diff --git a/3rdparty/tvm b/3rdparty/tvm index 07648907e..240802497 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 07648907e1678ec2b84d8ec579b2ec8f4925d218 +Subproject commit 2408024972b9199683491871329829d98b59dc5c diff --git a/bitblas/builder/lib_generator/__init__.py b/bitblas/builder/lib_generator/__init__.py index fd877c679..64eaee9e8 100644 --- a/bitblas/builder/lib_generator/__init__.py +++ b/bitblas/builder/lib_generator/__init__.py @@ -43,7 +43,8 @@ def compile_lib(self, timeout: float = None): "--shared", src.name, "-lcuda", - f"-gencode=arch=compute_{compute_version},code=compute_{compute_version}", + "-gencode", + f"arch=compute_{compute_version},code=sm_{compute_version}", "-o", libpath, ] diff --git a/bitblas/gpu/element_wise.py b/bitblas/gpu/element_wise.py index 07ea3a27e..3d67937e8 100644 --- a/bitblas/gpu/element_wise.py +++ b/bitblas/gpu/element_wise.py @@ -8,6 +8,7 @@ from tvm import tir from ..base import ScheduleRule, normalize_prim_func, try_inline +from ..base.analysis import get_coalesced_veclen class ElementWise(ScheduleRule): @@ -39,6 +40,11 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring dom_kind = block.dom_kind() block = block.block_rv + # set vector factors + vec_len = get_coalesced_veclen(sch.get(block)) + vector_factors = [1] * len(block_factors) + vector_factors[-1] = vec_len + if ( any( [ @@ -93,5 +99,10 @@ def apply_config( # pylint: disable=too-many-locals,missing-docstring for i, ax in enumerate(vthread_loops): sch.bind(ax, "vthread" + [".x", ".y", ".z"][i]) - + + # vectorize the last axis + ax = inner_loops[-1] + if sch.get(ax).extent.value > 1: + sch.vectorize(ax) + return sch diff --git a/bitblas/gpu/intrin/lop3.py b/bitblas/gpu/intrin/lop3.py index 184d74211..a6a7011a0 100644 --- a/bitblas/gpu/intrin/lop3.py +++ b/bitblas/gpu/intrin/lop3.py @@ -1495,6 +1495,7 @@ def fast_decode_impl( (1, "int8", "float16", 8, "local", "uint", True, True, "rescale"), (4, "int8", "int8", 
8, "local", "uint", False, False, "original"), (4, "int8", "int8", 16, "local", "uint", False, False, "original"), + (4, "int8", "int8", 16, "local", "int", False, False, "original"), (2, "int8", "int8", 16, "local", "uint", False, False, "original"), (2, "int8", "int8", 16, "local", "int", False, False, "original"), (1, "int8", "int8", 16, "local", "uint", False, False, "original"), @@ -1523,6 +1524,7 @@ def fast_decode_impl( (1, "int8", "float16", 8, "warp", "uint", True, True, "rescale"), (4, "int8", "int8", 8, "warp", "uint", False, False, "original"), (4, "int8", "int8", 16, "warp", "uint", False, False, "original"), + (4, "int8", "int8", 16, "warp", "int", False, False, "original"), (2, "int8", "int8", 16, "warp", "uint", False, False, "original"), (2, "int8", "int8", 16, "warp", "int", False, False, "original"), (1, "int8", "int8", 16, "warp", "uint", False, False, "original"), diff --git a/bitblas/gpu/matmul_analysis.py b/bitblas/gpu/matmul_analysis.py index 1d0889fa3..36cba1969 100644 --- a/bitblas/gpu/matmul_analysis.py +++ b/bitblas/gpu/matmul_analysis.py @@ -623,7 +623,8 @@ def check_last_trait(region: List[Range]): # Currently, we only support block reduction depth 2 for small M # When the func is a dequantize like ops, we should consider the M require_block_reduce = False - if hasattr(func.attrs, "dequantize_info"): + # And we only support float16 for now + if hasattr(func.attrs, "dequantize_info") and in_dtype == "float16": for arg in func.params: inp_shape = func.buffer_map[arg].shape M = inp_shape[0] diff --git a/bitblas/ops/general_matmul/__init__.py b/bitblas/ops/general_matmul/__init__.py index 16908dd41..dea4042e1 100644 --- a/bitblas/ops/general_matmul/__init__.py +++ b/bitblas/ops/general_matmul/__init__.py @@ -85,6 +85,8 @@ class MatmulConfig(OperatorConfig): None # propagate_b is a flag to control the ladder permutation ) + # TODO: This is a temporary solution to legalize the dynamic symbolic. + # Maybe we should remove this in the future. 
# optimize strategy, default is SingleBatchDecodeOnly optimize_stratety: Union[int, OptimizeStrategy] = OptimizeStrategy.SingleBatchDecodeOnly diff --git a/bitblas/wrapper/general.py b/bitblas/wrapper/general.py index aa76f6158..4aaebc64e 100644 --- a/bitblas/wrapper/general.py +++ b/bitblas/wrapper/general.py @@ -160,7 +160,8 @@ def compile_lib(self, timeout: float = None): "--shared", src.name, "-lcuda", - f"-gencode=arch=compute_{compute_version},code=compute_{compute_version}", + "-gencode", + f"arch=compute_{compute_version},code=sm_{compute_version}", "-o", libpath, ] diff --git a/integration/BitNet/eval_correctness.py b/integration/BitNet/eval_correctness.py index 4017a6c17..6bd787535 100644 --- a/integration/BitNet/eval_correctness.py +++ b/integration/BitNet/eval_correctness.py @@ -72,18 +72,19 @@ def get_runtime(num_repeats=1): def main(): model = BitnetForCausalLM.from_pretrained( model_path, - use_flash_attention_2=True, + use_flash_attention_2=False, torch_dtype=torch.float16, ).cuda().half() - with torch.no_grad(): - model._post_process_weights() tokenizer = BitnetTokenizer.from_pretrained(model_path, use_fast=False) input_id = tokenizer("Hello")['input_ids'] input_id = torch.tensor(input_id).unsqueeze(0).cuda() - output = model(input_id) - print(output) + print("original model generated text:") + print(generate_text(model, tokenizer, "Hello", max_length=100)) + + model.quantize() + print("quantized model generated text:") print(generate_text(model, tokenizer, "Hello", max_length=100)) diff --git a/integration/BitNet/maint/create_bitblas_ckpt.py b/integration/BitNet/maint/create_bitblas_ckpt.py index 0bf603e0d..4f0555430 100644 --- a/integration/BitNet/maint/create_bitblas_ckpt.py +++ b/integration/BitNet/maint/create_bitblas_ckpt.py @@ -68,7 +68,7 @@ def main(): model = ( BitnetForCausalLM.from_pretrained( model_name_or_path, - use_flash_attention_2=True, + use_flash_attention_2=False, torch_dtype=torch.float16, ).cuda().half()) tokenizer = 
BitnetTokenizer.from_pretrained(model_name_or_path, use_fast=False) @@ -80,7 +80,7 @@ def main(): output = model(input_ids) print("original model output:", output) - model.quantize() + model.quantize(fuse_qkv=True, fuse_gateup=True) print("original model generated text:") print(generate_text(model, tokenizer, "Hi, ", max_length=100)) @@ -93,6 +93,8 @@ def main(): print("quant config:") print(quant_config) quant_config["checkpoint_format"] = "bitblas" + quant_config["fuse_qkv"] = True + quant_config["fuse_gateup"] = True # save quant config quant_config_path = os.path.join(saved_model_path, "quantize_config.json") diff --git a/integration/BitNet/maint/generate_bitnet_model_bitblas_format.sh b/integration/BitNet/maint/generate_bitnet_model_bitblas_format.sh index 3ace58031..e265658ac 100755 --- a/integration/BitNet/maint/generate_bitnet_model_bitblas_format.sh +++ b/integration/BitNet/maint/generate_bitnet_model_bitblas_format.sh @@ -18,6 +18,9 @@ fi if [ -z "$SAVED_MODEL_DIR" ]; then python ./maint/create_bitblas_ckpt.py --model_name_or_path $MODEL_DIR else + if [ ! 
-d "$SAVED_MODEL_DIR" ]; then + mkdir -p $SAVED_MODEL_DIR + fi python ./maint/create_bitblas_ckpt.py --model_name_or_path $MODEL_DIR --saved_model_path $SAVED_MODEL_DIR fi diff --git a/integration/BitNet/modeling_bitnet.py b/integration/BitNet/modeling_bitnet.py index e4e1d88ea..22a985ce0 100644 --- a/integration/BitNet/modeling_bitnet.py +++ b/integration/BitNet/modeling_bitnet.py @@ -244,6 +244,49 @@ def forward(self, x): return x +class BitnetMLPFuseGateUp(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.intermediate_size = config.intermediate_size + self.gate_up_proj = BitLinear( + self.hidden_size, + self.intermediate_size * 2, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.down_proj = BitLinear( + self.intermediate_size, + self.hidden_size, + bias=False, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.act_fn = ACT2FN[config.hidden_act] + self.ffn_layernorm = BitnetRMSNorm(self.intermediate_size, eps=config.rms_norm_eps) + + @classmethod + def from_bit_mlp(cls, bit_mlp: BitnetMLP): + module = cls(bit_mlp.config) + # assign the weights + module.gate_up_proj.weight = nn.Parameter( + torch.cat([bit_mlp.gate_proj.weight, bit_mlp.up_proj.weight], dim=0)) + module.down_proj = bit_mlp.down_proj + module.ffn_layernorm = bit_mlp.ffn_layernorm + return module + + def forward(self, x): + gate_up = self.gate_up_proj(x) + gate, up = torch.chunk(gate_up, chunks=2, dim=-1) + x = self.act_fn(gate) * up + x = self.ffn_layernorm(x) + x = self.down_proj(x) + return x + + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). 
The hidden states go from (batch, @@ -394,6 +437,153 @@ def forward( return attn_output, attn_weights, past_key_value +class BitnetAttentionQKVFused(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: BitnetConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class.") + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads}).") + + self.qkv_proj = BitLinear( + self.hidden_size, + self.num_heads * self.head_dim + (self.num_key_value_heads * self.head_dim) * 2, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self.o_proj = BitLinear( + self.hidden_size, + self.hidden_size, + bias=config.attention_bias, + weight_bits=config.weight_bits, + input_bits=config.input_bits, + ) + self._init_rope() + self.inner_attn_ln = BitnetRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + def _init_rope(self): + if self.config.rope_scaling is None: + self.rotary_emb = BitnetRotaryEmbedding( + self.head_dim, + 
max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + raise NotImplementedError + + @classmethod + def from_bit_attention(cls, bit_attention: BitnetAttention): + module = cls(bit_attention.config, bit_attention.layer_idx) + # assign the weights + module.qkv_proj.weight = nn.Parameter( + torch.cat([ + bit_attention.q_proj.weight, bit_attention.k_proj.weight, + bit_attention.v_proj.weight + ], + dim=0)) + if bit_attention.q_proj.bias is not None and bit_attention.k_proj.bias is not None and bit_attention.v_proj.bias is not None: + module.qkv_proj.bias = nn.Parameter( + torch.cat([ + bit_attention.q_proj.bias, bit_attention.k_proj.bias, bit_attention.v_proj.bias + ], + dim=0)) + module.o_proj = bit_attention.o_proj + module.inner_attn_ln = bit_attention.inner_attn_ln + if bit_attention.config.rope_scaling is None: + module.rotary_emb = bit_attention.rotary_emb + return module + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + qkv_states = self.qkv_proj(hidden_states) + query_states, key_states, value_states = torch.split( + qkv_states, [ + self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim, + self.num_key_value_heads * self.head_dim + ], + dim=-1) + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, + self.head_dim).transpose(1, 2) + + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, 
sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, + self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt( + self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, :key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}") + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.inner_attn_ln(attn_output) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + class BitnetFlashAttention2(BitnetAttention): """ Bitnet flash attention module. 
This module inherits from `BitnetAttention` as the weights of the module stays @@ -1240,13 +1430,30 @@ def recursive_set(model, name, attr): obj = getattr(obj, n) setattr(obj, names[-1], attr) - def quantize(self): + def quantize(self, fuse_qkv=True, fuse_gateup=True): + for name, module in self.model.named_modules(): + # if is bitnet layer + if fuse_qkv and isinstance(module, BitnetAttention): + # create quantized version of the layer + print("Replacing BitnetAttention", name) + bitnet_attention_qkv_fused = BitnetAttentionQKVFused.from_bit_attention(module) + self.recursive_set(self.model, name, bitnet_attention_qkv_fused) + if fuse_gateup and isinstance(module, BitnetMLP): + # create quantized version of the layer + print("Replacing BitnetMLP", name) + bitnet_mlp_fused = BitnetMLPFuseGateUp.from_bit_mlp(module) + self.recursive_set(self.model, name, bitnet_mlp_fused) for name, module in self.model.named_modules(): # if is bitnet layer if isinstance(module, BitLinear): # create quantized version of the layer print("Quantizing module", name) - bitblas_linear = BitLinearBitBLAS.from_bit_linear(module) + if name.endswith(".qkv_proj"): + bitblas_linear = BitLinearBitBLAS.from_bit_linear(module, weight_group=3) + elif name.endswith(".gate_up_proj"): + bitblas_linear = BitLinearBitBLAS.from_bit_linear(module, weight_group=2) + else: + bitblas_linear = BitLinearBitBLAS.from_bit_linear(module) print("Replacing module", name, "with a quantized version") self.recursive_set(self.model, name, bitblas_linear) self.quantized = True @@ -1300,20 +1507,34 @@ def from_quantized( trust_remote_code=trust_remote_code, **cached_file_kwargs, ) - # only load from remote instead of local - # TODO(lei): add local support + # load quantize config quantize_file = cached_file(model_name_or_path, "quantize_config.json") assert quantize_file is not None, "quantize config file not found" import json + # get quantize format with open(quantize_file, "r") as f: quant_config = json.load(f) 
checkpoint_format = quant_config["checkpoint_format"] assert checkpoint_format in ["bitblas"], "quantize format not supported" + fuse_qkv = quant_config.get("fuse_qkv", True) + fuse_gateup = quant_config.get("fuse_gateup", True) import accelerate if checkpoint_format == "bitblas": model = cls(config) + for name, module in model.named_modules(): + # if is bitnet layer + if fuse_qkv and isinstance(module, BitnetAttention): + # create quantized version of the layer + print("Replacing BitnetAttention", name) + bitnet_attention_qkv_fused = BitnetAttentionQKVFused.from_bit_attention(module) + model.recursive_set(model, name, bitnet_attention_qkv_fused) + if fuse_gateup and isinstance(module, BitnetMLP): + # create quantized version of the layer + print("Replacing BitnetMLP", name) + bitnet_mlp_fused = BitnetMLPFuseGateUp.from_bit_mlp(module) + model.recursive_set(model, name, bitnet_mlp_fused) for name, module in model.named_modules(): if isinstance(module, BitLinear): # create quantized version of the layer diff --git a/integration/BitNet/requirements.txt b/integration/BitNet/requirements.txt index 7d4b14956..45952b615 100644 --- a/integration/BitNet/requirements.txt +++ b/integration/BitNet/requirements.txt @@ -1,2 +1,3 @@ lm_eval==0.3.0 flash_attn +transformers==4.40 \ No newline at end of file diff --git a/integration/BitNet/utils_quant.py b/integration/BitNet/utils_quant.py index cb0c0f50b..3da74c213 100644 --- a/integration/BitNet/utils_quant.py +++ b/integration/BitNet/utils_quant.py @@ -101,10 +101,10 @@ def replace_weight_param_with_qweight(self): self.format = "bitblas" @classmethod - def from_bit_linear(cls, bitlinear): + def from_bit_linear(cls, bitlinear, weight_group=1): bitblas_linear = cls( bitlinear.in_features, bitlinear.out_features, weight_bits=1, input_bits=8) - sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight) + sw, qweight = bitblas_linear.create_bitblas_weights(bitlinear.weight, weight_group) 
bitblas_linear.register_buffer("qweight", qweight) bitblas_linear.register_buffer("sw", sw) if bitlinear.bias is not None: @@ -113,11 +113,31 @@ def from_bit_linear(cls, bitlinear): bitblas_linear.bias = None return bitblas_linear - def create_bitblas_weights(self, weight): - sw = 1 / weight.abs().mean().clamp(min=1e-5) - quant_weight = self.weight_quant(weight).detach() - quant_weight = self.bitblas_matmul.transform_weight(quant_weight) - qweight = nn.Parameter(quant_weight, requires_grad=False) + def create_bitblas_weights(self, weight, weight_group=1): + if weight_group: + hidden_size = weight.size(0) + group_size = hidden_size // weight_group + + sw_list = [] + qweight_list = [] + + for i in range(weight_group): + start_idx = i * group_size + end_idx = (i + 1) * group_size + + sw = 1 / weight[start_idx:end_idx].abs().mean().clamp(min=1e-5) + sw_list.append(sw.repeat(group_size)) + + qweight = self.weight_quant(weight[start_idx:end_idx]).detach() + qweight_list.append(qweight) + + sw = torch.cat(sw_list, dim=0) + qweight = torch.cat(qweight_list, dim=0) + else: + sw = 1 / weight.abs().mean().clamp(min=1e-5) + qweight = self.weight_quant(weight).detach() + qweight = self.bitblas_matmul.transform_weight(qweight) + qweight = nn.Parameter(qweight, requires_grad=False) return sw, qweight def post_process_weights(self):