Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

Nice!
…hannels_per_head] in order to make use of batched matmuls. fuse multiply into matmul. breaks bias, mask in exchange for massive speedup.
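For context, here is a minimal sketch of the layout and fused-scale idea this commit describes (shapes and names are illustrative, not the PR's exact code): the head dimension is folded into the batch dimension so plain batched matmuls apply, and the softmax scale is passed as baddbmm's `alpha` instead of a separate multiply.

```python
import torch

# Illustrative only: fold heads into the batch dimension so batched matmuls
# can be used, and fuse the scale into the matmul via baddbmm's `alpha`.
b, h, t, c = 2, 8, 64, 40                          # batch, num_heads, tokens, channels_per_head
q = torch.randn(b, h, t, c).reshape(b * h, t, c)   # [batch * num_heads, tokens, channels_per_head]
k = torch.randn(b, h, t, c).reshape(b * h, t, c)
scale = c ** -0.5
attn_scores = torch.baddbmm(
    torch.empty(1, 1, 1, dtype=q.dtype, device=q.device),  # placeholder; ignored because beta=0
    q,
    k.transpose(1, 2),
    beta=0,
    alpha=scale,
)  # [batch * num_heads, q_tokens, k_tokens]
```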
…ghts_calc_fn, calc_fn_data) and unused vars
…ul for SD 2.1. but remove value float32, having established that it works without.
…to prefer fast-path whenever unchunked attention would fit into memory. add kv_chunk_size_min to control the kv_chunk_size=None behaviour, so that sqrt(key_tokens) does not pick too small of a chunk size
…of chunk key size. improve separation of concerns.
…al kv_chunk_size: they can notice when no chunking would happen at all, and use fast-path. note: there's a question of whether that concern belongs *inside* the algorithm. but it'd feel weird for chunked attention to have a no-chunking-at-all branch.
… equivalent fast-path for 1 query chunk, 1 kv chunk is already supported inside
…ything in one chunk, to re-use an existing fast-path.
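A hedged sketch of the chunk-size policy described in the commits above (the helper name and exact defaulting are illustrative, not the PR's code): `kv_chunk_size=None` defaults to roughly `sqrt(key_tokens)`, `kv_chunk_size_min` keeps that default from becoming too small, and a caller can detect that no chunking would occur and take the existing fast path.

```python
import math

def choose_kv_chunk_size(key_tokens: int, kv_chunk_size=None, kv_chunk_size_min=None) -> int:
    # kv_chunk_size=None defaults to ~sqrt(key_tokens); kv_chunk_size_min
    # clamps that default so short sequences don't get tiny chunks.
    size = kv_chunk_size if kv_chunk_size is not None else int(math.ceil(math.sqrt(key_tokens)))
    if kv_chunk_size_min is not None:
        size = max(size, kv_chunk_size_min)
    return min(size, key_tokens)

# If the chosen chunk covers all key tokens, no chunking would happen at all,
# so a caller can route to the unchunked fast path instead:
# if choose_kv_chunk_size(key_tokens, kv_chunk_size, kv_chunk_size_min) >= key_tokens:
#     ... use standard (unchunked) attention ...
```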
(force-pushed from 84bf1c0 to 0eafb95)
…ose during the matmul
(force-pushed from 3c92600 to 9dc6822)
    starts: List[int],
    sizes: List[int],
) -> Tensor:
    slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
this attempts to implement jax.lax.dynamic_slice(), but hey is this literally just torch.narrow()?
Yeah that works also:
brkirch/stable-diffusion-webui@b119815
No notable performance difference that I observed, but it's probably slightly more efficient nonetheless.
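For illustration, a torch.narrow-based equivalent of the slice-list approach might look like this (a sketch in the spirit of the linked commit, not its exact code):

```python
from typing import List
import torch
from torch import Tensor

def dynamic_slice(x: Tensor, starts: List[int], sizes: List[int]) -> Tensor:
    # Chain narrow() calls, one per dimension; each returns a view, so this
    # behaves like indexing with a list of slices, without copying data.
    for dim, (start, size) in enumerate(zip(starts, sizes)):
        x = x.narrow(dim, start, size)
    return x
```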
    scale: float,
) -> AttnChunk:
    attn_weights = torch.baddbmm(
        torch.empty(1, 1, 1, device=query.device, dtype=query.dtype),
Shouldn't torch.zeros() be used here instead of torch.empty()?
Nope; it's actually an unused tensor (because beta=0), so we want whatever is the cheapest thing that passes the parameter validation. Unfortunately PyTorch complains if you pass None. Bad API design.
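A quick check of that claim (a standalone snippet, not part of the PR): with beta=0, baddbmm never reads the placeholder, so an uninitialised torch.empty() gives the same result torch.zeros() would.

```python
import torch

q = torch.randn(2, 4, 8)
k_t = torch.randn(2, 8, 4)
placeholder = torch.empty(1, 1, 1)  # uninitialised; contents are arbitrary
out = torch.baddbmm(placeholder, q, k_t, beta=0, alpha=1.0)
assert torch.allclose(out, torch.bmm(q, k_t))  # beta=0: the placeholder is never read
```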
Implementation of:
Self-attention Does Not Need O(n^2) Memory:
https://arxiv.org/abs/2112.05682v2
Based on Amin Rezaei's implementation:
https://github.com/AminRezaei0x443/memory-efficient-attention
With:
[batch * num_heads, tokens, channels_per_head] format
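For readers unfamiliar with the paper, here is a rough, hedged sketch of the constant-memory softmax trick it describes, written as a streaming pass over key/value chunks (illustrative only; the PR's implementation also chunks queries, uses baddbmm, and exposes the kv_chunk_size knobs discussed above):

```python
import torch
from torch import Tensor

def chunked_attention(q: Tensor, k: Tensor, v: Tensor, kv_chunk_size: int) -> Tensor:
    # q, k, v: [batch * num_heads, tokens, channels_per_head]
    # Keys/values are visited in chunks; a running max, running sum of
    # exponentials, and a running weighted-value accumulator are maintained,
    # so the full [q_tokens, k_tokens] score matrix is never materialised.
    scale = q.shape[-1] ** -0.5
    running_max = torch.full(q.shape[:2] + (1,), float('-inf'), device=q.device, dtype=q.dtype)
    running_sum = torch.zeros(q.shape[:2] + (1,), device=q.device, dtype=q.dtype)
    acc = torch.zeros_like(q)
    for i in range(0, k.shape[1], kv_chunk_size):
        k_chunk = k[:, i:i + kv_chunk_size]
        v_chunk = v[:, i:i + kv_chunk_size]
        scores = torch.bmm(q, k_chunk.transpose(1, 2)) * scale    # [b*h, q_tokens, chunk]
        chunk_max = scores.amax(dim=-1, keepdim=True)
        new_max = torch.maximum(running_max, chunk_max)
        correction = torch.exp(running_max - new_max)             # rescale earlier partials
        exp_scores = torch.exp(scores - new_max)
        running_sum = running_sum * correction + exp_scores.sum(dim=-1, keepdim=True)
        acc = acc * correction + torch.bmm(exp_scores, v_chunk)
        running_max = new_max
    return acc / running_sum
```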