# TODO: The following code should be refactored.
class MatMulNTDequantizeEmitter:
    """Emit a TE-defined PrimFunc for ``C[i, j] = sum_k A[i, k] * decode(B)[j, k]``.

    ``B`` is stored packed in ``storage_dtype`` and is decoded element-wise to
    ``in_dtype`` (block name ``"B_decode"``) before the reduction.  Optional
    scaling, zero points, bias and a blocked ("propagated") layout for ``A`` /
    ``B`` are supported.  Call :meth:`emit` to obtain the ``tvm.IRModule``.

    Args:
        M: Rows of ``A``.  A non-``int`` value is replaced by a symbolic
            ``te.var`` named ``"m"``.
        N: Rows of the decoded weight (columns of the output).
        K: Reduction extent.
        in_dtype: dtype of ``A``, of the decoded weight, and of
            ``LUT`` / ``Scale`` / ``Zeros`` / ``Bias``.
        out_dtype: dtype of the result; a cast stage is added only when it
            differs from ``accum_dtype``.
        accum_dtype: dtype both operands are cast to inside the reduction.
        bit: Bits per packed weight element; must be one of 1, 2, 4, 8.
        storage_dtype: dtype of the packed ``B`` buffer.  Its bit width is
            parsed from the digits in the name.
        source_format: One of ``"uint"``, ``"int"``, ``"fp"``, ``"fp_e4m3"``,
            ``"nf"``; selects the decode routine.
        with_scaling: Apply ``Scale[n, k // group_size]`` to decoded weights.
        with_zeros: Apply zero points (only consulted together with
            ``with_scaling``, except for the ``"quantized"`` decode path).
        group_size: Number of ``k`` elements sharing one scale/zero entry;
            ``-1`` means ``K``.
        fast_decoding: Not used by the computation here; only recorded in the
            ``dequantize_info`` function attribute.
        with_bias: Add ``Bias[j]`` to the output.
        zeros_mode: ``"original"``, ``"rescale"`` or ``"quantized"``.
        propagate_a: Layout transform kind for ``A`` (``TransformKind``,
            ``bool``, ``int`` or ``None``).
        propagate_b: Layout transform kind for ``B`` (same accepted types).

    Raises:
        ValueError: If ``bit`` is unsupported or a ``propagate_*`` argument has
            an unsupported type.
    """

    def __init__(
        self,
        M,
        N,
        K,
        in_dtype="float16",
        out_dtype="float16",
        accum_dtype="float16",
        bit=4,
        storage_dtype="int8",
        source_format="uint",
        with_scaling=False,
        with_zeros=False,
        group_size=-1,
        fast_decoding=False,
        with_bias=False,
        zeros_mode="original",
        propagate_a: TransformKind = TransformKind.NonTransform,
        propagate_b: TransformKind = TransformKind.NonTransform,
    ):
        self.M = self._validate_dimension(M, "M")
        self.N = N
        self.K = K
        self.in_dtype = in_dtype
        self.out_dtype = out_dtype
        self.accum_dtype = accum_dtype
        self.bit = bit
        self.storage_dtype = storage_dtype
        self.source_format = source_format
        self.with_scaling = with_scaling
        self.with_zeros = with_zeros
        # -1 means "a single group spanning the whole K axis".
        self.group_size = group_size if group_size != -1 else K
        self.fast_decoding = fast_decoding
        self.with_bias = with_bias
        self.zeros_mode = zeros_mode
        self.propagate_a = self._legalize_transform_kind(propagate_a)
        self.propagate_b = self._legalize_transform_kind(propagate_b)

        self._validate_bit()
        self._validate_layout()

    @staticmethod
    def _validate_dimension(dim, name):
        """Return ``dim`` if it is an ``int``, else a ``te.var`` named ``name.lower()``."""
        if not isinstance(dim, int):
            return tvm.te.var(name.lower())
        return dim

    def _validate_bit(self):
        """Raise ``ValueError`` unless ``self.bit`` is 1, 2, 4 or 8."""
        if self.bit not in [1, 2, 4, 8]:
            raise ValueError(f"Unsupported bit: {self.bit}")

    def _validate_layout(self):
        # TODO: extend the dequantize operators into General Layout
        pass

    def _legalize_group_size(self):
        """Replace a ``group_size`` of ``-1`` with ``K`` (``__init__`` already does this)."""
        if self.group_size == -1:
            self.group_size = self.K

    def _legalize_transform_kind(self, propagate):
        """Normalize a ``propagate_*`` argument to a ``TransformKind``.

        ``None`` and ``False`` map to ``NonTransform``, ``True`` maps to
        ``IntraWarpTransform``, and an ``int`` is converted with
        ``TransformKind(propagate)``.

        Raises:
            ValueError: For any other type.  Previously such values fell
                through and silently became ``None``.
        """
        if propagate is None:
            return TransformKind.NonTransform
        if isinstance(propagate, TransformKind):
            return propagate
        # bool must be tested before int because bool is a subclass of int.
        if isinstance(propagate, bool):
            return (TransformKind.IntraWarpTransform if propagate else TransformKind.NonTransform)
        if isinstance(propagate, int):
            return TransformKind(propagate)
        raise ValueError(f"Unsupported transform kind: {propagate!r}")

    def _storage_nbit(self):
        """Return the bit width parsed from the digits of ``storage_dtype``."""
        return int("".join(c for c in self.storage_dtype if c.isdigit()))

    def _tile_shape(self):
        """Return the ``(rows, cols)`` tile extents of the blocked layout.

        ``(16, 32)`` for ``int8`` / ``e4m3_float8`` / ``e5m2_float8`` inputs,
        ``(16, 16)`` otherwise.
        """
        if self.in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]:
            return 16, 32
        return 16, 16

    def _weight_scaling_factor(self):
        """Return the factor relating ``in_dtype`` columns to packed storage columns.

        It is 1 unless ``0 < bit < bits(in_dtype)``.
        """
        target_dtype = DataType(self.in_dtype)
        if 0 < self.bit < target_dtype.bits:
            return ((target_dtype.bits // self.bit) * DataType(self.storage_dtype).bits //
                    target_dtype.bits)
        return 1

    def _create_placeholders(self):
        """Create every input placeholder the computation may use.

        Returns:
            The tuple ``(A, B, LUT, Scale, Zeros, QZeros, Bias)``.  ``A`` and
            ``B`` are 4-D blocked tensors when the corresponding
            ``propagate_*`` kind is truthy, 2-D otherwise.
        """
        storage_dtype = self.storage_dtype
        storage_nbit = self._storage_nbit()
        in_dtype = self.in_dtype
        bit = self.bit
        tile_rows, tile_cols = self._tile_shape()

        if self.propagate_a:
            A = te.placeholder((self.M // tile_rows, self.K // tile_cols, tile_rows, tile_cols),
                               name="A",
                               dtype=in_dtype)
        else:
            A = te.placeholder((self.M, self.K), name="A", dtype=in_dtype)

        if self.propagate_b:
            scaling_factor = self._weight_scaling_factor()
            # Tile width expressed in packed storage elements.
            qr = tile_cols * bit // storage_nbit
            B = te.placeholder(
                (self.N // tile_rows, (self.K // scaling_factor) // qr, tile_rows, qr),
                name="B",
                dtype=storage_dtype)
        else:
            B = te.placeholder((self.N, self.K // storage_nbit * bit),
                               name="B",
                               dtype=storage_dtype)

        LUT = te.placeholder((1 << bit,), name="LUT", dtype=in_dtype)
        Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=in_dtype)
        Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=in_dtype)
        QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * bit),
                                name="QZeros",
                                dtype=self.storage_dtype)
        Bias = te.placeholder((self.N,), name="Bias", dtype=in_dtype)
        return A, B, LUT, Scale, Zeros, QZeros, Bias

    def _propagate_input(self, tensor, transform_kind=TransformKind.NonTransform, matrix_name="A"):
        """Expose a blocked ``A`` as a 2-D ``(M, K)`` tensor.

        Returns ``tensor`` unchanged for ``NonTransform``.  Otherwise returns a
        compute stage named ``"{matrix_name}_reindex"`` that reads the 4-D
        tensor; for kinds ``>= IntraWarpTransform`` the in-tile indices are
        additionally remapped through the inverse map from
        ``get_propagate_map``.
        """
        if transform_kind == TransformKind.NonTransform:
            return tensor
        tile_rows, tile_cols = self._tile_shape()
        _, inversed_index_map = get_propagate_map(
            trans=False, dtype=self.in_dtype, matrix_name=matrix_name)

        def fcompute(i, j):
            warp_i, warp_j = i % tile_rows, j % tile_cols
            spatial_args = i // tile_rows, j // tile_cols
            if transform_kind >= TransformKind.IntraWarpTransform:
                warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j])
            new_index = (*spatial_args, warp_i, warp_j)
            return tensor[new_index]

        return te.compute(
            (self.M, self.K),
            fcompute,
            name=f"{matrix_name}_reindex",
        )

    def _propagate_weight(self, tensor, transform_kind=TransformKind.NonTransform,
                          matrix_name="B"):
        """Expose a blocked packed ``B`` as a 2-D ``(N, K // storage_nbit * bit)`` tensor.

        Returns ``tensor`` unchanged for ``NonTransform``.  Otherwise the
        inverse index map from ``get_propagate_map`` is rewritten so that its
        last axis is scaled up by the packing factor before mapping and scaled
        back down afterwards, and the result is used to read the 4-D tensor.
        """
        if transform_kind == TransformKind.NonTransform:
            return tensor
        bit = self.bit
        storage_nbit = self._storage_nbit()
        tile_rows, tile_cols = self._tile_shape()

        _, inversed_index_map = get_propagate_map(
            trans=True, dtype=self.in_dtype, matrix_name=matrix_name)
        scaling_factor = self._weight_scaling_factor()
        initial_indices = inversed_index_map.initial_indices
        scaling_final_indices = inversed_index_map.map_indices(
            initial_indices[:-1] + [initial_indices[-1] * scaling_factor])
        scaling_final_indices = scaling_final_indices[:-1] + [
            scaling_final_indices[-1] // scaling_factor
        ]
        inversed_index_map = IndexMap(
            initial_indices,
            scaling_final_indices,
            None,
        )

        # Tile width expressed in packed storage elements.
        qr = tile_cols * bit // storage_nbit

        def fcompute(i, j):
            warp_i, warp_j = i % tile_rows, j % qr
            spatial_args = i // tile_rows, j // qr
            if transform_kind >= TransformKind.IntraWarpTransform:
                warp_i, warp_j = inversed_index_map.map_indices([warp_i, warp_j])
            new_index = (*spatial_args, warp_i, warp_j)
            return tensor[new_index]

        return te.compute(
            (self.N, self.K // storage_nbit * bit),
            fcompute,
            name=f"{matrix_name}_reindex",
        )

    # Backward-compatible alias for the original (misspelled) method name.
    _propagage_weight = _propagate_weight

    def _decode_func(self, B, LUT, Scale, Zeros, QZeros):
        """Build the ``(N, K)`` compute stage ``"B_decode"`` that unpacks ``B``.

        The unpacking routine is chosen from ``source_format`` (or from the
        quantized-zeros path when ``with_zeros`` and
        ``zeros_mode == "quantized"``); scale and zero points are then applied
        according to ``with_scaling``, ``with_zeros`` and ``zeros_mode``.

        Raises:
            ValueError: On an unsupported ``source_format`` or ``zeros_mode``
                (raised while the compute body is being traced).
        """
        bit = self.bit
        in_dtype = self.in_dtype
        storage_dtype = self.storage_dtype
        storage_nbit = self._storage_nbit()
        storage_type = str("".join(c for c in storage_dtype if not c.isdigit()))
        n_float_per_elem = storage_nbit // bit
        group_size = self.group_size
        zeros_mode = self.zeros_mode

        # TODO: Move the decode function into a more general place
        def decode(n, k):
            if self.with_zeros and zeros_mode == "quantized":
                # QZeros has K // group_size rows, so the row index is the
                # group index, not the raw k.
                qzeros_dequantize = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
                    bit,
                    QZeros[k // group_size, n // n_float_per_elem],
                    n % n_float_per_elem,
                    dtype=self.storage_dtype,
                )
                w = _tir_packed_to_unsigned_convert_with_zeros(storage_type, storage_nbit)(
                    bit,
                    B[n, k // n_float_per_elem],
                    k % n_float_per_elem,
                    qzeros_dequantize,
                    dtype=in_dtype,
                )
            elif self.source_format == "uint":
                if bit == 8:
                    # 8-bit values are stored unpacked; a plain cast suffices.
                    w = B[n, k].astype(in_dtype)
                else:
                    w = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
                        bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
            elif self.source_format == "int":
                if bit == 1:
                    w = _tir_packed_int_to_int_convert(storage_type, storage_nbit)(
                        bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
                elif bit == 8:
                    # 8-bit values are stored unpacked; a plain cast suffices.
                    w = B[n, k].astype(in_dtype)
                else:
                    w = _tir_packed_to_signed_convert(storage_type, storage_nbit)(
                        bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
            elif self.source_format == "fp":
                w = _tir_u32_to_f4_to_f16(
                    bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=in_dtype)
            elif self.source_format == "fp_e4m3":
                w = _tir_u8_to_f8_e4m3_to_f16(bit, B[n, k], dtype=in_dtype)
            elif self.source_format == "nf":
                # The packed value is an index into the lookup table.
                index = _tir_packed_to_unsigned_convert(storage_type, storage_nbit)(
                    bit,
                    B[n, k // n_float_per_elem],
                    k % n_float_per_elem,
                    dtype="int32",
                )
                w = LUT[index]
            else:
                raise ValueError(f"Unsupported source_format: {self.source_format}")

            if not self.with_scaling:
                return w

            if not self.with_zeros:
                return w * Scale[n, k // group_size]

            if zeros_mode == "original":
                w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size]
            elif zeros_mode == "rescale":
                w = w * Scale[n, k // group_size] - Zeros[n, k // group_size]
            elif zeros_mode == "quantized":
                # The zero point was already consumed while unpacking.
                w = w * Scale[n, k // group_size]
            else:
                raise ValueError(f"Unsupported zeros_mode: {zeros_mode}")

            return w

        return te.compute((self.N, self.K), decode, name="B_decode")

    def _compute_matmul(self, A, B_decode):
        """Return stage ``"C"``: the ``(M, N)`` reduction over ``k`` in ``accum_dtype``."""
        k = te.reduce_axis((0, self.K), name="k")
        C = te.compute(
            (self.M, self.N),
            lambda i, j: te.sum(
                A[i, k].astype(self.accum_dtype) * B_decode[j, k].astype(self.accum_dtype), axis=k),
            name="C",
        )
        return C

    def _convert_dtype(self, tensor):
        """Cast to ``out_dtype`` in a stage ``"D"``; no stage is added when the dtypes match."""
        if self.accum_dtype != self.out_dtype:
            return te.compute((self.M, self.N),
                              lambda i, j: tensor[i, j].astype(self.out_dtype),
                              name="D")
        return tensor

    def _apply_bias(self, tensor, Bias):
        """Add ``Bias[j]`` in a stage ``"E"`` when ``with_bias`` is set; otherwise pass through."""
        if self.with_bias:
            return te.compute((self.M, self.N), lambda i, j: tensor[i, j] + Bias[j], name="E")
        return tensor

    def emit(self):
        """Assemble the computation and return it as a ``tvm.IRModule``.

        The PrimFunc arguments are ``A, B`` followed by, in this order and only
        when enabled: ``LUT`` (``source_format == "nf"``), ``Scale``, the zero
        points (``QZeros`` in ``"quantized"`` mode, else ``Zeros``), ``Bias``,
        and finally the output tensor.  The function carries a
        ``dequantize_info`` attribute and, when propagation is enabled,
        ``input_transform_kind`` / ``weight_transform_kind``.
        """
        A, B, LUT, Scale, Zeros, QZeros, Bias = self._create_placeholders()
        A_reindex = self._propagate_input(A, self.propagate_a, "A")
        B_reindex = self._propagate_weight(B, self.propagate_b, "B")

        B_decode = self._decode_func(B_reindex, LUT, Scale, Zeros, QZeros)
        C = self._compute_matmul(A_reindex, B_decode)
        D = self._convert_dtype(C)
        last_output = self._apply_bias(D, Bias)

        args = [A, B]
        if self.source_format == "nf":
            args.append(LUT)
        if self.with_scaling:
            args.append(Scale)
        if self.with_zeros:
            args.append(QZeros if self.zeros_mode == "quantized" else Zeros)
        if self.with_bias:
            args.append(Bias)
        args.append(last_output)

        func = te.create_prim_func(args).with_attr(
            "dequantize_info",
            {
                "B_decode": {
                    "decode_block": "B_decode",
                    "fast_decoding": self.fast_decoding,
                    "source_format": {
                        "bits": self.bit,
                        "format": self.source_format,
                    },
                    "storage_dtype": self.storage_dtype,
                    "target_format": self.in_dtype,
                    "with_zeros": self.with_zeros,
                    "zeros_mode": self.zeros_mode,
                    "with_scaling": self.with_scaling,
                    "group_size": self.group_size,
                }
            },
        )
        if self.propagate_a:
            func = func.with_attr("input_transform_kind", self.propagate_a.value)
        if self.propagate_b:
            func = func.with_attr("weight_transform_kind", self.propagate_b.value)
        return tvm.IRModule.from_expr(func)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from bitblas import tvm
import logging
from bitblas import set_log_level

set_log_level(logging.DEBUG)


def check_eual_ref_scripts_with_emitter(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
    bit,
    storage_dtype,
    source_format,
    with_scaling,
    with_zeros,
    group_size,
    fast_decoding,
    with_bias,
    zeros_mode,
    propagate_a,
    propagate_b,
):
    """Assert that the emitter and the reference TE function build the same IR.

    The reference is ``matmul_nt_dequantize_b_propagate_a_propagate_b`` when
    both ``propagate_a`` and ``propagate_b`` are truthy,
    ``matmul_nt_dequantize_b_propagate_b`` when only ``propagate_b`` is, and
    ``matmul_nt_dequantize_b`` otherwise.  The result is compared with
    ``MatMulNTDequantizeEmitter(...).emit()`` through
    ``tvm.ir.assert_structural_equal``, which raises on mismatch.

    NOTE(review): the parameters between ``K`` and ``fast_decoding`` (and the
    matching call arguments) are not visible in the reviewed patch hunks; they
    are reconstructed from the 17-argument positional call sites below and the
    emitter's constructor order -- confirm against the full file.
    """
    # Imported lazily so that merely collecting this module does not import
    # the implementation module.
    from bitblas.ops.impl.matmul_dequantize_impl import (
        MatMulNTDequantizeEmitter,
        matmul_nt_dequantize_b,
        matmul_nt_dequantize_b_propagate_b,
        matmul_nt_dequantize_b_propagate_a_propagate_b,
    )

    if propagate_a and propagate_b:
        func = matmul_nt_dequantize_b_propagate_a_propagate_b
    elif propagate_b:
        func = matmul_nt_dequantize_b_propagate_b
    else:
        func = matmul_nt_dequantize_b

    ref_func = func(
        M,
        N,
        K,
        in_dtype,
        out_dtype,
        accum_dtype,
        bit,
        storage_dtype,
        source_format,
        with_scaling,
        with_zeros,
        group_size,
        fast_decoding,
        with_bias,
        zeros_mode,
    )

    emit_func = MatMulNTDequantizeEmitter(
        M,
        N,
        K,
        in_dtype,
        out_dtype,
        accum_dtype,
        bit,
        storage_dtype,
        source_format,
        with_scaling,
        with_zeros,
        group_size,
        fast_decoding,
        with_bias,
        zeros_mode,
        propagate_a=propagate_a,
        propagate_b=propagate_b,
    ).emit()

    tvm.ir.assert_structural_equal(ref_func, emit_func)


def test_check_eual_ref_scripts_with_emitter():
    """Run the emitter-vs-reference comparison over a set of configurations.

    Each tuple is the full positional argument list of
    ``check_eual_ref_scripts_with_emitter``.  Configurations that were listed
    twice in the original test are executed once.
    """
    cases = [
        (1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1,
         False, False, "original", False, False),
        (16384, 16384, 16384, "float16", "float16", "float16", 4, "int8", "nf", True, False, -1,
         False, False, "original", False, False),
        (1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1,
         False, False, "original", False, False),
        (1, 16384, 16384, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1,
         False, False, "original", False, True),
        (1024, 1024, 1024, "float16", "float16", "float16", 4, "int8", "uint", True, False, -1,
         False, False, "original", True, True),
    ]
    for case in cases:
        check_eual_ref_scripts_with_emitter(*case)


if __name__ == "__main__":
    test_check_eual_ref_scripts_with_emitter()