update ProtoOutputStream from upstream

This commit is contained in:
woodser 2024-01-01 20:20:13 -05:00
parent e7371d1299
commit 20b55ed9dd
3 changed files with 75 additions and 100 deletions

View file

@ -170,11 +170,11 @@ public class Connection implements HasCapabilities, Runnable, MessageListener {
/////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////
Connection(Socket socket, Connection(Socket socket,
MessageListener messageListener, MessageListener messageListener,
ConnectionListener connectionListener, ConnectionListener connectionListener,
@Nullable NodeAddress peersNodeAddress, @Nullable NodeAddress peersNodeAddress,
NetworkProtoResolver networkProtoResolver, NetworkProtoResolver networkProtoResolver,
@Nullable BanFilter banFilter) { @Nullable BanFilter banFilter) {
this.socket = socket; this.socket = socket;
this.connectionListener = connectionListener; this.connectionListener = connectionListener;
this.banFilter = banFilter; this.banFilter = banFilter;
@ -200,7 +200,7 @@ public class Connection implements HasCapabilities, Runnable, MessageListener {
// When you construct an ObjectInputStream, in the constructor the class attempts to read a header that // When you construct an ObjectInputStream, in the constructor the class attempts to read a header that
// the associated ObjectOutputStream on the other end of the connection has written. // the associated ObjectOutputStream on the other end of the connection has written.
// It will not return until that header has been read. // It will not return until that header has been read.
protoOutputStream = new SynchronizedProtoOutputStream(socket.getOutputStream(), statistic); protoOutputStream = new ProtoOutputStream(socket.getOutputStream(), statistic);
protoInputStream = socket.getInputStream(); protoInputStream = socket.getInputStream();
// We create a thread for handling inputStream data // We create a thread for handling inputStream data
executorService.submit(this); executorService.submit(this);
@ -239,8 +239,7 @@ public class Connection implements HasCapabilities, Runnable, MessageListener {
if (banFilter != null && if (banFilter != null &&
peersNodeAddressOptional.isPresent() && peersNodeAddressOptional.isPresent() &&
banFilter.isPeerBanned(peersNodeAddressOptional.get())) { banFilter.isPeerBanned(peersNodeAddressOptional.get())) {
log.warn("We tried to send a message to a banned peer. message={}", log.warn("We tried to send a message to a banned peer. message={}", networkEnvelope.getClass().getSimpleName());
networkEnvelope.getClass().getSimpleName());
reportInvalidRequest(RuleViolation.PEER_BANNED); reportInvalidRequest(RuleViolation.PEER_BANNED);
return; return;
} }
@ -256,7 +255,7 @@ public class Connection implements HasCapabilities, Runnable, MessageListener {
long elapsed = now - lastSendTimeStamp; long elapsed = now - lastSendTimeStamp;
if (elapsed < getSendMsgThrottleTrigger()) { if (elapsed < getSendMsgThrottleTrigger()) {
log.debug("We got 2 sendMessage requests in less than {} ms. We set the thread to sleep " + log.debug("We got 2 sendMessage requests in less than {} ms. We set the thread to sleep " +
"for {} ms to avoid flooding our peer. lastSendTimeStamp={}, now={}, elapsed={}, networkEnvelope={}", "for {} ms to avoid flooding our peer. lastSendTimeStamp={}, now={}, elapsed={}, networkEnvelope={}",
getSendMsgThrottleTrigger(), getSendMsgThrottleSleep(), lastSendTimeStamp, now, elapsed, getSendMsgThrottleTrigger(), getSendMsgThrottleSleep(), lastSendTimeStamp, now, elapsed,
networkEnvelope.getClass().getSimpleName()); networkEnvelope.getClass().getSimpleName());
@ -437,6 +436,7 @@ public class Connection implements HasCapabilities, Runnable, MessageListener {
messageListeners.forEach(listener -> listener.onMessage(envelope, connection)))); messageListeners.forEach(listener -> listener.onMessage(envelope, connection))));
} }
/////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////
// Setters // Setters
/////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////
@ -456,6 +456,7 @@ public class Connection implements HasCapabilities, Runnable, MessageListener {
peersNodeAddressProperty.set(peerNodeAddress); peersNodeAddressProperty.set(peerNodeAddress);
} }
/////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////
// Getters // Getters
/////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////
@ -498,7 +499,8 @@ public class Connection implements HasCapabilities, Runnable, MessageListener {
Uninterruptibles.sleepUninterruptibly(200, TimeUnit.MILLISECONDS); Uninterruptibles.sleepUninterruptibly(200, TimeUnit.MILLISECONDS);
} catch (Throwable t) { } catch (Throwable t) {
handleException(t); log.error(t.getMessage());
t.printStackTrace();
} finally { } finally {
stopped = true; stopped = true;
EXECUTOR.execute(() -> doShutDown(closeConnectionReason, shutDownCompleteHandler)); EXECUTOR.execute(() -> doShutDown(closeConnectionReason, shutDownCompleteHandler));
@ -611,8 +613,7 @@ public class Connection implements HasCapabilities, Runnable, MessageListener {
"connection with address{} and uid {}", ruleViolations, peersNodeAddressProperty, uid); "connection with address{} and uid {}", ruleViolations, peersNodeAddressProperty, uid);
this.ruleViolation = ruleViolation; this.ruleViolation = ruleViolation;
if (ruleViolation == RuleViolation.PEER_BANNED) { if (ruleViolation == RuleViolation.PEER_BANNED) {
log.debug("We close connection due RuleViolation.PEER_BANNED. peersNodeAddress={}", log.debug("We close connection due RuleViolation.PEER_BANNED. peersNodeAddress={}", getPeersNodeAddressOptional());
getPeersNodeAddressOptional());
shutDown(CloseConnectionReason.PEER_BANNED); shutDown(CloseConnectionReason.PEER_BANNED);
} else if (ruleViolation == RuleViolation.INVALID_CLASS) { } else if (ruleViolation == RuleViolation.INVALID_CLASS) {
log.warn("We close connection due RuleViolation.INVALID_CLASS"); log.warn("We close connection due RuleViolation.INVALID_CLASS");
@ -655,8 +656,8 @@ public class Connection implements HasCapabilities, Runnable, MessageListener {
// TODO sometimes we get StreamCorruptedException, OptionalDataException, IllegalStateException // TODO sometimes we get StreamCorruptedException, OptionalDataException, IllegalStateException
closeConnectionReason = CloseConnectionReason.UNKNOWN_EXCEPTION; closeConnectionReason = CloseConnectionReason.UNKNOWN_EXCEPTION;
log.warn("Unknown reason for exception at socket: {}\n\t" + log.warn("Unknown reason for exception at socket: {}\n\t" +
"peer={}\n\t" + "peer={}\n\t" +
"Exception={}", "Exception={}",
socket.toString(), socket.toString(),
this.peersNodeAddressOptional, this.peersNodeAddressOptional,
e.toString()); e.toString());
@ -756,7 +757,7 @@ public class Connection implements HasCapabilities, Runnable, MessageListener {
long elapsed = now - lastReadTimeStamp; long elapsed = now - lastReadTimeStamp;
if (elapsed < 10) { if (elapsed < 10) {
log.debug("We got 2 network messages received in less than 10 ms. We set the thread to sleep " + log.debug("We got 2 network messages received in less than 10 ms. We set the thread to sleep " +
"for 20 ms to avoid getting flooded by our peer. lastReadTimeStamp={}, now={}, elapsed={}", "for 20 ms to avoid getting flooded by our peer. lastReadTimeStamp={}, now={}, elapsed={}",
lastReadTimeStamp, now, elapsed); lastReadTimeStamp, now, elapsed);
Thread.sleep(20); Thread.sleep(20);
} }
@ -803,7 +804,7 @@ public class Connection implements HasCapabilities, Runnable, MessageListener {
if (!proto.getMessageVersion().equals(Version.getP2PMessageVersion()) if (!proto.getMessageVersion().equals(Version.getP2PMessageVersion())
&& reportInvalidRequest(RuleViolation.WRONG_NETWORK_ID)) { && reportInvalidRequest(RuleViolation.WRONG_NETWORK_ID)) {
log.warn("RuleViolation.WRONG_NETWORK_ID. version of message={}, app version={}, " + log.warn("RuleViolation.WRONG_NETWORK_ID. version of message={}, app version={}, " +
"proto.toTruncatedString={}", proto.getMessageVersion(), "proto.toTruncatedString={}", proto.getMessageVersion(),
Version.getP2PMessageVersion(), Version.getP2PMessageVersion(),
Utilities.toTruncatedString(proto.toString())); Utilities.toTruncatedString(proto.toString()));
return; return;
@ -821,8 +822,7 @@ public class Connection implements HasCapabilities, Runnable, MessageListener {
if (CloseConnectionReason.PEER_BANNED.name().equals(proto.getCloseConnectionMessage().getReason())) { if (CloseConnectionReason.PEER_BANNED.name().equals(proto.getCloseConnectionMessage().getReason())) {
log.warn("We got shut down because we are banned by the other peer. " + log.warn("We got shut down because we are banned by the other peer. " +
"(InputHandler.run CloseConnectionMessage). Peer: {}", "(InputHandler.run CloseConnectionMessage). Peer: {}", getPeersNodeAddressOptional());
getPeersNodeAddressOptional());
} }
shutDown(CloseConnectionReason.CLOSE_REQUESTED_BY_PEER); shutDown(CloseConnectionReason.CLOSE_REQUESTED_BY_PEER);
return; return;
@ -841,8 +841,7 @@ public class Connection implements HasCapabilities, Runnable, MessageListener {
} }
if (!(networkEnvelope instanceof SendersNodeAddressMessage) && peersNodeAddressOptional.isEmpty()) { if (!(networkEnvelope instanceof SendersNodeAddressMessage) && peersNodeAddressOptional.isEmpty()) {
log.info("We got a {} from a peer with yet unknown address on connection with uid={}", log.info("We got a {} from a peer with yet unknown address on connection with uid={}", networkEnvelope.getClass().getSimpleName(), uid);
networkEnvelope.getClass().getSimpleName(), uid);
} }
EXECUTOR.execute(() -> onMessage(networkEnvelope, this)); EXECUTOR.execute(() -> onMessage(networkEnvelope, this));

View file

@ -17,49 +17,84 @@
package haveno.network.p2p.network; package haveno.network.p2p.network;
import haveno.common.proto.network.NetworkEnvelope;
import haveno.network.p2p.peers.keepalive.messages.KeepAliveMessage; import haveno.network.p2p.peers.keepalive.messages.KeepAliveMessage;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.concurrent.NotThreadSafe; import haveno.common.proto.network.NetworkEnvelope;
import java.io.IOException; import java.io.IOException;
import java.io.OutputStream; import java.io.OutputStream;
@NotThreadSafe import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.concurrent.ThreadSafe;
@ThreadSafe
class ProtoOutputStream { class ProtoOutputStream {
private static final Logger log = LoggerFactory.getLogger(ProtoOutputStream.class); private static final Logger log = LoggerFactory.getLogger(ProtoOutputStream.class);
private final OutputStream delegate; private final OutputStream outputStream;
private final Statistic statistic; private final Statistic statistic;
ProtoOutputStream(OutputStream delegate, Statistic statistic) { private final AtomicBoolean isConnectionActive = new AtomicBoolean(true);
this.delegate = delegate; private final Lock lock = new ReentrantLock();
ProtoOutputStream(OutputStream outputStream, Statistic statistic) {
this.outputStream = outputStream;
this.statistic = statistic; this.statistic = statistic;
} }
void writeEnvelope(NetworkEnvelope envelope) { void writeEnvelope(NetworkEnvelope envelope) {
lock.lock();
try { try {
writeEnvelopeOrThrow(envelope); writeEnvelopeOrThrow(envelope);
} catch (IOException e) { } catch (IOException e) {
if (!isConnectionActive.get()) {
// Connection was closed by us.
return;
}
log.error("Failed to write envelope", e); log.error("Failed to write envelope", e);
throw new HavenoRuntimeException("Failed to write envelope", e); throw new HavenoRuntimeException("Failed to write envelope", e);
} finally {
lock.unlock();
} }
} }
void onConnectionShutdown() { void onConnectionShutdown() {
isConnectionActive.set(false);
boolean acquiredLock = tryToAcquireLock();
if (!acquiredLock) {
return;
}
try { try {
delegate.close(); outputStream.close();
} catch (Throwable t) { } catch (Throwable t) {
log.error("Failed to close connection", t); log.error("Failed to close connection", t);
} finally {
lock.unlock();
} }
} }
private void writeEnvelopeOrThrow(NetworkEnvelope envelope) throws IOException { private void writeEnvelopeOrThrow(NetworkEnvelope envelope) throws IOException {
long ts = System.currentTimeMillis();
protobuf.NetworkEnvelope proto = envelope.toProtoNetworkEnvelope(); protobuf.NetworkEnvelope proto = envelope.toProtoNetworkEnvelope();
proto.writeDelimitedTo(delegate); proto.writeDelimitedTo(outputStream);
delegate.flush(); outputStream.flush();
long duration = System.currentTimeMillis() - ts;
if (duration > 10000) {
log.info("Sending {} to peer took {} sec.", envelope.getClass().getSimpleName(), duration / 1000d);
}
statistic.addSentBytes(proto.getSerializedSize()); statistic.addSentBytes(proto.getSerializedSize());
statistic.addSentMessage(envelope); statistic.addSentMessage(envelope);
@ -67,4 +102,13 @@ class ProtoOutputStream {
statistic.updateLastActivityTimestamp(); statistic.updateLastActivityTimestamp();
} }
} }
private boolean tryToAcquireLock() {
long shutdownTimeout = Connection.getShutdownTimeout();
try {
return lock.tryLock(shutdownTimeout, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
return false;
}
}
} }

View file

@ -1,68 +0,0 @@
/*
* This file is part of Haveno.
*
* Haveno is free software: you can redistribute it and/or modify it
* under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or (at
* your option) any later version.
*
* Haveno is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU Affero General Public
* License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with Haveno. If not, see <http://www.gnu.org/licenses/>.
*/
package haveno.network.p2p.network;
import haveno.common.proto.network.NetworkEnvelope;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import javax.annotation.concurrent.ThreadSafe;
import java.io.OutputStream;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
@ThreadSafe
class SynchronizedProtoOutputStream extends ProtoOutputStream {
private static final Logger log = LoggerFactory.getLogger(SynchronizedProtoOutputStream.class);
private final ExecutorService executorService;
SynchronizedProtoOutputStream(OutputStream delegate, Statistic statistic) {
super(delegate, statistic);
this.executorService = Executors.newSingleThreadExecutor();
}
@Override
void writeEnvelope(NetworkEnvelope envelope) {
Future<?> future = executorService.submit(() -> super.writeEnvelope(envelope));
try {
future.get();
} catch (InterruptedException e) {
Thread currentThread = Thread.currentThread();
currentThread.interrupt();
String msg = "Thread " + currentThread + " was interrupted. InterruptedException=" + e;
log.error(msg);
throw new HavenoRuntimeException(msg, e);
} catch (ExecutionException e) {
String msg = "Failed to write envelope. ExecutionException " + e;
log.error(msg);
throw new HavenoRuntimeException(msg, e);
}
}
void onConnectionShutdown() {
try {
executorService.shutdownNow();
super.onConnectionShutdown();
} catch (Throwable t) {
log.error("Failed to handle connection shutdown. Throwable={}", t.toString());
}
}
}