main.py 1.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. from skimage.measure import label, regionprops
  4. from skimage.color import rgb2hsv
  5. from collections import defaultdict
  def get_shapes(regions, tol=1e-6):
      """Count labelled regions by shape, as ``"rect"`` or ``"ball"``.

      A region whose eccentricity is (numerically) zero has equal principal
      second moments, as a square or a disk does.  Such a region is a ``"rect"``
      if it fills its bounding box completely and a ``"ball"`` otherwise.
      Every region with a non-zero eccentricity is counted as a ``"rect"``.

      Args:
          regions: Iterable of ``skimage.measure.regionprops`` results, or any
              objects exposing ``eccentricity``, ``area`` and ``image``.
          tol: Eccentricities at or below this value are treated as zero.
              Pixelation and floating-point rounding can leave a tiny non-zero
              eccentricity for a disk; the previous exact ``== 0`` comparison
              then misfiled the ball as a rect.

      Returns:
          A ``defaultdict(int)`` mapping ``"rect"`` / ``"ball"`` to counts.
          Only shapes that actually occurred are present as keys.
      """
      shapes = defaultdict(int)
      for region in regions:
          if region.eccentricity <= tol:
              # ``region.image`` is the mask cropped to the bounding box, so a
              # size equal to the area means the box is completely filled.
              key = "rect" if region.image.size == region.area else "ball"
          else:
              key = "rect"
          shapes[key] += 1
      return shapes
  # Segment every non-black object and report the total count, then group the
  # objects by hue and count balls/rects within each hue group.
  file = "./balls_and_rects.png"
  img = plt.imread(file)
  # plt.imread may return RGBA for PNGs.  Drop the alpha channel: otherwise an
  # opaque alpha makes every pixel non-zero in the channel mean below, and
  # rgb2hsv expects exactly three channels.
  if img.ndim == 3 and img.shape[2] == 4:
      img = img[:, :, :3]
  # Foreground = any pixel that is not pure black.
  binary = np.mean(img, 2)
  binary[binary > 0] = 1
  labeled = label(binary)
  print(f"Total f={labeled.max()}")
  hsv_image = rgb2hsv(img)
  hue = hsv_image[:, :, 0]
  foreground = binary > 0
  c = 0
  # Half-open hue bins [lo, hi): each pixel's hue lands in exactly one bin.
  # The old closed-interval masking kept hue == 0 pixels (e.g. pure red) in
  # both of the first two bins, and double-counted hues lying exactly on an
  # edge, so one colour could be reported twice.
  edges = np.linspace(0, 1, 10)
  for lo, hi in zip(edges[:-1], edges[1:]):
      mask = foreground & (hue >= lo) & (hue < hi)
      labeled = label(mask)
      if labeled.max() > 0:
          c += 1
          print(f"Color {c}:")
          shapes = get_shapes(regionprops(labeled))
          for cur_key, count in shapes.items():
              print(f"\t{cur_key}s: {count}")