Unverified 提交 4e841b9b authored 作者: imyhxy's avatar imyhxy 提交者: GitHub

Reuse `de_parallel()` rather than `is_parallel()` (#6354)

上级 9708cf56
...@@ -7,7 +7,7 @@ import torch ...@@ -7,7 +7,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from utils.metrics import bbox_iou from utils.metrics import bbox_iou
from utils.torch_utils import is_parallel from utils.torch_utils import de_parallel
def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441 def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
...@@ -107,7 +107,7 @@ class ComputeLoss: ...@@ -107,7 +107,7 @@ class ComputeLoss:
if g > 0: if g > 0:
BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g) BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
det = model.module.model[-1] if is_parallel(model) else model.model[-1] # Detect() module det = de_parallel(model).model[-1] # Detect() module
self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7 self.balance = {3: [4.0, 1.0, 0.4]}.get(det.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7
self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index self.ssi = list(det.stride).index(16) if autobalance else 0 # stride 16 index
self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
......
...@@ -295,7 +295,7 @@ class ModelEMA: ...@@ -295,7 +295,7 @@ class ModelEMA:
def __init__(self, model, decay=0.9999, updates=0): def __init__(self, model, decay=0.9999, updates=0):
# Create EMA # Create EMA
self.ema = deepcopy(model.module if is_parallel(model) else model).eval() # FP32 EMA self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
# if next(model.parameters()).device.type != 'cpu': # if next(model.parameters()).device.type != 'cpu':
# self.ema.half() # FP16 EMA # self.ema.half() # FP16 EMA
self.updates = updates # number of EMA updates self.updates = updates # number of EMA updates
...@@ -309,7 +309,7 @@ class ModelEMA: ...@@ -309,7 +309,7 @@ class ModelEMA:
self.updates += 1 self.updates += 1
d = self.decay(self.updates) d = self.decay(self.updates)
msd = model.module.state_dict() if is_parallel(model) else model.state_dict() # model state_dict msd = de_parallel(model).state_dict() # model state_dict
for k, v in self.ema.state_dict().items(): for k, v in self.ema.state_dict().items():
if v.dtype.is_floating_point: if v.dtype.is_floating_point:
v *= d v *= d
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论