import torch
import torch.nn as nn

from ..utils.misc import stack, logmeanexp
from torch.distributions import kl_divergence
from .modules import build_mlp
from .modules import TTPoolingEncoder_Dim
from .attention import SelfAttn
import math

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence of per-feature tokens.

    The token axis (``x.size(-2)``) is laid out as the x-feature tokens followed
    by ``y_dim`` y-feature tokens. The x tokens receive rows ``0 .. n_x - 1`` of
    the precomputed table; the y tokens receive the LAST ``y_dim`` rows, so the
    y encodings do not depend on how many x features a dataset has.

    Args:
        d_model: Width of every token (odd widths are supported).
        max_seq_len: Number of rows in the precomputed table.
    """

    def __init__(self, d_model, max_seq_len):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_seq_len, ceil(d_model / 2)]
        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so the cosine half must be truncated to fit.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x, y_dim):
        """Return ``x`` with position encodings added.

        Args:
            x: Tensor whose last two axes are (tokens, d_model); the final
                ``y_dim`` tokens are treated as y features.
            y_dim: Number of trailing y tokens (0 means all tokens are x).

        Returns:
            Tensor with the same shape as ``x``.
        """
        num_x = x.size(-2) - y_dim
        # Slice with an explicit start index so y_dim == 0 yields no rows
        # (``pe[-0:]`` would return the whole table).
        y_rows = self.pe[self.pe.size(0) - y_dim:]
        pe = torch.cat([self.pe[:num_x], y_rows], dim=0)
        return x + pe
    
class TargetPositionalEncoding(nn.Module):
    """Sinusoidal position encoding for a single y (target) token.

    The table is built with the standard sine/cosine scheme (sine in even
    columns, cosine in odd columns); ``forward`` returns one row of it.

    Args:
        d_model: Width of the encoding (odd widths are supported).
        max_seq_len: Number of rows in the precomputed table.
    """

    def __init__(self, d_model, max_seq_len):
        super(TargetPositionalEncoding, self).__init__()
        position = torch.arange(0, max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_seq_len, ceil(d_model / 2)]
        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even column,
        # so the cosine half must be truncated to fit.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, y_dim):
        """Return row ``-y_dim`` of the table as a ``[1, d_model]`` tensor.

        For ``y_dim == 1`` this is the last row of the table.
        """
        return self.pe[-y_dim, :].unsqueeze(0)

class DimensionAggregator(nn.Module):
    """Embed each data point by attending over its individual feature values.

    Every scalar of a point's concatenated ``[x, y]`` vector becomes its own
    ``dim_hid``-wide token through a shared ``nn.Linear(1, dim_hid)``. The tokens
    receive position encodings and are mixed by ``SelfAttn``, independently for
    each point. The x tokens and the y tokens are then mean-pooled separately
    and the two pooled vectors are concatenated.

    Args:
        dim_hid: Width of each per-feature token.
        dim_out: Output width passed to ``SelfAttn``.
        max_seq_len: Size of the position-encoding table.
    """

    def __init__(self, dim_hid, dim_out, max_seq_len=101):
        super(DimensionAggregator, self).__init__()
        self.dim_hid = dim_hid
        self.dim_out = dim_out
        # Registration order is kept fixed: it determines parameter order,
        # state_dict layout and the RNG draws used for initialisation.
        self.positional_encoding = PositionalEncoding(self.dim_hid, max_seq_len)
        self.linear = nn.Linear(1, self.dim_hid)
        self.selfattention = SelfAttn(self.dim_hid, self.dim_out)

    def forward(self, data_xy, y_dim):
        """Pool per-feature attention outputs into one vector per point.

        Args:
            data_xy: Tensor expected to be ``[batch, num_points, num_features]``
                (the reshape below rebuilds exactly four axes), whose last
                ``y_dim`` features are y values.
            y_dim: Number of trailing y features.

        Returns:
            ``[batch, num_points, 2 * C]`` where ``C`` is the last-axis width
            produced by ``SelfAttn``. The x-pooled half comes first, then the
            y-pooled half.
        """
        # One token per scalar feature: [B, N, F, dim_hid].
        tokens = self.linear(data_xy.unsqueeze(-1))
        tokens = self.positional_encoding(tokens, y_dim)
        n_batch, n_points, n_feats = tokens.shape[0], tokens.shape[1], tokens.shape[2]

        # Fold points into the batch axis so attention only mixes the
        # features of a single point.
        flat_tokens = tokens.reshape(-1, *tokens.shape[-2:])
        attended = self.selfattention(flat_tokens)
        attended = attended.reshape(n_batch, n_points, n_feats, attended.shape[-1])

        # Separate x tokens ([.., F - y_dim, C]) from y tokens ([.., y_dim, C]).
        x_tokens, y_tokens = attended.split([n_feats - y_dim, y_dim], dim=-2)
        x_pooled = x_tokens.mean(dim=-2, keepdim=True)
        y_pooled = y_tokens.mean(dim=-2, keepdim=True)
        pooled = torch.cat([x_pooled, y_pooled], dim=-1)  # [B, N, 1, 2 * C]
        return pooled.squeeze(-2)

class DDLTTTANP(nn.Module):
    """Dimension-wise transformer NP with a latent path and an RNN over output dimensions.

    Pipeline, as wired in ``encode``:

    1. ``dim_agg`` embeds every point from its individual feature values.
    2. ``encoder1`` is a transformer in which every point attends only to the
       context points (see ``create_mask``).
    3. ``encoder2`` followed by ``lenc`` maps the context embeddings to a
       distribution from which the latent ``z`` is sampled.
    4. The ``encoder1`` output concatenated with ``z`` is the initial hidden
       state of ``rnn``. The RNN input is the point embedding stacked
       ``y_dim`` times along the sequence axis, which gives one hidden state
       per output dimension.

    Args:
        dim_x, dim_y, emb_depth: accepted for interface compatibility; not used
            by this constructor.
        d_model: transformer width.
        dim_feedforward: transformer feed-forward width; also passed as the
            latent width (``dim_lat``) of ``lenc``.
        nhead, dropout: transformer-layer hyper-parameters.
        num_layers: depth of ``encoder1`` (``encoder2`` always has 2 layers).
        bound_std: stored on the instance; not used by the methods shown here.
    """

    def __init__(
        self,
        dim_x,
        dim_y,
        d_model,
        emb_depth,
        dim_feedforward,
        nhead,
        dropout,
        num_layers,
        bound_std
    ):
        super(DDLTTTANP, self).__init__()
        self.dim_agg = DimensionAggregator(int(d_model/2), int(d_model/2))

        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.encoder1 = nn.TransformerEncoder(encoder_layer, num_layers)

        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.encoder2 = nn.TransformerEncoder(encoder_layer, 2)

        self.bound_std = bound_std
        self.lenc = TTPoolingEncoder_Dim(
                dim_x=int(d_model/2),
                dim_y=int(d_model/2),
                dim_hid=d_model,
                dim_lat=dim_feedforward,
                self_attn=True,
                pre_depth=4,
                post_depth=2)
        # Hidden size = encoder width + latent width, matching the
        # torch.cat([out, z], -1) that seeds the RNN state in ``encode``.
        self.rnn = nn.RNN(d_model, d_model + dim_feedforward)

    def construct_input(self, batch, autoreg=False):
        """Concatenate context ``(xc, yc)`` and target ``(xt, 0)`` points along dim 1.

        Target y values are replaced by zeros so the model cannot see them.
        ``autoreg`` is accepted for interface compatibility and ignored.
        """
        x_y_ctx = torch.cat((batch.xc, batch.yc), dim=-1)
        x_0_tar = torch.cat((batch.xt, torch.zeros_like(batch.yt)), dim=-1)
        return torch.cat((x_y_ctx, x_0_tar), dim=1)

    def create_mask(self, batch, autoreg=False):
        """Build the additive attention mask for ``encoder1``.

        Every query (context or target) may attend to the context columns
        (0.0); all target columns are ``-inf``.

        Returns:
            ``(mask, num_tar)``: a float ``[num_all, num_all]`` mask on the same
            device as ``batch.xc``, and the number of target points.
        """
        num_ctx = batch.xc.shape[1]
        num_tar = batch.xt.shape[1]
        num_all = num_ctx + num_tar

        # Allocate on the batch's device rather than a hard-coded 'cuda' so the
        # model also runs on CPU and on non-default GPU indices.
        mask = torch.full((num_all, num_all), float('-inf'), device=batch.xc.device)
        mask[:, :num_ctx] = 0.0
        return mask, num_tar

    def encode(self, batch, z=None, num_samples=None, autoreg=False):
        """Compute per-target, per-output-dimension RNN hidden states.

        Args:
            batch: object exposing ``xc``, ``yc``, ``xt``, ``yt``; points are
                concatenated along dim 1 and features along the last dim.
            z: optional latent sample. When ``None``, one is drawn (with
                ``rsample``) from the prior built from the context points.
            num_samples: optional number of latent samples. Forwarded to
                ``stack``, which presumably adds a leading sample axis of
                that size (and is a no-op for ``None``). TODO: confirm in
                utils.misc.
            autoreg: forwarded to ``construct_input`` / ``create_mask``.

        Returns:
            RNN outputs for the target points, shaped
            ``[*lead, num_tar, y_dim, d_model + dim_feedforward]``. ``lead`` is
            whatever leading axes the embeddings carry, e.g.
            ``(num_samples, batch)``.
        """
        inp = self.construct_input(batch, autoreg)
        mask, num_tar = self.create_mask(batch, autoreg)
        y_dim = batch.yt.shape[-1]
        embeddings = self.dim_agg(inp, y_dim)
        out = stack(self.encoder1(embeddings, mask=mask), num_samples)
        if z is None:
            # NOTE(review): this prior passes through encoder2, while
            # ``lencode``'s pz does not -- confirm the two are meant to differ.
            pz = self.lenc(self.encoder2(embeddings[:, :batch.xc.shape[1]]))
            z = pz.rsample() if num_samples is None \
                    else pz.rsample([num_samples])
        z = stack(z, inp.shape[-2], -2)
        out = torch.cat([out, z], -1)

        states = stack(embeddings, num_samples)
        states_shape = states.shape
        hidden_size = out.shape[-1]
        # Fold every leading axis into the RNN batch axis. Flattening by the
        # last axis (not ``shape[3:]``) works with or without a sample axis.
        states = states.reshape(-1, states.shape[-1])
        h0 = out.reshape(-1, hidden_size)
        states = stack(states, y_dim)  # sequence of length y_dim (nn.RNN is seq-first)
        h0 = stack(h0, 1)              # single-layer initial hidden state
        result, _ = self.rnn(states, h0)
        result = torch.swapaxes(result, 0, 1)
        result = result.reshape((*states_shape[:-1], y_dim, hidden_size))

        # The point axis is third from the end regardless of the leading axes.
        return result[..., -num_tar:, :, :]

    def lencode(self, batch, autoreg=False):
        """Return ``(pz, qz)`` from ``lenc``.

        ``pz`` is computed from the context points. ``qz`` is computed from all
        points using their observed y values, including the targets' ``yt``.
        ``autoreg`` is accepted for interface compatibility and ignored.
        """
        y_dim = batch.yt.shape[-1]
        # Unlike construct_input, keep the real target y here: the posterior
        # must see the target labels. dim_agg embeds each point independently,
        # so the context rows are identical to those built from construct_input.
        x_y_ctx = torch.cat((batch.xc, batch.yc), dim=-1)
        x_y_tar = torch.cat((batch.xt, batch.yt), dim=-1)
        embeddings = self.dim_agg(torch.cat((x_y_ctx, x_y_tar), dim=1), y_dim)
        pz = self.lenc(embeddings[:, :batch.xc.shape[1]])
        qz = self.lenc(embeddings[:, :batch.x.shape[1]])
        return pz, qz