diff --git a/src/wallet/wallet2.cpp b/src/wallet/wallet2.cpp index 74b19df3c..04be12c13 100644 --- a/src/wallet/wallet2.cpp +++ b/src/wallet/wallet2.cpp @@ -1881,7 +1881,7 @@ void wallet2::scan_tx(const std::unordered_set &txids) // TODO: handle this sweep case detached_blockchain_data dbd; dbd.original_chain_size = m_blockchain.size(); - if (m_blockchain.size() > txs_to_scan.lowest_height) + if (txs_to_scan.highest_height > 0) { // When connected to an untrusted daemon, if we will need to re-process 1+ // tx that the user did not request to scan, then we fail out because @@ -1920,7 +1920,7 @@ void wallet2::scan_tx(const std::unordered_set &txids) if (skip_to_height > m_blockchain.size()) { m_skip_to_height = skip_to_height; - LOG_PRINT_L0("Skipping refresh to height " << skip_to_height); + LOG_PRINT_L0("Next refresh will skip to height " << skip_to_height); // update last block reward here because the refresh loop won't necessarily set it try @@ -1932,9 +1932,7 @@ void wallet2::scan_tx(const std::unordered_set &txids) } catch (...) { MERROR("Failed getting block header at height " << txs_to_scan.highest_height); } - // TODO: use fast_refresh instead of refresh to update m_blockchain. It needs refactoring to work correctly here. - // Or don't refresh at all, and let it update on the next refresh loop. - refresh(is_trusted_daemon()); + // The wallet's blockchain state will now sync from the expected height correctly on next refresh loop } } //---------------------------------------------------------------------------------------------------- @@ -4346,7 +4344,7 @@ wallet2::detached_blockchain_data wallet2::detach_blockchain(uint64_t height, st uint64_t blocks_detached = 0; dbd.original_chain_size = m_blockchain.size(); - if (height >= m_blockchain.offset()) + if (height <= m_blockchain.size() && height >= m_blockchain.offset()) { for (uint64_t i = height; i < m_blockchain.size(); ++i) dbd.detached_blockchain.push_back(m_blockchain[i]); diff --git a/tests/functional_tests/transfer.py b/tests/functional_tests/transfer.py index ef80dc739..03dfd0397 100755 --- a/tests/functional_tests/transfer.py +++ b/tests/functional_tests/transfer.py @@ -888,12 +888,16 @@ class TransferTest(): print('Testing scan_tx') + def restore_wallet(wallet, seed, restore_height = 0): + try: wallet.close_wallet() + except: pass + wallet.restore_deterministic_wallet(seed = seed, restore_height = restore_height) + wallet.auto_refresh(enable = False) + assert wallet.get_transfers() == {} + # set up sender_wallet sender_wallet = self.wallet[0] - try: sender_wallet.close_wallet() - except: pass - sender_wallet.restore_deterministic_wallet(seed = seeds[0]) - sender_wallet.auto_refresh(enable = False) + restore_wallet(sender_wallet, seeds[0]) sender_wallet.refresh() res = sender_wallet.get_transfers() out_len = 0 if 'out' not in res else len(res.out) @@ -903,10 +907,7 @@ class TransferTest(): # set up receiver_wallet receiver_wallet = self.wallet[1] - try: receiver_wallet.close_wallet() - except: pass - receiver_wallet.restore_deterministic_wallet(seed = seeds[1]) - receiver_wallet.auto_refresh(enable = False) + restore_wallet(receiver_wallet, seeds[1]) receiver_wallet.refresh() res = receiver_wallet.get_transfers() in_len = 0 if 'in' not in res else len(res['in']) @@ -971,6 +972,7 @@ class TransferTest(): print('Checking scan_tx on outgoing tx before refresh') sender_wallet.scan_tx([txid]) + sender_wallet.refresh() res = sender_wallet.get_transfers() assert 'pending' not in res or len(res.pending) == 0 assert 'pool' not in res or len (res.pool) == 0 @@ -1011,9 +1013,7 @@ class TransferTest(): all_txs = out_txids + in_txids for test_type in ["all txs", "incoming first", "duplicates within", "duplicates across"]: print(test + ' (' + test_type + ')') - sender_wallet.close_wallet() - sender_wallet.restore_deterministic_wallet(seed = seeds[0], restore_height = height) - assert sender_wallet.get_transfers() == {} + restore_wallet(sender_wallet, seeds[0], height) if test_type == "all txs": sender_wallet.scan_tx(all_txs) elif test_type == "incoming first": @@ -1027,18 +1027,19 @@ class TransferTest(): sender_wallet.scan_tx(all_txs) else: assert True == False - diff_transfers(sender_wallet.get_transfers(), res) assert sender_wallet.get_balance().balance == expected_sender_balance + sender_wallet.refresh() + diff_transfers(sender_wallet.get_transfers(), res) print('Sanity check against outgoing wallet restored at height 0') - sender_wallet.close_wallet() - sender_wallet.restore_deterministic_wallet(seed = seeds[0], restore_height = 0) + restore_wallet(sender_wallet, seeds[0], 0) sender_wallet.refresh() diff_transfers(sender_wallet.get_transfers(), res) assert sender_wallet.get_balance().balance == expected_sender_balance print('Checking scan_tx on incoming txs before refresh') receiver_wallet.scan_tx([txid, miner_txid]) + receiver_wallet.refresh() res = receiver_wallet.get_transfers() assert 'pending' not in res or len(res.pending) == 0 assert 'pool' not in res or len (res.pool) == 0 @@ -1071,20 +1072,18 @@ class TransferTest(): txids = [x.txid for x in res['in']] if 'out' in res: txids = txids + [x.txid for x in res.out] - receiver_wallet.close_wallet() - receiver_wallet.restore_deterministic_wallet(seed = seeds[1], restore_height = height) - assert receiver_wallet.get_transfers() == {} + restore_wallet(receiver_wallet, seeds[1], height) receiver_wallet.scan_tx(txids) if 'out' in res: for i, out_tx in enumerate(res.out): if 'destinations' in out_tx: del res.out[i]['destinations'] # destinations are not expected after wallet restore - diff_transfers(receiver_wallet.get_transfers(), res) assert receiver_wallet.get_balance().balance == expected_receiver_balance + receiver_wallet.refresh() + diff_transfers(receiver_wallet.get_transfers(), res) print('Sanity check against incoming wallet restored at height 0') - receiver_wallet.close_wallet() - receiver_wallet.restore_deterministic_wallet(seed = seeds[1], restore_height = 0) + restore_wallet(receiver_wallet, seeds[1], 0) receiver_wallet.refresh() diff_transfers(receiver_wallet.get_transfers(), res) assert receiver_wallet.get_balance().balance == expected_receiver_balance