import torch
import torch.nn as nn
from torch.distributions import kl_divergence
from attrdict import AttrDict
import math

from ..utils.misc import stack, logmeanexp

from .modules import CrossAttnEncoder_Dim, PoolingEncoder_Dim, Decoder_Dim
from .attention import SelfAttn


class PositionalEncoding(nn.Module):
    """Add a fixed sinusoidal positional encoding along the second-to-last axis.

    A ``(max_seq_len, d_model)`` table is precomputed: even columns hold
    ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)`` with
    ``w_k = 10000 ** (-2k / d_model)``.  In :meth:`forward`, the first
    ``seq_len - 1`` slots receive rows ``0 .. seq_len - 2`` and the final slot
    always receives the LAST table row, so the final slot gets the same
    encoding regardless of the sequence length.

    Args:
        d_model: size of the last axis of the inputs (may be odd).
        max_seq_len: number of table rows; the longest supported sequence.
    """

    def __init__(self, d_model, max_seq_len):
        super(PositionalEncoding, self).__init__()
        position = torch.arange(0, max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column,
        # so only the first d_model // 2 frequencies are used here.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the positional encoding for its length.

        Args:
            x: tensor whose axis -2 indexes positions and whose last axis
                has size ``d_model``.

        Raises:
            ValueError: if ``x.size(-2)`` exceeds ``max_seq_len``. Without this
                check, a length of ``max_seq_len + 1`` silently duplicates the
                last row, and longer inputs fail with a broadcast error.
        """
        seq_len = x.size(-2)
        max_len = self.pe.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {max_len}"
            )
        # Leading slots use positions 0..seq_len-2; the final slot is pinned
        # to the last table row.
        pe = torch.cat([self.pe[:seq_len - 1, :], self.pe[-1:, :]])
        return x + pe

class DimensionAggregator(nn.Module):
    """Embed each coordinate of an (x, y) row and pool the x-coordinates.

    Each scalar of the input is lifted to a ``dim_hid`` vector by a shared
    ``nn.Linear(1, dim_hid)``. The vectors are tagged with a positional
    encoding and mixed by self-attention across the ``dim_x + 1`` slots of
    each row. The ``dim_x`` x-slots are then averaged into a single token,
    and the y-slot (the last one) is kept as is.

    Args:
        dim_hid: width of the per-coordinate embedding.
        dim_out: output width of the self-attention layer.
        max_seq_len: positional-encoding table size; expected to be at least
            ``dim_x + 1``.
    """

    def __init__(self, dim_hid, dim_out, max_seq_len=101):
        super().__init__()
        self.dim_hid = dim_hid
        self.dim_out = dim_out
        # Construction order is kept (encoding, linear, attention) so that
        # parameter registration and RNG-driven initialisation are unchanged.
        self.positional_encoding = PositionalEncoding(dim_hid, max_seq_len)
        self.linear = nn.Linear(1, dim_hid)
        self.selfattention = SelfAttn(dim_hid, dim_out)

    def forward(self, data_xy):
        """Map ``[B, N, dim_x + 1]`` rows to ``[B, N, 2, dim_out]`` tokens.

        Slot 0 of the output is the mean over the x-slots and slot 1 is the
        y-slot, which is the last input column.
        """
        # [B, N, dim_x+1] -> [B, N, dim_x+1, 1] -> [B, N, dim_x+1, dim_hid]
        tokens = self.positional_encoding(self.linear(data_xy.unsqueeze(-1)))
        lead_shape = tokens.shape[:2]
        n_slots = tokens.shape[2]

        # Attention runs per row, so batch and point axes are flattened.
        flat = tokens.reshape(-1, n_slots, tokens.shape[-1])
        mixed = self.selfattention(flat)
        mixed = mixed.reshape(*lead_shape, n_slots, mixed.shape[-1])

        x_slots, y_slot = mixed.split([n_slots - 1, 1], dim=-2)
        pooled_x = x_slots.mean(dim=-2, keepdim=True)
        return torch.cat([pooled_x, y_slot], dim=-2)


class DDANP(nn.Module):
    """Latent-variable neural process on dimension-aggregated embeddings.

    Context and target rows are embedded by a ``DimensionAggregator`` into an
    x-summary token and a y token. :meth:`predict` then performs these steps:

    1. Self-attend jointly over target queries, context keys and context
       values (``denc1``).
    2. Cross-attend the queries to the context for a deterministic path
       (``denc2``).
    3. Draw a global latent ``z`` from ``lenc``.
    4. Decode both together (``dec``).

    Args:
        dim_x: accepted but not read by this constructor.
        dim_y: output dimension passed to the decoder.
        dim_hid: hidden width shared by all sub-modules.
        dim_lat: latent width.
        enc_v_depth, enc_qk_depth: accepted but not read by this constructor.
        enc_pre_depth, enc_post_depth: depths of the latent encoder.
        dec_depth: depth of the decoder.
    """

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            dim_lat=128,
            enc_v_depth=4,
            enc_qk_depth=2,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3):

        super().__init__()

        self.denc1 = SelfAttn(
                dim_in=dim_hid, dim_out=dim_hid)

        self.denc2 = CrossAttnEncoder_Dim(
                dim_hid=dim_hid)

        self.lenc = PoolingEncoder_Dim(
                dim_x=dim_hid,
                dim_y=dim_hid,
                dim_hid=dim_hid,
                dim_lat=dim_lat,
                self_attn=True,
                pre_depth=enc_pre_depth,
                post_depth=enc_post_depth)

        self.dec = Decoder_Dim(
                dim_y=dim_y,
                dim_enc=dim_hid,
                dim_lat=dim_lat,
                dim_hid=dim_hid,
                depth=dec_depth)

        self.dim_agg = DimensionAggregator(dim_hid, dim_hid)

    def predict(self, xc, yc, xt, z=None, num_samples=None):
        """Return the decoder's predictive output for targets ``xt``.

        Args:
            xc, yc: context inputs and outputs (concatenated on the last axis).
            xt: target inputs.
            z: optional latent sample. When None, it is drawn from
                ``self.lenc`` applied to the (self-attended) context tokens.
            num_samples: when not None, the number of latent samples. It is
                forwarded to ``rsample`` and to the project ``stack`` helper.

        Returns:
            Whatever ``self.dec`` returns. Callers in this class use its
            ``.loc`` and ``.log_prob``, so presumably a distribution.
        """
        # Targets have no observed y; a zero placeholder fills the y slot so
        # the aggregator sees the same dim_x + 1 layout as the context.
        # new_zeros matches xt's device AND dtype.
        yt_fake = xt.new_zeros([*xt.shape[:-1], 1])
        target_xy = torch.cat([xt, yt_fake], dim=-1)  # [B, num_target, dim_x+1]
        context_xy = torch.cat([xc, yc], dim=-1)  # [B, num_context, dim_x+1]
        target_xy_agg = self.dim_agg(target_xy)  # [B, num_target, 2, dim_hid]
        context_xy_agg = self.dim_agg(context_xy)  # [B, num_context, 2, dim_hid]
        query = target_xy_agg[:, :, 0, :]  # [B, num_target, dim_hid]
        key = context_xy_agg[:, :, 0, :]  # [B, num_context, dim_hid]
        value = context_xy_agg[:, :, 1, :]  # [B, num_context, dim_hid]
        num_targ, num_cont = query.shape[-2], key.shape[-2]

        # Joint self-attention over [queries | keys | values], then split back.
        querykeyvalue = self.denc1(torch.cat([query, key, value], dim=-2))
        query = querykeyvalue[:, :num_targ, :]
        key = querykeyvalue[:, num_targ:num_targ + num_cont, :]
        value = querykeyvalue[:, num_targ + num_cont:, :]

        theta = stack(self.denc2(query, key, value), num_samples)
        if z is None:
            pz = self.lenc(key, value)
            z = pz.rsample() if num_samples is None \
                    else pz.rsample([num_samples])
        # Broadcast the global latent to every target point.
        z = stack(z, query.shape[-2], -2)
        encoded = torch.cat([theta, z], -1)
        return self.dec(encoded, stack(query, num_samples))

    def sample(self, xc, yc, xt, z=None, num_samples=None):
        """Return ``.loc`` of the predictive output of :meth:`predict`."""
        pred_dist = self.predict(xc, yc, xt, z, num_samples)
        return pred_dist.loc

    def forward(self, batch, num_samples=None, reduce_ll=True):
        """Compute the training loss or evaluation log-likelihoods.

        Args:
            batch: object exposing ``xc``, ``yc``, ``x`` and ``y``. The ll
                slicing below treats the first ``xc.shape[-2]`` entries of
                ``x``/``y`` as the context; presumably ``x = cat(xc, xt)``.
                TODO confirm against the data loader.
            num_samples: number of latent samples, or None for a single one.
            reduce_ll: in eval mode, whether to average log-likelihoods
                (and logmeanexp over samples).

        Returns:
            AttrDict. In training mode it contains ``loss``, plus ``recon``
            and ``kld`` unless ``num_samples > 1``. In eval mode it contains
            ``ctx_ll`` and ``tar_ll``.
        """
        outs = AttrDict()
        if self.training:
            context_xy = torch.cat([batch.xc, batch.yc], dim=-1)
            target_xy = torch.cat([batch.x, batch.y], dim=-1)
            context_xy_agg = self.dim_agg(context_xy)
            target_xy_agg = self.dim_agg(target_xy)
            # NOTE(review): pz/qz here are built from the aggregator output
            # directly, whereas predict() builds its prior from tokens AFTER
            # denc1 self-attention. Confirm this asymmetry is intended.
            cx = context_xy_agg[:, :, 0, :]
            cy = context_xy_agg[:, :, 1, :]
            tx = target_xy_agg[:, :, 0, :]
            ty = target_xy_agg[:, :, 1, :]
            pz = self.lenc(cx, cy)
            qz = self.lenc(tx, ty)
            z = qz.rsample() if num_samples is None else \
                    qz.rsample([num_samples])
            py = self.predict(batch.xc, batch.yc, batch.x,
                    z=z, num_samples=num_samples)

            # num_samples defaults to None; `None > 1` raises TypeError, so
            # guard before choosing the multi-sample objective.
            if num_samples is not None and num_samples > 1:
                # Importance-weighted estimate over K latent samples.
                # K * B * N
                recon = py.log_prob(stack(batch.y, num_samples)).sum(-1)
                # K * B
                log_qz = qz.log_prob(z).sum(-1)
                log_pz = pz.log_prob(z).sum(-1)

                # K * B
                log_w = recon.sum(-1) + log_pz - log_qz

                outs.loss = -logmeanexp(log_w).mean() / batch.x.shape[-2]
            else:
                outs.recon = py.log_prob(batch.y).sum(-1).mean()
                outs.kld = kl_divergence(qz, pz).sum(-1).mean()
                outs.loss = -outs.recon + outs.kld / batch.x.shape[-2]

        else:
            py = self.predict(batch.xc, batch.yc, batch.x, num_samples=num_samples)
            if num_samples is None:
                ll = py.log_prob(batch.y).sum(-1)
            else:
                y = torch.stack([batch.y] * num_samples)
                if reduce_ll:
                    ll = logmeanexp(py.log_prob(y).sum(-1))
                else:
                    ll = py.log_prob(y).sum(-1)
            num_ctx = batch.xc.shape[-2]

            if reduce_ll:
                outs.ctx_ll = ll[..., :num_ctx].mean()
                outs.tar_ll = ll[..., num_ctx:].mean()
            else:
                outs.ctx_ll = ll[..., :num_ctx]
                outs.tar_ll = ll[..., num_ctx:]

        return outs
