nnfunctions_fp.h 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997
  1. /* ----------------------------------------------------------------------
  2. * Project: Tiny Training Engine, MCUNetV3
  3. * Title: nnfunctions_fp.h
  4. *
  5. * Reference papers:
  6. * - MCUNet: Tiny Deep Learning on IoT Device, NeurIPS 2020
  7. * - MCUNetV2: Memory-Efficient Patch-based Inference for Tiny Deep Learning, NeurIPS 2021
  8. * - MCUNetV3: On-Device Training Under 256KB Memory, NeurIPS 2022
  9. * Contact authors:
  10. * - Wei-Chen Wang, wweichen@mit.edu
  11. * - Wei-Ming Chen, wmchen@mit.edu
  12. * - Ji Lin, jilin@mit.edu
  13. * - Ligeng Zhu, ligeng@mit.edu
  14. * - Song Han, songhan@mit.edu
  15. * - Chuang Gan, ganchuang@csail.mit.edu
  16. *
  17. * Target ISA: ARMv7E-M
  18. * -------------------------------------------------------------------- */
  19. #include "nnfunctions.h"
  20. /* START: MAC Functions for Pointwise Conv */
  21. static inline void mac_1row_10col_fp_IOHW_forint8w(float* sum, const float* input_0,
  22. const float filter_0, const float filter_1, const float filter_2, const float filter_3, const float filter_4,
  23. const float filter_5, const float filter_6, const float filter_7, const float filter_8, const float filter_9) {
  24. *sum += *input_0++ * filter_0;
  25. *sum += *input_0++ * filter_1;
  26. *sum += *input_0++ * filter_2;
  27. *sum += *input_0++ * filter_3;
  28. *sum += *input_0++ * filter_4;
  29. *sum += *input_0++ * filter_5;
  30. *sum += *input_0++ * filter_6;
  31. *sum += *input_0++ * filter_7;
  32. *sum += *input_0++ * filter_8;
  33. *sum += *input_0++ * filter_9;
  34. }
  35. static inline void mac_4row_4col_fp_IOHW_forint8w(float* sum, const float* input_0, const float* input_1, const float* input_2, const float* input_3,
  36. const float filter_0, const float filter_1, const float filter_2, const float filter_3) {
  37. *sum += *input_0++ * filter_0;
  38. *sum += *input_0++ * filter_1;
  39. *sum += *input_0++ * filter_2;
  40. *sum++ += *input_0++ * filter_3;
  41. *sum += *input_1++ * filter_0;
  42. *sum += *input_1++ * filter_1;
  43. *sum += *input_1++ * filter_2;
  44. *sum++ += *input_1++ * filter_3;
  45. *sum += *input_2++ * filter_0;
  46. *sum += *input_2++ * filter_1;
  47. *sum += *input_2++ * filter_2;
  48. *sum++ += *input_2++ * filter_3;
  49. *sum += *input_3++ * filter_0;
  50. *sum += *input_3++ * filter_1;
  51. *sum += *input_3++ * filter_2;
  52. *sum++ += *input_3++ * filter_3;
  53. }
  54. static inline void mac_1row_4col_fp_IOHW_forint8w(float* sum, const float* input_0,
  55. const float filter_0, const float filter_1, const float filter_2, const float filter_3) {
  56. *sum += *input_0++ * filter_0;
  57. *sum += *input_0++ * filter_1;
  58. *sum += *input_0++ * filter_2;
  59. *sum += *input_0++ * filter_3;
  60. }
  61. /* END: MAC Functions for Pointwise Conv */
  62. /* START: MAC Functions for Group Conv */
  63. static inline void group_mac_kernel4_4row_fp_uniweight_reuse_output_input(float* sum_0, float* sum_1, float* sum_2, float* sum_3,
  64. const float* input_0, const float* input_1, const float* input_2, const float* input_3,
  65. const float* filter) {
  66. float tmp;
  67. tmp = 0;
  68. tmp += input_0[0] * filter[0];
  69. tmp += input_0[1] * filter[1];
  70. tmp += input_0[2] * filter[2];
  71. tmp += input_0[3] * filter[3];
  72. tmp += input_0[4] * filter[4];
  73. tmp += input_0[5] * filter[5];
  74. tmp += input_0[6] * filter[6];
  75. tmp += input_0[7] * filter[7];
  76. tmp += input_0[8] * filter[8];
  77. tmp += input_0[9] * filter[9];
  78. tmp += input_0[10] * filter[10];
  79. tmp += input_0[11] * filter[11];
  80. tmp += input_0[12] * filter[12];
  81. tmp += input_0[13] * filter[13];
  82. tmp += input_0[14] * filter[14];
  83. tmp += input_0[15] * filter[15];
  84. *sum_0 += tmp;
  85. tmp = 0;
  86. tmp += input_1[0] * filter[0];
  87. tmp += input_1[1] * filter[1];
  88. tmp += input_1[2] * filter[2];
  89. tmp += input_1[3] * filter[3];
  90. tmp += input_1[4] * filter[4];
  91. tmp += input_1[5] * filter[5];
  92. tmp += input_1[6] * filter[6];
  93. tmp += input_1[7] * filter[7];
  94. tmp += input_1[8] * filter[8];
  95. tmp += input_1[9] * filter[9];
  96. tmp += input_1[10] * filter[10];
  97. tmp += input_1[11] * filter[11];
  98. tmp += input_1[12] * filter[12];
  99. tmp += input_1[13] * filter[13];
  100. tmp += input_1[14] * filter[14];
  101. tmp += input_1[15] * filter[15];
  102. *sum_1 += tmp;
  103. tmp = 0;
  104. tmp += input_2[0] * filter[0];
  105. tmp += input_2[1] * filter[1];
  106. tmp += input_2[2] * filter[2];
  107. tmp += input_2[3] * filter[3];
  108. tmp += input_2[4] * filter[4];
  109. tmp += input_2[5] * filter[5];
  110. tmp += input_2[6] * filter[6];
  111. tmp += input_2[7] * filter[7];
  112. tmp += input_2[8] * filter[8];
  113. tmp += input_2[9] * filter[9];
  114. tmp += input_2[10] * filter[10];
  115. tmp += input_2[11] * filter[11];
  116. tmp += input_2[12] * filter[12];
  117. tmp += input_2[13] * filter[13];
  118. tmp += input_2[14] * filter[14];
  119. tmp += input_2[15] * filter[15];
  120. *sum_2 += tmp;
  121. tmp = 0;
  122. tmp += input_3[0] * filter[0];
  123. tmp += input_3[1] * filter[1];
  124. tmp += input_3[2] * filter[2];
  125. tmp += input_3[3] * filter[3];
  126. tmp += input_3[4] * filter[4];
  127. tmp += input_3[5] * filter[5];
  128. tmp += input_3[6] * filter[6];
  129. tmp += input_3[7] * filter[7];
  130. tmp += input_3[8] * filter[8];
  131. tmp += input_3[9] * filter[9];
  132. tmp += input_3[10] * filter[10];
  133. tmp += input_3[11] * filter[11];
  134. tmp += input_3[12] * filter[12];
  135. tmp += input_3[13] * filter[13];
  136. tmp += input_3[14] * filter[14];
  137. tmp += input_3[15] * filter[15];
  138. *sum_3 += tmp;
  139. }
  140. static inline void group_mac_kernel8_4row_fp_uniweight_reuse_output_input(float* sum_0, float* sum_1, float* sum_2, float* sum_3,
  141. const float* input_0, const float* input_1, const float* input_2, const float* input_3,
  142. const float* filter) {
  143. float tmp;
  144. tmp = 0;
  145. tmp += input_0[0] * filter[0];
  146. tmp += input_0[1] * filter[1];
  147. tmp += input_0[2] * filter[2];
  148. tmp += input_0[3] * filter[3];
  149. tmp += input_0[4] * filter[4];
  150. tmp += input_0[5] * filter[5];
  151. tmp += input_0[6] * filter[6];
  152. tmp += input_0[7] * filter[7];
  153. tmp += input_0[8] * filter[8];
  154. tmp += input_0[9] * filter[9];
  155. tmp += input_0[10] * filter[10];
  156. tmp += input_0[11] * filter[11];
  157. tmp += input_0[12] * filter[12];
  158. tmp += input_0[13] * filter[13];
  159. tmp += input_0[14] * filter[14];
  160. tmp += input_0[15] * filter[15];
  161. tmp += input_0[16] * filter[16];
  162. tmp += input_0[17] * filter[17];
  163. tmp += input_0[18] * filter[18];
  164. tmp += input_0[19] * filter[19];
  165. tmp += input_0[20] * filter[20];
  166. tmp += input_0[21] * filter[21];
  167. tmp += input_0[22] * filter[22];
  168. tmp += input_0[23] * filter[23];
  169. tmp += input_0[24] * filter[24];
  170. tmp += input_0[25] * filter[25];
  171. tmp += input_0[26] * filter[26];
  172. tmp += input_0[27] * filter[27];
  173. tmp += input_0[28] * filter[28];
  174. tmp += input_0[29] * filter[29];
  175. tmp += input_0[30] * filter[30];
  176. tmp += input_0[31] * filter[31];
  177. tmp += input_0[32] * filter[32];
  178. tmp += input_0[33] * filter[33];
  179. tmp += input_0[34] * filter[34];
  180. tmp += input_0[35] * filter[35];
  181. tmp += input_0[36] * filter[36];
  182. tmp += input_0[37] * filter[37];
  183. tmp += input_0[38] * filter[38];
  184. tmp += input_0[39] * filter[39];
  185. tmp += input_0[40] * filter[40];
  186. tmp += input_0[41] * filter[41];
  187. tmp += input_0[42] * filter[42];
  188. tmp += input_0[43] * filter[43];
  189. tmp += input_0[44] * filter[44];
  190. tmp += input_0[45] * filter[45];
  191. tmp += input_0[46] * filter[46];
  192. tmp += input_0[47] * filter[47];
  193. tmp += input_0[48] * filter[48];
  194. tmp += input_0[49] * filter[49];
  195. tmp += input_0[50] * filter[50];
  196. tmp += input_0[51] * filter[51];
  197. tmp += input_0[52] * filter[52];
  198. tmp += input_0[53] * filter[53];
  199. tmp += input_0[54] * filter[54];
  200. tmp += input_0[55] * filter[55];
  201. tmp += input_0[56] * filter[56];
  202. tmp += input_0[57] * filter[57];
  203. tmp += input_0[58] * filter[58];
  204. tmp += input_0[59] * filter[59];
  205. tmp += input_0[60] * filter[60];
  206. tmp += input_0[61] * filter[61];
  207. tmp += input_0[62] * filter[62];
  208. tmp += input_0[63] * filter[63];
  209. *sum_0 += tmp;
  210. tmp = 0;
  211. tmp += input_1[0] * filter[0];
  212. tmp += input_1[1] * filter[1];
  213. tmp += input_1[2] * filter[2];
  214. tmp += input_1[3] * filter[3];
  215. tmp += input_1[4] * filter[4];
  216. tmp += input_1[5] * filter[5];
  217. tmp += input_1[6] * filter[6];
  218. tmp += input_1[7] * filter[7];
  219. tmp += input_1[8] * filter[8];
  220. tmp += input_1[9] * filter[9];
  221. tmp += input_1[10] * filter[10];
  222. tmp += input_1[11] * filter[11];
  223. tmp += input_1[12] * filter[12];
  224. tmp += input_1[13] * filter[13];
  225. tmp += input_1[14] * filter[14];
  226. tmp += input_1[15] * filter[15];
  227. tmp += input_1[16] * filter[16];
  228. tmp += input_1[17] * filter[17];
  229. tmp += input_1[18] * filter[18];
  230. tmp += input_1[19] * filter[19];
  231. tmp += input_1[20] * filter[20];
  232. tmp += input_1[21] * filter[21];
  233. tmp += input_1[22] * filter[22];
  234. tmp += input_1[23] * filter[23];
  235. tmp += input_1[24] * filter[24];
  236. tmp += input_1[25] * filter[25];
  237. tmp += input_1[26] * filter[26];
  238. tmp += input_1[27] * filter[27];
  239. tmp += input_1[28] * filter[28];
  240. tmp += input_1[29] * filter[29];
  241. tmp += input_1[30] * filter[30];
  242. tmp += input_1[31] * filter[31];
  243. tmp += input_1[32] * filter[32];
  244. tmp += input_1[33] * filter[33];
  245. tmp += input_1[34] * filter[34];
  246. tmp += input_1[35] * filter[35];
  247. tmp += input_1[36] * filter[36];
  248. tmp += input_1[37] * filter[37];
  249. tmp += input_1[38] * filter[38];
  250. tmp += input_1[39] * filter[39];
  251. tmp += input_1[40] * filter[40];
  252. tmp += input_1[41] * filter[41];
  253. tmp += input_1[42] * filter[42];
  254. tmp += input_1[43] * filter[43];
  255. tmp += input_1[44] * filter[44];
  256. tmp += input_1[45] * filter[45];
  257. tmp += input_1[46] * filter[46];
  258. tmp += input_1[47] * filter[47];
  259. tmp += input_1[48] * filter[48];
  260. tmp += input_1[49] * filter[49];
  261. tmp += input_1[50] * filter[50];
  262. tmp += input_1[51] * filter[51];
  263. tmp += input_1[52] * filter[52];
  264. tmp += input_1[53] * filter[53];
  265. tmp += input_1[54] * filter[54];
  266. tmp += input_1[55] * filter[55];
  267. tmp += input_1[56] * filter[56];
  268. tmp += input_1[57] * filter[57];
  269. tmp += input_1[58] * filter[58];
  270. tmp += input_1[59] * filter[59];
  271. tmp += input_1[60] * filter[60];
  272. tmp += input_1[61] * filter[61];
  273. tmp += input_1[62] * filter[62];
  274. tmp += input_1[63] * filter[63];
  275. *sum_1 += tmp;
  276. tmp = 0;
  277. tmp += input_2[0] * filter[0];
  278. tmp += input_2[1] * filter[1];
  279. tmp += input_2[2] * filter[2];
  280. tmp += input_2[3] * filter[3];
  281. tmp += input_2[4] * filter[4];
  282. tmp += input_2[5] * filter[5];
  283. tmp += input_2[6] * filter[6];
  284. tmp += input_2[7] * filter[7];
  285. tmp += input_2[8] * filter[8];
  286. tmp += input_2[9] * filter[9];
  287. tmp += input_2[10] * filter[10];
  288. tmp += input_2[11] * filter[11];
  289. tmp += input_2[12] * filter[12];
  290. tmp += input_2[13] * filter[13];
  291. tmp += input_2[14] * filter[14];
  292. tmp += input_2[15] * filter[15];
  293. tmp += input_2[16] * filter[16];
  294. tmp += input_2[17] * filter[17];
  295. tmp += input_2[18] * filter[18];
  296. tmp += input_2[19] * filter[19];
  297. tmp += input_2[20] * filter[20];
  298. tmp += input_2[21] * filter[21];
  299. tmp += input_2[22] * filter[22];
  300. tmp += input_2[23] * filter[23];
  301. tmp += input_2[24] * filter[24];
  302. tmp += input_2[25] * filter[25];
  303. tmp += input_2[26] * filter[26];
  304. tmp += input_2[27] * filter[27];
  305. tmp += input_2[28] * filter[28];
  306. tmp += input_2[29] * filter[29];
  307. tmp += input_2[30] * filter[30];
  308. tmp += input_2[31] * filter[31];
  309. tmp += input_2[32] * filter[32];
  310. tmp += input_2[33] * filter[33];
  311. tmp += input_2[34] * filter[34];
  312. tmp += input_2[35] * filter[35];
  313. tmp += input_2[36] * filter[36];
  314. tmp += input_2[37] * filter[37];
  315. tmp += input_2[38] * filter[38];
  316. tmp += input_2[39] * filter[39];
  317. tmp += input_2[40] * filter[40];
  318. tmp += input_2[41] * filter[41];
  319. tmp += input_2[42] * filter[42];
  320. tmp += input_2[43] * filter[43];
  321. tmp += input_2[44] * filter[44];
  322. tmp += input_2[45] * filter[45];
  323. tmp += input_2[46] * filter[46];
  324. tmp += input_2[47] * filter[47];
  325. tmp += input_2[48] * filter[48];
  326. tmp += input_2[49] * filter[49];
  327. tmp += input_2[50] * filter[50];
  328. tmp += input_2[51] * filter[51];
  329. tmp += input_2[52] * filter[52];
  330. tmp += input_2[53] * filter[53];
  331. tmp += input_2[54] * filter[54];
  332. tmp += input_2[55] * filter[55];
  333. tmp += input_2[56] * filter[56];
  334. tmp += input_2[57] * filter[57];
  335. tmp += input_2[58] * filter[58];
  336. tmp += input_2[59] * filter[59];
  337. tmp += input_2[60] * filter[60];
  338. tmp += input_2[61] * filter[61];
  339. tmp += input_2[62] * filter[62];
  340. tmp += input_2[63] * filter[63];
  341. *sum_2 += tmp;
  342. tmp = 0;
  343. tmp += input_3[0] * filter[0];
  344. tmp += input_3[1] * filter[1];
  345. tmp += input_3[2] * filter[2];
  346. tmp += input_3[3] * filter[3];
  347. tmp += input_3[4] * filter[4];
  348. tmp += input_3[5] * filter[5];
  349. tmp += input_3[6] * filter[6];
  350. tmp += input_3[7] * filter[7];
  351. tmp += input_3[8] * filter[8];
  352. tmp += input_3[9] * filter[9];
  353. tmp += input_3[10] * filter[10];
  354. tmp += input_3[11] * filter[11];
  355. tmp += input_3[12] * filter[12];
  356. tmp += input_3[13] * filter[13];
  357. tmp += input_3[14] * filter[14];
  358. tmp += input_3[15] * filter[15];
  359. tmp += input_3[16] * filter[16];
  360. tmp += input_3[17] * filter[17];
  361. tmp += input_3[18] * filter[18];
  362. tmp += input_3[19] * filter[19];
  363. tmp += input_3[20] * filter[20];
  364. tmp += input_3[21] * filter[21];
  365. tmp += input_3[22] * filter[22];
  366. tmp += input_3[23] * filter[23];
  367. tmp += input_3[24] * filter[24];
  368. tmp += input_3[25] * filter[25];
  369. tmp += input_3[26] * filter[26];
  370. tmp += input_3[27] * filter[27];
  371. tmp += input_3[28] * filter[28];
  372. tmp += input_3[29] * filter[29];
  373. tmp += input_3[30] * filter[30];
  374. tmp += input_3[31] * filter[31];
  375. tmp += input_3[32] * filter[32];
  376. tmp += input_3[33] * filter[33];
  377. tmp += input_3[34] * filter[34];
  378. tmp += input_3[35] * filter[35];
  379. tmp += input_3[36] * filter[36];
  380. tmp += input_3[37] * filter[37];
  381. tmp += input_3[38] * filter[38];
  382. tmp += input_3[39] * filter[39];
  383. tmp += input_3[40] * filter[40];
  384. tmp += input_3[41] * filter[41];
  385. tmp += input_3[42] * filter[42];
  386. tmp += input_3[43] * filter[43];
  387. tmp += input_3[44] * filter[44];
  388. tmp += input_3[45] * filter[45];
  389. tmp += input_3[46] * filter[46];
  390. tmp += input_3[47] * filter[47];
  391. tmp += input_3[48] * filter[48];
  392. tmp += input_3[49] * filter[49];
  393. tmp += input_3[50] * filter[50];
  394. tmp += input_3[51] * filter[51];
  395. tmp += input_3[52] * filter[52];
  396. tmp += input_3[53] * filter[53];
  397. tmp += input_3[54] * filter[54];
  398. tmp += input_3[55] * filter[55];
  399. tmp += input_3[56] * filter[56];
  400. tmp += input_3[57] * filter[57];
  401. tmp += input_3[58] * filter[58];
  402. tmp += input_3[59] * filter[59];
  403. tmp += input_3[60] * filter[60];
  404. tmp += input_3[61] * filter[61];
  405. tmp += input_3[62] * filter[62];
  406. tmp += input_3[63] * filter[63];
  407. *sum_3 += tmp;
  408. }
  409. /* END: MAC Functions for Group Conv */
  410. /* START: MAC Functions for Transpose Depthwise Conv */
  411. /* START: For 3x3 kernel size*/
  412. static inline void transpose_depthwise_mac_kernel3_2row_fp_uniweight(float* sum_0, float* sum_1,
  413. const float* two_column_buffer, const float* ksrc_transposed, const uint16_t input_width,
  414. const uint16_t STRIDE, const uint16_t IN_PAD, const uint16_t OUT_PAD) {
  415. *sum_0 += two_column_buffer[0] * ksrc_transposed[0];
  416. *sum_1 += two_column_buffer[1] * ksrc_transposed[0];
  417. *sum_0 += two_column_buffer[1] * ksrc_transposed[1];
  418. *sum_1 += two_column_buffer[2] * ksrc_transposed[1];
  419. *sum_0 += two_column_buffer[2] * ksrc_transposed[2];
  420. *sum_1 += two_column_buffer[3] * ksrc_transposed[2];
  421. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  422. *sum_0 += two_column_buffer[0] * ksrc_transposed[3];
  423. *sum_1 += two_column_buffer[1] * ksrc_transposed[3];
  424. *sum_0 += two_column_buffer[1] * ksrc_transposed[4];
  425. *sum_1 += two_column_buffer[2] * ksrc_transposed[4];
  426. *sum_0 += two_column_buffer[2] * ksrc_transposed[5];
  427. *sum_1 += two_column_buffer[3] * ksrc_transposed[5];
  428. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  429. *sum_0 += two_column_buffer[0] * ksrc_transposed[6];
  430. *sum_1 += two_column_buffer[1] * ksrc_transposed[6];
  431. *sum_0 += two_column_buffer[1] * ksrc_transposed[7];
  432. *sum_1 += two_column_buffer[2] * ksrc_transposed[7];
  433. *sum_0 += two_column_buffer[2] * ksrc_transposed[8];
  434. *sum_1 += two_column_buffer[3] * ksrc_transposed[8];
  435. }
  436. static inline void transpose_depthwise_mac_kernel3_1row_fp_uniweight(float* sum_0,
  437. const float* two_column_buffer, const float* ksrc_transposed, const uint16_t input_width,
  438. const uint16_t STRIDE, const uint16_t IN_PAD, const uint16_t OUT_PAD) {
  439. *sum_0 += two_column_buffer[0] * ksrc_transposed[0];
  440. *sum_0 += two_column_buffer[1] * ksrc_transposed[1];
  441. *sum_0 += two_column_buffer[2] * ksrc_transposed[2];
  442. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  443. *sum_0 += two_column_buffer[0] * ksrc_transposed[3];
  444. *sum_0 += two_column_buffer[1] * ksrc_transposed[4];
  445. *sum_0 += two_column_buffer[2] * ksrc_transposed[5];
  446. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  447. *sum_0 += two_column_buffer[0] * ksrc_transposed[6];
  448. *sum_0 += two_column_buffer[1] * ksrc_transposed[7];
  449. *sum_0 += two_column_buffer[2] * ksrc_transposed[8];
  450. }
  451. /* END: For 3x3 kernel size*/
  452. /* START: For 5x5 kernel size*/
  453. static inline void transpose_depthwise_mac_kernel5_2row_fp_uniweight(float* sum_0, float* sum_1,
  454. const float* two_column_buffer, const float* ksrc_transposed, const uint16_t input_width,
  455. const uint16_t STRIDE, const uint16_t IN_PAD, const uint16_t OUT_PAD) {
  456. *sum_0 += two_column_buffer[0] * ksrc_transposed[0];
  457. *sum_1 += two_column_buffer[1] * ksrc_transposed[0];
  458. *sum_0 += two_column_buffer[1] * ksrc_transposed[1];
  459. *sum_1 += two_column_buffer[2] * ksrc_transposed[1];
  460. *sum_0 += two_column_buffer[2] * ksrc_transposed[2];
  461. *sum_1 += two_column_buffer[3] * ksrc_transposed[2];
  462. *sum_0 += two_column_buffer[3] * ksrc_transposed[3];
  463. *sum_1 += two_column_buffer[4] * ksrc_transposed[3];
  464. *sum_0 += two_column_buffer[4] * ksrc_transposed[4];
  465. *sum_1 += two_column_buffer[5] * ksrc_transposed[4];
  466. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  467. *sum_0 += two_column_buffer[0] * ksrc_transposed[5];
  468. *sum_1 += two_column_buffer[1] * ksrc_transposed[5];
  469. *sum_0 += two_column_buffer[1] * ksrc_transposed[6];
  470. *sum_1 += two_column_buffer[2] * ksrc_transposed[6];
  471. *sum_0 += two_column_buffer[2] * ksrc_transposed[7];
  472. *sum_1 += two_column_buffer[3] * ksrc_transposed[7];
  473. *sum_0 += two_column_buffer[3] * ksrc_transposed[8];
  474. *sum_1 += two_column_buffer[4] * ksrc_transposed[8];
  475. *sum_0 += two_column_buffer[4] * ksrc_transposed[9];
  476. *sum_1 += two_column_buffer[5] * ksrc_transposed[9];
  477. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  478. *sum_0 += two_column_buffer[0] * ksrc_transposed[10];
  479. *sum_1 += two_column_buffer[1] * ksrc_transposed[10];
  480. *sum_0 += two_column_buffer[1] * ksrc_transposed[11];
  481. *sum_1 += two_column_buffer[2] * ksrc_transposed[11];
  482. *sum_0 += two_column_buffer[2] * ksrc_transposed[12];
  483. *sum_1 += two_column_buffer[3] * ksrc_transposed[12];
  484. *sum_0 += two_column_buffer[3] * ksrc_transposed[13];
  485. *sum_1 += two_column_buffer[4] * ksrc_transposed[13];
  486. *sum_0 += two_column_buffer[4] * ksrc_transposed[14];
  487. *sum_1 += two_column_buffer[5] * ksrc_transposed[14];
  488. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  489. *sum_0 += two_column_buffer[0] * ksrc_transposed[15];
  490. *sum_1 += two_column_buffer[1] * ksrc_transposed[15];
  491. *sum_0 += two_column_buffer[1] * ksrc_transposed[16];
  492. *sum_1 += two_column_buffer[2] * ksrc_transposed[16];
  493. *sum_0 += two_column_buffer[2] * ksrc_transposed[17];
  494. *sum_1 += two_column_buffer[3] * ksrc_transposed[17];
  495. *sum_0 += two_column_buffer[3] * ksrc_transposed[18];
  496. *sum_1 += two_column_buffer[4] * ksrc_transposed[18];
  497. *sum_0 += two_column_buffer[4] * ksrc_transposed[19];
  498. *sum_1 += two_column_buffer[5] * ksrc_transposed[19];
  499. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  500. *sum_0 += two_column_buffer[0] * ksrc_transposed[20];
  501. *sum_1 += two_column_buffer[1] * ksrc_transposed[20];
  502. *sum_0 += two_column_buffer[1] * ksrc_transposed[21];
  503. *sum_1 += two_column_buffer[2] * ksrc_transposed[21];
  504. *sum_0 += two_column_buffer[2] * ksrc_transposed[22];
  505. *sum_1 += two_column_buffer[3] * ksrc_transposed[22];
  506. *sum_0 += two_column_buffer[3] * ksrc_transposed[23];
  507. *sum_1 += two_column_buffer[4] * ksrc_transposed[23];
  508. *sum_0 += two_column_buffer[4] * ksrc_transposed[24];
  509. *sum_1 += two_column_buffer[5] * ksrc_transposed[24];
  510. }
  511. static inline void transpose_depthwise_mac_kernel5_1row_fp_uniweight(float* sum_0,
  512. const float* two_column_buffer, const float* ksrc_transposed, const uint16_t input_width,
  513. const uint16_t STRIDE, const uint16_t IN_PAD, const uint16_t OUT_PAD) {
  514. *sum_0 += two_column_buffer[0] * ksrc_transposed[0];
  515. *sum_0 += two_column_buffer[1] * ksrc_transposed[1];
  516. *sum_0 += two_column_buffer[2] * ksrc_transposed[2];
  517. *sum_0 += two_column_buffer[3] * ksrc_transposed[3];
  518. *sum_0 += two_column_buffer[4] * ksrc_transposed[4];
  519. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  520. *sum_0 += two_column_buffer[0] * ksrc_transposed[5];
  521. *sum_0 += two_column_buffer[1] * ksrc_transposed[6];
  522. *sum_0 += two_column_buffer[2] * ksrc_transposed[7];
  523. *sum_0 += two_column_buffer[3] * ksrc_transposed[8];
  524. *sum_0 += two_column_buffer[4] * ksrc_transposed[9];
  525. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  526. *sum_0 += two_column_buffer[0] * ksrc_transposed[10];
  527. *sum_0 += two_column_buffer[1] * ksrc_transposed[11];
  528. *sum_0 += two_column_buffer[2] * ksrc_transposed[12];
  529. *sum_0 += two_column_buffer[3] * ksrc_transposed[13];
  530. *sum_0 += two_column_buffer[4] * ksrc_transposed[14];
  531. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  532. *sum_0 += two_column_buffer[0] * ksrc_transposed[15];
  533. *sum_0 += two_column_buffer[1] * ksrc_transposed[16];
  534. *sum_0 += two_column_buffer[2] * ksrc_transposed[17];
  535. *sum_0 += two_column_buffer[3] * ksrc_transposed[18];
  536. *sum_0 += two_column_buffer[4] * ksrc_transposed[19];
  537. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  538. *sum_0 += two_column_buffer[0] * ksrc_transposed[20];
  539. *sum_0 += two_column_buffer[1] * ksrc_transposed[21];
  540. *sum_0 += two_column_buffer[2] * ksrc_transposed[22];
  541. *sum_0 += two_column_buffer[3] * ksrc_transposed[23];
  542. *sum_0 += two_column_buffer[4] * ksrc_transposed[24];
  543. }
  544. /* END: For 5x5 kernel size*/
  545. /* START: For 7x7 kernel size*/
  546. static inline void transpose_depthwise_mac_kernel7_2row_fp_uniweight(float* sum_0, float* sum_1,
  547. const float* two_column_buffer, const float* ksrc_transposed, const uint16_t input_width,
  548. const uint16_t STRIDE, const uint16_t IN_PAD, const uint16_t OUT_PAD) {
  549. *sum_0 += two_column_buffer[0] * ksrc_transposed[0];
  550. *sum_1 += two_column_buffer[1] * ksrc_transposed[0];
  551. *sum_0 += two_column_buffer[1] * ksrc_transposed[1];
  552. *sum_1 += two_column_buffer[2] * ksrc_transposed[1];
  553. *sum_0 += two_column_buffer[2] * ksrc_transposed[2];
  554. *sum_1 += two_column_buffer[3] * ksrc_transposed[2];
  555. *sum_0 += two_column_buffer[3] * ksrc_transposed[3];
  556. *sum_1 += two_column_buffer[4] * ksrc_transposed[3];
  557. *sum_0 += two_column_buffer[4] * ksrc_transposed[4];
  558. *sum_1 += two_column_buffer[5] * ksrc_transposed[4];
  559. *sum_0 += two_column_buffer[5] * ksrc_transposed[5];
  560. *sum_1 += two_column_buffer[6] * ksrc_transposed[5];
  561. *sum_0 += two_column_buffer[6] * ksrc_transposed[6];
  562. *sum_1 += two_column_buffer[7] * ksrc_transposed[6];
  563. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  564. *sum_0 += two_column_buffer[0] * ksrc_transposed[7];
  565. *sum_1 += two_column_buffer[1] * ksrc_transposed[7];
  566. *sum_0 += two_column_buffer[1] * ksrc_transposed[8];
  567. *sum_1 += two_column_buffer[2] * ksrc_transposed[8];
  568. *sum_0 += two_column_buffer[2] * ksrc_transposed[9];
  569. *sum_1 += two_column_buffer[3] * ksrc_transposed[9];
  570. *sum_0 += two_column_buffer[3] * ksrc_transposed[10];
  571. *sum_1 += two_column_buffer[4] * ksrc_transposed[10];
  572. *sum_0 += two_column_buffer[4] * ksrc_transposed[11];
  573. *sum_1 += two_column_buffer[5] * ksrc_transposed[11];
  574. *sum_0 += two_column_buffer[5] * ksrc_transposed[12];
  575. *sum_1 += two_column_buffer[6] * ksrc_transposed[12];
  576. *sum_0 += two_column_buffer[6] * ksrc_transposed[13];
  577. *sum_1 += two_column_buffer[7] * ksrc_transposed[13];
  578. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  579. *sum_0 += two_column_buffer[0] * ksrc_transposed[14];
  580. *sum_1 += two_column_buffer[1] * ksrc_transposed[14];
  581. *sum_0 += two_column_buffer[1] * ksrc_transposed[15];
  582. *sum_1 += two_column_buffer[2] * ksrc_transposed[15];
  583. *sum_0 += two_column_buffer[2] * ksrc_transposed[16];
  584. *sum_1 += two_column_buffer[3] * ksrc_transposed[16];
  585. *sum_0 += two_column_buffer[3] * ksrc_transposed[17];
  586. *sum_1 += two_column_buffer[4] * ksrc_transposed[17];
  587. *sum_0 += two_column_buffer[4] * ksrc_transposed[18];
  588. *sum_1 += two_column_buffer[5] * ksrc_transposed[18];
  589. *sum_0 += two_column_buffer[5] * ksrc_transposed[19];
  590. *sum_1 += two_column_buffer[6] * ksrc_transposed[19];
  591. *sum_0 += two_column_buffer[6] * ksrc_transposed[20];
  592. *sum_1 += two_column_buffer[7] * ksrc_transposed[20];
  593. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  594. *sum_0 += two_column_buffer[0] * ksrc_transposed[21];
  595. *sum_1 += two_column_buffer[1] * ksrc_transposed[21];
  596. *sum_0 += two_column_buffer[1] * ksrc_transposed[22];
  597. *sum_1 += two_column_buffer[2] * ksrc_transposed[22];
  598. *sum_0 += two_column_buffer[2] * ksrc_transposed[23];
  599. *sum_1 += two_column_buffer[3] * ksrc_transposed[23];
  600. *sum_0 += two_column_buffer[3] * ksrc_transposed[24];
  601. *sum_1 += two_column_buffer[4] * ksrc_transposed[24];
  602. *sum_0 += two_column_buffer[4] * ksrc_transposed[25];
  603. *sum_1 += two_column_buffer[5] * ksrc_transposed[25];
  604. *sum_0 += two_column_buffer[5] * ksrc_transposed[26];
  605. *sum_1 += two_column_buffer[6] * ksrc_transposed[26];
  606. *sum_0 += two_column_buffer[6] * ksrc_transposed[27];
  607. *sum_1 += two_column_buffer[7] * ksrc_transposed[27];
  608. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  609. *sum_0 += two_column_buffer[0] * ksrc_transposed[28];
  610. *sum_1 += two_column_buffer[1] * ksrc_transposed[28];
  611. *sum_0 += two_column_buffer[1] * ksrc_transposed[29];
  612. *sum_1 += two_column_buffer[2] * ksrc_transposed[29];
  613. *sum_0 += two_column_buffer[2] * ksrc_transposed[30];
  614. *sum_1 += two_column_buffer[3] * ksrc_transposed[30];
  615. *sum_0 += two_column_buffer[3] * ksrc_transposed[31];
  616. *sum_1 += two_column_buffer[4] * ksrc_transposed[31];
  617. *sum_0 += two_column_buffer[4] * ksrc_transposed[32];
  618. *sum_1 += two_column_buffer[5] * ksrc_transposed[32];
  619. *sum_0 += two_column_buffer[5] * ksrc_transposed[33];
  620. *sum_1 += two_column_buffer[6] * ksrc_transposed[33];
  621. *sum_0 += two_column_buffer[6] * ksrc_transposed[34];
  622. *sum_1 += two_column_buffer[7] * ksrc_transposed[34];
  623. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  624. *sum_0 += two_column_buffer[0] * ksrc_transposed[35];
  625. *sum_1 += two_column_buffer[1] * ksrc_transposed[35];
  626. *sum_0 += two_column_buffer[1] * ksrc_transposed[36];
  627. *sum_1 += two_column_buffer[2] * ksrc_transposed[36];
  628. *sum_0 += two_column_buffer[2] * ksrc_transposed[37];
  629. *sum_1 += two_column_buffer[3] * ksrc_transposed[37];
  630. *sum_0 += two_column_buffer[3] * ksrc_transposed[38];
  631. *sum_1 += two_column_buffer[4] * ksrc_transposed[38];
  632. *sum_0 += two_column_buffer[4] * ksrc_transposed[39];
  633. *sum_1 += two_column_buffer[5] * ksrc_transposed[39];
  634. *sum_0 += two_column_buffer[5] * ksrc_transposed[40];
  635. *sum_1 += two_column_buffer[6] * ksrc_transposed[40];
  636. *sum_0 += two_column_buffer[6] * ksrc_transposed[41];
  637. *sum_1 += two_column_buffer[7] * ksrc_transposed[41];
  638. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  639. *sum_0 += two_column_buffer[0] * ksrc_transposed[42];
  640. *sum_1 += two_column_buffer[1] * ksrc_transposed[42];
  641. *sum_0 += two_column_buffer[1] * ksrc_transposed[43];
  642. *sum_1 += two_column_buffer[2] * ksrc_transposed[43];
  643. *sum_0 += two_column_buffer[2] * ksrc_transposed[44];
  644. *sum_1 += two_column_buffer[3] * ksrc_transposed[44];
  645. *sum_0 += two_column_buffer[3] * ksrc_transposed[45];
  646. *sum_1 += two_column_buffer[4] * ksrc_transposed[45];
  647. *sum_0 += two_column_buffer[4] * ksrc_transposed[46];
  648. *sum_1 += two_column_buffer[5] * ksrc_transposed[46];
  649. *sum_0 += two_column_buffer[5] * ksrc_transposed[47];
  650. *sum_1 += two_column_buffer[6] * ksrc_transposed[47];
  651. *sum_0 += two_column_buffer[6] * ksrc_transposed[48];
  652. *sum_1 += two_column_buffer[7] * ksrc_transposed[48];
  653. }
  654. static inline void transpose_depthwise_mac_kernel7_1row_fp_uniweight(float* sum_0,
  655. const float* two_column_buffer, const float* ksrc_transposed, const uint16_t input_width,
  656. const uint16_t STRIDE, const uint16_t IN_PAD, const uint16_t OUT_PAD) {
  657. *sum_0 += two_column_buffer[0] * ksrc_transposed[0];
  658. *sum_0 += two_column_buffer[1] * ksrc_transposed[1];
  659. *sum_0 += two_column_buffer[2] * ksrc_transposed[2];
  660. *sum_0 += two_column_buffer[3] * ksrc_transposed[3];
  661. *sum_0 += two_column_buffer[4] * ksrc_transposed[4];
  662. *sum_0 += two_column_buffer[5] * ksrc_transposed[5];
  663. *sum_0 += two_column_buffer[6] * ksrc_transposed[6];
  664. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  665. *sum_0 += two_column_buffer[0] * ksrc_transposed[7];
  666. *sum_0 += two_column_buffer[1] * ksrc_transposed[8];
  667. *sum_0 += two_column_buffer[2] * ksrc_transposed[9];
  668. *sum_0 += two_column_buffer[3] * ksrc_transposed[10];
  669. *sum_0 += two_column_buffer[4] * ksrc_transposed[11];
  670. *sum_0 += two_column_buffer[5] * ksrc_transposed[12];
  671. *sum_0 += two_column_buffer[6] * ksrc_transposed[13];
  672. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  673. *sum_0 += two_column_buffer[0] * ksrc_transposed[14];
  674. *sum_0 += two_column_buffer[1] * ksrc_transposed[15];
  675. *sum_0 += two_column_buffer[2] * ksrc_transposed[16];
  676. *sum_0 += two_column_buffer[3] * ksrc_transposed[17];
  677. *sum_0 += two_column_buffer[4] * ksrc_transposed[18];
  678. *sum_0 += two_column_buffer[5] * ksrc_transposed[19];
  679. *sum_0 += two_column_buffer[6] * ksrc_transposed[20];
  680. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  681. *sum_0 += two_column_buffer[0] * ksrc_transposed[21];
  682. *sum_0 += two_column_buffer[1] * ksrc_transposed[22];
  683. *sum_0 += two_column_buffer[2] * ksrc_transposed[23];
  684. *sum_0 += two_column_buffer[3] * ksrc_transposed[24];
  685. *sum_0 += two_column_buffer[4] * ksrc_transposed[25];
  686. *sum_0 += two_column_buffer[5] * ksrc_transposed[26];
  687. *sum_0 += two_column_buffer[6] * ksrc_transposed[27];
  688. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  689. *sum_0 += two_column_buffer[0] * ksrc_transposed[28];
  690. *sum_0 += two_column_buffer[1] * ksrc_transposed[29];
  691. *sum_0 += two_column_buffer[2] * ksrc_transposed[30];
  692. *sum_0 += two_column_buffer[3] * ksrc_transposed[31];
  693. *sum_0 += two_column_buffer[4] * ksrc_transposed[32];
  694. *sum_0 += two_column_buffer[5] * ksrc_transposed[33];
  695. *sum_0 += two_column_buffer[6] * ksrc_transposed[34];
  696. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  697. *sum_0 += two_column_buffer[0] * ksrc_transposed[35];
  698. *sum_0 += two_column_buffer[1] * ksrc_transposed[36];
  699. *sum_0 += two_column_buffer[2] * ksrc_transposed[37];
  700. *sum_0 += two_column_buffer[3] * ksrc_transposed[38];
  701. *sum_0 += two_column_buffer[4] * ksrc_transposed[39];
  702. *sum_0 += two_column_buffer[5] * ksrc_transposed[40];
  703. *sum_0 += two_column_buffer[6] * ksrc_transposed[41];
  704. two_column_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  705. *sum_0 += two_column_buffer[0] * ksrc_transposed[42];
  706. *sum_0 += two_column_buffer[1] * ksrc_transposed[43];
  707. *sum_0 += two_column_buffer[2] * ksrc_transposed[44];
  708. *sum_0 += two_column_buffer[3] * ksrc_transposed[45];
  709. *sum_0 += two_column_buffer[4] * ksrc_transposed[46];
  710. *sum_0 += two_column_buffer[5] * ksrc_transposed[47];
  711. *sum_0 += two_column_buffer[6] * ksrc_transposed[48];
  712. }
/* END: For 7x7 kernel size */
  714. /* END: MAC Functions for Transpose Depthwise Conv */
  715. /* START: Assign Output Functions */
  716. /* START: For Group Conv */
  717. static inline void assign_sum_to_group_output_4row16col(int8_t* out_0, int8_t* out_1, int8_t* out_2, int8_t* out_3, int8_t* out_4, int8_t* out_5, int8_t* out_6, int8_t* out_7,
  718. int8_t* out_8, int8_t* out_9, int8_t* out_10, int8_t* out_11, int8_t* out_12, int8_t* out_13, int8_t* out_14, int8_t* out_15,
  719. const float* sum_0, const float* sum_1, const float* sum_2, const float* sum_3,
  720. const float output_activation_min, const float output_activation_max, const float* scales, const float learning_rate, const int i_output_depth) {
  721. *out_0++ -= TN_MIN(TN_MAX(sum_0[0], output_activation_min), output_activation_max) * scales[i_output_depth] * learning_rate;
  722. *out_1++ -= TN_MIN(TN_MAX(sum_0[1], output_activation_min), output_activation_max) * scales[i_output_depth + 1] * learning_rate;
  723. *out_2++ -= TN_MIN(TN_MAX(sum_0[2], output_activation_min), output_activation_max) * scales[i_output_depth + 2] * learning_rate;
  724. *out_3++ -= TN_MIN(TN_MAX(sum_0[3], output_activation_min), output_activation_max) * scales[i_output_depth + 3] * learning_rate;
  725. *out_4++ -= TN_MIN(TN_MAX(sum_0[4], output_activation_min), output_activation_max) * scales[i_output_depth + 4] * learning_rate;
  726. *out_5++ -= TN_MIN(TN_MAX(sum_0[5], output_activation_min), output_activation_max) * scales[i_output_depth + 5] * learning_rate;
  727. *out_6++ -= TN_MIN(TN_MAX(sum_0[6], output_activation_min), output_activation_max) * scales[i_output_depth + 6] * learning_rate;
  728. *out_7++ -= TN_MIN(TN_MAX(sum_0[7], output_activation_min), output_activation_max) * scales[i_output_depth + 7] * learning_rate;
  729. *out_8++ -= TN_MIN(TN_MAX(sum_0[8], output_activation_min), output_activation_max) * scales[i_output_depth + 8] * learning_rate;
  730. *out_9++ -= TN_MIN(TN_MAX(sum_0[9], output_activation_min), output_activation_max) * scales[i_output_depth + 9] * learning_rate;
  731. *out_10++ -= TN_MIN(TN_MAX(sum_0[10], output_activation_min), output_activation_max) * scales[i_output_depth + 10] * learning_rate;
  732. *out_11++ -= TN_MIN(TN_MAX(sum_0[11], output_activation_min), output_activation_max) * scales[i_output_depth + 11] * learning_rate;
  733. *out_12++ -= TN_MIN(TN_MAX(sum_0[12], output_activation_min), output_activation_max) * scales[i_output_depth + 12] * learning_rate;
  734. *out_13++ -= TN_MIN(TN_MAX(sum_0[13], output_activation_min), output_activation_max) * scales[i_output_depth + 13] * learning_rate;
  735. *out_14++ -= TN_MIN(TN_MAX(sum_0[14], output_activation_min), output_activation_max) * scales[i_output_depth + 14] * learning_rate;
  736. *out_15++ -= TN_MIN(TN_MAX(sum_0[15], output_activation_min), output_activation_max) * scales[i_output_depth + 15] * learning_rate;
  737. *out_0++ -= TN_MIN(TN_MAX(sum_1[0], output_activation_min), output_activation_max) * scales[i_output_depth] * learning_rate;
  738. *out_1++ -= TN_MIN(TN_MAX(sum_1[1], output_activation_min), output_activation_max) * scales[i_output_depth + 1] * learning_rate;
  739. *out_2++ -= TN_MIN(TN_MAX(sum_1[2], output_activation_min), output_activation_max) * scales[i_output_depth + 2] * learning_rate;
  740. *out_3++ -= TN_MIN(TN_MAX(sum_1[3], output_activation_min), output_activation_max) * scales[i_output_depth + 3] * learning_rate;
  741. *out_4++ -= TN_MIN(TN_MAX(sum_1[4], output_activation_min), output_activation_max) * scales[i_output_depth + 4] * learning_rate;
  742. *out_5++ -= TN_MIN(TN_MAX(sum_1[5], output_activation_min), output_activation_max) * scales[i_output_depth + 5] * learning_rate;
  743. *out_6++ -= TN_MIN(TN_MAX(sum_1[6], output_activation_min), output_activation_max) * scales[i_output_depth + 6] * learning_rate;
  744. *out_7++ -= TN_MIN(TN_MAX(sum_1[7], output_activation_min), output_activation_max) * scales[i_output_depth + 7] * learning_rate;
  745. *out_8++ -= TN_MIN(TN_MAX(sum_1[8], output_activation_min), output_activation_max) * scales[i_output_depth + 8] * learning_rate;
  746. *out_9++ -= TN_MIN(TN_MAX(sum_1[9], output_activation_min), output_activation_max) * scales[i_output_depth + 9] * learning_rate;
  747. *out_10++ -= TN_MIN(TN_MAX(sum_1[10], output_activation_min), output_activation_max) * scales[i_output_depth + 10] * learning_rate;
  748. *out_11++ -= TN_MIN(TN_MAX(sum_1[11], output_activation_min), output_activation_max) * scales[i_output_depth + 11] * learning_rate;
  749. *out_12++ -= TN_MIN(TN_MAX(sum_1[12], output_activation_min), output_activation_max) * scales[i_output_depth + 12] * learning_rate;
  750. *out_13++ -= TN_MIN(TN_MAX(sum_1[13], output_activation_min), output_activation_max) * scales[i_output_depth + 13] * learning_rate;
  751. *out_14++ -= TN_MIN(TN_MAX(sum_1[14], output_activation_min), output_activation_max) * scales[i_output_depth + 14] * learning_rate;
  752. *out_15++ -= TN_MIN(TN_MAX(sum_1[15], output_activation_min), output_activation_max) * scales[i_output_depth + 15] * learning_rate;
  753. *out_0++ -= TN_MIN(TN_MAX(sum_2[0], output_activation_min), output_activation_max) * scales[i_output_depth] * learning_rate;
  754. *out_1++ -= TN_MIN(TN_MAX(sum_2[1], output_activation_min), output_activation_max) * scales[i_output_depth + 1] * learning_rate;
  755. *out_2++ -= TN_MIN(TN_MAX(sum_2[2], output_activation_min), output_activation_max) * scales[i_output_depth + 2] * learning_rate;
  756. *out_3++ -= TN_MIN(TN_MAX(sum_2[3], output_activation_min), output_activation_max) * scales[i_output_depth + 3] * learning_rate;
  757. *out_4++ -= TN_MIN(TN_MAX(sum_2[4], output_activation_min), output_activation_max) * scales[i_output_depth + 4] * learning_rate;
  758. *out_5++ -= TN_MIN(TN_MAX(sum_2[5], output_activation_min), output_activation_max) * scales[i_output_depth + 5] * learning_rate;
  759. *out_6++ -= TN_MIN(TN_MAX(sum_2[6], output_activation_min), output_activation_max) * scales[i_output_depth + 6] * learning_rate;
  760. *out_7++ -= TN_MIN(TN_MAX(sum_2[7], output_activation_min), output_activation_max) * scales[i_output_depth + 7] * learning_rate;
  761. *out_8++ -= TN_MIN(TN_MAX(sum_2[8], output_activation_min), output_activation_max) * scales[i_output_depth + 8] * learning_rate;
  762. *out_9++ -= TN_MIN(TN_MAX(sum_2[9], output_activation_min), output_activation_max) * scales[i_output_depth + 9] * learning_rate;
  763. *out_10++ -= TN_MIN(TN_MAX(sum_2[10], output_activation_min), output_activation_max) * scales[i_output_depth + 10] * learning_rate;
  764. *out_11++ -= TN_MIN(TN_MAX(sum_2[11], output_activation_min), output_activation_max) * scales[i_output_depth + 11] * learning_rate;
  765. *out_12++ -= TN_MIN(TN_MAX(sum_2[12], output_activation_min), output_activation_max) * scales[i_output_depth + 12] * learning_rate;
  766. *out_13++ -= TN_MIN(TN_MAX(sum_2[13], output_activation_min), output_activation_max) * scales[i_output_depth + 13] * learning_rate;
  767. *out_14++ -= TN_MIN(TN_MAX(sum_2[14], output_activation_min), output_activation_max) * scales[i_output_depth + 14] * learning_rate;
  768. *out_15++ -= TN_MIN(TN_MAX(sum_2[15], output_activation_min), output_activation_max) * scales[i_output_depth + 15] * learning_rate;
  769. *out_0++ -= TN_MIN(TN_MAX(sum_3[0], output_activation_min), output_activation_max) * scales[i_output_depth] * learning_rate;
  770. *out_1++ -= TN_MIN(TN_MAX(sum_3[1], output_activation_min), output_activation_max) * scales[i_output_depth + 1] * learning_rate;
  771. *out_2++ -= TN_MIN(TN_MAX(sum_3[2], output_activation_min), output_activation_max) * scales[i_output_depth + 2] * learning_rate;
  772. *out_3++ -= TN_MIN(TN_MAX(sum_3[3], output_activation_min), output_activation_max) * scales[i_output_depth + 3] * learning_rate;
  773. *out_4++ -= TN_MIN(TN_MAX(sum_3[4], output_activation_min), output_activation_max) * scales[i_output_depth + 4] * learning_rate;
  774. *out_5++ -= TN_MIN(TN_MAX(sum_3[5], output_activation_min), output_activation_max) * scales[i_output_depth + 5] * learning_rate;
  775. *out_6++ -= TN_MIN(TN_MAX(sum_3[6], output_activation_min), output_activation_max) * scales[i_output_depth + 6] * learning_rate;
  776. *out_7++ -= TN_MIN(TN_MAX(sum_3[7], output_activation_min), output_activation_max) * scales[i_output_depth + 7] * learning_rate;
  777. *out_8++ -= TN_MIN(TN_MAX(sum_3[8], output_activation_min), output_activation_max) * scales[i_output_depth + 8] * learning_rate;
  778. *out_9++ -= TN_MIN(TN_MAX(sum_3[9], output_activation_min), output_activation_max) * scales[i_output_depth + 9] * learning_rate;
  779. *out_10++ -= TN_MIN(TN_MAX(sum_3[10], output_activation_min), output_activation_max) * scales[i_output_depth + 10] * learning_rate;
  780. *out_11++ -= TN_MIN(TN_MAX(sum_3[11], output_activation_min), output_activation_max) * scales[i_output_depth + 11] * learning_rate;
  781. *out_12++ -= TN_MIN(TN_MAX(sum_3[12], output_activation_min), output_activation_max) * scales[i_output_depth + 12] * learning_rate;
  782. *out_13++ -= TN_MIN(TN_MAX(sum_3[13], output_activation_min), output_activation_max) * scales[i_output_depth + 13] * learning_rate;
  783. *out_14++ -= TN_MIN(TN_MAX(sum_3[14], output_activation_min), output_activation_max) * scales[i_output_depth + 14] * learning_rate;
  784. *out_15++ -= TN_MIN(TN_MAX(sum_3[15], output_activation_min), output_activation_max) * scales[i_output_depth + 15] * learning_rate;
  785. }
  786. static inline void assign_sum_to_group_output_4row8col(int8_t* out_0, int8_t* out_1, int8_t* out_2, int8_t* out_3, int8_t* out_4, int8_t* out_5, int8_t* out_6, int8_t* out_7,
  787. const float* sum_0, const float* sum_1, const float* sum_2, const float* sum_3,
  788. const float output_activation_min, const float output_activation_max, const float* scales, const float learning_rate, const int i_output_depth) {
  789. *out_0++ -= TN_MIN(TN_MAX(sum_0[0], output_activation_min), output_activation_max) * scales[i_output_depth] * learning_rate;
  790. *out_1++ -= TN_MIN(TN_MAX(sum_0[1], output_activation_min), output_activation_max) * scales[i_output_depth + 1] * learning_rate;
  791. *out_2++ -= TN_MIN(TN_MAX(sum_0[2], output_activation_min), output_activation_max) * scales[i_output_depth + 2] * learning_rate;
  792. *out_3++ -= TN_MIN(TN_MAX(sum_0[3], output_activation_min), output_activation_max) * scales[i_output_depth + 3] * learning_rate;
  793. *out_4++ -= TN_MIN(TN_MAX(sum_0[4], output_activation_min), output_activation_max) * scales[i_output_depth + 4] * learning_rate;
  794. *out_5++ -= TN_MIN(TN_MAX(sum_0[5], output_activation_min), output_activation_max) * scales[i_output_depth + 5] * learning_rate;
  795. *out_6++ -= TN_MIN(TN_MAX(sum_0[6], output_activation_min), output_activation_max) * scales[i_output_depth + 6] * learning_rate;
  796. *out_7++ -= TN_MIN(TN_MAX(sum_0[7], output_activation_min), output_activation_max) * scales[i_output_depth + 7] * learning_rate;
  797. *out_0++ -= TN_MIN(TN_MAX(sum_1[0], output_activation_min), output_activation_max) * scales[i_output_depth] * learning_rate;
  798. *out_1++ -= TN_MIN(TN_MAX(sum_1[1], output_activation_min), output_activation_max) * scales[i_output_depth + 1] * learning_rate;
  799. *out_2++ -= TN_MIN(TN_MAX(sum_1[2], output_activation_min), output_activation_max) * scales[i_output_depth + 2] * learning_rate;
  800. *out_3++ -= TN_MIN(TN_MAX(sum_1[3], output_activation_min), output_activation_max) * scales[i_output_depth + 3] * learning_rate;
  801. *out_4++ -= TN_MIN(TN_MAX(sum_1[4], output_activation_min), output_activation_max) * scales[i_output_depth + 4] * learning_rate;
  802. *out_5++ -= TN_MIN(TN_MAX(sum_1[5], output_activation_min), output_activation_max) * scales[i_output_depth + 5] * learning_rate;
  803. *out_6++ -= TN_MIN(TN_MAX(sum_1[6], output_activation_min), output_activation_max) * scales[i_output_depth + 6] * learning_rate;
  804. *out_7++ -= TN_MIN(TN_MAX(sum_1[7], output_activation_min), output_activation_max) * scales[i_output_depth + 7] * learning_rate;
  805. *out_0++ -= TN_MIN(TN_MAX(sum_2[0], output_activation_min), output_activation_max) * scales[i_output_depth] * learning_rate;
  806. *out_1++ -= TN_MIN(TN_MAX(sum_2[1], output_activation_min), output_activation_max) * scales[i_output_depth + 1] * learning_rate;
  807. *out_2++ -= TN_MIN(TN_MAX(sum_2[2], output_activation_min), output_activation_max) * scales[i_output_depth + 2] * learning_rate;
  808. *out_3++ -= TN_MIN(TN_MAX(sum_2[3], output_activation_min), output_activation_max) * scales[i_output_depth + 3] * learning_rate;
  809. *out_4++ -= TN_MIN(TN_MAX(sum_2[4], output_activation_min), output_activation_max) * scales[i_output_depth + 4] * learning_rate;
  810. *out_5++ -= TN_MIN(TN_MAX(sum_2[5], output_activation_min), output_activation_max) * scales[i_output_depth + 5] * learning_rate;
  811. *out_6++ -= TN_MIN(TN_MAX(sum_2[6], output_activation_min), output_activation_max) * scales[i_output_depth + 6] * learning_rate;
  812. *out_7++ -= TN_MIN(TN_MAX(sum_2[7], output_activation_min), output_activation_max) * scales[i_output_depth + 7] * learning_rate;
  813. *out_0++ -= TN_MIN(TN_MAX(sum_3[0], output_activation_min), output_activation_max) * scales[i_output_depth] * learning_rate;
  814. *out_1++ -= TN_MIN(TN_MAX(sum_3[1], output_activation_min), output_activation_max) * scales[i_output_depth + 1] * learning_rate;
  815. *out_2++ -= TN_MIN(TN_MAX(sum_3[2], output_activation_min), output_activation_max) * scales[i_output_depth + 2] * learning_rate;
  816. *out_3++ -= TN_MIN(TN_MAX(sum_3[3], output_activation_min), output_activation_max) * scales[i_output_depth + 3] * learning_rate;
  817. *out_4++ -= TN_MIN(TN_MAX(sum_3[4], output_activation_min), output_activation_max) * scales[i_output_depth + 4] * learning_rate;
  818. *out_5++ -= TN_MIN(TN_MAX(sum_3[5], output_activation_min), output_activation_max) * scales[i_output_depth + 5] * learning_rate;
  819. *out_6++ -= TN_MIN(TN_MAX(sum_3[6], output_activation_min), output_activation_max) * scales[i_output_depth + 6] * learning_rate;
  820. *out_7++ -= TN_MIN(TN_MAX(sum_3[7], output_activation_min), output_activation_max) * scales[i_output_depth + 7] * learning_rate;
  821. }
  822. /* END: For Group Conv */
  823. /* START: For Pointwise Conv */
  824. static inline void assign_sum_to_pointwise_output_4row8col(float* out_0, float* out_1, float* out_2, float* out_3,
  825. const float* sum, const float output_activation_min, const float output_activation_max) {
  826. *out_0++ += TN_MIN(TN_MAX(sum[0], output_activation_min), output_activation_max);
  827. *out_1++ += TN_MIN(TN_MAX(sum[1], output_activation_min), output_activation_max);
  828. *out_2++ += TN_MIN(TN_MAX(sum[2], output_activation_min), output_activation_max);
  829. *out_3++ += TN_MIN(TN_MAX(sum[3], output_activation_min), output_activation_max);
  830. *out_0++ += TN_MIN(TN_MAX(sum[4], output_activation_min), output_activation_max);
  831. *out_1++ += TN_MIN(TN_MAX(sum[5], output_activation_min), output_activation_max);
  832. *out_2++ += TN_MIN(TN_MAX(sum[6], output_activation_min), output_activation_max);
  833. *out_3++ += TN_MIN(TN_MAX(sum[7], output_activation_min), output_activation_max);
  834. *out_0++ += TN_MIN(TN_MAX(sum[8], output_activation_min), output_activation_max);
  835. *out_1++ += TN_MIN(TN_MAX(sum[9], output_activation_min), output_activation_max);
  836. *out_2++ += TN_MIN(TN_MAX(sum[10], output_activation_min), output_activation_max);
  837. *out_3++ += TN_MIN(TN_MAX(sum[11], output_activation_min), output_activation_max);
  838. *out_0++ += TN_MIN(TN_MAX(sum[12], output_activation_min), output_activation_max);
  839. *out_1++ += TN_MIN(TN_MAX(sum[13], output_activation_min), output_activation_max);
  840. *out_2++ += TN_MIN(TN_MAX(sum[14], output_activation_min), output_activation_max);
  841. *out_3++ += TN_MIN(TN_MAX(sum[15], output_activation_min), output_activation_max);
  842. *out_0++ += TN_MIN(TN_MAX(sum[16], output_activation_min), output_activation_max);
  843. *out_1++ += TN_MIN(TN_MAX(sum[17], output_activation_min), output_activation_max);
  844. *out_2++ += TN_MIN(TN_MAX(sum[18], output_activation_min), output_activation_max);
  845. *out_3++ += TN_MIN(TN_MAX(sum[19], output_activation_min), output_activation_max);
  846. *out_0++ += TN_MIN(TN_MAX(sum[20], output_activation_min), output_activation_max);
  847. *out_1++ += TN_MIN(TN_MAX(sum[21], output_activation_min), output_activation_max);
  848. *out_2++ += TN_MIN(TN_MAX(sum[22], output_activation_min), output_activation_max);
  849. *out_3++ += TN_MIN(TN_MAX(sum[23], output_activation_min), output_activation_max);
  850. *out_0++ += TN_MIN(TN_MAX(sum[24], output_activation_min), output_activation_max);
  851. *out_1++ += TN_MIN(TN_MAX(sum[25], output_activation_min), output_activation_max);
  852. *out_2++ += TN_MIN(TN_MAX(sum[26], output_activation_min), output_activation_max);
  853. *out_3++ += TN_MIN(TN_MAX(sum[27], output_activation_min), output_activation_max);
  854. *out_0++ += TN_MIN(TN_MAX(sum[28], output_activation_min), output_activation_max);
  855. *out_1++ += TN_MIN(TN_MAX(sum[29], output_activation_min), output_activation_max);
  856. *out_2++ += TN_MIN(TN_MAX(sum[30], output_activation_min), output_activation_max);
  857. *out_3++ += TN_MIN(TN_MAX(sum[31], output_activation_min), output_activation_max);
  858. }
  859. static inline void assign_sum_to_pointwise_output_1row8col(float* out_0,
  860. const float* sum, const float output_activation_min, const float output_activation_max) {
  861. *out_0++ += TN_MIN(TN_MAX(sum[0], output_activation_min), output_activation_max);
  862. *out_0++ += TN_MIN(TN_MAX(sum[1], output_activation_min), output_activation_max);
  863. *out_0++ += TN_MIN(TN_MAX(sum[2], output_activation_min), output_activation_max);
  864. *out_0++ += TN_MIN(TN_MAX(sum[3], output_activation_min), output_activation_max);
  865. *out_0++ += TN_MIN(TN_MAX(sum[4], output_activation_min), output_activation_max);
  866. *out_0++ += TN_MIN(TN_MAX(sum[5], output_activation_min), output_activation_max);
  867. *out_0++ += TN_MIN(TN_MAX(sum[6], output_activation_min), output_activation_max);
  868. *out_0++ += TN_MIN(TN_MAX(sum[7], output_activation_min), output_activation_max);
  869. }
  870. static inline void assign_sum_to_pointwise_output_4row8col_noMINMAX(float* out_0, float* out_1, float* out_2, float* out_3, const float* sum) {
  871. *out_0++ += sum[0];
  872. *out_1++ += sum[1];
  873. *out_2++ += sum[2];
  874. *out_3++ += sum[3];
  875. *out_0++ += sum[4];
  876. *out_1++ += sum[5];
  877. *out_2++ += sum[6];
  878. *out_3++ += sum[7];
  879. *out_0++ += sum[8];
  880. *out_1++ += sum[9];
  881. *out_2++ += sum[10];
  882. *out_3++ += sum[11];
  883. *out_0++ += sum[12];
  884. *out_1++ += sum[13];
  885. *out_2++ += sum[14];
  886. *out_3++ += sum[15];
  887. *out_0++ += sum[16];
  888. *out_1++ += sum[17];
  889. *out_2++ += sum[18];
  890. *out_3++ += sum[19];
  891. *out_0++ += sum[20];
  892. *out_1++ += sum[21];
  893. *out_2++ += sum[22];
  894. *out_3++ += sum[23];
  895. *out_0++ += sum[24];
  896. *out_1++ += sum[25];
  897. *out_2++ += sum[26];
  898. *out_3++ += sum[27];
  899. *out_0++ += sum[28];
  900. *out_1++ += sum[29];
  901. *out_2++ += sum[30];
  902. *out_3++ += sum[31];
  903. }
  904. static inline void assign_sum_to_pointwise_output_1row8col_noMINMAX(float* out_0, const float* sum) {
  905. *out_0++ += sum[0];
  906. *out_0++ += sum[1];
  907. *out_0++ += sum[2];
  908. *out_0++ += sum[3];
  909. *out_0++ += sum[4];
  910. *out_0++ += sum[5];
  911. *out_0++ += sum[6];
  912. *out_0++ += sum[7];
  913. }
  914. static inline void assign_sum_to_pointwise_output_4row4col(float* out_0, float* out_1, float* out_2, float* out_3,
  915. const float* sum, const float output_activation_min, const float output_activation_max) {
  916. *out_0++ += TN_MIN(TN_MAX(sum[0], output_activation_min), output_activation_max);
  917. *out_1++ += TN_MIN(TN_MAX(sum[1], output_activation_min), output_activation_max);
  918. *out_2++ += TN_MIN(TN_MAX(sum[2], output_activation_min), output_activation_max);
  919. *out_3++ += TN_MIN(TN_MAX(sum[3], output_activation_min), output_activation_max);
  920. *out_0++ += TN_MIN(TN_MAX(sum[4], output_activation_min), output_activation_max);
  921. *out_1++ += TN_MIN(TN_MAX(sum[5], output_activation_min), output_activation_max);
  922. *out_2++ += TN_MIN(TN_MAX(sum[6], output_activation_min), output_activation_max);
  923. *out_3++ += TN_MIN(TN_MAX(sum[7], output_activation_min), output_activation_max);
  924. *out_0++ += TN_MIN(TN_MAX(sum[8], output_activation_min), output_activation_max);
  925. *out_1++ += TN_MIN(TN_MAX(sum[9], output_activation_min), output_activation_max);
  926. *out_2++ += TN_MIN(TN_MAX(sum[10], output_activation_min), output_activation_max);
  927. *out_3++ += TN_MIN(TN_MAX(sum[11], output_activation_min), output_activation_max);
  928. *out_0++ += TN_MIN(TN_MAX(sum[12], output_activation_min), output_activation_max);
  929. *out_1++ += TN_MIN(TN_MAX(sum[13], output_activation_min), output_activation_max);
  930. *out_2++ += TN_MIN(TN_MAX(sum[14], output_activation_min), output_activation_max);
  931. *out_3++ += TN_MIN(TN_MAX(sum[15], output_activation_min), output_activation_max);
  932. }
  933. static inline void assign_sum_to_pointwise_output_1row4col(float* out_0,
  934. const float* sum, const float output_activation_min, const float output_activation_max) {
  935. *out_0++ += TN_MIN(TN_MAX(sum[0], output_activation_min), output_activation_max);
  936. *out_0++ += TN_MIN(TN_MAX(sum[1], output_activation_min), output_activation_max);
  937. *out_0++ += TN_MIN(TN_MAX(sum[2], output_activation_min), output_activation_max);
  938. *out_0++ += TN_MIN(TN_MAX(sum[3], output_activation_min), output_activation_max);
  939. }
  940. /* END: For Pointwise Conv */
  941. /* END: Assign Output Functions */