YOLOv5模型的读取
创始人
2024-06-02 01:47:44
0

目录

1. 模型yaml文件

2. yolo.py

3. common.py

4. 改进YOLOv5具体步骤


YOLOv5是目前最主流的目标检测算法之一,我们可以在YOLOv5的基础之上进行改进和创新。本文针对YOLOv5的7.0版本,整理改进YOLOv5的模型加载流程,以便于后面改进yolo结构。

1. 模型yaml文件

        YOLOv5的模型结构是通过yaml文件组织的。

        下面为yolov5s.yaml

# YOLOv5 🚀 by Ultralytics, GPL-3.0 license# Parameters
nc: 80  # number of classes
depth_multiple: 0.33  # model depth multiple
width_multiple: 0.50  # layer channel multiple
anchors: - [10,13, 16,30, 33,23]  # P3/8- [30,61, 62,45, 59,119]  # P4/16- [116,90, 156,198, 373,326]  # P5/32# YOLOv5 v6.0 backbone
backbone:# [from, number, module, args][[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2[-1, 1, Conv, [128, 3, 2]],  # 1-P2/4[-1, 3, C3, [128]],[-1, 1, Conv, [256, 3, 2]],  # 3-P3/8[-1, 6, C3, [256]],[-1, 1, Conv, [512, 3, 2]],  # 5-P4/16[-1, 9, C3, [512]],[-1, 1, Conv, [1024, 3, 2]],  # 7-P5/32[-1, 3, C3, [1024]],[-1, 1, SPPF, [1024, 5]],  # 9]# YOLOv5 v6.0 head
head:[[-1, 1, Conv, [512, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 6], 1, Concat, [1]],  # cat backbone P4[-1, 3, C3, [512, False]],  # 13[-1, 1, Conv, [256, 1, 1]],[-1, 1, nn.Upsample, [None, 2, 'nearest']],[[-1, 4], 1, Concat, [1]],  # cat backbone P3[-1, 3, C3, [256, False]],  # 17 (P3/8-small)[-1, 1, Conv, [256, 3, 2]],[[-1, 14], 1, Concat, [1]],  # cat head P4[-1, 3, C3, [512, False]],  # 20 (P4/16-medium)[-1, 1, Conv, [512, 3, 2]],[[-1, 10], 1, Concat, [1]],  # cat head P5[-1, 3, C3, [1024, False]],  # 23 (P5/32-large)[[17, 20, 23], 1, Detect, [nc, anchors]],  # Detect(P3, P4, P5)]

        yaml文件保存的是yolo模型的结构,如果想对yolo的模型结构做出改变,yaml文件是必须要修改的。

        先来看网络的结构部分,yolov5的网络结构是由一个backbone和一个head组成的,backbone表示主干网络,负责提取特征;head部分由neck和检测头组成,neck主要负责融合浅层网络的图形特征和深层网络的语义特征,检测头负责将特征图转化为预测结果。具体细节可以查看我的这篇博客:YOLOv5深度剖析

        以第一个Conv层为例:

[-1, 1, Conv, [64, 6, 2, 2]],  # 0-P1/2

 [连接层, 模块重复的次数, 模块名, [模块参数1, 模块参数2, 模块参数3, 模块参数4]],

  • 连接层:表示模块连接了哪一层,-1表示连接了上一层,例:0表示第一层,1表示第二层……,这个参数的作用主要体现在neck中,用于连接不同层的参数。
  • 模块重复的次数:只有C3模块不是1,其他模块均为1,例:3表示将本层重复3次,提取更深层次的特征。
  • 模块名:表示这一层使用哪个模块
  • 模块参数:一个模块可以有很多参数,一般来说,模块参数1为输出通道数。这个可以自己定义,后面还会提到,例:[64, 6, 2, 2]表示[输出通道数为64,卷积核大小为6,步长为2,填充为2]

yaml文件中还有以下几个变量:

        nc:表示检测的类别数,这里默认为coco数据集中的80个类别。

        depth_multiple:用于控制网络深度的因子。

        width_multiple:用于控制网络宽度的因子。

        anchor:预定义的anchor大小。

yolov5中共有五种模型,从小到大依次为:yolov5n.yaml,yolov5s.yaml,yolov5m.yaml,yolov5l.yaml,yolov5x.yaml,越大的模型准确率越高,但其推理速度也会相应降低,如下图所示。

         其实四种结构的网络模型是完全一致的,区分他们的大小依靠 depth_multiple和width_multiple两因子。

        depth_multiple是用于控制网络深度的因子,是通过控制C3模块的重复次数来实现控制网络深度,C3重复次数越多,网络越深。

        width_multiple是用于控制网络宽度的因子,是通过控制网络中的通道数来实现控制网络宽度,输出的通道数越多,宽度越宽。

        也就是说,yaml文件中写的输出通道数和模块的重复次数,并不代表最总模型的输出通道数和模块的重复次数。yaml中的输出通道数以及模块的重复次数都要乘depth_multiple和width_multiple,才是真正的网络规模。

2. yolo.py

         yaml文件真正被加载为模型,是在yolo.py中完成的。

def parse_model(d, ch):  # model_dict, input_channels(3)# Parse a YOLOv5 model.yaml dictionaryLOGGER.info(f"\n{'':>3}{'from':>18}{'n':>3}{'params':>10}  {'module':<40}{'arguments':<30}")anchors, nc, gd, gw, act = d['anchors'], d['nc'], d['depth_multiple'], d['width_multiple'], d.get('activation')if act:Conv.default_act = eval(act)  # redefine default activation, i.e. Conv.default_act = nn.SiLU()LOGGER.info(f"{colorstr('activation:')} {act}")  # printna = (len(anchors[0]) // 2) if isinstance(anchors, list) else anchors  # number of anchorsno = na * (nc + 5)  # number of outputs = anchors * (classes + 5)layers, save, c2 = [], [], ch[-1]  # layers, savelist, ch outfor i, (f, n, m, args) in enumerate(d['backbone'] + d['head']):  # from, number, module, argsm = eval(m) if isinstance(m, str) else m  # eval stringsfor j, a in enumerate(args):with contextlib.suppress(NameError):args[j] = eval(a) if isinstance(a, str) else a  # eval stringsn = n_ = max(round(n * gd), 1) if n > 1 else n  # depth gainif m in {Conv, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, Conv, GhostConv,Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, MixConv2d, Focus, CrossConv,BottleneckCSP, C3, C3TR, C3SPP, C3Ghost, SELayer, C3CBAM, ECA, CARAFE, RepVGGBlock, nn.ConvTranspose2d,RepC3, CBAM, ECA, TCN, C3TCN}:c1, c2 = ch[f], args[0]if c2 != no:  # if not outputc2 = make_divisible(c2 * gw, 8)args = [c1, c2, *args[1:]]if m in {BottleneckCSP, C3, C3TR, C3Ghost, C3SE, C3SE, C3CBAM, RepC3, C3TCN}:args.insert(2, n)  # number of repeatsn = 1elif m is nn.ConvTranspose2d:if len(args) >= 7:args[6] = make_divisible(args[6] * gw, 8)elif m is nn.BatchNorm2d:args = [ch[f]]elif m in {Concat, BiFPN_Add2, BiFPN_Add3}:c2 = sum(ch[x] for x in f)# TODO: channel, gw, gdelif m in {Detect, Segment}:args.append([ch[x] for x in f])if isinstance(args[1], int):  # number of anchorsargs[1] = [list(range(args[1] * 2))] * len(f)if m is Segment:args[3] = make_divisible(args[3] * gw, 8)elif m is Contract:c2 = ch[f] * args[0] ** 2elif m is Expand:c2 = ch[f] // args[0] ** 2else:c2 = ch[f]m_ = nn.Sequential(*(m(*args) for _ in range(n))) if n > 1 else m(*args)  # modulet = str(m)[8:-2].replace('__main__.', '')  # module typenp = sum(x.numel() for x in m_.parameters())  # number paramsm_.i, m_.f, m_.type, m_.np = i, f, t, np  # attach index, 'from' index, type, number paramsLOGGER.info(f'{i:>3}{str(f):>18}{n_:>3}{np:10.0f}  {t:<40}{str(args):<30}')  # printsave.extend(x % i for x in ([f] if isinstance(f, int) else f) if x != -1)  # append to savelistlayers.append(m_)if i == 0:ch = []ch.append(c2)return nn.Sequential(*layers), sorted(save)

        yolo.py中有一个parse_model函数。这个函数主要负责将yaml文件转化为pyorch模型。

        在这个函数中,程序会解析读取的yaml文件,将C3模块的重复次数与depth_multiple相乘,产生新的重复次数;将输出通道数与width_multiple相乘,产生新的输出通道数,并通过输入图片的通道数(一般为3),将模块参数的第一位前插入输入通道数。

3. common.py

        common.py中定义的是所有层的具体网络结构

class Conv(nn.Module):# Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)default_act = nn.SiLU()  # default activationdef __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):super().__init__()self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)self.bn = nn.BatchNorm2d(c2)self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()def forward(self, x):return self.act(self.bn(self.conv(x)))def forward_fuse(self, x):return self.act(self.conv(x))

        common.py中的网络结构很多,yolov5作者很贴心的为我们预定义了一些常用的模块,上面只拿出Conv模块为例。common.py以类的形式,定义了网络的结构。

4. 改进YOLOv5具体步骤

  1. 第一步要先把要实现的网络结构,以类的形式写在common.py中。
  2. 在yolo.py中加入新实现的网络结构
  3. 修改网络结构yaml文件

后面我将持续更新YOLOv5的改进策略。

相关内容

热门资讯

女娲传说之灵珠经典台词 女娲传说之灵珠经典台词15句  你们根本不懂爱,你们的爱太过自私,不择手段。——仙乐  你放心,我不...
青春演讲比赛主持词 青春演讲比赛主持词  主持:  男——李勇  女——张雨  女:我与祖国共奋进男:我为崛起献青春。 ...
年会赞助商致辞 年会赞助商致辞(精选5篇)  在日常的学习、工作、生活中,要用到致辞的情况还是蛮多的,致辞具有很强的...
培训会主持词 培训会主持词(精选10篇)  在日常中,大家总免不了要培训及会议吧,那么你知道主持词怎么写吗?下面是...
电影《三少爷的剑》经典台词 电影《三少爷的剑》经典台词精选  1. 冷风如刀,大地荒漠,苍天无情。  2. 这世上永远有两种人,...
《十六个夏天》的经典台词 《十六个夏天》的经典台词大全  1.就说我不是他的收藏品,乱说什么啊!  2.说话冲的人,心里都是软...
80岁生日宴会致辞 80岁生日宴会致辞(精选11篇)  在学习、工作或生活中,大家都不可避免地会接触到致辞吧,致辞具有有...
婚宴长辈证婚人致辞 婚宴长辈证婚人致辞  在平日的学习、工作和生活里,大家总少不了要接触或使用致辞吧,致辞受场合、事件的...
元旦舞会主持词 元旦舞会主持词(精选7篇)  主持词要尽量增加文化内涵、寓教于乐,不断提高观众的文化知识和素养。在人...
圣诞联欢会主持词 圣诞联欢会主持词  活动对象的不同,主持词的写作风格也会大不一样。时代不断在进步,主持成为很多活动不...
美好童年—庆“六一”大型活动... 美好童年—庆“六一”大型活动主持词  利用在中国拥有几千年文化的诗词能够有效提高主持词的感染力。随着...
毕业晚会主持词串词 毕业晚会主持词串词  毕业,是人生的一个转折点,愿你们能展开双翼,飞得更高、看得更远。下面是小编给大...
运动会致辞 运动会致辞(精选5篇)  无论在学习、工作或是生活中,大家或多或少都用到过致辞吧,致辞要求风格的雅、...
六一儿童节的主持稿 六一儿童节的主持稿(精选8篇)  随着社会一步步向前发展,我们都不可避免地要接触到主持稿,主持稿是主...
元旦文艺汇演主持稿 元旦文艺汇演主持稿范文(通用5篇)  在当下社会,很多情况下我们需要用到主持稿,主持稿起到承上启下的...
颁奖主持词 颁奖主持词三篇  主持人在一场活动中是十分重要的,一个好的主持人是一直带动着活动过程中的气氛,让大家...
婚宴答谢宴简短主持词 婚宴答谢宴简短主持词  主持词要根据活动对象的不同去设置不同的主持词。在人们积极参与各种活动的今天,...
汽车公司庆典主持词 汽车公司庆典主持词  利用在中国拥有几千年文化的诗词能够有效提高主持词的感染力。现今社会在不断向前发...
古筝音乐会主持词 古筝音乐会主持词6篇  主持词要把握好吸引观众、导入主题、创设情境等环节以吸引观众。在一步步向前发展...
小学元旦联欢会主持词开场白和... 小学元旦联欢会主持词开场白和结束词  根据活动对象的不同,需要设置不同的主持词。随着社会一步步向前发...