import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random

class trans_matrix(nn.Module):
    """Learnable class-transition matrix ``T`` plus per-sample noise terms ``u``/``v``.

    ``T`` is a ``num_classes x num_classes`` row-stochastic matrix. Its diagonal
    is fixed at 1 and its off-diagonal entries are ``sigmoid(self.trans)``, with
    each row then L1-normalised. ``u`` and ``v`` hold one row per training
    example. Their squares (clamped to [0, 1]) are added to and subtracted from
    the transformed prediction once the warm-up phase is over.

    Args:
        num_examp: number of training examples (rows of ``u`` and ``v``).
        num_classes: number of classes.
        ratio_T: weight applied to log|det T| to form the volume term.
        k: stored on the module; not read by this class.
        init: initial logit for every entry of ``self.trans``.
        ratio_balance: stored on the module; not read by this class.
        noise_rate: rate ``p`` of the symmetric reference matrix
            ``p / n * 1 + (1 - p) * I``, used only for the ``error`` diagnostic
            when ``forward`` is called with ``P=None``. NOTE(review): the
            original code referenced an undefined ``p`` here; confirm the
            intended default with the caller.
    """

    def __init__(self, num_examp, num_classes=10, ratio_T=0.01, k=1, init=-4.5,
                 ratio_balance=0.0, noise_rate=0.0):
        super(trans_matrix, self).__init__()
        self.num_classes = num_classes
        self.num_examp = num_examp
        self.k = k
        # Previously read from an undefined global (NameError on construction).
        self.ratio_balance = ratio_balance
        self.ratio_T = ratio_T
        self.noise_rate = noise_rate
        self.trans = nn.Parameter(init * torch.ones(num_classes, num_classes, dtype=torch.float32))

        self.u = nn.Parameter(torch.zeros(num_examp, num_classes, dtype=torch.float32))
        self.v = nn.Parameter(torch.zeros(num_examp, num_classes, dtype=torch.float32))
        self.lam = 0.5

        self.init_param(mean=0.0, std=1e-8)

    def init_param(self, mean=0., std=1e-8):
        """Re-initialise ``u`` and ``v`` in place from N(mean, std**2)."""
        torch.nn.init.normal_(self.u, mean=mean, std=std)
        torch.nn.init.normal_(self.v, mean=mean, std=std)

    def _transition_matrix(self):
        """Build the row-normalised transition matrix on the parameters' device."""
        eye = torch.eye(self.num_classes, dtype=torch.float32, device=self.trans.device)
        off_diag = torch.ones_like(eye) - eye
        T = eye + torch.sigmoid(self.trans) * off_diag
        return F.normalize(T, p=1, dim=1)

    def _reference_matrix(self, P, device):
        """Return ``P`` as a float32 tensor, or the symmetric-noise matrix if ``P`` is None."""
        if P is not None:
            return torch.as_tensor(P, dtype=torch.float32, device=device)
        n = self.num_classes
        p = self.noise_rate
        ones = torch.ones(n, n, dtype=torch.float32, device=device)
        eye = torch.eye(n, dtype=torch.float32, device=device)
        return ones * p / n + (1 - p) * eye

    @staticmethod
    def _clamp_normalize(x, eps):
        """Clamp to >= eps, L1-normalise along dim 1, then clamp again so log() is safe."""
        x = torch.clamp(x, min=eps)
        x = F.normalize(x, p=1, dim=1, eps=eps)
        return torch.clamp(x, min=eps)

    def forward(self, epoch, index, output, label, batch_idx, ep=-1, labels_s=None, P=None):
        """Compute the training loss and diagnostics for one batch.

        Args:
            epoch: current epoch; compared against ``ep``.
            index: indices of the batch's examples into ``u`` / ``v``.
            output: classifier logits; softmax is taken over dim 1.
            label: targets multiplied elementwise with log-probabilities, so
                presumably one-hot or soft labels with the same shape as
                ``output`` (TODO confirm with caller).
            batch_idx: unused.
            ep: while ``epoch < ep`` (warm-up), ``u``/``v`` are not used. With
                the default ``-1`` this branch is never taken for
                non-negative epochs.
            labels_s: unused.
            P: optional reference transition matrix (array-like or tensor) for
                the ``error`` diagnostic; see ``noise_rate`` when None.

        Returns:
            ``(loss, MSE_loss, CE_loss, vol_loss, error, det_T)``, where
            ``det_T`` is log|det T| (not the determinant itself) and ``error``
            is the relative L1 distance between ``T`` and the reference matrix.
        """
        eps = 1e-7

        original_prediction = F.softmax(output, dim=1)
        T = self._transition_matrix()

        # Per-sample noise terms, squared and clamped to [0, 1].
        U_square = torch.clamp(self.u[index] ** 2, 0, 1)
        V_square = torch.clamp(self.v[index] ** 2, 0, 1)

        # Diagnostic only: not part of the loss.
        T_s = self._reference_matrix(P, T.device)
        error = torch.sum(torch.abs(T - T_s)) / torch.sum(torch.abs(T_s))

        det_T = T.slogdet().logabsdet
        vol_loss = self.ratio_T * det_T

        if epoch < ep:
            # Warm-up: fit the labels through T alone.
            prediction = self._clamp_normalize(original_prediction @ T, eps)
            CE_loss = torch.mean(-torch.sum(label * torch.log(prediction), dim=-1))
            # Reported but not added to the loss in this phase.
            MSE_loss = F.mse_loss(prediction, label, reduction='sum') / len(label)
            loss = CE_loss + vol_loss
        else:
            # v is detached here so the CE term only drives u, T and the classifier.
            prediction = self._clamp_normalize(
                original_prediction @ T + U_square - V_square.detach(), eps)
            CE_loss = torch.mean(-torch.sum(label * torch.log(prediction), dim=-1))
            # Classifier output and T are detached, so this term only trains u and v.
            MSE_loss = F.mse_loss((original_prediction.detach() @ T.detach() +
                                   U_square - V_square), label, reduction='sum') / len(label)
            loss = CE_loss + MSE_loss + vol_loss

        return loss, MSE_loss, CE_loss, vol_loss, error, det_T

