import torch
from torch.nn.functional import one_hot
import time



def label_mapping_base(logits, mapping_sequence):
    """Pick columns of ``logits`` in the order given by ``mapping_sequence``.

    Args:
        logits: Tensor indexed as ``logits[:, mapping_sequence]``; the first
            axis is kept whole and the second axis is selected/reordered.
        mapping_sequence: Index (tensor, list, ...) applied to the second axis.

    Returns:
        The indexed tensor, with one column per entry of ``mapping_sequence``.
    """
    return logits[:, mapping_sequence]


def get_dist_matrix(fx, y):
    """Count, per target label, how often each source class is the top prediction.

    Args:
        fx: Prediction scores; the last axis enumerates source classes and the
            arg-max over it is taken as the predicted source class.
        y: Target labels aligned with the leading axis of ``fx``. They are
            treated as non-negative class indices ``0 .. max(y)``.

    Returns:
        An integer tensor of shape ``(fx.size(-1), max(y) + 1)`` whose entry
        ``[s, t]`` is the number of samples with label ``t`` that were
        predicted as source class ``s``.
    """
    num_source = fx.size(-1)
    predicted = one_hot(torch.argmax(fx, dim=-1), num_classes=num_source)
    # Size the target axis by the largest label rather than by the number of
    # distinct labels: if a class id is absent from ``y`` the two differ, and
    # using the distinct count would silently drop every sample whose label is
    # >= that count.  Absent labels simply get an all-zero column.
    num_target = int(y.max()) + 1
    per_label_counts = [predicted[y == label].sum(0) for label in range(num_target)]

    return torch.stack(per_label_counts, dim=1)


def predictive_distribution_based_multi_label_mapping(dist_matrix, mlm_num: int):
    """Greedily assign ``mlm_num`` source labels to every target label.

    The largest remaining entry of ``dist_matrix`` is picked repeatedly; the
    chosen row (source label) is then retired, and a column (target label) is
    retired once it has received ``mlm_num`` source labels.

    Args:
        dist_matrix: 2-D tensor with one row per source label and one column
            per target label. Entries are expected to be greater than ``-1``
            (e.g. counts), because ``-1`` is used to mark retired cells.
            The tensor is not modified.
        mlm_num: Number of source labels to assign to each target label.

    Returns:
        An int64 tensor shaped like ``dist_matrix`` holding ``1`` at every
        selected (source, target) pair and ``0`` elsewhere.

    Raises:
        AssertionError: If there are fewer than ``mlm_num * num_targets``
            source labels.
    """
    num_source, num_target = dist_matrix.size(0), dist_matrix.size(1)
    if mlm_num * num_target > num_source:
        # Raised explicitly (same exception type as the former ``assert``) so
        # the check is not stripped when Python runs with ``-O``.
        raise AssertionError("source label number not enough for mapping")

    mapping_matrix = torch.zeros_like(dist_matrix, dtype=torch.long)
    # Work on a private, contiguous copy.  Masking must be visible to the next
    # arg-max; relying on ``flatten()`` aliasing the input only holds for
    # contiguous tensors, and writing into the input would clobber the
    # caller's data.
    remaining = dist_matrix.clone(memory_format=torch.contiguous_format)
    for _ in range(mlm_num * num_target):
        # ``argmax()`` without ``dim`` indexes the row-major flattened tensor.
        source, target = divmod(remaining.argmax().item(), num_target)
        mapping_matrix[source, target] = 1
        remaining[source] = -1  # each source label is used at most once
        if mapping_matrix[:, target].sum() == mlm_num:
            remaining[:, target] = -1  # this target label is fully served

    return mapping_matrix


def _module_device(module):
    """Return the device of ``module``'s first parameter, or ``None`` if unavailable."""
    parameters = getattr(module, "parameters", None)
    if parameters is None:
        return None
    first = next(iter(parameters()), None)
    return None if first is None else first.device


def _frequency_label_mapping(forward, data_loader, device, mapping_num):
    """Run ``forward`` over ``data_loader`` and derive the frequency-based mapping."""
    fx0s = []
    ys = []
    start = time.time()
    for x, y in data_loader:
        if device is not None:
            x, y = x.to(device), y.to(device)
        with torch.no_grad():
            fx0 = forward(x)
        fx0s.append(fx0)
        ys.append(y)
    elapsed = time.time() - start
    print(f'Label Mapping\tTime {elapsed:.2f}')

    fx0s = torch.cat(fx0s).cpu().float()
    ys = torch.cat(ys).cpu().int()
    if ys.size(0) != fx0s.size(0):
        # More output rows than labels: only a whole multiple is accepted, and
        # the label vector is tiled to line up with the outputs.
        assert fx0s.size(0) % ys.size(0) == 0
        ys = ys.repeat(fx0s.size(0) // ys.size(0))

    dist_matrix = get_dist_matrix(fx0s, ys)
    # Each row of ``pairs`` is a selected (source index, target index) pair.
    pairs = torch.nonzero(predictive_distribution_based_multi_label_mapping(dist_matrix, mapping_num))
    # Order the source indices by the target index they were assigned to.
    return pairs[:, 0][torch.sort(pairs[:, 1]).indices]


def generate_label_mapping_by_frequency(visual_prompt, network, data_loader, mapping_num=1):
    """Build a label mapping from the predictions of ``network(visual_prompt(x))``.

    Args:
        visual_prompt: Callable applied to every input batch before the
            network. May be ``None``, in which case batches are passed to the
            network unchanged.
        network: Callable producing prediction scores; ``eval()`` is invoked
            on it when available.
        data_loader: Iterable yielding ``(x, y)`` batches.
        mapping_num: Number of source labels assigned to each target label.

    Returns:
        1-D tensor of source label indices, ordered by the target label each
        one was assigned to.
    """
    # Explicit ``None`` test: module truthiness is unreliable (an empty
    # container module is falsy) and the prompt is dereferenced below.
    if visual_prompt is not None:
        visual_prompt.eval()
    if hasattr(network, "eval"):
        network.eval()

    # Batches go to the prompt's device; fall back to the network's device when
    # the prompt is absent or parameter-free.  With neither, batches stay put.
    device = _module_device(visual_prompt)
    if device is None:
        device = _module_device(network)

    def forward(x):
        return network(x if visual_prompt is None else visual_prompt(x))

    return _frequency_label_mapping(forward, data_loader, device, mapping_num)


def generate_label_mapping_by_frequency_ordinary(network, data_loader, mapping_num=1):
    """Build a label mapping from the predictions of ``network(x)``.

    Args:
        network: Callable producing prediction scores; ``eval()`` is invoked
            on it when available.
        data_loader: Iterable yielding ``(x, y)`` batches.
        mapping_num: Number of source labels assigned to each target label.

    Returns:
        1-D tensor of source label indices, ordered by the target label each
        one was assigned to.
    """
    if hasattr(network, "eval"):
        network.eval()

    # Follow the network's own device instead of unconditionally calling
    # ``.cuda()``, which fails on CPU-only hosts or when the network lives on
    # another device.  For a parameter-free callable keep the previous
    # behaviour (CUDA) whenever CUDA is actually available.
    device = _module_device(network)
    if device is None and torch.cuda.is_available():
        device = torch.device("cuda")

    return _frequency_label_mapping(network, data_loader, device, mapping_num)
