import torch
import torch.nn.functional as F

from torch import nn


class LinearCritic(nn.Module):
    """Projection head plus pairwise scorer for a two-view contrastive loss.

    Both views are passed through the same projection MLP
    (``w1`` -> ``bn1`` -> ReLU -> ``w2`` -> ``bn2``). ``forward`` then returns
    ``(raw_scores, targets)``, where ``raw_scores`` is a ``(2d, 2d)`` matrix
    (``d`` = batch size) and ``targets = arange(2d)`` gives the positive
    column of every row, in a form that can be passed to
    ``F.cross_entropy(raw_scores, targets)``.

    Args:
        latent_dim: Feature size of the inputs ``h1`` / ``h2``.
        temperature: Divisor applied to the similarities before scoring.
        attention_sim: If True, score pairs with row-softmaxed dot products
            (``attention_forward``); otherwise with scaled cosine similarity
            (``cosine_forward``).
    """

    def __init__(self, latent_dim, temperature=1., attention_sim=False):
        super().__init__()
        self.temperature = temperature
        self.projection_dim = 128
        self.w1 = nn.Linear(latent_dim, latent_dim, bias=False)
        self.bn1 = nn.BatchNorm1d(latent_dim)
        self.relu = nn.ReLU()
        self.w2 = nn.Linear(latent_dim, self.projection_dim, bias=False)
        self.bn2 = nn.BatchNorm1d(self.projection_dim, affine=False)
        self.cossim = nn.CosineSimilarity(dim=-1)
        self.softmax = nn.Softmax(dim=-1)
        self.attention_sim = attention_sim

        # NOTE(review): W_b is not used by any method of this class. It is kept
        # because removing it would change the module's state_dict keys.
        self.W_b = nn.Linear(latent_dim, self.projection_dim, bias=False)

    def project(self, h):
        """Map features ``h`` into the ``projection_dim``-sized embedding space.

        Presumably ``h`` is ``(batch, latent_dim)``; BatchNorm1d needs more
        than one sample per batch in training mode.
        """
        return self.bn2(self.w2(self.relu(self.bn1(self.w1(h)))))

    @staticmethod
    def _assemble_scores(sim11, sim22, sim12, sim21):
        """Stack four ``(d, d)`` similarity blocks into the ``(2d, 2d)`` layout.

        Columns are ``[view-2 samples | view-1 samples]`` for every row.
        Rows ``0..d-1`` score view-1 samples (``sim12 | sim11``) and rows
        ``d..2d-1`` score view-2 samples (``sim22 | sim21``). The positive for
        row ``k`` therefore sits in column ``k``.
        """
        top = torch.cat([sim12, sim11], dim=-1)
        bottom = torch.cat([sim22, sim21], dim=-1)
        raw_scores = torch.cat([top, bottom], dim=-2)
        targets = torch.arange(sim12.shape[-1] * 2, dtype=torch.long,
                               device=raw_scores.device)
        return raw_scores, targets

    def attention_forward(self, h1, h2):
        """Score pairs with softmax-normalised dot products ("attention").

        Each ``(d, d)`` block is ``softmax(q @ k.T / temperature)`` taken over
        the last dimension, and self-pairs in the same-view blocks are set to
        ``-inf``.

        NOTE(review): the diagonal is masked *after* the softmax, so each
        sample's self-similarity still takes part in normalising its row of
        the same-view blocks. Scores are also divided only by ``temperature``,
        not by ``sqrt(projection_dim)``. Confirm both are intended.
        """
        z1, z2 = self.project(h1), self.project(h2)

        def attend(queries, keys):
            # Row i holds the softmax of queries[i]'s scores against all keys.
            return self.softmax(torch.matmul(queries, keys.T) / self.temperature)

        sim11 = attend(z1, z1)
        sim22 = attend(z2, z2)
        sim12 = attend(z1, z2)
        # The softmax is row-wise, so sim12.T is NOT z2's attention over z1
        # (unlike the symmetric cosine case). Compute that block explicitly.
        sim21 = attend(z2, z1)

        d = sim12.shape[-1]
        self_pairs = torch.eye(d, dtype=torch.bool, device=sim12.device)
        # Out-of-place fill: the softmax outputs are still needed by autograd.
        sim11 = sim11.masked_fill(self_pairs, float('-inf'))
        sim22 = sim22.masked_fill(self_pairs, float('-inf'))
        return self._assemble_scores(sim11, sim22, sim12, sim21)

    def cosine_forward(self, h1, h2):
        """Score pairs with cosine similarity divided by ``temperature``.

        Self-pairs in the same-view blocks are set to ``-inf`` so they are
        excluded by a downstream cross-entropy.
        """
        z1, z2 = self.project(h1), self.project(h2)
        sim11 = self.cossim(z1.unsqueeze(-2), z1.unsqueeze(-3)) / self.temperature
        sim22 = self.cossim(z2.unsqueeze(-2), z2.unsqueeze(-3)) / self.temperature
        sim12 = self.cossim(z1.unsqueeze(-2), z2.unsqueeze(-3)) / self.temperature
        d = sim12.shape[-1]
        sim11[..., range(d), range(d)] = float('-inf')
        sim22[..., range(d), range(d)] = float('-inf')
        # Cosine similarity is symmetric, so the view-2-vs-view-1 block is
        # simply the transpose of sim12.
        return self._assemble_scores(sim11, sim22, sim12, sim12.transpose(-1, -2))

    def forward(self, h1, h2):
        """Return ``(raw_scores, targets)`` using the configured scorer."""
        scorer = self.attention_forward if self.attention_sim else self.cosine_forward
        return scorer(h1, h2)
        

