import numpy as np
import torch as th
import torch.nn as nn
from torchdiffeq import odeint
from functools import partial
from tqdm import tqdm

class sde:
    """Fixed-step, forward-time SDE sampler.

    The time grid is ``th.linspace(t0, t1, num_steps)`` and every step uses
    the constant step size ``dt = t[1] - t[0]`` (so ``num_steps`` must be at
    least 2, otherwise indexing ``t[1]`` fails).

    Args:
        drift: Callable invoked as ``drift(x, t, model, **model_kwargs)``.
        diffusion: Callable invoked as ``diffusion(x, t)``.
        t0: Start time; must be strictly smaller than ``t1``.
        t1: End time.
        num_steps: Number of points on the time grid.
        sampler_type: One of ``"Euler"``, ``"Heun"`` or ``"ODE"``; it is
            validated lazily when :meth:`sample` is called.
    """

    def __init__(
        self,
        drift,
        diffusion,
        *,
        t0,
        t1,
        num_steps,
        sampler_type,
    ):
        assert t0 < t1, "SDE sampler has to be in forward time"

        self.num_timesteps = num_steps
        self.t = th.linspace(t0, t1, num_steps)
        # 0-dim tensor holding the uniform step size of the grid.
        self.dt = self.t[1] - self.t[0]
        self.drift = drift
        self.diffusion = diffusion
        self.sampler_type = sampler_type

    def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
        """Take one Euler-Maruyama step.

        Computes ``mean_x = x + drift * dt`` and then adds the noise term
        ``sqrt(2 * diffusion) * sqrt(dt) * N(0, I)``.

        The incoming ``mean_x`` is ignored; it is only part of the signature
        so that all step functions can be called the same way.

        Returns:
            Tuple ``(x_next, mean_x)``: the noised state and its noise-free mean.
        """
        # Noise is drawn with th.randn and then matched to x's dtype/device;
        # kept in this exact form so seeded runs reproduce earlier results.
        w_cur = th.randn(x.size()).to(x)
        # Broadcast the scalar time to one entry per batch element.
        t = th.ones(x.size(0)).to(x) * t
        dw = w_cur * th.sqrt(self.dt)
        drift = self.drift(x, t, model, **model_kwargs)
        diffusion = self.diffusion(x, t)
        mean_x = x + drift * self.dt
        x = mean_x + th.sqrt(2 * diffusion) * dw
        return x, mean_x

    def __X_ODE_step(self, x, mean_x, t, model, **model_kwargs):
        """Take one deterministic Euler step (drift only, no noise).

        The incoming ``mean_x`` is ignored.

        Returns:
            Tuple ``(x_next, mean_x)`` where both entries are ``x + drift * dt``.
        """
        t = th.ones(x.size(0)).to(x) * t
        drift = self.drift(x, t, model, **model_kwargs)
        mean_x = x + drift * self.dt
        x = mean_x
        return x, mean_x

    def __Heun_step(self, x, _, t, model, **model_kwargs):
        """Take one stochastic Heun step.

        Noise is injected first (``xhat = x + sqrt(2 * diffusion) * dw``) and
        the drift is then integrated from ``xhat`` with a two-stage Heun
        (predictor/corrector) update evaluated at ``t`` and ``t + dt``.

        Returns:
            Tuple ``(x_next, xhat)``: the corrected state and the noised state
            the correction started from.
        """
        w_cur = th.randn(x.size()).to(x)
        dw = w_cur * th.sqrt(self.dt)
        t_cur = th.ones(x.size(0)).to(x) * t
        diffusion = self.diffusion(x, t_cur)
        xhat = x + th.sqrt(2 * diffusion) * dw
        K1 = self.drift(xhat, t_cur, model, **model_kwargs)
        xp = xhat + self.dt * K1
        K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
        # The Heun correction is applied on every step, including the final one.
        return xhat + 0.5 * self.dt * (K1 + K2), xhat

    def __forward_fn(self):
        """Return the bound step function selected by ``self.sampler_type``.

        Raises:
            NotImplementedError: If ``self.sampler_type`` is not a known sampler.
        """
        # TODO: generalize by collecting all private methods ending in "_step".
        sampler_dict = {
            "Euler": self.__Euler_Maruyama_step,
            "Heun": self.__Heun_step,
            "ODE": self.__X_ODE_step,
        }

        try:
            return sampler_dict[self.sampler_type]
        except (KeyError, TypeError):
            # KeyError: unknown name. TypeError: unhashable sampler_type.
            # Anything else is unexpected and must not be masked.
            raise NotImplementedError("Sampler type not implemented.") from None

    def sample(self, init, model, **model_kwargs):
        """Integrate from ``init`` over the time grid without tracking gradients.

        Args:
            init: Initial state tensor; its first dimension is used as the
                batch size when broadcasting the time.
            model: Passed through to ``drift`` unchanged.
            **model_kwargs: Extra keyword arguments forwarded to ``drift``.

        Returns:
            List with one state per step taken (``len(self.t) - 1`` entries);
            the initial state itself is not included.

        Raises:
            NotImplementedError: If ``self.sampler_type`` is not supported.
        """
        x = init
        mean_x = init
        samples = []
        sampler = self.__forward_fn()
        # One no_grad context around the whole loop instead of one per step.
        with th.no_grad():
            for ti in self.t[:-1]:
                x, mean_x = sampler(x, mean_x, ti.to(x.dtype), model, **model_kwargs)
                samples.append(x)

        return samples

class ode:
    """Thin wrapper that integrates a drift field with ``odeint``.

    Args:
        drift: Callable invoked as ``drift(x, t, model, **model_kwargs)``.
        t0: Start time; must be strictly smaller than ``t1``.
        t1: End time.
        sampler_type: Forwarded to ``odeint`` as its ``method`` argument.
        num_steps: Number of points on the ``th.linspace(t0, t1, ...)`` grid.
        atol: Absolute tolerance forwarded to ``odeint``.
        rtol: Relative tolerance forwarded to ``odeint``.
    """

    def __init__(
        self,
        drift,
        *,
        t0,
        t1,
        sampler_type,
        num_steps,
        atol,
        rtol,
    ):
        assert t0 < t1, "ODE sampler has to be in forward time"

        self.sampler_type = sampler_type
        self.rtol = rtol
        self.atol = atol
        self.t = th.linspace(t0, t1, num_steps)
        self.drift = drift

    def sample(self, x, model, **model_kwargs):
        """Integrate ``x`` over ``self.t`` and return whatever ``odeint`` returns.

        Args:
            x: Initial state; either a tensor or a tuple of tensors.
            model: Passed through to ``drift`` unchanged.
            **model_kwargs: Extra keyword arguments forwarded to ``drift``.
        """
        if isinstance(x, tuple):
            device = x[0].device
            n_tol = len(x)
        else:
            device = x.device
            n_tol = 1

        def _fn(t, state):
            # Batch size comes from the first tensor when the state is a tuple.
            if isinstance(state, tuple):
                batch = state[0].size(0)
            else:
                batch = state.size(0)
            # Broadcast the solver's scalar time to one entry per batch element.
            t_batch = th.ones(batch).to(device) * t
            return self.drift(state, t_batch, model, **model_kwargs)

        # One tolerance entry per state component.
        abs_tol = [self.atol] * n_tol
        rel_tol = [self.rtol] * n_tol
        time_grid = self.t.to(device)

        return odeint(
            _fn,
            x,
            time_grid,
            method=self.sampler_type,
            atol=abs_tol,
            rtol=rel_tol,
        )

import os
# Feature flags read once from the environment at import time. Each holds the
# raw string value (or None when unset) and is only ever truth-tested, so any
# non-empty value -- including "0" -- switches the flag on.
#
# DISTRL_RL_FP32_STD: when set (and DISTRL_STATUS_FILL_POOL != "1"), the RL
# Euler-Maruyama step computes std/log-prob from ``t.float()`` inside a CUDA
# autocast context.
DISTRL_RL_FP32_STD = os.environ.get("DISTRL_RL_FP32_STD", None)
# DISTRL_DEBUG_ALIGN_SFT_RL: when set, the RL Euler-Maruyama step attaches
# ``dt``, ``drift`` and ``x_cur`` to the returned distribution object.
DISTRL_DEBUG_ALIGN_SFT_RL = os.environ.get('DISTRL_DEBUG_ALIGN_SFT_RL', None)

from .log_prob import GaussianDistribution
class rl_sde(sde):
    """SDE sampler variant that also returns per-step log-probabilities.

    Each Euler-Maruyama step is expressed as a draw from
    ``GaussianDistribution(mean_x, std_x, log_std_x)`` so that the
    log-probability of the sampled next state is available to the caller.
    Only the ``"Euler"`` sampler type is supported.
    """

    def get_std_t(self, t, dt=None):
        """Compute the standard deviation of one step at time ``t``.

        Evaluates ``sqrt(2 * diffusion(ones(1, 1), t)) * sqrt(dt)``.

        Args:
            t (torch.Tensor): Time point(s); also determines the dtype/device
                of the dummy state passed to ``diffusion``.
            dt (torch.Tensor, optional): Step size. Defaults to the sampler's
                own step size ``self.dt``.

        Returns:
            torch.Tensor: Standard deviation at time ``t``.
        """
        if dt is None:
            dt = self.dt
            # ``self.dt`` is a 0-dim tensor as built by ``sde.__init__``;
            # indexing it with [0] raises IndexError. Only index when a
            # dimension actually exists.
            if th.is_tensor(dt) and dt.dim() > 0:
                dt = dt[0]
        return th.sqrt(2 * self.diffusion(th.ones(1, 1).to(t), t)) * th.sqrt(dt)

    def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
        """Take one Euler-Maruyama step and score the sampled state.

        The incoming ``mean_x`` is ignored; it is only part of the signature
        so the step can be called like the base-class step functions.

        Returns:
            Tuple ``(x_next, mean_x, log_prob, dist)`` where ``dist`` is the
            ``GaussianDistribution`` that ``x_next`` was sampled from.

        Raises:
            AssertionError: If the ``DISTRL_RL`` environment variable is unset.
        """
        assert os.environ.get('DISTRL_RL', None) is not None, "RL is not enabled"
        # Broadcast the scalar time to one entry per batch element.
        t = th.ones(x.size(0)).to(x) * t
        drift = self.drift(x, t, model, **model_kwargs)
        mean_x = x + drift * self.dt

        # Read on every call (not at import) so the flag can change at runtime.
        DISTRL_STATUS_FILL_POOL = os.environ.get('DISTRL_STATUS_FILL_POOL', None) == "1"
        if DISTRL_RL_FP32_STD and not DISTRL_STATUS_FILL_POOL:
            # std is computed from a float32 time tensor; sampling and scoring
            # run under CUDA bfloat16 autocast.
            with th.autocast(device_type='cuda', dtype=th.bfloat16):
                std_x = self.get_std_t(t.float(), self.dt)
                log_std_x = th.log(std_x)
                dist = GaussianDistribution(mean_x, std_x, log_std_x)
                x = dist.sample(x.size())
                log_prob = dist.log_prob(x)
        else:
            std_x = self.get_std_t(t, self.dt)
            log_std_x = th.log(std_x)
            dist = GaussianDistribution(mean_x, std_x, log_std_x)
            if DISTRL_DEBUG_ALIGN_SFT_RL:
                # Debug aid: expose the ingredients of this step on the
                # distribution object (x is still the pre-step state here).
                dist.dt = self.dt
                dist.drift = drift
                dist.x_cur = x
            x = dist.sample(x.size())
            log_prob = dist.log_prob(x)

        return x, mean_x, log_prob, dist

    def Euler_Maruyama_step(self, x, t, model, **model_kwargs):
        """Public single-step entry point.

        Returns:
            Tuple ``(x_next, mean_x, log_prob, dist)`` from the private
            Euler-Maruyama step.
        """
        return self.__Euler_Maruyama_step(x, None, t, model, **model_kwargs)

    def Euler_Maruyama_step_for_sft_with_sde_drift(self, x, t, model, **model_kwargs):
        """Return the noise-free Euler update ``x + drift * dt``."""
        t = th.ones(x.size(0)).to(x) * t
        drift = self.drift(x, t, model, **model_kwargs)
        return x + drift * self.dt

    def __forward_fn(self):
        """Return the bound step function selected by ``self.sampler_type``.

        Raises:
            NotImplementedError: If ``self.sampler_type`` is not ``"Euler"``.
        """
        # TODO: generalize by collecting all private methods ending in "_step".
        sampler_dict = {
            "Euler": self.__Euler_Maruyama_step,
        }

        try:
            return sampler_dict[self.sampler_type]
        except (KeyError, TypeError):
            # KeyError: unknown name. TypeError: unhashable sampler_type.
            # Anything else is unexpected and must not be masked.
            raise NotImplementedError("Sampler type not implemented.") from None

    def sample(self, init, model, **model_kwargs):
        """Integrate from ``init`` over the time grid without tracking gradients.

        Args:
            init: Initial state tensor.
            model: Passed through to ``drift`` unchanged.
            **model_kwargs: Extra keyword arguments forwarded to ``drift``.

        Returns:
            Tuple ``(samples, log_prob_list)``; each list has one entry per
            step taken (``len(self.t) - 1`` entries).

        Raises:
            NotImplementedError: If ``self.sampler_type`` is not supported.
        """
        x = init
        mean_x = init
        samples = []
        log_prob_list = []
        sampler = self.__forward_fn()
        # One no_grad context around the whole loop instead of one per step.
        with th.no_grad():
            for ti in self.t[:-1]:
                x, mean_x, log_prob, dist = sampler(x, mean_x, ti.to(x.dtype), model, **model_kwargs)
                samples.append(x)
                log_prob_list.append(log_prob)

        return samples, log_prob_list
