main.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. from skimage.filters import threshold_otsu
  4. from skimage.measure import label, regionprops
  5. from collections import defaultdict
  6. from pathlib import Path
  7. def filling_factor(arr):
  8. return np.sum(arr) / arr.size
  9. def count_holes(arr):
  10. labeled = label(np.logical_not(arr))
  11. regions = regionprops(labeled)
  12. holes = 0
  13. for region in regions:
  14. coords = np.transpose(region.coords, (1,0))
  15. ymin = np.min(coords[1])
  16. ymax = np.max(coords[1])
  17. xmin = np.min(coords[0])
  18. xmax = np.max(coords[0])
  19. if ymin == 0 or ymax == arr.shape[1]-1 or xmin == 0 or xmax == arr.shape[0]-1: continue
  20. holes += 1
  21. return holes
  22. def count_holes_rame(arr):
  23. labeled = label(np.logical_not(arr))
  24. return np.max(labeled)
  25. def count_vline(arr):
  26. return np.sum(arr.mean(0) == 1)
  27. def recognize(region):
  28. if filling_factor(region.image) == 1.0:
  29. return '-'
  30. else:
  31. holes = count_holes(region.image)
  32. if holes == 2: # B or 8
  33. if count_vline(region.image) >= 3:
  34. return 'B'
  35. else: return '8'
  36. elif holes == 1: #A or 0
  37. if count_vline(region.image) >= 2:
  38. if abs((region.local_centroid[0] / region.image.shape[0]) - (region.local_centroid[1] / region.image.shape[1])) > 0.035:
  39. return 'D'
  40. else: return 'P'
  41. elif abs((region.local_centroid[0] / region.image.shape[0]) - (region.local_centroid[1] / region.image.shape[1])) > 0.02:
  42. return 'A'
  43. else: return '0'
  44. else:
  45. if count_vline(region.image) >= 1:
  46. return '1'
  47. else:
  48. if region.eccentricity < 0.4:
  49. return '*'
  50. else:
  51. match count_holes_rame(region.image):
  52. case 2: return '/'
  53. case 4: return 'X'
  54. case 5: return 'W'
  55. return '_'
  56. img = plt.imread('/home/jezv/Projects/Volovikov_CV/alphabet/symbols.png')
  57. img = np.mean(img, axis=2)
  58. thrash = threshold_otsu(img)
  59. img[img > 0] = 1
  60. regions = regionprops(label(img))
  61. result = defaultdict(lambda: 0)
  62. path = Path('.') / 'alphabet' /'result'
  63. path.mkdir(exist_ok=True)
  64. for i,region in enumerate(regions):
  65. symbol = recognize(region)
  66. result[symbol] += 1
  67. if symbol in 'PD':
  68. plt.clf()
  69. plt.title(f"{symbol}")
  70. plt.imshow(region.image)
  71. plt.savefig(path / f"{i}")
  72. print(result, sum(result.values()))