import torch
import torch.nn as nn
import torch.nn.functional as F
from .backbone import resnet18, resnet34, resnet101
from geomloss import SamplesLoss

def gumbel_softmax(logits: torch.Tensor, tau: float = 1, hard: bool = True, dim: int = -1) -> torch.Tensor:
    """Sample from the Gumbel-softmax distribution, optionally as a straight-through one-hot.

    Gumbel(0, 1) noise is drawn via ``torch.distributions.gumbel.Gumbel``
    instead of ``-log(Exp(1))`` for numerical stability
    (see https://github.com/pytorch/pytorch/issues/41663).

    Args:
        logits: unnormalised log-probabilities; noise of the same shape is added.
        tau: temperature dividing the perturbed logits before the softmax.
        hard: if True, the forward value is the one-hot of the soft sample's
            argmax along ``dim``, while gradients flow through the soft sample.
        dim: dimension along which the softmax / one-hot is taken.

    Returns:
        Tensor with the same shape as ``logits``.
    """
    zero = torch.tensor(0., device=logits.device, dtype=logits.dtype)
    one = torch.tensor(1., device=logits.device, dtype=logits.dtype)
    noise = torch.distributions.gumbel.Gumbel(zero, one).sample(logits.shape)

    perturbed = (logits + noise) / tau  # ~Gumbel(logits, tau)
    soft_sample = perturbed.softmax(dim)

    if not hard:
        # Reparametrised (relaxed) sample only.
        return soft_sample

    # Straight-through estimator: forward value is one-hot, backward uses soft_sample.
    _, winner = soft_sample.max(dim, keepdim=True)
    one_hot = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format)
    one_hot = one_hot.scatter_(dim, winner, 1.0)
    return one_hot - soft_sample.detach() + soft_sample

def straight_through(logits, dim: int = -1):
    """Deterministic straight-through argmax over a softmax.

    The forward value is the one-hot of ``argmax(softmax(logits))`` along
    ``dim``. The ``hard - soft.detach() + soft`` trick makes the backward
    pass use the softmax gradient.

    Args:
        logits: unnormalised scores.
        dim: dimension along which the softmax / one-hot is taken (defaults to
            the last dimension, matching ``gumbel_softmax``).

    Returns:
        Tensor with the same shape as ``logits``.
    """
    probs_soft = F.softmax(logits, dim=dim)
    winner = probs_soft.argmax(dim=dim, keepdim=True)
    y_hard = torch.zeros_like(probs_soft).scatter_(dim, winner, 1.0)
    return y_hard - probs_soft.detach() + probs_soft

class Sinkhorn(nn.Module):
    """Sinkhorn-style transport cost between temperature-softened predictions.

    Both inputs are softmaxed along dim 1 at temperature ``T``. An L1 cost
    matrix is built between them with ``torch.cdist``, turned into a Gibbs
    kernel ``exp(-cost / epsilon)``, and alternately row/column normalised.
    The loss is the sum of plan * cost.

    Works for 2-D ``(batch, classes)`` inputs and for batched 3-D inputs.
    Normalisation always acts on the last two axes of the cost matrix.
    """

    def __init__(self, T=1):
        super(Sinkhorn, self).__init__()
        self.T = T  # softmax temperature

    def sinkhorn_normalized(self, x, n_iters=10):
        """Alternately normalise the last two axes of ``x`` for ``n_iters`` rounds.

        After the final round, entries along axis -2 (columns) sum to 1.
        Rows sum to 1 only approximately.
        """
        for _ in range(n_iters):
            x = x / torch.sum(x, dim=-1, keepdim=True)  # rows
            x = x / torch.sum(x, dim=-2, keepdim=True)  # columns
        return x

    def sinkhorn_loss(self, x, y, epsilon=0.1, n_iters=20):
        """Return sum(P * C), where C = L1 cdist(x, y) and P is the normalised Gibbs kernel of C."""
        cost = torch.cdist(x, y, p=1)
        kernel = torch.exp(-cost / epsilon)
        plan = self.sinkhorn_normalized(kernel, n_iters)
        return torch.sum(plan * cost)

    def forward(self, y_s, y_t, mode="classification"):
        """Compute the loss between logits ``y_s`` and ``y_t``.

        ``mode`` is accepted for call-site compatibility but is not used.
        """
        p_s = F.softmax(y_s / self.T, dim=1)
        p_t = F.softmax(y_t / self.T, dim=1)
        return self.sinkhorn_loss(x=p_s, y=p_t)

        
        
class exp_b1(nn.Module):
    """Audio-visual classifier with per-class, per-dimension modality routing.

    For every class prototype, each embedding dimension is taken from either
    the audio or the visual embedding. Routing is a Gumbel-softmax sample in
    training and a hard argmax otherwise. The prediction uses the class whose
    prototype is best matched by any class-conditioned fusion. In training, a
    Sinkhorn loss (geomloss ``SamplesLoss``) compares two softmax similarity
    profiles:
      - the ground-truth-class fusion against all prototypes;
      - the ground-truth prototype against all prototypes.
    """

    def __init__(self, args):
        super(exp_b1, self).__init__()

        if args.dataset == 'CREMAD':
            n_classes = 6
        elif args.dataset == 'KineticSound':
            n_classes = 31
        else:
            raise NotImplementedError('Incorrect dataset name {}'.format(args.dataset))

        self.n_classes = n_classes
        self.args = args
        self.net_a = resnet18(modality='audio')
        self.net_v = resnet18(modality='visual')
        self.temperature = 0.5  # not read inside this class
        # Prototype width. Prototypes are compared element-wise with the pooled
        # backbone embeddings, so this must equal the backbone feature size.
        self.feat_dim = 512

        self.classifier_a = nn.Linear(args.embed_dim, n_classes)
        self.classifier_v = nn.Linear(args.embed_dim, n_classes)

        # One learnable prototype per class: (n_classes, feat_dim).
        self.proto = nn.Parameter(torch.randn(n_classes, self.feat_dim), requires_grad=True)

        self.sinkhornloss = SamplesLoss("sinkhorn", p=1, blur=0.05)

    def cosine_sim_dim(self, x, y, eps=1e-8):
        """Element-wise "cosine" x*y / (|x|*|y| + eps).

        Each entry lies in (-1, 1). It is close to sign(x*y) unless an operand
        is near zero, and it is 0 when either operand is exactly zero. Inputs
        are broadcast against each other.
        """
        numerator = x * y
        denominator = torch.abs(x) * torch.abs(y) + eps
        return numerator / denominator

    def forward(self, audio, visual, label=None, B=16, epoch=150, train=True):
        """Encode both modalities, fuse per class, and score against prototypes.

        Args:
            audio: input to the audio backbone.
            visual: input to the visual backbone. Its 4-D output (N, C, H, W) is
                regrouped as (B, N // B, C, H, W), so N must be a multiple of B.
            label: class indices. Presumably a LongTensor of shape (B,); it is
                indexed alongside ``torch.arange(B)``. Required when ``train``.
            B: batch size used to regroup visual frames per sample.
            epoch: unused here; kept for call-site compatibility.
            train: if True, use stochastic Gumbel routing and also return the loss.

        Returns:
            ``(a, v, proto, out_a, out_v, logit, loss)`` when ``train``, otherwise
            ``(a, v, proto, out_a, out_v, logit)``. ``a``, ``v`` and ``proto`` are
            L2-normalised.

        Raises:
            ValueError: if ``label`` is missing in train mode, or if the Gumbel
                routing produced NaN.
        """
        if train and label is None:
            raise ValueError("label is required when train=True")

        # Audio: pool backbone features -> unimodal logits -> L2-normalised embedding.
        a = self.net_a(audio)
        a = F.adaptive_avg_pool2d(a, 1)
        a = torch.flatten(a, 1)
        out_a = self.classifier_a(a)
        a = F.normalize(a, dim=-1)

        # Visual: regroup frames per sample, pool over (frames, H, W).
        v = self.net_v(visual)
        (_, C, H, W) = v.size()
        v = v.view(B, -1, C, H, W)
        v = v.permute(0, 2, 1, 3, 4)
        v = F.adaptive_avg_pool3d(v, 1)
        v = torch.flatten(v, 1)
        out_v = self.classifier_v(v)
        v = F.normalize(v, dim=-1)

        proto = F.normalize(self.proto, dim=-1)  # (K, D)

        # Per-dimension agreement of each modality with every prototype: (B, K, D).
        sim_a_all = self.cosine_sim_dim(a.unsqueeze(1), proto.unsqueeze(0))
        sim_v_all = self.cosine_sim_dim(v.unsqueeze(1), proto.unsqueeze(0))

        av = torch.stack([a, v], dim=2)  # (B, D, 2); loop-invariant
        fusion_all = []
        for c in range(self.n_classes):
            # (B, D, 2): last axis selects audio (0) or visual (1) per dimension.
            sim = torch.stack([sim_a_all[:, c, :], sim_v_all[:, c, :]], dim=2)

            if train:
                pre_sim = gumbel_softmax(sim, tau=0.5)
                if torch.isnan(pre_sim).any():
                    # Raise (not return) so the caller cannot mistake this for outputs.
                    raise ValueError("NaN in gumbels")
            else:
                # Hard routing; zeros_like keeps sim's device (works on CPU too).
                pre_sim = torch.zeros_like(sim).scatter_(-1, sim.argmax(dim=-1, keepdim=True), 1.0)

            fusion_c = (pre_sim * av).sum(dim=2)
            fusion_all.append(F.normalize(fusion_c, dim=-1))

        fusion_all = torch.stack(fusion_all, dim=1)  # (B, K, D)
        proto_list = proto.expand(B, -1, -1)  # (B, K, D) view, no copy

        # sim_pf[b, k, j] = <fusion for class k, prototype j>.
        sim_pf = fusion_all @ proto_list.transpose(1, 2)

        # Best fusion score per prototype, then the best prototype. Reducing with
        # .values (rather than squeeze()) keeps the batch axis when B == 1.
        max_k = sim_pf.max(dim=1).values  # (B, K)
        logit_el = max_k.argmax(dim=1)

        batch_idx = torch.arange(B, device=fusion_all.device)
        fusion = fusion_all[batch_idx, logit_el]
        logit = fusion @ proto.T

        if not train:
            return a, v, proto, out_a, out_v, logit

        # Similarity profile of the ground-truth-class fusion against fixed prototypes.
        label_fusion = fusion_all[batch_idx, label]
        prob = F.softmax(label_fusion @ proto.detach().T, dim=-1)

        # Target profile: the ground-truth prototype's similarity to all prototypes.
        label_proto = proto_list[batch_idx, label].detach()
        target = F.softmax(label_proto @ proto.detach().T, dim=1)

        loss = self.sinkhornloss(prob.contiguous(), target.contiguous())
        loss = loss.mean()
        return a, v, proto, out_a, out_v, logit, loss
        
        
        
        
