🏥DETR 源码笔记_CSDN(二)

2022-9-4|2022-9-4
NotionNext
NotionNext
type
status
date
slug
summary
tags
category
icon
password
Property
Sep 4, 2022 01:39 AM
URL
DETR参考CSDNblog的代码注释,主要transformer构建和后处理

CSDN

来源:

搭建 Transformer

看 build_transformer(args):
实质调用的 Transformer(),d_model:  transformer 输入通道数, nhead: 多头注意力头数, num_encoder_layer: encoder 层数,num_decoder_layer: decoder 层数, dim_feedforward:前馈网络层输入通道数。

Encoder

首先用 TransformerEncoderLayer()forward 建立 encoder 中的其中一层,因为每层都是相同的,后面直接复制就行。根据归一化的前后顺序不同有两种构建方式:
根据前面建立的encoder层来搭建encoder。 self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

Decoder

回到 Transformer(), 接着 decoder 就是和 encoder 类似的操作了
但需要注意的是,decoder 会多一个输入是 query embedding, query_pos 是可学习输出位置向量, 解码器中的这个参数全局共享,提供全局注意力
我理解它为我们预测的输出,一开始预测就初始化为 0, 需要在 transformer decoder 中不断 refine 它,具体的了解推荐一个源码解析目标检测的跨界之星 DETR(四)、Detection with Transformer - 简书

DETR 搭建

Transformer 搭建完后又回到最上面的 build(args) 函数,紧接着是搭建 DETR 模型,将 Backbone 和 transformer 搭在一起。

LOSS 计算和 GT 匈牙利匹配

DETR 整个模型搭建完成后,回到 build() 中,因为它的预测结果是无序的,是以集合的形式输出,需要准备模型的预测结果与 GT 的匹配函数,来判断 GT 是否被检测分类成功,以进行 loss 计算。匹配函数使用的是匈牙利匹配,一个二分图的最大匹配算法,可以尝试用一下进化版的 KM 匹配算法试一试。
这篇博客讲的挺有趣的
SetCriterion()计算各个 loss 后返回,最后只对分类损失部分的源码贴了注释,其他的也是类似操作,可以自己看一下。
分类 LOSS 计算

PostProcess 后处理

再回到 build() 中,处理完 Loss 计算函数,就是准备 detr 模型输出的后处理方法了。
看 PostProcess,masks 语义模型的暂不探究。

main(三) 构建数据集、训练验证操作

模型搭建讲完后,就回到我们的 main()函数,该轮到构建数据集、模型训练和验证了。大体操作和以前的模型操作大体相同,已经有很多人说了,就不多叙述了,但还是有中文注释供参考。
训练的主要部分就在train_one_epoch()中:
自此,主体部分就差不多完成了,撒花。。后面会写一个训练自己的数据集的 DETR 的 blog,欢迎评论探讨!
(带注释的代码可以在 https://gitee.com/fgy120/DETR 自取) > 本文由简悦 SimpRead 转码
DETR 源码笔记_CSDN(一)论文笔记
Loading...