diff --git a/.azure-pipelines/ut.yml b/.azure-pipelines/ut.yml index 9a1a96f65..5e950606e 100644 --- a/.azure-pipelines/ut.yml +++ b/.azure-pipelines/ut.yml @@ -35,7 +35,7 @@ jobs: cmake -DCMAKE_BUILD_TYPE=Release -DMSCCLPP_BYPASS_GPU_CHECK=ON -DMSCCLPP_USE_CUDA=ON .. make -j workingDirectory: '$(System.DefaultWorkingDirectory)' - + - task: DownloadSecureFile@1 name: SshKeyFile displayName: Download key file diff --git a/apps/nccl/src/broadcast.hpp b/apps/nccl/src/broadcast.hpp index a453bcb2c..e9a9111f6 100644 --- a/apps/nccl/src/broadcast.hpp +++ b/apps/nccl/src/broadcast.hpp @@ -15,7 +15,7 @@ template __global__ void __launch_bounds__(1024, 1) broadcast6(void* sendbuff, void* scratchbuff, void* recvbuff, mscclpp::DeviceHandle* smChannels, - size_t channelOutOffset, size_t rank, [[maybe_unused]] size_t worldSize, size_t root, + [[maybe_unused]] size_t channelOutOffset, size_t rank, [[maybe_unused]] size_t worldSize, size_t root, size_t nRanksPerNode, size_t nelemsPerGPU) { const size_t nThread = blockDim.x * gridDim.x; const size_t nPeer = nRanksPerNode - 1; diff --git a/apps/nccl/src/nccl.cu b/apps/nccl/src/nccl.cu index 802d399e1..734ec6a04 100644 --- a/apps/nccl/src/nccl.cu +++ b/apps/nccl/src/nccl.cu @@ -176,9 +176,9 @@ static std::shared_ptr> setupSmChannel std::transform(smChannels.begin(), smChannels.end(), std::back_inserter(smChannelDeviceHandles), [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); }); std::shared_ptr> ptr = - mscclpp::allocSharedCuda>(smChannelDeviceHandles.size()); - mscclpp::memcpyCuda>(ptr.get(), smChannelDeviceHandles.data(), - smChannelDeviceHandles.size(), cudaMemcpyHostToDevice); + mscclpp::detail::gpuCallocShared>(smChannelDeviceHandles.size()); + mscclpp::gpuMemcpy>(ptr.get(), smChannelDeviceHandles.data(), + smChannelDeviceHandles.size(), cudaMemcpyHostToDevice); return ptr; } @@ -360,7 +360,7 @@ static void ncclCommInitRankFallbackSingleNode(ncclComm* commPtr, std::shared_pt commPtr->smSemaphores = std::move(smSemaphores); commPtr->buffFlag = 0; commPtr->numScratchBuff = 2; - commPtr->scratchBuff = mscclpp::allocExtSharedCuda(SCRATCH_SIZE); + commPtr->scratchBuff = mscclpp::GpuBuffer(SCRATCH_SIZE).memory(); commPtr->remoteScratchRegMemories = setupRemoteMemories(commPtr->comm, rank, commPtr->scratchBuff.get(), SCRATCH_SIZE, mscclpp::Transport::CudaIpc); } @@ -624,7 +624,6 @@ NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t } int rank = comm->comm->bootstrap()->getRank(); - int nRank = comm->comm->bootstrap()->getNranks(); std::vector& plans = comm->executionPlans["broadcast"]; std::shared_ptr plan; @@ -817,18 +816,13 @@ NCCL_API ncclResult_t ncclCommDeregister(const ncclComm_t, void*) { } ncclResult_t ncclMemAlloc(void** ptr, size_t size) { - // Allocate memory using mscclpp::allocSharedPhysicalCuda if (ptr == nullptr || size == 0) { WARN("ptr is nullptr or size is 0"); return ncclInvalidArgument; } std::shared_ptr sharedPtr; try { - if (mscclpp::isNvlsSupported()) { - sharedPtr = mscclpp::allocSharedPhysicalCuda(size); - } else { - sharedPtr = mscclpp::allocExtSharedCuda(size); - } + sharedPtr = mscclpp::GpuBuffer(size).memory(); if (sharedPtr == nullptr) { INFO(MSCCLPP_ALLOC, "Failed to allocate memory"); return ncclSystemError; diff --git a/docs/getting-started/tutorials/python-api.md b/docs/getting-started/tutorials/python-api.md index fcc7eee2f..c2f26c23f 100644 --- a/docs/getting-started/tutorials/python-api.md +++ b/docs/getting-started/tutorials/python-api.md @@ -14,6 +14,7 @@ 
from mscclpp import ( ProxyService, Transport, ) +from mscclpp.utils import GpuBuffer import mscclpp.comm as mscclpp_comm def create_connection(group: mscclpp_comm.CommGroup, transport: str): @@ -32,7 +33,7 @@ if __name__ == "__main__": mscclpp_group = mscclpp_comm.CommGroup(MPI.COMM_WORLD) connections = create_connection(mscclpp_group, "NVLink") nelems = 1024 - memory = cp.zeros(nelem, dtype=cp.int32) + memory = GpuBuffer(nelem, dtype=cp.int32) proxy_service = ProxyService() simple_channels = group.make_proxy_channels(proxy_service, memory, connections) proxy_service.start_proxy() diff --git a/include/mscclpp/gpu_utils.hpp b/include/mscclpp/gpu_utils.hpp index 58a6e2556..8b7d8b19b 100644 --- a/include/mscclpp/gpu_utils.hpp +++ b/include/mscclpp/gpu_utils.hpp @@ -4,7 +4,6 @@ #ifndef MSCCLPP_GPU_UTILS_HPP_ #define MSCCLPP_GPU_UTILS_HPP_ -#include #include #include "errors.hpp" @@ -35,19 +34,6 @@ namespace mscclpp { -/// set memory access permission to read-write -/// @param base Base memory pointer. -/// @param size Size of the memory. -inline void setReadWriteMemoryAccess(void* base, size_t size) { - CUmemAccessDesc accessDesc = {}; - int deviceId; - MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId)); - accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - accessDesc.location.id = deviceId; - accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; - MSCCLPP_CUTHROW(cuMemSetAccess((CUdeviceptr)base, size, &accessDesc, 1)); -} - /// A RAII guard that will cudaThreadExchangeStreamCaptureMode to cudaStreamCaptureModeRelaxed on construction and /// restore the previous mode on destruction. This is helpful when we want to avoid CUDA graph capture. struct AvoidCudaGraphCaptureGuard { @@ -64,101 +50,46 @@ struct CudaStreamWithFlags { cudaStream_t stream_; }; -template -struct CudaDeleter; - namespace detail { -/// A wrapper of cudaMalloc that sets the allocated memory to zero. -/// @tparam T Type of each element in the allocated memory. -/// @param nelem Number of elements to allocate. -/// @return A pointer to the allocated memory. 
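The nccl.cu hunks above collapse the old allocExtSharedCuda/allocSharedPhysicalCuda split into a single GpuBuffer allocation. A minimal sketch of the resulting pattern follows; the explicit template argument and the SCRATCH_SIZE value are illustrative assumptions, not copied from the repository.

    // Sketch only: GpuBuffer picks physical (NVLS) or regular allocation internally,
    // which is why the explicit isNvlsSupported() branch disappears from ncclMemAlloc.
    #include <cstddef>
    #include <memory>
    #include <mscclpp/gpu_utils.hpp>

    constexpr std::size_t SCRATCH_SIZE = 1 << 25;  // placeholder value, not the real constant

    std::shared_ptr<char> allocateScratch() {
      // Before: allocExtSharedCuda<char>(SCRATCH_SIZE) or allocSharedPhysicalCuda<char>(SCRATCH_SIZE),
      // chosen by hand. After: one call; the memory lives as long as the returned shared_ptr.
      return mscclpp::GpuBuffer<char>(SCRATCH_SIZE).memory();
    }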
-template -T* cudaCalloc(size_t nelem) { - AvoidCudaGraphCaptureGuard cgcGuard; - T* ptr; - CudaStreamWithFlags stream(cudaStreamNonBlocking); - MSCCLPP_CUDATHROW(cudaMalloc(&ptr, nelem * sizeof(T))); - MSCCLPP_CUDATHROW(cudaMemsetAsync(ptr, 0, nelem * sizeof(T), stream)); - MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream)); - return ptr; -} +void setReadWriteMemoryAccess(void* base, size_t size); +void* gpuCalloc(size_t bytes); +void* gpuCallocHost(size_t bytes); +#if defined(__HIP_PLATFORM_AMD__) +void* gpuCallocUncached(size_t bytes); +#endif // defined(__HIP_PLATFORM_AMD__) #if (CUDA_NVLS_SUPPORTED) -template -T* cudaPhysicalCalloc(size_t nelems, size_t gran) { - AvoidCudaGraphCaptureGuard cgcGuard; - int deviceId = -1; - CUdevice currentDevice; - MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId)); - MSCCLPP_CUTHROW(cuDeviceGet(¤tDevice, deviceId)); - - CUmemAllocationProp prop = {}; - prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; - prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; - prop.requestedHandleTypes = - (CUmemAllocationHandleType)(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR | CU_MEM_HANDLE_TYPE_FABRIC); - prop.location.id = currentDevice; - - // allocate physical memory - CUmemGenericAllocationHandle memHandle; - size_t nbytes = (nelems * sizeof(T) + gran - 1) / gran * gran; - MSCCLPP_CUTHROW(cuMemCreate(&memHandle, nbytes, &prop, 0 /*flags*/)); - - T* devicePtr = nullptr; - MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&devicePtr, nbytes, gran, 0U, 0)); - MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)devicePtr, nbytes, 0, memHandle, 0)); - setReadWriteMemoryAccess(devicePtr, nbytes); - CudaStreamWithFlags stream(cudaStreamNonBlocking); - MSCCLPP_CUDATHROW(cudaMemsetAsync(devicePtr, 0, nbytes, stream)); - MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream)); - - return devicePtr; -} -#endif +void* gpuCallocPhysical(size_t bytes, size_t gran = 0, size_t align = 0); +#endif // CUDA_NVLS_SUPPORTED -template -T* cudaExtCalloc(size_t nelem) { - AvoidCudaGraphCaptureGuard cgcGuard; - T* ptr; - CudaStreamWithFlags stream(cudaStreamNonBlocking); -#if defined(__HIP_PLATFORM_AMD__) - MSCCLPP_CUDATHROW(hipExtMallocWithFlags((void**)&ptr, nelem * sizeof(T), hipDeviceMallocUncached)); -#else - MSCCLPP_CUDATHROW(cudaMalloc(&ptr, nelem * sizeof(T))); -#endif - MSCCLPP_CUDATHROW(cudaMemsetAsync(ptr, 0, nelem * sizeof(T), stream)); - MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream)); - return ptr; -} +void gpuFree(void* ptr); +void gpuFreeHost(void* ptr); +#if (CUDA_NVLS_SUPPORTED) +void gpuFreePhysical(void* ptr); +#endif // CUDA_NVLS_SUPPORTED -/// A wrapper of cudaHostAlloc that sets the allocated memory to zero. -/// @tparam T Type of each element in the allocated memory. -/// @param nelem Number of elements to allocate. -/// @return A pointer to the allocated memory. -template -T* cudaHostCalloc(size_t nelem) { - AvoidCudaGraphCaptureGuard cgcGuard; - T* ptr; - MSCCLPP_CUDATHROW(cudaHostAlloc(&ptr, nelem * sizeof(T), cudaHostAllocMapped | cudaHostAllocWriteCombined)); - memset(ptr, 0, nelem * sizeof(T)); - return ptr; -} +void gpuMemcpyAsync(void* dst, const void* src, size_t bytes, cudaStream_t stream, + cudaMemcpyKind kind = cudaMemcpyDefault); +void gpuMemcpy(void* dst, const void* src, size_t bytes, cudaMemcpyKind kind = cudaMemcpyDefault); /// A template function that allocates memory while ensuring that the memory will be freed when the returned object is /// destroyed. /// @tparam T Type of each element in the allocated memory. -/// @tparam alloc A function that allocates memory. 
/// @tparam Deleter A deleter that will be used to free the allocated memory. /// @tparam Memory The type of the returned object. -/// @param nelem Number of elements to allocate. +/// @tparam Alloc A function type that allocates memory. +/// @tparam Args Input types of the @p alloc function variables. +/// @param alloc A function that allocates memory. +/// @param nelems Number of elements to allocate. +/// @param args Extra input variables for the @p alloc function. /// @return An object of type @p Memory that will free the allocated memory when destroyed. /// -template -Memory safeAlloc(size_t nelem) { +template +Memory safeAlloc(Alloc alloc, size_t nelems, Args&&... args) { T* ptr = nullptr; try { - ptr = alloc(nelem); + ptr = reinterpret_cast(alloc(nelems * sizeof(T), std::forward(args)...)); } catch (...) { if (ptr) { Deleter()(ptr); @@ -168,258 +99,164 @@ Memory safeAlloc(size_t nelem) { return Memory(ptr, Deleter()); } -template -Memory safeAlloc(size_t nelem, size_t gran) { - if ((nelem * sizeof(T)) % gran) { - throw Error("The request allocation size is not divisible by the required granularity:" + - std::to_string(nelem * sizeof(T)) + " vs " + std::to_string(gran), - ErrorCode::InvalidUsage); - } - T* ptr = nullptr; - try { - ptr = alloc(nelem, gran); - } catch (...) { - if (ptr) { - Deleter()(ptr); - } - throw; - } - return Memory(ptr, Deleter()); -} - -} // namespace detail +/// A deleter that calls gpuFree for use with std::unique_ptr or std::shared_ptr. +/// @tparam T Type of each element in the allocated memory. +template +struct GpuDeleter { + void operator()(void* ptr) { gpuFree(ptr); } +}; -/// A deleter that calls cudaFree for use with std::unique_ptr or std::shared_ptr. +/// A deleter that calls gpuFreeHost for use with std::unique_ptr or std::shared_ptr. /// @tparam T Type of each element in the allocated memory. -template -struct CudaDeleter { - using TPtrOrArray = std::conditional_t, T, T*>; - void operator()(TPtrOrArray ptr) { - AvoidCudaGraphCaptureGuard cgcGuard; - MSCCLPP_CUDATHROW(cudaFree(ptr)); - } +template +struct GpuHostDeleter { + void operator()(void* ptr) { gpuFreeHost(ptr); } }; -template -struct CudaPhysicalDeleter { - static_assert(!std::is_array_v, "T must not be an array"); - void operator()(T* ptr) { - AvoidCudaGraphCaptureGuard cgcGuard; - CUmemGenericAllocationHandle handle; - size_t size = 0; - MSCCLPP_CUTHROW(cuMemRetainAllocationHandle(&handle, ptr)); - MSCCLPP_CUTHROW(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr)); - MSCCLPP_CUTHROW(cuMemUnmap((CUdeviceptr)ptr, size)); - MSCCLPP_CUTHROW(cuMemRelease(handle)); - MSCCLPP_CUTHROW(cuMemAddressFree((CUdeviceptr)ptr, size)); - } +#if (CUDA_NVLS_SUPPORTED) +template +struct GpuPhysicalDeleter { + void operator()(void* ptr) { gpuFreePhysical(ptr); } }; +#endif // CUDA_NVLS_SUPPORTED -/// A deleter that calls cudaFreeHost for use with std::unique_ptr or std::shared_ptr. -/// @tparam T Type of each element in the allocated memory. template -struct CudaHostDeleter { - using TPtrOrArray = std::conditional_t, T, T*>; - void operator()(TPtrOrArray ptr) { - AvoidCudaGraphCaptureGuard cgcGuard; - MSCCLPP_CUDATHROW(cudaFreeHost(ptr)); - } -}; +using UniqueGpuPtr = std::unique_ptr>; -/// Allocates memory on the device and returns a std::shared_ptr to it. The memory is zeroed out. -/// @tparam T Type of each element in the allocated memory. -/// @param count Number of elements to allocate. -/// @return A std::shared_ptr to the allocated memory. 
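The reworked safeAlloc above takes the raw byte allocator as a runtime argument and forwards any extra parameters to it. Below is a rough illustration of how it composes with the new deleters and pointer aliases; the angle-bracketed template arguments are assumptions, since this rendering of the diff drops them.

    // Not library code: roughly what the typed helpers defined next expand to.
    #include <cstddef>
    #include <mscclpp/gpu_utils.hpp>

    template <class T>
    mscclpp::detail::UniqueGpuPtr<T> makeDeviceArray(std::size_t nelems) {
      // safeAlloc<ElementType, Deleter, OwningPointer>(allocator, nelems, extra args...)
      return mscclpp::detail::safeAlloc<T, mscclpp::detail::GpuDeleter<T>,
                                        mscclpp::detail::UniqueGpuPtr<T>>(
          mscclpp::detail::gpuCalloc, nelems);
    }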
template -std::shared_ptr allocSharedCuda(size_t count = 1) { - return detail::safeAlloc, CudaDeleter, std::shared_ptr>(count); -} +using UniqueGpuHostPtr = std::unique_ptr>; -#if (CUDA_NVLS_SUPPORTED) -static inline size_t getMulticastGranularity(size_t size, CUmulticastGranularity_flags granFlag) { - size_t gran = 0; - int numDevices = 0; - MSCCLPP_CUDATHROW(cudaGetDeviceCount(&numDevices)); - - CUmulticastObjectProp prop = {}; - prop.size = size; - // This is a dummy value, it might affect the granularity in the future - prop.numDevices = numDevices; - prop.handleTypes = (CUmemAllocationHandleType)(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR | CU_MEM_HANDLE_TYPE_FABRIC); - prop.flags = 0; - MSCCLPP_CUTHROW(cuMulticastGetGranularity(&gran, &prop, granFlag)); - return gran; +template +auto gpuCallocShared(size_t nelems = 1) { + return detail::safeAlloc, std::shared_ptr>(detail::gpuCalloc, nelems); } -#endif -/// Allocates physical memory on the device and returns a std::shared_ptr to it. The memory is zeroed out. -/// @tparam T Type of each element in the allocated memory. -/// @param count Number of elements to allocate. -/// @param gran the granularity of the allocation. -/// @return A std::shared_ptr to the allocated memory. template -std::shared_ptr allocSharedPhysicalCuda([[maybe_unused]] size_t count, [[maybe_unused]] size_t gran = 0) { -#if (CUDA_NVLS_SUPPORTED) - if (!isNvlsSupported()) { - throw Error("Only support GPU with NVLS support", ErrorCode::InvalidUsage); - } - if (count == 0) { - return nullptr; - } - - if (gran == 0) { - gran = getMulticastGranularity(count * sizeof(T), CU_MULTICAST_GRANULARITY_RECOMMENDED); - } - size_t nelems = ((count * sizeof(T) + gran - 1) / gran * gran) / sizeof(T); - return detail::safeAlloc, CudaPhysicalDeleter, std::shared_ptr>(nelems, gran); -#else - throw Error("Only support GPU with Fabric support", ErrorCode::InvalidUsage); -#endif +auto gpuCallocUnique(size_t nelems = 1) { + return detail::safeAlloc, UniqueGpuPtr>(detail::gpuCalloc, nelems); } -/// Allocates memory on the device and returns a std::shared_ptr to it. The memory is zeroed out. -/// @tparam T Type of each element in the allocated memory. -/// @param count Number of elements to allocate. -/// @return A std::shared_ptr to the allocated memory. template -std::shared_ptr allocExtSharedCuda(size_t count = 1) { - return detail::safeAlloc, CudaDeleter, std::shared_ptr>(count); +auto gpuCallocHostShared(size_t nelems = 1) { + return detail::safeAlloc, std::shared_ptr>(detail::gpuCallocHost, nelems); } -/// Unique device pointer that will call cudaFree on destruction. -/// @tparam T Type of each element in the allocated memory. template -using UniqueCudaPtr = std::unique_ptr>; +auto gpuCallocHostUnique(size_t nelems = 1) { + return detail::safeAlloc, UniqueGpuHostPtr>(detail::gpuCallocHost, nelems); +} + +#if defined(__HIP_PLATFORM_AMD__) -/// Allocates memory on the device and returns a std::unique_ptr to it. The memory is zeroed out. -/// @tparam T Type of each element in the allocated memory. -/// @param count Number of elements to allocate. -/// @return A std::unique_ptr to the allocated memory. template -UniqueCudaPtr allocUniqueCuda(size_t count = 1) { - return detail::safeAlloc, CudaDeleter, UniqueCudaPtr>(count); +auto gpuCallocUncachedShared(size_t nelems = 1) { + return detail::safeAlloc, std::shared_ptr>(detail::gpuCallocUncached, nelems); } -/// Allocates memory on the device and returns a std::unique_ptr to it. The memory is zeroed out. 
-/// @tparam T Type of each element in the allocated memory. -/// @param count Number of elements to allocate. -/// @return A std::unique_ptr to the allocated memory. template -UniqueCudaPtr allocExtUniqueCuda(size_t count = 1) { - return detail::safeAlloc, CudaDeleter, UniqueCudaPtr>(count); +auto gpuCallocUncachedUnique(size_t nelems = 1) { + return detail::safeAlloc, UniqueGpuPtr>(detail::gpuCallocUncached, nelems); } -/// Allocates memory with cudaHostAlloc, constructs an object of type T in it and returns a std::shared_ptr to it. -/// @tparam T Type of the object to construct. -/// @tparam Args Types of the arguments to pass to the constructor. -/// @param args Arguments to pass to the constructor. -/// @return A std::shared_ptr to the allocated memory. -template -std::shared_ptr makeSharedCudaHost(Args&&... args) { - auto ptr = detail::safeAlloc, CudaHostDeleter, std::shared_ptr>(1); - new (ptr.get()) T(std::forward(args)...); - return ptr; -} +#endif // defined(__HIP_PLATFORM_AMD__) + +#if (CUDA_NVLS_SUPPORTED) -/// Allocates an array of objects of type T with cudaHostAlloc, default constructs each element and returns a -/// std::shared_ptr to it. -/// @tparam T Type of the object to construct. -/// @param count Number of elements to allocate. -/// @return A std::shared_ptr to the allocated memory. template -std::shared_ptr makeSharedCudaHost(size_t count) { - using TElem = std::remove_extent_t; - auto ptr = detail::safeAlloc, CudaHostDeleter, std::shared_ptr>(count); - for (size_t i = 0; i < count; ++i) { - new (&ptr[i]) TElem(); - } - return ptr; +using UniqueGpuPhysicalPtr = std::unique_ptr>; + +template +auto gpuCallocPhysicalShared(size_t nelems = 1, size_t gran = 0, size_t align = 0) { + return detail::safeAlloc, std::shared_ptr>(detail::gpuCallocPhysical, nelems, + gran, align); } -/// Unique CUDA host pointer that will call cudaFreeHost on destruction. -/// @tparam T Type of each element in the allocated memory. template -using UniqueCudaHostPtr = std::unique_ptr>; - -/// Allocates memory with cudaHostAlloc, constructs an object of type T in it and returns a std::unique_ptr to it. -/// @tparam T Type of the object to construct. -/// @tparam Args Types of the arguments to pass to the constructor. -/// @param args Arguments to pass to the constructor. -/// @return A std::unique_ptr to the allocated memory. -template , bool> = true> -UniqueCudaHostPtr makeUniqueCudaHost(Args&&... args) { - auto ptr = detail::safeAlloc, CudaHostDeleter, UniqueCudaHostPtr>(1); - new (ptr.get()) T(std::forward(args)...); - return ptr; +auto gpuCallocPhysicalUnique(size_t nelems = 1, size_t gran = 0, size_t align = 0) { + return detail::safeAlloc, UniqueGpuPhysicalPtr>(detail::gpuCallocPhysical, nelems, + gran, align); } -/// Allocates an array of objects of type T with cudaHostAlloc, default constructs each element and returns a -/// std::unique_ptr to it. -/// @tparam T Type of the object to construct. -/// @param count Number of elements to allocate. -/// @return A std::unique_ptr to the allocated memory. 
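Taken together, the helpers above form a calloc-style family: device memory, mapped host memory, uncached device memory on ROCm, and physically backed memory when NVLS is available. A hedged usage sketch with arbitrary sizes:

    // Usage sketch; the HIP and NVLS branches only compile on those platforms.
    #include <cstdint>
    #include <mscclpp/gpu_utils.hpp>

    void allocateExamples() {
      auto dev = mscclpp::detail::gpuCallocShared<int>(1024);        // zeroed device memory
      auto host = mscclpp::detail::gpuCallocHostUnique<uint64_t>();  // zeroed, mapped host memory
    #if defined(__HIP_PLATFORM_AMD__)
      auto flag = mscclpp::detail::gpuCallocUncachedShared<uint64_t>();  // uncached device memory
    #endif
    #if (CUDA_NVLS_SUPPORTED)
      auto phys = mscclpp::detail::gpuCallocPhysicalShared<char>(1 << 21);  // gran/align default to 0
    #endif
    }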
-template , bool> = true> -UniqueCudaHostPtr makeUniqueCudaHost(size_t count) { - using TElem = std::remove_extent_t; - auto ptr = detail::safeAlloc, CudaHostDeleter, UniqueCudaHostPtr>(count); - for (size_t i = 0; i < count; ++i) { - new (&ptr[i]) TElem(); - } - return ptr; +size_t getMulticastGranularity(size_t size, CUmulticastGranularity_flags granFlag); + +#endif // CUDA_NVLS_SUPPORTED + +} // namespace detail + +template +void gpuMemcpyAsync(T* dst, const T* src, size_t nelems, cudaStream_t stream, cudaMemcpyKind kind = cudaMemcpyDefault) { + detail::gpuMemcpyAsync(dst, src, nelems * sizeof(T), stream, kind); } -/// Allocated physical memory on the device and returns a memory handle along with a virtual memory handle for it. -/// The memory is zeroed out. -/// @tparam T Type of each element in the allocated memory. -/// @param count Number of elements to allocate. -/// @param gran the granularity of the allocation. -/// @return A std::unique_ptr to the allocated memory. -template -std::unique_ptr allocUniquePhysicalCuda([[maybe_unused]] size_t count, [[maybe_unused]] size_t gran = 0) { +template +void gpuMemcpy(T* dst, const T* src, size_t nelems, cudaMemcpyKind kind = cudaMemcpyDefault) { + detail::gpuMemcpy(dst, src, nelems * sizeof(T), kind); +} + +bool isNvlsSupported(); + +/// Allocates a GPU memory space specialized for communication. The memory is zeroed out. Get the device pointer by +/// `GpuBuffer::data()`. +/// +/// Use this function for communication buffers, i.e., only when other devices (CPU, GPU, NIC, etc.) may access this +/// memory space at the same time with the local device (GPU). Running heavy computation over this memory space +/// may perform bad and is not recommended in general. +/// +/// The allocated memory space is managed by the `memory_` object, not by the class instance. Which means, +/// the class destructor will NOT free the allocated memory if `memory_` is shared with and alive in other contexts. +/// +/// @tparam T Type of each element in the allocated memory. Default is `char`. +/// +template +class GpuBuffer { + public: + /// Constructs a GpuBuffer with the specified number of elements. + /// @param nelems Number of elements to allocate. If it is zero, `data()` will return a null pointer. + GpuBuffer(size_t nelems) : nelems_(nelems) { + if (nelems == 0) { + bytes_ = 0; + return; + } #if (CUDA_NVLS_SUPPORTED) - if (!isNvlsSupported()) { - throw Error("Only support GPU with NVLS support", ErrorCode::InvalidUsage); - } - if (count == 0) { - return nullptr; - } + if (isNvlsSupported()) { + size_t gran = detail::getMulticastGranularity(nelems * sizeof(T), CU_MULTICAST_GRANULARITY_RECOMMENDED); + bytes_ = (nelems * sizeof(T) + gran - 1) / gran * gran / sizeof(T) * sizeof(T); + memory_ = detail::gpuCallocPhysicalShared(nelems, gran); + return; + } +#endif // CUDA_NVLS_SUPPORTED - if (gran == 0) { - gran = getMulticastGranularity(count * sizeof(T), CU_MULTICAST_GRANULARITY_RECOMMENDED); + bytes_ = nelems * sizeof(T); +#if defined(__HIP_PLATFORM_AMD__) + memory_ = detail::gpuCallocUncachedShared(nelems); +#else // !defined(__HIP_PLATFORM_AMD__) + memory_ = detail::gpuCallocShared(nelems); +#endif // !defined(__HIP_PLATFORM_AMD__) } - return detail::safeAlloc, CudaPhysicalDeleter, - std::unique_ptr, CudaDeleter>>>(count, gran); -#else - throw Error("Only support GPU with Fabric support", ErrorCode::InvalidUsage); -#endif -} -/// Asynchronous cudaMemcpy without capture into a CUDA graph. -/// @tparam T Type of each element in the allocated memory. 
-/// @param dst Destination pointer. -/// @param src Source pointer. -/// @param count Number of elements to copy. -/// @param stream CUDA stream to use. -/// @param kind Type of cudaMemcpy to perform. -template -void memcpyCudaAsync(T* dst, const T* src, size_t count, cudaStream_t stream, cudaMemcpyKind kind = cudaMemcpyDefault) { - AvoidCudaGraphCaptureGuard cgcGuard; - MSCCLPP_CUDATHROW(cudaMemcpyAsync(dst, src, count * sizeof(T), kind, stream)); -} + /// Returns the number of elements in the allocated memory. + /// @return The number of elements. + size_t nelems() const { return nelems_; } -/// Synchronous cudaMemcpy without capture into a CUDA graph. -/// @tparam T Type of each element in the allocated memory. -/// @param dst Destination pointer. -/// @param src Source pointer. -/// @param count Number of elements to copy. -/// @param kind Type of cudaMemcpy to perform. -template -void memcpyCuda(T* dst, const T* src, size_t count, cudaMemcpyKind kind = cudaMemcpyDefault) { - AvoidCudaGraphCaptureGuard cgcGuard; - CudaStreamWithFlags stream(cudaStreamNonBlocking); - MSCCLPP_CUDATHROW(cudaMemcpyAsync(dst, src, count * sizeof(T), kind, stream)); - MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream)); -} + /// Returns the number of bytes that is actually allocated. This may be larger than `nelems() * sizeof(T)`. + /// @return The number of bytes. + size_t bytes() const { return bytes_; } + + /// Returns the shared pointer to the allocated memory. + /// If `nelems()` is zero, this function will return an empty shared pointer. + /// @return A `std::shared_ptr` to the allocated memory. + std::shared_ptr memory() { return memory_; } + + /// Returns the device pointer to the allocated memory. Equivalent to `memory().get()`. + /// If `nelems()` is zero, this function will return a null pointer. + /// @return A device pointer to the allocated memory. 
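The constructor above rounds the requested size up to the multicast granularity while keeping it a multiple of sizeof(T), which is why bytes() can exceed nelems() * sizeof(T). A worked example with a made-up 512-byte granularity:

    // Illustration of the padding formula used by the GpuBuffer constructor above.
    #include <cstddef>

    constexpr std::size_t paddedBytes(std::size_t nelems, std::size_t elemSize, std::size_t gran) {
      return (nelems * elemSize + gran - 1) / gran * gran / elemSize * elemSize;
    }
    // 1000 floats = 4000 bytes, rounded up to 8 * 512 = 4096 bytes.
    static_assert(paddedBytes(1000, sizeof(float), 512) == 4096, "1000 floats pad to 4096 bytes");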
+ T* data() { return memory_.get(); } + + private: + size_t nelems_; + size_t bytes_; + std::shared_ptr memory_; +}; } // namespace mscclpp diff --git a/include/mscclpp/npkit/npkit.hpp b/include/mscclpp/npkit/npkit.hpp index 24caf3603..803ddd444 100644 --- a/include/mscclpp/npkit/npkit.hpp +++ b/include/mscclpp/npkit/npkit.hpp @@ -83,18 +83,18 @@ class NpKit { // 64K * 2 (send/recv) * (1024/64) = 2M, 2M * 64 * 16B = 2GB per CPU static const uint64_t kMaxNumCpuEventsPerBuffer = 1ULL << 21; - static std::vector> gpu_event_buffers_; + static std::vector> gpu_event_buffers_; static std::vector> cpu_event_buffers_; - static mscclpp::UniqueCudaPtr gpu_collect_contexts_; + static mscclpp::detail::UniqueGpuPtr gpu_collect_contexts_; static std::unique_ptr cpu_collect_contexts_; static uint64_t rank_; #if defined(__HIP_PLATFORM_AMD__) - static mscclpp::UniqueCudaHostPtr cpu_timestamp_; + static mscclpp::detail::UniqueGpuHostPtr cpu_timestamp_; #else - static mscclpp::UniqueCudaHostPtr cpu_timestamp_; + static mscclpp::detail::UniqueGpuHostPtr cpu_timestamp_; #endif static std::unique_ptr cpu_timestamp_update_thread_; static volatile bool cpu_timestamp_update_thread_should_stop_; diff --git a/include/mscclpp/nvls.hpp b/include/mscclpp/nvls.hpp index 36ad614ba..90915cf74 100644 --- a/include/mscclpp/nvls.hpp +++ b/include/mscclpp/nvls.hpp @@ -16,9 +16,6 @@ class NvlsConnection { NvlsConnection() = delete; std::vector serialize(); - // the recommended buffer size for NVLS, returned by cuMulticastGetGranularity - static const int DefaultNvlsBufferSize; - // Everyone needs to synchronize after creating a NVLS connection before adding devices void addDevice(); void addDevice(int cudaDeviceId); @@ -39,10 +36,10 @@ class NvlsConnection { friend class NvlsConnection; }; - /// @brief bind the allocated memory via @ref mscclpp::allocSharedPhysicalCuda to the multicast handle. The behavior - /// is undefined if the devicePtr is not allocated by @ref mscclpp::allocSharedPhysicalCuda. - /// @param devicePtr - /// @param size + /// @brief bind the memory allocated via @ref mscclpp::GpuBuffer to the multicast handle. The behavior + /// is undefined if the devicePtr is not allocated by @ref mscclpp::GpuBuffer. + /// @param devicePtr The device pointer returned by `mscclpp::GpuBuffer::data()`. + /// @param size The bytes of the memory to bind to the multicast handle. /// @return DeviceMulticastPointer with devicePtr, mcPtr and bufferSize DeviceMulticastPointer bindAllocatedMemory(CUdeviceptr devicePtr, size_t size); @@ -65,7 +62,7 @@ class Communicator; /// @param config The configuration for the local endpoint. /// @return std::shared_ptr A shared pointer to the NVLS connection. std::shared_ptr connectNvlsCollective(std::shared_ptr comm, std::vector allRanks, - size_t bufferSize = NvlsConnection::DefaultNvlsBufferSize); + size_t bufferSize); } // namespace mscclpp diff --git a/include/mscclpp/semaphore.hpp b/include/mscclpp/semaphore.hpp index 5f1800990..b28373bdc 100644 --- a/include/mscclpp/semaphore.hpp +++ b/include/mscclpp/semaphore.hpp @@ -64,7 +64,7 @@ class BaseSemaphore { }; /// A semaphore for sending signals from the host to the device. -class Host2DeviceSemaphore : public BaseSemaphore { +class Host2DeviceSemaphore : public BaseSemaphore { private: std::shared_ptr connection_; @@ -117,7 +117,7 @@ class Host2HostSemaphore : public BaseSemaphore { +class SmDevice2DeviceSemaphore : public BaseSemaphore { public: /// Constructor. /// @param communicator The communicator. 
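The updated nvls.hpp comments above tie bindAllocatedMemory to buffers produced by GpuBuffer. A minimal sketch of that contract, assuming an already-established NvlsConnection and an arbitrary buffer size:

    // Sketch only: the connection object, buffer size, and cast are assumptions.
    #include <mscclpp/gpu_utils.hpp>
    #include <mscclpp/nvls.hpp>

    void bindCommBuffer(mscclpp::NvlsConnection& conn) {
      mscclpp::GpuBuffer<char> buf(1 << 20);  // communication buffer, zero-initialized
      auto mcPtr = conn.bindAllocatedMemory(reinterpret_cast<CUdeviceptr>(buf.data()),
                                            buf.bytes());
      (void)mcPtr;
    }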
diff --git a/include/mscclpp/utils.hpp b/include/mscclpp/utils.hpp index 80b3bf39d..c8ef3d271 100644 --- a/include/mscclpp/utils.hpp +++ b/include/mscclpp/utils.hpp @@ -37,8 +37,6 @@ struct ScopedTimer : public Timer { std::string getHostName(int maxlen, const char delim); -bool isNvlsSupported(); - } // namespace mscclpp #endif // MSCCLPP_UTILS_HPP_ diff --git a/python/mscclpp/__init__.py b/python/mscclpp/__init__.py index 1c2567f42..410ad246e 100644 --- a/python/mscclpp/__init__.py +++ b/python/mscclpp/__init__.py @@ -24,9 +24,9 @@ Executor, ExecutionPlan, PacketType, + RawGpuBuffer, version, is_nvls_supported, - alloc_shared_physical_cuda, npkit, ) diff --git a/python/mscclpp/gpu_utils_py.cpp b/python/mscclpp/gpu_utils_py.cpp index 32c578fb7..db57e61cf 100644 --- a/python/mscclpp/gpu_utils_py.cpp +++ b/python/mscclpp/gpu_utils_py.cpp @@ -1,30 +1,18 @@ #include #include -// #include #include #include namespace nb = nanobind; using namespace mscclpp; -class PyCudaMemory { - public: - PyCudaMemory(size_t size) : size_(size) { ptr_ = allocSharedPhysicalCuda(size); } - - uintptr_t getPtr() const { return (uintptr_t)(ptr_.get()); } - size_t size() const { return size_; } - - private: - std::shared_ptr ptr_; - size_t size_; -}; - void register_gpu_utils(nb::module_& m) { - nb::class_(m, "PyCudaMemory") - .def(nb::init(), nb::arg("size")) - .def("get_ptr", &PyCudaMemory::getPtr, "Get the raw pointer") - .def("size", &PyCudaMemory::size, "Get the size of the allocated memory"); - m.def( - "alloc_shared_physical_cuda", [](size_t size) { return std::make_shared(size); }, nb::arg("size")); + m.def("is_nvls_supported", &isNvlsSupported); + + nb::class_>(m, "RawGpuBuffer") + .def(nb::init(), nb::arg("nelems")) + .def("nelems", &GpuBuffer::nelems) + .def("bytes", &GpuBuffer::bytes) + .def("data", [](GpuBuffer& self) { return reinterpret_cast(self.data()); }); } diff --git a/python/mscclpp/nvls_py.cpp b/python/mscclpp/nvls_py.cpp index 91b966bd8..e48587761 100644 --- a/python/mscclpp/nvls_py.cpp +++ b/python/mscclpp/nvls_py.cpp @@ -34,5 +34,5 @@ void register_nvls(nb::module_& m) { .def("get_multicast_min_granularity", &NvlsConnection::getMultiCastMinGranularity); m.def("connect_nvls_collective", &connectNvlsCollective, nb::arg("communicator"), nb::arg("allRanks"), - nb::arg("bufferSize") = NvlsConnection::DefaultNvlsBufferSize); + nb::arg("bufferSize")); } diff --git a/python/mscclpp/utils.py b/python/mscclpp/utils.py index eeac96d12..08e3d4a9b 100644 --- a/python/mscclpp/utils.py +++ b/python/mscclpp/utils.py @@ -6,10 +6,11 @@ import struct import subprocess import tempfile -from typing import Any, Type +from typing import Any, Type, Union, Tuple import cupy as cp import numpy as np +from ._mscclpp import RawGpuBuffer try: import torch @@ -36,7 +37,7 @@ def launch_kernel( nblocks: int, nthreads: int, shared: int, - stream: Type[cp.cuda.Stream] or Type[None], + stream: Union[cp.cuda.Stream, None], ): buffer = (ctypes.c_byte * len(params)).from_buffer_copy(params) buffer_size = ctypes.c_size_t(len(params)) @@ -137,6 +138,25 @@ def __del__(self): self._tempdir.cleanup() +class GpuBuffer(cp.ndarray): + def __new__( + cls, shape: Union[int, Tuple[int]], dtype: cp.dtype = float, strides: Tuple[int] = None, order: str = "C" + ): + # Check if `shape` is valid + if isinstance(shape, int): + shape = (shape,) + try: + shape = tuple(shape) + except TypeError: + raise ValueError("Shape must be a tuple-like or an integer.") + if any(s <= 0 for s in shape): + raise ValueError("Shape must be positive.") + # Create 
the buffer + buffer = RawGpuBuffer(np.prod(shape) * np.dtype(dtype).itemsize) + memptr = cp.cuda.MemoryPointer(cp.cuda.UnownedMemory(buffer.data(), buffer.bytes(), buffer), 0) + return cp.ndarray(shape, dtype=dtype, strides=strides, order=order, memptr=memptr) + + def pack(*args): res = b"" for arg in list(args): diff --git a/python/mscclpp/utils_py.cpp b/python/mscclpp/utils_py.cpp index e9e847ee8..16800a752 100644 --- a/python/mscclpp/utils_py.cpp +++ b/python/mscclpp/utils_py.cpp @@ -20,5 +20,4 @@ void register_utils(nb::module_& m) { nb::class_(m, "ScopedTimer").def(nb::init(), nb::arg("name")); m.def("get_host_name", &getHostName, nb::arg("maxlen"), nb::arg("delim")); - m.def("is_nvls_supported", &isNvlsSupported); } diff --git a/python/mscclpp_benchmark/allreduce_bench.py b/python/mscclpp_benchmark/allreduce_bench.py index e93c0479e..cf679a887 100644 --- a/python/mscclpp_benchmark/allreduce_bench.py +++ b/python/mscclpp_benchmark/allreduce_bench.py @@ -15,6 +15,7 @@ import cupy.cuda.nccl as nccl import mscclpp.comm as mscclpp_comm from mscclpp import ProxyService, is_nvls_supported +from mscclpp.utils import GpuBuffer from prettytable import PrettyTable import netifaces as ni import ipaddress @@ -162,8 +163,8 @@ def find_best_config(mscclpp_call, niter): def run_benchmark( mscclpp_group: mscclpp_comm.CommGroup, nccl_op: nccl.NcclCommunicator, table: PrettyTable, niter: int, nelem: int ): - memory = cp.zeros(nelem, dtype=data_type) - memory_out = cp.zeros(nelem, dtype=data_type) + memory = GpuBuffer(nelem, dtype=data_type) + memory_out = GpuBuffer(nelem, dtype=data_type) cp.cuda.runtime.deviceSynchronize() proxy_service = ProxyService() diff --git a/python/mscclpp_benchmark/mscclpp_op.py b/python/mscclpp_benchmark/mscclpp_op.py index 88840a743..c2af7a4fc 100644 --- a/python/mscclpp_benchmark/mscclpp_op.py +++ b/python/mscclpp_benchmark/mscclpp_op.py @@ -1,9 +1,9 @@ import os import cupy as cp import ctypes -from mscclpp import Transport, ProxyService, SmDevice2DeviceSemaphore, alloc_shared_physical_cuda +from mscclpp import Transport, ProxyService, SmDevice2DeviceSemaphore import mscclpp.comm as mscclpp_comm -from mscclpp.utils import KernelBuilder, pack +from mscclpp.utils import KernelBuilder, GpuBuffer, pack IB_TRANSPORTS = [ @@ -115,7 +115,7 @@ def __init__( self.connections = self.group.make_connection(remote_nghrs, Transport.CudaIpc) type_str = type_to_str(memory.dtype) - self.scratch = cp.zeros(self.memory.size * 8, dtype=self.memory.dtype) + self.scratch = GpuBuffer(self.memory.size * 8, dtype=self.memory.dtype) # create a sm_channel for each remote neighbor self.sm_channels = self.group.make_sm_channels_with_scratch(self.memory, self.scratch, self.connections) file_dir = os.path.dirname(os.path.abspath(__file__)) @@ -179,7 +179,7 @@ def __init__( type_str = type_to_str(memory.dtype) self.proxy_service = proxy_service - self.scratch = cp.zeros(self.memory.size, dtype=self.memory.dtype) + self.scratch = GpuBuffer(self.memory.size, dtype=self.memory.dtype) # create a sm_channel for each remote neighbor self.fst_round_proxy_chans = self.group.make_proxy_channels_with_scratch( @@ -259,7 +259,7 @@ def __init__( type_str = type_to_str(memory.dtype) self.proxy_service = proxy_service - self.scratch = cp.zeros(self.memory.size, dtype=self.memory.dtype) + self.scratch = GpuBuffer(self.memory.size, dtype=self.memory.dtype) same_node_connections = {rank: conn for rank, conn in self.connections.items() if in_same_node(rank)} # create a sm_channel for each remote neighbor self.sm_channels 
= self.group.make_sm_channels(self.memory, same_node_connections) @@ -362,8 +362,8 @@ def __init__( type_str = type_to_str(memory.dtype) self.proxy_service = proxy_service - self.scratch = cp.zeros(self.memory.size * 8, dtype=self.memory.dtype) - self.put_buff = cp.zeros(self.memory.size * 8 // nranks_per_node, dtype=self.memory.dtype) + self.scratch = GpuBuffer(self.memory.size * 8, dtype=self.memory.dtype) + self.put_buff = GpuBuffer(self.memory.size * 8 // nranks_per_node, dtype=self.memory.dtype) same_node_connections = {rank: conn for rank, conn in self.connections.items() if in_same_node(rank)} across_node_connections = {rank: conn for rank, conn in self.connections.items() if not in_same_node(rank)} # create a sm_channel for each remote neighbor @@ -441,18 +441,10 @@ def __init__( # create a connection for each remote neighbor self.nvlink_connections = self.group.make_connection(remote_nghrs, Transport.CudaIpc) self.nvls_connection = group.make_connection(all_ranks, Transport.Nvls) - min_gran = self.nvls_connection.get_multicast_min_granularity() - aligned_buffer_size = int(((buffer_size + min_gran - 1) // min_gran) * min_gran) - buffer_raw = alloc_shared_physical_cuda(aligned_buffer_size) + self.memory = GpuBuffer(nelem, memory_dtype) self.nvls_mem_handle = self.nvls_connection.bind_allocated_memory( - buffer_raw.get_ptr(), aligned_buffer_size - ) # just using recommended size for now - self.memory_ptr = self.nvls_mem_handle.get_device_ptr() - - self.cp_memory_ptr = cp.cuda.MemoryPointer( - cp.cuda.UnownedMemory(self.memory_ptr, aligned_buffer_size, buffer_raw), 0 + self.memory.data.ptr, self.memory.data.mem.size ) - self.memory = cp.ndarray(nelem, memory_dtype, self.cp_memory_ptr) # create a sm_channel for each remote neighbor self.semaphores = group.make_semaphore(self.nvlink_connections, SmDevice2DeviceSemaphore) diff --git a/python/test/executor_test.py b/python/test/executor_test.py index 67e9929f1..c973ae0e9 100644 --- a/python/test/executor_test.py +++ b/python/test/executor_test.py @@ -8,11 +8,9 @@ ExecutionPlan, PacketType, npkit, - alloc_shared_physical_cuda, - is_nvls_supported, ) import mscclpp.comm as mscclpp_comm -from mscclpp.utils import KernelBuilder, pack +from mscclpp.utils import KernelBuilder, GpuBuffer, pack import os import struct @@ -129,18 +127,6 @@ def dtype_to_mscclpp_dtype(dtype): raise ValueError(f"Unknown data type: {dtype}") -def allocate_buffer(nelems, dtype): - if is_nvls_supported(): - buffer_raw = alloc_shared_physical_cuda(nelems * cp.dtype(dtype).itemsize) - buffer_ptr = cp.cuda.MemoryPointer( - cp.cuda.UnownedMemory(buffer_raw.get_ptr(), buffer_raw.size(), buffer_raw), 0 - ) - buffer = cp.ndarray(nelems, dtype=dtype, memptr=buffer_ptr) - return buffer - else: - return cp.zeros(nelems, dtype=dtype) - - def build_bufs( collective: str, size: int, @@ -160,14 +146,14 @@ def build_bufs( nelems_input = nelems nelems_output = nelems - result_buf = allocate_buffer(nelems_output, dtype=dtype) + result_buf = GpuBuffer(nelems_output, dtype=dtype) if in_place: if "allgather" in collective: input_buf = cp.split(result_buf, num_ranks)[rank] else: input_buf = result_buf else: - input_buf = allocate_buffer(nelems_input, dtype=dtype) + input_buf = GpuBuffer(nelems_input, dtype=dtype) test_buf = cp.zeros(nelems_output, dtype=dtype) return input_buf, result_buf, test_buf diff --git a/python/test/test_mscclpp.py b/python/test/test_mscclpp.py index 929d975c6..976d74362 100644 --- a/python/test/test_mscclpp.py +++ b/python/test/test_mscclpp.py @@ -27,7 +27,7 @@ 
npkit, ) import mscclpp.comm as mscclpp_comm -from mscclpp.utils import KernelBuilder, pack +from mscclpp.utils import KernelBuilder, GpuBuffer, pack from ._cpp import _ext from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group @@ -156,12 +156,26 @@ def test_group_with_connections(mpi_group: MpiGroup, transport: str): create_group_and_connection(mpi_group, transport) +@parametrize_mpi_groups(1) +@pytest.mark.parametrize("nelem", [2**i for i in [0, 10, 15, 20]]) +@pytest.mark.parametrize("dtype", [cp.float32, cp.float16]) +def test_gpu_buffer(mpi_group: MpiGroup, nelem: int, dtype: cp.dtype): + memory = GpuBuffer(nelem, dtype=dtype) + assert memory.shape == (nelem,) + assert memory.dtype == dtype + assert memory.itemsize == cp.dtype(dtype).itemsize + assert memory.nbytes == nelem * cp.dtype(dtype).itemsize + assert memory.data.ptr != 0 + assert memory.data.mem.ptr != 0 + assert memory.data.mem.size >= nelem * cp.dtype(dtype).itemsize + + @parametrize_mpi_groups(2, 4, 8, 16) @pytest.mark.parametrize("transport", ["IB", "NVLink"]) @pytest.mark.parametrize("nelem", [2**i for i in [10, 15, 20]]) def test_connection_write(mpi_group: MpiGroup, transport: Transport, nelem: int): group, connections = create_group_and_connection(mpi_group, transport) - memory = cp.zeros(nelem, dtype=cp.int32) + memory = GpuBuffer(nelem, dtype=cp.int32) nelemPerRank = nelem // group.nranks sizePerRank = nelemPerRank * memory.itemsize memory[(nelemPerRank * group.my_rank) : (nelemPerRank * (group.my_rank + 1))] = group.my_rank + 1 @@ -436,13 +450,12 @@ def test_d2d_semaphores(mpi_group: MpiGroup): def test_sm_channels(mpi_group: MpiGroup, nelem: int, use_packet: bool): group, connections = create_group_and_connection(mpi_group, "NVLink") - memory = cp.zeros(nelem, dtype=cp.int32) + memory = GpuBuffer(nelem, dtype=cp.int32) if use_packet: - scratch = cp.zeros(nelem * 2, dtype=cp.int32) + scratch = GpuBuffer(nelem * 2, dtype=cp.int32) else: scratch = None nelemPerRank = nelem // group.nranks - nelemPerRank * memory.itemsize memory[(nelemPerRank * group.my_rank) : (nelemPerRank * (group.my_rank + 1))] = group.my_rank + 1 memory_expected = cp.zeros_like(memory) for rank in range(group.nranks): @@ -484,7 +497,7 @@ def test_fifo( def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str): group, connections = create_group_and_connection(mpi_group, transport) - memory = cp.zeros(nelem, dtype=cp.int32) + memory = GpuBuffer(nelem, dtype=cp.int32) nelemPerRank = nelem // group.nranks nelemPerRank * memory.itemsize memory[(nelemPerRank * group.my_rank) : (nelemPerRank * (group.my_rank + 1))] = group.my_rank + 1 @@ -534,11 +547,11 @@ def test_proxy(mpi_group: MpiGroup, nelem: int, transport: str): def test_proxy_channel(mpi_group: MpiGroup, nelem: int, transport: str, use_packet: bool): group, connections = create_group_and_connection(mpi_group, transport) - memory = cp.zeros(nelem, dtype=cp.int32) + memory = GpuBuffer(nelem, dtype=cp.int32) if use_packet: - scratch = cp.zeros(nelem * 2, dtype=cp.int32) + scratch = GpuBuffer(nelem * 2, dtype=cp.int32) else: - scratch = cp.zeros(1, dtype=cp.int32) # just so that we can pass a valid ptr + scratch = GpuBuffer(1, dtype=cp.int32) # just so that we can pass a valid ptr nelemPerRank = nelem // group.nranks nelemPerRank * memory.itemsize memory[(nelemPerRank * group.my_rank) : (nelemPerRank * (group.my_rank + 1))] = group.my_rank + 1 diff --git a/src/connection.cc b/src/connection.cc index 6a5b554d5..0781afe25 100644 --- a/src/connection.cc +++ b/src/connection.cc @@ 
-301,8 +301,8 @@ void EthernetConnection::write(RegisteredMemory dst, uint64_t dstOffset, Registe uint64_t dataSize = std::min(sendBufferSize_ - headerSize / sizeof(char), (size - sentDataSize) / sizeof(char)) * sizeof(char); uint64_t messageSize = dataSize + headerSize; - mscclpp::memcpyCuda(sendBuffer_.data() + headerSize / sizeof(char), - (char*)srcPtr + (sentDataSize / sizeof(char)), dataSize, cudaMemcpyDeviceToHost); + mscclpp::gpuMemcpy(sendBuffer_.data() + headerSize / sizeof(char), srcPtr + (sentDataSize / sizeof(char)), dataSize, + cudaMemcpyDeviceToHost); sendSocket_->send(sendBuffer_.data(), messageSize); sentDataSize += messageSize; headerSize = 0; @@ -402,8 +402,7 @@ void EthernetConnection::recvMessages() { received &= !closed; if (received) - mscclpp::memcpyCuda((char*)ptr + (recvSize / sizeof(char)), recvBuffer_.data(), messageSize, - cudaMemcpyHostToDevice); + mscclpp::gpuMemcpy(ptr + (recvSize / sizeof(char)), recvBuffer_.data(), messageSize, cudaMemcpyHostToDevice); recvSize += messageSize; } diff --git a/src/executor/executor.cc b/src/executor/executor.cc index d2e5ac7e2..944ddb254 100644 --- a/src/executor/executor.cc +++ b/src/executor/executor.cc @@ -155,10 +155,10 @@ struct Executor::Impl { plan.impl_->lightLoadExecutionPlan(inputMessageSize, outputMessageSize, constSrcOffset, constDstOffset); this->setupDeviceExecutionPlan(this->contexts[key], devicePlanKey, rank, plan); this->contexts[key].deviceExecutionPlansBuffers[devicePlanKey] = - allocExtSharedCuda(devicePlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan)); - memcpyCuda(this->contexts[key].deviceExecutionPlansBuffers[devicePlanKey].get(), - (char*)devicePlans[devicePlanKey].data(), - devicePlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan), cudaMemcpyHostToDevice); + GpuBuffer(devicePlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan)).memory(); + gpuMemcpy(this->contexts[key].deviceExecutionPlansBuffers[devicePlanKey].get(), + (char*)devicePlans[devicePlanKey].data(), + devicePlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan), cudaMemcpyHostToDevice); this->contexts[key].currentDevicePlan = devicePlanKey; return this->contexts[key]; } @@ -170,12 +170,7 @@ struct Executor::Impl { size_t maxScratchBufferSize = plan.impl_->getMaxScratchBufferSize(rank); size_t scratchBufferSize = std::min(plan.impl_->getScratchBufferSize(rank, sendMemRange, recvMemRange), maxScratchBufferSize); - std::shared_ptr scratchBuffer; - if (isNvlsSupported()) { - scratchBuffer = allocSharedPhysicalCuda(scratchBufferSize); - } else { - scratchBuffer = allocExtSharedCuda(scratchBufferSize); - } + std::shared_ptr scratchBuffer = GpuBuffer(scratchBufferSize).memory(); context.scratchBuffer = scratchBuffer; context.scratchBufferSize = scratchBufferSize; context.proxyService = std::make_shared(); @@ -186,11 +181,10 @@ struct Executor::Impl { this->setupNvlsChannels(context, sendbuff, recvbuff, sendMemRange, recvMemRange, rank, plan); this->setupDeviceExecutionPlan(context, devicePlanKey, rank, plan); context.deviceExecutionPlansBuffers[devicePlanKey] = - allocExtSharedCuda(context.deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan)); - memcpyCuda(context.deviceExecutionPlansBuffers[devicePlanKey].get(), - (char*)context.deviceExecutionPlans[devicePlanKey].data(), - context.deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan), - cudaMemcpyHostToDevice); + GpuBuffer(context.deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan)).memory(); + 
gpuMemcpy(context.deviceExecutionPlansBuffers[devicePlanKey].get(), + (char*)context.deviceExecutionPlans[devicePlanKey].data(), + context.deviceExecutionPlans[devicePlanKey].size() * sizeof(DeviceExecutionPlan), cudaMemcpyHostToDevice); context.currentDevicePlan = devicePlanKey; context.proxyService->startProxy(); this->contexts.insert({key, context}); diff --git a/src/fifo.cc b/src/fifo.cc index 592bf7d00..43e73ca92 100644 --- a/src/fifo.cc +++ b/src/fifo.cc @@ -10,9 +10,9 @@ namespace mscclpp { struct Fifo::Impl { - UniqueCudaHostPtr triggers; - UniqueCudaPtr head; - UniqueCudaPtr tailReplica; + detail::UniqueGpuHostPtr triggers; + detail::UniqueGpuPtr head; + detail::UniqueGpuPtr tailReplica; const int size; // allocated on the host. Only accessed by the host. This is a copy of the @@ -28,9 +28,9 @@ struct Fifo::Impl { CudaStreamWithFlags stream; Impl(int size) - : triggers(makeUniqueCudaHost(size)), - head(allocUniqueCuda()), - tailReplica(allocUniqueCuda()), + : triggers(detail::gpuCallocHostUnique(size)), + head(detail::gpuCallocUnique()), + tailReplica(detail::gpuCallocUnique()), size(size), hostTail(0), stream(cudaStreamNonBlocking) {} diff --git a/src/gpu_utils.cc b/src/gpu_utils.cc new file mode 100644 index 000000000..c70cdcfa1 --- /dev/null +++ b/src/gpu_utils.cc @@ -0,0 +1,165 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#include +#include +#include + +namespace mscclpp { + +namespace detail { + +/// set memory access permission to read-write +/// @param base Base memory pointer. +/// @param size Size of the memory. +void setReadWriteMemoryAccess(void* base, size_t size) { + CUmemAccessDesc accessDesc = {}; + int deviceId; + MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId)); + accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + accessDesc.location.id = deviceId; + accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE; + MSCCLPP_CUTHROW(cuMemSetAccess((CUdeviceptr)base, size, &accessDesc, 1)); +} + +void* gpuCalloc(size_t bytes) { + AvoidCudaGraphCaptureGuard cgcGuard; + void* ptr; + CudaStreamWithFlags stream(cudaStreamNonBlocking); + MSCCLPP_CUDATHROW(cudaMalloc(&ptr, bytes)); + MSCCLPP_CUDATHROW(cudaMemsetAsync(ptr, 0, bytes, stream)); + MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream)); + return ptr; +} + +void* gpuCallocHost(size_t bytes) { + AvoidCudaGraphCaptureGuard cgcGuard; + void* ptr; + MSCCLPP_CUDATHROW(cudaHostAlloc(&ptr, bytes, cudaHostAllocMapped | cudaHostAllocWriteCombined)); + ::memset(ptr, 0, bytes); + return ptr; +} + +#if defined(__HIP_PLATFORM_AMD__) +void* gpuCallocUncached(size_t bytes) { + AvoidCudaGraphCaptureGuard cgcGuard; + void* ptr; + CudaStreamWithFlags stream(cudaStreamNonBlocking); + MSCCLPP_CUDATHROW(hipExtMallocWithFlags((void**)&ptr, bytes, hipDeviceMallocUncached)); + MSCCLPP_CUDATHROW(cudaMemsetAsync(ptr, 0, bytes, stream)); + MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream)); + return ptr; +} +#endif // defined(__HIP_PLATFORM_AMD__) + +#if (CUDA_NVLS_SUPPORTED) +size_t getMulticastGranularity(size_t size, CUmulticastGranularity_flags granFlag) { + size_t gran = 0; + int numDevices = 0; + MSCCLPP_CUDATHROW(cudaGetDeviceCount(&numDevices)); + + CUmulticastObjectProp prop = {}; + prop.size = size; + // This is a dummy value, it might affect the granularity in the future + prop.numDevices = numDevices; + prop.handleTypes = (CUmemAllocationHandleType)(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR | CU_MEM_HANDLE_TYPE_FABRIC); + prop.flags = 0; + MSCCLPP_CUTHROW(cuMulticastGetGranularity(&gran, &prop, 
granFlag)); + return gran; +} + +void* gpuCallocPhysical(size_t bytes, size_t gran, size_t align) { + AvoidCudaGraphCaptureGuard cgcGuard; + int deviceId = -1; + CUdevice currentDevice; + MSCCLPP_CUDATHROW(cudaGetDevice(&deviceId)); + MSCCLPP_CUTHROW(cuDeviceGet(¤tDevice, deviceId)); + + CUmemAllocationProp prop = {}; + prop.type = CU_MEM_ALLOCATION_TYPE_PINNED; + prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE; + prop.requestedHandleTypes = + (CUmemAllocationHandleType)(CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR | CU_MEM_HANDLE_TYPE_FABRIC); + prop.location.id = currentDevice; + + if (gran == 0) { + gran = getMulticastGranularity(bytes, CU_MULTICAST_GRANULARITY_RECOMMENDED); + } + + // allocate physical memory + CUmemGenericAllocationHandle memHandle; + size_t nbytes = (bytes + gran - 1) / gran * gran; + MSCCLPP_CUTHROW(cuMemCreate(&memHandle, nbytes, &prop, 0 /*flags*/)); + + if (align == 0) { + align = getMulticastGranularity(nbytes, CU_MULTICAST_GRANULARITY_MINIMUM); + } + + void* devicePtr = nullptr; + MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&devicePtr, nbytes, align, 0U, 0)); + MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)devicePtr, nbytes, 0, memHandle, 0)); + setReadWriteMemoryAccess(devicePtr, nbytes); + CudaStreamWithFlags stream(cudaStreamNonBlocking); + MSCCLPP_CUDATHROW(cudaMemsetAsync(devicePtr, 0, nbytes, stream)); + MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream)); + + return devicePtr; +} +#endif // CUDA_NVLS_SUPPORTED + +void gpuFree(void* ptr) { + AvoidCudaGraphCaptureGuard cgcGuard; + MSCCLPP_CUDATHROW(cudaFree(ptr)); +} + +void gpuFreeHost(void* ptr) { + AvoidCudaGraphCaptureGuard cgcGuard; + MSCCLPP_CUDATHROW(cudaFreeHost(ptr)); +} + +#if (CUDA_NVLS_SUPPORTED) +void gpuFreePhysical(void* ptr) { + AvoidCudaGraphCaptureGuard cgcGuard; + CUmemGenericAllocationHandle handle; + size_t size = 0; + MSCCLPP_CUTHROW(cuMemRetainAllocationHandle(&handle, ptr)); + MSCCLPP_CUTHROW(cuMemGetAddressRange(NULL, &size, (CUdeviceptr)ptr)); + MSCCLPP_CUTHROW(cuMemUnmap((CUdeviceptr)ptr, size)); + MSCCLPP_CUTHROW(cuMemRelease(handle)); + MSCCLPP_CUTHROW(cuMemAddressFree((CUdeviceptr)ptr, size)); +} +#endif // CUDA_NVLS_SUPPORTED + +void gpuMemcpyAsync(void* dst, const void* src, size_t bytes, cudaStream_t stream, cudaMemcpyKind kind) { + AvoidCudaGraphCaptureGuard cgcGuard; + MSCCLPP_CUDATHROW(cudaMemcpyAsync(dst, src, bytes, kind, stream)); +} + +void gpuMemcpy(void* dst, const void* src, size_t bytes, cudaMemcpyKind kind) { + AvoidCudaGraphCaptureGuard cgcGuard; + CudaStreamWithFlags stream(cudaStreamNonBlocking); + MSCCLPP_CUDATHROW(cudaMemcpyAsync(dst, src, bytes, kind, stream)); + MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream)); +} + +} // namespace detail + +bool isNvlsSupported() { + [[maybe_unused]] static bool result = false; + [[maybe_unused]] static bool isChecked = false; +#if (CUDA_NVLS_SUPPORTED) + if (!isChecked) { + int isMulticastSupported; + int isFabricSupported; + CUdevice dev; + MSCCLPP_CUTHROW(cuCtxGetDevice(&dev)); + MSCCLPP_CUTHROW(cuDeviceGetAttribute(&isMulticastSupported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, dev)); + MSCCLPP_CUTHROW(cuDeviceGetAttribute(&isFabricSupported, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, dev)); + result = (isMulticastSupported == 1 && isFabricSupported == 1); + } + return result; +#endif + return false; +} + +} // namespace mscclpp diff --git a/src/npkit/npkit.cc b/src/npkit/npkit.cc index 77db6b90c..30fc35c75 100644 --- a/src/npkit/npkit.cc +++ b/src/npkit/npkit.cc @@ -12,16 +12,16 @@ uint64_t NpKit::rank_ = 0; 
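In gpuCallocPhysical above, a zero granularity or alignment falls back to the recommended and minimum multicast granularities queried from the driver. Below is a hedged sketch of passing an explicit granularity through gpuCallocPhysicalShared; the 2 MiB value is an assumption, not a recommendation.

    #include <cstddef>
    #include <memory>
    #include <mscclpp/gpu_utils.hpp>

    #if (CUDA_NVLS_SUPPORTED)
    std::shared_ptr<char> allocatePhysical(std::size_t nelems) {
      constexpr std::size_t kGran = 2 * 1024 * 1024;  // assumed granularity for illustration
      return mscclpp::detail::gpuCallocPhysicalShared<char>(nelems, kGran);
    }
    #endif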
-std::vector> NpKit::gpu_event_buffers_; +std::vector> NpKit::gpu_event_buffers_; std::vector> NpKit::cpu_event_buffers_; -mscclpp::UniqueCudaPtr NpKit::gpu_collect_contexts_; +mscclpp::detail::UniqueGpuPtr NpKit::gpu_collect_contexts_; std::unique_ptr NpKit::cpu_collect_contexts_; #if defined(__HIP_PLATFORM_AMD__) -mscclpp::UniqueCudaHostPtr NpKit::cpu_timestamp_; +mscclpp::detail::UniqueGpuHostPtr NpKit::cpu_timestamp_; #else -mscclpp::UniqueCudaHostPtr NpKit::cpu_timestamp_; +mscclpp::detail::UniqueGpuHostPtr NpKit::cpu_timestamp_; #endif std::unique_ptr NpKit::cpu_timestamp_update_thread_; volatile bool NpKit::cpu_timestamp_update_thread_should_stop_ = false; @@ -53,11 +53,11 @@ void NpKit::Init(int rank) { rank_ = rank; // Init event data structures - gpu_collect_contexts_ = mscclpp::allocUniqueCuda(NpKit::kNumGpuEventBuffers); + gpu_collect_contexts_ = mscclpp::detail::gpuCallocUnique(NpKit::kNumGpuEventBuffers); for (i = 0; i < NpKit::kNumGpuEventBuffers; i++) { - gpu_event_buffers_.emplace_back(mscclpp::allocUniqueCuda(kMaxNumGpuEventsPerBuffer)); + gpu_event_buffers_.emplace_back(mscclpp::detail::gpuCallocUnique(kMaxNumGpuEventsPerBuffer)); ctx.event_buffer = gpu_event_buffers_[i].get(); - mscclpp::memcpyCuda(gpu_collect_contexts_.get() + i, &ctx, 1); + mscclpp::gpuMemcpy(gpu_collect_contexts_.get() + i, &ctx, 1); } cpu_collect_contexts_ = std::make_unique(NpKit::kNumCpuEventBuffers); @@ -69,15 +69,15 @@ void NpKit::Init(int rank) { #if defined(__HIP_PLATFORM_AMD__) // Init timestamp. Allocates MAXCHANNELS*128 bytes buffer for GPU - cpu_timestamp_ = mscclpp::makeUniqueCudaHost(NPKIT_MAX_NUM_GPU_THREADBLOCKS * - NPKIT_CPU_TIMESTAMP_SLOT_SIZE / sizeof(uint64_t)); + cpu_timestamp_ = mscclpp::detail::gpuCallocHostUnique(NPKIT_MAX_NUM_GPU_THREADBLOCKS * + NPKIT_CPU_TIMESTAMP_SLOT_SIZE / sizeof(uint64_t)); for (int i = 0; i < NPKIT_MAX_NUM_GPU_THREADBLOCKS; i++) { NPKIT_STORE_CPU_TIMESTAMP_PER_BLOCK(cpu_timestamp_.get(), std::chrono::system_clock::now().time_since_epoch().count(), i); } #else // Init timestamp - cpu_timestamp_ = mscclpp::makeUniqueCudaHost(); + cpu_timestamp_ = mscclpp::detail::gpuCallocHostUnique(); volatile uint64_t* volatile_cpu_timestamp = cpu_timestamp_.get(); *volatile_cpu_timestamp = std::chrono::system_clock::now().time_since_epoch().count(); #endif @@ -153,8 +153,8 @@ void NpKit::Dump(const std::string& dump_dir) { dump_file_path += std::to_string(rank_); dump_file_path += "_buf_"; dump_file_path += std::to_string(i); - mscclpp::memcpyCuda(cpu_event_buffers_[0].get(), gpu_event_buffers_[i].get(), kMaxNumGpuEventsPerBuffer); - mscclpp::memcpyCuda(cpu_collect_contexts_.get(), gpu_collect_contexts_.get() + i, 1); + mscclpp::gpuMemcpy(cpu_event_buffers_[0].get(), gpu_event_buffers_[i].get(), kMaxNumGpuEventsPerBuffer); + mscclpp::gpuMemcpy(cpu_collect_contexts_.get(), gpu_collect_contexts_.get() + i, 1); auto gpu_trace_file = std::fstream(dump_file_path, std::ios::out | std::ios::binary); gpu_trace_file.write(reinterpret_cast(cpu_event_buffers_[0].get()), cpu_collect_contexts_[0].event_buffer_head * sizeof(NpKitEvent)); diff --git a/src/nvls.cc b/src/nvls.cc index 3221e6e00..07620fa9f 100644 --- a/src/nvls.cc +++ b/src/nvls.cc @@ -58,7 +58,7 @@ NvlsConnection::Impl::Impl(size_t bufferSize, int numDevices) { mcProp_.handleTypes = CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR; MSCCLPP_CUTHROW(cuMulticastGetGranularity(&minMcGran_, &mcProp_, CU_MULTICAST_GRANULARITY_MINIMUM)); MSCCLPP_CUTHROW(cuMulticastGetGranularity(&mcGran_, &mcProp_, 
CU_MULTICAST_GRANULARITY_RECOMMENDED)); - mcProp_.size = ((mcProp_.size + minMcGran_ - 1) / minMcGran_) * minMcGran_; + mcProp_.size = ((mcProp_.size + mcGran_ - 1) / mcGran_) * mcGran_; bufferSize_ = mcProp_.size; MSCCLPP_CUTHROW(cuMulticastCreate(&mcHandle_, &mcProp_)); mcFileDesc_ = 0; @@ -200,7 +200,7 @@ std::shared_ptr NvlsConnection::Impl::bindMemory(CUdeviceptr devicePtr, si char* mcPtr; MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)(&mcPtr), devBuffSize, minMcGran_, 0U, 0)); MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)(mcPtr), devBuffSize, 0, mcHandle_, 0)); - setReadWriteMemoryAccess(mcPtr, devBuffSize); + detail::setReadWriteMemoryAccess(mcPtr, devBuffSize); INFO(MSCCLPP_COLL, "NVLS connection bound memory at offset %ld, size %ld", offset, devBuffSize); auto deleter = [=, self = shared_from_this()](char* ptr) { @@ -240,8 +240,6 @@ class NvlsConnection::Impl { }; #endif // !(CUDA_NVLS_SUPPORTED) -const int NvlsConnection::DefaultNvlsBufferSize = (1 << 29); - NvlsConnection::NvlsConnection(size_t bufferSize, int numDevices) : pimpl_(std::make_shared(bufferSize, numDevices)) {} diff --git a/src/registered_memory.cc b/src/registered_memory.cc index bf40470de..84b0ccc47 100644 --- a/src/registered_memory.cc +++ b/src/registered_memory.cc @@ -244,7 +244,7 @@ RegisteredMemory::Impl::Impl(const std::vector& serialization) { size_t gran = getRecommendedGranularity(); MSCCLPP_CUTHROW(cuMemAddressReserve((CUdeviceptr*)&base, this->size, gran, 0, 0)); MSCCLPP_CUTHROW(cuMemMap((CUdeviceptr)base, this->size, 0, handle, 0)); - setReadWriteMemoryAccess(base, this->size); + detail::setReadWriteMemoryAccess(base, this->size); this->data = static_cast(base) + entry.offsetFromBase; } else { MSCCLPP_CUDATHROW(cudaIpcOpenMemHandle(&base, entry.cudaIpcBaseHandle, cudaIpcMemLazyEnablePeerAccess)); diff --git a/src/semaphore.cc b/src/semaphore.cc index 348f1cdb1..c6238b532 100644 --- a/src/semaphore.cc +++ b/src/semaphore.cc @@ -19,9 +19,17 @@ static NonblockingFuture setupInboundSemaphoreId(Communicator& return communicator.recvMemoryOnSetup(remoteRank, tag); } +static detail::UniqueGpuPtr createGpuSemaphoreId() { +#if defined(__HIP_PLATFORM_AMD__) + return detail::gpuCallocUncachedUnique(); +#else // !defined(__HIP_PLATFORM_AMD__) + return detail::gpuCallocUnique(); +#endif // !defined(__HIP_PLATFORM_AMD__) +} + MSCCLPP_API_CPP Host2DeviceSemaphore::Host2DeviceSemaphore(Communicator& communicator, std::shared_ptr connection) - : BaseSemaphore(allocExtUniqueCuda(), allocExtUniqueCuda(), std::make_unique()), + : BaseSemaphore(createGpuSemaphoreId(), createGpuSemaphoreId(), std::make_unique()), connection_(connection) { INFO(MSCCLPP_INIT, "Creating a Host2Device semaphore for %s transport from %d to %d", connection->getTransportName().c_str(), communicator.bootstrap()->getRank(), @@ -85,7 +93,7 @@ MSCCLPP_API_CPP void Host2HostSemaphore::wait(int64_t maxSpinCount) { MSCCLPP_API_CPP SmDevice2DeviceSemaphore::SmDevice2DeviceSemaphore(Communicator& communicator, std::shared_ptr connection) - : BaseSemaphore(allocExtUniqueCuda(), allocExtUniqueCuda(), allocExtUniqueCuda()) { + : BaseSemaphore(createGpuSemaphoreId(), createGpuSemaphoreId(), createGpuSemaphoreId()) { INFO(MSCCLPP_INIT, "Creating a Device2Device semaphore for %s transport from %d to %d", connection->getTransportName().c_str(), communicator.bootstrap()->getRank(), communicator.remoteRankOf(*connection)); diff --git a/src/utils.cc b/src/utils.cc index fb470a4ab..7153d55c5 100644 --- a/src/utils.cc +++ b/src/utils.cc @@ -7,7 +7,6 @@ #include #include 
 #include
-#include
 #include
 #include

@@ -67,22 +66,4 @@ std::string getHostName(int maxlen, const char delim) {
   return hostname.substr(0, i);
 }

-bool isNvlsSupported() {
-  [[maybe_unused]] static bool result = false;
-  [[maybe_unused]] static bool isChecked = false;
-#if (CUDA_NVLS_SUPPORTED)
-  if (!isChecked) {
-    int isMulticastSupported;
-    int isFabricSupported;
-    CUdevice dev;
-    MSCCLPP_CUTHROW(cuCtxGetDevice(&dev));
-    MSCCLPP_CUTHROW(cuDeviceGetAttribute(&isMulticastSupported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, dev));
-    MSCCLPP_CUTHROW(cuDeviceGetAttribute(&isFabricSupported, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, dev));
-    result = (isMulticastSupported == 1 && isFabricSupported == 1);
-  }
-  return result;
-#endif
-  return false;
-}
-
 }  // namespace mscclpp
diff --git a/test/executor_test.cc b/test/executor_test.cc
index 68e8bfa32..ff5ad9e36 100644
--- a/test/executor_test.cc
+++ b/test/executor_test.cc
@@ -129,12 +129,7 @@ int main(int argc, char* argv[]) {
   }
   mscclpp::ExecutionPlan plan(executionPlanPath);
-  std::shared_ptr<char> sendbuff;
-  if (mscclpp::isNvlsSupported()) {
-    sendbuff = mscclpp::allocSharedPhysicalCuda<char>(bufferSize);
-  } else {
-    sendbuff = mscclpp::allocExtSharedCuda<char>(bufferSize);
-  }
+  std::shared_ptr<char> sendbuff = mscclpp::GpuBuffer(bufferSize).memory();
   std::vector<int> dataHost(bufferSize / sizeof(int), rank);
   MSCCLPP_CUDATHROW(cudaMemcpy(sendbuff.get(), dataHost.data(), bufferSize, cudaMemcpyHostToDevice));
   double deltaSec = benchTime(rank, bootstrap, executor, plan, sendbuff, bufferSize, niters, ngraphIters, packetType);
diff --git a/test/mp_unit/communicator_tests.cu b/test/mp_unit/communicator_tests.cu
index adb6b5df6..a73ca732f 100644
--- a/test/mp_unit/communicator_tests.cu
+++ b/test/mp_unit/communicator_tests.cu
@@ -113,7 +113,7 @@ void CommunicatorTest::SetUp() {
   }

   for (size_t n = 0; n < numBuffers; n++) {
-    devicePtr[n] = mscclpp::allocSharedCuda<int>(deviceBufferSize / sizeof(int));
+    devicePtr[n] = mscclpp::detail::gpuCallocShared<int>(deviceBufferSize / sizeof(int));
     registerMemoryPairs(devicePtr[n].get(), deviceBufferSize, mscclpp::Transport::CudaIpc | ibTransport, 0,
                         remoteRanks, localMemory[n], remoteMemory[n]);
   }
@@ -133,7 +133,7 @@ void CommunicatorTest::deviceBufferInit() {
     for (size_t i = 0; i < dataCount; i++) {
       hostBuffer[i] = gEnv->rank + n * gEnv->worldSize;
    }
-    mscclpp::memcpyCuda<int>(devicePtr[n].get(), hostBuffer.data(), dataCount, cudaMemcpyHostToDevice);
+    mscclpp::gpuMemcpy<int>(devicePtr[n].get(), hostBuffer.data(), dataCount, cudaMemcpyHostToDevice);
   }
 }

@@ -155,7 +155,7 @@ bool CommunicatorTest::testWriteCorrectness(bool skipLocal) {
   size_t dataCount = deviceBufferSize / sizeof(int);
   for (int n = 0; n < (int)devicePtr.size(); n++) {
     std::vector<int> hostBuffer(dataCount, 0);
-    mscclpp::memcpyCuda<int>(hostBuffer.data(), devicePtr[n].get(), dataCount, cudaMemcpyDeviceToHost);
+    mscclpp::gpuMemcpy<int>(hostBuffer.data(), devicePtr[n].get(), dataCount, cudaMemcpyDeviceToHost);
     for (int i = 0; i < gEnv->worldSize; i++) {
       if (((i / gEnv->nRanksPerNode) == (gEnv->rank / gEnv->nRanksPerNode)) && skipLocal) {
         continue;
@@ -214,12 +214,13 @@ TEST_F(CommunicatorTest, WriteWithDeviceSemaphores) {
   deviceBufferInit();
   communicator->bootstrap()->barrier();

-  auto deviceSemaphoreHandles = mscclpp::allocSharedCuda<mscclpp::Host2DeviceSemaphore::DeviceHandle>(gEnv->worldSize);
+  auto deviceSemaphoreHandles =
+      mscclpp::detail::gpuCallocShared<mscclpp::Host2DeviceSemaphore::DeviceHandle>(gEnv->worldSize);
   for (int i = 0; i < gEnv->worldSize; i++) {
     if (i != gEnv->rank) {
       mscclpp::Host2DeviceSemaphore::DeviceHandle deviceHandle = semaphores[i]->deviceHandle();
-      mscclpp::memcpyCuda<mscclpp::Host2DeviceSemaphore::DeviceHandle>(deviceSemaphoreHandles.get() + i, &deviceHandle,
-                                                                       1, cudaMemcpyHostToDevice);
+      mscclpp::gpuMemcpy<mscclpp::Host2DeviceSemaphore::DeviceHandle>(deviceSemaphoreHandles.get() + i, &deviceHandle,
+                                                                      1, cudaMemcpyHostToDevice);
     }
   }
   communicator->bootstrap()->barrier();
diff --git a/test/mp_unit/executor_tests.cc b/test/mp_unit/executor_tests.cc
index 116470dd1..d5f87e6d7 100644
--- a/test/mp_unit/executor_tests.cc
+++ b/test/mp_unit/executor_tests.cc
@@ -57,7 +57,7 @@ TEST_F(ExecutorTest, TwoNodesAllreduce) {
       path.parent_path().parent_path().parent_path() / "test/execution-files/allreduce.json";
   mscclpp::ExecutionPlan plan(executionFilesPath.string());
   const int bufferSize = 1024 * 1024;
-  std::shared_ptr<char> sendbuff = mscclpp::allocExtSharedCuda<char>(bufferSize);
+  std::shared_ptr<char> sendbuff = mscclpp::GpuBuffer(bufferSize).memory();
   mscclpp::CudaStreamWithFlags stream(cudaStreamNonBlocking);
   executor->execute(gEnv->rank, sendbuff.get(), sendbuff.get(), bufferSize, bufferSize, mscclpp::DataType::FLOAT16,
                     plan, stream);
diff --git a/test/mp_unit/ib_tests.cu b/test/mp_unit/ib_tests.cu
index 92ee287e2..9bb073520 100644
--- a/test/mp_unit/ib_tests.cu
+++ b/test/mp_unit/ib_tests.cu
@@ -82,7 +82,7 @@ TEST_F(IbPeerToPeerTest, SimpleSendRecv) {
   const int maxIter = 100000;
   const int nelem = 1;
-  auto data = mscclpp::allocUniqueCuda<int>(nelem);
+  auto data = mscclpp::detail::gpuCallocUnique<int>(nelem);

   registerBufferAndConnect(data.get(), sizeof(int) * nelem);

@@ -196,7 +196,7 @@ TEST_F(IbPeerToPeerTest, MemoryConsistency) {
   const uint64_t signalPeriod = 1024;
   const uint64_t maxIter = 10000;
   const uint64_t nelem = 65536 + 1;
-  auto data = mscclpp::allocUniqueCuda<uint64_t>(nelem);
+  auto data = mscclpp::detail::gpuCallocUnique<uint64_t>(nelem);

   registerBufferAndConnect(data.get(), sizeof(uint64_t) * nelem);

@@ -205,12 +205,14 @@ TEST_F(IbPeerToPeerTest, MemoryConsistency) {

   if (gEnv->rank == 0) {
     // Receiver
-    auto curIter = mscclpp::makeUniqueCudaHost<uint64_t>(0);
-    auto result = mscclpp::makeUniqueCudaHost<int>(0);
+    auto curIter = mscclpp::detail::gpuCallocHostUnique<uint64_t>();
+    auto result = mscclpp::detail::gpuCallocHostUnique<int>();
     volatile uint64_t* ptrCurIter = (volatile uint64_t*)curIter.get();
     volatile int* ptrResult = (volatile int*)result.get();

+    ASSERT_NE(ptrCurIter, nullptr);
+    ASSERT_NE(ptrResult, nullptr);
     ASSERT_EQ(*ptrCurIter, 0);
     ASSERT_EQ(*ptrResult, 0);

@@ -246,7 +248,7 @@ TEST_F(IbPeerToPeerTest, MemoryConsistency) {
       for (uint64_t i = 0; i < nelem; i++) {
         hostBuffer[i] = iter;
       }
-      mscclpp::memcpyCuda<uint64_t>(data.get(), hostBuffer.data(), nelem, cudaMemcpyHostToDevice);
+      mscclpp::gpuMemcpy<uint64_t>(data.get(), hostBuffer.data(), nelem, cudaMemcpyHostToDevice);

       // Need to signal from time to time to empty the IB send queue
       bool signaled = (iter % signalPeriod == 0);
@@ -303,7 +305,7 @@ TEST_F(IbPeerToPeerTest, SimpleAtomicAdd) {
   const int maxIter = 100000;
   const int nelem = 1;
-  auto data = mscclpp::allocUniqueCuda<int>(nelem);
+  auto data = mscclpp::detail::gpuCallocUnique<int>(nelem);

   registerBufferAndConnect(data.get(), sizeof(int) * nelem);

diff --git a/test/mp_unit/proxy_channel_tests.cu b/test/mp_unit/proxy_channel_tests.cu
index 79ca9b656..192985b47 100644
--- a/test/mp_unit/proxy_channel_tests.cu
+++ b/test/mp_unit/proxy_channel_tests.cu
@@ -157,7 +157,7 @@ void ProxyChannelOneToOneTest::testPingPong(PingPongTestParams params) {
   const int nElem = 4 * 1024 * 1024;

   std::vector proxyChannels;
-  std::shared_ptr<int> buff = mscclpp::allocExtSharedCuda<int>(nElem);
+  std::shared_ptr<int> buff = mscclpp::GpuBuffer<int>(nElem).memory();
   setupMeshConnections(proxyChannels, params.useIPC, params.useIB,
                       params.useEthernet, buff.get(), nElem * sizeof(int));
   std::vector> proxyChannelHandles;
@@ -169,7 +169,7 @@ void ProxyChannelOneToOneTest::testPingPong(PingPongTestParams params) {

   proxyService->startProxy();

-  std::shared_ptr<int> ret = mscclpp::makeSharedCudaHost<int>(0);
+  std::shared_ptr<int> ret = mscclpp::detail::gpuCallocHostShared<int>();

   const int nTries = 1000;

@@ -202,7 +202,7 @@ void ProxyChannelOneToOneTest::testPingPongPerf(PingPongTestParams params) {
   const int nElem = 4 * 1024 * 1024;

   std::vector proxyChannels;
-  std::shared_ptr<int> buff = mscclpp::allocExtSharedCuda<int>(nElem);
+  std::shared_ptr<int> buff = mscclpp::GpuBuffer<int>(nElem).memory();
   setupMeshConnections(proxyChannels, params.useIPC, params.useIB, params.useEthernet, buff.get(),
                        nElem * sizeof(int));
   std::vector> proxyChannelHandles;
@@ -214,7 +214,7 @@ void ProxyChannelOneToOneTest::testPingPongPerf(PingPongTestParams params) {

   proxyService->startProxy();

-  std::shared_ptr<int> ret = mscclpp::makeSharedCudaHost<int>(0);
+  std::shared_ptr<int> ret = mscclpp::detail::gpuCallocHostShared<int>();

   auto* testInfo = ::testing::UnitTest::GetInstance()->current_test_info();
   const std::string testName = std::string(testInfo->test_suite_name()) + "." + std::string(testInfo->name());
@@ -344,11 +344,11 @@ void ProxyChannelOneToOneTest::testPacketPingPong(bool useIbOnly) {
   const int nElem = 4 * 1024 * 1024;

   std::vector proxyChannels;
-  std::shared_ptr<int> buff = mscclpp::allocExtSharedCuda<int>(nElem);
+  std::shared_ptr<int> buff = mscclpp::GpuBuffer<int>(nElem).memory();

   const size_t nPacket = (nElem * sizeof(int) + sizeof(uint64_t) - 1) / sizeof(uint64_t);
-  auto putPacketBuffer = mscclpp::allocExtSharedCuda<mscclpp::LLPacket>(nPacket);
-  auto getPacketBuffer = mscclpp::allocExtSharedCuda<mscclpp::LLPacket>(nPacket);
+  auto putPacketBuffer = mscclpp::GpuBuffer<mscclpp::LLPacket>(nPacket).memory();
+  auto getPacketBuffer = mscclpp::GpuBuffer<mscclpp::LLPacket>(nPacket).memory();

   setupMeshConnections(proxyChannels, !useIbOnly, true, false, putPacketBuffer.get(),
                        nPacket * sizeof(mscclpp::LLPacket), getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket));
@@ -368,7 +368,7 @@ void ProxyChannelOneToOneTest::testPacketPingPong(bool useIbOnly) {

   proxyService->startProxy();

-  std::shared_ptr<int> ret = mscclpp::makeSharedCudaHost<int>(0);
+  std::shared_ptr<int> ret = mscclpp::detail::gpuCallocHostShared<int>();

   const int nTries = 1000;

@@ -411,11 +411,11 @@ void ProxyChannelOneToOneTest::testPacketPingPongPerf(bool useIbOnly) {
   const int nElem = 4 * 1024 * 1024;

   std::vector proxyChannels;
-  std::shared_ptr<int> buff = mscclpp::allocExtSharedCuda<int>(nElem);
+  std::shared_ptr<int> buff = mscclpp::GpuBuffer<int>(nElem).memory();

   const size_t nPacket = (nElem * sizeof(int) + sizeof(uint64_t) - 1) / sizeof(uint64_t);
-  auto putPacketBuffer = mscclpp::allocExtSharedCuda<mscclpp::LLPacket>(nPacket);
-  auto getPacketBuffer = mscclpp::allocExtSharedCuda<mscclpp::LLPacket>(nPacket);
+  auto putPacketBuffer = mscclpp::GpuBuffer<mscclpp::LLPacket>(nPacket).memory();
+  auto getPacketBuffer = mscclpp::GpuBuffer<mscclpp::LLPacket>(nPacket).memory();

   setupMeshConnections(proxyChannels, !useIbOnly, true, false, putPacketBuffer.get(),
                        nPacket * sizeof(mscclpp::LLPacket), getPacketBuffer.get(), nPacket * sizeof(mscclpp::LLPacket));
diff --git a/test/mp_unit/sm_channel_tests.cu b/test/mp_unit/sm_channel_tests.cu
index 45c5fa644..af4aa2985 100644
--- a/test/mp_unit/sm_channel_tests.cu
+++ b/test/mp_unit/sm_channel_tests.cu
@@ -77,8 +77,8 @@ void SmChannelOneToOneTest::packetPingPongTest(const std::string testName, Packe
   const int defaultNTries = 1000;

   std::vector smChannels;
-  std::shared_ptr<int> buff = mscclpp::allocExtSharedCuda<int>(nElem);
-  std::shared_ptr<int> intermBuff = mscclpp::allocExtSharedCuda<int>(nElem * 2);
+  std::shared_ptr<int> buff = mscclpp::GpuBuffer<int>(nElem).memory();
+  std::shared_ptr<int> intermBuff = mscclpp::GpuBuffer<int>(nElem * 2).memory();
   setupMeshConnections(smChannels, buff.get(), nElem * sizeof(int), intermBuff.get(), nElem * 2 * sizeof(int));
   std::vector> deviceHandles(smChannels.size());
   std::transform(smChannels.begin(), smChannels.end(), deviceHandles.begin(),
@@ -88,7 +88,7 @@ void SmChannelOneToOneTest::packetPingPongTest(const std::string testName, Packe
   MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstSmChans, deviceHandles.data(),
                                        sizeof(DeviceHandle)));

-  std::shared_ptr<int> ret = mscclpp::makeSharedCudaHost<int>(0);
+  std::shared_ptr<int> ret = mscclpp::detail::gpuCallocHostShared<int>();

   // The least nelem is 2 for packet ping pong
   kernelWrapper(buff.get(), gEnv->rank, 2, ret.get(), defaultNTries);
@@ -178,7 +178,7 @@ TEST_F(SmChannelOneToOneTest, PutPingPong) {
   const int nElem = 4 * 1024 * 1024;

   std::vector smChannels;
-  std::shared_ptr<int> buff = mscclpp::allocExtSharedCuda<int>(nElem);
+  std::shared_ptr<int> buff = mscclpp::GpuBuffer<int>(nElem).memory();
   setupMeshConnections(smChannels, buff.get(), nElem * sizeof(int));
   std::vector> deviceHandles(smChannels.size());
   std::transform(smChannels.begin(), smChannels.end(), deviceHandles.begin(),
@@ -188,7 +188,7 @@ TEST_F(SmChannelOneToOneTest, PutPingPong) {
   MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstSmChans, deviceHandles.data(),
                                        sizeof(DeviceHandle)));

-  std::shared_ptr<int> ret = mscclpp::makeSharedCudaHost<int>(0);
+  std::shared_ptr<int> ret = mscclpp::detail::gpuCallocHostShared<int>();

   kernelSmPutPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, ret.get());
   MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
@@ -257,7 +257,7 @@ TEST_F(SmChannelOneToOneTest, GetPingPong) {
   const int nElem = 4 * 1024 * 1024;

   std::vector smChannels;
-  std::shared_ptr<int> buff = mscclpp::allocExtSharedCuda<int>(nElem);
+  std::shared_ptr<int> buff = mscclpp::GpuBuffer<int>(nElem).memory();
   setupMeshConnections(smChannels, buff.get(), nElem * sizeof(int));
   std::vector> deviceHandles(smChannels.size());
   std::transform(smChannels.begin(), smChannels.end(), deviceHandles.begin(),
@@ -267,7 +267,7 @@ TEST_F(SmChannelOneToOneTest, GetPingPong) {
   MSCCLPP_CUDATHROW(cudaMemcpyToSymbol(gChannelOneToOneTestConstSmChans, deviceHandles.data(),
                                        sizeof(DeviceHandle)));

-  std::shared_ptr<int> ret = mscclpp::makeSharedCudaHost<int>(0);
+  std::shared_ptr<int> ret = mscclpp::detail::gpuCallocHostShared<int>();

   kernelSmGetPingPong<<<1, 1024>>>(buff.get(), gEnv->rank, 1, ret.get());
   MSCCLPP_CUDATHROW(cudaDeviceSynchronize());
diff --git a/test/mscclpp-test/allgather_test.cu b/test/mscclpp-test/allgather_test.cu
index 3fd65e3d2..27506f340 100644
--- a/test/mscclpp-test/allgather_test.cu
+++ b/test/mscclpp-test/allgather_test.cu
@@ -2,6 +2,7 @@
 // Licensed under the MIT license.

 #include
+#include
 #include
 #include
@@ -711,13 +712,13 @@ class AllGatherTestEngine : public BaseTestEngine {

 AllGatherTestEngine::AllGatherTestEngine(const TestArgs& args) : BaseTestEngine(args, "allgather") {}

 void AllGatherTestEngine::allocateBuffer() {
-  sendBuff_ = mscclpp::allocExtSharedCuda<int>(args_.maxBytes / sizeof(int));
+  sendBuff_ = mscclpp::GpuBuffer<int>(args_.maxBytes / sizeof(int)).memory();
   expectedBuff_ = std::shared_ptr(new int[args_.maxBytes / sizeof(int)]);
   if (args_.kernelNum == 7) {
     const size_t nPacket = (args_.maxBytes + sizeof(uint64_t) - 1) / sizeof(uint64_t);
     // 2x for double-buffering, scratchBuff used to store original data and reduced results
     const size_t scratchBuffNelem = nPacket * 2 /*original data & reduced result */ * 2 /* double buffering*/;
-    scratchPacketBuff_ = mscclpp::allocExtSharedCuda<mscclpp::LLPacket>(scratchBuffNelem);
+    scratchPacketBuff_ = mscclpp::GpuBuffer<mscclpp::LLPacket>(scratchBuffNelem).memory();
   }
 }

diff --git a/test/mscclpp-test/allreduce_test.cu b/test/mscclpp-test/allreduce_test.cu
index 6ba6ce3db..b7632a83d 100644
--- a/test/mscclpp-test/allreduce_test.cu
+++ b/test/mscclpp-test/allreduce_test.cu
@@ -2,6 +2,7 @@
 // Licensed under the MIT license.

 #include
+#include
 #include
 #include
 #include
@@ -1272,30 +1273,30 @@ bool AllReduceTestEngine::isInPlace() const {
 }

 void AllReduceTestEngine::allocateBuffer() {
-  inputBuff_ = mscclpp::allocExtSharedCuda<int>(args_.maxBytes / sizeof(int));
-  resultBuff_ = mscclpp::allocExtSharedCuda<int>(args_.maxBytes / sizeof(int));
+  inputBuff_ = mscclpp::GpuBuffer<int>(args_.maxBytes / sizeof(int)).memory();
+  resultBuff_ = mscclpp::GpuBuffer<int>(args_.maxBytes / sizeof(int)).memory();
   inputBuff = inputBuff_.get();
   resultBuff = resultBuff_.get();

   if (args_.kernelNum == 0 || args_.kernelNum == 1 || args_.kernelNum == 3 || args_.kernelNum == 4) {
-    scratchBuff_ = mscclpp::allocExtSharedCuda<int>(args_.maxBytes / sizeof(int));
+    scratchBuff_ = mscclpp::GpuBuffer<int>(args_.maxBytes / sizeof(int)).memory();
     scratchBuff = scratchBuff_.get();
   } else if (args_.kernelNum == 2) {
     const size_t nPacket = (args_.maxBytes + sizeof(uint64_t) - 1) / sizeof(uint64_t);
     // 2x for double-buffering
     const size_t scratchBuffNelem = nPacket * std::max(args_.nRanksPerNode - 1, 1) * 2;
-    scratchPacketBuff_ = mscclpp::allocExtSharedCuda<mscclpp::LLPacket>(scratchBuffNelem);
+    scratchPacketBuff_ = mscclpp::GpuBuffer<mscclpp::LLPacket>(scratchBuffNelem).memory();
     scratchPacketBuff = scratchPacketBuff_.get();
     const size_t packetBuffNelem = nPacket * 2;
-    putPacketBuff_ = mscclpp::allocExtSharedCuda<mscclpp::LLPacket>(packetBuffNelem);
-    getPacketBuff_ = mscclpp::allocExtSharedCuda<mscclpp::LLPacket>(packetBuffNelem);
+    putPacketBuff_ = mscclpp::GpuBuffer<mscclpp::LLPacket>(packetBuffNelem).memory();
+    getPacketBuff_ = mscclpp::GpuBuffer<mscclpp::LLPacket>(packetBuffNelem).memory();
     putPacketBuff = putPacketBuff_.get();
     getPacketBuff = getPacketBuff_.get();
   } else if (args_.kernelNum == 6 || args_.kernelNum == 7) {
     const size_t nPacket = (args_.maxBytes + sizeof(uint64_t) - 1) / sizeof(uint64_t);
     // 2x for double-buffering, scratchBuff used to store original data and reduced results
     const size_t scratchBuffNelem = nPacket * 2 /*original data & reduced result */ * 2 /* double buffering*/;
-    scratchPacketBuff_ = mscclpp::allocExtSharedCuda<mscclpp::LLPacket>(scratchBuffNelem);
+    scratchPacketBuff_ = mscclpp::GpuBuffer<mscclpp::LLPacket>(scratchBuffNelem).memory();
     scratchPacketBuff = scratchPacketBuff_.get();
   }
diff --git a/test/mscclpp-test/alltoall_test.cu b/test/mscclpp-test/alltoall_test.cu
index d3c8d891a..6d39e9f5f 100644
--- a/test/mscclpp-test/alltoall_test.cu
+++ b/test/mscclpp-test/alltoall_test.cu
@@ -2,6 +2,7 @@
 // Licensed under the MIT license.

 #include
+#include
 #include

 #include "common.hpp"
@@ -139,8 +140,8 @@ class AllToAllTestEngine : public BaseTestEngine {

 AllToAllTestEngine::AllToAllTestEngine(const TestArgs& args) : BaseTestEngine(args, "alltoall") { inPlace_ = false; }

 void AllToAllTestEngine::allocateBuffer() {
-  sendBuff_ = mscclpp::allocExtSharedCuda<int>(args_.maxBytes / sizeof(int));
-  recvBuff_ = mscclpp::allocExtSharedCuda<int>(args_.maxBytes / sizeof(int));
+  sendBuff_ = mscclpp::GpuBuffer<int>(args_.maxBytes / sizeof(int)).memory();
+  recvBuff_ = mscclpp::GpuBuffer<int>(args_.maxBytes / sizeof(int)).memory();
   expectedBuff_ = std::shared_ptr(new int[args_.maxBytes / sizeof(int)]);

   localSendBuff = sendBuff_.get();
diff --git a/test/mscclpp-test/sendrecv_test.cu b/test/mscclpp-test/sendrecv_test.cu
index b0f830a1a..0bd13e02c 100644
--- a/test/mscclpp-test/sendrecv_test.cu
+++ b/test/mscclpp-test/sendrecv_test.cu
@@ -137,8 +137,8 @@ class SendRecvTestEngine : public BaseTestEngine {

 SendRecvTestEngine::SendRecvTestEngine(const TestArgs& args) : BaseTestEngine(args, "sendrecv") { inPlace_ = false; }

 void SendRecvTestEngine::allocateBuffer() {
-  std::shared_ptr<int> sendBuff = mscclpp::allocExtSharedCuda<int>(args_.maxBytes / sizeof(int));
-  std::shared_ptr<int> recvBuff = mscclpp::allocExtSharedCuda<int>(args_.maxBytes / sizeof(int));
+  std::shared_ptr<int> sendBuff = mscclpp::GpuBuffer<int>(args_.maxBytes / sizeof(int)).memory();
+  std::shared_ptr<int> recvBuff = mscclpp::GpuBuffer<int>(args_.maxBytes / sizeof(int)).memory();
   devicePtrs_.push_back(sendBuff);
   devicePtrs_.push_back(recvBuff);
diff --git a/test/unit/cuda_utils_tests.cc b/test/unit/cuda_utils_tests.cc
index c2f565967..c74f9d9cf 100644
--- a/test/unit/cuda_utils_tests.cc
+++ b/test/unit/cuda_utils_tests.cc
@@ -6,23 +6,23 @@
 #include

 TEST(CudaUtilsTest, AllocShared) {
-  auto p1 = mscclpp::allocSharedCuda();
-  auto p2 = mscclpp::allocSharedCuda(5);
+  auto p1 = mscclpp::detail::gpuCallocShared();
+  auto p2 = mscclpp::detail::gpuCallocShared(5);
 }

 TEST(CudaUtilsTest, AllocUnique) {
-  auto p1 = mscclpp::allocUniqueCuda();
-  auto p2 = mscclpp::allocUniqueCuda(5);
+  auto p1 = mscclpp::detail::gpuCallocUnique();
+  auto p2 = mscclpp::detail::gpuCallocUnique(5);
 }

 TEST(CudaUtilsTest, MakeSharedHost) {
-  auto p1 = mscclpp::makeSharedCudaHost();
-  auto p2 = mscclpp::makeSharedCudaHost(5);
+  auto p1 = mscclpp::detail::gpuCallocHostShared();
+  auto p2 = mscclpp::detail::gpuCallocHostShared(5);
 }

 TEST(CudaUtilsTest, MakeUniqueHost) {
-  auto p1 = mscclpp::makeUniqueCudaHost();
-  auto p2 = mscclpp::makeUniqueCudaHost(5);
+  auto p1 = mscclpp::detail::gpuCallocHostUnique();
+  auto p2 = mscclpp::detail::gpuCallocHostUnique(5);
 }

 TEST(CudaUtilsTest, Memcpy) {
@@ -32,9 +32,9 @@ TEST(CudaUtilsTest, Memcpy) {
     hostBuff[i] = i + 1;
   }
   std::vector hostBuffTmp(nElem, 0);
-  auto devBuff = mscclpp::allocSharedCuda(nElem);
-  mscclpp::memcpyCuda(devBuff.get(), hostBuff.data(), nElem, cudaMemcpyHostToDevice);
-  mscclpp::memcpyCuda(hostBuffTmp.data(), devBuff.get(), nElem, cudaMemcpyDeviceToHost);
+  auto devBuff = mscclpp::detail::gpuCallocShared(nElem);
+  mscclpp::gpuMemcpy(devBuff.get(), hostBuff.data(), nElem, cudaMemcpyHostToDevice);
+  mscclpp::gpuMemcpy(hostBuffTmp.data(), devBuff.get(), nElem, cudaMemcpyDeviceToHost);
   for (int i = 0; i < nElem; ++i) {
     EXPECT_EQ(hostBuff[i], hostBuffTmp[i]);
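
Note (editor's addition, not part of the patch): a minimal usage sketch of the allocation and copy API that callers migrate to in this change, mirroring the updated CudaUtilsTest.Memcpy test above. It assumes the declarations come from <mscclpp/gpu_utils.hpp> and that GpuBuffer is templated on the element type, as the updated test code suggests; the function name copyRoundTrip and the element count are illustrative only.

  #include <mscclpp/gpu_utils.hpp>
  #include <memory>
  #include <vector>

  void copyRoundTrip() {
    constexpr size_t nElem = 1024;  // hypothetical element count
    std::vector<int> host(nElem, 7);
    std::vector<int> back(nElem, 0);
    // GpuBuffer<T>(n) allocates n elements of device memory; memory() returns a shared_ptr<T>
    // that owns the allocation (replacing allocExtSharedCuda<T>(n) in older code).
    std::shared_ptr<int> dev = mscclpp::GpuBuffer<int>(nElem).memory();
    // gpuMemcpy takes an element count, like the memcpyCuda call it replaces.
    mscclpp::gpuMemcpy(dev.get(), host.data(), nElem, cudaMemcpyHostToDevice);
    mscclpp::gpuMemcpy(back.data(), dev.get(), nElem, cudaMemcpyDeviceToHost);
  }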