Standard reinforcement learning methods for LLMs are inefficient, wasting compute by generating many independent reasoning paths from scratch. We propose TreePO, a new method that organizes generation into a tree structure. This allows for sharing common reasoning steps (like a shared trunk and branches), which makes the process much faster and more stable. Our method saves up to 43% of GPU time while achieving state-of-the-art performance on reasoning benchmarks.
Validation performance curves during training with Qwen2.5-7B (left, middle), and an illustration of TreePO sampling (right).
When using reinforcement learning to teach LLMs how to reason, a common strategy is to generate multiple possible solutions for a single problem and reward the good ones. However, this is like asking 16 different people to solve a math problem, and finding out they all started by writing down the same first few steps. It's redundant and inefficient. Each of these independent solutions (or "trajectories") re-computes the same initial reasoning steps, wasting valuable GPU time and memory.
Multiple sampled trajectories from the same prompt, with shared reasoning segments highlighted in matching colors. Despite stochastic generation, key problem-solving steps are consistently reproduced.
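To make the redundancy concrete, here is a back-of-the-envelope sketch of the token accounting. The rollout count, prefix length, and response length are hypothetical values chosen only to illustrate the bookkeeping, not figures from the paper.

```python
# Back-of-the-envelope illustration of the redundancy (hypothetical numbers,
# not measurements from the paper).
num_rollouts = 16            # trajectories sampled per prompt
shared_prefix_tokens = 256   # opening reasoning the rollouts have in common
response_tokens = 1024       # total tokens decoded per rollout

# Independent sampling decodes every rollout from scratch.
independent_cost = num_rollouts * response_tokens
# Sharing the common prefix decodes it once, plus only the divergent suffixes.
shared_cost = shared_prefix_tokens + num_rollouts * (response_tokens - shared_prefix_tokens)

print(f"independent decoding  : {independent_cost} tokens")   # 16384
print(f"prefix-shared decoding: {shared_cost} tokens")         # 12544
print(f"tokens saved          : {1 - shared_cost / independent_cost:.0%}")  # ~23%
```

The deeper the shared prefix and the more rollouts per prompt, the larger the fraction of decoding that is pure repetition.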
TreePO introduces a smarter way to explore the solution space. Instead of sampling independent paths, we grow a single tree of possibilities: the model generates a segment of text and can then "branch", producing multiple continuations of the same prefix. This has two major benefits:

- **Efficiency:** shared reasoning segments are generated (and cached) once rather than recomputed for every trajectory, eliminating redundant GPU work.
- **Stability:** the tree naturally groups trajectories by their shared prefixes, so rewards can be estimated over sub-groups of related trajectories, yielding a more stable training signal.
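As a rough illustration of the mechanics, here is a minimal sketch of segment-wise tree sampling under our reading of the idea. The `toy_generate_segment` stub, the branching factor, and the segment length are placeholders, not the actual TreePO implementation.

```python
import random
from dataclasses import dataclass, field

@dataclass
class Node:
    text: str                                   # prompt plus tokens decoded along this path
    children: list = field(default_factory=list)
    done: bool = False                          # True once a final answer / EOS is reached

def toy_generate_segment(prefix: str, max_tokens: int) -> tuple[str, bool]:
    # Stand-in for the real decoding call (max_tokens unused in this stub):
    # appends a dummy segment and terminates with some probability.
    return f" [segment after {len(prefix)} chars]", random.random() < 0.3

def expand_tree(prompt, depth=4, branch=2, seg_len=512, generate=toy_generate_segment):
    """Grow a tree of continuations: at each level, every unfinished node
    is extended by `branch` independently sampled segments."""
    root = Node(text=prompt)
    frontier = [root]
    for _ in range(depth):
        next_frontier = []
        for node in frontier:
            if node.done:
                continue
            # Branch: sample several continuations of the SAME prefix;
            # the prefix itself is held only once.
            for _ in range(branch):
                segment, finished = generate(node.text, seg_len)
                child = Node(text=node.text + segment, done=finished)
                node.children.append(child)
                next_frontier.append(child)
        frontier = next_frontier
    return root

tree = expand_tree("Problem: ...", depth=3, branch=2, seg_len=512)
```

Collecting the leaves of the returned tree yields the full set of trajectories; in a real serving engine the shared prefix of each node would typically live once in the prefix/KV cache, which is where the compute savings come from.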
Demonstration of the TreePO Advantage Estimation, which calculates rewards based on sub-groups of trajectories within the tree.
Our experiments show that TreePO is not just a theoretical improvement. It delivers concrete benefits.
TreePO achieves top results on several challenging math and reasoning benchmarks.
Model | AIME | AMC | MATH | Overall |
---|---|---|---|---|
GRPO (Baseline) | 17.13% | 44.42% | 72.89% | 46.63% |
TreePO (Ours) | 27.83% | 55.53% | 85.34% | 58.21% |
This is where TreePO truly shines. By avoiding redundant computations, we see significant reductions in the GPU hours required for training, making the whole process more scalable and accessible.
Model | Sampling Method | Overall Accuracy ↑ | GPU Hours ↓ |
---|---|---|---|
TreePO w/ More Init Divergence | Sequential | 58.21% | 6.40 |
TreePO w/ More Init Divergence | Tree-based (b=2) | 54.67% | 3.65 (↓43%) |
We benchmarked the throughput of TreePO against conventional sampling. On average, our tree-based sampling yields a +40% increase in trajectories per second and a +30% increase in tokens per second across three different models. The optimal configuration, however, depends on the model and the task.
Efficiency peaks at an intermediate trade-off between tree depth and segment length. Shorter segments allow for more branching and parallelism, but increase computational overhead. Longer segments are better for prefilling, but limit the tree's depth. The ideal balance is model-specific.
Performance comparison across different tree depths.
For instruction-tuned models, throughput increases almost linearly as we generate more rollouts (solutions) per prompt. For base models, throughput peaks and then declines, as the increased diversity of solutions leads to less sharing of the computational path.
Performance comparison across different numbers of rollouts.
We explored several ways to calculate the reward signal (the "advantage"). The plots below show that a simple averaging of rewards across different subgroups in the tree works best. More complex weighting schemes or filtering out certain subgroups can actually hurt performance by creating a biased signal.
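As a concrete, if simplified, illustration of the "simple averaging" variant, here is a sketch of subgroup-baselined advantages. The grouping scheme, the function `subgroup_advantages`, and its toy inputs are our own illustrative construction; the exact formulation in the paper may differ.

```python
import statistics

def subgroup_advantages(rewards: dict, subgroups: dict) -> dict:
    """rewards:   trajectory id -> scalar reward.
    subgroups: subgroup id -> list of trajectory ids sharing that subtree.
    Each trajectory's advantage is the simple (unweighted) mean of
    (reward - subgroup mean reward) over the subgroups it belongs to."""
    baselines = {g: statistics.mean(rewards[t] for t in traj_ids)
                 for g, traj_ids in subgroups.items()}
    advantages = {}
    for t, r in rewards.items():
        groups_of_t = [g for g, traj_ids in subgroups.items() if t in traj_ids]
        advantages[t] = statistics.mean(r - baselines[g] for g in groups_of_t)
    return advantages

# Toy example: four leaf trajectories, one root group plus two branch-level subgroups.
rewards = {"t1": 1.0, "t2": 0.0, "t3": 1.0, "t4": 1.0}
subgroups = {"root": ["t1", "t2", "t3", "t4"],
             "left": ["t1", "t2"],
             "right": ["t3", "t4"]}
print(subgroup_advantages(rewards, subgroups))
# {'t1': 0.375, 't2': -0.625, 't3': 0.125, 't4': 0.125}
```

Swapping the plain mean for a size-weighted mean, or dropping some subgroups entirely, corresponds to the weighting and filtering variants that performed worse in our ablations.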
How should we slice up the generation? Should the tree be deep with short text segments, or shallow with long segments? We tested various combinations and found a sweet spot. A moderately deep tree with 512-token segments (14x512) gave the best results, while very long segments underperformed. This suggests that giving the model more frequent opportunities to branch is beneficial.
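For reference, the 14x512 setting caps each response at 14 × 512 = 7168 tokens. The snippet below simply enumerates how that same budget could be split differently; the alternative segment lengths are illustrative, not configurations reported in the paper.

```python
# Best-performing setting: 14 segments of 512 tokens = 7168-token response budget.
# Other splits of the same budget trade branching opportunities against
# per-segment prefill; the alternatives below are illustrative only.
max_response_tokens = 14 * 512   # 7168

for seg_len in (256, 512, 1024, 2048):
    depth = max_response_tokens // seg_len   # branch points available per rollout
    print(f"segment length {seg_len:>4} tokens -> up to {depth:>2} branch points")
```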
Can we guide the search using the model's own confidence? We tested a heuristic where we gave more computational budget to less likely paths to encourage exploration. The results show this is a bad idea. Forcing the model down low-probability paths leads to longer, less coherent answers and worse performance. A balanced exploration strategy is more effective.
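For concreteness, here is a sketch of the kind of inverse-confidence heuristic being described. It is our illustrative reading, not the paper's exact allocation rule, and `allocate_budget` together with its toy inputs is hypothetical.

```python
import math

def allocate_budget(logprobs: dict, total_branches: int) -> dict:
    """logprobs: prefix id -> average token log-probability of that prefix.
    Returns prefix id -> number of continuations to sample from it,
    giving less likely prefixes a larger share of the branching budget."""
    weights = {p: math.exp(-lp) for p, lp in logprobs.items()}   # lower log-prob -> larger weight
    total = sum(weights.values())
    return {p: max(1, round(total_branches * w / total)) for p, w in weights.items()}

# Toy example: the least likely prefix ("c") receives the most branches.
print(allocate_budget({"a": -0.2, "b": -0.6, "c": -1.5}, total_branches=8))
# {'a': 1, 'b': 2, 'c': 5}
```

Even in this toy form the failure mode is visible: the least likely prefix ends up with the most continuations, so low-quality paths consume a disproportionate share of compute, matching the degradation we observed.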
If you find this work useful, please consider citing the paper: