目录
1. 模型yaml文件
2. yolo.py
3. common.py
4. 改进YOLOv5具体步骤
YOLOv5是目前最主流的目标检测算法之一,我们可以在YOLOv5的基础之上进行改进和创新。本文针对YOLOv5的7.0版本,整理改进YOLOv5的模型加载流程,以便于后面改进yolo结构。
YOLOv5的模型结构是通过yaml文件组织的。
下面为yolov5s.yaml
# YOLOv5 🚀 by Ultralytics, GPL-3.0 license# Parameters
nc: 80 # number of classes
depth_multiple: 0.33 # model depth multiple
width_multiple: 0.50 # layer channel multiple
anchors: - [10,13, 16,30, 33,23] # P3/8- [30,61, 62,45, 59,119] # P4/16- [116,90, 156,198, 373,326] # P5/32# YOLOv5 v6.0 backbone
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2[-1, 1, Conv, [128, 3, 2]], # 1-P2/4[-1, 3, C3, [128]],[-1, 1, Conv, [256, 3, 2]], # 3-P3/8[-1, 6, C3, [256]],[-1, 1, Conv, [512, 3, 2]], # 5-P4/16[-1, 9, C3, [512]],[-1, 1, Conv, [1024, 3, 2]], # 7-P5/32[-1, 3, C3, [1024]],[-1, 1, SPPF, [1024, 5]], # 9]# YOLOv5 v6.0 head
head:[[-1, 1, Conv, [512, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 6], 1, Concat, [1]], # cat backbone P4[-1, 3, C3, [512, False]], # 13[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 4], 1, Concat, [1]], # cat backbone P3[-1, 3, C3, [256, False]], # 17 (P3/8-small)[-1, 1, Conv, [256, 3, 2]],[[-1, 14], 1, Concat, [1]], # cat head P4[-1, 3, C3, [512, False]], # 20 (P4/16-medium)[-1, 1, Conv, [512, 3, 2]],[[-1, 10], 1, Concat, [1]], # cat head P5[-1, 3, C3, [1024, False]], # 23 (P5/32-large)[[17, 20, 23], 1, Detect, [nc, anchors]], # Detect(P3, P4, P5)]
yaml文件保存的是yolo模型的结构,如果想对yolo的模型结构做出改变,yaml文件是必须要修改的。
先来看网络的结构部分,yolov5的网络结构是由一个backbone和一个head组成的,backbone表示主干网络,负责提取特征;head部分由neck和检测头组成,neck主要负责融合浅层网络的图形特征和深层网络的语义特征,检测头负责将特征图转化为预测结果。具体细节可以查看我的这篇博客:YOLOv5深度剖析
以第一个Conv层为例:
[-1, 1, Conv, [64, 6, 2, 2]], # 0-P1/2
[连接层, 模块重复的次数, 模块名, [模块参数1, 模块参数2, 模块参数3, 模块参数4]],
- 连接层:表示模块连接了哪一层,-1表示连接了上一层,例:0表示第一层,1表示第二层……,这个参数的作用主要体现在neck中,用于连接不同层的参数。
- 模块重复的次数:只有C3模块不是1,其他模块均为1,例:3表示将本层重复3次,提取更深层次的特征。
- 模块名:表示这一层使用哪个模块
- 模块参数:一个模块可以有很多参数,一般来说,模块参数1为输出通道数。这个可以自己定义,后面还会提到,例:[64, 6, 2, 2]表示[输出通道数为64,卷积核大小为6,步长为2,填充为2]
yaml文件中还有以下几个变量:
nc:表示检测的类别数,这里默认为coco数据集中的80个类别。
depth_multiple:用于控制网络深度的因子。
width_multiple:用于控制网络宽度的因子。
anchor:预定义的anchor大小。
yolov5中共有五种模型,从小到大依次为:yolov5n.yaml,yolov5s.yaml,yolov5m.yaml,yolov5l.yaml,yolov5x.yaml,越大的模型准确率越高,但其推理速度也会相应降低,如下图所示。
其实四种结构的网络模型是完全一致的,区分他们的大小依靠 depth_multiple和width_multiple两因子。
depth_multiple是用于控制网络深度的因子,是通过控制C3模块的重复次数来实现控制网络深度,C3重复次数越多,网络越深。
width_multiple是用于控制网络宽度的因子,是通过控制网络中的通道数来实现控制网络宽度,输出的通道数越多,宽度越宽。
也就是说,yaml文件中写的输出通道数和模块的重复次数,并不代表最总模型的输出通道数和模块的重复次数。yaml中的输出通道数以及模块的重复次数都要乘depth_multiple和width_multiple,才是真正的网络规模。
yaml文件真正被加载为模型,是在yolo.py中完成的。
def parse_model(d, ch): # model_dict, input_channels(3)# Parse a YOLOv5 model.yaml dictionaryLOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10} {'module':<40}{'arguments':<30}")anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')if act:Conv.default_act = eval(act) # redefine default activation, i.e. Conv.default_act = nn.SiLU()LOGGER.info(f"{colorstr('activation:')} {act}") # printna = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors # number of anchorsno = na * (nc + 5) # number of outputs = anchors * (classes + 5)layers, save, c2 = [], [], ch[-1] # layers, savelist, ch outfor i, (f, n, m, args) in enumerate(d['backbone'] + d['head']): # from, number, module, argsm = eval(m) if isinstance(m, str) else m # eval stringsfor j, a in enumerate(args):with contextlib.suppress(NameError):args[j] = eval(a) if isinstance(a, str) else a # eval stringsn = n_ = max(round(n * gd), 1) if n > 1 else n # depth gainif m in {Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, Conv, GhostConv,Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, SELayer, C3CBAM, ECA, CARAFE, RepVGGBlock, nn.ConvTranspose2d,RepC3, CBAM, ECA, TCN, C3TCN}:c1, c2 = ch[f], args[0]if c2 != no: # if not outputc2 = make_divisible(c2 * gw, 8)args = [c1, c2, *args[1:]]if m in {BottleneckCSP, C3, C3TR, C3Ghost, C3SE, C3SE, C3CBAM, RepC3, C3TCN}:args.insert(2, n) # number of repeatsn = 1elif m is nn.ConvTranspose2d:if len(args) >= 7:args[6] = make_divisible(args[6] * gw, 8)elif m is nn.BatchNorm2d:args = [ch[f]]elif m in {Concat, BiFPN_Add2, BiFPN_Add3}:c2 = sum(ch[x] for x in f)# TODO: channel, gw, gdelif m in {Detect, Segment}:args.append([ch[x] for x in f])if isinstance(args[1], int): # number of anchorsargs[1] = [list(range(args[1] * 2))] * len(f)if m is Segment:args[3] = make_divisible(args[3] * gw, 8)elif m is Contract:c2 = ch[f] * args[0] ** 2elif m is Expand:c2 = ch[f] // args[0] ** 2else:c2 = ch[f]m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args) # modulet = str(m)[8:-2].replace('__main__.', '') # module typenp = sum(x.numel() for x in m_.parameters()) # number paramsm_.i, m_.f, m_.type, m_.np = i, f, t, np # attach index, 'from' index, type, number paramsLOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f} {t:<40}{str(args):<30}') # printsave.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1) # append to savelistlayers.append(m_)if i == 0:ch = []ch.append(c2)return nn.Sequential(*layers), sorted(save)
yolo.py中有一个parse_model函数。这个函数主要负责将yaml文件转化为pyorch模型。
在这个函数中,程序会解析读取的yaml文件,将C3模块的重复次数与depth_multiple相乘,产生新的重复次数;将输出通道数与width_multiple相乘,产生新的输出通道数,并通过输入图片的通道数(一般为3),将模块参数的第一位前插入输入通道数。
common.py中定义的是所有层的具体网络结构
class Conv(nn.Module):# Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)default_act = nn.SiLU() # default activationdef __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):super().__init__()self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)self.bn = nn.BatchNorm2d(c2)self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()def forward(self, x):return self.act(self.bn(self.conv(x)))def forward_fuse(self, x):return self.act(self.conv(x))
common.py中的网络结构很多,yolov5作者很贴心的为我们预定义了一些常用的模块,上面只拿出Conv模块为例。common.py以类的形式,定义了网络的结构。
后面我将持续更新YOLOv5的改进策略。