diff-gpt

Autoregressive time-series forecasting with a decoder-only transformer trained on a derivative-based tokenization. The encoder differentiates the input, scales by its observed range, and quantizes with a MASH (multi-stage noise-shaping) $k$-th-order $\Sigma\Delta$ modulator; the decoder inverts this. The model is a standard GPT — the inductive bias sits entirely in the tokenization.

See papers/paper.pdf (English) and papers/paper_ru.pdf (Russian) for the full write-up with proofs and benchmarks.

Signal assumption

$$\exists m \ge 0 \quad \exists M \ge 0 \quad \forall n \ge 0 \quad \forall \mathbf{x} \in \mathbb{R}^n: \quad \|f^{(m)}(\mathbf{x})\|_{\infty} \le M$$

Tokenization contract

Two vocabulary-scaling laws hold at the tokenization level, independent of the model on top.

| quantity | per vocab doubling | reference |
| --- | --- | --- |
| ideal CE $\mathcal L^\star(V)$ | $+\log 2 \approx 0.69$ nats | Prop. 2 |
| reconstruction $\lVert \hat x - x \rVert_\infty$ | $\times \tfrac{1}{2}$ | Thm. 1 |
| ordinal smoothing $\sigma^\star$ | $\times 2$ (hold value-space constant) | §4 |

Theorem 1 (MASH reconstruction, any $k$). For any signal with $D = \max_t |\Delta^k x_t| > 0$ and $V > 2^k + 2$,

$$\|\hat x - x\|_\infty \le \frac{D}{V - 2 - 2^k} = \frac{\Delta}{2}$$

uniformly in horizon $T$, where $\Delta = \frac{2D}{V - 2 - 2^k}$ is the quantization step. Doubling the vocabulary halves the reconstruction error at any order $k$.
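The mechanism behind the horizon-uniform bound is error feedback: the quantizer carries its rounding error into the next step, so re-integration in the decoder cancels it instead of accumulating it. A minimal first-order ($k = 1$) sketch of that idea, not the repo's MASH encoder (all names here are illustrative):

import numpy as np

rng = np.random.default_rng(0)
x = np.cumsum(rng.normal(size=1_000))        # toy signal with bounded first difference
d = np.diff(x, prepend=x[0])                 # k = 1 derivative, same length as x
D = np.abs(d).max()
V = 256
step = 2 * D / (V - 4)                       # Theorem 1 step for k = 1: 2D / (V - 2 - 2^k)

tokens, carry = [], 0.0
for dt in d:
    s = dt + carry                           # feed the previous rounding error back in
    q = int(np.round(s / step))
    carry = s - q * step                     # |carry| <= step / 2 at every t
    tokens.append(q + V // 2)                # shift to a non-negative token id

x_hat = x[0] + np.cumsum((np.array(tokens) - V // 2) * step)
print(np.abs(x_hat - x).max() <= step / 2 + 1e-9)   # True: error bounded uniformly in T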

Proposition 2 (achievable CE). In the fine-quantization limit, $\mathcal L^\star(V) = h(X_t \mid \text{ctx}) + \log V - \log(2D) + o(1)$. Ideal loss is linear in $\log V$, slope 1. Cross-vocab comparisons must subtract out $\log V$ to get the vocab-invariant differential-entropy residual $\tilde h(V) = \mathcal L(V) - \log V + \log(2D)$.
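A tiny helper (hypothetical name, just restating Prop. 2) makes that comparison concrete:

import math

def vocab_invariant_residual(ce_nats: float, V: int, D: float) -> float:
    # Prop. 2: strip the log V term (and restore log 2D) so cross-entropies
    # trained with different vocabulary sizes live on the same scale.
    return ce_nats - math.log(V) + math.log(2 * D)

Two runs whose raw CE differs by roughly $\log 2$ only because one uses twice the vocabulary should be compared through this residual, not through the raw losses.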

Key features

  • MASH $k$-th-order $\Sigma\Delta$: stable noise-shaping at arbitrary order $k$, tight $\Delta/2$ reconstruction bound uniformly in $T$. Falls back to the classical carry scheme when $V \le 2^k + 2$.
  • Gaussian ordinal soft-CE: replaces one-hot targets with a discrete Gaussian over bin indices, $p^\text{soft}_i \propto \exp\!\big(-(i - y)^2 / (2\sigma^2)\big)$. Teaches the model that "off by one bin is almost right", which matches the ordinal structure of derivative-quantized tokens (a toy construction follows this list). On M4 Hourly: sMAPE $18.58 \to 17.49$ at $\sigma = 2$ with no architectural change.
  • Per-column derivative order: order_of_derivative can be an array of shape $(F,)$ giving each channel its own $k_c$. Mix $k = 0$ (raw) with $k = 1, 2, \ldots$ in a single model.
  • GPT architecture niceties: RoPE, RMSNorm (no learnable scale), SwiGLU with soft limit, tied embeddings, logit softcap, Block Attention Residuals (Chen et al. 2026), Schedule-Free AMSGrad optimizer.
  • Inference engine: KV-cache-accelerated generation, position-aware force_schedule for teacher-forced decoding (TimeXer-style exogenous covariates), best-val checkpoint auto-restore.
  • NUMA-aware training: single-process CPU-affinity pin via DiffGPT.train(pin_node=...), or one-worker-per-node DDP via DiffGPT.train_numa on multi-socket boxes.
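A toy version of the ordinal soft-target construction referenced in the list above (illustrative; the function name and shapes are assumptions, not the repo's code):

import torch

def ordinal_soft_targets(y: torch.Tensor, vocab_size: int, sigma: float) -> torch.Tensor:
    # Discrete Gaussian over bin indices, centred on the true bin y.
    idx = torch.arange(vocab_size, dtype=torch.float32)             # (V,)
    logits = -(idx[None, :] - y[:, None].float()) ** 2 / (2 * sigma ** 2)
    return torch.softmax(logits, dim=-1)                            # (B, V), rows sum to 1

targets = ordinal_soft_targets(torch.tensor([3, 120]), vocab_size=256, sigma=2.0)
# soft cross-entropy: loss = -(targets * log_probs).sum(-1).mean()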

Usage

import numpy as np, pandas as pd, torch
from diff_gpt.diff_gpt import DiffGPT
from diff_gpt.model.gpt import GPT
from diff_gpt.data_loader import DiffDataFrameDataLoader
from diff_gpt.encoder_decoder import get_domain_of_definition
from diff_gpt.sampler.temperature import TemperatureSampler

# Your data as DataFrames (one per series).
dfs: list[pd.DataFrame] = [...]

# Compute domain (per-channel max |k-th diff|) from training data.
order = 1  # scalar, OR np.array([0, 1, 2]) for per-column
all_data = np.concatenate([df.to_numpy(dtype=np.float64) for df in dfs], axis=0)
domain = get_domain_of_definition(
    inp=all_data, order_of_derivative=order, use_decimal=False,
)

# Construct GPT + DiffGPT wrapper.
base = GPT(
    vocab_size=256,
    n_embd=64,
    block_size=138,
    n_head=4,
    n_layer=2,
    label_smoothing_sigma=2.0,  # ordinal soft-CE
)
model = DiffGPT(
    model=base,
    order_of_derivative=order,
    domain_of_definition=domain,
    use_decimal=False,
)

loader = DiffDataFrameDataLoader(
    dfs=dfs,
    block_size=138,
    batch_size=32,
    vocab_size=256,
    order_of_derivative=order,
    domain_of_definition=domain,
    use_decimal=False,
    device="cuda",
    train_part=0.8,
)

model.train(loader=loader, max_iters=20_000, eval_interval=500)

# Point forecast (argmax).
context = dfs[0].iloc[-96:]
prediction = model.predict(
    df=context,
    max_new_points=48,
    sampler=TemperatureSampler(temperature=0.0),
)

Probabilistic forecasting

Every trained model is already a probabilistic forecaster — the token head outputs a calibrated distribution at each step (cross-entropy is a strictly proper scoring rule). Point argmax is one readout; num_samples independent temperature-1 trajectories give Monte-Carlo samples from the joint future distribution, which we can reduce to any set of quantiles.

from diff_gpt.sampler.temperature import TemperatureSampler
from diff_gpt.sampler.nucleus import NucleusSampler

# N Monte-Carlo trajectories, shape (N, H, F).
# Plain temperature-1 sampling — the most diverse readout.
samples = model.predict_samples(
    df=context,
    max_new_points=48,
    num_samples=100,
    sampler=TemperatureSampler(temperature=1.0),
)

# Nucleus top-p on top of temperature: drop the low-probability tail of
# each per-step distribution before sampling. Tightens forecast bands
# without any retraining (−25% MSIS on M4 Hourly at p=0.9).
nucleus = NucleusSampler(
    p=0.9, sampler=TemperatureSampler(temperature=1.0),
)
samples_nucleus = model.predict_samples(
    df=context, max_new_points=48, num_samples=100, sampler=nucleus,
)

# Or directly as quantile bands {0.1, 0.5, 0.9}: each a DataFrame of shape (H, F).
bands = model.predict_quantiles(
    df=context,
    max_new_points=48,
    quantiles=(0.1, 0.5, 0.9),
    num_samples=100,
    sampler=nucleus,
)
p10, p50, p90 = bands[0.1], bands[0.5], bands[0.9]

No architectural change and no auxiliary pinball-loss head: one model serves arbitrary quantile sets at inference time. Since the model predicts the derivative and the decoder integrates, confidence bands naturally widen with horizon as an integral of random-walk uncertainty.
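To see the widening directly, reduce the Monte-Carlo trajectories to per-step bands by hand (a sketch; samples is assumed to be array-like of shape (N, H, F), as the comments above state):

import numpy as np

arr = np.asarray(samples)                      # (N, H, F) Monte-Carlo trajectories
lo, hi = np.quantile(arr, [0.1, 0.9], axis=0)  # 10% / 90% bands, each (H, F)
width = hi - lo                                # grows with the forecast horizon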

First-pass M4 Hourly numbers (same model as the point-forecast row, no probabilistic-specific tuning; 414 series, horizon 48, N = 100 samples):

| metric | T = 1 | T = 1, top-p = 0.9 |
| --- | --- | --- |
| sMAPE (median point) | 24.18 | 22.02 |
| MASE (median point) | 2.94 | 2.41 |
| CRPS | 332.0 | 293.6 |
| MSIS (α = 0.05) | 41.1 | 30.8 (−25%) |
| coverage@80 (nominal 0.80) | 0.946 | 0.919 |

Nucleus top-p = 0.9 at inference time (zero retraining) tightens every metric, MSIS by 25%. The remaining over-coverage reflects that the model was tuned for argmax sMAPE; conformal calibration or a CRPS-tuned σ / temperature closes most of the remaining gap to the M4 probabilistic winners (MSIS ~13–15).
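A minimal split-conformal adjustment of the bands, sketched under the assumption of a held-out calibration split (not something the repo ships; y_cal, p10_cal, p90_cal are hypothetical arrays of actuals and band edges on that split):

import numpy as np

def conformal_pad(y_cal, lo_cal, hi_cal, alpha=0.2):
    # Nonconformity score: how far the actual lands outside the band
    # (negative when it sits comfortably inside).
    scores = np.maximum(lo_cal - y_cal, y_cal - hi_cal)
    return np.quantile(scores, 1 - alpha)

# pad = conformal_pad(y_cal, p10_cal, p90_cal)
# p10_adj, p90_adj = p10 - pad, p90 + pad   # for an over-covering model the pad
#                                           # is negative and the band tightens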

Anomaly detection

Cross-entropy training also gives a calibrated per-step conditional log-likelihood for free. Feeding an observed signal through a single forward pass and taking $-\log p(\text{token}_t \mid \text{ctx})$ at each position yields a per-(time, channel) anomaly score: low where the observation is typical, high where it surprises the model.

# df: observed multivariate signal, shape (T, F).
scores = model.anomaly_scores(df)
# scores: DataFrame, same shape and columns as df.
# First max_k rows are NaN (the derivative prefix); the first encoded
# token at column 0 is also NaN (no preceding context).

row_surprise = scores.sum(axis=1, skipna=True)     # per-step aggregate
anomalies = row_surprise[row_surprise > row_surprise.quantile(0.99)]

Regression-based forecasters have to approximate this with $|y_t - \hat y_t|$, which miscalibrates on heteroscedastic series; the categorical head's log-likelihood is proper under the learned distribution.

ETTh1 oil-temperature demo (benchmarks/anomaly_demo.py): train on 2880 clean hours, inject 5 synthetic ±5σ spikes into a 336-hour test window, rank positions by NLL. After only 2000 training iterations the top-10 contains 3 of the 5 injected spikes (60% recall); the rest of the top-10 are genuine irregular transitions in the original signal.
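A condensed version of that protocol, reusing the API from the earlier snippets (illustrative; the actual experiment lives in benchmarks/anomaly_demo.py and may differ in details):

import numpy as np

test = dfs[0].iloc[-336:].copy()                       # 336-hour test window
rng = np.random.default_rng(0)
pos = rng.choice(len(test), size=5, replace=False)     # spike positions
test.iloc[pos, 0] += 5 * test.iloc[:, 0].std() * rng.choice([-1.0, 1.0], size=5)

scores = model.anomaly_scores(test)
top10 = scores.sum(axis=1, skipna=True).nlargest(10)   # highest-NLL positions
hits = top10.index.intersection(test.index[pos])       # injected spikes recovered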

Compiled training (torch.compile)

PyTorch 2's torch.compile fuses the GPT forward/backward into a handful of kernels — a sizable step-time speedup at this model scale, once the one-time graph-tracing cost is amortized. Two rules:

  • Disable gradient checkpointing. use_checkpoint=True introduces graph breaks that defeat the fusion. Pass use_checkpoint=False to GPT when compiling.
  • Compile the inner GPT, not the DiffGPT wrapper. DiffGPT is a numpy-side encoder/decoder shell around an nn.Module; torch.compile only operates on the latter. Compile before constructing the wrapper.
base = GPT(
    vocab_size=256, n_embd=64, block_size=138, n_head=4, n_layer=2,
    label_smoothing_sigma=2.0,
    use_checkpoint=False,           # required alongside torch.compile
)
# GPU: max-autotune without cudagraphs — peak Triton/Inductor tuning,
# no CUDA-graph capture (which breaks under variable shapes, DDP, and
# KV-cache inference). CPU: max-autotune enables Inductor's CPU GEMM
# template / tile sweep; cudagraphs is a no-op on CPU either way.
compile_mode = "max-autotune-no-cudagraphs" if torch.cuda.is_available() else "max-autotune"
base = torch.compile(base, mode=compile_mode)
model = DiffGPT(
    model=base, order_of_derivative=order,
    domain_of_definition=domain, use_decimal=False,
)
model.train(loader=loader, max_iters=20_000, eval_interval=500)

Combined with DiffGPT.train_numa, compile inside the factory so every worker DDP-wraps an already-compiled module:

def build_gpt(rank: int) -> GPT:
    gpt = GPT(vocab_size=256, n_embd=64, block_size=138, n_head=4,
              n_layer=2, use_checkpoint=False)
    mode = "max-autotune-no-cudagraphs" if torch.cuda.is_available() else "max-autotune"
    return torch.compile(gpt, mode=mode)

The first 1–2 training iterations are dominated by tracing and kernel compilation; benchmark step time from iteration 10+ for a fair number. Inference (predict, predict_samples, anomaly_scores) still works on a compiled model — OptimizedModule delegates attribute access — but the KV-cache generation path grows T by one token per step, so the compiled graph is recompiled or specialized frequently and the speedup is smaller than at training time.
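If per-step recompilation hurts at inference time, one option (an assumption, not something the repo configures) is a separate dynamic-shape compile of the generation model, so the growing sequence length does not re-specialize the graph:

# Hypothetical: dynamic-shape compile for the generation path; load the trained
# state_dict into it before use.
infer_gpt = torch.compile(
    GPT(vocab_size=256, n_embd=64, block_size=138, n_head=4, n_layer=2,
        label_smoothing_sigma=2.0, use_checkpoint=False),
    mode=compile_mode,
    dynamic=True,
)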

NUMA-aware training

On multi-socket CPU boxes, cross-socket memory traffic and thread migration dominate wall-clock for this size of model. Two entry points:

Single-process pin. Pass a NUMA node id to train() — pins CPU affinity and aligns torch.set_num_threads / OMP_NUM_THREADS to that node's core count before the hot loop:

model.train(loader=loader, max_iters=20_000, pin_node=0)

One process per NUMA node (DDP). For linear scaling across sockets, DiffGPT.train_numa spawns one worker per detected NUMA node, pins each to its CPUs, and wraps the inner GPT in DistributedDataParallel. Both factories run inside the pinned worker so allocations land on local memory; seed the loader by rank so workers draw different batches (otherwise DDP averages identical gradients).

from diff_gpt.diff_gpt import DiffGPT
from diff_gpt.model.gpt import GPT
from diff_gpt.data_loader import DiffDataFrameDataLoader
import torch

def build_gpt(rank: int) -> GPT:
    return GPT(vocab_size=256, n_embd=64, block_size=138,
               n_head=4, n_layer=2, label_smoothing_sigma=2.0)

def build_loader(rank: int, world: int) -> DiffDataFrameDataLoader:
    rng = torch.Generator().manual_seed(42 + rank)
    return DiffDataFrameDataLoader(
        dfs=dfs, block_size=138, batch_size=32, vocab_size=256,
        order_of_derivative=order, domain_of_definition=domain,
        use_decimal=False, device="cpu", train_part=0.8, rng=rng,
    )

val_loss, _ = DiffGPT.train_numa(
    model_factory=build_gpt,
    loader_factory=build_loader,
    train_kwargs={"max_iters": 20_000, "eval_interval": 500},
    checkpoint_path="best.pt",
)
# Rebuild DiffGPT around the trained weights:
gpt = build_gpt(0); gpt.load_state_dict(torch.load("best.pt"))
model = DiffGPT(model=gpt, order_of_derivative=order,
                domain_of_definition=domain, use_decimal=False)

Backend: NCCL on CUDA, Gloo on CPU. Single-node systems short-circuit to the in-process pinning path — no spawn, no DDP overhead.

Benchmarks

M4 Hourly (short-term, global)

414 series, horizon 48, per-series z-normalization, single global model, 20k iters, argmax inference.

| config | sMAPE | MASE |
| --- | --- | --- |
| baseline ($V=256$, plain CE) | 18.58 | |
| + soft-CE $\sigma = 1$ | 17.55 | 1.663 |
| + soft-CE $\sigma = 2$ | 17.49 | 1.659 |
| + soft-CE $\sigma = 3$ | 17.35 | 1.949 (over-smoothed) |

Vocab sweep (U-curve, optimum at $V = 256$):

| $V$ | $\sigma$ | sMAPE |
| --- | --- | --- |
| 64 | 0.5 | 19.96 |
| 128 | 1.0 | 19.39 |
| 256 | 2.0 | 17.49 |
| 512 | 4.0 | 18.21 |

Higher-order tokenization: $k = 1 \to 17.49$ sMAPE, $k = 2 \to 39.27$. Prediction error compounds as $O(\Delta \cdot T^{k-1})$: MASH keeps reconstruction tight, but a noisy predictor integrated twice is much worse than one integrated once. $k = 1$ is the forecasting sweet spot.

ETTh1 (long-term, multivariate, iTransformer protocol)

7 channels, seq_len = 96, pred_len = 96, chronological 12/4/4-month split.

| config | MSE | MAE |
| --- | --- | --- |
| $V = 256$, plain CE | 0.9861 | 0.6056 |
| $V = 64$, plain CE | 0.9787 | 0.6030 |
| $V = 64$, soft-CE $\sigma = 1.0$ | 0.9806 | 0.6034 |

ETTh1 is data-limited (~2 tokens/parameter), not loss-limited. Soft targets help most in data-starved regimes.

Install

pip install git+https://github.com/fxeqxmulfx/diff-gpt

Develop & test

uv sync
uv run --no-sync pytest

References
