import torch
import torch.nn as nn
from demo.helper import (
    at_least_ndim, to_tensor,
    SUPPORTED_NOISE_SCHEDULES, SUPPORTED_SAMPLING_STEP_SCHEDULE)

# Solver names accepted by DiffusionSDE.sample(solver=...).
SUPPORTED_SOLVERS = [
    # DDPM (adds fresh noise each step) and DDIM (deterministic) updates
    "ddpm",
    "ddim",
    # deterministic (ODE) DPM-Solver variants
    "ode_dpmsolver_1",
    "ode_dpmsolver++_1",
    "ode_dpmsolver++_2M",
    # stochastic (SDE) DPM-Solver variants
    "sde_dpmsolver_1",
    "sde_dpmsolver++_1",
    "sde_dpmsolver++_2M",
]


def epstheta_to_xtheta(x, alpha, sigma, eps_theta):
    """Convert a noise prediction into a clean-sample prediction.

    Solves ``x = alpha * x_theta + sigma * eps_theta`` for ``x_theta``::

        x_theta = (x - sigma * eps_theta) / alpha
    """
    noise_component = sigma * eps_theta
    return (x - noise_component) / alpha


def xtheta_to_epstheta(x, alpha, sigma, x_theta):
    """Convert a clean-sample prediction into a noise prediction.

    Solves ``x = alpha * x_theta + sigma * eps_theta`` for ``eps_theta``::

        eps_theta = (x - alpha * x_theta) / sigma
    """
    signal_component = alpha * x_theta
    return (x - signal_component) / sigma


class DiffusionSDE(nn.Module):
    """Continuous-time diffusion model built from three denoising branches.

    The prediction trained by :meth:`compute_loss` is the composite::

        cond_agent, cond_adv = Condition(condition)
        pred = Diffusion2(xt, t, cond_agent)
               + 0.5 * (Diffusion1(xt, t, condition['state']) - Diffusion3(xt, t, cond_adv))

    :meth:`sample` integrates the reverse process with one of the solvers in
    ``SUPPORTED_SOLVERS``.
    """

    def __init__(
            self, nn_diffusion1, nn_diffusion2, nn_condition,
            fix_mask=None, loss_weight=None, sample_steps=None, epsilon=1e-3,
            noise_schedule="cosine", x_max=None, x_min=None, predict_noise=True
    ):
        """
        Args:
            nn_diffusion1: network evaluated as ``Diffusion1(xt, t, condition['state'])``.
            nn_diffusion2: network used for both the agent branch (``Diffusion2``)
                and the adversary branch (``Diffusion3``).
            nn_condition: module mapping the condition to ``(cond_agent, cond_adv)``.
            fix_mask: optional blend mask. Where it is 1, the value from ``x0``
                (training) or ``prior`` (sampling) replaces the noisy value.
            loss_weight: optional element-wise weight on the squared error.
            sample_steps: when not None, overrides the ``sample_steps`` argument
                of :meth:`sample`.
            epsilon: smallest diffusion time used for training and sampling.
            noise_schedule: key of ``SUPPORTED_NOISE_SCHEDULES``, or a dict whose
                ``"forward"`` entry maps ``t -> (alpha, sigma)``.
            x_max, x_min: optional bounds for clipping predictions and final samples.
            predict_noise: True if the networks predict the noise ``eps``;
                False if they predict the clean sample ``x0``.

        Raises:
            ValueError: if ``noise_schedule`` is an unknown name or neither a
                str nor a dict.
        """
        super().__init__()
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # ===================== Networks =====================
        self.Diffusion1 = nn_diffusion1.to(self.device)
        self.Diffusion2 = nn_diffusion2.to(self.device)
        # NOTE: Diffusion3 is bound to the same network as Diffusion2, so the
        # weights are shared. It is registered under a second name, so its
        # parameters also appear under "Diffusion3." in state_dict(). Keep the
        # attribute for checkpoint compatibility.
        self.Diffusion3 = nn_diffusion2.to(self.device)
        self.Condition = nn_condition.to(self.device)

        # ===================== Params =====================
        self.sample_steps = sample_steps
        # [None, ] prepends a leading axis so mask/weight broadcast over the batch.
        self.fix_mask = to_tensor(fix_mask, self.device)[None, ] if fix_mask is not None else 0.
        self.loss_weight = to_tensor(loss_weight, self.device)[None, ] if loss_weight is not None else 1.

        self.predict_noise = predict_noise
        self.epsilon = epsilon
        self.x_max = x_max.to(self.device) if isinstance(x_max, torch.Tensor) else x_max
        self.x_min = x_min.to(self.device) if isinstance(x_min, torch.Tensor) else x_min

        # ==================== Continuous Time-step Range ====================
        # The cosine schedule stops short of t=1, presumably to keep alpha(t)
        # away from 0 (log-SNR is taken during sampling).
        if noise_schedule == "cosine":
            self.t_diffusion = [epsilon, 0.9946]
        else:
            self.t_diffusion = [epsilon, 1.]

        # ===================== Noise Schedule ======================
        if isinstance(noise_schedule, str):
            if noise_schedule not in SUPPORTED_NOISE_SCHEDULES:
                raise ValueError(f"Noise schedule {noise_schedule} is not supported.")
            self.noise_schedule_funcs = SUPPORTED_NOISE_SCHEDULES[noise_schedule]
        elif isinstance(noise_schedule, dict):
            self.noise_schedule_funcs = noise_schedule
        else:
            raise ValueError("noise_schedule must be a dict or a string")

    # ---------------------------------------------------------------------------
    # Load weight

    def load_weights(self, path=None, strict=True):
        """Load a state dict saved with ``torch.save`` into this module.

        Args:
            path: checkpoint file path.
            strict: forwarded to ``load_state_dict``.

        Raises:
            ValueError: wrapping any error raised while loading. The original
                exception is chained as ``__cause__``.
        """
        try:
            checkpoint = torch.load(path, map_location=self.device, weights_only=True)
            self.load_state_dict(checkpoint, strict=strict)
        except Exception as e:
            raise ValueError(f"Failed to load diffusion model: {e}") from e
        print(f"Loaded Diffusion model weights from {path}")

    # ---------------------------------------------------------------------------
    # Training

    @property
    def clip_pred(self):
        """True when at least one clipping bound (``x_min``/``x_max``) is set."""
        return (self.x_max is not None) or (self.x_min is not None)

    # ==================== Training: Score Matching ======================

    def add_noise(self, x0, t=None, eps=None):
        """Diffuse ``x0`` to time ``t``: ``xt = alpha(t) * x0 + sigma(t) * eps``.

        Args:
            x0: clean samples; the first dimension is the batch.
            t: optional diffusion times, one per sample. Drawn uniformly from
                ``self.t_diffusion`` when omitted.
            eps: optional noise. Standard normal when omitted.

        Returns:
            ``(xt, t, eps)``. In ``xt``, the entries selected by ``fix_mask``
            are reset to ``x0``.
        """
        if t is None:
            t_lo, t_hi = self.t_diffusion
            t = torch.rand((x0.shape[0],), device=self.device) * (t_hi - t_lo) + t_lo

        if eps is None:
            eps = torch.randn_like(x0)

        alpha, sigma = self.noise_schedule_funcs["forward"](t)
        alpha = at_least_ndim(alpha, x0.dim())
        sigma = at_least_ndim(sigma, x0.dim())

        xt = alpha * x0 + sigma * eps
        xt = (1. - self.fix_mask) * xt + self.fix_mask * x0

        return xt, t, eps

    def _composite_prediction(self, xt, t, condition):
        """Combine the three branches into the model's prediction at ``(xt, t)``."""
        cond_agent, cond_adv = self.Condition(condition)
        # Evaluation order matches the original single-expression form.
        pred_agent = self.Diffusion2(xt, t, cond_agent)
        pred_state = self.Diffusion1(xt, t, condition['state'])
        pred_adv = self.Diffusion3(xt, t, cond_adv)
        return pred_agent + 0.5 * (pred_state - pred_adv)

    def compute_loss(self, x0, condition):
        """Return the weighted denoising MSE for a batch.

        The regression target is the injected noise when ``predict_noise`` is
        True, and the clean sample ``x0`` otherwise. This matches how
        :meth:`sample` interprets the prediction. The result is the mean of
        ``(pred - target)**2 * loss_weight * (1 - fix_mask)`` over all elements.
        """
        xt, t, eps = self.add_noise(x0)
        pred = self._composite_prediction(xt, t, condition)

        target = eps if self.predict_noise else x0
        loss = (pred - target) ** 2
        return (loss * self.loss_weight * (1 - self.fix_mask)).mean()

    def forward(self, x0, condition=None, use_mask=False):
        """Training entry point: return :meth:`compute_loss` for ``(x0, condition)``.

        ``x0`` is a tensor batch of clean samples. ``use_mask`` is accepted for
        interface compatibility and is currently unused.

        Raises:
            ValueError: if ``condition`` is falsy (None or empty).
        """
        if not condition:
            raise ValueError("Condition can't be None")

        return self.compute_loss(x0, condition)

    # ==================== Sampling: Solving SDE/ODE ======================
    @torch.no_grad()
    def classifier_free_guidance(self, xt, t, condition=None, w_cfg=None, d_step=None):
        """Return the network prediction the sampler uses at ``(xt, t)``.

        Args:
            xt: current noisy samples.
            t: diffusion times, one per sample.
            condition: condition passed to the networks.
            w_cfg: ``None`` selects the composite prediction trained by
                :meth:`compute_loss`. ``'meta'`` and ``'demo'`` select the
                weighted modes below.
            d_step: forwarded to ``self.Weight`` in ``'demo'`` mode.

        Raises:
            ValueError: for any other ``w_cfg``.
        """
        # default: the same prediction the model is trained on
        if w_cfg is None:
            pred = self._composite_prediction(xt, t, condition)

        # NOTE(review): the 'meta' and 'demo' branches use self.Diffusion and
        # self.Weight, which __init__ never creates. They raise AttributeError
        # unless those modules are attached elsewhere. Confirm they are still needed.

        # meta policy
        elif w_cfg == 'meta':
            pred = self.Diffusion(xt, t, condition)

        # imitation policy
        elif w_cfg == 'demo':
            eps_agent = self.Diffusion(xt, t, condition['agent_state'])
            eps_adv = self.Diffusion(xt, t, condition['adv_state'])
            eps = {'eps_agent': eps_agent, 'eps_adv': eps_adv}

            w = self.Weight(noise=xt, pred_eps=eps, d_step=d_step)
            pred = w * eps_agent + (1 - w) * eps_adv

        else:
            raise ValueError(f"Weight type should be None, 'meta' or 'demo', but got {w_cfg}.")

        return pred

    def clip_prediction(self, pred, xt, alpha, sigma):
        """
        Clip the prediction at each sampling step to stabilize the generation.

        The bounds ``x_min <= x0 <= x_max`` are enforced directly on an x0
        prediction, and translated for a noise prediction:
        ``(xt - alpha * x_max) / sigma <= eps <= (xt - alpha * x_min) / sigma``.
        """
        if self.predict_noise:
            if self.clip_pred:
                upper_bound = (xt - alpha * self.x_min) / sigma if self.x_min is not None else None
                lower_bound = (xt - alpha * self.x_max) / sigma if self.x_max is not None else None
                pred = pred.clip(lower_bound, upper_bound)
        else:
            if self.clip_pred:
                pred = pred.clip(self.x_min, self.x_max)

        return pred

    def sample(self, prior, solver="ddpm", n_samples=1,
               sample_steps=5, sample_step_schedule="uniform_continuous",
               temperature = 1.0,
               condition=None, w_cfg=None,
               diffusion_x_sampling_steps=0,
               warm_start_reference=None,
               warm_start_forward_level=0.3,
    ):
        """Draw samples by integrating the reverse diffusion process.

        Args:
            prior: tensor whose shape defines the initial noise. Its entries are
                written back wherever ``fix_mask`` is 1.
            solver: one of ``SUPPORTED_SOLVERS``.
            n_samples: number of time values fed to the networks per step.
                Presumably equal to the batch size of ``prior`` (TODO confirm).
            sample_steps: number of solver steps. Ignored when ``self.sample_steps``
                is set.
            sample_step_schedule: key of ``SUPPORTED_SAMPLING_STEP_SCHEDULE``, or a
                callable ``(t_range, steps) -> times``.
            temperature: scale of the initial Gaussian noise (unused on warm start).
            condition: condition passed to :meth:`classifier_free_guidance`.
            w_cfg: prediction mode passed to :meth:`classifier_free_guidance`.
            diffusion_x_sampling_steps: extra solver updates repeated at the
                final step (index 1).
            warm_start_reference: optional tensor to start from instead of pure
                noise. It is diffused forward to ``warm_start_forward_level``.
            warm_start_forward_level: fraction of the time range used for the
                warm start. It is clamped to the trained range ``self.t_diffusion``.

        Returns:
            The final ``xt``, clipped to ``[x_min, x_max]`` when bounds are set.
        """
        assert solver in SUPPORTED_SOLVERS, f"Solver {solver} is not supported."

        # A model-level step count (set in __init__) takes precedence over the argument.
        if self.sample_steps is not None:
            sample_steps = self.sample_steps

        # ===================== Initialization =====================
        prior = prior.to(self.device)
        warm_start = isinstance(warm_start_reference, torch.Tensor) and warm_start_forward_level > 0.
        if warm_start:
            warm_start_reference = warm_start_reference.to(self.device)
            # Map the level onto diffusion time, but never start later than the
            # largest time add_noise() trains on (upper end of self.t_diffusion).
            warm_start_forward_level = min(
                self.epsilon + warm_start_forward_level * (1. - self.epsilon),
                self.t_diffusion[1])
            # noise_schedule_params is optional; __init__ does not define it.
            schedule_params = getattr(self, "noise_schedule_params", None) or {}
            fwd_alpha, fwd_sigma = self.noise_schedule_funcs["forward"](
                torch.ones((1,), device=self.device) * warm_start_forward_level, **schedule_params)
            xt = warm_start_reference * fwd_alpha + fwd_sigma * torch.randn_like(warm_start_reference)
        else:
            xt = torch.randn_like(prior) * temperature
        xt = xt * (1. - self.fix_mask) + prior * self.fix_mask

        # ===================== Sampling Schedule ====================
        if warm_start:
            t_diffusion = [self.t_diffusion[0], warm_start_forward_level]
        else:
            t_diffusion = self.t_diffusion
        if isinstance(sample_step_schedule, str):
            if sample_step_schedule in SUPPORTED_SAMPLING_STEP_SCHEDULE.keys():
                sample_step_schedule = SUPPORTED_SAMPLING_STEP_SCHEDULE[sample_step_schedule](
                    t_diffusion, sample_steps)
            else:
                raise ValueError(f"Sampling step schedule {sample_step_schedule} is not supported.")
        elif callable(sample_step_schedule):
            sample_step_schedule = sample_step_schedule(t_diffusion, sample_steps)
        else:
            raise ValueError("sample_step_schedule must be a callable or a string")

        alphas, sigmas = self.noise_schedule_funcs["forward"](sample_step_schedule)
        logSNRs = torch.log(alphas / sigmas)
        # hs[i]: log-SNR increment from step i to step i-1.
        hs = torch.zeros_like(logSNRs)
        hs[1:] = logSNRs[:-1] - logSNRs[1:]  # hs[0] is not correctly calculated, but it will not be used.
        # stds[i]: standard deviation of the noise injected by the DDPM update at step i.
        stds = torch.zeros((sample_steps + 1,), device=self.device)
        stds[1:] = sigmas[:-1] / sigmas[1:] * (1 - (alphas[1:] / alphas[:-1]) ** 2).sqrt()
        # History of x_theta predictions for the second-order (2M) multistep solvers.
        buffer = []

        # ===================== Denoising Loop ========================
        # After reversal: steps sample_steps..1, then diffusion_x_sampling_steps
        # extra updates at index 1.
        loop_steps = [1] * diffusion_x_sampling_steps + list(range(1, sample_steps + 1))
        for i in reversed(loop_steps):
            t = torch.full((n_samples,), sample_step_schedule[i], dtype=torch.float32, device=self.device)
            pred = self.classifier_free_guidance(xt, t, condition, w_cfg)
            pred = self.clip_prediction(pred, xt, alphas[i], sigmas[i])

            # express the prediction both as noise (eps_theta) and as clean sample (x_theta)
            eps_theta = pred if self.predict_noise else xtheta_to_epstheta(xt, alphas[i], sigmas[i], pred)
            x_theta = pred if not self.predict_noise else epstheta_to_xtheta(xt, alphas[i], sigmas[i], pred)

            # one-step update from time index i to i-1
            if solver == "ddpm":
                xt = (
                        (alphas[i - 1] / alphas[i]) * (xt - sigmas[i] * eps_theta) +
                        (sigmas[i - 1] ** 2 - stds[i] ** 2 + 1e-8).sqrt() * eps_theta)
                if i > 1:
                    xt += (stds[i] * torch.randn_like(xt))

            elif solver == "ddim":
                xt = (alphas[i - 1] * ((xt - sigmas[i] * eps_theta) / alphas[i]) + sigmas[i - 1] * eps_theta)

            elif solver == "ode_dpmsolver_1":
                xt = (alphas[i - 1] / alphas[i]) * xt - sigmas[i - 1] * torch.expm1(hs[i]) * eps_theta

            elif solver == "ode_dpmsolver++_1":
                xt = (sigmas[i - 1] / sigmas[i]) * xt - alphas[i - 1] * torch.expm1(-hs[i]) * x_theta

            elif solver == "ode_dpmsolver++_2M":
                buffer.append(x_theta)
                if i < sample_steps:
                    # extrapolate from the last two x_theta using the step-size ratio
                    r = hs[i + 1] / hs[i]
                    D = (1 + 0.5 / r) * buffer[-1] - 0.5 / r * buffer[-2]
                    xt = (sigmas[i - 1] / sigmas[i]) * xt - alphas[i - 1] * torch.expm1(-hs[i]) * D
                else:
                    # first step has no history: fall back to the first-order update
                    xt = (sigmas[i - 1] / sigmas[i]) * xt - alphas[i - 1] * torch.expm1(-hs[i]) * x_theta

            elif solver == "sde_dpmsolver_1":
                xt = ((alphas[i - 1] / alphas[i]) * xt -
                      2 * sigmas[i - 1] * torch.expm1(hs[i]) * eps_theta +
                      sigmas[i - 1] * torch.expm1(2 * hs[i]).sqrt() * torch.randn_like(xt))

            elif solver == "sde_dpmsolver++_1":
                xt = ((sigmas[i - 1] / sigmas[i]) * (-hs[i]).exp() * xt -
                      alphas[i - 1] * torch.expm1(-2 * hs[i]) * x_theta +
                      sigmas[i - 1] * (-torch.expm1(-2 * hs[i])).sqrt() * torch.randn_like(xt))

            elif solver == "sde_dpmsolver++_2M":
                buffer.append(x_theta)
                if i < sample_steps:
                    r = hs[i + 1] / hs[i]
                    D = (1 + 0.5 / r) * buffer[-1] - 0.5 / r * buffer[-2]
                    xt = ((sigmas[i - 1] / sigmas[i]) * (-hs[i]).exp() * xt -
                          alphas[i - 1] * torch.expm1(-2 * hs[i]) * D +
                          sigmas[i - 1] * (-torch.expm1(-2 * hs[i])).sqrt() * torch.randn_like(xt))
                else:
                    xt = ((sigmas[i - 1] / sigmas[i]) * (-hs[i]).exp() * xt -
                          alphas[i - 1] * torch.expm1(-2 * hs[i]) * x_theta +
                          sigmas[i - 1] * (-torch.expm1(-2 * hs[i])).sqrt() * torch.randn_like(xt))

            # fix the known portion, and preserve the sampling history
            xt = xt * (1. - self.fix_mask) + prior * self.fix_mask

        # ================= Post-processing =================
        if self.clip_pred:
            xt = xt.clip(self.x_min, self.x_max)

        return xt