
Notes on inference scaling laws and GPU server optimizations

by Soham Govande (@sohamgovande)

September 18, 2024

The bitter lesson in the last decade of AI has been that doing the simple, stupid thing of scaling up compute actually delivers better results than coming up with novel architectures and algorithms. OpenAI's recent release of o1 extended this lesson from training compute into inference compute. The core ideas of o1 — reinforcement learning and chain of thought — are not new. The scale at which they were applied, however, is.

In the past few years of reasoning research, most papers using a best-of-N framework used small values of N, like 5, 10, or 32. What happens when N is 10,000, 100,000, or 1,000,000? Generating such a large number of completions requires orders of magnitude more compute, compute that most people don't have. This effectively entrenches the advantage of incumbents and large players. As a student researcher at the Stanford AI Lab, I'm bummed about this; while certain areas of LLM research, like pretraining, have always been compute-heavy, inference scaling is far more interesting and seems to be where the field is headed. But what if small players could still participate, albeit to a lesser extent, just by optimizing our existing systems?
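
To make the framing concrete, here is a minimal best-of-N sketch in Python. The generate and score functions are placeholders, not real library calls; they stand in for whatever sampler and verifier or reward model you'd actually use. The point is simply that the N completions are sampled independently, so compute cost grows linearly with N while the selection step stays cheap.

```python
import random
from concurrent.futures import ThreadPoolExecutor

def generate(prompt: str) -> str:
    # Placeholder: pretend each call is an independently sampled chain of thought.
    return f"{prompt} -> candidate #{random.randrange(1_000_000)}"

def score(prompt: str, completion: str) -> float:
    # Placeholder: a verifier or reward model would go here.
    return random.random()

def best_of_n(prompt: str, n: int, workers: int = 64) -> str:
    """Sample n completions concurrently and keep the highest-scoring one."""
    with ThreadPoolExecutor(max_workers=workers) as pool:
        completions = list(pool.map(generate, [prompt] * n))
    return max(completions, key=lambda c: score(prompt, c))

if __name__ == "__main__":
    print(best_of_n("Prove that the sum of two even numbers is even.", n=10_000))
```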

Model performance scales log-linearly with test-time compute (Source: OpenAI)

Throughput-optimized inference

Compute constraints will require creative solutions to extract more tokens per GPU hour, and current LLM inference stacks have a long way to go. Inference providers, both local engines like vLLM and hosted services like Together, are already well-optimized but focused on a different objective: latency. Latency matters for chat applications, but when all we care about is getting as many tokens as cheaply as possible, we shouldn't optimize for it. Instead, we should optimize for total token throughput: how many tokens we can get out of a box in a given hour, across an entire batch of requests.
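
As a rough illustration of what optimizing for throughput looks like in practice, here is a sketch that pushes one large batch through vLLM's offline generate interface and divides generated tokens by wall-clock time. The model name, prompt set, batch size, and sampling settings are illustrative, and the tokens-per-GPU-hour figure assumes a single GPU.

```python
import time
from vllm import LLM, SamplingParams  # assumes vLLM's offline batch interface

# Illustrative model, batch size, and sampling settings -- substitute your own.
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
params = SamplingParams(temperature=0.8, max_tokens=512)
prompts = [f"Solve problem #{i}: ..." for i in range(4096)]  # one big batch

start = time.perf_counter()
outputs = llm.generate(prompts, params)
elapsed = time.perf_counter() - start

generated = sum(len(o.outputs[0].token_ids) for o in outputs)
print(f"{generated} tokens in {elapsed:.0f}s -> "
      f"{generated / elapsed:.0f} tok/s, "
      f"{3600 * generated / elapsed:.0f} tok per GPU-hour (single GPU assumed)")
```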

At first glance, optimizing for one of these objectives should also optimize the other, but it actually doesn't. Agrawal et al. explain this non-obvious tension quite well.

  1. Latency is minimized at small batch sizes, whereas throughput is maximized at large ones (a toy model after this list makes the tension concrete).
  2. When batch sizes are large, output lengths are quite variable. Early response termination and subsequent prompt prefills create inefficiencies in both prefill-prioritized schedules (e.g., those used by vLLM/Orca) and decode-prioritized ones (e.g., FasterTransformer).
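
A toy model makes the first point concrete. Assume, purely for illustration, that each decode step pays a fixed memory-bandwidth cost to stream the model weights plus a small per-sequence compute cost; then per-token latency grows with batch size while aggregate throughput grows far faster.

```python
# Toy model of the latency/throughput tension. The numbers are assumptions for
# illustration, not measurements: each decode step pays a fixed cost to stream
# the model weights (memory-bound) plus a small per-sequence compute cost.
WEIGHT_STREAM_MS = 50.0    # assumed fixed cost per decode step
PER_SEQ_COMPUTE_MS = 0.5   # assumed incremental cost per sequence in the batch

def decode_step_ms(batch_size: int) -> float:
    return WEIGHT_STREAM_MS + PER_SEQ_COMPUTE_MS * batch_size

for batch in (1, 8, 64, 256):
    step = decode_step_ms(batch)          # every sequence waits one step per token
    throughput = batch / step * 1000.0    # tokens/sec summed across the batch
    print(f"batch={batch:4d}  latency={step:6.1f} ms/token  throughput={throughput:7.0f} tok/s")
```

With these made-up constants, batch size 256 delivers roughly 70x the throughput of batch size 1 at only about 3.5x the per-token latency, which is exactly the trade a throughput-oriented workload wants to make.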

Optimizing for throughput instead of latency can result in steep cost savings, both when you own the GPUs and when you use other people's. There are a few reasons why the costs are much lower here.

  1. Consumer GPUs like the 4090 actually have more FLOPS than A100s, albeit lower memory bandwidth. They're also much cheaper, nearly ten times cheaper per hour. While you won't get bells and whistles like NVLink or the Tensor Memory Accelerator, each GPU hour goes roughly ten times as far, which more than makes up for the difference. And maybe you can write kernels tuned specifically for the 4090 to eke out even more compute.
  2. Every hardware provider has a fixed fleet of GPUs, a fraction of which is not just underutilized but sitting entirely unused. What if you could queue requests into this excess capacity? They would be serviced much, much more cheaply. OpenAI's async batch API and inference.net both take advantage of these economics: they sell compute in large chunks, to be serviced from the excess capacity of global GPU clusters (a sketch of the batch API follows the quote below). In the words of inference.net:

Like a stock option that is about to expire, unused compute becomes less valuable as it approaches its expiration date. Few customers need just a few minutes of compute time, making these fragments challenging to sell conventionally.
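
For reference, queueing work into that excess capacity looks roughly like the following with OpenAI's Python client and Batch API. The file name, polling logic, and JSONL contents are illustrative; each line of the input file is one self-contained request, and the batch is serviced asynchronously within a 24-hour window at a discount.

```python
from openai import OpenAI  # official openai Python client

client = OpenAI()

# requests.jsonl contains one request per line, e.g.
# {"custom_id": "q1", "method": "POST", "url": "/v1/chat/completions", "body": {...}}
batch_file = client.files.create(file=open("requests.jsonl", "rb"), purpose="batch")

batch = client.batches.create(
    input_file_id=batch_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",  # serviced asynchronously, at a discount
)

# Poll later; once the batch completes, download the results file.
status = client.batches.retrieve(batch.id)
if status.status == "completed":
    results = client.files.content(status.output_file_id).text
```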

While delayed inference is interesting, it does restrict the class of problems you can approach. Delayed responses make it particularly difficult to have back-and-forth roundtrips with the inference server, which hurts 'interactive' applications (e.g., an agent getting feedback from an environment, as in coding). Using such inference providers therefore restricts you to the subset of problems that don't require an external environment, like basic math.

Looking Forward

Even with both of these strategies, I doubt we can reduce inference costs by more than a single order of magnitude. Yes, inference costs are falling dramatically, but they're decreasing more slowly than you'd think. To get 2-3 OOMs of improvement, we'll need custom silicon. In the next 12 months, we'll see transformer-specific ASICs from companies like Etched and Groq that reduce inference costs to 1/100th of GPU costs. I'm really excited for this.

Looking forward, it's clear that a greater proportion of training budgets will be devoted to inference. Already, it's well known that synthetic data drives the majority of post-training compute, but the scale of this is underestimated. Problems will require millions of attempts to solve, and we'll have to identify the one correct chain of thought to distill from. Inference scaling is a new paradigm of compute, and I'm excited to see where it leads.