diff --git a/python/bitblas/gpu/intrin/lop3.py b/python/bitblas/gpu/intrin/lop3.py
index 70819362a..7ea0f93f4 100644
--- a/python/bitblas/gpu/intrin/lop3.py
+++ b/python/bitblas/gpu/intrin/lop3.py
@@ -366,6 +366,47 @@
 }
 """
 
+decode_i2_to_f16_scale_zeros_quantized = """
+template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
+__device__ void decode_i2b_to_f16_scale_zeros_quantized(T1 *_i2s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr)
+{
+    uint *h = reinterpret_cast<uint *>(B_local_decode);
+
+    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
+    static constexpr uint BOTTOM_MASK = 0x00030003;
+    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
+    static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400;
+    int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
+    T3 const scale_r = *scale;
+    uint const packed_scales = __pack_half2(scale_r, scale_r);
+    T4 const zero_r = *zeros;
+    uint median_num = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r);
+
+    // decode 2 elems at one time.
+    // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
+    // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0}
+    // otherwise the pointer of _i2s should be moved to
+    int i2s = (i2s_i16 & 0x00ff);
+    i2s |= ((i2s_i16 & 0xff00) << 8);
+
+#pragma unroll
+    for (int i = 0; i < (N / 2); i++)
+    {
+        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n"
+                     : "=r"(h[i])
+                     : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
+        asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num));
+
+        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
+    }
+}
+template <typename T1, typename T2, typename T3, typename T4>
+__device__ void decode_i2u_to_f16_scale_zeros_quantized(T1 *_i2u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8)
+{
+    decode_i2b_to_f16_scale_zeros_quantized<T1, T2, T3, T4, false>(_i2u, B_local_decode, N, scale, zeros);
+}
+"""
+
 decode_i1_to_f16 = """
 template <typename T1, typename T2>
 __device__ void decode_i1u_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8)
@@ -1359,6 +1400,21 @@ def fast_decode_impl(
     ),
 )
 
+LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN = (
+    "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_zeros_quantized_")
+TensorIntrin.register(
+    LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN,
+    *get_fast_decode_intrin(
+        source_bit=2,
+        storage_dtype="int8",
+        target_dtype="float16",
+        loops_extent=8,
+        with_scale=True,
+        with_zeros=True,
+        zeros_mode="quantized",
+    ),
+)
+
 LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN = (
     "lop3_fast_decode_u1_to_int8_to_f16_l8_scale_")
 TensorIntrin.register(
@@ -1561,6 +1617,7 @@ def get_lop3_intrin_group(
         "i2_to_f16_scale_zeros_rescale": decode_i2_to_f16_scale_zeros_rescale,
         "i1_to_f16_scale_zeros_rescale": decode_i1_to_f16_scale_zeros_rescale,
         "i4_to_f16_scale_zeros_quantized": decode_i4_to_f16_scale_zeros_quantized,
+        "i2_to_f16_scale_zeros_quantized": decode_i2_to_f16_scale_zeros_quantized,
         "i1_to_i8": decode_i1s_to_i8s,
         "i2_to_i8": decode_i2s_to_i8s,
         "i4_to_i8": decode_i4s_to_i8s,
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 2c9828847..4fd416900 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -26,3 +26,4 @@ psutil
 scipy
 tornado
 torch
+thefuzz
diff --git a/requirements.txt b/requirements.txt
index 935da1857..e8257a571 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -19,3 +19,4 @@ psutil
 scipy
 tornado
 torch
+thefuzz
diff --git a/testing/cpp/lop3_type_conversion/fast_decoding.hpp b/testing/cpp/lop3_type_conversion/fast_decoding.hpp
index 184dfa243..6d5b6335a 100644
--- a/testing/cpp/lop3_type_conversion/fast_decoding.hpp
+++ b/testing/cpp/lop3_type_conversion/fast_decoding.hpp
@@ -381,6 +381,45 @@ __device__ void decode_i2u_to_f16_scale_zeros_rescale(T1 *_i2u, T2 *B_local_deco
     decode_i2b_to_f16(_i2u, B_local_decode, N, scale, zeros);
 }
 
+template <typename T1, typename T2, typename T3, typename T4, bool isSigned = false>
+__device__ void decode_i2b_to_f16_scale_zeros_quantized(T1 *_i2s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr)
+{
+    uint *h = reinterpret_cast<uint *>(B_local_decode);
+
+    static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa;
+    static constexpr uint BOTTOM_MASK = 0x00030003;
+    static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400;
+    static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400;
+    int16_t const i2s_i16 = *reinterpret_cast<int16_t *>(_i2s);
+    T3 const scale_r = *scale;
+    uint const packed_scales = __pack_half2(scale_r, scale_r);
+    T4 const zero_r = *zeros;
+    uint median_num = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r);
+
+    // decode 2 elems at one time.
+    // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0}
+    // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0}
+    // otherwise the pointer of _i2s should be moved to
+    int i2s = (i2s_i16 & 0x00ff);
+    i2s |= ((i2s_i16 & 0xff00) << 8);
+
+#pragma unroll
+    for (int i = 0; i < (N / 2); i++)
+    {
+        asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+                     : "=r"(h[i])
+                     : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut));
+        asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num));
+
+        asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0));
+    }
+}
+template <typename T1, typename T2, typename T3, typename T4>
+__device__ void decode_i2u_to_f16_scale_zeros_quantized(T1 *_i2u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8)
+{
+    decode_i2b_to_f16_scale_zeros_quantized<T1, T2, T3, T4, false>(_i2u, B_local_decode, N, scale, zeros);
+}
+
 /*
 Kind 0: original
 Kind 1: rescale
diff --git a/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu b/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu
index 0d0ebf7d2..7307ad1fe 100644
--- a/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu
+++ b/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu
@@ -46,6 +46,7 @@ REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i4u_to_f16_scale_zeros_rescale, dec
 REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i2u_to_f16_scale_zeros_rescale, decode_i2u_to_f16_scale_zeros_rescale)
 REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i1u_to_f16_scale_zeros_rescale, decode_i1u_to_f16_scale_zeros_rescale)
 REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i4u_to_f16_scale_zeros_quantized, decode_i4u_to_f16_scale_zeros_quantized)
+REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i2u_to_f16_scale_zeros_quantized, decode_i2u_to_f16_scale_zeros_quantized)
 
 TEST(DecodeTest, DecodeInt4ToFloat16)
 {
@@ -1076,4 +1077,59 @@ TEST(DecodeTest, DecodeUInt4ToFloat16WithScalingWithZerosQuantized)
     free(ins);
     free(interleaved);
     free(decoded);
-}
\ No newline at end of file
+}
+
+TEST(DecodeTest, DecodeUInt2ToFloat16WithScalingWithZerosQuantized)
+{
+    constexpr int nbits = 2;
+    constexpr int N = 32 / nbits;
+    constexpr int QN = N / 8 * nbits;
+    constexpr bool isSigned = false;
+
+    // create N 2-bit input values stored as int8_t
+    int8_t in_data[N] = {
+        0};
+    half scale[1] = {__float2half(1.2)};
+    uint qzeros[1] = {(1 << (nbits - 1)) - 1};
+    // seed the random generator
+    srand(0);
+
+    // random initialization within the nbits range
+    for (int i = 0; i < N; i++)
+    {
+        in_data[i] = (rand() % (1 << nbits));
+    }
+
+    int8_t *ins = new int8_t[QN];
+    general_compress(in_data, ins, nbits, N, isSigned);
+
+    int8_t *interleaved = new int8_t[QN];
+    general_interleave_fp16(ins, interleaved, nbits, QN * sizeof(int8_t), false);
+    half *decoded = new half[N];
+    int8_t *ins_gpu;
+    half *decoded_gpu, *scale_gpu;
+    uint *qzeros_gpu;
+
+    cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t)));
+    cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(half)));
+    cudaCheckLastError(cudaMalloc((void **)&scale_gpu, 1 * sizeof(half)));
+    cudaCheckLastError(cudaMalloc((void **)&qzeros_gpu, 1 * sizeof(uint)));
+    cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice));
+    cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(half), cudaMemcpyHostToDevice));
+    cudaCheckLastError(cudaMemcpy(scale_gpu, scale, 1 * sizeof(half), cudaMemcpyHostToDevice));
+    cudaCheckLastError(cudaMemcpy(qzeros_gpu, qzeros, 1 * sizeof(uint), cudaMemcpyHostToDevice));
+    cudaCheckLastError(cudaDeviceSynchronize());
+    kernelWrapper_i2u_to_f16_scale_zeros_quantized<<<dim3(1, 1, 1), dim3(1, 1, 1)>>>(ins_gpu, decoded_gpu, scale_gpu, qzeros_gpu);
+    kernelWrapper_i2u_to_f16_scale_zeros_quantized<<<dim3(1, 1, 1), dim3(1, 1, 1)>>>(ins_gpu + QN / 2, decoded_gpu + N / 2, scale_gpu, qzeros_gpu);
+    cudaCheckLastError(cudaDeviceSynchronize());
+    cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(half), cudaMemcpyDeviceToHost));
+    cudaCheckLastError(cudaFree(ins_gpu));
+    cudaCheckLastError(cudaFree(decoded_gpu));
+    for (int i = 0; i < N; i++)
+    {
+        EXPECT_NEAR(((int)in_data[i] - (int)qzeros[0]) * float(scale[0]), float(decoded[i]), 1e-2);
+    }
+    free(ins);
+    free(interleaved);
+    free(decoded);
+}
\ No newline at end of file
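
For reference, the EXPECT_NEAR assertion in the new test encodes the dequantization rule that the "quantized" zeros mode is expected to reproduce on the GPU: each 2-bit code q decodes to (q - zero_point) * scale. The snippet below is a minimal host-side sketch of that reference computation; the helper name and the standalone main are illustrative only and are not part of the patch.

```cpp
// Host-side reference for the "quantized" zeros decode path tested above:
// decoded[i] == (q[i] - zero_point) * scale. Helper name is hypothetical.
#include <cstdint>
#include <cstdio>

static void reference_dequant_i2_quantized_zeros(const int8_t *q, float *out, int n,
                                                 unsigned zero_point, float scale)
{
    for (int i = 0; i < n; ++i)
        out[i] = (static_cast<int>(q[i]) - static_cast<int>(zero_point)) * scale;
}

int main()
{
    constexpr int nbits = 2;
    const int8_t q[4] = {0, 1, 2, 3};                     // every possible 2-bit code
    const unsigned zero_point = (1u << (nbits - 1)) - 1;  // == 1, as in the test's qzeros
    const float scale = 1.2f;                             // matches the test's scale
    float out[4];
    reference_dequant_i2_quantized_zeros(q, out, 4, zero_point, scale);
    for (int i = 0; i < 4; ++i)
        std::printf("%d -> %.3f\n", q[i], out[i]);        // expect -1.2, 0.0, 1.2, 2.4
    return 0;
}
```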