Overview
PyTorch is one of the most widely used frameworks for deep learning research and production.
Core Concepts
- Autograd (Automatic Differentiation)
- Tensors and GPU execution
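Autograd records the operations applied to tensors as a graph, then applies the chain rule backwards through that graph. The plain-Python sketch below illustrates the idea with scalar reverse-mode differentiation. It is a toy, not PyTorch's actual implementation, and the `Value` class and its methods are hypothetical names. In PyTorch itself, you would create tensors with `requires_grad=True` and call `.backward()` on the result. Gradients then appear in each leaf tensor's `.grad`.

```python
import math


class Value:
    """A scalar that remembers how it was computed so gradients can flow back."""

    def __init__(self, data, parents=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents
        self._backward_fn = None  # pushes self.grad into the parents' grads

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other))

        def backward_fn():
            # d(a + b)/da = d(a + b)/db = 1
            self.grad += out.grad
            other.grad += out.grad

        out._backward_fn = backward_fn
        return out

    def __mul__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data * other.data, (self, other))

        def backward_fn():
            # d(a * b)/da = b, d(a * b)/db = a
            self.grad += other.data * out.grad
            other.grad += self.data * out.grad

        out._backward_fn = backward_fn
        return out

    __radd__ = __add__
    __rmul__ = __mul__

    def sin(self):
        out = Value(math.sin(self.data), (self,))

        def backward_fn():
            # d(sin a)/da = cos a
            self.grad += math.cos(self.data) * out.grad

        out._backward_fn = backward_fn
        return out

    def backward(self):
        """Reverse-mode pass: topologically sort the graph, then apply the chain rule."""
        order, seen = [], set()

        def visit(node):
            if id(node) not in seen:
                seen.add(id(node))
                for parent in node._parents:
                    visit(parent)
                order.append(node)

        visit(self)
        self.grad = 1.0
        for node in reversed(order):
            if node._backward_fn is not None:
                node._backward_fn()


if __name__ == "__main__":
    x, y = Value(2.0), Value(3.0)
    z = x * y + x.sin()
    z.backward()
    print(x.grad, y.grad)  # dz/dx = y + cos(x), dz/dy = x
```

Like PyTorch, this sketch accumulates into `.grad` rather than overwriting it. That is why a variable used twice (for example `w * w`) correctly receives both contributions.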
Distributed Training
Data Parallelism
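- DataParallel (DP): Single-process, multi-threaded. On every iteration it replicates the model onto each GPU and splits the minibatch so each replica processes a different slice. It then gathers the outputs on the primary GPU and sums the replicas' gradients back into the model on that GPU. Warning: Suffers from Python’s GIL contention between threads, and the primary GPU becomes a communication bottleneck.
- Distributed Data Parallel (DDP): Multi-process (one per GPU), which avoids GIL contention. Gradients are averaged with an all-reduce that overlaps with the backward pass. With the NCCL backend this can use the Ring All-Reduce algorithm, so no single GPU becomes a central bottleneck.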
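Data parallelism rests on one invariant. With equal-size shards, averaging the per-shard gradients gives the same update as a single full-batch gradient step, so every replica applies an identical update and the replicas stay in sync. The single-process sketch below demonstrates this on a 1-D linear model. The helper names `grad_mse`, `shard`, and `data_parallel_step` are hypothetical. In DDP, the averaging line corresponds to an all-reduce across processes.

```python
def grad_mse(w, batch):
    """Gradient of the mean squared error of the 1-D model y_hat = w * x."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def shard(batch, num_workers):
    """Split a batch into equal contiguous shards, one per worker."""
    assert len(batch) % num_workers == 0, "sketch assumes equal-size shards"
    size = len(batch) // num_workers
    return [batch[i * size:(i + 1) * size] for i in range(num_workers)]


def data_parallel_step(w, batch, num_workers, lr):
    """One synchronous data-parallel SGD step, simulated in a single process."""
    # Each worker holds an identical replica of w and sees a different shard.
    local_grads = [grad_mse(w, s) for s in shard(batch, num_workers)]
    # All-reduce (sum) followed by division by the worker count = averaging.
    avg_grad = sum(local_grads) / num_workers
    # Every replica applies the same update, so the replicas never diverge.
    return w - lr * avg_grad


if __name__ == "__main__":
    batch = [(float(x), 2.0 * x) for x in range(1, 9)]  # true weight is 2.0
    w = 0.5
    for _ in range(50):
        w = data_parallel_step(w, batch, num_workers=4, lr=0.01)
    print(w)  # close to 2.0
```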
Synchronization Approaches
Workers can either synchronize after every minibatch or proceed independently, trading idle time against weight staleness:
- Bulk Synchronous Parallel (BSP): Workers synchronize at the end of every minibatch. Prevents stale weights but forces faster machines to halt and wait for the slowest one.
- Asynchronous Parallel (ASP): Each GPU computes and applies updates without waiting for the others. Increases raw computation throughput, but gradients may be computed from stale weights, lowering statistical efficiency (more updates are needed to reach the same loss).
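The BSP/ASP trade-off can be reproduced on a toy one-parameter problem. In this sketch the function names are hypothetical. Staleness is modeled as a fixed delay, which simplifies real asynchronous systems. BSP computes every gradient from the current weights. ASP applies gradients computed from weights that are `staleness` updates old.

```python
def grad(w, target=3.0):
    """Gradient of the toy loss f(w) = 0.5 * (w - target) ** 2."""
    return w - target


def bsp(steps, num_workers, lr):
    """Bulk synchronous: a barrier each step, so every gradient uses the current w."""
    w = 0.0
    for _ in range(steps):
        grads = [grad(w) for _ in range(num_workers)]  # all workers read the same w
        w -= lr * sum(grads) / num_workers             # then one synchronized update
    return w


def asp(steps, staleness, lr):
    """Asynchronous: each update applies a gradient computed `staleness` updates ago."""
    history = [0.0]  # history[k] holds the weights after k updates
    for k in range(steps):
        stale_w = history[max(0, k - staleness)]  # the worker read an old copy
        history.append(history[-1] - lr * grad(stale_w))
    return history[-1]


if __name__ == "__main__":
    print(bsp(100, num_workers=4, lr=0.1))  # converges toward 3.0
    print(asp(100, staleness=0, lr=0.1))    # no staleness: same trajectory as BSP
    print(asp(200, staleness=10, lr=0.5))   # stale updates overshoot and diverge
```

With zero staleness the two schemes coincide. With a delay of 10 updates and a large learning rate, the stale gradients keep pushing in an outdated direction, so the iterates oscillate with growing amplitude. This is the statistical-efficiency cost of asynchrony taken to an extreme.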
ZeRO (Zero Redundancy Optimizer) / FSDP
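Instead of replicating the full model state (parameters, gradients, and optimizer state) on every GPU, which keeps a redundant copy per GPU and wastes memory, ZeRO shards these components across GPUs:
- Stage 1 (Optimizer State Partitioning): Up to 4x memory reduction (assuming mixed-precision Adam and many GPUs). Shards Adam's optimizer state (fp32 master weights, momentum, and variance) across GPUs.
- Stage 2 (Gradient Partitioning): Up to 8x memory reduction (same assumptions). Additionally shards the gradients.
- Stage 3 (Parameter Partitioning / FSDP): Additionally slices the model weights themselves across GPUs. Memory reduction grows linearly with the number of GPUs.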
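The stage-by-stage figures follow from simple per-parameter accounting. The sketch below assumes mixed-precision Adam, as the ZeRO paper does: 2 bytes per parameter for fp16 weights, 2 for fp16 gradients, and 12 for optimizer state (fp32 master weights, momentum, and variance). It ignores activations and temporary buffers. The function name `zero_bytes_per_gpu` is hypothetical.

```python
def zero_bytes_per_gpu(num_params, num_gpus, stage, k=12):
    """Per-GPU model-state memory in bytes under mixed-precision Adam.

    p: fp16 weights (2 B/param), g: fp16 gradients (2 B/param),
    o: optimizer state (k = 12 B/param). Activations are ignored.
    """
    p, g, o = 2 * num_params, 2 * num_params, k * num_params
    n = num_gpus
    if stage == 0:  # plain data parallelism: everything replicated
        return p + g + o
    if stage == 1:  # shard optimizer state
        return p + g + o / n
    if stage == 2:  # shard optimizer state and gradients
        return p + (g + o) / n
    if stage == 3:  # shard optimizer state, gradients, and parameters
        return (p + g + o) / n
    raise ValueError(f"unknown ZeRO stage: {stage}")


if __name__ == "__main__":
    for stage in range(4):
        gb = zero_bytes_per_gpu(7.5e9, 64, stage) / 1e9
        print(f"stage {stage}: {gb:.2f} GB per GPU")
```

For a 7.5B-parameter model on 64 GPUs, this gives 120 GB, about 31.4 GB, about 16.6 GB, and 1.875 GB of model state per GPU for stages 0 through 3. The 4x and 8x ratios are the limits of stage 0 over stages 1 and 2 as the GPU count grows. Stage 3 divides everything by the GPU count, which is where the linear scaling comes from.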
Communication Primitives
- All-Reduce: Everyone ends up with the sum of all data (used for syncing gradients).
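- Ring All-Reduce: GPUs pass chunks of data around a ring (a reduce-scatter followed by an all-gather). Each GPU sends $2(N-1) \cdot X/N$ data, where $N$ is the number of GPUs and $X$ is the size of the data being reduced. Per-GPU traffic therefore stays below $2X$ however many GPUs participate.
- All-Gather: Each process gathers data chunks from everyone else so everyone has the complete picture.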
- Reduce-Scatter: Each process sums a chunk of data across all processes and keeps only its own chunk.
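These primitives compose: a ring all-reduce is a ring reduce-scatter followed by a ring all-gather. The single-process simulation below moves one chunk per rank per step around the ring and counts the elements each rank sends. For a length-$X$ vector on $N$ ranks the count comes out to $2(N-1)X/N$. All helper names are hypothetical. Real PyTorch code would call `torch.distributed.reduce_scatter`, `torch.distributed.all_gather`, or `torch.distributed.all_reduce` and let the backend choose the algorithm.

```python
def split(vec, n):
    """Split a vector into n equal chunks (sketch assumes len(vec) % n == 0)."""
    size = len(vec) // n
    return [list(vec[i * size:(i + 1) * size]) for i in range(n)]


def ring_reduce_scatter(chunks):
    """chunks[r][c] is chunk c on rank r. Mutates in place; returns elements sent per rank.

    After n - 1 steps, rank r holds the full sum of chunk (r + 1) % n
    (which chunk a rank ends up owning is just a convention of this ring order).
    """
    n, sent = len(chunks), 0
    for s in range(n - 1):
        # All ranks send simultaneously, so snapshot outgoing chunks first.
        msgs = [list(chunks[r][(r - s) % n]) for r in range(n)]
        for r in range(n):
            dst, c = (r + 1) % n, (r - s) % n
            chunks[dst][c] = [a + b for a, b in zip(chunks[dst][c], msgs[r])]
        sent += len(msgs[0])
    return sent


def ring_all_gather(chunks):
    """Circulate each rank's fully reduced chunk until every rank has all chunks."""
    n, sent = len(chunks), 0
    for s in range(n - 1):
        msgs = [list(chunks[r][(r + 1 - s) % n]) for r in range(n)]
        for r in range(n):
            dst, c = (r + 1) % n, (r + 1 - s) % n
            chunks[dst][c] = msgs[r]  # overwrite with the reduced chunk
        sent += len(msgs[0])
    return sent


def ring_all_reduce(vectors):
    """All-reduce (sum) = reduce-scatter + all-gather. Returns (results, elements sent per rank)."""
    n = len(vectors)
    assert len(vectors[0]) % n == 0, "sketch assumes length divisible by N"
    chunks = [split(v, n) for v in vectors]
    sent = ring_reduce_scatter(chunks)
    sent += ring_all_gather(chunks)
    results = [[x for chunk in rank for x in chunk] for rank in chunks]
    return results, sent


if __name__ == "__main__":
    vectors = [[10 * r + i for i in range(8)] for r in range(4)]  # 4 ranks, X = 8
    results, sent = ring_all_reduce(vectors)
    print(results[0], sent)  # every rank holds the same sum; sent = 2 * 3 * 8 / 4
```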
Additional Resources: