import numpy as np
import scipy
from scipy.sparse.csgraph import minimum_spanning_tree
from scipy.sparse import coo_matrix
from collections import Counter

def create_problem(seed, N, dim = 2):
    """Generate a random Euclidean TSP instance with uniformly placed cities.

    Seeds NumPy's global random generator with ``seed``, draws ``N`` points
    uniformly from ``[0, 1)`` in every coordinate and computes all pairwise
    Euclidean distances.

    Args:
        seed: Value passed to ``np.random.seed``.
        N: Number of cities.
        dim: Number of coordinates per city.

    Returns:
        A tuple ``(cities, distance_matrix)`` where ``cities`` has shape
        ``(N, dim)`` and ``distance_matrix`` has shape ``(N, N)``.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.spatial`` is loaded on every SciPy release.
    from scipy.spatial.distance import cdist

    np.random.seed(seed)
    cities = np.random.random((N, dim))
    distance_matrix = cdist(cities, cities)

    return cities, distance_matrix

def create_problem_cluster(seed, N, n_clusters, dim = 2):
    """Generate a clustered Euclidean TSP instance inside the unit hypercube.

    For each cluster a centre is drawn uniformly from ``[0.25, 0.75)`` in
    every coordinate together with a random diagonal covariance (entries in
    ``[0, 2e-3)``). Cities are sampled from the corresponding normal
    distribution; samples falling outside ``[0, 1]`` in any coordinate are
    rejected and redrawn.

    Every cluster receives ``N // n_clusters`` cities, except the last one,
    which receives all remaining cities.

    Args:
        seed: Value passed to ``np.random.seed``.
        N: Total number of cities.
        n_clusters: Number of clusters to generate.
        dim: Number of coordinates per city.

    Returns:
        A tuple ``(cities, distance_matrix)`` where ``cities`` has shape
        ``(N, dim)`` and ``distance_matrix`` has shape ``(N, N)``.

    Raises:
        ValueError: If ``n_clusters`` is smaller than 1.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.spatial`` is loaded on every SciPy release.
    from scipy.spatial.distance import cdist

    if n_clusters < 1:
        raise ValueError("n_clusters must be at least 1")

    city_list = []
    np.random.seed(seed)

    for k in range(n_clusters):
        mean = 0.25 + np.random.random(dim) * 0.5
        # Scalars are drawn one at a time so the random stream is unchanged
        # for the default dim = 2.
        cov = 2e-3 * np.diag([np.random.random() for _ in range(dim)])

        # The last cluster absorbs whatever the integer division left over.
        if k == n_clusters - 1:
            max_cnt = N - len(city_list)
        else:
            max_cnt = N // n_clusters

        cnt = 0
        while cnt < max_cnt:
            z = np.random.multivariate_normal(mean, cov, 1)[0]

            # Rejection sampling keeps every city inside the unit hypercube.
            if np.all((z >= 0) & (z <= 1)):
                city_list.append(z)
                cnt += 1

    # An explicit reshape (rather than squeeze) keeps the array 2-D even when
    # only a single city was generated.
    cities = np.array(city_list, dtype=float).reshape(-1, dim)
    distance_matrix = cdist(cities, cities)

    return cities, distance_matrix

def create_problem_nonmetric(seed, N):
    """Generate a random symmetric distance matrix without metric structure.

    Seeds NumPy's global random generator with ``seed`` and draws an
    ``N x N`` uniform sample. Only its strictly-lower triangle is kept, and
    the result is mirrored and halved, so every off-diagonal entry equals
    half of one uniform draw and the diagonal is zero. The entries are
    independent draws; no triangle inequality is enforced.

    Args:
        seed: Value passed to ``np.random.seed``.
        N: Number of cities.

    Returns:
        The symmetric ``(N, N)`` distance matrix.
    """
    np.random.seed(seed)

    # Zero the diagonal and everything above it, keeping the lower triangle.
    lower = np.tril(np.random.random((N, N)), k=-1)

    # Mirror the lower triangle across the diagonal and average with zeros.
    return (lower + lower.T) / 2

def create_problem_cluster2(seed, N, n_clusters):
    """Generate a 2-D clustered TSP instance with well-separated clusters.

    Cluster centres are drawn uniformly from ``[0.1, 0.9)`` in both
    coordinates; a candidate is rejected when it lies closer than
    ``2 * cluster_radius + margin`` to an already accepted centre. Cities are
    then sampled from an isotropic normal distribution around each centre
    (standard deviation ``cluster_radius``); samples outside the unit square
    are rejected and redrawn.

    Every cluster receives ``N // n_clusters`` cities; the remainder goes to
    the last cluster.

    Args:
        seed: Value passed to ``np.random.seed`` so that the instance is
            reproducible. Pass ``None`` to leave the global random state
            untouched.
        N: Total number of cities.
        n_clusters: Number of clusters to generate.

    Returns:
        A tuple ``(cities, distance_matrix)`` where ``cities`` has shape
        ``(N, 2)`` and ``distance_matrix`` has shape ``(N, N)``.

    Raises:
        ValueError: If ``n_clusters`` is smaller than 1, or if the requested
            number of centres could not be placed within the attempt budget.
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not
    # guarantee that ``scipy.spatial`` is loaded on every SciPy release.
    from scipy.spatial.distance import cdist

    if n_clusters < 1:
        raise ValueError("n_clusters must be at least 1")

    if seed is not None:
        np.random.seed(seed)

    cluster_radius = 0.05  # per-axis standard deviation of a cluster (controls spread)
    margin = 0.1  # extra gap required on top of two cluster radii
    min_separation = 2 * cluster_radius + margin
    max_attempts = 1000

    # Pick cluster centres, rejecting candidates that are too close to an
    # already accepted centre.
    centers = []
    attempts = 0
    while len(centers) < n_clusters and attempts < max_attempts:
        candidate = 0.1 + 0.8 * np.random.rand(2)  # centre within [0.1, 0.9)
        too_close = any(
            np.linalg.norm(candidate - c) < min_separation for c in centers
        )
        if not too_close:
            centers.append(candidate)
        attempts += 1

    if len(centers) < n_clusters:
        raise ValueError("Не удалось разместить кластеры без пересечений")

    # Split the cities between clusters; the remainder goes to the last one.
    points_per_cluster = [N // n_clusters] * n_clusters
    points_per_cluster[-1] += N - sum(points_per_cluster)

    # The covariance is the same for every cluster, so build it once.
    cov = (cluster_radius ** 2) * np.eye(2)

    city_list = []
    for center, n_points in zip(centers, points_per_cluster):
        accepted = 0
        while accepted < n_points:
            z = np.random.multivariate_normal(center, cov)
            # Rejection sampling keeps every city inside the unit square.
            if 0 <= z[0] <= 1 and 0 <= z[1] <= 1:
                city_list.append(z)
                accepted += 1

    # The reshape keeps the array 2-D even when no city was generated.
    cities = np.array(city_list, dtype=float).reshape(-1, 2)
    distance_matrix = cdist(cities, cities)
    return cities, distance_matrix

def optimize_D(Din, lr, n_iter = 100, verbose = True):
    """Reweight a distance matrix so that its MST has few high-degree vertices.

    In every iteration the minimum spanning tree of the current matrix is
    computed and the degree of each vertex in that tree is counted. The row
    and column of every tree vertex are then shifted by
    ``lr * mean_mst_edge * (degree - 2)``: edges at vertices of degree above 2
    become longer, edges at leaves become shorter.

    The snapshot whose MST had the fewest vertices of degree greater than 2 is
    remembered (the earliest one on ties) and returned. The search stops early
    once that count reaches 0, because no later iteration can improve on it.

    Args:
        Din: Square distance matrix. It is not modified.
        lr: Step size of the reweighting, relative to the mean MST edge.
        n_iter: Maximum number of iterations.
        verbose: If true, print ``best <count>`` whenever a better snapshot
            is found.

    Returns:
        A copy of the best reweighted matrix found, with a zero diagonal. If
        ``n_iter`` is 0 an unmodified copy of ``Din`` is returned.
    """
    D = np.copy(Din)
    if not np.issubdtype(D.dtype, np.floating):
        # The in-place updates below add floats; integer storage would make
        # them fail.
        D = D.astype(float)

    best_cnt = float('inf')
    best_D = np.copy(Din)

    for _ in range(n_iter):
        mst_coo = coo_matrix(minimum_spanning_tree(D))

        # Iterate over the edges actually present: the tree may hold fewer
        # than n - 1 edges, e.g. when the graph is not connected.
        degrees = Counter()
        for u, v in zip(mst_coo.row, mst_coo.col):
            degrees[u] += 1
            degrees[v] += 1

        # The previous update also touched the diagonal; reset it before a
        # snapshot is taken.
        np.fill_diagonal(D, 0)

        cnt_deg_more_2 = sum(1 for degree in degrees.values() if degree > 2)

        if cnt_deg_more_2 < best_cnt:
            best_cnt = cnt_deg_more_2
            best_D = np.copy(D)

            if verbose:
                print('best', best_cnt)

        if best_cnt == 0:
            # Only strictly smaller counts replace the best snapshot, so
            # nothing can change any more.
            break

        # best_cnt > 0 implies the current tree has edges, so the mean is
        # well defined here.
        mean_edge = np.mean(mst_coo.data)

        for v, degree in degrees.items():
            shift = lr * mean_edge * (degree - 2)
            D[:, v] += shift
            D[v, :] += shift

    return best_D
