import numpy as np
import torch
from torch.nn import functional as F


def mixup_two_pairs(x1, x2, conf_one, conf_zero, alpha=1.0, is_bias=True):
    """
        Mix a batch of "one" pairs with a batch of "zero" pairs.

        A single coefficient ``lam`` is drawn from Beta(alpha, alpha) (``lam``
        is 1 when ``alpha <= 0``).  Both members of each pair and the
        confidences are then blended as ``lam * one + (1 - lam) * zero``.

        Before blending, the "one" batch is aligned with the "zero" batch
        along dim 0: if it is shorter it is padded with rows sampled uniformly
        (with replacement) from itself, otherwise only its first
        ``size_zero`` rows are used.

        Args:
            x1: tuple ``(x1_one, x2_one)`` -- the two members of the "one"
                pairs (note: despite its name this is a pair, not a tensor).
            x2: tuple ``(x1_zero, x2_zero)`` -- the two members of the "zero"
                pairs.
            conf_one: confidences of the "one" pairs; assumed to have the same
                dim-0 length as ``x1_one`` / ``x2_one`` (not checked here).
            conf_zero: confidences of the "zero" pairs.
            alpha: Beta distribution parameter; ``alpha <= 0`` disables mixing
                (``lam == 1``).
            is_bias: if true, use ``max(lam, 1 - lam)`` so the "one" batch
                always gets the larger weight.

        Returns:
            ``(mixed_x1, mixed_x2, mixup_conf, lam)``.

        Raises:
            IndexError: if the "one" batch is empty while the "zero" batch is
                not (there is nothing to sample the padding from).
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    if is_bias:
        lam = max(lam, 1 - lam)

    x1_one, x2_one = x1
    x1_zero, x2_zero = x2
    size_one = x1_one.size(0)
    size_zero = x1_zero.size(0)

    if size_one < size_zero:
        if size_one == 0:
            raise IndexError(
                "mixup_two_pairs: cannot pad an empty 'one' batch to match "
                "a 'zero' batch of size %d" % size_zero)
        # Draw every padding index at once and concatenate a single time,
        # instead of growing the tensors row by row.  The same indices are
        # used for both pair members and the confidences so they stay aligned.
        # The index tensor is left on the CPU, which is accepted when indexing
        # tensors on any device.
        pad_idx = torch.randint(size_one, (size_zero - size_one,))
        x1_one = torch.cat((x1_one, x1_one[pad_idx]), dim=0)
        x2_one = torch.cat((x2_one, x2_one[pad_idx]), dim=0)
        conf_one = torch.cat((conf_one, conf_one[pad_idx]), dim=0)
    else:
        # The "one" batch is at least as long: keep only the leading rows.
        x1_one = x1_one[:size_zero]
        x2_one = x2_one[:size_zero]
        conf_one = conf_one[:size_zero]

    mixed_x1 = lam * x1_one + (1 - lam) * x1_zero
    mixed_x2 = lam * x2_one + (1 - lam) * x2_zero
    mixup_conf = lam * conf_one + (1 - lam) * conf_zero

    return mixed_x1, mixed_x2, mixup_conf, lam



# adapted from https://github.com/iBelieveCJM/Tricks-of-Semi-supervisedDeepLeanring-Pytorch
def mixup_two_targets(x, y, alpha=1.0, device='cuda', is_bias=False):
    """
        Mix a batch with a randomly permuted copy of itself.

        A coefficient ``lam`` is drawn from Beta(alpha, alpha) (``lam`` is 1
        when ``alpha <= 0``) and the inputs are blended as
        ``lam * x + (1 - lam) * x[index]`` where ``index`` is a random
        permutation of the rows of ``x``.

        Args:
            x: input batch, permuted along dim 0.
            y: targets, indexed along dim 0 with the same permutation.
            alpha: Beta distribution parameter; ``alpha <= 0`` disables mixing.
            device: kept for backward compatibility only.  The permutation is
                moved to the device of the tensor it indexes, so the default
                of ``'cuda'`` no longer breaks CPU inputs or CUDA-less hosts.
            is_bias: if true, use ``max(lam, 1 - lam)`` so the un-permuted
                sample always gets the larger weight.

        Returns:
            ``(mixed_x, y_a, y_b, lam)`` where ``y_a`` is ``y`` and ``y_b`` is
            ``y`` permuted like the second mixing operand.
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    if is_bias:
        lam = max(lam, 1 - lam)

    # Drawn on the CPU (as before), then moved next to whichever tensor is
    # being indexed so that x and y may live on any device.
    index = torch.randperm(x.size(0))

    mixed_x = lam * x + (1 - lam) * x[index.to(x.device)]
    y_a, y_b = y, y[index.to(y.device)]
    return mixed_x, y_a, y_b, lam


def mixup_ce_loss_with_softmax(preds, targets_a, targets_b, lam):
    """
        Mixed categorical cross-entropy loss against two soft target sets.

        Both target tensors are passed through a softmax over dim 1 and used
        as soft labels for ``log_softmax(preds)``; the two batch-averaged
        cross-entropies are then combined as
        ``lam * loss_a + (1 - lam) * loss_b``.

        Args:
            preds: unnormalised predictions; classes along dim 1.
            targets_a: unnormalised targets for the first mixing operand
                (softmax is applied here).
            targets_b: unnormalised targets for the second mixing operand.
            lam: mixing coefficient applied to the ``targets_a`` term.

        Returns:
            The mixed loss as a tensor.
    """
    # The log-probabilities are the same for both target sets, so compute
    # them once instead of once per term.
    log_probs = F.log_softmax(preds, dim=1)

    mixup_loss_a = -torch.mean(
        torch.sum(F.softmax(targets_a, dim=1) * log_probs, dim=1))
    mixup_loss_b = -torch.mean(
        torch.sum(F.softmax(targets_b, dim=1) * log_probs, dim=1))

    return lam * mixup_loss_a + (1 - lam) * mixup_loss_b


def mixup_bce(scores, targets_a, targets_b, lam):
    """
        Mixed binary cross-entropy loss.

        Evaluates ``F.binary_cross_entropy`` of ``scores`` against each of the
        two target tensors and returns
        ``lam * bce(scores, targets_a) + (1 - lam) * bce(scores, targets_b)``.
    """
    bce_first, bce_second = (
        F.binary_cross_entropy(scores, target)
        for target in (targets_a, targets_b)
    )
    return lam * bce_first + (1 - lam) * bce_second


def get_mix_up_loss(x, y, model, device):
    """
        Compute the mixup cross-entropy loss of ``model`` on one batch.

        The batch is mixed with ``mixup_two_targets`` (``alpha=1.0``,
        ``is_bias=False``, ``device`` forwarded unchanged), the model is
        evaluated on the mixed inputs, and the result is scored with
        ``mixup_ce_loss_with_softmax`` against both target sets.
    """
    blended, targets_first, targets_second, weight = mixup_two_targets(
        x, y, alpha=1.0, device=device, is_bias=False)
    outputs = model(blended)
    return mixup_ce_loss_with_softmax(outputs, targets_first, targets_second,
                                      weight)
