Best Practices¶
This guide covers optimization strategies for getting the most performance from fast-axolotl.
Data Format Selection¶
Use Parquet for Best Performance¶
Parquet provides the best streaming performance due to:
- Columnar storage (only read needed columns)
- Efficient compression (ZSTD recommended)
- Row group organization for batch reading
# Convert to Parquet with optimal settings
import pyarrow as pa
import pyarrow.parquet as pq

table = pa.Table.from_pandas(df)
pq.write_table(
    table,
    "data.parquet",
    compression="zstd",
    row_group_size=10000,  # Tune based on batch_size
)
Format Decision Tree¶
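In practice the tree collapses to one rule of thumb, sketched below (the conversion branch is an assumption based on the guidance above, not a confirmed fast-axolotl recommendation):

import os

def choose_strategy(path: str) -> str:
    ext = os.path.splitext(path)[1].lower()
    if ext == ".parquet":
        # Best case: columnar reads plus ZSTD compression
        return "stream directly"
    # Row-oriented formats: a one-time conversion usually pays for itself
    return "convert to Parquet, then stream"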
Streaming Optimization¶
Batch Size Tuning¶
| Memory Available | Recommended Batch Size |
|---|---|
| < 8 GB | 100-500 |
| 8-16 GB | 500-2000 |
| 16-32 GB | 1000-5000 |
| > 32 GB | 2000-10000 |
# Start conservative, increase until memory runs out
from fast_axolotl import streaming_dataset_reader

largest_working = None
for batch_size in [500, 1000, 2000, 5000]:
    try:
        for batch in streaming_dataset_reader(path, batch_size=batch_size):
            process(batch)
        largest_working = batch_size
        print(f"batch_size={batch_size} works")
    except MemoryError:
        print(f"batch_size={batch_size} too large")
        break
Column Selection¶
Always specify only the columns you need:
# Good - only loads needed columns
reader = streaming_dataset_reader(
    "data.parquet",
    columns=["input_ids", "attention_mask", "labels"]
)
# Bad - loads all columns including unused ones
reader = streaming_dataset_reader("data.parquet")
File Organization¶
For large datasets, split the data into multiple files (see the sharding sketch after the list below).
Benefits:
- Parallel file processing
- Better memory management
- Easier data versioning
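A minimal sharding sketch with pyarrow (the shard count and naming scheme here are illustrative, not a fast-axolotl convention):

import pyarrow as pa
import pyarrow.parquet as pq

def write_shards(table: pa.Table, num_shards: int = 8, prefix: str = "data"):
    # Slice into roughly equal shards; each becomes its own Parquet file
    shard_size = (table.num_rows + num_shards - 1) // num_shards
    for i in range(num_shards):
        shard = table.slice(i * shard_size, shard_size)
        pq.write_table(shard, f"{prefix}-{i:05d}.parquet", compression="zstd")

Each shard can then be streamed, processed, or re-versioned independently.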
Token Packing Optimization¶
When to Use Packing¶
Use packing when:
- Average sequence length < 50% of max length
- High variance in sequence lengths
- Training on concatenated documents
Skip packing when:
- Sequences are already near max length
- Uniform sequence lengths
- Very small batch sizes
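A quick way to check the first two conditions on your own data (a sketch; the 0.5 ratio mirrors the rule of thumb above, and the variance threshold is an illustrative choice):

from statistics import mean, pstdev

def should_pack(sequences, max_length):
    lengths = [len(s) for s in sequences]
    avg = mean(lengths)
    # Short-on-average or highly variable lengths both favor packing
    return avg < 0.5 * max_length or pstdev(lengths) > 0.25 * avg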
Packing Strategy¶
from fast_axolotl import pack_sequences

def efficient_packing(sequences, max_length, pad_token_id):
    # Sort by length for better packing efficiency
    sorted_seqs = sorted(sequences, key=len)
    # Pack
    packed = pack_sequences(sorted_seqs, max_length, pad_token_id)
    return packed
Memory-Efficient Packing¶
For large datasets, pack in chunks:
import torch

def chunked_packing(sequences, max_length, pad_token_id, chunk_size=10000):
    all_packed = []
    for i in range(0, len(sequences), chunk_size):
        chunk = sequences[i:i+chunk_size]
        packed = pack_sequences(chunk, max_length, pad_token_id)
        all_packed.append(packed)
    return torch.cat(all_packed, dim=0)
Parallel Hashing Optimization¶
Optimal Row Size¶
Parallel hashing works best with:
- Row sizes between 100 bytes and 10 KB
- Large number of rows (1000+)
# For very small rows, batch them before hashing
from fast_axolotl import parallel_hash_rows

def batch_hash(items, items_per_row=10):
    # items must be bytes objects
    batched = [
        b"".join(items[i:i+items_per_row])
        for i in range(0, len(items), items_per_row)
    ]
    return parallel_hash_rows(batched)
Streaming Deduplication¶
For datasets too large for memory:
from fast_axolotl import parallel_hash_rows, streaming_dataset_reader

def streaming_dedupe(path, output_path, chunk_size=100000):
    seen_hashes = set()
    with open(output_path, "w") as out:
        for batch in streaming_dataset_reader(path, batch_size=chunk_size):
            rows = [str(r).encode() for r in batch["text"]]
            hashes = parallel_hash_rows(rows)
            for i, h in enumerate(hashes):
                if h not in seen_hashes:
                    seen_hashes.add(h)
                    out.write(batch["text"][i] + "\n")
Batch Padding Optimization¶
Dynamic vs Static Batching¶
Dynamic batching (pad to longest in batch):
- Less wasted compute
- Variable memory usage
- Slightly more overhead
Static batching (pad to fixed length):
- Consistent memory usage
- Better for caching
- May waste compute on short sequences
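A minimal sketch of both modes. The dynamic case uses pad_sequences as elsewhere in this guide; the static case is written out by hand, since a fixed-length keyword for pad_sequences is not assumed here:

from fast_axolotl import pad_sequences

batch_seqs = [[1, 2, 3, 4], [5, 6], [7]]

# Dynamic: pad each batch to its own longest sequence
dynamic = pad_sequences(batch_seqs, pad_value=0)

# Static: pad every sequence to a fixed target length by hand
def pad_static(seqs, target_length, pad_value=0):
    return [s + [pad_value] * (target_length - len(s)) for s in seqs]

static = pad_static(batch_seqs, target_length=8)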
Length-Sorted Batching¶
Minimize padding by sorting sequences by length:
from fast_axolotl import pad_sequences

def create_sorted_batches(sequences, batch_size, pad_value):
    # Sort by length
    sorted_idx = sorted(range(len(sequences)), key=lambda i: len(sequences[i]))
    sorted_seqs = [sequences[i] for i in sorted_idx]
    # Create batches
    batches = []
    for i in range(0, len(sorted_seqs), batch_size):
        batch = sorted_seqs[i:i+batch_size]
        padded = pad_sequences(batch, pad_value=pad_value)
        batches.append(padded)
    return batches
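One caveat: shuffle the finished batches before training, otherwise the model sees examples in strictly increasing length order.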
Memory Management¶
Reduce Peak Memory¶
import gc

def process_large_dataset(path):
    for batch in streaming_dataset_reader(path, batch_size=1000):
        result = process(batch)
        save(result)
        # Explicitly free batch memory
        del batch
        gc.collect()
Monitor Memory Usage¶
import gc
import psutil

def memory_efficient_processing(path):
    proc = psutil.Process()  # named proc to avoid shadowing process()
    for batch in streaming_dataset_reader(path):
        mem_before = proc.memory_info().rss / 1024 / 1024
        result = process(batch)
        mem_after = proc.memory_info().rss / 1024 / 1024
        if mem_after > mem_before + 100:  # More than 100 MB growth
            gc.collect()
Integration Patterns¶
With PyTorch DataLoader¶
from torch.utils.data import DataLoader
from fast_axolotl import create_rust_streaming_dataset
# Optimal DataLoader settings for fast-axolotl
loader = DataLoader(
    create_rust_streaming_dataset("data.parquet", batch_size=32),
    batch_size=None,  # Dataset handles batching
    num_workers=0,    # Rust handles parallelism
    pin_memory=True,  # Fast GPU transfer
)
# prefetch_factor is left at its default: it only applies with worker processes,
# and is not needed with streaming
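A typical consumption loop (the column names and device handling are illustrative and depend on your dataset):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
for batch in loader:
    input_ids = batch["input_ids"].to(device, non_blocking=True)
    labels = batch["labels"].to(device, non_blocking=True)
    # ... forward/backward pass ...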
With HuggingFace Trainer¶
import fast_axolotl
from transformers import Trainer

# Enable shimming before creating the trainer
fast_axolotl.install()

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    # fast-axolotl accelerates internal operations
)
Profiling¶
Identify Bottlenecks¶
import time
from fast_axolotl import pack_sequences, pad_sequences, streaming_dataset_reader

def profile_pipeline(path):
    timings = {}
    # Profile streaming
    start = time.time()
    batches = list(streaming_dataset_reader(path, batch_size=1000))
    timings["streaming"] = time.time() - start
    # Flatten batches into individual sequences before packing/padding
    sequences = [seq for batch in batches for seq in batch["input_ids"]]
    # Profile packing
    start = time.time()
    packed = pack_sequences(sequences, max_length=2048, pad_token_id=0)
    timings["packing"] = time.time() - start
    # Profile padding
    start = time.time()
    padded = pad_sequences(sequences, pad_value=0)
    timings["padding"] = time.time() - start
    return timings
Compare With and Without fast-axolotl¶
import time
import fast_axolotl

def compare_performance():
    # With fast-axolotl
    fast_axolotl.install()
    start = time.time()
    run_pipeline()
    with_fa = time.time() - start
    # Without fast-axolotl
    fast_axolotl.uninstall()
    start = time.time()
    run_pipeline()
    without_fa = time.time() - start
    print(f"Speedup: {without_fa / with_fa:.1f}x")
Common Pitfalls¶
1. Not Using Shimming¶
# Wrong - import axolotl before install()
import axolotl
import fast_axolotl
fast_axolotl.install() # Too late!
# Right - install before importing axolotl
import fast_axolotl
fast_axolotl.install()
import axolotl
2. Loading All Columns¶
# Wrong - loads unused columns
reader = streaming_dataset_reader("data.parquet")
# Right - specify needed columns
reader = streaming_dataset_reader("data.parquet", columns=["input_ids", "labels"])
3. Small Batch Sizes¶
# Wrong - too small, high overhead per batch
reader = streaming_dataset_reader("data.parquet", batch_size=10)
# Right - larger batches for better throughput
reader = streaming_dataset_reader("data.parquet", batch_size=1000)
See Also¶
- Benchmarks - Performance comparisons
- Streaming Guide - Detailed streaming usage
- API Reference - Complete API documentation