diff --git a/3rdparty/tvm b/3rdparty/tvm index 780b83017..a9b770a85 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 780b83017be0b5b12d123834adb07546bc4e6082 +Subproject commit a9b770a85d2b856424a2b4c71d870e3f1af90396 diff --git a/README.md b/README.md index 73796371b..fcfcc2956 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,17 @@ Some of the key features of BitBLAS include: - BitBLAS first implemented $W_{INT2}A_{INT8}$ GEMV/GEMM in [BitNet-b1.58](https://arxiv.org/abs/2402.17764) with 8x/2x speedup over cuBLAS $W_{FP16}A_{FP16}$ on A100, please checkout [op_benchmark_a100_int2_scaling](https://github.com/microsoft/BitBLAS/blob/main/images/figures/op_benchmark_a100_int2_scaling.png) for detailed benchmark results. Please checkout [BitNet-b1.58 integration](https://github.com/microsoft/BitBLAS/blob/main/integration/BitNet) for the integration with the 3rdparty reproduced BitNet-b1.58 model. - Support customizing mixed-precision DNN operations for your specific scenarios via the flexible DSL (TIR Script). +## Latest News + +- 2024.04.19: BitBLAS is now open source! We are excited to announce that BitBLAS, a high-performance library for mixed-precision DNN model deployment, is now available to the public. 
+- 2024.04.30: BitBLAS now supports FP8 ($W_{FP8}A_{FP8}$) matrix multiplication with the FP8_E4M3 and FP8_E5M2 formats on RTX 4090(SM_89). + +## Integration Example of FasterTransformer with BitBLAS +![FasterTransformer Integration](images/gif/FasterTransformer.gif) + +## Benchmark Summary + + ## Integration Example of FasterTransformer with BitBLAS ![FasterTransformer Integration](images/gif/FasterTransformer.gif) @@ -63,6 +74,8 @@ For more detailed information on benchmark sets with other formats (NF4/FP4) and | INT8 | UINT4/INT4 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | | INT8 | UINT2/INT2 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | | INT8 | UINT1 | INT32 | FP32/INT32/FP16/INT8 | **√** | V100(SM_70)/A100(SM_80)/A6000(SM_86)/RTX 4090(SM_89) | +| FP8_E4M3 | FP8_E4M3 | FP32 | FP32/FP16 | **√** | RTX 4090(SM_89) | +| FP8_E5M2 | FP8_E5M2 | FP32 | FP32/FP16 | **√** | RTX 4090(SM_89) | We are continuously expanding the support matrix. If you have any specific requirements, please feel free to open an issue or PR. 
diff --git a/python/bitblas/base/utils.py b/python/bitblas/base/utils.py index da7b66ad8..0e51ef57b 100644 --- a/python/bitblas/base/utils.py +++ b/python/bitblas/base/utils.py @@ -135,16 +135,27 @@ def var_wrapper(v): else: raise ValueError("Not supported type: ", type(func)) + def map_numpy_type(intype): + typemap = { + 'e4m3_float8': 'float8_e4m3fn', + 'e5m2_float8': 'float8_e5m2', + } + if intype in typemap: + return typemap[intype] + else: + return intype + + numpy_dtype = map_numpy_type(arg.dtype) if distribution == "uniform": profile_tensors.append( tvm.nd.array( - np.random.rand(*[var_wrapper(i) for i in arg.shape]).astype(arg.dtype), + np.random.rand(*[var_wrapper(i) for i in arg.shape]).astype(numpy_dtype), device=device, )) elif distribution == "onefill": profile_tensors.append( tvm.nd.array( - np.ones([var_wrapper(i) for i in arg.shape]).astype(arg.dtype), + np.ones([var_wrapper(i) for i in arg.shape]).astype(numpy_dtype), device=device, )) else: @@ -245,7 +256,7 @@ def tvm_callback_cuda_postproc(code, _): try: latency = cpresult.profile() except Exception as e_mesg: - logger.debug("Evaluation with config failed: ", e_mesg) + logger.debug(f"Evaluation with config failed {e_mesg}") continue logger.info("Evaluation with config {}".format(config)) logger.info("Time cost of this config: {:.3f} ms".format(latency)) diff --git a/python/bitblas/gpu/gemv.py b/python/bitblas/gpu/gemv.py index 33388bffe..7a2880ed1 100644 --- a/python/bitblas/gpu/gemv.py +++ b/python/bitblas/gpu/gemv.py @@ -64,10 +64,9 @@ def get_extent(sch: tir.Schedule, loop_rv: tir.schedule.LoopRV): def get_bytes(dtype: Union[DataType, str]) -> int: - num = re.findall(r"\d+", dtype) - if len(num) != 1: - raise ValueError(f"Cannot get bytes from {dtype}") - return int(num[0]) // 8 + if isinstance(dtype, str): + dtype = DataType(dtype) + return int(dtype.bits) // 8 def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]: diff --git 
a/python/bitblas/gpu/matmul_analysis.py b/python/bitblas/gpu/matmul_analysis.py index df50d283c..2fa9c16a4 100644 --- a/python/bitblas/gpu/matmul_analysis.py +++ b/python/bitblas/gpu/matmul_analysis.py @@ -512,7 +512,7 @@ def get_tensorized_func_and_tags( allow_gemv: bool = False, ) -> Tuple[tir.PrimFunc, Dict[str, Union[List[int], int]]]: from tvm.tir.tensor_intrin.cuda import ( # pylint: disable=import-outside-toplevel - get_wmma_intrin_group,) + get_mma_intrin_group,) """ transform function to matmul if necessary (e.g. transform conv2d with im2col) """ @@ -607,14 +607,18 @@ def check_last_trait(region: List[Range]): block_stmt = sch.get(main_block) if target.kind.name == "cuda" and check_sm_version(target.arch) >= 70: + # TODO(lei): we should consider the dtype of the input a and b + # instead of assuming both a and b share the same dtype. + # As the tensorcore may supports e4m3_float8 * e5m2_float8 in_dtype, out_dtype = get_in_out_dtypes(block_stmt) try: - _ = get_wmma_intrin_group( - in_dtype=in_dtype, + _ = get_mma_intrin_group( + a_dtype=in_dtype, + b_dtype=in_dtype, out_dtype=out_dtype, ) except Exception: - logger.debug("Cannot find the corresponding wmma intrin group") + logger.debug("Cannot find the corresponding mma intrin group") return func, None # reindex and transform functions @@ -651,11 +655,16 @@ def get_propagate_map(trans: bool = True, dtype="float16", matrix_name="A", inde ldmatrix_32x16_to_shared_16x32_layout_a, ldmatrix_32x16_to_shared_16x32_layout_b, ) - assert dtype in ["float16", "int8"], "Only support float16 for now" + assert dtype in [ + "float16", + "int8", + "e4m3_float8", + "e5m2_float8", + ], "Only support float16, int8, e4m3_float8, e5m2_float8" if dtype == "float16": ldmatrix_layout = ldmatrix_32x8_to_shared_16x16_layout ldmatrix_layout_trans = ldmatrix_trans_32x8_to_shared_16x16_layout - elif dtype == "int8": + elif dtype in ["int8", "e4m3_float8", "e5m2_float8"]: # int8 mma only support 32x16 to 16x32 layout if matrix_name == 
"A" and trans is False: ldmatrix_layout = ldmatrix_32x16_to_shared_16x32_layout_a diff --git a/python/bitblas/gpu/matmul_mma.py b/python/bitblas/gpu/matmul_mma.py index 5043501d5..a20359e11 100644 --- a/python/bitblas/gpu/matmul_mma.py +++ b/python/bitblas/gpu/matmul_mma.py @@ -303,7 +303,8 @@ def store_output(block_outer, write_buffer_idx): intrin_group = get_mma_intrin_group( load_scope="shared.dyn", store_scope="shared.dyn", - in_dtype=str(dtype_a), + a_dtype=str(dtype_a), + b_dtype=str(dtype_b), out_dtype=str(dtype_c), trans_a=is_transpose_a, trans_b=is_transpose_b, @@ -396,7 +397,8 @@ def check_has_dynamic(func: tir.PrimFunc): intrin_group = get_mma_intrin_group( load_scope=shared_scope, store_scope=shared_scope if cache_write_required else "global", - in_dtype=intrin_info.in_dtype, + a_dtype=intrin_info.in_dtype, + b_dtype=intrin_info.in_dtype, out_dtype=intrin_info.out_dtype, trans_a=intrin_info.trans_a, trans_b=intrin_info.trans_b, diff --git a/python/bitblas/gpu/matmul_mma_dequantize.py b/python/bitblas/gpu/matmul_mma_dequantize.py index 79a5eeb2f..e4c5a272a 100644 --- a/python/bitblas/gpu/matmul_mma_dequantize.py +++ b/python/bitblas/gpu/matmul_mma_dequantize.py @@ -167,7 +167,8 @@ def check_weight_decode_info(weight_decode_info): intrin_group = get_mma_intrin_group( load_scope=shared_scope, store_scope=shared_scope if cache_write_required else "global", - in_dtype=intrin_info.in_dtype, + a_dtype=intrin_info.in_dtype, + b_dtype=intrin_info.in_dtype, out_dtype=intrin_info.out_dtype, trans_a=intrin_info.trans_a, trans_b=intrin_info.trans_b, @@ -654,7 +655,8 @@ def check_weight_decode_info(weight_decode_info): intrin_group = get_mma_intrin_group( load_scope=shared_scope, store_scope=shared_scope if cache_write_required else "global", - in_dtype=intrin_info.in_dtype, + a_dtype=intrin_info.in_dtype, + b_dtype=intrin_info.in_dtype, out_dtype=intrin_info.out_dtype, trans_a=intrin_info.trans_a, trans_b=intrin_info.trans_b, @@ -1143,7 +1145,8 @@ def 
check_weight_decode_info(weight_decode_info): intrin_group = get_mma_intrin_group( load_scope=shared_scope, store_scope=shared_scope if cache_write_required else "global", - in_dtype=intrin_info.in_dtype, + a_dtype=intrin_info.in_dtype, + b_dtype=intrin_info.in_dtype, out_dtype=intrin_info.out_dtype, trans_a=intrin_info.trans_a, trans_b=intrin_info.trans_b, diff --git a/python/bitblas/ops/general_matmul.py b/python/bitblas/ops/general_matmul.py index aed148df8..4a48fb901 100644 --- a/python/bitblas/ops/general_matmul.py +++ b/python/bitblas/ops/general_matmul.py @@ -23,6 +23,24 @@ WORKSPACE_SIZE = 1024 * 1024 * 256 +# TODO(lei): This should be improved into a general +# Method to get the consistent compute patterns. +NATIVE_COMPUTE_PATTERNS = [ + # A_dtype, W_dtype + ("float64", "float64"), + ("float32", "float32"), + ("float16", "float16"), + ("int8", "int8"), + ("e4m3_float8", "e4m3_float8"), + ("e4m3_float8", "e5m2_float8"), + ("e5m2_float8", "e4m3_float8"), + ("e5m2_float8", "e5m2_float8"), +] + + +def is_native_compute(A_dtype, W_dtype) -> bool: + return (A_dtype, W_dtype) in NATIVE_COMPUTE_PATTERNS + class OPExecutorCPU: @@ -150,8 +168,15 @@ def __post_init__(self): if self.with_zeros is None: object.__setattr__(self, "with_zeros", False) - if self.A_dtype == self.W_dtype and self.W_dtype in ["float16", "int8"]: + if self.A_dtype == self.W_dtype and self.W_dtype in [ + "float16", "int8", "e4m3_float8", "e5m2_float8" + ]: object.__setattr__(self, "storage_dtype", self.W_dtype) + # TODO(lei): This is a limitation arose by pytorch + # Should be removed in the future. 
+ if self.A_dtype in ["e4m3_float8", "e5m2_float8"]: + object.__setattr__(self, "propagate_a", TransformKind.NonTransform) + object.__setattr__(self, "propagate_b", TransformKind.NonTransform) class Matmul(Operator): @@ -176,6 +201,8 @@ class Matmul(Operator): "nf4": ("nf", 4), "fp8_e5m2": ("fp", 8), "fp4_e2m1": ("fp", 4), + "e4m3_float8": ("fp", 8), # "e4m3_float8" is a trick for "float8_e4m3fn" + "e5m2_float8": ("fp", 8), } def __init__( @@ -316,7 +343,7 @@ def _build_default_module(self, target: Target): self._build_runtime_module(target) def _select_implementation(self): - if self.A_dtype == self.W_dtype: + if is_native_compute(self.A_dtype, self.W_dtype): return consistent_implementation( M=self.M, N=self.N, @@ -446,8 +473,9 @@ def forward(self, A, W, scale=None, zeros=None, bias=None, output=None) -> Any: args.append(bias) args.append(output) - m = reduce(operator.mul, A.shape[:-1], 1) - args.append(m) + if self.dynamic_range is not None: + m = reduce(operator.mul, A.shape[:-1], 1) + args.append(m) if self.lib is None: self._forward_from_torch_func(*args) diff --git a/python/bitblas/ops/impl/ladder_permutate_impl.py b/python/bitblas/ops/impl/ladder_permutate_impl.py index 5ac4a5334..8086bf584 100644 --- a/python/bitblas/ops/impl/ladder_permutate_impl.py +++ b/python/bitblas/ops/impl/ladder_permutate_impl.py @@ -9,7 +9,7 @@ def select_implementation( M: int, N: int, - datatype: Literal["float16", "int8"] = "float16", + datatype: Literal["float16", "int8", "e4m3_float8", "e5m2_float8"] = "float16", dequantize_bits: int = -1, storage_dtype: Literal["float16", "int8", "uint8", "int32", "uint32"] = "float16", propagate_kind: Literal["A", "B"] = "B", @@ -23,7 +23,7 @@ def select_implementation( # This is trick to get the basic tile size for the current datatype # as for nvidia tensorcore instruction, the basic tile size is 16x16/16x32 for float16/int8 l = r = 16 # noqa: E741 - if datatype == "int8": + if datatype in ["int8", "e4m3_float8", "e5m2_float8"]: l, r = 
16, 32 # noqa: E741 intra_index_map, _ = get_propagate_map( transpose_matrix, dtype=datatype, matrix_name=propagate_kind) diff --git a/python/bitblas/ops/impl/matmul_dequantize_impl.py b/python/bitblas/ops/impl/matmul_dequantize_impl.py index 28e9ae42b..b7c7c64ee 100644 --- a/python/bitblas/ops/impl/matmul_dequantize_impl.py +++ b/python/bitblas/ops/impl/matmul_dequantize_impl.py @@ -97,8 +97,6 @@ def decode_func(n, k): else: raise ValueError("Unsupported source_format: {}".format(source_format)) - - if not with_scaling: return w @@ -187,7 +185,7 @@ def matmul_nt_dequantize_b_propagate_b( M = tvm.te.var("m") l = r = 16 # noqa: E741 - if in_dtype == "int8": + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: l, r = 16, 32 # noqa: E741 _, inverse_indexmap = get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") @@ -358,7 +356,7 @@ def matmul_nt_dequantize_b_propagate_a_propagate_b( M = tvm.te.var("m") l = r = 16 # noqa: E741 - if in_dtype == "int8": + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: l, r = 16, 32 # noqa: E741 _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) diff --git a/python/bitblas/ops/impl/matmul_impl.py b/python/bitblas/ops/impl/matmul_impl.py index 26b748a88..69b426354 100644 --- a/python/bitblas/ops/impl/matmul_impl.py +++ b/python/bitblas/ops/impl/matmul_impl.py @@ -111,7 +111,7 @@ def matmul_nt_propagate_a( if not isinstance(M, int): M = tvm.te.var("m") l = r = 16 # noqa: E741 - if in_dtype == "int8": + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: l, r = 16, 32 # noqa: E741 _, inversed_index_map = get_propagate_map(trans=False, dtype=in_dtype, matrix_name="A") @@ -171,7 +171,7 @@ def matmul_nt_propagate_b( if not isinstance(M, int): M = tvm.te.var("m") l = r = 16 # noqa: E741 - if in_dtype == "int8": + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: l, r = 16, 32 # noqa: E741 _, inversed_index_map = 
get_propagate_map(trans=True, dtype=in_dtype, matrix_name="B") @@ -232,7 +232,7 @@ def matmul_nt_propagate_a_propagate_b( if not isinstance(M, int): M = tvm.te.var("m") l = r = 16 # noqa: E741 - if in_dtype == "int8": + if in_dtype in ["int8", "e4m3_float8", "e5m2_float8"]: l, r = 16, 32 # noqa: E741 A = te.placeholder((M // l, K // r, l, r), name="A", dtype=in_dtype) diff --git a/python/bitblas/ops/impl/param_permutate_impl.py b/python/bitblas/ops/impl/param_permutate_impl.py index 620212eb6..4ecb17709 100644 --- a/python/bitblas/ops/impl/param_permutate_impl.py +++ b/python/bitblas/ops/impl/param_permutate_impl.py @@ -24,7 +24,7 @@ def select_implementation( # This is trick to get the basic tile size for the current datatype # as for nvidia tensorcore instruction, the basic tile size is 16x16/16x32 for float16/int8 l = r = 16 # noqa: E741 - if datatype == "int8": + if datatype in ["int8", "e4m3_float8", "e5m2_float8"]: l, r = 16, 32 # noqa: E741 if group_size == -1: group_size = N diff --git a/python/bitblas/ops/operator.py b/python/bitblas/ops/operator.py index 0290e2e28..224726b6d 100644 --- a/python/bitblas/ops/operator.py +++ b/python/bitblas/ops/operator.py @@ -220,15 +220,27 @@ def var_warpper(v): else: raise RuntimeError("Not supported type: ", type(v)) + def map_numpy_type(intype): + typemap = { + 'e4m3_float8': 'float8_e4m3fn', + 'e5m2_float8': 'float8_e5m2', + } + if intype in typemap: + return typemap[intype] + else: + return intype + profile_tensors = [] for param in func.params: if param not in func.buffer_map: # in case of dynamic symbolic may in params continue arg = func.buffer_map[param] + numpy_dtype = map_numpy_type(arg.dtype) profile_tensors.append( tvm.nd.array( - np.random.uniform(0, 1, [var_warpper(i) for i in arg.shape]).astype(arg.dtype), + np.random.uniform(0, 1, + [var_warpper(i) for i in arg.shape]).astype(numpy_dtype), device=device, )) self.profile_tensors = profile_tensors diff --git 
a/python/bitblas/relax/transform/weight_only_propagate.py b/python/bitblas/relax/transform/weight_only_propagate.py index 8240a0fd8..709e02085 100644 --- a/python/bitblas/relax/transform/weight_only_propagate.py +++ b/python/bitblas/relax/transform/weight_only_propagate.py @@ -130,7 +130,8 @@ def transform_matmul(self, g_var: GlobalVar, func: tir.PrimFunc, intrin_info): intrin_group = get_mma_intrin_group( load_scope="shared", store_scope="shared", - in_dtype=intrin_info["in_dtype"], + a_dtype=intrin_info["in_dtype"], + b_dtype=intrin_info["in_dtype"], out_dtype=intrin_info["out_dtype"], trans_a=False, trans_b=intrin_info["trans_b"], diff --git a/python/bitblas/wrapper/general.py b/python/bitblas/wrapper/general.py index f97405716..fa2e7405c 100644 --- a/python/bitblas/wrapper/general.py +++ b/python/bitblas/wrapper/general.py @@ -21,6 +21,8 @@ "float32": "float", "float16": "half", "bfloat16": "__nv_bfloat162", + "e4m3_float8": "__nv_fp8_e4m3", + "e5m2_float8": "__nv_fp8_e5m2", "float64": "double", "int64": "int64_t", "int32": "int", @@ -247,7 +249,7 @@ def update_lib_code(self, code: str): for dyn_sym in dynamic_symbolic_set: function_args.append({"name": dyn_sym, "type": "int"}) - function_args.append({"name": "stream=0", "type": "cudaStream_t"},) + function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},) # Format the function arguments for declaration def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) @@ -285,8 +287,8 @@ def legalize_c(p): # Determine the shared memory size, defaulting to 0 if not specified smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf # Format the CUDA kernel launch string - call_str = "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str, - smem_str, call_args) + call_str = "{}<<<{}, {}, {}, 0>>>({});".format(function_name, grid_str, block_str, smem_str, + call_args) # Create the host function wrapper for the CUDA kernel host_func = """ 
extern "C" void call({}) {{ @@ -352,7 +354,7 @@ def create_dispatch_func(self, code, function_informations): for dyn_sym in dynamic_symbolic_set: function_args.append({"name": dyn_sym, "type": "int"}) - function_args.append({"name": "stream=0", "type": "cudaStream_t"},) + function_args.append({"name": "stream=cudaStreamDefault", "type": "cudaStream_t"},) # Format the argument definitions for function declaration def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args]) diff --git a/testing/python/operators/test_general_matmul_fp8.py b/testing/python/operators/test_general_matmul_fp8.py new file mode 100644 index 000000000..720e8fb32 --- /dev/null +++ b/testing/python/operators/test_general_matmul_fp8.py @@ -0,0 +1,77 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import pytest +import bitblas +from bitblas import MatmulConfig, Matmul +import logging +from bitblas import set_log_level + +set_log_level(logging.DEBUG) + + +@pytest.mark.parametrize( + "M,N,K,A_dtype,W_dtype,accum_dtype,out_dtype,layout,with_bias,group_size,with_scaling,with_zeros,zeros_mode", + [ + (1, 1024, 1024, "e4m3_float8", "e4m3_float8", "float32", "float32", "nt", None, None, None, None, + None), + (1024, 1024, 1024, "e4m3_float8", "e4m3_float8", "float32", "float32", "nt", None, None, None, None, + None), + (1, 1024, 1024, "e5m2_float8", "e5m2_float8", "float32", "float32", "nt", None, None, None, None, + None), + (1024, 1024, 1024, "e5m2_float8", "e5m2_float8", "float32", "float32", "nt", None, None, None, None, + None), + ], +) +def test_matmul_torch_forward(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, + group_size, with_scaling, with_zeros, zeros_mode): + import torch + torch.random.manual_seed(0) + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + layout=layout, + with_bias=with_bias, + group_size=group_size, + 
with_scaling=with_scaling, + with_zeros=with_zeros, + zeros_mode=zeros_mode, + ) + matmul = Matmul(config=matmul_config, enable_tuning=True) + + input_shape = (M, K) + weight_shape = (N, K) if layout == "nt" else (K, N) + + def map_torch_type(intype): + + typemap = { + 'e4m3_float8': torch.float8_e4m3fn, + 'e5m2_float8': torch.float8_e5m2, + } + if intype in typemap: + return typemap[intype] + else: + return getattr(torch, intype) + + numpytype_a = map_torch_type(A_dtype) + numpytype_b = map_torch_type(W_dtype) + numpytype_c = map_torch_type(out_dtype) + + torch_a = torch.rand(M*K).uniform_(-5, 5).reshape(input_shape).type(numpytype_a).cuda() + torch_b = torch.rand(N*K).uniform_(-5, 5).reshape(weight_shape).type(numpytype_b).cuda() + ref_out = torch.matmul(torch_a.to(torch.float32), torch_b.t().to(torch.float32)) if layout == "nt" else torch.matmul(torch_a.to(torch.float32), torch_b.to(torch.float32)) + ref_out = ref_out.to(numpytype_c) + + print("torch_ref_out", ref_out) + new_torch_b = matmul.transform_weight(torch_b) + bitblas_out = matmul(torch_a, new_torch_b) + print("bitblas_out", bitblas_out) + +# fmt: on +if __name__ == "__main__": + bitblas.testing.main()