tinyengine_function_fp.h 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  /* ----------------------------------------------------------------------
   * Project: Tiny Training Engine, MCUNetV3
   * Title:   tinyengine_function_fp.h
   *
   * Reference papers:
   *  - MCUNet: Tiny Deep Learning on IoT Device, NeurIPS 2020
   *  - MCUNetV2: Memory-Efficient Patch-based Inference for Tiny Deep Learning, NeurIPS 2021
   *  - MCUNetV3: On-Device Training Under 256KB Memory, NeurIPS 2022
   * Contact authors:
   *  - Wei-Chen Wang, wweichen@mit.edu
   *  - Wei-Ming Chen, wmchen@mit.edu
   *  - Ji Lin, jilin@mit.edu
   *  - Ligeng Zhu, ligeng@mit.edu
   *  - Song Han, songhan@mit.edu
   *  - Chuang Gan, ganchuang@csail.mit.edu
   *
   * Target ISA:  ARMv7E-M
   * -------------------------------------------------------------------- */
  19. #include <stdint.h>
  20. #include <complex.h>
  21. #include <stdio.h>
  22. #include <stdbool.h>
  23. #include <math.h>
  24. #include <float.h>
  /**
   * Status code used as the return type of the floating-point kernels
   * declared in this header.
   *
   * The definition carries its own guard: enumerators may not be declared
   * twice in one translation unit, so without it a second inclusion of this
   * header would fail to compile.
   */
  #ifndef TINYENGINE_STATUS_FP_DEFINED
  #define TINYENGINE_STATUS_FP_DEFINED
  typedef enum {
      STATE_SUCCESS_fp = 0,    /* No error */
      PARAM_NO_SUPPORT_fp = 1, /* Unsupported parameters */
  } tinyengine_status_fp;
  #endif /* TINYENGINE_STATUS_FP_DEFINED */
  /*
   * Type-generic maximum / minimum of two expressions.
   *
   * Both arguments and the whole expansion are parenthesised, so the macros
   * are safe inside larger expressions. The selected argument is evaluated
   * twice (once in the comparison, once as the result), so never pass an
   * expression with side effects, e.g. TN_MAX(i++, j).
   */
  #define TN_MAX(A,B) ((A) > (B) ? (A) : (B))
  #define TN_MIN(A,B) ((A) < (B) ? (A) : (B))
  31. tinyengine_status_fp add_fp(const uint16_t size, const float* input1_data,
  32. const float* input2_data, float* output_data);
  33. tinyengine_status_fp div_fp(const uint16_t size, const float* input1_data,
  34. const float* input2_data, float* output_data);
  35. tinyengine_status_fp less(const uint16_t size, const float* input1_data,
  36. const float* input2_data, bool* output_data);
  37. tinyengine_status_fp LogSoftmax(const float* input_data, const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  38. float* output_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth);
  39. tinyengine_status_fp mul(const uint16_t size, const float* input1_data,
  40. const float* input2_data, float* output_data);
  41. tinyengine_status_fp negative(const uint16_t size, const float* input1_data, bool* output_data);
  42. tinyengine_status_fp nll_loss(const float* input_data, const uint16_t input_dim, const uint16_t input_depth,
  43. const float* target, const uint16_t target_size, float* output_data);
  44. tinyengine_status_fp strided_slice_3Dto3D(const float* input, const uint16_t input_h, const uint16_t input_w, const uint16_t input_c,
  45. const uint16_t* begin, const uint16_t* end, const uint16_t* stride,
  46. float* output, const uint16_t output_h, const uint16_t output_w, const uint16_t output_c);
  47. tinyengine_status_fp strided_slice_4Dto4D(const float* input, const uint16_t inn, const uint16_t inc, const uint16_t inh, const uint16_t inw,
  48. const uint16_t* begin, const uint16_t* end, const uint16_t* stride,
  49. float* output, const uint16_t on, const uint16_t oc, const uint16_t oh, const uint16_t ow);
  50. tinyengine_status_fp sub(const uint16_t size, const float* input1_data,
  51. const float* input2_data, float* output_data);
  52. tinyengine_status_fp sum_2D(const float* input_data, const uint16_t matA_row,
  53. const uint16_t matA_col, const uint16_t axis, float* output_data);
  54. tinyengine_status_fp sum_3D(const float* input_data, const uint16_t input_w, const uint16_t input_h,
  55. const uint16_t input_c, const uint16_t axis, float* output_data);
  56. tinyengine_status_fp sum_4D_exclude(const float* input_data, const uint16_t d1, const uint16_t d2,
  57. const uint16_t d3, const uint16_t d4, const uint16_t axis, float* output_data);
  58. tinyengine_status_fp tte_exp(const uint16_t size, const float* input_data, float* output_data);
  59. tinyengine_status_fp where(const bool* inMask, const uint16_t size, const float* input1_data,
  60. const float* input2_data, float* output_data);
  61. tinyengine_status_fp where_zeros(const bool* inMask, const uint16_t size, const float* input1_data, float* output_data);
  62. tinyengine_status_fp where_zeros_inplace(const bool* inMask, const uint16_t size, float* input1_data);
  63. tinyengine_status_fp where_zeros_inplace_bit(const unsigned char* inMask, const uint16_t size, float* input1_data);
  64. tinyengine_status_fp group_pointwise_conv_fp_in1x1_out1x1_1row10col_uniweight_int8input_inplace(const int8_t* input_data,
  65. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  66. const float* filter_data, const float* bias_data,
  67. int8_t* output_weight_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  68. const float output_activation_min, const float output_activation_max,
  69. float* im2col_data, const uint16_t batches, const uint16_t groups,
  70. const float* scales, const float learning_rate);
  71. tinyengine_status_fp group_pointwise_conv_fp_in1x1_out1x1_1row10col_uniweight_inplace(const float* input_data,
  72. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  73. const float* filter_data, const float* bias_data,
  74. int8_t* output_weight_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  75. const float output_activation_min, const float output_activation_max,
  76. float* im2col_data, const uint16_t batches, const uint16_t groups,
  77. const float* scales, const float learning_rate);
  78. tinyengine_status_fp group_conv_fp_kernel4_stride1_pad0_in4x4_out1x1_uniweight_4row16col_inplace(const float* input_data,
  79. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  80. const float* filter_data, const float* bias_data,
  81. int8_t* output_weight_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  82. const float output_activation_min, const float output_activation_max,
  83. float* im2col_data, const uint16_t batches, const uint16_t groups,
  84. const float* scales, const float learning_rate);
  85. tinyengine_status_fp group_conv_fp_kernel4_stride1_pad0_in4x4_out1x1_uniweight_4row8col_inplace(const float* input_data,
  86. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  87. const float* filter_data, const float* bias_data,
  88. int8_t* output_weight_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  89. const float output_activation_min, const float output_activation_max,
  90. float* im2col_data, const uint16_t batches, const uint16_t groups,
  91. const float* scales, const float learning_rate);
  92. tinyengine_status_fp group_conv_fp_kernel8_stride1_pad0_in8x8_out1x1_uniweight_4row16col_inplace(const float* input_data,
  93. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  94. const float* filter_data, const float* bias_data,
  95. int8_t* output_weight_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  96. const float output_activation_min, const float output_activation_max,
  97. float* im2col_data, const uint16_t batches, const uint16_t groups,
  98. const float* scales, const float learning_rate);
  99. tinyengine_status_fp group_conv_fp_kernel8_stride1_pad0_in8x8_out1x1_uniweight_4row8col_inplace(const float* input_data,
  100. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  101. const float* filter_data, const float* bias_data,
  102. int8_t* output_weight_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  103. const float output_activation_min, const float output_activation_max,
  104. float* im2col_data, const uint16_t batches, const uint16_t groups,
  105. const float* scales, const float learning_rate);
  106. tinyengine_status_fp transpose_depthwise_conv_fp_kernel3_stride1_inpad1_outpad0_IOHW_int8weight_partialCH(float* input_output_data,
  107. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  108. const int8_t* filter_sram, const int8_t* filter_flash, const uint16_t first_k_channel, const float* bias_data,
  109. float* output_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  110. const float output_activation_min, const float output_activation_max,
  111. float* im2col_data, const uint16_t batches, const int pad_value);
  112. tinyengine_status_fp transpose_depthwise_conv_fp_kernel3_stride1_inpad1_outpad0_IOHW_int8weight(float* input_output_data,
  113. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  114. const int8_t* filter_data, const float* bias_data,
  115. float* output_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  116. const float output_activation_min, const float output_activation_max,
  117. float* im2col_data, const uint16_t batches, const int pad_value);
  118. tinyengine_status_fp transpose_depthwise_conv_fp_kernel3_stride2_inpad1_outpad1_IOHW_int8weight_partialCH(float* input_data,
  119. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  120. const int8_t* filter_sram, const int8_t* filter_flash, const uint16_t first_k_channel, const float* bias_data,
  121. float* output_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  122. const float output_activation_min, const float output_activation_max,
  123. float* im2col_data, const uint16_t batches, const int pad_value);
  124. tinyengine_status_fp transpose_depthwise_conv_fp_kernel3_stride2_inpad1_outpad1_IOHW_int8weight(float* input_data,
  125. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  126. const int8_t* filter_data, const float* bias_data,
  127. float* output_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  128. const float output_activation_min, const float output_activation_max,
  129. float* im2col_data, const uint16_t batches, const int pad_value);
  130. tinyengine_status_fp transpose_depthwise_conv_fp_kernel5_stride1_inpad2_outpad0_IOHW_int8weight_partialCH(float* input_output_data,
  131. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  132. const int8_t* filter_sram, const int8_t* filter_flash, const uint16_t first_k_channel, const float* bias_data,
  133. float* output_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  134. const float output_activation_min, const float output_activation_max,
  135. float* im2col_data, const uint16_t batches, const int pad_value);
  136. tinyengine_status_fp transpose_depthwise_conv_fp_kernel5_stride1_inpad2_outpad0_IOHW_int8weight(float* input_output_data,
  137. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  138. const int8_t* filter_data, const float* bias_data,
  139. float* output_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  140. const float output_activation_min, const float output_activation_max,
  141. float* im2col_data, const uint16_t batches, const int pad_value);
  142. tinyengine_status_fp transpose_depthwise_conv_fp_kernel5_stride2_inpad2_outpad1_IOHW_int8weight_partialCH(float* input_data,
  143. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  144. const int8_t* filter_sram, const int8_t* filter_flash, const uint16_t first_k_channel, const float* bias_data,
  145. float* output_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  146. const float output_activation_min, const float output_activation_max,
  147. float* im2col_data, const uint16_t batches, const int pad_value);
  148. tinyengine_status_fp transpose_depthwise_conv_fp_kernel5_stride2_inpad2_outpad1_IOHW_int8weight(float* input_data,
  149. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  150. const int8_t* filter_data, const float* bias_data,
  151. float* output_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  152. const float output_activation_min, const float output_activation_max,
  153. float* im2col_data, const uint16_t batches, const int pad_value);
  154. tinyengine_status_fp transpose_depthwise_conv_fp_kernel7_stride1_inpad3_outpad0_IOHW_int8weight_partialCH(float* input_output_data,
  155. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  156. const int8_t* filter_sram, const int8_t* filter_flash, const uint16_t first_k_channel, const float* bias_data,
  157. float* output_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  158. const float output_activation_min, const float output_activation_max,
  159. float* im2col_data, const uint16_t batches, const int pad_value);
  160. tinyengine_status_fp transpose_depthwise_conv_fp_kernel7_stride1_inpad3_outpad0_IOHW_int8weight(float* input_output_data,
  161. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  162. const int8_t* filter_data, const float* bias_data,
  163. float* output_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  164. const float output_activation_min, const float output_activation_max,
  165. float* im2col_data, const uint16_t batches, const int pad_value);
  166. tinyengine_status_fp transpose_depthwise_conv_fp_kernel7_stride2_inpad3_outpad1_IOHW_int8weight_partialCH(float* input_data,
  167. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  168. const int8_t* filter_sram, const int8_t* filter_flash, const uint16_t first_k_channel, const float* bias_data,
  169. float* output_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  170. const float output_activation_min, const float output_activation_max,
  171. float* im2col_data, const uint16_t batches, const int pad_value);
  172. tinyengine_status_fp transpose_depthwise_conv_fp_kernel7_stride2_inpad3_outpad1_IOHW_int8weight(float* input_data,
  173. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  174. const int8_t* filter_data, const float* bias_data,
  175. float* output_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  176. const float output_activation_min, const float output_activation_max,
  177. float* im2col_data, const uint16_t batches, const int pad_value);
  178. tinyengine_status_fp pointwise_conv_fp_1row10col_10inputdepth_IOHW_int8weight(const float* input_data, const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  179. const int8_t* filter_data, const float* bias_data,
  180. float* output_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  181. const float output_activation_min, const float output_activation_max,
  182. float* im2col_data, const uint16_t batches);
  183. tinyengine_status_fp pointwise_conv_fp_4row4col_IOHW_int8weight(const float* input_data,
  184. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  185. const int8_t* filter_data, const float* bias_data,
  186. float* output_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  187. const float output_activation_min, const float output_activation_max,
  188. float* im2col_data, const uint16_t batches);
  189. tinyengine_status_fp pointwise_conv_fp_4row4col_IOHW_int8weight_partialCH_8innercol(const float* input_data,
  190. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  191. const int8_t* filter_sram, const int8_t* filter_flash, const uint16_t first_k_channel, const float* bias_data,
  192. float* output_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  193. const float output_activation_min, const float output_activation_max,
  194. float* im2col_data, const uint16_t batches);
  195. tinyengine_status_fp pointwise_conv_fp_4row4col_IOHW_int8weight_partialCH_4innercol(const float* input_data,
  196. const uint16_t input_height, const uint16_t input_width, const uint16_t input_depth,
  197. const int8_t* filter_sram, const int8_t* filter_flash, const uint16_t first_k_channel, const float* bias_data,
  198. float* output_data, const uint16_t output_height, const uint16_t output_width, const uint16_t output_depth,
  199. const float output_activation_min, const float output_activation_max,
  200. float* im2col_data, const uint16_t batches);