Unverified 提交 496ec33a authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Update detect.py

Added some recent updates that were missing, and updated the filename with an if else.
上级 68f63616
...@@ -46,7 +46,7 @@ def detect(save_img=False): ...@@ -46,7 +46,7 @@ def detect(save_img=False):
dataset = LoadImages(source, img_size=imgsz) dataset = LoadImages(source, img_size=imgsz)
# Get names and colors # Get names and colors
names = model.names if hasattr(model, 'names') else model.modules.names names = model.module.names if hasattr(model, 'module') else model.names
colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))] colors = [[random.randint(0, 255) for _ in range(3)] for _ in range(len(names))]
# Run inference # Run inference
...@@ -80,6 +80,7 @@ def detect(save_img=False): ...@@ -80,6 +80,7 @@ def detect(save_img=False):
p, s, im0 = path, '', im0s p, s, im0 = path, '', im0s
save_path = str(Path(out) / Path(p).name) save_path = str(Path(out) / Path(p).name)
txt_path = save_path[:save_path.rfind('.')] + ('_%g' % dataset.frame if dataset.mode == 'video' else '')
s += '%gx%g ' % img.shape[2:] # print string s += '%gx%g ' % img.shape[2:] # print string
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] #  normalization gain whwh gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] #  normalization gain whwh
if det is not None and len(det): if det is not None and len(det):
...@@ -95,12 +96,8 @@ def detect(save_img=False): ...@@ -95,12 +96,8 @@ def detect(save_img=False):
for *xyxy, conf, cls in det: for *xyxy, conf, cls in det:
if save_txt: # Write to file if save_txt: # Write to file
xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh xywh = (xyxy2xywh(torch.tensor(xyxy).view(1, 4)) / gn).view(-1).tolist() # normalized xywh
if dataset.frame == 0: with open(txt_path + '.txt', 'a') as f:
with open(save_path[:save_path.rfind('.')] + '.txt', 'a') as f: f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format
f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format
else:
with open(save_path[:save_path.rfind('.')] + '_' + str(dataset.frame) + '.txt', 'a') as f:
f.write(('%g ' * 5 + '\n') % (cls, *xywh)) # label format
if save_img or view_img: # Add bbox to image if save_img or view_img: # Add bbox to image
label = '%s %.2f' % (names[int(cls)], conf) label = '%s %.2f' % (names[int(cls)], conf)
...@@ -160,3 +157,8 @@ if __name__ == '__main__': ...@@ -160,3 +157,8 @@ if __name__ == '__main__':
with torch.no_grad(): with torch.no_grad():
detect() detect()
# Update all models
# for opt.weights in ['yolov5s.pt', 'yolov5m.pt', 'yolov5l.pt', 'yolov5x.pt', 'yolov3-spp.pt']:
# detect()
# create_pretrained(opt.weights, opt.weights)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论