import sys
from skimage.measure import label
import matplotlib.pyplot as plt
import numpy as np


def dil(arr, struct):
    """Dilate a 2-D binary image with a 3x3 structuring element.

    Each truthy pixel of ``arr`` that is not on the outermost row or column
    stamps the truthy cells of ``struct`` onto the output, centred on itself.
    Pixels on the image border never act as stamp centres (the same
    interior-only convention ``erros`` uses). For example, a lone foreground
    pixel in row 0 produces an all-zero result.

    Args:
        arr: 2-D array; nonzero values are foreground.
        struct: Structuring element, 3x3 (or broadcastable to 3x3); nonzero
            cells belong to the element.

    Returns:
        Array with the same shape and dtype as ``arr``. It holds 1 (``True``
        for boolean input) where the dilation covers a pixel and 0 elsewhere.
    """
    res = np.zeros_like(arr)
    height, width = arr.shape[0], arr.shape[1]
    if height < 3 or width < 3:
        # No interior pixels, so nothing can act as a stamp centre.
        return res
    kernel = np.broadcast_to(np.asarray(struct).astype(bool), (3, 3))
    centres = arr[1 : height - 1, 1 : width - 1].astype(bool)
    covered = np.zeros((height, width), dtype=bool)
    # Loop over the (at most 9) kernel cells instead of over every pixel:
    # kernel cell (i, j) shifts the whole block of centres by (i-1, j-1).
    for i, j in zip(*np.nonzero(kernel)):
        covered[i : i + height - 2, j : j + width - 2] |= centres
    # Assigning a boolean mask casts it to arr's dtype (1/0, True/False, ...).
    res[...] = covered
    return res


def erros(arr, struct):
    """Erode a 2-D image with a 3x3 structuring element.

    An interior pixel (not on the outermost row or column) is set in the
    output when every pixel of its 3x3 neighbourhood is ``>=`` the
    corresponding cell of ``struct``. For a 0/1 image and a 0/1 element, this
    means every pixel under the element's 1-cells is foreground. Border
    pixels are always 0 in the result.

    Args:
        arr: 2-D array to erode.
        struct: Structuring element, 3x3 (or broadcastable to 3x3).

    Returns:
        Array with the same shape and dtype as ``arr``. It holds 1 (``True``
        for boolean input) where the element fits and 0 elsewhere.
    """
    result = np.zeros_like(arr)
    height, width = arr.shape[0], arr.shape[1]
    if height < 3 or width < 3:
        # No interior pixels: the result stays all zeros.
        return result
    kernel = np.broadcast_to(np.asarray(struct), (3, 3))
    fits = np.ones((height - 2, width - 2), dtype=bool)
    # Compare the whole interior against one kernel cell at a time instead of
    # one 3x3 window per pixel. All 9 cells are checked (including 0 cells),
    # exactly like the element-wise `window >= struct` test.
    for i in range(3):
        for j in range(3):
            fits &= arr[i : i + height - 2, j : j + width - 2] >= kernel[i, j]
    result[1 : height - 1, 1 : width - 1] = fits
    return result


def cl(arr, struct):
    """Erode ``arr`` by ``struct``, then dilate the result by the same element.

    NOTE: erosion followed by dilation is what morphology calls an *opening*,
    not a closing, despite this function's name. The name is kept for
    backward compatibility with existing callers.
    """
    eroded = erros(arr, struct)
    return dil(eroded, struct)


def op(arr, struct):
    """Dilate ``arr`` by ``struct``, then erode the result by the same element.

    NOTE: dilation followed by erosion is what morphology calls a *closing*,
    not an opening, despite this function's name. The name is kept for
    backward compatibility with existing callers.
    """
    dilated = dil(arr, struct)
    return erros(dilated, struct)


if len(sys.argv) < 2:
    print("No path provided")
    exit()

num = np.load(sys.argv[1])

m = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]])

plt.subplot(1, 2, 1)
plt.imshow(num)
plt.subplot(1, 2, 2)
plt.imshow(cl(num, m))

for w in np.unique(label(num))[1:]:
    res = label(cl(num, m))[label(num) == w]
    match len(np.unique(res)[1:]):
        case 1:
            print("Провод не порван")
        case 0:
            print("Провод уничтожен")
        case _:
            print(f"Провод {w} порван на {len(np.unique(res)[1:])}")

plt.show()