From 98decee8accd219527e710439353f5245348371d Mon Sep 17 00:00:00 2001
From: j-berman <justinberman@protonmail.com>
Date: Sat, 29 Jun 2024 10:31:51 -0700
Subject: [PATCH] wallet2: remove refresh() from scan_tx

Fixes #9354
---
 src/wallet/wallet2.cpp             | 10 +++-----
 tests/functional_tests/transfer.py | 39 +++++++++++++++---------------
 2 files changed, 23 insertions(+), 26 deletions(-)

diff --git a/src/wallet/wallet2.cpp b/src/wallet/wallet2.cpp
index 74b19df3c..04be12c13 100644
--- a/src/wallet/wallet2.cpp
+++ b/src/wallet/wallet2.cpp
@@ -1881,7 +1881,7 @@ void wallet2::scan_tx(const std::unordered_set<crypto::hash> &txids)
   // TODO: handle this sweep case
   detached_blockchain_data dbd;
   dbd.original_chain_size = m_blockchain.size();
-  if (m_blockchain.size() > txs_to_scan.lowest_height)
+  if (txs_to_scan.highest_height > 0)
   {
     // When connected to an untrusted daemon, if we will need to re-process 1+
     // tx that the user did not request to scan, then we fail out because
@@ -1920,7 +1920,7 @@ void wallet2::scan_tx(const std::unordered_set<crypto::hash> &txids)
   if (skip_to_height > m_blockchain.size())
   {
     m_skip_to_height = skip_to_height;
-    LOG_PRINT_L0("Skipping refresh to height " << skip_to_height);
+    LOG_PRINT_L0("Next refresh will skip to height " << skip_to_height);
 
     // update last block reward here because the refresh loop won't necessarily set it
     try
@@ -1932,9 +1932,7 @@ void wallet2::scan_tx(const std::unordered_set<crypto::hash> &txids)
     }
     catch (...) { MERROR("Failed getting block header at height " << txs_to_scan.highest_height); }
 
-    // TODO: use fast_refresh instead of refresh to update m_blockchain. It needs refactoring to work correctly here.
-    // Or don't refresh at all, and let it update on the next refresh loop.
-    refresh(is_trusted_daemon());
+    // The wallet's blockchain state will now sync from the expected height correctly on next refresh loop
   }
 }
 //----------------------------------------------------------------------------------------------------
@@ -4346,7 +4344,7 @@ wallet2::detached_blockchain_data wallet2::detach_blockchain(uint64_t height, st
 
   uint64_t blocks_detached = 0;
   dbd.original_chain_size = m_blockchain.size();
-  if (height >= m_blockchain.offset())
+  if (height <= m_blockchain.size() && height >= m_blockchain.offset())
   {
     for (uint64_t i = height; i < m_blockchain.size(); ++i)
       dbd.detached_blockchain.push_back(m_blockchain[i]);
diff --git a/tests/functional_tests/transfer.py b/tests/functional_tests/transfer.py
index ef80dc739..03dfd0397 100755
--- a/tests/functional_tests/transfer.py
+++ b/tests/functional_tests/transfer.py
@@ -888,12 +888,16 @@ class TransferTest():
 
         print('Testing scan_tx')
 
+        def restore_wallet(wallet, seed, restore_height = 0):
+            try: wallet.close_wallet()
+            except: pass
+            wallet.restore_deterministic_wallet(seed = seed, restore_height = restore_height)
+            wallet.auto_refresh(enable = False)
+            assert wallet.get_transfers() == {}
+
         # set up sender_wallet
         sender_wallet = self.wallet[0]
-        try: sender_wallet.close_wallet()
-        except: pass
-        sender_wallet.restore_deterministic_wallet(seed = seeds[0])
-        sender_wallet.auto_refresh(enable = False)
+        restore_wallet(sender_wallet, seeds[0])
         sender_wallet.refresh()
         res = sender_wallet.get_transfers()
         out_len = 0 if 'out' not in res else len(res.out)
@@ -903,10 +907,7 @@ class TransferTest():
 
         # set up receiver_wallet
         receiver_wallet = self.wallet[1]
-        try: receiver_wallet.close_wallet()
-        except: pass
-        receiver_wallet.restore_deterministic_wallet(seed = seeds[1])
-        receiver_wallet.auto_refresh(enable = False)
+        restore_wallet(receiver_wallet, seeds[1])
         receiver_wallet.refresh()
         res = receiver_wallet.get_transfers()
         in_len = 0 if 'in' not in res else len(res['in'])
@@ -971,6 +972,7 @@ class TransferTest():
 
         print('Checking scan_tx on outgoing tx before refresh')
         sender_wallet.scan_tx([txid])
+        sender_wallet.refresh()
         res = sender_wallet.get_transfers()
         assert 'pending' not in res or len(res.pending) == 0
         assert 'pool' not in res or len (res.pool) == 0
@@ -1011,9 +1013,7 @@ class TransferTest():
         all_txs = out_txids + in_txids
         for test_type in ["all txs", "incoming first", "duplicates within", "duplicates across"]:
             print(test + ' (' + test_type + ')')
-            sender_wallet.close_wallet()
-            sender_wallet.restore_deterministic_wallet(seed = seeds[0], restore_height = height)
-            assert sender_wallet.get_transfers() == {}
+            restore_wallet(sender_wallet, seeds[0], height)
             if test_type == "all txs":
                 sender_wallet.scan_tx(all_txs)
             elif test_type == "incoming first":
@@ -1027,18 +1027,19 @@ class TransferTest():
                 sender_wallet.scan_tx(all_txs)
             else:
                 assert True == False
-            diff_transfers(sender_wallet.get_transfers(), res)
             assert sender_wallet.get_balance().balance == expected_sender_balance
+            sender_wallet.refresh()
+            diff_transfers(sender_wallet.get_transfers(), res)
 
         print('Sanity check against outgoing wallet restored at height 0')
-        sender_wallet.close_wallet()
-        sender_wallet.restore_deterministic_wallet(seed = seeds[0], restore_height = 0)
+        restore_wallet(sender_wallet, seeds[0], 0)
         sender_wallet.refresh()
         diff_transfers(sender_wallet.get_transfers(), res)
         assert sender_wallet.get_balance().balance == expected_sender_balance
 
         print('Checking scan_tx on incoming txs before refresh')
         receiver_wallet.scan_tx([txid, miner_txid])
+        receiver_wallet.refresh()
         res = receiver_wallet.get_transfers()
         assert 'pending' not in res or len(res.pending) == 0
         assert 'pool' not in res or len (res.pool) == 0
@@ -1071,20 +1072,18 @@ class TransferTest():
         txids = [x.txid for x in res['in']]
         if 'out' in res:
             txids = txids + [x.txid for x in res.out]
-        receiver_wallet.close_wallet()
-        receiver_wallet.restore_deterministic_wallet(seed = seeds[1], restore_height = height)
-        assert receiver_wallet.get_transfers() == {}
+        restore_wallet(receiver_wallet, seeds[1], height)
         receiver_wallet.scan_tx(txids)
         if 'out' in res:
             for i, out_tx in enumerate(res.out):
                 if 'destinations' in out_tx:
                     del res.out[i]['destinations'] # destinations are not expected after wallet restore
-        diff_transfers(receiver_wallet.get_transfers(), res)
         assert receiver_wallet.get_balance().balance == expected_receiver_balance
+        receiver_wallet.refresh()
+        diff_transfers(receiver_wallet.get_transfers(), res)
 
         print('Sanity check against incoming wallet restored at height 0')
-        receiver_wallet.close_wallet()
-        receiver_wallet.restore_deterministic_wallet(seed = seeds[1], restore_height = 0)
+        restore_wallet(receiver_wallet, seeds[1], 0)
         receiver_wallet.refresh()
         diff_transfers(receiver_wallet.get_transfers(), res)
         assert receiver_wallet.get_balance().balance == expected_receiver_balance