Skip to content
This repository was archived by the owner on Feb 24, 2026. It is now read-only.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion bitblas/gpu/matmul_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,10 +561,15 @@ def check_sm_version(arch: str) -> int:
sm_version = arch.replace("sm_", "")
return int(sm_version) if sm_version.isdigit() else -1

def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV, target: Target) -> bool:
def analysis_tensorcore_tags(sch: tir.Schedule, block: BlockRV,
target: Target) -> Union[bool, Dict]:
tags: Dict[str, Union[List[int], int]] = {}
block_stmt = sch.get(block)

    # NVIDIA only supports Tensor Cores on devices with
    # SM version >= 70 (Volta and newer).
    if check_sm_version(target.arch) < 70:
return False
# analysis tensorcore axis
# todo(lei): maybe we can remove this in the future
(write_buffer_region,) = block_stmt.writes
Expand Down Expand Up @@ -612,6 +617,11 @@ def check_last_trait(region: List[Range]):
in_dtype, out_dtype = get_in_out_dtypes(block_stmt)
intrin_info["in_dtype"] = in_dtype
intrin_info["out_dtype"] = out_dtype

    if 70 <= check_sm_version(target.arch) < 80 and out_dtype == "int32":
        # INT32-accumulation Tensor Core ops are only supported on SM version >= 80,
        # so reject SM 70-79 targets here.
return False

# if the last dimension is reduce axis, the B is transposed
intrin_info["trans_b"] = check_last_trait(block_stmt.reads[1].region)
if func.attrs is not None and "input_transform_kind" in func.attrs:
Expand Down Expand Up @@ -666,6 +676,7 @@ def check_last_trait(region: List[Range]):

block_stmt = sch.get(main_block)

    # Minimal tensorize threshold: 16 for 16-bit Tensor Core inputs, 32 for 8-bit inputs.
minimal_tensorize_threshold = 16 if in_dtype in ["bfloat16", "float16"] else 32
# the batch dimension is not taken into consideration.
extent = block_stmt.iter_vars[1].dom.extent
Expand Down