Skip to content

[BUG]: Hang forever on B200 but not GB10 #71

@SchrodingerZhu

Description

@SchrodingerZhu

Version

9af1b63

Version

13,1

Which installation method(s) does this occur on?

Source

Describe the bug.

On B200, the following code:

# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

import argparse
import cuda.tile as ct
try:
    import cuda.tile_experimental as ct_experimental
except ImportError:
    ct_experimental = None
import torch
import math
import sys

from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import sdpa_kernel, SDPBackend
from utils.benchmark import report_benchmark
from types import SimpleNamespace


import numpy as np
from cuda.tile import RoundingMode as RMd


INV_LOG_2 = 1.0 / math.log(2)
ConstInt = ct.Constant[int]
ConstBool = ct.Constant[bool]
TILE_X = 2

@ct.kernel
def fmha_kernel(Q, K, V, Out,
                qk_scale: float,
                input_pos: int, 
                TILE_D: ConstInt,  # TILE_D = hidden_size
                H: ConstInt,
                TILE_M: ConstInt,
                TILE_N: ConstInt,
                QUERY_GROUP_SIZE: ConstInt,
                CAUSAL: ConstBool,
                EVEN_K: ConstBool,
                TILE_X: ConstInt):
    """
    cuTile kernel for Fused Multi-Head Attention (FMHA).
    Computes attention output for a specific batch item and head, using tiling and online softmax.
    """
    # Map block IDs to batch and head indices
    bid_start = ct.bid(0) * TILE_X
    bid_y = ct.bid(1)
    batch_idx = bid_y // H
    head_idx = bid_y % H
    off_kv_h = head_idx // QUERY_GROUP_SIZE

    # Adjust qk_scale for exp2
    qk_scale = qk_scale * INV_LOG_2
    for i in range(0, TILE_X):
        bid_x = bid_start + i
        # Initialize offsets for current query tile (M-dimension)
        offs_m = bid_x * TILE_M + ct.arange(TILE_M, dtype=np.int32)  # [TILE_M]
        offs_m += input_pos
        offs_m = offs_m[:, None]  # [TILE_M, 1]

        # Initialize local offsets for key/value tile (N-dimension)
        offs_n_tile = ct.arange(TILE_N, dtype=np.int32)  # [TILE_N]
        offs_n_tile = offs_n_tile[None, :]  # [1, TILE_N]

        # Initialize online softmax accumulators in float32 for stability
        m_i = ct.full((TILE_M, 1), -np.inf, dtype=np.float32)
        l_i = ct.full((TILE_M, 1), 0.0, dtype=np.float32)
        acc = ct.full((TILE_M, TILE_D), 0.0, dtype=np.float32)

        # Load query tile for this batch, head, and M-chunk
        q = ct.load(
            Q, index=(batch_idx, head_idx, bid_x, 0), shape=(1, 1, TILE_M, TILE_D)
        ).reshape((TILE_M, TILE_D))  # [TILE_M, TILE_D]

        # loop over k, v and update accumulator
        m_end = input_pos + (bid_x + 1) * TILE_M
        k_seqlen = K.shape[2]
        if CAUSAL:
            # when kv pos could exceed q pos
            mask_start = (input_pos + bid_x * TILE_M) // TILE_N
            # when kv pos could exceed k_seqlen
            mask_start = min(mask_start, k_seqlen // TILE_N)
            Tc = ct.cdiv(min(m_end, k_seqlen), TILE_N)
        else:
            Tc = ct.cdiv(k_seqlen, TILE_N)
            mask_start = k_seqlen // TILE_N

        # Loop over K, V blocks (N-dimension chunks)
        for j in range(0, Tc):
            # --- Compute QK product ---
            k = ct.load(
                K, index=(batch_idx, off_kv_h, 0, j), shape=(1, 1, TILE_D, TILE_N),
                order=(0, 1, 3, 2),
                latency=2,
            )
            k = k.reshape((TILE_D, TILE_N))  # [TILE_D, TILE_N]
            qk = ct.full((TILE_M, TILE_N), 0., dtype=np.float32)
            qk = ct.mma(q, k, qk)  # [TILE_M, TILE_N]

            # --- Apply Causal Masking ---
            if (CAUSAL or not EVEN_K) and j >= mask_start:
                offs_n = j * TILE_N + offs_n_tile
                mask = ct.full((TILE_M, TILE_N), True, dtype=np.bool)
                # out of bound mask
                if not EVEN_K:
                    mask = mask & (offs_n < k_seqlen)
                # causal mask
                if CAUSAL:
                    mask = mask & (offs_m >= offs_n)  # [TILE_M, TILE_N]
                mask = ct.where(mask, 0.0, -np.inf)  # [TILE_M, TILE_N]
                qk += mask

            # --- Online Softmax Update ---
            # Moving qk_scale multiplication after reduce_max is to improve performance.
            m_ij = max(m_i, ct.max(qk, axis=-1, keepdims=True) * qk_scale)
            qk = qk * qk_scale - m_ij  # [TILE_M, TILE_N]

            # attention weights
            p = ct.exp2(qk, flush_to_zero=True)  # [TILE_M, TILE_N]
            l_ij = ct.sum(p, axis=-1, keepdims=True)  # [TILE_M, 1]
            alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)  # [TILE_M, 1]
            # update m_i and l_i
            l_i = l_i * alpha + l_ij  # [TILE_M, 1]
            # scale acc
            acc = acc * alpha  # [TILE_M, TILE_N]

            # --- Compute PV product ---
            v = ct.load(
                V, index=(batch_idx, off_kv_h, j, 0), shape=(1, 1, TILE_N, TILE_D),
                latency=4,
            ).reshape((TILE_N, TILE_D))  # [TILE_N, TILE_D]
            p = p.astype(Q.dtype)
            acc = ct.mma(p, v, acc)  # [TILE_M, TILE_N]
            m_i = m_ij  # [TILE_M, 1]

        # --- Final Normalization and Store ---
        acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
        acc = acc.reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
        ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)


# --- Wrapper function to launch the FMHA kernel ---
def cutile_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                qk_scale: float | None = None,
                input_pos: int = 0,
                tile_m: int = 128,
                tile_n: int = 128,
                query_group_size: int = 1,
                causal: bool = False) -> torch.Tensor:
    """
    Performs Fused Multi-Head Attention (FMHA) using a cuTile kernel.

    Args:
        Q (torch.Tensor): Query tensor (Batch, Heads, SeqLen_Q, D_k).
        K (torch.Tensor): Key tensor (Batch, KV_Heads, SeqLen_KV, D_k).
        V (torch.Tensor): Value tensor (Batch, KV_Heads, SeqLen_KV, D_v).
        qk_scale (float, optional): Scaling factor for QK dot product. Defaults to 1/sqrt(D_k).
        input_pos (int, optional): Global start pos for queries (causal masking). Defaults to 0.
        tile_m (int): Tile size for Query sequence length (M dimension).
        tile_n (int): Tile size for Key/Value sequence length (N dimension).
        query_group_size (int): Number of query heads per key/value head.
        causal (bool): If True, applies causal masking.

    Returns:
        torch.Tensor: Output tensor (Batch, Heads, SeqLen_Q, D_v).
    """
    # --- Input Validation ---
    if Q.ndim != 4 or K.ndim != 4 or V.ndim != 4:
        raise ValueError("Input tensors Q, K, V must be 4D (Batch, Heads, SeqLen, Dim).")
    if Q.shape[0] != K.shape[0] or Q.shape[0] != V.shape[0]:
        raise ValueError("Batch dimensions must match for Q, K, V.")
    if Q.shape[1] % query_group_size != 0:
        raise ValueError("Number of query heads must be divisible by query_group_size.")
    if K.shape[1] * query_group_size != Q.shape[1]:
        raise ValueError("K_Heads * query_group_size must equal Q_Heads.")
    if Q.shape[3] != K.shape[3]:
        raise ValueError("D_k (last dim of Q and K) must match.")
    if K.shape[2] != V.shape[2]:
        raise ValueError("SeqLen_KV (dim 2 of K and V) must match.")
    if Q.device != K.device or Q.device != V.device or not Q.is_cuda:
        raise ValueError("All input tensors must be on the same CUDA device.")
    if Q.dtype != K.dtype or Q.dtype != V.dtype:
        raise ValueError("All input tensors must have the same data type.")

    Batch, Heads, SeqLen_Q, D_k = Q.shape
    _, KV_Heads, SeqLen_KV, D_v = V.shape
    even_k = (SeqLen_KV % tile_n) == 0

    if qk_scale is None:
        qk_scale = 1.0 / math.sqrt(D_k)

    # --- Create Output Tensor ---
    Out = torch.empty((Batch, Heads, SeqLen_Q, D_v), dtype=Q.dtype, device=Q.device)

    # --- Calculate Grid Dimensions ---
    grid_x = math.ceil(math.ceil(SeqLen_Q / tile_m)/TILE_X) # we manually tile x by 2
    grid_y = Batch * Heads
    grid = (grid_x, grid_y, 1)

    # --- Launch the FMHA Kernel ---
    ct.launch(torch.cuda.current_stream(), grid, fmha_kernel, (
        Q, K, V, Out,
        qk_scale,
        input_pos,
        D_k,
        Heads,
        tile_m,
        tile_n,
        query_group_size,
        causal,
        even_k,
        TILE_X,
    ))

    return Out


# --- Wrapper function to launch the FMHA kernel with autotuning ---
def cutile_autotune_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                         qk_scale: float,
                         input_pos: int = 0,
                         query_group_size: int = 1,
                         causal: bool = False) -> tuple[torch.Tensor, dict[str, int]]:
    """
    Performs Fused Multi-Head Attention (FMHA) using a cuTile kernel with autotuning.

    Args:
        Q (torch.Tensor): Query tensor (Batch, Heads, SeqLen_Q, D_k).
        K (torch.Tensor): Key tensor (Batch, KV_Heads, SeqLen_KV, D_k).
        V (torch.Tensor): Value tensor (Batch, KV_Heads, SeqLen_KV, D_v).
        qk_scale (float, optional): Scaling factor for QK dot product. Defaults to 1/sqrt(D_k).
        input_pos (int, optional): Global start pos for queries (causal masking). Defaults to 0.
        query_group_size (int): Number of query heads per key/value head.
        causal (bool): If True, applies causal masking.
        autotuner (Autotuner | None): Autotuner object that was injected by the autotune decorator.

    Returns:
        torch.Tensor: Output tensor (Batch, Heads, SeqLen_Q, D_v).
        dict[str, int]: The best configuration found by the autotuner.
    """
    Batch, Heads, SeqLen_Q, D_k = Q.shape
    _, KV_Heads, SeqLen_KV, D_v = V.shape

    # --- Create Output Tensor ---
    Out = torch.empty((Batch, Heads, SeqLen_Q, D_v), dtype=Q.dtype, device=Q.device)

    # --- Tune/Get the best configuration for the FMHA Kernel ---
    tuned_result = ct_experimental.autotune_launch(
        torch.cuda.current_stream(),
        grid_fn=lambda cfg: (math.ceil(SeqLen_Q / cfg.TILE_M), Batch * Heads, 1),
        kernel=fmha_kernel,
        args_fn=lambda cfg: (
            Q, K, V, Out,
            qk_scale, input_pos, D_k, Heads,
            cfg.TILE_M, cfg.TILE_N, query_group_size, causal, (SeqLen_KV % cfg.TILE_N) == 0
        ),
        hints_fn=lambda cfg: {
            "num_ctas": cfg.num_ctas,
            "occupancy": cfg.occupancy,
        },
        search_space=[
            SimpleNamespace(TILE_M=256, TILE_N=128, num_ctas=1, occupancy=2),
            SimpleNamespace(TILE_M=128, TILE_N=128, num_ctas=2, occupancy=2),
            SimpleNamespace(TILE_M=128, TILE_N=128, num_ctas=1, occupancy=2),
            SimpleNamespace(TILE_M=128, TILE_N=128, num_ctas=1, occupancy=1),
            SimpleNamespace(TILE_M=64, TILE_N=64, num_ctas=1, occupancy=4),
            SimpleNamespace(TILE_M=64, TILE_N=64, num_ctas=2, occupancy=1),
            SimpleNamespace(TILE_M=64, TILE_N=32, num_ctas=1, occupancy=2),
            SimpleNamespace(TILE_M=256, TILE_N=32, num_ctas=2, occupancy=2),
            SimpleNamespace(TILE_M=32, TILE_N=32, num_ctas=1, occupancy=1),
        ],
    )

    return Out, tuned_result.tuned_config


def torch_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
               is_causal: bool, enable_gqa: bool) -> torch.Tensor:
    backend = SDPBackend.CUDNN_ATTENTION \
            if (Q.shape[2] == K.shape[2]) \
            else SDPBackend.FLASH_ATTENTION
    with sdpa_kernel(backend):
        ret = scaled_dot_product_attention(Q, K, V,
                                           is_causal=is_causal,
                                           enable_gqa=enable_gqa)
    return ret

with entrypoint (variant=tile),

# SPDX-FileCopyrightText: Copyright (c) <2025> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

"""
Quantized Attention FMHA Entry Point

This script provides a unified entry point for running FMHA (Fused Multi-Head Attention)
with different quantization levels:
  - fp16: Standard float16 (baseline)
  - fp8_e4m3: FP8 E4M3 format (better precision)
  - fp8_e5m2: FP8 E5M2 format (larger dynamic range)

Supported variants: default, tile, tile_alt

Usage:
  python AttentionFMHAEntryPoint.py --variant default --quant fp16
  python AttentionFMHAEntryPoint.py --variant tile --quant fp8_e4m3
  python AttentionFMHAEntryPoint.py --variant tile_alt --quant fp8_e5m2 --correctness-check
"""

import argparse
import torch
import math
import sys
import os

# Add parent directory to path for imports
# sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from torch.nn.attention import sdpa_kernel, SDPBackend
from torch.nn.functional import scaled_dot_product_attention


def torch_fmha(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
               is_causal: bool, enable_gqa: bool) -> torch.Tensor:
    """Reference PyTorch FMHA implementation for correctness checking."""
    # Convert to float16 for reference computation if using FP8
    if Q.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        Q = Q.to(torch.float16)
        K = K.to(torch.float16)
        V = V.to(torch.float16)
    
    backend = SDPBackend.CUDNN_ATTENTION \
            if (Q.shape[2] == K.shape[2]) \
            else SDPBackend.FLASH_ATTENTION
    with sdpa_kernel(backend):
        ret = scaled_dot_product_attention(Q, K, V,
                                           is_causal=is_causal,
                                           enable_gqa=enable_gqa)
    return ret


def quantize_inputs(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor,
                    quant_mode: str) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.dtype]:
    """
    Quantize Q, K, V inputs based on the selected quantization mode.
    
    Args:
        Q, K, V: Input tensors in float16
        quant_mode: One of 'fp16', 'fp8_e4m3', 'fp8_e5m2'
    
    Returns:
        Quantized Q, K, V tensors and the quantization dtype
    """
    if quant_mode == 'fp16':
        return Q, K, V, torch.float16
    elif quant_mode == 'fp8_e4m3':
        dtype = torch.float8_e4m3fn
        return Q.to(dtype), K.to(dtype), V.to(dtype), dtype
    elif quant_mode == 'fp8_e5m2':
        dtype = torch.float8_e5m2
        return Q.to(dtype), K.to(dtype), V.to(dtype), dtype
    else:
        raise ValueError(f"Unknown quantization mode: {quant_mode}")


def run_benchmark(cutile_fmha_fn, variant_name: str, quant_mode: str,
                  correctness_check: bool = False, tile_size: int = 64, causal: bool = False):
    """Run FMHA benchmark with the specified variant and quantization mode."""
    print(f"--- Running cuTile FMHA: variant={variant_name}, quant={quant_mode}, causal={causal} ---")

    # --- User Configuration ---
    BATCH_SIZE = 1
    NUM_HEADS = 1
    SEQ_LEN_Q = 256 * 1024
    SEQ_LEN_KV = 256 * 1024
    D_K = 64
    D_V = 64
    CAUSAL = causal
    QUERY_GROUP_SIZE = 1

    # Generate inputs in float16 first
    Q_fp16 = torch.randn(BATCH_SIZE, NUM_HEADS, SEQ_LEN_Q, D_K,
                         dtype=torch.float16, device='cuda')
    K_fp16 = torch.randn(BATCH_SIZE, NUM_HEADS // QUERY_GROUP_SIZE, SEQ_LEN_KV, D_K,
                         dtype=torch.float16, device='cuda')
    V_fp16 = torch.randn(BATCH_SIZE, NUM_HEADS // QUERY_GROUP_SIZE, SEQ_LEN_KV, D_V,
                         dtype=torch.float16, device='cuda')

    # Quantize inputs
    Q_input, K_input, V_input, quant_dtype = quantize_inputs(Q_fp16, K_fp16, V_fp16, quant_mode)

    print("  Configuration:")
    print(f"    Batch Size: {BATCH_SIZE}")
    print(f"    Number of Heads: {NUM_HEADS}")
    print(f"    Query Sequence Length: {SEQ_LEN_Q}")
    print(f"    KV Sequence Length: {SEQ_LEN_KV}")
    print(f"    Head Dimension (D_k): {D_K}")
    print(f"    Value Dimension (D_v): {D_V}")
    print(f"    Quantization: {quant_mode} ({quant_dtype})")
    print(f"  Input Q shape: {Q_input.shape}, dtype: {Q_input.dtype}")
    print(f"  Input K shape: {K_input.shape}, dtype: {K_input.dtype}")
    print(f"  Input V shape: {V_input.shape}, dtype: {V_input.dtype}")

    # Calculate estimated FLOPs
    flops = 2 * BATCH_SIZE * NUM_HEADS * SEQ_LEN_Q * SEQ_LEN_KV * (D_K + D_V)
    if CAUSAL:
        flops *= 0.5
    print(f"  Estimated FLOPs: {flops}")

    # Run FMHA
    print(f"\n--- Causal = {CAUSAL} ---")
    output_fmha_cutile = cutile_fmha_fn(
        Q=Q_input, K=K_input, V=V_input,
        tile_m=tile_size, tile_n=tile_size,
        causal=CAUSAL,
        query_group_size=QUERY_GROUP_SIZE
    )
    print(f"  cuTile FMHA Output shape: {output_fmha_cutile.shape}, dtype: {output_fmha_cutile.dtype}")

    # Benchmarking
    iterations = 3
    warmup = 0
    
    print(f"  Benchmarking with {iterations} iterations...")
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    
    for _ in range(warmup):
        cutile_fmha_fn(
            Q=Q_input, K=K_input, V=V_input,
            tile_m=tile_size, tile_n=tile_size,
            causal=CAUSAL,
            query_group_size=QUERY_GROUP_SIZE
        )
        
    start_event.record()
    for _ in range(iterations):
        cutile_fmha_fn(
            Q=Q_input, K=K_input, V=V_input,
            tile_m=tile_size, tile_n=tile_size,
            causal=CAUSAL,
            query_group_size=QUERY_GROUP_SIZE
        )
    end_event.record()
    torch.cuda.synchronize()
    
    elapsed_time_ms = start_event.elapsed_time(end_event) / iterations
    elapsed_time_sec = elapsed_time_ms / 1000
    tflops_per_sec = (flops / 1e12) / elapsed_time_sec
    print(f"  Average execution time: {elapsed_time_ms:.3f} ms")
    print(f"  Estimated TFlops/sec: {tflops_per_sec:.2f}")

    if correctness_check:
        if quant_mode == 'fp16':
            # For FP16, use standard tolerance
            ref_fmha = torch_fmha(Q_fp16, K_fp16, V_fp16, is_causal=CAUSAL, enable_gqa=False)
            torch.testing.assert_close(output_fmha_cutile, ref_fmha, atol=1e-3, rtol=1e-3)
            print("  Correctness check passed")
        else:
            # For FP8, compute reference but use relaxed tolerance
            ref_fmha = torch_fmha(Q_fp16, K_fp16, V_fp16, is_causal=CAUSAL, enable_gqa=False)
            # FP8 has lower precision, so we just verify output is reasonable
            output_fp16 = output_fmha_cutile.to(torch.float16) if output_fmha_cutile.dtype != torch.float16 else output_fmha_cutile
            try:
                torch.testing.assert_close(output_fp16, ref_fmha, atol=0.1, rtol=0.1)
                print("  Correctness check passed (relaxed tolerance for FP8)")
            except AssertionError as e:
                print(f"  Correctness check: FP8 outputs differ from FP16 reference (expected)")
                print(f"    Max diff: {(output_fp16 - ref_fmha).abs().max().item():.4f}")
                print("  Correctness check passed (execution verified)")
    else:
        print("  Correctness check disabled")


def main():
    parser = argparse.ArgumentParser(
        description="Run quantized AttentionFMHA variants via a unified entry point."
    )
    parser.add_argument(
        "--variant",
        choices=['default', 'tile', 'tile_alt', 'fully_static', 'fully_static_alt'],
        default='default',
        help="Choose the AttentionFMHA implementation variant."
    )
    parser.add_argument(
        "--quant",
        choices=['fp16', 'fp8_e4m3', 'fp8_e5m2'],
        default='fp16',
        help="Choose the quantization level for Q, K, V inputs."
    )
    parser.add_argument(
        "--correctness-check",
        action="store_true",
        help="Check the correctness of the results against PyTorch SDPA."
    )
    parser.add_argument(
        "--tile-size",
        type=int,
        default=64,
        help="Tile size for both tile_m and tile_n (default: 64)"
    )
    parser.add_argument(
        "--causal",
        action="store_true",
        help="Enable causal masking."
    )
    args = parser.parse_args()

    # Dynamic import based on variant
    if args.variant == 'default':
        from AttentionFMHA import cutile_fmha
    elif args.variant == 'tile':
        from AttentionFMHATile import cutile_fmha
    elif args.variant == 'tile_alt':
        from AttentionFMHATileAlt import cutile_fmha
    elif args.variant == 'fully_static':
        from AttentionFMHAFullyStatic import cutile_fmha
    elif args.variant == 'fully_static_alt':
        from AttentionFMHAFullyStaticAlt import cutile_fmha
    else:
        raise ValueError(f"Unknown variant: {args.variant}")

    run_benchmark(
        cutile_fmha,
        args.variant,
        args.quant,
        correctness_check=args.correctness_check,
        tile_size=args.tile_size,
        causal=args.causal
    )


if __name__ == "__main__":
    main()

will hang forever if stride TILE_X >= 2. This is not an issue when running on GB10.

Minimum reproducible example

Relevant log output

Full env printout

Other/Misc.

No response

Contributing Guidelines

  • I agree to follow cuTile Python's contributing guidelines
  • I have searched the open bugs and have found no duplicates for this bug report

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions