diff --git a/3rdparty/tvm b/3rdparty/tvm index d0c06c764..1fa647dbf 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit d0c06c7641956a3bd9ab1174ed05a1aa2a624d2a +Subproject commit 1fa647dbff6a273cbdf2a6f0a64b3478ba553223 diff --git a/bitblas/base/utils.py b/bitblas/base/utils.py index 90fab86d0..2b887ba2d 100644 --- a/bitblas/base/utils.py +++ b/bitblas/base/utils.py @@ -193,7 +193,7 @@ def _apply_schedule(f, c): sch = None return sch - with ThreadPoolExecutor(max_workers=4) as scheduler: + with ThreadPoolExecutor(max_workers=max_workers) as scheduler: futures = {scheduler.submit(_apply_schedule, func, config) for config in configs} for future in as_completed(futures, timeout=timeout): _sched.append(future.result()) diff --git a/bitblas/ops/base_scheduler.py b/bitblas/ops/base_scheduler.py index 72a52937b..72ee1b29c 100644 --- a/bitblas/ops/base_scheduler.py +++ b/bitblas/ops/base_scheduler.py @@ -4,6 +4,7 @@ from dataclasses import dataclass, field from tvm.tir.transform import Simplify from abc import ABC, abstractmethod +from bitblas.base.arch import TileDevice @dataclass @@ -20,6 +21,10 @@ def Simplify(stmt: Union[PrimFunc, IRModule]): else: raise ValueError(f"Unsupported type: {type(stmt)}") + def get_hardware_aware_configs(self, arch: TileDevice = None): + raise NotImplementedError( + f"{self.__class__.__name__} does not support hardware-aware tuning for {arch}") + def activate_simplify(self): self._enable_simplify = True return self diff --git a/bitblas/ops/general_matmul/tilelang/dense/__init__.py b/bitblas/ops/general_matmul/tilelang/dense/__init__.py index 2a929355c..9ab9b6990 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/__init__.py +++ b/bitblas/ops/general_matmul/tilelang/dense/__init__.py @@ -1,13 +1,17 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .matmul import ( +from .matmul_simt import ( + MatmulFineGrainSIMTScheduler, # noqa: F401 +) + +from .matmul_tensorcore import ( matmul_blocked, # noqa: F401 matmul_macro_tensorcore, # noqa: F401 matmul_macro_tensorcore_weight_propagation_level_ldmatrix # noqa: F401 ) -from .matmul import ( +from .matmul_tensorcore import ( MatmulScheduler, # noqa: F401 MatmulFineGrainScheduler, # noqa: F401 MatmulWeightPropagationScheduler, # noqa: F401 diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py new file mode 100644 index 000000000..bc091f910 --- /dev/null +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_simt.py @@ -0,0 +1,62 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from bitblas import tvm as tvm +from tvm import DataType +import tvm.tl.language as T +from typing import Optional +from bitblas.tl.utils import ( + get_mma_micro_size, + make_swizzle_layout, +) + +from bitblas.ops.base_scheduler import BaseScheduler + +from dataclasses import dataclass + + +@dataclass +class MatmulFineGrainSIMTScheduler(BaseScheduler): + # Fine-grained matrix multiplication scheduler + # Allows for more detailed configuration. + + # Operation Configuration + M: Optional[int] = None + N: Optional[int] = None + K: Optional[int] = None + in_dtype: str = "float16" + out_dtype: str = "float16" + trans_A: bool = False + trans_B: bool = True + accum_dtype: str = "float16" + + # Tensor Core Warp Configuration + block_row_warps: int = 2 + block_col_warps: int = 2 + warp_row_tiles: int = 32 + warp_col_tiles: int = 32 + chunk: int = 32 # Usually determines the K-dimension split size + + # Tiling and Other Optimization Parameters + num_stages: int = 2 + enable_rasterization: bool = False + + def with_default_config(self): + raise NotImplementedError + + def apply_config( + self, + ): + + # M, N, K = self.M, self.N, self.K + # trans_A, trans_B = self.trans_A, self.trans_B + # in_dtype, out_dtype, accum_dtype = self.in_dtype, self.out_dtype, self.accum_dtype + + raise NotImplementedError + + + def __post_init__(self): + # Validate the matrix transpose settings + assert self.trans_A is False, "Currently only support Matrix A not transposed" + assert self.trans_B is True, "Currently only support Matrix B transposed" + + return diff --git a/bitblas/ops/general_matmul/tilelang/dense/matmul.py b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py similarity index 97% rename from bitblas/ops/general_matmul/tilelang/dense/matmul.py rename to bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py index 1c28ff695..35a200527 100644 --- a/bitblas/ops/general_matmul/tilelang/dense/matmul.py +++ b/bitblas/ops/general_matmul/tilelang/dense/matmul_tensorcore.py @@ -1,5 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import itertools from bitblas import tvm as tvm from tvm import DataType import tvm.tl.language as T @@ -15,7 +16,7 @@ ) from bitblas.ops.common import TransformKind from bitblas.ops.base_scheduler import BaseScheduler - +from bitblas.base.arch import CUDA from dataclasses import dataclass @@ -40,6 +41,22 @@ class MatmulScheduler(BaseScheduler): threads: int = 128 enable_rasterization: bool = False # Enhance L2 Locality + def get_configs_sm80(self): + num_stages = 2 + configs = [ + {'block_M': 128, 'block_N': 256, 'block_K': 32, 'threads': 128}, + {'block_M': 256, 'block_N': 128, 'block_K': 32, 'threads': 128}, + {'block_M': 128, 'block_N': 128, 'block_K': 32, 'threads': 128}, + ] + configs = [{**c, 'num_stages': num_stages} for c in configs] + return configs + + def get_hardware_aware_configs(self, arch: CUDA = None): + # TODO(lei): implement only for SM80 Currently + sm_version: int = int(arch.sm_partition) + assert sm_version is not None, "Please provide a valid CUDA Arch" + return self.get_configs_sm80() + def with_default_config(self): block_M = getattr(self, "block_M", 64) block_N = getattr(self, "block_N", 64) diff --git a/bitblas/ops/operator.py b/bitblas/ops/operator.py index b723eabf8..eb173352f 100644 --- a/bitblas/ops/operator.py +++ b/bitblas/ops/operator.py @@ -10,9 +10,10 @@ from tvm.contrib.dlpack import to_pytorch_func import bitblas import ctypes -from typing import (List, Dict, Any, Optional, Tuple, Literal, Callable) +from typing import List, Dict, Any, Optional, Tuple, Literal, Callable import numpy as np from bitblas.base import fast_tune, fast_tune_with_dynamic_range +from bitblas.tl.tuner import apply_and_build as tl_apply_and_build from copy import deepcopy from bitblas.ops.base_scheduler import BaseScheduler from bitblas.base.arch import get_arch, TileDevice @@ -38,6 +39,7 @@ @dataclass(frozen=True) class OperatorConfig: """Base class for operator configurations. Used for typing.""" + pass @@ -55,7 +57,7 @@ def is_valid_config(self, config: OperatorConfig): @abstractmethod def generate(self, hint: Hint = None) -> str: - '''Generate the kernel name based on the config and hint''' + """Generate the kernel name based on the config and hint""" pass @@ -73,18 +75,20 @@ def generate(self, hint: Hint = None) -> str: return self.DEFAULT_PREFIX def is_valid_config(self, config: OperatorConfig) -> bool: - # hint is not used + # config is not used assert config is not None return True class Operator(object): - def __init__(self, - name, - config: OperatorConfig, - target: Target = None, - backend: Literal["tir", "tl"] = "tir"): + def __init__( + self, + name, + config: OperatorConfig, + target: Target = None, + backend: Literal["tir", "tl"] = "tir", + ): if isinstance(target, str): target = Target(target) self.name = name @@ -169,7 +173,7 @@ def tvm_callback_cuda_postproc(code, _): config={ "tir.use_async_copy": True, "tir.disable_cse_tir": True, - **(self.pass_context if self.pass_context else {}) + **(self.pass_context if self.pass_context else {}), }): if self.is_tir_backend(): rt_mod = tvm.build(self.scheduled_ir_module, target=target) @@ -183,9 +187,12 @@ def tvm_callback_cuda_postproc(code, _): raise ValueError(f"Unsupported backend: {self.backend}") except Exception: # noqa: F841 logger.debug( - BUILD_RUNTIME_LIBRARY_FAILED_MESSAGE.format(self.__class__.__name__, target, - "optimized", - "Failed to build optimized module")) + BUILD_RUNTIME_LIBRARY_FAILED_MESSAGE.format( + self.__class__.__name__, + target, + "optimized", + "Failed to build optimized module", + )) else: # For non-CUDA platforms or when no optimized function is available, build with the primary function rt_mod = tvm.build(self.prim_func, target=target, name=self.name) @@ -248,10 +255,12 @@ def _build_default_module(self, target: Target): scheduled_mod = self.apply_default_schedule(self.ir_module, target) elif self.is_tilelang_backend(): scheduled_mod = self.scheduler_with_default(self.scheduler) - assert len(scheduled_mod.get_global_vars()) == 1, ( - "The optimized module should only have one global variable for default schedule.") - assert "main" in scheduled_mod, ( - "The optimized module should have a function named 'main' for default schedule.") + assert ( + len(scheduled_mod.get_global_vars()) == 1 + ), "The optimized module should only have one global variable for default schedule." + assert ( + "main" in scheduled_mod + ), "The optimized module should have a function named 'main' for default schedule." default_kernal_name = self.kernel_name_generator.generate() func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name) scheduled_ir_module = tvm.IRModule({default_kernal_name: func}) @@ -267,54 +276,77 @@ def _build_default_module(self, target: Target): def post_process(self, code: str) -> str: return code - def apply_fast_tuning(self, - func: PrimFunc, - target: Target, - topk: int = 20, - parallel_build=True) -> Tuple[IRModule, Hint]: - _, best = fast_tune(func, target, topk=topk, parallel_build=parallel_build) - # annotate the best pass context - # TODO(lei): actually we should remove this by enable pass through - # annotation in the func's attribute. - self.pass_context = best.config.pass_context - return ((best.sch.mod, best.config) if best is not None else (None, None)) + def get_tl_tuning_config(self): + assert self.is_tilelang_backend(), "Only support tilelang backend" + return self.scheduler.get_hardware_aware_configs(self.arch) + + def apply_fast_tuning( + self, + func_or_scheduler: PrimFunc, + target: Target, + topk: int = 20, + parallel_build=True, + ) -> Tuple[IRModule, Hint]: + if self.is_tir_backend(): + _, best = fast_tune(func_or_scheduler, target, topk=topk, parallel_build=parallel_build) + # annotate the best pass context + # TODO(lei): actually we should remove this by enable pass through + # annotation in the func's attribute. + self.pass_context = best.config.pass_context + return (best.sch.mod, best.config) if best is not None else (None, None) + elif self.is_tilelang_backend(): + # Finetune the schedule + tuning_configs = self.get_tl_tuning_config() + _, best = tl_apply_and_build( + func_or_scheduler, tuning_configs, arch=self.arch, parallel_build=False) + # Return the best Config as Hint + return (best.sch.mod, best.config) if best is not None else (None, None) def apply_fast_tuning_with_dynamic_range( self, - func: PrimFunc, + func_or_scheduler: PrimFunc, target: Target, topk: int = 20, dynamic_range: Dict[str, List[int]] = None, ): scheduled_ir_module = fast_tune_with_dynamic_range( - func, + func_or_scheduler, target, topk=topk, parallel_build=True, dynamic_range=dynamic_range, - kernel_name_generator=self.kernel_name_generator) + kernel_name_generator=self.kernel_name_generator, + ) if scheduled_ir_module is not None: return scheduled_ir_module return None - def hardware_aware_finetune(self, - topk: int = 20, - target: Optional[tvm.target.Target] = None, - parallel_build=True): + def hardware_aware_finetune( + self, + topk: int = 20, + target: Optional[tvm.target.Target] = None, + parallel_build=True, + ): if target is None: target = self.target dynamic_range = self.dynamic_range - func = self.prim_func if dynamic_range is not None: - self.scheduled_ir_module = self.apply_fast_tuning_with_dynamic_range( - func, target, topk, dynamic_range) + if self.is_tir_backend(): + func = self.prim_func + self.scheduled_ir_module = self.apply_fast_tuning_with_dynamic_range( + func, target, topk, dynamic_range) + elif self.is_tilelang_backend(): + raise NotImplementedError("Not support dynamic range for tilelang backend") else: + func_or_scheduler = (self.prim_func if self.is_tir_backend() else self.scheduler) scheduled_mod, best_hint = self.apply_fast_tuning( - func, target, topk, parallel_build=parallel_build) - assert len(scheduled_mod.get_global_vars()) == 1, ( - "The optimized module should only have one global variable for default schedule.") - assert "main" in scheduled_mod, ( - "The optimized module should have a function named 'main' for default schedule.") + func_or_scheduler, target, topk, parallel_build=parallel_build) + assert ( + len(scheduled_mod.get_global_vars()) == 1 + ), "The optimized module should only have one global variable for default schedule." + assert ( + "main" in scheduled_mod + ), "The optimized module should have a function named 'main' for default schedule." default_kernal_name = self.kernel_name_generator.generate(best_hint) func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name) scheduled_ir_module = tvm.IRModule({default_kernal_name: func}) @@ -341,8 +373,9 @@ def var_warpper(v): for i in func.attrs["opt_shapes"][v.name]: avg_shape += i.value avg_shape = avg_shape // len(func.attrs["opt_shapes"][v.name]) - _info_message = f"Doesn't provide dynamic symbolic constrains for {v.name} when do benchmarking, "\ - f"use average shape {avg_shape}" + _info_message = ( + f"Doesn't provide dynamic symbolic constrains for {v.name} when do benchmarking, " + f"use average shape {avg_shape}") logger.info(_info_message) return avg_shape else: diff --git a/bitblas/tl/tuner.py b/bitblas/tl/tuner.py new file mode 100644 index 000000000..8f9ab4f84 --- /dev/null +++ b/bitblas/tl/tuner.py @@ -0,0 +1,284 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from bitblas import tvm +import os +from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind +from concurrent.futures import ThreadPoolExecutor, as_completed +import numpy as np +from typing import List, Tuple, Optional, Dict, Union, Literal, Callable +from tvm import tir, IRModule +from tvm.runtime import Module +from tvm.tir import Schedule +from tvm.relax.expr import Function +import tvm.tl as tl +import bitblas +from bitblas.ops.base_scheduler import BaseScheduler +from bitblas.base.arch import CUDA +from bitblas.base import Hint +from bitblas.base.utils import get_dummy_input_arrays +from bitblas.base.roller.policy import TensorCorePolicy, DefaultPolicy +from bitblas.gpu.matmul_analysis import get_tensorized_func_and_tags +import tempfile +import itertools +from tvm.ir.supply import GlobalVarSupply +from bitblas.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2 +from bitblas.utils.tensor_adapter import ( + np_float2np_bf16,) +import logging + +logger = logging.getLogger(__name__) + + +def get_rasterization_code(pannel_width: int = 8) -> str: + return f""" + const int MAX_BLOCK_N = {pannel_width}; + const auto baseBlockIdx = blockIdx.x + gridDim.x *blockIdx.y; + const auto totalPanel = (gridDim.x * gridDim.y +MAX_BLOCK_N * gridDim.x - 1) / (MAX_BLOCK_N * gridDim.x); + const auto totalBlock = gridDim.x * gridDim.y; + const auto panelIdx = baseBlockIdx / (MAX_BLOCK_N *gridDim.x); + const auto strideLd = panelIdx + 1 < totalPanel ?MAX_BLOCK_N : (totalBlock - panelIdx * (MAX_BLOCK_N *gridDim.x)) / gridDim.x; + const auto bx = (panelIdx & 1) ? gridDim.x -(baseBlockIdx - panelIdx * MAX_BLOCK_N * gridDim.x) /strideLd - 1 : (baseBlockIdx - panelIdx * MAX_BLOCK_N *gridDim.x) / strideLd; + const auto by = (baseBlockIdx - panelIdx * MAX_BLOCK_N *gridDim.x) % strideLd + panelIdx * MAX_BLOCK_N; + const auto bz = blockIdx.z; + const dim3 blockIdx(bx, by, bz); + """ + + +class CompileResult: + """ + Class to store the result of compilation + """ + + def __init__(self, config, sch, mod: Module): + self.config = config + self.sch = sch + self.mod = mod + self.code = mod.imported_modules[0].get_source() if mod else None + self.latency = 1e9 + self.time_evaluator = None + + def profile(self, data_distribution="uniform"): + func = self.sch.mod["main"] + device = self.config.arch.device + profile_tensors = get_dummy_input_arrays(func, device, distribution=data_distribution) + latency = self.time_evaluator(*profile_tensors).mean * 1e3 + return latency + + +def _apply_config( + scheduler: BaseScheduler, + config: Dict = None, +) -> Optional[IRModule]: + """ + find rules: + case 1. if the main block has no reduce op, then use the Elementwise rule. + case 2. if the config enabled tensorcore, then use the TensorCore rule. + case 3. if any([t > 1 for t in config.reduce_thread]), we should use the InnerThread Reduction Rule. + case 4. else we should use general reduction rule. + """ + logger.debug("Scheduler Apply config {}".format(config)) + scheduled_func = scheduler.apply_config(**config) + if scheduled_func is None: + return None + else: + return tvm.IRModule.from_expr(scheduled_func) + + +def apply_and_build_parallel(scheduler, + configs, + arch, + num_repeats=3, + max_workers=10, + timeout=30, + data_distribution="uniform") -> CompileResult: + cpresults = [] + + max_workers = min(len(configs), os.cpu_count(), max_workers) + + # apply config in thread parallel + _scheduled_ir_modules: List[Schedule] = [] + + def _submit_config(f, c): + try: + scheduled_ir_module = _apply_config(f, c) + except Exception as apply_schedule_error: + logger.debug("Apply schedule failed: {}".format(apply_schedule_error)) + scheduled_ir_module = None + return scheduled_ir_module + + with ThreadPoolExecutor(max_workers=max_workers) as _scheduler: + futures = {_scheduler.submit(_submit_config, scheduler, config) for config in configs} + for future in as_completed(futures, timeout=timeout): + _scheduled_ir_modules.append(future.result()) + + builder = PopenPoolExecutor(max_workers=max_workers, timeout=timeout) + + # build in process parallel + def _build(context) -> str: + idx, mod, arch = context + if mod is None: + return idx, None, None + + config = configs[idx] + + @tvm.register_func(func_name="tvm_callback_cuda_postproc", override=True) + def tvm_callback_cuda_postproc(code, _): + code = tensor_replace_dp4a(code) + code = tensor_remove_make_int4(code) + code = tensor_remove_make_int2(code) + return code + # check only have one function in the module + if len(mod.functions) > 1: + raise ValueError("Only support one function in the module") + tl_prim_func = list(mod.functions.values())[0] + with tvm.transform.PassContext(config={ + "tir.use_async_copy": True, + "tir.disable_cse_tir": True, + }): + rt_mod = tl.lower(tl_prim_func, arch.target, runtime_only=True) + + from tvm.contrib.tar import tar # pylint: disable=import-outside-toplevel + + artifact_path = os.path.join(tempfile.mkdtemp(), "tvm_tmp_mod." + tar.output_format) + code = rt_mod.imported_modules[0].get_source() + rt_mod.export_library(artifact_path, fcompile=tar) + return idx, code, artifact_path + + _mods = [mod for mod in _scheduled_ir_modules] + + for map_result in builder.map_with_error_catching( + _build, + [(i, mod, arch) for i, mod in enumerate(_mods)], + ): + if map_result.status == StatusKind.TIMEOUT: + logger.debug("LocalBuilder: Timeout") + elif map_result.status == StatusKind.EXCEPTION: + # TODO(lei): redirect the exception to file if needed + logger.debug("LocalBuilder: An exception occurred {}".format(map_result.value)) + continue + elif map_result.status == StatusKind.COMPLETE: + idx, code, artifact_path = map_result.value + ir_module = _scheduled_ir_modules[idx] + sch = tvm.tir.Schedule(ir_module) + config = configs[idx] + if artifact_path is None: + ARTIFACT_NOT_FOUND = f"Apply config {config} failed, artifact path is None" + logger.debug(ARTIFACT_NOT_FOUND) + continue + rt_mod = tvm.runtime.load_module(artifact_path) + # Transform Tuning Config to Hint + hint = Hint.from_dict( + { + **{"arch": arch}, + **config, + } + ) + cpresult = CompileResult(hint, sch, rt_mod) + timer_cuda_mod = rt_mod.time_evaluator( + rt_mod.entry_name, arch.device, number=num_repeats) + cpresult.time_evaluator = timer_cuda_mod + cpresult.code = code + cpresults.append(cpresult) + else: + raise ValueError(f"Unreachable: unexpected result: {map_result}") + + del builder + + best = None + best_latency = 1e9 + for cpresult in cpresults: + config = cpresult.config + try: + latency = cpresult.profile(data_distribution=data_distribution) + except Exception as e_mesg: + logger.debug(f"Evaluation with config failed {e_mesg}") + continue + logger.info("Evaluation with config {}".format(config)) + logger.info("Time cost of this config: {:.3f} ms".format(latency)) + + cpresult.latency = latency + if latency < best_latency: + best_latency = latency + best = cpresult + + return cpresults, best + + +def apply_and_build( + scheduler, + configs, + arch, + parallel_build=False, + data_distribution="uniform", +) -> Tuple[List[CompileResult], CompileResult]: + max_workers = 10 if parallel_build else 1 + return apply_and_build_parallel( + scheduler, configs, arch, max_workers=max_workers, data_distribution=data_distribution) + + +def fast_tune( + func: tir.PrimFunc, + target: tvm.target.Target, + topk: int = 10, + parallel_build: bool = True, + data_distribution: Literal["uniform", "onefill"] = "uniform", +): + # check the function is a primfunc + if not isinstance(func, tir.PrimFunc): + raise ValueError("Only support func is PrimFunc") # pragma: no cover + + if target.kind.name != "cuda": + logger.error("Only support CUDA target") + return None, None + + specilized_func = func + if func.attrs is not None and "opt_shapes" in func.attrs: + opt_shapes = func.attrs["opt_shapes"] + # should be int value + if not all([isinstance(v.value, int) for v in opt_shapes.values()]): + logger.error("The opt_shapes should be int value") + return None, None + # currently only support one dynamic range + if len(opt_shapes) > 1: + logger.error("Currently only support one dynamic range") + return None, None + + for buffer in func.buffer_map.values(): + for axis in buffer.shape: + if isinstance(axis, tvm.tir.Var) and axis.name not in opt_shapes: + raise NotImplementedError( + "Currently do not support fast tune with none-dynamic range set") + if opt_shapes: + for name, shape in opt_shapes.items(): + var = find_var_from_func(func, name) + specilized_func = func.specialize({ + var: shape.astype(var.dtype) + }).with_attr("is_specialized") + + arch = CUDA(target) + + policy = DefaultPolicy(func=func, arch=arch) + try: + specilized_func, tags = get_tensorized_func_and_tags(specilized_func, arch.target) + except Exception as e_msg: + logger.debug("Get tensorized func and tags failed: ", e_msg) + tags = None + if tags: + policy = TensorCorePolicy(func=specilized_func, arch=arch, tags=tags) + + configs = policy.emit_config(topk) + + if len(configs) == 0: + raise ValueError("No valid config generated") + + cpresults, best = apply_and_build( + func, + configs, + arch, + parallel_build=parallel_build, + data_distribution=data_distribution, + ) + + return cpresults, best + diff --git a/testing/python/operators/test_general_matmul_ops_backend_tl.py b/testing/python/operators/test_general_matmul_ops_backend_tl.py index 90ed00c6e..eccb8ebb3 100644 --- a/testing/python/operators/test_general_matmul_ops_backend_tl.py +++ b/testing/python/operators/test_general_matmul_ops_backend_tl.py @@ -38,11 +38,47 @@ def matmul_codegen_default(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, la assert get_codegen_result(matmul) +def matmul_finetune(M, N, K, A_dtype, W_dtype, accum_dtype, out_dtype, layout, with_bias, + group_size, with_scaling, with_zeros, zeros_mode): + + matmul_config = MatmulConfig( + M=M, + N=N, + K=K, + A_dtype=A_dtype, + W_dtype=W_dtype, + accum_dtype=accum_dtype, + out_dtype=out_dtype, + layout=layout, + with_bias=with_bias, + group_size=group_size, + with_scaling=with_scaling, + with_zeros=with_zeros, + zeros_mode=zeros_mode, + propagate_a=False, + propagate_b=False, + ) + matmul = Matmul(config=matmul_config, enable_tuning=False, backend="tl") + matmul.hardware_aware_finetune(topk=10) + assert get_codegen_result(matmul) + + def test_matmul_codegen_default(): matmul_codegen_default(1, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, None), matmul_codegen_default(768, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, False, False, None), + # FP32 Accum + matmul_codegen_default(768, 768, 768, "float16", "float16", "float32", "float16", "nt", False, + -1, False, False, None), + # INT32 Accum + matmul_codegen_default(768, 768, 768, "int8", "int8", "int32", "int8", "nt", False, -1, False, + False, None), + + +def test_matmul_finetune(): + matmul_finetune(768, 768, 768, "float16", "float16", "float16", "float16", "nt", False, -1, + False, False, None), # fmt: on