
Conversation


@feich-ms commented Dec 22, 2025

We found a performance regression for long-context generation, described in ticket #1910. This PR proposes a fix that avoids the switch from the short RoPE factor to the long factor for the Phi model on the WebGPU EP, mitigating the regression caused by recomputing the position IDs and KV cache when switching to the long factor. Experimental results from benchmark_e2e.py show that response tokens are generated without a perf regression for short (< 1000), middle (> 1000 and < 4097), and long (> 4097) sequence lengths. The fix has two pieces:

  1. In src/python/py/models/builders/base.py, for the WebGPU EP, always export the large cos/sin caches so that the model never needs the short-to-long switch (see the sketch after this list):
    cos_cache = cos_cache_large
    sin_cache = sin_cache_large
  2. In src/generators.cpp, for the WebGPU EP, skip the workaround that recomputes the position IDs and KV cache for the short-to-long factor switch, and always use the long factor for models converted by the updated base.py above.
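
For illustration, here is a minimal sketch of piece 1, assuming base.py builds separate short- and long-factor caches before export; the helper function and its `ep` parameter are hypothetical simplifications, and only the cos_cache/sin_cache assignments above come from this PR.

```python
def select_rope_caches(cos_cache_small, sin_cache_small,
                       cos_cache_large, sin_cache_large, ep):
    # Hypothetical helper mirroring the intent of the base.py change.
    if ep == "webgpu":
        # Always bake the long-factor (large) caches into the exported model,
        # so the generator never switches factors or recomputes the KV cache.
        return cos_cache_large, sin_cache_large
    # Other EPs keep the existing behavior: start with the short-factor
    # caches and switch to the long-factor ones past the context threshold.
    return cos_cache_small, sin_cache_small
```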

Please note this is not the final fix for the issue; it is meant to demonstrate a reasonable direction to discuss and move forward with.

I have tested this change as follows.

Convert model: python3 builder.py -m /Volumes/workspace/ort-web-perf/models/Phi-4-mini-instruct -o /Volumes/workspace/ort-web-perf/models/Phi-4-mini-instruct-onnx-optimized -p int4 -e webgpu --extra_options int4_accuracy_level=4 int4_algo_config=k_quant_last

Run the benchmark with the following commands:

  • Long context (7936+6000): benchmark/python/benchmark_e2e.py -i /Volumes/workspace/ort-web-perf/models/Phi-4-mini-instruct-onnx-optimized -l 7936 -g 6000 --use_prompt_set -mo
  • Long context (4096+6000): benchmark/python/benchmark_e2e.py -i /Volumes/workspace/ort-web-perf/models/Phi-4-mini-instruct-onnx-optimized -l 4096 -g 6000 --use_prompt_set -mo
  • Long context (4096+2000): benchmark/python/benchmark_e2e.py -i /Volumes/workspace/ort-web-perf/models/Phi-4-mini-instruct-onnx-optimized -l 4096 -g 2000 --use_prompt_set -mo
  • Middle context: benchmark/python/benchmark_e2e.py -i /Volumes/workspace/ort-web-perf/models/Phi-4-mini-instruct-onnx-optimized -l 2048 -g 1000 --use_prompt_set -mo
  • Short context: benchmark/python/benchmark_e2e.py -i /Volumes/workspace/ort-web-perf/models/Phi-4-mini-instruct-onnx-optimized -l 256 -g 300 --use_prompt_set -mo

Performance and correctness comparison (metric: tps, the average number of tokens generated per second; "Original Solution" is the original model + generator, "Proposed Solution" is the model + generator updated by this PR):

| Length | Original Solution (tps) | Proposed Solution (tps) |
| --- | --- | --- |
| Long Context (7936+6000) | 22.62 (Correctness: True) | 24.49 (Correctness: True) |
| Long Context (4096+6000) | 26.26 (Correctness: True) | 29.83 (Correctness: True) |
| Long Context (4096+2000) | 26.84 (Correctness: True) | 34.1 (Correctness: True) |
| Middle Context (2048+1000) | 38.35 (Correctness: True) | 38.54 (Correctness: True) |
| Short Context (256+300) | 43.74 (Correctness: True) | 43.54 (Correctness: True) |

@feich-ms (Author) commented:

CC @guschmue @qjia7 @gyagp

@qjia7 (Contributor) commented Dec 23, 2025

@kunal-vaishnavi @baijumeswani Could you please advise whether always using cos_cache_large and sin_cache_large would be an acceptable approach?
I’m not entirely sure whether there would be any quality impact when the total sequence length is < 4096, but adopting this approach would allow us to remove the fix introduced in #1161 and resolve the performance issue described in #1910.
Please let us know if there’s anything we may have overlooked.

@feich-ms feich-ms changed the title Make phi model of webgpu ep always use long RoPE to improve tps performance Make phi model of webgpu ep always use long RoPE to improve tps performance and correctness for long context scenario Dec 23, 2025
@feich-ms feich-ms changed the title Make phi model of webgpu ep always use long RoPE to improve tps performance and correctness for long context scenario Make phi model of webgpu ep always use long RoPE to improve tps performance for long context scenario Dec 23, 2025
@kunal-vaishnavi (Contributor) commented:

> @kunal-vaishnavi @baijumeswani Could you please advise whether always using cos_cache_large and sin_cache_large would be an acceptable approach? I’m not entirely sure whether there would be any quality impact when the total sequence length is < 4096, but adopting this approach would allow us to remove the fix introduced in #1161 and resolve the performance issue described in #1910. Please let us know if there’s anything we may have overlooked.

The reasoning for why the KV caches need to be re-computed, and the impact on quality, can be found here. You should use both the small and large caches for the best output quality; ideally, the KV cache re-computation should not be avoided here.

This was a similar issue with the DML EP, and the initial fix was to use just the large caches. However, quality issues soon emerged and a new fix was made: the small and large caches are combined into one tensor, and the position ids are updated to index into it accordingly. The KV cache re-computation is skipped there, however (a sketch of the combined-cache idea follows below).
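
For discussion only, here is a minimal Python sketch of the combined-cache idea described above, assuming each cache has the usual (max_positions, rotary_dim / 2) layout; the function names, shapes, and the exact switching rule are assumptions and may differ from the actual DML EP fix.

```python
import numpy as np

def build_combined_rope_cache(cos_small, sin_small, cos_large, sin_large):
    # Stack the short-factor rows first and the long-factor rows second,
    # producing a single (2 * max_positions, rotary_dim / 2) tensor that
    # can be baked into the model once.
    cos_combined = np.concatenate([cos_small, cos_large], axis=0)
    sin_combined = np.concatenate([sin_small, sin_large], axis=0)
    return cos_combined, sin_combined

def remap_position_ids(position_ids, total_seq_len, original_max_len):
    # Hypothetical remapping: while the total sequence length still fits in
    # the original context window, index the short-factor half; once it
    # exceeds the window, offset the ids so they index the long-factor half.
    if total_seq_len > original_max_len:
        return position_ids + original_max_len
    return position_ids
```

With this layout the generator only shifts an offset in the position ids instead of swapping cache tensors, which is why the KV cache re-computation can be skipped in that fix.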
