import torch
import torch.nn as nn
from torch.distributions import kl_divergence
from attrdict import AttrDict
from torch.distributions import Normal
import numpy as np
import uncertainty_toolbox as uct

from ..utils.misc import stack, logmeanexp

from .modules import CrossAttnEncoder, PoolingEncoder, Decoder

class ANP_UNCERTAINTY(nn.Module):
    """Attentive Neural Process with uncertainty-quality evaluation helpers.

    A deterministic cross-attention path (``denc``) and a latent path
    (``lenc``) encode the context; their outputs are concatenated and decoded
    by ``dec`` into a distribution-like object exposing ``loc``, ``scale`` and
    ``log_prob``. Besides the likelihood-based ``forward``, ``crps`` and
    ``uncertainty`` score how well-calibrated the predictions are.

    Throughout, the first ``batch.xc.shape[-2]`` points of ``batch.x`` /
    ``batch.y`` are treated as the context points.
    """

    def __init__(self,
            dim_x=1,
            dim_y=1,
            dim_hid=128,
            dim_lat=128,
            enc_v_depth=4,
            enc_qk_depth=2,
            enc_pre_depth=4,
            enc_post_depth=2,
            dec_depth=3):
        """Build the deterministic encoder, latent encoder and decoder.

        Args:
            dim_x: input dimensionality.
            dim_y: output dimensionality.
            dim_hid: hidden width; also the width of the deterministic code.
            dim_lat: width of the latent code ``z``.
            enc_v_depth, enc_qk_depth: depths forwarded to ``CrossAttnEncoder``.
            enc_pre_depth, enc_post_depth: depths forwarded to ``PoolingEncoder``.
            dec_depth: depth forwarded to ``Decoder``.
        """
        super().__init__()

        self.denc = CrossAttnEncoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_hid=dim_hid,
                v_depth=enc_v_depth,
                qk_depth=enc_qk_depth)

        self.lenc = PoolingEncoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_hid=dim_hid,
                dim_lat=dim_lat,
                self_attn=True,
                pre_depth=enc_pre_depth,
                post_depth=enc_post_depth)

        # The decoder consumes the deterministic code (dim_hid) concatenated
        # with the latent code (dim_lat); see ``predict``.
        self.dec = Decoder(
                dim_x=dim_x,
                dim_y=dim_y,
                dim_enc=dim_hid+dim_lat,
                dim_hid=dim_hid,
                depth=dec_depth)

    def predict(self, xc, yc, xt, z=None, num_samples=None):
        """Decode a predictive distribution at the target inputs ``xt``.

        Args:
            xc, yc: context inputs / outputs.
            xt: target inputs to predict at.
            z: optional latent sample. When None it is drawn with ``rsample``
                from ``self.lenc(xc, yc)``.
            num_samples: if given, ``num_samples`` latents are drawn and the
                deterministic code and ``xt`` are replicated via the project
                ``stack`` helper (presumably adding a leading sample axis).

        Returns:
            The decoder output (exposes ``loc``, ``scale``, ``log_prob``).
        """
        theta = stack(self.denc(xc, yc, xt), num_samples)
        if z is None:
            pz = self.lenc(xc, yc)
            z = pz.rsample() if num_samples is None \
                    else pz.rsample([num_samples])
        # Replicate z once per target point (along dim -2) so it can be
        # concatenated feature-wise with theta.
        z = stack(z, xt.shape[-2], -2)
        encoded = torch.cat([theta, z], -1)
        return self.dec(encoded, stack(xt, num_samples))

    def sample(self, xc, yc, xt, z=None, num_samples=None):
        """Return the predictive mean (``loc``) at ``xt``; see ``predict``."""
        pred_dist = self.predict(xc, yc, xt, z, num_samples)
        return pred_dist.loc

    @staticmethod
    def _sample_first(py, num_samples):
        """Return ``(loc, scale)`` of ``py`` with a guaranteed leading sample axis.

        ``predict`` adds no sample axis when ``num_samples`` is None, so a
        size-1 axis is inserted in that case. Downstream code can then always
        reduce over dim 0 as the sample axis instead of the batch axis.
        """
        loc, scale = py.loc, py.scale
        if num_samples is None:
            loc, scale = loc.unsqueeze(0), scale.unsqueeze(0)
        return loc, scale

    @staticmethod
    def _mixture_moments(loc, scale):
        """Mean and std of the equal-weight Gaussian mixture over dim 0.

        Uses the law of total variance: Var = E[sigma^2] + E[mu^2] - E[mu]^2.
        """
        mean = loc.mean(dim=0)
        std = torch.sqrt((scale**2).mean(dim=0) + (loc**2).mean(dim=0) - mean**2)
        return mean, std

    @staticmethod
    def _interval_coverage(y, mean, std, level=0.68):
        """Fraction of ``y`` inside the central ``level`` interval mean +/- z*std."""
        z_score = Normal(0, 1).icdf(torch.tensor([(1 + level) / 2])).to(mean.device)
        lower = mean - z_score * std
        upper = mean + z_score * std
        return ((y >= lower) & (y <= upper)).float().mean()

    def calculate_crps(self, y_true, means, stds):
        """CRPS-style score of Gaussian predictions, reduced over the last axis.

        Args:
            y_true: observations with a trailing size-1 output axis; must
                broadcast against ``means`` (a leading sample axis is optional).
            means, stds: Gaussian parameters with the sample axis at dim 0
                and a trailing size-1 output axis.

        Returns:
            ``crps.mean(dim=-1)`` after squeezing the output axis.

        Based on the Gaussian closed form
        sigma * (z (2 Phi(z) - 1) + 2 phi(z) - 1/sqrt(pi)).
        NOTE(review): Phi(z) and phi(z) are averaged over the sample axis
        (dim 0) *before* being combined with the per-sample z and sigma. With
        more than one sample this is neither the per-sample Gaussian CRPS nor
        the exact mixture CRPS. It is kept unchanged so reported numbers stay
        comparable; confirm this is the intended metric.
        """
        y_true = y_true.squeeze(-1)
        means = means.squeeze(-1)
        stds = stds.squeeze(-1)

        z = (y_true - means) / stds
        sqrt_two = torch.sqrt(torch.tensor(2.0, device=z.device))
        sqrt_two_pi = torch.sqrt(torch.tensor(2 * torch.pi, device=z.device))
        sqrt_pi = torch.sqrt(torch.tensor(torch.pi, device=z.device))

        cdf_z = 0.5 * (1 + torch.erf(z / sqrt_two).mean(dim=0))
        pdf_z = (torch.exp(-0.5 * z**2) / sqrt_two_pi).mean(dim=0)

        crps = stds * (z * (2 * cdf_z - 1) + 2 * pdf_z - 1 / sqrt_pi)
        return crps.mean(dim=-1)

    def crps(self, batch, num_samples=None):
        """Score calibration on context and target points.

        Args:
            batch: needs ``xc``, ``yc``, ``x``, ``y``.
            num_samples: number of latent samples; None draws a single latent
                without a sample axis (handled by adding a size-1 axis).

        Returns:
            AttrDict with:
                ctx_ci / tar_ci: fraction of observations inside the central
                    68% interval of the moment-matched mixture.
                ctx_crps / tar_crps: mean of ``calculate_crps``.
        """
        outs = AttrDict()
        num_ctx = batch.xc.shape[-2]

        py = self.predict(batch.xc, batch.yc, batch.x, num_samples=num_samples)
        means, stds = self._sample_first(py, num_samples)

        # batch.y has no sample axis; it broadcasts against the sampled
        # predictions, which is equivalent to stacking it num_samples times.
        y_ctx, y_tar = batch.y[..., :num_ctx, :], batch.y[..., num_ctx:, :]

        ctx_crps = self.calculate_crps(y_ctx, means[..., :num_ctx, :], stds[..., :num_ctx, :])
        tar_crps = self.calculate_crps(y_tar, means[..., num_ctx:, :], stds[..., num_ctx:, :])

        mix_mean, mix_std = self._mixture_moments(means, stds)
        outs.ctx_ci = self._interval_coverage(
                y_ctx, mix_mean[..., :num_ctx, :], mix_std[..., :num_ctx, :])
        outs.tar_ci = self._interval_coverage(
                y_tar, mix_mean[..., num_ctx:, :], mix_std[..., num_ctx:, :])

        outs.ctx_crps = ctx_crps.mean()
        outs.tar_crps = tar_crps.mean()

        return outs

    def forward(self, batch, num_samples=None, reduce_ll=True):
        """Compute the training loss (train mode) or log-likelihoods (eval mode).

        Training mode (``pz = lenc(xc, yc)``, ``qz = lenc(x, y)``; latents are
        drawn from ``qz``):
            * ``num_samples > 1``: importance-weighted objective
              ``log_w = recon + log pz(z) - log qz(z)``; ``outs.loss`` is
              ``-logmeanexp(log_w).mean()`` divided by the number of points.
            * otherwise (None or 1): ``outs.recon``, ``outs.kld`` and
              ``outs.loss = -recon + kld / num_points``.

        Eval mode: ``outs.ctx_ll`` / ``outs.tar_ll`` split at the context
        size. With ``num_samples`` set and ``reduce_ll`` true, the per-sample
        log-likelihoods are combined with ``logmeanexp`` (presumably over the
        sample axis). They are then averaged when ``reduce_ll`` is true and
        returned unreduced otherwise.

        Returns:
            AttrDict with the entries listed above.
        """
        outs = AttrDict()
        if self.training:
            pz = self.lenc(batch.xc, batch.yc)
            qz = self.lenc(batch.x, batch.y)
            z = qz.rsample() if num_samples is None else \
                    qz.rsample([num_samples])
            py = self.predict(batch.xc, batch.yc, batch.x,
                    z=z, num_samples=num_samples)

            # num_samples defaults to None; comparing None > 1 raises TypeError.
            if num_samples is not None and num_samples > 1:
                # K * B * N
                recon = py.log_prob(stack(batch.y, num_samples)).sum(-1)
                # K * B
                log_qz = qz.log_prob(z).sum(-1)
                log_pz = pz.log_prob(z).sum(-1)

                # K * B
                log_w = recon.sum(-1) + log_pz - log_qz

                outs.loss = -logmeanexp(log_w).mean() / batch.x.shape[-2]
            else:
                outs.recon = py.log_prob(batch.y).sum(-1).mean()
                outs.kld = kl_divergence(qz, pz).sum(-1).mean()
                outs.loss = -outs.recon + outs.kld / batch.x.shape[-2]

        else:
            py = self.predict(batch.xc, batch.yc, batch.x, num_samples=num_samples)
            if num_samples is None:
                ll = py.log_prob(batch.y).sum(-1)
            else:
                y = torch.stack([batch.y]*num_samples)
                if reduce_ll:
                    ll = logmeanexp(py.log_prob(y).sum(-1))
                else:
                    ll = py.log_prob(y).sum(-1)
            num_ctx = batch.xc.shape[-2]

            if reduce_ll:
                outs.ctx_ll = ll[...,:num_ctx].mean()
                outs.tar_ll = ll[...,num_ctx:].mean()
            else:
                outs.ctx_ll = ll[...,:num_ctx]
                outs.tar_ll = ll[...,num_ctx:]

        return outs

    def uncertainty(self, batch, num_samples=None):
        """Compute uncertainty-toolbox accuracy/calibration metrics.

        The sampled predictions are collapsed to the moment-matched mixture
        mean/std and passed to ``uct.metrics.get_all_metrics`` separately for
        context and target points.

        NOTE(review): as in the original, only batch element 0 and output
        dimension 0 are scored; other batch elements are ignored.

        Returns:
            AttrDict with ``{ctx,tar}_{mae,rmse,mdae,marpd,r2,corr}`` from the
            ``accuracy`` group and ``{ctx,tar}_{rms_cal,ma_cal,miscal_area}``
            from the ``avg_calibration`` group.
        """
        outs = AttrDict()
        num_ctx = batch.xc.shape[-2]

        py = self.predict(batch.xc, batch.yc, batch.x, num_samples=num_samples)
        means, stds = self._mixture_moments(*self._sample_first(py, num_samples))

        def first_task(t, points):
            # detach(): numpy() raises on tensors that require grad.
            return t[0, points, 0].detach().cpu().numpy()

        ctx, tar = slice(None, num_ctx), slice(num_ctx, None)
        ctx_metrics = uct.metrics.get_all_metrics(
                first_task(means, ctx), first_task(stds, ctx), first_task(batch.y, ctx))
        tar_metrics = uct.metrics.get_all_metrics(
                first_task(means, tar), first_task(stds, tar), first_task(batch.y, tar))

        metric_keys = (
            ('accuracy', ('mae', 'rmse', 'mdae', 'marpd', 'r2', 'corr')),
            ('avg_calibration', ('rms_cal', 'ma_cal', 'miscal_area')),
        )
        for prefix, metrics in (('ctx', ctx_metrics), ('tar', tar_metrics)):
            for group, names in metric_keys:
                for name in names:
                    outs[f'{prefix}_{name}'] = metrics[group][name]

        return outs