| import contextlib |
| import functools |
| import inspect |
| from enum import Enum |
| from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union |
|
|
| import torch |
|
|
| |
| |
| |
| import torch.autograd |
| from diffusers.utils.import_utils import OptionalDependencyNotAvailable |
| from torch.nn.functional import scaled_dot_product_attention as native_sdpa |
|
|
| from finetrainers.constants import FINETRAINERS_ATTN_CHECKS, FINETRAINERS_ATTN_PROVIDER |
| from finetrainers.logging import get_logger |
| from finetrainers.utils.import_utils import ( |
| is_flash_attn_available, |
| is_flash_attn_version, |
| is_sageattention_available, |
| is_sageattention_version, |
| is_torch_version, |
| is_xformers_available, |
| is_xformers_version, |
| ) |
|
|
|
|
| if is_flash_attn_available(): |
| if is_flash_attn_version("<", "2.6.3"): |
| raise OptionalDependencyNotAvailable( |
| "The `flash-attn` library version is too old. Please update it to at least 2.6.3." |
| ) |
|
|
| from flash_attn import flash_attn_func, flash_attn_varlen_func |
| from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward |
| else: |
| flash_attn_func = None |
| flash_attn_varlen_func = None |
| _flash_attn_forward = None |
| _flash_attn_backward = None |
|
|
|
|
| if is_sageattention_available(): |
| if is_sageattention_version("<", "2.1.1"): |
| raise OptionalDependencyNotAvailable( |
| "The `sageattention` library version is too old. Please update it to at least 2.1.1." |
| ) |
|
|
| from sageattention import ( |
| sageattn, |
| sageattn_qk_int8_pv_fp8_cuda, |
| sageattn_qk_int8_pv_fp8_cuda_sm90, |
| sageattn_qk_int8_pv_fp16_cuda, |
| sageattn_qk_int8_pv_fp16_triton, |
| sageattn_varlen, |
| ) |
| else: |
| sageattn = None |
| sageattn_qk_int8_pv_fp16_cuda = None |
| sageattn_qk_int8_pv_fp16_triton = None |
| sageattn_qk_int8_pv_fp8_cuda = None |
| sageattn_qk_int8_pv_fp8_cuda_sm90 = None |
| sageattn_varlen = None |
|
|
|
|
| if is_torch_version(">=", "2.5.0"): |
| import torch.nn.attention.flex_attention as flex_attention |
|
|
|
|
if is_torch_version(">=", "2.6.0"):
    from torch.distributed.tensor.experimental._attention import (
        _AttentionOp,
        _cp_options,
        _templated_ring_attention,
        _templated_ring_attention_backward,
        set_rotate_method,
    )
else:
    # The context-parallel (ring attention) helpers only exist in PyTorch >= 2.6. Bind every name imported above to a
    # placeholder so this module still imports on older versions; the ring-attention code paths must not run there.
    _cp_options = None
    _templated_ring_attention = None
    _templated_ring_attention_backward = None
    set_rotate_method = None

    class _AttentionOp:
        """Placeholder that fails loudly if instantiated on PyTorch < 2.6."""

        def __init__(self, *args, **kwargs):
            raise OptionalDependencyNotAvailable(
                "The `torch.distributed.tensor.experimental._attention` module is not available. Please update PyTorch to at least 2.6.0."
            )
|
|
|
|
| if is_xformers_available(): |
| if is_xformers_version("<", "0.0.29"): |
| raise OptionalDependencyNotAvailable( |
| "The `xformers` library version is too old. Please update it to at least 0.0.29." |
| ) |
|
|
| import xformers.ops as xops |
| else: |
| xops = None |
|
|
|
|
# Module-level logger used by the registry and the dispatcher.
logger = get_logger()


# Allowed values for the sageattention PV-accumulation dtype option (presumably forwarded to the sage kernels
# imported above -- confirm against the sageattention API).
_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
# Allowed QK quantization granularities for sageattention (presumably its `qk_quant_gran` option -- confirm).
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
# Allowed quantization kernel backends; the names mirror the `*_cuda` / `*_triton` sage kernels imported above.
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
|
|
|
|
| |
|
|
|
|
| def _finetrainers_scaled_dot_product_efficient_attention_forward( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_bias: Optional[torch.Tensor] = None, |
| compute_log_sumexp: bool = False, |
| dropout_p: float = 0.0, |
| is_causal: bool = False, |
| scale: Optional[float] = None, |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| |
| |
| seqlen_q = query.shape[-2] |
| out, lse, philox_seed, philox_offset = torch.ops.aten._scaled_dot_product_efficient_attention( |
| query=query, |
| key=key, |
| value=value, |
| attn_bias=attn_bias, |
| compute_log_sumexp=compute_log_sumexp, |
| dropout_p=dropout_p, |
| is_causal=is_causal, |
| scale=scale, |
| ) |
|
|
| |
| |
| if compute_log_sumexp: |
| assert lse.ndim == 3 |
| lse = lse[:, :, :seqlen_q] |
|
|
| return out, lse, philox_seed, philox_offset |
|
|
|
|
| |
| def _finetrainers_scaled_dot_product_efficient_attention_backward( |
| grad_out_: torch.Tensor, |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_bias: torch.Tensor, |
| out: torch.Tensor, |
| logsumexp: torch.Tensor, |
| philox_seed: torch.Tensor, |
| philox_offset: torch.Tensor, |
| dropout_p: float, |
| grad_input_mask: List[bool], |
| is_causal: bool = False, |
| scale: Optional[float] = None, |
| ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: |
| assert len(grad_input_mask) == 4 |
| |
| kAlignLSE = 32 |
|
|
| logsumexp = torch.nn.functional.pad( |
| logsumexp, (0, kAlignLSE - (logsumexp.shape[-1] % kAlignLSE)), value=float("inf") |
| ) |
|
|
| grad_query, grad_key, grad_value, grad_attn_bias = torch.ops.aten._scaled_dot_product_efficient_attention_backward( |
| grad_out_=grad_out_, |
| query=query, |
| key=key, |
| value=value, |
| attn_bias=attn_bias, |
| out=out, |
| logsumexp=logsumexp, |
| philox_seed=philox_seed, |
| philox_offset=philox_offset, |
| dropout_p=dropout_p, |
| grad_input_mask=grad_input_mask, |
| is_causal=is_causal, |
| scale=scale, |
| ) |
|
|
| return grad_query, grad_key, grad_value, grad_attn_bias |
|
|
|
|
| |
| def _finetrainers_flash_attn_forward( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| dropout_p: float = 0.0, |
| scale: Optional[float] = None, |
| is_causal: bool = False, |
| window_size: Tuple[int, int] = (-1, -1), |
| softcap: float = 0.0, |
| alibi_slopes: Optional[torch.Tensor] = None, |
| return_softmax: bool = False, |
| ): |
| query, key, value = ( |
| x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value) |
| ) |
| out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( |
| query, key, value, dropout_p, scale, is_causal, window_size, softcap, alibi_slopes, return_softmax |
| ) |
| out = out.permute(0, 2, 1, 3).contiguous() |
| return out, softmax_lse, q, k, v, out_padded, S_dmask, rng_state |
|
|
|
|
| |
| def _finetrainers_flash_attn_backward( |
| grad_out: torch.Tensor, |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| out: torch.Tensor, |
| logsumexp: torch.Tensor, |
| dropout_p: float, |
| scale: Optional[float] = None, |
| is_causal: bool = False, |
| window_size: Tuple[int, int] = (-1, -1), |
| softcap: float = 0.0, |
| alibi_slopes: Optional[torch.Tensor] = None, |
| deterministic: bool = False, |
| rng_state: Optional[torch.Tensor] = None, |
| _permute_outputs: bool = True, |
| ): |
| dq, dk, dv = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value) |
| grad_out = grad_out.permute(0, 2, 1, 3).contiguous() |
|
|
| dq, dk, dv, softmax_d = _flash_attn_backward( |
| grad_out, |
| query, |
| key, |
| value, |
| out, |
| logsumexp, |
| dq, |
| dk, |
| dv, |
| dropout_p, |
| scale, |
| is_causal, |
| window_size, |
| softcap, |
| alibi_slopes, |
| deterministic, |
| rng_state, |
| ) |
|
|
| |
| dq = dq[..., : grad_out.shape[-1]] |
| dk = dk[..., : grad_out.shape[-1]] |
| dv = dv[..., : grad_out.shape[-1]] |
|
|
| if _permute_outputs: |
| dq, dk, dv = (x.permute(0, 2, 1, 3).contiguous() for x in (dq, dk, dv)) |
| return dq, dk, dv |
|
|
|
|
| |
|
|
|
|
class AttentionProvider(str, Enum):
    """Names of the attention backends that `attention_dispatch` can route to.

    Members subclass ``str``, so a provider can be built from its string value. The registry does exactly
    that with ``AttentionProvider(FINETRAINERS_ATTN_PROVIDER)``. Underscore-prefixed members appear to
    select a specific kernel variant rather than a library's default entry point.
    """

    # `flash-attn` library (implemented by `flash_attn_flash_attention` and `_flash_varlen_attention`).
    FLASH = "flash"
    FLASH_VARLEN = "flash_varlen"

    # PyTorch-native backends: flex attention, `scaled_dot_product_attention`, and (presumably) individual SDPA
    # kernels such as cuDNN (see `_native_cudnn_attention`).
    FLEX = "flex"
    NATIVE = "native"
    _NATIVE_CUDNN = "_native_cudnn"
    _NATIVE_EFFICIENT = "_native_efficient"
    _NATIVE_FLASH = "_native_flash"
    _NATIVE_MATH = "_native_math"

    # `sageattention` library; the underscore-prefixed names mirror the individual kernels imported from it.
    SAGE = "sage"
    SAGE_VARLEN = "sage_varlen"
    _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
    _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
    _SAGE_QK_INT8_PV_FP16_CUDA = "_sage_qk_int8_pv_fp16_cuda"
    _SAGE_QK_INT8_PV_FP16_TRITON = "_sage_qk_int8_pv_fp16_triton"

    # `xformers` library.
    XFORMERS = "xformers"
|
|
|
|
class _AttentionProviderRegistry:
    """Process-wide registry of attention implementations, keyed by `AttentionProvider`.

    All state lives in class attributes, so the registry is shared global state:

    * implementations are added through the `register` decorator;
    * `attention_dispatch` looks up the active one;
    * `attention_provider` temporarily swaps the active provider and the context-parallel settings.
    """

    # provider -> implementation function
    _providers = {}
    # provider -> list of constraint callables run before dispatch (possibly empty)
    _constraints = {}
    # provider -> whether the implementation may run under context parallelism
    _supports_cp = {}
    # provider -> set of parameter names the implementation accepts
    _supported_arg_names = {}

    # Initial provider and the constraint-check switch come from finetrainers constants.
    _active_provider = AttentionProvider(FINETRAINERS_ATTN_PROVIDER)
    _checks_enabled = FINETRAINERS_ATTN_CHECKS

    # Context-parallel settings; a `None` mesh means context parallelism is disabled.
    _mesh: torch.distributed.device_mesh.DeviceMesh = None
    _convert_to_fp32: bool = None
    _rotate_method: Literal["allgather", "alltoall"] = None

    @classmethod
    def register(
        cls, provider: AttentionProvider, constraints: Optional[List[Callable]] = None, supports_cp: bool = False
    ):
        """Return a decorator that records the decorated function as ``provider``'s implementation.

        The function is returned unchanged. Its parameter names are captured so the dispatcher can drop
        keyword arguments the implementation does not accept.
        """
        logger.debug(f"Registering attention provider: {provider}")

        def _record(func):
            cls._providers[provider] = func
            cls._constraints[provider] = constraints if constraints else []
            cls._supports_cp[provider] = supports_cp
            cls._supported_arg_names[provider] = set(inspect.signature(func).parameters)
            return func

        return _record

    @classmethod
    def get_active_provider(cls):
        """Return ``(provider, implementation)`` for the currently active provider."""
        active = cls._active_provider
        return active, cls._providers[active]

    @classmethod
    def list_providers(cls):
        """Return the registered providers in registration order."""
        return [*cls._providers]

    @classmethod
    def supports_context_parallel(cls, provider: AttentionProvider):
        """Return whether ``provider`` was registered with ``supports_cp=True``.

        Raises ``ValueError`` for unknown providers.
        """
        if provider in cls._providers:
            return cls._supports_cp.get(provider, False)
        raise ValueError(f"Provider {provider} is not registered.")

    @classmethod
    def context_parallel_enabled(cls):
        """Return whether a device mesh is currently configured."""
        return cls._mesh is not None

    @classmethod
    def _set_context_parallel(
        cls,
        mesh: torch.distributed.device_mesh.DeviceMesh = None,
        convert_to_fp32: bool = None,
        rotate_method: str = None,
        *,
        reset: bool = False,
    ):
        """Overwrite all three context-parallel settings; ``reset=True`` clears them to ``None``."""
        if reset:
            mesh, convert_to_fp32, rotate_method = None, None, None
        cls._mesh, cls._convert_to_fp32, cls._rotate_method = mesh, convert_to_fp32, rotate_method

    @classmethod
    def _raise_cp_error_if_mesh_not_set(cls):
        """Raise ``ValueError`` unless a device mesh has been configured."""
        if cls._mesh is not None:
            return
        raise ValueError(
            "`_AttentionProviderRegistry._mesh` is None. It must be set before calling context parallel attention methods."
        )
|
|
|
|
@contextlib.contextmanager
def attention_provider(
    provider: AttentionProvider = AttentionProvider.NATIVE,
    *,
    mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
    convert_to_fp32: bool = True,
    rotate_method: str = "allgather",
):
    """Context manager to set the active attention provider and possibly enable context parallelism.

    Args:
        provider: Registered provider to make active inside the ``with`` block.
        mesh: Device mesh for context-parallel (ring) attention. ``None`` disables context parallelism inside
            the block.
        convert_to_fp32: Stored on the registry. It is copied into ``_cp_options.convert_to_f32`` when
            attention is dispatched.
        rotate_method: Stored on the registry. It is passed to ``set_rotate_method`` when attention is
            dispatched.

    Raises:
        ValueError: If ``provider`` is not registered, or ``mesh`` is given for a provider without
            context-parallel support.
        OptionalDependencyNotAvailable: If ``mesh`` is given but PyTorch's context-parallel module is
            unavailable (torch < 2.6).

    On exit the following are restored, so nested uses compose correctly:

    * the previous provider;
    * the previous registry context-parallel settings;
    * when ``mesh`` was given, the previous ``_cp_options`` values.
    """
    # Validate everything before touching global state so a failure leaves the registry unchanged.
    if provider not in _AttentionProviderRegistry._providers:
        raise ValueError(f"Provider {provider} is not registered.")
    if mesh is not None and not _AttentionProviderRegistry.supports_context_parallel(provider):
        raise ValueError(f"Provider {provider} does not support context parallelism.")
    if mesh is not None and _cp_options is None:
        raise OptionalDependencyNotAvailable(
            "Context parallel attention requires `torch.distributed.tensor.experimental._attention`. "
            "Please update PyTorch to at least 2.6.0."
        )

    registry = _AttentionProviderRegistry

    # Snapshot everything about to be overwritten (rather than resetting to None on exit) so an enclosing
    # `attention_provider(...)` context keeps its own settings after an inner one exits.
    old_provider = registry._active_provider
    old_mesh = registry._mesh
    old_convert_to_fp32 = registry._convert_to_fp32
    old_rotate_method = registry._rotate_method
    if mesh is not None:
        old_cp_state = (_cp_options.convert_to_f32, _cp_options.enable_load_balance, _cp_options.rotate_method)

    registry._active_provider = provider
    registry._set_context_parallel(mesh, convert_to_fp32, rotate_method)

    try:
        yield
    finally:
        registry._active_provider = old_provider
        registry._set_context_parallel(old_mesh, old_convert_to_fp32, old_rotate_method)
        if mesh is not None:
            _cp_options.convert_to_f32, _cp_options.enable_load_balance, _cp_options.rotate_method = old_cp_state
|
|
|
|
def attention_dispatch(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
    """Run attention through the currently active provider.

    The standard SDPA arguments are merged with ``attention_kwargs``; entries in ``attention_kwargs`` win
    on key collisions.

    When ``_AttentionProviderRegistry._checks_enabled`` is set:

    * a rate-limited warning names the arguments the provider does not accept;
    * every registered constraint for the provider runs on the full argument set.

    Only arguments the provider accepts are forwarded. When context parallelism is enabled, the global CP
    options are updated from those arguments before the call.
    """
    registry = _AttentionProviderRegistry
    provider_name, provider_fn = registry.get_active_provider()

    sdpa_args = {
        "query": query,
        "key": key,
        "value": value,
        "attn_mask": attn_mask,
        "dropout_p": dropout_p,
        "is_causal": is_causal,
        "scale": scale,
        "enable_gqa": enable_gqa,
    }
    call_kwargs = {**sdpa_args, **(attention_kwargs or {})}
    accepted_names = registry._supported_arg_names[provider_name]

    if registry._checks_enabled:
        dropped_names = set(call_kwargs) - set(accepted_names)
        if dropped_names:
            log_freq = 512
            warning_text = (
                f"Removing unsupported arguments for attention provider {provider_name}: {dropped_names}. This "
                f"message will be logged every {log_freq} calls."
            )
            logger.log_freq("WARNING", "REMOVING_ATTN_UNSUPPORTED_KWARGS", warning_text, log_freq)
        for constraint in registry._constraints.get(provider_name):
            constraint(**call_kwargs)

    forwarded = {name: arg for name, arg in call_kwargs.items() if name in accepted_names}

    if registry.context_parallel_enabled():
        _set_context_parallel_options(**forwarded)

    return provider_fn(**forwarded)
|
|
|
|
| |
|
|
|
|
| |
| def _set_context_parallel_options(is_causal: bool, **kwargs): |
| _cp_options.enable_load_balance = is_causal |
| _cp_options.convert_to_f32 = _AttentionProviderRegistry._convert_to_fp32 |
| set_rotate_method(_AttentionProviderRegistry._rotate_method) |
|
|
|
|
| def _check_attn_mask_is_none(attn_mask: Optional[torch.Tensor], **kwargs) -> None: |
| if attn_mask is not None: |
| raise ValueError("Attention mask must be None for this provider.") |
|
|
|
|
| def _check_attn_mask_or_causal(attn_mask: Optional[torch.Tensor], is_causal: bool, **kwargs) -> None: |
| if attn_mask is not None and is_causal: |
| raise ValueError("`is_causal` cannot be True when `attn_mask` is not None.") |
|
|
|
|
| def _check_device(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: |
| if query.device != key.device or query.device != value.device: |
| raise ValueError("Query, key, and value must be on the same device.") |
| if query.dtype != key.dtype or query.dtype != value.dtype: |
| raise ValueError("Query, key, and value must have the same dtype.") |
|
|
|
|
| def _check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: |
| _check_device(query, key, value) |
| if query.device.type != "cuda": |
| raise ValueError("Query, key, and value must be on a CUDA device.") |
|
|
|
|
| def _check_device_cuda_atleast_smXY(major: int, minor: int) -> Callable: |
| def check_device_cuda(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: |
| _check_device_cuda(query, key, value) |
| if torch.cuda.get_device_capability(query.device) < (major, minor): |
| raise ValueError( |
| f"Query, key, and value must be on a CUDA device with compute capability >= {major}.{minor}." |
| ) |
|
|
| return check_device_cuda |
|
|
|
|
| def _check_qkv_dtype_match(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: |
| if query.dtype != key.dtype: |
| raise ValueError("Query and key must have the same dtype.") |
| if query.dtype != value.dtype: |
| raise ValueError("Query and value must have the same dtype.") |
|
|
|
|
| def _check_qkv_dtype_bf16_or_fp16(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs) -> None: |
| _check_qkv_dtype_match(query, key, value) |
| if query.dtype not in (torch.bfloat16, torch.float16): |
| raise ValueError("Query, key, and value must be either bfloat16 or float16.") |
|
|
|
|
| def _check_shape( |
| query: torch.Tensor, |
| key: torch.Tensor, |
| value: torch.Tensor, |
| attn_mask: Optional[torch.Tensor] = None, |
| **kwargs, |
| ) -> None: |
| if query.shape[-1] != key.shape[-1]: |
| raise ValueError("Query and key must have the same last dimension.") |
| if query.shape[-2] != value.shape[-2]: |
| raise ValueError("Query and value must have the same second to last dimension.") |
| if attn_mask is not None and attn_mask.shape[-1] != key.shape[-2]: |
| raise ValueError("Attention mask must match the key's second to last dimension.") |
|
|
|
|
| def _prepare_for_flash_attn_or_sage_varlen( |
| batch_size: int, |
| seq_len_q: int, |
| seq_len_kv: int, |
| attn_mask: Optional[torch.Tensor] = None, |
| device: Optional[torch.device] = None, |
| ) -> None: |
| seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device) |
| if attn_mask is None: |
| seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device) |
| else: |
| seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32) |
| cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) |
| cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device) |
| cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) |
| cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0) |
| max_seqlen_q = seqlens_q.max().item() |
| max_seqlen_k = seqlens_k.max().item() |
| return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) |
|
|
|
|
| def _normalize_attn_mask(attn_mask: torch.Tensor, batch_size: int, seq_len_k: int) -> torch.Tensor: |
| """ |
| Normalize an attention mask to shape [batch_size, seq_len_k] (bool) suitable for inferring seqlens_k in |
| FlashAttention/Sage varlen. |
| |
| Supports 1D to 4D shapes and common broadcasting patterns. |
| """ |
| if attn_mask.dtype != torch.bool: |
| raise ValueError(f"Attention mask must be of type bool, got {attn_mask.dtype}.") |
|
|
| if attn_mask.ndim == 1: |
| |
| attn_mask = attn_mask.unsqueeze(0).expand(batch_size, seq_len_k) |
|
|
| elif attn_mask.ndim == 2: |
| |
| if attn_mask.size(0) not in [1, batch_size]: |
| raise ValueError( |
| f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 2D attention mask." |
| ) |
| attn_mask = attn_mask.expand(batch_size, seq_len_k) |
|
|
| elif attn_mask.ndim == 3: |
| |
| if attn_mask.size(0) not in [1, batch_size]: |
| raise ValueError( |
| f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 3D attention mask." |
| ) |
| attn_mask = attn_mask.any(dim=1) |
| attn_mask = attn_mask.expand(batch_size, seq_len_k) |
|
|
| elif attn_mask.ndim == 4: |
| |
| if attn_mask.size(0) not in [1, batch_size]: |
| raise ValueError( |
| f"attn_mask.shape[0] ({attn_mask.shape[0]}) must be 1 or {batch_size} for 4D attention mask." |
| ) |
| attn_mask = attn_mask.expand(batch_size, -1, -1, seq_len_k) |
| attn_mask = attn_mask.any(dim=(1, 2)) |
|
|
| else: |
| raise ValueError(f"Unsupported attention mask shape: {attn_mask.shape}") |
|
|
| if attn_mask.shape != (batch_size, seq_len_k): |
| raise ValueError( |
| f"Normalized attention mask shape mismatch: got {attn_mask.shape}, expected ({batch_size}, {seq_len_k})" |
| ) |
|
|
| return attn_mask |
|
|
|
|
| def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): |
| return q_idx >= kv_idx |
|
|
|
|
| |
|
|
|
|
| |
class _flash_attn_flash_attention(torch.autograd.Function):
    """Autograd wrapper around flash-attn's low-level kernels via the `_finetrainers_flash_attn_*` helpers.

    ``forward`` returns ``out``, or ``(out, lse)`` when ``return_softmax`` is true. ``backward`` returns
    gradients for ``q``, ``k`` and ``v`` only; every other ``forward`` argument gets ``None``.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        dropout_p: float = 0.0,
        softmax_scale: Optional[float] = None,
        causal: bool = False,
        window_size: Tuple[int, int] = (-1, -1),
        softcap: float = 0.0,
        alibi_slopes: Optional[torch.Tensor] = None,
        deterministic: bool = False,
        return_softmax: bool = False,
    ):
        # Default scale is 1 / sqrt(q.shape[-1]).
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        # Non-tensor options are kept on ctx for the backward pass.
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic

        # NOTE(review): `return_softmax` is also forwarded to the kernel, which may make it materialize S_dmask;
        # S_dmask is discarded here.
        out, lse, q, k, v, out_padded, S_dmask, rng_state = _finetrainers_flash_attn_forward(
            query=q,
            key=k,
            value=v,
            dropout_p=dropout_p,
            scale=softmax_scale,
            is_causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=return_softmax,
        )

        # The q/k/v saved here are the permuted copies returned by the helper (kernel layout), and out_padded is
        # saved instead of the permuted-back `out`; `_finetrainers_flash_attn_backward` expects that layout.
        ctx.save_for_backward(q, k, v, out_padded, lse, rng_state)

        return (out, lse) if return_softmax else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        # Any incoming gradient for the returned lse (in *args) is ignored.
        q, k, v, out, lse, rng_state = ctx.saved_tensors

        grad_query, grad_key, grad_value = _finetrainers_flash_attn_backward(
            grad_out=grad_out,
            query=q,
            key=k,
            value=v,
            out=out,
            logsumexp=lse,
            dropout_p=ctx.dropout_p,
            scale=ctx.softmax_scale,
            is_causal=ctx.causal,
            window_size=ctx.window_size,
            softcap=ctx.softcap,
            alibi_slopes=ctx.alibi_slopes,
            deterministic=ctx.deterministic,
            rng_state=rng_state,
        )

        # One slot per forward input after ctx: grads for q, k, v and None for the 8 remaining arguments.
        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
|
|
|
|
| |
class _native_ring_flash_attn_flash_attention(torch.autograd.Function):
    """Context-parallel (ring) variant of `_flash_attn_flash_attention`.

    The flash-attn helpers are driven through PyTorch's experimental ``_templated_ring_attention`` and
    ``_templated_ring_attention_backward``. Both use the device mesh stored on
    ``_AttentionProviderRegistry._mesh``. ``backward`` returns gradients for ``q``, ``k`` and ``v`` only.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        dropout_p: float = 0.0,
        softmax_scale: Optional[float] = None,
        causal: bool = False,
        window_size: Tuple[int, int] = (-1, -1),
        softcap: float = 0.0,
        alibi_slopes: Optional[torch.Tensor] = None,
        deterministic: bool = False,
        return_softmax: bool = False,
    ):
        # Default scale is 1 / sqrt(q.shape[-1]).
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        # NOTE(review): a zero dropout is replaced by a tiny positive value. The rationale is not visible here
        # (possibly so the kernel always produces a usable rng_state across ring steps) -- confirm.
        dropout_p = dropout_p if dropout_p > 0 else 1e-30

        # Non-tensor options are kept on ctx for the backward pass.
        ctx.dropout_p = dropout_p
        ctx.softmax_scale = softmax_scale
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic

        # NOTE(review): q/k/v/out_padded/rng_state are whatever `_templated_ring_attention` passes through from the
        # wrapped op (its implementation is not visible here); confirm they are the tensors the ring backward
        # expects for this rank.
        out, lse, q, k, v, out_padded, S_dmask, rng_state = _templated_ring_attention(
            mesh=_AttentionProviderRegistry._mesh,
            seq_dim=2,
            op=_finetrainers_flash_attn_forward,
            query=q,
            key=k,
            value=v,
            dropout_p=dropout_p,
            scale=softmax_scale,
            is_causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            return_softmax=True,
        )

        ctx.save_for_backward(q, k, v, out_padded, lse, rng_state)

        return (out, lse) if return_softmax else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        q, k, v, out, lse, rng_state = ctx.saved_tensors
        # NOTE(review): lse has dims 1 and 2 swapped before the ring backward; the layout it needs is defined by
        # `_templated_ring_attention_backward`, which is not visible here -- confirm.
        lse = lse.permute(0, 2, 1).contiguous()

        grad_query, grad_key, grad_value = _templated_ring_attention_backward(
            mesh=_AttentionProviderRegistry._mesh,
            # The saved q/k/v were permuted (dims 1 and 2 swapped) by `_finetrainers_flash_attn_forward`, so the
            # sequence dim is presumably 1 here instead of 2 as in forward -- TODO confirm.
            seq_dim=1,
            op=functools.partial(_finetrainers_flash_attn_backward, _permute_outputs=False),
            grad_out=grad_out,
            grad_out_name="grad_out",
            query=q,
            key=k,
            value=v,
            out=out,
            logsumexp=lse,
            dropout_p=ctx.dropout_p,
            scale=ctx.softmax_scale,
            is_causal=ctx.is_causal if False else ctx.causal,
            window_size=ctx.window_size,
            softcap=ctx.softcap,
            alibi_slopes=ctx.alibi_slopes,
            deterministic=ctx.deterministic,
            rng_state=rng_state,
        )
        # The op ran with `_permute_outputs=False`, so swap dims 1 and 2 of the gradients back here.
        grad_query, grad_key, grad_value = (
            x.permute(0, 2, 1, 3).contiguous() for x in (grad_query, grad_key, grad_value)
        )

        # One slot per forward input after ctx: grads for q, k, v and None for the 8 remaining arguments.
        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider.FLASH,
    constraints=[_check_attn_mask_is_none, _check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_cp=True,
)
def flash_attn_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    deterministic: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    """`AttentionProvider.FLASH` implementation backed by flash-attn.

    The ring (context-parallel) autograd function is used when a device mesh is configured; otherwise the
    plain one. ``return_lse`` feeds the autograd functions' ``return_softmax`` slot, so a true value
    yields ``(out, lse)``.
    """
    if _AttentionProviderRegistry.context_parallel_enabled():
        autograd_fn = _native_ring_flash_attn_flash_attention
    else:
        autograd_fn = _flash_attn_flash_attention

    # Positional order must match the autograd functions' forward signature (after ctx).
    forward_args = (
        query,
        key,
        value,
        dropout_p,
        scale,
        is_causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_lse,
    )
    return autograd_fn.apply(*forward_args)
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider.FLASH_VARLEN,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_cp=False,
)
def _flash_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    dropout_p: float = 0.0,
    scale: Optional[float] = None,
    is_causal: bool = False,
    window_size: Tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    alibi_slopes: Optional[torch.Tensor] = None,
    deterministic: bool = False,
    return_attn_probs: bool = False,
    attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """`AttentionProvider.FLASH_VARLEN` implementation using ``flash_attn_varlen_func``.

    Inputs are unpacked as ``batch, _, seq, _ = tensor.shape``. Output has the same layout as ``query``.

    Sequence metadata is derived in one of two ways:

    * If any of ``cu_seqlens_q``/``cu_seqlens_k``/``max_seqlen_q``/``max_seqlen_k`` is missing, it is
      built from ``attn_mask`` (normalized to ``[batch, seq_len_kv]``).
    * Otherwise the caller's values are used, and per-sample key lengths come from ``cu_seqlens_k``.

    Keys/values are packed by keeping the first ``seqlens_k[b]`` positions of each sample. All query
    positions are kept.

    Returns:
        The attention output, or ``(out, lse)`` when ``return_attn_probs`` is true.
    """
    batch_size, _, seq_len_q, _ = query.shape
    _, _, seq_len_kv, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
            _prepare_for_flash_attn_or_sage_varlen(
                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
            )
        )
    else:
        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
        # The packed keys must agree with the caller's offsets. Packing `max_seqlen_k` rows per sample
        # would misalign every sample after the first whenever the key lengths vary.
        seqlens_k = torch.diff(cu_seqlens_k)

    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))

    # A single host transfer instead of one device->host sync per batch element.
    valid_lens_k = seqlens_k.tolist()
    key_packed = torch.cat([key[b, :n] for b, n in enumerate(valid_lens_k)], dim=0)
    value_packed = torch.cat([value[b, :n] for b, n in enumerate(valid_lens_k)], dim=0)
    query_packed = query.flatten(0, 1)

    if _AttentionProviderRegistry.context_parallel_enabled():
        return_attn_probs = True

    out = flash_attn_varlen_func(
        q=query_packed,
        k=key_packed,
        v=value_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=dropout_p,
        softmax_scale=scale,
        causal=is_causal,
        window_size=window_size,
        softcap=softcap,
        alibi_slopes=alibi_slopes,
        deterministic=deterministic,
        return_attn_probs=return_attn_probs,
    )

    rest = []
    if return_attn_probs:
        out, *rest = out
    out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3)
    if return_attn_probs:
        return out, *rest[:1]
    return out
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider.FLEX,
    constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
    supports_cp=False,
)
def _native_flex_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[Union[torch.Tensor, "flex_attention.BlockMask"]] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    kernel_options: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
    """`AttentionProvider.FLEX` implementation using ``torch.nn.attention.flex_attention``.

    Masking, in order of precedence:

    * a ``BlockMask`` is used as-is;
    * ``is_causal=True`` builds a causal block mask. This also applies when ``attn_mask`` is None;
    * a tensor mask is broadcast to ``[batch, heads, seq_q, seq_kv]``. A 2D tensor is read as a
      ``[batch, seq_kv]`` key mask, matching `_check_shape`. Boolean masks become a block mask, and any
      other dtype is added to the scores via ``score_mod``.

    ``kernel_options`` is forwarded to ``flex_attention``. ``dropout_p`` is accepted for interface parity
    but is not forwarded.

    Raises:
        ValueError: If ``attn_mask`` is neither None, a BlockMask, nor a tensor.
    """
    score_mod = None
    block_mask = None
    batch_size, num_heads, seq_len_q, _ = query.shape
    _, _, seq_len_kv, _ = key.shape

    if isinstance(attn_mask, flex_attention.BlockMask):
        block_mask = attn_mask
    elif is_causal:
        # Checked before the `attn_mask is None` case so that `is_causal=True` without a mask is honored.
        block_mask = flex_attention.create_block_mask(
            _flex_attention_causal_mask_mod, None, None, seq_len_q, seq_len_kv, query.device
        )
    elif torch.is_tensor(attn_mask):
        if attn_mask.ndim == 2:
            # [batch, seq_kv] key mask -> broadcast over heads and query positions.
            attn_mask = attn_mask.view(attn_mask.size(0), 1, 1, attn_mask.size(1))

        attn_mask = attn_mask.expand(batch_size, num_heads, seq_len_q, seq_len_kv)

        if attn_mask.dtype == torch.bool:

            def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
                return attn_mask[batch_idx, head_idx, q_idx, kv_idx]

            block_mask = flex_attention.create_block_mask(
                mask_mod, batch_size, None, seq_len_q, seq_len_kv, query.device
            )
        else:

            def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
                return score + attn_mask[batch_idx, head_idx, q_idx, kv_idx]

    elif attn_mask is not None:
        raise ValueError("Attention mask must be either None, a BlockMask, or a 2D/4D tensor.")

    return flex_attention.flex_attention(
        query=query,
        key=key,
        value=value,
        score_mod=score_mod,
        block_mask=block_mask,
        scale=scale,
        enable_gqa=enable_gqa,
        return_lse=return_lse,
        kernel_options=kernel_options,
    )
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider.NATIVE,
    constraints=[_check_device, _check_shape],
    supports_cp=False,
)
def _native_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """`AttentionProvider.NATIVE` implementation: a direct call to PyTorch's ``scaled_dot_product_attention``."""
    sdpa_options = {
        "attn_mask": attn_mask,
        "dropout_p": dropout_p,
        "is_causal": is_causal,
        "scale": scale,
        "enable_gqa": enable_gqa,
    }
    return native_sdpa(query=query, key=key, value=value, **sdpa_options)
|
|
|
|
class _native_cudnn_attention(torch.autograd.Function):
    """Autograd wrapper around PyTorch's cuDNN SDPA forward/backward ATen ops.

    The forward op is always called with ``compute_log_sumexp=True`` so the backward op gets the
    log-sum-exp. ``attn_mask`` is passed to both ops as ``attn_bias``. It is kept on ``ctx`` rather than
    saved with ``save_for_backward``, and receives no gradient.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        return_lse: bool = False,
    ):
        # Options needed again by the backward op.
        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        ctx.attn_mask = attn_mask

        out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
            torch.ops.aten._scaled_dot_product_cudnn_attention(
                query=query,
                key=key,
                value=value,
                attn_bias=attn_mask,
                compute_log_sumexp=True,
                dropout_p=dropout_p,
                is_causal=is_causal,
                return_debug_mask=False,
                scale=scale,
            )
        )

        # max_q/max_k are stored as ctx attributes; everything else the backward op needs is saved as tensors.
        ctx.max_q = max_q
        ctx.max_k = max_k
        ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)

        return (out, lse) if return_lse else out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        # Any incoming gradient for the returned lse (in *args) is ignored.
        query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors

        grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward(
            grad_out=grad_out,
            query=query,
            key=key,
            value=value,
            out=out,
            logsumexp=lse,
            philox_seed=philox_seed,
            philox_offset=philox_offset,
            attn_bias=ctx.attn_mask,
            cum_seq_q=cum_seq_q,
            cum_seq_k=cum_seq_k,
            max_q=ctx.max_q,
            max_k=ctx.max_k,
            dropout_p=ctx.dropout_p,
            is_causal=ctx.is_causal,
            scale=ctx.scale,
        )

        # One slot per forward input after ctx: grads for query, key, value and None for the 5 remaining arguments.
        return grad_query, grad_key, grad_value, None, None, None, None, None
|
|
|
|
class _native_ring_native_cudnn_attention(torch.autograd.Function):
    """Context-parallel (ring) autograd wrapper around the cuDNN SDPA kernels.

    Forward and backward are delegated to ``_templated_ring_attention`` and
    ``_templated_ring_attention_backward`` over the mesh stored on
    ``_AttentionProviderRegistry``, with ``seq_dim=2`` as the sharded axis.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        return_lse: bool = False,
    ):
        """Run ring cuDNN attention.

        Returns ``(out, lse)`` when ``return_lse`` is truthy, else ``out``.
        The registry raises first if no device mesh has been set.
        """
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()

        # Hyper-parameters (and the optional mask) ride on ctx; the tensors the
        # backward kernel consumes are registered via save_for_backward below.
        ctx.dropout_p, ctx.is_causal = dropout_p, is_causal
        ctx.scale, ctx.attn_mask = scale, attn_mask

        ring_outputs = _templated_ring_attention(
            mesh=_AttentionProviderRegistry._mesh,
            seq_dim=2,
            op=torch.ops.aten._scaled_dot_product_cudnn_attention,
            query=query,
            key=key,
            value=value,
            attn_bias=attn_mask,
            compute_log_sumexp=True,
            dropout_p=dropout_p,
            is_causal=is_causal,
            return_debug_mask=False,
            scale=scale,
        )
        out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, _debug_mask = ring_outputs

        # max_q / max_k are kept as ctx attributes rather than saved tensors.
        ctx.max_q, ctx.max_k = max_q, max_k
        ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)

        if return_lse:
            return out, lse
        return out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Ring backward through ``_scaled_dot_product_cudnn_attention_backward``.

        Any extra incoming gradients (e.g. for ``lse``) are ignored.
        """
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
        saved = ctx.saved_tensors
        query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = saved

        grad_query, grad_key, grad_value = _templated_ring_attention_backward(
            mesh=_AttentionProviderRegistry._mesh,
            seq_dim=2,
            op=torch.ops.aten._scaled_dot_product_cudnn_attention_backward,
            grad_out=grad_out,
            grad_out_name="grad_out",
            query=query,
            key=key,
            value=value,
            out=out,
            logsumexp=lse,
            philox_seed=philox_seed,
            philox_offset=philox_offset,
            attn_bias=ctx.attn_mask,
            cum_seq_q=cum_seq_q,
            cum_seq_k=cum_seq_k,
            max_q=ctx.max_q,
            max_k=ctx.max_k,
            dropout_p=ctx.dropout_p,
            is_causal=ctx.is_causal,
            scale=ctx.scale,
        )

        # One slot per forward input: q, k, v get gradients; attn_mask,
        # dropout_p, is_causal, scale and return_lse get None.
        return (grad_query, grad_key, grad_value) + (None,) * 5
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider._NATIVE_CUDNN,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_cp=True,
)
def native_cudnn_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
) -> torch.Tensor:
    """cuDNN SDPA provider.

    Picks the ring (context-parallel) autograd function when context
    parallelism is enabled on the registry, otherwise the single-device one.
    Returns ``(out, lse)`` when ``return_lse`` is set, else ``out``.
    """
    if _AttentionProviderRegistry.context_parallel_enabled():
        attention_fn = _native_ring_native_cudnn_attention
    else:
        attention_fn = _native_cudnn_attention
    return attention_fn.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, return_lse)
|
|
|
|
class _native_efficient_attention(torch.autograd.Function):
    """Single-device autograd wrapper over the memory-efficient SDPA kernels.

    Forward calls ``_finetrainers_scaled_dot_product_efficient_attention_forward``.
    Backward calls the matching ``..._backward`` helper, requesting gradients
    for query/key/value only (never for the attention bias).
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        return_lse: bool = False,
    ):
        """Compute efficient attention; returns ``(out, lse)`` if ``return_lse`` else ``out``."""
        # Non-saved state needed to replay the kernel in backward.
        ctx.dropout_p, ctx.is_causal = dropout_p, is_causal
        ctx.scale, ctx.attn_mask = scale, attn_mask

        kernel_outputs = _finetrainers_scaled_dot_product_efficient_attention_forward(
            query=query,
            key=key,
            value=value,
            attn_bias=attn_mask,
            compute_log_sumexp=True,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )
        out, lse, philox_seed, philox_offset = kernel_outputs

        ctx.save_for_backward(query, key, value, out, lse, philox_seed, philox_offset)

        if return_lse:
            return out, lse
        return out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Backward pass; gradients for any extra forward outputs are ignored."""
        query, key, value, out, lse, philox_seed, philox_offset = ctx.saved_tensors

        kernel_grads = _finetrainers_scaled_dot_product_efficient_attention_backward(
            grad_out_=grad_out,
            query=query,
            key=key,
            value=value,
            attn_bias=ctx.attn_mask,
            out=out,
            logsumexp=lse,
            philox_seed=philox_seed,
            philox_offset=philox_offset,
            dropout_p=ctx.dropout_p,
            # Only q/k/v gradients are requested; the bias gradient slot is off.
            grad_input_mask=[True, True, True, False],
            is_causal=ctx.is_causal,
            scale=ctx.scale,
        )
        grad_query, grad_key, grad_value, _grad_attn_bias = kernel_grads

        # attn_mask, dropout_p, is_causal, scale, return_lse receive no gradient.
        return (grad_query, grad_key, grad_value) + (None,) * 5
|
|
|
|
class _native_ring_native_efficient_attention(torch.autograd.Function):
    """Context-parallel (ring) autograd wrapper over the memory-efficient SDPA helpers.

    The finetrainers efficient-attention forward/backward helpers are used as
    the per-step ``op`` of ``_templated_ring_attention`` and its backward
    counterpart, on the registry mesh with ``seq_dim=2``.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        return_lse: bool = False,
    ):
        """Ring efficient attention; returns ``(out, lse)`` if ``return_lse`` else ``out``."""
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
        ctx.dropout_p, ctx.is_causal = dropout_p, is_causal
        ctx.scale, ctx.attn_mask = scale, attn_mask

        ring_outputs = _templated_ring_attention(
            mesh=_AttentionProviderRegistry._mesh,
            seq_dim=2,
            op=_finetrainers_scaled_dot_product_efficient_attention_forward,
            query=query,
            key=key,
            value=value,
            attn_bias=attn_mask,
            compute_log_sumexp=True,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )
        out, lse, philox_seed, philox_offset = ring_outputs

        ctx.save_for_backward(query, key, value, out, lse, philox_seed, philox_offset)

        if return_lse:
            return out, lse
        return out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Ring backward; the incoming gradient is passed under the name ``grad_out_``."""
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
        query, key, value, out, lse, philox_seed, philox_offset = ctx.saved_tensors

        ring_grads = _templated_ring_attention_backward(
            mesh=_AttentionProviderRegistry._mesh,
            seq_dim=2,
            op=_finetrainers_scaled_dot_product_efficient_attention_backward,
            grad_out=grad_out,
            # The efficient backward helper names its gradient argument ``grad_out_``.
            grad_out_name="grad_out_",
            query=query,
            key=key,
            value=value,
            attn_bias=ctx.attn_mask,
            out=out,
            logsumexp=lse,
            philox_seed=philox_seed,
            philox_offset=philox_offset,
            dropout_p=ctx.dropout_p,
            grad_input_mask=[True, True, True, False],
            is_causal=ctx.is_causal,
            scale=ctx.scale,
        )
        grad_query, grad_key, grad_value, _grad_attn_bias = ring_grads

        return (grad_query, grad_key, grad_value) + (None,) * 5
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider._NATIVE_EFFICIENT,
    constraints=[_check_device, _check_shape],
    supports_cp=True,
)
def native_efficient_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
) -> torch.Tensor:
    """Memory-efficient SDPA provider.

    Uses the ring (context-parallel) autograd function when context
    parallelism is enabled, otherwise the single-device one. Unlike the cuDNN
    and flash providers there is no ``return_lse`` flag, so only the attention
    output is returned.
    """
    use_ring = _AttentionProviderRegistry.context_parallel_enabled()
    impl = _native_ring_native_efficient_attention if use_ring else _native_efficient_attention
    return impl.apply(query, key, value, attn_mask, dropout_p, is_causal, scale)
|
|
|
|
class _native_flash_attention(torch.autograd.Function):
    """Single-device autograd wrapper over PyTorch's flash SDPA kernels.

    Forward calls ``torch.ops.aten._scaled_dot_product_flash_attention`` and
    backward calls ``..._flash_attention_backward``. This path takes no
    attention mask.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        return_lse: bool = False,
    ):
        """Compute flash attention; returns ``(out, lse)`` if ``return_lse`` else ``out``."""
        ctx.dropout_p, ctx.is_causal, ctx.scale = dropout_p, is_causal, scale

        kernel_outputs = torch.ops.aten._scaled_dot_product_flash_attention(
            query=query,
            key=key,
            value=value,
            dropout_p=dropout_p,
            is_causal=is_causal,
            return_debug_mask=False,
            scale=scale,
        )
        out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, _debug_mask = kernel_outputs

        # max_q / max_k go on ctx (presumably plain ints rather than tensors);
        # everything else backward needs is saved as tensors.
        ctx.max_q, ctx.max_k = max_q, max_k
        ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)

        if return_lse:
            return out, lse
        return out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Backward pass; gradients for any extra forward outputs are ignored."""
        query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors

        kernel_grads = torch.ops.aten._scaled_dot_product_flash_attention_backward(
            grad_out=grad_out,
            query=query,
            key=key,
            value=value,
            out=out,
            logsumexp=lse,
            cum_seq_q=cum_seq_q,
            cum_seq_k=cum_seq_k,
            max_q=ctx.max_q,
            max_k=ctx.max_k,
            dropout_p=ctx.dropout_p,
            is_causal=ctx.is_causal,
            philox_seed=philox_seed,
            philox_offset=philox_offset,
            scale=ctx.scale,
        )
        grad_query, grad_key, grad_value = kernel_grads

        # dropout_p, is_causal, scale and return_lse receive no gradient.
        return (grad_query, grad_key, grad_value) + (None,) * 4
|
|
|
|
class _native_ring_native_flash_attention(torch.autograd.Function):
    """Context-parallel (ring) autograd wrapper over PyTorch's flash SDPA kernels.

    Delegates to ``_templated_ring_attention`` / ``_templated_ring_attention_backward``
    on the registry mesh with ``seq_dim=2``. No attention mask is supported.
    """

    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        dropout_p: float = 0.0,
        is_causal: bool = False,
        scale: Optional[float] = None,
        return_lse: bool = False,
    ):
        """Ring flash attention; returns ``(out, lse)`` if ``return_lse`` else ``out``."""
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
        ctx.dropout_p, ctx.is_causal, ctx.scale = dropout_p, is_causal, scale

        ring_outputs = _templated_ring_attention(
            mesh=_AttentionProviderRegistry._mesh,
            seq_dim=2,
            op=torch.ops.aten._scaled_dot_product_flash_attention,
            query=query,
            key=key,
            value=value,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
        )
        out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, _debug_mask = ring_outputs

        ctx.max_q, ctx.max_k = max_q, max_k
        ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)

        if return_lse:
            return out, lse
        return out

    @staticmethod
    def backward(
        ctx: torch.autograd.function.FunctionCtx,
        grad_out: torch.Tensor,
        *args: torch.Tensor,
    ):
        """Ring backward; only the first three returned gradients (q, k, v) are used."""
        _AttentionProviderRegistry._raise_cp_error_if_mesh_not_set()
        query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors

        ring_grads = _templated_ring_attention_backward(
            mesh=_AttentionProviderRegistry._mesh,
            seq_dim=2,
            op=torch.ops.aten._scaled_dot_product_flash_attention_backward,
            grad_out=grad_out,
            grad_out_name="grad_out",
            query=query,
            key=key,
            value=value,
            out=out,
            logsumexp=lse,
            dropout_p=ctx.dropout_p,
            is_causal=ctx.is_causal,
            scale=ctx.scale,
            cum_seq_q=cum_seq_q,
            cum_seq_k=cum_seq_k,
            max_q=ctx.max_q,
            max_k=ctx.max_k,
            philox_seed=philox_seed,
            philox_offset=philox_offset,
        )
        grad_query, grad_key, grad_value, *_unused = ring_grads

        return (grad_query, grad_key, grad_value) + (None,) * 4
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider._NATIVE_FLASH,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_cp=True,
)
def native_flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
) -> torch.Tensor:
    """Flash SDPA provider (no attention mask).

    Uses the ring autograd function under context parallelism, the
    single-device one otherwise. Returns ``(out, lse)`` when ``return_lse`` is
    set, else ``out``.
    """
    use_ring = _AttentionProviderRegistry.context_parallel_enabled()
    impl = _native_ring_native_flash_attention if use_ring else _native_flash_attention
    return impl.apply(query, key, value, dropout_p, is_causal, scale, return_lse)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider._NATIVE_MATH,
    constraints=[_check_device, _check_shape],
    supports_cp=False,
)
def native_math_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """SDPA provider pinned to PyTorch's MATH backend.

    Forwards all arguments to ``scaled_dot_product_attention`` inside an
    ``sdpa_kernel(SDPBackend.MATH)`` context. Context parallelism is not
    supported by this provider.
    """
    math_backend = torch.nn.attention.SDPBackend.MATH
    with torch.nn.attention.sdpa_kernel(math_backend):
        result = native_sdpa(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            dropout_p=dropout_p,
            is_causal=is_causal,
            scale=scale,
            enable_gqa=enable_gqa,
        )
    return result
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider.SAGE,
    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_cp=False,
)
def _sage_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    return_lse: bool = False,
) -> torch.Tensor:
    """SageAttention provider backed by ``sageattn`` with ``tensor_layout="HND"``.

    When an LSE is requested, the result is ``out`` followed by the first extra
    value ``sageattn`` returned (expected to be the LSE). Otherwise it is
    whatever ``sageattn`` returns. While context parallelism is enabled, an LSE
    is always requested, regardless of ``return_lse``.
    """
    # Under context parallelism the LSE is forced on (presumably so partial
    # results can be merged).
    lse_requested = True if _AttentionProviderRegistry.context_parallel_enabled() else return_lse

    result = sageattn(
        q=query,
        k=key,
        v=value,
        tensor_layout="HND",
        is_causal=is_causal,
        sm_scale=scale,
        return_lse=lse_requested,
    )
    if not lse_requested:
        return result

    out, *extras = result
    return (out, *extras[:1])
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider.SAGE_VARLEN,
    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _sage_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    max_seqlen_k: Optional[int] = None,
    is_causal: bool = False,
    scale: Optional[float] = None,
    smooth_k: bool = True,
    attn_mask: Optional[torch.Tensor] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """Variable-length SageAttention over padded, batched inputs.

    Inputs are read as ``(batch, heads, seq, dim)`` (see the shape unpacking
    below). Keys and values are trimmed to each sample's valid length and
    packed along the sequence axis. Queries are packed untrimmed. The packed
    output is reshaped back to ``(batch, heads, seq_q, dim)``.

    Args:
        cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k: Optional
            precomputed varlen metadata. It is used only when all four are
            given. Otherwise it is derived (from ``attn_mask`` if present) by
            ``_prepare_for_flash_attn_or_sage_varlen``. ``cu_seqlens_k`` must
            hold cumulative key offsets with ``batch + 1`` entries. Per-sample
            key lengths are taken from its differences.
        is_causal, scale, smooth_k: Forwarded to ``sageattn_varlen``
            (``scale`` as ``sm_scale``).
        attn_mask: Optional mask, normalized by ``_normalize_attn_mask``
            (presumably to a ``(batch, seq_len_kv)`` key mask).
        enable_gqa: Accepted for interface compatibility. No GQA-specific
            handling happens here, and heads are passed to the kernel as-is.

    Returns:
        The attention output in the same layout as ``query``.
    """
    batch_size, _, seq_len_q, _ = query.shape
    _, _, seq_len_kv, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
            _prepare_for_flash_attn_or_sage_varlen(
                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
            )
        )
    else:
        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
        # Take per-sample key lengths from the offsets themselves rather than
        # assuming every sample has ``max_seqlen_k`` keys. Otherwise a ragged
        # cu_seqlens_k would not match the packed keys handed to the kernel.
        seqlens_k = cu_seqlens_k[1:] - cu_seqlens_k[:-1]

    # (batch, heads, seq, dim) -> (batch, seq, heads, dim)
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))

    # Move the lengths to the host once instead of an implicit sync per slice.
    key_lens = seqlens_k.tolist() if torch.is_tensor(seqlens_k) else list(seqlens_k)

    query_packed = query.flatten(0, 1)
    key_packed = torch.cat([key[b, : key_lens[b]] for b in range(batch_size)], dim=0)
    value_packed = torch.cat([value[b, : key_lens[b]] for b in range(batch_size)], dim=0)

    out = sageattn_varlen(
        q=query_packed,
        k=key_packed,
        v=value_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        is_causal=is_causal,
        sm_scale=scale,
        smooth_k=smooth_k,
    )
    # Unpack queries (all ``seq_len_q`` long) back to (batch, heads, seq, dim).
    out = out.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3)

    return out
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA,
    constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
    supports_cp=False,
)
def _sage_qk_int8_pv_fp8_cuda_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
    pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
    smooth_k: bool = True,
    smooth_v: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    """Thin wrapper over ``sageattn_qk_int8_pv_fp8_cuda``.

    All options are forwarded unchanged (``scale`` as ``sm_scale``), and the
    kernel is always called with ``tensor_layout="HND"``.
    """
    kernel_kwargs = {
        "q": query,
        "k": key,
        "v": value,
        "tensor_layout": "HND",
        "is_causal": is_causal,
        "qk_quant_gran": qk_quant_gran,
        "sm_scale": scale,
        "pv_accum_dtype": pv_accum_dtype,
        "smooth_k": smooth_k,
        "smooth_v": smooth_v,
        "return_lse": return_lse,
    }
    return sageattn_qk_int8_pv_fp8_cuda(**kernel_kwargs)
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
    constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
    supports_cp=False,
)
def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
    pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
    smooth_k: bool = True,
    return_lse: bool = False,
) -> torch.Tensor:
    """Thin wrapper over ``sageattn_qk_int8_pv_fp8_cuda_sm90``.

    Unlike the non-SM90 FP8 variant, this one takes no ``smooth_v`` option.
    The kernel is always called with ``tensor_layout="HND"``.
    """
    kernel_kwargs = {
        "q": query,
        "k": key,
        "v": value,
        "tensor_layout": "HND",
        "is_causal": is_causal,
        "qk_quant_gran": qk_quant_gran,
        "sm_scale": scale,
        "pv_accum_dtype": pv_accum_dtype,
        "smooth_k": smooth_k,
        "return_lse": return_lse,
    }
    return sageattn_qk_int8_pv_fp8_cuda_sm90(**kernel_kwargs)
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider._SAGE_QK_INT8_PV_FP16_CUDA,
    constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
    supports_cp=False,
)
def _sage_qk_int8_pv_fp16_cuda_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    qk_quant_gran: _SAGE_ATTENTION_QK_QUANT_GRAN = "per_thread",
    pv_accum_dtype: _SAGE_ATTENTION_PV_ACCUM_DTYPE = "fp32+fp32",
    smooth_k: bool = True,
    smooth_v: bool = False,
    return_lse: bool = False,
) -> torch.Tensor:
    """Thin wrapper over ``sageattn_qk_int8_pv_fp16_cuda``.

    All options are forwarded unchanged (``scale`` as ``sm_scale``), and the
    kernel is always called with ``tensor_layout="HND"``.
    """
    kernel_kwargs = {
        "q": query,
        "k": key,
        "v": value,
        "tensor_layout": "HND",
        "is_causal": is_causal,
        "qk_quant_gran": qk_quant_gran,
        "sm_scale": scale,
        "pv_accum_dtype": pv_accum_dtype,
        "smooth_k": smooth_k,
        "smooth_v": smooth_v,
        "return_lse": return_lse,
    }
    return sageattn_qk_int8_pv_fp16_cuda(**kernel_kwargs)
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider._SAGE_QK_INT8_PV_FP16_TRITON,
    constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
    supports_cp=False,
)
def _sage_qk_int8_pv_fp16_triton_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    is_causal: bool = False,
    scale: Optional[float] = None,
    quantization_backend: _SAGE_ATTENTION_QUANTIZATION_BACKEND = "triton",
    smooth_k: bool = True,
    return_lse: bool = False,
) -> torch.Tensor:
    """Thin wrapper over ``sageattn_qk_int8_pv_fp16_triton``.

    ``quantization_backend`` selects the kernel's quantization backend. The
    kernel is always called with ``tensor_layout="HND"``.
    """
    kernel_kwargs = {
        "q": query,
        "k": key,
        "v": value,
        "tensor_layout": "HND",
        "quantization_backend": quantization_backend,
        "is_causal": is_causal,
        "sm_scale": scale,
        "smooth_k": smooth_k,
        "return_lse": return_lse,
    }
    return sageattn_qk_int8_pv_fp16_triton(**kernel_kwargs)
|
|
|
|
@_AttentionProviderRegistry.register(
    AttentionProvider.XFORMERS,
    constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
)
def _xformers_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: Optional[torch.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    """Attention through ``xops.memory_efficient_attention``.

    Inputs are read as ``(batch, heads, seq, head_dim)`` and permuted to put
    heads on dim 2 for xformers. The output is permuted back.

    Args:
        attn_mask: Optional mask, ignored when ``is_causal`` is set.
            - A 2D mask is a per-key padding mask of shape
              ``(batch, seq_len_kv)``, broadcast over heads and query positions.
            - A 4D mask must broadcast to
              ``(batch, heads_q, seq_len_q, seq_len_kv)``.
            - Boolean masks follow SDPA semantics (``True`` = attend) and are
              turned into an additive bias (0 / ``-inf``).
            - Other dtypes are cast to ``query.dtype`` and used as an additive
              bias.
        is_causal: Use xformers' ``LowerTriangularMask`` instead of ``attn_mask``.
        enable_gqa: Group query heads over the key/value heads.

    Raises:
        ValueError: If the mask is neither 2D nor 4D, or if the query head count
            is not divisible by the key/value head count under GQA.
    """
    batch_size, num_heads_q, seq_len_q, _ = query.shape
    _, num_heads_kv, seq_len_kv, _ = key.shape

    if is_causal:
        attn_mask = xops.LowerTriangularMask()
    elif attn_mask is not None:
        if attn_mask.ndim == 2:
            # (batch, seq_len_kv) key-padding mask: the key axis is the LAST dim
            # of the attention matrix, so broadcast over heads and queries.
            attn_mask = attn_mask.view(attn_mask.size(0), 1, 1, attn_mask.size(1))
        elif attn_mask.ndim != 4:
            raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
        attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv)
        if attn_mask.dtype == torch.bool:
            # xformers adds tensor biases to the scores, so a boolean mask cast
            # to 0/1 would mask nothing. Convert it to a 0 / -inf bias instead.
            keep = attn_mask.to(query.device)
            attn_mask = torch.zeros(keep.shape, dtype=query.dtype, device=query.device).masked_fill(
                ~keep, float("-inf")
            )
        else:
            attn_mask = attn_mask.type_as(query)

    # (batch, heads, seq, dim) -> (batch, seq, heads, dim)
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))

    if enable_gqa:
        if num_heads_q % num_heads_kv != 0:
            raise ValueError("Number of heads in query must be divisible by number of heads in key/value.")
        num_heads_per_group = num_heads_q // num_heads_kv
        # Split query heads into (kv_heads, group) and broadcast each kv head
        # across its group without copying.
        query = query.unflatten(2, (num_heads_kv, -1))
        key = key.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)
        value = value.unflatten(2, (num_heads_kv, -1)).expand(-1, -1, -1, num_heads_per_group, -1)

    out = xops.memory_efficient_attention(query, key, value, attn_mask, dropout_p, scale)
    if enable_gqa:
        out = out.flatten(2, 3)

    # (batch, seq, heads, dim) -> (batch, heads, seq, dim)
    out = out.permute(0, 2, 1, 3)
    return out
|
|