import torch
import torch.nn.functional as F

from torch import nn


class FareHead(nn.Module):
    """Projection head applied to a pair of representations ``h1`` and ``h2``.

    Each input is passed through its own projection (a bias-free linear map
    or a small MLP followed by batch norm). Depending on ``lsh`` the module
    returns either the two projections or a block matrix of softmax-normalised
    pairwise similarities between them.

    Args:
        in_dim: Feature size of the inputs; the linear projections are
            ``in_dim -> in_dim``.
        temperature: Divisor applied to the similarity logits before the
            softmax. ``None`` selects ``1 / sqrt(in_dim)``.
        projection_dim: Stored on the module; not used by any layer here.
        lsh: ``'default'`` makes ``forward`` return the projections only.
            Any other value makes it return the similarity matrix.
        weight_init: ``'normal'`` re-initialises every ``nn.Linear`` with
            N(0, 1) weights and zero bias; ``'identity'`` sets the two linear
            projections to the identity matrix; any other value keeps
            PyTorch's default initialisation.
        projection_mode: ``'linear'`` or ``'nonlinear'``.
    """

    # Width of every layer in the nonlinear projection MLPs.
    # NOTE(review): this is independent of ``in_dim`` and ``projection_dim``,
    # so ``projection_mode='nonlinear'`` only accepts inputs whose last
    # dimension is 3 -- confirm whether it should follow one of them.
    _NONLINEAR_WIDTH = 3

    def __init__(self, in_dim, temperature, projection_dim = 3, lsh = 'default', weight_init = 'default', projection_mode = 'linear'):
        super().__init__()
        self.temperature = temperature
        self.in_dim = in_dim
        self.projection_dim = projection_dim
        self.w1_z = nn.Linear(self.in_dim, self.in_dim, bias=False)
        self.w2_z = nn.Linear(self.in_dim, self.in_dim, bias=False)
        self.softmax = nn.Softmax(dim=-1)
        self.lsh = lsh
        self.projection_mode = projection_mode

        if self.temperature is None:
            self.temperature = 1 / torch.sqrt(torch.tensor(in_dim))

        # Built in this order so parameter names and RNG consumption match
        # the layout existing checkpoints were saved with.
        self.nonlinearproject1 = self._build_nonlinear_projection()
        self.nonlinearproject2 = self._build_nonlinear_projection()

        if weight_init == 'normal':
            # Applies to every nn.Linear in the module, including the MLPs.
            self.apply(self._init_norm)

        if weight_init == 'identity':
            # The linear projections are in_dim x in_dim, so the identity must
            # be sized from in_dim (a fixed 3x3 fails for any other in_dim).
            with torch.no_grad():
                self.w1_z.weight.copy_(torch.eye(self.in_dim))
                self.w2_z.weight.copy_(torch.eye(self.in_dim))

    @classmethod
    def _build_nonlinear_projection(cls):
        """Return a fresh three-layer ReLU MLP followed by non-affine batch norm."""
        width = cls._NONLINEAR_WIDTH
        return nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.BatchNorm1d(width, affine=False),
        )

    def _init_norm(self, module):
        """Initialise an ``nn.Linear`` with N(0, 1) weights and zero bias."""
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=1.0)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def nonlinearproject(self, h1, h2):
        """Project ``h1`` and ``h2`` through their respective MLPs."""
        return self.nonlinearproject1(h1), self.nonlinearproject2(h2)

    def forward(self, h1, h2):
        """Project both inputs and optionally compute their similarities.

        Args:
            h1: First input; the last dimension is the feature dimension.
            h2: Second input, same layout as ``h1``.

        Returns:
            ``(z1, z2)`` when ``self.lsh == 'default'``. Otherwise a tensor
            whose last two dimensions form the block matrix
            ``[[sim12, sim11], [sim22, sim12^T]]``, where each ``sim`` block
            is a row-wise softmax over temperature-scaled dot products.

        Raises:
            ValueError: If ``self.projection_mode`` is not ``'linear'`` or
                ``'nonlinear'``.
        """
        if self.projection_mode == 'linear':
            # Cast so integer-valued inputs can go through the float weights.
            h1, h2 = h1.float(), h2.float()
            z1, z2 = self.w1_z(h1), self.w2_z(h2)
        elif self.projection_mode == 'nonlinear':
            z1, z2 = self.nonlinearproject(h1, h2)
        else:
            raise ValueError(
                "projection_mode must be 'linear' or 'nonlinear', got "
                f"{self.projection_mode!r}"
            )

        if self.lsh == 'default':
            # Only the projections are needed in this mode.
            return z1, z2

        # Transpose only the last two dims so leading batch dims are kept;
        # for 2-D inputs this is the same as ``.T``.
        z1_t = z1.transpose(-1, -2)
        z2_t = z2.transpose(-1, -2)

        # NOTE: self-similarities on the diagonals of sim11 / sim22 are not
        # masked out before the softmax.
        sim11 = self.softmax(torch.matmul(z1, z1_t) / self.temperature)
        sim22 = self.softmax(torch.matmul(z2, z2_t) / self.temperature)
        sim12 = self.softmax(torch.matmul(z1, z2_t) / self.temperature)

        raw_scores1 = torch.cat([sim12, sim11], dim=-1)
        raw_scores2 = torch.cat([sim22, sim12.transpose(-1, -2)], dim=-1)
        raw_scores = torch.cat([raw_scores1, raw_scores2], dim=-2)

        return raw_scores
