caffe 源码学习笔记(10) eltwise layer

背景

这个layer和reduce layer有一些相似,就干脆一起看了. 作用是输入至少两个blob,然后对每个blob中的元素做一些运算,最后得到一个blob.

caffe 支持的运算有"PROD","SUM","MAX"三种

顺便提一句,TensorRT支持的要多一些:

enum class ElementWiseOperation : int
{
    kSUM = 0,  //!< Sum of the two elements.
    kPROD = 1, //!< Product of the two elements.
    kMAX = 2,  //!< Maximum of the two elements.
    kMIN = 3,  //!< Minimum of the two elements.
    kSUB = 4,  //!< Substract the second element from the first.
    kDIV = 5,  //!< Divide the first element by the second.
    kPOW = 6   //!< The first element to the power of the second element.
};

proto

message EltwiseParameter {
  enum EltwiseOp {
    PROD = 0;
    SUM = 1;
    MAX = 2;
  }
  optional EltwiseOp operation = 1 [default = SUM]; // element-wise operation
  repeated float coeff = 2; // blob-wise coefficient for SUM operation

  // Whether to use an asymptotically slower (for >2 inputs) but stabler method
  // of computing the gradient for the PROD operation. (No effect for SUM op.)
  optional bool stable_prod_grad = 3 [default = true];
}

proto里面的coeff是对于SUM操作,可以给每一个bottom blob一个加权系数, stable_prod_grad是backward用的,不用管.

c++ 实现

代码比较容易看懂,加了一些注释. 有两个地方可以提一下. 一个是PROD和MAX的做法,都是先求前两个blob的结果,再把得到的结果和后面的每一个blob逐个进行运算.(其实是很自然的操作...似乎也没什么可说的orz)

另外一个是mask这个变量,是在MAX操作时用来标记在哪个bottom blob 取到了最大值,反向传播时要用.
 8template <typename Dtype>
 9void EltwiseLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
10      const vector<Blob<Dtype>*>& top) {
11  for (int i = 1; i < bottom.size(); ++i) {
12    CHECK(bottom[i]->shape() == bottom[0]->shape());
13  }
14  //  check所有的bottom blob的shape都一样. 至少存在两个bottom blob
15  top[0]->ReshapeLike(*bottom[0]);
16  // If max operation, we will initialize the vector index part.
17  if (this->layer_param_.eltwise_param().operation() ==
18      EltwiseParameter_EltwiseOp_MAX && top.size() == 1) {
19    max_idx_.Reshape(bottom[0]->shape());
20  }
21}
22
23template <typename Dtype>
24void EltwiseLayer<Dtype>::Forward_cpu(
25    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
26  int* mask = NULL;
27  const Dtype* bottom_data_a = NULL;
28  const Dtype* bottom_data_b = NULL;
29  const int count = top[0]->count();
30  Dtype* top_data = top[0]->mutable_cpu_data();
31  switch (op_) {
32  case EltwiseParameter_EltwiseOp_PROD:
33    caffe_mul(count, bottom[0]->cpu_data(), bottom[1]->cpu_data(), top_data);
34    for (int i = 2; i < bottom.size(); ++i) {
35      caffe_mul(count, top_data, bottom[i]->cpu_data(), top_data);
36    }
37    //  先算前两个,然后把结果和后面的每一个blob(如果还有的话)做运算
38    break;
39  case EltwiseParameter_EltwiseOp_SUM:
40    caffe_set(count, Dtype(0), top_data);
41    // 初始化top data为0
42    // TODO(shelhamer) does BLAS optimize to sum for coeff = 1?
43    for (int i = 0; i < bottom.size(); ++i) {
44      caffe_axpy(count, coeffs_[i], bottom[i]->cpu_data(), top_data);
45    }
46    break;
47  //  mask干啥用的??? 
48  //  forward应该用不到,是backward求梯度需要知道在哪个位置得到了最大值
49  case EltwiseParameter_EltwiseOp_MAX:
50    // Initialize
51    mask = max_idx_.mutable_cpu_data();
52    caffe_set(count, -1, mask);
53    caffe_set(count, Dtype(-FLT_MAX), top_data);
54    // bottom 0 & 1
55    bottom_data_a = bottom[0]->cpu_data();
56    bottom_data_b = bottom[1]->cpu_data();
57    for (int idx = 0; idx < count; ++idx) {
58      if (bottom_data_a[idx] > bottom_data_b[idx]) {
59        top_data[idx] = bottom_data_a[idx];  // maxval
60        mask[idx] = 0;  // maxid
61      } else {
62        top_data[idx] = bottom_data_b[idx];  // maxval
63        mask[idx] = 1;  // maxid
64      }
65    }
66    // bottom 2++
67    for (int blob_idx = 2; blob_idx < bottom.size(); ++blob_idx) {
68      bottom_data_b = bottom[blob_idx]->cpu_data();
69      for (int idx = 0; idx < count; ++idx) {
70        if (bottom_data_b[idx] > top_data[idx]) {
71          top_data[idx] = bottom_data_b[idx];  // maxval
72          mask[idx] = blob_idx;  // maxid
73        }
74      }
75    }
76    break;
77  default:
78    LOG(FATAL) << "Unknown elementwise operation.";
79  }
80}
81

Posts in this Series