From 4183ae1786ad5cde630153690a77eacf77855e3e Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Wed, 18 Sep 2024 06:27:25 +0000 Subject: [PATCH] fix for int8 gemm --- bitblas/gpu/matmul_analysis.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/bitblas/gpu/matmul_analysis.py b/bitblas/gpu/matmul_analysis.py index 4a0ef532f..9ddc4500b 100644 --- a/bitblas/gpu/matmul_analysis.py +++ b/bitblas/gpu/matmul_analysis.py @@ -561,10 +561,15 @@ def check_sm_version(arch: str) -> int: sm_version = arch.replace("sm_", "") return int(sm_version) if sm_version.isdigit() else -1 - def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV, target: Target) -> bool: + def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV, + target: Target) -> Union[bool, Dict]: tags: Dict[str, Union[List[int], int]] = {} block_stmt = sch.get(block) + # NVIDIA only supports Tensor Cores on + # devices with SM version 70 or greater. + if check_sm_version(target.arch) < 70: + return False # analysis tensorcore axis # todo(lei): maybe we can remove this in the future (write_buffer_region,) = block_stmt.writes @@ -612,6 +617,11 @@ def check_last_trait(region: List[Range]): in_dtype, out_dtype = get_in_out_dtypes(block_stmt) intrin_info["in_dtype"] = in_dtype intrin_info["out_dtype"] = out_dtype + + if 70 <= check_sm_version(target.arch) < 80 and out_dtype == "int32": + # The INT32-accumulation Tensor Core path requires SM version >= 80. + return False + # if the last dimension is reduce axis, the B is transposed intrin_info["trans_b"] = check_last_trait(block_stmt.reads[1].region) if func.attrs is not None and "input_transform_kind" in func.attrs: @@ -666,6 +676,7 @@ def check_last_trait(region: List[Range]): block_stmt = sch.get(main_block) + # Threshold is 16 for 16-bit tensor cores, 32 for 8-bit tensor cores. minimal_tensorize_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32 # the batch dimension is not taken into consideration. extent = block_stmt.iter_vars[1].dom.extent