import torch
import torch.nn as nn
import numpy as np
from diffusers import DDPMScheduler
from tqdm import tqdm
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
import sample_sensitivity_mog

import sys
# NOTE: the project root must be importable (e.g. sys.path.append(<project root>)) for the import below to resolve.
import utils.cfdm_ddpm_conversion as convert

def rowwise_corr_all(neural: torch.Tensor, exact: torch.Tensor) -> torch.Tensor:
    """
    Pearson correlation between corresponding rows of two (B, D) tensors.

    Both inputs are mean-centred per row and the correlation is computed as
    sum(a * b) / (sqrt(sum(a**2) * sum(b**2)) + 1e-9). The epsilon makes a
    row with zero variance in either input come out as 0 instead of dividing
    by zero; any NaN that still appears (e.g. from NaN inputs) is mapped to 0.

    Returns:
        Tensor of shape (B,).
    """
    a = neural - neural.mean(dim=1, keepdim=True)
    b = exact - exact.mean(dim=1, keepdim=True)

    cross = (a * b).sum(dim=1)
    norm_product = (a.pow(2).sum(dim=1) * b.pow(2).sum(dim=1)).sqrt()

    # 0 / (0 + 1e-9) == 0 keeps constant rows finite.
    return torch.nan_to_num(cross / (norm_product + 1e-9), nan=0.0)

# Sample from mixture of Gaussians
def sample_mog(mog, n_samples):
    n_components = mog.means.shape[0]
    component_ids = torch.randint(0, n_components, (n_samples,), device=mog.means.device)
    samples = mog.means[component_ids] + mog.sigma_0 * torch.randn(n_samples, mog.means.shape[1], device=mog.means.device)
    return samples

def train_step(x, t, model, optimizer, scheduler):
    """
    Run one optimiser step of the noise-prediction (epsilon) MSE loss.

    ``x`` is noised to timestep ``t`` as
    x_t = sqrt(alpha_bar_t) * x + sqrt(1 - alpha_bar_t) * noise,
    using ``scheduler.alphas_cumprod``. The model's ``.sample`` output is then
    regressed onto ``noise`` with a mean-squared error.

    Args:
        x: clean batch. The per-sample coefficients are broadcast with
            ``unsqueeze(1)``, so x is expected to be 2-D (B, D).
        t: integer timestep(s): a Python int, a 0-dim tensor (shared by the
            batch) or a (B,) tensor. It indexes ``scheduler.alphas_cumprod``
            on the CPU and is passed to the model on ``x.device``.
        model: callable ``model(x_t, t)`` returning an object with ``.sample``.
        optimizer: optimiser over the model's parameters.
        scheduler: object exposing ``alphas_cumprod``.

    Returns:
        The loss as a Python float.
    """
    optimizer.zero_grad()
    batch = x.shape[0]

    # Normalise t to a (B,) CPU tensor, since alphas_cumprod is indexed on the CPU.
    if isinstance(t, torch.Tensor):
        t_idx = t.expand(batch) if t.dim() == 0 else t
        t_idx = t_idx.cpu()
    else:
        t_idx = torch.tensor([t] * batch).cpu()

    abar = scheduler.alphas_cumprod[t_idx].to(x.device)
    std = torch.sqrt(1 - abar)
    eps = torch.randn_like(x)
    x_noisy = abar.sqrt().unsqueeze(1) * x + std.unsqueeze(1) * eps

    pred = model(x_noisy, t_idx.to(x.device)).sample
    loss = torch.mean((pred - eps) ** 2)
    loss.backward()
    optimizer.step()
    return loss.item()

class InputMapping(nn.Module):
    """
    Random Fourier-feature embedding with an optional per-frequency mask.

    A fixed random matrix B (n_freq x d_in) maps an input xi to y = xi @ B.T.
    The output is [sin(y) * mask, cos(y) * mask, passthrough], where
    passthrough is xi itself, or xi without column 0 when ``Tperiod`` is set.

    Column 0 of the input is treated as time: its frequencies are divided by
    ``tdiv``. When ``Tperiod`` is given, they are also snapped to integer
    multiples of 2*pi/Tperiod, so the sinusoidal features are
    Tperiod-periodic in that column.

    Args:
        d_in: number of input columns.
        n_freq: number of random frequencies (rows of B).
        sigma: frequency scale; B ~ N(0, 1) * pi * sigma / sqrt(d_in).
        tdiv: divisor applied to the column-0 frequencies.
        incrementalMask: if True the mask starts all-zero and ``step`` unlocks
            frequencies lowest-norm first; if False it starts all-one.
        Tperiod: optional period for column 0 (see above).
        kill: if True the mask is all-zero regardless of ``incrementalMask``,
            so only the passthrough features are non-zero.
    """
    def __init__(self, d_in, n_freq, sigma=1, tdiv=2, incrementalMask=True, Tperiod=None, kill=False):
        super().__init__()
        Bmat = torch.randn(n_freq, d_in) * np.pi * sigma / np.sqrt(d_in)
        Bmat[:, 0] /= tdiv
        self.Tperiod = Tperiod
        if Tperiod is not None:
            # Snap time frequencies to whole cycles per period.
            Tcycles = (Bmat[:, 0] * Tperiod / (2 * np.pi)).round()
            Bmat[:, 0] = Tcycles * (2 * np.pi) / Tperiod
        # Order rows by increasing L2 norm so the incremental mask unlocks low frequencies first.
        _, order = torch.sort(torch.norm(Bmat, p=2, dim=1))
        Bmat = Bmat[order, :]

        self.d_in = d_in
        self.n_freq = n_freq
        # sin + cos features plus the passthrough (which omits column 0 when Tperiod is set).
        self.d_out = n_freq * 2 + d_in if Tperiod is None else n_freq * 2 + d_in - 1
        # Fixed, non-trainable projection d_in -> n_freq (declared with the shape its weight really has).
        self.B = nn.Linear(d_in, n_freq, bias=False)
        self.B.weight = nn.Parameter(Bmat.to('cpu'), requires_grad=False)

        self.incrementalMask = incrementalMask
        # kill -> all off; incremental -> start all off; otherwise all on.
        initial_mask = torch.zeros(1, n_freq) if (kill or incrementalMask) else torch.ones(1, n_freq)
        self.mask = nn.Parameter(initial_mask, requires_grad=False)

    def step(self, progressPercent):
        """
        Unlock frequencies in proportion to training progress.

        Does nothing unless ``incrementalMask`` is set. The first
        floor(progressPercent * n_freq / 0.7) frequencies are switched on, so
        all of them are on once progressPercent reaches 0.7. Frequencies are
        never switched back off, and negative progress unlocks nothing.
        """
        if not self.incrementalMask:
            return
        n_on = int((progressPercent * self.n_freq) / 0.7 // 1)
        # Clamp to [0, n_freq]: a negative end index would otherwise select all but the last |n_on| entries.
        n_on = max(0, min(n_on, self.n_freq))
        self.mask[0, :n_on] = 1

    def forward(self, xi):
        """Map xi of shape (N, d_in) to features of shape (N, d_out)."""
        y = self.B(xi)
        passthrough = xi if self.Tperiod is None else xi[:, 1:]
        return torch.cat([torch.sin(y) * self.mask, torch.cos(y) * self.mask, passthrough], dim=-1)

class TimeDependentScoreModel(nn.Module):
    """
    MLP on Fourier features of (t, x).

    The timestep t is prepended to x as column 0, embedded with
    ``InputMapping``, and passed through a 3-layer SiLU MLP.

    Args:
        d_in: dimension of x (the embedding sees d_in + 1 columns).
        d_hid: hidden width of the MLP.
        d_out: output dimension.
        n_freq, sigma, tdiv, Tperiod, incrementalMask: forwarded to InputMapping.
        num_timesteps: stored on the instance; not used by ``forward``.

    NOTE: time is placed in column 0 because InputMapping applies ``tdiv`` and
    ``Tperiod`` to column 0. Weights trained with time in the last column are
    not interchangeable with this layout.
    """
    def __init__(self, d_in, d_hid, d_out, n_freq=100, sigma=1, tdiv=2, Tperiod=None, incrementalMask=False, num_timesteps=1000):
        super().__init__()
        self.input_mapping = InputMapping(
            d_in=d_in + 1,
            n_freq=n_freq,
            sigma=sigma,
            tdiv=tdiv,
            incrementalMask=incrementalMask,
            Tperiod=Tperiod,
        )
        self.net = nn.Sequential(
            nn.Linear(self.input_mapping.d_out, d_hid),
            nn.SiLU(),
            nn.Linear(d_hid, d_hid),
            nn.SiLU(),
            nn.Linear(d_hid, d_out),
        )
        self.num_timesteps = num_timesteps

    def step(self, progressPercent):
        """Forward training progress to the input mapping's frequency mask."""
        self.input_mapping.step(progressPercent)

    def forward(self, x, t):
        """
        Evaluate the model.

        Args:
            x: (B, d_in) inputs.
            t: timestep(s) as a Python number, a 0-dim tensor, or a tensor of
                shape (B,), (B, 1), (1,) or (1, 1). Single values are broadcast
                over the batch.

        Returns:
            SimpleNamespace with ``sample`` of shape (B, d_out).
        """
        # Accept ints / integer tensors and match x's dtype and device so the concatenation is well-typed.
        t = torch.as_tensor(t, device=x.device).to(x.dtype)
        if t.dim() == 0:
            t = t.reshape(1)
        if t.dim() == 1:
            t = t.unsqueeze(1)
        if t.shape[0] != x.shape[0]:
            t = t.expand(x.shape[0], -1)
        # InputMapping treats column 0 as time, so t goes first.
        xt = torch.cat([t, x], dim=1)
        h = self.input_mapping(xt)
        y = self.net(h)
        from types import SimpleNamespace
        return SimpleNamespace(sample=y)

def _sensitivity_agreement(sens_neural, sens_exact):
    """Return (median row-wise correlation, median relative L2 error) of two (B, D) sensitivity tensors."""
    neural = sens_neural.detach()
    exact = sens_exact.detach()
    median_corr = torch.median(rowwise_corr_all(neural, exact)).item()
    # Per-row relative error; the epsilon guards rows where the exact sensitivity is zero.
    rel_l2_error = torch.norm(neural - exact, dim=1) / (torch.norm(exact, dim=1) + 1e-9)
    return median_corr, torch.median(rel_l2_error).item()


def run_sample_sensitivity_analysis_correlations_neural(
    means,
    sigma_0,
    n_model_samples,
    ode_step_size,
    min_clamp,
    max_clamp,
    num_hutchinson_samples=100,
    n_train_steps=50000,
    batch_size=100000,
    d_hid=512,
    n_freq=128,
    device='cuda',
    correlation_eval_interval=1000,
    weights_0=None,
    weights_1=None,
):
    """
    Train a neural noise-prediction model on a Gaussian mixture and track how
    closely its sample sensitivities match those of the analytic model.

    Reference sensitivities are computed once from the analytic mixture
    (wrapped by ``convert.EpsFromScore``) along the sample path started at
    ``z1``. Every ``correlation_eval_interval`` training steps, and after the
    final step, the same computation is repeated with the neural model and
    compared row-wise against the reference.

    Args:
        means: component means; D is read from ``means.shape[1]``.
        sigma_0: mixture scale parameter, forwarded to MixtureOfGaussians and
            to ``sensitivity_given_sample_path``.
        n_model_samples: number of initial noise samples z1.
        ode_step_size: forwarded to DDPMSampleSensitivityMoG. The scheduler is
            also built with ``int(1/ode_step_size)+1`` training timesteps.
        min_clamp, max_clamp, num_hutchinson_samples: forwarded to
            DDPMSampleSensitivityMoG.
        n_train_steps: number of optimisation steps.
        batch_size: training samples drawn per step.
        d_hid, n_freq: hidden width and Fourier-feature count of the model.
        device: device for z1, the weights and the model.
        correlation_eval_interval: evaluate every this many steps (positive int).
        weights_0: mixture weights of the training MoG. Defaults to
            [0.5, 0.5], which assumes two components.
        weights_1: weights passed to ``sensitivity_given_sample_path``.
            Defaults to [0.0, 1.0]; their meaning is defined by
            sample_sensitivity_mog.

    Returns:
        Tuple ``(correlations, l2_distances, losses, steps, final_sens_neural,
        final_sens_exact, zt, logpt, zt_gt, logpt_gt, z1)``:
        correlations / l2_distances: median row-wise correlation and median
            relative L2 error at each evaluation;
        losses: single-batch training loss at each evaluation step;
        steps: 1-based training steps at which evaluation ran;
        final_sens_neural / final_sens_exact: sensitivities from the last
            evaluation (None if no evaluation ran);
        zt, logpt: neural sample path from the last evaluation (None if no
            evaluation ran);
        zt_gt, logpt_gt: sample path of the analytic model;
        z1: the initial noise, shape (n_model_samples, D).
    """
    D = means.shape[1]
    scheduler = DDPMScheduler(
        num_train_timesteps=int(1/ode_step_size)+1,
        beta_start=1e-4,
        beta_end=0.02,
        beta_schedule="linear"
    )
    z1 = torch.randn(n_model_samples, D, device=device)
    if weights_0 is None:
        weights_0 = torch.tensor([0.5, 0.5], device=device)
    else:
        weights_0 = torch.as_tensor(weights_0, device=device)
    if weights_1 is None:
        weights_1 = torch.tensor([0.0, 1.0], device=device)
    else:
        weights_1 = torch.as_tensor(weights_1, device=device)

    mog = sample_sensitivity_mog.MixtureOfGaussians(means=means, sigma_0=sigma_0, scheduler=scheduler, weights=weights_0)
    model = TimeDependentScoreModel(
        d_in=D,
        d_hid=d_hid,
        d_out=D,
        n_freq=n_freq,
        sigma=2.0,
        tdiv=2,
        incrementalMask=False,
        num_timesteps=scheduler.config.num_train_timesteps
    ).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    correlations = []
    l2_distances = []
    steps = []
    losses = []
    final_sens_neural = None
    final_sens_exact = None
    # Bound up front so the return works even when no evaluation runs (e.g. n_train_steps == 0).
    zt = None
    logpt = None

    # Reference sensitivities from the analytic model, computed once.
    eps_eta0 = convert.EpsFromScore(mog)
    ddpm_ss_exact = sample_sensitivity_mog.DDPMSampleSensitivityMoG(
        model=eps_eta0,  # Use analytic score
        scheduler=scheduler,
        num_hutchinson_samples=num_hutchinson_samples,
        ode_step_size=ode_step_size,
        min_clamp=min_clamp,
        max_clamp=max_clamp
    )
    zt_gt, logpt_gt = ddpm_ss_exact.precompute_sample_path(z1)
    sample_sensitivity_exact = ddpm_ss_exact.sensitivity_given_sample_path(
        zt_gt, logpt_gt, means, sigma_0, weights_1
    )

    for step in tqdm(range(n_train_steps), desc="Training neural score model"):
        x = sample_mog(mog, batch_size).to(device)
        t = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size,)).cpu()
        loss = train_step(x, t, model, optimizer, scheduler)

        is_last = (step + 1) == n_train_steps
        if (step + 1) % correlation_eval_interval != 0 and not is_last:
            continue

        ddpm_ss_neural = sample_sensitivity_mog.DDPMSampleSensitivityMoG(
            model=model,
            scheduler=scheduler,
            num_hutchinson_samples=num_hutchinson_samples,
            ode_step_size=ode_step_size,
            min_clamp=min_clamp,
            max_clamp=max_clamp
        )
        zt, logpt = ddpm_ss_neural.precompute_sample_path(z1)
        sample_sensitivity_neural = ddpm_ss_neural.sensitivity_given_sample_path(
            zt, logpt, means, sigma_0, weights_1
        )

        median_corr, median_rel_l2_distance = _sensitivity_agreement(
            sample_sensitivity_neural, sample_sensitivity_exact
        )
        print(f"Step {step+1}, Loss: {loss:.4f}, Median correlation: {median_corr:.4f}")
        print(f"Step {step+1}, Median relative L2 distance: {median_rel_l2_distance:.4f}")

        correlations.append(median_corr)
        l2_distances.append(median_rel_l2_distance)
        losses.append(loss)
        steps.append(step + 1)
        final_sens_neural = sample_sensitivity_neural
        final_sens_exact = sample_sensitivity_exact

    return correlations, l2_distances, losses, steps, final_sens_neural, final_sens_exact, zt, logpt, zt_gt, logpt_gt, z1