add dynamic fa3 #1334
Code Review
This pull request introduces support for dynamic Multi-Token Prediction (MTP) verification within the FlashAttention-3 (fa3) decode state, including the implementation of specialized Triton kernels for parameter building and compaction, along with corresponding unit tests and test scripts. Additionally, it refactors the cache tensor manager's recycling logic using a use-count bias. The code review highlights two important issues: first, a potential data corruption bug in cache_tensor_manager.py where a buffer node could be appended to the free list multiple times, which can be resolved by checking for prior existence in the list; second, a portability issue in the test script dynamic_fa3.sh due to hardcoded GPU indices, which should instead default to standard indices with environment overrides.
free_use_count = t_buf_node.free_use_count_bias + 1 + len(t_buf_node.shape_to_tensor)
if self.use_count(ptr) <= free_use_count:
    self.free_shape_dtype_to_bufs[t_buf_node.shape_key].append(t_buf_node)
If a temporary view of a cached tensor is deleted after the parent tensor has already been returned to the free list, it will trigger custom_del and add the pointer to changed_ptr again. Since the use count is still <= free_use_count, the same BufNode will be appended to free_shape_dtype_to_bufs multiple times. This leads to duplicate references in the free list, causing severe data corruption when the same buffer is allocated to multiple active tensors simultaneously. Adding a check to ensure the node is not already in the free list prevents this duplicate appending.
Suggested change:
- free_use_count = t_buf_node.free_use_count_bias + 1 + len(t_buf_node.shape_to_tensor)
- if self.use_count(ptr) <= free_use_count:
-     self.free_shape_dtype_to_bufs[t_buf_node.shape_key].append(t_buf_node)
+ free_use_count = t_buf_node.free_use_count_bias + 1 + len(t_buf_node.shape_to_tensor)
+ if self.use_count(ptr) <= free_use_count and t_buf_node not in self.free_shape_dtype_to_bufs[t_buf_node.shape_key]:
+     self.free_shape_dtype_to_bufs[t_buf_node.shape_key].append(t_buf_node)
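The failure mode described in this comment can be reproduced in isolation. The following is a minimal sketch, not the project's code: `BufNode`, `FreePool`, `release`, `alloc`, and `demo` are hypothetical stand-ins, and the use-count check is left out so that only the duplicate-append behavior is shown.

```python
from collections import defaultdict


class BufNode:
    """Hypothetical stand-in for a cached buffer node."""

    def __init__(self, shape_key):
        self.shape_key = shape_key


class FreePool:
    """Hypothetical free list keyed by shape, with an optional duplicate guard."""

    def __init__(self, guard):
        self.guard = guard
        self.free_bufs = defaultdict(list)

    def release(self, node):
        bucket = self.free_bufs[node.shape_key]
        # With the guard on, a node that is already free is not appended again.
        if self.guard and any(n is node for n in bucket):
            return
        bucket.append(node)

    def alloc(self, shape_key):
        bucket = self.free_bufs[shape_key]
        return bucket.pop() if bucket else None


def demo(guard):
    pool = FreePool(guard)
    node = BufNode(shape_key=("float16", 1024))
    pool.release(node)  # parent tensor goes back to the free list
    pool.release(node)  # a late view deletion triggers a second release
    first = pool.alloc(node.shape_key)
    second = pool.alloc(node.shape_key)
    return first, second


unguarded = demo(guard=False)  # both allocations return the same node
guarded = demo(guard=True)     # second allocation finds the bucket empty
```

Without the guard, two callers receive the same buffer, which is the aliasing the comment warns about. The membership test is a linear scan of the bucket; if free lists can grow large, a boolean flag on the node may be a cheaper way to get the same effect, though that is a design choice for the maintainers.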
MAX_TOTAL_TOKEN_NUM=""
MAX_REQ_TOTAL_LEN=""
BATCH_MAX_TOKENS=""
export CUDA_VISIBLE_DEVICES=4,5
Hardcoding specific GPU indices like 4,5 makes the script non-portable and prone to failing on environments with fewer GPUs or different GPU allocations. It is better to default to standard indices (e.g., 0,1) while allowing overrides from the environment.
Suggested change:
- export CUDA_VISIBLE_DEVICES=4,5
+ export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0,1}
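The shell expansion `${VAR:-default}` substitutes the default when the variable is unset or empty. For a Python test driver, roughly the same rule can be sketched as below; `resolve_visible_devices` is a hypothetical helper and not part of the PR.

```python
import os


def resolve_visible_devices(env=None, default="0,1"):
    """Return CUDA_VISIBLE_DEVICES from env, or default if unset or empty."""
    env = os.environ if env is None else env
    # `or` falls back for both a missing key and an empty string,
    # mirroring the shell's ${VAR:-default} rather than ${VAR-default}.
    return env.get("CUDA_VISIBLE_DEVICES") or default


# Example: apply the resolved value before launching any CUDA work.
os.environ["CUDA_VISIBLE_DEVICES"] = resolve_visible_devices()
```

Note that `os.environ.setdefault` would behave differently here: it keeps an existing empty value, whereas the suggested shell form replaces it with the default.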
No description provided.