import time
import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from torch.distributions import kl_divergence
from experiments.utils_metrics import compute_log_normal_pdf, get_mse
from model.components import Z_to_mu_ReLU, Z_to_std_ReLU
import math
import utils
from model.ivp_solvers.gnn import GNN
import torch.nn.functional as F
import matplotlib.pyplot as plt
import os
import matplotlib.ticker as ticker
import numpy as np
from torch.linalg import svd 
import seaborn as sns
import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from torch.distributions import kl_divergence
from experiments.utils_metrics import compute_log_normal_pdf
from model.components import Z_to_mu_ReLU, Z_to_std_ReLU
import time
import utils
import matplotlib.pyplot as plt
import os
import numpy as np

class IVPVAE(nn.Module):
    """VAE whose latent states are evolved by an injected IVP solver.

    The module wires together components supplied by the caller (two
    embedding networks, a reconstruction mapper and an IVP solver) with
    internally created posterior heads (``Z_to_mu_ReLU`` / ``Z_to_std_ReLU``)
    and a ``SelfAttentionAdj(1)`` block whose output is stored as ``self.adj``
    and passed to the solver on every call.

    Only the solver names listed in ``_GRAPH_SOLVERS`` are supported by
    ``forward``; any other ``args.ivp_solver`` raises ``NotImplementedError``.
    """

    # Values of ``args.ivp_solver`` that ``forward`` knows how to drive.
    _GRAPH_SOLVERS = ('gnn', 'invergnn', 'graph')

    def __init__(
            self,
            args,
            embedding_nn,
            embedding_nn_gnn,
            reconst_mapper,
            ivp_solver1):
        """Store the injected sub-modules and build the posterior heads.

        Args:
            args: configuration object. ``__init__`` reads ``latent_dim``,
                ``obsrv_std``, ``prior_mu`` and ``prior_std``; ``forward``
                additionally reads ``ivp_solver``, ``extrap_full``,
                ``combine_methods``, ``kl_coef`` and (optionally) ``L``.
            embedding_nn: called as ``embedding_nn(data_in, mask_in)``.
            embedding_nn_gnn: called with ``data_in`` / ``mask_in`` after a
                trailing singleton axis has been added to each.
            reconst_mapper: maps solver output back to data space.
            ivp_solver1: called as
                ``ivp_solver1(state, state_h, t_start, t_end, adj)``.
        """
        super().__init__()

        self.args = args
        # Overwritten with a wall-clock timestamp at the start of each forward.
        self.time_start = 0
        self.latent_dim = args.latent_dim

        # Buffers (not parameters): they follow .to(device) and are saved in
        # the state dict, but are not trained.
        self.register_buffer('obsrv_std', torch.tensor([args.obsrv_std]))
        self.register_buffer('mu', torch.tensor([args.prior_mu]))
        self.register_buffer('std', torch.tensor([args.prior_std]))

        self.embedding_nn = embedding_nn
        self.embedding_nn_gnn = embedding_nn_gnn
        self.ivp_solver1 = ivp_solver1
        self.reconst_mapper = reconst_mapper
        self.z2mu_mapper = Z_to_mu_ReLU(self.latent_dim)
        self.z2std_mapper = Z_to_std_ReLU(self.latent_dim)
        self.attention2 = SelfAttentionAdj(1)

        self.cur_delta = 0.0

    def forward(self, batch, k_iwae=1):
        """Encode a batch, evolve the latent state and compute the loss.

        Args:
            batch: mapping with keys ``'times_in'``, ``'data_in'`` and
                ``'mask_in'``; when ``args.extrap_full == True`` the keys
                ``'times_out'``, ``'data_out'`` and ``'mask_out'`` are read
                as well, otherwise the inputs double as targets.
            k_iwae: number of copies of the posterior mean/std stacked along
                a new leading axis before sampling.

        Returns:
            ``(results, forward_info)``. ``results`` holds ``'loss'``
            (``loss + 0.1 * max_loss + 0.1 * loss_rtg``, with gradients) and
            the detached scalars ``'likelihood'``, ``'kldiv_z0'``,
            ``'loss_ll_z'``, ``'lat_variance'`` and ``'loss_rtg'``. The keys
            ``'mse'`` and ``'forward_time'`` are created but left ``None``.
            ``forward_info`` holds ``'initial_state'``, ``'sol_z'`` and
            ``'pred_x'``.

        Raises:
            NotImplementedError: if ``args.ivp_solver`` is not one of
                ``_GRAPH_SOLVERS`` or ``args.combine_methods`` is neither
                ``"average"`` nor ``"kl_weighted"``.
        """
        # Fail fast with a clear message. Without this guard an unsupported
        # solver name skipped every solver branch and crashed later with an
        # UnboundLocalError on ``latent``.
        if self.args.ivp_solver not in self._GRAPH_SOLVERS:
            raise NotImplementedError(
                f"Unsupported ivp_solver: {self.args.ivp_solver!r}")

        results = dict.fromkeys(['likelihood', 'mse', 'forward_time', 'loss'])

        times_in = batch['times_in']
        times_initial = torch.zeros_like(times_in)
        data_in = batch['data_in']
        mask_in = batch['mask_in']
        T = times_in.shape[1]

        # Explicit comparison kept on purpose so that non-bool config values
        # select the same branch as before.
        if self.args.extrap_full == True:
            times_out = batch['times_out']
            data_out = batch['data_out']
            mask_out = batch['mask_out']
        else:
            times_out = batch['times_in']
            data_out = batch['data_in']
            mask_out = batch['mask_in']

        utils.check_mask(data_in, mask_in)
        self.time_start = time.time()
        data_embeded = self.embedding_nn(data_in, mask_in)
        data_embeded_h = self.embedding_nn_gnn(data_in.unsqueeze(-1), mask_in.unsqueeze(-1))
        # Attention over a trailing singleton feature axis; the resulting map
        # is reused as the adjacency handed to the solver.
        self.attention2_weights = self.attention2(data_embeded.unsqueeze(-1))
        self.adj = self.attention2_weights
        # Time stamps equal to 0 are treated as "no observation".
        t_exist = times_in.gt(torch.zeros_like(times_in))
        lat_exist = t_exist.unsqueeze(-1).repeat(1, 1, self.latent_dim)

        # Solve from each observation time to time 0.
        # NOTE(review): the bare .squeeze() drops *every* size-1 axis, which
        # presumably also removes the batch axis when batch size is 1 - confirm.
        latent = self.ivp_solver1(
            data_embeded.unsqueeze(-2),
            data_embeded_h.unsqueeze(-3),
            times_in.unsqueeze(-1),
            times_initial.unsqueeze(-1),
            self.adj).squeeze()

        # Masked mean / variance of the per-observation latents (diagnostics).
        lat_mu = torch.sum(latent * lat_exist, dim=-2, keepdim=True) / lat_exist.sum(dim=-2, keepdim=True)
        lat_variance = torch.sum((latent - lat_mu)**2 * lat_exist, dim=-2, keepdim=True) / lat_exist.sum(dim=-2, keepdim=True)

        z0_mean = self.z2mu_mapper(latent)
        z0_std = self.z2std_mapper(latent) + 1e-8
        z0_mean = z0_mean * lat_exist
        z0_mean = torch.nan_to_num(z0_mean)
        z0_std = torch.nan_to_num(z0_std, nan=1e-8)

        # time_start is pushed forward by the time spent on loss terms,
        # presumably so that an external ``now - time_start`` excludes them.
        t_loss_start = time.time()
        fp_distr = Normal(z0_mean, z0_std)
        kldiv_z0_all = kl_divergence(fp_distr, torch.distributions.Normal(self.mu, self.std))
        kldiv_z0 = torch.sum(kldiv_z0_all * lat_exist, (1, 2)) / lat_exist.sum((1, 2))
        t_loss_end = time.time()
        self.time_start += t_loss_end - t_loss_start

        z0_mean_iwae = z0_mean.repeat(k_iwae, 1, 1, 1)
        z0_std_iwae = z0_std.repeat(k_iwae, 1, 1, 1)
        initial_state = utils.sample_standard_gaussian(z0_mean_iwae, z0_std_iwae)

        # Collapse the per-observation samples into one initial state.
        if self.args.combine_methods == "average":
            initial_state = torch.sum(initial_state * lat_exist, dim=-2, keepdim=True) / lat_exist.sum(dim=-2, keepdim=True)
        elif self.args.combine_methods == "kl_weighted":
            kl_r = kldiv_z0_all
            kl_w = kl_r / torch.sum(kl_r * lat_exist, dim=-2, keepdim=True)
            kl_w = (kl_w * lat_exist).repeat(k_iwae, 1, 1, 1)
            initial_state = torch.sum(initial_state * kl_w, dim=-2, keepdim=True)
        else:
            raise NotImplementedError

        initial_state_h = initial_state.view(
            initial_state.shape[0], initial_state.shape[1],
            initial_state.shape[2], self.latent_dim, 1)
        # Pool the adjacency over dim 1 using the KL-based weights.
        kl_r = kldiv_z0_all
        kl_w = kl_r / torch.sum(kl_r * lat_exist, dim=-2, keepdim=True)
        kl_w = (kl_w * lat_exist)
        kl_w = kl_w.unsqueeze(-1)
        self.adj = torch.sum(self.adj * kl_w, dim=1, keepdim=True).squeeze(1)

        # Forward pass: time 0 -> every output time.
        # NOTE(review): times_initial has the shape of times_in while the end
        # times come from times_out; confirm they agree when extrap_full is on.
        sol_z = self.ivp_solver1(initial_state, initial_state_h, times_initial.unsqueeze(0), times_out.unsqueeze(0), self.adj)

        # Backward pass: start at the last forward state and solve back to 0.
        zn = sol_z[:, :, -1:, :]
        zn_h = zn.view(zn.shape[0], zn.shape[1], zn.shape[2], self.latent_dim, 1)
        sol_z_back = self.ivp_solver1(zn, zn_h, times_out.unsqueeze(0), times_initial.unsqueeze(0), self.adj)
        sol_z_back = sol_z_back.flip(2)

        pred_x_back = self.reconst_mapper(sol_z_back)

        # Re-anchor at the midpoint index and solve again from there.
        # NOTE(review): the index is derived from times_in but applied to
        # times_out / sol_z - confirm both have the same length.
        t_star_idx = int(T / 2)
        initial_state_shifted = sol_z[:, :, t_star_idx:t_star_idx+1, :]
        initial_state_h_shifted = initial_state_shifted.view(
            initial_state_shifted.shape[0], initial_state_shifted.shape[1],
            initial_state_shifted.shape[2], self.latent_dim, 1)
        times_abs = times_out[:, t_star_idx:].unsqueeze(0)
        times_abs_initial = times_out[:, t_star_idx:t_star_idx+1].unsqueeze(0)
        times_abs_initial = times_abs_initial.repeat(1, 1, times_abs.shape[-1])
        sol_z_abs = self.ivp_solver1(initial_state_shifted, initial_state_h_shifted, times_abs_initial, times_abs, self.adj)

        # The "past" solve and its reconstruction are not consumed below; the
        # calls are retained so any side effects inside the sub-modules are
        # unchanged.
        times_past = times_out[:, :t_star_idx+1]
        times_start_back_exp = times_out[:, t_star_idx:t_star_idx+1].unsqueeze(0).repeat(1, 1, times_past.shape[1])
        sol_z_past = self.ivp_solver1(initial_state_shifted, initial_state_h_shifted, times_start_back_exp, times_past.unsqueeze(0), self.adj)
        pred_x_past = self.reconst_mapper(sol_z_past)

        data_out = data_out.repeat(k_iwae, 1, 1, 1)
        mask_out = mask_out.repeat(k_iwae, 1, 1, 1)

        pred_x = self.reconst_mapper(sol_z)
        # Likewise not consumed below; kept for parity of module calls.
        pred_x_abs = self.reconst_mapper(sol_z_abs)
        rec_likelihood = compute_log_normal_pdf(data_out, mask_out, pred_x, self.args)

        t_loss_start = time.time()
        ll_z = compute_log_normal_pdf(data_embeded, lat_exist, sol_z, self.args)
        loss_ll_z = -torch.logsumexp(ll_z, 0).mean(dim=0)

        loss = -torch.logsumexp(rec_likelihood - self.args.kl_coef * kldiv_z0, 0)
        loss = torch.mean(loss, dim=0)

        L_param = getattr(self.args, 'L', 1.0)

        # The repulsion term is evaluated on latent trajectories (sol_z /
        # sol_z_abs), not on the reconstructions.
        max_loss = repulsion_loss(
            sol_z, sol_z_abs, model=self, sol_z=sol_z, sol_z_abs=sol_z_abs,
            times_out=times_out, t_star_idx=t_star_idx, L_param=L_param)

        # Masked MSE between the backward-pass reconstruction and the targets.
        loss_rtg = torch.sum((pred_x_back - data_out)**2 * mask_out) / torch.clamp(torch.sum(mask_out), min=1.0)

        results["loss"] = loss + 0.1 * max_loss + 0.1 * loss_rtg

        results['likelihood'] = torch.mean(rec_likelihood).detach()
        results['kldiv_z0'] = torch.mean(kldiv_z0).detach()
        results['loss_ll_z'] = loss_ll_z.detach()
        results["lat_variance"] = torch.mean(lat_variance).detach()

        results['loss_rtg'] = loss_rtg.detach()

        t_loss_end = time.time()
        self.time_start += t_loss_end - t_loss_start
        forward_info = {'initial_state': initial_state, 'sol_z': sol_z, 'pred_x': pred_x}

        return results, forward_info

    def run_validation(self, batch):
        """Run ``forward`` with ``k_iwae`` taken from ``args.k_iwae``."""
        return self.forward(batch, k_iwae=self.args.k_iwae)


class SelfAttentionAdj(nn.Module):
    """Scaled dot-product self-attention that returns only the attention map.

    Queries and keys are separate learned linear projections of the input.
    There is no value projection: the module outputs the softmax-normalised
    score matrix itself.
    """

    def __init__(self, d_model):
        """Build the query and key projections, each ``d_model -> d_model``."""
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return ``softmax(Q K^T / sqrt(d))`` for input ``x``.

        ``d`` is the size of the last axis of the projected queries; the
        softmax is taken over the last axis of the score matrix, so each row
        of the result sums to 1.
        """
        queries = self.q_proj(x)
        keys = self.k_proj(x)
        scale = queries.shape[-1] ** 0.5
        scores = (queries @ keys.transpose(-2, -1)) / scale
        return self.softmax(scores)



# Failures that compute_delta treats as "could not estimate this quantity"
# and replaces with a fallback value instead of propagating.
_DELTA_RECOVERABLE_ERRORS = (
    RuntimeError, ValueError, TypeError, AttributeError, IndexError,
    np.linalg.LinAlgError,
)


def _extreme_singular_values(matrix):
    """Return ``(smallest, largest)`` singular value of *matrix* as floats.

    *matrix* may be a ``torch.Tensor`` or anything ``numpy.linalg.svd``
    accepts.
    """
    if isinstance(matrix, torch.Tensor):
        # Only the values are needed, so detach to skip autograd bookkeeping.
        s_vals = torch.linalg.svdvals(matrix.detach())
        return float(s_vals.min().cpu()), float(s_vals.max().cpu())
    s_vals = np.linalg.svd(matrix, compute_uv=False)
    return float(s_vals.min()), float(s_vals.max())


def compute_delta(sol_z, sol_z_abs, times_out, t_star_idx,
                  W, A_local, L_enc, kappa, m, L_param,
                  eps_fallback=1e-8):
    """Estimate a positive margin from spectral bounds and trajectory drift.

    The returned value is::

        num_steps * Delta_in * (kappa * m * L_enc - (1 + 0.5 * L_g))

    where ``m`` / ``L_g`` are the smallest / largest singular values of ``W``,
    ``kappa`` is the smallest singular value of ``A_local``, ``Delta_in`` is
    the mean Euclidean distance (over the last axis) between ``sol_z`` and
    ``sol_z_abs``, and ``num_steps = times_out.shape[1] - t_star_idx``.
    Whenever that expression is not a strictly positive float (including NaN)
    ``eps_fallback`` is returned instead.

    Args:
        sol_z: reference latent trajectory (tensor, time on axis ``-2``).
        sol_z_abs: re-anchored trajectory. If it covers only the window
            starting at ``t_star_idx`` (i.e. it is ``t_star_idx`` steps
            shorter than ``sol_z`` on axis ``-2``), ``sol_z`` is sliced to
            that window before the difference is taken.
        times_out: tensor whose second axis length gives the horizon.
        t_star_idx: index of the re-anchoring step.
        W: matrix whose singular values bound the decoder.
        A_local: matrix (or batch of matrices) giving ``kappa``.
        L_enc: scalar convertible with ``float``.
        kappa, m: ignored - both are recomputed from ``A_local`` and ``W``.
            Kept only so existing call sites remain valid.
        L_param: ignored; kept for interface compatibility.
        eps_fallback: value substituted for quantities that cannot be
            estimated, and returned when the margin is not positive.

    Returns:
        A Python ``float`` that is either the margin above or
        ``float(eps_fallback)``.
    """
    # Decoder spectral bounds.
    try:
        m_val, L_g_val = _extreme_singular_values(W)
        if m_val <= 0:
            m_val = eps_fallback
        if L_g_val <= 0:
            L_g_val = 1.0
    except _DELTA_RECOVERABLE_ERRORS:
        m_val, L_g_val = eps_fallback, 1.0

    m = float(m_val)
    L_g = float(L_g_val)

    # Smallest singular value of the local adjacency.
    try:
        kappa_val, _ = _extreme_singular_values(A_local)
        if kappa_val < 0:
            kappa_val = 0.0
    except _DELTA_RECOVERABLE_ERRORS:
        kappa_val = 0.0

    kappa = float(kappa_val)
    L_enc = float(L_enc)

    # Mean drift between the two trajectories.
    try:
        reference = sol_z.detach()
        shifted = sol_z_abs.detach()
        # The re-anchored trajectory may span only [t_star_idx:]. Subtracting
        # it from the full-horizon trajectory cannot broadcast, which used to
        # be swallowed and silently pinned Delta_in to eps_fallback. Compare
        # the two on their common window instead.
        if (isinstance(t_star_idx, int) and t_star_idx > 0
                and reference.dim() >= 2
                and reference.dim() == shifted.dim()
                and reference.shape[-2] - t_star_idx == shifted.shape[-2]):
            reference = reference[..., t_star_idx:, :]
        diff = (reference - shifted).reshape(-1, sol_z.shape[-1])
        Delta_in_val = float(torch.mean(torch.norm(diff, dim=-1)).cpu())
        if Delta_in_val <= 0:
            Delta_in_val = eps_fallback
    except _DELTA_RECOVERABLE_ERRORS:
        Delta_in_val = eps_fallback

    Delta_in = float(Delta_in_val)

    try:
        num_steps = float(times_out.shape[1] - t_star_idx)

        eta = kappa * m * L_enc * Delta_in

        phi = 0.5
        self_term_bound = (1.0 + phi * L_g) * Delta_in

        step_margin = eta - self_term_bound

        delta = num_steps * step_margin
    except _DELTA_RECOVERABLE_ERRORS:
        delta = float(eps_fallback)

    # ``delta > 0`` is False for NaN as well, so NaN also maps to the fallback.
    if not (isinstance(delta, float) and delta > 0):
        delta = float(eps_fallback)

    return delta

def repulsion_loss(pred_x, pred_x_abs, model=None, sol_z=None, sol_z_abs=None,
                   times_out=None, t_star_idx=None, L_param=None,
                   fallback_margin=1e-8):
    """Hinge penalty that pushes two time-averaged trajectories apart.

    Both inputs are averaged over axis 2 and the first entry along axis 0 is
    taken; the squared distance between the resulting summaries is compared
    with a margin and ``relu(margin - squared_distance)`` is averaged.
    Gradients flow through both ``pred_x`` and ``pred_x_abs`` (nothing is
    detached here).

    The margin comes from ``compute_delta`` when ``model``, ``sol_z`` and
    ``sol_z_abs`` are all given; otherwise, or if deriving it fails,
    ``fallback_margin`` is used.

    Args:
        pred_x: anchor tensor.
        pred_x_abs: tensor to be repelled from the anchor.
        model: optional object. ``model.reconst_mapper.layers[-1].weight`` is
            used as ``W`` when present (identity otherwise) and
            ``model.adj.mean(dim=0)`` as the local adjacency when present
            (identity otherwise). ``model.args.L`` supplies ``L_param`` when
            that argument is ``None``.
        sol_z, sol_z_abs, times_out, t_star_idx: forwarded to
            ``compute_delta``.
        L_param: forwarded to ``compute_delta``.
        fallback_margin: margin used when no model-derived margin is
            available.

    Returns:
        A 0-dim tensor with the mean hinge loss.
    """
    device = pred_x.device

    # Preferred summary: mean over axis 2, then the first slice of axis 0.
    # Inputs that cannot be indexed that way fall back to one global mean
    # over the first three axes.
    try:
        anchor = pred_x.mean(dim=2)[0]
        negative = pred_x_abs.mean(dim=2)[0]
    except (IndexError, RuntimeError):
        anchor = pred_x.mean(dim=(0, 1, 2))
        negative = pred_x_abs.mean(dim=(0, 1, 2))

    sq_diff = (anchor - negative) ** 2
    if sq_diff.dim() >= 2:
        sq_dist = sq_diff.sum(dim=1)
    else:
        # Too few axes for a per-row reduction: use a single total distance.
        sq_dist = sq_diff.sum().unsqueeze(0)

    margin_value = fallback_margin
    if model is not None and sol_z is not None and sol_z_abs is not None:
        try:
            if hasattr(model, "reconst_mapper") and hasattr(model.reconst_mapper, "layers"):
                W = model.reconst_mapper.layers[-1].weight
            else:
                W = torch.eye(sol_z.shape[-1], device=device)

            if hasattr(model, "adj"):
                A_local = model.adj.mean(dim=0)
            else:
                A_local = torch.eye(W.shape[0], device=device)

            if L_param is None:
                L_param = getattr(model.args, 'L', 1.0)

            # compute_delta recomputes ``m`` and ``kappa`` from W and A_local
            # and ignores the values passed in, so placeholders suffice. The
            # previous local estimates used torch.linalg.svd(...,
            # compute_uv=False), which is not a valid signature: every call
            # raised TypeError and fell back to 1.0.
            margin_value = compute_delta(
                sol_z, sol_z_abs, times_out, t_star_idx,
                W=W, A_local=A_local,
                L_enc=1.0, kappa=1.0, m=1.0,
                L_param=L_param
            )
        except (AttributeError, IndexError, TypeError, ValueError, RuntimeError):
            # Best effort: any failure to derive a margin uses the fallback.
            margin_value = fallback_margin

    margin = torch.tensor(margin_value, device=device)

    # Broadcast the scalar margin to sq_dist's shape and dtype.
    delta = torch.full_like(sq_dist, fill_value=margin)

    return F.relu(delta - sq_dist).mean()
