import torch
import math
from tqdm import tqdm
from utils import marginal_prob_std, diff_coeff


class FKSampler:
    """
    FK-steered reverse SDE sampler for a score-based model.

    A population of particles is integrated from t=1 down to t=1e-5 with
    Euler-Maruyama steps. Every ``resample_every`` steps (and on the final
    step) the population is multinomially resampled with weights
    ``exp(lmbda * (r - r_prev))``, where ``r`` is a particle's current reward
    and ``r_prev`` is the reward recorded for it at the previous resampling
    (zero before the first one).

    Args:
        model: Callable invoked as ``model(x, t)``; its output is scaled by
            ``g(t)**2`` to form the drift (presumably a score network).
            Moved to ``device`` and put in eval mode.
        reward_fn: Callable mapping the particle tensor to one reward per
            particle (shape ``[n_particles]``).
        sigma: Parameter forwarded to ``marginal_prob_std`` and ``diff_coeff``.
        n_particles: Population size.
        n_time_steps: Number of points on the time grid (must be >= 2, since
            the step size is the spacing between the first two grid points).
        lmbda: Scale applied to reward differences inside the weights.
        resample_every: Resampling period, in integration steps.
        device: Torch device used for every tensor that is created.
        data_dim: Dimensionality of a single particle.
    """

    def __init__(
        self,
        model,
        reward_fn,
        sigma,
        n_particles=512,
        n_time_steps=400,
        lmbda=5.0,
        resample_every=10,
        device="cpu",
        data_dim=2,
    ):
        self.model = model.to(device)
        self.model.eval()
        self.reward_fn = reward_fn
        self.n_particles = n_particles
        self.n_time_steps = n_time_steps
        self.lmbda = lmbda
        self.resample_every = resample_every
        self.device = device
        self.sigma = sigma
        self.data_dim = data_dim

    @staticmethod
    def _resampling_probs(log_w):
        """Turn 1-D unnormalised log-weights into resampling probabilities.

        Normalising in log space (softmax) avoids the overflow/underflow of
        ``exp`` that would otherwise saturate large weights or zero out every
        weight. Infinite entries are clamped to the finite float range so that
        an all ``-inf`` input yields uniform probabilities; NaNs are left in
        place so that invalid rewards still surface as an error downstream.
        """
        limits = torch.finfo(log_w.dtype)
        log_w = log_w.clamp(min=limits.min, max=limits.max)
        return torch.softmax(log_w, dim=0)

    @torch.no_grad()
    def sample(self, seed=None):
        """Run the steered sampler once.

        Args:
            seed: Optional seed passed to ``torch.manual_seed`` (this sets the
                global torch RNG).

        Returns:
            A tuple ``(None, x)`` where ``x`` holds the final particles with
            shape ``[n_particles, data_dim]``. The first entry is always
            ``None``.
        """
        if seed is not None:
            torch.manual_seed(seed)

        # Initialize particles from marginal at t=T
        t = torch.ones(self.n_particles, device=self.device)
        x = (
            torch.randn(self.n_particles, self.data_dim, device=self.device)
            * marginal_prob_std(t, self.sigma)[:, None]
        )

        t_range = torch.linspace(1, 1e-5, self.n_time_steps, device=self.device)
        dt = t_range[0] - t_range[1]
        last_step = len(t_range) - 1

        # Reward recorded for each particle at the most recent resampling.
        population_reward = torch.zeros(self.n_particles, device=self.device)

        for i, t_step in tqdm(enumerate(t_range), total=len(t_range)):
            t_feed = torch.full((x.size(0),), t_step, device=self.device)

            # ---- Reverse SDE step (Euler-Maruyama) ----
            g = diff_coeff(t_feed, self.sigma)
            drift = (g**2)[:, None] * self.model(x, t_feed)
            mean = x + drift * dt
            x = mean + torch.sqrt(dt) * g[:, None] * torch.randn_like(x)

            # ---- FK resampling ----
            if (i % self.resample_every == 0) or (i == last_step):
                r = self.reward_fn(x)  # shape [n_particles]
                # Weights are handled in log space; see _resampling_probs.
                log_w = self.lmbda * (r - population_reward)
                probs = self._resampling_probs(log_w)
                idx = torch.multinomial(
                    probs, num_samples=self.n_particles, replacement=True
                )

                # Resample particles and carry their rewards along as the
                # baseline for the next weight computation.
                x = x[idx]
                population_reward = r[idx]

        return None, x


class SourceTemperingSampler:
    """
    Parallel-tempering MCMC in the latent (source) space of a pretrained model.

    Non-adaptive pCN proposals + parallel tempering. Each of the ``n_chains``
    chains has its own inverse temperature (``betas``, linearly spaced from 0
    to ``beta``) and pCN angle (``thetas``, linearly spaced from pi/2 down to
    0.05). The quantity being tempered is the reward of the latent after it
    has been pushed to data space by ``transport``; larger rewards are
    favoured.

    Args:
        model: Callable invoked as ``model(x, t)``; moved to ``device`` and
            put in eval mode.
        reward_fn: Callable mapping a ``[N, data_dim]`` tensor to ``N``
            rewards.
        sigma: Std of the initial (latent) Gaussian distribution.
        beta: Inverse temperature of the last chain.
        n_chains: Number of tempered chains.
        n_time_steps: Number of integration steps used by ``transport``
            (must be >= 2, since the step size is the spacing between the
            first two time points).
        device: Torch device used for every tensor that is created.
        data_dim: Dimensionality of a single sample.
        sde_sigma: Parameter passed to ``diff_coeff`` during transport. The
            default of 4.0 reproduces the previously hard-coded value.
    """

    def __init__(
        self,
        model,
        reward_fn,
        sigma,
        beta=5.0,
        n_chains=10,
        n_time_steps=20,
        device="cpu",
        data_dim=2,
        sde_sigma=4.0,
    ):
        self.sigma = sigma  # The std of our initial distribution.
        self.sde_sigma = sde_sigma  # Forwarded to diff_coeff in transport().
        self.model = model.to(device)
        self.model.eval()

        self.reward_fn = reward_fn
        self.device = device
        self.n_chains = n_chains

        self.betas = torch.linspace(0.0, beta, n_chains, device=device)
        self.thetas = torch.linspace(math.pi / 2, 0.05, n_chains, device=device)
        # Only the length of this grid is used (it fixes the step count).
        self.grid = torch.linspace(0.0, 1.0, n_time_steps + 1, device=device)
        self.data_dim = data_dim

    @torch.no_grad()
    def transport(self, z):
        """
        Transport latent z to data space x with deterministic Euler steps.

        Integrates from t=1 down to t=1e-5 using the update
        ``x <- x + 0.5 * g(t)**2 * model(x, t) * dt``. No noise is injected,
        so the map is deterministic given ``z``.

        z: any shape whose trailing dimension is ``data_dim``, e.g.
           [n_chains, batch, dim] or [batch, dim].
        Returns: tensor with the same shape as ``z``.
        """
        orig_shape = z.shape
        # reshape (rather than view) so non-contiguous inputs are accepted.
        x = z.reshape(-1, self.data_dim)

        n_steps = len(self.grid) - 1
        t_range = torch.linspace(1.0, 1e-5, n_steps, device=self.device)
        dt = t_range[0] - t_range[1]

        for t_step in t_range:
            t_feed = torch.full((x.size(0),), t_step, device=self.device)
            g = diff_coeff(t_feed, self.sde_sigma)  # diffusion coefficient
            drift = (g**2)[:, None] * self.model(x, t_feed)
            # Half the drift and no stochastic term: a deterministic update.
            x = x + drift * dt / 2

        return x.reshape(orig_shape)

    @torch.no_grad()
    def get_energy(self, z):
        """
        Pullback energy: E(z) = R(T(z))
        z: [n_chains, batch_size, dim]
        Returns: [n_chains, batch_size]
        """
        x = self.transport(z)
        x_flat = x.reshape(-1, x.shape[-1])
        rewards = self.reward_fn(x_flat)
        return rewards.reshape(z.shape[0], z.shape[1])

    @torch.no_grad()
    def propose_updates(self, z, energies):
        """
        Propose pCN updates for each chain and accept/reject.

        The proposal is ``cos(theta) * z + sin(theta) * sigma * xi`` with
        ``xi`` standard normal and ``theta`` the per-chain angle. It is
        accepted when ``log(u) < beta_k * (E_prop - E)`` for uniform ``u``,
        so moves that increase the energy (reward) are always accepted on
        chains with positive beta.

        z: [n_chains, batch_size, dim]
        energies: [n_chains, batch_size]
        Returns: (z_new, energies_new) as new tensors of the same shapes.
        """
        theta = self.thetas.view(-1, 1, 1)
        beta_k = self.betas.view(-1, 1)
        xi = torch.randn_like(z)
        z_prop = torch.cos(theta) * z + torch.sin(theta) * xi * self.sigma
        E_prop = self.get_energy(z_prop)
        log_alpha = beta_k * (E_prop - energies)
        accept = torch.rand_like(log_alpha).log() < log_alpha
        z_new = torch.where(accept.unsqueeze(-1), z_prop, z)
        energies_new = torch.where(accept, E_prop, energies)
        return z_new, energies_new

    @torch.no_grad()
    def swap_between_chains(self, z, energies):
        """
        Attempt state swaps between each pair of adjacent chains.

        Pairs (k, k+1) are visited in increasing k. Each batch element is
        swapped independently when
        ``log(u) < (beta_k - beta_{k+1}) * (E_{k+1} - E_k)`` for uniform
        ``u``.

        NOTE: ``z`` and ``energies`` are modified in place and also returned.
        """
        for k in range(self.n_chains - 1):
            delta_beta = self.betas[k] - self.betas[k + 1]
            delta_E = energies[k + 1] - energies[k]
            log_alpha = delta_beta * delta_E
            accept = torch.rand(z.shape[1], device=self.device).log() < log_alpha
            if accept.any():
                # Keep a copy of chain k before it is overwritten.
                z_k, E_k = z[k].clone(), energies[k].clone()
                z[k, accept] = z[k + 1, accept]
                energies[k, accept] = energies[k + 1, accept]
                z[k + 1, accept] = z_k[accept]
                energies[k + 1, accept] = E_k[accept]
        return z, energies

    @torch.no_grad()
    def sample(self, n_iterations, batch_size):
        """
        Run the sampler and return samples from the coldest chain.

        Latents are initialised as ``sigma * N(0, I)``; each iteration does
        one pCN update per chain followed by one sweep of adjacent swaps.

        Returns: (last_chain, transported_samples) where ``last_chain`` is
        the [batch_size, dim] latent state of the final chain (largest beta)
        and ``transported_samples`` is its image under ``transport``.
        """
        z = self.sigma * torch.randn(
            self.n_chains, batch_size, self.data_dim, device=self.device
        )
        energies = self.get_energy(z)
        for _ in tqdm(range(n_iterations)):
            z, energies = self.propose_updates(z, energies)
            z, energies = self.swap_between_chains(z, energies)
        last_chain = z[-1]
        transported_samples = self.transport(last_chain)
        return last_chain, transported_samples
