from skimage.draw import disk
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import spectral_clustering
from sklearn.feature_extraction import image
import time

def euclidian(x1, y1, x2, y2):
    """Return the Euclidean distance between points (x1, y1) and (x2, y2)."""
    delta_a = x1 - x2
    delta_b = y1 - y2
    return (delta_a**2 + delta_b**2) ** 0.5

# Synthetic binary test image: filled disks (value 1) on a 100x100 zero canvas.
# Each entry is ((row, col) centre, radius) as passed to skimage.draw.disk.
img = np.zeros((100, 100))
for centre, radius in [((35, 35), 20), ((80, 35), 15), ((65, 65), 25)]:
    rr, cc = disk(centre, radius)
    img[rr, cc] = 1


start = time.perf_counter()

# Spectral clustering of the foreground pixels into two segments.
mask2 = img.astype(bool)  # foreground = every nonzero pixel
img2 = img.astype(float)
graph = image.img_to_graph(img2, mask=mask2)
# Re-weight the graph's stored values as exp(-value / std).
graph.data = np.exp(-(graph.data / graph.data.std()))
labels = spectral_clustering(graph, n_clusters=2, eigen_solver="arpack")

# Label image: background stays -1.0, each foreground pixel gets its cluster id.
label_im = -np.ones(mask2.shape)
label_im[mask2] = labels

def _distance_to_background(binary, chunk_size=256):
    """Return the exact Euclidean distance from each foreground pixel to the
    nearest background pixel.

    Parameters
    ----------
    binary : array-like, 2-D
        Image in which pixels equal to 1 are foreground and pixels equal to 0
        are background.
    chunk_size : int
        Number of foreground pixels handled per vectorised batch; bounds the
        temporary (chunk_size x n_background) array.

    Returns
    -------
    numpy.ndarray
        Float copy of *binary* in which every foreground pixel holds its
        distance to the closest background pixel. Background pixels stay 0 and
        pixels with any other value are copied unchanged. If the image has no
        background pixel at all, foreground pixels are set to ``inf``.

    Only pixels inside the image count as background; nothing beyond the
    border is assumed, and indices never wrap around.
    """
    binary = np.asarray(binary)
    result = binary.astype(float)  # astype copies, so *binary* is untouched
    fg_y, fg_x = np.nonzero(binary == 1)
    bg_y, bg_x = np.nonzero(binary == 0)
    if fg_y.size == 0:
        return result
    if bg_y.size == 0:
        # No background anywhere: the distance is unbounded.
        result[fg_y, fg_x] = np.inf
        return result
    # Exhaustive search over all background pixels. A square-ring search that
    # stops at the first ring containing background can overshoot by up to
    # a factor of sqrt(2), because a diagonal hit in ring r may be farther
    # than an axis-aligned hit in a later ring.
    for lo in range(0, fg_y.size, chunk_size):
        ys = fg_y[lo:lo + chunk_size]
        xs = fg_x[lo:lo + chunk_size]
        dy = ys[:, None] - bg_y[None, :]
        dx = xs[:, None] - bg_x[None, :]
        # Minimise squared integer distances, then take one sqrt per pixel, so
        # equal distances produce bit-identical floats (the centre detection
        # compares values with ==).
        result[ys, xs] = np.sqrt((dy * dy + dx * dx).min(axis=1))
    return result


# Distance from every foreground pixel to the nearest background pixel.
distance_map = _distance_to_background(img)

# Pick "centres" as local maxima of the distance map. A pixel qualifies when
# no pixel in its (2*fy+1) x (2*fx+1) neighbourhood is larger and the
# neighbourhood contains no background (0) pixel. Ties are broken in raster
# (row-major) order: a pixel is rejected when an equal pixel in its
# neighbourhood comes after it, so two adjacent equal maxima yield one centre.
centers_y = []
center_x = []
frame_size = (2, 2)
fy, fx = frame_size
rows, cols = distance_map.shape
# Only pixels whose full neighbourhood lies inside the image are considered.
for y in range(fy, rows - fy):
    for x in range(fx, cols - fx):
        peak = distance_map[y, x]
        window = distance_map[y - fy:y + fy + 1, x - fx:x + fx + 1]
        if np.any(window == 0) or np.any(window > peak):
            continue
        tie_y, tie_x = np.nonzero(window == peak)
        # The pixel itself sits at (fy, fx) in the window; any equal pixel
        # later in raster order claims this peak instead.
        if np.any((tie_y > fy) | ((tie_y == fy) & (tie_x > fx))):
            continue
        centers_y.append(y)
        center_x.append(x)



print(f"Elapsed for {time.perf_counter() - start}")

# Three side-by-side panels: the input image, the distance map with the
# detected centres overlaid, and the cluster label image.
panels = (img, distance_map, label_im)
for panel_no, panel in enumerate(panels, start=1):
    plt.subplot(1, len(panels), panel_no)
    plt.imshow(panel)
    if panel_no == 2:
        plt.scatter(center_x, centers_y)
plt.show()