import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class MLP(nn.Module):
    """Fully connected network in which every layer maps n_inputs -> n_inputs.

    The stack is one input layer, ``k`` hidden layers and one output layer.
    Every layer except the output layer is followed by dropout (p=0.2) and
    then ReLU; the output layer is returned without an activation.
    """

    def __init__(self, n_inputs, k=1):
        """Build the layer stack.

        Args:
            n_inputs: Feature width used for the input and output of every layer.
            k: Number of hidden layers between the input and output layers.
        """
        super().__init__()
        width = n_inputs
        # Sub-modules are created in the order input, dropout, hiddens, output
        # so parameter names and initialisation order stay stable.
        self.input = nn.Linear(width, width)
        self.dropout = nn.Dropout(0.2)
        hidden_layers = [nn.Linear(width, width) for _ in range(k)]
        self.hiddens = nn.ModuleList(hidden_layers)
        self.output = nn.Linear(width, width)

    def forward(self, x):
        """Run ``x`` through the stack and return the output layer's result."""
        # The input layer and the hidden layers share the same
        # linear -> dropout -> ReLU treatment, so walk them as one sequence.
        for layer in (self.input, *self.hiddens):
            x = F.relu(self.dropout(layer(x)))
        return self.output(x)
    
class CNN_encoder(nn.Module):
    """Two-stage strided CNN that encodes each sample as a 512-element vector.

    Each sample is viewed as a 2-channel 16x16 image (so it must hold
    2 * 16 * 16 = 512 values), passed through two conv -> ReLU -> GroupNorm
    stages, and globally average-pooled.
    """

    def __init__(self, n_inputs):
        """Create the convolution, normalisation and pooling layers.

        Args:
            n_inputs: Unused; all layer sizes are fixed in this constructor.
        """
        super(CNN_encoder, self).__init__()
        self.conv1 = nn.Conv2d(2, 128, 3, 4, padding=1)
        self.conv2 = nn.Conv2d(128, 512, 3, stride=4, padding=1)
        # Named bn0/bn1 but these are GroupNorm layers; the attribute names are
        # left alone because they determine the state_dict keys.
        self.bn0 = nn.GroupNorm(8, 128)
        self.bn1 = nn.GroupNorm(8, 512)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Encode a batch.

        Args:
            x: Tensor whose first dimension is the batch and which holds 512
                values per sample.

        Returns:
            Tensor of shape (batch, 512).
        """
        x = x.reshape(x.shape[0], 2, 16, 16)

        x = self.conv1(x)
        x = F.relu(x)
        x = self.bn0(x)

        x = self.conv2(x)
        x = F.relu(x)
        x = self.bn1(x)

        # Pooling yields (batch, 512, 1, 1). flatten(1) drops only the trailing
        # singleton dims; a bare squeeze() would also remove the batch dim when
        # the batch size is 1 and return a 1-D tensor.
        x = self.avgpool(x).flatten(1)

        return x
    
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    A table of shape (max_len, 1, d_model) is precomputed and stored as the
    non-trainable buffer ``pe``: even feature columns hold sines, odd feature
    columns hold cosines.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        """Precompute the encoding table.

        Args:
            d_model: Size of the last (feature) dimension of the inputs.
            dropout: Dropout probability applied after adding the encoding.
            max_len: Number of positions in the precomputed table.
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column, so
        # trim the frequencies to match; for even d_model this slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # (max_len, d_model) -> (max_len, 1, d_model) so the table broadcasts
        # over dimension 1 of the input.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding for the first ``x.size(0)`` positions, then dropout.

        Args:
            x: Tensor indexed as (position, ..., d_model) in its first and last
                dimensions; the first ``x.size(0)`` rows of the table are
                broadcast-added to it.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
    
class IndexNet_Transformer(nn.Module):
    """Transformer encoder that produces two embeddings from one input.

    ``forward`` duplicates the input along dimension 1, splits the result into
    tokens of width ``dim``, runs them through positional encoding and a
    3-layer transformer encoder, and returns the two halves of the flattened
    output. ``reconstruction_forward`` and ``projection_forward`` expose the
    two linear heads separately.
    """

    # Class-level fallback so instances that lack a per-instance ``dim``
    # attribute still use the original token width.
    dim = 32

    def __init__(self, in_dim=512, dim=32):
        """Build the encoder and the two linear heads.

        Args:
            in_dim: Feature width used to size the reconstruction
                (2 * in_dim -> in_dim, no bias) and projection
                (in_dim -> in_dim) layers.
            dim: Token width (transformer ``d_model``). The encoder layer uses
                8 attention heads, so ``dim`` must be divisible by 8.
        """
        super(IndexNet_Transformer, self).__init__()
        self.in_dim = in_dim
        self.dim = dim

        # Transformer Encoder
        # NOTE: "posiontion_1" is misspelled, but the attribute name is part of
        # the state_dict keys, so it is kept as-is.
        self.posiontion_1 = PositionalEncoding(d_model=dim)
        encode_layers_1 = nn.TransformerEncoderLayer(d_model=dim, nhead=8)
        self.embedding_1 = nn.TransformerEncoder(encode_layers_1, num_layers=3)

        self.reconstruction = nn.Linear(in_dim * 2, in_dim, bias=False)
        self.projection = nn.Linear(in_dim, in_dim)

    def forward(self, x):
        """Encode ``x`` and return two embeddings.

        Args:
            x: Tensor whose first dimension is the batch; expected to be
                (batch, features). Twice the per-sample element count must be
                divisible by ``dim``.

        Returns:
            Tuple ``(x_1, x_2)``: the first and second halves of the flattened
            encoder output, each holding as many values per sample as ``x``.
        """
        # Per-sample element count; decides both the token count and where the
        # output is split (512 for the original hard-coded configuration).
        n_features = math.prod(x.shape[1:])
        seq_len = (2 * n_features) // self.dim

        x = torch.cat([x, x], dim=1)
        x = x.reshape(x.shape[0], seq_len, self.dim)
        # The encoder layer is built with the default sequence-first layout,
        # so move the token axis to the front.
        x = x.permute(1, 0, 2)
        x = x * math.sqrt(self.dim)
        x_1 = self.posiontion_1(x)
        x_1 = self.embedding_1(x_1)
        x_1 = x_1.permute(1, 0, 2)
        x_1 = x_1.reshape(x_1.shape[0], x_1.shape[1] * x_1.shape[2])
        return x_1[:, :n_features], x_1[:, n_features:]

    def reconstruction_forward(self, x_1, x_2):
        """Concatenate two embeddings along dim 1 and apply the reconstruction layer."""
        x = torch.cat([x_1, x_2], dim=1)
        x = self.reconstruction(x)
        return x

    def projection_forward(self, x):
        """Apply the projection layer to ``x``."""
        x = self.projection(x)
        return x

class IndexNet_MLP(nn.Module):
    """Two parallel MLP encoders with linear reconstruction and projection heads.

    ``forward`` feeds the same input to both encoders and returns both
    embeddings. ``reconstruction_forward`` maps a concatenated pair of
    embeddings back to ``in_dim`` features, and ``projection_forward`` applies
    a single ``in_dim -> in_dim`` linear layer.
    """

    def __init__(self, in_dim=512):
        """Create the encoders and heads.

        Args:
            in_dim: Feature width of the input and of every embedding.
        """
        super().__init__()
        self.in_dim = in_dim

        # Two independent encoders, each an MLP with three hidden layers.
        # Commented-out alternatives (CNN_encoder, TransformerEncoder, plain
        # nn.Linear) were removed from this spot; see version control.
        self.embedding_1 = MLP(in_dim, 3)
        self.embedding_2 = MLP(in_dim, 3)

        self.reconstruction = nn.Linear(2 * in_dim, in_dim)
        self.projection = nn.Linear(in_dim, in_dim)

    def forward(self, x):
        """Return ``(embedding_1(x), embedding_2(x))``."""
        first = self.embedding_1(x)
        second = self.embedding_2(x)
        return first, second

    def reconstruction_forward(self, x_1, x_2):
        """Concatenate two embeddings along dim 1 and apply the reconstruction layer."""
        joined = torch.cat((x_1, x_2), dim=1)
        return self.reconstruction(joined)

    def projection_forward(self, x):
        """Apply the projection layer to ``x``."""
        return self.projection(x)