Distributed Training
Horovod, TensorFlow Distributed Training, PyTorch Distributed Training
Distributed Training Goals
๋ถ์ฐํ์ต ํ๋ ์์ํฌ๋ ๊ธฐ์กด์ ML๋ชจ๋ธ ํ์ต ์ฝ๋๋ฅผ ์กฐ๊ธ๋ง ๋ฐ๊พธ์ด๋ ๋ถ์ฐํ์ต์ ์ง์ํ๋ ๊ฒ์ ๋ชฉํ๋ก ํ๋ค. ๋๊ท๋ชจ ML๋ชจ๋ธ๋ ํ์ต ๊ฐ๋ฅํด์ผ ํ๋ฉฐ(Model Parallelism) ML๋ชจ๋ธ ํ์ต ์๊ฐ๋ ์ค์ผ ์ ์์ด์ผ ํ๋ค(Data Parallelism).
Model Parallelism

Data Parallelism

Distributed Training Mechanism
Synchronous training
๋๊ธฐ์ ํ์ต์ ์ ์ฒด ๋ฐ์ดํฐ๋ฅผ ์ชผ๊ฐ์ Worker๋ค์ด Gradient๋ฅผ ๊ณ์ฐํ๊ณ , ๊ณ์ฐํ Gradient๋ฅผ ์ง๊ณํ ํ ์๋ก์ด ML๋ชจ๋ธ์ ์์ฑํ๋ค. ์์ฑํ ML๋ชจ๋ธ์ Worker์ ์ ์กํ๊ณ Gradient๋ฅผ ๊ณ์ฐํ๋ ๊ณผ์ ์ ๋ฐ๋ณตํ์ฌ ํ์ต์ ์ํํ๋ค.
Asynchronous training
๋น๋์ ํ์ต์ Worker๋ค์ด ๊ฐ๊ฐ ์ ์ฒด ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํด Gradient๋ฅผ ๊ณ์ฐํ๊ณ , ๊ฐ๊ฐ ๋น๋๊ธฐ์ ์ผ๋ก Gradient๋ฅผ ์ ๋ฐ์ดํธํ๋ ๋ฐฉ์์ด๋ค. ์ผ๋ฐ์ ์ผ๋ก ๋๊ธฐ์ ํ์ต์ all-reduce ๋ฐฉ์์ผ๋ก ๊ตฌํํ๊ณ , ๋น๋๊ธฐ์ ํ์ต์ Parameter Server ๋ฐฉ์์ ์ฌ์ฉํ๋ค.

Parameter Server Training
Worker๊ฐ ๋ฐ์ดํฐ๋ฅผ ํ์ตํ์ฌ Gradient๋ฅผ ๊ณ์ฐํ ํ, Parameter Server๋ก ์ ์กํ๊ณ ํ๊ท ์ ๊ณ์ฐํด ๋ค์ Worker๋ก ์ ์กํ๋ ๋ฐฉ์์ด๋ค. Worker๋ ๋์ญํญ ์ ์ฒด๋ฅผ ์ฌ์ฉํ์ง ์์ง๋ง, Parameter Server๋ ๋์ญํญ ๋ณ๋ชฉํ์์ด ๋ฐ์ํ ์ ์๋ค. Parameter Server๋ฅผ ์ฌ๋ฌ ๊ฐ ๋ ๊ฒฝ์ฐ, ๋คํธ์ํฌ ์ํธ ์ฐ๊ฒฐ์ด ๋ณต์กํด ์ง ์๋ ์๋ค.

All Reduce-based Distributed Training
Worker๊ฐ ์๋ก Gradient๋ฅผ ์ฃผ๊ณ ๋ฐ์ผ๋ฉด์ Reducingํ๋ ๋ฐฉ์์ผ๋ก ๋์ํ๋ค. Local ReduceScatter, Remote AllReduce, Local Gather ์์ผ๋ก Reducing ์ ์งํํ๋ค.


TensorFlow Distributed Training
TensorFlow ํด๋ฌ์คํฐ ๋ด์์ ๊ฐ์ง ์ ์๋ ์ญํ ์ Chief, PS, Worker, Evaluator ์ค ํ๋์ด๋ฉฐ, PS ์ญํ ์ Parameter Server Training์์๋ง ์ฌ์ฉํ๋ค. Chief๋ ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ์ ์์ ์ ์ํํ๋ค. PS๋ ๋ชจ๋ธ ํ๋ผ๋ฏธํฐ ์๋ฒ ์ญํ ์ ์ํํ๋ค. Worker๋ Gradient ๊ตฌํ๋ ์ญํ ์ ํ๋ฉฐ, Chief ์ค์ ์ ํ์ง ์์๋ค๋ฉด 0๋ฒ Worker๊ฐ Chief๊ฐ ๋๋ค. Evaluator๋ ํ๊ฐ ์งํ ๊ณ์ฐํ๋ ์ญํ ์ ํ๋ค.

TensorFlow ๋ถ์ฐํ๊ฒฝ ํด๋ฌ์คํฐ ๊ตฌ์ฑ TensorFlow ๋ tf.distribute.Strategy
ํจํค์ง๋ฅผ ์ฌ์ฉํ์ฌ ๋ถ์ฐํ์ต์ ์ํํ ์ ์๋ค. ๋ถ์ฐํ์ต ์ ์ํํ ํด๋ฌ์คํฐ ํ๊ฒฝ ๊ตฌ์ฑ์ tf.train.ClusterSpec
์ ์ง์ ์ค์ ํ๊ฑฐ๋, TF_CONFIG
ํ๊ฒฝ ๋ณ์๋ฅผ์ ์ค์ ํ ์ ์๋ค.
tf.train.ClusterSpec
tf.train.ClusterSpec
cluster = tf.train.ClusterSpec({"worker": ["worker0.example.com:2222",
"worker1.example.com:2222",
"worker2.example.com:2222"],
"ps": ["ps0.example.com:2222",
"ps1.example.com:2222"]})
TF_CONFIG
Environment variable
TF_CONFIG
Environment variableTF_CONFIG
ํ๊ฒฝ๋ณ์๋ JSON ํฌ๋งท์ผ๋ก , Cluster๋ฅผ ๊ตฌ์ฑํ Host์ ์ญํ ์ ์ง์ ํ ์ ์๋ค.
tf.distribute.Strategy
๋ถ์ฐ ํจํค์ง
tf.distribute.Strategy
๋ถ์ฐ ํจํค์ง tf.distribute.Strategy
๋ ๊ธฐ์กด ML๋ชจ๋ธ ํ์ต ์ฝ๋๋ฅผ ์กฐ๊ธ๋ง ์์ ํด๋ ๋ถ์ฐ ํ์ต์ด ๊ฐ๋ฅํ๋ฉฐ, Multi-GPU ์ธ์ง, Multi-Node ์ธ์ง์ ๋ฐ๋ผ ๋ถ์ฐํ์ต ๋ฐฉ๋ฒ์ด ๊ตฌ๋ถ๋๋ค.
MirroredStrategy
์ฅ๋น ํ๋์์ ๋ค์คGPU(Multi-GPU)๋ฅผ ์ด์ฉํ ๋๊ธฐ์ ๋ถ์ฐํ์ต ๋ฐฉ๋ฒ์ด๋ค.
MultiWorkerMirroredStrategy
Multi-Worker ๋ฅผ ์ด์ฉํ ๋๊ธฐ์ ๋ถ์ฐํ์ต์ผ๋ก ๊ฐ Worker๋ Multi-GPU๋ฅผ ์ฌ์ฉํ ์ ์๋ค. Multi-Worker๋ค ์ฌ์ด์๋ All Reduce๋ฐฉ์์ ์ฌ์ฉํ๋ค.
ParameterServerStrategy
Parameter Server ๋ฐฉ์์ ๋น๋๊ธฐ์ ๋ถ์ฐํ์ต์ ์ ๊ณตํ๋ค.
PyTorch Distributed Training
PyTorch๋ torch.distributed
ํจํค์ง์์ ๋ถ์ฐํ์ต์ ์ ๊ณตํ๋ฉฐ, Parameter Server ๋ถ์ฐ ๋ฐฉ์์ผ๋ก DataParallel
๊ณผ All Reduce ๋ถ์ฐํ์ต ๋ฐฉ์์ผ๋ก DistributedDataParallel
์ ์ง์ํ๋ค.
DataParallel (DP)
DP๋ Master-Worker ๊ตฌ์กฐ๋ก Master๋ Cluster Coordinator์ญํ ์ ์ํํ๊ณ Worker๋ ๋ฐ์ดํฐ๋ฅผ ํ์ตํ๋ค.
Master๋ ๋ชจ๋ธ Weight๋ฅผ ๋ณต์ ํ๊ณ Worker์ Broadcastํ๋ค. Master๋ ํ์ต๋ฐ์ดํฐ๋ฅผ ์ชผ๊ฐ์ ๊ฐ๊ฐ Worker์ ์ ์กํ๋ค. Worker์์ Gradient ์ ๊ณ์ฐํ ํ, ๊ฐ GPU์์ Local Gradient๋ฅผ ์์งํ๋ค. Master๋ ์์งํ Gradient๋ฅผ ์ง๊ณํ๊ณ ๋ชจ๋ธ ์ ๋ฐ์ดํธ๋ฅผ ์ํํ๋ค.

DistributedDataParallel (DDP)
DDP๋ All Reduce ๋ฐฉ์์ผ๋ก Worker๋ง์ผ๋ก ๋ถ์ฐํ์ต์ ์ํํ๋ค.

PyTorch DDP Example
# ํ๋ก์ธ์ค๊ฐ ์๋ก ํต์ ํ ์ ์๋๋ก ์ด๊ธฐํ๋ฅผ ์ํํ๋ค.
torch.distributed.init_process_group(backend='mpi')
# ์ด๊ธฐํ๊ฐ ์๋ฃ๋๋ฉด, ๊ฐ ํ๋ก์ธ์ค์ GPU ์ฅ์น๋ฅผ ๋งคํํ๋ค.
local_rank = int(os.environ['LOCAL_RANK'])
device = torch.device("cuda:{}".format(local_rank))
torch.cuda.set_device(local_rank)
# ๋ฐ์ดํฐ์
์ ๋ถ์ฐํ๊ธฐ ์ํด DistributedSampler๋ฅผ ์์ฑํ๋ค.
torch.utils.data.distributed.DistributedSampler(trainset)
# ํ์ต ํน์ ํ๊ฐ์์ ์ฌ์ฉํ ๋ฐ์ดํฐ๋ฅผ Dataloaders์์ ๊ฐ์ ธ์จ๋ค.
train_loader = torch.utils.data.DataLoader(trainset,
batch_size=batch_size,
shuffle=(train_sample is None),
num_workers=workers,
pin_memory=False,
sampler=train_sampler)
# ๋ถ์ฐํ์ต ๋ฐฉ์์ ์ ์ํ๊ณ ๋ชจ๋ธ์ GPU ์ฅ์น์ ๋งคํํ๋ค.
model = Net().to(device)
Distributor = nn.parallel.DistributedDataParallel
model = Distributor(model)
# GPU ์ฅ์น์ ๋ฐ์ดํฐ๋ฅผ ๋งคํํ๋ค.
for data, target in train_loader:
data, target = data.to(device), target.to(device)
PyTorch
Environment variable
PyTorch
Environment variableWORLD_SIZE
๋ ํด๋ฌ์คํฐ ์ด ๋
ธ๋ ์์ด๊ณ ,
RANK๋ ๊ฐ ๋
ธ๋์ ๊ณ ์ ์๋ณ์์ด๋ค. RANK 0 ~WORLD_SIZE โ 1๊น์ง ์ธ๋ฑ์ค๋ฅผ ์ฌ์ฉํ๋ฉฐ, ๊ฐ ๋
ธ๋์ ์ธ๋ฑ์ค๋ฅผ ๋ถ์ฌํ๋ค. RANK๊ฐ 0์ด๋ฉด Master์ด๋ค. MASTER_ADDR
, MASTER_PORT
๋ Master ์ ๋ณด๋ก ๋ชจ๋ ๋
ธ๋์ ์ค์ ํ๋ค.
Horovod
Horovod๋ Tensorflow, Keras, PyTorch, MXNet ์์ MPI๊ธฐ๋ฐ์ผ๋ก Multi-GPU๋ฅผ ํ์ฉํ์ฌ Distributed Training์ ์ง์ํ๋ ํ๋ ์์ํฌ์ด๊ณ , Parameter Server ๋ถ์ฐํ์ต ๋ฐฉ์์ ๋คํธ์ํฌ ๋์ญํญ ๋ณ๋ชฉ ํ์์ ๊ฐ์ ํ๊ธฐ ์ํด ๊ณ ์๋ All Reduce ๋ถ์ฐํ์ต ๋ฐฉ์์ ์ฌ์ฉํ๋ค.
Horovod๋ฅผ ํ์ฉํ๋ฉด, ๊ธฐ์กด ํ์ต์ฝ๋์ ์ ์ ์ฝ๋๋ฅผ ์ถ๊ฐํ์ฌ ์ ์ฝ๊ฒ Distributed Training์ ๊ตฌํํ ์ ์๋ค.
Horovod์ ์ฌ์ฉํ๋ ์ฉ์ด
Horovod๋ฅผ ์ดํดํ๊ธฐ ์ํด์๋ ๋ค์๊ณผ ๊ฐ์ MPI์ ๊ด๋ จ๋ ๋ช๊ฐ์ง ์ฉ์ด๋ฅผ ๋จผ์ ์์์ผ ํ๋ค. MPI (Message Passing Interface)๋ ๋ณ๋ ฌํ ๊ธฐ์ ๋ก ์ฌ๋ฌ๋์ CPU/GPU๋ฅผ ๋ณ๋ ฌํํ ๋ ์ฌ์ฉํ๋ค.
size ์ ์ฒด ํ๋ก์ธ์ค ๊ฐ์
slots ํ๋ก์ธ์ค ๊ฐ์ (processing unit์ผ๋ก ๋ณดํต Worker๋น Process ๊ฐ์๋ฅผ ์ ์ํจ)
rank ํด๋ฌ์คํฐ์์ 0 ~ size-1
์ฌ์ด์ ๊ณ ์ ํ ํ๋ก์ธ์ค ID
local_rank ํ๋์ ํธ์คํธ์์ ๊ณ ์ ํ ํ๋ก์ธ์ค ID
๋ค์์ ๊ฐ๊ฐ GPU์ฅ์น๊ฐ 2๊ฐ์ฉ ์ฅ์ฐฉ๋ ํธ์คํธ 2๋๋ฅผ ๊ฐ์ง ์์คํ ํ๊ฒฝ์์ GPU ์ฅ์น๋ฅผ ์ด๋ป๊ฒ ๊ตฌ๋ณํ๋ ๋ณด์ฌ์ค๋ค.

Horovod ๊ธฐ๋ฐ ๋ถ์ฐํ์ต ์์
hvd.init()
์คํํ๋ค.
hvd.local_rank()
๋ณ๋ก ์ฌ์ฉํ GPU๋ฅผ ์ง์ ํ๋ค.
Learning rate๋ฅผ Worker ๊ฐ์์ ๋ฐ๋ผ ์กฐ์ ํ๋ค.
๊ธฐ์กด Optimizer๋ฅผ Horovod Optimzer ๋ก ํ์ฅํ๋ค.
rank 0 ์ด๊ธฐ ์ํ๋ฅผ ๋ค๋ฅธ rank์๊ณผ ๋๊ธฐํํ๊ธฐ ์ hook๋ฅผ ์ค์ ํ๋ค.
rank 0 ๋ง ์ฒดํฌํฌ์ธํธ ํ๋๋ก ์ค์ ํ๋ค.
import tensorflow as tf
import horovod.tensorflow as hvd
# Initialize Horovod
# ํธ๋ก๋ณด๋๋ฅผ ์ด๊ธฐํ ํ๋ค.
hvd.init()
# ์ฌ์ฉํ GPU๋ฅผ ๋งคํํ๋ค.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = str(hvd.local_rank())
# ๋ชจ๋ธ ์ฝ๋๋ฅผ ์์ฑํ๋ค.
loss = ...
opt = tf.train.AdamOptimizer(lr=0.01 * hvd.size())
# ๊ธฐ์กด ๋ชจ๋ธ ์ตํฐ๋ง์ด์ ๋ฅผ ํธ๋ก๋ณด๋ ๋ถ์ฐ ์ตํฐ๋ง์ด์ ๋ก ํ์ฅํ๋ค.
opt = hvd.DistributedOptimizer(opt)
# rank 0 ์ด๊ธฐ ์ํ๋ฅผ ๋ค๋ฅธ rank์ ๋๊ธฐํํ๊ธฐ ์ํด hook๋ฅผ ์ค์ ํ๋ค.
hooks = [hvd.BroadcastGlobalVariablesHook(0)]
# rank 0 ๋ฅผ ์ฒดํฌํฌ์ธํธ ํ๋ค.
ckpt_dir = "/tmp/train_logs" if hvd.rank() == 0 else None
Last updated
Was this helpful?