import torch
import numpy as np
import torch.nn as nn
from attrdict import AttrDict

from .modules import CrossAttnEncoder2, Decoder, PoolingEncoder, SelfAttn, build_mlp
import time
from torch.distributions.normal import Normal

import numpy as np
import uncertainty_toolbox as uct

class MAB(nn.Module):
    """Multi-head attention block: attention + residual/LayerNorm + ReLU feed-forward.

    Queries ``q`` attend over keys/values projected from ``v``; the attended
    features are added to the projected queries, normalised, passed through a
    residual ReLU feed-forward layer, normalised again and finally projected by
    ``fc_real_out``. The output has one row per query row.

    Args:
        dim_out: Feature width of queries, keys, values and outputs.
        num_heads: Number of attention heads; must divide ``dim_out``.

    Raises:
        ValueError: If ``dim_out`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim_out=128, num_heads=8):
        super().__init__()
        if dim_out % num_heads != 0:
            # scatter/gather split the feature axis into num_heads equal chunks;
            # an uneven split would fail later with an opaque shape error.
            raise ValueError(
                f"dim_out ({dim_out}) must be divisible by num_heads ({num_heads})"
            )
        self.dim_out = dim_out
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_out, dim_out)
        self.fc_k = nn.Linear(dim_out, dim_out)
        self.fc_v = nn.Linear(dim_out, dim_out)
        self.fc_out = nn.Linear(dim_out, dim_out)
        self.fc_real_out = nn.Linear(dim_out, dim_out)
        self.ln1 = nn.LayerNorm(dim_out)
        self.ln2 = nn.LayerNorm(dim_out)

    def scatter(self, x):
        """Split the last axis into ``num_heads`` chunks and stack them along axis 0."""
        return torch.cat(torch.split(x, x.shape[-1] // self.num_heads, dim=-1), dim=0)

    def gather(self, x):
        """Inverse of ``scatter``: move the head chunks from axis 0 back onto the last axis."""
        return torch.cat(torch.split(x, x.shape[0] // self.num_heads, dim=0), dim=-1)

    def attend(self, q, k, v, mask=None):
        """Multi-head softmax attention of ``q`` over ``k``/``v``.

        ``mask`` is accepted for interface compatibility but is currently ignored.
        """
        q_, k_, v_ = self.scatter(q), self.scatter(k), self.scatter(v)
        # NOTE: logits are scaled by sqrt(dim_out) (the full width), not by the
        # per-head width dim_out // num_heads.
        A_logits = q_ @ k_.swapaxes(-2, -1) / np.sqrt(self.dim_out)
        A = torch.nn.functional.softmax(A_logits, dim=-1)
        return self.gather(A @ v_)

    def forward(self, q, v, mask=None):
        """Attend from ``q`` over ``v``.

        Defined as ``forward`` (not ``__call__``) so that ``nn.Module.__call__``
        runs registered hooks before dispatching here.

        Args:
            q: Query set whose last axis has size ``dim_out``.
            v: Key/value set whose last axis has size ``dim_out``.
            mask: Forwarded to ``attend``, which ignores it.

        Returns:
            Tensor with the same leading shape as ``q`` and last axis ``dim_out``.
        """
        q, k, v = self.fc_q(q), self.fc_k(v), self.fc_v(v)
        out = self.ln1(q + self.attend(q, k, v, mask))
        out = self.ln2(out + nn.functional.relu(self.fc_out(out)))
        return self.fc_real_out(out)
    
class ISAB(nn.Module):
    """Two-stage cross-attention between a context set and a generated set.

    Stage 1 (``mab0``): context rows attend over the generated rows.
    Stage 2 (``mab1``): generated rows attend over the stage-1 result, so the
    output has one row per generated element.

    Args:
        dim_out: Feature width shared by both sets.
        num_heads: Number of attention heads in each ``MAB``.
    """

    def __init__(self, dim_out=128, num_heads=8):
        super().__init__()
        self.mab0 = MAB(dim_out=dim_out, num_heads=num_heads)
        self.mab1 = MAB(dim_out=dim_out, num_heads=num_heads)

    def forward(self, context, generate_sample, mask_context=None, mask_generate=None):
        """Refine ``generate_sample`` by attending through ``context``.

        Defined as ``forward`` so ``nn.Module.__call__`` runs registered hooks.
        The masks are passed through to the ``MAB`` blocks unchanged.
        """
        h = self.mab0(context, generate_sample, mask_generate)
        return self.mab1(generate_sample, h, mask_context)
    
class SetGenerate(nn.Module):
    """Generate ``generate_num`` set elements conditioned on a set ``r``.

    Gaussian noise with the same feature width as ``r`` is refined by an
    ``ISAB`` that cross-attends between the noise and ``r``.

    Args:
        dim_out: Feature width of ``r`` and of the generated elements.
        dim_hidden: Unused; kept so existing constructor calls keep working.
        num_heads: Number of attention heads in the ISAB.
    """

    def __init__(self, dim_out=128, dim_hidden=128, num_heads=8):
        super().__init__()
        self.isab = ISAB(dim_out=dim_out, num_heads=num_heads)

    def forward(self, r, generate_num, mask=None):
        """Return ``generate_num`` generated elements per set in ``r``.

        Args:
            r: Conditioning set of shape ``(*batch, set_size, dim_out)``.
            generate_num: Number of elements to generate per set.
            mask: Optional context mask forwarded to the ISAB.

        Returns:
            Tensor of shape ``(*batch, generate_num, dim_out)``.
        """
        batch_shape, feat_dim = r.shape[:-2], r.shape[-1]
        # Draw from torch's global RNG directly on r's device/dtype, so that
        # successive calls get fresh noise and torch.manual_seed makes runs
        # reproducible (a per-call generator seeded from the wall clock would
        # repeat the same noise for every call within the same second).
        generate_initial = torch.randn(
            (*batch_shape, generate_num, feat_dim), device=r.device, dtype=r.dtype
        )
        generate_mask = torch.ones(
            (*batch_shape, generate_num), dtype=torch.bool, device=r.device
        )
        return self.isab(r, generate_initial, mask, generate_mask)
    
def autoregressive(**kwargs):
    """Build the set-generation module used to augment context representations.

    Every keyword argument is handed to ``SetGenerate`` unchanged.
    """
    generator_module = SetGenerate(**kwargs)
    return generator_module

class MPANP_UNCERTAINTY(nn.Module):
    """Attentive neural process whose context set is augmented with generated points.

    Context keys (from ``xc``) and values (from ``(xc, yc)``) are extended with
    ``num_generates`` generated key/value pairs from a ``SetGenerate`` module.
    Target queries then cross-attend over the augmented set, and a decoder turns
    the result into a predictive distribution.

    The evaluation helpers (``forward`` in eval mode, ``crps``, ``uncertainty``)
    treat the first ``batch.xc.shape[-2]`` points of ``batch.x`` / ``batch.y``
    as context points and the remainder as target points.
    """

    def __init__(
        self,
        dim_x=1,
        dim_y=1,
        dim_hid=128,
        enc_v_depth=4,
        enc_qk_depth=2,
        enc_pre_depth=4,
        enc_post_depth=2,
        dec_depth=3,
    ):
        # NOTE: ``enc_pre_depth`` is accepted but not used by any layer below.
        super().__init__()

        self.net_qk = build_mlp(dim_x, dim_hid, dim_hid, enc_qk_depth)
        self.net_v_pre = build_mlp(dim_x + dim_y, dim_hid, dim_hid, enc_v_depth - 2)
        self.self_attn = SelfAttn(dim_hid, dim_hid)
        # Currently not called by ``predict``.
        self.net_v_post = build_mlp(dim_hid, dim_hid, dim_hid, enc_post_depth)

        self.enc1 = CrossAttnEncoder2(dim_x=dim_x, dim_y=dim_y, dim_hid=dim_hid)

        # Keys and values are generated jointly, concatenated on the feature axis.
        self.auto_regressive = autoregressive(dim_out=dim_hid * 2)

        self.dec = Decoder(
            dim_x=dim_x,
            dim_y=dim_y,
            dim_enc=dim_hid,
            dim_hid=dim_hid,
            depth=dec_depth,
        )

    def _auto_regressive(self, r_i_k, r_i_v, mask_kv, num_samples, num_generates):
        """Generate ``num_generates`` extra key/value rows conditioned on the context.

        ``num_samples`` is accepted for interface compatibility and is unused.

        Returns:
            ``(r_i_k_gen, r_i_v_gen)``: the first and second halves of the
            generated feature axis, used as generated keys and values.
        """
        r_i_kv = torch.cat([r_i_k, r_i_v], -1)  # [batch, context, 2 * dim_hid]
        r_i_gen = self.auto_regressive(r_i_kv, num_generates, mask_kv)
        r_i_k_gen, r_i_v_gen = torch.split(r_i_gen, r_i_gen.shape[-1] // 2, dim=-1)
        return r_i_k_gen, r_i_v_gen

    def predict(self, xc, yc, xt, num_samples=None, num_generates=40, mask=None):
        """Return the decoder's predictive distribution at ``xt`` given ``(xc, yc)``.

        Args:
            xc, yc: Context inputs and outputs.
            xt: Query inputs.
            num_samples: Unused; accepted for interface compatibility.
            num_generates: Number of generated key/value pairs appended to the context.
            mask: Context mask forwarded to the set generator.

        Returns:
            Whatever ``self.dec`` returns; callers in this class use it as a
            distribution exposing ``log_prob``, ``mean``, ``loc`` and ``scale``.
        """
        r_i_q = self.net_qk(xt)  # [batch, target, dim_hid]
        r_i_k_base = self.net_qk(xc)  # [batch, context, dim_hid]

        r_i_v_base = self.net_v_pre(torch.cat([xc, yc], -1))
        r_i_v_base = self.self_attn(r_i_v_base)  # [batch, context, dim_hid]

        r_i_k_gen, r_i_v_gen = self._auto_regressive(
            r_i_k_base, r_i_v_base, mask, num_samples, num_generates
        )

        # Append generated pairs to the real context: [batch, context + num_generates, dim_hid]
        r_i_k = torch.cat([r_i_k_base, r_i_k_gen], -2)
        r_i_v = torch.cat([r_i_v_base, r_i_v_gen], -2)

        r_ctx_gen = self.enc1(r_i_q, r_i_k, r_i_v)  # [batch, target, dim_hid]
        return self.dec(r_ctx_gen, xt)

    def forward(self, batch, num_samples=None, reduce_ll=True):
        """Training loss or evaluation metrics for one batch.

        In training mode returns ``loss`` (negative mean log-likelihood). In eval
        mode returns ``mse`` over all points plus ``ctx_ll`` / ``tar_ll``, the
        log-likelihoods of the context and target points (averaged to scalars
        when ``reduce_ll`` is true). ``num_samples`` is unused.
        """
        outs = AttrDict()
        py = self.predict(batch.xc, batch.yc, batch.x)
        ll = py.log_prob(batch.y).sum(-1)  # summed over output dimensions

        if self.training:
            outs.loss = -ll.mean()
        else:
            num_ctx = batch.xc.shape[-2]
            outs.mse = torch.mean((py.mean - batch.y) ** 2)

            ctx_ll, tar_ll = ll[..., :num_ctx], ll[..., num_ctx:]
            if reduce_ll:
                ctx_ll, tar_ll = ctx_ll.mean(), tar_ll.mean()
            outs.ctx_ll = ctx_ll
            outs.tar_ll = tar_ll

        return outs

    def calculate_crps(self, y_true, means, stds):
        """Closed-form CRPS of Gaussian predictions, averaged over the last axis.

        For N(mu, sigma^2) and observation y, with z = (y - mu) / sigma:
            CRPS = sigma * (z * (2 * Phi(z) - 1) + 2 * phi(z) - 1 / sqrt(pi)).

        Args:
            y_true, means, stds: Broadcast-compatible tensors; a trailing
                singleton (output-dimension) axis is squeezed away.

        Returns:
            CRPS averaged over the last axis remaining after the squeeze.
        """
        y_true = y_true.squeeze(-1)
        means = means.squeeze(-1)
        stds = stds.squeeze(-1)

        z = (y_true - means) / stds
        # Phi(z) and phi(z) are evaluated elementwise so every prediction gets
        # its own exact CRPS; averaging them across dim 0 first would mix
        # different predictions unless dim 0 only holds identical copies.
        cdf_z = 0.5 * (1 + torch.erf(z / np.sqrt(2.0)))
        pdf_z = torch.exp(-0.5 * z ** 2) / np.sqrt(2 * np.pi)

        crps = stds * (z * (2 * cdf_z - 1) + 2 * pdf_z - 1 / np.sqrt(np.pi))
        return crps.mean(dim=-1)

    def crps(self, batch, num_samples=None):
        """Gaussian CRPS and central-68% interval coverage for context and target points.

        ``y`` gets a leading copy axis (``num_samples`` stacked copies of
        ``batch.y``, or a single copy when ``num_samples`` is None) so that it
        lines up with predictions that may carry a leading sample axis.

        Returns:
            AttrDict with ``ctx_ci`` / ``tar_ci`` (fraction of observations inside
            mean ± z * std, z being the standard-normal 84% quantile) and
            ``ctx_crps`` / ``tar_crps`` (mean CRPS).
        """
        outs = AttrDict()

        # Leading copy axis (dim 0) in both branches: (copies, *batch.y.shape).
        if num_samples is None:
            y = batch.y.unsqueeze(0)
        else:
            y = torch.stack([batch.y] * num_samples)

        py = self.predict(batch.xc, batch.yc, batch.x, num_samples=num_samples)
        num_ctx = batch.xc.shape[-2]

        means, stds = py.loc, py.scale
        ctx_crps = self.calculate_crps(
            y[..., :num_ctx, :], means[..., :num_ctx, :], stds[..., :num_ctx, :]
        )
        tar_crps = self.calculate_crps(
            y[..., num_ctx:, :], means[..., num_ctx:, :], stds[..., num_ctx:, :]
        )

        # Only when the predictions share y's leading sample axis is there a
        # mixture to collapse (moment matching via the law of total variance).
        # Otherwise dim 0 is the batch axis, and averaging it would mix tasks.
        if means.dim() == y.dim():
            means = py.loc.mean(dim=0)
            stds = torch.sqrt(
                (py.scale ** 2).mean(dim=0)
                + (py.loc ** 2).mean(dim=0)
                - py.loc.mean(dim=0) ** 2
            )

        z_score = Normal(0, 1).icdf(torch.tensor([(1 + 0.68) / 2])).to(means.device)
        inside = (y >= means - z_score * stds) & (y <= means + z_score * stds)
        outs.ctx_ci = inside[..., :num_ctx, :].float().mean()
        outs.tar_ci = inside[..., num_ctx:, :].float().mean()

        outs.ctx_crps = ctx_crps.mean()
        outs.tar_crps = tar_crps.mean()

        return outs

    def uncertainty(self, batch, num_samples=None):
        """uncertainty-toolbox accuracy and calibration metrics for context and target points.

        Only batch element 0 and output dimension 0 are evaluated. Returns an
        AttrDict with keys ``{ctx,tar}_{mae,rmse,mdae,marpd,r2,corr,rms_cal,
        ma_cal,miscal_area}`` read from ``uct.metrics.get_all_metrics``.
        """
        outs = AttrDict()

        # Leading copy axis (dim 0) in both branches: (copies, *batch.y.shape).
        if num_samples is None:
            y = batch.y.unsqueeze(0)
        else:
            y = torch.stack([batch.y] * num_samples)

        py = self.predict(batch.xc, batch.yc, batch.x, num_samples=num_samples)
        num_ctx = batch.xc.shape[-2]

        # (result group, metric name); each is stored as "<split>_<metric name>".
        reported = (
            ("accuracy", "mae"),
            ("accuracy", "rmse"),
            ("accuracy", "mdae"),
            ("accuracy", "marpd"),
            ("accuracy", "r2"),
            ("accuracy", "corr"),
            ("avg_calibration", "rms_cal"),
            ("avg_calibration", "ma_cal"),
            ("avg_calibration", "miscal_area"),
        )
        splits = (("ctx", slice(None, num_ctx)), ("tar", slice(num_ctx, None)))

        for prefix, points in splits:
            split_means = py.loc[..., points, :].detach().cpu().numpy()[0, :, 0]
            split_stds = py.scale[..., points, :].detach().cpu().numpy()[0, :, 0]
            split_y = y[..., points, :].detach().cpu().numpy()[0, 0, :, 0]
            metrics = uct.metrics.get_all_metrics(split_means, split_stds, split_y)
            for group, name in reported:
                outs[f"{prefix}_{name}"] = metrics[group][name]

        return outs