import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
from attrdict import AttrDict

from .tnp import TNP


class TNPD(TNP):
    """``TNP`` variant with an independent (diagonal) Gaussian output head.

    The encoder comes from the ``TNP`` base class; this subclass adds an MLP
    that maps each encoded point to a mean and a standard deviation per output
    dimension, plus log-likelihood, MSE, CRPS and interval-coverage metrics.
    """

    def __init__(
        self,
        dim_x,
        dim_y,
        d_model,
        emb_depth,
        dim_feedforward,
        nhead,
        dropout,
        num_layers,
        bound_std=True
    ):
        """Build the base ``TNP`` model and the Gaussian prediction head.

        All arguments are forwarded unchanged to ``TNP.__init__``. In addition,
        ``d_model``, ``dim_feedforward`` and ``dim_y`` size the prediction head
        defined here.
        """
        super().__init__(
            dim_x,
            dim_y,
            d_model,
            emb_depth,
            dim_feedforward,
            nhead,
            dropout,
            num_layers,
            bound_std
        )

        # Emits 2 * dim_y values per point; `predict` splits them into the
        # mean and the raw (pre-activation) standard deviation.
        self.predictor = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, dim_y*2)
        )

    def forward(self, batch, reduce_ll=True):
        """Compute the training loss or the evaluation metrics for ``batch``.

        Args:
            batch: object exposing ``xc``, ``yc``, ``x`` and ``y``. The first
                ``xc.shape[-2]`` points of ``x``/``y`` are treated as context
                points and the remainder as target points (assumes ``x`` is
                laid out as context followed by targets -- confirm against
                the data pipeline).
            reduce_ll: only used in eval mode. If True, ``ctx_ll`` and
                ``tar_ll`` are averaged to scalars; if False they are returned
                per point, without the mean reduction.

        Returns:
            AttrDict with ``loss`` in training mode, or with ``mse``,
            ``ctx_ll`` and ``tar_ll`` in eval mode.
        """
        pred = self.predict(batch.xc, batch.yc, batch.x)
        outs = AttrDict()
        # Log-density of every point, summed over the output dimensions.
        ll = pred.log_prob(batch.y).sum(-1)

        if self.training:
            outs.loss = -ll.mean()
            return outs

        num_ctx = batch.xc.shape[-2]
        outs.mse = torch.mean((pred.mean - batch.y) ** 2)

        ctx_ll = ll[..., :num_ctx]
        tar_ll = ll[..., num_ctx:]
        if reduce_ll:
            ctx_ll = ctx_ll.mean()
            tar_ll = tar_ll.mean()
        outs.ctx_ll = ctx_ll
        outs.tar_ll = tar_ll

        return outs

    def predict(self, xc, yc, xt):
        """Return the predictive ``Normal`` distribution at the inputs ``xt``.

        Args:
            xc: context inputs.
            yc: context outputs; indexed as ``yc.shape[2]``, so it needs at
                least three dimensions.
            xt: query inputs; indexed as ``xt.shape[0]`` and ``xt.shape[1]``.

        Returns:
            ``torch.distributions.Normal`` whose mean and scale are the two
            halves of the predictor output along the last dimension.
        """
        batch = AttrDict()
        batch.xc = xc
        batch.yc = yc
        batch.xt = xt
        # Placeholder targets: allocated next to the inputs rather than on a
        # hard-coded 'cuda' device so the model also runs on CPU and on
        # whichever accelerator the inputs live on.
        batch.yt = torch.zeros(
            (xt.shape[0], xt.shape[1], yc.shape[2]),
            dtype=yc.dtype,
            device=xt.device,
        )

        z_target = self.encode(batch, autoreg=False)
        out = self.predictor(z_target)

        mean, raw_std = torch.chunk(out, 2, dim=-1)
        if self.bound_std:
            # Softplus keeps the scale positive; the offset bounds it below by 0.1.
            std = 0.1 + 0.9 * F.softplus(raw_std)
        else:
            std = torch.exp(raw_std)

        return Normal(mean, std)

    def calculate_crps(self, y_true, means, stds):
        """Element-wise closed-form CRPS of Gaussian predictions.

        Evaluates ``std * (z * (2 * Phi(z) - 1) + 2 * phi(z) - 1 / sqrt(pi))``
        with ``z = (y_true - mean) / std``, where ``Phi`` and ``phi`` are the
        standard normal CDF and PDF.

        Args:
            y_true: observed values.
            means: predictive means, broadcastable with ``y_true``.
            stds: predictive standard deviations, broadcastable with ``y_true``.

        Returns:
            Tensor of CRPS values with the broadcast shape of the inputs.
        """
        # Plain Python floats: no per-call tensor allocation, and full
        # precision is kept when the inputs are float64.
        sqrt_2 = 2.0 ** 0.5
        sqrt_2pi = (2 * torch.pi) ** 0.5
        inv_sqrt_pi = 1 / torch.pi ** 0.5

        z = (y_true - means) / stds
        cdf_z = 0.5 * (1 + torch.erf(z / sqrt_2))
        pdf_z = torch.exp(-0.5 * z ** 2) / sqrt_2pi

        return stds * (z * (2 * cdf_z - 1) + 2 * pdf_z - inv_sqrt_pi)

    def crps(self, batch, num_samples=None):
        """Compute mean CRPS and central 68% interval coverage.

        Args:
            batch: object exposing ``xc``, ``yc``, ``x`` and ``y``; the first
                ``xc.shape[-2]`` points along the second-to-last dimension are
                treated as context points and the rest as target points.
            num_samples: unused by this method.

        Returns:
            AttrDict with scalar tensors ``ctx_ci``, ``tar_ci`` (fraction of
            observations inside the interval) and ``ctx_crps``, ``tar_crps``.
        """
        outs = AttrDict()

        y = batch.y
        py = self.predict(batch.xc, batch.yc, batch.x)
        num_ctx = batch.xc.shape[-2]

        means = py.loc
        stds = py.scale

        # Half-width, in standard deviations, of the central 68% interval.
        z_score = Normal(0, 1).icdf(torch.tensor([(1 + 0.68) / 2])).to(means.device)

        ctx_means, tar_means = means[..., :num_ctx, :], means[..., num_ctx:, :]
        ctx_stds, tar_stds = stds[..., :num_ctx, :], stds[..., num_ctx:, :]
        y_ctx, y_tar = y[..., :num_ctx, :], y[..., num_ctx:, :]

        lower_bounds_ctx = ctx_means - z_score * ctx_stds
        upper_bounds_ctx = ctx_means + z_score * ctx_stds
        lower_bounds_tar = tar_means - z_score * tar_stds
        upper_bounds_tar = tar_means + z_score * tar_stds

        outs.ctx_ci = ((y_ctx >= lower_bounds_ctx) & (y_ctx <= upper_bounds_ctx)).float().mean()
        outs.tar_ci = ((y_tar >= lower_bounds_tar) & (y_tar <= upper_bounds_tar)).float().mean()

        outs.ctx_crps = self.calculate_crps(y_ctx, ctx_means, ctx_stds).mean()
        outs.tar_crps = self.calculate_crps(y_tar, tar_means, tar_stds).mean()

        return outs