import numpy as np
import torch
import math
from training.evaluation import soft_predictions, get_models_in_buffer
import copy


def tri_2_square(v):
    """Pad a ragged (e.g. lower-triangular) list of rows to a rectangle and transpose it.

    Every row is padded on the right with ``math.inf`` up to the length of the
    longest row. The padded ``(n_rows, width)`` table is returned transposed,
    as a NumPy array of shape ``(width, n_rows)``. Padding with +inf keeps the
    missing cells out of any subsequent ``argmin``.

    Args:
        v: Sequence of row sequences; rows may have different lengths.

    Returns:
        np.ndarray of shape ``(width, len(v))``. For an empty ``v`` an empty
        ``(0, 0)`` array is returned.
    """
    # Do NOT wrap ``v`` in np.array(): for ragged rows that raises ValueError
    # on NumPy >= 1.24 (inhomogeneous shape). Plain lists are enough here.
    rows = [list(row) for row in v]
    if not rows:
        return np.empty((0, 0))

    # For a triangular table the longest row is the last one, so this matches
    # padding to len(v[-1]), but it also stays rectangular for other inputs.
    width = max(len(row) for row in rows)
    padded = [row + [math.inf] * (width - len(row)) for row in rows]

    return np.array(padded).transpose()


def remove_indices(v, indices):
    """Return a copy of table ``v`` with row ``indices`` and column ``indices`` removed.

    Row number ``indices`` is dropped. From every remaining row that is long
    enough to have a position ``indices``, the element at that position is
    dropped too. Shorter rows are copied unchanged.

    Unlike an in-place ``pop``, this never modifies ``v`` or its rows, so any
    other reference to the original table stays intact.

    Args:
        v: Sequence of row sequences (e.g. the triangular distance table).
        indices: Non-negative integer position to remove in both dimensions.

    Returns:
        A new list of new row lists.
    """
    return [
        [value for col, value in enumerate(row) if col != indices]
        for pos, row in enumerate(v)
        if pos != indices
    ]


def compute_M(inputs, buffer, M, cfg, num_samples=2):
    """Append one row of symmetric KL divergences to the distance table ``M``.

    The models are obtained with ``get_models_in_buffer(cfg, buffer)`` and each
    one is evaluated with ``soft_predictions(inputs, num_samples, net)``. For
    every model except the last one in that list, the symmetric KL divergence

        0.5 * (KL(last || model) + KL(model || last))

    between its predictions and the last model's predictions is computed. The
    resulting list (one float per non-last model, in list order) is appended
    to ``M``.

    Args:
        inputs: Passed unchanged to ``soft_predictions``.
        buffer: Passed unchanged to ``get_models_in_buffer``.
        M: List of rows; extended in place with the new row.
        cfg: Passed unchanged to ``get_models_in_buffer``.
        num_samples: Passed unchanged to ``soft_predictions``.

    Returns:
        ``M`` (the same list object, now one row longer).
    """
    net_list = get_models_in_buffer(cfg, buffer)

    # NOTE(review): assumes soft_predictions returns probability tensors
    # (non-negative, summing to 1 over the class dim) with the batch on
    # dim 0 -- "batchmean" divides by size(0). Exact zeros would make
    # .log() return -inf and the divergence inf; consider clamping if so.
    softmax_output_all = [
        soft_predictions(inputs, num_samples, net) for net in net_list
    ]

    # KLDivLoss(input, target) computes KL(target || exp(input)):
    # `input` must be log-probabilities, `target` plain probabilities.
    kl = torch.nn.KLDivLoss(reduction="batchmean")

    diff_all = []
    if softmax_output_all:
        last = softmax_output_all[-1]
        last_log = last.log()
        for other in softmax_output_all[:-1]:
            forward = kl(other.log(), last).item()  # KL(last || other)
            # Previously called as kl(last, other.log()), i.e. probabilities as
            # `input` and negative log-probs as `target` -> NaN / garbage.
            backward = kl(last_log, other).item()  # KL(other || last)
            diff_all.append((forward + backward) * 0.5)

    M.append(diff_all)

    return M


def erase_LIM(M):
    """Drop one index from the distance table ``M`` based on its smallest entry.

    ``M`` is turned into a padded, transposed rectangle with ``tri_2_square``.
    The position of its minimum entry is located, and the smaller of that
    entry's two indices is removed from ``M`` in both dimensions via
    ``remove_indices``.

    Args:
        M: Ragged list of rows of pairwise distances.

    Returns:
        Tuple ``(min_indices, M)``. ``min_indices`` is the removed index (a
        NumPy integer) and ``M`` is the table returned by ``remove_indices``.
    """
    # Build the padded matrix once; the original built it twice.
    square = tri_2_square(M)

    # Padding cells are +inf, so argmin lands on a real distance whenever a
    # finite one exists. NOTE(review): with NaN entries argmin returns the
    # first NaN, so the choice is then meaningless.
    indices = np.unravel_index(np.argmin(square, axis=None), square.shape)
    min_indices = min(indices)

    M = remove_indices(M, min_indices)

    return min_indices, M