Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flushing Proxy Channels at CPU side upon reaching the Inflight Request Limit #415

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
42 changes: 34 additions & 8 deletions include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,11 @@ class Endpoint {
/// @return The transport used.
Transport transport();

/// Get the maximum write queue size.
///
/// @return The maximum number of write requests that can be queued.
int maxWriteQueueSize();

/// Serialize the Endpoint object to a vector of characters.
///
/// @return A vector of characters representing the serialized Endpoint object.
Expand Down Expand Up @@ -416,6 +421,10 @@ class Endpoint {
/// Represents a connection between two processes.
class Connection {
public:
/// Constructor.
/// @param maxWriteQueueSize The maximum number of write requests that can be queued.
Connection(int maxWriteQueueSize) : maxWriteQueueSize(maxWriteQueueSize){};

virtual ~Connection() = default;

/// Write data from a source @ref RegisteredMemory to a destination @ref RegisteredMemory.
Expand Down Expand Up @@ -454,10 +463,16 @@ class Connection {
/// @return name of @ref transport() -> @ref remoteTransport()
std::string getTransportName();

/// Get the maximum write queue size.
///
/// @return The maximum number of write requests that can be queued.
int getMaxWriteQueueSize();

protected:
// Internal methods for getting implementation pointers.
static std::shared_ptr<RegisteredMemory::Impl> getImpl(RegisteredMemory& memory);
static std::shared_ptr<Endpoint::Impl> getImpl(Endpoint& memory);
int maxWriteQueueSize;
};

/// Used to configure an endpoint.
Expand All @@ -468,18 +483,29 @@ struct EndpointConfig {
static const int DefaultMaxWrPerSend = 64;

Transport transport;
int ibMaxCqSize = DefaultMaxCqSize;
int ibMaxCqPollNum = DefaultMaxCqPollNum;
int ibMaxSendWr = DefaultMaxSendWr;
int ibMaxWrPerSend = DefaultMaxWrPerSend;

/// Default constructor. Sets transport to Transport::Unknown.
EndpointConfig() : transport(Transport::Unknown) {}
int ibMaxCqSize;
int ibMaxCqPollNum;
int ibMaxSendWr;
int ibMaxWrPerSend;
int maxWriteQueueSize;

/// Constructor. Every parameter has a default value, so this also serves as the default constructor.
///
/// @param transport The transport to use.
EndpointConfig(Transport transport) : transport(transport) {}
/// @param ibMaxCqSize The maximum completion queue size.
/// @param ibMaxCqPollNum The maximum completion queue poll number.
/// @param ibMaxSendWr The maximum send work requests.
/// @param ibMaxWrPerSend The maximum work requests per send.
/// @param maxWriteQueueSize The maximum write queue size.
/// @note A maxWriteQueueSize of -1 (the default) means no explicit limit was configured:
/// IBConnection substitutes DefaultMaxCqSize for it, and ProxyService::handleTrigger
/// does not apply its queue-size-triggered flush to connections that report -1.
EndpointConfig(Transport transport = Transport::Unknown, int ibMaxCqSize = DefaultMaxCqSize,
int ibMaxCqPollNum = DefaultMaxCqPollNum, int ibMaxSendWr = DefaultMaxSendWr,
int ibMaxWrPerSend = DefaultMaxWrPerSend, int maxWriteQueueSize = -1)
: transport(transport),
ibMaxCqSize(ibMaxCqSize),
ibMaxCqPollNum(ibMaxCqPollNum),
ibMaxSendWr(ibMaxSendWr),
ibMaxWrPerSend(ibMaxWrPerSend),
maxWriteQueueSize(maxWriteQueueSize) {}
};

/// Represents a context for communication. This provides a low-level interface for forming connections in use-cases
Expand Down
1 change: 1 addition & 0 deletions include/mscclpp/proxy_channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class ProxyService : public BaseProxyService {
std::vector<RegisteredMemory> memories_;
std::shared_ptr<Proxy> proxy_;
int deviceNumaNode;
std::unordered_map<std::shared_ptr<Connection>, int> inflightRequests;

void bindThread();

Expand Down
13 changes: 10 additions & 3 deletions src/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,12 @@ std::string Connection::getTransportName() {
TransportNames[static_cast<int>(this->remoteTransport())];
}

int Connection::getMaxWriteQueueSize() { return maxWriteQueueSize; }

// CudaIpcConnection

CudaIpcConnection::CudaIpcConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, cudaStream_t stream)
: stream_(stream) {
: Connection(localEndpoint.maxWriteQueueSize()), stream_(stream) {
if (localEndpoint.transport() != Transport::CudaIpc) {
throw mscclpp::Error("Cuda IPC connection can only be made from a Cuda IPC endpoint", ErrorCode::InvalidUsage);
}
Expand Down Expand Up @@ -119,7 +121,9 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) {
// IBConnection

IBConnection::IBConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, Context& context)
: transport_(localEndpoint.transport()),
: Connection(localEndpoint.maxWriteQueueSize() != -1 ? localEndpoint.maxWriteQueueSize()
: EndpointConfig::DefaultMaxCqSize),
transport_(localEndpoint.transport()),
remoteTransport_(remoteEndpoint.transport()),
dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
qp = getImpl(localEndpoint)->ibQp_;
Expand Down Expand Up @@ -231,7 +235,10 @@ void IBConnection::flush(int64_t timeoutUsec) {

EthernetConnection::EthernetConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, uint64_t sendBufferSize,
uint64_t recvBufferSize)
: abortFlag_(0), sendBufferSize_(sendBufferSize), recvBufferSize_(recvBufferSize) {
: Connection(localEndpoint.maxWriteQueueSize()),
abortFlag_(0),
sendBufferSize_(sendBufferSize),
recvBufferSize_(recvBufferSize) {
// Validating Transport Protocol
if (localEndpoint.transport() != Transport::Ethernet || remoteEndpoint.transport() != Transport::Ethernet) {
throw mscclpp::Error("Ethernet connection can only be made from Ethernet endpoints", ErrorCode::InvalidUsage);
Expand Down
4 changes: 3 additions & 1 deletion src/endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace mscclpp {

Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
: transport_(config.transport), hostHash_(getHostHash()) {
: transport_(config.transport), hostHash_(getHostHash()), maxWriteQueueSize_(config.maxWriteQueueSize) {
if (AllIBTransports.has(transport_)) {
ibLocal_ = true;
ibQp_ = contextImpl.getIbContext(transport_)
Expand All @@ -34,6 +34,8 @@ Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)

// Returns the transport recorded in this endpoint's implementation object.
MSCCLPP_API_CPP Transport Endpoint::transport() {
  return (*pimpl_).transport_;
}

// Returns the write-queue limit that Endpoint::Impl copied from EndpointConfig::maxWriteQueueSize.
MSCCLPP_API_CPP int Endpoint::maxWriteQueueSize() {
  return (*pimpl_).maxWriteQueueSize_;
}

MSCCLPP_API_CPP std::vector<char> Endpoint::serialize() {
std::vector<char> data;
std::copy_n(reinterpret_cast<char*>(&pimpl_->transport_), sizeof(pimpl_->transport_), std::back_inserter(data));
Expand Down
1 change: 1 addition & 0 deletions src/include/endpoint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ struct Endpoint::Impl {

Transport transport_;
uint64_t hostHash_;
int maxWriteQueueSize_;

// The following are only used for IB and are undefined for other transports.
bool ibLocal_;
Expand Down
7 changes: 6 additions & 1 deletion src/proxy_channel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,21 +70,26 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) {
std::shared_ptr<Host2DeviceSemaphore> semaphore = semaphores_[trigger->fields.chanId];

auto result = ProxyHandlerResult::Continue;
int maxWriteQueueSize = semaphore->connection()->getMaxWriteQueueSize();

if (trigger->fields.type & TriggerData) {
RegisteredMemory& dst = memories_[trigger->fields.dstMemoryId];
RegisteredMemory& src = memories_[trigger->fields.srcMemoryId];
semaphore->connection()->write(dst, trigger->fields.dstOffset, src, trigger->fields.srcOffset,
trigger->fields.size);
inflightRequests[semaphore->connection()]++;
}

if (trigger->fields.type & TriggerFlag) {
semaphore->signal();
inflightRequests[semaphore->connection()]++;
}

if (trigger->fields.type & TriggerSync) {
if (trigger->fields.type & TriggerSync ||
(maxWriteQueueSize != -1 && inflightRequests[semaphore->connection()] > maxWriteQueueSize)) {
semaphore->connection()->flush();
result = ProxyHandlerResult::FlushFifoTailAndContinue;
inflightRequests[semaphore->connection()] = 0;
}

return result;
Expand Down
Loading