check_firmware.cpp 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. #include <Arduino.h>
  2. #include "check_firmware.h"
  3. #include "monocypher.h"
  4. #include "parameters.h"
  5. #include <string.h>
  /*
    Small base64 decoder for the standard alphabet (A-Z a-z 0-9 + /).

    Decodes characters from s until the terminating NUL or the first character
    outside the alphabet (for example '=' padding). Only complete 8-bit bytes
    are emitted; leftover bits are dropped, so padded and unpadded input both
    decode correctly. Any trailing bits are padding bits.

    s       - NUL-terminated base64 text
    out     - destination buffer, at least max_len bytes long
    max_len - capacity of out; decoding stops once out is full and nothing is
              ever written at or beyond out[max_len]
    returns the number of bytes written to out
  */
  static int32_t base64_decode(const char *s, uint8_t *out, const uint32_t max_len)
  {
      static const char b64[] = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
      uint32_t bits = 0;   // not-yet-emitted bits, right-aligned
      uint32_t nbits = 0;  // number of valid bits held in 'bits' (< 8 between characters)
      uint32_t n = 0;      // bytes written so far
      for (; *s != '\0'; s++) {
          const char *p = strchr(b64, *s);
          if (p == nullptr) {
              // '=' padding or any other non-alphabet character ends the input
              break;
          }
          bits = (bits << 6) | uint32_t(p - b64);
          nbits += 6;
          if (nbits >= 8) {
              // check capacity before touching out[] so we never write past max_len
              if (n >= max_len) {
                  break;
              }
              nbits -= 8;
              out[n++] = uint8_t(bits >> nbits);
              bits &= (1U << nbits) - 1;  // keep only the bits not yet emitted
          }
      }
      return int32_t(n);
  }
  45. bool CheckFirmware::check_partition(const uint8_t *flash, uint32_t flash_len,
  46. const uint8_t *lead_bytes, uint32_t lead_length,
  47. const app_descriptor_t *ad, const uint8_t public_key[32])
  48. {
  49. crypto_check_ctx ctx {};
  50. crypto_check_ctx_abstract *actx = (crypto_check_ctx_abstract*)&ctx;
  51. crypto_check_init(actx, ad->sign_signature, public_key);
  52. if (lead_length > 0) {
  53. crypto_check_update(actx, lead_bytes, lead_length);
  54. }
  55. crypto_check_update(actx, &flash[lead_length], flash_len-lead_length);
  56. return crypto_check_final(actx) == 0;
  57. }
  58. bool CheckFirmware::check_OTA_partition(const esp_partition_t *part, const uint8_t *lead_bytes, uint32_t lead_length, uint32_t &board_id)
  59. {
  60. Serial.printf("Checking partition %s\n", part->label);
  61. spi_flash_mmap_handle_t handle;
  62. const void *ptr = nullptr;
  63. auto ret = esp_partition_mmap(part, 0, part->size, SPI_FLASH_MMAP_DATA, &ptr, &handle);
  64. if (ret != ESP_OK) {
  65. Serial.printf("mmap failed\n");
  66. return false;
  67. }
  68. const uint8_t sig_rev[] = APP_DESCRIPTOR_REV;
  69. uint8_t sig[8];
  70. for (uint8_t i=0; i<8; i++) {
  71. sig[i] = sig_rev[7-i];
  72. }
  73. const app_descriptor_t *ad = (app_descriptor_t *)memmem(ptr, part->size, sig, sizeof(sig));
  74. if (ad == nullptr) {
  75. Serial.printf("app_descriptor not found\n");
  76. spi_flash_munmap(handle);
  77. return false;
  78. }
  79. Serial.printf("app descriptor at 0x%x size=%u id=%u\n", unsigned(ad)-unsigned(ptr), ad->image_size, ad->board_id);
  80. const uint32_t img_len = uint32_t(uintptr_t(ad) - uintptr_t(ptr));
  81. if (ad->image_size != img_len) {
  82. Serial.printf("app_descriptor bad size %u\n", ad->image_size);
  83. spi_flash_munmap(handle);
  84. return false;
  85. }
  86. board_id = ad->board_id;
  87. bool no_keys = true;
  88. for (uint8_t i=0; i<MAX_PUBLIC_KEYS; i++) {
  89. const char *b64_key = g.public_keys[i].b64_key;
  90. Serial.printf("Checking public key: '%s'\n", b64_key);
  91. const char *ktype = "PUBLIC_KEYV1:";
  92. if (strncmp(b64_key, ktype, strlen(ktype)) != 0) {
  93. continue;
  94. }
  95. no_keys = false;
  96. b64_key += strlen(ktype);
  97. uint8_t key[32];
  98. int32_t out_len = base64_decode(b64_key, key, sizeof(key));
  99. if (out_len != 32) {
  100. continue;
  101. }
  102. if (check_partition((const uint8_t *)ptr, img_len, lead_bytes, lead_length, ad, key)) {
  103. Serial.printf("check firmware good for key %u\n", i);
  104. spi_flash_munmap(handle);
  105. return true;
  106. }
  107. Serial.printf("check failed key %u\n", i);
  108. }
  109. spi_flash_munmap(handle);
  110. if (no_keys) {
  111. Serial.printf("No public keys - accepting firmware\n");
  112. return true;
  113. }
  114. Serial.printf("firmware failed checks\n");
  115. return false;
  116. }
  117. bool CheckFirmware::check_OTA_next(const uint8_t *lead_bytes, uint32_t lead_length)
  118. {
  119. const auto *running_part = esp_ota_get_running_partition();
  120. if (running_part == nullptr) {
  121. Serial.printf("No running OTA partition\n");
  122. return false;
  123. }
  124. const auto *part = esp_ota_get_next_update_partition(running_part);
  125. if (part == nullptr) {
  126. Serial.printf("No next OTA partition\n");
  127. return false;
  128. }
  129. uint32_t board_id=0;
  130. bool sig_ok = check_OTA_partition(part, lead_bytes, lead_length, board_id);
  131. // if app descriptor has a board ID and the ID is wrong then reject
  132. if (board_id != 0 && board_id != BOARD_ID) {
  133. return false;
  134. }
  135. if (g.lock_level == 0) {
  136. // if unlocked then accept any firmware
  137. return true;
  138. }
  139. return sig_ok;
  140. }
  141. bool CheckFirmware::check_OTA_running(void)
  142. {
  143. const auto *running_part = esp_ota_get_running_partition();
  144. if (running_part == nullptr) {
  145. Serial.printf("No running OTA partition\n");
  146. return false;
  147. }
  148. uint32_t board_id=0;
  149. return check_OTA_partition(running_part, nullptr, 0, board_id);
  150. }
  151. esp_err_t esp_partition_read_raw(const esp_partition_t* partition,
  152. size_t src_offset, void* dst, size_t size);