Unverified 提交 34da872a authored 作者: Yoni Chechik's avatar Yoni Chechik 提交者: GitHub

fix `tf` conversion in new v6 models (#5153)

* fix `tf` conversion in new v6 (#5147) * sort imports Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
上级 956be8e6
...@@ -28,7 +28,7 @@ import torch ...@@ -28,7 +28,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from tensorflow import keras from tensorflow import keras
from models.common import Conv, Bottleneck, SPP, DWConv, Focus, BottleneckCSP, Concat, autopad, C3 from models.common import Bottleneck, BottleneckCSP, Concat, Conv, C3, DWConv, Focus, SPP, SPPF, autopad
from models.experimental import CrossConv, MixConv2d, attempt_load from models.experimental import CrossConv, MixConv2d, attempt_load
from models.yolo import Detect from models.yolo import Detect
from utils.general import make_divisible, print_args, set_logging from utils.general import make_divisible, print_args, set_logging
...@@ -183,6 +183,22 @@ class TFSPP(keras.layers.Layer): ...@@ -183,6 +183,22 @@ class TFSPP(keras.layers.Layer):
return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3)) return self.cv2(tf.concat([x] + [m(x) for m in self.m], 3))
class TFSPPF(keras.layers.Layer):
    """Spatial pyramid pooling-Fast (SPPF) layer.

    A 1x1 TFConv reduces the input to ``c1 // 2`` hidden channels, the same
    max-pool layer is then applied three times in sequence, and the reduced
    tensor plus the three pooled tensors are concatenated along axis 3 and
    passed through a second 1x1 TFConv that outputs ``c2`` channels.

    Args:
        c1: input channel count, forwarded to the first TFConv.
        c2: output channel count, forwarded to the second TFConv.
        k: pool size of the shared max-pool layer.
        w: object exposing ``cv1`` and ``cv2`` attributes, forwarded as the
            ``w`` argument of the two TFConv layers (presumably the matching
            PyTorch SPPF module -- confirm against the caller).
            NOTE: although the default is None, ``w.cv1`` is read
            unconditionally, so a value must be supplied.
    """

    def __init__(self, c1, c2, k=5, w=None):
        super(TFSPPF, self).__init__()
        hidden = c1 // 2  # hidden channels
        self.cv1 = TFConv(c1, hidden, 1, 1, w=w.cv1)
        # Four tensors are concatenated in call(), hence 4 * hidden inputs.
        self.cv2 = TFConv(hidden * 4, c2, 1, 1, w=w.cv2)
        self.m = keras.layers.MaxPool2D(pool_size=k, strides=1, padding='SAME')

    def call(self, inputs):
        # Start with the reduced tensor, then append three successive pools,
        # each one pooling the previous result.
        stages = [self.cv1(inputs)]
        for _ in range(3):
            stages.append(self.m(stages[-1]))
        return self.cv2(tf.concat(stages, 3))
class TFDetect(keras.layers.Layer): class TFDetect(keras.layers.Layer):
def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detection layer def __init__(self, nc=80, anchors=(), ch=(), imgsz=(640, 640), w=None): # detection layer
super(TFDetect, self).__init__() super(TFDetect, self).__init__()
...@@ -272,7 +288,7 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3) ...@@ -272,7 +288,7 @@ def parse_model(d, ch, model, imgsz): # model_dict, input_channels(3)
pass pass
n = max(round(n * gd), 1) if n > 1 else n # depth gain n = max(round(n * gd), 1) if n > 1 else n # depth gain
if m in [nn.Conv2d, Conv, Bottleneck, SPP, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]: if m in [nn.Conv2d, Conv, Bottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv, BottleneckCSP, C3]:
c1, c2 = ch[f], args[0] c1, c2 = ch[f], args[0]
c2 = make_divisible(c2 * gw, 8) if c2 != no else c2 c2 = make_divisible(c2 * gw, 8) if c2 != no else c2
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论