nnfunctions.h 42 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045
  1. /* ----------------------------------------------------------------------
  2. * Project: Tiny Training Engine, MCUNetV3
  3. * Title: nnfunctions.h
  4. *
  5. * Reference papers:
  6. * - MCUNet: Tiny Deep Learning on IoT Devices, NeurIPS 2020
  7. * - MCUNetV2: Memory-Efficient Patch-based Inference for Tiny Deep Learning, NeurIPS 2021
  8. * - MCUNetV3: On-Device Training Under 256KB Memory, NeurIPS 2022
  9. * Contact authors:
  10. * - Wei-Chen Wang, wweichen@mit.edu
  11. * - Wei-Ming Chen, wmchen@mit.edu
  12. * - Ji Lin, jilin@mit.edu
  13. * - Ligeng Zhu, ligeng@mit.edu
  14. * - Song Han, songhan@mit.edu
  15. * - Chuang Gan, ganchuang@csail.mit.edu
  16. *
  17. * Target ISA: ARMv7E-M
  18. * -------------------------------------------------------------------- */
  19. /*
  20. * Copyright (c) 2023 HPMicro
  21. *
  22. * SPDX-License-Identifier: BSD-3-Clause
  23. * Target ISA: RISC-V D45
  24. *
  25. */
  26. #include "hpm_math.h"
  27. #include <stdlib.h>
  28. /* START: MAC Functions for Pointwise Conv */
  29. static inline void mac_4row_4col_IOHW_forint8w(q31_t* sum, const q7_t* input_0, const q7_t* input_1, const q7_t* input_2, const q7_t* input_3,
  30. const q7_t* filter_0, const q7_t* filter_1, const q7_t* filter_2, const q7_t* filter_3) {
  31. *sum += *input_0++ * *filter_0;
  32. *sum += *input_0++ * *filter_1;
  33. *sum += *input_0++ * *filter_2;
  34. *sum++ += *input_0++ * *filter_3;
  35. *sum += *input_1++ * *filter_0;
  36. *sum += *input_1++ * *filter_1;
  37. *sum += *input_1++ * *filter_2;
  38. *sum++ += *input_1++ * *filter_3;
  39. *sum += *input_2++ * *filter_0;
  40. *sum += *input_2++ * *filter_1;
  41. *sum += *input_2++ * *filter_2;
  42. *sum++ += *input_2++ * *filter_3;
  43. *sum += *input_3++ * *filter_0;
  44. *sum += *input_3++ * *filter_1;
  45. *sum += *input_3++ * *filter_2;
  46. *sum++ += *input_3++ * *filter_3;
  47. }
  48. static inline void mac_1row_4col_IOHW_forint8w(q31_t* sum, const q7_t* input_0,
  49. const q7_t* filter_0, const q7_t* filter_1, const q7_t* filter_2, const q7_t* filter_3) {
  50. *sum += *input_0++ * *filter_0;
  51. *sum += *input_0++ * *filter_1;
  52. *sum += *input_0++ * *filter_2;
  53. *sum += *input_0++ * *filter_3;
  54. }
  55. /* END: MAC Functions for Pointwise Conv */
  56. /* START: MAC Functions for Group Conv */
  57. static inline void group_mac_kernel8_4row_uniweight_reuse_output_input(q31_t* sum_0, q31_t* sum_1, q31_t* sum_2, q31_t* sum_3,
  58. const q7_t* input_0, const q7_t* input_1, const q7_t* input_2, const q7_t* input_3,
  59. const q7_t* filter) {
  60. q31_t tmp;
  61. tmp = 0;
  62. tmp += input_0[0] * filter[0];
  63. tmp += input_0[1] * filter[1];
  64. tmp += input_0[2] * filter[2];
  65. tmp += input_0[3] * filter[3];
  66. tmp += input_0[4] * filter[4];
  67. tmp += input_0[5] * filter[5];
  68. tmp += input_0[6] * filter[6];
  69. tmp += input_0[7] * filter[7];
  70. tmp += input_0[8] * filter[8];
  71. tmp += input_0[9] * filter[9];
  72. tmp += input_0[10] * filter[10];
  73. tmp += input_0[11] * filter[11];
  74. tmp += input_0[12] * filter[12];
  75. tmp += input_0[13] * filter[13];
  76. tmp += input_0[14] * filter[14];
  77. tmp += input_0[15] * filter[15];
  78. tmp += input_0[16] * filter[16];
  79. tmp += input_0[17] * filter[17];
  80. tmp += input_0[18] * filter[18];
  81. tmp += input_0[19] * filter[19];
  82. tmp += input_0[20] * filter[20];
  83. tmp += input_0[21] * filter[21];
  84. tmp += input_0[22] * filter[22];
  85. tmp += input_0[23] * filter[23];
  86. tmp += input_0[24] * filter[24];
  87. tmp += input_0[25] * filter[25];
  88. tmp += input_0[26] * filter[26];
  89. tmp += input_0[27] * filter[27];
  90. tmp += input_0[28] * filter[28];
  91. tmp += input_0[29] * filter[29];
  92. tmp += input_0[30] * filter[30];
  93. tmp += input_0[31] * filter[31];
  94. tmp += input_0[32] * filter[32];
  95. tmp += input_0[33] * filter[33];
  96. tmp += input_0[34] * filter[34];
  97. tmp += input_0[35] * filter[35];
  98. tmp += input_0[36] * filter[36];
  99. tmp += input_0[37] * filter[37];
  100. tmp += input_0[38] * filter[38];
  101. tmp += input_0[39] * filter[39];
  102. tmp += input_0[40] * filter[40];
  103. tmp += input_0[41] * filter[41];
  104. tmp += input_0[42] * filter[42];
  105. tmp += input_0[43] * filter[43];
  106. tmp += input_0[44] * filter[44];
  107. tmp += input_0[45] * filter[45];
  108. tmp += input_0[46] * filter[46];
  109. tmp += input_0[47] * filter[47];
  110. tmp += input_0[48] * filter[48];
  111. tmp += input_0[49] * filter[49];
  112. tmp += input_0[50] * filter[50];
  113. tmp += input_0[51] * filter[51];
  114. tmp += input_0[52] * filter[52];
  115. tmp += input_0[53] * filter[53];
  116. tmp += input_0[54] * filter[54];
  117. tmp += input_0[55] * filter[55];
  118. tmp += input_0[56] * filter[56];
  119. tmp += input_0[57] * filter[57];
  120. tmp += input_0[58] * filter[58];
  121. tmp += input_0[59] * filter[59];
  122. tmp += input_0[60] * filter[60];
  123. tmp += input_0[61] * filter[61];
  124. tmp += input_0[62] * filter[62];
  125. tmp += input_0[63] * filter[63];
  126. *sum_0 += tmp;
  127. tmp = 0;
  128. tmp += input_1[0] * filter[0];
  129. tmp += input_1[1] * filter[1];
  130. tmp += input_1[2] * filter[2];
  131. tmp += input_1[3] * filter[3];
  132. tmp += input_1[4] * filter[4];
  133. tmp += input_1[5] * filter[5];
  134. tmp += input_1[6] * filter[6];
  135. tmp += input_1[7] * filter[7];
  136. tmp += input_1[8] * filter[8];
  137. tmp += input_1[9] * filter[9];
  138. tmp += input_1[10] * filter[10];
  139. tmp += input_1[11] * filter[11];
  140. tmp += input_1[12] * filter[12];
  141. tmp += input_1[13] * filter[13];
  142. tmp += input_1[14] * filter[14];
  143. tmp += input_1[15] * filter[15];
  144. tmp += input_1[16] * filter[16];
  145. tmp += input_1[17] * filter[17];
  146. tmp += input_1[18] * filter[18];
  147. tmp += input_1[19] * filter[19];
  148. tmp += input_1[20] * filter[20];
  149. tmp += input_1[21] * filter[21];
  150. tmp += input_1[22] * filter[22];
  151. tmp += input_1[23] * filter[23];
  152. tmp += input_1[24] * filter[24];
  153. tmp += input_1[25] * filter[25];
  154. tmp += input_1[26] * filter[26];
  155. tmp += input_1[27] * filter[27];
  156. tmp += input_1[28] * filter[28];
  157. tmp += input_1[29] * filter[29];
  158. tmp += input_1[30] * filter[30];
  159. tmp += input_1[31] * filter[31];
  160. tmp += input_1[32] * filter[32];
  161. tmp += input_1[33] * filter[33];
  162. tmp += input_1[34] * filter[34];
  163. tmp += input_1[35] * filter[35];
  164. tmp += input_1[36] * filter[36];
  165. tmp += input_1[37] * filter[37];
  166. tmp += input_1[38] * filter[38];
  167. tmp += input_1[39] * filter[39];
  168. tmp += input_1[40] * filter[40];
  169. tmp += input_1[41] * filter[41];
  170. tmp += input_1[42] * filter[42];
  171. tmp += input_1[43] * filter[43];
  172. tmp += input_1[44] * filter[44];
  173. tmp += input_1[45] * filter[45];
  174. tmp += input_1[46] * filter[46];
  175. tmp += input_1[47] * filter[47];
  176. tmp += input_1[48] * filter[48];
  177. tmp += input_1[49] * filter[49];
  178. tmp += input_1[50] * filter[50];
  179. tmp += input_1[51] * filter[51];
  180. tmp += input_1[52] * filter[52];
  181. tmp += input_1[53] * filter[53];
  182. tmp += input_1[54] * filter[54];
  183. tmp += input_1[55] * filter[55];
  184. tmp += input_1[56] * filter[56];
  185. tmp += input_1[57] * filter[57];
  186. tmp += input_1[58] * filter[58];
  187. tmp += input_1[59] * filter[59];
  188. tmp += input_1[60] * filter[60];
  189. tmp += input_1[61] * filter[61];
  190. tmp += input_1[62] * filter[62];
  191. tmp += input_1[63] * filter[63];
  192. *sum_1 += tmp;
  193. tmp = 0;
  194. tmp += input_2[0] * filter[0];
  195. tmp += input_2[1] * filter[1];
  196. tmp += input_2[2] * filter[2];
  197. tmp += input_2[3] * filter[3];
  198. tmp += input_2[4] * filter[4];
  199. tmp += input_2[5] * filter[5];
  200. tmp += input_2[6] * filter[6];
  201. tmp += input_2[7] * filter[7];
  202. tmp += input_2[8] * filter[8];
  203. tmp += input_2[9] * filter[9];
  204. tmp += input_2[10] * filter[10];
  205. tmp += input_2[11] * filter[11];
  206. tmp += input_2[12] * filter[12];
  207. tmp += input_2[13] * filter[13];
  208. tmp += input_2[14] * filter[14];
  209. tmp += input_2[15] * filter[15];
  210. tmp += input_2[16] * filter[16];
  211. tmp += input_2[17] * filter[17];
  212. tmp += input_2[18] * filter[18];
  213. tmp += input_2[19] * filter[19];
  214. tmp += input_2[20] * filter[20];
  215. tmp += input_2[21] * filter[21];
  216. tmp += input_2[22] * filter[22];
  217. tmp += input_2[23] * filter[23];
  218. tmp += input_2[24] * filter[24];
  219. tmp += input_2[25] * filter[25];
  220. tmp += input_2[26] * filter[26];
  221. tmp += input_2[27] * filter[27];
  222. tmp += input_2[28] * filter[28];
  223. tmp += input_2[29] * filter[29];
  224. tmp += input_2[30] * filter[30];
  225. tmp += input_2[31] * filter[31];
  226. tmp += input_2[32] * filter[32];
  227. tmp += input_2[33] * filter[33];
  228. tmp += input_2[34] * filter[34];
  229. tmp += input_2[35] * filter[35];
  230. tmp += input_2[36] * filter[36];
  231. tmp += input_2[37] * filter[37];
  232. tmp += input_2[38] * filter[38];
  233. tmp += input_2[39] * filter[39];
  234. tmp += input_2[40] * filter[40];
  235. tmp += input_2[41] * filter[41];
  236. tmp += input_2[42] * filter[42];
  237. tmp += input_2[43] * filter[43];
  238. tmp += input_2[44] * filter[44];
  239. tmp += input_2[45] * filter[45];
  240. tmp += input_2[46] * filter[46];
  241. tmp += input_2[47] * filter[47];
  242. tmp += input_2[48] * filter[48];
  243. tmp += input_2[49] * filter[49];
  244. tmp += input_2[50] * filter[50];
  245. tmp += input_2[51] * filter[51];
  246. tmp += input_2[52] * filter[52];
  247. tmp += input_2[53] * filter[53];
  248. tmp += input_2[54] * filter[54];
  249. tmp += input_2[55] * filter[55];
  250. tmp += input_2[56] * filter[56];
  251. tmp += input_2[57] * filter[57];
  252. tmp += input_2[58] * filter[58];
  253. tmp += input_2[59] * filter[59];
  254. tmp += input_2[60] * filter[60];
  255. tmp += input_2[61] * filter[61];
  256. tmp += input_2[62] * filter[62];
  257. tmp += input_2[63] * filter[63];
  258. *sum_2 += tmp;
  259. tmp = 0;
  260. tmp += input_3[0] * filter[0];
  261. tmp += input_3[1] * filter[1];
  262. tmp += input_3[2] * filter[2];
  263. tmp += input_3[3] * filter[3];
  264. tmp += input_3[4] * filter[4];
  265. tmp += input_3[5] * filter[5];
  266. tmp += input_3[6] * filter[6];
  267. tmp += input_3[7] * filter[7];
  268. tmp += input_3[8] * filter[8];
  269. tmp += input_3[9] * filter[9];
  270. tmp += input_3[10] * filter[10];
  271. tmp += input_3[11] * filter[11];
  272. tmp += input_3[12] * filter[12];
  273. tmp += input_3[13] * filter[13];
  274. tmp += input_3[14] * filter[14];
  275. tmp += input_3[15] * filter[15];
  276. tmp += input_3[16] * filter[16];
  277. tmp += input_3[17] * filter[17];
  278. tmp += input_3[18] * filter[18];
  279. tmp += input_3[19] * filter[19];
  280. tmp += input_3[20] * filter[20];
  281. tmp += input_3[21] * filter[21];
  282. tmp += input_3[22] * filter[22];
  283. tmp += input_3[23] * filter[23];
  284. tmp += input_3[24] * filter[24];
  285. tmp += input_3[25] * filter[25];
  286. tmp += input_3[26] * filter[26];
  287. tmp += input_3[27] * filter[27];
  288. tmp += input_3[28] * filter[28];
  289. tmp += input_3[29] * filter[29];
  290. tmp += input_3[30] * filter[30];
  291. tmp += input_3[31] * filter[31];
  292. tmp += input_3[32] * filter[32];
  293. tmp += input_3[33] * filter[33];
  294. tmp += input_3[34] * filter[34];
  295. tmp += input_3[35] * filter[35];
  296. tmp += input_3[36] * filter[36];
  297. tmp += input_3[37] * filter[37];
  298. tmp += input_3[38] * filter[38];
  299. tmp += input_3[39] * filter[39];
  300. tmp += input_3[40] * filter[40];
  301. tmp += input_3[41] * filter[41];
  302. tmp += input_3[42] * filter[42];
  303. tmp += input_3[43] * filter[43];
  304. tmp += input_3[44] * filter[44];
  305. tmp += input_3[45] * filter[45];
  306. tmp += input_3[46] * filter[46];
  307. tmp += input_3[47] * filter[47];
  308. tmp += input_3[48] * filter[48];
  309. tmp += input_3[49] * filter[49];
  310. tmp += input_3[50] * filter[50];
  311. tmp += input_3[51] * filter[51];
  312. tmp += input_3[52] * filter[52];
  313. tmp += input_3[53] * filter[53];
  314. tmp += input_3[54] * filter[54];
  315. tmp += input_3[55] * filter[55];
  316. tmp += input_3[56] * filter[56];
  317. tmp += input_3[57] * filter[57];
  318. tmp += input_3[58] * filter[58];
  319. tmp += input_3[59] * filter[59];
  320. tmp += input_3[60] * filter[60];
  321. tmp += input_3[61] * filter[61];
  322. tmp += input_3[62] * filter[62];
  323. tmp += input_3[63] * filter[63];
  324. *sum_3 += tmp;
  325. }
  326. static inline void group_mac_kernel4_4row_uniweight_reuse_output_input(q31_t* sum_0, q31_t* sum_1, q31_t* sum_2, q31_t* sum_3,
  327. const q7_t* input_0, const q7_t* input_1, const q7_t* input_2, const q7_t* input_3,
  328. const q7_t* filter) {
  329. q31_t tmp;
  330. tmp = 0;
  331. tmp += input_0[0] * filter[0];
  332. tmp += input_0[1] * filter[1];
  333. tmp += input_0[2] * filter[2];
  334. tmp += input_0[3] * filter[3];
  335. tmp += input_0[4] * filter[4];
  336. tmp += input_0[5] * filter[5];
  337. tmp += input_0[6] * filter[6];
  338. tmp += input_0[7] * filter[7];
  339. tmp += input_0[8] * filter[8];
  340. tmp += input_0[9] * filter[9];
  341. tmp += input_0[10] * filter[10];
  342. tmp += input_0[11] * filter[11];
  343. tmp += input_0[12] * filter[12];
  344. tmp += input_0[13] * filter[13];
  345. tmp += input_0[14] * filter[14];
  346. tmp += input_0[15] * filter[15];
  347. *sum_0 += tmp;
  348. tmp = 0;
  349. tmp += input_1[0] * filter[0];
  350. tmp += input_1[1] * filter[1];
  351. tmp += input_1[2] * filter[2];
  352. tmp += input_1[3] * filter[3];
  353. tmp += input_1[4] * filter[4];
  354. tmp += input_1[5] * filter[5];
  355. tmp += input_1[6] * filter[6];
  356. tmp += input_1[7] * filter[7];
  357. tmp += input_1[8] * filter[8];
  358. tmp += input_1[9] * filter[9];
  359. tmp += input_1[10] * filter[10];
  360. tmp += input_1[11] * filter[11];
  361. tmp += input_1[12] * filter[12];
  362. tmp += input_1[13] * filter[13];
  363. tmp += input_1[14] * filter[14];
  364. tmp += input_1[15] * filter[15];
  365. *sum_1 += tmp;
  366. tmp = 0;
  367. tmp += input_2[0] * filter[0];
  368. tmp += input_2[1] * filter[1];
  369. tmp += input_2[2] * filter[2];
  370. tmp += input_2[3] * filter[3];
  371. tmp += input_2[4] * filter[4];
  372. tmp += input_2[5] * filter[5];
  373. tmp += input_2[6] * filter[6];
  374. tmp += input_2[7] * filter[7];
  375. tmp += input_2[8] * filter[8];
  376. tmp += input_2[9] * filter[9];
  377. tmp += input_2[10] * filter[10];
  378. tmp += input_2[11] * filter[11];
  379. tmp += input_2[12] * filter[12];
  380. tmp += input_2[13] * filter[13];
  381. tmp += input_2[14] * filter[14];
  382. tmp += input_2[15] * filter[15];
  383. *sum_2 += tmp;
  384. tmp = 0;
  385. tmp += input_3[0] * filter[0];
  386. tmp += input_3[1] * filter[1];
  387. tmp += input_3[2] * filter[2];
  388. tmp += input_3[3] * filter[3];
  389. tmp += input_3[4] * filter[4];
  390. tmp += input_3[5] * filter[5];
  391. tmp += input_3[6] * filter[6];
  392. tmp += input_3[7] * filter[7];
  393. tmp += input_3[8] * filter[8];
  394. tmp += input_3[9] * filter[9];
  395. tmp += input_3[10] * filter[10];
  396. tmp += input_3[11] * filter[11];
  397. tmp += input_3[12] * filter[12];
  398. tmp += input_3[13] * filter[13];
  399. tmp += input_3[14] * filter[14];
  400. tmp += input_3[15] * filter[15];
  401. *sum_3 += tmp;
  402. }
  403. /* END: MAC Functions for Group Conv */
  404. /* START: MAC Functions for Transpose Depthwise Conv */
  405. /* START: For 3x3 kernel size*/
  406. static inline void transpose_depthwise_mac_kernel3_2row_uniweight(q31_t* sum_0, q31_t* sum_1,
  407. const q7_t* im2col_buffer, const q7_t* ksrc_transposed, const uint16_t input_width,
  408. const uint16_t STRIDE, const uint16_t IN_PAD, const uint16_t OUT_PAD) {
  409. *sum_0 += im2col_buffer[0] * ksrc_transposed[0];
  410. *sum_1 += im2col_buffer[1] * ksrc_transposed[0];
  411. *sum_0 += im2col_buffer[1] * ksrc_transposed[1];
  412. *sum_1 += im2col_buffer[2] * ksrc_transposed[1];
  413. *sum_0 += im2col_buffer[2] * ksrc_transposed[2];
  414. *sum_1 += im2col_buffer[3] * ksrc_transposed[2];
  415. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  416. *sum_0 += im2col_buffer[0] * ksrc_transposed[3];
  417. *sum_1 += im2col_buffer[1] * ksrc_transposed[3];
  418. *sum_0 += im2col_buffer[1] * ksrc_transposed[4];
  419. *sum_1 += im2col_buffer[2] * ksrc_transposed[4];
  420. *sum_0 += im2col_buffer[2] * ksrc_transposed[5];
  421. *sum_1 += im2col_buffer[3] * ksrc_transposed[5];
  422. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  423. *sum_0 += im2col_buffer[0] * ksrc_transposed[6];
  424. *sum_1 += im2col_buffer[1] * ksrc_transposed[6];
  425. *sum_0 += im2col_buffer[1] * ksrc_transposed[7];
  426. *sum_1 += im2col_buffer[2] * ksrc_transposed[7];
  427. *sum_0 += im2col_buffer[2] * ksrc_transposed[8];
  428. *sum_1 += im2col_buffer[3] * ksrc_transposed[8];
  429. }
  430. static inline void transpose_depthwise_mac_kernel3_1row_uniweight(q31_t* sum_0,
  431. const q7_t* im2col_buffer, const q7_t* ksrc_transposed, const uint16_t input_width,
  432. const uint16_t STRIDE, const uint16_t IN_PAD, const uint16_t OUT_PAD) {
  433. *sum_0 += im2col_buffer[0] * ksrc_transposed[0];
  434. *sum_0 += im2col_buffer[1] * ksrc_transposed[1];
  435. *sum_0 += im2col_buffer[2] * ksrc_transposed[2];
  436. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  437. *sum_0 += im2col_buffer[0] * ksrc_transposed[3];
  438. *sum_0 += im2col_buffer[1] * ksrc_transposed[4];
  439. *sum_0 += im2col_buffer[2] * ksrc_transposed[5];
  440. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  441. *sum_0 += im2col_buffer[0] * ksrc_transposed[6];
  442. *sum_0 += im2col_buffer[1] * ksrc_transposed[7];
  443. *sum_0 += im2col_buffer[2] * ksrc_transposed[8];
  444. }
  445. /* END: For 3x3 kernel size*/
  446. /* START: For 5x5 kernel size*/
  447. static inline void transpose_depthwise_mac_kernel5_2row_uniweight(q31_t* sum_0, q31_t* sum_1,
  448. const q7_t* im2col_buffer, const q7_t* ksrc_transposed, const uint16_t input_width,
  449. const uint16_t STRIDE, const uint16_t IN_PAD, const uint16_t OUT_PAD) {
  450. *sum_0 += im2col_buffer[0] * ksrc_transposed[0];
  451. *sum_1 += im2col_buffer[1] * ksrc_transposed[0];
  452. *sum_0 += im2col_buffer[1] * ksrc_transposed[1];
  453. *sum_1 += im2col_buffer[2] * ksrc_transposed[1];
  454. *sum_0 += im2col_buffer[2] * ksrc_transposed[2];
  455. *sum_1 += im2col_buffer[3] * ksrc_transposed[2];
  456. *sum_0 += im2col_buffer[3] * ksrc_transposed[3];
  457. *sum_1 += im2col_buffer[4] * ksrc_transposed[3];
  458. *sum_0 += im2col_buffer[4] * ksrc_transposed[4];
  459. *sum_1 += im2col_buffer[5] * ksrc_transposed[4];
  460. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  461. *sum_0 += im2col_buffer[0] * ksrc_transposed[5];
  462. *sum_1 += im2col_buffer[1] * ksrc_transposed[5];
  463. *sum_0 += im2col_buffer[1] * ksrc_transposed[6];
  464. *sum_1 += im2col_buffer[2] * ksrc_transposed[6];
  465. *sum_0 += im2col_buffer[2] * ksrc_transposed[7];
  466. *sum_1 += im2col_buffer[3] * ksrc_transposed[7];
  467. *sum_0 += im2col_buffer[3] * ksrc_transposed[8];
  468. *sum_1 += im2col_buffer[4] * ksrc_transposed[8];
  469. *sum_0 += im2col_buffer[4] * ksrc_transposed[9];
  470. *sum_1 += im2col_buffer[5] * ksrc_transposed[9];
  471. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  472. *sum_0 += im2col_buffer[0] * ksrc_transposed[10];
  473. *sum_1 += im2col_buffer[1] * ksrc_transposed[10];
  474. *sum_0 += im2col_buffer[1] * ksrc_transposed[11];
  475. *sum_1 += im2col_buffer[2] * ksrc_transposed[11];
  476. *sum_0 += im2col_buffer[2] * ksrc_transposed[12];
  477. *sum_1 += im2col_buffer[3] * ksrc_transposed[12];
  478. *sum_0 += im2col_buffer[3] * ksrc_transposed[13];
  479. *sum_1 += im2col_buffer[4] * ksrc_transposed[13];
  480. *sum_0 += im2col_buffer[4] * ksrc_transposed[14];
  481. *sum_1 += im2col_buffer[5] * ksrc_transposed[14];
  482. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  483. *sum_0 += im2col_buffer[0] * ksrc_transposed[15];
  484. *sum_1 += im2col_buffer[1] * ksrc_transposed[15];
  485. *sum_0 += im2col_buffer[1] * ksrc_transposed[16];
  486. *sum_1 += im2col_buffer[2] * ksrc_transposed[16];
  487. *sum_0 += im2col_buffer[2] * ksrc_transposed[17];
  488. *sum_1 += im2col_buffer[3] * ksrc_transposed[17];
  489. *sum_0 += im2col_buffer[3] * ksrc_transposed[18];
  490. *sum_1 += im2col_buffer[4] * ksrc_transposed[18];
  491. *sum_0 += im2col_buffer[4] * ksrc_transposed[19];
  492. *sum_1 += im2col_buffer[5] * ksrc_transposed[19];
  493. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  494. *sum_0 += im2col_buffer[0] * ksrc_transposed[20];
  495. *sum_1 += im2col_buffer[1] * ksrc_transposed[20];
  496. *sum_0 += im2col_buffer[1] * ksrc_transposed[21];
  497. *sum_1 += im2col_buffer[2] * ksrc_transposed[21];
  498. *sum_0 += im2col_buffer[2] * ksrc_transposed[22];
  499. *sum_1 += im2col_buffer[3] * ksrc_transposed[22];
  500. *sum_0 += im2col_buffer[3] * ksrc_transposed[23];
  501. *sum_1 += im2col_buffer[4] * ksrc_transposed[23];
  502. *sum_0 += im2col_buffer[4] * ksrc_transposed[24];
  503. *sum_1 += im2col_buffer[5] * ksrc_transposed[24];
  504. }
  505. static inline void transpose_depthwise_mac_kernel5_1row_uniweight(q31_t* sum_0,
  506. const q7_t* im2col_buffer, const q7_t* ksrc_transposed, const uint16_t input_width,
  507. const uint16_t STRIDE, const uint16_t IN_PAD, const uint16_t OUT_PAD) {
  508. *sum_0 += im2col_buffer[0] * ksrc_transposed[0];
  509. *sum_0 += im2col_buffer[1] * ksrc_transposed[1];
  510. *sum_0 += im2col_buffer[2] * ksrc_transposed[2];
  511. *sum_0 += im2col_buffer[3] * ksrc_transposed[3];
  512. *sum_0 += im2col_buffer[4] * ksrc_transposed[4];
  513. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  514. *sum_0 += im2col_buffer[0] * ksrc_transposed[5];
  515. *sum_0 += im2col_buffer[1] * ksrc_transposed[6];
  516. *sum_0 += im2col_buffer[2] * ksrc_transposed[7];
  517. *sum_0 += im2col_buffer[3] * ksrc_transposed[8];
  518. *sum_0 += im2col_buffer[4] * ksrc_transposed[9];
  519. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  520. *sum_0 += im2col_buffer[0] * ksrc_transposed[10];
  521. *sum_0 += im2col_buffer[1] * ksrc_transposed[11];
  522. *sum_0 += im2col_buffer[2] * ksrc_transposed[12];
  523. *sum_0 += im2col_buffer[3] * ksrc_transposed[13];
  524. *sum_0 += im2col_buffer[4] * ksrc_transposed[14];
  525. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  526. *sum_0 += im2col_buffer[0] * ksrc_transposed[15];
  527. *sum_0 += im2col_buffer[1] * ksrc_transposed[16];
  528. *sum_0 += im2col_buffer[2] * ksrc_transposed[17];
  529. *sum_0 += im2col_buffer[3] * ksrc_transposed[18];
  530. *sum_0 += im2col_buffer[4] * ksrc_transposed[19];
  531. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  532. *sum_0 += im2col_buffer[0] * ksrc_transposed[20];
  533. *sum_0 += im2col_buffer[1] * ksrc_transposed[21];
  534. *sum_0 += im2col_buffer[2] * ksrc_transposed[22];
  535. *sum_0 += im2col_buffer[3] * ksrc_transposed[23];
  536. *sum_0 += im2col_buffer[4] * ksrc_transposed[24];
  537. }
  538. /* END: For 5x5 kernel size*/
  539. /* START: For 7x7 kernel size*/
  540. static inline void transpose_depthwise_mac_kernel7_2row_uniweight(q31_t* sum_0, q31_t* sum_1,
  541. const q7_t* im2col_buffer, const q7_t* ksrc_transposed, const uint16_t input_width,
  542. const uint16_t STRIDE, const uint16_t IN_PAD, const uint16_t OUT_PAD) {
  543. *sum_0 += im2col_buffer[0] * ksrc_transposed[0];
  544. *sum_1 += im2col_buffer[1] * ksrc_transposed[0];
  545. *sum_0 += im2col_buffer[1] * ksrc_transposed[1];
  546. *sum_1 += im2col_buffer[2] * ksrc_transposed[1];
  547. *sum_0 += im2col_buffer[2] * ksrc_transposed[2];
  548. *sum_1 += im2col_buffer[3] * ksrc_transposed[2];
  549. *sum_0 += im2col_buffer[3] * ksrc_transposed[3];
  550. *sum_1 += im2col_buffer[4] * ksrc_transposed[3];
  551. *sum_0 += im2col_buffer[4] * ksrc_transposed[4];
  552. *sum_1 += im2col_buffer[5] * ksrc_transposed[4];
  553. *sum_0 += im2col_buffer[5] * ksrc_transposed[5];
  554. *sum_1 += im2col_buffer[6] * ksrc_transposed[5];
  555. *sum_0 += im2col_buffer[6] * ksrc_transposed[6];
  556. *sum_1 += im2col_buffer[7] * ksrc_transposed[6];
  557. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  558. *sum_0 += im2col_buffer[0] * ksrc_transposed[7];
  559. *sum_1 += im2col_buffer[1] * ksrc_transposed[7];
  560. *sum_0 += im2col_buffer[1] * ksrc_transposed[8];
  561. *sum_1 += im2col_buffer[2] * ksrc_transposed[8];
  562. *sum_0 += im2col_buffer[2] * ksrc_transposed[9];
  563. *sum_1 += im2col_buffer[3] * ksrc_transposed[9];
  564. *sum_0 += im2col_buffer[3] * ksrc_transposed[10];
  565. *sum_1 += im2col_buffer[4] * ksrc_transposed[10];
  566. *sum_0 += im2col_buffer[4] * ksrc_transposed[11];
  567. *sum_1 += im2col_buffer[5] * ksrc_transposed[11];
  568. *sum_0 += im2col_buffer[5] * ksrc_transposed[12];
  569. *sum_1 += im2col_buffer[6] * ksrc_transposed[12];
  570. *sum_0 += im2col_buffer[6] * ksrc_transposed[13];
  571. *sum_1 += im2col_buffer[7] * ksrc_transposed[13];
  572. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  573. *sum_0 += im2col_buffer[0] * ksrc_transposed[14];
  574. *sum_1 += im2col_buffer[1] * ksrc_transposed[14];
  575. *sum_0 += im2col_buffer[1] * ksrc_transposed[15];
  576. *sum_1 += im2col_buffer[2] * ksrc_transposed[15];
  577. *sum_0 += im2col_buffer[2] * ksrc_transposed[16];
  578. *sum_1 += im2col_buffer[3] * ksrc_transposed[16];
  579. *sum_0 += im2col_buffer[3] * ksrc_transposed[17];
  580. *sum_1 += im2col_buffer[4] * ksrc_transposed[17];
  581. *sum_0 += im2col_buffer[4] * ksrc_transposed[18];
  582. *sum_1 += im2col_buffer[5] * ksrc_transposed[18];
  583. *sum_0 += im2col_buffer[5] * ksrc_transposed[19];
  584. *sum_1 += im2col_buffer[6] * ksrc_transposed[19];
  585. *sum_0 += im2col_buffer[6] * ksrc_transposed[20];
  586. *sum_1 += im2col_buffer[7] * ksrc_transposed[20];
  587. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  588. *sum_0 += im2col_buffer[0] * ksrc_transposed[21];
  589. *sum_1 += im2col_buffer[1] * ksrc_transposed[21];
  590. *sum_0 += im2col_buffer[1] * ksrc_transposed[22];
  591. *sum_1 += im2col_buffer[2] * ksrc_transposed[22];
  592. *sum_0 += im2col_buffer[2] * ksrc_transposed[23];
  593. *sum_1 += im2col_buffer[3] * ksrc_transposed[23];
  594. *sum_0 += im2col_buffer[3] * ksrc_transposed[24];
  595. *sum_1 += im2col_buffer[4] * ksrc_transposed[24];
  596. *sum_0 += im2col_buffer[4] * ksrc_transposed[25];
  597. *sum_1 += im2col_buffer[5] * ksrc_transposed[25];
  598. *sum_0 += im2col_buffer[5] * ksrc_transposed[26];
  599. *sum_1 += im2col_buffer[6] * ksrc_transposed[26];
  600. *sum_0 += im2col_buffer[6] * ksrc_transposed[27];
  601. *sum_1 += im2col_buffer[7] * ksrc_transposed[27];
  602. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  603. *sum_0 += im2col_buffer[0] * ksrc_transposed[28];
  604. *sum_1 += im2col_buffer[1] * ksrc_transposed[28];
  605. *sum_0 += im2col_buffer[1] * ksrc_transposed[29];
  606. *sum_1 += im2col_buffer[2] * ksrc_transposed[29];
  607. *sum_0 += im2col_buffer[2] * ksrc_transposed[30];
  608. *sum_1 += im2col_buffer[3] * ksrc_transposed[30];
  609. *sum_0 += im2col_buffer[3] * ksrc_transposed[31];
  610. *sum_1 += im2col_buffer[4] * ksrc_transposed[31];
  611. *sum_0 += im2col_buffer[4] * ksrc_transposed[32];
  612. *sum_1 += im2col_buffer[5] * ksrc_transposed[32];
  613. *sum_0 += im2col_buffer[5] * ksrc_transposed[33];
  614. *sum_1 += im2col_buffer[6] * ksrc_transposed[33];
  615. *sum_0 += im2col_buffer[6] * ksrc_transposed[34];
  616. *sum_1 += im2col_buffer[7] * ksrc_transposed[34];
  617. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  618. *sum_0 += im2col_buffer[0] * ksrc_transposed[35];
  619. *sum_1 += im2col_buffer[1] * ksrc_transposed[35];
  620. *sum_0 += im2col_buffer[1] * ksrc_transposed[36];
  621. *sum_1 += im2col_buffer[2] * ksrc_transposed[36];
  622. *sum_0 += im2col_buffer[2] * ksrc_transposed[37];
  623. *sum_1 += im2col_buffer[3] * ksrc_transposed[37];
  624. *sum_0 += im2col_buffer[3] * ksrc_transposed[38];
  625. *sum_1 += im2col_buffer[4] * ksrc_transposed[38];
  626. *sum_0 += im2col_buffer[4] * ksrc_transposed[39];
  627. *sum_1 += im2col_buffer[5] * ksrc_transposed[39];
  628. *sum_0 += im2col_buffer[5] * ksrc_transposed[40];
  629. *sum_1 += im2col_buffer[6] * ksrc_transposed[40];
  630. *sum_0 += im2col_buffer[6] * ksrc_transposed[41];
  631. *sum_1 += im2col_buffer[7] * ksrc_transposed[41];
  632. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  633. *sum_0 += im2col_buffer[0] * ksrc_transposed[42];
  634. *sum_1 += im2col_buffer[1] * ksrc_transposed[42];
  635. *sum_0 += im2col_buffer[1] * ksrc_transposed[43];
  636. *sum_1 += im2col_buffer[2] * ksrc_transposed[43];
  637. *sum_0 += im2col_buffer[2] * ksrc_transposed[44];
  638. *sum_1 += im2col_buffer[3] * ksrc_transposed[44];
  639. *sum_0 += im2col_buffer[3] * ksrc_transposed[45];
  640. *sum_1 += im2col_buffer[4] * ksrc_transposed[45];
  641. *sum_0 += im2col_buffer[4] * ksrc_transposed[46];
  642. *sum_1 += im2col_buffer[5] * ksrc_transposed[46];
  643. *sum_0 += im2col_buffer[5] * ksrc_transposed[47];
  644. *sum_1 += im2col_buffer[6] * ksrc_transposed[47];
  645. *sum_0 += im2col_buffer[6] * ksrc_transposed[48];
  646. *sum_1 += im2col_buffer[7] * ksrc_transposed[48];
  647. }
  648. static inline void transpose_depthwise_mac_kernel7_1row_uniweight(q31_t* sum_0,
  649. const q7_t* im2col_buffer, const q7_t* ksrc_transposed, const uint16_t input_width,
  650. const uint16_t STRIDE, const uint16_t IN_PAD, const uint16_t OUT_PAD) {
  651. *sum_0 += im2col_buffer[0] * ksrc_transposed[0];
  652. *sum_0 += im2col_buffer[1] * ksrc_transposed[1];
  653. *sum_0 += im2col_buffer[2] * ksrc_transposed[2];
  654. *sum_0 += im2col_buffer[3] * ksrc_transposed[3];
  655. *sum_0 += im2col_buffer[4] * ksrc_transposed[4];
  656. *sum_0 += im2col_buffer[5] * ksrc_transposed[5];
  657. *sum_0 += im2col_buffer[6] * ksrc_transposed[6];
  658. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  659. *sum_0 += im2col_buffer[0] * ksrc_transposed[7];
  660. *sum_0 += im2col_buffer[1] * ksrc_transposed[8];
  661. *sum_0 += im2col_buffer[2] * ksrc_transposed[9];
  662. *sum_0 += im2col_buffer[3] * ksrc_transposed[10];
  663. *sum_0 += im2col_buffer[4] * ksrc_transposed[11];
  664. *sum_0 += im2col_buffer[5] * ksrc_transposed[12];
  665. *sum_0 += im2col_buffer[6] * ksrc_transposed[13];
  666. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  667. *sum_0 += im2col_buffer[0] * ksrc_transposed[14];
  668. *sum_0 += im2col_buffer[1] * ksrc_transposed[15];
  669. *sum_0 += im2col_buffer[2] * ksrc_transposed[16];
  670. *sum_0 += im2col_buffer[3] * ksrc_transposed[17];
  671. *sum_0 += im2col_buffer[4] * ksrc_transposed[18];
  672. *sum_0 += im2col_buffer[5] * ksrc_transposed[19];
  673. *sum_0 += im2col_buffer[6] * ksrc_transposed[20];
  674. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  675. *sum_0 += im2col_buffer[0] * ksrc_transposed[21];
  676. *sum_0 += im2col_buffer[1] * ksrc_transposed[22];
  677. *sum_0 += im2col_buffer[2] * ksrc_transposed[23];
  678. *sum_0 += im2col_buffer[3] * ksrc_transposed[24];
  679. *sum_0 += im2col_buffer[4] * ksrc_transposed[25];
  680. *sum_0 += im2col_buffer[5] * ksrc_transposed[26];
  681. *sum_0 += im2col_buffer[6] * ksrc_transposed[27];
  682. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  683. *sum_0 += im2col_buffer[0] * ksrc_transposed[28];
  684. *sum_0 += im2col_buffer[1] * ksrc_transposed[29];
  685. *sum_0 += im2col_buffer[2] * ksrc_transposed[30];
  686. *sum_0 += im2col_buffer[3] * ksrc_transposed[31];
  687. *sum_0 += im2col_buffer[4] * ksrc_transposed[32];
  688. *sum_0 += im2col_buffer[5] * ksrc_transposed[33];
  689. *sum_0 += im2col_buffer[6] * ksrc_transposed[34];
  690. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  691. *sum_0 += im2col_buffer[0] * ksrc_transposed[35];
  692. *sum_0 += im2col_buffer[1] * ksrc_transposed[36];
  693. *sum_0 += im2col_buffer[2] * ksrc_transposed[37];
  694. *sum_0 += im2col_buffer[3] * ksrc_transposed[38];
  695. *sum_0 += im2col_buffer[4] * ksrc_transposed[39];
  696. *sum_0 += im2col_buffer[5] * ksrc_transposed[40];
  697. *sum_0 += im2col_buffer[6] * ksrc_transposed[41];
  698. im2col_buffer += (input_width - 1) * STRIDE + 1 + IN_PAD * 2 + OUT_PAD;
  699. *sum_0 += im2col_buffer[0] * ksrc_transposed[42];
  700. *sum_0 += im2col_buffer[1] * ksrc_transposed[43];
  701. *sum_0 += im2col_buffer[2] * ksrc_transposed[44];
  702. *sum_0 += im2col_buffer[3] * ksrc_transposed[45];
  703. *sum_0 += im2col_buffer[4] * ksrc_transposed[46];
  704. *sum_0 += im2col_buffer[5] * ksrc_transposed[47];
  705. *sum_0 += im2col_buffer[6] * ksrc_transposed[48];
  706. }
  707. /* END: For 7x7 kernel size*/
  708. /* END: MAC Functions for Transpose Depthwise Conv */
  709. /* START: Assign Output Functions */
  710. /* START: For Pointwise Conv */
  711. static inline void assign_sum_to_pointwise_tmp_output_buffer_4row8col_int8(q31_t* out_0, q31_t* out_1, q31_t* out_2, q31_t* out_3,
  712. const q31_t* sum) {
  713. *out_0++ += sum[0];
  714. *out_1++ += sum[1];
  715. *out_2++ += sum[2];
  716. *out_3++ += sum[3];
  717. *out_0++ += sum[4];
  718. *out_1++ += sum[5];
  719. *out_2++ += sum[6];
  720. *out_3++ += sum[7];
  721. *out_0++ += sum[8];
  722. *out_1++ += sum[9];
  723. *out_2++ += sum[10];
  724. *out_3++ += sum[11];
  725. *out_0++ += sum[12];
  726. *out_1++ += sum[13];
  727. *out_2++ += sum[14];
  728. *out_3++ += sum[15];
  729. *out_0++ += sum[16];
  730. *out_1++ += sum[17];
  731. *out_2++ += sum[18];
  732. *out_3++ += sum[19];
  733. *out_0++ += sum[20];
  734. *out_1++ += sum[21];
  735. *out_2++ += sum[22];
  736. *out_3++ += sum[23];
  737. *out_0++ += sum[24];
  738. *out_1++ += sum[25];
  739. *out_2++ += sum[26];
  740. *out_3++ += sum[27];
  741. *out_0++ += sum[28];
  742. *out_1++ += sum[29];
  743. *out_2++ += sum[30];
  744. *out_3++ += sum[31];
  745. }
  746. static inline void assign_sum_to_pointwise_tmp_output_buffer_1row8col_int8(q31_t* out_0, const q31_t* sum) {
  747. *out_0++ += sum[0];
  748. *out_0++ += sum[1];
  749. *out_0++ += sum[2];
  750. *out_0++ += sum[3];
  751. *out_0++ += sum[4];
  752. *out_0++ += sum[5];
  753. *out_0++ += sum[6];
  754. *out_0++ += sum[7];
  755. }
  756. static inline void assign_sum_to_pointwise_tmp_output_buffer_4row4col_int8(q31_t* out_0, q31_t* out_1, q31_t* out_2, q31_t* out_3,
  757. const q31_t* sum) {
  758. *out_0++ += sum[0];
  759. *out_1++ += sum[1];
  760. *out_2++ += sum[2];
  761. *out_3++ += sum[3];
  762. *out_0++ += sum[4];
  763. *out_1++ += sum[5];
  764. *out_2++ += sum[6];
  765. *out_3++ += sum[7];
  766. *out_0++ += sum[8];
  767. *out_1++ += sum[9];
  768. *out_2++ += sum[10];
  769. *out_3++ += sum[11];
  770. *out_0++ += sum[12];
  771. *out_1++ += sum[13];
  772. *out_2++ += sum[14];
  773. *out_3++ += sum[15];
  774. }
  775. static inline void assign_sum_to_pointwise_tmp_output_buffer_1row4col_int8(q31_t* out_0, const q31_t* sum) {
  776. *out_0++ += sum[0];
  777. *out_0++ += sum[1];
  778. *out_0++ += sum[2];
  779. *out_0++ += sum[3];
  780. }
  781. /* END: For Pointwise Conv */
  782. /* START: For Group Conv */
  783. static inline void assign_sum_to_group_tmp_output_buffer_4row8col_int8(q31_t* out_0, q31_t* out_1, q31_t* out_2, q31_t* out_3,
  784. q31_t* out_4, q31_t* out_5, q31_t* out_6, q31_t* out_7,
  785. const q31_t* sum_0, const q31_t* sum_1, const q31_t* sum_2, const q31_t* sum_3,
  786. q31_t* out_max_0, q31_t* out_max_1, q31_t* out_max_2, q31_t* out_max_3,
  787. q31_t* out_max_4, q31_t* out_max_5, q31_t* out_max_6, q31_t* out_max_7) {
  788. *out_0++ = sum_0[0];
  789. *out_max_0 = TN_MAX(abs(*out_max_0), abs(sum_0[0]));
  790. *out_1++ = sum_0[1];
  791. *out_max_1 = TN_MAX(abs(*out_max_1), abs(sum_0[1]));
  792. *out_2++ = sum_0[2];
  793. *out_max_2 = TN_MAX(abs(*out_max_2), abs(sum_0[2]));
  794. *out_3++ = sum_0[3];
  795. *out_max_3 = TN_MAX(abs(*out_max_3), abs(sum_0[3]));
  796. *out_4++ = sum_0[4];
  797. *out_max_4 = TN_MAX(abs(*out_max_4), abs(sum_0[4]));
  798. *out_5++ = sum_0[5];
  799. *out_max_5 = TN_MAX(abs(*out_max_5), abs(sum_0[5]));
  800. *out_6++ = sum_0[6];
  801. *out_max_6 = TN_MAX(abs(*out_max_6), abs(sum_0[6]));
  802. *out_7++ = sum_0[7];
  803. *out_max_7 = TN_MAX(abs(*out_max_7), abs(sum_0[7]));
  804. *out_0++ = sum_1[0];
  805. *out_max_0 = TN_MAX(abs(*out_max_0), abs(sum_1[0]));
  806. *out_1++ = sum_1[1];
  807. *out_max_1 = TN_MAX(abs(*out_max_1), abs(sum_1[1]));
  808. *out_2++ = sum_1[2];
  809. *out_max_2 = TN_MAX(abs(*out_max_2), abs(sum_1[2]));
  810. *out_3++ = sum_1[3];
  811. *out_max_3 = TN_MAX(abs(*out_max_3), abs(sum_1[3]));
  812. *out_4++ = sum_1[4];
  813. *out_max_4 = TN_MAX(abs(*out_max_4), abs(sum_1[4]));
  814. *out_5++ = sum_1[5];
  815. *out_max_5 = TN_MAX(abs(*out_max_5), abs(sum_1[5]));
  816. *out_6++ = sum_1[6];
  817. *out_max_6 = TN_MAX(abs(*out_max_6), abs(sum_1[6]));
  818. *out_7++ = sum_1[7];
  819. *out_max_7 = TN_MAX(abs(*out_max_7), abs(sum_1[7]));
  820. *out_0++ = sum_2[0];
  821. *out_max_0 = TN_MAX(abs(*out_max_0), abs(sum_2[0]));
  822. *out_1++ = sum_2[1];
  823. *out_max_1 = TN_MAX(abs(*out_max_1), abs(sum_2[1]));
  824. *out_2++ = sum_2[2];
  825. *out_max_2 = TN_MAX(abs(*out_max_2), abs(sum_2[2]));
  826. *out_3++ = sum_2[3];
  827. *out_max_3 = TN_MAX(abs(*out_max_3), abs(sum_2[3]));
  828. *out_4++ = sum_2[4];
  829. *out_max_4 = TN_MAX(abs(*out_max_4), abs(sum_2[4]));
  830. *out_5++ = sum_2[5];
  831. *out_max_5 = TN_MAX(abs(*out_max_5), abs(sum_2[5]));
  832. *out_6++ = sum_2[6];
  833. *out_max_6 = TN_MAX(abs(*out_max_6), abs(sum_2[6]));
  834. *out_7++ = sum_2[7];
  835. *out_max_7 = TN_MAX(abs(*out_max_7), abs(sum_2[7]));
  836. *out_0++ = sum_3[0];
  837. *out_max_0 = TN_MAX(abs(*out_max_0), abs(sum_3[0]));
  838. *out_1++ = sum_3[1];
  839. *out_max_1 = TN_MAX(abs(*out_max_1), abs(sum_3[1]));
  840. *out_2++ = sum_3[2];
  841. *out_max_2 = TN_MAX(abs(*out_max_2), abs(sum_3[2]));
  842. *out_3++ = sum_3[3];
  843. *out_max_3 = TN_MAX(abs(*out_max_3), abs(sum_3[3]));
  844. *out_4++ = sum_3[4];
  845. *out_max_4 = TN_MAX(abs(*out_max_4), abs(sum_3[4]));
  846. *out_5++ = sum_3[5];
  847. *out_max_5 = TN_MAX(abs(*out_max_5), abs(sum_3[5]));
  848. *out_6++ = sum_3[6];
  849. *out_max_6 = TN_MAX(abs(*out_max_6), abs(sum_3[6]));
  850. *out_7++ = sum_3[7];
  851. *out_max_7 = TN_MAX(abs(*out_max_7), abs(sum_3[7]));
  852. }
  853. static inline void assign_sum_to_group_tmp_output_buffer_4row16col_int8(q31_t* out_0, q31_t* out_1, q31_t* out_2, q31_t* out_3, q31_t* out_4, q31_t* out_5,
  854. q31_t* out_6, q31_t* out_7, q31_t* out_8, q31_t* out_9, q31_t* out_10, q31_t* out_11, q31_t* out_12, q31_t* out_13, q31_t* out_14, q31_t* out_15,
  855. const q31_t* sum_0, const q31_t* sum_1, const q31_t* sum_2, const q31_t* sum_3,
  856. q31_t* out_max_0, q31_t* out_max_1, q31_t* out_max_2, q31_t* out_max_3, q31_t* out_max_4, q31_t* out_max_5, q31_t* out_max_6, q31_t* out_max_7,
  857. q31_t* out_max_8, q31_t* out_max_9, q31_t* out_max_10, q31_t* out_max_11, q31_t* out_max_12, q31_t* out_max_13, q31_t* out_max_14, q31_t* out_max_15) {
  858. *out_0++ = sum_0[0];
  859. *out_max_0 = TN_MAX(abs(*out_max_0), abs(sum_0[0]));
  860. *out_1++ = sum_0[1];
  861. *out_max_1 = TN_MAX(abs(*out_max_1), abs(sum_0[1]));
  862. *out_2++ = sum_0[2];
  863. *out_max_2 = TN_MAX(abs(*out_max_2), abs(sum_0[2]));
  864. *out_3++ = sum_0[3];
  865. *out_max_3 = TN_MAX(abs(*out_max_3), abs(sum_0[3]));
  866. *out_4++ = sum_0[4];
  867. *out_max_4 = TN_MAX(abs(*out_max_4), abs(sum_0[4]));
  868. *out_5++ = sum_0[5];
  869. *out_max_5 = TN_MAX(abs(*out_max_5), abs(sum_0[5]));
  870. *out_6++ = sum_0[6];
  871. *out_max_6 = TN_MAX(abs(*out_max_6), abs(sum_0[6]));
  872. *out_7++ = sum_0[7];
  873. *out_max_7 = TN_MAX(abs(*out_max_7), abs(sum_0[7]));
  874. *out_8++ = sum_0[8];
  875. *out_max_8 = TN_MAX(abs(*out_max_8), abs(sum_0[8]));
  876. *out_9++ = sum_0[9];
  877. *out_max_9 = TN_MAX(abs(*out_max_9), abs(sum_0[9]));
  878. *out_10++ = sum_0[10];
  879. *out_max_10 = TN_MAX(abs(*out_max_10), abs(sum_0[10]));
  880. *out_11++ = sum_0[11];
  881. *out_max_11 = TN_MAX(abs(*out_max_11), abs(sum_0[11]));
  882. *out_12++ = sum_0[12];
  883. *out_max_12 = TN_MAX(abs(*out_max_12), abs(sum_0[12]));
  884. *out_13++ = sum_0[13];
  885. *out_max_13 = TN_MAX(abs(*out_max_13), abs(sum_0[13]));
  886. *out_14++ = sum_0[14];
  887. *out_max_14 = TN_MAX(abs(*out_max_14), abs(sum_0[14]));
  888. *out_15++ = sum_0[15];
  889. *out_max_15 = TN_MAX(abs(*out_max_15), abs(sum_0[15]));
  890. *out_0++ = sum_1[0];
  891. *out_max_0 = TN_MAX(abs(*out_max_0), abs(sum_1[0]));
  892. *out_1++ = sum_1[1];
  893. *out_max_1 = TN_MAX(abs(*out_max_1), abs(sum_1[1]));
  894. *out_2++ = sum_1[2];
  895. *out_max_2 = TN_MAX(abs(*out_max_2), abs(sum_1[2]));
  896. *out_3++ = sum_1[3];
  897. *out_max_3 = TN_MAX(abs(*out_max_3), abs(sum_1[3]));
  898. *out_4++ = sum_1[4];
  899. *out_max_4 = TN_MAX(abs(*out_max_4), abs(sum_1[4]));
  900. *out_5++ = sum_1[5];
  901. *out_max_5 = TN_MAX(abs(*out_max_5), abs(sum_1[5]));
  902. *out_6++ = sum_1[6];
  903. *out_max_6 = TN_MAX(abs(*out_max_6), abs(sum_1[6]));
  904. *out_7++ = sum_1[7];
  905. *out_max_7 = TN_MAX(abs(*out_max_7), abs(sum_1[7]));
  906. *out_8++ = sum_1[8];
  907. *out_max_8 = TN_MAX(abs(*out_max_8), abs(sum_1[8]));
  908. *out_9++ = sum_1[9];
  909. *out_max_9 = TN_MAX(abs(*out_max_9), abs(sum_1[9]));
  910. *out_10++ = sum_1[10];
  911. *out_max_10 = TN_MAX(abs(*out_max_10), abs(sum_1[10]));
  912. *out_11++ = sum_1[11];
  913. *out_max_11 = TN_MAX(abs(*out_max_11), abs(sum_1[11]));
  914. *out_12++ = sum_1[12];
  915. *out_max_12 = TN_MAX(abs(*out_max_12), abs(sum_1[12]));
  916. *out_13++ = sum_1[13];
  917. *out_max_13 = TN_MAX(abs(*out_max_13), abs(sum_1[13]));
  918. *out_14++ = sum_1[14];
  919. *out_max_14 = TN_MAX(abs(*out_max_14), abs(sum_1[14]));
  920. *out_15++ = sum_1[15];
  921. *out_max_15 = TN_MAX(abs(*out_max_15), abs(sum_1[15]));
  922. *out_0++ = sum_2[0];
  923. *out_max_0 = TN_MAX(abs(*out_max_0), abs(sum_2[0]));
  924. *out_1++ = sum_2[1];
  925. *out_max_1 = TN_MAX(abs(*out_max_1), abs(sum_2[1]));
  926. *out_2++ = sum_2[2];
  927. *out_max_2 = TN_MAX(abs(*out_max_2), abs(sum_2[2]));
  928. *out_3++ = sum_2[3];
  929. *out_max_3 = TN_MAX(abs(*out_max_3), abs(sum_2[3]));
  930. *out_4++ = sum_2[4];
  931. *out_max_4 = TN_MAX(abs(*out_max_4), abs(sum_2[4]));
  932. *out_5++ = sum_2[5];
  933. *out_max_5 = TN_MAX(abs(*out_max_5), abs(sum_2[5]));
  934. *out_6++ = sum_2[6];
  935. *out_max_6 = TN_MAX(abs(*out_max_6), abs(sum_2[6]));
  936. *out_7++ = sum_2[7];
  937. *out_max_7 = TN_MAX(abs(*out_max_7), abs(sum_2[7]));
  938. *out_8++ = sum_2[8];
  939. *out_max_8 = TN_MAX(abs(*out_max_8), abs(sum_2[8]));
  940. *out_9++ = sum_2[9];
  941. *out_max_9 = TN_MAX(abs(*out_max_9), abs(sum_2[9]));
  942. *out_10++ = sum_2[10];
  943. *out_max_10 = TN_MAX(abs(*out_max_10), abs(sum_2[10]));
  944. *out_11++ = sum_2[11];
  945. *out_max_11 = TN_MAX(abs(*out_max_11), abs(sum_2[11]));
  946. *out_12++ = sum_2[12];
  947. *out_max_12 = TN_MAX(abs(*out_max_12), abs(sum_2[12]));
  948. *out_13++ = sum_2[13];
  949. *out_max_13 = TN_MAX(abs(*out_max_13), abs(sum_2[13]));
  950. *out_14++ = sum_2[14];
  951. *out_max_14 = TN_MAX(abs(*out_max_14), abs(sum_2[14]));
  952. *out_15++ = sum_2[15];
  953. *out_max_15 = TN_MAX(abs(*out_max_15), abs(sum_2[15]));
  954. *out_0++ = sum_3[0];
  955. *out_max_0 = TN_MAX(abs(*out_max_0), abs(sum_3[0]));
  956. *out_1++ = sum_3[1];
  957. *out_max_1 = TN_MAX(abs(*out_max_1), abs(sum_3[1]));
  958. *out_2++ = sum_3[2];
  959. *out_max_2 = TN_MAX(abs(*out_max_2), abs(sum_3[2]));
  960. *out_3++ = sum_3[3];
  961. *out_max_3 = TN_MAX(abs(*out_max_3), abs(sum_3[3]));
  962. *out_4++ = sum_3[4];
  963. *out_max_4 = TN_MAX(abs(*out_max_4), abs(sum_3[4]));
  964. *out_5++ = sum_3[5];
  965. *out_max_5 = TN_MAX(abs(*out_max_5), abs(sum_3[5]));
  966. *out_6++ = sum_3[6];
  967. *out_max_6 = TN_MAX(abs(*out_max_6), abs(sum_3[6]));
  968. *out_7++ = sum_3[7];
  969. *out_max_7 = TN_MAX(abs(*out_max_7), abs(sum_3[7]));
  970. *out_8++ = sum_3[8];
  971. *out_max_8 = TN_MAX(abs(*out_max_8), abs(sum_3[8]));
  972. *out_9++ = sum_3[9];
  973. *out_max_9 = TN_MAX(abs(*out_max_9), abs(sum_3[9]));
  974. *out_10++ = sum_3[10];
  975. *out_max_10 = TN_MAX(abs(*out_max_10), abs(sum_3[10]));
  976. *out_11++ = sum_3[11];
  977. *out_max_11 = TN_MAX(abs(*out_max_11), abs(sum_3[11]));
  978. *out_12++ = sum_3[12];
  979. *out_max_12 = TN_MAX(abs(*out_max_12), abs(sum_3[12]));
  980. *out_13++ = sum_3[13];
  981. *out_max_13 = TN_MAX(abs(*out_max_13), abs(sum_3[13]));
  982. *out_14++ = sum_3[14];
  983. *out_max_14 = TN_MAX(abs(*out_max_14), abs(sum_3[14]));
  984. *out_15++ = sum_3[15];
  985. *out_max_15 = TN_MAX(abs(*out_max_15), abs(sum_3[15]));
  986. }
  987. /* END: For Group Conv */
  988. /* END: Assign Output Functions */