123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
import time

import matplotlib.pyplot as plt
import numpy as np
from scipy import ndimage
from skimage.draw import disk
from sklearn.cluster import spectral_clustering
from sklearn.feature_extraction import image
def euclidian(x1, y1, x2, y2):
    """Return the straight-line (Euclidean) distance between (x1, y1) and (x2, y2)."""
    dx = x1 - x2
    dy = y1 - y2
    return (dx ** 2 + dy ** 2) ** 0.5
# Synthetic test scene: a 100x100 canvas holding three partly overlapping
# filled disks. Pixels inside any disk are 1.0, all others are 0.0.
img = np.zeros((100, 100))
for center, radius in (((35, 35), 20), ((80, 35), 15), ((65, 65), 25)):
    rr, cc = disk(center, radius)
    img[rr, cc] = 1
# Everything from here up to the final "Elapsed" print is included in the timing.
start = time.perf_counter()

# Build a pixel-adjacency graph restricted to the foreground (non-zero) pixels.
mask2 = img.astype(bool)
img2 = img.astype(float)
graph = image.img_to_graph(img2, mask=mask2)

# Convert the graph values into similarities: a value of 0 maps to weight 1,
# and larger values decay exponentially, scaled by the spread of the data.
weight_scale = graph.data.std()
graph.data = np.exp(-graph.data / weight_scale)
labels = spectral_clustering(graph, n_clusters=2, eigen_solver="arpack")

# Scatter the per-pixel cluster ids back into image space.
# Pixels outside the mask stay at -1.
label_im = np.full(mask2.shape, -1.0)
label_im[mask2] = labels
# Distance map: each foreground pixel (non-zero in img) gets the exact
# Euclidean distance to the nearest zero pixel inside the array; background
# pixels stay 0. Pixels outside the image are not treated as background.
# An exact transform matters for the local-maximum search that follows.
# A square-ring search that stops at the first ring containing background
# can miss a closer pixel in the next ring: for ring radius k >= 3, offset
# (k + 1, 0) is nearer than (k, k). Such a search would also index past
# the image edges.
distance_map = ndimage.distance_transform_edt(img)
# Candidate disk centres: foreground pixels that are local maxima of the
# distance map inside a (2*fy + 1) x (2*fx + 1) window made only of
# foreground pixels (distance > 0).
centers_y = []
center_x = []
frame_size = (2, 2)
fy, fx = frame_size
height, width = distance_map.shape
# The last centre whose window still fits is height - fy - 1, and range()
# excludes its end, so the bound is height - fy (likewise for columns).
for y in range(fy, height - fy):
    for x in range(fx, width - fx):
        value = distance_map[y, x]
        window = distance_map[y - fy:y + fy + 1, x - fx:x + fx + 1]
        # Reject background pixels, windows that touch background, and
        # pixels that have a strictly larger neighbour.
        if value <= 0 or window.min() <= 0 or window.max() > value:
            continue
        # Plateau tie-break: if an equal-valued neighbour comes later in
        # raster order (a lower row, or the same row further right), defer
        # to it. This is a strict order, so of any tied pair exactly one
        # pixel survives. The window centre itself is never "later".
        tie_rows, tie_cols = np.nonzero(window == value)
        later = (tie_rows > fy) | ((tie_rows == fy) & (tie_cols > fx))
        if later.any():
            continue
        centers_y.append(y)
        center_x.append(x)
elapsed = time.perf_counter() - start
print(f"Elapsed for {elapsed}")

# Panels, left to right: input image, distance map with the detected centres
# overlaid, and spectral-clustering labels.
for panel_index, panel in enumerate((img, distance_map, label_im), start=1):
    plt.subplot(1, 3, panel_index)
    plt.imshow(panel)
    if panel_index == 2:
        plt.scatter(center_x, centers_y)
plt.show()
|