Unverified 提交 54043a9f authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Streaming --save-txt bug fix (#1672)

* Streaming --save-txt bug fix * cleanup
上级 bc52ea2d
...@@ -81,12 +81,12 @@ def detect(save_img=False): ...@@ -81,12 +81,12 @@ def detect(save_img=False):
# Process detections # Process detections
for i, det in enumerate(pred): # detections per image for i, det in enumerate(pred): # detections per image
if webcam: # batch_size >= 1 if webcam: # batch_size >= 1
p, s, im0 = Path(path[i]), '%g: ' % i, im0s[i].copy() p, s, im0, frame = Path(path[i]), '%g: ' % i, im0s[i].copy(), dataset.count
else: else:
p, s, im0 = Path(path), '', im0s p, s, im0, frame = Path(path), '', im0s, getattr(dataset, 'frame', 0)
save_path = str(save_dir / p.name) save_path = str(save_dir / p.name)
txt_path = str(save_dir / 'labels' / p.stem) + ('_%g' % dataset.frame if dataset.mode == 'video' else '') txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')
s += '%gx%g ' % img.shape[2:] # print string s += '%gx%g ' % img.shape[2:] # print string
gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh gn = torch.tensor(im0.shape)[[1, 0, 1, 0]] # normalization gain whwh
if len(det): if len(det):
...@@ -96,7 +96,7 @@ def detect(save_img=False): ...@@ -96,7 +96,7 @@ def detect(save_img=False):
# Print results # Print results
for c in det[:, -1].unique(): for c in det[:, -1].unique():
n = (det[:, -1] == c).sum() # detections per class n = (det[:, -1] == c).sum() # detections per class
s += '%g %ss, ' % (n, names[int(c)]) # add to string s += f'{n} {names[int(c)]}s, ' # add to string
# Write results # Write results
for *xyxy, conf, cls in reversed(det): for *xyxy, conf, cls in reversed(det):
...@@ -107,11 +107,11 @@ def detect(save_img=False): ...@@ -107,11 +107,11 @@ def detect(save_img=False):
f.write(('%g ' * len(line)).rstrip() % line + '\n') f.write(('%g ' * len(line)).rstrip() % line + '\n')
if save_img or view_img: # Add bbox to image if save_img or view_img: # Add bbox to image
label = '%s %.2f' % (names[int(cls)], conf) label = f'{names[int(cls)]} {conf:.2f}'
plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3) plot_one_box(xyxy, im0, label=label, color=colors[int(cls)], line_thickness=3)
# Print time (inference + NMS) # Print time (inference + NMS)
print('%sDone. (%.3fs)' % (s, t2 - t1)) print(f'{s}Done. ({t2 - t1:.3f}s)')
# Stream results # Stream results
if view_img: if view_img:
...@@ -121,9 +121,9 @@ def detect(save_img=False): ...@@ -121,9 +121,9 @@ def detect(save_img=False):
# Save results (image with detections) # Save results (image with detections)
if save_img: if save_img:
if dataset.mode == 'images': if dataset.mode == 'image':
cv2.imwrite(save_path, im0) cv2.imwrite(save_path, im0)
else: else: # 'video'
if vid_path != save_path: # new video if vid_path != save_path: # new video
vid_path = save_path vid_path = save_path
if isinstance(vid_writer, cv2.VideoWriter): if isinstance(vid_writer, cv2.VideoWriter):
...@@ -140,7 +140,7 @@ def detect(save_img=False): ...@@ -140,7 +140,7 @@ def detect(save_img=False):
s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else '' s = f"\n{len(list(save_dir.glob('labels/*.txt')))} labels saved to {save_dir / 'labels'}" if save_txt else ''
print(f"Results saved to {save_dir}{s}") print(f"Results saved to {save_dir}{s}")
print('Done. (%.3fs)' % (time.time() - t0)) print(f'Done. ({time.time() - t0:.3f}s)')
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -138,7 +138,7 @@ class LoadImages: # for inference ...@@ -138,7 +138,7 @@ class LoadImages: # for inference
self.files = images + videos self.files = images + videos
self.nf = ni + nv # number of files self.nf = ni + nv # number of files
self.video_flag = [False] * ni + [True] * nv self.video_flag = [False] * ni + [True] * nv
self.mode = 'images' self.mode = 'image'
if any(videos): if any(videos):
self.new_video(videos[0]) # new video self.new_video(videos[0]) # new video
else: else:
...@@ -256,7 +256,7 @@ class LoadWebcam: # for inference ...@@ -256,7 +256,7 @@ class LoadWebcam: # for inference
class LoadStreams: # multiple IP or RTSP cameras class LoadStreams: # multiple IP or RTSP cameras
def __init__(self, sources='streams.txt', img_size=640): def __init__(self, sources='streams.txt', img_size=640):
self.mode = 'images' self.mode = 'stream'
self.img_size = img_size self.img_size = img_size
if os.path.isfile(sources): if os.path.isfile(sources):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论