import torch
import torch.nn as nn
from torch.distributions import kl_divergence
from attrdict import AttrDict

from torch.distributions import Normal

from ..utils.misc import stack, logmeanexp, log_w_weighted_sum_exp
from ..utils.sampling import sample_subset
from .modules import PoolingEncoder, Decoder



class ForwardTransition(nn.Module):
    """Gaussian transition kernel over latent vectors.

    A three-layer MLP maps ``psi`` to a mean offset and a raw log-scale; the
    returned distribution is centred at ``psi + offset`` with standard
    deviation ``0.01 * exp(raw_scale)``.
    """

    def __init__(self, d_model):
        super().__init__()
        width = 2 * d_model
        layers = [
            nn.Linear(d_model, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
        ]
        # Kept as one Sequential named ``predictor`` so parameter names
        # (predictor.0/2/4.*) and saved checkpoints stay compatible.
        self.predictor = nn.Sequential(*layers)

    def forward(self, psi):
        """Return Normal(psi + offset, 0.01 * exp(raw_scale)) predicted from ``psi``."""
        offset, raw_scale = self.predictor(psi).chunk(2, dim=-1)
        return Normal(psi + offset, 0.01 * torch.exp(raw_scale))


class BackwardTransition(nn.Module):
    """Gaussian transition kernel over latent vectors (backward direction).

    Structurally identical to the forward kernel: an MLP predicts a mean
    offset and a raw log-scale from ``psi``, giving
    ``Normal(psi + offset, 0.01 * exp(raw_scale))``.
    """

    def __init__(self, d_model):
        super().__init__()
        hidden = d_model * 2
        # Layer order must stay Linear/ReLU/Linear/ReLU/Linear so the
        # parameter names predictor.{0,2,4}.* match existing checkpoints.
        self.predictor = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
        )

    def forward(self, psi):
        """Return the backward kernel distribution conditioned on ``psi``."""
        prediction = self.predictor(psi)
        shift, log_scale = torch.chunk(prediction, 2, dim=-1)
        centre = psi + shift
        spread = 0.01 * torch.exp(log_scale)
        return Normal(centre, spread)

class NP_TTS_5(nn.Module):
    """Latent neural process with annealed, importance-weighted test-time scaling.

    Combines a deterministic context encoder (``denc``), a latent encoder
    (``lenc``) and a decoder (``dec``) fed with ``[theta, z]``. Learned
    forward/backward Gaussian kernels move latent samples through ``T``
    annealing steps while log importance weights are tracked.
    """

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            dim_lat=128,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3):

        super().__init__()

        # Deterministic path: context summary ``theta``.
        self.denc = PoolingEncoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_hid=dim_hid,
                pre_depth=enc_pre_depth,
                post_depth=enc_post_depth)

        # Latent path: returns a distribution over z (it is sampled with rsample()).
        self.lenc = PoolingEncoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_hid=dim_hid,
                dim_lat=dim_lat,
                pre_depth=enc_pre_depth,
                post_depth=enc_post_depth)

        # Decoder consumes the concatenation [theta, z].
        self.dec = Decoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_enc=dim_hid+dim_lat,
                dim_hid=dim_hid,
                depth=dec_depth)

        self.forward_transition = ForwardTransition(d_model=dim_lat)
        self.backward_transition = BackwardTransition(d_model=dim_lat)

    def test_time_scaling(self, xc, yc, pz=None, z=None, w=None, T=10, num_samples=50):
        """Run ``T`` annealed transition steps on latent samples, updating log-weights.

        At step t every sample moves z -> z' ~ forward_transition(z), and its
        log-weight is incremented by
            log pi_{t+1}(z') + log b(z | z') - log pi_t(z) - log f(z' | z),
        where pi_t comes from ``compute_log_pi`` and f / b are the forward /
        backward kernels. After each step the log-weights are shifted so that
        their maximum over dim 0 is zero. They are shifted, not normalized.

        Args:
            xc, yc: context inputs/outputs forwarded to ``compute_log_pi``.
            pz: latent distribution used as the annealing start; required.
            z: initial latent samples; presumably (num_samples, batch, dim_lat)
                -- TODO confirm against callers.
            w: initial importance weights shaped like ``z.shape[:2]``; uniform
                ``1 / z.shape[0]`` when None.
            T: number of annealing steps; must be >= 1.
            num_samples: forwarded to ``predict`` inside ``compute_log_pi``.

        Returns:
            ``(z, log_w_list, log_w_diff, update_log_pi_diff,
            previous_log_pi_diff, log_backward_diff, log_forward_diff)``:
            the final samples, the per-step shifted log-weights, and
            diagnostics. Each diagnostic is max - min over dim 0, taken at
            index 0 of dim 1, for the last step's quantities.

        Raises:
            ValueError: if ``T < 1`` (no step runs, so the diagnostics are undefined).
        """
        if T < 1:
            raise ValueError(f"test_time_scaling requires T >= 1, got {T}")
        if w is None:
            w = torch.ones([z.shape[0], z.shape[1]], device=z.device) / z.shape[0]
        # Follow z's device instead of hard-coding CUDA.
        log_w = torch.log(w).to(z.device)
        log_w_list = []
        for t in range(T):
            forward_normal = self.forward_transition(z)
            new_z = forward_normal.rsample()
            backward_normal = self.backward_transition(new_z)
            update_log_pi = self.compute_log_pi(T, t + 1, new_z, pz, xc, yc, num_samples)
            previous_log_pi = self.compute_log_pi(T, t, z, pz, xc, yc, num_samples)

            log_backward = backward_normal.log_prob(z).sum(dim=-1)
            log_forward = forward_normal.log_prob(new_z).sum(dim=-1)
            updated_new_log_w = log_w + update_log_pi + log_backward - previous_log_pi - log_forward

            z = new_z
            # Shift by the per-column max over samples for numerical stability.
            log_w = updated_new_log_w - torch.max(updated_new_log_w, dim=0, keepdim=True)[0]
            log_w_list.append(log_w)

        def _range_first_column(x):
            # max - min over dim 0 at index 0 of dim 1.
            return torch.max(x[:, 0]) - torch.min(x[:, 0])

        return (z, log_w_list,
                _range_first_column(log_w),
                _range_first_column(update_log_pi),
                _range_first_column(previous_log_pi),
                _range_first_column(log_backward),
                _range_first_column(log_forward))

    def compute_log_pi(self, T, t, z, pz, xc, yc, num_samples):
        """Annealed log-target at step ``t`` of ``T``.

        log pi_t(z) = (t/T) * [log p(yc | xc, z) + log N(z; 0, I)]
                      + ((T - t)/T) * log pz(z)

        The likelihood is the decoder evaluated at the context inputs. It is
        summed over the last two axes of ``py.log_prob(yc)``, which are
        presumably points and output dims. The prior terms are summed over
        the last axis of z.
        """
        py = self.predict(xc, yc, xc, z=z, num_samples=num_samples)
        # Standard-normal prior shaped like z. This replaces a hard-coded
        # [16, 128] CUDA tensor that only fit batch 16 / dim_lat 128.
        prior = Normal(torch.zeros_like(z), torch.ones_like(z))
        # Summing over both trailing axes also handles dim_y > 1. The old
        # .sum(-2).squeeze(-1) form only reduced correctly when dim_y == 1.
        log_lik = py.log_prob(yc).sum(dim=(-2, -1))
        beta = t / T
        return (beta * log_lik
                + beta * prior.log_prob(z).sum(dim=-1)
                + (T - t) / T * pz.log_prob(z).sum(dim=-1))

    def compute_ess_loss(self, log_w):
        """Mean of sum_k w_k^2 / (sum_k w_k)^2, with weights taken along dim 0.

        This is the reciprocal of the effective sample size and lies in
        [1/K, 1] for K samples. Minimizing it favors evenly spread weights.
        """
        ess = torch.exp(torch.logsumexp(2 * log_w, dim=0) - 2 * torch.logsumexp(log_w, dim=0))
        return ess.mean()

    def predict(self, xc, yc, xt, z=None, num_samples=None):
        """Decode at ``xt`` from the context summary ``theta`` and latent ``z``.

        If ``z`` is None it is drawn with rsample() from ``self.lenc(xc, yc)``.
        A sample shape ``[num_samples]`` is used when ``num_samples`` is given.
        Returns whatever ``self.dec`` returns; ``sample`` reads its ``.loc``.
        """
        theta = stack(self.denc(xc, yc), num_samples)
        if z is None:
            pz = self.lenc(xc, yc)
            z = pz.rsample() if num_samples is None \
                    else pz.rsample([num_samples])
        encoded = torch.cat([theta, z], -1)
        encoded = stack(encoded, xt.shape[-2], -2)

        return self.dec(encoded, stack(xt, num_samples))

    def tts_predict(self, xc, yc, xt, z=None, num_samples=None, T=10):
        """Like ``predict``, but first refines z with ``test_time_scaling``.

        Returns ``(py, log_w, log_w_list, *diagnostics)``. ``log_w`` holds the
        final-step log-weights replicated over the target-point axis via
        ``stack(..., xt.shape[-2], -1)``. The diagnostics are the six extra
        values returned by ``test_time_scaling``.
        """
        theta = stack(self.denc(xc, yc), num_samples)
        # pz is needed as the annealing start even when the caller supplies z.
        # Previously it was unbound in that case.
        pz = self.lenc(xc, yc)
        if z is None:
            z = pz.rsample() if num_samples is None \
                    else pz.rsample([num_samples])

        # Uniform initial weights, created on z's device.
        w = torch.ones([z.shape[0], z.shape[1]], device=z.device) / z.shape[0]

        (z, log_w_list, log_w_diff, update_log_pi_diff, previous_log_pi_diff,
         log_backward_diff, log_forward_diff) = self.test_time_scaling(
            xc, yc, pz, z, w, num_samples=num_samples, T=T)
        encoded = torch.cat([theta, z], -1)
        encoded = stack(encoded, xt.shape[-2], -2)
        # log_w already lives on z's device. The old get_device() == '-1'
        # check compared an int to a str and always forced .cuda().
        log_w = stack(log_w_list[-1], xt.shape[-2], -1)

        return (self.dec(encoded, stack(xt, num_samples)), log_w, log_w_list,
                log_w_diff, update_log_pi_diff, previous_log_pi_diff,
                log_backward_diff, log_forward_diff)

    def sample(self, xc, yc, xt, z=None, num_samples=None):
        """Return the ``.loc`` of the predictive distribution from ``predict``."""
        pred_dist = self.predict(xc, yc, xt, z, num_samples)
        return pred_dist.loc

    @staticmethod
    def _split_ll(ll, num_ctx, reduce_ll):
        """Split log-likelihoods along the last axis into (context, target) parts.

        The first ``num_ctx`` entries are context. If ``reduce_ll`` is set,
        each part is averaged to a scalar.
        """
        ctx_ll, tar_ll = ll[..., :num_ctx], ll[..., num_ctx:]
        if reduce_ll:
            return ctx_ll.mean(), tar_ll.mean()
        return ctx_ll, tar_ll

    def forward(self, batch, num_samples=None, reduce_ll=True, test_time_scaling=False, ess=False, T=10, ess_lambda=1.):
        """Compute training losses (``self.training``) or evaluation log-likelihoods.

        Args:
            batch: object with ``xc``, ``yc`` (context) and ``x``, ``y`` (all
                points). The first ``batch.xc.shape[-2]`` entries of the last
                log-likelihood axis are reported as context.
            num_samples: number of latent samples. It is required when
                ``test_time_scaling`` is set.
            reduce_ll: average (and, with samples, log-mean-exp) the
                log-likelihoods.
            test_time_scaling: refine latents with ``tts_predict`` before decoding.
            ess: in TTS training, add ``ess_lambda * compute_ess_loss`` for every step.
            T: number of annealing steps for test-time scaling.
            ess_lambda: weight of the per-step ESS penalty.

        Returns:
            AttrDict with ``loss`` (and possibly ``recon``, ``kld``,
            ``ess_loss`` and diagnostics) in training, or ``ctx_ll`` /
            ``tar_ll`` in evaluation.

        Raises:
            ValueError: if ``test_time_scaling`` is set without ``num_samples``.
        """
        if test_time_scaling and num_samples is None:
            raise ValueError("test_time_scaling=True requires num_samples to be set")

        outs = AttrDict()
        if not test_time_scaling:
            if self.training:
                pz = self.lenc(batch.xc, batch.yc)
                qz = self.lenc(batch.x, batch.y)
                z = qz.rsample() if num_samples is None else \
                        qz.rsample([num_samples])
                py = self.predict(batch.xc, batch.yc, batch.x,
                        z=z, num_samples=num_samples)

                # num_samples may be None, which previously crashed on `None > 1`.
                if num_samples is not None and num_samples > 1:
                    # Importance-weighted bound over K samples.
                    # K * B * N
                    recon = py.log_prob(stack(batch.y, num_samples)).sum(-1)
                    # K * B
                    log_qz = qz.log_prob(z).sum(-1)
                    log_pz = pz.log_prob(z).sum(-1)

                    # K * B
                    log_w = recon.sum(-1) + log_pz - log_qz

                    outs.loss = -logmeanexp(log_w).mean() / batch.x.shape[-2]
                else:
                    outs.recon = py.log_prob(batch.y).sum(-1).mean()
                    outs.kld = kl_divergence(qz, pz).sum(-1).mean()
                    outs.loss = -outs.recon + outs.kld / batch.x.shape[-2]

            else:
                py = self.predict(batch.xc, batch.yc, batch.x, num_samples=num_samples)
                if num_samples is None:
                    ll = py.log_prob(batch.y).sum(-1)
                else:
                    y = torch.stack([batch.y]*num_samples)
                    if reduce_ll:
                        ll = logmeanexp(py.log_prob(y).sum(-1))
                    else:
                        ll = py.log_prob(y).sum(-1)
                outs.ctx_ll, outs.tar_ll = self._split_ll(ll, batch.xc.shape[-2], reduce_ll)
            return outs

        # Test-time scaling path: both training and evaluation decode from refined latents.
        (py, log_w, log_w_list, log_w_diff, update_log_pi_diff, previous_log_pi_diff,
         log_backward_diff, log_forward_diff) = self.tts_predict(
            batch.xc, batch.yc, batch.x, num_samples=num_samples, T=T)
        if self.training:
            # NOTE(review): with num_samples == 1 no loss is set here; confirm intended.
            if num_samples > 1:
                # K * B * N
                recon = py.log_prob(stack(batch.y, num_samples)).sum(-1)
                outs.loss = -log_w_weighted_sum_exp(recon, log_w).mean()

                if ess:
                    # Compute each step's ESS term once and reuse it for the loss and the report.
                    ess_loss_list = [self.compute_ess_loss(step_log_w) for step_log_w in log_w_list]
                    for ess_term in ess_loss_list:
                        outs.loss += ess_lambda * ess_term
                    outs.ess_loss = sum(ess_loss_list) / T
                outs.log_w_diff = log_w_diff
                outs.update_log_pi_diff = update_log_pi_diff
                outs.previous_log_pi_diff = previous_log_pi_diff
                outs.log_backward_diff = log_backward_diff
                outs.log_forward_diff = log_forward_diff
        else:
            y = torch.stack([batch.y]*num_samples)
            if reduce_ll:
                ll = log_w_weighted_sum_exp(py.log_prob(y).sum(-1), log_w)
            else:
                ll = py.log_prob(y).sum(-1)
            outs.ctx_ll, outs.tar_ll = self._split_ll(ll, batch.xc.shape[-2], reduce_ll)
        return outs