Token Packing¶
Token packing efficiently combines multiple short sequences into longer ones, maximizing GPU utilization during training.
Why Token Packing?¶
Without packing, short sequences waste compute:
Sequence 1: [token, token, token, PAD, PAD, PAD, PAD, PAD] # 62.5% padding
Sequence 2: [token, token, PAD, PAD, PAD, PAD, PAD, PAD] # 75% padding
With packing, both fit into a single sequence:
Packed: [token, token, token, token, token, PAD, PAD, PAD] # 37.5% padding
Basic Usage¶
from fast_axolotl import pack_sequences
import torch
sequences = [
    torch.tensor([1, 2, 3]),
    torch.tensor([4, 5]),
    torch.tensor([6, 7, 8, 9]),
    torch.tensor([10, 11, 12]),
]

packed = pack_sequences(
    sequences,
    max_length=8,
    pad_token_id=0,
)
print(packed)
# tensor([[ 1, 2, 3, 4, 5, 0, 0, 0],
# [ 6, 7, 8, 9, 10, 11, 12, 0]])
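Because pad_token_id is 0 here, you can measure the padding overhead of the result directly:
# Fraction of slots filled with the pad token
pad_fraction = (packed == 0).float().mean().item()
print(f"{pad_fraction:.0%} padding")  # 25% padding (4 pad slots out of 16)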
Configuration Options¶
Max Length¶
Set the target sequence length:
# For models with 2048 context
packed = pack_sequences(sequences, max_length=2048, pad_token_id=0)
# For models with 4096 context
packed = pack_sequences(sequences, max_length=4096, pad_token_id=0)
Pad Token ID¶
Specify the padding token for your tokenizer:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b")
packed = pack_sequences(
    sequences,
    max_length=2048,
    pad_token_id=tokenizer.pad_token_id,
)
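Note that some tokenizers, including Llama-2's, define no pad token by default, so pad_token_id may be None. A common workaround is to reuse the end-of-sequence token:
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # reuse EOS as the padding token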
Advanced: Concatenate and Pack¶
For more control, use concatenate_and_pack with separate inputs, labels, and masks:
from fast_axolotl import concatenate_and_pack
# Separate input_ids and labels
input_sequences = [
    [1, 2, 3],
    [4, 5, 6, 7],
]
label_sequences = [
    [-100, 2, 3],  # -100 = ignore in loss
    [-100, -100, 6, 7],
]
attention_masks = [
    [1, 1, 1],
    [1, 1, 1, 1],
]

packed_inputs, packed_labels, packed_masks = concatenate_and_pack(
    input_sequences,
    label_sequences,
    attention_masks,
    max_length=8,
    pad_token_id=0,
    label_pad_id=-100,
)
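The three outputs stay aligned position for position. Assuming they come back as tensors (convert with torch.tensor otherwise), they can be passed straight to a Hugging Face-style model:
outputs = model(
    input_ids=packed_inputs,
    attention_mask=packed_masks,
    labels=packed_labels,  # -100 positions are ignored by the loss
)
loss = outputs.loss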
Packing Strategies¶
Greedy First-Fit¶
The default strategy packs sequences greedily:
# Sequences are packed in order, fitting as many as possible
packed = pack_sequences(sequences, max_length=2048, pad_token_id=0)
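Conceptually, first-fit packing walks the sequences once and opens a new row whenever the next sequence no longer fits. A minimal pure-Python sketch of the idea (not fast_axolotl's actual implementation):
def greedy_pack(lengths, max_length):
    # Group sequence indices into rows, in order, opening a new
    # row whenever the next sequence would overflow max_length
    rows, current, used = [], [], 0
    for i, n in enumerate(lengths):
        if used + n > max_length and current:
            rows.append(current)
            current, used = [], 0
        current.append(i)
        used += n
    if current:
        rows.append(current)
    return rows

greedy_pack([3, 2, 4, 3], max_length=8)  # [[0, 1], [2, 3]], matching the Basic Usage example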
With Sequence Boundaries¶
To preserve sequence boundaries for causal attention:
from fast_axolotl import pack_sequences
packed, boundaries = pack_sequences(
    sequences,
    max_length=2048,
    pad_token_id=0,
    return_boundaries=True,
)
# boundaries contains start/end indices for each original sequence
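Boundary information lets you keep attention from crossing sequence borders within a packed row. The exact structure returned by return_boundaries=True is not shown here, so the sketch below assumes a hypothetical list of (start, end) pairs per row; check the API reference for the real format:
import torch

def block_causal_mask(row_boundaries, max_length):
    # True where position q may attend to position k: causal attention
    # restricted to tokens from the same original sequence
    mask = torch.zeros(max_length, max_length, dtype=torch.bool)
    for start, end in row_boundaries:
        n = end - start
        mask[start:end, start:end] = torch.tril(torch.ones(n, n, dtype=torch.bool))
    return mask

# First packed row of the Basic Usage example: [1, 2, 3] at 0:3, [4, 5] at 3:5
mask = block_causal_mask([(0, 3), (3, 5)], max_length=8)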
Integration with Training¶
With Streaming Data¶
from fast_axolotl import streaming_dataset_reader, pack_sequences
def create_packed_batches(data_path, max_length, batch_size):
    # Buffer sequences from the stream until there are enough to pack
    buffer = []
    for batch in streaming_dataset_reader(data_path, batch_size=100):
        buffer.extend(batch["input_ids"])
        while len(buffer) >= batch_size:
            to_pack = buffer[:batch_size]
            buffer = buffer[batch_size:]
            packed = pack_sequences(
                to_pack,
                max_length=max_length,
                pad_token_id=0,
            )
            yield packed
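A hypothetical call site (the path and sizes are placeholders):
for packed in create_packed_batches("data/train.jsonl", max_length=2048, batch_size=64):
    ...  # feed each packed tensor to your training step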
Complete Training Loop¶
import torch
from fast_axolotl import pack_sequences
def train_with_packing(model, sequences, max_length=2048, batch_size=8):
    optimizer = torch.optim.AdamW(model.parameters())

    # Pack all sequences up front
    packed = pack_sequences(
        sequences,
        max_length=max_length,
        pad_token_id=model.config.pad_token_id,
    )

    # Training loop over packed rows
    for i in range(0, len(packed), batch_size):
        batch = packed[i:i + batch_size].to(model.device)
        outputs = model(batch, labels=batch)  # in practice, mask pad positions in labels with -100
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
Performance Considerations¶
When Packing Helps¶
- Short sequences (< 25% of max length)
- Variable-length datasets
- High padding ratios (see the estimator sketch below)
When Packing May Not Help¶
- Already long sequences
- Uniform sequence lengths
- Very small batch sizes
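To check which regime your dataset falls in, you can estimate the padding ratio directly from sequence lengths (utilization is roughly avg_length / max_length); a minimal sketch:
def padding_ratio(lengths, max_length):
    # Fraction of slots that would be padding with one sequence per row
    used = sum(min(n, max_length) for n in lengths)
    return 1 - used / (len(lengths) * max_length)

padding_ratio([256] * 1000, max_length=2048)  # 0.875, so packing helps a lot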
Benchmarks¶
| Scenario | Without Packing | With Packing | Improvement |
|---|---|---|---|
| Avg length 256, max 2048 | 12.5% utilization | 80%+ utilization | 6.4x |
| Avg length 512, max 2048 | 25% utilization | 85%+ utilization | 3.4x |
| Avg length 1024, max 2048 | 50% utilization | 90%+ utilization | 1.8x |
Common Patterns¶
Packing with Labels¶
# Pack inputs and labels separately; each label sequence has the same
# length as its input, so the deterministic strategy keeps them aligned
packed_inputs = pack_sequences(input_ids, max_length=2048, pad_token_id=0)
packed_labels = pack_sequences(labels, max_length=2048, pad_token_id=-100)
Packing for Different Model Sizes¶
# Adjust max_length based on model
model_configs = {
    "7b": 4096,
    "13b": 4096,
    "70b": 8192,
}
max_length = model_configs[model_size]
packed = pack_sequences(sequences, max_length=max_length, pad_token_id=0)
Troubleshooting¶
Sequences Longer Than Max Length¶
# Drop empty sequences and truncate anything longer than max_length
sequences = [s[:max_length] for s in sequences if len(s) > 0]
packed = pack_sequences(sequences, max_length=max_length, pad_token_id=0)
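If truncation would lose too much data, an alternative is to split over-length sequences into max_length-sized pieces before packing; a minimal sketch:
def chunk_long(sequences, max_length):
    # Split each over-length sequence into max_length-sized pieces
    # instead of discarding its tail
    out = []
    for s in sequences:
        for i in range(0, len(s), max_length):
            out.append(s[i:i + max_length])
    return out

packed = pack_sequences(chunk_long(sequences, max_length), max_length=max_length, pad_token_id=0)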
Memory Issues¶
# Pack in smaller chunks to bound peak memory
chunk_size = 10000
all_packed = []
for i in range(0, len(sequences), chunk_size):
    chunk = sequences[i:i + chunk_size]
    packed = pack_sequences(chunk, max_length=2048, pad_token_id=0)
    all_packed.append(packed)

# All chunks are packed to the same width (max_length), so they concatenate cleanly
final = torch.cat(all_packed, dim=0)
Next Steps¶
- Batch Padding - Efficient batch preprocessing
- API Reference - Complete API docs
- Best Practices - Optimization tips