import numpy as np
from collections import Counter
from scipy.sparse.csgraph import minimum_spanning_tree
from scipy.sparse import coo_matrix

def _mst_edges(W):
    """Return (rows, cols) index arrays of a minimum spanning tree of dense W."""
    off_diag = ~np.eye(W.shape[0], dtype=bool)
    if np.any(np.isclose(W[off_diag], 0.0)):
        # scipy treats zero entries of a dense input as missing edges, which
        # can disconnect the graph. Adding the same constant to every edge
        # leaves the MST unchanged but makes all edge weights >= 1.
        W = np.where(off_diag, W - W[off_diag].min() + 1.0, 0.0)
    tree = coo_matrix(minimum_spanning_tree(W))
    return tree.row, tree.col


def _one_tree_degrees(D):
    """Return (degree Counter, mean MST edge weight) for the minimum 1-tree of D.

    The 1-tree is an MST over vertices 1..n-1 plus the two cheapest edges from
    vertex 0. Vertex 0 itself is not stored in the Counter: its degree is
    always 2, so it never contributes to the loss or to a penalty step.
    """
    D1 = D[1:, 1:]
    rows, cols = _mst_edges(D1)
    # Mean weight of the chosen MST edges, read from the unshifted matrix.
    m = np.mean(D1[rows, cols])

    degrees = Counter()
    for r, c in zip(rows, cols):
        degrees[int(r) + 1] += 1
        degrees[int(c) + 1] += 1

    # Two cheapest edges (0, v), choosing v only among the OTHER vertices:
    # penalized distances can be <= 0, so the self-distance D[0, 0] is not
    # guaranteed to sort first. A stable sort breaks ties by lower index.
    for v in np.argsort(D[0, 1:], kind='stable')[:2]:
        degrees[int(v) + 1] += 1
    return degrees, m


def optimize_D_1tree(Din, lr, n_iter=10**5, *, patience=100, verbose=True):
    """Apply Held-Karp style degree penalties to a distance matrix.

    Each iteration builds the minimum 1-tree of the current matrix ``D``: an
    MST over vertices 1..n-1 plus the two cheapest edges from vertex 0. Every
    vertex ``v`` with 1-tree degree ``d`` gets ``lr * m * (d - 2)`` added to
    row ``v`` and column ``v`` of ``D``, where ``m`` is the mean MST edge
    weight. The diagonal is then reset to 0. This pushes the 1-tree towards a
    Hamiltonian tour, where every degree is 2.

    Parameters
    ----------
    Din : array_like, shape (n, n)
        Distance matrix with n >= 3. Presumably symmetric (TODO confirm with
        callers). The row/column updates keep a symmetric input symmetric.
    lr : float
        Step-size multiplier.
    n_iter : int
        Maximum number of iterations.
    patience : int, keyword-only
        Stop once more than this many iterations pass without improving the
        loss.
    verbose : bool, keyword-only
        Print ``iter_num, 'best', best_loss`` whenever the loss improves.

    Returns
    -------
    best_loss : int or float
        Smallest number of vertices with 1-tree degree != 2 seen. This is 0
        when the 1-tree is a tour, and ``inf`` if ``n_iter == 0``.
    best_D : numpy.ndarray of float
        The penalized matrix that produced ``best_loss``.

    Raises
    ------
    ValueError
        If ``Din`` has fewer than 3 rows (a 1-tree needs two edges from
        vertex 0).
    """
    # Work in float: the penalty steps are fractional, and an in-place add of
    # a float into an integer array raises a numpy casting error.
    D = np.array(Din, dtype=float)
    if D.shape[0] < 3:
        raise ValueError("optimize_D_1tree needs a distance matrix with at least 3 vertices")

    best_loss = float('inf')
    best_D = D.copy()
    best_iter = 0

    for iter_num in range(n_iter):
        degrees, m = _one_tree_degrees(D)
        loss = sum(1 for d in degrees.values() if d != 2)

        if loss < best_loss:
            best_loss = loss
            best_D = D.copy()
            best_iter = iter_num
            if verbose:
                print(iter_num, 'best', best_loss)

        # All degrees are 2: the 1-tree is a tour. Every step below would be
        # zero, so further iterations cannot change D.
        if loss == 0:
            return best_loss, best_D

        if iter_num - best_iter > patience:
            return best_loss, best_D

        for v, degree in degrees.items():
            step = lr * m * (degree - 2)
            D[:, v] += step
            D[v, :] += step

        # D[v, v] received the step twice above; self-distances stay 0.
        np.fill_diagonal(D, 0)

    return best_loss, best_D
