import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

def entropy(outputs, correct_pos, corrupt_pos):
    """Mean predictive entropy over the correct and the corrupt subsets.

    Per sample, the entropy of the softmax distribution over dim 1 is
    ``H_i = -sum_j p_ij * log p_ij``; the function then averages ``H_i``
    over the rows selected by each index/mask.

    Args:
        outputs: score tensor with the class dimension at dim 1.
        correct_pos: index or boolean mask selecting the "correct" rows.
        corrupt_pos: index or boolean mask selecting the "corrupt" rows.

    Returns:
        Tuple ``(mean_entropy_correct, mean_entropy_corrupt)`` of scalar tensors.
    """
    with torch.no_grad():
        log_probs = F.log_softmax(outputs, dim=1)
        # Sum over classes and negate: the elementwise p*log(p) alone is not
        # an entropy (it is negative and scaled by 1/num_classes once averaged).
        per_sample = -(F.softmax(outputs, dim=1) * log_probs).sum(dim=1)
    return torch.mean(per_sample[correct_pos]), torch.mean(per_sample[corrupt_pos])

def offset(outputs, labels, correct_pos, corrupt_pos):
    """Mean logit margin over the correct and the corrupt subsets.

    The per-sample margin is
    ``outputs[i, labels[i]] - max_{j != labels[i]} outputs[i, j]``;
    it is positive when the labelled class has the strictly highest score.

    Args:
        outputs: 2-D score tensor of shape (N, C), C >= 2.
        labels: integer class indices with N elements, on the same device
            as ``outputs``.
        correct_pos: index or boolean mask selecting the "correct" rows.
        corrupt_pos: index or boolean mask selecting the "corrupt" rows.

    Returns:
        Tuple ``(mean_margin_correct, mean_margin_corrupt)`` of scalar tensors.
    """
    with torch.no_grad():
        labels = labels.reshape(-1, 1).long()
        true_scores = outputs.gather(1, labels).squeeze(1)
        # Mask the labelled column with -inf so max() returns the strongest
        # competing class; works on whatever device `outputs` lives on.
        rivals = outputs.scatter(1, labels, float("-inf"))
        diffs = true_scores - rivals.max(dim=1).values
    return torch.mean(diffs[correct_pos]), torch.mean(diffs[corrupt_pos])

def compareloss(loss, correct_pos, corrupt_pos):
    """Average ``loss`` separately over the correct and the corrupt subsets.

    Args:
        loss: per-sample loss tensor.
        correct_pos: index or boolean mask selecting the "correct" entries.
        corrupt_pos: index or boolean mask selecting the "corrupt" entries.

    Returns:
        Tuple ``(mean_loss_correct, mean_loss_corrupt)`` of scalar tensors.
    """
    subsets = (loss[correct_pos], loss[corrupt_pos])
    return tuple(torch.mean(subset) for subset in subsets)

def get_reweights(reweights, correct_pos, corrupt_pos, ifplot=False, epoch=1, args=None):
    """Summarise proxy reweights for the correct and the corrupt subsets.

    Args:
        reweights: tensor of proxy outputs, one row per sample.
        correct_pos: index or boolean mask selecting the "correct" rows.
        corrupt_pos: index or boolean mask selecting the "corrupt" rows.
        ifplot: when True, also save one 50-bin density histogram per subset.
        epoch: epoch number used in the plot titles and file names.
        args: must provide ``save_plot_path`` (the output directory) when
            ``ifplot`` is True; unused otherwise.

    Returns:
        Tuple ``(mean_reweight_correct, mean_reweight_corrupt)`` of scalar tensors.
    """
    correct_values = reweights[correct_pos]
    corrupt_values = reweights[corrupt_pos]
    summary = (torch.mean(correct_values), torch.mean(corrupt_values))
    if not ifplot:
        return summary

    # One figure per subset; clf() resets the canvas between the two plots.
    plt.hist(correct_values.cpu().numpy(), 50, density=True, facecolor='g', alpha=0.75)
    plt.title(f'Histogram {epoch} of correct reweights')
    plt.savefig(f"{args.save_plot_path}/epoch{epoch}_correct.png")
    plt.clf()

    plt.hist(corrupt_values.cpu().numpy(), 50, density=True, facecolor='g', alpha=0.75)
    plt.title(f'Histogram {epoch} of corrupt reweights')
    plt.savefig(f"{args.save_plot_path}/epoch{epoch}_corrupt.png")
    plt.clf()
    return summary

def _build_proxy_input(proxy_input, output, target, loss):
    """Return the feature tensor fed to the proxy network for one batch.

    Only the features the chosen recipe needs are computed.

    Args:
        proxy_input: name of the feature recipe (one of the strings below).
        output: model scores for the batch, class dimension at dim 1.
        target: integer class labels for the batch.
        loss: unreduced per-sample loss for the batch.

    Raises:
        ValueError: if ``proxy_input`` is not a recognised recipe.
    """
    loss_col = loss.reshape(-1, 1)
    # Size the one-hot labels to the model's class count so they line up with
    # softmax(output); the elementwise "out+label+add" recipe requires it.
    num_classes = output.size(1)

    def probs():
        return F.softmax(output, dim=1)

    def one_hot():
        return F.one_hot(target, num_classes=num_classes).float()

    def margin_col():
        return margin(output, target).reshape(-1, 1)

    def var_col():
        return torch.var(output, dim=1).reshape(-1, 1)

    if proxy_input == "loss":
        return loss_col
    if proxy_input == "logits":
        return output
    if proxy_input == "margin":
        return margin_col()
    if proxy_input == "var":
        return var_col()
    if proxy_input == "all":
        return torch.cat([loss_col, output, margin_col(), var_col()], dim=1)
    if proxy_input == "out+label+concat":
        return torch.cat([probs(), one_hot()], dim=1)
    if proxy_input == "out+label+add":
        return probs() + one_hot()
    if proxy_input == "loss+label":
        return torch.cat([loss_col, one_hot()], dim=1)
    if proxy_input == "loss+std":
        return torch.cat([loss_col, torch.std(output, dim=1).reshape(-1, 1)], dim=1)
    if proxy_input == "loss+var":
        return torch.cat([loss_col, var_col()], dim=1)
    if proxy_input == "loss+out+label":
        return torch.cat([loss_col, probs(), one_hot()], dim=1)
    raise ValueError(f"Unknown proxy_input: {proxy_input!r}")


def get_all_information(model, proxy, loader, loss_func, correct_pos, corrupt_pos, proxy_input="loss", ifplot=False, epoch=1, args=None):
    """Run the model over ``loader`` and summarise it per correct/corrupt subset.

    The model is switched to eval mode (and left there). Every batch is moved
    to CUDA, scored, and its unreduced loss and proxy reweight are collected.
    The concatenated results are then reduced by ``entropy``, ``offset``,
    ``get_reweights`` and ``compareloss``.

    Args:
        model: network producing per-class scores.
        proxy: callable mapping the feature tensor built by ``_build_proxy_input``
            to per-sample reweights.
        loader: iterable of ``(data, target)`` batches.
        loss_func: called as ``loss_func(output, target, reduction='none')``.
        correct_pos: index or boolean mask over the concatenated samples.
        corrupt_pos: index or boolean mask over the concatenated samples.
        proxy_input: feature recipe for the proxy (see ``_build_proxy_input``).
        ifplot, epoch, args: forwarded to ``get_reweights``.

    Returns:
        ``(entropies, offsets, losses, reweights)``, each a
        ``(correct_mean, corrupt_mean)`` pair.

    Raises:
        ValueError: if ``proxy_input`` is not a recognised recipe.
    """
    outputs, targets, losses, reweights = [], [], [], []
    model.eval()
    with torch.no_grad():
        for data, target in loader:
            data, target = data.cuda(), target.cuda()
            output = model(data)
            loss = loss_func(output, target, reduction='none')
            outputs.append(output)
            targets.append(target)
            losses.append(loss)
            reweights.append(proxy(_build_proxy_input(proxy_input, output, target, loss)))
    outputs, targets, losses, reweights = torch.cat(outputs, dim=0), torch.cat(targets, dim=0), torch.cat(losses, dim=0), torch.cat(reweights, dim=0)
    print(f"shape out {outputs.size()}, shape targets {targets.size()}, shape losses {losses.size()}")
    entropies = entropy(outputs, correct_pos, corrupt_pos)
    offsets = offset(outputs, targets, correct_pos, corrupt_pos)
    reweights = get_reweights(reweights, correct_pos, corrupt_pos, ifplot=ifplot, epoch=epoch, args=args)
    losses = compareloss(losses, correct_pos, corrupt_pos)
    return entropies, offsets, losses, reweights

def get_all_information_idx(model, proxy, loader, loss_func, correct_pos, corrupt_pos):
    """Like ``get_all_information`` for loaders yielding ``(data, target, idx)``.

    The proxy always receives the per-sample loss as a single column. The
    third element yielded by the loader is ignored.

    The model is evaluated in eval mode so that dropout is off and BatchNorm
    running statistics are not updated by this diagnostic pass. The caller's
    train/eval mode is restored afterwards, even if an exception occurs.

    Returns:
        ``(entropies, offsets, losses, reweights)``, each a
        ``(correct_mean, corrupt_mean)`` pair.
    """
    outputs, targets, losses, reweights = [], [], [], []
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for data, target, _ in loader:
                data, target = data.cuda(), target.cuda()
                output = model(data)
                loss = loss_func(output, target, reduction='none')
                outputs.append(output)
                targets.append(target)
                losses.append(loss)
                reweights.append(proxy(torch.reshape(loss, (-1, 1))))
    finally:
        model.train(was_training)
    outputs, targets, losses, reweights = torch.cat(outputs, dim=0), torch.cat(targets, dim=0), torch.cat(losses, dim=0), torch.cat(reweights, dim=0)
    print(f"shape out {outputs.size()}, shape targets {targets.size()}, shape losses {losses.size()}")
    entropies = entropy(outputs, correct_pos, corrupt_pos)
    offsets = offset(outputs, targets, correct_pos, corrupt_pos)
    reweights = get_reweights(reweights, correct_pos, corrupt_pos)
    losses = compareloss(losses, correct_pos, corrupt_pos)
    return entropies, offsets, losses, reweights

def margin(outputs, labels):
    """Per-sample margin between the labelled score and the best other score.

    For row i this is ``outputs[i, labels[i]] - max_{j != labels[i]} outputs[i, j]``;
    it is positive when the labelled class has the strictly highest score.

    Args:
        outputs: 2-D score tensor of shape (N, C), C >= 2.
        labels: integer class indices with N elements, on the same device
            as ``outputs``.

    Returns:
        1-D tensor of shape (N,) on the same device as ``outputs``.
    """
    with torch.no_grad():
        labels = labels.reshape(-1, 1).long()
        true_scores = outputs.gather(1, labels).squeeze(1)
        # Mask the labelled column with -inf so max() returns the strongest
        # competing class; no CUDA-only index grid is needed.
        rivals = outputs.scatter(1, labels, float("-inf"))
        diffs = true_scores - rivals.max(dim=1).values
    return diffs