Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flushing Proxy Channels at CPU side upon reaching the Inflight Request Limit #415

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
21 changes: 16 additions & 5 deletions include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ std::string version();
/// Base class for bootstraps.
class Bootstrap {
public:
Bootstrap(){};
Bootstrap() {};
virtual ~Bootstrap() = default;
virtual int getRank() = 0;
virtual int getNranks() = 0;
Expand Down Expand Up @@ -388,6 +388,11 @@ class Endpoint {
/// @return The transport used.
Transport transport();

/// Get the maximum number of inflight requests configured for this endpoint.
///
/// The value comes from EndpointConfig::maxInflightRequests. Its default is -1, which the connection
/// constructors replace with a transport-specific limit.
///
/// @return The configured maximum number of inflight requests, or -1 if none was set.
int maxInflightRequests();
chhwang marked this conversation as resolved.
Show resolved Hide resolved

/// Serialize the Endpoint object to a vector of characters.
///
/// @return A vector of characters representing the serialized Endpoint object.
Expand Down Expand Up @@ -416,6 +421,10 @@ class Endpoint {
/// Represents a connection between two processes.
class Connection {
public:
/// Constructor.
/// @param maxInflightRequests The maximum number of inflight requests.
Connection(int maxInflightRequests) : maxInflightRequests(maxInflightRequests) {};

virtual ~Connection() = default;

/// Write data from a source @ref RegisteredMemory to a destination @ref RegisteredMemory.
Expand Down Expand Up @@ -454,10 +463,13 @@ class Connection {
/// @return name of @ref transport() -> @ref remoteTransport()
std::string getTransportName();

/// Get the maximum number of inflight requests for this connection.
///
/// @return The value of maxInflightRequests, which the constructor initializes.
int getMaxInflightRequest();

protected:
// Internal methods for getting implementation pointers.
static std::shared_ptr<RegisteredMemory::Impl> getImpl(RegisteredMemory& memory);
static std::shared_ptr<Endpoint::Impl> getImpl(Endpoint& memory);
int maxInflightRequests;
};

/// Used to configure an endpoint.
Expand All @@ -472,14 +484,13 @@ struct EndpointConfig {
int ibMaxCqPollNum = DefaultMaxCqPollNum;
int ibMaxSendWr = DefaultMaxSendWr;
int ibMaxWrPerSend = DefaultMaxWrPerSend;

/// Default constructor. Sets transport to Transport::Unknown.
EndpointConfig() : transport(Transport::Unknown) {}
int maxInflightRequests;

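/// Constructor that takes a transport and a maximum number of inflight requests, and sets the other fields to
/// their default values. Called with no arguments, it acts as the default constructor and sets transport to
/// Transport::Unknown.
///
/// @param transport The transport to use.
/// @param maxInflightRequests The maximum number of inflight requests. The default of -1 lets the connection
/// choose a transport-specific limit.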
EndpointConfig(Transport transport) : transport(transport) {}
EndpointConfig(Transport transport = Transport::Unknown, int maxInflightRequests = -1)
chhwang marked this conversation as resolved.
Show resolved Hide resolved
: transport(transport), maxInflightRequests(maxInflightRequests) {}
};

/// Represents a context for communication. This provides a low-level interface for forming connections in use-cases
Expand Down
1 change: 1 addition & 0 deletions include/mscclpp/proxy_channel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class ProxyService : public BaseProxyService {
std::vector<RegisteredMemory> memories_;
std::shared_ptr<Proxy> proxy_;
int deviceNumaNode;
std::unordered_map<std::shared_ptr<Connection>, int> inflightRequests;

void bindThread();

Expand Down
14 changes: 11 additions & 3 deletions src/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,13 @@ std::string Connection::getTransportName() {
TransportNames[static_cast<int>(this->remoteTransport())];
}

int Connection::getMaxInflightRequest() { return maxInflightRequests; }

// CudaIpcConnection

CudaIpcConnection::CudaIpcConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, cudaStream_t stream)
: stream_(stream) {
: Connection(localEndpoint.maxInflightRequests() != -1 ? localEndpoint.maxInflightRequests() : INT_MAX),
caiomcbr marked this conversation as resolved.
Show resolved Hide resolved
stream_(stream) {
if (localEndpoint.transport() != Transport::CudaIpc) {
throw mscclpp::Error("Cuda IPC connection can only be made from a Cuda IPC endpoint", ErrorCode::InvalidUsage);
}
Expand Down Expand Up @@ -119,7 +122,9 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) {
// IBConnection

IBConnection::IBConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, Context& context)
: transport_(localEndpoint.transport()),
: Connection(localEndpoint.maxInflightRequests() != -1 ? localEndpoint.maxInflightRequests()
: EndpointConfig::DefaultMaxCqPollNum),
transport_(localEndpoint.transport()),
remoteTransport_(remoteEndpoint.transport()),
dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
qp = getImpl(localEndpoint)->ibQp_;
Expand Down Expand Up @@ -231,7 +236,10 @@ void IBConnection::flush(int64_t timeoutUsec) {

EthernetConnection::EthernetConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, uint64_t sendBufferSize,
uint64_t recvBufferSize)
: abortFlag_(0), sendBufferSize_(sendBufferSize), recvBufferSize_(recvBufferSize) {
: Connection(localEndpoint.maxInflightRequests() != -1 ? localEndpoint.maxInflightRequests() : INT_MAX),
abortFlag_(0),
sendBufferSize_(sendBufferSize),
recvBufferSize_(recvBufferSize) {
// Validating Transport Protocol
if (localEndpoint.transport() != Transport::Ethernet || remoteEndpoint.transport() != Transport::Ethernet) {
throw mscclpp::Error("Ethernet connection can only be made from Ethernet endpoints", ErrorCode::InvalidUsage);
Expand Down
4 changes: 3 additions & 1 deletion src/endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace mscclpp {

Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
: transport_(config.transport), hostHash_(getHostHash()) {
: transport_(config.transport), hostHash_(getHostHash()), maxInflightRequests_(config.maxInflightRequests) {
if (AllIBTransports.has(transport_)) {
ibLocal_ = true;
ibQp_ = contextImpl.getIbContext(transport_)
Expand All @@ -34,6 +34,8 @@ Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)

/// Returns the transport stored on this endpoint's implementation object.
MSCCLPP_API_CPP Transport Endpoint::transport() {
  const auto& impl = *pimpl_;
  return impl.transport_;
}

/// Returns the inflight-request limit stored on this endpoint's implementation object.
MSCCLPP_API_CPP int Endpoint::maxInflightRequests() {
  const int limit = pimpl_->maxInflightRequests_;
  return limit;
}

MSCCLPP_API_CPP std::vector<char> Endpoint::serialize() {
std::vector<char> data;
std::copy_n(reinterpret_cast<char*>(&pimpl_->transport_), sizeof(pimpl_->transport_), std::back_inserter(data));
Expand Down
1 change: 1 addition & 0 deletions src/include/endpoint.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ struct Endpoint::Impl {

Transport transport_;
uint64_t hostHash_;
int maxInflightRequests_;

// The following are only used for IB and are undefined for other transports.
bool ibLocal_;
Expand Down
6 changes: 5 additions & 1 deletion src/proxy_channel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,19 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) {
RegisteredMemory& src = memories_[trigger->fields.srcMemoryId];
semaphore->connection()->write(dst, trigger->fields.dstOffset, src, trigger->fields.srcOffset,
trigger->fields.size);
inflightRequests[semaphore->connection()]++;
}

if (trigger->fields.type & TriggerFlag) {
semaphore->signal();
inflightRequests[semaphore->connection()]++;
}

if (trigger->fields.type & TriggerSync) {
if (trigger->fields.type & TriggerSync ||
inflightRequests[semaphore->connection()] > semaphore->connection()->getMaxInflightRequest()) {
semaphore->connection()->flush();
result = ProxyHandlerResult::FlushFifoTailAndContinue;
inflightRequests[semaphore->connection()] = 0;
}

return result;
Expand Down
Loading