/* ----------------------------------------------------------------------
 * Project: TinyEngine
 * Title: kernel_element.h
 *
 * Reference papers:
 * - MCUNet: Tiny Deep Learning on IoT Devices, NeurIPS 2020
 * - MCUNetV2: Memory-Efficient Patch-based Inference for Tiny Deep Learning, NeurIPS 2021
 * - MCUNetV3: On-Device Training Under 256KB Memory, NeurIPS 2022
 * Contact authors:
 * - Wei-Ming Chen, wmchen@mit.edu
 * - Wei-Chen Wang, wweichen@mit.edu
 * - Ji Lin, jilin@mit.edu
 * - Ligeng Zhu, ligeng@mit.edu
 * - Song Han, songhan@mit.edu
 *
 * Target ISA: ARMv7E-M
 * -------------------------------------------------------------------- */
  18. #ifndef ARMNN_INCLUDE_KERNEL_ELEMENT_H_
  19. #define ARMNN_INCLUDE_KERNEL_ELEMENT_H_
  20. #include "mutable_function.h"
  21. #include "precision_cnt.h"
  22. #define loop_ele_ext() \
  23. sum = __SMLAD(col32[0], k_buf1[0], sum); \
  24. sum_2 = __SMLAD(col32[1], k_buf1[1], sum_2); \
  25. sum_3 = __SMLAD(col32[2], k_buf1[2], sum_3); \
  26. sum_4 = __SMLAD(col32[3], k_buf1[3], sum_4); \
  27. col32 += 4;\
  28. k_buf1 += 4; \
  29. #define loop_ele() \
  30. op_a = hpm_nn_read_q15x2(col_pos); \
  31. op_b = hpm_nn_read_q15x2(col_pos + input_ch); \
  32. \
  33. op_c = __PKHBT(op_b, op_a, 16); \
  34. op_a = __PKHTB(op_b, op_a, 16); \
  35. sum = __SMLAD(op_c, k_buf1[0], sum); \
  36. sum_2 = __SMLAD(op_a, k_buf1[q32_elements], sum_2); \
  37. \
  38. op_a = hpm_nn_read_q15x2(col_pos + 2); \
  39. op_b = hpm_nn_read_q15x2(col_pos + input_ch + 2); \
  40. \
  41. op_c = __PKHBT(op_b, op_a, 16); \
  42. op_a = __PKHTB(op_b, op_a, 16); \
  43. sum_3 = __SMLAD(op_c, k_buf1[q32_elements*2], sum_3); \
  44. sum_4 = __SMLAD(op_a, k_buf1[q32_elements*3], sum_4); \
  45. \
  46. col_pos += two_inch; \
  47. k_buf1++;
  48. /* end of loop_ele() */
  49. #define prepare_loops()\
  50. q7_t *out_1 = out + output_ch / output_scaler;\
  51. const int32_t *out_shift = output_shift;\
  52. const int32_t *out_mult = output_mult;\
  53. const int32_t *obias = bias;\
  54. uint16_t row_count = output_ch / 2;\
  55. q31_t *ksrc = &kbuf[0];\
  56. /* end of prepare_loops() */
  57. #define conv_1stloop_ele()\
  58. q31_t ch_0_out_0 = *obias;\
  59. q31_t ch_0_out_1 = *obias++;\
  60. q31_t ch_1_out_0 = *obias;\
  61. q31_t ch_1_out_1 = *obias++;\
  62. q31_t b0 = hpm_nn_read_q15x2_ia((const q15_t **)&ip_b0);\
  63. q31_t b1 = hpm_nn_read_q15x2_ia((const q15_t **)&ip_b1);\
  64. ch_0_out_0 = __SMLAD(*ksrc, b0, ch_0_out_0);\
  65. ch_0_out_1 = __SMLAD(*ksrc++, b1, ch_0_out_1);\
  66. ch_1_out_0 = __SMLAD(*ksrc2, b0, ch_1_out_0);\
  67. b0 = hpm_nn_read_q15x2_ia((const q15_t **)&ip_b0);\
  68. ch_1_out_1 = __SMLAD(*ksrc2++, b1, ch_1_out_1);\
  69. /* end of conv_1stloop_ele */
  70. #define conv_lastloop_ele()\
  71. b1 = hpm_nn_read_q15x2_ia((const q15_t **)&ip_b1);\
  72. \
  73. ch_0_out_0 = __SMLAD(*ksrc, b0, ch_0_out_0);\
  74. ch_0_out_1 = __SMLAD(*ksrc++, b1, ch_0_out_1);\
  75. ch_1_out_0 = __SMLAD(*ksrc2, b0, ch_1_out_0);\
  76. ch_1_out_1 = __SMLAD(*ksrc2++, b1, ch_1_out_1);\
  77. \
  78. ksrc = ksrc2;\
  79. /* end of conv_lastloop_ele */
  80. #define conv_midloop_ele(k_index) \
  81. b1 = hpm_nn_read_q15x2_ia(&ip_b1);\
  82. ch_0_out_0 = __SMLAD(ksrc[k_index], b0, ch_0_out_0);\
  83. ch_0_out_1 = __SMLAD(ksrc[k_index], b1, ch_0_out_1);\
  84. ch_1_out_0 = __SMLAD(ksrc2[k_index], b0, ch_1_out_0);\
  85. b0 = hpm_nn_read_q15x2_ia(&ip_b0);\
  86. ch_1_out_1 = __SMLAD(ksrc2[k_index], b1, ch_1_out_1);\
  87. /* end of conv_midloop_ele */
  88. #define conv_midloop_ptrele() \
  89. b1 = hpm_nn_read_q15x2_ia((const q15_t **)&ip_b1);\
  90. ch_0_out_0 = __SMLAD(*ksrc, b0, ch_0_out_0);\
  91. ch_0_out_1 = __SMLAD(*ksrc++, b1, ch_0_out_1);\
  92. ch_1_out_0 = __SMLAD(*ksrc2, b0, ch_1_out_0);\
  93. b0 = hpm_nn_read_q15x2_ia((const q15_t **)&ip_b0);\
  94. ch_1_out_1 = __SMLAD(*ksrc2++, b1, ch_1_out_1);\
  95. /* end of conv_midloop_ele */
  96. #define unroll_8inch()\
  97. prepare_loops();\
  98. while (row_count) {\
  99. const q15_t *ip_b0 = two_column_buffer;\
  100. const q15_t *ip_b1 = ip_b0 + 8;\
  101. q31_t *ksrc2 = ksrc + 4;\
  102. conv_1stloop_ele()\
  103. conv_midloop_ptrele()\
  104. conv_midloop_ptrele()\
  105. conv_lastloop_ele()\
  106. mix_assign_requantize()\
  107. row_count--;\
  108. }\
  109. /* Specialized Loop Unrolling */
  110. //this can be selected for different models
  111. #define unroll_8inch()\
  112. prepare_loops();\
  113. while (row_count) {\
  114. const q15_t *ip_b0 = two_column_buffer;\
  115. const q15_t *ip_b1 = ip_b0 + 8;\
  116. q31_t *ksrc2 = ksrc + 4;\
  117. conv_1stloop_ele()\
  118. conv_midloop_ptrele()\
  119. conv_midloop_ptrele()\
  120. conv_lastloop_ele()\
  121. mix_assign_requantize()\
  122. row_count--;\
  123. }\
  124. #define unroll_12inch()\
  125. prepare_loops();\
  126. while (row_count) {\
  127. const q15_t *ip_b0 = two_column_buffer;\
  128. const q15_t *ip_b1 = ip_b0 + 12;\
  129. q31_t *ksrc2 = ksrc + 6;\
  130. conv_1stloop_ele()\
  131. conv_midloop_ptrele()\
  132. conv_midloop_ptrele()\
  133. conv_midloop_ptrele()\
  134. conv_midloop_ptrele()\
  135. conv_lastloop_ele()\
  136. mix_assign_requantize()\
  137. row_count--;\
  138. }\
  139. #define unroll_16inch()\
  140. prepare_loops();\
  141. while (row_count) {\
  142. const q15_t *ip_b0 = two_column_buffer;\
  143. const q15_t *ip_b1 = ip_b0 + 16;\
  144. q31_t *ksrc2 = ksrc + 8;\
  145. conv_1stloop_ele()\
  146. conv_midloop_ptrele()\
  147. conv_midloop_ptrele()\
  148. conv_midloop_ptrele()\
  149. conv_midloop_ptrele()\
  150. conv_midloop_ptrele()\
  151. conv_midloop_ptrele()\
  152. conv_lastloop_ele()\
  153. mix_assign_requantize()\
  154. row_count--;\
  155. }\
  156. #define unroll_20inch()\
  157. prepare_loops();\
  158. while (row_count) {\
  159. const q15_t *ip_b0 = two_column_buffer;\
  160. const q15_t *ip_b1 = ip_b0 + 20;\
  161. q31_t *ksrc2 = ksrc + 10;\
  162. conv_1stloop_ele()\
  163. conv_midloop_ptrele()\
  164. conv_midloop_ptrele()\
  165. conv_midloop_ptrele()\
  166. conv_midloop_ptrele()\
  167. conv_midloop_ptrele()\
  168. conv_midloop_ptrele()\
  169. conv_midloop_ptrele()\
  170. conv_midloop_ptrele()\
  171. conv_lastloop_ele()\
  172. mix_assign_requantize()\
  173. row_count--;\
  174. }\
  175. #define unroll_24inch()\
  176. prepare_loops();\
  177. while (row_count) {\
  178. const q15_t *ip_b0 = two_column_buffer;\
  179. const q15_t *ip_b1 = ip_b0 + 24;\
  180. q31_t *ksrc2 = ksrc + 12;\
  181. conv_1stloop_ele()\
  182. conv_midloop_ptrele()\
  183. conv_midloop_ptrele()\
  184. conv_midloop_ptrele()\
  185. conv_midloop_ptrele()\
  186. conv_midloop_ptrele()\
  187. conv_midloop_ptrele()\
  188. conv_midloop_ptrele()\
  189. conv_midloop_ptrele()\
  190. conv_midloop_ptrele()\
  191. conv_midloop_ptrele()\
  192. conv_lastloop_ele()\
  193. mix_assign_requantize()\
  194. row_count--;\
  195. }\
  196. #define unroll_32inch()\
  197. prepare_loops();\
  198. while (row_count) {\
  199. const q15_t *ip_b0 = two_column_buffer;\
  200. const q15_t *ip_b1 = ip_b0 + 32;\
  201. q31_t *ksrc2 = ksrc + 16;\
  202. conv_1stloop_ele()\
  203. conv_midloop_ptrele()\
  204. conv_midloop_ptrele()\
  205. conv_midloop_ptrele()\
  206. conv_midloop_ptrele()\
  207. conv_midloop_ptrele()\
  208. conv_midloop_ptrele()\
  209. conv_midloop_ptrele()\
  210. conv_midloop_ptrele()\
  211. conv_midloop_ptrele()\
  212. conv_midloop_ptrele()\
  213. conv_midloop_ptrele()\
  214. conv_midloop_ptrele()\
  215. conv_midloop_ptrele()\
  216. conv_midloop_ptrele()\
  217. conv_lastloop_ele()\
  218. mix_assign_requantize()\
  219. row_count--;\
  220. }\
  221. #define unroll_36inch()\
  222. prepare_loops();\
  223. while (row_count) {\
  224. const q15_t *ip_b0 = two_column_buffer;\
  225. const q15_t *ip_b1 = ip_b0 + 36;\
  226. q31_t *ksrc2 = ksrc + 18;\
  227. conv_1stloop_ele()\
  228. conv_midloop_ptrele()\
  229. conv_midloop_ptrele()\
  230. conv_midloop_ptrele()\
  231. conv_midloop_ptrele()\
  232. conv_midloop_ptrele()\
  233. conv_midloop_ptrele()\
  234. conv_midloop_ptrele()\
  235. conv_midloop_ptrele()\
  236. conv_midloop_ptrele()\
  237. conv_midloop_ptrele()\
  238. conv_midloop_ptrele()\
  239. conv_midloop_ptrele()\
  240. conv_midloop_ptrele()\
  241. conv_midloop_ptrele()\
  242. conv_midloop_ptrele()\
  243. conv_midloop_ptrele()\
  244. conv_lastloop_ele()\
  245. mix_assign_requantize()\
  246. row_count--;\
  247. }\
  248. #define unroll_40inch()\
  249. prepare_loops();\
  250. while (row_count) {\
  251. const q15_t *ip_b0 = two_column_buffer;\
  252. const q15_t *ip_b1 = ip_b0 + 40;\
  253. q31_t *ksrc2 = ksrc + 20;\
  254. conv_1stloop_ele()\
  255. conv_midloop_ptrele()\
  256. conv_midloop_ptrele()\
  257. conv_midloop_ptrele()\
  258. conv_midloop_ptrele()\
  259. conv_midloop_ptrele()\
  260. conv_midloop_ptrele()\
  261. conv_midloop_ptrele()\
  262. conv_midloop_ptrele()\
  263. conv_midloop_ptrele()\
  264. conv_midloop_ptrele()\
  265. conv_midloop_ptrele()\
  266. conv_midloop_ptrele()\
  267. conv_midloop_ptrele()\
  268. conv_midloop_ptrele()\
  269. conv_midloop_ptrele()\
  270. conv_midloop_ptrele()\
  271. conv_midloop_ptrele()\
  272. conv_midloop_ptrele()\
  273. conv_lastloop_ele()\
  274. mix_assign_requantize()\
  275. row_count--;\
  276. }\
  277. #define unroll_48inch()\
  278. prepare_loops();\
  279. while (row_count) {\
  280. const q15_t *ip_b0 = two_column_buffer;\
  281. const q15_t *ip_b1 = ip_b0 + 48;\
  282. q31_t *ksrc2 = ksrc + 24;\
  283. conv_1stloop_ele()\
  284. conv_midloop_ptrele()\
  285. conv_midloop_ptrele()\
  286. conv_midloop_ptrele()\
  287. conv_midloop_ptrele()\
  288. conv_midloop_ptrele()\
  289. conv_midloop_ptrele()\
  290. conv_midloop_ptrele()\
  291. conv_midloop_ptrele()\
  292. conv_midloop_ptrele()\
  293. conv_midloop_ptrele()\
  294. conv_midloop_ptrele()\
  295. conv_midloop_ptrele()\
  296. conv_midloop_ptrele()\
  297. conv_midloop_ptrele()\
  298. conv_midloop_ptrele()\
  299. conv_midloop_ptrele()\
  300. conv_midloop_ptrele()\
  301. conv_midloop_ptrele()\
  302. conv_midloop_ptrele()\
  303. conv_midloop_ptrele()\
  304. conv_midloop_ptrele()\
  305. conv_midloop_ptrele()\
  306. conv_lastloop_ele()\
  307. mix_assign_requantize()\
  308. row_count--;\
  309. }\
  310. /* END: Specialized Loop Unrolling */
  311. #define b2_assign_requantize() \
  312. ch_0_out_0 = hpm_nn_requantize(ch_0_out_0, *out_mult,*out_shift);\
  313. ch_0_out_0 += out_offset;\
  314. ch_0_out_0 = MAX(ch_0_out_0, out_activation_min);\
  315. ch_0_out_0 = MIN(ch_0_out_0, out_activation_max);\
  316. \
  317. ch_0_out_1 = hpm_nn_requantize(ch_0_out_1, *out_mult,*out_shift);\
  318. ch_0_out_1 += out_offset;\
  319. ch_0_out_1 = MAX(ch_0_out_1, out_activation_min);\
  320. ch_0_out_1 = MIN(ch_0_out_1, out_activation_max);\
  321. out_mult++;\
  322. out_shift++;\
  323. ch_1_out_0 = hpm_nn_requantize(ch_1_out_0, *out_mult,*out_shift);\
  324. ch_1_out_0 += out_offset;\
  325. ch_1_out_0 = MAX(ch_1_out_0, out_activation_min);\
  326. ch_1_out_0 = MIN(ch_1_out_0, out_activation_max);\
  327. ch_1_out_1 = hpm_nn_requantize(ch_1_out_1, *out_mult,*out_shift);\
  328. ch_1_out_1 += out_offset;\
  329. ch_1_out_1 = MAX(ch_1_out_1, out_activation_min);\
  330. ch_1_out_1 = MIN(ch_1_out_1, out_activation_max);\
  331. if(lower_bit == 1){\
  332. *out = (q7_t) ((ch_0_out_0 & 0x03) + ((ch_1_out_0 & 0x03) << 2));\
  333. *out_1 = (q7_t) ((ch_0_out_0 & 0x03) + ((ch_1_out_1 & 0x03) << 2));\
  334. lower_bit = 3;\
  335. }\
  336. else{\
  337. *out++ += (q7_t) (((ch_0_out_0 & 0x03) + ((ch_1_out_0 & 0x03) << 2)) << 4);\
  338. *out_1++ += (q7_t) (((ch_0_out_1 & 0x03) + ((ch_1_out_1 & 0x03) << 2)) << 4);\
  339. lower_bit = 1;\
  340. }\
  341. out_mult++;\
  342. out_shift++;\
  343. #define b4_assign_requantize() \
  344. ch_0_out_0 = hpm_nn_requantize(ch_0_out_0, *out_mult,*out_shift);\
  345. ch_0_out_0 += out_offset;\
  346. ch_0_out_0 = MAX(ch_0_out_0, out_activation_min);\
  347. ch_0_out_0 = MIN(ch_0_out_0, out_activation_max);\
  348. \
  349. ch_0_out_1 = hpm_nn_requantize(ch_0_out_1, *out_mult,*out_shift);\
  350. ch_0_out_1 += out_offset;\
  351. ch_0_out_1 = MAX(ch_0_out_1, out_activation_min);\
  352. ch_0_out_1 = MIN(ch_0_out_1, out_activation_max);\
  353. out_mult++;\
  354. out_shift++;\
  355. ch_1_out_0 = hpm_nn_requantize(ch_1_out_0, *out_mult,*out_shift);\
  356. ch_1_out_0 += out_offset;\
  357. ch_1_out_0 = MAX(ch_1_out_0, out_activation_min);\
  358. ch_1_out_0 = MIN(ch_1_out_0, out_activation_max);\
  359. *out++ = (q7_t) ((ch_0_out_0 & 0x0F) + ((ch_1_out_0 & 0x0F) << 4));\
  360. ch_1_out_1 = hpm_nn_requantize(ch_1_out_1, *out_mult,*out_shift);\
  361. ch_1_out_1 += out_offset;\
  362. ch_1_out_1 = MAX(ch_1_out_1, out_activation_min);\
  363. ch_1_out_1 = MIN(ch_1_out_1, out_activation_max);\
  364. *out_1++ = (q7_t) ((ch_0_out_1 & 0x0F) + ((ch_1_out_1 & 0x0F) << 4));\
  365. out_mult++;\
  366. out_shift++;\
  367. #define assign_requantize() \
  368. ch_0_out_0 = hpm_nn_requantize(ch_0_out_0, *out_mult,*out_shift);\
  369. ch_0_out_0 += out_offset;\
  370. ch_0_out_0 = MAX(ch_0_out_0, out_activation_min);\
  371. ch_0_out_0 = MIN(ch_0_out_0, out_activation_max);\
  372. *out++ = (q7_t) ch_0_out_0;\
  373. \
  374. ch_0_out_1 = hpm_nn_requantize(ch_0_out_1, *out_mult,*out_shift);\
  375. ch_0_out_1 += out_offset;\
  376. ch_0_out_1 = MAX(ch_0_out_1, out_activation_min);\
  377. ch_0_out_1 = MIN(ch_0_out_1, out_activation_max);\
  378. *out_1++ = (q7_t) ch_0_out_1;\
  379. out_mult++;\
  380. out_shift++;\
  381. ch_1_out_0 = hpm_nn_requantize(ch_1_out_0, *out_mult,*out_shift);\
  382. ch_1_out_0 += out_offset;\
  383. ch_1_out_0 = MAX(ch_1_out_0, out_activation_min);\
  384. ch_1_out_0 = MIN(ch_1_out_0, out_activation_max);\
  385. *out++ = (q7_t) ch_1_out_0;\
  386. \
  387. ch_1_out_1 = hpm_nn_requantize(ch_1_out_1, *out_mult,*out_shift);\
  388. ch_1_out_1 += out_offset;\
  389. ch_1_out_1 = MAX(ch_1_out_1, out_activation_min);\
  390. ch_1_out_1 = MIN(ch_1_out_1, out_activation_max);\
  391. *out_1++ = (q7_t) ch_1_out_1;\
  392. out_mult++;\
  393. out_shift++;\
  394. /* end of assign_requantize */
  395. #endif /* ARMNN_INCLUDE_KERNEL_ELEMENT_H_ */