import torch.nn.functional as F
import torch
from parse_config import ConfigParser
import torch.nn as nn
import numpy as np

# Alias for torch's built-in cross-entropy loss *class* (not an instance):
# call ``cross_entropy_val(...)`` to construct the actual loss module.
cross_entropy_val = nn.CrossEntropyLoss


class CrossEntropyLossStable(nn.Module):
    """Cross-entropy evaluated as NLL over ``log(softmax(x) + eps)``.

    Adding ``eps`` to the probabilities before the logarithm keeps the log
    finite when a class probability underflows to zero.

    Args:
        reduction: Reduction mode forwarded to ``nn.NLLLoss``.
        eps: Constant added to the softmax output before taking the log.
    """

    def __init__(self, reduction='mean', eps=1e-5):
        super().__init__()
        self._name = "Stable Cross Entropy Loss"
        self._eps = eps
        self._softmax = nn.Softmax(dim=-1)
        self._nllloss = nn.NLLLoss(reduction=reduction)

    def forward(self, outputs, labels):
        """Return the NLL of ``labels`` under the eps-shifted softmax of ``outputs``."""
        probabilities = self._softmax(outputs)
        shifted_log_probs = torch.log(probabilities + self._eps)
        return self._nllloss(shifted_log_probs, labels)


class ProbLossStable(nn.Module):
    """Negative softmax probability of the labelled class.

    ``nn.NLLLoss`` is applied to the softmax output itself (not to its log),
    so with ``reduction='none'`` the result is ``-softmax(outputs)[i, labels[i]]``
    for every sample ``i``.

    Args:
        reduction: Reduction mode forwarded to ``nn.NLLLoss``. The default
            ``'none'`` returns one value per sample.
        eps: Stored on the instance; not used by ``forward``.
    """

    def __init__(self, reduction='none', eps=1e-5):
        super(ProbLossStable, self).__init__()
        self._name = "Prob Loss"
        self._eps = eps
        self._softmax = nn.Softmax(dim=-1)
        # Honour the caller's reduction instead of silently forcing 'none'.
        self._nllloss = nn.NLLLoss(reduction=reduction)

    def forward(self, outputs, labels):
        """Return the (reduced) negative probability assigned to ``labels``."""
        return self._nllloss(self._softmax(outputs), labels)


class ProbLossStable2Term(nn.Module):
    """Softmax wrapper: turns raw outputs into class probabilities.

    Args:
        reduction: Accepted for signature parity with the other losses in
            this file; not used.
        eps: Stored on the instance; not used by ``forward``.
    """

    def __init__(self, reduction='none', eps=1e-5):
        super().__init__()
        self._name = "Prob Loss"
        self._eps = eps
        self._softmax = nn.Softmax(dim=-1)

    def forward(self, outputs):
        """Return the softmax of ``outputs`` over the last dimension."""
        class_probabilities = self._softmax(outputs)
        return class_probabilities


class FLoss(nn.Module):
    """f-divergence training loss with optional balance and consistency terms.

    The loss has two modes, selected by ``change_var``:

    * ``change_var=True``: the loss is delegated to ``compute_loss_divergence``
      using the divergence name ``div``.
    * ``change_var=False``: the loss is
      ``activation(softmax(out_1)[label]) - conjugate(out_2)`` where
      ``activation`` / ``conjugate`` are the pair of functions chosen by
      ``div`` in ``__init__``.

    Args:
        div: Name of the divergence. When ``change_var`` is False it must be
            one of 'KL', 'Reverse-KL', 'Jeffrey', 'Squared-Hellinger',
            'Pearson', 'Neyman', 'Jensen-Shannon', 'Total-Variation', 'SL'.
        num_examp: Stored on the instance; not used by the code in this class.
        num_classes: Number of classes.
        change_var: Selects the mode described above.
        ratio_consistency: Weight of the consistency term (disabled when 0).
        ratio_balance: Weight of the class-balance term (disabled when 0).
        T: Accepted for signature compatibility; not used by this class.

    Raises:
        NotImplementedError: ``change_var`` is False and ``div`` is unknown.
    """

    def __init__(self, div, num_examp, num_classes=10, change_var=True, ratio_consistency=0, ratio_balance=0, T=None):
        super(FLoss, self).__init__()
        self.div = div
        self.num_examp = num_examp
        self.num_classes = num_classes
        self.config = ConfigParser.get_instance()
        self.USE_CUDA = torch.cuda.is_available()
        self.criterion = CrossEntropyLossStable() # .to(device)?
        self.criterion_prob = ProbLossStable()
        self.criterion_prob2 = ProbLossStable2Term()
        self.loss_fn_2 = nn.BCELoss(reduction='none')
        self.loss_fn_3 = nn.CrossEntropyLoss()
        self.change_var = change_var
        self.ratio_consistency = ratio_consistency
        self.ratio_balance = ratio_balance
        # The activation/conjugate pair is only needed (and only defined) in
        # the change_var=False mode.
        if not change_var:
            if div == 'KL':
                self.activation = lambda x: -torch.mean(x)
                self.conjugate = lambda x: -torch.mean(torch.exp(x - 1.))
            elif div == 'Reverse-KL':
                self.activation = lambda x: -torch.mean(-torch.exp(x))
                self.conjugate = lambda x: -torch.mean(-1. - x)
            elif div == 'Jeffrey':
                self.activation = lambda x: -torch.mean(x)
                self.conjugate = lambda x: -torch.mean(x + torch.mul(x, x) / 4. + torch.mul(torch.mul(x, x), x) / 16.)
            elif div == 'Squared-Hellinger':
                self.activation = lambda x: -torch.mean(1. - torch.exp(x))
                self.conjugate = lambda x: -torch.mean((1. - torch.exp(x)) / (torch.exp(x)))
            elif div == 'Pearson':
                self.activation = lambda x: -torch.mean(x)
                self.conjugate = lambda x: -torch.mean(torch.mul(x, x) / 4. + x)
            elif div == 'Neyman':
                self.activation = lambda x: -torch.mean(1. - torch.exp(x))
                self.conjugate = lambda x: -torch.mean(2. - 2. * torch.sqrt(1. - x))
            elif div == 'Jensen-Shannon':
                self.activation = lambda x: -torch.mean(- torch.log(1. + torch.exp(-x))) - torch.log(torch.tensor(2.))
                self.conjugate = lambda x: -torch.mean(x + torch.log(1. + torch.exp(-x))) + torch.log(torch.tensor(2.))
            elif div == 'Total-Variation':
                self.activation = lambda x: -torch.mean(torch.tanh(x) / 2.)
                self.conjugate = lambda x: -torch.mean(torch.tanh(x) / 2.)
            elif div == 'SL':
                self.activation = lambda x: -torch.mean(1 / (1 + torch.exp(-x)) - 1)
                self.conjugate = lambda x: -torch.mean(x + torch.log(1 + torch.exp(-x) + torch.exp(-x)/(1 + torch.exp(-x))))
            else:
                raise NotImplementedError("[-] Not Implemented f-divergence %s" % div)

    def forward(self, index, outputs, label):  # outputs is the tuple of outputs
        """Compute the total loss for one batch.

        Args:
            index: Passed through to ``consistency_loss``.
            outputs: Tuple ``(out_1, out_2, out_aug)`` of network outputs.
            label: Labels for ``out_1``.

        Returns:
            The divergence loss plus the enabled balance/consistency terms.
        """
        eps = 1e-4
        (out_1, out_2, out_aug) = outputs
        if self.change_var:
            # Build helper tensors on the device the outputs live on, so the
            # loss also works on CPU or on a GPU other than cuda:0.
            device = out_1.device
            alpha = 0.9
            current_batch_size = out_1.shape[0]
            loss = compute_loss_divergence(self.div, out_1, out_2, label, self.num_classes, current_batch_size, alpha, device)
        else:
            # criterion_prob returns -softmax(out_1)[label]; negate to get the probability.
            prob_reg = -self.criterion_prob(out_1, label.long())
            loss_regular = self.activation(prob_reg)
            # NOTE(review): conjugate is applied to the raw out_2 here, while
            # FLossBias.forward applies it to softmax(out_2) - confirm intent.
            loss_peer = self.conjugate(out_2)
            loss = loss_regular - loss_peer
        if self.ratio_balance > 0:
            # Cross-entropy between a uniform prior and the batch-averaged out_1.
            # NOTE(review): the clamp to [eps, 1] assumes out_1 holds probabilities.
            avg_prediction = torch.mean(out_1, dim=0)
            prior_distr = 1.0 / self.num_classes * torch.ones_like(avg_prediction)
            avg_prediction = torch.clamp(avg_prediction, min=eps, max=1.0)
            balance_kl = torch.mean(-(prior_distr * torch.log(avg_prediction)).sum(dim=0))
            loss = loss + self.ratio_balance * balance_kl
        if self.ratio_consistency > 0:
            consistency_loss = self.consistency_loss(index, out_1, out_aug)
            loss = loss + self.ratio_consistency * torch.mean(consistency_loss)

        return loss

    def consistency_loss(self, index, output1, output2):
        """Per-sample KL divergence between the two outputs' softmax distributions.

        ``softmax(output1)`` is detached and used as the target;
        ``output2`` is the prediction. ``index`` is not used.

        Returns:
            Tensor with one KL value per row.
        """
        ddiv = "KL"  # hard-wired: the symmetric branch below is currently unreachable
        if ddiv == "KL":
            # NOTE(review): the posterior conversion is applied only when
            # num_classes == 10 - confirm this restriction is intended.
            if self.num_classes == 10:
                if self.change_var:
                    output1 = obtain_posterior_from_net_out(output1, self.div)
                    output2 = obtain_posterior_from_net_out(output2, self.div)

            preds1 = F.softmax(output1, dim=1).detach()
            preds2 = F.log_softmax(output2, dim=1)

            loss_kldiv = F.kl_div(preds2, preds1, reduction='none')
            loss_div = torch.sum(loss_kldiv, dim=1)
        else:
            loss_div = self.kl_loss_compute(output1, output2, reduce=False) + self.kl_loss_compute(output2, output1, reduce=False)
        return loss_div

    def kl_loss_compute(self, pred, soft_targets, reduce=True):
        """KL divergence of ``softmax(soft_targets)`` from ``softmax(pred)``.

        Args:
            pred: Raw outputs used as the prediction.
            soft_targets: Raw outputs used as the target distribution.
            reduce: If True return the batch mean, otherwise one value per row.
        """
        # reduction='none' replaces the deprecated ``reduce=False`` argument.
        kl = F.kl_div(F.log_softmax(pred, dim=1), F.softmax(soft_targets, dim=1), reduction='none')
        per_sample = torch.sum(kl, dim=1)
        if reduce:
            return torch.mean(per_sample)
        return per_sample

def jsd(p, q):
    """Average of two batch-mean ``KLDivLoss`` terms against the midpoint of ``p`` and ``q``.

    The midpoint ``0.5 * (p + q)`` is passed as the ``KLDivLoss`` input and
    ``p`` / ``q`` as log-space targets (``log_target=True``).

    NOTE(review): with ``log_target=True`` both arguments are interpreted as
    log-probabilities, yet the midpoint is the plain average of the inputs -
    confirm which representation callers pass.
    """
    kl_batchmean = torch.nn.KLDivLoss(reduction='batchmean', log_target=True)
    midpoint = 0.5 * (p + q)
    term_p = kl_batchmean(midpoint, p)
    term_q = kl_batchmean(midpoint, q)
    return 0.5 * (term_p + term_q)


class JSD(nn.Module):
    """Jensen-Shannon divergence between two batches of probability vectors.

    ``JSD(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M)`` with ``M = (P + Q) / 2``,
    averaged over the batch (``reduction='batchmean'``).
    """

    def __init__(self):
        super(JSD, self).__init__()
        self.kl = nn.KLDivLoss(reduction='batchmean', log_target=True)

    def forward(self, p: torch.Tensor, q: torch.Tensor):
        """Return the batch-mean Jensen-Shannon divergence of ``p`` and ``q``.

        Args:
            p: Probabilities; the last dimension indexes the classes. Leading
                dimensions are flattened into the batch.
            q: Probabilities with the same layout as ``p``.

        Entries must be strictly positive, since their logarithm is taken.
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        p = p.reshape(-1, p.size(-1))
        q = q.reshape(-1, q.size(-1))
        log_m = (0.5 * (p + q)).log()
        # KLDivLoss(input, target) computes KL(target || input), so the mixture
        # must be the *input* and p / q the targets to obtain KL(P||M) and KL(Q||M).
        return 0.5 * (self.kl(log_m, p.log()) + self.kl(log_m, q.log()))


def obtain_posterior_from_net_out(D, cost_function_v):
    """Transform the network output ``D`` according to the divergence name.

    Args:
        D: Network output tensor.
        cost_function_v: Divergence name.

    Returns:
        * ``(1 - D) / D`` for "Jensen-Shannon" and "SL",
        * ``exp(D)`` for "KL",
        * ``D`` unchanged for "JS_s", "SL_s" and "JS_s2".

    Raises:
        NotImplementedError: ``cost_function_v`` is none of the names above
            (previously this surfaced as an UnboundLocalError).
    """
    if cost_function_v in ("Jensen-Shannon", "SL"):
        return (1 - D) / D
    if cost_function_v == "KL":
        return torch.exp(D)
    if cost_function_v in ("JS_s", "SL_s", "JS_s2"):
        return D
    raise NotImplementedError("[-] Not Implemented f-divergence %s" % cost_function_v)


class FLossBias(nn.Module):
    """f-divergence loss minus a bias estimate derived from the matrix ``T``.

    ``forward`` computes the same divergence loss as the two modes selected by
    ``change_var`` (delegation to ``compute_loss_divergence`` vs. an explicit
    activation/conjugate pair) and subtracts a bias term computed by
    ``bias_estimate_custom_t`` or ``bias_estimate_symm``.

    Args:
        div: Name of the divergence. When ``change_var`` is False it must be
            one of 'KL', 'Reverse-KL', 'Jeffrey', 'Squared-Hellinger',
            'Pearson', 'Neyman', 'Jensen-Shannon', 'Total-Variation', 'SL'.
        num_examp: Stored on the instance; not used by the code in this class.
        num_classes: Number of classes.
        T: Array-like converted with ``torch.tensor``; must not be None.
            NOTE(review): presumably a noise transition matrix - confirm.
        change_var: Selects the loss mode described above.
        ratio_consistency: Accepted for signature compatibility; not used.
        ratio_balance: Accepted for signature compatibility; not used.
        noise_type: "custom_T" selects ``bias_estimate_custom_t``; any other
            value selects ``bias_estimate_symm``.

    Raises:
        NotImplementedError: ``change_var`` is False and ``div`` is unknown.
    """

    def __init__(self, div, num_examp, num_classes=10, T=None, change_var=True, ratio_consistency=0, ratio_balance=0, noise_type="custom_T"):
        super(FLossBias, self).__init__()
        self.div = div
        self.num_examp = num_examp
        self.num_classes = num_classes
        self.config = ConfigParser.get_instance()
        self.USE_CUDA = torch.cuda.is_available()
        self.criterion = CrossEntropyLossStable()
        self.criterion_prob = ProbLossStable()
        self.criterion_prob2 = ProbLossStable2Term()
        self.loss_fn_2 = nn.BCELoss(reduction='none')
        self.loss_fn_3 = nn.CrossEntropyLoss()
        self.change_var = change_var
        self.T = torch.tensor(T)
        self._nllloss = nn.NLLLoss(reduction='none')
        self.noise_type = noise_type
        # The activation/conjugate pair is only needed (and only defined) in
        # the change_var=False mode.
        if not self.change_var:
            if div == 'KL':
                self.activation = lambda x: -torch.mean(x)
                self.conjugate = lambda x: -torch.mean(torch.exp(x - 1.))
            elif div == 'Reverse-KL':
                self.activation = lambda x: -torch.mean(-torch.exp(x))
                self.conjugate = lambda x: -torch.mean(-1. - x)
            elif div == 'Jeffrey':
                self.activation = lambda x: -torch.mean(x)
                self.conjugate = lambda x: -torch.mean(x + torch.mul(x, x) / 4. + torch.mul(torch.mul(x, x), x) / 16.)
            elif div == 'Squared-Hellinger':
                self.activation = lambda x: -torch.mean(1. - torch.exp(x))
                self.conjugate = lambda x: -torch.mean((1. - torch.exp(x)) / (torch.exp(x)))
            elif div == 'Pearson':
                self.activation = lambda x: -torch.mean(x)
                self.conjugate = lambda x: -torch.mean(torch.mul(x, x) / 4. + x)
            elif div == 'Neyman':
                self.activation = lambda x: -torch.mean(1. - torch.exp(x))
                self.conjugate = lambda x: -torch.mean(2. - 2. * torch.sqrt(1. - x))
            elif div == 'Jensen-Shannon':
                self.activation = lambda x: -torch.mean(- torch.log(1. + torch.exp(-x))) - torch.log(torch.tensor(2.))
                self.conjugate = lambda x: -torch.mean(x + torch.log(1. + torch.exp(-x))) + torch.log(torch.tensor(2.))
            elif div == 'Total-Variation':
                self.activation = lambda x: -torch.mean(torch.tanh(x) / 2.)
                self.conjugate = lambda x: -torch.mean(torch.tanh(x) / 2.)
            elif div == 'SL':
                self.activation = lambda x: -torch.mean(1 / (1 + torch.exp(-x)) - 1)
                self.conjugate = lambda x: -torch.mean(x + torch.log(1 + torch.exp(-x) + torch.exp(-x)/(1 + torch.exp(-x))))
            else:
                raise NotImplementedError("[-] Not Implemented f-divergence %s" % div)

    def compute_first_term(self, output, label):
        """Batch mean of ``log(1 - output[i, label[i]] + eps)`` (Jensen-Shannon only).

        Raises:
            NotImplementedError: ``self.div`` is not 'Jensen-Shannon'
                (previously ``None`` was returned implicitly).
        """
        if self.div != 'Jensen-Shannon':
            raise NotImplementedError("[-] Not Implemented f-divergence %s" % self.div)
        eps = 1e-4
        # NLLLoss returns -x[i, label[i]]; negate to pick the entry itself.
        return torch.mean(-self._nllloss(torch.log(1 - output + eps), label))

    def compute_second_term(self, output):
        """Batch mean of the per-row sum of ``log(output + eps)`` (Jensen-Shannon only).

        Raises:
            NotImplementedError: ``self.div`` is not 'Jensen-Shannon'
                (previously ``None`` was returned implicitly).
        """
        if self.div != 'Jensen-Shannon':
            raise NotImplementedError("[-] Not Implemented f-divergence %s" % self.div)
        eps = 1e-4
        return torch.mean(torch.sum(torch.log(output + eps), dim=1))

    def forward(self, index, outputs, label):  # outputs is the tuple of outputs
        """Return the divergence loss minus the bias estimate.

        Args:
            index: Not used.
            outputs: Tuple ``(out_1, out_2, out_aug)``; ``out_aug`` is not used.
            label: Labels for ``out_1``.
        """
        (out_1, out_2, out_aug) = outputs
        if self.change_var:
            # Build helper tensors on the device the outputs live on, so the
            # loss also works on CPU or on a GPU other than cuda:0.
            device = out_1.device
            alpha = 0.9
            current_batch_size = out_1.shape[0]
            loss = compute_loss_divergence(self.div, out_1, out_2, label, self.num_classes, current_batch_size, alpha, device)
        else:
            # criterion_prob returns -softmax(out_1)[label]; negate to get the probability.
            prob_reg = -self.criterion_prob(out_1, label.long())
            loss_regular = self.activation(prob_reg)
            loss_peer = self.conjugate(self.criterion_prob2(out_2))
            loss = loss_regular - loss_peer
        if self.noise_type == "custom_T":
            bias = self.bias_estimate_custom_t(out_1, label)
        else:
            bias = self.bias_estimate_symm(out_1, label)
        return loss - bias

    def _variational_bias(self, weights, total_weight, output, label):
        """Sum over classes k of ``weights[k] * activation(p_k) - total_weight * conjugate(p_k)``.

        ``p_k`` is the per-sample softmax probability of class ``k``. Relies on
        ``self.activation`` / ``self.conjugate``, which ``__init__`` defines
        only when ``change_var`` is False.
        """
        bias = 0
        for k in range(self.num_classes):
            class_k = torch.full_like(label, k, dtype=torch.long)
            # Evaluate the class-k probability once and reuse it for both terms.
            prob_k = -self.criterion_prob(output, class_k)
            bias += weights[k] * self.activation(prob_k) - total_weight * self.conjugate(prob_k)
        return bias

    def bias_estimate_symm(self, output, label):
        """Bias estimate weighted by ``self.T[k]`` and the sum of all entries of ``self.T``.

        ``label`` is only used as a template for shape and device.
        """
        return self._variational_bias(self.T, torch.sum(self.T), output, label)

    def bias_estimate_custom_t(self, output, label):
        """Bias estimate weighted by the first row of ``self.T`` with entry 0 replaced by ``T[1][0]``.

        ``label`` is only used as a template for shape and device.
        """
        # Work on a copy: indexing returns a view, and writing through it
        # would silently overwrite self.T[0][0] on the first call.
        e_js = self.T[0].clone().detach()
        e_js[0] = self.T[1][0]
        sum_e = torch.sum(e_js)
        if not self.change_var:
            return self._variational_bias(e_js, sum_e, output, label)
        # The second term does not depend on k, so evaluate it once.
        second_term = sum_e * self.compute_second_term(output)
        bias = 0
        for k in range(self.num_classes):
            class_k = torch.full_like(label, k, dtype=torch.long)
            bias += e_js[k] * self.compute_first_term(output, class_k) - second_term
        return bias

def sl_cost_fcn(out_1, out_2, data_tx, num_classes, alpha):
    """Combine the two SL terms: ``sl_first(out_1) + alpha * sl_sec(out_2)``.

    Both outputs are squeezed before being handed to the term functions;
    ``data_tx`` and ``num_classes`` are forwarded to ``sl_first``.
    """
    labelled_term = sl_first(out_1.squeeze(), data_tx, num_classes)
    unlabelled_term = sl_sec(out_2.squeeze())
    return labelled_term + alpha * unlabelled_term


def sl_first(y_pred, data_tx, num_classes, t_tensor=True):
    """Batch mean of the row-wise dot product between ``y_pred`` and ``data_tx``.

    Args:
        y_pred: Predictions holding as many elements as ``data_tx``; a
            prediction whose batch axis was squeezed away is accepted.
        data_tx: Per-sample weights with the class axis last.
            NOTE(review): presumably the one-hot encoded labels - confirm.
        num_classes: Not used; kept for interface compatibility.
        t_tensor: Not used; kept for interface compatibility.

    Returns:
        Scalar tensor.
    """
    weights = data_tx.float()
    # Row-wise dot product. This equals the diagonal of
    # ``y_pred @ weights.T`` without materialising the batch x batch matrix,
    # and the reshape restores a batch axis lost to ``squeeze()``.
    per_sample = torch.sum(y_pred.reshape(weights.shape) * weights, dim=-1)
    return torch.mean(per_sample)


def sl_sec(y_pred):
    """Return the negated mean of ``log(y_pred) - y_pred``.

    The mean is taken over dimension 1 first and then over the remaining
    values, so ``y_pred`` must have at least two dimensions.
    """
    per_entry = torch.log(y_pred) - y_pred
    per_sample = per_entry.mean(dim=1)
    return -per_sample.mean()

def to_categorical(y, num_classes, t_tensor=False, dtype="uint8"):
    """One-hot encode class indices.

    Args:
        y: Class indices. A torch index tensor when ``t_tensor`` is True,
            otherwise a numpy array (cast to int and squeezed).
        num_classes: Width of the one-hot encoding.
        t_tensor: Selects the torch path (``F.one_hot``) over the numpy path.
        dtype: dtype of the result on the numpy path only.
    """
    if not t_tensor:
        identity = np.eye(num_classes, dtype=dtype)
        row_ids = y.astype(int).squeeze()
        return identity[row_ids]
    return F.one_hot(y, num_classes=num_classes)

def compute_loss_divergence(cost_function_v, out_1, out_2, data_tx, num_classes, current_batch_size, alpha, device):
    """Dispatch to the cost function selected by ``cost_function_v``.

    Args:
        cost_function_v: Divergence selector. String names "Jensen-Shannon",
            "KL", "SL", "JS_s", "SL_s", "JS_s2" or the numeric codes
            7, 9, 10, 12.
        out_1: Network output paired with the labels.
        out_2: Second network output, forwarded to the cost function.
        data_tx: Class-index labels.
        num_classes: Number of classes for the one-hot encoding.
        current_batch_size: Not used; kept for interface compatibility.
        alpha: Weight forwarded to the cost functions that accept it.
        device: Device forwarded to the cost functions that accept it.

    Returns:
        The loss returned by the selected cost function.

    Raises:
        NotImplementedError: ``cost_function_v`` matches none of the selectors
            (previously this surfaced as an UnboundLocalError).
    """
    if cost_function_v == "KL":  # cross-entropy / KL
        # Cross-entropy consumes the integer labels directly, so the one-hot
        # encoding below is not needed on this path.
        return nn.CrossEntropyLoss()(out_1.squeeze(), data_tx.squeeze().long())

    data_tx_categorical = torch.Tensor(to_categorical(data_tx, t_tensor=True, num_classes=num_classes))

    if cost_function_v == "Jensen-Shannon":  # GAN
        return gan_cost_fcn(out_1, out_2, data_tx_categorical, num_classes, device=device)
    if cost_function_v == "SL":  # SL
        return sl_cost_fcn(out_1, out_2, data_tx_categorical, num_classes, alpha)
    # NOTE(review): the four helpers used by the numeric codes are not defined
    # in the visible part of this file - confirm they exist before use.
    if cost_function_v == 7:  # KL with softplus
        return cross_entropy_sup(out_1, out_2, data_tx_categorical, num_classes, alpha)
    if cost_function_v == 9:  # RKL
        return reverse_kl(out_1, out_2, data_tx_categorical, num_classes, alpha)
    if cost_function_v == 10:  # HD
        return hellinger_distance(out_1, out_2, data_tx_categorical, num_classes, alpha)
    if cost_function_v == 12:  # P
        return pearson_chi2(out_1, out_2, data_tx_categorical, num_classes, alpha)
    if cost_function_v == "JS_s":
        return js_s_cost_fcn(out_1, out_2, data_tx_categorical, num_classes, device=device)
    if cost_function_v == "SL_s":
        return sl_s_cost_fcn(out_1, out_2, data_tx_categorical, num_classes, device=device)
    if cost_function_v == "JS_s2":
        return js_s2_cost_fcn(out_1, out_2, data_tx_categorical, num_classes, device=device)
    raise NotImplementedError("[-] Not Implemented f-divergence %s" % cost_function_v)


def gan_cost_fcn(out_1, out_2, digits, num_classes, device="cpu", t_tensor=True):
    """GAN-style (Jensen-Shannon) cost built from two binary cross-entropies.

    First term: BCE of ``out_1`` against an all-zero target, i.e.
    ``-log(1 - out_1)``, weighted per sample by ``digits`` and averaged over
    the batch. Second term: mean BCE of ``out_2`` against an all-one target,
    i.e. the mean of ``-log(out_2)``.

    Args:
        out_1: Values in [0, 1] holding as many elements as ``digits``.
        out_2: Values in [0, 1].
        digits: Per-sample weights with the class axis last.
            NOTE(review): presumably the one-hot encoded labels - confirm.
        num_classes: Not used; kept for interface compatibility.
        device: Not used any more; helper tensors are created on the device
            of the outputs. Kept for interface compatibility.
        t_tensor: Not used; kept for interface compatibility.

    Returns:
        Scalar tensor: the sum of the two terms.
    """
    selector = digits.float().to(out_1.device)
    # Align with ``digits`` so a batch of one keeps its batch axis.
    scores = out_1.reshape(selector.shape)
    # Targets are created directly on the outputs' device instead of going
    # through numpy and a host-to-device copy on every call.
    rejected = nn.BCELoss(reduction='none')(scores, torch.zeros_like(scores))
    # Row-wise weighted sum: equals the diagonal of ``rejected @ selector.T``
    # without materialising the batch x batch matrix.
    loss_1 = torch.mean(torch.sum(rejected * selector, dim=-1))
    loss_2 = nn.BCELoss()(out_2, torch.ones_like(out_2))
    return loss_1 + loss_2


def gan_cost_fcn2(out_1, out_2, digits, num_classes, device="cpu", t_tensor=True):
    """GAN-style cost written with explicit, eps-shifted logarithms.

    Returns ``-(A + B)`` where ``A`` is the batch mean of
    ``log(1 - out_1 + eps)`` weighted per sample by ``digits`` and ``B`` is
    the mean of ``log(out_2 + eps)``, with ``eps = 1e-4``.

    Args:
        out_1: Tensor holding as many elements as ``digits``.
        out_2: Tensor of any shape.
        digits: Per-sample weights with the class axis last.
            NOTE(review): presumably the one-hot encoded labels - confirm.
        num_classes: Not used; kept for interface compatibility.
        device: Not used any more; ``digits`` is moved to the device of
            ``out_1``. Kept for interface compatibility.
        t_tensor: Not used; kept for interface compatibility.
    """
    eps = 1e-4
    selector = digits.float().to(out_1.device)
    # Align with ``digits`` so a batch of one keeps its batch axis.
    log_rejected = torch.log(1 - out_1.reshape(selector.shape) + eps)
    # Row-wise weighted sum: equals the diagonal of ``log_rejected @ selector.T``
    # without materialising the batch x batch matrix.
    loss_1 = torch.mean(torch.sum(log_rejected * selector, dim=-1))
    loss_2 = torch.mean(torch.log(out_2 + eps))
    return -(loss_1 + loss_2)


def js_s_cost_fcn(out_1, out_2, digits, num_classes, device="cpu", t_tensor=True):
    """Jensen-Shannon style cost computed from ``out_1`` only.

    Returns ``-(A - B)`` where ``A`` is the batch mean of
    ``log(out_1 / (out_1 + 1))`` weighted per sample by ``digits`` and ``B``
    is the batch mean of the per-row sum of ``log(out_1 + 1)``.

    ``out_2``, ``num_classes`` and ``t_tensor`` are not used. ``digits`` is
    moved to ``device`` before the matrix product.
    """
    selector_t = torch.transpose(digits.float(), 0, 1).to(device)
    log_ratio = torch.log(out_1 / (out_1 + 1))
    # Entry (i, i) of the product is the digits-weighted sum of row i.
    pairwise = torch.matmul(log_ratio, selector_t)
    labelled_term = torch.mean(torch.diagonal(pairwise, 0))
    normaliser = torch.mean(torch.sum(torch.log(out_1 + 1), dim=1))
    return -(labelled_term - normaliser)

def js_s2_cost_fcn(out_1, out_2, digits, num_classes, device="cpu", t_tensor=True):
    """Jensen-Shannon style cost on the softmax of ``out_1``.

    With ``s = softmax(out_1, dim=1)`` this returns ``-(A - B)`` where ``A``
    is the batch mean of ``log((s + eps) / (s + 1))`` weighted per sample by
    ``digits`` (``eps = 1e-8``) and ``B`` is the batch mean of the per-row sum
    of ``log(s + 1)``.

    ``out_2``, ``num_classes`` and ``t_tensor`` are not used. ``digits`` is
    moved to ``device`` before the matrix product.
    """
    eps = 1e-8
    probs = F.softmax(out_1, dim=1)
    selector_t = torch.transpose(digits.float(), 0, 1).to(device)
    log_ratio = torch.log((probs + eps) / (probs + 1))
    # Entry (i, i) of the product is the digits-weighted sum of row i.
    pairwise = torch.matmul(log_ratio, selector_t)
    labelled_term = torch.mean(torch.diagonal(pairwise, 0))
    normaliser = torch.mean(torch.sum(torch.log(probs + 1), dim=1))
    return -(labelled_term - normaliser)


def sl_s_cost_fcn(out_1, out_2, digits, num_classes, device="cpu", t_tensor=True):
    """SL style cost.

    Returns ``-(A + B)`` where ``A`` is the batch mean of ``-1 / (out_1 + 1)``
    weighted per sample by ``digits`` and ``B`` is the batch mean of the
    per-row sum of ``log(1 / (out_2 + 1)) - 1 / (out_2 + 1)``.

    ``num_classes`` and ``t_tensor`` are not used. ``digits`` is moved to
    ``device`` before the matrix product.
    """
    selector_t = torch.transpose(digits.float(), 0, 1).to(device)
    neg_inverse = -1 / (out_1 + 1)
    # Entry (i, i) of the product is the digits-weighted sum of row i.
    pairwise = torch.matmul(neg_inverse, selector_t)
    labelled_term = torch.mean(torch.diagonal(pairwise, 0))
    per_entry = torch.log(1 / (out_2 + 1)) - 1 / (out_2 + 1)
    unlabelled_term = torch.mean(torch.sum(per_entry, dim=1))
    return -(labelled_term + unlabelled_term)

