import torch
import torch.nn as nn

from .modules import build_mlp
import math

from .modules import CrossAttnEncoder, Decoder, PoolingEncoder, CrossAttnEncoder_Dim, Decoder_Dim
from .attention import SelfAttn

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding applied along the second-to-last axis.

    For an input with ``L`` tokens on axis -2, the first ``L - 1`` tokens
    receive table rows ``0 .. L-2`` and the final token always receives the
    last table row (``max_seq_len - 1``). The final token is therefore encoded
    identically regardless of how many tokens precede it.

    Args:
        d_model: Width of the encoding (must match the input's last axis).
            Odd widths are supported.
        max_seq_len: Number of rows in the encoding table; inputs may have
            at most this many tokens.
    """

    def __init__(self, d_model, max_seq_len):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so only the first d_model // 2 frequencies are used for cosines.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the positional encoding (tokens on axis -2).

        Raises:
            ValueError: if ``x`` has no tokens or more than ``max_seq_len``
                tokens. With ``max_seq_len + 1`` tokens the last two tokens
                would otherwise silently share one encoding row.
        """
        seq_len = x.size(-2)
        max_seq_len = self.pe.size(0)
        if not 1 <= seq_len <= max_seq_len:
            raise ValueError(
                f"PositionalEncoding expects between 1 and {max_seq_len} tokens, got {seq_len}"
            )
        # First seq_len - 1 rows in order, then the reserved last row for the final token.
        pe = torch.cat([self.pe[: seq_len - 1], self.pe[-1:]], dim=0)
        return x + pe

class DimensionAggregator(nn.Module):
    """Summarise each data point by self-attention across its coordinates.

    Every scalar coordinate of a point becomes its own ``dim_hid``-wide token.
    The tokens get a positional encoding, and the last coordinate always uses
    the final table row. A transformer encoder then attends only among the
    tokens of the same point. The result for each point concatenates the mean
    of all but the last token with the last token.

    Args:
        dim_hid: Token width inside the aggregator; also the width of each
            half of the output.
        dim_out: Stored as ``self.dim_out`` but not used to size any layer.
            The output width is ``2 * dim_hid``.
        nhead: Attention heads; must divide ``dim_hid``.
        dim_feedforward: Feed-forward width of each encoder layer.
        dropout: Dropout rate of each encoder layer.
        num_layers: The aggregator uses ``max(1, num_layers // 2)`` layers.
        max_seq_len: Size of the positional-encoding table, which bounds the
            number of coordinates per point.
    """

    def __init__(self, dim_hid, dim_out, nhead, dim_feedforward, dropout, num_layers, max_seq_len=101):
        super(DimensionAggregator, self).__init__()
        self.dim_hid = dim_hid
        self.dim_out = dim_out
        self.positional_encoding = PositionalEncoding(self.dim_hid, max_seq_len)
        # Lifts each scalar coordinate to a dim_hid-wide token.
        self.linear = nn.Linear(1, self.dim_hid)
        tencoder = nn.TransformerEncoderLayer(self.dim_hid, nhead, dim_feedforward, dropout, batch_first=True)
        # Half the depth of the caller's encoder, but never zero layers
        # (num_layers // 2 == 0 for num_layers == 1 gives an unusable encoder).
        self.tencoder = nn.TransformerEncoder(tencoder, max(1, num_layers // 2))

    def forward(self, data_xy):
        """Aggregate the coordinates of every point.

        Args:
            data_xy: Tensor of shape [B, num_data, D]. The last coordinate is
                treated as the y value and the other D - 1 as x values.
                NOTE(review): only a single y coordinate is split off; confirm
                callers never pass dim_y > 1.

        Returns:
            Tensor of shape [B, num_data, 2 * dim_hid]. It is the mean over the
            x tokens concatenated with the y token.
        """
        batch, num_data, num_dims = data_xy.shape[:3]
        tokens = self.linear(data_xy.unsqueeze(-1))  # [B, num_data, D, dim_hid]
        tokens = self.positional_encoding(tokens)  # [B, num_data, D, dim_hid]
        # Fold points into the batch axis so attention stays within one point.
        encoded = self.tencoder(tokens.reshape(batch * num_data, num_dims, tokens.size(-1)))  # [B*num_data, D, dim_hid]
        encoded = encoded.reshape(batch, num_data, num_dims, encoded.size(-1))  # [B, num_data, D, dim_hid]
        data_x_tokens, data_y_token = encoded.split([num_dims - 1, 1], dim=-2)  # [B, num_data, D-1, dim_hid], [B, num_data, 1, dim_hid]
        data_x_agg = data_x_tokens.mean(dim=-2)  # [B, num_data, dim_hid]
        data_y_agg = data_y_token.squeeze(dim=-2)  # [B, num_data, dim_hid]
        return torch.cat([data_x_agg, data_y_agg], dim=-1)  # [B, num_data, 2 * dim_hid]

class DHTLTNP(nn.Module):
    """Masked transformer encoder over context points and target queries.

    Each point is summarised by ``DimensionAggregator`` into a ``d_rep`` (= 64)
    wide vector, lifted to ``d_model`` by ``build_mlp``, and processed by an
    ``nn.TransformerEncoder`` under the mask from ``create_mask``. ``encode``
    returns the encoder outputs at the zero-y target query positions.

    NOTE(review): ``dim_x`` and ``dim_y`` are accepted but not used here. The
    aggregator splits off only the last feature as y, so ``dim_y > 1`` is
    presumably unsupported; confirm. ``nhead`` must divide both
    ``d_rep // 2`` (32) and ``d_model``.
    """

    def __init__(
        self,
        dim_x,
        dim_y,
        d_model,
        emb_depth,
        dim_feedforward,
        nhead,
        dropout,
        num_layers,
        bound_std
    ):
        super(DHTLTNP, self).__init__()
        d_rep = 64
        # Aggregator output is 2 * (d_rep // 2) == d_rep wide (x half + y half).
        self.dim_agg = DimensionAggregator(d_rep // 2, d_rep // 2, nhead, dim_feedforward, dropout, num_layers)

        self.embedder = build_mlp(d_rep, d_model, d_model, emb_depth)

        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers)

        self.bound_std = bound_std

    def construct_input(self, batch, autoreg=False):
        """Stack context and target points into one sequence and aggregate each point.

        Context points carry their observed y. Target queries carry zeros in
        place of y. With ``autoreg``, the observed targets are inserted between
        the context and the zero-y queries. When training with ``bound_std``
        set, those targets are jittered with 0.05-scale Gaussian noise.
        Assumes batch tensors are [B, num_points, features]; confirm.

        Returns:
            The output of ``self.dim_agg`` over the stacked sequence.
        """
        x_y_ctx = torch.cat((batch.xc, batch.yc), dim=-1)
        x_0_tar = torch.cat((batch.xt, torch.zeros_like(batch.yt)), dim=-1)
        if not autoreg:
            inp = torch.cat((x_y_ctx, x_0_tar), dim=1)
        else:
            if self.training and self.bound_std:
                yt_noise = batch.yt + 0.05 * torch.randn_like(batch.yt)  # add noise to the past to smooth the model
                x_y_tar = torch.cat((batch.xt, yt_noise), dim=-1)
            else:
                x_y_tar = torch.cat((batch.xt, batch.yt), dim=-1)
            inp = torch.cat((x_y_ctx, x_y_tar, x_0_tar), dim=1)
        return self.dim_agg(inp)

    def create_mask(self, batch, autoreg=False):
        """Build the additive attention mask used by ``encode``.

        ``0.0`` allows attention and ``-inf`` blocks it. Every token may attend
        to all context points. Without ``autoreg`` nothing else is visible.
        With ``autoreg``, each observed target also sees itself and earlier
        observed targets. Each zero-y query sees only the observed targets
        strictly before its own index.

        Returns:
            ``(mask, num_tar)``. ``mask`` has shape [S, S] and lives on
            ``batch.xc``'s device. S is ``num_ctx + num_tar``, or
            ``num_ctx + 2 * num_tar`` with ``autoreg``.
        """
        num_ctx = batch.xc.shape[1]
        num_tar = batch.xt.shape[1]
        num_all = num_ctx + num_tar
        size = num_all + num_tar if autoreg else num_all
        # Allocate on the data's device instead of assuming CUDA is available.
        mask = torch.full((size, size), float('-inf'), device=batch.xc.device)
        mask[:, :num_ctx] = 0.0  # all points attend to context points
        if autoreg:
            # Each real target point attends to itself and preceding real target points.
            mask[num_ctx:num_all, num_ctx:num_all].triu_(diagonal=1)
            # Each fake (zero-y) target point attends to strictly preceding real target points.
            mask[num_all:, num_ctx:num_all].triu_(diagonal=0)
        return mask, num_tar

    def encode(self, batch, autoreg=False):
        """Run the masked encoder and return outputs at the target query positions.

        Returns:
            Tensor with the last ``num_tar`` sequence positions of the encoder
            output, i.e. one entry per zero-y target query.
        """
        embeddings = self.construct_input(batch, autoreg)
        mask, num_tar = self.create_mask(batch, autoreg)
        embeddings = self.embedder(embeddings)
        out = self.encoder(embeddings, mask=mask)
        # Explicit start index: out[:, -0:] would return every token when num_tar == 0.
        return out[:, out.size(1) - num_tar:]
