import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from Model.sinkhorn import SinkhornDistance


class LossFunction(nn.Module):
    """Aggregate the training loss terms of the prototype model.

    Args:
        model_config: mapping read for the keys ``'basis_num'``,
            ``'basis_type'`` and ``'attention_type'``. All three are stored as
            attributes; only ``basis_num`` is used by ``forward``.
    """

    def __init__(self, model_config):
        super().__init__()
        self.basis_num = model_config['basis_num']
        self.basis_type = model_config['basis_type']
        self.attention_type = model_config['attention_type']
        # Sub-losses, registered in this order: feature OT, attention OT,
        # orthogonality regularizer.
        self.otloss_feature = OTLoss("feature", model_config['basis_num'])
        self.otloss_attention = OTLoss("attention", model_config['basis_num'])
        self.orthogonalloss = OrthogonalLoss()

    def forward(self, x_input, x_pred, masks, feature_map, feature_prototype,
                attention_map, attention_prototype):
        """Compute every loss term for one batch.

        ``masks`` is accepted for interface compatibility but is not used.

        Returns:
            ``(loss, recon, ortho, feature_ot, attention_ot)`` where
            ``loss = recon + ortho``. ``recon`` is the mean L2 reconstruction
            error and ``ortho`` is 0.1 times the orthogonality penalty on
            ``feature_prototype``. The two OT terms are summed and returned
            separately; they are NOT included in ``loss``.
        """
        # NOTE(review): assumes x_input is (batch, dim) and x_pred is
        # (batch, basis_num, dim) -- confirm against the caller.
        tiled_input = x_input.unsqueeze(1).repeat(1, self.basis_num, 1)
        residual_norms = torch.norm(x_pred - tiled_input, p=2, dim=2)
        per_sample = residual_norms.mean(dim=1, keepdim=True)
        recon = per_sample.mean()

        ortho = self.orthogonalloss(feature_prototype) * 0.1
        total = recon + ortho

        feature_ot = self.otloss_feature(feature_prototype, feature_map)
        attention_ot = self.otloss_attention(attention_prototype, attention_map)

        return (
            total,
            recon,
            torch.mean(ortho),
            torch.sum(feature_ot),
            torch.sum(attention_ot),
        )
        

class OrthogonalLoss(nn.Module):
    """Penalty that pushes a set of vectors toward mutual orthogonality.

    Every input row, after flattening all trailing dimensions, is treated as
    one vector. ``S`` is the matrix of absolute pairwise cosine similarities,
    clamped to ``[0, 1]``. The loss is::

        sum(S) / sum(S ** 2) + 0.5 * |sum(S) - n|

    where ``n`` is the number of rows. For mutually orthogonal non-zero rows,
    ``S`` is the identity and the loss equals 1.

    Args:
        eps: small constant guarding the row norms and the ``sum(S ** 2)``
            denominator against division by zero.
    """

    def __init__(self, eps=1e-12):
        super(OrthogonalLoss, self).__init__()
        self.eps = eps

    def forward(self, z):
        """Return the scalar orthogonality loss for ``z`` (first dim = rows).

        The input is deliberately NOT detached, so that gradients reach
        ``nn.Parameter`` prototypes exactly as they reach plain tensors.
        """
        # Flatten to (n, d) and compute in float32.
        z = z.reshape(z.shape[0], -1).float()
        # Row-normalize with a clamped norm: an all-zero row becomes a zero
        # vector (zero similarity row) instead of producing 0/0 = NaN.
        unit = F.normalize(z, p=2, dim=1, eps=self.eps)
        sim = torch.clamp((unit @ unit.T).abs(), 0, 1)

        l1 = torch.sum(sim)  # sim is already non-negative
        l2 = torch.sum(sim ** 2)

        # Clamp only matters when every row is zero (l1 == l2 == 0).
        loss_sparse = l1 / l2.clamp_min(self.eps)
        loss_constraint = torch.abs(l1 - sim.shape[0])

        return loss_sparse + 0.5 * loss_constraint


class OTLoss(nn.Module):
    """Optimal-transport matching loss between prototypes and features.

    A Sinkhorn solver is run between ``match_feature`` and ``prototype``. The
    minimum of the element-wise product ``C * pi`` over its last dimension is
    kept. ``C`` and ``pi`` are presumably the cost matrix and the transport
    plan returned by ``SinkhornDistance``; confirm against ``Model.sinkhorn``.

    Args:
        prototype_type: ``"feature"`` or ``"attention"``; selects how the
            inputs are flattened or grouped in ``forward``.
        basis_num: stored on the instance; not used by this class directly.
        num_heads: number of groups the ``"attention"`` branch splits its
            inputs into. Defaults to 4, the previously hard-coded value.
    """

    def __init__(self, prototype_type, basis_num, num_heads=4):
        super(OTLoss, self).__init__()
        self.basis_num = basis_num
        self.prototype_type = prototype_type
        self.num_heads = num_heads
        sinkhorn = SinkhornDistance(eps=0.1, max_iter=200, reduction=None, dis='euc')
        # Move to the GPU only when one exists, so the loss can also be
        # built on CPU-only machines (unconditional .cuda() raised there).
        if torch.cuda.is_available():
            sinkhorn = sinkhorn.cuda()
        self.sinkhorn = sinkhorn

    def forward(self, prototype, match_feature):
        """Compute the OT matching cost.

        Returns:
            ``"feature"``: tensor of minima of ``C * pi`` over the last
            dimension, computed from both inputs flattened to 2-D by their
            first dimension.
            ``"attention"``: scalar, the mean over ``num_heads`` groups of the
            summed minima of ``C * pi`` over the last dimension.

        Raises:
            ValueError: if ``prototype_type`` is neither supported value.
        """
        if self.prototype_type == 'feature':
            # reshape (not view) so non-contiguous inputs are accepted too.
            prototype = prototype.reshape(prototype.shape[0], -1)
            match_feature = match_feature.reshape(match_feature.shape[0], -1)
            C, pi, _ = self.sinkhorn(match_feature, prototype)
            cost, _ = torch.min(C * pi, dim=-1)
            return cost

        if self.prototype_type == 'attention':
            heads = self.num_heads
            # Swap the first two axes so prototype[i] selects group i.
            # NOTE(review): assumes prototype is 3-D with its group axis at
            # dim 1 of size >= num_heads -- confirm against the caller.
            prototype = prototype.permute(1, 0, 2)
            # NOTE(review): assumes match_feature's element count is
            # divisible into num_heads groups of rows of width shape[1].
            match_feature = match_feature.reshape(heads, -1, match_feature.shape[1])
            total = 0
            for i in range(heads):
                C, pi, _ = self.sinkhorn(match_feature[i], prototype[i])
                row_min, _ = torch.min(C * pi, dim=-1)
                total = total + torch.sum(row_min)
            return total / heads

        raise ValueError(
            f"unsupported prototype_type {self.prototype_type!r}; "
            "expected 'feature' or 'attention'"
        )

