import copy

from tqdm import tqdm
import torch
import time
import warnings
import matplotlib.pyplot as plt
import numpy as np
import wandb
from types import SimpleNamespace
from crc.baselines.contrastive_crl.src.evaluation import LossCollection, get_R2_values

import math

# Adaptive bandwidth design

def _mean_var_scale(X: torch.Tensor) -> float:
    """Return mean variance across dimensions (trace(Cov)/d) for setting scale."""
    if X.numel() == 0: return 1.0
    return X.var(dim=0, unbiased=False).mean().clamp_min(1e-12).item()

def _normal_ref_h2(n: int, d: int, s2: float) -> float:
    """Multivariate normal reference (Scott/Silverman): h^2 = c_d^2 * s2 * n^{-2/(d+4)}"""
    if n <= 1: return max(s2, 1e-6)
    c = (4.0 / (d + 2.0)) ** (1.0 / (d + 4.0))
    return (c ** 2) * s2 * (n ** (-2.0 / (d + 4.0)))

def _grid_search_h2(objective, h2_base: float, device, r: float = 4.0, K: int = 25):
    """
    Grid search on log scale [h2_base/r, h2_base*r], objective returns tensor scalar.
    """
    h2_base = max(h2_base, 1e-8)
    ts = torch.linspace(-math.log(r), math.log(r), K, device=device)
    cand = h2_base * torch.exp(ts)                   # [K]
    losses = []
    for h2 in cand:
        loss = objective(h2)                         # tensor()
        losses.append(loss)
    losses = torch.stack(losses)                     # [K]
    idx = int(torch.argmin(losses).item())
    return float(cand[idx].item())


# Precomputed closure: given squared distance matrix D2, return h2-dependent log-kde evaluation
def _precomp_log_kde_from_D2(D2: torch.Tensor, d: int, eps: float = 1e-12):
    """
    D2: [B,M], pairwise squared distance matrix U–X
    Returns function f(h2) -> [B], equivalent to per-sample log_kde(U, X, h2)
    """
    B, M = D2.shape
    def f(h2: float) -> torch.Tensor:
        h2_t = torch.as_tensor(h2, device=D2.device, dtype=D2.dtype)
        Q = -0.5 * D2 / (h2_t + eps)                       # [B,M]
        logsum = torch.logsumexp(Q, dim=1) - math.log(M)   # [B]
        log_norm = -0.5 * d * math.log(2 * math.pi) - 0.5 * d * torch.log(h2_t)
        return log_norm + logsum                           # [B]
    return f

def _precomp_log_kde_loo_from_D2(D2: torch.Tensor, d: int, eps: float = 1e-12):
    """
    Build a leave-one-out Gaussian KDE log-density evaluator from fixed distances.

    Args:
        D2: [N, N] squared distances among N points.
        d: dimension used in the Gaussian normalizing constant.
        eps: added to the bandwidth in the exponent and inside the final log.

    Returns:
        A callable ``h2 -> [N]`` computing, for each point i, the log of the mean
        kernel value over all j != i (the same quantity as ``log_kde_loo``).
        With N < 2 there is nothing to leave out, so the self-inclusive KDE is used.
    """
    n_points = D2.shape[0]
    diag_mask = torch.eye(n_points, device=D2.device, dtype=torch.bool)
    # Self-inclusive evaluator, only consulted when n_points < 2
    fallback = _precomp_log_kde_from_D2(D2, d, eps)

    def evaluate(h2: float) -> torch.Tensor:
        if n_points < 2:
            return fallback(h2)
        bw = torch.as_tensor(h2, device=D2.device, dtype=D2.dtype)
        logits = (-0.5 * D2 / (bw + eps)).masked_fill(diag_mask, -float('inf'))  # [N,N]
        row_max = logits.max(dim=1, keepdim=True).values
        mean_kernel = torch.exp(logits - row_max).sum(dim=1) / (n_points - 1)
        log_mean_kernel = row_max.squeeze(1) + torch.log(mean_kernel + eps)     # [N]
        log_norm = -0.5 * d * math.log(2 * math.pi) - 0.5 * d * torch.log(bw)
        return log_norm + log_mean_kernel                                       # [N]

    return evaluate

# Bandwidth estimators. estimate_h2_shared / estimate_g2_particles compute the pairwise distance matrix once and reuse it for every grid candidate.
def estimate_h2_shared(U_batch: torch.Tensor, X: torch.Tensor,
                       ema_prev: float | None = None, beta: float = 0.0,
                       r: float = 4.0, K: int = 25) -> float:
    """
    Estimate one shared squared bandwidth h^2 (for hv = hq = hu = h^2).

    Minimizes -mean_u log KDE(u; centers=X, h^2 I) over a log-grid (see
    _grid_search_h2). The grid is centered on the normal-reference value for the
    size-weighted pooled variance of U_batch and X. The U–X distance matrix is
    built once and shared by all candidates.
    Optional EMA: h^2 <- (1 - beta) * ema_prev + beta * h^2 when beta > 0.

    Returns:
        h^2 as a float, clamped to >= 1e-8.
    """
    n_obs, dim = U_batch.shape
    n_part = X.shape[0]
    w_obs = n_obs / max(n_obs + n_part, 1)
    pooled_var = w_obs * _mean_var_scale(U_batch) + (1 - w_obs) * _mean_var_scale(X)
    start = _normal_ref_h2(max(n_obs, 2), dim, pooled_var)

    log_density = _precomp_log_kde_from_D2(_pairwise_sq_dist(U_batch, X).detach(), dim)

    def neg_mean_log_density(h2: float) -> torch.Tensor:
        return -log_density(h2).mean()

    best = _grid_search_h2(neg_mean_log_density, start, device=U_batch.device, r=r, K=K)

    if ema_prev is not None and beta > 0:
        best = (1 - beta) * ema_prev + beta * best

    return max(float(best), 1e-8)

def estimate_g2_particles(X: torch.Tensor,
                          ema_prev: float | None = None, beta: float = 0.0,
                          r: float = 4.0, K: int = 25) -> float:
    """
    Select the particle bandwidth g^2 by leave-one-out likelihood.

    g^2 = argmin -mean_i log KDE_{-i}(X_i; X, g^2 I), searched on a log-grid
    (see _grid_search_h2) around the normal-reference value for X. The X–X
    distance matrix is computed once and shared by all grid candidates.
    Optional EMA: g^2 <- (1 - beta) * ema_prev + beta * g^2 when beta > 0.

    Returns:
        g^2 as a float, clamped to >= 1e-8.
    """
    n_points, dim = X.shape
    start = _normal_ref_h2(max(n_points, 2), dim, _mean_var_scale(X))

    loo_log_density = _precomp_log_kde_loo_from_D2(_pairwise_sq_dist(X, X).detach(), dim)

    def neg_mean_loo(h2: float) -> torch.Tensor:
        return -loo_log_density(h2).mean()

    g2 = _grid_search_h2(neg_mean_loo, start, device=X.device, r=r, K=K)

    if ema_prev is not None and beta > 0:
        g2 = (1 - beta) * ema_prev + beta * g2

    return max(float(g2), 1e-8)

def estimate_h2_v_loo(U_batch: torch.Tensor, ema_prev: float | None = None,
                      beta: float = 0.0, r: float = 4.0, K: int = 25) -> float:
    """
    Select the observation-side bandwidth h_v^2 by leave-one-out likelihood of U_batch.

    This is exactly the LOO criterion used for the particle bandwidth, applied to the
    observed batch: a normal-reference start from U_batch's own variance, a log-grid
    search (see _grid_search_h2), optional EMA with ``ema_prev``/``beta``, and a
    1e-8 floor. It delegates to ``estimate_g2_particles``. That function computes the
    U–U distance matrix once instead of once per grid candidate, as the previous
    per-candidate ``log_kde_loo`` calls did.

    Returns:
        h_v^2 as a float, >= 1e-8.
    """
    return estimate_g2_particles(U_batch, ema_prev=ema_prev, beta=beta, r=r, K=K)

def estimate_h2_q_direct(U_batch: torch.Tensor, X: torch.Tensor, g2_hat: float,
                         ema_prev: float | None = None, beta: float = 0.0,
                         r: float = 4.0, K: int = 25) -> float:
    """
    Given g^2 = hrho, directly select the q-smoothing bandwidth h_q^2:
      h_q^2 = argmin_{h_q^2>0} -E_u log q̂_ρ(u; h_q^2 + g^2),
    i.e. U_batch is scored under a Gaussian KDE centred on the particles X with
    effective bandwidth h_q^2 + g^2.

    The log-grid search (see _grid_search_h2) starts from the normal-reference value
    for the size-weighted pooled U/X variance minus g^2, clamped to >= 1e-8. The U–X
    distance matrix is computed once and reused for every grid candidate.

    Args:
        U_batch: [M, d] observed batch.
        X: [N, d] particles.
        g2_hat: particle bandwidth g^2 added to every candidate h_q^2.
        ema_prev, beta: optional EMA, h_q^2 <- (1 - beta) * ema_prev + beta * h_q^2.
        r, K: grid half-width factor and number of candidates.

    Returns:
        h_q^2 as a float, >= 1e-8.
    """
    M, d = U_batch.shape
    N = X.shape[0]
    # Reference start: NR "effective" bandwidth, then subtract g^2 to get the h_q^2 start
    s2U = _mean_var_scale(U_batch)
    s2X = _mean_var_scale(X)
    alpha = M / max(M + N, 1)
    s2pool = alpha * s2U + (1 - alpha) * s2X
    h2_eff0 = _normal_ref_h2(max(M, 2), d, s2pool)  # effective-bandwidth start
    # NOTE: when g2_hat >= h2_eff0 the start collapses to 1e-8, so the grid only spans
    # [1e-8 / r, 1e-8 * r] and the result is tiny; any flooring is left to the caller.
    h2_q0 = max(h2_eff0 - g2_hat, 1e-8)

    # Distances do not depend on the bandwidth: compute them once for all K candidates
    D2_UX = _pairwise_sq_dist(U_batch, X).detach()
    logkde = _precomp_log_kde_from_D2(D2_UX, d)

    # Objective: -E_u log q̂_ρ(u; h_q^2 + g^2)
    def _obj(h2_q):
        return -logkde(h2_q + g2_hat).mean()

    h2_q_hat = _grid_search_h2(_obj, h2_q0, device=U_batch.device, r=r, K=K)

    # EMA smoothing (optional)
    if (ema_prev is not None) and (beta > 0):
        h2_q_hat = (1 - beta) * ema_prev + beta * h2_q_hat

    return float(max(h2_q_hat, 1e-8))

# KDE helpers for denoising

def _pairwise_sq_dist(u, v):
    diff = u[:, None, :] - v[None, :, :]
    return (diff ** 2).sum(dim=-1)

def log_kde(u, V, h2, eps=1e-12):
    r"""
    Gaussian KDE log-density log \hat{p}(u) with isotropic bandwidth h^2 I:

        \hat{p}(u_b) = (1/M) * sum_m N(u_b; V_m, h2 * I)

    Args:
        u:   [B, d] query points.
        V:   [M, d] kernel centers.
        h2:  squared bandwidth (Python float or 0-d tensor).
        eps: added to h2 in the exponent only (the normalizing constant uses h2).

    Returns:
        [B] tensor of log-densities.
    """
    B, d = u.shape
    Q = -0.5 * _pairwise_sq_dist(u, V) / (h2 + eps)        # [B,M]
    logsum = torch.logsumexp(Q, dim=1) - math.log(Q.shape[1])
    h2_t = torch.as_tensor(h2, device=u.device, dtype=u.dtype)
    log_norm = -0.5 * d * math.log(2 * math.pi) - 0.5 * d * torch.log(h2_t)
    return log_norm + logsum                                # [B]

def log_kde_loo(V, h2, eps=1e-12):
    """
    Leave-one-out Gaussian KDE log-density of each row of V under the other rows.

    For each i: log( (1/(M-1)) * sum_{j != i} N(V_i; V_j, h2 * I) ), with ``eps``
    added to h2 in the exponent and inside the final log. The eps inside the log keeps
    the result finite even if every other kernel value underflows.

    Args:
        V:   [M, d] points.
        h2:  squared bandwidth (Python float or 0-d tensor).
        eps: numerical floor (see above).

    Returns:
        [M] tensor. For M < 2 there is nothing to leave out, so the self-inclusive
        KDE ``log_kde(V, V, h2, eps)`` is returned (slightly biased but stable).
    """
    M, d = V.shape
    if M < 2:
        # Forward eps so the caller's numerical floor is honoured in the fallback too
        return log_kde(V, V, h2, eps)
    Q = -0.5 * _pairwise_sq_dist(V, V) / (h2 + eps)
    Q.fill_diagonal_(-float('inf'))                       # exclude self-pairs
    m = Q.max(dim=1, keepdim=True).values                 # row max for a stable exp
    logsum = m.squeeze(1) + torch.log(torch.exp(Q - m).sum(dim=1) / (M - 1) + eps)
    h2_t = torch.as_tensor(h2, device=V.device, dtype=V.dtype)
    log_norm = -0.5 * d * math.log(2 * math.pi) - 0.5 * d * torch.log(h2_t)
    return log_norm + logsum

def gaussian_weights(u, Z, h2, eps=1e-12):
    """Row-stochastic Gaussian kernel weights of queries u [B,d] vs centers Z [M,d] -> [B,M].

    Row b is softmax_j(-|u_b - Z_j|^2 / (2 (h2 + eps))), so every row sums to 1.
    """
    logits = -0.5 * _pairwise_sq_dist(u, Z) / (h2 + eps)
    return logits.softmax(dim=1)

# _scaled_nw: always-convolution counterpart of smooth_field's conv-like branch. It rescales the row-normalized (NW) average by kernel mass / M to recover a convolution.
def _scaled_nw(u, Z, fieldZ, h2, eps=1e-12):
    """
    Kernel convolution (1/M) * sum_j exp(-|u_i - z_j|^2 / (2 (h2 + eps))) * fieldZ_j.

    It is computed as a row-normalized (Nadaraya–Watson) average rescaled by the kernel
    mass / M, which is numerically stable. The kernel is unnormalized (no
    (2*pi*h2)^{-d/2} factor).

    Args:
        u: [B, d] query points; Z: [M, d] centers; fieldZ: [M, k] values at Z.
        h2: squared bandwidth; eps: numerical floor.

    Returns:
        [B, k] convolved field.
    """
    Q = -0.5 * _pairwise_sq_dist(u, Z) / (h2 + eps)
    # Subtract the row max before exp to avoid under/overflow
    row_max = Q.max(dim=1, keepdim=True).values
    W = torch.exp(Q - row_max)
    mass = W.sum(dim=1, keepdim=True).clamp_min(eps)
    # Undo the shift: the true kernel mass of row i is exp(row_max_i) * sum_j W_ij.
    # Without this factor the result is inflated by exp(-row_max_i) >= 1 whenever u_i
    # does not coincide with a center, which breaks the convolution equivalence.
    true_mass = mass * torch.exp(row_max)
    return (W / mass) @ fieldZ * (true_mass / Z.size(0))


def smooth_field(u, Z, fieldZ, h2, eps=1e-12, *, cv_max=0.20, keff_min=48.0):
    """
    Outer smoothing of a field given at centers Z, evaluated at query points u.

    It switches between two estimators built on the unnormalized Gaussian kernel
    K(x) = exp(-|x|^2 / (2 (h2 + eps))):
      * Nadaraya–Watson (row-normalized average): the default, more balanced early on.
      * Convolution (1/M) * sum_j K(u_i - z_j) fieldZ_j: used when the kernel mass is
        nearly uniform across queries (CV of mass <= cv_max) and the mean effective
        neighbour count k_eff = (sum_j w_ij)^2 / sum_j w_ij^2 is >= keff_min.

    Args:
        u: [B, d] query points; Z: [M, d] centers; fieldZ: [M, k] field values at Z.
        h2: squared bandwidth; eps: numerical floor.
        cv_max, keff_min: switching thresholds (defaults are the previous constants).

    Returns:
        [B, k] smoothed field.
    """
    Q = -0.5 * _pairwise_sq_dist(u, Z) / (h2 + eps)
    row_max = Q.max(dim=1, keepdim=True).values
    W = torch.exp(Q - row_max)                             # [B,M], row-shifted for stability
    mass = W.sum(dim=1, keepdim=True).clamp_min(eps)       # [B,1] shifted row mass
    true_mass = mass * torch.exp(row_max)                  # [B,1] sum_j K(u_i - z_j)

    # CV of the actual kernel mass (the shifted mass is not comparable across rows)
    cv = (true_mass.std(unbiased=False) / (true_mass.mean() + eps)).item()
    # k_eff is invariant to per-row scaling, so the shifted weights are fine here
    keff = ((mass.squeeze(1) ** 2) / (W.pow(2).sum(dim=1) + eps)).mean().item()

    nw = (W / mass) @ fieldZ
    if cv <= cv_max and keff >= keff_min:
        # Convolution with the shift undone: (W/mass) @ F * (true_mass / M)
        return nw * (true_mass / Z.size(0))
    return nw

def score_kde(u, V, h2, eps=1e-12):
    """
    Score of a Gaussian KDE: grad_u log KDE(u; V, h^2 I) = (E_w[V] - u) / h^2,
    where w are the softmax kernel weights of each query against the centers V
    (``eps`` is added to h^2 in both places).
    u: [B,d], V: [M,d] -> [B,d]
    """
    kernel_mean = gaussian_weights(u, V, h2, eps) @ V
    return (kernel_mean - u) / (h2 + eps)

def estimate_sigma_diag_kernel(X, h2, eps=1e-12):
    """
    Use only kernel weighting to estimate diagonal Σ_i in each particle ξ_i's neighborhood:
        Σ_i = diag(C_i),
    where C_i is the local covariance under a Gaussian kernel K_{h^2} of bandwidth h^2.

    X:  [N, d]  particle positions (same latent basis as updates)
    h2: float   bandwidth (recommended: the scalar outer-smoothing bandwidth)
    Returns: [N, d], diagonal Σ_i vector for each particle (each entry >= eps)
    """
    # Local variance is shift-invariant. Centering on the global mean first avoids
    # catastrophic cancellation in E[x^2] - E[x]^2 when |mean| >> local spread
    # (e.g. float32 particles far from the origin).
    Xc = X - X.mean(dim=0, keepdim=True)

    # Pairwise kernel weights
    diff2 = (Xc[:, None, :] - Xc[None, :, :]).pow(2).sum(-1)    # [N,N]
    Q = -0.5 * diff2 / (h2 + eps)
    # Numerically stable: subtract maximum per row before exp
    Q = Q - Q.max(dim=1, keepdim=True).values
    W = torch.exp(Q)                                            # [N,N]
    mass = W.sum(dim=1, keepdim=True).clamp_min(eps)            # [N,1]

    # Kernel-weighted mean and per-dimension variance: diagonal of C_i
    mu = (W @ Xc) / mass                                        # [N,d]
    EX2 = (W @ (Xc * Xc)) / mass                                # [N,d]
    var = (EX2 - mu * mu).clamp_min(0.0) + eps                  # [N,d]

    return var  # Used as per-dimension scaling factor (v = - grad_psi * var)

@torch.no_grad()
def _scale_floor(x: torch.Tensor, frac: float = 0.05) -> float:
        # Use "mean variance per dimension" as scale, give bandwidth a lower bound: h^2_min = frac * trace(Cov)/d
        if x.numel() == 0: return 1e-8
        return float((x.var(dim=0, unbiased=False).mean() * frac).clamp_min(1e-8).item())

class SimpleParticleDenoiser:
    """Particle-cloud KDE denoiser providing a KL(q_rho || nu) + lam_rho * Ent(rho) penalty.

    One particle cloud rho is kept per label and created lazily from the first batch
    seen for that label. Each ``denoise_loss`` call first moves the particles in a
    no-grad inner loop. It then returns a loss whose gradient flows only through
    ``u_batch``, via the nu-KDE term.

    Bandwidths (squared, i.e. variances of isotropic Gaussian kernels):
        hq   -- extra smoothing on top of rho forming q_rho (effective bandwidth hq + hrho)
        hv   -- KDE bandwidth for the observed batch nu
        hrho -- KDE bandwidth for the particle cloud rho
        hu   -- outer smoothing bandwidth of the particle update; None means "use hq"
    With ``auto_bandwidth=True``, hq, hv and hrho are re-estimated on every inner update
    and EMA-smoothed per label with weight ``bw_ema_beta``; hu is never auto-tuned.
    """

    def __init__(self, n_particles=512,
                 hq=0.5, hv=0.5, hrho=0.5, hu=None,
                 lam_rho=1.0, inner_steps=1, step_size=1e-2,
                 bw_ema_beta=0.2, auto_bandwidth=True):
        self.np = n_particles
        self.hq, self.hv, self.hrho = hq, hv, hrho
        # Outer smoothing bandwidth (previously accepted but ignored); None -> use hq
        self.hu = hu
        self.lam_rho = lam_rho
        self.inner_steps = inner_steps
        self.step_size = step_size
        self.particles = {}
        # Adaptive bandwidth state: per-label EMA histories (label -> float)
        self.auto_bandwidth = auto_bandwidth
        self._h2_hist = {}
        self._g2_hist = {}
        self._hq_hist = {}
        self._hv_hist = {}
        self._bw_beta = bw_ema_beta

    def _outer_h2(self):
        """Return the outer smoothing bandwidth: ``hu`` if set, otherwise the current ``hq``."""
        return self.hq if self.hu is None else self.hu

    def _maybe_init(self, label, u_batch):
        """Create ``self.np`` particles for ``label`` from ``u_batch`` if none exist yet.

        Source rows are drawn without replacement when the batch is large enough and
        with replacement otherwise. Every particle gets independent N(0, 0.01^2)
        jitter, so the cloud never contains exact duplicates (which would make the
        leave-one-out bandwidth objective degenerate).
        """
        if label in self.particles:
            return
        M = u_batch.shape[0]
        if M == 0:
            raise ValueError("cannot initialise particles from an empty batch")
        if M >= self.np:
            idx = torch.randperm(M, device=u_batch.device)[:self.np]
        else:
            idx = torch.randint(0, M, (self.np,), device=u_batch.device)
        P = u_batch[idx]
        self.particles[label] = P + 0.01 * torch.randn_like(P)

    @torch.no_grad()
    def _autotune_bandwidths(self, label, u_batch):
        """Re-estimate hrho, hq, hv for ``label`` (no-op unless ``auto_bandwidth``)."""
        if not self.auto_bandwidth:
            return
        X = self.particles[label]

        # Lower bounds from the data scale so the effective neighbour count is not too small
        floor_u = _scale_floor(u_batch, frac=0.05)   # 5% of mean variance per dimension
        floor_x = _scale_floor(X, frac=0.05)

        g2_hat = estimate_g2_particles(X, ema_prev=self._g2_hist.get(label), beta=self._bw_beta)
        # Floor g^2 BEFORE choosing h_q^2, so h_q^2 is tuned against the hrho actually used
        # (q_rho's effective bandwidth is hq + hrho).
        g2_hat = max(g2_hat, 0.5 * floor_x)

        h2_q_hat = estimate_h2_q_direct(u_batch, X, g2_hat,
                                        ema_prev=self._hq_hist.get(label), beta=self._bw_beta)
        # q's outer convolution/scoring should not be too small
        h2_q_hat = max(h2_q_hat, 0.5 * floor_x)

        h2_v_hat = estimate_h2_v_loo(u_batch, ema_prev=self._hv_hist.get(label), beta=self._bw_beta)
        h2_v_hat = max(h2_v_hat, floor_u)

        self.hrho = g2_hat
        self.hq = h2_q_hat
        self.hv = h2_v_hat

        self._g2_hist[label] = g2_hat
        self._hq_hist[label] = h2_q_hat
        self._hv_hist[label] = h2_v_hat

    @torch.no_grad()
    def _inner_update(self, label, u_batch):
        """Run ``inner_steps`` diagonal-mobility particle updates for ``label`` (no autograd)."""
        self._maybe_init(label, u_batch)
        # Adaptive bandwidth
        self._autotune_bandwidths(label, u_batch)

        X = self.particles[label]
        h2_out = self._outer_h2()
        # The nu score depends only on u_batch and hv, so compute it once
        s_nu_U = score_kde(u_batch, u_batch, self.hv)
        for _ in range(self.inner_steps):
            s_rho_X = score_kde(X, X, self.hrho)
            s_q_X = score_kde(X, X, self.hq + self.hrho)

            tilde_s_q = _scaled_nw(X, X, s_q_X, h2_out)
            tilde_s_nu = _scaled_nw(X, u_batch, s_nu_U, h2_out)

            grad_psi = (tilde_s_q - tilde_s_nu) + self.lam_rho * s_rho_X

            # Mobility (diagonal); for identity mobility use v = -grad_psi
            Sigma_diag = estimate_sigma_diag_kernel(X, h2=h2_out)
            v = -grad_psi * Sigma_diag

            X = X + self.step_size * v

        self.particles[label] = X

    def denoise_loss(self, label, u_batch):
        """Estimate KL(q_ρ || ν0) + λ Ent(ρ), after one no_grad particle inner-loop update."""
        self._inner_update(label, u_batch.detach())
        X = self.particles[label]                                    # [N,d]

        # Effective bandwidth of q_ρ = hq + hrho
        h2_eff = float(self.hq + self.hrho)
        # Sample q_ρ
        Uq = X + torch.sqrt(torch.tensor(self.hq, device=X.device, dtype=X.dtype)) * torch.randn_like(X)
        # KL estimation (only log_nu depends on u_batch, so gradients flow through it)
        log_q = log_kde(Uq, X, h2_eff)
        log_nu = log_kde(Uq, u_batch, self.hv)
        kl = (log_q - log_nu).mean()
        # Ent(ρ) estimation (LOO)
        log_rho_loo = log_kde_loo(X, self.hrho)
        ent = - log_rho_loo.mean()
        return kl + self.lam_rho * ent

# End of denoising code

def get_NOTEARS_loss(A):
    """NOTEARS acyclicity penalty h(A) = tr(exp(A ∘ A)) - d for a square [d, d] matrix A.

    A ∘ A is the elementwise (Hadamard) square, so the penalty is >= 0 and is zero
    exactly when the weighted graph of A has no directed cycles.
    """
    hadamard_sq = A * A
    return torch.matrix_exp(hadamard_sq).trace() - A.size(0)

def train_model(model, device, dl_train, dl_val, training_kwargs, z_gt=None, x_val=None, verbose=False):
    """Train ``model`` and keep a copy of the checkpoint with the best validation loss.

    The mode is selected by ``training_kwargs["type"]``:
      * 'vae', 'vae_vanilla', 'vae_vanilla2', 'vae_contrastive': reconstruction + KL
        loss plus the observational/interventional classifier loss.
      * anything else (contrastive): classifier loss plus a kappa-weighted penalty on
        the mean observational embedding. Optionally the particle-denoising term is
        added, weighted by tau3 with a linear warm-up and gated off while
        tau3 <= tau3_gate_eps.
    Both modes add eta * L1 and mu * NOTEARS penalties on ``model.parametric_part.A``.

    Args:
        model: module exposing ``parametric_part`` (with ``A``) and optionally
            ``embedding``/``encoder``/``decoder`` submodules.
        device: torch device to train on.
        dl_train, dl_val: iterables of batches whose first three items are
            (x_obs, x_int, t_int).
        training_kwargs: hyperparameter dict (see the ``.get`` calls below).
        z_gt, x_val: optional ground-truth latents and matching inputs. When z_gt is
            given, the mean R^2 of ``model.get_z(x_val)`` against z_gt is recorded each epoch.
        verbose: print per-epoch mean training losses.

    Returns:
        (best_model, model, val_loss, [train_loss_history, val_loss_history, r2_history]).
        best_model is a deep copy at the lowest validation loss (CE loss, or method loss
        if ``model.vanilla``), model is the final model, and val_loss is that lowest loss.
    """
    model = model.to(device)
    best_model = copy.deepcopy(model)

    mse = torch.nn.MSELoss()
    huber_loss = torch.nn.HuberLoss(delta=1., reduction='sum')
    ce = torch.nn.CrossEntropyLoss()
    loss_tracker = LossCollection()
    val_loss = np.inf

    epochs = training_kwargs.get("epochs", 10)
    mu = training_kwargs.get('mu', 0.0)
    eta = training_kwargs.get('eta', 0.0)
    kappa = training_kwargs.get('kappa', 0.0)
    lr_nonparametric = training_kwargs.get('lr_nonparametric', .1)
    lr_parametric = training_kwargs.get('lr_parametric', lr_nonparametric)
    contrastive = training_kwargs.get("type") not in ['vae', 'vae_vanilla', 'vae_vanilla2', 'vae_contrastive']
    optimizer_name = training_kwargs.get("optimizer", "sgd").lower()

    # Validate every `val_every` epochs (and always on the last epoch)
    val_every = max(1, int(training_kwargs.get('val_every', 10)))

    # Denoising parameters
    use_denoise = training_kwargs.get('denoise', True)          # Enable/disable denoising term, default on
    tau3_target = training_kwargs.get('tau3', 1.0)
    tau3_warmup_epochs = int(training_kwargs.get('tau3_warmup_epochs', 10))
    tau3_gate_eps = float(training_kwargs.get('tau3_gate_eps', 0.1))  # skip the denoiser while tau3 <= this
    denoise_cfg = dict(
        n_particles = training_kwargs.get('denoise_particles', 512),
        hq          = training_kwargs.get('kde_hq',   0.5),   # Inner: q bandwidth
        hv          = training_kwargs.get('kde_hv',   0.5),   # Inner: ν bandwidth
        hrho        = training_kwargs.get('kde_hrho', 0.5),   # Inner: ρ bandwidth
        hu          = training_kwargs.get('kde_hu',   None),  # Outer convolution bandwidth; None→default use hq
        lam_rho     = training_kwargs.get('denoise_lambda', 0.1),
        inner_steps = training_kwargs.get('denoise_steps', 1),
        step_size   = training_kwargs.get('denoise_step', 1e-2),
    )
    denoiser = SimpleParticleDenoiser(**denoise_cfg) if use_denoise else None

    non_parametric_params = []
    for attr in ("embedding", "encoder", "decoder"):
        if hasattr(model, attr):
            non_parametric_params += list(getattr(model, attr).parameters())

    optim_cls = {'sgd': torch.optim.SGD, 'adam': torch.optim.Adam}.get(optimizer_name)
    if optim_cls is None:
        raise NotImplementedError("Only Adam and SGD supported at the moment")
    optim = optim_cls([
            {'params': model.parametric_part.parameters(), 'lr': lr_parametric},
            {'params': non_parametric_params, 'lr': lr_nonparametric}
        ], weight_decay=training_kwargs.get('weight_decay', 0.0))
    # Only ReduceLROnPlateau is used (stepped on the validation loss). A CosineAnnealingLR
    # used to be constructed here and then immediately overwritten, so it was dead code.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optim, factor=0.3, patience=3)

    # R^2 tracking inputs are converted once instead of every epoch
    z_gt_tensor = x_val_tensor = None
    if z_gt is not None:
        z_gt_tensor = torch.tensor(z_gt, device=device, dtype=torch.float)
        x_val_tensor = torch.tensor(x_val, device=device, dtype=torch.float)

    train_loss_history = []
    val_loss_history = []
    r2_history = []

    for i in tqdm(range(epochs)):
        model.train()

        # Set current τ₃ at the beginning of each epoch (linear warm-up)
        tau3_now = tau3_target * min(1.0, (i + 1) / max(1, tau3_warmup_epochs))
        for data in dl_train:
            x_obs, x_int, t_int = data[0].to(device), data[1].to(device), data[2].to(device)

            # Defined for both branches: the VAE branch never computes a denoising term,
            # but the logging below reads it whenever use_denoise is on.
            denoise_loss = torch.tensor(0.0, device=device)

            if contrastive:
                logits_obs, emb_obs = model(x_obs, t_int, True)
                logits_int, emb_int = model(x_int, t_int, True)

                method_specific_loss = kappa * torch.sum(torch.mean(emb_obs, dim=0) ** 2)

                # Gating: when tau3_now is too small, completely skip denoiser (no loss, no inner loop)
                if use_denoise and (tau3_now > tau3_gate_eps):
                    terms = [denoiser.denoise_loss(label=0, u_batch=emb_obs)]
                    labels_int = torch.argmax(t_int, dim=1).long() if t_int.ndim > 1 else t_int.view(-1).long()
                    for j in torch.unique(labels_int):
                        idx = (labels_int == j)
                        if idx.any():
                            terms.append(denoiser.denoise_loss(label=int(j.item()) + 1, u_batch=emb_int[idx]))
                    denoise_loss = sum(terms) / len(terms)

            else:
                x_int_hat, mean_int, logvar_int, logits_int = model(x_int, t_int, True)
                x_obs_hat, mean_obs, logvar_obs, logits_obs = model(x_obs, t_int, True)

                rec_loss = mse(x_obs, x_obs_hat) / x_obs.size(0)
                # Learn only to reconstruct observational distribution
                kl_divergence = - 0.5 * torch.mean(1 + logvar_obs - mean_obs.pow(2) - logvar_obs.exp())
                if not model.match_observation_dist_only:
                    rec_loss += mse(x_int, x_int_hat) / x_int.size(0)
                    kl_divergence += - 0.5 * torch.mean(1 + logvar_int - mean_int.pow(2) - logvar_int.exp())
                method_specific_loss = rec_loss + kl_divergence

            classifier_loss = ce(logits_obs, torch.zeros(x_obs.size(0), dtype=torch.long, device=device)) + \
                              ce(logits_int, torch.ones(x_int.size(0), dtype=torch.long, device=device))
            accuracy = (torch.sum(torch.argmax(logits_obs, dim=1) == 0) +
                        torch.sum(torch.argmax(logits_int, dim=1) == 1)) / (2 * x_int.size(0))
            reg_loss = eta * torch.sum(
                torch.abs(model.parametric_part.A)) + mu * get_NOTEARS_loss(
                model.parametric_part.A)

            loss = method_specific_loss + classifier_loss + reg_loss
            if contrastive and use_denoise:
                loss = loss + tau3_now * denoise_loss

            log_dict = {
                'loss': loss.item(),
                'method_specific_loss': method_specific_loss.item(),
                'classifier_loss': classifier_loss.item(),
                'reg_loss': reg_loss.item()
            }
            if use_denoise:
                log_dict['denoise_loss'] = denoise_loss.item()
            wandb.log(log_dict)

            optim.zero_grad()
            loss.backward()
            optim.step()
            loss_tracker.add_loss(
                {'method_loss': method_specific_loss.item(), 'CE-loss': classifier_loss.item(),
                 'A-reg loss': reg_loss.item(), 'accuracy': accuracy.item()}, x_obs.size(0))
        if verbose:
            print("Finished epoch {}, printing test and validation loss".format(i + 1))
            loss_tracker.print_mean_loss()
        if getattr(model, 'vanilla', False):
            train_loss_history.append(loss_tracker.get_mean_loss()['method_loss'])
        else:
            train_loss_history.append(loss_tracker.get_mean_loss()['CE-loss'])
        loss_tracker.reset()

        # Validation block: eval + no_grad
        should_validate = (i % val_every == 0) or (i == epochs - 1)
        if should_validate:
            model.eval()
            with torch.no_grad():
                for data in dl_val:
                    x_obs, x_int, t_int = (tensor.to(device) for tensor in data[:3])
                    if contrastive:
                        logits_obs, emb_obs = model(x_obs, t_int, True)
                        logits_int, _       = model(x_int, t_int, True)
                        method_specific_loss = kappa * torch.sum(torch.mean(emb_obs, dim=0) ** 2)
                    else:
                        x_int_hat, mean_int, _, logits_int = model(x_int, t_int, False)
                        x_obs_hat, mean_obs, logvar_obs, logits_obs = model(x_obs, t_int, False)
                        # NOTE(review): validation uses a summed Huber loss and the obs-only KL,
                        # while training uses MSE and (optionally) the int KL as well. Confirm
                        # this asymmetry is intended before comparing train/val method losses.
                        rec_loss = huber_loss(x_obs, x_obs_hat) / x_obs.size(0)
                        if not model.match_observation_dist_only:
                            rec_loss += huber_loss(x_int, x_int_hat) / x_int.size(0)
                        kl = -0.5 * torch.mean(1 + logvar_obs - mean_obs.pow(2) - logvar_obs.exp())
                        method_specific_loss = rec_loss + kl

                    ce_loss_step = ce(logits_obs, torch.zeros(x_obs.size(0), dtype=torch.long, device=device)) + \
                                ce(logits_int, torch.ones(x_int.size(0), dtype=torch.long, device=device))
                    acc_obs = (torch.argmax(logits_obs, dim=1) == 0).float().mean()
                    acc_int = (torch.argmax(logits_int, dim=1) == 1).float().mean()
                    accuracy = 0.5 * (acc_obs + acc_int)

                    loss_tracker.add_loss({'method_loss': method_specific_loss.item(),
                                        'CE-loss': ce_loss_step.item(),
                                        'accuracy': accuracy.item()},
                                        x_obs.size(0))
            ce_loss = loss_tracker.get_mean_loss()['CE-loss']
            if getattr(model, 'vanilla', False):
                ce_loss = loss_tracker.get_mean_loss()['method_loss']
            if ce_loss < val_loss:
                best_model = copy.deepcopy(model)
                val_loss = ce_loss
            val_loss_history.append(ce_loss)
            scheduler.step(ce_loss)
            loss_tracker.reset()
            model.train()

        if z_gt_tensor is not None:
            # Evaluate without autograd and in eval mode: no graph is built, and any
            # train-mode-only behaviour (e.g. dropout or BatchNorm statistic updates,
            # if the model has them) is not applied to x_val.
            model.eval()
            with torch.no_grad():
                z_pred = model.get_z(x_val_tensor)
                r2_history.append(torch.mean(get_R2_values(z_gt_tensor, z_pred)).item())
            model.train()
        else:
            r2_history.append(0)
    return best_model, model, val_loss, [train_loss_history, val_loss_history, r2_history]
