import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import copy
import numpy as np

# Module-wide default device string: "cuda" when a CUDA device is visible, else "cpu".
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"


class Features_Projector(nn.Module):
    """Width-preserving two-layer MLP head: Linear -> ReLU -> Linear.

    Args:
        model_dim: Size of the last dimension of both the input and the output.
    """

    def __init__(self, model_dim):
        super().__init__()
        # Layers are built in the same order as before (first Linear, then second),
        # so parameter initialisation draws from the RNG identically.
        first_layer = nn.Linear(in_features=model_dim, out_features=model_dim)
        second_layer = nn.Linear(in_features=model_dim, out_features=model_dim)
        # Kept as a single Sequential named `net` so state_dict keys stay net.0.* / net.2.*.
        self.net = nn.Sequential(first_layer, nn.ReLU(), second_layer)

    def forward(self, x):
        """Project ``x`` through the MLP along its last dimension."""
        projected = self.net(x)
        return projected
    

class EMB_Trainer(nn.Module):
    """Self-supervised training wrapper: encoder + projection head + contrastive loss.

    The objective is selected by ``method``:

    * ``"mix_up"``: each pair ``(x1[i], x2[i])`` is blended with a per-sample
      coefficient drawn from Beta(0.2, 0.2). The projection of the blended
      sample is scored against the projections of both sources with
      ``MixUpLoss``.
    * any other value: an NT-Xent-style loss (``nll_loss_fn``). In the batch
      ``cat([x1, x2])``, ``x1[i]`` and ``x2[i]`` are the positive pair and
      every other row is a negative.

    Args:
        base_net: Encoder applied to the concatenated batch. Its output is
            ``squeeze()``-d before the projector, so it presumably has shape
            ``(N, model_dim)``, possibly with trailing singleton dims.
            TODO confirm against the encoder used.
        method: Objective selector (see above).
        model_dim: Feature width expected by the projection head.
    """

    def __init__(self, base_net, method, model_dim):
        super(EMB_Trainer, self).__init__()

        self.base_net = base_net
        # Only the projector is moved to the module-level `device`; base_net is used as given.
        self.projector = Features_Projector(model_dim=model_dim).to(device)

        self.method = method

        if self.method == "mix_up":
            self.loss_fn = MixUpLoss(device=device)
        else:
            # Bound method stored as a plain attribute (not a registered submodule).
            self.loss_fn = self.nll_loss_fn

        print("Number Trainable Parameters: " + str(self.count_parameters()))

    def forward(self, x1, x2):
        """Compute the self-supervised loss for a batch of paired views.

        Args:
            x1: First view; the batch is along dim 0.
            x2: Second view, same shape as ``x1``.

        Returns:
            The scalar loss tensor produced by ``self.loss_fn``.
        """
        if self.method == "mix_up":
            bs = x1.shape[0]
            betas = np.random.beta(0.2, 0.2, bs)
            # Create the coefficients on the inputs' device, not a global default.
            betas = torch.as_tensor(betas, dtype=torch.float32, device=x1.device)
            # Shape (bs, 1, ..., 1) broadcasts over every non-batch dim of x1.
            # A fixed (bs, 1) view only lines up correctly with 2-D inputs.
            mix = betas.view((bs,) + (1,) * (x1.dim() - 1))
            aug_views = x1 * mix + x2 * (1 - mix)
            tmp_x = torch.cat([x1, x2, aug_views], dim=0)

            features = self.base_net(tmp_x).squeeze()
            features_projection = self.projector(features)

            # Undo the concatenation order: originals of x1, originals of x2, mixed views.
            z1, z2, z_aug = torch.chunk(features_projection, 3, dim=0)

            # MixUpLoss receives one coefficient per row as a (bs, 1) column.
            loss = self.loss_fn(z_aug, z1, z2, betas.view(-1, 1))
        else:
            tmp_x = torch.cat([x1, x2], dim=0)
            features = self.base_net(tmp_x).squeeze()
            features_projection = self.projector(features)
            loss = self.loss_fn(features_projection)

        return loss

    def count_parameters(self):
        """Return the total number of elements in parameters with ``requires_grad``."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def get_projections(self, x):
        """Return ``base_net(x)``.

        Despite the name, the projection head is NOT applied: this is the raw
        encoder output.
        """
        return self.base_net(x)

    def nll_loss_fn(self, feats, temperature=0.1):
        """NT-Xent-style contrastive loss over ``feats = cat([z1, z2])``.

        Args:
            feats: ``(N, D)`` features, where row ``i`` and row ``(i + N/2) mod N``
                are the positive pair. N is even by construction in ``forward``.
            temperature: Divisor applied to the cosine similarities (default 0.1).

        Returns:
            Mean over rows of ``-sim(i, pos(i)) + logsumexp_j sim(i, j)``, where
            the self-similarity is excluded from the logsumexp.
        """
        cos_sim = F.cosine_similarity(feats[:, None, :], feats[None, :, :], dim=-1)
        self_mask = torch.eye(cos_sim.shape[0], dtype=torch.bool, device=cos_sim.device)
        # A huge negative value removes each row's self-similarity from the logsumexp.
        cos_sim.masked_fill_(self_mask, -9e15)

        # Rolling the identity by N/2 marks, in row i, the column of its paired view.
        pos_mask = self_mask.roll(shifts=cos_sim.shape[0] // 2, dims=0)

        cos_sim = cos_sim / temperature
        # pos_mask has exactly one True per row, so cos_sim[pos_mask] is row-aligned.
        nll = -cos_sim[pos_mask] + torch.logsumexp(cos_sim, dim=-1)
        return nll.mean()

    
    
    


class MixUpLoss(nn.Module):
    """Soft-target contrastive loss for mixed-up samples.

    Each mixed projection ``z_aug[i]`` is compared (dot product of L2-normalised
    vectors, divided by ``tau``) against every row of ``z_1`` and ``z_2``. The
    target distribution puts weight ``lam[i]`` on ``z_1[i]`` and ``1 - lam[i]``
    on ``z_2[i]``.

    Args:
        device: Kept for backward compatibility and stored as ``self.device``.
            Label tensors are created on the device of the inputs passed to
            ``forward``.
        tau: Softmax temperature (default 0.5).
    """

    def __init__(self, device, tau=0.5):
        super(MixUpLoss, self).__init__()

        self.tau = tau
        self.device = device

        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, z_aug, z_1, z_2, lam):
        """Compute the mean soft-target cross-entropy.

        Args:
            z_aug: ``(bs, D)`` projections of the mixed samples.
            z_1: ``(bs, D)`` projections of the first sources.
            z_2: ``(bs, D)`` projections of the second sources.
            lam: Mixing coefficients, broadcastable against a ``(bs, bs)``
                matrix. Use a ``(bs, 1)`` column for per-row coefficients.

        Returns:
            Scalar loss tensor.
        """
        bs = z_1.shape[0]

        z_1 = nn.functional.normalize(z_1)
        z_2 = nn.functional.normalize(z_2)
        z_aug = nn.functional.normalize(z_aug)

        # Build targets on the inputs' device so a stale constructor-time device cannot mismatch.
        eye = torch.eye(bs, device=z_1.device)
        labels_lam_0 = lam * eye
        labels_lam_1 = (1 - lam) * eye

        # Targets over the 2*bs candidates [z_1 | z_2].
        labels = torch.cat((labels_lam_0, labels_lam_1), 1)

        logits = torch.cat((torch.mm(z_aug, z_1.T),
                            torch.mm(z_aug, z_2.T)), 1)

        loss = self.cross_entropy(logits / self.tau, labels)

        return loss

    def cross_entropy(self, logits, soft_targets):
        """Mean over rows of ``-sum(soft_targets * log_softmax(logits, dim=1))``."""
        return torch.mean(torch.sum(- soft_targets * self.logsoftmax(logits), 1))
    