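A decorator that caches the most recent results of a function with tensor inputs.
The decorator stores the output of the decorated function for the most recent sets of input tensors. Calls are matched by object identity (the is operator), not by tensor value, so a different tensor with equal contents is a cache miss. The cache holds a fixed number of entries (cache_size = 8 in the source). A cache hit moves the entry to the most recent position, and when the cache is full the least recently used entry is removed.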
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fn | Callable[..., Tensor] | The function to be decorated. It should take tensor inputs and return tensor outputs. | required |
Returns:
| Type | Description |
| --- | --- |
| Callable[..., Tensor] | A wrapped version of the input function that caches up to cache_size results and evicts the least recently used one first. |
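The wrapper compares arguments with `is`, so a cache hit requires the very same objects, not merely equal values. A minimal sketch of that matching rule follows. The helper name `same_call` is hypothetical (it is not part of vLLM), and plain lists stand in for tensors so the snippet runs without PyTorch.

```python
def same_call(args: tuple, kwargs: dict, last_args: tuple, last_kwargs: dict) -> bool:
    """Return True when both calls received the very same objects.

    Hypothetical helper that mirrors the comparison used by the wrapper.
    """
    return (
        len(args) == len(last_args)
        and len(kwargs) == len(last_kwargs)
        # Positional arguments must be identical objects, in the same order.
        and all(a is b for a, b in zip(args, last_args))
        # Keyword arguments must use the same names and identical objects.
        and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items())
    )


x = [1.0, 2.0]  # stand-in for a tensor
y = [1.0, 2.0]  # equal contents, but a different object

hit = same_call((x,), {"scale": x}, (x,), {"scale": x})
miss = same_call((y,), {}, (x,), {})
print(hit, miss)  # True False
```

Identity comparison keeps the lookup cheap because no tensor data is read. The trade-off is that a tensor modified in place still matches its earlier cache entry, so the stale result would be returned.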
Source code in vllm/model_executor/layers/fla/ops/utils.py
# Imports this excerpt relies on.
import functools
from collections.abc import Callable
from typing import Any

import torch


def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
    """
    A decorator that caches the most recent results of a function with tensor inputs.

    This decorator will store the output of the decorated function for the most
    recent sets of input tensors, matched by object identity.
    The cache is limited to a fixed size (8). When the cache is full, the least
    recently used entry will be removed.

    Args:
        fn (Callable[..., torch.Tensor]):
            The function to be decorated. It should take tensor inputs and return tensor outputs.

    Returns:
        Callable[..., torch.Tensor]:
            A wrapped version of the input function with multi-entry caching.
    """
    # Each entry is (args, kwargs, result); the most recently used entry is last.
    cache_entries: list[tuple[tuple, dict, Any]] = []
    cache_size = 8

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        nonlocal cache_entries, cache_size
        for i, entry in enumerate(cache_entries):
            last_args, last_kwargs, last_result = entry
            # A hit requires the same objects (identity), not equal values.
            if (
                len(args) == len(last_args)
                and len(kwargs) == len(last_kwargs)
                and all(a is b for a, b in zip(args, last_args))
                and all(
                    k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()
                )
            ):
                # Move the matching entry to the end to mark it most recently used.
                cache_entries = (
                    cache_entries[:i]
                    + cache_entries[i + 1 :]
                    + [(args, kwargs, last_result)]
                )
                return last_result

        result = fn(*args, **kwargs)
        # Cache full: drop the least recently used entry before appending.
        if len(cache_entries) >= cache_size:
            cache_entries = cache_entries[1:]
        cache_entries.append((args, kwargs, result))
        return result

    return wrapper
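The eviction order in the listing above can be observed with a small cache. The sketch below repeats the same wrapper logic in a hypothetical decorator factory, `identity_lru`, whose `maxsize` parameter is an assumption added for illustration (the vLLM decorator hard-codes its size). Lists again stand in for tensors so it runs without PyTorch.

```python
import functools
from typing import Any, Callable


def identity_lru(maxsize: int = 8) -> Callable:
    """Hypothetical factory: same logic as the wrapper, with a configurable size."""

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        entries: list[tuple[tuple, dict, Any]] = []

        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            nonlocal entries
            for i, (last_args, last_kwargs, last_result) in enumerate(entries):
                if (
                    len(args) == len(last_args)
                    and len(kwargs) == len(last_kwargs)
                    and all(a is b for a, b in zip(args, last_args))
                    and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items())
                ):
                    # Hit: move the entry to the end (most recently used).
                    entries = entries[:i] + entries[i + 1 :] + [(args, kwargs, last_result)]
                    return last_result
            result = fn(*args, **kwargs)
            if len(entries) >= maxsize:
                entries = entries[1:]  # Drop the least recently used entry.
            entries.append((args, kwargs, result))
            return result

        return wrapper

    return decorator


calls = []


@identity_lru(maxsize=2)
def total(values: list) -> float:
    calls.append(values)  # Record every real computation.
    return sum(values)


a, b, c = [1.0], [2.0], [3.0]  # stand-ins for tensors
total(a)  # miss, cache order: a
total(b)  # miss, cache order: a, b
total(a)  # hit,  cache order: b, a
total(c)  # miss, evicts b, cache order: a, c
total(a)  # hit,  cache order: c, a
total(b)  # miss again, because b was evicted
print(len(calls))  # 4
```

Only four of the six calls run the function. The hit on `a` before `c` arrives is what protects `a` from eviction, which is the difference between least-recently-used and plain first-in-first-out removal.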