import torch
import torch.nn as nn

import copy
from ..utils.misc import stack, logmeanexp
from torch.distributions import kl_divergence
from .modules import build_mlp
from .modules import TTPoolingEncoder_Dim
from .attention import SelfAttn
import math

import ipdb 
import torch.nn.functional as F
from torch.distributions import kl_divergence

from torch.distributions.normal import Normal
from attrdict import AttrDict
from ..utils.misc import stack, logmeanexp
from typing import Callable

def grad_U(U: Callable, x: torch.Tensor) -> torch.Tensor:
    """Return dU/dx evaluated at ``x`` using autograd.

    ``U`` maps ``x`` to a tensor of potential energies. The energies are
    summed before differentiating. If each entry of ``U(x)`` depends only on
    its own chain, the result therefore holds every chain's own gradient.

    Gradient tracking is enabled locally, so this also works when the caller
    is inside ``torch.no_grad()`` (e.g. an evaluation loop). ``x`` is
    detached and copied first. The result therefore carries no graph back to
    ``x``'s history, and ``x`` itself is never modified.
    """
    with torch.enable_grad():
        # Fresh leaf: requires_grad_ is always legal on it, and the
        # autograd.grad call below never walks into x's upstream graph.
        x_leaf = x.detach().clone().requires_grad_(True)
        potential = U(x_leaf).sum()  # scalar output needed by autograd.grad
        (grad,) = torch.autograd.grad(potential, x_leaf)
    return grad

# -----------------------------
# Leapfrog integrator: simulates Hamiltonian
#    dynamics using grad_U for the force term
# -----------------------------
def leapfrog(
    U: Callable,
    x: torch.Tensor,
    p: torch.Tensor,
    step_size: float,
    n_steps: int):
    """Simulate Hamiltonian dynamics with the leapfrog scheme.

    Performs a half momentum kick, then ``n_steps`` position drifts separated
    by full momentum kicks, and a closing half kick. The kicks use
    ``grad_U(U, ...)``, i.e. the force is -dU/dx and the mass matrix is the
    identity.

    Returns:
        ``(x, p)``: the proposed position and momentum after the trajectory.
    """
    half_step = 0.5 * step_size
    last_step = n_steps - 1

    # Opening half kick at the starting position.
    p = p - half_step * grad_U(U, x)

    for step in range(n_steps):
        x = x + step_size * p  # drift
        if step != last_step:
            # Full kick between consecutive drifts; the last one is halved below.
            p = p - step_size * grad_U(U, x)

    # Closing half kick at the final position.
    p = p - half_step * grad_U(U, x)
    return x, p

# -----------------------------
# HMC step: sample momentum, run leapfrog,
#    then Metropolis accept/reject
# -----------------------------
def hmc_step(
    U: Callable,
    x_current: torch.Tensor,
    step_size: float,
    n_steps: int):
    """Run one Metropolis-adjusted HMC transition for a batch of chains.

    ``x_current`` has shape ``[..., D]``. Only the last axis is reduced, so
    every leading index is treated as an independent chain. ``U`` must map
    ``[..., D]`` to per-chain potential energies ``[...]``.

    Returns:
        x_next: same shape as ``x_current``. Each chain holds its proposal if
            accepted, otherwise its current state.
        accept_mask: boolean tensor with the chain shape, True where the
            proposal was accepted.
    """
    def kinetic(momentum):
        # Unit-mass kinetic energy, reduced over the state dimension.
        return 0.5 * (momentum ** 2).sum(dim=-1)

    momentum_start = torch.randn_like(x_current)
    h_start = U(x_current) + kinetic(momentum_start)

    x_prop, momentum_end = leapfrog(U, x_current, momentum_start,
                                    step_size, n_steps)
    h_end = U(x_prop) + kinetic(momentum_end)

    # Metropolis test: accept with probability min(1, exp(H_start - H_end)).
    accept_prob = torch.exp(h_start - h_end)
    accepted = torch.rand_like(accept_prob) < accept_prob

    # Trailing singleton axis broadcasts the per-chain decision over D.
    x_next = torch.where(accepted[..., None], x_prop, x_current)
    return x_next, accepted
class PositionalEncoding(nn.Module):
    """Sinusoidal encodings with separate tables for x and y positions.

    ``pe_1`` uses the usual layout: sin on even channels, cos on odd ones.
    It is added to the leading ``x.size(-2) - y_dim`` positions. ``pe_2``
    swaps sin and cos and is added to the trailing ``y_dim`` positions, so
    input and output coordinates receive distinguishable codes.

    Odd ``d_model`` values are supported: the odd channels use only the first
    ``d_model // 2`` frequencies.
    """

    def __init__(self, d_model, max_seq_len):
        super().__init__()
        position = torch.arange(0, max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_seq_len, ceil(d_model / 2)]
        # There are only d_model // 2 odd channels; trimming keeps odd d_model valid.
        odd_angles = angles[:, : d_model // 2]

        pe_1 = torch.zeros(max_seq_len, d_model)
        pe_1[:, 0::2] = torch.sin(angles)
        pe_1[:, 1::2] = torch.cos(odd_angles)
        self.register_buffer('pe_1', pe_1)

        pe_2 = torch.zeros(max_seq_len, d_model)
        pe_2[:, 0::2] = torch.cos(angles)
        pe_2[:, 1::2] = torch.sin(odd_angles)
        self.register_buffer('pe_2', pe_2)

    def forward(self, x, y_dim):
        """Add the positional code to ``x`` (positions on axis -2, channels last).

        Raises:
            ValueError: if ``y_dim`` is negative or larger than the sequence
                length, or either group exceeds ``max_seq_len``. Without this
                check the slices below would silently produce a wrongly sized
                table.
        """
        max_len = self.pe_1.size(0)
        num_x = x.size(-2) - y_dim
        if y_dim < 0 or num_x < 0 or num_x > max_len or y_dim > max_len:
            raise ValueError(
                f"cannot encode {x.size(-2)} positions with y_dim={y_dim} "
                f"(max_seq_len={max_len})"
            )
        pe = torch.cat([self.pe_1[:num_x, :], self.pe_2[:y_dim, :]])
        return x + pe

class DimensionAggregator(nn.Module):
    """Turn each scalar coordinate of an ``[x, y]`` row into a token.

    Every coordinate is lifted to ``dim_hid`` features and tagged with a
    positional encoding (x and y positions use different tables). The
    coordinates of one row are then mixed by self-attention. Finally the x
    tokens are averaged and paired with every y token.
    """

    def __init__(self, dim_hid, dim_out, max_seq_len=101):
        super().__init__()
        self.dim_hid = dim_hid
        self.dim_out = dim_out
        self.positional_encoding = PositionalEncoding(self.dim_hid, max_seq_len)
        self.linear = nn.Linear(1, self.dim_hid)
        self.selfattention = SelfAttn(self.dim_hid, self.dim_out)

    def forward(self, data_xy, y_dim):
        """Encode ``data_xy`` (``[B, num_data, dim_x + dim_y]``) into y-dim tokens.

        Returns:
            ``[B, num_data, y_dim, 2 * F]``, where ``F`` is the self-attention
            output width. The first ``F`` channels are the mean over the x
            tokens, repeated for every y position; the last ``F`` channels
            are that y token.
        """
        # One token per scalar: [B, N, L] -> [B, N, L, 1] -> [B, N, L, dim_hid].
        tokens = self.linear(data_xy.unsqueeze(-1))
        tokens = self.positional_encoding(tokens, y_dim)

        lead_b, lead_n, n_tok = tokens.shape[0], tokens.shape[1], tokens.shape[2]
        # Fold the batch and point axes so attention runs within each row only.
        flat_tokens = tokens.reshape(-1, tokens.shape[-2], tokens.shape[-1])
        mixed = self.selfattention(flat_tokens)
        mixed = mixed.reshape(lead_b, lead_n, n_tok, mixed.shape[-1])

        x_tok, y_tok = mixed.split([n_tok - y_dim, y_dim], dim=-2)
        x_summary = x_tok.mean(dim=-2, keepdim=True).expand(-1, -1, y_dim, -1)
        return torch.cat([x_summary, y_tok], dim=-1)
    
class DTANP_Y_base(nn.Module):
    """Shared encoder for the dimension-wise transformer NP models.

    Each (data point, output dimension) pair becomes one token via
    ``DimensionAggregator``. ``encoder1`` attends over every token with a
    mask that only exposes context tokens. ``encoder2`` followed by ``lenc``
    maps the context tokens to a latent distribution; a sample from it is
    appended to every token. Subclasses must define ``self.predictor``,
    which the HMC target inside ``encode_v1`` calls.
    """

    def __init__(
        self,
        dim_x,
        dim_y,
        d_model,
        emb_depth,
        dim_feedforward,
        nhead,
        dropout,
        num_layers,
        bound_std
    ):
        """Build the token aggregator, both transformer stacks and ``lenc``.

        Args:
            dim_x, dim_y, emb_depth: accepted for interface compatibility;
                this class does not use them.
            d_model: token width of both transformer stacks. The aggregator
                and ``lenc`` are built from ``int(d_model / 2)`` halves.
            dim_feedforward: feed-forward width of the transformer layers,
                also passed to ``lenc`` as ``dim_lat``.
            nhead, dropout: transformer layer settings.
            num_layers: depth of ``encoder1``; ``encoder2`` always has 2 layers.
            bound_std: if true, raw scales become std-devs via
                ``0.1 + 0.9 * softplus``; otherwise via ``exp``.
        """
        super().__init__()
        half_model = int(d_model / 2)
        self.dim_agg = DimensionAggregator(half_model, half_model)

        # Masked encoder over all (point, dim) tokens; see create_mask.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.encoder1 = nn.TransformerEncoder(encoder_layer, num_layers)

        # Unmasked encoder applied to context tokens only (latent path).
        encoder_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_feedforward, dropout, batch_first=True)
        self.encoder2 = nn.TransformerEncoder(encoder_layer, 2)

        self.bound_std = bound_std
        self.lenc = TTPoolingEncoder_Dim(
                dim_x=half_model,
                dim_y=half_model,
                dim_hid=d_model,
                dim_lat=dim_feedforward,
                self_attn=True,
                pre_depth=4,
                post_depth=2)

    def construct_input(self, batch, autoreg=False):
        """Stack context ``[x, y]`` rows followed by target ``[x, 0]`` rows.

        Target outputs are replaced with zeros so they never reach the
        encoder. Rows are concatenated along dim 1, which is assumed to be
        the point axis of ``[B, N, features]`` inputs. ``autoreg`` is
        accepted but unused.
        """
        context_rows = torch.cat((batch.xc, batch.yc), dim=-1)
        target_rows = torch.cat((batch.xt, torch.zeros_like(batch.yt)), dim=-1)
        return torch.cat((context_rows, target_rows), dim=1)

    def create_mask(self, batch, y_dim, autoreg=False):
        """Build the additive attention mask for ``encoder1``.

        There are ``y_dim`` tokens per point, with context points first.
        Every query may attend to the first ``y_dim * num_ctx`` (context)
        tokens, which get 0.0, and to nothing else, which gets -inf.
        ``autoreg`` is accepted but unused.

        Returns:
            ``(mask, num_tar)``. ``mask`` is a float ``[T, T]`` tensor with
            ``T = y_dim * (num_ctx + num_tar)``, allocated on
            ``batch.xc.device``. ``num_tar`` is the number of target points.
        """
        num_ctx = batch.xc.shape[1]
        num_tar = batch.xt.shape[1]
        num_tokens = y_dim * (num_ctx + num_tar)
        # Follow the inputs' device instead of assuming CUDA is available.
        mask = torch.full((num_tokens, num_tokens), float('-inf'),
                          device=batch.xc.device)
        mask[:, :y_dim * num_ctx] = 0.0
        return mask, num_tar

    def encode(self, *args, **kwargs):
        """Encode target tokens; currently dispatches to ``encode_v1``."""
        return self.encode_v1(*args, **kwargs)

    def _embed_and_encode(self, batch, num_samples, autoreg=False):
        """Tokenize ``batch`` and run the masked ``encoder1`` over all tokens.

        Returns:
            ``(inp, y_dim, num_tar, embeddings, out)``. ``embeddings`` is the
            flattened token tensor ``[B, num_points * y_dim, d]``. ``out`` is
            the ``encoder1`` output passed through ``stack(..., num_samples)``.
        """
        y_dim = batch.yt.shape[-1]
        inp = self.construct_input(batch, autoreg)
        mask, num_tar = self.create_mask(batch, y_dim, autoreg)
        embeddings = self.dim_agg(inp, y_dim)
        # [B, num_points, y_dim, d] -> [B, num_points * y_dim, d]
        embeddings = embeddings.view(embeddings.shape[0], -1, embeddings.shape[-1])
        out = stack(self.encoder1(embeddings, mask), num_samples)
        return inp, y_dim, num_tar, embeddings, out

    def _sample_context_latent(self, batch, embeddings, y_dim, num_samples):
        """Draw a reparameterized z from ``lenc(encoder2(context tokens))``."""
        # Context points come first and contribute y_dim tokens each.
        context_embeddings = embeddings[:, :batch.xc.shape[1] * y_dim]
        pz = self.lenc(self.encoder2(context_embeddings))
        return pz.rsample() if num_samples is None \
            else pz.rsample([num_samples])

    @staticmethod
    def _append_latent(out, z, num_points, y_dim, batch_size, num_samples):
        """Concatenate ``z`` onto every token and regroup the tokens per point.

        ``z`` is replicated over the ``num_points * y_dim`` tokens and joined
        to ``out`` on the feature axis. The result is viewed as
        ``[num_samples, batch_size, num_points, y_dim, features]``.
        """
        z = stack(z, num_points, -2)
        # The same z goes to every token, so the tiling order is irrelevant.
        z = z.repeat(1, 1, y_dim, 1).view(num_samples, batch_size, -1, z.shape[-1])
        out = torch.cat([out, z], dim=-1)
        return out.view(*out.shape[:2], -1, y_dim, out.shape[-1])

    def encode_v1(self, batch, z=None, num_samples=None, autoreg=False,
                  hmc_iters=5, hmc_step_size=0.1, hmc_leapfrog_steps=2):
        """Encode target tokens with an HMC-refined latent (evaluation only).

        A latent is drawn from ``lenc(encoder2(context tokens))``. It is then
        moved by ``hmc_iters`` HMC transitions targeting the potential
        defined in ``target`` below. That potential re-encodes the context
        points as targets and scores ``yc`` under ``self.predictor``.

        Args:
            batch: object with ``xc``, ``yc``, ``xt`` and ``yt`` tensors.
            z: must be ``None``; passing a latent raises ``AssertionError``.
            num_samples: number of latent samples. It is used as a ``view``
                size below, so it must be an int here.
            autoreg: forwarded to ``construct_input`` / ``create_mask``.
            hmc_iters: number of HMC transitions (default 5).
            hmc_step_size: leapfrog step size (default 0.1).
            hmc_leapfrog_steps: leapfrog steps per transition (default 2).

        Returns:
            The target slice ``[num_samples, B, num_tar, y_dim, features]``.

        Raises:
            AssertionError: if ``z`` is given, or if HMC ran and the summed
                per-iteration acceptance rate is not above 0.5.
        """
        inp, y_dim, num_tar, embeddings, out = self._embed_and_encode(
            batch, num_samples, autoreg)

        assert z is None, "Eval Only"
        z = self._sample_context_latent(batch, embeddings, y_dim, num_samples)

        # Context points reused as targets. The batch is copied once so the
        # caller's object is never mutated; the copy is read-only afterwards.
        ctx_batch = copy.deepcopy(batch)
        ctx_batch.xt = batch.xc
        ctx_batch.yt = batch.yc

        def target(z_cand):
            """HMC potential: ``-(log p(yc | z_cand) + log N(0, 1) term)``."""
            c_inp, c_y_dim, c_num_tar, c_emb, c_out = self._embed_and_encode(
                ctx_batch, num_samples, autoreg)
            c_out = self._append_latent(
                c_out, z_cand, c_inp.shape[-2], c_y_dim, c_emb.shape[0],
                num_samples)
            z_target = c_out[:, :, -c_num_tar:, :]

            raw = self.predictor(z_target)
            mean, std = torch.chunk(raw, 2, dim=-1)
            mean = mean.reshape((*mean.shape[:-2], -1))
            std = std.reshape((*std.shape[:-2], -1))
            if self.bound_std:
                std = 0.1 + 0.9 * F.softplus(std)
            else:
                std = torch.exp(std)

            py = Normal(mean, std)
            assert len(ctx_batch.yc.shape) == 3, ctx_batch.yc.shape
            assert len(z_target.shape) == 5, z_target.shape
            yc = stack(ctx_batch.yc, z_target.shape[0], 0)
            logp = py.log_prob(yc).sum((-1, -2))
            # NOTE(review): the N(0, 1) prior is applied to all of z_target,
            # i.e. the encoder features plus z replicated once per context
            # token, not to z alone. Confirm this is intended.
            logprior = Normal(loc=0, scale=1).log_prob(z_target).sum((-1, -2, -3))
            return -logp - logprior

        var = z
        accept = 0.
        for _ in range(hmc_iters):
            var, accept_mask = hmc_step(target, var, hmc_step_size,
                                        hmc_leapfrog_steps)
            accept += accept_mask.float().mean().item()
        if hmc_iters > 0:
            # ``accept`` sums per-iteration mean acceptance rates, so this
            # requires an average rate above 0.5 / hmc_iters.
            assert accept > 0.5, accept

        out = self._append_latent(out, var, inp.shape[-2], y_dim,
                                  embeddings.shape[0], num_samples)
        return out[:, :, -num_tar:, :]

    def encode_v0(self, batch, z=None, num_samples=None, autoreg=False):
        """Encode target tokens using a latent from the context encoder.

        If ``z`` is given it is used as-is. Otherwise it is sampled from
        ``lenc(encoder2(context tokens))``. ``num_samples`` is used as a
        ``view`` size, so it must be an int.

        Returns:
            The target slice ``[num_samples, B, num_tar, y_dim, features]``.
        """
        inp, y_dim, num_tar, embeddings, out = self._embed_and_encode(
            batch, num_samples, autoreg)
        if z is None:
            z = self._sample_context_latent(batch, embeddings, y_dim, num_samples)
        out = self._append_latent(out, z, inp.shape[-2], y_dim,
                                  embeddings.shape[0], num_samples)
        return out[:, :, -num_tar:, :]

    def lencode(self, batch, autoreg=False):
        """Return ``(pz, qz)``: latent distributions from context / all points.

        Token embeddings are averaged over the output-dimension axis. ``lenc``
        is then applied to the first ``batch.xc.shape[1]`` points to give
        ``pz``, and to the first ``batch.x.shape[1]`` points to give ``qz``.
        """
        inp = self.construct_input(batch, autoreg)
        embeddings = self.dim_agg(inp, batch.y.shape[-1])

        num_context = batch.xc.shape[1]
        num_total = batch.x.shape[1]

        # [B, N, y_dim, d] -> [B, N, d]: average over output dimensions.
        embeddings = torch.mean(embeddings, dim=-2)

        pz = self.lenc(embeddings[:, :num_context])
        qz = self.lenc(embeddings[:, :num_total])
        return pz, qz

class DTANP_GIBBS_HS(DTANP_Y_base):
    """DTANP model with a per-token Gaussian head on top of ``DTANP_Y_base``.

    ``predictor`` maps each ``[encoder1 output, z]`` token to a
    (mean, raw scale) pair, giving one Gaussian per (point, output dim).
    """

    def __init__(
        self,
        dim_x,
        dim_y,
        d_model,
        emb_depth,
        dim_feedforward,
        nhead,
        dropout,
        num_layers,
        bound_std=True
    ):
        super().__init__(
            dim_x,
            dim_y,
            d_model,
            emb_depth,
            dim_feedforward,
            nhead,
            dropout,
            num_layers,
            bound_std
        )

        # Input: d_model encoder features + dim_feedforward latent features
        # (presumably lenc's dim_lat width). Output: mean and raw scale.
        self.predictor = nn.Sequential(
            nn.Linear(dim_feedforward + d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, 2)
        )

    def forward(self, batch, num_samples, reduce_ll=True, iters=10, avg=False, **kwargs):
        """Compute the training loss or the evaluation metrics.

        Training draws z from ``qz`` (all points) and scores ``batch.y`` under
        the predictive. With more than one sample the loss is a
        log-mean-exp importance-weighted bound. Otherwise it is
        reconstruction plus a KL term scaled by the number of points.

        Evaluation reports ``mse`` of the predictive mean, and ``ctx_ll`` /
        ``tar_ll``: log-likelihood averaged over context / target points.
        When ``reduce_ll`` is true and samples are drawn, the sample axis is
        first reduced with ``logmeanexp``.

        ``iters``, ``avg`` and extra keyword arguments are accepted and ignored.

        NOTE(review): ``predict`` routes through ``encode``, which currently
        calls ``encode_v1``. ``encode_v1`` asserts ``z is None``, so the
        training branch (which passes ``z``) fails unless ``encode`` uses
        ``encode_v0``.
        """
        outs = AttrDict()
        if self.training:
            pz, qz = self.lencode(batch)

            z = qz.rsample() if num_samples is None else \
                    qz.rsample([num_samples])

            py = self.predict(batch.xc, batch.yc, batch.x, z=z, num_samples=num_samples)

            # ``None`` is checked explicitly so the comparison cannot raise TypeError.
            if num_samples is not None and num_samples > 1:
                recon = py.log_prob(stack(batch.y, num_samples)).sum(-1)
                log_qz = qz.log_prob(z).sum(-1)
                log_pz = pz.log_prob(z).sum(-1)

                log_w = recon.sum(-1) + log_pz - log_qz
                outs.loss = -logmeanexp(log_w).mean() / batch.x.shape[-2]
            else:
                outs.recon = py.log_prob(batch.y).sum(-1).mean()
                outs.kld = kl_divergence(qz, pz).sum(-1).mean()
                outs.loss = -outs.recon + outs.kld / batch.x.shape[-2]
        else:
            py = self.predict(batch.xc, batch.yc, batch.x, num_samples=num_samples)
            if num_samples is None:
                ll = py.log_prob(batch.y).sum(-1)
            else:
                y = torch.stack([batch.y] * num_samples)
                ll = py.log_prob(y).sum(-1)
                if reduce_ll:
                    ll = logmeanexp(ll)
            num_ctx = batch.xc.shape[-2]
            outs.mse = torch.mean((py.mean - batch.y) ** 2)

            # Points are averaged in both reduce_ll modes; reduce_ll only
            # controls the sample-axis logmeanexp above.
            outs.ctx_ll = ll[..., :num_ctx].mean()
            outs.tar_ll = ll[..., num_ctx:].mean()
        return outs

    def predict(self, *args, **kwargs):
        """Return the predictive distribution; dispatches to ``predict_v0``."""
        return self.predict_v0(*args, **kwargs)

    def predict_v0(self, xc, yc, xt, z=None, num_samples=None):
        """Return ``Normal(mean, std)`` over the outputs at ``xt``.

        ``xt`` rows are encoded as targets with zero outputs. ``z`` and
        ``num_samples`` are forwarded to ``encode``.
        """
        batch = AttrDict()
        batch.xc = xc
        batch.yc = yc
        batch.xt = xt
        # Placeholder target outputs, created on the inputs' device instead
        # of a hard-coded CUDA device.
        batch.yt = torch.zeros((xt.shape[0], xt.shape[1], yc.shape[2]), device=xt.device)

        z_target = self.encode(batch, z=z, num_samples=num_samples, autoreg=False)

        out = self.predictor(z_target)
        mean, std = torch.chunk(out, 2, dim=-1)
        # Merge the trailing (y_dim, 1) axes into a single y_dim axis.
        mean = mean.reshape((*mean.shape[:-2], -1))
        std = std.reshape((*std.shape[:-2], -1))
        if self.bound_std:
            std = 0.1 + 0.9 * F.softplus(std)
        else:
            std = torch.exp(std)

        return Normal(mean, std)

    def calculate_crps(self, y_true, means, stds):
        """Gaussian CRPS-style score, averaged over the last remaining axis.

        All inputs are first squeezed on their last axis (a no-op unless it
        has size 1). The standard-normal CDF and PDF terms are averaged over
        axis 0 (the sample axis when called from ``crps``) before being
        combined with the per-sample ``z`` and ``stds``.

        NOTE(review): with a single sample this is the closed-form CRPS of
        ``N(mean, std)``. With several samples it is not the exact CRPS of
        the resulting Gaussian mixture; confirm the approximation is intended.
        """
        y_true = y_true.squeeze(-1)
        means = means.squeeze(-1)
        stds = stds.squeeze(-1)

        z = (y_true - means) / stds
        sqrt_two = torch.sqrt(torch.tensor(2.0, device=z.device))
        cdf_z = 0.5 * (1 + torch.erf(z / sqrt_two).mean(dim=0))
        pdf_z = (torch.exp(-0.5 * z**2) / torch.sqrt(torch.tensor(2 * torch.pi, device=z.device))).mean(dim=0)

        crps = stds * (z * (2 * cdf_z - 1) + 2 * pdf_z - 1 / torch.sqrt(torch.tensor(torch.pi, device=z.device)))
        return crps.mean(dim=-1)

    def crps(self, batch, num_samples=None):
        """Return CRPS and 68% interval coverage for context and target points.

        ``ctx_crps`` / ``tar_crps`` come from ``calculate_crps`` on the
        per-sample predictives. Coverage (``ctx_ci`` / ``tar_ci``) uses one
        Gaussian per point. Its mean is the sample average, and its variance
        is the mean sample variance plus the variance of the sample means.

        NOTE(review): in the ``num_samples is None`` branch ``y`` gains a
        trailing axis, so the ``[..., :num_ctx, :]`` slices below index a
        different axis there. Only the sampled path looks consistent.
        """
        outs = AttrDict()

        if num_samples is None:
            y = batch.y.unsqueeze(-1)
        else:
            y = torch.stack([batch.y] * num_samples)

        py = self.predict(batch.xc, batch.yc, batch.x, num_samples=num_samples)

        num_ctx = batch.xc.shape[-2]

        means = py.loc
        stds = py.scale

        ctx_means, tar_means = means[..., :num_ctx, :], means[..., num_ctx:, :]
        ctx_stds, tar_stds = stds[..., :num_ctx, :], stds[..., num_ctx:, :]
        y_ctx, y_tar = y[..., :num_ctx, :], y[..., num_ctx:, :]

        ctx_crps = self.calculate_crps(y_ctx, ctx_means, ctx_stds)
        tar_crps = self.calculate_crps(y_tar, tar_means, tar_stds)

        # Collapse the sample axis: E[mu] and sqrt(E[sigma^2] + Var[mu]).
        means = means.mean(dim=0)
        stds = torch.sqrt((stds**2).mean(dim=0) + (py.loc**2).mean(dim=0) - (py.loc.mean(dim=0)**2))

        # Half-width multiplier of a central 68% normal interval.
        z_score = Normal(0, 1).icdf(torch.tensor([(1 + 0.68) / 2])).to(means.device)
        ctx_means, tar_means = means[..., :num_ctx, :], means[..., num_ctx:, :]
        ctx_stds, tar_stds = stds[..., :num_ctx, :], stds[..., num_ctx:, :]
        y_ctx, y_tar = y[..., :num_ctx, :], y[..., num_ctx:, :]

        lower_bounds_ctx = ctx_means - z_score * ctx_stds
        upper_bounds_ctx = ctx_means + z_score * ctx_stds
        lower_bounds_tar = tar_means - z_score * tar_stds
        upper_bounds_tar = tar_means + z_score * tar_stds

        outs.ctx_ci = ((y_ctx >= lower_bounds_ctx) & (y_ctx <= upper_bounds_ctx)).float().mean()
        outs.tar_ci = ((y_tar >= lower_bounds_tar) & (y_tar <= upper_bounds_tar)).float().mean()

        outs.ctx_crps = ctx_crps.mean()
        outs.tar_crps = tar_crps.mean()

        return outs