From 99d946e6191056f747225d36a2408085624b516e Mon Sep 17 00:00:00 2001
From: moneromooo-monero <moneromooo-monero@users.noreply.github.com>
Date: Sun, 6 Jan 2019 19:49:52 +0000
Subject: [PATCH] ringct: encode 8 byte amount, saving 24 bytes per output

Found by knaccc
---
 .../cryptonote_boost_serialization.h          | 12 +++++++-
 src/device/device.hpp                         |  4 +--
 src/device/device_default.cpp                 |  8 +++---
 src/device/device_default.hpp                 |  4 +--
 src/device/device_ledger.cpp                  |  8 +++---
 src/device/device_ledger.hpp                  |  4 +--
 src/ringct/rctOps.cpp                         | 28 ++++++++++++++++---
 src/ringct/rctOps.h                           |  4 +--
 src/ringct/rctSigs.cpp                        |  8 +++---
 src/ringct/rctTypes.h                         | 15 +++++++++-
 src/wallet/wallet2.cpp                        |  4 +--
 tests/core_tests/multisig.cpp                 |  2 +-
 tests/unit_tests/device.cpp                   | 22 +++++++++++++--
 tests/unit_tests/ringct.cpp                   |  4 +--
 14 files changed, 93 insertions(+), 34 deletions(-)

diff --git a/src/cryptonote_basic/cryptonote_boost_serialization.h b/src/cryptonote_basic/cryptonote_boost_serialization.h
index e3d0ec18f..6f26d8756 100644
--- a/src/cryptonote_basic/cryptonote_boost_serialization.h
+++ b/src/cryptonote_basic/cryptonote_boost_serialization.h
@@ -45,6 +45,8 @@
 #include "ringct/rctTypes.h"
 #include "ringct/rctOps.h"
 
+BOOST_CLASS_VERSION(rct::ecdhTuple, 1)
+
 //namespace cryptonote {
 namespace boost
 {
@@ -248,7 +250,15 @@ namespace boost
   inline void serialize(Archive &a, rct::ecdhTuple &x, const boost::serialization::version_type ver)
   {
     a & x.mask;
-    a & x.amount;
+    if (ver < 1)
+    {
+      a & x.amount;
+      return;
+    }
+    crypto::hash8 &amount = (crypto::hash8&)x.amount;
+    if (!Archive::is_saving::value)
+      memset(&x.amount, 0, sizeof(x.amount));
+    a & amount;
     // a & x.senderPk; // not serialized, as we do not use it in monero currently
   }
 
diff --git a/src/device/device.hpp b/src/device/device.hpp
index 399648f01..bdb608907 100644
--- a/src/device/device.hpp
+++ b/src/device/device.hpp
@@ -208,8 +208,8 @@ namespace hw {
             return encrypt_payment_id(payment_id, public_key, secret_key);
         }
 
-        virtual bool  ecdhEncode(rct::ecdhTuple & unmasked, const rct::key & sharedSec) = 0;
-        virtual bool  ecdhDecode(rct::ecdhTuple & masked, const rct::key & sharedSec) = 0;
+        virtual bool  ecdhEncode(rct::ecdhTuple & unmasked, const rct::key & sharedSec, bool short_amount) = 0;
+        virtual bool  ecdhDecode(rct::ecdhTuple & masked, const rct::key & sharedSec, bool short_amount) = 0;
 
         virtual bool  add_output_key_mapping(const crypto::public_key &Aout, const crypto::public_key &Bout, const bool is_subaddress, const size_t real_output_index,
                                              const rct::key &amount_key,  const crypto::public_key &out_eph_public_key) = 0;
diff --git a/src/device/device_default.cpp b/src/device/device_default.cpp
index 2286998a4..cb2f4e266 100644
--- a/src/device/device_default.cpp
+++ b/src/device/device_default.cpp
@@ -302,13 +302,13 @@ namespace hw {
             return true;
         }
 
-        bool  device_default::ecdhEncode(rct::ecdhTuple & unmasked, const rct::key & sharedSec) {
-            rct::ecdhEncode(unmasked, sharedSec);
+        bool  device_default::ecdhEncode(rct::ecdhTuple & unmasked, const rct::key & sharedSec, bool short_amount) {
+            rct::ecdhEncode(unmasked, sharedSec, short_amount);
             return true;
         }
 
-        bool  device_default::ecdhDecode(rct::ecdhTuple & masked, const rct::key & sharedSec) {
-            rct::ecdhDecode(masked, sharedSec);
+        bool  device_default::ecdhDecode(rct::ecdhTuple & masked, const rct::key & sharedSec, bool short_amount) {
+            rct::ecdhDecode(masked, sharedSec, short_amount);
             return true;
         }
 
diff --git a/src/device/device_default.hpp b/src/device/device_default.hpp
index 5c59a9066..54d159b11 100644
--- a/src/device/device_default.hpp
+++ b/src/device/device_default.hpp
@@ -111,8 +111,8 @@ namespace hw {
 
             bool  encrypt_payment_id(crypto::hash8 &payment_id, const crypto::public_key &public_key, const crypto::secret_key &secret_key) override;
 
-            bool  ecdhEncode(rct::ecdhTuple & unmasked, const rct::key & sharedSec) override;
-            bool  ecdhDecode(rct::ecdhTuple & masked, const rct::key & sharedSec) override;
+            bool  ecdhEncode(rct::ecdhTuple & unmasked, const rct::key & sharedSec, bool short_amount) override;
+            bool  ecdhDecode(rct::ecdhTuple & masked, const rct::key & sharedSec, bool short_amount) override;
 
             bool  add_output_key_mapping(const crypto::public_key &Aout, const crypto::public_key &Bout, const bool is_subaddress, const size_t real_output_index,
                                          const rct::key &amount_key,  const crypto::public_key &out_eph_public_key) override;
diff --git a/src/device/device_ledger.cpp b/src/device/device_ledger.cpp
index bfb41bbe4..afe3bea62 100644
--- a/src/device/device_ledger.cpp
+++ b/src/device/device_ledger.cpp
@@ -1140,13 +1140,13 @@ namespace hw {
         return true;
     }
 
-    bool  device_ledger::ecdhEncode(rct::ecdhTuple & unmasked, const rct::key & AKout) {
+    bool  device_ledger::ecdhEncode(rct::ecdhTuple & unmasked, const rct::key & AKout, bool short_amount) {
         AUTO_LOCK_CMD();
 
         #ifdef DEBUG_HWDEVICE
         const rct::key AKout_x =   hw::ledger::decrypt(AKout);
         rct::ecdhTuple unmasked_x = unmasked;
-        this->controle_device->ecdhEncode(unmasked_x, AKout_x);
+        this->controle_device->ecdhEncode(unmasked_x, AKout_x, short_amount);
         #endif
 
         int offset = set_command_header_noopt(INS_BLIND);
@@ -1177,13 +1177,13 @@ namespace hw {
         return true;
     }
 
-    bool  device_ledger::ecdhDecode(rct::ecdhTuple & masked, const rct::key & AKout) {
+    bool  device_ledger::ecdhDecode(rct::ecdhTuple & masked, const rct::key & AKout, bool short_amount) {
         AUTO_LOCK_CMD();
 
         #ifdef DEBUG_HWDEVICE
         const rct::key AKout_x =   hw::ledger::decrypt(AKout);
         rct::ecdhTuple masked_x = masked;
-        this->controle_device->ecdhDecode(masked_x, AKout_x);
+        this->controle_device->ecdhDecode(masked_x, AKout_x, short_amount);
         #endif
 
         int offset = set_command_header_noopt(INS_UNBLIND);
diff --git a/src/device/device_ledger.hpp b/src/device/device_ledger.hpp
index 2f5beb044..88c9419a8 100644
--- a/src/device/device_ledger.hpp
+++ b/src/device/device_ledger.hpp
@@ -191,8 +191,8 @@ namespace hw {
 
         bool  encrypt_payment_id(crypto::hash8 &payment_id, const crypto::public_key &public_key, const crypto::secret_key &secret_key) override;
 
-        bool  ecdhEncode(rct::ecdhTuple & unmasked, const rct::key & sharedSec) override;
-        bool  ecdhDecode(rct::ecdhTuple & masked, const rct::key & sharedSec) override;
+        bool  ecdhEncode(rct::ecdhTuple & unmasked, const rct::key & sharedSec, bool short_format) override;
+        bool  ecdhDecode(rct::ecdhTuple & masked, const rct::key & sharedSec, bool short_format) override;
 
         bool  add_output_key_mapping(const crypto::public_key &Aout, const crypto::public_key &Bout, const bool is_subaddress, const size_t real_output_index,
                                      const rct::key &amount_key,  const crypto::public_key &out_eph_public_key) override;
diff --git a/src/ringct/rctOps.cpp b/src/ringct/rctOps.cpp
index 0ec654af6..b28aa4fe6 100644
--- a/src/ringct/rctOps.cpp
+++ b/src/ringct/rctOps.cpp
@@ -670,18 +670,38 @@ namespace rct {
 
     //Elliptic Curve Diffie Helman: encodes and decodes the amount b and mask a
     // where C= aG + bH
-    void ecdhEncode(ecdhTuple & unmasked, const key & sharedSec) {
+    static key ecdhHash(const key &k)
+    {
+      char data[38];
+      rct::key hash;
+      memcpy(data, "amount", 6);
+      memcpy(data + 6, &k, sizeof(k));
+      cn_fast_hash(hash, data, sizeof(data));
+      return hash;
+    }
+    static void xor8(key &v, const key &k)
+    {
+      for (int i = 0; i < 8; ++i)
+        v.bytes[i] ^= k.bytes[i];
+    }
+    void ecdhEncode(ecdhTuple & unmasked, const key & sharedSec, bool short_amount) {
         key sharedSec1 = hash_to_scalar(sharedSec);
         key sharedSec2 = hash_to_scalar(sharedSec1);
         //encode
         sc_add(unmasked.mask.bytes, unmasked.mask.bytes, sharedSec1.bytes);
-        sc_add(unmasked.amount.bytes, unmasked.amount.bytes, sharedSec2.bytes);
+        if (short_amount)
+          xor8(unmasked.amount, ecdhHash(sharedSec));
+        else
+          sc_add(unmasked.amount.bytes, unmasked.amount.bytes, sharedSec2.bytes);
     }
-    void ecdhDecode(ecdhTuple & masked, const key & sharedSec) {
+    void ecdhDecode(ecdhTuple & masked, const key & sharedSec, bool short_amount) {
         key sharedSec1 = hash_to_scalar(sharedSec);
         key sharedSec2 = hash_to_scalar(sharedSec1);
         //decode
         sc_sub(masked.mask.bytes, masked.mask.bytes, sharedSec1.bytes);
-        sc_sub(masked.amount.bytes, masked.amount.bytes, sharedSec2.bytes);
+        if (short_amount)
+          xor8(masked.amount, ecdhHash(sharedSec));
+        else
+          sc_sub(masked.amount.bytes, masked.amount.bytes, sharedSec2.bytes);
     }
 }
diff --git a/src/ringct/rctOps.h b/src/ringct/rctOps.h
index 60e920b3a..01cdd6fd7 100644
--- a/src/ringct/rctOps.h
+++ b/src/ringct/rctOps.h
@@ -182,7 +182,7 @@ namespace rct {
 
     //Elliptic Curve Diffie Helman: encodes and decodes the amount b and mask a
     // where C= aG + bH
-    void ecdhEncode(ecdhTuple & unmasked, const key & sharedSec);
-    void ecdhDecode(ecdhTuple & masked, const key & sharedSec);
+    void ecdhEncode(ecdhTuple & unmasked, const key & sharedSec, bool short_amount);
+    void ecdhDecode(ecdhTuple & masked, const key & sharedSec, bool short_amount);
 }
 #endif  /* RCTOPS_H */
diff --git a/src/ringct/rctSigs.cpp b/src/ringct/rctSigs.cpp
index 298afd0d9..6687c91cd 100644
--- a/src/ringct/rctSigs.cpp
+++ b/src/ringct/rctSigs.cpp
@@ -716,7 +716,7 @@ namespace rct {
             //mask amount and mask
             rv.ecdhInfo[i].mask = copy(outSk[i].mask);
             rv.ecdhInfo[i].amount = d2h(amounts[i]);
-            hwdev.ecdhEncode(rv.ecdhInfo[i], amount_keys[i]);
+            hwdev.ecdhEncode(rv.ecdhInfo[i], amount_keys[i], rv.type == RCTTypeBulletproof2);
         }
 
         //set txn fee
@@ -853,7 +853,7 @@ namespace rct {
             //mask amount and mask
             rv.ecdhInfo[i].mask = copy(outSk[i].mask);
             rv.ecdhInfo[i].amount = d2h(outamounts[i]);
-            hwdev.ecdhEncode(rv.ecdhInfo[i], amount_keys[i]);
+            hwdev.ecdhEncode(rv.ecdhInfo[i], amount_keys[i], rv.type == RCTTypeBulletproof2);
         }
             
         //set txn fee
@@ -1151,7 +1151,7 @@ namespace rct {
 
         //mask amount and mask
         ecdhTuple ecdh_info = rv.ecdhInfo[i];
-        hwdev.ecdhDecode(ecdh_info, sk);
+        hwdev.ecdhDecode(ecdh_info, sk, rv.type == RCTTypeBulletproof2);
         mask = ecdh_info.mask;
         key amount = ecdh_info.amount;
         key C = rv.outPk[i].mask;
@@ -1181,7 +1181,7 @@ namespace rct {
 
         //mask amount and mask
         ecdhTuple ecdh_info = rv.ecdhInfo[i];
-        hwdev.ecdhDecode(ecdh_info, sk);
+        hwdev.ecdhDecode(ecdh_info, sk, rv.type == RCTTypeBulletproof2);
         mask = ecdh_info.mask;
         key amount = ecdh_info.amount;
         key C = rv.outPk[i].mask;
diff --git a/src/ringct/rctTypes.h b/src/ringct/rctTypes.h
index 5578a51dc..54fca1d05 100644
--- a/src/ringct/rctTypes.h
+++ b/src/ringct/rctTypes.h
@@ -283,7 +283,20 @@ namespace rct {
             return false;
           for (size_t i = 0; i < outputs; ++i)
           {
-            FIELDS(ecdhInfo[i])
+            if (type == RCTTypeBulletproof2)
+            {
+              ar.begin_object();
+              FIELD_N("mask", ecdhInfo[i].mask);
+              if (!typename Archive<W>::is_saving())
+                memset(ecdhInfo[i].amount.bytes, 0, sizeof(ecdhInfo[i].amount.bytes));
+              crypto::hash8 &amount = (crypto::hash8&)ecdhInfo[i].amount;
+              FIELD(amount);
+              ar.end_object();
+            }
+            else
+            {
+              FIELDS(ecdhInfo[i])
+            }
             if (outputs - i > 1)
               ar.delimit_array();
           }
diff --git a/src/wallet/wallet2.cpp b/src/wallet/wallet2.cpp
index 32bf4370a..7517c1d99 100644
--- a/src/wallet/wallet2.cpp
+++ b/src/wallet/wallet2.cpp
@@ -10143,7 +10143,7 @@ void wallet2::check_tx_key_helper(const crypto::hash &txid, const crypto::key_de
         crypto::secret_key scalar1;
         hwdev.derivation_to_scalar(found_derivation, n, scalar1);
         rct::ecdhTuple ecdh_info = tx.rct_signatures.ecdhInfo[n];
-        hwdev.ecdhDecode(ecdh_info, rct::sk2rct(scalar1));
+        hwdev.ecdhDecode(ecdh_info, rct::sk2rct(scalar1), tx.rct_signatures.type == rct::RCTTypeBulletproof2);
         const rct::key C = tx.rct_signatures.outPk[n].mask;
         rct::key Ctmp;
         THROW_WALLET_EXCEPTION_IF(sc_check(ecdh_info.mask.bytes) != 0, error::wallet_internal_error, "Bad ECDH input mask");
@@ -10648,7 +10648,7 @@ bool wallet2::check_reserve_proof(const cryptonote::account_public_address &addr
       crypto::secret_key shared_secret;
       crypto::derivation_to_scalar(derivation, proof.index_in_tx, shared_secret);
       rct::ecdhTuple ecdh_info = tx.rct_signatures.ecdhInfo[proof.index_in_tx];
-      rct::ecdhDecode(ecdh_info, rct::sk2rct(shared_secret));
+      rct::ecdhDecode(ecdh_info, rct::sk2rct(shared_secret), tx.rct_signatures.type == rct::RCTTypeBulletproof2);
       amount = rct::h2d(ecdh_info.amount);
     }
     total += amount;
diff --git a/tests/core_tests/multisig.cpp b/tests/core_tests/multisig.cpp
index 7dd6a89e2..37fda6643 100644
--- a/tests/core_tests/multisig.cpp
+++ b/tests/core_tests/multisig.cpp
@@ -455,7 +455,7 @@ bool gen_multisig_tx_validation_base::generate_with(std::vector<test_event_entry
       crypto::secret_key scalar1;
       crypto::derivation_to_scalar(derivation, n, scalar1);
       rct::ecdhTuple ecdh_info = tx.rct_signatures.ecdhInfo[n];
-      rct::ecdhDecode(ecdh_info, rct::sk2rct(scalar1));
+      rct::ecdhDecode(ecdh_info, rct::sk2rct(scalar1), tx.rct_signatures.type == rct::RCTTypeBulletproof2);
       rct::key C = tx.rct_signatures.outPk[n].mask;
       rct::addKeys2(Ctmp, ecdh_info.mask, ecdh_info.amount, rct::H);
       CHECK_AND_ASSERT_MES(rct::equalKeys(C, Ctmp), false, "Failed to decode amount");
diff --git a/tests/unit_tests/device.cpp b/tests/unit_tests/device.cpp
index 50ccec9fa..3ae748145 100644
--- a/tests/unit_tests/device.cpp
+++ b/tests/unit_tests/device.cpp
@@ -114,7 +114,7 @@ TEST(device, ops)
   ASSERT_EQ(ki0, ki1);
 }
 
-TEST(device, ecdh)
+TEST(device, ecdh32)
 {
   hw::core::device_default dev;
   rct::ecdhTuple tuple, tuple2;
@@ -123,8 +123,24 @@ TEST(device, ecdh)
   tuple.amount = rct::skGen();
   tuple.senderPk = rct::pkGen();
   tuple2 = tuple;
-  dev.ecdhEncode(tuple, key);
-  dev.ecdhDecode(tuple, key);
+  dev.ecdhEncode(tuple, key, false);
+  dev.ecdhDecode(tuple, key, false);
+  ASSERT_EQ(tuple2.mask, tuple.mask);
+  ASSERT_EQ(tuple2.amount, tuple.amount);
+  ASSERT_EQ(tuple2.senderPk, tuple.senderPk);
+}
+
+TEST(device, ecdh8)
+{
+  hw::core::device_default dev;
+  rct::ecdhTuple tuple, tuple2;
+  rct::key key = rct::skGen();
+  tuple.mask = rct::skGen();
+  tuple.amount = rct::skGen();
+  tuple.senderPk = rct::pkGen();
+  tuple2 = tuple;
+  dev.ecdhEncode(tuple, key, true);
+  dev.ecdhDecode(tuple, key, true);
   ASSERT_EQ(tuple2.mask, tuple.mask);
   ASSERT_EQ(tuple2.amount, tuple.amount);
   ASSERT_EQ(tuple2.senderPk, tuple.senderPk);
diff --git a/tests/unit_tests/ringct.cpp b/tests/unit_tests/ringct.cpp
index ec704daa1..905b8471a 100644
--- a/tests/unit_tests/ringct.cpp
+++ b/tests/unit_tests/ringct.cpp
@@ -843,8 +843,8 @@ TEST(ringct, ecdh_roundtrip)
     t0.amount = d2h(amount);
 
     t1 = t0;
-    ecdhEncode(t1, k);
-    ecdhDecode(t1, k);
+    ecdhEncode(t1, k, true);
+    ecdhDecode(t1, k, true);
     ASSERT_TRUE(t0.mask == t1.mask);
     ASSERT_TRUE(equalKeys(t0.mask, t1.mask));
     ASSERT_TRUE(t0.amount == t1.amount);