提交 a21bd068 authored 作者: Glenn Jocher

Update train.py forward simplification

上级 455f7b8f
...@@ -265,18 +265,12 @@ def train(hyp, opt, device, tb_writer=None): ...@@ -265,18 +265,12 @@ def train(hyp, opt, device, tb_writer=None):
ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple) ns = [math.ceil(x * sf / gs) * gs for x in imgs.shape[2:]] # new shape (stretched to gs-multiple)
imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False) imgs = F.interpolate(imgs, size=ns, mode='bilinear', align_corners=False)
# Autocast # Forward
with amp.autocast(enabled=cuda): with amp.autocast(enabled=cuda):
# Forward pred = model(imgs) # forward
pred = model(imgs) loss, loss_items = compute_loss(pred, targets.to(device), model) # loss scaled by batch_size
# Loss
loss, loss_items = compute_loss(pred, targets.to(device), model) # scaled by batch_size
if rank != -1: if rank != -1:
loss *= opt.world_size # gradient averaged between devices in DDP mode loss *= opt.world_size # gradient averaged between devices in DDP mode
# if not torch.isfinite(loss):
# logger.info('WARNING: non-finite loss, ending training ', loss_items)
# return results
# Backward # Backward
scaler.scale(loss).backward() scaler.scale(loss).backward()
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论