From 4a7ecc51dfa02e7356f026b26e5bcb162a74caf9 Mon Sep 17 00:00:00 2001 From: Ross Lawley Date: Mon, 2 Feb 2026 14:42:24 +0000 Subject: [PATCH] ByteBuf leak fixes - Ensure Default Server Monitor calls close on resources before interrupt - Update ByteBufferBsonOutput documentation - Improve ReplyHeader testing and ensure resources are closed - Improve ServerSessionPool testing - Ensure reactive client session closing is idempotent - Added System.gc to unified test cleanup. Should cause more gc when testing. JAVA-6081 --- bson/src/main/org/bson/BsonDocument.java | 4 +- config/spotbugs/exclude.xml | 6 + .../connection/ByteBufferBsonOutput.java | 117 ++++++- .../connection/DefaultServerMonitor.java | 122 +++++-- .../ReplyHeaderSpecification.groovy | 201 ----------- .../internal/connection/ReplyHeaderTest.java | 199 +++++++++++ .../connection/DefaultServerMonitorTest.java | 51 +++ .../ServerSessionPoolSpecification.groovy | 229 ------------- .../session/ServerSessionPoolTest.java | 319 ++++++++++++++++++ .../main/com/mongodb/DBDecoderAdapter.java | 9 +- .../internal/ClientSessionPublisherImpl.java | 20 +- .../reactivestreams/client/Fixture.java | 18 +- .../client/AbstractSessionsProseTest.java | 4 +- .../mongodb/client/unified/UnifiedTest.java | 1 + 14 files changed, 824 insertions(+), 476 deletions(-) delete mode 100644 driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderSpecification.groovy create mode 100644 driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderTest.java delete mode 100644 driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolSpecification.groovy create mode 100644 driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolTest.java diff --git a/bson/src/main/org/bson/BsonDocument.java b/bson/src/main/org/bson/BsonDocument.java index 87625de8dbd..f52a1f25f7d 100644 --- a/bson/src/main/org/bson/BsonDocument.java +++ b/bson/src/main/org/bson/BsonDocument.java @@ -921,10 +921,12 @@ private static class SerializationProxy implements Serializable { new BsonDocumentCodec().encode(new BsonBinaryWriter(buffer), document, EncoderContext.builder().build()); this.bytes = new byte[buffer.size()]; int curPos = 0; - for (ByteBuf cur : buffer.getByteBuffers()) { + List byteBuffers = buffer.getByteBuffers(); + for (ByteBuf cur : byteBuffers) { System.arraycopy(cur.array(), cur.position(), bytes, curPos, cur.limit()); curPos += cur.position(); } + byteBuffers.forEach(ByteBuf::release); } private Object readResolve() { diff --git a/config/spotbugs/exclude.xml b/config/spotbugs/exclude.xml index 20684680865..d1647b1f149 100644 --- a/config/spotbugs/exclude.xml +++ b/config/spotbugs/exclude.xml @@ -290,4 +290,10 @@ + + + + + + diff --git a/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java b/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java index 1edbb0f4c2f..0988679f28d 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java +++ b/driver-core/src/main/com/mongodb/internal/connection/ByteBufferBsonOutput.java @@ -16,6 +16,9 @@ package com.mongodb.internal.connection; +import com.mongodb.annotations.Sealed; +import com.mongodb.internal.ResourceUtil; +import com.mongodb.internal.VisibleForTesting; import org.bson.BsonSerializationException; import org.bson.ByteBuf; import org.bson.io.OutputBuffer; @@ -28,11 +31,28 @@ import static com.mongodb.assertions.Assertions.assertTrue; import static com.mongodb.assertions.Assertions.notNull; +import static com.mongodb.internal.VisibleForTesting.AccessModifier.PRIVATE; import static java.lang.String.format; /** + * A BSON output implementation that uses pooled {@link ByteBuf} instances for efficient memory management. + * + *

ByteBuf Ownership and Lifecycle

+ *

This class manages the lifecycle of {@link ByteBuf} instances obtained from the {@link BufferProvider}. + * The ownership model is as follows:

+ *
    + *
  • Internal buffers are owned by this output and released when {@link #close()} is called or + * when {@link #truncateToPosition(int)} removes them.
  • + *
  • Methods that return {@link ByteBuf} instances (e.g., {@link #getByteBuffers()}) return + * duplicates with their own reference counts. Callers are responsible for releasing + * these buffers to prevent memory leaks.
  • + *
  • The {@link Branch} subclass merges its buffers into the parent on close, transferring + * ownership by retaining buffers before the branch releases them.
  • + *
+ * *

This class is not part of the public API and may be removed or changed at any time

*/ +@Sealed public class ByteBufferBsonOutput extends OutputBuffer { private static final int MAX_SHIFT = 31; @@ -50,6 +70,9 @@ public class ByteBufferBsonOutput extends OutputBuffer { /** * Construct an instance that uses the given buffer provider to allocate byte buffers as needs as it grows. * + *

The buffer provider is used to allocate new {@link ByteBuf} instances as the output grows. + * All allocated buffers are owned by this output and will be released when {@link #close()} is called.

+ * * @param bufferProvider the non-null buffer provider */ public ByteBufferBsonOutput(final BufferProvider bufferProvider) { @@ -63,6 +86,10 @@ public ByteBufferBsonOutput(final BufferProvider bufferProvider) { * If multiple branches are created, they are merged in the order they are {@linkplain ByteBufferBsonOutput.Branch#close() closed}. * {@linkplain #close() Closing} this {@link ByteBufferBsonOutput} does not {@linkplain ByteBufferBsonOutput.Branch#close() close} the branch. * + *

ByteBuf Ownership: The branch allocates its own buffers. When the branch is closed, + * ownership of these buffers is transferred to the parent by retaining them before the branch releases + * its references. The parent then becomes responsible for releasing these buffers when it is closed.

+ * * @return A new {@link ByteBufferBsonOutput.Branch}. */ public ByteBufferBsonOutput.Branch branch() { @@ -223,10 +250,28 @@ protected void write(final int absolutePosition, final int value) { byteBuffer.put(bufferPositionPair.position++, (byte) value); } + /** + * Returns a list of duplicated byte buffers containing the written data, flipped for reading. + * + *

ByteBuf Ownership: The returned buffers are duplicates with their own + * reference counts (each starts with a reference count of 1). The caller is responsible + * for releasing each buffer when done to prevent memory leaks. Example usage:

+ *
{@code
+     * List buffers = output.getByteBuffers();
+     * try {
+     *     // use buffers
+     * } finally {
+     *     ResourceUtil.release(buffers);
+     * }
+     * }
+ *

Note: These buffers must be released before this {@code ByteBufferBsonOutput} is closed. + * Otherwise there is a risk of the buffers being released back to the bufferProvider and data corruption.

+ * + * @return a list of duplicated buffers, flipped for reading + */ @Override public List getByteBuffers() { ensureOpen(); - List buffers = new ArrayList<>(bufferList.size()); for (final ByteBuf cur : bufferList) { buffers.add(cur.duplicate().order(ByteOrder.LITTLE_ENDIAN).flip()); @@ -234,6 +279,17 @@ public List getByteBuffers() { return buffers; } + /** + * Returns a list of duplicated byte buffers without flipping them. + * + *

ByteBuf Ownership: The returned buffers are duplicates with their own + * reference counts (each starts with a reference count of 1). The caller is responsible + * for releasing each buffer when done to prevent memory leaks.

+ * + * @return a list of duplicated buffers + * @see #getByteBuffers() + */ + @VisibleForTesting(otherwise = PRIVATE) public List getDuplicateByteBuffers() { ensureOpen(); @@ -245,6 +301,13 @@ public List getDuplicateByteBuffers() { } + /** + * {@inheritDoc} + * + *

ByteBuf Management: This method obtains duplicated buffers via + * {@link #getByteBuffers()} and releases them after writing to the output stream, + * ensuring no buffer leaks occur.

+ */ @Override public int pipe(final OutputStream out) throws IOException { ensureOpen(); @@ -263,11 +326,20 @@ public int pipe(final OutputStream out) throws IOException { total += cur.limit(); } } finally { - byteBuffers.forEach(ByteBuf::release); + ResourceUtil.release(byteBuffers); } return total; } + /** + * Truncates this output to the specified position, releasing any buffers that are no longer needed. + * + *

ByteBuf Management: Any buffers beyond the new position are removed from + * the internal buffer list and released. This ensures no memory leaks when truncating.

+ * + * @param newPosition the new position to truncate to + * @throws IllegalArgumentException if newPosition is negative or greater than the current position + */ @Override public void truncateToPosition(final int newPosition) { ensureOpen(); @@ -306,13 +378,15 @@ public final void flush() throws IOException { * {@inheritDoc} *

* Idempotent.

+ * + *

ByteBuf Management: Releases internal buffers and clears the buffer list. + * After this method returns, all buffers that were allocated by this output will have been fully released + * back to the buffer provider.

*/ @Override public void close() { if (isOpen()) { - for (final ByteBuf cur : bufferList) { - cur.release(); - } + ResourceUtil.release(bufferList); currentByteBuffer = null; bufferList.clear(); closed = true; @@ -345,7 +419,14 @@ boolean isOpen() { } /** - * @see #branch() + * Merges a branch's buffers into this output. + * + *

ByteBuf Ownership: This method retains each buffer from the branch before + * adding it to this output's buffer list. This is necessary because the branch will release its + * references when it closes. The retain ensures the buffers remain valid and are now owned by + * this output.

+ * + * @param branch the branch to merge */ private void merge(final ByteBufferBsonOutput branch) { assertTrue(branch instanceof ByteBufferBsonOutput.Branch); @@ -356,6 +437,20 @@ private void merge(final ByteBufferBsonOutput branch) { currentByteBuffer = null; } + /** + * A branch of a {@link ByteBufferBsonOutput} that can be merged back into its parent. + * + *

ByteBuf Ownership: A branch allocates its own buffers independently. + * When {@link #close()} is called:

+ *
    + *
  1. The parent's {@link ByteBufferBsonOutput#merge(ByteBufferBsonOutput)} method is called, + * which retains all buffers in this branch.
  2. + *
  3. Then {@code super.close()} is called, which releases the branch's references to the buffers.
  4. + *
+ *

The retain/release sequence ensures buffers are safely transferred to the parent without leaks.

+ * + * @see #branch() + */ public static final class Branch extends ByteBufferBsonOutput { private final ByteBufferBsonOutput parent; @@ -365,6 +460,16 @@ private Branch(final ByteBufferBsonOutput parent) { } /** + * Closes this branch and merges its data into the parent output. + * + *

ByteBuf Ownership: On close, this branch's buffers are transferred + * to the parent. The parent retains the buffers (incrementing reference counts), and then + * this branch releases only its own single reference. The parent + * becomes the sole owner of the buffers and is responsible for releasing them.

+ * + *

Idempotent. If already closed, this method does nothing.

+ * + * @throws AssertionError if the parent has been closed before this branch * @see #branch() */ @Override diff --git a/driver-core/src/main/com/mongodb/internal/connection/DefaultServerMonitor.java b/driver-core/src/main/com/mongodb/internal/connection/DefaultServerMonitor.java index bb97517d315..6eeea6b030b 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/DefaultServerMonitor.java +++ b/driver-core/src/main/com/mongodb/internal/connection/DefaultServerMonitor.java @@ -187,11 +187,17 @@ class ServerMonitor extends Thread implements AutoCloseable { @Override public void close() { - interrupt(); - InternalConnection connection = this.connection; - if (connection != null) { - connection.close(); + InternalConnection localConnection = withLock(lock, () -> { + InternalConnection result = connection; + connection = null; + return result; + }); + + if (localConnection != null) { + localConnection.close(); } + + interrupt(); } @Override @@ -293,24 +299,36 @@ private ServerDescription setupNewConnectionAndGetInitialDescription(final boole */ private ServerDescription doHeartbeat(final ServerDescription currentServerDescription, final boolean shouldStreamResponses) { + // Check if monitor was closed or connection is unusable + InternalConnection localConnection = withLock(lock, () -> { + if (isClosed || connection == null || connection.isClosed()) { + return null; + } + return connection; + }); + + if (localConnection == null) { + throw new MongoSocketException("Monitor closed", serverId.getAddress()); + } + try { OperationContext operationContext = operationContextFactory.create(); - if (!connection.hasMoreToCome()) { + if (!localConnection.hasMoreToCome()) { BsonDocument helloDocument = new BsonDocument(getHandshakeCommandName(currentServerDescription), new BsonInt32(1)) .append("helloOk", BsonBoolean.TRUE); if (shouldStreamResponses) { helloDocument.append("topologyVersion", assertNotNull(currentServerDescription.getTopologyVersion()).asDocument()); helloDocument.append("maxAwaitTimeMS", new BsonInt64(serverSettings.getHeartbeatFrequency(MILLISECONDS))); } - connection.send(createCommandMessage(helloDocument, connection, currentServerDescription), new BsonDocumentCodec(), + localConnection.send(createCommandMessage(helloDocument, localConnection, currentServerDescription), new BsonDocumentCodec(), operationContext); } BsonDocument helloResult; if (shouldStreamResponses) { - helloResult = connection.receive(new BsonDocumentCodec(), operationContextWithAdditionalTimeout(operationContext)); + helloResult = localConnection.receive(new BsonDocumentCodec(), operationContextWithAdditionalTimeout(operationContext)); } else { - helloResult = connection.receive(new BsonDocumentCodec(), operationContext); + helloResult = localConnection.receive(new BsonDocumentCodec(), operationContext); } logAndNotifyHeartbeatSucceeded(shouldStreamResponses, helloResult); return createServerDescription(serverId.getAddress(), helloResult, roundTripTimeSampler.getAverage(), @@ -322,10 +340,18 @@ private ServerDescription doHeartbeat(final ServerDescription currentServerDescr } private void logAndNotifyHeartbeatStarted(final boolean shouldStreamResponses) { - alreadyLoggedHeartBeatStarted = true; - logHeartbeatStarted(serverId, connection.getDescription(), shouldStreamResponses); - serverMonitorListener.serverHearbeatStarted(new ServerHeartbeatStartedEvent( - connection.getDescription().getConnectionId(), shouldStreamResponses)); + ConnectionDescription description = connection.getDescription(); + if (description != null) { + alreadyLoggedHeartBeatStarted = true; + logHeartbeatStarted(serverId, description, shouldStreamResponses); + serverMonitorListener.serverHearbeatStarted(new ServerHeartbeatStartedEvent( + description.getConnectionId(), shouldStreamResponses)); + } else { + // Connection not fully established yet - skip logging for this heartbeat + if (LOGGER.isDebugEnabled()) { + LOGGER.debug(format("Skipping heartbeat started event for %s - connection description not available", serverId)); + } + } } private void logAndNotifyHeartbeatSucceeded(final boolean shouldStreamResponses, final BsonDocument helloResult) { @@ -343,12 +369,20 @@ private void logAndNotifyHeartbeatSucceeded(final boolean shouldStreamResponses, private void logAndNotifyHeartbeatFailed(final boolean shouldStreamResponses, final Exception e) { alreadyLoggedHeartBeatStarted = false; long elapsedTimeNanos = getElapsedTimeNanos(); - logHeartbeatFailed(serverId, connection.getDescription(), shouldStreamResponses, elapsedTimeNanos, e); - serverMonitorListener.serverHeartbeatFailed( - new ServerHeartbeatFailedEvent(connection.getDescription().getConnectionId(), elapsedTimeNanos, - shouldStreamResponses, e)); - } + ConnectionDescription description = connection != null ? connection.getDescription() : null; + if (description != null) { + logHeartbeatFailed(serverId, description, shouldStreamResponses, elapsedTimeNanos, e); + serverMonitorListener.serverHeartbeatFailed( + new ServerHeartbeatFailedEvent(description.getConnectionId(), elapsedTimeNanos, + shouldStreamResponses, e)); + } else { + // Log failure without connection details + if (LOGGER.isDebugEnabled()) { + LOGGER.debug(format("Heartbeat failed for %s but connection description not available", serverId), e); + } + } + } private long getElapsedTimeNanos() { return System.nanoTime() - lookupStartTimeNanos; } @@ -514,11 +548,17 @@ private class RoundTripTimeMonitor extends Thread implements AutoCloseable { @Override public void close() { - interrupt(); - InternalConnection connection = this.connection; - if (connection != null) { - connection.close(); + InternalConnection localConnection = withLock(lock, () -> { + InternalConnection result = connection; + connection = null; + return result; + }); + + if (localConnection != null) { + localConnection.close(); } + + interrupt(); } @Override @@ -552,13 +592,45 @@ public void run() { } private void initialize() { - connection = null; - connection = internalConnectionFactory.create(serverId); - connection.open(operationContextFactory.create()); - roundTripTimeSampler.addSample(connection.getInitialServerDescription().getRoundTripTimeNanos()); + boolean shouldProceed = withLock(lock, () -> !isClosed); + + if (!shouldProceed) { + return; + } + + InternalConnection newConnection = internalConnectionFactory.create(serverId); + newConnection.open(operationContextFactory.create()); + + // Check again after the potentially long open() operation + boolean stillValid = withLock(lock, () -> { + if (!isClosed) { + connection = newConnection; + return true; + } + return false; + }); + + if (stillValid) { + roundTripTimeSampler.addSample(newConnection.getInitialServerDescription().getRoundTripTimeNanos()); + } else { + // Monitor was closed during open(), clean up the connection + newConnection.close(); + } } private void pingServer(final InternalConnection connection) { + // Atomically check if monitor was closed and connection is still valid + boolean shouldProceed = withLock(lock, () -> + !isClosed && this.connection == connection + ); + + if (!shouldProceed) { + if (LOGGER.isDebugEnabled()) { + LOGGER.debug(format("Skipping ping for %s - monitor closed or connection changed", serverId)); + } + return; // Monitor closed or connection changed, skip ping + } + long start = System.nanoTime(); OperationContext operationContext = operationContextFactory.create(); executeCommand("admin", diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderSpecification.groovy b/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderSpecification.groovy deleted file mode 100644 index 0407baeca8a..00000000000 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderSpecification.groovy +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Copyright 2008-present MongoDB, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - - -package com.mongodb.internal.connection - -import com.mongodb.MongoInternalException -import org.bson.io.BasicOutputBuffer -import spock.lang.Specification - -import static com.mongodb.connection.ConnectionDescription.getDefaultMaxMessageSize - -class ReplyHeaderSpecification extends Specification { - - def 'should parse reply header'() { - def outputBuffer = new BasicOutputBuffer() - outputBuffer.with { - writeInt(186) - writeInt(45) - writeInt(23) - writeInt(1) - writeInt(responseFlags) - writeLong(9000) - writeInt(4) - writeInt(1) - } - def byteBuf = outputBuffer.byteBuffers.get(0) - - when: - def replyHeader = new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize())) - - then: - replyHeader.messageLength == 186 - replyHeader.requestId == 45 - replyHeader.responseTo == 23 - - where: - responseFlags << [0, 1, 2, 3] - cursorNotFound << [false, true, false, true] - queryFailure << [false, false, true, true] - } - - def 'should parse reply header with compressed header'() { - def outputBuffer = new BasicOutputBuffer() - outputBuffer.with { - writeInt(186) - writeInt(45) - writeInt(23) - writeInt(2012) - writeInt(1) - writeInt(258) - writeByte(2) - writeInt(responseFlags) - writeLong(9000) - writeInt(4) - writeInt(1) - } - def byteBuf = outputBuffer.byteBuffers.get(0) - def compressedHeader = new CompressedHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize())) - - when: - def replyHeader = new ReplyHeader(byteBuf, compressedHeader) - - then: - replyHeader.messageLength == 274 - replyHeader.requestId == 45 - replyHeader.responseTo == 23 - - where: - responseFlags << [0, 1, 2, 3] - cursorNotFound << [false, true, false, true] - queryFailure << [false, false, true, true] - } - - def 'should throw MongoInternalException on incorrect opCode'() { - def outputBuffer = new BasicOutputBuffer() - outputBuffer.with { - writeInt(36) - writeInt(45) - writeInt(23) - writeInt(2) - writeInt(0) - writeLong(2) - writeInt(0) - writeInt(0) - } - def byteBuf = outputBuffer.byteBuffers.get(0) - - when: - new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize())) - - then: - def ex = thrown(MongoInternalException) - ex.getMessage() == 'Unexpected reply message opCode 2' - } - - def 'should throw MongoInternalException on message size < 36'() { - def outputBuffer = new BasicOutputBuffer() - outputBuffer.with { - writeInt(35) - writeInt(45) - writeInt(23) - writeInt(1) - writeInt(0) - writeLong(2) - writeInt(0) - writeInt(0) - } - def byteBuf = outputBuffer.byteBuffers.get(0) - - when: - new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize())) - - then: - def ex = thrown(MongoInternalException) - ex.getMessage() == 'The reply message length 35 is less than the minimum message length 36' - } - - def 'should throw MongoInternalException on message size > max message size'() { - def outputBuffer = new BasicOutputBuffer() - outputBuffer.with { - writeInt(400) - writeInt(45) - writeInt(23) - writeInt(1) - writeInt(0) - writeLong(2) - writeInt(0) - writeInt(0) - } - def byteBuf = outputBuffer.byteBuffers.get(0) - - when: - new ReplyHeader(byteBuf, new MessageHeader(byteBuf, 399)) - - then: - def ex = thrown(MongoInternalException) - ex.getMessage() == 'The reply message length 400 is greater than the maximum message length 399' - } - - def 'should throw MongoInternalException on num documents < 0'() { - def outputBuffer = new BasicOutputBuffer() - outputBuffer.with { - writeInt(186) - writeInt(45) - writeInt(23) - writeInt(1) - writeInt(1) - writeLong(9000) - writeInt(4) - writeInt(-1) - } - def byteBuf = outputBuffer.byteBuffers.get(0) - - when: - new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize())) - - then: - def ex = thrown(MongoInternalException) - ex.getMessage() == 'The reply message number of returned documents, -1, is expected to be 1' - } - - def 'should throw MongoInternalException on num documents < 0 with compressed header'() { - def outputBuffer = new BasicOutputBuffer() - outputBuffer.with { - writeInt(186) - writeInt(45) - writeInt(23) - writeInt(2012) - writeInt(1) - writeInt(258) - writeByte(2) - writeInt(1) - writeLong(9000) - writeInt(4) - writeInt(-1) - } - def byteBuf = outputBuffer.byteBuffers.get(0) - def compressedHeader = new CompressedHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize())) - - when: - new ReplyHeader(byteBuf, compressedHeader) - - then: - def ex = thrown(MongoInternalException) - ex.getMessage() == 'The reply message number of returned documents, -1, is expected to be 1' - } -} diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderTest.java b/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderTest.java new file mode 100644 index 00000000000..ffe46a9354e --- /dev/null +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/ReplyHeaderTest.java @@ -0,0 +1,199 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.connection; + +import com.mongodb.MongoInternalException; +import org.bson.ByteBuf; +import org.bson.io.BasicOutputBuffer; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static com.mongodb.connection.ConnectionDescription.getDefaultMaxMessageSize; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +@DisplayName("ReplyHeader") +class ReplyHeaderTest { + + @ParameterizedTest(name = "with responseFlags {0}") + @ValueSource(ints = {0, 1, 2, 3}) + @DisplayName("should parse reply header with various response flags") + void testParseReplyHeader(final int responseFlags) { + try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) { + outputBuffer.writeInt(186); + outputBuffer.writeInt(45); + outputBuffer.writeInt(23); + outputBuffer.writeInt(1); + outputBuffer.writeInt(responseFlags); + outputBuffer.writeLong(9000); + outputBuffer.writeInt(4); + outputBuffer.writeInt(1); + + ByteBuf byteBuf = outputBuffer.getByteBuffers().get(0); + ReplyHeader replyHeader = new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize())); + + assertEquals(186, replyHeader.getMessageLength()); + assertEquals(45, replyHeader.getRequestId()); + assertEquals(23, replyHeader.getResponseTo()); + } + } + + @ParameterizedTest(name = "with responseFlags {0}") + @ValueSource(ints = {0, 1, 2, 3}) + @DisplayName("should parse reply header with compressed header and various response flags") + void testParseReplyHeaderWithCompressedHeader(final int responseFlags) { + try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) { + outputBuffer.writeInt(186); + outputBuffer.writeInt(45); + outputBuffer.writeInt(23); + outputBuffer.writeInt(2012); + outputBuffer.writeInt(1); + outputBuffer.writeInt(258); + outputBuffer.writeByte(2); + outputBuffer.writeInt(responseFlags); + outputBuffer.writeLong(9000); + outputBuffer.writeInt(4); + outputBuffer.writeInt(1); + + ByteBuf byteBuf = outputBuffer.getByteBuffers().get(0); + CompressedHeader compressedHeader = new CompressedHeader(byteBuf, + new MessageHeader(byteBuf, getDefaultMaxMessageSize())); + ReplyHeader replyHeader = new ReplyHeader(byteBuf, compressedHeader); + + assertEquals(274, replyHeader.getMessageLength()); + assertEquals(45, replyHeader.getRequestId()); + assertEquals(23, replyHeader.getResponseTo()); + } + } + + @Test + @DisplayName("should throw MongoInternalException on incorrect opCode") + void testThrowExceptionOnIncorrectOpCode() { + try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) { + outputBuffer.writeInt(36); + outputBuffer.writeInt(45); + outputBuffer.writeInt(23); + outputBuffer.writeInt(2); + outputBuffer.writeInt(0); + outputBuffer.writeLong(2); + outputBuffer.writeInt(0); + outputBuffer.writeInt(0); + + ByteBuf byteBuf = outputBuffer.getByteBuffers().get(0); + + MongoInternalException ex = assertThrows(MongoInternalException.class, + () -> new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize()))); + + assertEquals("Unexpected reply message opCode 2", ex.getMessage()); + } + } + + @Test + @DisplayName("should throw MongoInternalException on message size less than 36 bytes") + void testThrowExceptionOnMessageSizeLessThan36() { + try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) { + outputBuffer.writeInt(35); + outputBuffer.writeInt(45); + outputBuffer.writeInt(23); + outputBuffer.writeInt(1); + outputBuffer.writeInt(0); + outputBuffer.writeLong(2); + outputBuffer.writeInt(0); + outputBuffer.writeInt(0); + + ByteBuf byteBuf = outputBuffer.getByteBuffers().get(0); + + MongoInternalException ex = assertThrows(MongoInternalException.class, + () -> new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize()))); + + assertEquals("The reply message length 35 is less than the minimum message length 36", ex.getMessage()); + } + } + + @Test + @DisplayName("should throw MongoInternalException on message size exceeding max message size") + void testThrowExceptionOnMessageSizeExceedingMax() { + try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) { + outputBuffer.writeInt(400); + outputBuffer.writeInt(45); + outputBuffer.writeInt(23); + outputBuffer.writeInt(1); + outputBuffer.writeInt(0); + outputBuffer.writeLong(2); + outputBuffer.writeInt(0); + outputBuffer.writeInt(0); + + ByteBuf byteBuf = outputBuffer.getByteBuffers().get(0); + + MongoInternalException ex = assertThrows(MongoInternalException.class, + () -> new ReplyHeader(byteBuf, new MessageHeader(byteBuf, 399))); + + assertEquals("The reply message length 400 is greater than the maximum message length 399", ex.getMessage()); + } + } + + @Test + @DisplayName("should throw MongoInternalException on negative number of returned documents") + void testThrowExceptionOnNegativeNumberOfDocuments() { + try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) { + outputBuffer.writeInt(186); + outputBuffer.writeInt(45); + outputBuffer.writeInt(23); + outputBuffer.writeInt(1); + outputBuffer.writeInt(1); + outputBuffer.writeLong(9000); + outputBuffer.writeInt(4); + outputBuffer.writeInt(-1); + + ByteBuf byteBuf = outputBuffer.getByteBuffers().get(0); + + MongoInternalException ex = assertThrows(MongoInternalException.class, + () -> new ReplyHeader(byteBuf, new MessageHeader(byteBuf, getDefaultMaxMessageSize()))); + + assertEquals("The reply message number of returned documents, -1, is expected to be 1", ex.getMessage()); + } + } + + @Test + @DisplayName("should throw MongoInternalException on negative number of documents with compressed header") + void testThrowExceptionOnNegativeNumberOfDocumentsWithCompressedHeader() { + try (BasicOutputBuffer outputBuffer = new BasicOutputBuffer()) { + outputBuffer.writeInt(186); + outputBuffer.writeInt(45); + outputBuffer.writeInt(23); + outputBuffer.writeInt(2012); + outputBuffer.writeInt(1); + outputBuffer.writeInt(258); + outputBuffer.writeByte(2); + outputBuffer.writeInt(1); + outputBuffer.writeLong(9000); + outputBuffer.writeInt(4); + outputBuffer.writeInt(-1); + + ByteBuf byteBuf = outputBuffer.getByteBuffers().get(0); + CompressedHeader compressedHeader = new CompressedHeader(byteBuf, + new MessageHeader(byteBuf, getDefaultMaxMessageSize())); + + MongoInternalException ex = assertThrows(MongoInternalException.class, + () -> new ReplyHeader(byteBuf, compressedHeader)); + + assertEquals("The reply message number of returned documents, -1, is expected to be 1", ex.getMessage()); + } + } +} diff --git a/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java b/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java index 3aff244ea1e..0176b8e9ad3 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java +++ b/driver-core/src/test/unit/com/mongodb/internal/connection/DefaultServerMonitorTest.java @@ -56,8 +56,10 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.Mockito.atLeast; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -254,6 +256,55 @@ public void serverHeartbeatFailed(final ServerHeartbeatFailedEvent event) { assertEquals(expectedEvents, events); } + @Test + void closeDuringConnectionShouldNotLeakBuffers() throws Exception { + CountDownLatch connectionStarted = new CountDownLatch(1); + CountDownLatch proceedWithOpen = new CountDownLatch(1); + + InternalConnection mockConnection = mock(InternalConnection.class); + doAnswer(invocation -> { + connectionStarted.countDown(); + assertTrue(proceedWithOpen.await(5, TimeUnit.SECONDS)); + return null; + }).when(mockConnection).open(any()); + + when(mockConnection.getInitialServerDescription()) + .thenReturn(createDefaultServerDescription()); + + InternalConnectionFactory factory = createConnectionFactory(mockConnection); + + monitor = createAndStartMonitor(factory, mock(ServerMonitorListener.class)); + + // Wait for connection to start opening + assertTrue(connectionStarted.await(5, TimeUnit.SECONDS)); + + // Close monitor while connection is opening + monitor.close(); + + // Allow connection to complete + proceedWithOpen.countDown(); + + // Verify no leaks by checking connection was properly closed + monitor.getServerMonitor().join(5000); + } + + @Test + void heartbeatWithNullConnectionDescriptionShouldNotCrash() throws Exception { + InternalConnection mockConnection = mock(InternalConnection.class); + when(mockConnection.getDescription()).thenReturn(null); + when(mockConnection.getInitialServerDescription()) + .thenReturn(createDefaultServerDescription()); + when(mockConnection.isClosed()).thenReturn(false); + + InternalConnectionFactory factory = createConnectionFactory(mockConnection); + monitor = createAndStartMonitor(factory, mock(ServerMonitorListener.class)); + + // Wait a bit for the monitor to run + Thread.sleep(500); + + // Monitor should handle null description gracefully + verify(mockConnection, atLeast(1)).open(any()); + } private InternalConnectionFactory createConnectionFactory(final InternalConnection connection) { InternalConnectionFactory factory = mock(InternalConnectionFactory.class); diff --git a/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolSpecification.groovy deleted file mode 100644 index 19bfa994200..00000000000 --- a/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolSpecification.groovy +++ /dev/null @@ -1,229 +0,0 @@ -/* - * Copyright 2008-present MongoDB, Inc. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.mongodb.internal.session - -import com.mongodb.ServerAddress -import com.mongodb.connection.ClusterDescription -import com.mongodb.connection.ClusterSettings -import com.mongodb.connection.ServerDescription -import com.mongodb.connection.ServerSettings -import com.mongodb.internal.connection.Cluster -import com.mongodb.internal.connection.Connection -import com.mongodb.internal.connection.Server -import com.mongodb.internal.connection.ServerTuple -import com.mongodb.internal.validator.NoOpFieldNameValidator -import org.bson.BsonArray -import org.bson.BsonBinarySubType -import org.bson.BsonDocument -import org.bson.codecs.BsonDocumentCodec -import spock.lang.Specification - -import static com.mongodb.ClusterFixture.OPERATION_CONTEXT -import static com.mongodb.ClusterFixture.TIMEOUT_SETTINGS -import static com.mongodb.ClusterFixture.getServerApi -import static com.mongodb.ReadPreference.primaryPreferred -import static com.mongodb.connection.ClusterConnectionMode.MULTIPLE -import static com.mongodb.connection.ClusterType.REPLICA_SET -import static com.mongodb.connection.ServerConnectionState.CONNECTED -import static com.mongodb.connection.ServerConnectionState.CONNECTING -import static com.mongodb.connection.ServerType.REPLICA_SET_PRIMARY -import static com.mongodb.connection.ServerType.UNKNOWN -import static java.util.concurrent.TimeUnit.MINUTES - -class ServerSessionPoolSpecification extends Specification { - - def connectedDescription = new ClusterDescription(MULTIPLE, REPLICA_SET, - [ - ServerDescription.builder().ok(true) - .state(CONNECTED) - .address(new ServerAddress()) - .type(REPLICA_SET_PRIMARY) - .logicalSessionTimeoutMinutes(30) - .build() - ], ClusterSettings.builder().hosts([new ServerAddress()]).build(), ServerSettings.builder().build()) - - def unconnectedDescription = new ClusterDescription(MULTIPLE, REPLICA_SET, - [ - ServerDescription.builder().ok(true) - .state(CONNECTING) - .address(new ServerAddress()) - .type(UNKNOWN) - .logicalSessionTimeoutMinutes(null) - .build() - ], ClusterSettings.builder().hosts([new ServerAddress()]).build(), ServerSettings.builder().build()) - - def 'should get session'() { - given: - def cluster = Stub(Cluster) { - getCurrentDescription() >> connectedDescription - } - def pool = new ServerSessionPool(cluster, TIMEOUT_SETTINGS, getServerApi()) - - when: - def session = pool.get() - - then: - session != null - } - - def 'should throw IllegalStateException if pool is closed'() { - given: - def cluster = Stub(Cluster) { - getCurrentDescription() >> connectedDescription - } - def pool = new ServerSessionPool(cluster, TIMEOUT_SETTINGS, getServerApi()) - pool.close() - - when: - pool.get() - - then: - thrown(IllegalStateException) - } - - def 'should pool session'() { - given: - def cluster = Stub(Cluster) { - getCurrentDescription() >> connectedDescription - } - def pool = new ServerSessionPool(cluster, TIMEOUT_SETTINGS, getServerApi()) - def session = pool.get() - - when: - pool.release(session) - def pooledSession = pool.get() - - then: - session == pooledSession - } - - def 'should prune sessions when getting'() { - given: - def cluster = Mock(Cluster) { - getCurrentDescription() >> connectedDescription - } - def clock = Stub(ServerSessionPool.Clock) { - millis() >>> [0, MINUTES.toMillis(29) + 1, - ] - } - def pool = new ServerSessionPool(cluster, OPERATION_CONTEXT, clock) - def sessionOne = pool.get() - - when: - pool.release(sessionOne) - - then: - !sessionOne.closed - - when: - def sessionTwo = pool.get() - - then: - sessionTwo != sessionOne - sessionOne.closed - 0 * cluster.selectServer(_) - } - - def 'should not prune session when timeout is null'() { - given: - def cluster = Stub(Cluster) { - getCurrentDescription() >> unconnectedDescription - } - def clock = Stub(ServerSessionPool.Clock) { - millis() >>> [0, 0, 0] - } - def pool = new ServerSessionPool(cluster, OPERATION_CONTEXT, clock) - def session = pool.get() - - when: - pool.release(session) - def newSession = pool.get() - - then: - session == newSession - } - - def 'should initialize session'() { - given: - def cluster = Stub(Cluster) { - getCurrentDescription() >> connectedDescription - } - def clock = Stub(ServerSessionPool.Clock) { - millis() >> 42 - } - def pool = new ServerSessionPool(cluster, OPERATION_CONTEXT, clock) - - when: - def session = pool.get() as ServerSessionPool.ServerSessionImpl - - then: - session.lastUsedAtMillis == 42 - session.transactionNumber == 0 - def uuid = session.identifier.getBinary('id') - uuid != null - uuid.type == BsonBinarySubType.UUID_STANDARD.value - uuid.data.length == 16 - } - - def 'should advance transaction number'() { - given: - def cluster = Stub(Cluster) { - getCurrentDescription() >> connectedDescription - } - def clock = Stub(ServerSessionPool.Clock) { - millis() >> 42 - } - def pool = new ServerSessionPool(cluster, OPERATION_CONTEXT, clock) - - when: - def session = pool.get() as ServerSessionPool.ServerSessionImpl - - then: - session.transactionNumber == 0 - session.advanceTransactionNumber() == 1 - session.transactionNumber == 1 - } - - def 'should end pooled sessions when pool is closed'() { - given: - def connection = Mock(Connection) - def server = Stub(Server) { - getConnection(_) >> connection - } - def cluster = Mock(Cluster) { - getCurrentDescription() >> connectedDescription - } - def pool = new ServerSessionPool(cluster, TIMEOUT_SETTINGS, getServerApi()) - def sessions = [] - 10.times { sessions.add(pool.get()) } - - for (def cur : sessions) { - pool.release(cur) - } - - when: - pool.close() - - then: - 1 * cluster.selectServer(_, _) >> new ServerTuple(server, connectedDescription.serverDescriptions[0]) - 1 * connection.command('admin', - new BsonDocument('endSessions', new BsonArray(sessions*.getIdentifier())), - { it instanceof NoOpFieldNameValidator }, primaryPreferred(), - { it instanceof BsonDocumentCodec }, _) >> new BsonDocument() - 1 * connection.release() - } -} diff --git a/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolTest.java b/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolTest.java new file mode 100644 index 00000000000..f0af97a1276 --- /dev/null +++ b/driver-core/src/test/unit/com/mongodb/internal/session/ServerSessionPoolTest.java @@ -0,0 +1,319 @@ +/* + * Copyright 2008-present MongoDB, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.mongodb.internal.session; + +import com.mongodb.MongoException; +import com.mongodb.ServerAddress; +import com.mongodb.connection.ClusterDescription; +import com.mongodb.connection.ClusterSettings; +import com.mongodb.connection.ServerDescription; +import com.mongodb.connection.ServerSettings; +import com.mongodb.internal.connection.Cluster; +import com.mongodb.internal.connection.Connection; +import com.mongodb.internal.connection.Server; +import com.mongodb.internal.connection.ServerTuple; +import com.mongodb.session.ServerSession; +import org.bson.BsonArray; +import org.bson.BsonDocument; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.ArgumentMatcher; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; + +import java.util.ArrayList; +import java.util.List; + +import static com.mongodb.ClusterFixture.OPERATION_CONTEXT; +import static com.mongodb.ClusterFixture.TIMEOUT_SETTINGS; +import static com.mongodb.ClusterFixture.getServerApi; +import static com.mongodb.connection.ClusterConnectionMode.MULTIPLE; +import static com.mongodb.connection.ClusterType.REPLICA_SET; +import static com.mongodb.connection.ServerConnectionState.CONNECTED; +import static com.mongodb.connection.ServerConnectionState.CONNECTING; +import static com.mongodb.connection.ServerType.REPLICA_SET_PRIMARY; +import static com.mongodb.connection.ServerType.UNKNOWN; +import static java.util.Collections.singletonList; +import static java.util.concurrent.TimeUnit.MINUTES; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +@DisplayName("ServerSessionPool") +@ExtendWith(MockitoExtension.class) +class ServerSessionPoolTest { + + private ClusterDescription connectedDescription; + private ClusterDescription unconnectedDescription; + + @Mock + private Cluster clusterMock; + + @BeforeEach + void setUp() { + connectedDescription = new ClusterDescription( + MULTIPLE, + REPLICA_SET, + singletonList( + ServerDescription.builder() + .ok(true) + .state(CONNECTED) + .address(new ServerAddress()) + .type(REPLICA_SET_PRIMARY) + .logicalSessionTimeoutMinutes(30) + .build() + ), + ClusterSettings.builder().hosts(singletonList(new ServerAddress())).build(), + ServerSettings.builder().build() + ); + + unconnectedDescription = new ClusterDescription( + MULTIPLE, + REPLICA_SET, + singletonList( + ServerDescription.builder() + .ok(true) + .state(CONNECTING) + .address(new ServerAddress()) + .type(UNKNOWN) + .logicalSessionTimeoutMinutes(null) + .build() + ), + ClusterSettings.builder().hosts(singletonList(new ServerAddress())).build(), + ServerSettings.builder().build() + ); + } + + @Test + @DisplayName("should get session from pool") + void testGetSession() { + ServerSessionPool pool = new ServerSessionPool(clusterMock, TIMEOUT_SETTINGS, getServerApi()); + + ServerSession session = pool.get(); + + assertNotNull(session); + } + + @Test + @DisplayName("should throw IllegalStateException when pool is closed") + void testThrowExceptionIfPoolClosed() { + ServerSessionPool pool = new ServerSessionPool(clusterMock, TIMEOUT_SETTINGS, getServerApi()); + pool.close(); + + assertThrows(IllegalStateException.class, pool::get); + } + + @Test + @DisplayName("should reuse released session from pool") + void testPoolSession() { + when(clusterMock.getCurrentDescription()).thenReturn(connectedDescription); + ServerSessionPool pool = new ServerSessionPool(clusterMock, TIMEOUT_SETTINGS, getServerApi()); + + ServerSession session = pool.get(); + pool.release(session); + ServerSession pooledSession = pool.get(); + + assertEquals(session, pooledSession); + } + + @Test + @DisplayName("should prune expired sessions when getting new session") + void testPruneSessionsWhenGetting() { + when(clusterMock.getCurrentDescription()).thenReturn(connectedDescription); + + ServerSessionPool.Clock clock = mock(ServerSessionPool.Clock.class); + when(clock.millis()).thenReturn(0L, MINUTES.toMillis(29) + 1); + + ServerSessionPool pool = new ServerSessionPool(clusterMock, OPERATION_CONTEXT, clock); + ServerSession sessionOne = pool.get(); + + pool.release(sessionOne); + assertFalse(sessionOne.isClosed()); + + ServerSession sessionTwo = pool.get(); + + assertNotEquals(sessionTwo, sessionOne); + // Note: Actual closed state verification depends on implementation details + } + + @Test + @DisplayName("should not prune session when timeout is null") + void testNotPruneSessionWhenTimeoutIsNull() { + when(clusterMock.getCurrentDescription()).thenReturn(unconnectedDescription); + + ServerSessionPool.Clock clock = mock(ServerSessionPool.Clock.class); + when(clock.millis()).thenReturn(0L, 0L, 0L); + + ServerSessionPool pool = new ServerSessionPool(clusterMock, OPERATION_CONTEXT, clock); + ServerSession session = pool.get(); + + pool.release(session); + ServerSession newSession = pool.get(); + + assertEquals(session, newSession); + } + + @Test + @DisplayName("should initialize session with correct properties") + void testInitializeSession() { + ServerSessionPool.Clock clock = mock(ServerSessionPool.Clock.class); + when(clock.millis()).thenReturn(42L); + + ServerSessionPool pool = new ServerSessionPool(clusterMock, OPERATION_CONTEXT, clock); + ServerSession session = pool.get(); + + ServerSessionPool.ServerSessionImpl sessionImpl = (ServerSessionPool.ServerSessionImpl) session; + assertEquals(42L, sessionImpl.getLastUsedAtMillis()); + assertEquals(0L, sessionImpl.getTransactionNumber()); + + BsonDocument identifier = sessionImpl.getIdentifier(); + assertNotNull(identifier); + byte[] uuid = identifier.getBinary("id").getData(); + assertNotNull(uuid); + assertEquals(16, uuid.length); + } + + @Test + @DisplayName("should advance transaction number") + void testAdvanceTransactionNumber() { + ServerSessionPool.Clock clock = mock(ServerSessionPool.Clock.class); + when(clock.millis()).thenReturn(42L); + + ServerSessionPool pool = new ServerSessionPool(clusterMock, OPERATION_CONTEXT, clock); + ServerSession session = pool.get(); + + ServerSessionPool.ServerSessionImpl sessionImpl = (ServerSessionPool.ServerSessionImpl) session; + assertEquals(0L, sessionImpl.getTransactionNumber()); + assertEquals(1L, sessionImpl.advanceTransactionNumber()); + assertEquals(1L, sessionImpl.getTransactionNumber()); + } + + @Test + @DisplayName("should end pooled sessions when pool is closed") + void testEndPooledSessionsWhenPoolClosed() { + Connection connection = mock(Connection.class); + Server server = mock(Server.class); + when(server.getConnection(any())).thenReturn(connection); + + when(clusterMock.getCurrentDescription()).thenReturn(connectedDescription); + when(clusterMock.selectServer(any(), any())) + .thenReturn(new ServerTuple(server, connectedDescription.getServerDescriptions().get(0))); + + when(connection.command( + any(String.class), + any(BsonDocument.class), + any(), + any(), + any(), + any() + )).thenReturn(new BsonDocument()); + + ServerSessionPool pool = new ServerSessionPool(clusterMock, TIMEOUT_SETTINGS, getServerApi()); + List sessions = new ArrayList<>(); + for (int i = 0; i < 10; i++) { + sessions.add(pool.get()); + } + + for (ServerSession session : sessions) { + pool.release(session); + } + + pool.close(); + + verify(clusterMock, times(1)).selectServer(any(), any()); + verify(connection, times(1)).command( + any(String.class), + argThat(endSessionsDocMatcher(sessions)), + any(), + any(), + any(), + any() + ); + verify(connection, times(1)).release(); + } + + @Test + @DisplayName("should handle MongoException during endSessions without leaking resources") + void testHandleMongoExceptionDuringEndSessionsWithoutLeakingResources() { + Connection connection = mock(Connection.class); + Server server = mock(Server.class); + when(server.getConnection(any())).thenReturn(connection); + + when(clusterMock.getCurrentDescription()).thenReturn(connectedDescription); + when(clusterMock.selectServer(any(), any())) + .thenReturn(new ServerTuple(server, connectedDescription.getServerDescriptions().get(0))); + + when(connection.command( + any(String.class), + any(BsonDocument.class), + any(), + any(), + any(), + any() + )).thenThrow(new MongoException("Simulated error")); + + ServerSessionPool pool = new ServerSessionPool(clusterMock, TIMEOUT_SETTINGS, getServerApi()); + List sessions = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + sessions.add(pool.get()); + } + + for (ServerSession session : sessions) { + pool.release(session); + } + + // Should not throw - exception is handled internally + pool.close(); + + verify(clusterMock, times(1)).selectServer(any(), any()); + verify(connection, times(1)).release(); + } + + /** + * Matcher to verify the endSessions document contains the correct session identifiers. + */ + private ArgumentMatcher endSessionsDocMatcher(List sessions) { + return doc -> { + if (!doc.containsKey("endSessions")) { + return false; + } + BsonArray endSessionsArray = doc.getArray("endSessions"); + if (endSessionsArray.size() != sessions.size()) { + return false; + } + for (int i = 0; i < sessions.size(); i++) { + ServerSession session = sessions.get(i); + BsonDocument sessionIdentifier = session.getIdentifier(); + BsonDocument arrayElement = endSessionsArray.get(i).asDocument(); + if (!sessionIdentifier.equals(arrayElement)) { + return false; + } + } + return true; + }; + } +} diff --git a/driver-legacy/src/main/com/mongodb/DBDecoderAdapter.java b/driver-legacy/src/main/com/mongodb/DBDecoderAdapter.java index dd761234df9..75a60ca382f 100644 --- a/driver-legacy/src/main/com/mongodb/DBDecoderAdapter.java +++ b/driver-legacy/src/main/com/mongodb/DBDecoderAdapter.java @@ -39,9 +39,9 @@ class DBDecoderAdapter implements Decoder { @Override public DBObject decode(final BsonReader reader, final DecoderContext decoderContext) { - ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(bufferProvider); - BsonBinaryWriter binaryWriter = new BsonBinaryWriter(bsonOutput); - try { + + try (ByteBufferBsonOutput bsonOutput = new ByteBufferBsonOutput(bufferProvider); + BsonBinaryWriter binaryWriter = new BsonBinaryWriter(bsonOutput)) { binaryWriter.pipe(reader); BufferExposingByteArrayOutputStream byteArrayOutputStream = new BufferExposingByteArrayOutputStream(binaryWriter.getBsonOutput().getSize()); @@ -50,9 +50,6 @@ public DBObject decode(final BsonReader reader, final DecoderContext decoderCont } catch (IOException e) { // impossible with a byte array output stream throw new MongoInternalException("An unlikely IOException thrown.", e); - } finally { - binaryWriter.close(); - bsonOutput.close(); } } diff --git a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java index 5cf0ea103bd..7a0b016cc98 100644 --- a/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java +++ b/driver-reactive-streams/src/main/com/mongodb/reactivestreams/client/internal/ClientSessionPublisherImpl.java @@ -37,6 +37,8 @@ import reactor.core.publisher.Mono; import reactor.core.publisher.MonoSink; +import java.util.concurrent.atomic.AtomicBoolean; + import static com.mongodb.MongoException.TRANSIENT_TRANSACTION_ERROR_LABEL; import static com.mongodb.MongoException.UNKNOWN_TRANSACTION_COMMIT_RESULT_LABEL; import static com.mongodb.assertions.Assertions.assertNotNull; @@ -52,7 +54,7 @@ final class ClientSessionPublisherImpl extends BaseClientSessionImpl implements private boolean messageSentInCurrentTransaction; private boolean commitInProgress; private TransactionOptions transactionOptions; - + private AtomicBoolean closed = new AtomicBoolean(); ClientSessionPublisherImpl(final ServerSessionPool serverSessionPool, final MongoClientImpl mongoClient, final ClientSessionOptions options, final OperationExecutor executor) { @@ -221,10 +223,18 @@ private void clearTransactionContextOnError(final MongoException e) { @Override public void close() { - if (transactionState == TransactionState.IN) { - Mono.from(abortTransaction()).doFinally(it -> super.close()).subscribe(); - } else { - super.close(); + if (closed.compareAndSet(false, true)) { + if (transactionState == TransactionState.IN) { + Mono.from(abortTransaction()) + .doFinally(it -> { + clearTransactionContext(); + super.close(); + }) + .subscribe(); + } else { + clearTransactionContext(); + super.close(); + } } } diff --git a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/Fixture.java b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/Fixture.java index 2881b47e38e..05ca89dd048 100644 --- a/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/Fixture.java +++ b/driver-reactive-streams/src/test/functional/com/mongodb/reactivestreams/client/Fixture.java @@ -24,6 +24,7 @@ import com.mongodb.MongoTimeoutException; import com.mongodb.connection.ClusterType; import com.mongodb.connection.ServerVersion; +import com.mongodb.connection.TransportSettings; import com.mongodb.reactivestreams.client.internal.MongoClientImpl; import org.bson.Document; import org.bson.conversions.Bson; @@ -33,6 +34,7 @@ import java.util.List; import static com.mongodb.ClusterFixture.TIMEOUT_DURATION; +import static com.mongodb.ClusterFixture.getOverriddenTransportSettings; import static com.mongodb.ClusterFixture.getServerApi; import static com.mongodb.internal.thread.InterruptionUtil.interruptAndCreateMongoInterruptedException; import static java.lang.Thread.sleep; @@ -67,11 +69,18 @@ public static MongoClientSettings.Builder getMongoClientSettingsBuilder() { } public static MongoClientSettings.Builder getMongoClientSettingsBuilder(final ConnectionString connectionString) { - MongoClientSettings.Builder builder = MongoClientSettings.builder(); + MongoClientSettings.Builder builder = MongoClientSettings.builder() + .applyConnectionString(connectionString); + + TransportSettings overriddenTransportSettings = getOverriddenTransportSettings(); + if (overriddenTransportSettings != null) { + builder.transportSettings(overriddenTransportSettings); + } + if (getServerApi() != null) { builder.serverApi(getServerApi()); } - return builder.applyConnectionString(connectionString); + return builder; } public static String getDefaultDatabaseName() { @@ -164,6 +173,11 @@ public static synchronized ConnectionString getConnectionString() { public static MongoClientSettings.Builder getMongoClientBuilderFromConnectionString() { MongoClientSettings.Builder builder = MongoClientSettings.builder() .applyConnectionString(getConnectionString()); + + TransportSettings overriddenTransportSettings = getOverriddenTransportSettings(); + if (overriddenTransportSettings != null) { + builder.transportSettings(overriddenTransportSettings); + } if (getServerApi() != null) { builder.serverApi(getServerApi()); } diff --git a/driver-sync/src/test/functional/com/mongodb/client/AbstractSessionsProseTest.java b/driver-sync/src/test/functional/com/mongodb/client/AbstractSessionsProseTest.java index 3682bd64ff0..764f3dbaa47 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/AbstractSessionsProseTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/AbstractSessionsProseTest.java @@ -93,7 +93,9 @@ public void shouldCreateServerSessionOnlyAfterConnectionCheckout() throws Interr .addCommandListener(new CommandListener() { @Override public void commandStarted(final CommandStartedEvent event) { - lsidSet.add(event.getCommand().getDocument("lsid")); + if (event.getCommand().containsKey("lsid")) { + lsidSet.add(event.getCommand().getDocument("lsid").toBsonDocument()); + } } }) .build())) { diff --git a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java index cf003078f04..1fc2d18a1fb 100644 --- a/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java +++ b/driver-sync/src/test/functional/com/mongodb/client/unified/UnifiedTest.java @@ -311,6 +311,7 @@ public void cleanUp() { if (testDef != null) { postCleanUp(testDef); } + System.gc(); } protected void postCleanUp(final TestDef testDef) {