diff --git a/bitblas/ops/impl/matmul_dequantize_impl.py b/bitblas/ops/impl/matmul_dequantize_impl.py
index 1ef14100d..55d672097 100644
--- a/bitblas/ops/impl/matmul_dequantize_impl.py
+++ b/bitblas/ops/impl/matmul_dequantize_impl.py
@@ -15,204 +15,6 @@
     _tir_packed_to_unsigned_convert_with_zeros,
 )
 
-# TODO: The following code should be refactored.
-class MatMulNTDequantizeEmitter:
-    def __init__(
-        self,
-        M,
-        N,
-        K,
-        in_dtype="float16",
-        out_dtype="float16",
-        accum_dtype="float16",
-        bit=4,
-        storage_dtype="int8",
-        source_format="uint",
-        with_scaling=False,
-        with_zeros=False,
-        group_size=-1,
-        fast_decoding=False,
-        with_bias=False,
-        zeros_mode="original",
-        propagate_a: TransformKind = TransformKind.NonTransform,
-        propagate_b: TransformKind = TransformKind.NonTransform,
-    ):
-        self.M = self._validate_dimension(M, "M")
-        self.N = N
-        self.K = K
-        self.in_dtype = in_dtype
-        self.out_dtype = out_dtype
-        self.accum_dtype = accum_dtype
-        self.bit = bit
-        self.storage_dtype = storage_dtype
-        self.source_format = source_format
-        self.with_scaling = with_scaling
-        self.with_zeros = with_zeros
-        self.group_size = group_size if group_size != -1 else K
-        self.fast_decoding = fast_decoding
-        self.with_bias = with_bias
-        self.zeros_mode = zeros_mode
-        self.propagate_a = propagate_a
-        self.propagate_b = propagate_b
-
-        self._validate_bit()
-        self._validate_layout()
-
-    @staticmethod
-    def _validate_dimension(dim, name):
-        if not isinstance(dim, int):
-            return tvm.te.var(name.lower())
-        return dim
-
-    def _validate_bit(self):
-        if self.bit not in [1, 2, 4, 8]:
-            raise ValueError(f"Unsupported bit: {self.bit}")
-
-    def _validate_layout(self):
-        if self.layout not in ["nt"]:
-            raise ValueError(f"Unsupported layout: {self.layout}")
-
-    def _create_placeholders(self):
-        storage_nbit = int("".join(c for c in self.storage_dtype if c.isdigit()))
-        n_float_per_elem = storage_nbit // self.bit
-
-        A = te.placeholder((self.M, self.K), name="A", dtype=self.in_dtype)
-        B = te.placeholder((self.N, self.K // storage_nbit * self.bit), name="B", dtype=self.storage_dtype)
-        LUT = te.placeholder((1 << self.bit,), name="LUT", dtype=self.in_dtype)
-        Scale = te.placeholder((self.N, self.K // self.group_size), name="Scale", dtype=self.in_dtype)
-        Zeros = te.placeholder((self.N, self.K // self.group_size), name="Zeros", dtype=self.in_dtype)
-        QZeros = te.placeholder(((self.K // self.group_size), self.N // storage_nbit * self.bit),
-                                name="QZeros",
-                                dtype=self.storage_dtype)
-        Bias = te.placeholder((self.N,), name="Bias", dtype=self.in_dtype)
-        return A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem
-
-    def _decode_func(self, B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem):
-        w = None
-        def decode(n, k):
-            if self.with_zeros and self.zeros_mode == "quantized":
-                qzeros_dequantize = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)(
-                    self.bit,
-                    QZeros[k, n // n_float_per_elem],
-                    n % n_float_per_elem,
-                    dtype=self.storage_dtype,
-                )
-                w = _tir_packed_to_unsigned_convert_with_zeros(self.storage_dtype, storage_nbit)(
-                    self.bit,
-                    B[n, k // n_float_per_elem],
-                    k % n_float_per_elem,
-                    qzeros_dequantize,
-                    dtype=self.in_dtype,
-                )
-            elif self.source_format == "uint":
-                if self.bit == 8:
-                    w = B[n, k].astype(self.in_dtype)
-                w = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)(
-                    self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
-            elif self.source_format == "int":
-                if self.bit == 1:
-                    w = _tir_packed_int_to_int_convert(self.storage_dtype, storage_nbit)(
-                        self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
-                if self.bit == 8:
-                    w = B[n, k].astype(self.in_dtype)
-                w = _tir_packed_to_signed_convert(self.storage_dtype, storage_nbit)(
-                    self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
-            elif self.source_format == "fp":
-                w = _tir_u32_to_f4_to_f16(
-                    self.bit, B[n, k // n_float_per_elem], k % n_float_per_elem, dtype=self.in_dtype)
-            elif self.source_format == "fp_e4m3":
-                w = _tir_u8_to_f8_e4m3_to_f16(self.bit, B[n, k], dtype=self.in_dtype)
-            elif self.source_format == "nf":
-                index = _tir_packed_to_unsigned_convert(self.storage_dtype, storage_nbit)(
-                    self.bit,
-                    B[n, k // n_float_per_elem],
-                    k % n_float_per_elem,
-                    dtype="int32",
-                )
-                w = LUT[index]
-            else:
-                raise ValueError(f"Unsupported source_format: {self.source_format}")
-
-            group_size = self.group_size
-            zeros_mode = self.zeros_mode
-
-            if not self.with_scaling:
-                return w
-
-            if not self.with_zeros:
-                return w * Scale[n, k // group_size]
-
-            if zeros_mode == "original":
-                w = (w - Zeros[n, k // group_size]) * Scale[n, k // group_size]
-            elif zeros_mode == "rescale":
-                w = w * Scale[n, k // group_size] - Zeros[n, k // group_size]
-            elif zeros_mode == "quantized":
-                w = w * Scale[n, k // group_size]
-            else:
-                raise ValueError("Unsupported zeros_mode: {}".format(zeros_mode))
-
-            return w
-
-        return te.compute((self.N, self.K), decode, name="B_decode")
-
-    def _compute_matmul(self, A, B_decode):
-        k = te.reduce_axis((0, self.K), name="k")
-        C = te.compute(
-            (self.M, self.N),
-            lambda i, j: te.sum(
-                A[i, k].astype(self.accum_dtype) * B_decode[j, k].astype(self.accum_dtype), axis=k),
-            name="C",
-        )
-        return C
-
-    def _convert_dtype(self, tensor):
-        if self.accum_dtype != self.out_dtype:
-            return te.compute((self.M, self.N), lambda i, j: tensor[i, j].astype(self.out_dtype), name="D")
-        return tensor
-
-    def _apply_bias(self, tensor, Bias):
-        if self.with_bias:
-            return te.compute((self.M, self.N), lambda i, j: tensor[i, j] + Bias[j], name="E")
-        return tensor
-
-    def emit(self):
-        A, B, LUT, Scale, Zeros, QZeros, Bias, storage_nbit, n_float_per_elem = self._create_placeholders()
-        B_decode = self._decode_func(B, LUT, Scale, Zeros, QZeros, storage_nbit, n_float_per_elem)
-        C = self._compute_matmul(A, B_decode)
-        D = self._convert_dtype(C)
-        last_output = self._apply_bias(D, Bias)
-
-        args = [A, B]
-        if self.source_format == "nf":
-            args.append(LUT)
-        if self.with_scaling:
-            args.append(Scale)
-        if self.with_zeros:
-            args.append(QZeros if self.zeros_mode == "quantized" else Zeros)
-        if self.with_bias:
-            args.append(Bias)
-        args.append(last_output)
-
-        func = te.create_prim_func(args).with_attr(
-            "dequantize_info",
-            {
-                "B_decode": {
-                    "decode_block": "B_decode",
-                    "fast_decoding": self.fast_decoding,
-                    "source_format": {
-                        "bits": self.bit,
-                        "format": self.source_format,
-                    },
-                    "storage_dtype": self.storage_dtype,
-                    "target_format": self.in_dtype,
-                    "with_zeros": self.with_zeros,
-                    "zeros_mode": self.zeros_mode,
-                    "with_scaling": self.with_scaling,
-                    "group_size": self.group_size,
-                }
-            },
-        )
-        return tvm.IRModule.from_expr(func)
-
 
 # TODO: The following code should be refactored.
 class MatMulNTDequantizeEmitter:
@@ -671,8 +473,7 @@ def decode_func(n, k):
         else:
             args.append(Zeros)
     if with_bias:
-        E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E")
-        last_output = E
+        last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E")
         args.append(Bias)
     args.append(last_output)
 
@@ -852,8 +653,7 @@ def decode_func(n, k):
     if with_zeros:
         args.append(Zeros)
     if with_bias:
-        E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E")
-        last_output = E
+        last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E")
         args.append(Bias)
     args.append(last_output)
 
@@ -1052,8 +852,7 @@ def decode_func(n, k):
     if with_zeros:
         args.append(Zeros)
     if with_bias:
-        E = te.compute((M, N), lambda i, j: D[i, j] + Bias[j], name="E")
-        last_output = E
+        last_output = te.compute((M, N), lambda i, j: last_output[i, j] + Bias[j], name="E")
         args.append(Bias)
     args.append(last_output)
 
diff --git a/testing/python/module/test_bitblas_linear.py b/testing/python/module/test_bitblas_linear.py
index eeaf90475..eee08c93c 100644
--- a/testing/python/module/test_bitblas_linear.py
+++ b/testing/python/module/test_bitblas_linear.py
@@ -11,16 +11,7 @@
 torch.manual_seed(0)
 bitblas.set_log_level("DEBUG")
 
-@pytest.mark.parametrize(
-    "m, in_features, out_features, bias",
-    [
-        (1, 1024, 1024, False),
-        (1, 1024, 1024, True),
-        (1024, 1024, 1024, True),
-        ([1, 1024], 1024, 1024, True),
-    ],
-)
-def test_correctness_consistent(m, in_features, out_features, bias):
+def correctness_consistent(m, in_features, out_features, bias):
     linear_torch = (nn.Linear(in_features, out_features, bias=bias).to(torch.float16).cuda())
     linear_bitblas = BitBLASLinear(
         in_features,
@@ -48,19 +39,13 @@ def test_correctness_consistent(m, in_features, out_features, bias):
     torch.testing.assert_close(output_torch, output_bitblas, rtol=1e-1, atol=1e-2)
 
 
-@pytest.mark.parametrize(
-    "m, in_features, out_features, bias, W_dtype, group_size, with_scaling, with_zeros, zeros_mode",
-    [
-        (1, 1024, 1024, False, "uint4", -1, False, False, None),
-        (1, 1024, 1024, False, "uint4", -1, False, False, None),
-        (1024, 1024, 1024, True, "uint4", -1, False, False, None),
-        (1, 1024, 1024, True, "uint2", -1, True, False, None),
-        (1, 1024, 1024, True, "uint2", 128, True, True, "original"),
-        (1024, 1024, 1024, True, "uint2", 128, True, True, "original"),
-        (1, 1024, 1024, True, "uint2", 128, True, True, "rescale"),
-    ],
-)
-def test_correctness_weight_only_dequantize(
+def test_correctness_consistent():
+    correctness_consistent(1, 1024, 1024, False)
+    correctness_consistent(1, 1024, 1024, True)
+    correctness_consistent(1024, 1024, 1024, True)
+    correctness_consistent([1, 1024], 1024, 1024, True)
+
+def correctness_weight_only_dequantize(
         m,
         in_features,
         out_features,
@@ -169,6 +154,16 @@ def test_correctness_weight_only_dequantize(
     torch.testing.assert_close(output_bitblas, ref_result, rtol=1e0, atol=1e0)
 
 
+def test_correctness_weight_only_dequantize():
+    correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None)
+    correctness_weight_only_dequantize(1, 1024, 1024, False, "uint4", -1, False, False, None)
+    correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint4", -1, False, False, None)
+    correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", -1, True, False, None)
+    correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "original")
+    correctness_weight_only_dequantize(1024, 1024, 1024, True, "uint2", 128, True, True, "original")
+    correctness_weight_only_dequantize(1, 1024, 1024, True, "uint2", 128, True, True, "rescale")
+
+
 def profile(model, input_data):
     model = model.cuda()
     model.eval()
diff --git a/testing/python/operators/test_general_matmul_ops.py b/testing/python/operators/test_general_matmul_ops.py
index 05e0a45f4..62808e2a7 100644
--- a/testing/python/operators/test_general_matmul_ops.py
+++ b/testing/python/operators/test_general_matmul_ops.py
@@ -195,7 +195,7 @@ def matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layo
     if with_bias:
         permuted_inputs.append(bias)
     permuted_inputs.append(inputs[2])
-    matmul(*permuted_inputs[:2], output=permuted_inputs[-1])
+    matmul(*permuted_inputs[:-1], output=permuted_inputs[-1])
     if zeros_mode == "rescale":
         torch.testing.assert_close(permuted_inputs[-1], ref_result, rtol=1e2, atol=1e0)
     else: