main.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. import os
  2. import matplotlib.pyplot as plt
  3. import numpy as np
  4. from skimage.filters import threshold_otsu
  5. from skimage.measure import label, regionprops
  6. from collections import defaultdict
  7. def filling_factor(arr):
  8. return np.sum(arr) / arr.size
  9. def count_holes(arr):
  10. labeled = label(np.logical_not(arr))
  11. regions = regionprops(labeled)
  12. holes = 0
  13. for region in regions:
  14. coords = np.transpose(region.coords, (1, 0))
  15. ymin = np.min(coords[1])
  16. ymax = np.max(coords[1])
  17. xmin = np.min(coords[0])
  18. xmax = np.max(coords[0])
  19. if (
  20. ymin == 0
  21. or ymax == arr.shape[1] - 1
  22. or xmin == 0
  23. or xmax == arr.shape[0] - 1
  24. ):
  25. continue
  26. holes += 1
  27. return holes
  28. def count_holes_rame(arr):
  29. labeled = label(np.logical_not(arr))
  30. return np.max(labeled)
  31. def count_vline(arr):
  32. return np.sum(arr.mean(0) == 1)
  33. def recognize(region):
  34. if filling_factor(region.image) == 1.0:
  35. return "-"
  36. else:
  37. holes = count_holes(region.image)
  38. if holes == 2: # B or 8
  39. if count_vline(region.image) >= 3:
  40. return "B"
  41. else:
  42. return "8"
  43. elif holes == 1: # A or 0
  44. if count_vline(region.image) >= 2:
  45. ecc = region.eccentricity
  46. if ecc < 0.65:
  47. return "D"
  48. else:
  49. return "P"
  50. else:
  51. if (
  52. abs(
  53. (region.local_centroid[0] / region.image.shape[0])
  54. - (region.local_centroid[1] / region.image.shape[1])
  55. )
  56. > 0.02
  57. ):
  58. return "A"
  59. else:
  60. return "0"
  61. else:
  62. if count_vline(region.image) >= 1:
  63. return "1"
  64. else:
  65. ecc = region.eccentricity
  66. if ecc < 0.4:
  67. return "*"
  68. match count_holes_rame(region.image):
  69. case 2:
  70. return "/"
  71. case 4:
  72. return "X"
  73. case _:
  74. return "W"
  75. img = plt.imread("./symbols.png")
  76. img = np.mean(img, axis=2)
  77. thrash = threshold_otsu(img)
  78. img[img > 0] = 1
  79. regions = regionprops(label(img))
  80. result = {}
  81. path = "./res"
  82. if not os.path.exists(path):
  83. os.mkdir(path)
  84. for i, region in enumerate(regions):
  85. symbol = recognize(region)
  86. if symbol in ["P", "D"]:
  87. plt.clf()
  88. plt.title(f"{symbol}")
  89. plt.imshow(region.image)
  90. plt.savefig(f"{path}/{i}")
  91. if symbol not in result.keys():
  92. result[symbol] = 0
  93. result[symbol] += 1
  94. print(result, sum(result.values()))