From 3a0451a8be43154f0da06dd9c693ed6f0c657042 Mon Sep 17 00:00:00 2001
From: Sarang Noether <32460187+SarangNoether@users.noreply.github.com>
Date: Tue, 27 Aug 2019 16:22:44 -0400
Subject: [PATCH] MLSAG speedup and additional checks

---
 src/ringct/bulletproofs.cc          |  5 +-
 src/ringct/rctOps.cpp               | 44 +++--------------
 src/ringct/rctOps.h                 |  5 +-
 src/ringct/rctSigs.cpp              | 75 +++++++++++++++--------------
 tests/performance_tests/main.cpp    | 20 +-------
 tests/performance_tests/rct_mlsag.h | 12 ++---
 tests/unit_tests/ringct.cpp         | 15 +++++-
 7 files changed, 75 insertions(+), 101 deletions(-)

diff --git a/src/ringct/bulletproofs.cc b/src/ringct/bulletproofs.cc
index 6270d4d14..ff6fee95c 100644
--- a/src/ringct/bulletproofs.cc
+++ b/src/ringct/bulletproofs.cc
@@ -101,7 +101,10 @@ static rct::key get_exponent(const rct::key &base, size_t idx)
 {
   static const std::string salt("bulletproof");
   std::string hashed = std::string((const char*)base.bytes, sizeof(base)) + salt + tools::get_varint_data(idx);
-  const rct::key e = rct::hashToPoint(rct::hash2rct(crypto::cn_fast_hash(hashed.data(), hashed.size())));
+  rct::key e;
+  ge_p3 e_p3;
+  rct::hash_to_p3(e_p3, rct::hash2rct(crypto::cn_fast_hash(hashed.data(), hashed.size())));
+  ge_p3_tobytes(e.bytes, &e_p3);
   CHECK_AND_ASSERT_THROW_MES(!(e == rct::identity()), "Exponent is point at infinity");
   return e;
 }
diff --git a/src/ringct/rctOps.cpp b/src/ringct/rctOps.cpp
index b5499262f..6e4d063df 100644
--- a/src/ringct/rctOps.cpp
+++ b/src/ringct/rctOps.cpp
@@ -620,45 +620,17 @@ namespace rct {
        sc_reduce32(rv.bytes);
        return rv;
    }
-
-    key hashToPointSimple(const key & hh) {
-        key pointk;
-        ge_p1p1 point2;
-        ge_p2 point;
-        ge_p3 res;
-        key h = cn_fast_hash(hh); 
-        CHECK_AND_ASSERT_THROW_MES_L1(ge_frombytes_vartime(&res, h.bytes) == 0, "ge_frombytes_vartime failed at "+boost::lexical_cast<std::string>(__LINE__));
-        ge_p3_to_p2(&point, &res);
-        ge_mul8(&point2, &point);
-        ge_p1p1_to_p3(&res, &point2);
-        ge_p3_tobytes(pointk.bytes, &res);
-        return pointk;
-    }    
     
-    key hashToPoint(const key & hh) {
-        key pointk;
-        ge_p2 point;
-        ge_p1p1 point2;
-        ge_p3 res;
-        key h = cn_fast_hash(hh); 
-        ge_fromfe_frombytes_vartime(&point, h.bytes);
-        ge_mul8(&point2, &point);
-        ge_p1p1_to_p3(&res, &point2);        
-        ge_p3_tobytes(pointk.bytes, &res);
-        return pointk;
+    // Hash a key to p3 representation
+    void hash_to_p3(ge_p3 &hash8_p3, const key &k) {
+      key hash_key = cn_fast_hash(k);
+      ge_p2 hash_p2;
+      ge_fromfe_frombytes_vartime(&hash_p2, hash_key.bytes);
+      ge_p1p1 hash8_p1p1;
+      ge_mul8(&hash8_p1p1, &hash_p2);
+      ge_p1p1_to_p3(&hash8_p3, &hash8_p1p1);
     }
 
-    void hashToPoint(key & pointk, const key & hh) {
-        ge_p2 point;
-        ge_p1p1 point2;
-        ge_p3 res;
-        key h = cn_fast_hash(hh); 
-        ge_fromfe_frombytes_vartime(&point, h.bytes);
-        ge_mul8(&point2, &point);
-        ge_p1p1_to_p3(&res, &point2);        
-        ge_p3_tobytes(pointk.bytes, &res);
-    }    
-
     //sums a vector of curve points (for scalars use sc_add)
     void sumKeys(key & Csum, const keyV &  Cis) {
         identity(Csum);
diff --git a/src/ringct/rctOps.h b/src/ringct/rctOps.h
index dd6d87593..c24d48e9a 100644
--- a/src/ringct/rctOps.h
+++ b/src/ringct/rctOps.h
@@ -172,10 +172,7 @@ namespace rct {
     key cn_fast_hash(const key64 keys);
     key hash_to_scalar(const key64 keys);
 
-    //returns hashToPoint as described in https://github.com/ShenNoether/ge_fromfe_writeup 
-    key hashToPointSimple(const key &in);
-    key hashToPoint(const key &in);
-    void hashToPoint(key &out, const key &in);
+    void hash_to_p3(ge_p3 &hash8_p3, const key &k);
 
     //sums a vector of curve points (for scalars use sc_add)
     void sumKeys(key & Csum, const key &Cis);
diff --git a/src/ringct/rctSigs.cpp b/src/ringct/rctSigs.cpp
index ff2a81d43..a7b265d63 100644
--- a/src/ringct/rctSigs.cpp
+++ b/src/ringct/rctSigs.cpp
@@ -163,14 +163,11 @@ namespace rct {
       return verifyBorromean(bb, P1_p3, P2_p3);
     }
 
-    //Multilayered Spontaneous Anonymous Group Signatures (MLSAG signatures)
-    //This is a just slghtly more efficient version than the ones described below
-    //(will be explained in more detail in Ring Multisig paper
-    //These are aka MG signatutes in earlier drafts of the ring ct paper
-    // c.f. https://eprint.iacr.org/2015/1098 section 2. 
-    // Gen creates a signature which proves that for some column in the keymatrix "pk"
-    //   the signer knows a secret key for each row in that column
-    // Ver verifies that the MG sig was created correctly        
+    // MLSAG signatures
+    // See paper by Noether (https://eprint.iacr.org/2015/1098)
+    // This generalization allows for some dimensions not to require linkability;
+    //   this is used in practice for commitment data within signatures
+    // Note that using more than one linkable dimension is not recommended.
     mgSig MLSAG_Gen(const key &message, const keyM & pk, const keyV & xx, const multisig_kLRki *kLRki, key *mscout, const unsigned int index, size_t dsRows, hw::device &hwdev) {
         mgSig rv;
         size_t cols = pk.size();
@@ -188,6 +185,7 @@ namespace rct {
 
         size_t i = 0, j = 0, ii = 0;
         key c, c_old, L, R, Hi;
+        ge_p3 Hi_p3;
         sc_0(c_old.bytes);
         vector<geDsmp> Ip(dsRows);
         rv.II = keyV(dsRows);
@@ -208,7 +206,8 @@ namespace rct {
               rv.II[i] = kLRki->ki;
             }
             else {
-              Hi = hashToPoint(pk[index][i]);
+              hash_to_p3(Hi_p3, pk[index][i]);
+              ge_p3_tobytes(Hi.bytes, &Hi_p3);
               hwdev.mlsag_prepare(Hi, xx[i], alpha[i] , aG[i] , aHP[i] , rv.II[i]);
               toHash[3 * i + 2] = aG[i];
               toHash[3 * i + 3] = aHP[i];
@@ -235,7 +234,8 @@ namespace rct {
             sc_0(c.bytes);
             for (j = 0; j < dsRows; j++) {
                 addKeys2(L, rv.ss[i][j], c_old, pk[i][j]);
-                hashToPoint(Hi, pk[i][j]);
+                hash_to_p3(Hi_p3, pk[i][j]);
+                ge_p3_tobytes(Hi.bytes, &Hi_p3);
                 addKeys3(R, rv.ss[i][j], Hi, c_old, Ip[j].k);
                 toHash[3 * j + 1] = pk[i][j];
                 toHash[3 * j + 2] = L; 
@@ -260,43 +260,42 @@ namespace rct {
         return rv;
     }
     
-    //Multilayered Spontaneous Anonymous Group Signatures (MLSAG signatures)
-    //This is a just slghtly more efficient version than the ones described below
-    //(will be explained in more detail in Ring Multisig paper
-    //These are aka MG signatutes in earlier drafts of the ring ct paper
-    // c.f. https://eprint.iacr.org/2015/1098 section 2. 
-    // Gen creates a signature which proves that for some column in the keymatrix "pk"
-    //   the signer knows a secret key for each row in that column
-    // Ver verifies that the MG sig was created correctly            
+    // MLSAG signatures
+    // See paper by Noether (https://eprint.iacr.org/2015/1098)
+    // This generalization allows for some dimensions not to require linkability;
+    //   this is used in practice for commitment data within signatures
+    // Note that using more than one linkable dimension is not recommended.
     bool MLSAG_Ver(const key &message, const keyM & pk, const mgSig & rv, size_t dsRows) {
-
         size_t cols = pk.size();
-        CHECK_AND_ASSERT_MES(cols >= 2, false, "Error! What is c if cols = 1!");
+        CHECK_AND_ASSERT_MES(cols >= 2, false, "Signature must contain more than one public key");
         size_t rows = pk[0].size();
-        CHECK_AND_ASSERT_MES(rows >= 1, false, "Empty pk");
+        CHECK_AND_ASSERT_MES(rows >= 1, false, "Bad total row number");
         for (size_t i = 1; i < cols; ++i) {
-          CHECK_AND_ASSERT_MES(pk[i].size() == rows, false, "pk is not rectangular");
+          CHECK_AND_ASSERT_MES(pk[i].size() == rows, false, "Bad public key matrix dimensions");
         }
-        CHECK_AND_ASSERT_MES(rv.II.size() == dsRows, false, "Bad II size");
-        CHECK_AND_ASSERT_MES(rv.ss.size() == cols, false, "Bad rv.ss size");
+        CHECK_AND_ASSERT_MES(rv.II.size() == dsRows, false, "Wrong number of key images present");
+        CHECK_AND_ASSERT_MES(rv.ss.size() == cols, false, "Bad scalar matrix dimensions");
         for (size_t i = 0; i < cols; ++i) {
-          CHECK_AND_ASSERT_MES(rv.ss[i].size() == rows, false, "rv.ss is not rectangular");
+          CHECK_AND_ASSERT_MES(rv.ss[i].size() == rows, false, "Bad scalar matrix dimensions");
         }
-        CHECK_AND_ASSERT_MES(dsRows <= rows, false, "Bad dsRows value");
+        CHECK_AND_ASSERT_MES(dsRows <= rows, false, "Non-double-spend rows cannot exceed total rows");
 
-        for (size_t i = 0; i < rv.ss.size(); ++i)
-          for (size_t j = 0; j < rv.ss[i].size(); ++j)
-            CHECK_AND_ASSERT_MES(sc_check(rv.ss[i][j].bytes) == 0, false, "Bad ss slot");
-        CHECK_AND_ASSERT_MES(sc_check(rv.cc.bytes) == 0, false, "Bad cc");
+        for (size_t i = 0; i < rv.ss.size(); ++i) {
+          for (size_t j = 0; j < rv.ss[i].size(); ++j) {
+            CHECK_AND_ASSERT_MES(sc_check(rv.ss[i][j].bytes) == 0, false, "Bad signature scalar");
+          }
+        }
+        CHECK_AND_ASSERT_MES(sc_check(rv.cc.bytes) == 0, false, "Bad initial signature hash");
 
         size_t i = 0, j = 0, ii = 0;
-        key c,  L, R, Hi;
+        key c,  L, R;
         key c_old = copy(rv.cc);
         vector<geDsmp> Ip(dsRows);
         for (i = 0 ; i < dsRows ; i++) {
+            CHECK_AND_ASSERT_MES(!(rv.II[i] == rct::identity()), false, "Bad key image");
             precomp(Ip[i].k, rv.II[i]);
         }
-        size_t ndsRows = 3 * dsRows; //non Double Spendable Rows (see identity chains paper
+        size_t ndsRows = 3 * dsRows; // number of dimensions not requiring linkability
         keyV toHash(1 + 3 * dsRows + 2 * (rows - dsRows));
         toHash[0] = message;
         i = 0;
@@ -304,9 +303,14 @@ namespace rct {
             sc_0(c.bytes);
             for (j = 0; j < dsRows; j++) {
                 addKeys2(L, rv.ss[i][j], c_old, pk[i][j]);
-                hashToPoint(Hi, pk[i][j]);
-                CHECK_AND_ASSERT_MES(!(Hi == rct::identity()), false, "Data hashed to point at infinity");
-                addKeys3(R, rv.ss[i][j], Hi, c_old, Ip[j].k);
+
+                // Compute R directly
+                ge_p3 hash8_p3;
+                hash_to_p3(hash8_p3, pk[i][j]);
+                ge_p2 R_p2;
+                ge_double_scalarmult_precomp_vartime(&R_p2, rv.ss[i][j].bytes, &hash8_p3, c_old.bytes, Ip[j].k);
+                ge_tobytes(R.bytes, &R_p2);
+
                 toHash[3 * j + 1] = pk[i][j];
                 toHash[3 * j + 2] = L; 
                 toHash[3 * j + 3] = R;
@@ -317,6 +321,7 @@ namespace rct {
                 toHash[ndsRows + 2 * ii + 2] = L;
             }
             c = hash_to_scalar(toHash);
+            CHECK_AND_ASSERT_MES(!(c == rct::zero()), false, "Bad signature hash");
             copy(c_old, c);
             i = (i + 1);
         }
diff --git a/tests/performance_tests/main.cpp b/tests/performance_tests/main.cpp
index c32e0df20..bd7414c59 100644
--- a/tests/performance_tests/main.cpp
+++ b/tests/performance_tests/main.cpp
@@ -57,7 +57,6 @@
 #include "rct_mlsag.h"
 #include "equality.h"
 #include "range_proof.h"
-#include "rct_mlsag.h"
 #include "bulletproof.h"
 #include "crypto_ops.h"
 #include "multiexp.h"
@@ -214,14 +213,8 @@ int main(int argc, char** argv)
   TEST_PERFORMANCE1(filter, p, test_cn_fast_hash, 32);
   TEST_PERFORMANCE1(filter, p, test_cn_fast_hash, 16384);
 
-  TEST_PERFORMANCE3(filter, p, test_ringct_mlsag, 1, 3, false);
-  TEST_PERFORMANCE3(filter, p, test_ringct_mlsag, 1, 5, false);
-  TEST_PERFORMANCE3(filter, p, test_ringct_mlsag, 1, 10, false);
-  TEST_PERFORMANCE3(filter, p, test_ringct_mlsag, 1, 100, false);
-  TEST_PERFORMANCE3(filter, p, test_ringct_mlsag, 1, 3, true);
-  TEST_PERFORMANCE3(filter, p, test_ringct_mlsag, 1, 5, true);
-  TEST_PERFORMANCE3(filter, p, test_ringct_mlsag, 1, 10, true);
-  TEST_PERFORMANCE3(filter, p, test_ringct_mlsag, 1, 100, true);
+  TEST_PERFORMANCE2(filter, p, test_ringct_mlsag, 11, false);
+  TEST_PERFORMANCE2(filter, p, test_ringct_mlsag, 11, true);
 
   TEST_PERFORMANCE2(filter, p, test_equality, memcmp32, true);
   TEST_PERFORMANCE2(filter, p, test_equality, memcmp32, false);
@@ -251,15 +244,6 @@ int main(int argc, char** argv)
   TEST_PERFORMANCE6(filter, p, test_aggregated_bulletproof, false, 2, 1, 1, 0, 64);
   TEST_PERFORMANCE6(filter, p, test_aggregated_bulletproof, true, 2, 1, 1, 0, 64); // 64 proof, each with 2 amounts
 
-  TEST_PERFORMANCE3(filter, p, test_ringct_mlsag, 1, 3, false);
-  TEST_PERFORMANCE3(filter, p, test_ringct_mlsag, 1, 5, false);
-  TEST_PERFORMANCE3(filter, p, test_ringct_mlsag, 1, 10, false);
-  TEST_PERFORMANCE3(filter, p, test_ringct_mlsag, 1, 100, false);
-  TEST_PERFORMANCE3(filter, p, test_ringct_mlsag, 1, 3, true);
-  TEST_PERFORMANCE3(filter, p, test_ringct_mlsag, 1, 5, true);
-  TEST_PERFORMANCE3(filter, p, test_ringct_mlsag, 1, 10, true);
-  TEST_PERFORMANCE3(filter, p, test_ringct_mlsag, 1, 100, true);
-
   TEST_PERFORMANCE1(filter, p, test_crypto_ops, op_sc_add);
   TEST_PERFORMANCE1(filter, p, test_crypto_ops, op_sc_sub);
   TEST_PERFORMANCE1(filter, p, test_crypto_ops, op_sc_mul);
diff --git a/tests/performance_tests/rct_mlsag.h b/tests/performance_tests/rct_mlsag.h
index 0141710f7..59eae074e 100644
--- a/tests/performance_tests/rct_mlsag.h
+++ b/tests/performance_tests/rct_mlsag.h
@@ -35,13 +35,13 @@
 
 #include "single_tx_test_base.h"
 
-template<size_t inputs, size_t ring_size, bool ver>
+template<size_t ring_size, bool ver>
 class test_ringct_mlsag : public single_tx_test_base
 {
 public:
   static const size_t cols = ring_size;
-  static const size_t rows = inputs;
-  static const size_t loop_count = 100;
+  static const size_t rows = 2; // single spend and commitment data
+  static const size_t loop_count = 1000;
 
   bool init()
   {
@@ -65,7 +65,7 @@ public:
     {
         sk[j] = xm[ind][j];
     }
-    IIccss = MLSAG_Gen(rct::identity(), P, sk, NULL, NULL, ind, rows, hw::get_device("default"));
+    IIccss = MLSAG_Gen(rct::identity(), P, sk, NULL, NULL, ind, rows-1, hw::get_device("default"));
 
     return true;
   }
@@ -73,9 +73,9 @@ public:
   bool test()
   {
     if (ver)
-      MLSAG_Ver(rct::identity(), P, IIccss, rows);
+      MLSAG_Ver(rct::identity(), P, IIccss, rows-1);
     else
-      MLSAG_Gen(rct::identity(), P, sk, NULL, NULL, ind, rows, hw::get_device("default"));
+      MLSAG_Gen(rct::identity(), P, sk, NULL, NULL, ind, rows-1, hw::get_device("default"));
     return true;
   }
 
diff --git a/tests/unit_tests/ringct.cpp b/tests/unit_tests/ringct.cpp
index 4d51ec434..8788dba8d 100644
--- a/tests/unit_tests/ringct.cpp
+++ b/tests/unit_tests/ringct.cpp
@@ -788,7 +788,20 @@ TEST(ringct, HPow2)
 {
   key G = scalarmultBase(d2h(1));
 
-  key H = hashToPointSimple(G);
+  // Note that H is computed differently than standard hashing
+  // This method is not guaranteed to return a curvepoint for all inputs
+  // Don't use it elsewhere
+  key H = cn_fast_hash(G);
+  ge_p3 H_p3;
+  int decode = ge_frombytes_vartime(&H_p3, H.bytes);
+  ASSERT_EQ(decode, 0); // this is known to pass for the particular value G
+  ge_p2 H_p2;
+  ge_p3_to_p2(&H_p2, &H_p3);
+  ge_p1p1 H8_p1p1;
+  ge_mul8(&H8_p1p1, &H_p2);
+  ge_p1p1_to_p3(&H_p3, &H8_p1p1);
+  ge_p3_tobytes(H.bytes, &H_p3);
+
   for (int j = 0 ; j < ATOMS ; j++) {
     ASSERT_TRUE(equalKeys(H, H2[j]));
     addKeys(H, H, H);