Reproducing the inference-time scaling experiments

In this blog post, I share my reproduction of Hugging Face's blogpost-scaling-test-time-compute. The goal is to show that, with more generated tokens, a smaller model's performance can approach that of a larger model.

1. Takeaways

2. Dataset and model

2.1. Dataset

The dataset used in this experiment is HuggingFaceH4/MATH-500. It consists of 500 problems from the MATH benchmark, each containing:

problem: Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. Enter your answer in the form $(r,\theta),$ where $r > 0$ and $0 \le \theta < 2 \pi.$

solution: We have that $r = \sqrt{0^2 + 3^2} = 3.$ Also, if we draw the line connecting the origin and $(0,3),$ this line makes an angle of $\frac{\pi}{2}$ with the positive $x$-axis. [asy] unitsize(0.8 cm); draw((-0.5,0)--(3.5,0)); draw((0,-0.5)--(0,3.5)); draw(arc((0,0),3,0,90),red,Arrow(6)); dot((0,3), red); label("$(0,3)$", (0,3), W); dot((3,0), red); [/asy] Therefore, the polar coordinates are $\boxed{\left( 3, \frac{\pi}{2} \right)}.$

answer: \left( 3, \frac{\pi}{2} \right)
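
For reference, a minimal sketch for loading the dataset with the datasets library (assuming the dataset's single test split):

from datasets import load_dataset

# MATH-500: 500 problems in a single "test" split.
math500 = load_dataset("HuggingFaceH4/MATH-500", split="test")

example = math500[0]
print(example["problem"])   # problem statement
print(example["solution"])  # worked reference solution
print(example["answer"])    # ground-truth final answer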

2.2. Large language models

I evaluate two model families, Llama and Qwen, at different sizes:

2.3. Reward model

Llama3.1-8B-PRM-Deepseek-Data

The model is trained from meta-llama/Llama-3.1-8B-Instruct on RLHFlow/Deepseek-PRM-Data for 1 epoch. It can be used as either an ORM or a PRM: an ORM evaluates the final solution as a whole, while a PRM judges the logical correctness of each intermediate step. The first example below shows the ORM-style input, with the full solution in a single user turn; the second shows the PRM-style input, where each step is a separate user turn and the assistant replies "+" (correct) or "-" (incorrect).

[
{"role": "user", "content": "Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. To convert from rectangular coordinates $(x, y)$ to polar coordinates $(r, \\theta)$, we can use the formulas\n\\[r = \\sqrt{x^2 + y^2}\\]\n\\[\\theta = \\arctan \\frac{y}{x}\\]\n\nIn this case, the rectangular coordinates are $(0,3)$, so $x = 0$ and $y = 3$. \n\nFirst, we calculate $r$:\n\\[r = \\sqrt{0^2 + 3^2} = \\sqrt{9} = 3\\]\n\nNext, we calculate $\\theta$:\n\\[\\theta = \\arctan \\frac{3}{0}\\]\nSince the tangent function is not defined for $x = 0$, we need to use a special case. When $x = 0$, $\\theta = \\frac{\\pi}{2}$ if $y > 0$, and $\\theta = \\frac{3\\pi}{2}$ if $y < 0$. In this case, $y = 3 > 0$, so $\\theta = \\frac{\\pi}{2}$.\n\nSo, the polar coordinates equivalent to $(0,3)$ are $\\boxed{(3,\\frac{\\pi}{2})}$."},
{"role": "assistant", "content": "+"},
]
[
{"role": "user", "content": "Convert the point $(0,3)$ in rectangular coordinates to polar coordinates. To convert from rectangular coordinates $(x, y)$ to polar coordinates $(r, \\theta)$, we can use the formulas\n\\[r = \\sqrt{x^2 + y^2}\\]\n\\[\\theta = \\arctan \\frac{y}{x}\\]"},
{"role": "assistant", "content": "+"},
{"role": "user", "content": "In this case, the rectangular coordinates are $(0,3)$, so $x = 0$ and $y = 3$."},
{"role": "assistant", "content": "+"},
{"role": "user", "content": "In this case, $y = 3 > 0$, so $\\theta = \\frac{\\pi}{2}$."},
{"role": "assistant", "content": "+"},
{"role": "user", "content": "So, the polar coordinates equivalent to $(0,3)$ are $\\boxed{(3,\\frac{\\pi}{2})}$."},
{"role": "assistant", "content": "+"},
]
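
To score a conversation like the ones above, one forward pass suffices: read the probability of "+" (vs. "-") at each assistant turn. Below is a rough sketch following the RLHFlow usage convention; the token-indexing details are my assumption, not verbatim code from the blog:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "RLHFlow/Llama3.1-8B-PRM-Deepseek-Data"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

plus_id = tokenizer.encode("+", add_special_tokens=False)[-1]
minus_id = tokenizer.encode("-", add_special_tokens=False)[-1]

def score_step(messages):
    """P(step correct): probability of '+' vs '-' at the final assistant turn.

    `messages` is a list of {role, content} dicts ending with the assistant
    '+' placeholder, exactly like the examples above.
    """
    input_ids = tokenizer.apply_chat_template(
        messages, return_tensors="pt"
    ).to(model.device)
    with torch.no_grad():
        logits = model(input_ids).logits[0]
    # Logits at position t predict token t+1, so read the distribution one
    # position before the last '+' placeholder.
    pos = (input_ids[0] == plus_id).nonzero()[-1].item()
    probs = torch.softmax(logits[pos - 1, [plus_id, minus_id]], dim=-1)
    return probs[0].item()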

2.4. Test-time scaling strategies
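
As the later sections use them, the strategies all boil down to sampling $N$ candidate solutions and aggregating: majority voting picks the most common final answer, while best-of-$N$ picks the candidate the reward model scores highest. A minimal sketch of the two aggregators (function names and signatures are illustrative):

from collections import Counter

def majority_vote(answers):
    """Pick the most frequent final answer among the N sampled solutions."""
    return Counter(answers).most_common(1)[0][0]

def best_of_n(solutions, scores):
    """Pick the solution whose reward-model score is highest."""
    return max(zip(solutions, scores), key=lambda pair: pair[1])[0]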

3. Reproduced results

3.1. Observations

4. Performance improvement in terms of FLOPs?

A natural question: does scaling test-time compute yield consistent improvements if we measure the actual FLOPs cost rather than just the number of generated tokens?

Different model sizes have different computational demands. Additionally, for inference, the FLOPs for prefill (the forward pass over the prompt) and decoding (token-by-token generation) are quite different. The PRM approach adds the extra overhead of the reward model's forward passes. Moreover, inference FLOPs may not be linear in model size, so I want to see whether the performance improvement measured in FLOPs is consistent with that measured in generated tokens.

Roughly, the total cost scales linearly with the number of candidates,

$\text{FLOPs}(N) \approx N \cdot \text{FLOPs}_{\text{sample}},$

where $N$ is the number of samples generated.

4.1. LLM FLOPs estimation

I estimated the FLOPs of the forward pass for the prefill and decoding stages as follows. The equations and the analysis are based on this paper on arXiv.

In the following analysis, I use this notation: $b$ is the batch size, $s$ the sequence length, $h$ the hidden dimension, $h'$ the intermediate (FFN) dimension, and $n$ the number of attention heads.

For the prefill stage, the per-layer operations and their corresponding FLOPs are:

๐‘ธ๐‘ฒ๐‘ฝ=๐‘ฟ๐‘พ๐‘„๐พ๐‘‰ 6๐‘๐‘ โ„Ž2
๐‘ธ๐‘ฒ=ย RoPE(๐‘ธ๐‘ฒ) 6๐‘๐‘ โ„Ž
๐‘ถ=ย Attn(๐‘ธ๐‘ฒ๐‘ฝ) 4๐‘๐‘ 2โ„Ž+4๐‘๐‘ 2๐‘›
๐‘ฟ=๐‘ถ๐‘พ๐‘‚ 2๐‘๐‘ โ„Ž2
๐‘ฟ=ย Add&Norm(๐‘ฟ) 5๐‘๐‘ โ„Ž
๐‘ฎ๐‘ผ=๐‘ฟ[๐‘พ๐บ,๐‘พ๐‘ˆ] 4๐‘๐‘ โ„Žโ„Žโ€ฒ
๐‘ซ=ย Swish(๐‘ฎ)๐‘ผ 2๐‘๐‘ โ„Žโ€ฒ
๐‘ฟ=๐‘ซ๐‘พ๐ท 2๐‘๐‘ โ„Žโ„Žโ€ฒ
๐‘ฟ=ย Add&Norm(๐‘ฟ) 5๐‘๐‘ โ„Ž

For the decoding stage, the per-token operations and corresponding FLOPs are:

(๐‘ž,๐‘˜,๐‘ฃ)=๐‘ฅ๐‘พ๐‘„๐พ๐‘‰ 6๐‘โ„Ž2
(๐‘ž,๐‘˜)=ย RoPE(๐‘ž,๐‘˜) 6๐‘โ„Ž
(๐พ,๐‘‰)=ย Cache(๐‘˜,๐‘ฃ) โ€œ-โ€
๐‘œ=ย Attn(๐‘ž,๐พ,๐‘‰) 4๐‘๐‘ โ„Ž+4๐‘๐‘ ๐‘›
๐‘ฅ=๐‘œ๐‘พ๐‘‚ 2๐‘โ„Ž2
๐‘ฅ=ย Add&Norm(๐‘ฅ) 5๐‘โ„Ž
(๐‘”,๐‘ข)=๐‘ฅ[๐‘พ๐บ,๐‘พ๐‘ˆ] 4๐‘โ„Žโ„Žโ€ฒ
๐‘‘=ย Swish(๐‘”)๐‘ข 2๐‘โ„Žโ€ฒ
๐‘ฅ=๐‘‘๐‘พ๐ท 2๐‘โ„Žโ„Žโ€ฒ
๐‘ฅ=ย Add&Norm(๐‘ฅ) 5๐‘โ„Ž

For the MATH-500 dataset, I estimate the FLOPs of the forward pass with batch size $b = 1$. Summing the per-operation costs above, and letting the KV cache grow by one token per decoding step, I use the following formulas to compute the total FLOPs:

$\text{FLOPs}_{\text{prefill}}(s) = 8sh^2 + 16sh + 4s^2h + 4s^2n + 6shh' + 2sh'$

$\text{FLOPs}_{\text{decode}}(s) = 8h^2 + 16h + 4sh + 4sn + 6hh' + 2h'$

$\text{FLOPs}_{\text{total}} = \text{FLOPs}_{\text{prefill}}(p_l) + \sum_{i=0}^{d_l - 1} \text{FLOPs}_{\text{decode}}(p_l + i)$

where $p_l$ is the length of the problem prompt, and $d_l$ is the number of tokens generated for the solution.
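
To make the bookkeeping concrete, here is a small sketch of the estimator in Python. The per-layer costs above are multiplied by the number of transformer layers (that handling is my assumption); $h$, $h'$, and $n$ would be read from each model's config:

def flops_prefill(s, h, h_ffn, n):
    """Per-layer forward FLOPs for a prompt of s tokens (b = 1)."""
    return 8*s*h**2 + 16*s*h + 4*s**2*h + 4*s**2*n + 6*s*h*h_ffn + 2*s*h_ffn

def flops_decode(s, h, h_ffn, n):
    """Per-layer FLOPs for decoding one token with a KV cache of length s."""
    return 8*h**2 + 16*h + 4*s*h + 4*s*n + 6*h*h_ffn + 2*h_ffn

def flops_total(p_len, d_len, h, h_ffn, n, n_layers):
    """Prefill the prompt once, then decode d_len tokens one by one."""
    total = flops_prefill(p_len, h, h_ffn, n)
    total += sum(flops_decode(p_len + i, h, h_ffn, n) for i in range(d_len))
    return n_layers * total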

4.2. Results

Below, we re-plot the same data (accuracy vs. total FLOPs) for Qwen2.5 models of various sizes. The left endpoint of each curve (for majority voting) corresponds to the minimal compute cost of greedy decoding ($N = 1$). Moving right along the compute axis, (ideally) smaller models with fewer FLOPs can achieve performance similar to that of larger models with more FLOPs.

The results are shown below:

4.3. Observations

5. Summary

This reproduction reaffirms the main conclusion from the Hugging Face blog post: scaling test-time compute (by sampling multiple solutions and picking the best or majority) can improve accuracy, especially for smaller models. Yet, these improvements don't entirely overcome the fundamental quality gap between smaller and larger models.

We further demonstrate how analyzing FLOPs clarifies the computational trade-offs in test-time scaling. It's not always free to sample or evaluate more solutions. Practitioners need to weigh the cost-to-benefit ratio carefully, particularly if they aim to deploy these methods at scale.