Batch Padding¶
fast-axolotl provides optimized batch padding for efficient model training, automatically handling variable-length sequences.
Overview¶
Batch padding ensures all sequences in a batch have the same length, which is required for efficient tensor operations. fast-axolotl's Rust implementation is optimized for:
- Fast memory allocation
- Vectorized operations
- Minimal Python overhead
Basic Usage¶
Padding Sequences¶
from fast_axolotl import pad_sequences
sequences = [
[1, 2, 3],
[4, 5],
[6, 7, 8, 9, 10],
]
# Pad to longest sequence
padded = pad_sequences(sequences, pad_value=0)
print(padded)
# [[1, 2, 3, 0, 0],
# [4, 5, 0, 0, 0],
# [6, 7, 8, 9, 10]]
Padding Side¶
Control whether padding is added to the left or right:
# Right padding (default) - for causal/decoder models
padded_right = pad_sequences(sequences, pad_value=0, padding_side="right")
# [[1, 2, 3, 0, 0], ...]
# Left padding - for encoder models or specific use cases
padded_left = pad_sequences(sequences, pad_value=0, padding_side="left")
# [[0, 0, 1, 2, 3], ...]
Creating Attention Masks¶
Basic Mask Creation¶
from fast_axolotl import create_padding_mask
sequences = [
[1, 2, 3],
[4, 5],
]
# Pad sequences
padded = pad_sequences(sequences, pad_value=0)
# Create attention mask (1 for real tokens, 0 for padding)
mask = create_padding_mask(padded, pad_value=0)
print(mask)
# [[1, 1, 1, 0, 0],
# [1, 1, 0, 0, 0]]
Position IDs¶
Generate position IDs that account for padding:
from fast_axolotl import create_padding_mask
def create_position_ids(attention_mask):
"""Create position IDs from attention mask."""
position_ids = attention_mask.cumsum(dim=-1) - 1
position_ids = position_ids.masked_fill(attention_mask == 0, 0)
return position_ids
mask = create_padding_mask(padded, pad_value=0)
positions = create_position_ids(mask)
Padding for Training¶
Complete Batch Preparation¶
from fast_axolotl import pad_sequences, create_padding_mask
import torch
def prepare_batch(input_ids_list, label_ids_list, tokenizer):
"""Prepare a batch for training."""
# Pad inputs
input_ids = pad_sequences(
input_ids_list,
pad_value=tokenizer.pad_token_id,
padding_side="right"
)
# Pad labels (use -100 to ignore in loss)
labels = pad_sequences(
label_ids_list,
pad_value=-100,
padding_side="right"
)
# Create attention mask
attention_mask = create_padding_mask(input_ids, pad_value=tokenizer.pad_token_id)
return {
"input_ids": torch.tensor(input_ids),
"attention_mask": torch.tensor(attention_mask),
"labels": torch.tensor(labels),
}
With Max Length¶
Truncate and pad to a specific length:
def pad_to_max_length(sequences, max_length, pad_value):
"""Pad or truncate sequences to exact length."""
# Truncate if needed
truncated = [seq[:max_length] for seq in sequences]
# Pad to max_length
padded = pad_sequences(
truncated,
pad_value=pad_value,
max_length=max_length
)
return padded
# Ensure all sequences are exactly 2048 tokens
padded = pad_to_max_length(sequences, max_length=2048, pad_value=0)
Dynamic vs Static Batching¶
Dynamic Batching (Default)¶
Pad to the longest sequence in each batch:
# Batch 1: max length might be 256
batch1 = pad_sequences(sequences_batch1, pad_value=0)
# Batch 2: max length might be 512
batch2 = pad_sequences(sequences_batch2, pad_value=0)
Static Batching¶
Pad all batches to the same length for consistent memory usage:
MAX_LENGTH = 2048
batch1 = pad_sequences(sequences_batch1, pad_value=0, max_length=MAX_LENGTH)
batch2 = pad_sequences(sequences_batch2, pad_value=0, max_length=MAX_LENGTH)
Integration with Data Loaders¶
Custom Collate Function¶
from torch.utils.data import DataLoader
from fast_axolotl import pad_sequences, create_padding_mask
import torch
def collate_fn(batch):
"""Custom collate with fast-axolotl padding."""
input_ids = [item["input_ids"] for item in batch]
labels = [item["labels"] for item in batch]
# Fast padding
padded_inputs = pad_sequences(input_ids, pad_value=0)
padded_labels = pad_sequences(labels, pad_value=-100)
attention_mask = create_padding_mask(padded_inputs, pad_value=0)
return {
"input_ids": torch.tensor(padded_inputs),
"attention_mask": torch.tensor(attention_mask),
"labels": torch.tensor(padded_labels),
}
# Use in DataLoader
loader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)
With HuggingFace Trainer¶
from transformers import Trainer, TrainingArguments
from fast_axolotl import pad_sequences
class FastCollator:
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def __call__(self, features):
input_ids = [f["input_ids"] for f in features]
labels = [f["labels"] for f in features]
padded_inputs = pad_sequences(input_ids, pad_value=self.tokenizer.pad_token_id)
padded_labels = pad_sequences(labels, pad_value=-100)
return {
"input_ids": torch.tensor(padded_inputs),
"labels": torch.tensor(padded_labels),
}
trainer = Trainer(
model=model,
args=training_args,
train_dataset=dataset,
data_collator=FastCollator(tokenizer),
)
Performance Tips¶
1. Batch Similar Lengths Together¶
Sorting by length reduces padding waste:
def create_length_sorted_batches(sequences, batch_size):
"""Create batches sorted by length to minimize padding."""
# Sort by length
sorted_idx = sorted(range(len(sequences)), key=lambda i: len(sequences[i]))
sorted_sequences = [sequences[i] for i in sorted_idx]
# Create batches
batches = []
for i in range(0, len(sorted_sequences), batch_size):
batch = sorted_sequences[i:i+batch_size]
padded = pad_sequences(batch, pad_value=0)
batches.append(padded)
return batches
2. Use Appropriate Batch Sizes¶
| Sequence Length | Recommended Batch Size |
|---|---|
| < 512 | 32-64 |
| 512-2048 | 8-32 |
| 2048-4096 | 4-16 |
| > 4096 | 1-8 |
3. Avoid Excessive Padding¶
Monitor your padding ratio:
def padding_ratio(padded_batch, pad_value):
"""Calculate percentage of padding in batch."""
total = padded_batch.numel()
padding = (padded_batch == pad_value).sum().item()
return padding / total
ratio = padding_ratio(batch["input_ids"], pad_value=0)
if ratio > 0.5:
print(f"Warning: {ratio:.1%} padding - consider smaller max_length")
Next Steps¶
- Token Packing - Reduce padding with sequence packing
- API Reference - Complete API docs
- Best Practices - Optimization strategies