import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from Model.MaskNets import MultiNets, Generator

from itertools import cycle
from Model.MHSA import MHSA


class Model(nn.Module):
    """Mask generator followed by an NPT network, with learnable prototypes.

    ``model_config`` is indexed with the keys ``'data_dim'``, ``'basis_num'``,
    ``'model'`` and ``'prototype_num'`` and is also handed unchanged to
    ``Generator``.
    """

    def __init__(self, model_config):
        super().__init__()
        # Width of the per-feature embedding used inside NPT (fixed here).
        self.hidden_dim = 16
        self.data_dim = model_config['data_dim']
        self.basis_num = model_config['basis_num']
        self.model = model_config['model']
        self.prototype_num = model_config['prototype_num']
        self.maskmodel = Generator(MultiNets(), model_config)

        # Learnable prototypes, randomly initialised. The statement order is
        # deliberate: it fixes the order in which random numbers are drawn.
        feature_shape = (self.basis_num, self.data_dim, self.hidden_dim)
        self.feature_prototype = nn.Parameter(torch.randn(*feature_shape))
        attention_shape = (5, 4, self.data_dim)
        self.attention_prototype = nn.Parameter(torch.randn(*attention_shape))

        # NPT receives the very same Parameter object as its prototype.
        self.npt = NPT(self.data_dim, self.hidden_dim, self.feature_prototype)

    def forward(self, x_input):
        """Mask the input, run it through NPT and return predictions + maps.

        Returns a 6-tuple:
        ``(x_pred, x_input, feature_map, feature_prototype, attention_map,
        attention_prototype)`` where ``x_pred`` has been reshaped to
        ``(x_input.shape[0], basis_num, x_input.shape[-1])``.
        """
        # The mask model yields the masked input and the masks; only the
        # masked input is used here.
        x_mask, _ = self.maskmodel(x_input)

        # x_mask is 3-D: (sample, basis_num, data_dim). Fold the first two
        # axes together so NPT sees one row per (sample, mask) pair.
        n_samples, n_views, n_dims = x_mask.shape
        flat_views = x_mask.reshape(n_samples * n_views, n_dims)

        x_pred, feature_map, attention_map = self.npt(flat_views)

        out_shape = (x_input.shape[0], self.basis_num, x_input.shape[-1])
        x_pred = x_pred.reshape(out_shape)

        return (
            x_pred,
            x_input,
            feature_map,
            self.feature_prototype,
            attention_map,
            self.attention_prototype,
        )


class NPT(nn.Module):
    """Per-feature embeddings around an attention encoder/decoder pair.

    Each of the ``data_dim`` input columns gets its own ``Linear(1, hidden_dim)``
    input embedding and its own ``Linear(hidden_dim, 1)`` output embedding.
    Between encoder and decoder the encoded features are masked against every
    prototype (see ``feature_mask``).
    """

    def __init__(self, data_dim, hidden_dim, prototype):
        super().__init__()
        self.data_dim = data_dim
        self.hidden_dim = hidden_dim
        # If ``prototype`` is an nn.Parameter, nn.Module also registers it as
        # a parameter of this module under the name ``prototype``.
        self.prototype = prototype

        # Submodules are created in this exact order (input embeddings,
        # encoder/decoder, output embeddings) so that weight initialisation
        # draws random numbers in a fixed sequence.
        per_feature_in = [
            nn.Linear(1, self.hidden_dim) for _ in range(self.data_dim)
        ]
        self.in_embedding = nn.ModuleList(per_feature_in)
        self.encoder, self.decoder = self.get_npt()
        per_feature_out = [
            nn.Linear(self.hidden_dim, 1) for _ in range(self.data_dim)
        ]
        self.out_embedding = nn.ModuleList(per_feature_out)

    def build_block(self, DA_dim, AA_dim, AttentionBlocks):
        """Return the four layers of one flat-attention / nested-attention block.

        ``AttentionBlocks`` is an iterator of layer classes; two are consumed,
        the first constructed with ``DA_dim`` and the second with ``AA_dim``
        (each passed three times).
        """
        # A list display evaluates left to right, so layers are constructed in
        # the same order in which they are applied.
        return [
            ReshapeToFlat(),
            next(AttentionBlocks)(DA_dim, DA_dim, DA_dim),
            ReshapeToNested(self.data_dim),
            next(AttentionBlocks)(AA_dim, AA_dim, AA_dim),
        ]

    def get_npt(self):
        """Build and return the ``(encoder, decoder)`` ``nn.Sequential`` pair."""
        print("Building NPT...")

        block_types = cycle([MHSA])
        aa_width = self.hidden_dim
        da_width = self.data_dim * aa_width

        # Encoder layers are built before decoder layers.
        encoder_layers = self.build_block(da_width, aa_width, block_types)
        decoder_layers = self.build_block(da_width, aa_width, block_types)
        return nn.Sequential(*encoder_layers), nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Embed, encode, prototype-mask, decode and project back.

        Returns ``(x, feature_map, attention_map)``: the output with dim 2
        squeezed, the encoder output, and ``layer.mab.attention_map`` of the
        last ``MHSA`` layer in the decoder (``None`` if there is none).
        """
        # Embed every input column with its own linear layer, then join the
        # embeddings along dim 1.
        embedded = []
        for col in range(self.data_dim):
            column = x[:, col].unsqueeze(1)
            embedded.append(self.in_embedding[col](column))
        x = torch.cat(embedded, dim=1)

        x = self.encoder(x)
        feature_map = x
        x = self.feature_mask(x)

        # Run the decoder layer by layer so the attention map of the most
        # recent MHSA layer can be captured along the way.
        attention_map = None
        for layer in self.decoder:
            x = layer(x)
            if isinstance(layer, MHSA):
                attention_map = layer.mab.attention_map

        # Project each feature slot back to a scalar with its own linear layer.
        projected = []
        for col in range(self.data_dim):
            slot = x[:, col].unsqueeze(1)
            projected.append(self.out_embedding[col](slot))
        x = torch.cat(projected, dim=1).squeeze(2)

        return x, feature_map, attention_map

    def feature_mask(self, feature_map):
        """Mask the feature map once per prototype.

        ``feature_map`` must be 3-D ``(batch, data_dim, hidden_dim)``. For each
        row and each prototype, the squared difference to the prototype is
        computed element-wise; elements whose squared difference is below the
        mean over that (row, prototype) pair are kept, all others are zeroed.
        The result has shape ``(batch * basis_num, data_dim, hidden_dim)``.
        """
        n_rows, n_features, n_hidden = feature_map.size()
        n_basis = self.prototype.size(0)

        # (batch, 1, F) against (1, basis, F) broadcasts to (batch, basis, F).
        flat_features = feature_map.reshape(n_rows, -1).unsqueeze(1)
        flat_prototypes = self.prototype.reshape(n_basis, -1).unsqueeze(0)
        sq_dist = (flat_features - flat_prototypes).pow(2)

        threshold = sq_dist.mean(dim=-1, keepdim=True)
        keep = (sq_dist < threshold).float()

        tiled = flat_features.repeat(1, n_basis, 1)
        return (tiled * keep).reshape(-1, n_features, n_hidden)






class Permute(nn.Module):
    """Module wrapper around ``Tensor.permute`` for use in ``nn.Sequential``."""

    def __init__(self, idxs):
        super().__init__()
        # Target ordering of the dimensions, applied on every forward pass.
        self.idxs = idxs

    def forward(self, X):
        """Return ``X`` with its dimensions reordered according to ``idxs``."""
        permuted = X.permute(self.idxs)
        return permuted


class ReshapeToFlat(nn.Module):
    """Reshape ``(N, D, E)`` to ``(1, N, D*E)``.

    All dimensions after the first are collapsed into one, and a leading
    dimension of size 1 is added.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def forward(X):
        """Return ``X`` viewed as ``(1, X.size(0), -1)``."""
        n_rows = X.size(0)
        target_shape = (1, n_rows, -1)
        return X.reshape(target_shape)


class ReshapeToNested(nn.Module):
    """Reshape ``(1, N, D*E)`` back to ``(N, D, E)`` for a fixed ``D``."""

    def __init__(self, D):
        super().__init__()
        # Size of the middle dimension of the output.
        self.D = D

    def forward(self, X):
        """Return ``X`` reshaped to ``(X.size(1), D, -1)``."""
        n_rows = X.size(1)
        target_shape = (n_rows, self.D, -1)
        return X.reshape(target_shape)


class Print(nn.Module):
    """Debugging pass-through layer: prints the input's shape, returns it as is."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Print ``'Debug'`` followed by ``x.shape`` and return ``x`` untouched."""
        shape = x.shape
        print('Debug', shape)
        return x