Unverified 提交 b25d5a75 authored 作者: Glenn Jocher's avatar Glenn Jocher 提交者: GitHub

Refactor dataset batch-size (#9551)

上级 489920ab
...@@ -91,10 +91,9 @@ def run( ...@@ -91,10 +91,9 @@ def run(
if webcam: if webcam:
view_img = check_imshow() view_img = check_imshow()
dataset = LoadStreams(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride) dataset = LoadStreams(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride)
bs = len(dataset) # batch_size
else: else:
dataset = LoadImages(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride) dataset = LoadImages(source, img_size=imgsz, transforms=classify_transforms(imgsz[0]), vid_stride=vid_stride)
bs = 1 # batch_size bs = len(dataset) # batch_size
vid_path, vid_writer = [None] * bs, [None] * bs vid_path, vid_writer = [None] * bs, [None] * bs
# Run inference # Run inference
......
...@@ -99,10 +99,9 @@ def run( ...@@ -99,10 +99,9 @@ def run(
if webcam: if webcam:
view_img = check_imshow() view_img = check_imshow()
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = len(dataset) # batch_size
else: else:
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = 1 # batch_size bs = len(dataset) # batch_size
vid_path, vid_writer = [None] * bs, [None] * bs vid_path, vid_writer = [None] * bs, [None] * bs
# Run inference # Run inference
......
...@@ -101,10 +101,9 @@ def run( ...@@ -101,10 +101,9 @@ def run(
if webcam: if webcam:
view_img = check_imshow() view_img = check_imshow()
dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = len(dataset) # batch_size
else: else:
dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride) dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt, vid_stride=vid_stride)
bs = 1 # batch_size bs = len(dataset) # batch_size
vid_path, vid_writer = [None] * bs, [None] * bs vid_path, vid_writer = [None] * bs, [None] * bs
# Run inference # Run inference
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论