cv_segm.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from skimage.draw import disk
  2. import numpy as np
  3. import matplotlib.pyplot as plt
  4. from sklearn.cluster import spectral_clustering
  5. from sklearn.feature_extraction import image
  6. import time
  7. def euclidian(x1, y1, x2, y2):
  8. return ((x1 - x2)**2 + (y1 - y2)**2)**0.5
  9. img = np.zeros((100, 100))
  10. # img[30:50, 30:50] = 1
  11. rr, cc = disk((35, 35), 20)
  12. img[rr, cc] = 1
  13. rr, cc = disk((80,35), 15)
  14. img[rr, cc] = 1
  15. rr, cc = disk((65, 65), 25)
  16. img[rr, cc] = 1
  17. start = time.perf_counter()
  18. mask2 = img.astype(bool)
  19. img2 = img.astype(float)
  20. graph = image.img_to_graph(img2, mask=mask2)
  21. graph.data = np.exp(-graph.data / graph.data.std())
  22. labels = spectral_clustering(graph, n_clusters=2, eigen_solver="arpack")
  23. label_im = np.full(mask2.shape, -1.0)
  24. label_im[mask2] = labels
  25. pos = np.where(img == 1)
  26. distance_map = img.copy()
  27. for y,x in zip(*pos):
  28. step = 1
  29. min_d = 10**19
  30. while True:
  31. big = np.ones((step+2, step+2))
  32. big[1:-1, 1:-1] = 0
  33. frame_y, frame_x = np.where(big==1)
  34. frame_y += y-(step//2)-1
  35. frame_x += x-(step//2)-1
  36. for ny, nx in zip(frame_y, frame_x):
  37. if img[ny ,nx] == 0:
  38. d = euclidian(y, x, ny, nx)
  39. if d < min_d:
  40. min_d = d
  41. if min_d != 10**19:
  42. distance_map[y, x] = min_d
  43. break
  44. step += 2
  45. centers_y = []
  46. center_x = []
  47. frame_size = (2,2)
  48. for y in range(frame_size[0],distance_map.shape[0]-frame_size[0]-1):
  49. for x in range(frame_size[1],distance_map.shape[1]-frame_size[1]-1):
  50. distances = distance_map[y-frame_size[0]:y+frame_size[0]+1, x-frame_size[1]:x+frame_size[1]+1].copy()
  51. distances[*frame_size] = 0.001
  52. if np.all(distances <= distance_map[y, x]) and np.all(distances != 0):
  53. ny, nx = np.where(distances == distance_map[y, x])
  54. print(ny, nx)
  55. if np.any(ny > frame_size[0]) or np.any(nx > frame_size[1]):
  56. continue
  57. centers_y.append(y)
  58. center_x.append(x)
  59. print(f"Elapsed for {time.perf_counter() - start}")
  60. plt.subplot(1,3,1)
  61. plt.imshow(img)
  62. plt.subplot(1,3,2)
  63. plt.imshow(distance_map)
  64. plt.scatter(center_x, centers_y)
  65. plt.subplot(1,3,3)
  66. plt.imshow(label_im)
  67. plt.show()