from src.diffusion.sampling.predictors import NonePredictor, EulerMaruyamaPredictor
from src.diffusion.sampling.correctors import NoneCorrector
from tqdm import tqdm
import torch
import numpy as np
from src.utils.diffusion_utils import (
    from_flattened_numpy,
    to_flattened_numpy,
    get_score_fn,
)
from scipy import integrate
from src.constants import EPS_SDE


def get_sampler(config, sde, denoiser):
    """Create a sampler object from the configuration.

    Args:
      config: A config object (e.g. `ml_collections.ConfigDict`). Always reads
        `config.sampler_name`; the chosen branch additionally reads
        `config.model.sampler_b_size` and `config.model.data_dim`, and the PC
        sampler also reads `config.model.p_steps`, `config.model.sampler_large_step`
        and `config.model.sampler_small_step`.
      sde: An `sde_lib.SDE` object that represents the forward SDE.
      denoiser: The score/denoiser network handed to the sampler.

    Returns:
      An `ODE_Sampler` (for ``"ode_sampler"``) or a `PC_Sampler` (for
      ``"pc_sampler"``); the name match is case-insensitive. Both integrate the
      reverse-time dynamics down to `EPS_SDE` and produce batches of shape
      ``(config.model.sampler_b_size, config.model.data_dim)``.

    Raises:
      ValueError: If `config.sampler_name` is not a known sampler.
    """
    sampler_name = config.sampler_name
    name = sampler_name.lower()

    if name == "ode_sampler":
        # Probability flow ODE sampling with a black-box (scipy) ODE solver.
        return ODE_Sampler(
            sde=sde,
            denoiser=denoiser,
            sample_shape=(config.model.sampler_b_size, config.model.data_dim),
            denoise=False,
            eps=EPS_SDE,
        )

    if name == "pc_sampler":
        # Predictor-Corrector sampling. Only the Euler-Maruyama predictor without a
        # corrector is wired up here (predictor-only special case).
        return PC_Sampler(
            sde=sde,
            denoiser=denoiser,
            predictor_name="euler_maruyama",
            corrector_name=None,
            sample_shape=(config.model.sampler_b_size, config.model.data_dim),
            snr=0,
            # Read only here: the ODE sampler does not use a fixed step count.
            p_steps=config.model.p_steps,
            c_steps=0,
            large_step=config.model.sampler_large_step,
            small_step=config.model.sampler_small_step,
            probability_flow=False,
            continuous=True,
            denoise=False,
            eps=EPS_SDE,
        )

    raise ValueError(f"Sampler name {sampler_name} unknown.")


class ODE_Sampler:
    """Sampler that integrates the probability-flow ODE with `scipy.integrate.solve_ivp`."""

    def __init__(
        self,
        denoiser,
        sde,
        sample_shape,
        denoise=False,
        rtol=1e-5,
        atol=1e-5,
        method="RK45",
        eps=1e-3,
        device="cpu",
    ):
        """Probability flow ODE sampler with a black-box ODE solver.

        Args:
          denoiser: The score/denoiser network; wrapped with `get_score_fn` when sampling.
          sde: An `sde_lib.SDE` object that represents the forward SDE.
          sample_shape: Shape of the sample batch; its first entry is used as the batch size.
          denoise: Stored for parity with `PC_Sampler`. One-step denoising is not
            implemented by this sampler yet, so the flag currently has no effect.
          rtol: A `float` number. The relative tolerance level of the ODE solver.
          atol: A `float` number. The absolute tolerance level of the ODE solver.
          method: A `str`. The algorithm used for the black-box ODE solver.
            See the documentation of `scipy.integrate.solve_ivp`.
          eps: A `float` number. The reverse-time ODE is integrated down to `eps`
            (not 0) for numerical stability.
          device: Torch device on which samples and time tensors are created.
        """
        self.sde = sde
        self.denoiser = denoiser
        self.sample_shape = sample_shape
        self.denoise = denoise
        self.rtol = rtol
        self.atol = atol
        self.method = method
        self.eps = eps
        self.device = device

    def run_sampler(self, eta=0.0, z=None):
        """Draw samples by integrating the probability-flow ODE from `sde.T` down to `eps`.

        Args:
          eta: Unused; accepted so the signature mirrors `PC_Sampler.run_sampler`.
          z: Optional starting point at time `sde.T`. If `None`, a batch is drawn
            from the SDE prior.

        Returns:
          (samples, nfe): the final state as a float32 tensor reshaped to
          `sample_shape`, and the number of right-hand-side evaluations
          (`solution.nfev`) reported by `solve_ivp`.
        """

        def drift_term(x, t):
            """Drift of the reverse-time probability-flow ODE."""
            return self.sde.get_reverse_sde_coefficients(x, t, probability_flow=True)[0]

        def ode_func(t, x):
            # solve_ivp works on flat numpy vectors: rebuild a batched float32 tensor,
            # evaluate the drift, and flatten the result back to numpy.
            x = (
                from_flattened_numpy(x, self.sample_shape)
                .to(self.device)
                .type(torch.float32)
            )
            vec_t = torch.ones(self.sample_shape[0], device=self.device) * t
            drift = drift_term(x, vec_t)
            return to_flattened_numpy(drift)

        with torch.no_grad():
            # Initial sample: if no latent code is given, sample from the SDE prior.
            if z is None:
                x = self.sde.prior_sampling(self.sample_shape).to(self.device)
            else:
                x = z

            # Register the score function with the SDE so the reverse drift can use it.
            score_fn = get_score_fn(self.denoiser, if_training=False)
            self.sde.init_score_fn(score_fn)

            # Integrate backwards in time from T to eps. Custom evaluation times
            # could be requested via `t_eval` (must lie within the span).
            solution = integrate.solve_ivp(
                ode_func,
                (self.sde.T, self.eps),
                to_flattened_numpy(x),
                rtol=self.rtol,
                atol=self.atol,
                method=self.method,
            )

            # nfev: number of evaluations of the right-hand side.
            nfe = solution.nfev
            x = torch.tensor(
                solution.y[:, -1], device=self.device, dtype=torch.float32
            ).reshape(self.sample_shape)

            # NOTE: `self.denoise` is not applied yet (no one-step denoising here).
            return x, nfe

    def run_sampler_forward(self, x, p_steps=1000):
        """Simulate the forward SDE from t=0 to t=`sde.T` with Euler-Maruyama.

        Args:
          x: Initial state at t=0.
          p_steps: Number of Euler-Maruyama steps; the step size is `sde.T / p_steps`.

        Returns:
          A tensor stacking `p_steps + 1` states along a new leading axis: the
          initial `x` followed by the state after each step (the last at t=`sde.T`).
        """
        times = torch.linspace(0, self.sde.T, p_steps + 1, device=self.device)
        dt = self.sde.T / p_steps
        x_evo = [x]
        # Step from every grid point except the last, so the final state sits
        # exactly at t = T instead of overshooting by one step.
        for t in times[:-1]:
            drift, diffusion = self.sde.get_sde_coefficients(x, t)
            x = x + drift * dt + diffusion * torch.randn_like(x) * dt ** 0.5
            x_evo.append(x.detach().clone())

        return torch.stack(x_evo)


class PC_Sampler:
    """Predictor-Corrector sampler for the reverse-time SDE on a uniform time grid."""

    def __init__(
        self,
        denoiser,
        sde,
        predictor_name,
        corrector_name,
        sample_shape,
        snr,
        p_steps,
        large_step=2,
        small_step=0,
        c_steps=1,
        probability_flow=False,
        continuous=False,
        denoise=True,
        eps=1e-3,
        device="cpu",
    ):
        """Predictor-Corrector (PC) sampler.

        Args:
          denoiser: The score/denoiser network; wrapped with `get_score_fn` to get the score.
          sde: An `sde_lib.SDE` object representing the forward SDE.
          predictor_name: `None` selects `NonePredictor` (corrector-only sampling);
            any other value selects `EulerMaruyamaPredictor`.
          corrector_name: Must be `None`; only `NoneCorrector` is supported so far.
          sample_shape: Shape of the sample batch; its first entry is used as the batch size.
          snr: A `float` number. Signal-to-noise ratio forwarded to the corrector.
          p_steps: Number of predictor steps; the time grid has `p_steps + 1` points.
          large_step, small_step: Stored only; not used by the code in this class.
          c_steps: Number of corrector steps per predictor update, forwarded to the corrector.
          probability_flow: If `True`, solve the reverse-time probability flow ODE when
            running the predictor (forwarded to the predictor).
          continuous: `True` indicates that the score model was continuously trained.
            Stored only.
          denoise: If `True`, `run_sampler` returns the predictor mean `x_mean` instead of `x`.
          eps: A `float` number. The reverse-time SDE/ODE is integrated down to `eps`
            (not 0) to avoid numerical issues.
          device: Torch device on which samples and time tensors are created.
        """
        self.large_step = large_step
        self.small_step = small_step
        self.sde = sde
        self.sample_shape = sample_shape
        self.snr = snr
        self.p_steps = p_steps
        self.c_steps = c_steps
        self.probability_flow = probability_flow
        self.continuous = continuous
        self.denoise = denoise
        self.eps = eps
        self.device = device
        self.denoiser = denoiser

        score_fn = get_score_fn(denoiser, if_training=False)

        # Create predictor & corrector.
        if predictor_name is None:
            # Corrector-only sampler.
            self.predictor = NonePredictor(sde, None, probability_flow)
        else:
            self.predictor = EulerMaruyamaPredictor(sde, probability_flow)

        assert corrector_name is None, "Corrector not supported yet"

        self.corrector = NoneCorrector(sde, score_fn, snr, c_steps)

    def run_sampler(self, eta, show_score_size=False, show_evolution=False):
        """Draw samples by integrating the reverse-time SDE from just below T down to `eps`.

        Args:
          eta (float): Noise scale forwarded to the predictor's `update_fn`.
          show_score_size (bool): If `True`, record and print the norm of the score
            after every step.
          show_evolution (bool): If `True`, return the whole trajectory instead of the
            final samples.

        Returns:
          (samples, s_sizes): `samples` is the final batch (`x_mean` if `self.denoise`
          else `x`), or, with `show_evolution`, the prior sample and the state after
          each step stacked on a new leading axis (CPU tensor of length `p_steps + 1`).
          `s_sizes` is a 1-D tensor of per-step score norms when `show_score_size`
          is set, otherwise `None`.
        """
        if show_evolution:
            evolution = []
        if show_score_size:
            s_sizes = []

        with torch.no_grad():
            # Initial sample from the SDE prior.
            x = (
                self.sde.prior_sampling(self.sample_shape)
                .to(self.device)
                .type(torch.float32)
            )
            # With zero steps the loop never assigns x_mean; fall back to the prior sample.
            x_mean = x
            if show_evolution:
                evolution.append(x.cpu())

            # Uniform decreasing grid: p_steps + 1 points, p_steps steps. Starts slightly
            # below T.
            timesteps = torch.linspace(
                self.sde.T - 0.000001, self.eps, self.p_steps + 1, device=self.device
            )

            # Register the score function with the SDE and hand the grid to the predictor.
            score_fn = get_score_fn(self.denoiser, if_training=False)
            self.sde.init_score_fn(score_fn)
            self.predictor.set_discretisation(discretisation=timesteps)

            for i in tqdm(range(self.p_steps)):
                vec_t = torch.ones(self.sample_shape[0], device=self.device) * timesteps[i]
                # NOTE(review): the corrector runs before the predictor here; the usual PC
                # ordering is predictor first. The corrector is a NoneCorrector for now.
                x, x_mean = self.corrector.update_fn(x, vec_t)
                x, x_mean = self.predictor.update_fn(x, vec_t, eta)

                if show_evolution:
                    evolution.append(x.cpu())
                if show_score_size:
                    score_size = torch.norm(score_fn(x, vec_t))
                    s_sizes.append(score_size)
                    print(score_size)

            samples = x_mean if self.denoise else x

            if show_evolution:
                samples = torch.stack(evolution)
            if show_score_size:
                s_sizes = torch.stack(s_sizes)
            else:
                s_sizes = None

            return samples, s_sizes

    def run_sampler_forward(self, x):
        """Simulate the forward SDE from t=0 to t=`sde.T` with Euler-Maruyama.

        Args:
          x: Initial state at t=0.

        Returns:
          A tensor stacking `self.p_steps + 1` states along a new leading axis: the
          initial `x` followed by the state after each step (the last at t=`sde.T`).
        """
        times = torch.linspace(0, self.sde.T, self.p_steps + 1, device=self.device)
        dt = self.sde.T / self.p_steps
        x_evo = [x]
        # Step from every grid point except the last, so the final state sits
        # exactly at t = T instead of overshooting by one step.
        for t in times[:-1]:
            drift, diffusion = self.sde.get_sde_coefficients(x, t)
            x = x + drift * dt + diffusion * torch.randn_like(x) * dt ** 0.5
            x_evo.append(x.detach().clone())

        return torch.stack(x_evo)

    def ode_loglik_computation(self, x_evo, timesteps=None):
        """Compute the probability-flow ODE log-likelihood of a trajectory.

        Args:
          x_evo (torch.Tensor): The evolution of the samples with size
            [num_times, num_samples, dim]. The first dimension is time in direction
            T_max -> T_min.
          timesteps (torch.Tensor, optional): The time grid used for integration.
            Must have at least `len(x_evo)` entries. When `None`, a default grid
            `linspace(1e-5, eps, p_steps + 1)` is used.

        Returns:
          A tensor with one log-likelihood estimate per sample: the prior log-density
          of `x_evo[0]` minus the trapezoid integral of the accumulated divergences.
        """
        x_T = x_evo[0]
        p_T = self.sde.prior_logp(x_T)
        p_sum = p_T

        if timesteps is None:
            # NOTE(review): placeholder default. It should match the grid actually
            # used at sampling time; save that grid when sampling and pass it in.
            timesteps = torch.linspace(1e-5, self.eps, self.p_steps + 1, device="cpu")
        else:
            # Keep the grid on the CPU: it is handed to scipy's trapezoid below.
            timesteps = torch.as_tensor(timesteps).cpu()
        self.predictor.set_discretisation(discretisation=timesteps)

        div_t = []
        with torch.enable_grad():
            for j in range(1, len(x_evo)):
                x_t = x_evo[j].to(self.device).detach().requires_grad_(True)
                # Time is reversed relative to x_evo; the extra -1 skips T_max.
                t = timesteps[-j - 1]

                vec_t = torch.ones(x_t.shape[0], device=self.device) * t
                dt = torch.as_tensor(self.predictor.find_dt(vec_t[0])).type_as(vec_t)
                ode_drift = self.sde.get_reverse_sde_coefficients(
                    x_t, vec_t, probability_flow=True
                )[0]
                # Gradient of sum(drift * dt) w.r.t. x gives the per-coordinate
                # derivatives; summing over the last axis yields dt * divergence.
                # NOTE(review): the divergence is scaled by dt here and then integrated
                # over time again below; confirm this double scaling is intended.
                torch.sum(ode_drift * dt).backward()
                div_t.append(torch.sum(x_t.grad.detach(), dim=-1))

        div_t = torch.stack(div_t)

        # NOTE(review): div_t[k] is evaluated at timesteps[-k-2], while the k-th
        # trapezoid grid point below is timesteps[-k-1]; confirm the one-index offset.
        grid = torch.flip(timesteps[1:], dims=(0,))
        for sample_num in range(len(p_T)):
            integral = integrate.trapezoid(div_t[:, sample_num].cpu(), x=grid)
            p_sum[sample_num] -= float(integral)

        return p_sum
