import torch
import torch.nn as nn

import copy
from ..utils.misc import stack, logmeanexp
from torch.distributions import kl_divergence
from .modules import build_mlp
from .modules import TTPoolingEncoder_Dim
from .attention import SelfAttn
import math

import ipdb 
import torch.nn.functional as F
from torch.distributions import kl_divergence

from torch.distributions.normal import Normal
from attrdict import AttrDict
from ..utils.misc import stack, logmeanexp

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding with two phase-swapped tables.

    Two tables of shape ``[max_seq_len, d_model]`` are registered as buffers:

    * ``pe_1`` puts ``sin`` in the even columns and ``cos`` in the odd columns.
    * ``pe_2`` puts ``cos`` in the even columns and ``sin`` in the odd columns.

    ``forward`` applies ``pe_1`` to the leading rows of the second-to-last
    axis and ``pe_2`` to the trailing ``y_dim`` rows, with positions
    restarting at 0 for each group.

    Args:
        d_model: Width of the encoding (last axis of the input). Odd values
            are supported; the odd-indexed columns then use one frequency
            fewer than the even-indexed ones.
        max_seq_len: Number of positions pre-computed for each table.
    """

    def __init__(self, d_model, max_seq_len):
        super().__init__()
        position = torch.arange(0, max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        # [max_seq_len, ceil(d_model / 2)]: one column per frequency.
        angles = position * div_term
        # Odd-indexed columns only have floor(d_model / 2) slots, so the
        # frequency table must be truncated when d_model is odd; for even
        # d_model this slice is the full table.
        odd_angles = angles[:, :d_model // 2]

        pe_1 = torch.zeros(max_seq_len, d_model)
        pe_1[:, 0::2] = torch.sin(angles)
        pe_1[:, 1::2] = torch.cos(odd_angles)
        self.register_buffer('pe_1', pe_1)

        pe_2 = torch.zeros(max_seq_len, d_model)
        pe_2[:, 0::2] = torch.cos(angles)
        pe_2[:, 1::2] = torch.sin(odd_angles)
        self.register_buffer('pe_2', pe_2)

    def forward(self, x, y_dim):
        """Add the positional encoding to ``x``.

        Args:
            x: Tensor whose last two axes are ``[num_rows, d_model]``.
            y_dim: Number of trailing rows (along axis ``-2``) that receive
                ``pe_2``; the remaining leading rows receive ``pe_1``.

        Returns:
            ``x + pe`` where ``pe`` is broadcast over the leading axes of ``x``.
        """
        pe = torch.cat([self.pe_1[:x.size(-2) - y_dim, :], self.pe_2[:y_dim, :]])
        x = x + pe
        return x

class DimensionAggregator(nn.Module):
    """Embed every scalar of an (x, y) row as a token and mix them with self-attention.

    Each scalar feature is lifted to ``dim_hid`` by a shared ``Linear(1, dim_hid)``,
    tagged with a positional encoding, and passed through ``SelfAttn``. The
    x-tokens are then mean-pooled and paired with each y-token.

    Shape comments below assume ``data_xy`` is ``[B, num_data, dim_x + dim_y]``
    and that ``SelfAttn`` keeps the token axis and emits ``dim_out`` features
    -- TODO confirm against ``SelfAttn``.

    Args:
        dim_hid: Token width fed to the self-attention module.
        dim_out: Output width requested from the self-attention module.
        max_seq_len: Number of positions pre-computed by the positional encoding.
    """

    def __init__(self, dim_hid, dim_out, max_seq_len=101):
        super().__init__()
        self.dim_hid = dim_hid
        self.dim_out = dim_out
        self.positional_encoding = PositionalEncoding(self.dim_hid, max_seq_len)
        self.linear = nn.Linear(1, self.dim_hid)
        self.selfattention = SelfAttn(self.dim_hid, self.dim_out)

    def forward(self, data_xy, y_dim):
        """Return one combined embedding per y-dimension of every data point.

        Args:
            data_xy: Concatenated inputs and outputs; the trailing ``y_dim``
                entries of the last axis are treated as the y part.
            y_dim: Number of y entries at the end of the last axis.

        Returns:
            Concatenation (last axis) of the pooled x embedding, repeated
            ``y_dim`` times, with the per-dimension y embeddings.
        """
        # Scalar -> token, then add positions: [B, num_data, dim_x+dim_y, dim_hid]
        tokens = self.positional_encoding(self.linear(data_xy.unsqueeze(-1)), y_dim)
        shape = tokens.shape

        # Attention runs over the feature tokens of each data point independently,
        # so the leading axes are folded together: [B * num_data, dim_x+dim_y, dim_hid]
        attended = self.selfattention(tokens.reshape(-1, shape[-2], shape[-1]))

        # Unfold back to [B, num_data, dim_x+dim_y, dim_out]
        attended = attended.reshape(shape[0], shape[1], shape[2], attended.shape[-1])

        # [B, num_data, dim_x, dim_out], [B, num_data, dim_y, dim_out]
        x_tokens, y_tokens = attended.split([shape[2] - y_dim, y_dim], dim=-2)

        # Pool the x-tokens to one vector and repeat it for every y-token.
        x_summary = x_tokens.mean(dim=-2, keepdim=True).expand(-1, -1, y_dim, -1)

        return torch.cat([x_summary, y_tokens], dim=-1)
    
class DTANP_Y_nolatent_base(nn.Module):
    """Deterministic (no latent variable) encoder shared by the DTANP-Y models.

    A ``DimensionAggregator`` turns every (point, y-dimension) pair into a
    ``d_model``-wide token (two halves of width ``d_model / 2`` concatenated),
    and a ``TransformerEncoder`` processes all context and target tokens with
    a mask that only exposes the context tokens as keys.

    Args:
        dim_x: Accepted for interface compatibility; not used in this class.
        dim_y: Accepted for interface compatibility; not used in this class.
        d_model: Transformer width. It should be even, because the aggregator
            concatenates two embeddings of width ``int(d_model / 2)``.
        emb_depth: Accepted for interface compatibility; not used in this class.
        dim_feedforward: Hidden width of the transformer feed-forward layers.
        nhead: Number of attention heads.
        dropout: Dropout probability of the transformer layers.
        num_layers: Number of transformer encoder layers.
        bound_std: Stored on the instance; not used in this class.
    """

    def __init__(
        self,
        dim_x,
        dim_y,
        d_model,
        emb_depth,
        dim_feedforward,
        nhead,
        dropout,
        num_layers,
        bound_std
    ):
        super().__init__()
        self.dim_agg = DimensionAggregator(int(d_model/2), int(d_model/2))

        encoder_layer = nn.TransformerEncoderLayer(d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.encoder1 = nn.TransformerEncoder(encoder_layer, num_layers)

        self.bound_std = bound_std

    def construct_input(self, batch, autoreg=False):
        """Build the encoder input: context (x, y) rows followed by target (x, 0) rows.

        Target outputs are replaced by zeros so they cannot leak into the input.
        ``autoreg`` is accepted but not used.
        """
        x_y_ctx = torch.cat((batch.xc, batch.yc), dim=-1)
        x_0_tar = torch.cat((batch.xt, torch.zeros_like(batch.yt)), dim=-1)

        inp = torch.cat((x_y_ctx, x_0_tar), dim=1)

        return inp

    def create_mask(self, batch, y_dim, autoreg=False):
        """Build the additive attention mask over all (point, y-dimension) tokens.

        Columns belonging to context tokens are 0 and every other column is
        ``-inf``, so each query can only attend to context tokens.
        ``autoreg`` is accepted but not used.

        Returns:
            Tuple ``(mask, num_tar)`` where ``mask`` is a square float tensor
            of side ``y_dim * (num_ctx + num_tar)`` placed on the same device
            as ``batch.xc``, and ``num_tar`` is ``batch.xt.shape[1]``.
        """
        num_ctx = batch.xc.shape[1]
        num_tar = batch.xt.shape[1]
        num_all = num_ctx + num_tar
        side = y_dim * num_all
        # Follow the data's device instead of assuming CUDA so the model also
        # runs on CPU or on a non-default accelerator.
        mask = torch.full((side, side), float('-inf'), device=batch.xc.device)
        mask[:, :y_dim * num_ctx] = 0.0

        return mask, num_tar

    def encode(self, batch, z=None, num_samples=None, autoreg=False):
        """Encode the batch and return the representations of the target points.

        ``z`` and ``num_samples`` are accepted but not used (this variant has
        no latent path).

        Returns:
            Tensor of shape ``[B, num_tar, y_dim, d_model]`` taken from the
            trailing ``num_tar`` points of the encoder output.
        """
        y_dim = batch.yt.shape[-1]
        inp = self.construct_input(batch, autoreg)
        mask, num_tar = self.create_mask(batch, y_dim, autoreg)
        embeddings = self.dim_agg(inp, y_dim)
        # Flatten (point, y-dimension) into a single token axis.
        embeddings = embeddings.view(embeddings.shape[0], -1, embeddings.shape[-1])

        out = self.encoder1(embeddings, mask)

        # Restore the (point, y-dimension) structure.
        out = out.view(*out.shape[:1], -1, y_dim, out.shape[-1])

        # Use an explicit start index: ``out[:, -num_tar:]`` would return every
        # row when ``num_tar`` is 0.
        return out[:, out.shape[1] - num_tar:, :]

class DTANP_Y_NOLATENT(DTANP_Y_nolatent_base):
    """DTANP-Y model without a latent variable, with a Gaussian prediction head.

    The head maps every ``d_model``-wide (point, y-dimension) token to two
    numbers that are interpreted as the mean and the raw standard deviation of
    a ``Normal`` distribution for that output dimension.

    Args:
        dim_x, dim_y, d_model, emb_depth, dim_feedforward, nhead, dropout,
        num_layers: Forwarded unchanged to the base class.
        bound_std: If true, the standard deviation is
            ``0.1 + 0.9 * softplus(raw)``; otherwise it is ``exp(raw)``.
    """

    def __init__(
        self,
        dim_x,
        dim_y,
        d_model,
        emb_depth,
        dim_feedforward,
        nhead,
        dropout,
        num_layers,
        bound_std=True
    ):
        super().__init__(
            dim_x,
            dim_y,
            d_model,
            emb_depth,
            dim_feedforward,
            nhead,
            dropout,
            num_layers,
            bound_std
        )

        self.predictor = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, 2)
        )

    def forward(self, batch, num_samples=0, reduce_ll=True):
        """Compute the training loss or the evaluation log-likelihoods.

        Predictions are made at ``batch.x`` from the context ``(batch.xc,
        batch.yc)`` and scored against ``batch.y``. The first
        ``batch.xc.shape[-2]`` points are reported as context and the rest as
        target (presumably ``batch.x``/``batch.y`` hold context points first,
        then targets -- verify against the data loader).

        Args:
            batch: Object exposing ``xc``, ``yc``, ``x`` and ``y`` tensors.
            num_samples: Accepted for interface compatibility; not used.
            reduce_ll: In eval mode, average the log-likelihoods to scalars
                when true; otherwise return them per point.

        Returns:
            ``AttrDict`` with ``loss`` in training mode, or ``ctx_ll`` and
            ``tar_ll`` in eval mode.
        """
        pred_tar = self.predict(batch.xc, batch.yc, batch.x)

        outs = AttrDict()
        # Sum over output dimensions -> one log-likelihood per point.
        ll = pred_tar.log_prob(batch.y).sum(-1)
        if self.training:
            outs.loss = -ll.mean()
        else:
            num_ctx = batch.xc.shape[-2]
            ctx_ll = ll[..., :num_ctx]
            tar_ll = ll[..., num_ctx:]
            if reduce_ll:
                ctx_ll = ctx_ll.mean()
                tar_ll = tar_ll.mean()
            # With reduce_ll=False the per-point values are returned as-is;
            # previously both branches averaged, making the flag a no-op.
            outs.ctx_ll = ctx_ll
            outs.tar_ll = tar_ll

        return outs

    def predict(self, xc, yc, xt, num_samples=0):
        """Return the predictive ``Normal`` distribution at ``xt``.

        Args:
            xc: Context inputs; ``shape[1]`` is the number of context points.
            yc: Context outputs; the last axis is the output dimension.
            xt: Query inputs, indexed as ``[batch, num_points, ...]``.
            num_samples: Accepted for interface compatibility; not used.

        Returns:
            ``Normal`` whose mean and standard deviation have shape
            ``[batch, num_points, dim_y]``.
        """
        batch = AttrDict()
        batch.xc = xc
        batch.yc = yc
        batch.xt = xt
        # Placeholder targets: only their shape is used by the encoder. They
        # are created on xt's device/dtype instead of a hard-coded 'cuda' so
        # prediction also works on CPU.
        batch.yt = xt.new_zeros((xt.shape[0], xt.shape[1], yc.shape[2]))

        z_target = self.encode(batch, autoreg=False)

        out = self.predictor(z_target)
        mean, std = torch.chunk(out, 2, dim=-1)
        # Merge the trailing [dim_y, 1] axes into a single dim_y axis.
        mean, std = mean.reshape((*mean.shape[:-2], -1)), std.reshape((*std.shape[:-2], -1))
        if self.bound_std:
            std = 0.1 + 0.9 * F.softplus(std)
        else:
            std = torch.exp(std)

        return Normal(mean, std)