# Disclaimer: This code is strongly influenced by
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_ddim.py

import math
from typing import Union

import numpy as np
import torch
from tqdm import tqdm

from utils import match_shape, clip


def cosine_beta_schedule(timesteps, beta_start=0.0, beta_end=0.999, s=0.008):
    """Return a cosine-shaped beta schedule as a float32 tensor.

    The cumulative alpha product is modelled as a squared cosine evaluated on
    ``timesteps + 1`` evenly spaced points, normalised so its first entry is 1.
    Each beta is one minus the ratio of consecutive cumulative products.

    Args:
        timesteps: number of betas to produce.
        beta_start: lower bound the betas are clipped to.
        beta_end: upper bound the betas are clipped to.
        s: offset added to the normalised time before the cosine is taken.

    Returns:
        1-D tensor of ``timesteps`` betas, each within
        ``[beta_start, beta_end]``.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float32)
    phase = ((grid / timesteps) + s) / (1 + s) * math.pi * 0.5
    cumulative = torch.cos(phase)**2
    # Normalise so that the cumulative product starts at exactly 1.
    cumulative = cumulative / cumulative[0]
    step_ratio = cumulative[1:] / cumulative[:-1]
    return torch.clip(1 - step_ratio, beta_start, beta_end)


class DDIMScheduler:
    """DDIM noise scheduler (https://arxiv.org/pdf/2010.02502.pdf).

    Stores the beta / alpha tables of the forward diffusion process as torch
    tensors and implements the reverse update ``x_t -> x_t-1`` of formula (12)
    of the paper, plus a sampling loop for a model that predicts the noise of
    two quantities (``pk`` and ``dti``).
    """

    def __init__(self,
                 num_train_timesteps=500,
                 beta_start=0.0001,
                 beta_end=0.02,
                 beta_schedule="cosine",
                 clip_sample=True,
                 set_alpha_to_one=True):
        """Build the beta schedule and the derived alpha tables.

        Args:
            num_train_timesteps: length of the training noise schedule.
            beta_start: first beta of the linear schedule; for the cosine
                schedule it is the lower clipping bound of the betas.
            beta_end: last beta of the linear schedule; for the cosine
                schedule it is the upper clipping bound of the betas.
            beta_schedule: ``"linear"`` or ``"cosine"``.
            clip_sample: whether ``step`` clips the predicted x_0 to [-1, 1].
            set_alpha_to_one: if true, the cumulative alpha used for the
                "previous" step below timestep 0 is 1.0; otherwise it is the
                cumulative alpha of timestep 0.

        Raises:
            NotImplementedError: if ``beta_schedule`` is not supported.
        """
        if beta_schedule == "linear":
            self.betas = np.linspace(beta_start,
                                     beta_end,
                                     num_train_timesteps,
                                     dtype=np.float32)
            self.betas = torch.tensor(self.betas)
        elif beta_schedule == "cosine":
            self.betas = cosine_beta_schedule(num_train_timesteps,
                                              beta_start=beta_start,
                                              beta_end=beta_end)
        else:
            raise NotImplementedError(
                f"{beta_schedule} is not implemented for {self.__class__}")

        self.num_train_timesteps = num_train_timesteps
        self.clip_sample = clip_sample
        self.alphas = 1.0 - self.betas
        # Stay in torch: the betas are tensors, so use the native cumprod
        # instead of going through NumPy's generic fallback.
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Kept as a tensor so that `step` never mixes NumPy and torch scalars.
        self.final_alpha_cumprod = (torch.tensor(1.0) if set_alpha_to_one else
                                    self.alphas_cumprod[0])

        self.num_inference_steps = None
        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()

    def _get_variance(self, timestep, prev_timestep):
        """Return the squared sigma_t of formula (16) for eta = 1.

        A negative ``prev_timestep`` means "before the first timestep" and
        uses ``self.final_alpha_cumprod`` as the previous cumulative alpha.
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[
            prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev /
                    beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def set_timesteps(self, num_inference_steps, offset=0):
        """Select the descending timesteps visited during sampling.

        Exactly ``num_inference_steps`` timesteps are produced, spaced
        ``num_train_timesteps // num_inference_steps`` apart, starting from 0
        and stored in descending order; ``offset`` is added to each of them.

        Args:
            num_inference_steps: number of denoising steps to run.
            offset: integer shift applied to every timestep.

        Raises:
            ValueError: if ``num_inference_steps`` is not in
                ``[1, num_train_timesteps]``.
        """
        if not 0 < num_inference_steps <= self.num_train_timesteps:
            raise ValueError(
                f"num_inference_steps must be in [1, {self.num_train_timesteps}], "
                f"got {num_inference_steps}")

        self.num_inference_steps = num_inference_steps
        stride = self.num_train_timesteps // num_inference_steps
        # Multiplying an index range by the stride (rather than striding over
        # the whole training range) guarantees the requested number of steps
        # even when it does not divide num_train_timesteps.
        self.timesteps = (np.arange(0, num_inference_steps) *
                          stride)[::-1].copy()
        self.timesteps += offset

    def step(
        self,
        model_output: Union[torch.FloatTensor, np.ndarray],
        timestep: int,
        sample: Union[torch.FloatTensor, np.ndarray],
        eta: float = 1.0,
        use_clipped_model_output: bool = True,
        generator=None,
    ):
        """Perform one reverse step ``x_t -> x_t-1`` (formula (12)).

        Args:
            model_output: noise predicted by the model for ``sample``.
            timestep: current timestep ``t``.
            sample: current noisy sample ``x_t``.
            eta: scale of the stochastic term; 0 gives the deterministic
                update, values above 0 add random noise scaled by sigma_t.
            use_clipped_model_output: if true, the predicted noise is
                re-derived from the (possibly clipped) predicted x_0.
            generator: optional ``torch.Generator`` used to draw the noise.

        Returns:
            The sample for the previous timestep.

        Raises:
            ValueError: if ``set_timesteps`` has not been called yet.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run "
                "'set_timesteps' before calling 'step'")

        # 1. get previous step value (=t-1)
        prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps

        # 2. compute alphas, betas
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[
            prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_original_sample = (sample - beta_prod_t**
                                (0.5) * model_output) / alpha_prod_t**(0.5)

        # 4. Clip "predicted x_0"
        if self.clip_sample:
            pred_original_sample = clip(pred_original_sample, -1, 1)

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1)
        variance = self._get_variance(timestep, prev_timestep)
        std_dev_t = eta * variance**(0.5)

        if use_clipped_model_output:
            # the model_output is always re-derived from the clipped x_0 in Glide
            model_output = (sample - alpha_prod_t**
                            (0.5) * pred_original_sample) / beta_prod_t**(0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev -
                                 std_dev_t**2)**(0.5) * model_output

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev**(
            0.5) * pred_original_sample + pred_sample_direction

        # 8. add the random-noise term sigma_t * eps when sampling stochastically
        if eta > 0:
            device = model_output.device if torch.is_tensor(
                model_output) else "cpu"
            noise = torch.randn(model_output.shape,
                                generator=generator).to(device)
            # std_dev_t already equals eta * sqrt(variance); reuse it instead
            # of recomputing the variance.
            noise_term = std_dev_t * noise

            if not torch.is_tensor(model_output):
                noise_term = noise_term.numpy()

            prev_sample = prev_sample + noise_term

        return prev_sample

    def add_noise(self, original_samples, noise, timesteps):
        """Diffuse ``original_samples`` forward to the given ``timesteps``.

        Computes ``sqrt(alpha_cumprod_t) * x_0 + sqrt(1 - alpha_cumprod_t) * noise``.
        The per-sample coefficients are passed through ``match_shape`` from
        ``utils`` (presumably to make them broadcastable against
        ``original_samples`` -- confirm against that helper).

        Args:
            original_samples: clean samples ``x_0``.
            noise: noise to mix in; combined elementwise with the samples.
            timesteps: tensor of timestep indices; moved to CPU before it is
                used to index the schedule.

        Returns:
            The noisy samples.
        """
        timesteps = timesteps.cpu()
        sqrt_alpha_prod = self.alphas_cumprod[timesteps]**0.5
        sqrt_alpha_prod = match_shape(sqrt_alpha_prod, original_samples)
        sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps])**0.5
        sqrt_one_minus_alpha_prod = match_shape(sqrt_one_minus_alpha_prod,
                                                original_samples)
        noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
        return noisy_samples

    @torch.no_grad()
    def generate(self,
                 model,
                 drug_cond,
                 prot_cond,
                 pk,
                 dti,
                 drug_dki = None,
                 target_dki = None,
                 generator=None,
                 eta=1.0,
                 use_clipped_model_output=False,
                 num_inference_steps=500,
                 device=None):
        """Run the full reverse process starting from ``pk`` and ``dti``.

        At every timestep the model is called with keyword arguments
        ``timesteps``, ``drug_cond``, ``prot_cond``, ``pk``, ``dti``,
        ``drug_dki`` and ``target_dki`` and must return a pair
        ``(pk_noise_pred, dti_noise_pred)``. ``pk`` is always denoised;
        ``dti`` is denoised only when it is not ``None``.

        Args:
            model: callable noise predictor (see above).
            drug_cond: conditioning input forwarded to the model.
            prot_cond: conditioning input forwarded to the model.
            pk: initial noisy sample for the first output.
            dti: initial noisy sample for the second output, or ``None``.
            drug_dki: optional extra input forwarded to the model.
            target_dki: optional extra input forwarded to the model.
            generator: optional ``torch.Generator`` forwarded to ``step`` so
                that seeded sampling is reproducible.
            eta: stochasticity of each step (0 = deterministic).
            use_clipped_model_output: forwarded to ``step``.
            num_inference_steps: number of denoising steps.
            device: accepted for backward compatibility; not used by the
                sampling loop.

        Returns:
            Tuple ``(pk, dti)`` after the last denoising step.
        """
        self.set_timesteps(num_inference_steps)

        for t in self.timesteps:
            # 1. predict the noise of both quantities
            pk_noise_pred, dti_noise_pred = model(
                timesteps = t,
                drug_cond = drug_cond,
                prot_cond = prot_cond,
                pk = pk,
                dti = dti,
                drug_dki = drug_dki,
                target_dki = target_dki
            )

            # 2. predict previous mean x_t-1 and add variance depending on eta
            # do x_t -> x_t-1
            pk = self.step(
                pk_noise_pred,
                t,
                pk,
                eta,
                use_clipped_model_output=use_clipped_model_output,
                generator=generator
            )
            if dti is not None:
                dti = self.step(
                    dti_noise_pred,
                    t,
                    dti,
                    eta,
                    use_clipped_model_output=use_clipped_model_output,
                    generator=generator
                )

        return pk, dti

    def __len__(self):
        """Return the number of training timesteps."""
        return self.num_train_timesteps