diff --git a/examples/mps/matrix-multiplication/main.rs b/examples/mps/matrix-multiplication/main.rs index ed54c46..4e8d0c5 100644 --- a/examples/mps/matrix-multiplication/main.rs +++ b/examples/mps/matrix-multiplication/main.rs @@ -1,7 +1,7 @@ use metal::mps::*; use metal::*; -fn generate_matrix() -> Matrix +fn generate_matrix() -> Matrix where T: MPSDataType, MatMulInput: Valid, @@ -17,9 +17,9 @@ fn main() { type A = Float32; type B = Float32; type C = Float32; - const M: usize = 1; - const N: usize = 1; - const K: usize = 5; + const M: u64 = 2; + const N: u64 = 2; + const K: u64 = 2; let transpose_left = false; let transpose_right = false; @@ -32,6 +32,27 @@ fn main() { println!("{left:?}"); println!("{right:?}"); - let result = matrix_multiplication(transpose_left, transpose_right, &left, &right, alpha, beta); + let device = Device::system_default().expect("No device found"); + let command_queue = device.new_command_queue(); + let command_buffer = command_queue.new_command_buffer(); + + // Add matrix multiplication to command buffer and get result buffer + let result_buffer = apply_gemm( + &device, + command_buffer, + transpose_left, + transpose_right, + &left, + &right, + alpha, + beta, + ); + + // Run multiplication + command_buffer.commit(); + command_buffer.wait_until_completed(); + + // Read result buffer + let result = result_buffer.contents(); println!("{result:?}"); } diff --git a/src/mps.rs b/src/mps.rs index c0c8e00..b1f70f1 100644 --- a/src/mps.rs +++ b/src/mps.rs @@ -1138,20 +1138,25 @@ impl MatrixBuffer { unsafe { std::slice::from_raw_parts(contents, (self.rows * self.columns) as usize) }; sl.to_vec() } + pub fn read_to_vec(&self) -> Vec { + read_buffer_to_vec(&self.buffer, (self.rows * self.columns) as usize) + } } pub fn read_buffer_to_vec(buffer: &BufferRef, len: usize) -> Vec { Vec::from(unsafe { std::slice::from_raw_parts(buffer.contents() as *const T, len) }) } -pub fn matrix_multiplication( +pub fn apply_gemm( + device: &DeviceRef, + command_buffer: &CommandBufferRef, transpose_left: bool, transpose_right: bool, left: &Matrix, right: &Matrix, alpha: f64, beta: f64, -) -> Matrix +) -> MatrixBuffer where A: MPSDataType, B: MPSDataType, @@ -1161,14 +1166,23 @@ where MatMulResult: Valid, MatMulSpecification: Valid, { - validate_matrix_multiplication(transpose_left, transpose_right, left, right); - - let device = Device::system_default().expect("No device found"); - let command_queue = device.new_command_queue(); + let M = if transpose_left { + left.columns + } else { + left.rows + }; + let N = if transpose_right { + right.rows + } else { + right.columns + }; + let K = if transpose_left { + left.rows + } else { + left.columns + }; - let M = left.rows; - let N = right.columns; - let K = left.columns; + validate_matrix_multiplication(left, right, M, N, K); // Create descriptors for the matrices. let left_row_bytes = MatrixDescriptor::row_bytes_for_columns(K, A::CODE); @@ -1185,7 +1199,8 @@ where device.new_buffer_with_data(left.entries.as_ptr().cast(), M * left_row_bytes, options); let right_buffer = device.new_buffer_with_data(right.entries.as_ptr().cast(), K * right_row_bytes, options); - let result_buffer = device.new_buffer(M * result_row_bytes, options); + + let result_buffer = MatrixBuffer::new(device, M, N, M * result_row_bytes, options); // Create matrix objects let left_matrix = @@ -1193,7 +1208,8 @@ where let right_matrix = MatrixObject::init_with_buffer_descriptor(&right_buffer, &right_descriptor).unwrap(); let result_matrix = - MatrixObject::init_with_buffer_descriptor(&result_buffer, &result_descriptor).unwrap(); + MatrixObject::init_with_buffer_descriptor(&result_buffer.buffer, &result_descriptor) + .unwrap(); // Create kernel let matrix_multiplication = MatrixMultiplication::init( @@ -1208,8 +1224,6 @@ where ) .unwrap(); - let command_buffer = command_queue.new_command_buffer(); - // Encode kernel to command buffer matrix_multiplication.encode_to_command_buffer( &command_buffer, @@ -1217,24 +1231,17 @@ where &right_matrix, &result_matrix, ); - // Run multiplication - command_buffer.commit(); - command_buffer.wait_until_completed(); - // Get result from buffer - let entries = read_buffer_to_vec::(&result_buffer, (M * N) as usize); - Matrix { - entries, - rows: M, - columns: N, - } + // Return result buffer + result_buffer } fn validate_matrix_multiplication( - transpose_left: bool, - transpose_right: bool, left: &Matrix, right: &Matrix, + M: NSUInteger, + N: NSUInteger, + K: NSUInteger, ) where A: MPSDataType, B: MPSDataType,