import pdb
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from Model.sinkhorn import SinkhornDistance


class ScoreFunction(nn.Module):
    """Score each sample by its mean Euclidean distance to a set of predictions.

    ``model_config`` must provide ``'basis_num'``, the number of predictions
    produced per sample (stored as ``self.basis_num``).
    """

    def __init__(self, model_config):
        super(ScoreFunction, self).__init__()
        self.basis_num = model_config['basis_num']

    def forward(self, x_input, x_pred):
        """Return the mean L2 distance between each input and its predictions.

        Args:
            x_input: 2-D tensor ``(B, D)``.
            x_pred: 3-D tensor ``(B, basis_num, D)``. Any size along dim 1
                that broadcasts against ``(B, 1, D)`` is accepted.

        Returns:
            Tensor of shape ``(B, 1)``: for every sample, the L2 norm of
            ``x_pred[b, k] - x_input[b]`` averaged over ``k``.
        """
        # Broadcast x_input across the prediction axis rather than materializing
        # basis_num copies with repeat(); the subtraction result is identical.
        diff = x_pred - x_input.unsqueeze(1)
        dist = torch.norm(diff, p=2, dim=2)
        return dist.mean(dim=1, keepdim=True)

class OTScoreFunction(nn.Module):
    """Score features by a Sinkhorn (entropic optimal-transport) match cost to prototypes.

    ``model_config`` must provide ``'basis_num'``. It may also provide
    ``'num_heads'`` (default 4): the number of slices the ``'attention'``
    mode splits its inputs into. Before this option existed the value was
    hard-coded to 4.
    """

    def __init__(self, model_config):
        super(OTScoreFunction, self).__init__()
        self.basis_num = model_config['basis_num']
        # Previously hard-coded as 4 in the attention branch; the default keeps that.
        self.num_heads = model_config.get('num_heads', 4)
        self.sinkhorn = SinkhornDistance(eps=0.1, max_iter=200, reduction=None, dis='euc')
        # Only move to GPU when one exists, so the module can be built on CPU-only hosts.
        if torch.cuda.is_available():
            self.sinkhorn = self.sinkhorn.cuda()

    def _min_match_cost(self, source, target):
        """Run Sinkhorn on (source, target) and return min over the last axis of ``C * pi``."""
        # NOTE(review): C and pi are presumably the cost matrix and transport plan
        # returned by SinkhornDistance -- confirm in Model.sinkhorn. Its third
        # output is not used here.
        C, pi, _ = self.sinkhorn(source, target)
        return torch.min(C * pi, dim=-1).values

    def forward(self, prototype, match_feature, prototype_type):
        """Return a match-cost score for ``match_feature`` against ``prototype``.

        Args:
            prototype: prototype tensor. In ``'attention'`` mode it must be
                3-D; after ``permute(1, 0, 2)`` its first axis indexes heads.
            match_feature: features to score. In ``'attention'`` mode it is
                reshaped to ``(num_heads, -1, match_feature.shape[1])``.
            prototype_type: ``'feature'`` or ``'attention'``.

        Returns:
            Tensor with a trailing singleton dimension (``keepdim``/``unsqueeze(1)``).

        Raises:
            ValueError: if ``prototype_type`` is not recognized (previously this
                surfaced as an UnboundLocalError).
        """
        if prototype_type == 'feature':
            # Flatten everything after the first axis so each row is one vector.
            prototype = prototype.view(prototype.shape[0], -1)
            match_feature = match_feature.view(match_feature.shape[0], -1)
            return self._min_match_cost(match_feature, prototype).unsqueeze(1)

        if prototype_type == 'attention':
            # Move the head axis to the front so prototype[h] selects head h.
            prototype = prototype.permute(1, 0, 2)
            # assumes match_feature rows are laid out head-major -- TODO confirm with caller
            match_feature = match_feature.reshape(self.num_heads, -1, match_feature.shape[1])
            # Minimum match cost per head.
            head_costs = [
                self._min_match_cost(match_feature[h], prototype[h])
                for h in range(self.num_heads)
            ]
            # Average over heads (same left-to-right summation as before), then
            # sum each consecutive run of basis_num entries into one score.
            cost = sum(head_costs) / self.num_heads
            return cost.reshape(-1, self.basis_num).sum(dim=1, keepdim=True)

        raise ValueError(
            f"Unknown prototype_type {prototype_type!r}; expected 'feature' or 'attention'"
        )
