永发信息网

caffe里面的误差的反向传播怎么实现来的

答案:1  悬赏:70  手机版
解决时间 2021-04-05 10:46
caffe里面的误差的反向传播怎么实现来的
最佳答案
首先概括回答一下这个问题,分类的CNN是有监督的,就是在最后一层计算分类结果的loss,然后利用这个loss对整个网络进行更新,更新的关键就是计算梯度和偏置的导数dW和db,而Back Propagation主要就是为了解决前面的层的dW不容易计算的问题,具体是将loss通过一个残差delta一层一层往前传,因此无论是全连接层还是卷积层,全部是有监督的。
至于实现BP的理论和推导,cjwdeq同学已经讲的非常清楚了。既然答题组的大大们总说要发扬“左手代码,右手公式”的精神,我就结合caffe的源码讲讲具体反向传播是怎么实现的。先从简单的全连接层入手:
打开Inner_product_layer.cpp,里面的Backward_cpu函数实现了反向传播的过程。(如果使用的是GPU,则会调用Inner_product_layer.cu文件里的Backward_gpu函数,实现过程是类似的)
先通过LayerSetUp函数明确几个变量:
N_ = num_output;
K_ = bottom[0]->count(axis);
M_ = bottom[0]->count(0, axis);
N_表示输出的特征维数,即输出的神经元的个数
K_表示输入的样本的特征维数,即输入的神经元的个数
M_表示样本个数
因此全连接层的W维数就是N_×K_,b维数就是N_×1
weight_shape[0] = N_;
weight_shape[1] = K_;
vector bias_shape(1, N_);
this->blobs_[1].reset(new Blob(bias_shape));
下面一行一行看Backward_cpu函数的代码,整个更新过程大概可以分成三步:(顺便盗几个cjwdeq同学贴的公式,哈哈)
1.
caffe_cpu_gemm(CblasTrans, CblasNoTrans, N_, K_, M_, (Dtype)1.,
top_diff, bottom_data, (Dtype)0., this->blobs_[0]->mutable_cpu_diff());
这一句是为了计算dW,对应公式就是

1.jpg

其中的bottom_data对应的是a,即输入的神经元激活值,维数为K_×N_,top_diff对应的是delta,维数是M_×N_,而caffe_cpu_gemm函数是对blas中的函数进行封装,实现了一个N_×M_的矩阵与一个M_×K_的矩阵相乘(注意此处乘之前对top_diff进行了转置)。相乘得到的结果保存于blobs_[0]->mutable_cpu_diff(),对应dW。
2.
caffe_cpu_gemv(CblasTrans, M_, N_, (Dtype)1., top_diff,
bias_multiplier_.cpu_data(), (Dtype)0.,
this->blobs_[1]->mutable_cpu_diff());
这一句是为了计算db,对应公式为

2.jpg

caffe_cpu_gemv函数实现了一个M_×N_的矩阵与N_×1的向量进行乘积,其实主要实现的是对delta进行了一下转置,就得到了db的值,保存于blobs_[1]->mutable_cpu_diff()中。此处的与bias_multiplier_.cpu_data()相乘是实现对M_个样本求和,bias_multiplier_.cpu_data()是全1向量,从公式上看应该是取平均的,但是从loss传过来时已经取过平均了,此处直接求和即可。(感谢@孙琳钧和@辛淼同学的提醒)
3.
caffe_cpu_gemm(CblasNoTrans, CblasNoTrans, M_, K_, N_, (Dtype)1.,
top_diff, this->blobs_[0]->cpu_data(), (Dtype)0.,
bottom[0]->mutable_cpu_diff());
这一句是为了利用后面层传过来的delta_l+1计算本层的delta_l,对应公式为

3.jpg

主要Inner_product层里面并没有激活函数,因此没有乘f’,与f’的相乘写在ReLU层的Backward函数里了,因此这一句里只有W和delta_l+1相乘。blobs_[0]->cpu_data()对应W,维度是N_×K_,bottom[0]->mutable_cpu_diff()是本层的delta_l,维度是M_×K_。
写了这么多,Backward_cpu函数终于结束了。但是更新其实没结束,我最初看源码时就觉得奇怪,因为Backward_cpu函数里只计算了dW,db,delta,并没有对W和b进行更新呀?后来才发现,其实caffe里的反向传播过程只是计算每层的梯度的导,把所有层都计算完之后,在solver.cpp里面统一对整个网络进行了更新。具体是在step函数里先通过ComputeUpdateValue把learning rate、momentum、weight_decay什么的都算好,然后调用了Net.cpp的update函数逐层更新,对应公式就是:

4.jpg
我要举报
如以上问答信息为低俗、色情、不良、暴力、侵权、涉及违法等信息,可以点下面链接进行举报!
大家都在看
名侦探柯南在哪里看比较全,集数排列正确的
已知0<α<π/2,cosα-sinα=-根号5/5....
有关七月的语句,关于七月的唯美句子
用更什么更什么更什么造句
石门县常德石门县创伤急救中心我想知道这个在
如果让你来安慰比赛失利的运动员你会说些什么
有哪些直戳心灵的有关亲情的散文,句子
别对别人太好的句子,看不起我的人经典句子
九几年的时候,听过一盘磁带全是一男一女对唱
打歌舞台是什么意思,问一下此人是exo里的谁?
The suggestion has been made that the spor
七天上旺都什么人参加
社保卡可以在取款机查询吗?
攻城掠地一件兵器可打几孔
“恶魔”的英文是什么?
推荐资讯
爱国的古诗名句陆游,写白梅的古诗词
我想加盟一家200平方的包天下中式快餐,大概
健康之路在武汉协和医院网上挂号候诊时间过了
之前看过一件户外风衣很好看,logo上好像是一
哈弗h2拉高速可以吗
感恩节给情人的祝福语,用一句很深奥的话形容
几年前的奇幻动画片:一个女孩是什么的女儿,会
有哪些电影或者电视剧的名字是两个字的?
变形计爱的守护台词,有关守护的古诗词
在新房睡觉喘不上来气,每次回去都是晚上被憋
猫叫春一般在什么时候
阅读下文,回答问题。沙 之 聚张抗抗去敦煌不
正方形一边上任一点到这个正方形两条对角线的
阴历怎么看 ?