Skip to content

Commit

Permalink
clean up a bit
Browse files Browse the repository at this point in the history
Signed-off-by: Liqun Fu <[email protected]>
  • Loading branch information
liqunfu committed Jan 8, 2025
1 parent 7e99d50 commit f36ec96
Show file tree
Hide file tree
Showing 6 changed files with 7 additions and 67 deletions.
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/lib/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ Return Value:
this->ConvSymU8S8Dispatch = &MlasConvSymDispatchAvx512Core;
this->FpQ4GemmDispatch = &MlasFpQ4GemmDispatchAvx512;
this->QNBitGemmDispatch = &MlasSQNBitGemmDispatchAvx512;

//
// Check if the processor supports AVX512VNNI.
//
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/mlas/lib/qnbitgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ MlasIsQNBitGemmAvailable(
}

const auto Variant = GetQNBitGemmVariant(BlkBitWidth, BlkLen, ComputeType);

switch (Variant) {
case SQNBitGemmVariant_BitWidth4_CompFp32: {
return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr &&
Expand Down
38 changes: 0 additions & 38 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -555,11 +555,6 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx2(
const float* QuantBBlkSum
)
{
//if (BlkLen >= 32 && CountM == 1) {
// SQ4BitGemmM1Kernel_CompInt8_avx2<false>(BlkLen, QuantA, QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, BlockCountK, Bias);
// return CountM;
//}

SQ4BitGemmKernel_CompInt8_avx2<false>(
BlkLen,
QuantA,
Expand All @@ -576,20 +571,6 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx2(
ABlockSum,
QuantBBlkSum
);
//float* c_blk = C;
//const float* b_blk_sum = QuantBBlkSum;

//size_t RowsRemaining = CountM;
//const float* a_blksum_row = ABlockSum;
//while (RowsRemaining > 0) {
// auto RowsHandled = GetMlasPlatform().GemmFloatKernel(
// a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false
// );

// c_blk += ldc * RowsHandled;
// a_blksum_row += BlockCountK * RowsHandled;
// RowsRemaining -= RowsHandled;
//}
return CountM;
}

Expand All @@ -611,11 +592,6 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni(
const float* QuantBBlkSum
)
{
//if (BlkLen >= 32 && CountM == 1) {
// SQ4BitGemmM1Kernel_CompInt8_avx2<true>(BlkLen, QuantA, QuantAScale, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, CountK, BlockCountK, Bias);
// return CountM;
//}

SQ4BitGemmKernel_CompInt8_avx2<true>(
BlkLen,
QuantA,
Expand All @@ -632,20 +608,6 @@ SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni(
ABlockSum,
QuantBBlkSum
);
//float* c_blk = C;
//const float* b_blk_sum = QuantBBlkSum;

//size_t RowsRemaining = CountM;
//const float* a_blksum_row = ABlockSum;
//while (RowsRemaining > 0) {
// auto RowsHandled = GetMlasPlatform().GemmFloatKernel(
// a_blksum_row, b_blk_sum, c_blk, BlockCountK, RowsRemaining, CountN, BlockCountK, ldc, 1.f, false
// );

// c_blk += ldc * RowsHandled;
// a_blksum_row += BlockCountK * RowsHandled;
// RowsRemaining -= RowsHandled;
//}
return CountM;
}

Expand Down
11 changes: 3 additions & 8 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,15 +327,10 @@ ComputePackBlkSum(
*(BlockSumBegin + dst_offset) = -QuantBScale * zp;
return;
}
}
}

if (!is_avx512 || (is_avx512 && BlkLen != 32)) {
const size_t dst_offset = n * BlockCountK + k_blk;
*(BlockSumBegin + dst_offset) = -QuantBScale * zp;
} else {
const size_t dst_offset = ((n / 16) * BlockCountK + k_blk) * 16 + n % 16;
*(BlockSumBegin + dst_offset) = -QuantBScale * zp;
}
const size_t dst_offset = n * BlockCountK + k_blk;
*(BlockSumBegin + dst_offset) = -QuantBScale * zp;

if (BlkLen == 16) {
} else if (BlkLen >= SubBlkLen) {
Expand Down
7 changes: 2 additions & 5 deletions onnxruntime/test/contrib_ops/matmul_4bits_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -390,9 +390,6 @@ TEST(MatMulNBits, Float32_Accuracy4) {
TestMatMulNBitsTyped<float, 100, 288, 93, 32, 4>();
TestMatMulNBitsTyped<float, 100, 288, 93, 128, 4>();
TestMatMulNBitsTyped<float, 100, 288, 1234, 16, 4>();
TestMatMulNBitsTyped<float, 2, 4, 128, 32, 4>();
TestMatMulNBitsTyped<float, 1, 1, 128, 32, 4>();
TestMatMulNBitsTyped<float, 100, 288, 1234, 32, 4>();
}

#if defined(MLAS_TARGET_AMD64_IX86) || defined(MLAS_TARGET_ARM64)
Expand Down Expand Up @@ -547,7 +544,7 @@ void LongTestMatMulNBitsTyped() {
}
}

TEST(MatMulNBits, LongTestFloat32) {
TEST(MatMulNBits, DISABLED_LongTestFloat32) {
// onnxruntime::profiling::Profiler::Profiler::Instance().StartProfiling<char>("profile.json");
LongTestMatMulNBitsTyped<float>();
}
Expand All @@ -556,7 +553,7 @@ TEST(MatMulNBits, LongTestFloat32) {
#if !defined(USE_DML)
// Actual and expected difference is over 0.01 with DmlExecutionProvider.
// Skip the tests instead of raising the tolerance to make is pass.
TEST(MatMulNBits, LongTestFloat16) {
TEST(MatMulNBits, DISABLED_LongTestFloat16) {
LongTestMatMulNBitsTyped<MLFloat16>();
}
#endif
Expand Down
16 changes: 0 additions & 16 deletions onnxruntime/test/mlas/unittest/test_sqnbitgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,22 +412,6 @@ class SQNBitGemmShortExecuteTest : public MlasTestFixture<MlasSQNBitGemmTest<Blk
tests_registered += RegisterSingleTest(1, 527, 2131, ComputeType, WithThreadpool, Symmetric, true);
tests_registered += RegisterSingleTest(11, 527, 2131, ComputeType, WithThreadpool, Symmetric, true);
// tests_registered += RegisterSingleTest(1001, 1027, 1031, ComputeType, WithThreadpool, Symmetric, false);
tests_registered += RegisterSingleTest(1, 1, 128, ComputeType, WithThreadpool, Symmetric, false);
tests_registered += RegisterSingleTest(1, 4, 128, ComputeType, WithThreadpool, Symmetric, false);
tests_registered += RegisterSingleTest(2, 1, 128, ComputeType, WithThreadpool, Symmetric, false);
tests_registered += RegisterSingleTest(2, 4, 128, ComputeType, WithThreadpool, Symmetric, false);
tests_registered += RegisterSingleTest(2, 4, 128, ComputeType, WithThreadpool, Symmetric, true);
tests_registered += RegisterSingleTest(3, 4, 128, ComputeType, WithThreadpool, Symmetric, false);
tests_registered += RegisterSingleTest(3, 4, 128, ComputeType, WithThreadpool, Symmetric, true);
tests_registered += RegisterSingleTest(1, 1, 33, ComputeType, WithThreadpool, Symmetric, true);
tests_registered += RegisterSingleTest(1, 4, 33, ComputeType, WithThreadpool, Symmetric, true);
tests_registered += RegisterSingleTest(1, 32, 33, ComputeType, WithThreadpool, Symmetric, true);
tests_registered += RegisterSingleTest(1, 32, 128, ComputeType, WithThreadpool, Symmetric, true);
tests_registered += RegisterSingleTest(8, 1, 1, ComputeType, WithThreadpool, Symmetric, true);
tests_registered += RegisterSingleTest(8, 4, 1, ComputeType, WithThreadpool, Symmetric, true);
tests_registered += RegisterSingleTest(8, 6, 1, ComputeType, WithThreadpool, Symmetric, true);
tests_registered += RegisterSingleTest(8, 8, 1, ComputeType, WithThreadpool, Symmetric, true);
tests_registered += RegisterSingleTest(8, 16, 1, ComputeType, WithThreadpool, Symmetric, true);
}
}
}
Expand Down

0 comments on commit f36ec96

Please sign in to comment.