From 7e330f53caae24d88a5e3a5a605b52f97342b550 Mon Sep 17 00:00:00 2001 From: LeiWang Date: Thu, 25 Apr 2024 09:07:31 -0400 Subject: [PATCH 1/5] Add thefuzz library to requirements.txt and requirements-dev.txt --- requirements-dev.txt | 1 + requirements.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/requirements-dev.txt b/requirements-dev.txt index 2c9828847..4fd416900 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -26,3 +26,4 @@ psutil scipy tornado torch +thefuzz diff --git a/requirements.txt b/requirements.txt index 935da1857..e8257a571 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,3 +19,4 @@ psutil scipy tornado torch +thefuzz From 702c49ac538ba468da7c3f02047c756a6ae06b71 Mon Sep 17 00:00:00 2001 From: LeiWang Date: Thu, 25 Apr 2024 09:16:46 -0400 Subject: [PATCH 2/5] Support quantized zero types for uint2. --- python/bitblas/gpu/intrin/lop3.py | 57 +++++++++++++++++++ .../lop3_type_conversion/fast_decoding.hpp | 39 +++++++++++++ .../lowprecision_to_float16.cu | 56 ++++++++++++++++++ 3 files changed, 152 insertions(+) diff --git a/python/bitblas/gpu/intrin/lop3.py b/python/bitblas/gpu/intrin/lop3.py index 70819362a..7ea0f93f4 100644 --- a/python/bitblas/gpu/intrin/lop3.py +++ b/python/bitblas/gpu/intrin/lop3.py @@ -366,6 +366,47 @@ } """ +decode_i2_to_f16_scale_zeros_quantized = """ +template +__device__ void decode_i2b_to_f16_scale_zeros_quantized(T1 *_i2s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 
0x64016401 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + T4 const zero_r = *zeros; + uint median_num = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r); + + // decode 2 elems at one time. + // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("add.f16x2 %0, %1, %2;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num)); + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} +template +__device__ void decode_i2u_to_f16_scale_zeros_quantized(T1 *_i2u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i2b_to_f16_scale_zeros_quantized(_i2u, B_local_decode, N, scale, zeros); +} +""" + decode_i1_to_f16 = """ template __device__ void decode_i1u_to_f16(T1 *_i1s, T2 *B_local_decode, const int N = 8) @@ -1359,6 +1400,21 @@ def fast_decode_impl( ), ) +LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN = ( + "lop3_fast_decode_u2_to_int8_to_f16_l8_scale_zeros_quantized_") +TensorIntrin.register( + LOP3_FAST_DECODE_UINT2_TO_INT8_TO_FP16_L8_SCALE_ZEROS_QUANTIZED_INTRIN, + *get_fast_decode_intrin( + source_bit=2, + storage_dtype="int8", + target_dtype="float16", + loops_extent=8, + with_scale=True, + with_zeros=True, + zeros_mode="quantized", + ), +) + LOP3_FAST_DECODE_UINT1_TO_INT8_TO_FP16_L8_SCALE_INTRIN = ( "lop3_fast_decode_u1_to_int8_to_f16_l8_scale_") TensorIntrin.register( @@ -1561,6 +1617,7 @@ def get_lop3_intrin_group( 
"i2_to_f16_scale_zeros_rescale": decode_i2_to_f16_scale_zeros_rescale, "i1_to_f16_scale_zeros_rescale": decode_i1_to_f16_scale_zeros_rescale, "i4_to_f16_scale_zeros_quantized": decode_i4_to_f16_scale_zeros_quantized, + "i2_to_f16_scale_zeros_quantized": decode_i2_to_f16_scale_zeros_quantized, "i1_to_i8": decode_i1s_to_i8s, "i2_to_i8": decode_i2s_to_i8s, "i4_to_i8": decode_i4s_to_i8s, diff --git a/testing/cpp/lop3_type_conversion/fast_decoding.hpp b/testing/cpp/lop3_type_conversion/fast_decoding.hpp index 184dfa243..6d5b6335a 100644 --- a/testing/cpp/lop3_type_conversion/fast_decoding.hpp +++ b/testing/cpp/lop3_type_conversion/fast_decoding.hpp @@ -381,6 +381,45 @@ __device__ void decode_i2u_to_f16_scale_zeros_rescale(T1 *_i2u, T2 *B_local_deco decode_i2b_to_f16(_i2u, B_local_decode, N, scale, zeros); } +template +__device__ void decode_i2b_to_f16_scale_zeros_quantized(T1 *_i2s, T2 *B_local_decode, const int N = 8, T3 *scale = nullptr, T4 *zeros = nullptr) +{ + uint *h = reinterpret_cast(B_local_decode); + + static constexpr uint immLut = (0xf0 & 0xcc) | 0xaa; + static constexpr uint BOTTOM_MASK = 0x00030003; + static constexpr uint FP16_TOP_MAGIC_NUM = 0x64006400; + static constexpr uint MEDIAN_NUM = isSigned ? 0x64016401 : 0x64006400; + int16_t const i2s_i16 = *reinterpret_cast(_i2s); + T3 const scale_r = *scale; + uint const packed_scales = __pack_half2(scale_r, scale_r); + T4 const zero_r = *zeros; + uint median_num = ((0xe400 | zero_r) << 16) | (0xe400 | zero_r); + + // decode 2 elems at one time. 
+ // interleave {e15,e13,e11,e9,e7,e5,e3,e1,e14,e12,e10,e8,e6,e4,e2,e0} + // only decode for {x,x,x,x,e7,e5,e3,e1,x,x,x,x,e6,e4,e2,e0} + // otherwise the pointer of _i2s should be moved to + int i2s = (i2s_i16 & 0x00ff); + i2s |= ((i2s_i16 & 0xff00) << 8); + +#pragma unroll + for (int i = 0; i < (N / 2); i++) + { + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(h[i]) + : "r"(i2s >> (2 * i)), "n"(BOTTOM_MASK), "n"(FP16_TOP_MAGIC_NUM), "n"(immLut)); + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(h[i]) : "r"(h[i]), "r"(median_num)); + + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(h[i]) : "r"(h[i]), "r"(packed_scales), "r"(0)); + } +} +template +__device__ void decode_i2u_to_f16_scale_zeros_quantized(T1 *_i2u, T2 *B_local_decode, T3 *scale = nullptr, T4 *zeros = nullptr, const int N = 8) +{ + decode_i2b_to_f16_scale_zeros_quantized(_i2u, B_local_decode, N, scale, zeros); +} + /* Kind 0: original Kind 1: rescale diff --git a/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu b/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu index 0d0ebf7d2..7307ad1fe 100644 --- a/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu +++ b/testing/cpp/lop3_type_conversion/lowprecision_to_float16.cu @@ -46,6 +46,7 @@ REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i4u_to_f16_scale_zeros_rescale, dec REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i2u_to_f16_scale_zeros_rescale, decode_i2u_to_f16_scale_zeros_rescale) REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i1u_to_f16_scale_zeros_rescale, decode_i1u_to_f16_scale_zeros_rescale) REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i4u_to_f16_scale_zeros_quantized, decode_i4u_to_f16_scale_zeros_quantized) +REGISTER_GLOBAL_DEVICE_INVOKER(kernelWrapper_i2u_to_f16_scale_zeros_quantized, decode_i2u_to_f16_scale_zeros_quantized) TEST(DecodeTest, DecodeInt4ToFloat16) { @@ -1076,4 +1077,59 @@ TEST(DecodeTest, DecodeUInt4ToFloat16WithScalingWithZerosQuantized) free(ins); free(interleaved); free(decoded); +} + 
+TEST(DecodeTest, DecodeUInt2toFloat16WithScalingWithZerosQuantized) +{ + constexpr int nbits = 2; + constexpr int N = 32 / nbits; + constexpr int QN = N / 8 * nbits; + constexpr bool isSigned = false; + + // create four int8_t values + int8_t in_data[N] = { + 0}; + half scale[1] = {__float2half(1.2)}; + uint qzeros[1] = {(1 << (nbits - 1)) - 1}; + // breed seed + srand(0); + + // random initializations with nbits range + for (int i = 0; i < N; i++) + { + in_data[i] = (rand() % (1 << nbits)); + } + + int8_t *ins = new int8_t[QN]; + general_compress(in_data, ins, nbits, N, isSigned); + + int8_t *interleaved = new int8_t[QN]; + general_interleave_fp16(ins, interleaved, nbits, QN * sizeof(int8_t), false); + half *decoded = new half[N]; + int8_t *ins_gpu; + half *decoded_gpu, *scale_gpu; + uint *qzeros_gpu; + + cudaCheckLastError(cudaMalloc((void **)&ins_gpu, QN * sizeof(int8_t))); + cudaCheckLastError(cudaMalloc((void **)&decoded_gpu, N * sizeof(half))); + cudaCheckLastError(cudaMalloc((void **)&scale_gpu, 1 * sizeof(half))); + cudaCheckLastError(cudaMalloc((void **)&qzeros_gpu, 1 * sizeof(uint))); + cudaCheckLastError(cudaMemcpy(ins_gpu, interleaved, QN * sizeof(int8_t), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(decoded_gpu, decoded, N * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(scale_gpu, scale, 1 * sizeof(half), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaMemcpy(qzeros_gpu, qzeros, 1 * sizeof(uint), cudaMemcpyHostToDevice)); + cudaCheckLastError(cudaDeviceSynchronize()); + kernelWrapper_i2u_to_f16_scale_zeros_quantized<<>>(ins_gpu, decoded_gpu, scale_gpu, qzeros_gpu); + kernelWrapper_i2u_to_f16_scale_zeros_quantized<<>>(ins_gpu + QN / 2, decoded_gpu + N / 2, scale_gpu, qzeros_gpu); + cudaCheckLastError(cudaDeviceSynchronize()); + cudaCheckLastError(cudaMemcpy(decoded, decoded_gpu, N * sizeof(half), cudaMemcpyDeviceToHost)); + cudaCheckLastError(cudaFree(ins_gpu)); + cudaCheckLastError(cudaFree(decoded_gpu)); + 
for (int i = 0; i < N; i++) + { + EXPECT_NEAR(((int)in_data[i] - (int)qzeros[0]) * float(scale[0]), float(decoded[i]), 1e-2); + } + free(ins); + free(interleaved); + free(decoded); } \ No newline at end of file From 44948ef97f1ce27c68d13d069552dec499887350 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 29 Apr 2024 11:33:32 +0000 Subject: [PATCH 3/5] Support FP8 Codegen --- 3rdparty/tvm | 2 +- python/bitblas/gpu/matmul_analysis.py | 144 ++++++++++++------ python/bitblas/gpu/matmul_mma.py | 6 +- python/bitblas/gpu/matmul_mma_dequantize.py | 9 +- .../bitblas/ops/impl/ladder_permutate_impl.py | 4 +- .../ops/impl/matmul_dequantize_impl.py | 4 +- python/bitblas/ops/impl/matmul_impl.py | 6 +- .../bitblas/ops/impl/param_permutate_impl.py | 2 +- .../relax/transform/weight_only_propagate.py | 3 +- 9 files changed, 122 insertions(+), 58 deletions(-) diff --git a/3rdparty/tvm b/3rdparty/tvm index 780b83017..4afdf9f22 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 780b83017be0b5b12d123834adb07546bc4e6082 +Subproject commit 4afdf9f2274c5c59edf45655fc9c1654db2b0d8b diff --git a/python/bitblas/gpu/matmul_analysis.py b/python/bitblas/gpu/matmul_analysis.py index df50d283c..e01943d30 100644 --- a/python/bitblas/gpu/matmul_analysis.py +++ b/python/bitblas/gpu/matmul_analysis.py @@ -53,7 +53,9 @@ def auto_inline_producers( inlined_cnt = 0 producers = _collect_producers(sch, block) for producer in producers: - if any(sch.get(producer) == sch.get(skip_block) for skip_block in skip_blocks): + if any( + sch.get(producer) == sch.get(skip_block) for skip_block in skip_blocks + ): continue try: sch.compute_inline(producer) @@ -123,7 +125,9 @@ def find_first_similar_buffer(regions: List[BufferRegion], buffer: tir.Buffer): # find the block that required to be reindex and scope. 
-def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> Optional[BlockRV]: +def find_last_producer_from_buffer( + sch, main_block, buffer: tir.Buffer +) -> Optional[BlockRV]: # block that most near to the arguments block = main_block buffer = buffer @@ -140,14 +144,17 @@ def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> Optio for write in sch.get(producer).writes: if write.buffer == buffer: block = producer - buffer = find_first_similar_buffer(sch.get(producer).reads, last_buffer) + buffer = find_first_similar_buffer( + sch.get(producer).reads, last_buffer + ) if buffer == last_buffer: break return block -def find_arg_idx_from_buffer_chain(sch: tir.Schedule, main_block: tir.schedule.BlockRV, - buffer: tir.Buffer) -> int: +def find_arg_idx_from_buffer_chain( + sch: tir.Schedule, main_block: tir.schedule.BlockRV, buffer: tir.Buffer +) -> int: """traverse to find the arg index from the buffer""" producers = sch.get_producers(main_block) @@ -216,7 +223,8 @@ def make_iter_fusion_index_map( fused_iters[trait.kind] = v_i final_indices: List[tir.PrimExpr] = [ - fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order + fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) + for kind in kind_order ] return tir.IndexMap(input_iters, final_indices, None) @@ -289,15 +297,22 @@ def get_access_axes(region: List[Range]) -> Set[Var]: if {x.kind for x in traits.values()}.intersection(gemm_traits) != gemm_traits: return None - A_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in A_axes] - B_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in B_axes] - C_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in C_axes] + A_traits = [ + traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in A_axes + ] + B_traits = [ + traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in B_axes + ] + 
C_traits = [ + traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in C_axes + ] block_traits = [traits[i.var] for i in block.iter_vars] return A_traits, B_traits, C_traits, block_traits -def get_index_map(block: tir.Block, - layout: Optional[List[str]] = None) -> Optional[Tuple[tir.IndexMap, ...]]: +def get_index_map( + block: tir.Block, layout: Optional[List[str]] = None +) -> Optional[Tuple[tir.IndexMap, ...]]: """Get index maps for the block Parameters @@ -369,17 +384,23 @@ def infer_layout(layout: str, region: List[Range], kind: str = "A"): if kind == "C": return [IterKind.kIter_S, primary_iter, secondary_iter] else: - return ([IterKind.kIter_S, spatial_iter, reduction_iter] if check_last_trait(region) - else [IterKind.kIter_S, reduction_iter, spatial_iter]) + return ( + [IterKind.kIter_S, spatial_iter, reduction_iter] + if check_last_trait(region) + else [IterKind.kIter_S, reduction_iter, spatial_iter] + ) else: raise ValueError(f"Unknown layout {layout}") A_index_map = make_iter_fusion_index_map( - A_traits, infer_layout(layout[0], block.reads[0].region, kind="A")) + A_traits, infer_layout(layout[0], block.reads[0].region, kind="A") + ) B_index_map = make_iter_fusion_index_map( - B_traits, infer_layout(layout[1], block.reads[1].region, kind="B")) + B_traits, infer_layout(layout[1], block.reads[1].region, kind="B") + ) C_index_map = make_iter_fusion_index_map( - C_traits, infer_layout(layout[2], block.writes[0].region, kind="C")) + C_traits, infer_layout(layout[2], block.writes[0].region, kind="C") + ) matmul_index_map = make_iter_fusion_index_map( block_traits, @@ -411,10 +432,14 @@ def is_dequantize(block: BlockRV) -> bool: block_stmt = sch.get(block) if len(block_stmt.reads) < 2: return False - has_uint_input = any("uint" in str(region.buffer.dtype) for region in block_stmt.reads) + has_uint_input = any( + "uint" in str(region.buffer.dtype) for region in block_stmt.reads + ) if not has_uint_input: return False - if len(block_stmt.writes) 
!= 1 or "float" not in str(block_stmt.writes[0].buffer.dtype): + if len(block_stmt.writes) != 1 or "float" not in str( + block_stmt.writes[0].buffer.dtype + ): return False return True @@ -439,7 +464,10 @@ def get_access_vars(region: List[Range]) -> List[Var]: axes.extend(undefined_vars(r.min)) # remove trivial axis trivial_vars = set( - iter_var.var for iter_var in block_stmt.iter_vars if _is_one(iter_var.dom.extent)) + iter_var.var + for iter_var in block_stmt.iter_vars + if _is_one(iter_var.dom.extent) + ) axes = [axis for axis in axes if axis not in trivial_vars] # remove duplicate axis axes = [var for i, var in enumerate(axes) if i == 0 or var != axes[i - 1]] @@ -448,8 +476,9 @@ def get_access_vars(region: List[Range]) -> List[Var]: lhs_access_vars = get_access_vars(block_stmt.reads[0].region)[-2:] rhs_access_vars = get_access_vars(block_stmt.writes[0].region)[-2:] is_identity = list(lhs_access_vars) == list(rhs_access_vars) - is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set(lhs_access_vars) == set( - rhs_access_vars) + is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set( + lhs_access_vars + ) == set(rhs_access_vars) return is_identity, is_transpose @@ -477,9 +506,9 @@ def inline_transpose_block(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV] return result_blocks -def normalize_to_matmul(sch: tir.Schedule, - main_block: BlockRV, - layout: Optional[List[str]] = None) -> Optional[tir.Schedule]: +def normalize_to_matmul( + sch: tir.Schedule, main_block: BlockRV, layout: Optional[List[str]] = None +) -> Optional[tir.Schedule]: if layout is None: layout = ["n", "t", "n"] block_stmt = sch.get(main_block) @@ -512,7 +541,9 @@ def get_tensorized_func_and_tags( allow_gemv: bool = False, ) -> Tuple[tir.PrimFunc, Dict[str, Union[List[int], int]]]: from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_wmma_intrin_group,) + get_mma_intrin_group, + ) + """ transform function to matmul if 
necessary (e.g. transform conv2d with im2col) """ @@ -533,8 +564,12 @@ def _can_be_tensorized(sch: tir.Schedule, block: BlockRV) -> bool: conditions.append(len(block_stmt.writes) == 1) conditions.append( len( - collect_block_iter_vars_used_in_access_region(block_stmt, - block_stmt.writes[0].region)) > 0) + collect_block_iter_vars_used_in_access_region( + block_stmt, block_stmt.writes[0].region + ) + ) + > 0 + ) if not all(conditions): return False return True @@ -544,7 +579,9 @@ def check_sm_version(arch: str) -> int: sm_version = arch.replace("sm_", "") return int(sm_version) if sm_version.isdigit() else -1 - def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV, target: Target) -> bool: + def analysis_tensorcore_tags( + sch: tir.Schedule, block: BlockRV, target: Target + ) -> bool: tags: Dict[str, Union[List[int], int]] = {} block_stmt = sch.get(block) @@ -607,14 +644,18 @@ def check_last_trait(region: List[Range]): block_stmt = sch.get(main_block) if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70: + # TODO(lei): we should consider the dtype of the input a and b + # instead of assuming both a and b share the same dtype. + # As the tensorcore may supports e4m3_float8 * e5m2_float8 in_dtype, out_dtype = get_in_out_dtypes(block_stmt) try: - _ = get_wmma_intrin_group( - in_dtype=in_dtype, + _ = get_mma_intrin_group( + a_dtype=in_dtype, + b_dtype=in_dtype, out_dtype=out_dtype, ) except Exception: - logger.debug("Cannot find the corresponding wmma intrin group") + logger.debug("Cannot find the corresponding mma intrin group") return func, None # reindex and transform functions @@ -631,13 +672,16 @@ def check_last_trait(region: List[Range]): minimal_tensorize_threshold = 16 # the batch dimension is not taken into consideration. 
extent = block_stmt.iter_vars[1].dom.extent - if isinstance(extent, - tir.expr.IntImm) and (extent.value < - (1 if allow_gemv else minimal_tensorize_threshold)): + if isinstance(extent, tir.expr.IntImm) and ( + extent.value < (1 if allow_gemv else minimal_tensorize_threshold) + ): return func, None for item_var in block_stmt.iter_vars[2:]: extent = item_var.dom.extent - if (isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold): + if ( + isinstance(extent, tir.expr.IntImm) + and extent.value < minimal_tensorize_threshold + ): return func, None tags = analysis_tensorcore_tags(sch, main_block, target) return sch.mod["main"], tags @@ -645,17 +689,26 @@ def check_last_trait(region: List[Range]): return func, None -def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", index_dtype="int32"): +def get_propagate_map( + trans: bool = True, dtype="float16", matrix_name="A", index_dtype="int32" +): from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - ldmatrix_32x8_to_shared_16x16_layout, ldmatrix_trans_32x8_to_shared_16x16_layout, - ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b, + ldmatrix_32x8_to_shared_16x16_layout, + ldmatrix_trans_32x8_to_shared_16x16_layout, + ldmatrix_32x16_to_shared_16x32_layout_a, + ldmatrix_32x16_to_shared_16x32_layout_b, ) - assert dtype in ["float16", "int8"], "Only support float16 for now" + assert dtype in [ + "float16", + "int8", + "e4m3_float8", + "e5m2_float8", + ], "Only support float16, int8, e4m3_float8, e5m2_float8" if dtype == "float16": ldmatrix_layout = ldmatrix_32x8_to_shared_16x16_layout ldmatrix_layout_trans = ldmatrix_trans_32x8_to_shared_16x16_layout - elif dtype == "int8": + elif dtype in ["int8", "e4m3_float8", "e5m2_float8"]: # int8 mma only support 32x16 to 16x32 layout if matrix_name == "A" and trans is False: ldmatrix_layout = ldmatrix_32x16_to_shared_16x32_layout_a @@ -683,7 +736,9 @@ def 
ldmatrix_permutation_16x32_32x16_32x16(kernel_i, kernel_j): if dtype == "float16": ldmatrix_index_map = ( ldmatrix_trans_permutation_16x16_32x8_16x16 - if trans else ldmatrix_permutation_16x16_32x8_16x16) + if trans + else ldmatrix_permutation_16x16_32x8_16x16 + ) else: ldmatrix_index_map = ldmatrix_permutation_16x32_32x16_32x16 @@ -726,7 +781,9 @@ def layout_propagate_chain( read_indices = [r.min for r in read.region] # reverse index map from [vi // x] -> [vi * x] to match the inconsistent layout tmp_index_map = IndexMap(write_indices, read_indices, None) - tmp_index_map = tmp_index_map.non_surjective_inverse(write.buffer.shape)[0] + tmp_index_map = tmp_index_map.non_surjective_inverse( + write.buffer.shape + )[0] # if dequantize like ops are used, the scaling factor should be considered # to be applied to the final indices @@ -734,7 +791,8 @@ def layout_propagate_chain( for i, j in zip(write.buffer.shape, read.buffer.shape): scaling_factor *= i // j final_indices = list( - index_map.map_indices(tmp_index_map.map_indices(write_indices))) + index_map.map_indices(tmp_index_map.map_indices(write_indices)) + ) final_indices[-1] = final_indices[-1] // scaling_factor index_map = IndexMap( write_indices, diff --git a/python/bitblas/gpu/matmul_mma.py b/python/bitblas/gpu/matmul_mma.py index 5043501d5..a20359e11 100644 --- a/python/bitblas/gpu/matmul_mma.py +++ b/python/bitblas/gpu/matmul_mma.py @@ -303,7 +303,8 @@ def store_output(block_outer, write_buffer_idx): intrin_group = get_mma_intrin_group( load_scope="shared.dyn", store_scope="shared.dyn", - in_dtype=str(dtype_a), + a_dtype=str(dtype_a), + b_dtype=str(dtype_b), out_dtype=str(dtype_c), trans_a=is_transpose_a, trans_b=is_transpose_b, @@ -396,7 +397,8 @@ def check_has_dynamic(func: tir.PrimFunc): intrin_group = get_mma_intrin_group( load_scope=shared_scope, store_scope=shared_scope if cache_write_required else "global", - in_dtype=intrin_info.in_dtype, + a_dtype=intrin_info.in_dtype, + b_dtype=intrin_info.in_dtype, 
out_dtype=intrin_info.out_dtype, trans_a=intrin_info.trans_a, trans_b=intrin_info.trans_b, diff --git a/python/bitblas/gpu/matmul_mma_dequantize.py b/python/bitblas/gpu/matmul_mma_dequantize.py index 79a5eeb2f..e4c5a272a 100644 --- a/python/bitblas/gpu/matmul_mma_dequantize.py +++ b/python/bitblas/gpu/matmul_mma_dequantize.py @@ -167,7 +167,8 @@ def check_weight_decode_info(weight_decode_info): intrin_group = get_mma_intrin_group( load_scope=shared_scope, store_scope=shared_scope if cache_write_required else "global", - in_dtype=intrin_info.in_dtype, + a_dtype=intrin_info.in_dtype, + b_dtype=intrin_info.in_dtype, out_dtype=intrin_info.out_dtype, trans_a=intrin_info.trans_a, trans_b=intrin_info.trans_b, @@ -654,7 +655,8 @@ def check_weight_decode_info(weight_decode_info): intrin_group = get_mma_intrin_group( load_scope=shared_scope, store_scope=shared_scope if cache_write_required else "global", - in_dtype=intrin_info.in_dtype, + a_dtype=intrin_info.in_dtype, + b_dtype=intrin_info.in_dtype, out_dtype=intrin_info.out_dtype, trans_a=intrin_info.trans_a, trans_b=intrin_info.trans_b, @@ -1143,7 +1145,8 @@ def check_weight_decode_info(weight_decode_info): intrin_group = get_mma_intrin_group( load_scope=shared_scope, store_scope=shared_scope if cache_write_required else "global", - in_dtype=intrin_info.in_dtype, + a_dtype=intrin_info.in_dtype, + b_dtype=intrin_info.in_dtype, out_dtype=intrin_info.out_dtype, trans_a=intrin_info.trans_a, trans_b=intrin_info.trans_b, diff --git a/python/bitblas/ops/impl/ladder_permutate_impl.py b/python/bitblas/ops/impl/ladder_permutate_impl.py index 5ac4a5334..8086bf584 100644 --- a/python/bitblas/ops/impl/ladder_permutate_impl.py +++ b/python/bitblas/ops/impl/ladder_permutate_impl.py @@ -9,7 +9,7 @@ def select_implementation( M: int, N: int, - datatype: Literal["float16", "int8"] = "float16", + datatype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"] = "float16", dequantize_bits: int = -1, storage_dtype: Literal["float16", 
"int8", "uint8", "int32", "uint32"] = "float16", propagate_kind: Literal["A", "B"] = "B", @@ -23,7 +23,7 @@ def select_implementation( # This is trick to get the basic tile size for the current datatype # as for nvidia tensorcore instruction, the basic tile size is 16x16/16x32 for float16/int8 l = r = 16 # noqa: E741 - if datatype == "int8": + if datatype in ["int8", "e4m3_float8", "e5m2_float8"]: l, r = 16, 32 # noqa: E741 intra_index_map, _ = get_propagate_map( transpose_matrix, dtype=datatype, matrix_name=propagate_kind) diff --git a/python/bitblas/ops/impl/matmul_dequantize_impl.py b/python/bitblas/ops/impl/matmul_dequantize_impl.py index 28e9ae42b..2aabcbc90 100644 --- a/python/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/python/bitblas/ops/impl/matmul_dequantize_impl.py @@ -187,7 +187,7 @@ def matmul_nt_dequantize_b_propagate_b( M = tvm.te.var("m") l = r = 16 # noqa: E741 - if in_dtype == "int8": + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: l, r = 16, 32 # noqa: E741 _, inverse_indexmap = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") @@ -358,7 +358,7 @@ def matmul_nt_dequantize_b_propagate_a_propagate_b( M = tvm.te.var("m") l = r = 16 # noqa: E741 - if in_dtype == "int8": + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: l, r = 16, 32 # noqa: E741 _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) diff --git a/python/bitblas/ops/impl/matmul_impl.py b/python/bitblas/ops/impl/matmul_impl.py index 26b748a88..69b426354 100644 --- a/python/bitblas/ops/impl/matmul_impl.py +++ b/python/bitblas/ops/impl/matmul_impl.py @@ -111,7 +111,7 @@ def matmul_nt_propagate_a( if not isinstance(M, int): M = tvm.te.var("m") l = r = 16 # noqa: E741 - if in_dtype == "int8": + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: l, r = 16, 32 # noqa: E741 _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, 
matrix_name="A") @@ -171,7 +171,7 @@ def matmul_nt_propagate_b( if not isinstance(M, int): M = tvm.te.var("m") l = r = 16 # noqa: E741 - if in_dtype == "int8": + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: l, r = 16, 32 # noqa: E741 _, inversed_index_map = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") @@ -232,7 +232,7 @@ def matmul_nt_propagate_a_propagate_b( if not isinstance(M, int): M = tvm.te.var("m") l = r = 16 # noqa: E741 - if in_dtype == "int8": + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: l, r = 16, 32 # noqa: E741 A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) diff --git a/python/bitblas/ops/impl/param_permutate_impl.py b/python/bitblas/ops/impl/param_permutate_impl.py index 620212eb6..4ecb17709 100644 --- a/python/bitblas/ops/impl/param_permutate_impl.py +++ b/python/bitblas/ops/impl/param_permutate_impl.py @@ -24,7 +24,7 @@ def select_implementation( # This is trick to get the basic tile size for the current datatype # as for nvidia tensorcore instruction, the basic tile size is 16x16/16x32 for float16/int8 l = r = 16 # noqa: E741 - if datatype == "int8": + if datatype in ["int8", "e4m3_float8", "e5m2_float8"]: l, r = 16, 32 # noqa: E741 if group_size == -1: group_size = N diff --git a/python/bitblas/relax/transform/weight_only_propagate.py b/python/bitblas/relax/transform/weight_only_propagate.py index 8240a0fd8..709e02085 100644 --- a/python/bitblas/relax/transform/weight_only_propagate.py +++ b/python/bitblas/relax/transform/weight_only_propagate.py @@ -130,7 +130,8 @@ def transform_matmul(self, g_var: GlobalVar, func: tir.PrimFunc, intrin_info): intrin_group = get_mma_intrin_group( load_scope="shared", store_scope="shared", - in_dtype=intrin_info["in_dtype"], + a_dtype=intrin_info["in_dtype"], + b_dtype=intrin_info["in_dtype"], out_dtype=intrin_info["out_dtype"], trans_a=False, trans_b=intrin_info["trans_b"], From 3b758fbc837f83b5777bb26697e7d08b47294565 Mon Sep 17 00:00:00 2001 From: 
LeiWang1999 Date: Mon, 29 Apr 2024 19:12:04 +0000 Subject: [PATCH 4/5] Add support for e4m3_float8 and e5m2_float8 types in CUDA wrapper --- python/bitblas/wrapper/general.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/bitblas/wrapper/general.py b/python/bitblas/wrapper/general.py index f97405716..651030c26 100644 --- a/python/bitblas/wrapper/general.py +++ b/python/bitblas/wrapper/general.py @@ -21,6 +21,8 @@ "float32": "float", "float16": "half", "bfloat16": "__nv_bfloat162", + "e4m3_float8": "__nv_fp8_e4m3", + "e5m2_float8": "__nv_fp8_e5m2", "float64": "double", "int64": "int64_t", "int32": "int", @@ -247,7 +249,7 @@ def update_lib_code(self, code: str): for dyn_sym in dynamic_symbolic_set: function_args.append({"name": dyn_sym, "type": "int"}) - function_args.append({"name": "stream=0", "type": "cudaStream_t"},) + function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},) # Format the function arguments for declaration def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) @@ -285,7 +287,7 @@ def legalize_c(p): # Determine the shared memory size, defaulting to 0 if not specified smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf # Format the CUDA kernel launch string - call_str = "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str, + call_str = "{}<<<{}, {}, {}, 0>>>({});".format(function_name, grid_str, block_str, smem_str, call_args) # Create the host function wrapper for the CUDA kernel host_func = """ @@ -352,7 +354,7 @@ def create_dispatch_func(self, code, function_informations): for dyn_sym in dynamic_symbolic_set: function_args.append({"name": dyn_sym, "type": "int"}) - function_args.append({"name": "stream=0", "type": "cudaStream_t"},) + function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},) # Format the argument definitions for function declaration def_args = ", ".join([f"{arg['type']} 
{arg['name']}" for arg in function_args]) From a62682f67bd46ffc6b5fee04697d1de0080bd33a Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Mon, 29 Apr 2024 19:21:33 +0000 Subject: [PATCH 5/5] Support FP8 --- 3rdparty/tvm | 2 +- README.md | 13 ++ python/bitblas/base/utils.py | 17 ++- python/bitblas/gpu/gemv.py | 7 +- python/bitblas/gpu/matmul_analysis.py | 125 ++++++------------ python/bitblas/ops/general_matmul.py | 36 ++++- .../ops/impl/matmul_dequantize_impl.py | 2 - python/bitblas/ops/operator.py | 14 +- python/bitblas/wrapper/general.py | 4 +- .../operators/test_general_matmul_fp8.py | 77 +++++++++++ 10 files changed, 193 insertions(+), 104 deletions(-) create mode 100644 testing/python/operators/test_general_matmul_fp8.py diff --git a/3rdparty/tvm b/3rdparty/tvm index 4afdf9f22..a9b770a85 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 4afdf9f2274c5c59edf45655fc9c1654db2b0d8b +Subproject commit a9b770a85d2b856424a2b4c71d870e3f1af90396 diff --git a/README.md b/README.md index 73796371b..fcfcc2956 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,17 @@ Some of the key features of BitBLAS include: - BitBLAS first implemented $W_{INT2}A_{INT8}$ GEMV/GEMM in [BitNet-b1.58](https://arxiv.org/abs/2402.17764) with 8x/2x speedup over cuBLAS $W_{FP16}A_{FP16}$ on A100, please checkout [op_benchmark_a100_int2_scaling](https://github.com/microsoft/BitBLAS/blob/main/images/figures/op_benchmark_a100_int2_scaling.png) for detailed benchmark results. Please checkout [BitNet-b1.58 integration](https://github.com/microsoft/BitBLAS/blob/main/integration/BitNet) for the integration with the 3rdparty reproduced BitNet-b1.58 model. - Support customizing mixed-precision DNN operations for your specific scenarios via the flexible DSL (TIR Script). +## Latest News + +- 2024.04.19: BitBLAS is now open source! 
We are excited to announce that BitBLAS, a high-performance library for mixed-precision DNN model deployment, is now available to the public. +- 2024.04.30: BitBLAS now support + +## Integration Example of FasterTransformer with BitBLAS +![FasterTransformer Integration](images/gif/FasterTransformer.gif) + +## Benchmark Summary + + ## Integration Example of FasterTransformer with BitBLAS ![FasterTransformer Integration](images/gif/FasterTransformer.gif) @@ -63,6 +74,8 @@ For more detailed information on benchmark sets with other formats (NF4/FP4) and | INT8 | UINT4/INT4 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | | INT8 | UINT2/INT2 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | | INT8 | UINT1 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| FP8_E4M3 | FP8_E4M3 | FP32 | FP32/FP16 | **√** | RTX 4090(SM_89) | +| FP8_E5M2 | FP8_E5M2 | FP32 | FP32/FP16 | **√** | RTX 4090(SM_89) | We are continuously expanding the support matrix. If you have any specific requirements, please feel free to open an issue or PR. 
diff --git a/python/bitblas/base/utils.py b/python/bitblas/base/utils.py index da7b66ad8..0e51ef57b 100644 --- a/python/bitblas/base/utils.py +++ b/python/bitblas/base/utils.py @@ -135,16 +135,27 @@ def var_wrapper(v): else: raise ValueError("Not supported type: ", type(func)) + def map_numpy_type(intype): + typemap = { + 'e4m3_float8': 'float8_e4m3fn', + 'e5m2_float8': 'float8_e5m2', + } + if intype in typemap: + return typemap[intype] + else: + return intype + + numpy_dtype = map_numpy_type(arg.dtype) if distribution == "uniform": profile_tensors.append( tvm.nd.array( - np.random.rand(*[var_wrapper(i) for i in arg.shape]).astype(arg.dtype), + np.random.rand(*[var_wrapper(i) for i in arg.shape]).astype(numpy_dtype), device=device, )) elif distribution == "onefill": profile_tensors.append( tvm.nd.array( - np.ones([var_wrapper(i) for i in arg.shape]).astype(arg.dtype), + np.ones([var_wrapper(i) for i in arg.shape]).astype(numpy_dtype), device=device, )) else: @@ -245,7 +256,7 @@ def tvm_callback_cuda_postproc(code, _): try: latency = cpresult.profile() except Exception as e_mesg: - logger.debug("Evaluation with config failed: ", e_mesg) + logger.debug(f"Evaluation with config failed {e_mesg}") continue logger.info("Evaluation with config {}".format(config)) logger.info("Time cost of this config: {:.3f} ms".format(latency)) diff --git a/python/bitblas/gpu/gemv.py b/python/bitblas/gpu/gemv.py index 33388bffe..7a2880ed1 100644 --- a/python/bitblas/gpu/gemv.py +++ b/python/bitblas/gpu/gemv.py @@ -64,10 +64,9 @@ def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV): def get_bytes(dtype: Union[DataType, str]) -> int: - num = re.findall(r"\d+", dtype) - if len(num) != 1: - raise ValueError(f"Cannot get bytes from {dtype}") - return int(num[0]) // 8 + if isinstance(dtype, str): + dtype = DataType(dtype) + return int(dtype.bits) // 8 def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]: diff --git 
a/python/bitblas/gpu/matmul_analysis.py b/python/bitblas/gpu/matmul_analysis.py index e01943d30..2fa9c16a4 100644 --- a/python/bitblas/gpu/matmul_analysis.py +++ b/python/bitblas/gpu/matmul_analysis.py @@ -53,9 +53,7 @@ def auto_inline_producers( inlined_cnt = 0 producers = _collect_producers(sch, block) for producer in producers: - if any( - sch.get(producer) == sch.get(skip_block) for skip_block in skip_blocks - ): + if any(sch.get(producer) == sch.get(skip_block) for skip_block in skip_blocks): continue try: sch.compute_inline(producer) @@ -125,9 +123,7 @@ def find_first_similar_buffer(regions: List[BufferRegion], buffer: tir.Buffer): # find the block that required to be reindex and scope. -def find_last_producer_from_buffer( - sch, main_block, buffer: tir.Buffer -) -> Optional[BlockRV]: +def find_last_producer_from_buffer(sch, main_block, buffer: tir.Buffer) -> Optional[BlockRV]: # block that most near to the arguments block = main_block buffer = buffer @@ -144,17 +140,14 @@ def find_last_producer_from_buffer( for write in sch.get(producer).writes: if write.buffer == buffer: block = producer - buffer = find_first_similar_buffer( - sch.get(producer).reads, last_buffer - ) + buffer = find_first_similar_buffer(sch.get(producer).reads, last_buffer) if buffer == last_buffer: break return block -def find_arg_idx_from_buffer_chain( - sch: tir.Schedule, main_block: tir.schedule.BlockRV, buffer: tir.Buffer -) -> int: +def find_arg_idx_from_buffer_chain(sch: tir.Schedule, main_block: tir.schedule.BlockRV, + buffer: tir.Buffer) -> int: """traverse to find the arg index from the buffer""" producers = sch.get_producers(main_block) @@ -223,8 +216,7 @@ def make_iter_fusion_index_map( fused_iters[trait.kind] = v_i final_indices: List[tir.PrimExpr] = [ - fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) - for kind in kind_order + fused_iters.get(kind, tir.IntImm(traits[0].extent.dtype, 0)) for kind in kind_order ] return tir.IndexMap(input_iters, final_indices, None) 
@@ -297,22 +289,15 @@ def get_access_axes(region: List[Range]) -> Set[Var]: if {x.kind for x in traits.values()}.intersection(gemm_traits) != gemm_traits: return None - A_traits = [ - traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in A_axes - ] - B_traits = [ - traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in B_axes - ] - C_traits = [ - traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in C_axes - ] + A_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in A_axes] + B_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in B_axes] + C_traits = [traits[iter_var.var] for iter_var in block.iter_vars if iter_var.var in C_axes] block_traits = [traits[i.var] for i in block.iter_vars] return A_traits, B_traits, C_traits, block_traits -def get_index_map( - block: tir.Block, layout: Optional[List[str]] = None -) -> Optional[Tuple[tir.IndexMap, ...]]: +def get_index_map(block: tir.Block, + layout: Optional[List[str]] = None) -> Optional[Tuple[tir.IndexMap, ...]]: """Get index maps for the block Parameters @@ -384,23 +369,17 @@ def infer_layout(layout: str, region: List[Range], kind: str = "A"): if kind == "C": return [IterKind.kIter_S, primary_iter, secondary_iter] else: - return ( - [IterKind.kIter_S, spatial_iter, reduction_iter] - if check_last_trait(region) - else [IterKind.kIter_S, reduction_iter, spatial_iter] - ) + return ([IterKind.kIter_S, spatial_iter, reduction_iter] if check_last_trait(region) + else [IterKind.kIter_S, reduction_iter, spatial_iter]) else: raise ValueError(f"Unknown layout {layout}") A_index_map = make_iter_fusion_index_map( - A_traits, infer_layout(layout[0], block.reads[0].region, kind="A") - ) + A_traits, infer_layout(layout[0], block.reads[0].region, kind="A")) B_index_map = make_iter_fusion_index_map( - B_traits, infer_layout(layout[1], block.reads[1].region, kind="B") - ) + B_traits, infer_layout(layout[1], 
block.reads[1].region, kind="B")) C_index_map = make_iter_fusion_index_map( - C_traits, infer_layout(layout[2], block.writes[0].region, kind="C") - ) + C_traits, infer_layout(layout[2], block.writes[0].region, kind="C")) matmul_index_map = make_iter_fusion_index_map( block_traits, @@ -432,14 +411,10 @@ def is_dequantize(block: BlockRV) -> bool: block_stmt = sch.get(block) if len(block_stmt.reads) < 2: return False - has_uint_input = any( - "uint" in str(region.buffer.dtype) for region in block_stmt.reads - ) + has_uint_input = any("uint" in str(region.buffer.dtype) for region in block_stmt.reads) if not has_uint_input: return False - if len(block_stmt.writes) != 1 or "float" not in str( - block_stmt.writes[0].buffer.dtype - ): + if len(block_stmt.writes) != 1 or "float" not in str(block_stmt.writes[0].buffer.dtype): return False return True @@ -464,10 +439,7 @@ def get_access_vars(region: List[Range]) -> List[Var]: axes.extend(undefined_vars(r.min)) # remove trivial axis trivial_vars = set( - iter_var.var - for iter_var in block_stmt.iter_vars - if _is_one(iter_var.dom.extent) - ) + iter_var.var for iter_var in block_stmt.iter_vars if _is_one(iter_var.dom.extent)) axes = [axis for axis in axes if axis not in trivial_vars] # remove duplicate axis axes = [var for i, var in enumerate(axes) if i == 0 or var != axes[i - 1]] @@ -476,9 +448,8 @@ def get_access_vars(region: List[Range]) -> List[Var]: lhs_access_vars = get_access_vars(block_stmt.reads[0].region)[-2:] rhs_access_vars = get_access_vars(block_stmt.writes[0].region)[-2:] is_identity = list(lhs_access_vars) == list(rhs_access_vars) - is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set( - lhs_access_vars - ) == set(rhs_access_vars) + is_transpose = list(lhs_access_vars) != list(rhs_access_vars) and set(lhs_access_vars) == set( + rhs_access_vars) return is_identity, is_transpose @@ -506,9 +477,9 @@ def inline_transpose_block(sch: tir.Schedule, blocks: List[tir.schedule.BlockRV] return 
result_blocks -def normalize_to_matmul( - sch: tir.Schedule, main_block: BlockRV, layout: Optional[List[str]] = None -) -> Optional[tir.Schedule]: +def normalize_to_matmul(sch: tir.Schedule, + main_block: BlockRV, + layout: Optional[List[str]] = None) -> Optional[tir.Schedule]: if layout is None: layout = ["n", "t", "n"] block_stmt = sch.get(main_block) @@ -541,9 +512,7 @@ def get_tensorized_func_and_tags( allow_gemv: bool = False, ) -> Tuple[tir.PrimFunc, Dict[str, Union[List[int], int]]]: from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_mma_intrin_group, - ) - + get_mma_intrin_group,) """ transform function to matmul if necessary (e.g. transform conv2d with im2col) """ @@ -564,12 +533,8 @@ def _can_be_tensorized(sch: tir.Schedule, block: BlockRV) -> bool: conditions.append(len(block_stmt.writes) == 1) conditions.append( len( - collect_block_iter_vars_used_in_access_region( - block_stmt, block_stmt.writes[0].region - ) - ) - > 0 - ) + collect_block_iter_vars_used_in_access_region(block_stmt, + block_stmt.writes[0].region)) > 0) if not all(conditions): return False return True @@ -579,9 +544,7 @@ def check_sm_version(arch: str) -> int: sm_version = arch.replace("sm_", "") return int(sm_version) if sm_version.isdigit() else -1 - def analysis_tensorcore_tags( - sch: tir.Schedule, block: BlockRV, target: Target - ) -> bool: + def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV, target: Target) -> bool: tags: Dict[str, Union[List[int], int]] = {} block_stmt = sch.get(block) @@ -672,16 +635,13 @@ def check_last_trait(region: List[Range]): minimal_tensorize_threshold = 16 # the batch dimension is not taken into consideration. 
extent = block_stmt.iter_vars[1].dom.extent - if isinstance(extent, tir.expr.IntImm) and ( - extent.value < (1 if allow_gemv else minimal_tensorize_threshold) - ): + if isinstance(extent, + tir.expr.IntImm) and (extent.value < + (1 if allow_gemv else minimal_tensorize_threshold)): return func, None for item_var in block_stmt.iter_vars[2:]: extent = item_var.dom.extent - if ( - isinstance(extent, tir.expr.IntImm) - and extent.value < minimal_tensorize_threshold - ): + if (isinstance(extent, tir.expr.IntImm) and extent.value < minimal_tensorize_threshold): return func, None tags = analysis_tensorcore_tags(sch, main_block, target) return sch.mod["main"], tags @@ -689,14 +649,10 @@ def check_last_trait(region: List[Range]): return func, None -def get_propagate_map( - trans: bool = True, dtype="float16", matrix_name="A", index_dtype="int32" -): +def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", index_dtype="int32"): from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - ldmatrix_32x8_to_shared_16x16_layout, - ldmatrix_trans_32x8_to_shared_16x16_layout, - ldmatrix_32x16_to_shared_16x32_layout_a, - ldmatrix_32x16_to_shared_16x32_layout_b, + ldmatrix_32x8_to_shared_16x16_layout, ldmatrix_trans_32x8_to_shared_16x16_layout, + ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b, ) assert dtype in [ @@ -736,9 +692,7 @@ def ldmatrix_permutation_16x32_32x16_32x16(kernel_i, kernel_j): if dtype == "float16": ldmatrix_index_map = ( ldmatrix_trans_permutation_16x16_32x8_16x16 - if trans - else ldmatrix_permutation_16x16_32x8_16x16 - ) + if trans else ldmatrix_permutation_16x16_32x8_16x16) else: ldmatrix_index_map = ldmatrix_permutation_16x32_32x16_32x16 @@ -781,9 +735,7 @@ def layout_propagate_chain( read_indices = [r.min for r in read.region] # reverse index map from [vi // x] -> [vi * x] to match the inconsistent layout tmp_index_map = IndexMap(write_indices, read_indices, None) - tmp_index_map = 
tmp_index_map.non_surjective_inverse( - write.buffer.shape - )[0] + tmp_index_map = tmp_index_map.non_surjective_inverse(write.buffer.shape)[0] # if dequantize like ops are used, the scaling factor should be considered # to be applied to the final indices @@ -791,8 +743,7 @@ def layout_propagate_chain( for i, j in zip(write.buffer.shape, read.buffer.shape): scaling_factor *= i // j final_indices = list( - index_map.map_indices(tmp_index_map.map_indices(write_indices)) - ) + index_map.map_indices(tmp_index_map.map_indices(write_indices))) final_indices[-1] = final_indices[-1] // scaling_factor index_map = IndexMap( write_indices, diff --git a/python/bitblas/ops/general_matmul.py b/python/bitblas/ops/general_matmul.py index aed148df8..4a48fb901 100644 --- a/python/bitblas/ops/general_matmul.py +++ b/python/bitblas/ops/general_matmul.py @@ -23,6 +23,24 @@ WORKSPACE_SIZE = 1024 * 1024 * 256 +# TODO(lei): This should be improved into a general +# Method to get the consistent compute patterns. +NATIVE_COMPUTE_PATTERNS = [ + # A_dtype, W_dtype + ("float64", "float64"), + ("float32", "float32"), + ("float16", "float16"), + ("int8", "int8"), + ("e4m3_float8", "e4m3_float8"), + ("e4m3_float8", "e5m2_float8"), + ("e5m2_float8", "e4m3_float8"), + ("e5m2_float8", "e5m2_float8"), +] + + +def is_native_compute(A_dtype, W_dtype) -> bool: + return (A_dtype, W_dtype) in NATIVE_COMPUTE_PATTERNS + class OPExecutorCPU: @@ -150,8 +168,15 @@ def __post_init__(self): if self.with_zeros is None: object.__setattr__(self, "with_zeros", False) - if self.A_dtype == self.W_dtype and self.W_dtype in ["float16", "int8"]: + if self.A_dtype == self.W_dtype and self.W_dtype in [ + "float16", "int8", "e4m3_float8", "e5m2_float8" + ]: object.__setattr__(self, "storage_dtype", self.W_dtype) + # TODO(lei): This is a limitation arose by pytorch + # Should be removed in the future. 
+ if self.A_dtype in ["e4m3_float8", "e5m2_float8"]: + object.__setattr__(self, "propagate_a", TransformKind.NonTransform) + object.__setattr__(self, "propagate_b", TransformKind.NonTransform) class Matmul(Operator): @@ -176,6 +201,8 @@ class Matmul(Operator): "nf4": ("nf", 4), "fp8_e5m2": ("fp", 8), "fp4_e2m1": ("fp", 4), + "e4m3_float8": ("fp", 8), # "e4m3_float8" is a trick for "float8_e4m3fn" + "e5m2_float8": ("fp", 8), } def __init__( @@ -316,7 +343,7 @@ def _build_default_module(self, target: Target): self._build_runtime_module(target) def _select_implementation(self): - if self.A_dtype == self.W_dtype: + if is_native_compute(self.A_dtype, self.W_dtype): return consistent_implementation( M=self.M, N=self.N, @@ -446,8 +473,9 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: args.append(bias) args.append(output) - m = reduce(operator.mul, A.shape[:-1], 1) - args.append(m) + if self.dynamic_range is not None: + m = reduce(operator.mul, A.shape[:-1], 1) + args.append(m) if self.lib is None: self._forward_from_torch_func(*args) diff --git a/python/bitblas/ops/impl/matmul_dequantize_impl.py b/python/bitblas/ops/impl/matmul_dequantize_impl.py index 2aabcbc90..b7c7c64ee 100644 --- a/python/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/python/bitblas/ops/impl/matmul_dequantize_impl.py @@ -97,8 +97,6 @@ def decode_func(n, k): else: raise ValueError("Unsupported source_format: {}".format(source_format)) - - if not with_scaling: return w diff --git a/python/bitblas/ops/operator.py b/python/bitblas/ops/operator.py index 0290e2e28..224726b6d 100644 --- a/python/bitblas/ops/operator.py +++ b/python/bitblas/ops/operator.py @@ -220,15 +220,27 @@ def var_warpper(v): else: raise RuntimeError("Not supported type: ", type(v)) + def map_numpy_type(intype): + typemap = { + 'e4m3_float8': 'float8_e4m3fn', + 'e5m2_float8': 'float8_e5m2', + } + if intype in typemap: + return typemap[intype] + else: + return intype + profile_tensors = [] for param in 
func.params: if param not in func.buffer_map: # in case of dynamic symbolic may in params continue arg = func.buffer_map[param] + numpy_dtype = map_numpy_type(arg.dtype) profile_tensors.append( tvm.nd.array( - np.random.uniform(0, 1, [var_warpper(i) for i in arg.shape]).astype(arg.dtype), + np.random.uniform(0, 1, + [var_warpper(i) for i in arg.shape]).astype(numpy_dtype), device=device, )) self.profile_tensors = profile_tensors diff --git a/python/bitblas/wrapper/general.py b/python/bitblas/wrapper/general.py index 651030c26..fa2e7405c 100644 --- a/python/bitblas/wrapper/general.py +++ b/python/bitblas/wrapper/general.py @@ -287,8 +287,8 @@ def legalize_c(p): # Determine the shared memory size, defaulting to 0 if not specified smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf # Format the CUDA kernel launch string - call_str = "{}<<<{}, {}, {}, 0>>>({});".format(function_name, grid_str, block_str, - smem_str, call_args) + call_str = "{}<<<{}, {}, {}, 0>>>({});".format(function_name, grid_str, block_str, smem_str, + call_args) # Create the host function wrapper for the CUDA kernel host_func = """ extern "C" void call({}) {{ diff --git a/testing/python/operators/test_general_matmul_fp8.py b/testing/python/operators/test_general_matmul_fp8.py new file mode 100644 index 000000000..720e8fb32 --- /dev/null +++ b/testing/python/operators/test_general_matmul_fp8.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+import pytest +import bitblas +from bitblas import MatmulConfig, Matmul +import logging +from bitblas import set_log_level + +set_log_level(logging.DEBUG) + + +@pytest.mark.parametrize( + "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", + [ + (1, 1024, 1024, "e4m3_float8", "e4m3_float8", "float32", "float32", "nt", None, None, None, None, + None), + (1024, 1024, 1024, "e4m3_float8", "e4m3_float8", "float32", "float32", "nt", None, None, None, None, + None), + (1, 1024, 1024, "e5m2_float8", "e5m2_float8", "float32", "float32", "nt", None, None, None, None, + None), + (1024, 1024, 1024, "e5m2_float8", "e5m2_float8", "float32", "float32", "nt", None, None, None, None, + None), + ], +) +def test_matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, + group_size, with_scaling, with_zeros, zeros_mode): + import torch + torch.random.manual_seed(0) + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + layout=layout, + with_bias=with_bias, + group_size=group_size, + with_scaling=with_scaling, + with_zeros=with_zeros, + zeros_mode=zeros_mode, + ) + matmul = Matmul(config=matmul_config, enable_tuning=True) + + input_shape = (M, K) + weight_shape = (N, K) if layout == "nt" else (K, N) + + def map_torch_type(intype): + + typemap = { + 'e4m3_float8': torch.float8_e4m3fn, + 'e5m2_float8': torch.float8_e5m2, + } + if intype in typemap: + return typemap[intype] + else: + return getattr(torch, intype) + + numpytype_a = map_torch_type(A_dtype) + numpytype_b = map_torch_type(W_dtype) + numpytype_c = map_torch_type(out_dtype) + + torch_a = torch.rand(M*K).uniform_(-5, 5).reshape(input_shape).type(numpytype_a).cuda() + torch_b = torch.rand(N*K).uniform_(-5, 5).reshape(weight_shape).type(numpytype_b).cuda() + ref_out = torch.matmul(torch_a.to(torch.float32), torch_b.t().to(torch.float32)) if layout == 
"nt" else torch.matmul(torch_a.to(torch.float32), torch_b.to(torch.float32)) + ref_out = ref_out.to(numpytype_c) + + print("torch_ref_out", ref_out) + new_torch_b = matmul.transform_weight(torch_b) + bitblas_out = matmul(torch_a, new_torch_b) + print("bitblas_out", bitblas_out) + +# fmt: on +if __name__ == "__main__": + bitblas.testing.main()