Chapter 3: FSDP, ZeRO, NCCL
๋ถ์ฐ ํ์ต์ ๋ฉ๋ชจ๋ฆฌ ํจ์จํ ๊ธฐ์ (ZeRO, FSDP)๊ณผ GPU ๊ฐ ํต์ ์ ํต์ฌ์ธ NCCL์ ์ฌ์ธต์ ์ผ๋ก ์ดํดํฉ๋๋ค. ์ด ๊ธฐ์ ๋ค์ Checkpointless Training์ ๊ธฐ๋ฐ์ด ๋ฉ๋๋ค.
1. ZeRO ๊ฐ์
DDP์ ๋ฉ๋ชจ๋ฆฌ ์ค๋ณต ๋ฌธ์
ํ์ค Data Parallelism (DDP)์ ๊ฐ GPU๊ฐ ๋์ผํ ๋ชจ๋ธ ๋ณต์ฌ๋ณธ(Model Replica)์ ๊ฐ์ง๊ณ , ๋ฐ์ดํฐ๋ง ๋ถ์ฐํ์ฌ ์ฒ๋ฆฌํฉ๋๋ค. ๊ฐ GPU๋ Forward/Backward pass๋ฅผ ๋
๋ฆฝ์ ์ผ๋ก ์ํํ ๋ค ๊ทธ๋๋์ธํธ๋ฅผ All-Reduce๋ก ๋๊ธฐํํฉ๋๋ค.
ํ์ต ์ํ(Training State) ๊ตฌ์ฑ
Mixed Precision Training (BF16/FP16 + FP32 Master Weights)๊ณผ Adam ์ตํฐ๋ง์ด์ ๋ฅผ ์ฌ์ฉํ ๋, ๋จ์ผ ํ๋ผ๋ฏธํฐ๋น ํ์ํ ๋ฉ๋ชจ๋ฆฌ๋ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
| ๊ตฌ์ฑ ์์ | ๋ฐ์ดํฐ ํ์ | ํ๋ผ๋ฏธํฐ๋น ํฌ๊ธฐ | ์ค๋ช |
|---|---|---|---|
| Model Weights | BF16/FP16 | 2 Bytes | Forward/Backward ์ฐ์ฐ์ฉ ๊ฐ์ค์น |
| Gradients | BF16/FP16 | 2 Bytes | Backward pass์์ ๊ณ์ฐ๋ ๊ทธ๋๋์ธํธ |
| Master Weights | FP32 | 4 Bytes | ์ตํฐ๋ง์ด์ ์ ๋ฐ์ดํธ์ฉ ๊ณ ์ ๋ฐ ๊ฐ์ค์น |
| Momentum (1st moment) | FP32 | 4 Bytes | Adam์ 1์ฐจ ๋ชจ๋ฉํธ (ํ๊ท ) |
| Variance (2nd moment) | FP32 | 4 Bytes | Adam์ 2์ฐจ ๋ชจ๋ฉํธ (๋ถ์ฐ) |
| ํฉ๊ณ | - | 16 Bytes | ํ๋ผ๋ฏธํฐ 1๊ฐ๋น ์ด ๋ฉ๋ชจ๋ฆฌ |
์: 70B ๋ชจ๋ธ = 70,000,000,000 x 16 = 1,120 GB (1.12 TB)
์ด 1.12TB๋ฅผ ๋จ์ผ GPU(80GB VRAM)์ ์ฌ๋ฆฌ๋ ๊ฒ์ ๋ถ๊ฐ๋ฅํฉ๋๋ค. ZeRO๋ ์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํฉ๋๋ค.
2. ZeRO Stage 1: Optimizer State Partitioning
๋์ ์๋ฆฌ
ZeRO Stage 1 ($P_{os}$)์ ๊ฐ์ฅ ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋ง์ด ์ฐจ์งํ๋ ์ตํฐ๋ง์ด์ ์ํ(Optimizer States)๋ง GPU๋ค์ ๋ถ์ฐ(Sharding)ํฉ๋๋ค.
- ๊ฐ GPU๋ ์ ์ฒด ํ๋ผ๋ฏธํฐ ์ค ์์ ์ด ๋งก์ ๋ถ๋ถ์ ์ตํฐ๋ง์ด์ ์ํ๋ง ์ ์ง
- ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ์ ๊ทธ๋๋์ธํธ๋ ๋ชจ๋ GPU์ ๋ณต์ (DDP์ ๋์ผ)
- ์ตํฐ๋ง์ด์ ์คํ
ํ, ์
๋ฐ์ดํธ๋ ํ๋ผ๋ฏธํฐ๋ฅผ
All-Gather๋ก ๋๊ธฐํ
๋ฉ๋ชจ๋ฆฌ ์ ๊ฐ ๊ณ์ฐ
N๊ฐ GPU ๋ถ์ฐ ์: 12/N Bytes/param
8 GPU ์์: 12/8 = 1.5 Bytes/param (๊ธฐ์กด 12B์์ 8๋ฐฐ ์ ๊ฐ)
DeepSpeed ZeRO Stage 1 Config JSON
{
"zero_optimization": {
"stage": 1,
"reduce_bucket_size": 5e8,
"allgather_bucket_size": 5e8
},
"fp16": {
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"initial_scale_power": 16,
"hysteresis": 2,
"min_loss_scale": 1
},
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-4,
"betas": [0.9, 0.999],
"eps": 1e-8,
"weight_decay": 0.01
}
}
}
3. ZeRO Stage 2: + Gradient Partitioning
๋์ ์๋ฆฌ
ZeRO Stage 2 ($P_{os+g}$)๋ Stage 1์ ์ถ๊ฐ๋ก ๊ทธ๋๋์ธํธ(Gradients)๋ ๋ถ์ฐํฉ๋๋ค.
- Backward pass ํ, ์ ์ฒด ๊ทธ๋๋์ธํธ๋ฅผ
Reduce-Scatter์ฐ์ฐ์ผ๋ก ํฉ์ฐ + ๋ถ๋ฐฐ - ๊ฐ GPU๋ ์์ ์ด ๋ด๋นํ๋ ํ๋ผ๋ฏธํฐ์ ๊ทธ๋๋์ธํธ๋ง ์ ์ง
- ์ตํฐ๋ง์ด์ ์คํ ์ ๊ฐ GPU๊ฐ ์์ ์ ํ๋ผ๋ฏธํฐ ์กฐ๊ฐ์ ๋ํด์๋ง ์ํ
Reduce-Scatter ์ฐ์ฐ
Reduce-Scatter๋ All-Reduce๋ฅผ ๋ ๋จ๊ณ๋ก ๋ถ๋ฆฌํ ๊ฒ ์ค ์ฒซ ๋ฒ์งธ์
๋๋ค:
Full Gradients
(ํฉ์ฐ)
(๋ถ๋ฐฐ)
1/N Gradient Shard
๋ฉ๋ชจ๋ฆฌ ์ ๊ฐ
Stage 2: + Gradients ๋ถ์ฐ = 2/N Bytes/param ์ถ๊ฐ ์ ๊ฐ
์ด: (12 + 2)/N = 14/N Bytes/param (ํ๋ผ๋ฏธํฐ ๋ณต์ ์ ์ธ)
DeepSpeed ZeRO Stage 2 Config JSON
{
"zero_optimization": {
"stage": 2,
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": 5e8,
"allgather_bucket_size": 5e8
},
"fp16": {
"enabled": true,
"loss_scale": 0,
"initial_scale_power": 16
},
"gradient_clipping": 1.0,
"train_batch_size": 32,
"train_micro_batch_size_per_gpu": 4
}
4. ZeRO Stage 3: + Parameter Partitioning
๋์ ์๋ฆฌ
ZeRO Stage 3 ($P_{os+g+p}$)๋ ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ(Parameters)๊น์ง ๋ชจ๋ ๋ถ์ฐํฉ๋๋ค. ์ด๊ฒ์ด PyTorch FSDP์ ๊ธฐ์ ์ ์ผ๋ก ๊ฑฐ์ ๋์ผํ ๊ตฌํ์ ๋๋ค.
- ๊ฐ GPU๋ ๋ชจ๋ธ์ 1/N ์กฐ๊ฐ๋ง ๋ฉ๋ชจ๋ฆฌ์ ์์ฃผ
- ์ฐ์ฐ์ด ํ์ํ ๋๋ง ๋ค๋ฅธ GPU๋ก๋ถํฐ ํ๋ผ๋ฏธํฐ๋ฅผ
All-Gather๋ก ๊ฐ์ ธ์ด - ์ฐ์ฐ ํ ์ฆ์ ํด์ ํ์ฌ ๋ฉ๋ชจ๋ฆฌ ํ๋ณด
All-Gather on Demand
Forward pass ์ ํน์ ๋ ์ด์ด ์ฐ์ฐ ์ง์ ์ ํด๋น ๋ ์ด์ด์ ํ๋ผ๋ฏธํฐ๋ฅผ ๋ชจ๋ GPU๋ก๋ถํฐ ์์งํฉ๋๋ค:
Shard 0
Shard 1
Shard 2
Shard 3
(์์)
๋ฉ๋ชจ๋ฆฌ ์ ๊ฐ (N-fold)
์ด ๋ฉ๋ชจ๋ฆฌ/GPU = (Parameters 2B + Gradients 2B + Optimizer 12B) / N = 16/N Bytes/param
256 GPU ์์: 70B ๋ชจ๋ธ
= 70B x 16 / 256 = 4.375 GB/GPU (๊ธฐ์กด 1.12TB์์ 256๋ฐฐ ์ ๊ฐ)
DeepSpeed ZeRO Stage 3 Config JSON
{
"zero_optimization": {
"stage": 3,
"overlap_comm": true,
"contiguous_gradients": true,
"reduce_bucket_size": 5e8,
"stage3_prefetch_bucket_size": 5e8,
"stage3_param_persistence_threshold": 1e6,
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"bf16": {
"enabled": true
},
"gradient_clipping": 1.0,
"train_batch_size": 256,
"train_micro_batch_size_per_gpu": 1,
"gradient_accumulation_steps": 32
}
5. ZeRO-Infinity: NVMe Offloading
๊ฐ๋
ZeRO-Infinity๋ ZeRO Stage 3์ NVMe SSD ์คํ๋ก๋ฉ์ ์ถ๊ฐํ์ฌ GPU ๋ฉ๋ชจ๋ฆฌ๋ฅผ ๋์ด ์์คํ ๋ฉ๋ชจ๋ฆฌ์ NVMe ์คํ ๋ฆฌ์ง๊น์ง ํ์ฉํฉ๋๋ค.
Memory Pool Hierarchy
(TB๊ธ)
(์๋ฐฑ GB)
(80 GB)
- Offload Optimizer: ์ตํฐ๋ง์ด์ ์ํ๋ฅผ CPU ๋ฉ๋ชจ๋ฆฌ๋ก ์คํ๋ก๋
- Offload Param: ํ๋ผ๋ฏธํฐ๊น์ง CPU/NVMe๋ก ์คํ๋ก๋
- NVMe Offload: CPU ๋ฉ๋ชจ๋ฆฌ๋ ๋ถ์กฑํ ๋ NVMe SSD ํ์ฉ
- ๋น๋๊ธฐ I/O(aio)๋ฅผ ํตํ prefetch๋ก ์ฑ๋ฅ ์ ํ ์ต์ํ
DeepSpeed ZeRO-Infinity Config JSON (aio ์ค์ ํฌํจ)
{
"zero_optimization": {
"stage": 3,
"offload_optimizer": {
"device": "nvme",
"nvme_path": "/local_nvme",
"pin_memory": true,
"buffer_count": 5,
"fast_init": false
},
"offload_param": {
"device": "nvme",
"nvme_path": "/local_nvme",
"pin_memory": true,
"buffer_count": 5,
"buffer_size": 1e8,
"max_in_cpu": 1e9
},
"overlap_comm": true,
"contiguous_gradients": true,
"sub_group_size": 1e9,
"reduce_bucket_size": "auto",
"stage3_prefetch_bucket_size": "auto",
"stage3_param_persistence_threshold": "auto",
"stage3_max_live_parameters": 1e9,
"stage3_max_reuse_distance": 1e9,
"stage3_gather_16bit_weights_on_model_save": true
},
"aio": {
"block_size": 1048576,
"queue_depth": 8,
"thread_count": 1,
"single_submit": false,
"overlap_events": true
},
"bf16": {
"enabled": true
},
"train_batch_size": 512,
"gradient_accumulation_steps": 64
}
6. FSDP ๋์ ์๋ฆฌ
FSDP๋?
FSDP (Fully Sharded Data Parallel)๋ PyTorch์ ZeRO Stage 3 ๋ค์ดํฐ๋ธ ๊ตฌํ์ ๋๋ค. ํ๋ผ๋ฏธํฐ, ๊ทธ๋๋์ธํธ, ์ตํฐ๋ง์ด์ ์ํ๋ฅผ ๋ชจ๋ GPU์ ๋ถ์ฐ(Shard)ํฉ๋๋ค.
Forward Pass
# FSDP Forward Pass ๋์ ๊ณผ์
# 1. At Rest: ๊ฐ GPU๋ 1/N ํ๋ผ๋ฏธํฐ๋ง ๋ณด์
GPU_0: [Shard_0] GPU_1: [Shard_1] GPU_2: [Shard_2] GPU_3: [Shard_3]
# 2. Before Forward: All-Gather๋ก ์ ์ฒด ํ๋ผ๋ฏธํฐ ์ฌ๊ตฌ์ฑ
All-Gather() โ ๋ชจ๋ GPU: [Full Parameters]
# 3. Forward Compute: ์ ์ฒด ํ๋ผ๋ฏธํฐ๋ก ์ฐ์ฐ ์ํ
output = layer(input) # with full parameters
# 4. After Forward: ์ฌ์ฉํ ํ๋ผ๋ฏธํฐ ํด์ (๋ฉ๋ชจ๋ฆฌ ํ๋ณด)
del full_parameters # keep only local shard
Backward Pass
# FSDP Backward Pass ๋์ ๊ณผ์
# 1. Before Backward: ๋ค์ All-Gather๋ก ํ๋ผ๋ฏธํฐ ์ฌ๊ตฌ์ฑ
All-Gather() โ ๋ชจ๋ GPU: [Full Parameters]
# 2. Backward Compute: ๊ทธ๋๋์ธํธ ๊ณ์ฐ
gradients = backward(loss)
# 3. After Backward: Reduce-Scatter๋ก ๊ทธ๋๋์ธํธ ํฉ์ฐ + ๋ถ๋ฐฐ
Reduce-Scatter(gradients) โ ๊ฐ GPU: [1/N Gradient Shard]
# 4. Optimizer Step: ๊ฐ GPU๊ฐ ์์ ์ ์ค๋๋ง ์
๋ฐ์ดํธ
optimizer.step(local_shard) # only 1/N of parameters
์ ์ฒด ํ๋ฆ ๋ค์ด์ด๊ทธ๋จ
Params
- Sharded State: ํ๋ผ๋ฏธํฐ๊ฐ DTensor ์กฐ๊ฐ์ผ๋ก ๋ถ์ฐ - ๋จ์ผ GPU์ ์ ์ฒด ๋ชจ๋ธ์ด ์์
- Reconstruction Required: "Full" checkpoint ์ ์ฅ ์ All-Gather ํ์
- Memory Spike: ์ ์ฒด ์ํ ์์ง ์ ์ผ์์ ์ผ๋ก ๋ฉ๋ชจ๋ฆฌ ์ฌ์ฉ๋ 2๋ฐฐ
- Coordination: ๋ชจ๋ ๋ญํฌ๊ฐ ์ฒดํฌํฌ์ธํธ ์ ์ฅ ์ ๋๊ธฐํ ํ์
7. FSDP Sharding Strategies
| Strategy | ๋์ | ๋ฉ๋ชจ๋ฆฌ | ์ฑ๋ฅ | ์ฌ์ฉ ์ฌ๋ก |
|---|---|---|---|---|
FULL_SHARD |
Forward ํ ํ๋ผ๋ฏธํฐ ํด์ | ์ต์ | ํต์ ๋ง์ | ๋ฉ๋ชจ๋ฆฌ ๊ทนํ ์ํฉ |
SHARD_GRAD_OP |
Forward ์ค ํ๋ผ๋ฏธํฐ ์ ์ง | ๋์ | ํต์ ์ ์ | ๋ฉ๋ชจ๋ฆฌ ์ฌ์ ์์ ๋ |
HYBRID_SHARD |
๋ ธ๋ ๋ด ์ค๋ฉ, ๋ ธ๋ ๊ฐ ๋ณต์ | ๊ท ํ | ์ต์ ํ๋จ | ๋ฉํฐ๋ ธ๋ ๋๊ท๋ชจ ํ์ต |
NO_SHARD |
์ค๋ฉ ์์ (DDP์ ๋์ผ) | ์ต๋ | ํต์ ์ต์ | ๋๋ฒ๊น , ์์ ๋ชจ๋ธ |
HYBRID_SHARD ์์ธ ์ค๋ช
HYBRID_SHARD๋ ๋คํธ์ํฌ ํ ํด๋ก์ง๋ฅผ ์ต์ ํํฉ๋๋ค:
- ๋ ธ๋ ๋ด๋ถ (Intra-node): NVLink๋ฅผ ํตํด ๋น ๋ฅธ All-Gather/Reduce-Scatter
- ๋ ธ๋ ๊ฐ (Inter-node): ๋ชจ๋ธ ๋ณต์ ๋ก ๋คํธ์ํฌ ํต์ ์ต์ํ
์: 8 GPU/node x 32 nodes = 256 GPU ํด๋ฌ์คํฐ์์, ๊ฐ ๋ ธ๋ ๋ด 8 GPU๋ FSDP๋ก ์ค๋ฉํ๊ณ , 32๊ฐ ๋ ธ๋ ๊ฐ์๋ DDP์ฒ๋ผ ๊ทธ๋๋์ธํธ๋ง ๋๊ธฐํํฉ๋๋ค.
8. FSDP2 vs FSDP1
์ฃผ์ ์ฐจ์ด์
| Feature | FSDP1 | FSDP2 |
|---|---|---|
| ๊ธฐ๋ฐ ๊ธฐ์ | FlatParameter | DTensor |
| API | FSDP(module) wrapper | fully_shard(module) ํจ์ |
| ์ ์ฐ์ฑ | ๋ชจ๋ ๋จ์ | ํ๋ผ๋ฏธํฐ ๋จ์ |
| ๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ | ์๋ ์กฐ์ ํ์ | ์๋ ์ต์ ํ |
| ์ถ์ฒ ๋ฒ์ | PyTorch 1.x ~ 2.3 | PyTorch 2.4+ |
DTensor๋?
DTensor (Distributed Tensor)๋ PyTorch 2.0์์ ๋์ ๋ ๋ถ์ฐ ํ ์ ์ถ์ํ์ ๋๋ค. ํ ์๊ฐ ์ฌ๋ฌ ๋๋ฐ์ด์ค์ ์ด๋ป๊ฒ ๋ถ์ฐ๋์ด ์๋์ง๋ฅผ ๋ฉํ๋ฐ์ดํฐ๋ก ๊ด๋ฆฌํฉ๋๋ค.
FSDP2 ์ฝ๋ ์์
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed.device_mesh import init_device_mesh
# Device Mesh ์ด๊ธฐํ (2D: DP x TP)
mesh = init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("dp", "tp"))
# Mixed Precision ์ ์ฑ
mp_policy = MixedPrecisionPolicy(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
)
# FSDP2: ๊ฐ๋ณ ๋ชจ๋์ fully_shard ์ ์ฉ
for layer in model.transformer.layers:
fully_shard(layer, mesh=mesh["dp"], mp_policy=mp_policy)
# ์ต์์ ๋ชจ๋์๋ ์ ์ฉ
fully_shard(model, mesh=mesh["dp"], mp_policy=mp_policy)
# ์ด์ model์ FSDP2๋ก ์ค๋ฉ๋จ
output = model(input_ids)
9. ๋ฉ๋ชจ๋ฆฌ ๊ณ์ฐ ์์
70B ๋ชจ๋ธ ์๋๋ฆฌ์ค๋ณ ๋ฉ๋ชจ๋ฆฌ
| ์๋๋ฆฌ์ค | GPU ์ | ๋ชจ๋ธ ์ํ ๋ฉ๋ชจ๋ฆฌ/GPU | ๊ฐ๋ฅ ์ฌ๋ถ |
|---|---|---|---|
| ๋จ์ผ GPU (DDP) | 1 | 70B x 16B = 1,120 GB | ๋ถ๊ฐ๋ฅ (80GB VRAM ์ด๊ณผ) |
| 8 GPU ZeRO-3 | 8 | 1,120 / 8 = 140 GB | ๋ถ๊ฐ๋ฅ |
| 32 GPU ZeRO-3 | 32 | 1,120 / 32 = 35 GB | ๊ฐ๋ฅ (+ Activation ๋ฉ๋ชจ๋ฆฌ ํ์) |
| 256 GPU ZeRO-3 | 256 | 1,120 / 256 = 4.375 GB | ์ฌ์ |
Activation Memory ๊ณต์
Forward pass ์ Backward๋ฅผ ์ํด ์ค๊ฐ ์ฐ์ฐ ๊ฒฐ๊ณผ(Activation)๋ฅผ ์ ์ฅํด์ผ ํฉ๋๋ค:
์: Llama 70B (hidden=8192, layers=80, BF16)
batch=1, seq=4096: 1 x 4096 x 8192 x 80 x 2 โ 5.4 GB
batch=4, seq=4096: 4 x 4096 x 8192 x 80 x 2 โ 21.5 GB
10. NCCL Collective Operations
NCCL์ด๋?
NCCL (NVIDIA Collective Communications Library)์ ๋ถ์ฐ GPU ํ์ต์์ GPU ๊ฐ ํต์ ์ ๋ด๋นํ๋ ๊ณ ์ฑ๋ฅ ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๋๋ค. "๋คํธ์ํฌ ์คํ"์ฒ๋ผ GPU๋ค์ด ์๋ก ๋ํํ๋ ๋ฐฉ๋ฒ์ ์ ์ํฉ๋๋ค.
์ฃผ์ Collective Operations
All-Reduce
๋ชจ๋ GPU๊ฐ ๊ธฐ์ฌํ๊ณ , ๋ชจ๋ GPU๊ฐ ํฉ์ฐ ๊ฒฐ๊ณผ๋ฅผ ๋ฐ์ต๋๋ค. DDP์์ ๊ทธ๋๋์ธํธ ๋๊ธฐํ์ ํ์.
# All-Reduce: ๋ชจ๋ GPU์ ๊ทธ๋๋์ธํธ ํฉ์ฐ โ ๋ชจ๋ GPU์ ๋์ผํ ๊ฒฐ๊ณผ
GPU 0: [1, 2, 3] GPU 0: [10, 20, 30]
GPU 1: [2, 4, 6] โ GPU 1: [10, 20, 30] (sum)
GPU 2: [3, 6, 9] GPU 2: [10, 20, 30]
GPU 3: [4, 8, 12] GPU 3: [10, 20, 30]
All-Gather
๊ฐ GPU์ ์กฐ๊ฐ์ ๋ชจ์ ์ ์ฒด ํ ์๋ฅผ ๊ตฌ์ฑ, ๋ชจ๋ GPU๊ฐ ๋์ผํ ์ ์ฒด ํ ์๋ฅผ ๋ฐ์ต๋๋ค. FSDP์์ ํ๋ผ๋ฏธํฐ ์ฌ๊ตฌ์ฑ์ ์ฌ์ฉ.
# All-Gather: ๊ฐ GPU์ ์ค๋ ์์ง โ ์ ์ฒด ํ
์ ์ฌ๊ตฌ์ฑ
GPU 0: [A] GPU 0: [A, B, C, D]
GPU 1: [B] โ GPU 1: [A, B, C, D] (concatenate)
GPU 2: [C] GPU 2: [A, B, C, D]
GPU 3: [D] GPU 3: [A, B, C, D]
Reduce-Scatter
๋ชจ๋ GPU๊ฐ ๊ธฐ์ฌํ๊ณ , ๊ฒฐ๊ณผ๋ฅผ N๋ฑ๋ถํ์ฌ ๊ฐ GPU๊ฐ ๋ค๋ฅธ ์กฐ๊ฐ์ ๋ฐ์ต๋๋ค. FSDP Backward์์ ๊ทธ๋๋์ธํธ ๋ถ๋ฐฐ์ ์ฌ์ฉ.
# Reduce-Scatter: ํฉ์ฐ + ๋ถ๋ฐฐ
GPU 0: [1,2,3,4] GPU 0: [10] (chunk 0์ ํฉ)
GPU 1: [2,4,6,8] โ GPU 1: [20] (chunk 1์ ํฉ)
GPU 2: [3,6,9,12] GPU 2: [30] (chunk 2์ ํฉ)
GPU 3: [4,8,12,16] GPU 3: [40] (chunk 3์ ํฉ)
Broadcast
ํ๋์ GPU(root)๊ฐ ๋ฐ์ดํฐ๋ฅผ ๋ณด๋ด๊ณ , ๋ชจ๋ GPU๊ฐ ๋ฐ์ต๋๋ค. ์ด๊ธฐ ๊ฐ์ค์น ๋ถ๋ฐฐ์ ์ฌ์ฉ.
# Broadcast: Root GPU์ ๋ฐ์ดํฐ๋ฅผ ๋ชจ๋ GPU๋ก ๋ณต์ฌ
GPU 0: [W] (root) GPU 0: [W]
GPU 1: [?] โ GPU 1: [W] (copied from root)
GPU 2: [?] GPU 2: [W]
GPU 3: [?] GPU 3: [W]
All-to-All
๊ฐ GPU๊ฐ ๋ค๋ฅธ ๋ชจ๋ GPU์๊ฒ ์๋ก ๋ค๋ฅธ ๋ฐ์ดํฐ๋ฅผ ๋ณด๋ ๋๋ค. MoE์์ Expert Parallelism์ ์ฌ์ฉ.
# All-to-All: ๊ฐ GPU๊ฐ ๋ค๋ฅธ GPU๋ค์๊ฒ ๊ฐ๊ฐ ๋ค๋ฅธ ๋ฐ์ดํฐ ์ ์ก
GPU 0: [A0,A1,A2,A3] GPU 0: [A0,B0,C0,D0]
GPU 1: [B0,B1,B2,B3] โ GPU 1: [A1,B1,C1,D1]
GPU 2: [C0,C1,C2,C3] GPU 2: [A2,B2,C2,D2]
GPU 3: [D0,D1,D2,D3] GPU 3: [A3,B3,C3,D3]
11. Ring vs Tree Algorithm
Ring Algorithm
GPU๋ค์ ๋ ผ๋ฆฌ์ ์ธ ๋ง(Ring) ํํ๋ก ์ฐ๊ฒฐํ์ฌ ๋ฐ์ดํฐ๋ฅผ ์ํ์ํต๋๋ค.
# Ring All-Reduce ๋์ (4 GPU ์์)
GPU 0 โโ GPU 1
โ โ
GPU 3 โโ GPU 2
# ๋จ๊ณ 1: ๊ฐ GPU๊ฐ ์ด์์๊ฒ ์ฒญํฌ ์ ์ก
# ๋จ๊ณ 2: ๋ฐ์ ๋ฐ์ดํฐ์ ๋ก์ปฌ ๋ฐ์ดํฐ ํฉ์ฐ
# ๋จ๊ณ 3: N-1๋ฒ ๋ฐ๋ณตํ๋ฉด ๋ชจ๋ GPU๊ฐ ์ ์ฒด ํฉ์ฐ ๊ฒฐ๊ณผ ๋ณด์
# ๋ณต์ก๋
๋์ญํญ: O(๋ฐ์ดํฐ ํฌ๊ธฐ) # ๋ฐ์ดํฐ ํฌ๊ธฐ์ ๋น๋ก, GPU ์ ๋ฌด๊ด
์ง์ฐ์๊ฐ: O(N) # GPU ์์ ๋น๋ก (๋จ์ )
Tree Algorithm (Double Binary Tree)
GPU๋ค์ ํธ๋ฆฌ(Tree) ๊ตฌ์กฐ๋ก ์ฐ๊ฒฐํ์ฌ ๊ณ์ธต์ ์ผ๋ก ๋ฐ์ดํฐ๋ฅผ ์ง๊ณํฉ๋๋ค.
# Tree All-Reduce ๋์
Root (GPU 0)
/ \
GPU 1 GPU 2
/ \ / \
GPU 3 GPU 4 GPU 5 GPU 6
# Reduce Phase: ๋ฆฌํ โ ๋ฃจํธ (ํฉ์ฐ)
# Broadcast Phase: ๋ฃจํธ โ ๋ฆฌํ (๋ถ๋ฐฐ)
# ๋ณต์ก๋
๋์ญํญ: O(๋ฐ์ดํฐ ํฌ๊ธฐ) # Ring๊ณผ ๋์ผ
์ง์ฐ์๊ฐ: O(log N) # GPU ์์ ๋ก๊ทธ์ ๋น๋ก (์ฅ์ )
๋น๊ตํ
| ํน์ฑ | Ring Algorithm | Tree Algorithm |
|---|---|---|
| ๋์ญํญ ํจ์จ | ์ต์ | ์ต์ |
| ์ง์ฐ์๊ฐ | O(N) - ๋์ | O(log N) - ๋ฎ์ |
| ์๊ท๋ชจ ํด๋ฌ์คํฐ | ์ฐ์ | ๋ณดํต |
| ๋๊ท๋ชจ ํด๋ฌ์คํฐ | ์ง์ฐ ์ฆ๊ฐ | ์ฐ์ |
12. NCCL ์ด๊ธฐํ
TCPStore Rendezvous ๊ณผ์
๋ถ์ฐ ํ์ต์ ์์ํ๋ ค๋ฉด ์๋ฐฑ~์์ฒ ๊ฐ์ ํ๋ก์ธ์ค(GPU)๊ฐ ์๋ก์ ์กด์ฌ์ ์์น๋ฅผ ์์์ผ ํฉ๋๋ค.
# ์ ํต์ ์ธ NCCL ์ด๊ธฐํ ๊ณผ์
# 1. Master Node (Rank 0)๊ฐ TCPStore ์๋ฒ ์์
Rank 0: TCPStore ์๋ฒ ์คํ (IP:PORT)
# 2. ๋ชจ๋ Worker๊ฐ Master์ ์ฐ๊ฒฐํ์ฌ ์์ ์ ์ ๋ณด ๋ฑ๋ก
Rank 1 โ Rank 0: "๋ด ์ฃผ์๋ 192.168.1.2:29501"
Rank 2 โ Rank 0: "๋ด ์ฃผ์๋ 192.168.1.3:29501"
...
Rank N โ Rank 0: "๋ด ์ฃผ์๋ ..."
# 3. ๋ชจ๋ Worker๊ฐ ๋ฑ๋ก ์๋ฃ๋๋ฉด NCCL Unique ID ์์ฑ
Rank 0: NCCL Unique ID ์์ฑ ๋ฐ ๋ธ๋ก๋์บ์คํธ
# 4. ๊ฐ Rank๊ฐ Communicator ํ์ฑ
ncclCommInitRank(comm, nranks, uniqueId, rank)
init_process_group ์ฝ๋
import torch.distributed as dist
import os
# ํ๊ฒฝ๋ณ์์์ ๋ถ์ฐ ์ค์ ์ฝ๊ธฐ
rank = int(os.environ['RANK'])
world_size = int(os.environ['WORLD_SIZE'])
master_addr = os.environ['MASTER_ADDR']
master_port = os.environ['MASTER_PORT']
# Process Group ์ด๊ธฐํ (TCPStore ๊ธฐ๋ฐ rendezvous)
dist.init_process_group(
backend='nccl', # GPU ํต์ ์ฉ
init_method=f'tcp://{master_addr}:{master_port}',
rank=rank,
world_size=world_size,
)
# ์ด์ collective operations ์ฌ์ฉ ๊ฐ๋ฅ
tensor = torch.ones(10).cuda()
dist.all_reduce(tensor) # ๋ชจ๋ GPU์ tensor ํฉ์ฐ
13. Topology Discovery
NCCL์ ์๋ ํ ํด๋ก์ง ํ์
NCCL์ ์์ ์ hwloc ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ์ฌ ํ๋์จ์ด ๊ตฌ์ฑ์ ์๋์ผ๋ก ํ์ํฉ๋๋ค.
์ฐ๊ฒฐ ๊ณ์ธต (P2P Level)
| ์ฐ๊ฒฐ ํ์ | ๋์ญํญ | ์ง์ฐ์๊ฐ | ์ฌ์ฉ ์์น |
|---|---|---|---|
| NVLink | 600-900 GB/s | ๋งค์ฐ ๋ฎ์ | ๋ ธ๋ ๋ด GPU ๊ฐ |
| NVSwitch | 7.2 TB/s (์ดํฉ) | ๋งค์ฐ ๋ฎ์ | ๋ ธ๋ ๋ด All-to-All |
| PCIe Gen5 | 64 GB/s | ๋ฎ์ | GPU-CPU, NVLink ์์ ๋ |
| InfiniBand HDR | 200 Gbps | ~1 us | ๋ ธ๋ ๊ฐ |
| InfiniBand NDR | 400 Gbps | ~1 us | ๋ ธ๋ ๊ฐ (์ต์ ) |
| EFA (AWS) | 3200 Gbps | ๋ฎ์ | AWS ๋ ธ๋ ๊ฐ |
NCCL_TOPO_FILE
๋ณต์กํ ํ ํด๋ก์ง์์๋ XML ํ์ผ๋ก ์ง์ ํ ํด๋ก์ง๋ฅผ ์ง์ ํ ์ ์์ต๋๋ค:
# ํ ํด๋ก์ง ํ์ผ ์ง์
export NCCL_TOPO_FILE=/path/to/topology.xml
# ํ ํด๋ก์ง ํ์ ๊ฒฐ๊ณผ ๋คํ
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.txt
14. NCCL ํ๊ฒฝ๋ณ์
์ค์ ํ๊ฒฝ๋ณ์ ํ ์ด๋ธ
| ํ๊ฒฝ๋ณ์ | ์ค๋ช | ๊ธฐ๋ณธ๊ฐ | ์ถ์ฒ๊ฐ |
|---|---|---|---|
NCCL_DEBUG |
๋๋ฒ๊ทธ ๋ก๊ทธ ๋ ๋ฒจ | WARN | INFO (๋ฌธ์ ํด๊ฒฐ ์ TRACE) |
NCCL_DEBUG_SUBSYS |
ํน์ ์๋ธ์์คํ ๋ง ๋ก๊น | ALL | INIT,COLL (์ด๊ธฐํ/์ง๋จํต์ ) |
NCCL_ALGO |
์๊ณ ๋ฆฌ์ฆ ๊ฐ์ ์ง์ | ์๋ | Ring, Tree, CollnetDirect |
NCCL_PROTO |
ํ๋กํ ์ฝ ์ง์ | ์๋ | Simple, LL, LL128 |
NCCL_BUFFSIZE |
ํต์ ๋ฒํผ ํฌ๊ธฐ | 4MB | 8388608 (8MB, ๋๊ท๋ชจ ์) |
NCCL_NTHREADS |
์ปค๋ ์ค๋ ๋ ์ | ์๋ | 512 (๋๊ท๋ชจ ์) |
NCCL_IB_TIMEOUT |
InfiniBand ํ์์์ | 18 | 22-23 (๋๊ท๋ชจ ํด๋ฌ์คํฐ) |
NCCL_IB_RETRY_CNT |
IB ์ฌ์๋ ํ์ | 7 | 13 (์์ ์ฑ ํฅ์) |
NCCL_IB_GID_INDEX |
IB GID ์ธ๋ฑ์ค | 0 | RoCE v2: 3 |
NCCL_SOCKET_IFNAME |
๋คํธ์ํฌ ์ธํฐํ์ด์ค | ์๋ | eth0, ens5 (AWS EFA) |
NCCL_P2P_LEVEL |
P2P ํต์ ์ ํ | 5 | NVL (NVLink๋ง ํ์ฉ) |
NCCL_SHM_DISABLE |
๊ณต์ ๋ฉ๋ชจ๋ฆฌ ๋นํ์ฑํ | 0 | 1 (๋๋ฒ๊น ์) |
AWS EFA ์ต์ ํ ํ๊ฒฝ๋ณ์
# AWS EFA (Elastic Fabric Adapter) ์ต์ ํ ์ค์
# ๊ธฐ๋ณธ EFA ์ค์
export FI_PROVIDER=efa
export FI_EFA_USE_DEVICE_RDMA=1
export FI_EFA_FORK_SAFE=1
# NCCL EFA ํ๋ฌ๊ทธ์ธ
export NCCL_NET=aws-ofi-nccl
export NCCL_DEBUG=INFO
# P5 ์ธ์คํด์ค (H100 x8) ์ต์ ํ
export NCCL_NVLS_ENABLE=1 # NVLink SHARP
export NCCL_IB_TIMEOUT=22
export NCCL_MIN_NCHANNELS=4
# ๋์ญํญ ์ต์ ํ
export NCCL_BUFFSIZE=8388608 # 8MB
export NCCL_P2P_NET_CHUNKSIZE=524288 # 512KB
NCCL ๋๋ฒ๊น ํ
# ๋ฌธ์ ๋ฐ์ ์ ์์ธ ๋ก๊น
export NCCL_DEBUG=TRACE
export NCCL_DEBUG_SUBSYS=INIT,COLL,P2P,NET
export NCCL_DEBUG_FILE=/tmp/nccl_debug_%h_%p.log
# ํ ํด๋ก์ง ํ์ธ
export NCCL_TOPO_DUMP_FILE=/tmp/nccl_topo.xml
# Hang ๊ฐ์ง (30์ด ํ์์์)
export NCCL_TIMEOUT=30
export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
# ๋๋ฒ๊ทธ ์ ๋ณด ์ถ๋ ฅ ์์
[Rank 0] NCCL INFO Bootstrap: Using eth0:192.168.1.10<6379>
[Rank 0] NCCL INFO Trees [0] -1/-1/-1->0->1 [1] -1/-1/-1->0->1
[Rank 0] NCCL INFO Channel 00 : 0 1 2 3
์์ฝ
- ZeRO: Stage 1(Optimizer) โ Stage 2(+Gradients) โ Stage 3(+Parameters) ์์ผ๋ก ๋ฉ๋ชจ๋ฆฌ ํจ์จ ๊ทน๋ํ
- FSDP: PyTorch์ ZeRO-3 ๋ค์ดํฐ๋ธ ๊ตฌํ, All-Gather/Reduce-Scatter๋ก ๋์
- FSDP2: DTensor ๊ธฐ๋ฐ,
fully_shard()API๋ก ๋ ์ ์ฐํ ์ค๋ฉ - NCCL: All-Reduce, All-Gather, Reduce-Scatter ๋ฑ ์ง๋จ ํต์ ๋ด๋น
- Ring vs Tree: ์๊ท๋ชจ๋ Ring, ๋๊ท๋ชจ๋ Tree๊ฐ ์ ๋ฆฌ, NCCL์ด ์๋ ์ ํ
- TCPStore: ๊ธฐ์กด NCCL ์ด๊ธฐํ์ ๋ณ๋ชฉ์ โ Checkpointless์์ Rootless๋ก ํด๊ฒฐ