Skip to content

Commit

Permalink
Use the "apply" scheme for the gemm function. The buffer-reading bug still needs a fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
ivarflakstad committed Oct 21, 2023
1 parent 4f4df06 commit 530e2c1
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 30 deletions.
31 changes: 26 additions & 5 deletions examples/mps/matrix-multiplication/main.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use metal::mps::*;
use metal::*;

fn generate_matrix<T, const ROWS: usize, const COLS: usize>() -> Matrix<T>
fn generate_matrix<T, const ROWS: u64, const COLS: u64>() -> Matrix<T>
where
T: MPSDataType,
MatMulInput<T>: Valid,
Expand All @@ -17,9 +17,9 @@ fn main() {
type A = Float32;
type B = Float32;
type C = Float32;
const M: usize = 1;
const N: usize = 1;
const K: usize = 5;
const M: u64 = 2;
const N: u64 = 2;
const K: u64 = 2;

let transpose_left = false;
let transpose_right = false;
Expand All @@ -32,6 +32,27 @@ fn main() {
println!("{left:?}");
println!("{right:?}");

let result = matrix_multiplication(transpose_left, transpose_right, &left, &right, alpha, beta);
let device = Device::system_default().expect("No device found");
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();

// Add matrix multiplication to command buffer and get result buffer
let result_buffer = apply_gemm(
&device,
command_buffer,
transpose_left,
transpose_right,
&left,
&right,
alpha,
beta,
);

// Run multiplication
command_buffer.commit();
command_buffer.wait_until_completed();

// Read result buffer
let result = result_buffer.contents();
println!("{result:?}");
}
57 changes: 32 additions & 25 deletions src/mps.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1138,20 +1138,25 @@ impl<T: MPSDataType> MatrixBuffer<T> {
unsafe { std::slice::from_raw_parts(contents, (self.rows * self.columns) as usize) };
sl.to_vec()
}
/// Copies the matrix entries out of the backing buffer into a new `Vec`.
///
/// Reads `rows * columns` consecutive elements starting at the beginning of
/// the buffer's contents.
///
/// NOTE(review): this presumes the entries are tightly packed. If the buffer
/// was allocated with a padded row stride, the elements read here may not
/// line up with the logical matrix entries — confirm against the allocation.
pub fn read_to_vec(&self) -> Vec<T::Type> {
    let element_count = (self.rows * self.columns) as usize;
    read_buffer_to_vec(&self.buffer, element_count)
}
}

/// Copies `len` elements of type `T` from the start of `buffer` into a new `Vec`.
///
/// The buffer's contents pointer is reinterpreted as `*const T`; the caller
/// is responsible for choosing a `T` that matches the data stored in the
/// buffer.
///
/// Returns an empty `Vec` when `len` is zero, without touching the buffer.
///
/// # Panics
///
/// Panics if `len * size_of::<T>()` overflows `usize`, if that many bytes
/// exceed the buffer's length, or if the buffer's contents pointer is null.
pub fn read_buffer_to_vec<T: Clone>(buffer: &BufferRef, len: usize) -> Vec<T> {
    // Nothing to read; also avoids building a slice from a possibly-null pointer.
    if len == 0 {
        return Vec::new();
    }

    let byte_len = len
        .checked_mul(std::mem::size_of::<T>())
        .expect("requested element count overflows usize");
    let available = buffer.length();
    assert!(
        byte_len as u64 <= available,
        "requested {byte_len} bytes but the buffer only holds {available} bytes"
    );

    let contents = buffer.contents() as *const T;
    assert!(
        !contents.is_null(),
        "buffer contents pointer is null; the buffer is not readable from the CPU"
    );

    // SAFETY: `contents` is non-null and the buffer spans at least `byte_len`
    // bytes (both checked above). NOTE(review): this still assumes the
    // allocation is suitably aligned for `T` and holds initialized values of
    // type `T` — confirm for every element type used with this helper.
    let elements = unsafe { std::slice::from_raw_parts(contents, len) };
    elements.to_vec()
}

pub fn matrix_multiplication<A, B, C>(
pub fn apply_gemm<A, B, C>(
device: &DeviceRef,
command_buffer: &CommandBufferRef,
transpose_left: bool,
transpose_right: bool,
left: &Matrix<A>,
right: &Matrix<B>,
alpha: f64,
beta: f64,
) -> Matrix<C>
) -> MatrixBuffer<C>
where
A: MPSDataType,
B: MPSDataType,
Expand All @@ -1161,14 +1166,23 @@ where
MatMulResult<C>: Valid,
MatMulSpecification<A, B, C>: Valid,
{
validate_matrix_multiplication(transpose_left, transpose_right, left, right);

let device = Device::system_default().expect("No device found");
let command_queue = device.new_command_queue();
let M = if transpose_left {
left.columns
} else {
left.rows
};
let N = if transpose_right {
right.rows
} else {
right.columns
};
let K = if transpose_left {
left.rows
} else {
left.columns
};

let M = left.rows;
let N = right.columns;
let K = left.columns;
validate_matrix_multiplication(left, right, M, N, K);

// Create descriptors for the matrices.
let left_row_bytes = MatrixDescriptor::row_bytes_for_columns(K, A::CODE);
Expand All @@ -1185,15 +1199,17 @@ where
device.new_buffer_with_data(left.entries.as_ptr().cast(), M * left_row_bytes, options);
let right_buffer =
device.new_buffer_with_data(right.entries.as_ptr().cast(), K * right_row_bytes, options);
let result_buffer = device.new_buffer(M * result_row_bytes, options);

let result_buffer = MatrixBuffer::new(device, M, N, M * result_row_bytes, options);

// Create matrix objects
let left_matrix =
MatrixObject::init_with_buffer_descriptor(&left_buffer, &left_descriptor).unwrap();
let right_matrix =
MatrixObject::init_with_buffer_descriptor(&right_buffer, &right_descriptor).unwrap();
let result_matrix =
MatrixObject::init_with_buffer_descriptor(&result_buffer, &result_descriptor).unwrap();
MatrixObject::init_with_buffer_descriptor(&result_buffer.buffer, &result_descriptor)
.unwrap();

// Create kernel
let matrix_multiplication = MatrixMultiplication::init(
Expand All @@ -1208,33 +1224,24 @@ where
)
.unwrap();

let command_buffer = command_queue.new_command_buffer();

// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
&command_buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
// Run multiplication
command_buffer.commit();
command_buffer.wait_until_completed();

// Get result from buffer
let entries = read_buffer_to_vec::<C::Type>(&result_buffer, (M * N) as usize);
Matrix {
entries,
rows: M,
columns: N,
}
// Return result buffer
result_buffer
}

fn validate_matrix_multiplication<A, B, C>(
transpose_left: bool,
transpose_right: bool,
left: &Matrix<A>,
right: &Matrix<B>,
M: NSUInteger,
N: NSUInteger,
K: NSUInteger,
) where
A: MPSDataType,
B: MPSDataType,
Expand Down

0 comments on commit 530e2c1

Please sign in to comment.