import torch
import torch.nn as nn

import copy
from ..utils.misc import stack, logmeanexp
from torch.distributions import kl_divergence
from .modules import build_mlp
from .modules import TTPoolingEncoder_Dim
from .attention import SelfAttn
import math
import time
import numpy as np

import ipdb 
from torch.distributions.relaxed_categorical import RelaxedOneHotCategorical
import torch.nn.functional as F
from torch.distributions import kl_divergence

from torch.distributions.normal import Normal
from attrdict import AttrDict
from ..utils.misc import stack, logmeanexp, log_w_weighted_sum_exp

class MAB(nn.Module):
    """Multi-head attention block with two residual + LayerNorm stages.

    Queries are projected from ``q``; keys and values are both projected from
    ``v``. All projections map ``dim_out`` features to ``dim_out`` features.

    The computation lives in ``forward`` (rather than an overridden
    ``__call__``) so that ``nn.Module``'s call machinery, including forward
    hooks, runs as usual. ``module(q, v, mask)`` keeps working unchanged.
    """

    def __init__(self, dim_out=128, num_heads=8):
        super().__init__()
        self.dim_out = dim_out
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_out, dim_out)
        self.fc_k = nn.Linear(dim_out, dim_out)
        self.fc_v = nn.Linear(dim_out, dim_out)
        self.fc_out = nn.Linear(dim_out, dim_out)
        self.fc_real_out = nn.Linear(dim_out, dim_out)
        self.ln1 = nn.LayerNorm(dim_out)
        self.ln2 = nn.LayerNorm(dim_out)

    def scatter(self, x):
        """Split the last axis into head-sized chunks and stack them along axis 0."""
        head_dim = x.shape[-1] // self.num_heads
        return torch.cat(torch.split(x, head_dim, dim=-1), dim=0)

    def gather(self, x):
        """Inverse of ``scatter``: split axis 0 per head and concatenate on the last axis."""
        per_head = x.shape[0] // self.num_heads
        return torch.cat(torch.split(x, per_head, dim=0), dim=-1)

    def attend(self, q, k, v, mask=None):
        """Return softmax(q k^T / sqrt(dim_out)) v, computed per head.

        ``mask`` is accepted for interface compatibility but is not applied.
        The logits are scaled by the full ``dim_out``, not the per-head width.
        """
        q_, k_, v_ = self.scatter(q), self.scatter(k), self.scatter(v)
        A_logits = q_ @ k_.swapaxes(-2, -1) / np.sqrt(self.dim_out)
        A = torch.nn.functional.softmax(A_logits, dim=-1)
        return self.gather(A @ v_)

    def forward(self, q, v, mask=None):
        """Attend from ``q`` over ``v`` and return a tensor shaped like ``q``.

        Args:
            q: query-side input whose last axis has ``dim_out`` features.
            v: key/value-side input whose last axis has ``dim_out`` features.
            mask: forwarded to ``attend``, which currently ignores it.
        """
        q, k, v = self.fc_q(q), self.fc_k(v), self.fc_v(v)
        out = self.ln1(q + self.attend(q, k, v, mask))
        out = self.ln2(out + nn.functional.relu(self.fc_out(out)))
        return self.fc_real_out(out)
    
class ISAB(nn.Module):
    """Two chained ``MAB`` blocks that route information context -> generated set.

    The first block uses ``context`` as queries over ``generate_sample``; the
    second uses ``generate_sample`` as queries over that intermediate result,
    so the output has one row per generated element.

    The computation is defined in ``forward`` (not an overridden ``__call__``)
    so that ``nn.Module`` hooks run; ``module(context, generate_sample, ...)``
    is unchanged for callers.
    """

    def __init__(self, dim_out=128, num_heads=8):
        super().__init__()
        self.mab0 = MAB(dim_out=dim_out, num_heads=num_heads)
        self.mab1 = MAB(dim_out=dim_out, num_heads=num_heads)

    def forward(self, context, generate_sample, mask_context=None, mask_generate=None):
        """Return ``mab1(generate_sample, mab0(context, generate_sample))``.

        Args:
            context: query input of the first block.
            generate_sample: key/value input of the first block and query
                input of the second block.
            mask_context: mask handed to the second block (keys derive from
                ``context``).
            mask_generate: mask handed to the first block (keys derive from
                ``generate_sample``).
        """
        h = self.mab0(context, generate_sample, mask_generate)
        return self.mab1(generate_sample, h, mask_context)
    

class SetGenerate(nn.Module):
    """Generate ``generate_num`` new set elements conditioned on ``r``.

    Gaussian noise is drawn as the initial set and refined against ``r`` by an
    ``ISAB`` block.

    Args:
        dim_out: feature width passed to the ``ISAB`` block.
        dim_hidden: accepted for interface compatibility; unused.
        num_heads: number of attention heads passed to the ``ISAB`` block.
    """

    def __init__(self, dim_out=128, dim_hidden=128, num_heads=8):
        super().__init__()
        self.isab = ISAB(dim_out=dim_out, num_heads=num_heads)

    def forward(self, r, generate_num, mask=None):
        """Return the ``ISAB`` output for ``r`` and a fresh noise set.

        The noise comes from PyTorch's default RNG, directly on ``r``'s device
        and in ``r``'s dtype. The previous implementation re-seeded a CPU
        generator with ``int(time.time())`` on every call, which produced
        identical noise for all calls within the same second and ignored
        ``torch.manual_seed``.

        Args:
            r: conditioning set; the last two axes are (elements, features).
            generate_num: number of elements to generate.
            mask: optional mask for ``r``, forwarded to the ``ISAB`` block.
        """
        lead_shape, feat_dim = r.shape[:-2], r.shape[-1]
        generate_initial = torch.randn(
            (*lead_shape, generate_num, feat_dim), device=r.device, dtype=r.dtype
        )
        # Every generated element is marked valid.
        generate_mask = torch.ones(
            (r.shape[0], generate_num), dtype=torch.bool, device=r.device
        )
        return self.isab(r, generate_initial, mask, generate_mask)

class PositionalEncoding(nn.Module):
    """Sinusoidal positional encodings with two interleavings.

    Two tables are registered as buffers:

    * ``pe_1``: sine on even feature indices, cosine on odd ones.
    * ``pe_2``: cosine on even feature indices, sine on odd ones.

    Args:
        d_model: feature width of the encodings. Odd values are supported; the
            odd-index slice then uses one column fewer than the even one.
        max_seq_len: number of positions stored in each table.
    """

    def __init__(self, d_model, max_seq_len):
        super().__init__()
        position = torch.arange(0, max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_seq_len, ceil(d_model / 2)]
        # Odd feature indices number d_model // 2, which is one fewer than the
        # even indices when d_model is odd; trim the angle table to match.
        num_odd = d_model // 2

        pe_1 = torch.zeros(max_seq_len, d_model)
        pe_1[:, 0::2] = torch.sin(angles)
        pe_1[:, 1::2] = torch.cos(angles[:, :num_odd])
        self.register_buffer('pe_1', pe_1)

        pe_2 = torch.zeros(max_seq_len, d_model)
        pe_2[:, 0::2] = torch.cos(angles)
        pe_2[:, 1::2] = torch.sin(angles[:, :num_odd])
        self.register_buffer('pe_2', pe_2)

    def forward(self, x, y_dim):
        """Add encodings along the second-to-last axis of ``x``.

        The first ``x.size(-2) - y_dim`` positions receive rows of ``pe_1``
        (starting at row 0) and the last ``y_dim`` positions receive rows of
        ``pe_2`` (also starting at row 0).
        """
        pe = torch.cat([self.pe_1[:x.size(-2) - y_dim, :], self.pe_2[:y_dim, :]])
        return x + pe

class DimensionAggregator(nn.Module):
    """Embed every scalar of a data point and pair pooled x-features with each y-dimension.

    Each scalar is lifted to ``dim_hid`` features, tagged with a positional
    encoding, and passed through self-attention across the dimensions of a
    single data point. The x-part is then averaged and concatenated onto every
    y-dimension's embedding.
    """

    def __init__(self, dim_hid, dim_out, max_seq_len=101):
        super().__init__()
        self.dim_hid = dim_hid
        self.dim_out = dim_out
        self.positional_encoding = PositionalEncoding(dim_hid, max_seq_len)
        self.linear = nn.Linear(1, dim_hid)
        self.selfattention = SelfAttn(dim_hid, dim_out)

    def forward(self, data_xy, y_dim):
        """Return per-y-dimension embeddings.

        Args:
            data_xy: tensor laid out as [B, num_data, dim_x + dim_y].
            y_dim: number of trailing entries of the last axis that are targets.

        Returns:
            Tensor laid out as [B, num_data, dim_y, 2 * dim_out]: the averaged
            x-embedding followed by the y-dimension's own embedding.
        """
        # [B, num_data, dim_x+dim_y] -> [B, num_data, dim_x+dim_y, dim_hid]
        lifted = self.linear(data_xy.unsqueeze(-1))
        encoded = self.positional_encoding(lifted, y_dim)
        size = encoded.shape

        # Attention runs over the dimensions of one data point, so batch and
        # data axes are folded together: [B * num_data, dim_x+dim_y, dim_out]
        attended = self.selfattention(encoded.reshape(-1, size[-2], size[-1]))
        # Unfold again: [B, num_data, dim_x+dim_y, dim_out]
        attended = attended.reshape(size[0], size[1], size[2], attended.shape[-1])

        # [B, num_data, dim_x, dim_out] and [B, num_data, dim_y, dim_out]
        x_part, y_part = attended.split([size[2] - y_dim, y_dim], dim=-2)

        # One pooled x-summary per data point, repeated for each y-dimension.
        x_summary = x_part.mean(dim=-2, keepdim=True).expand(-1, -1, y_dim, -1)
        return torch.cat([x_summary, y_part], dim=-1)
    
class Gradeint_estimator(nn.Module):
    """Three-layer ReLU MLP mapping ``3 * d_model`` features to ``2 * d_model``."""

    def __init__(self, d_model):
        super().__init__()
        width = 3 * d_model
        layers = [
            nn.Linear(int(width), width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, 2 * d_model),
        ]
        self.predictor = nn.Sequential(*layers)

    def forward(self, psi_data, t):
        """Apply the MLP to ``psi_data``.

        ``t`` is part of the signature but is not used by the computation.
        """
        prediction = self.predictor(psi_data)
        return prediction

class DTANP_Y_base(nn.Module):
    def __init__(
        self,
        dim_x,
        dim_y,
        d_model,
        emb_depth,
        dim_feedforward,
        nhead,
        dropout,
        num_layers,
        bound_std
    ):
        """Build the shared sub-modules of the model.

        ``dim_x``, ``dim_y`` and ``emb_depth`` are accepted but not read here.
        ``d_model`` is halved for the per-dimension aggregator and for the
        latent encoder's x/y widths; ``dim_feedforward`` is used both as the
        transformer feed-forward width and as the latent width.
        """
        super().__init__()
        half_width = int(d_model / 2)

        def new_encoder_layer():
            # A fresh layer per encoder stack, configured identically.
            return nn.TransformerEncoderLayer(
                d_model, nhead, dim_feedforward, dropout, batch_first=True
            )

        self.dim_agg = DimensionAggregator(half_width, half_width)
        self.encoder1 = nn.TransformerEncoder(new_encoder_layer(), num_layers)
        # The second encoder always has two layers.
        self.encoder2 = nn.TransformerEncoder(new_encoder_layer(), 2)

        self.bound_std = bound_std
        self.lenc = TTPoolingEncoder_Dim(
            dim_x=half_width,
            dim_y=half_width,
            dim_hid=d_model,
            dim_lat=dim_feedforward,
            self_attn=True,
            pre_depth=4,
            post_depth=2,
        )
        self.gradient_estimator = Gradeint_estimator(d_model)
        self.auto_regressive = SetGenerate(dim_out=d_model)

    def compute_kl_loss(self, log_w_pre, log_w_last):

        log_w_normalizer = torch.logsumexp(log_w_last, dim=0)
        w_tilde = torch.exp(log_w_last - log_w_normalizer)
        loss = -torch.mean(torch.mean(w_tilde*log_w_pre, dim=0))

        return loss
    
    def generate_pseudo_context(self, split_theta, num_pseudo_context=12):
        generated_pseudo_context = self.auto_regressive(split_theta, num_pseudo_context)
        return generated_pseudo_context
    
    def compute_gradient_log_pi(self, T, t, z, xc, yc, theta, pz, sigma=1., num_samples=50):
        """Return the drift used to move ``z`` at annealing step ``t`` of ``T``.

        The result combines four terms:

        * ``-(1 - t/T)`` times the output of ``self.gradient_estimator`` on
          ``z`` concatenated with the mean of ``theta`` over its second-to-last
          axis;
        * ``t/T`` times the gradient, w.r.t. ``z``, of the summed context
          log-likelihood returned by ``self.predict``;
        * ``-t/T * z / sigma**2`` (pull towards a zero mean);
        * ``-(1 - t/T) * (z - pz.loc) / sigma**2`` (pull towards ``pz.loc``).

        Changes relative to the previous implementation:

        * A hard-coded 128-entry mean tensor was built and immediately replaced
          by zeros; the zero mean is now created directly, sized from ``z`` and
          on ``z``'s device, so the method no longer requires CUDA or a latent
          width of 128.
        * The likelihood gradient is taken with ``torch.autograd.grad`` instead
          of ``loss.backward()``, so it no longer accumulates into the ``.grad``
          of every model parameter as a side effect.

        Args:
            T: total number of annealing steps.
            t: current step.
            z: latent samples; axis 0 is the sample axis.
            xc, yc: context inputs and outputs.
            theta: context representation fed to the gradient estimator.
            pz: distribution whose ``loc`` anchors the last term.
            sigma: standard deviation used by both quadratic terms.
            num_samples: forwarded to ``self.predict``.
        """
        # Differentiate w.r.t. a detached copy so the graph that produced z is untouched.
        z_ = z.clone().detach().requires_grad_(True)
        py = self.predict(xc, yc, xc, z=z_, num_samples=num_samples)
        log_lik = py.log_prob(yc).sum(dim=-2).squeeze(dim=-1).sum()
        (grad_log_lik,) = torch.autograd.grad(log_lik, z_)

        mean = torch.zeros(xc.shape[0], z.shape[-1], device=z.device, dtype=z.dtype)

        grad = self.gradient_estimator(torch.cat([z, stack(theta.mean(dim=-2), z.shape[0])], -1), t/T)
        gradient = -(1-t/T)*grad + grad_log_lik * t/T - (z-mean)/sigma**2*t/T - (1-t/T)*(z-pz.loc)/sigma**2

        return gradient

    def stable_ess(self, log_weights):

        log_w_max = torch.max(log_weights, dim=0, keepdim=True)[0]

        normalized_weights = torch.exp(log_weights - log_w_max)
        sum_weights = torch.sum(normalized_weights, dim=0)
        sum_weights_squared = torch.sum(normalized_weights ** 2, dim=0)
        ess = (sum_weights ** 2) / sum_weights_squared
        return ess

    def compute_log_pi(self, T, t, z, pz, pz_prime, sigma, xc, yc, num_samples):
        py = self.predict(xc, yc, xc, z=z, num_samples=num_samples)
        mean = torch.tensor([ 0.0035,  0.0329, -0.0217,  0.0685,  0.0085,  0.0802,  0.0094, -0.0053,                       
         0.0154,  0.0173,  0.0005, -0.0114,  0.0192, -0.0110, -0.0035, -0.0059,                       
         0.0022,  0.0025, -0.0092, -0.0198, -0.0027, -0.0241,  0.0131, -0.0127,                       
        -0.0102,  0.0126,  0.0166,  0.0050,  0.0295, -0.0017,  0.0237,  0.0085,                       
         0.0142, -0.0185,  0.0042, -0.0075, -0.0658, -0.0104,  0.0168, -0.0154,                       
         0.0285, -0.0804, -0.0086, -0.0448,  0.0079, -0.0033, -0.0253,  0.0258,                       
        -0.0269, -0.0404, -0.0144,  0.0116, -0.0096,  0.0048, -0.0148,  0.0047,                       
        -0.0087, -0.0548, -0.0114, -0.0086, -0.0177,  0.0636,  0.0108,  0.0046,                       
         0.0662, -0.0095, -0.0090, -0.0103,  0.0073, -0.0098, -0.0126,  0.0266,                       
        -0.0137, -0.0051,  0.0372,  0.0004, -0.0083, -0.0031,  0.0235,  0.0250,                       
         0.0023,  0.0105, -0.0083, -0.0135,  0.0251,  0.0396, -0.0095,  0.0310,                       
        -0.0178, -0.0228,  0.0218, -0.0258, -0.0213, -0.0021,  0.0142,  0.0090,                       
         0.0312,  0.0388,  0.0042, -0.0212,  0.0134, -0.0185,  0.0193,  0.0225,                       
         0.0094,  0.0041, -0.0236,  0.0229, -0.0221, -0.0253, -0.0075,  0.0377,                       
         0.0067,  0.0101, -0.0183,  0.0275, -0.0107, -0.0019,  0.0203,  0.0423,                       
        -0.0245,  0.0094,  0.0055, -0.0072, -0.0098, -0.0138, -0.0429,  0.0542]).unsqueeze(0).expand(xc.shape[0],-1).cuda()
        mean = torch.zeros_like(mean).cuda()
        # std = torch.tensor([0.0157, 0.0305, 0.0372, 0.0807, 0.0179, 0.1770, 0.0645, 0.0328, 0.0138,                       
        # 0.0140, 0.0696, 0.0304, 0.0716, 0.0174, 0.0208, 0.0094, 0.0253, 0.0485,                       
        # 0.0087, 0.0183, 0.0328, 0.0059, 0.0088, 0.0172, 0.0090, 0.0953, 0.0582,                       
        # 0.0083, 0.0180, 0.0153, 0.0221, 0.0170, 0.0241, 0.0778, 0.0210, 0.2163,                       
        # 0.0808, 0.0073, 0.0237, 0.0133, 0.0689, 0.0755, 0.0097, 0.1630, 0.0161,                       
        # 0.0206, 0.0731, 0.0617, 0.0163, 0.0311, 0.0170, 0.0152, 0.1277, 0.0075,                       
        # 0.0102, 0.0828, 0.2070, 0.0539, 0.0123, 0.0143, 0.0750, 0.0463, 0.0086,                       
        # 0.0097, 0.1265, 0.0472, 0.0176, 0.0472, 0.0227, 0.0074, 0.0833, 0.0411,                       
        # 0.0436, 0.0832, 0.0638, 0.0104, 0.0193, 0.0094, 0.0150, 0.0443, 0.0090,                       
        # 0.0155, 0.0577, 0.0337, 0.1094, 0.0405, 0.0130, 0.0579, 0.0268, 0.0214,                       
        # 0.0251, 0.0423, 0.0172, 0.0119, 0.0067, 0.0372, 0.0474, 0.0311, 0.0123,                       
        # 0.0388, 0.0366, 0.0459, 0.0161, 0.0097, 0.0083, 0.0235, 0.0159, 0.0595,                       
        # 0.2393, 0.0430, 0.0371, 0.0405, 0.0060, 0.0248, 0.0253, 0.0488, 0.0191,                       
        # 0.0161, 0.0243, 0.0661, 1.1723, 0.0086, 0.0296, 0.0407, 0.0130, 0.1904,                       
        # 0.0613, 0.2425], device='cuda:0').unsqueeze(0).expand(16,-1)
        # prior = Normal(mean,std)
        prior = Normal(mean,sigma*torch.ones([xc.shape[0],128]).cuda())
        if t != 0:
            pz.scale = sigma*torch.ones([xc.shape[0],128]).cuda()
            pz_prime.scale = sigma*torch.ones([xc.shape[0],128]).cuda()
        # pz.scale = sigma*torch.ones([xc.shape[0],128]).cuda()
        if t==0:
            log_pi = t/T*py.log_prob(yc).sum(dim=-2).squeeze(dim=-1) + t/T*prior.log_prob(z).sum(dim=-1)+(T-t)/T*pz.log_prob(z).sum(dim=-1)        # print(py.log_prob(yc).sum(dim=-2).squeeze(dim=-1)[:,0])

        else:
            log_pi = t/T*py.log_prob(yc).sum(dim=-2).squeeze(dim=-1) + t/T*prior.log_prob(z).sum(dim=-1)+(T-t)/T*pz.log_prob(z).sum(dim=-1) +(T-t)/T*pz_prime.log_prob(z).sum(dim=-1)
        # print(prior.log_prob(z).sum(dim=-1)[:,0])
        # print(pz.log_prob(z).sum(dim=-1)[:,0])
        # print(log_pi[:,0])
        return log_pi

    def test_time_scaling(self, xc, yc, pz=None, z=None, w=None, theta=None, T=10, cut_T=None, num_samples=50):
        """Run ``cut_T`` importance-weight updates with learned transitions (no resampling).

        Each step samples ``new_z`` from ``self.forward_transition``, evaluates
        ``self.backward_transition`` at ``new_z``, and updates the log-weights
        with ``log pi_{t+1}(new_z) + log backward(z) - log pi_t(z) - log forward(new_z)``.
        After every step the log-weights are shifted so their maximum along
        axis 0 is zero.

        Returns:
            ``(z, log_w, log_w_list, log_w_diff, update_log_pi_diff,
            previous_log_pi_diff, log_backward_diff, log_forward_diff)`` where
            each ``*_diff`` is max minus min of column 0 of the corresponding
            tensor from the last step.

        NOTE(review): ``compute_log_pi`` as defined in this class takes
        ``(T, t, z, pz, pz_prime, sigma, xc, yc, num_samples)``; the two calls
        below pass seven positional arguments and would raise ``TypeError``
        against that signature. ``forward_transition`` and
        ``backward_transition`` are not defined in the visible part of this
        class. This method looks stale; confirm before relying on it.

        NOTE(review): ``cut_T`` defaults to ``None``, which ``range`` rejects,
        and with ``cut_T == 0`` the loop variables read after the loop are
        never assigned.
        """
        # Log-weights are moved to the CUDA device unconditionally.
        log_w = torch.log(w).cuda()
        log_w_list = []
        for t in range(cut_T):
            forward_normal = self.forward_transition(torch.cat([z,theta], -1))

            new_z = forward_normal.rsample()
            backward_normal = self.backward_transition(torch.cat([new_z,theta], -1))
            update_log_pi   = self.compute_log_pi(T, t+1, new_z, pz, xc, yc, num_samples)
            previous_log_pi = self.compute_log_pi(T, t, z, pz, xc, yc, num_samples)

            log_backward = backward_normal.log_prob(z).sum(dim=-1).cuda()
            log_forward  = forward_normal.log_prob(new_z).sum(dim=-1).cuda()
            updated_new_log_w = log_w + update_log_pi + log_backward - previous_log_pi - log_forward

            z = new_z
            # Shift so the largest log-weight along axis 0 is zero.
            log_w = updated_new_log_w - torch.max(updated_new_log_w, dim=0,keepdim=True)[0]
            log_w_list.append(log_w)
        # Spread (max - min) of column 0 for each quantity at the final step.
        log_w_diff = torch.max(log_w[:,0])-torch.min(log_w[:,0])
        update_log_pi_diff = torch.max(update_log_pi[:,0])-torch.min(update_log_pi[:,0])
        previous_log_pi_diff = torch.max(previous_log_pi[:,0])-torch.min(previous_log_pi[:,0])
        log_backward_diff = torch.max(log_backward[:,0])-torch.min(log_backward[:,0])
        log_forward_diff = torch.max(log_forward[:,0])-torch.min(log_forward[:,0])
        return z, log_w, log_w_list, log_w_diff, update_log_pi_diff, previous_log_pi_diff, log_backward_diff, log_forward_diff

    def test_time_scaling_resampling(self, xc, yc, pz, sigma=1., z=None, w=None, theta=None, T=10, cut_T=None, num_samples=50, threshold_rate=.5):
        """Run ``cut_T`` gradient-guided weight updates with ESS-triggered resampling.

        Before the loop, a pseudo-context set is generated from ``theta``,
        passed through ``self.encoder2`` and ``self.lenc`` to obtain
        ``pz_prime``, and concatenated onto ``theta`` for the gradient calls.

        Each step:

        1. builds a Gaussian proposal centred at ``z + lr * gradient`` with
           standard deviation ``noise_scale * sqrt(2) * sqrt(lr)`` and samples
           ``new_z`` from it;
        2. builds a backward Gaussian centred at
           ``new_z - lr * gradient(new_z)`` with the same standard deviation;
        3. updates the log-weights with
           ``log pi_{t+1}(new_z) + log backward(z) - log pi_t(z) - log forward(new_z)``
           and shifts them so their maximum along axis 0 is zero;
        4. for batch columns whose effective sample size is below
           ``num_samples * threshold_rate``, replaces the particles by a
           straight-through relaxed-categorical resample and resets their
           log-weights to zero.

        Returns:
            ``(z, log_w, z_list, log_w_list, log_w_diff, update_log_pi_diff,
            previous_log_pi_diff, log_backward_diff, log_forward_diff)``.
            ``z_list`` / ``log_w_list`` hold the initial state followed by one
            entry per step; each ``*_diff`` is max minus min of column 0 of the
            corresponding tensor from the last step.

        NOTE(review): the forward proposal adds ``lr * gradient`` while the
        backward one subtracts it; earlier commented-out variants used ``+`` in
        both. Confirm the sign is intended.

        NOTE(review): ``cut_T`` defaults to ``None``, which ``range`` rejects,
        and with ``cut_T == 0`` the loop variables read after the loop are
        never assigned.
        """
        # Log-weights are moved to the CUDA device unconditionally.
        log_w = torch.log(w).cuda()
        log_w_list = []
        z_list = []
        # Resample a batch column once its ESS drops below this value.
        ess_threshold = num_samples*threshold_rate
        learning_rate = torch.tensor(0.01)
        noise_scale = 1.
        # Index 0 of both histories is the state before any update.
        z_list.append(z)
        log_w_list.append(log_w)

        generated_pseudo_context = self.generate_pseudo_context(theta)
        pseudo_context_embeddings = self.encoder2(generated_pseudo_context)
        pz_prime = self.lenc(pseudo_context_embeddings)
        # Real and generated context are joined along the second-to-last axis.
        context_pseudo_context = torch.cat([theta, generated_pseudo_context], dim=-2)

        for t in range(cut_T):
            forward_gradient = self.compute_gradient_log_pi(T, t+1, z, xc, yc, context_pseudo_context, pz, sigma=sigma, num_samples=num_samples)
            forward_normal = Normal(z+learning_rate*forward_gradient, noise_scale*torch.sqrt(torch.tensor(2.))*torch.sqrt(learning_rate)*torch.ones_like(z).cuda())

            new_z = forward_normal.rsample()

            backward_gradient = self.compute_gradient_log_pi(T, t, new_z, xc, yc, context_pseudo_context, pz, sigma=sigma, num_samples=num_samples)
            backward_normal = Normal(new_z-learning_rate*backward_gradient, noise_scale*torch.sqrt(torch.tensor(2.))*torch.sqrt(learning_rate)*torch.ones_like(new_z).cuda())
            update_log_pi   = self.compute_log_pi(T, t+1, new_z, pz, pz_prime, sigma, xc, yc, num_samples)
            previous_log_pi = self.compute_log_pi(T, t, z, pz, pz_prime, sigma, xc, yc, num_samples)

            log_backward = backward_normal.log_prob(z).sum(dim=-1).cuda()
            log_forward  = forward_normal.log_prob(new_z).sum(dim=-1).cuda()
            updated_new_log_w = log_w + update_log_pi + log_backward - previous_log_pi - log_forward


            z = new_z
            # Shift so the largest log-weight along axis 0 is zero.
            log_w = updated_new_log_w - torch.max(updated_new_log_w, dim=0,keepdim=True)[0]

            ess_now = self.stable_ess(log_w)
            # (Debugging idea left by the author: print the entry that attains the max.)

            # Batch columns whose ESS fell below the threshold.
            resample_index = torch.where(ess_now<ess_threshold)[0]
            resample_mask = ess_now<ess_threshold

            if len(resample_index) == 0:
                # No column needs resampling at this step.
                pass
            else:
                w = torch.exp(log_w - torch.max(log_w, dim=0,keepdim=True)[0])
                prob = w / w.sum(dim=0, keepdim=True)
                # Resample with a hard Gumbel-softmax: draw num_samples relaxed
                # one-hot index vectors per batch column so gradients can flow.

                sampled_indices = RelaxedOneHotCategorical(1, prob.T).rsample(sample_shape=(num_samples,)) # [num_samples(replace dimension), batch_size, num_samples(gumbel softmax dimension)]
                hard_sampled_indices = torch.argmax(sampled_indices, dim=-1)
                hard_sampled_indices = torch.nn.functional.one_hot(hard_sampled_indices, num_classes=prob.shape[0]).float() # [num_samples(replace dimension), batch_size, num_samples(gumbel softmax dimension)]
                # Straight-through: the value is the hard one-hot, the gradient
                # is that of the relaxed sample.
                st_sampled_indices = (hard_sampled_indices-sampled_indices).detach()+sampled_indices # [num_samples(replace dimension), batch_size, num_samples(gumbel softmax dimension)]

                resampled_z = z.clone() # [num_samples, batch_size, dim_lat]
                resampled_log_w = log_w.clone() # [num_samples, batch_size]
                # Select particles by multiplying the one-hot indices with z;
                # only the masked batch columns are overwritten below.
                st_z = torch.einsum('ijk,kjd->ijd', st_sampled_indices, z) # [num_samples, batch_size, dim_lat]
                resampled_z[:,resample_mask,:] = st_z[:,resample_mask,:]
                # Resampled columns restart with equal (zero) log-weights.
                resampled_log_w[:,resample_mask] = torch.zeros_like(log_w[:,resample_mask])
                z = resampled_z
                log_w = resampled_log_w

            z_list.append(z)
            log_w_list.append(log_w)



        # Spread (max - min) of column 0 for each quantity at the final step.
        log_w_diff = torch.max(log_w[:,0])-torch.min(log_w[:,0])
        update_log_pi_diff = torch.max(update_log_pi[:,0])-torch.min(update_log_pi[:,0])
        previous_log_pi_diff = torch.max(previous_log_pi[:,0])-torch.min(previous_log_pi[:,0])
        log_backward_diff = torch.max(log_backward[:,0])-torch.min(log_backward[:,0])
        log_forward_diff = torch.max(log_forward[:,0])-torch.min(log_forward[:,0])
        return z, log_w, z_list, log_w_list, log_w_diff, update_log_pi_diff, previous_log_pi_diff, log_backward_diff, log_forward_diff


    def construct_input(self, batch, autoreg=False):
        x_y_ctx = torch.cat((batch.xc, batch.yc), dim=-1)
        x_0_tar = torch.cat((batch.xt, torch.zeros_like(batch.yt)), dim=-1)

        inp = torch.cat((x_y_ctx, x_0_tar), dim=1)

        return inp

    def create_mask(self, batch, y_dim, autoreg=False):
        num_ctx = batch.xc.shape[1]
        num_tar = batch.xt.shape[1]
        num_all = num_ctx + num_tar
        mask = torch.zeros(y_dim * num_all, y_dim * num_all, device='cuda').fill_(float('-inf'))
        mask[:, :y_dim * num_ctx] = 0.0

        return mask, num_tar 
    
    def encode(self, batch, z=None, num_samples=None, autoreg=False):
        """Encode a batch and return target-token features joined with a latent sample.

        The batch is embedded per y-dimension, flattened to a token sequence,
        and passed through ``self.encoder1`` under the context-only mask. When
        ``z`` is not supplied it is sampled from the latent distribution that
        ``self.lenc`` produces from the ``self.encoder2`` encoding of the
        context tokens. The latent is tiled over all tokens, concatenated onto
        the encoder output, and only the rows of the last ``num_tar`` data
        points are returned.
        """
        y_dim = batch.yt.shape[-1]
        inp = self.construct_input(batch, autoreg)
        mask, num_tar = self.create_mask(batch, y_dim, autoreg)

        tokens = self.dim_agg(inp, y_dim)
        batch_size = tokens.shape[0]
        # Merge the (data point, y-dimension) axes into one token axis.
        tokens = tokens.view(batch_size, -1, tokens.shape[-1])
        attended = stack(self.encoder1(tokens, mask), num_samples)

        if z is None:
            num_ctx_tokens = batch.xc.shape[1] * y_dim
            pz = self.lenc(self.encoder2(tokens[:, :num_ctx_tokens]))
            if num_samples is None:
                z = pz.rsample()
            else:
                z = pz.rsample([num_samples])

        # One copy of the latent per data point, then per y-dimension token.
        z = stack(z, inp.shape[-2], -2)
        latent_dim = z.shape[-1]
        z = z.repeat(1, 1, y_dim, 1).view(num_samples, batch_size, -1, latent_dim)

        joined = torch.cat([attended, z], dim=-1)
        joined = joined.view(*joined.shape[:2], -1, y_dim, joined.shape[-1])
        return joined[:, :, -num_tar:, :]

    def tts_encode(self, batch, z=None, num_samples=None, T=10, cut_T=None, resampling=False, threshold_rate=.5):
        """Encode a batch, refine the latent with test-time scaling, and return the trajectory.

        The batch is embedded and passed through ``self.encoder1`` as in
        ``encode``. The latent distribution ``pz`` is computed from the context
        tokens, ``z`` is sampled from it when not supplied, and the particles
        are refined by ``test_time_scaling`` or, when ``resampling`` is set,
        by ``test_time_scaling_resampling`` starting from uniform weights.

        Fixes relative to the previous implementation:

        * ``pz`` and the context embeddings are computed even when the caller
          passes ``z``; both are required by the scaling routines and were
          previously undefined on that path.
        * ``z_list`` is defined on the non-resampling path (as an empty list,
          since that routine returns no particle history), so the method no
          longer fails on an unassigned name; ``encoded_list`` is then empty.
        * The initial weights are created on ``z``'s device.

        Returns:
            ``(z, log_w, z_list, log_w_list, log_w_diff, update_log_pi_diff,
            previous_log_pi_diff, log_backward_diff, log_forward_diff,
            encoded_list)``. ``z`` is the final latent tiled over all tokens;
            ``encoded_list`` holds, for each of the first ``cut_T`` entries of
            ``z_list``, the target-token features joined with that latent.
        """
        y_dim = batch.yt.shape[-1]
        inp = self.construct_input(batch, autoreg=False)
        mask, num_tar = self.create_mask(batch, y_dim, autoreg=False)
        embeddings = self.dim_agg(inp, y_dim)

        embeddings = embeddings.view(embeddings.shape[0], -1, embeddings.shape[-1])
        out = stack(self.encoder1(embeddings, mask), num_samples)

        sigma = 1.

        # Needed by the scaling routines regardless of whether z was supplied.
        context_embeddings_ = embeddings[:, :batch.xc.shape[1]*y_dim]
        pz = self.lenc(self.encoder2(context_embeddings_))

        if z is None:
            z = pz.rsample() if num_samples is None \
                    else pz.rsample([num_samples])

        # Uniform initial weights over the sample axis.
        w = torch.ones([z.shape[0], z.shape[1]], device=z.device)/z.shape[0]
        if resampling == False:
            z, log_w, log_w_list, log_w_diff, update_log_pi_diff, previous_log_pi_diff, log_backward_diff, log_forward_diff = self.test_time_scaling(batch.xc, batch.yc, pz, z, w, context_embeddings_, num_samples=num_samples,T=T,cut_T=cut_T)
            # This routine does not record the particle history.
            z_list = []
        else:
            z, log_w, z_list, log_w_list, log_w_diff, update_log_pi_diff, previous_log_pi_diff, log_backward_diff, log_forward_diff = self.test_time_scaling_resampling(batch.xc, batch.yc, pz, sigma, z, w, context_embeddings_, num_samples=num_samples,T=T,cut_T=cut_T,threshold_rate=threshold_rate)

        z = stack(z, inp.shape[-2], -2)
        z = z.repeat(1, 1, y_dim, 1).view(num_samples, embeddings.shape[0], -1, z.shape[-1])

        encoded_list = []
        for z_step in z_list[:cut_T]:
            z_now = stack(z_step, inp.shape[-2], -2)
            z_now = z_now.repeat(1, 1, y_dim, 1).view(num_samples, embeddings.shape[0], -1, z_now.shape[-1])
            encoded = torch.cat([out, z_now], -1)
            encoded_list.append(encoded.view(*encoded.shape[:2], -1, y_dim, encoded.shape[-1])[:, :, -num_tar:, :])

        return z, log_w, z_list, log_w_list, log_w_diff, update_log_pi_diff, previous_log_pi_diff, log_backward_diff, log_forward_diff, encoded_list

    def lencode(self, batch, autoreg=False):
        
        inp = self.construct_input(batch, autoreg)
        embeddings = self.dim_agg(inp, batch.y.shape[-1])
        
        num_context = batch.xc.shape[1]
        num_total = batch.x.shape[1]
        
        embeddings = torch.mean(embeddings, dim=-2)
        
        # context_embeddings = embeddings[:, :num_context].reshape(-1, num_context * batch.y.shape[-1], embeddings.shape[-1])
        # total_embeddings = embeddings[:, :num_total].reshape(-1, num_total * batch.y.shape[-1], embeddings.shape[-1])
        
        context_embeddings = embeddings[:, :num_context]
        total_embeddings = embeddings[:, :num_total]
        
        pz = self.lenc(context_embeddings)
        qz = self.lenc(total_embeddings)
        return pz, qz

class DTANP_Y_TTS_LG_LV_DIFF_PSEUDO(DTANP_Y_base):
    def __init__(
        self,
        dim_x,
        dim_y,
        d_model,
        emb_depth,
        dim_feedforward,
        nhead,
        dropout,
        num_layers,
        bound_std=True
    ):
        """Initialise the base model and attach the two-output prediction head.

        All arguments are passed positionally, in the same order, to the
        parent constructor. ``dim_feedforward`` and ``d_model`` additionally
        size the ``predictor`` MLP built here.
        """
        super().__init__(
            dim_x, dim_y, d_model, emb_depth, dim_feedforward,
            nhead, dropout, num_layers, bound_std,
        )

        # Head input is (dim_feedforward + d_model) wide; the 2 outputs are
        # split into a mean and a raw std by predict / tts_predict.
        head_layers = [
            nn.Linear(dim_feedforward + d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, 2),
        ]
        self.predictor = nn.Sequential(*head_layers)
    
    def forward(self, batch, sigma=1., num_samples=None, reduce_ll=True, test_time_scaling=False, kl=False, lv=False, marginal=False, T=10, ess_lambda=1., threshold_rate=.5):
        """Compute training losses or evaluation log-likelihoods for ``batch``.

        The path taken depends on ``test_time_scaling`` and ``self.training``:

        * no TTS, training: latent-variable objective (``outs.loss``, plus
          ``outs.recon`` / ``outs.kld`` when a single sample is used).
        * no TTS, eval: ``outs.ctx_ll`` / ``outs.tar_ll`` as tensors.
        * TTS, training: optional marginal / log-variance / KL terms summed
          into ``outs.loss`` together with diagnostics from ``tts_predict``.
        * TTS, eval: ``outs.ctx_ll`` / ``outs.tar_ll`` of the last TTS step
          (Python floats when ``reduce_ll`` is true, tensors otherwise).

        Args:
            batch: object exposing ``x``, ``y``, ``xc`` and ``yc`` tensors.
            sigma: forwarded to ``tts_predict`` (TTS paths only).
            num_samples: number of latent samples; ``None`` means one
                unstacked sample on the non-TTS paths.
            reduce_ll: if true, eval log-likelihoods are reduced over samples
                and averaged; otherwise per-sample/per-point tensors are kept.
            test_time_scaling: boolean switch selecting the TTS paths.
            kl, lv, marginal: boolean switches enabling the respective TTS
                training loss terms.
            T: number of TTS steps passed to ``tts_predict``.
            ess_lambda: weight applied to the log-variance and KL terms.
            threshold_rate: forwarded to ``tts_predict``.

        Returns:
            AttrDict with the fields listed above for the selected path.
        """
        if not test_time_scaling:
            if self.training:
                return self._latent_train_outputs(batch, num_samples)
            return self._latent_eval_outputs(batch, num_samples, reduce_ll)
        if self.training:
            return self._tts_train_outputs(
                batch, sigma, num_samples, T, ess_lambda, threshold_rate, kl, lv, marginal)
        return self._tts_eval_outputs(batch, sigma, num_samples, reduce_ll, T, threshold_rate)

    def _latent_train_outputs(self, batch, num_samples):
        """Training objective without test-time scaling."""
        outs = AttrDict()
        pz, qz = self.lencode(batch)
        z = qz.rsample() if num_samples is None else qz.rsample([num_samples])
        py = self.predict(batch.xc, batch.yc, batch.x, z=z, num_samples=num_samples)

        # `None > 1` raises TypeError in Python 3, so test for None explicitly;
        # a missing sample count uses the single-sample objective.
        if num_samples is not None and num_samples > 1:
            # K * B * N
            recon = py.log_prob(stack(batch.y, num_samples)).sum(-1)
            # K * B
            log_qz = qz.log_prob(z).sum(-1)
            log_pz = pz.log_prob(z).sum(-1)
            # K * B
            log_w = recon.sum(-1) + log_pz - log_qz
            outs.loss = -logmeanexp(log_w).mean() / batch.x.shape[-2]
        else:
            outs.recon = py.log_prob(batch.y).sum(-1).mean()
            outs.kld = kl_divergence(qz, pz).sum(-1).mean()
            outs.loss = -outs.recon + outs.kld / batch.x.shape[-2]
        return outs

    def _latent_eval_outputs(self, batch, num_samples, reduce_ll):
        """Context/target log-likelihoods without test-time scaling."""
        outs = AttrDict()
        py = self.predict(batch.xc, batch.yc, batch.x, num_samples=num_samples)
        if num_samples is None:
            ll = py.log_prob(batch.y).sum(-1)
        else:
            y = torch.stack([batch.y] * num_samples)
            ll = py.log_prob(y).sum(-1)
            if reduce_ll:
                ll = logmeanexp(ll)

        num_ctx = batch.xc.shape[-2]
        ctx_ll, tar_ll = ll[..., :num_ctx], ll[..., num_ctx:]
        if reduce_ll:
            outs.ctx_ll = ctx_ll.mean()
            outs.tar_ll = tar_ll.mean()
        else:
            outs.ctx_ll = ctx_ll
            outs.tar_ll = tar_ll
        return outs

    def _tts_train_outputs(self, batch, sigma, num_samples, T, ess_lambda, threshold_rate, kl, lv, marginal):
        """Training objective with test-time scaling."""
        outs = AttrDict()
        break_t = T
        (py_list, _log_w, _z_list, log_w_stack_list, log_w_diff,
         update_log_pi_diff, previous_log_pi_diff, log_backward_diff,
         log_forward_diff) = self.tts_predict(
            batch.xc, batch.yc, batch.x, sigma=sigma, num_samples=num_samples,
            T=T, cut_T=break_t, resampling=True, threshold_rate=threshold_rate)

        # As before, no loss terms are produced unless more than one sample is
        # used; the explicit None test avoids a TypeError on `None > 1`.
        if num_samples is None or not num_samples > 1:
            return outs

        outs.loss = 0
        outs.recon_loss = 0

        if marginal:
            # K * B * N
            recon = py_list[-1].log_prob(stack(batch.y, num_samples)).sum(-1)
            outs.marginal = -log_w_weighted_sum_exp(recon, log_w_stack_list[-1]).mean()
            outs.loss += outs.marginal

        if lv:
            lv_terms = [
                self.compute_log_variance_loss(log_w_stack_list[i][:, :, 0])
                for i in range(break_t)
            ]
            outs.log_variance = sum(lv_terms) / break_t
            outs.loss += ess_lambda * outs.log_variance

        if kl:
            kl_terms = []
            for i in range(1, break_t + 1):
                # The previous step's weights act as a fixed (detached) reference.
                previous = log_w_stack_list[i - 1].detach()
                kl_terms.append(
                    self.compute_kl_loss(log_w_stack_list[i][:, :, 0], previous[:, :, 0]))
            outs.kl_loss = sum(kl_terms) / break_t
            outs.loss += ess_lambda * outs.kl_loss.mean()

        outs.log_w_diff = log_w_diff
        outs.update_log_pi_diff = update_log_pi_diff
        outs.previous_log_pi_diff = previous_log_pi_diff
        outs.log_backward_diff = log_backward_diff
        outs.log_forward_diff = log_forward_diff
        return outs

    def _tts_eval_outputs(self, batch, sigma, num_samples, reduce_ll, T, threshold_rate):
        """Context/target log-likelihoods of the last test-time-scaling step."""
        outs = AttrDict()
        (py_list, _log_w, _z_list, log_w_stack_list, _log_w_diff,
         _update_log_pi_diff, _previous_log_pi_diff, _log_backward_diff,
         _log_forward_diff) = self.tts_predict(
            batch.xc, batch.yc, batch.x, num_samples=num_samples, T=T,
            resampling=True, threshold_rate=threshold_rate, sigma=sigma)

        # Only the last step is reported, so only that step is evaluated.
        # NOTE(review): training pairs py_list[-1] with log_w_stack_list[-1]
        # (index T), whereas evaluation pairs step T-1 of both lists; the
        # original indexing is kept here -- confirm which is intended.
        final_step = T - 1
        py = py_list[final_step]
        if num_samples is None:
            ll = py.log_prob(batch.y).sum(-1)
        else:
            y = torch.stack([batch.y] * num_samples)
            ll = py.log_prob(y).sum(-1)
            if reduce_ll:
                ll = log_w_weighted_sum_exp(ll, log_w_stack_list[final_step])

        num_ctx = batch.xc.shape[-2]
        ctx_ll, tar_ll = ll[..., :num_ctx], ll[..., num_ctx:]
        if reduce_ll:
            outs.ctx_ll = ctx_ll.mean().item()
            outs.tar_ll = tar_ll.mean().item()
        else:
            # Tensor.item() only works on one-element tensors, so unreduced
            # log-likelihoods are returned as tensors (as on the non-TTS path).
            outs.ctx_ll = ctx_ll
            outs.tar_ll = tar_ll
        return outs

    def predict(self, xc, yc, xt, z=None, num_samples=None):
        """Return the predictive Normal distribution at the locations ``xt``.

        Args:
            xc, yc: context inputs and outputs.
            xt: inputs at which to predict.
            z: optional latent sample forwarded to ``self.encode``.
            num_samples: forwarded to ``self.encode``.

        Returns:
            ``torch.distributions.Normal`` whose mean and std come from the
            two outputs of ``self.predictor``, with the last two axes of each
            flattened into one.
        """
        batch = AttrDict()
        batch.xc = xc
        batch.yc = yc
        batch.xt = xt
        # Zero placeholder for the unknown target outputs. It is created on
        # the same device as xt rather than on a hard-coded 'cuda' device, so
        # CPU runs and non-default GPUs work.
        batch.yt = torch.zeros((xt.shape[0], xt.shape[1], yc.shape[2]), device=xt.device)
        z_target = self.encode(batch, z=z, num_samples=num_samples, autoreg=False)

        out = self.predictor(z_target)
        mean, raw_std = torch.chunk(out, 2, dim=-1)
        mean = mean.reshape((*mean.shape[:-2], -1))
        raw_std = raw_std.reshape((*raw_std.shape[:-2], -1))
        if self.bound_std:
            # Keeps the std strictly above 0.1.
            std = 0.1 + 0.9 * F.softplus(raw_std)
        else:
            std = torch.exp(raw_std)

        return Normal(mean, std)

    def tts_predict(self, xc, yc, xt, sigma=1., z=None, num_samples=None, T=10, cut_T=None, resampling=False, threshold_rate=.5):
        """Run test-time scaling and build one predictive Normal per step.

        Args:
            xc, yc: context inputs and outputs.
            xt: inputs at which to predict.
            sigma: accepted for interface compatibility (see the review note
                in the body).
            z: optional initial latent forwarded to ``self.tts_encode``.
            num_samples: forwarded to ``self.tts_encode``.
            T: number of steps, forwarded to ``self.tts_encode``.
            cut_T: number of steps to decode; defaults to ``T``.
            resampling, threshold_rate: forwarded to ``self.tts_encode``.

        Returns:
            Tuple ``(py_list, log_w, z_list, log_w_stack_list, log_w_diff,
            update_log_pi_diff, previous_log_pi_diff, log_backward_diff,
            log_forward_diff)``. ``py_list`` holds ``cut_T`` Normal
            distributions and ``log_w_stack_list`` holds ``cut_T + 1`` stacked
            log-weights; the remaining items are passed through from
            ``self.tts_encode`` (``log_w`` after stacking).
        """
        if cut_T is None:
            cut_T = T

        batch = AttrDict()
        batch.xc = xc
        batch.yc = yc
        batch.xt = xt
        # Zero placeholder for the unknown target outputs, on xt's device
        # instead of a hard-coded 'cuda' device.
        batch.yt = torch.zeros((xt.shape[0], xt.shape[1], yc.shape[2]), device=xt.device)

        # NOTE(review): `sigma` is not forwarded to tts_encode, so the value
        # passed by callers has no effect here -- confirm whether tts_encode
        # accepts a `sigma` argument and should receive it.
        z, log_w, z_list, log_w_list, log_w_diff, update_log_pi_diff, previous_log_pi_diff, log_backward_diff, log_forward_diff, encoded_list = self.tts_encode(
            batch, z=z, num_samples=num_samples, T=T, cut_T=cut_T,
            threshold_rate=threshold_rate, resampling=resampling)

        py_list = []
        for i in range(cut_T):
            out = self.predictor(encoded_list[i])
            mean, raw_std = torch.chunk(out, 2, dim=-1)
            mean = mean.reshape((*mean.shape[:-2], -1))
            raw_std = raw_std.reshape((*raw_std.shape[:-2], -1))
            if self.bound_std:
                std = 0.1 + 0.9 * F.softplus(raw_std)
            else:
                std = torch.exp(raw_std)
            py_list.append(Normal(mean, std))

        # Tensor.get_device() returns an int, so the previous comparison with
        # the string '-1' was always false and everything was moved to CUDA.
        # Follow the device of the returned latent instead.
        target_device = z.device
        log_w_stack_list = [
            stack(log_w_list[i], xt.shape[-2], -1).to(target_device)
            for i in range(cut_T + 1)
        ]
        log_w = stack(log_w, xt.shape[-2], -1).to(target_device)

        return py_list, log_w, z_list, log_w_stack_list, log_w_diff, update_log_pi_diff, previous_log_pi_diff, log_backward_diff, log_forward_diff

    def calculate_crps(self, y_true, means, stds):
        """Evaluate a Gaussian-style CRPS expression and average the last axis.

        Each input has its last axis squeezed (only removed when it has size
        1). The standardised error ``z = (y_true - means) / stds`` is plugged
        into ``stds * (z * (2 * cdf - 1) + 2 * pdf - 1 / sqrt(pi))``, where
        ``cdf`` and ``pdf`` are the standard-normal CDF/PDF terms averaged
        over axis 0.

        Returns:
            The expression above averaged over the last remaining axis.
        """
        target, loc, scale = (t.squeeze(-1) for t in (y_true, means, stds))
        standardized = (target - loc) / scale

        device = standardized.device
        sqrt_two = torch.sqrt(torch.tensor(2.0, device=device))
        sqrt_two_pi = torch.sqrt(torch.tensor(2 * torch.pi, device=device))
        sqrt_pi = torch.sqrt(torch.tensor(torch.pi, device=device))

        # Both terms are averaged over axis 0 before entering the formula.
        mean_cdf = 0.5 * (1 + torch.erf(standardized / sqrt_two).mean(dim=0))
        mean_pdf = (torch.exp(-0.5 * standardized ** 2) / sqrt_two_pi).mean(dim=0)

        per_point = scale * (standardized * (2 * mean_cdf - 1) + 2 * mean_pdf - 1 / sqrt_pi)
        return per_point.mean(dim=-1)
    
    def crps(self, batch, num_samples=None):
        """Compute CRPS and 68% interval coverage for context and target points.

        Args:
            batch: object exposing ``x``, ``y``, ``xc`` and ``yc`` tensors.
            num_samples: forwarded to ``self.predict``; when given, ``batch.y``
                is stacked that many times along a new leading axis, otherwise
                a trailing singleton axis is appended to it.

        Returns:
            AttrDict with ``ctx_ci`` / ``tar_ci`` (fraction of values inside
            the central 68% interval of the axis-0 moment-matched Gaussian)
            and ``ctx_crps`` / ``tar_crps`` (mean of ``calculate_crps``).
        """
        outs = AttrDict()

        if num_samples is not None:
            y = torch.stack([batch.y] * num_samples)
        else:
            y = batch.y.unsqueeze(-1)

        py = self.predict(batch.xc, batch.yc, batch.x, num_samples=num_samples)
        num_ctx = batch.xc.shape[-2]

        def split_points(tensor):
            # Context entries come first along the second-to-last axis.
            return tensor[..., :num_ctx, :], tensor[..., num_ctx:, :]

        y_ctx, y_tar = split_points(y)
        ctx_means, tar_means = split_points(py.loc)
        ctx_stds, tar_stds = split_points(py.scale)

        ctx_crps = self.calculate_crps(y_ctx, ctx_means, ctx_stds)
        tar_crps = self.calculate_crps(y_tar, tar_means, tar_stds)

        # Moment-match along axis 0: var = E[std^2] + E[mean^2] - E[mean]^2.
        mix_mean = py.loc.mean(dim=0)
        mix_std = torch.sqrt(
            (py.scale ** 2).mean(dim=0) + (py.loc ** 2).mean(dim=0) - (mix_mean ** 2)
        )

        # Standard-normal quantile bounding a central 68% interval.
        z_score = Normal(0, 1).icdf(torch.tensor([(1 + 0.68) / 2])).to(mix_mean.device)

        def coverage(targets, centre, spread):
            lower = centre - z_score * spread
            upper = centre + z_score * spread
            return ((targets >= lower) & (targets <= upper)).float().mean()

        ctx_mix_mean, tar_mix_mean = split_points(mix_mean)
        ctx_mix_std, tar_mix_std = split_points(mix_std)

        outs.ctx_ci = coverage(y_ctx, ctx_mix_mean, ctx_mix_std)
        outs.tar_ci = coverage(y_tar, tar_mix_mean, tar_mix_std)

        outs.ctx_crps = ctx_crps.mean()
        outs.tar_crps = tar_crps.mean()

        return outs