from typing import Tuple

import torch
from torch import Tensor


class Diffusion:
    """Precomputed Gaussian diffusion schedule with forward and reverse steps.

    All per-timestep coefficients are 1-D tensors of length
    ``diffusion_timesteps`` created on ``device``. The methods look them up
    with a batch of timestep indices ``t`` and broadcast them against the
    data tensor, whose leading dimension must be the batch dimension.

    Args:
        batch_size: Stored on the instance; not used by the methods below.
        target: ``'noise'`` when the model predicts the added noise; any other
            value means the model output is treated as a prediction of ``x_0``.
        diffusion_timesteps: Number of diffusion steps in the schedule.
        beta_scheduler_type: ``'Linear'``, ``'Quadratic'`` or ``'Cosine'``;
            any other value selects the sigmoid schedule.
        device: Device on which the schedule tensors are created.
    """

    def __init__(
        self, batch_size: int, target: str, diffusion_timesteps: int, beta_scheduler_type: str, device: str
    ) -> None:
        self.device = device
        self.batch_size = batch_size

        self.target = target
        self.diffusion_timesteps = diffusion_timesteps
        self.beta_scheduler_type = beta_scheduler_type

        self.betas: Tensor = self.__beta_schedule(self.diffusion_timesteps)
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Forward Diffusion
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

        # Predict Start from Noise
        self.sqrt_recip_alphas_cumprod = torch.sqrt(1. / self.alphas_cumprod)
        self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / self.alphas_cumprod - 1.)

        # Backward Diffusion
        # Cumulative product shifted right by one step, with 1.0 as the value "before" step 0.
        self.alphas_cumprod_prev = torch.nn.functional.pad(self.alphas_cumprod[:-1], (1, 0), value=1.)
        self.posterior_mean_coef1 = self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        self.posterior_mean_coef2 = (1. - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (
                    1. - self.alphas_cumprod)
        self.posterior_variance = self.betas * (1. - self.alphas_cumprod_prev) / (1. - self.alphas_cumprod)
        # The posterior variance is exactly 0 at step 0 (alphas_cumprod_prev[0] == 1),
        # so clamp before taking the log to avoid -inf.
        self.model_log_variance = torch.log(torch.clamp(self.posterior_variance, min=1e-20))

    def forward_diffusion(self, x_0: Tensor, t: Tensor) -> Tuple[Tensor, Tensor]:
        """Noise ``x_0`` to the timesteps ``t``.

        Args:
            x_0: Clean data, batch first (e.g. ``[batch_size, seq_len, n_features]``).
            t: Integer timestep indices, shape ``[batch_size]``.

        Returns:
            ``(x_t, noise)`` where ``noise`` is the freshly sampled standard
            normal tensor that was mixed in; both have the shape of ``x_0``.
        """
        noise = torch.randn_like(x_0)
        # x_0.shape = noise.shape = [batch_size, ...]

        # Coefficients are shaped [batch_size, 1, ..., 1] to match the rank of x_0.
        sqrt_alphas_cumprod_t = self.extract(self.sqrt_alphas_cumprod, t, x_0.dim())
        sqrt_one_minus_alphas_cumprod_t = self.extract(self.sqrt_one_minus_alphas_cumprod, t, x_0.dim())

        x_t = sqrt_alphas_cumprod_t * x_0 + sqrt_one_minus_alphas_cumprod_t * noise
        # x_t.shape = x_0.shape

        return x_t, noise

    def predict_start_from_noise(self, x_t: Tensor, noise_pred: Tensor, t: Tensor) -> Tensor:
        """Reconstruct ``x_0`` from a noised sample and a noise estimate.

        This inverts the mixing done in ``forward_diffusion``.

        Args:
            x_t: Noised data, batch first.
            noise_pred: Noise estimate, same shape as ``x_t``.
            t: Integer timestep indices, shape ``[batch_size]``.

        Returns:
            The reconstructed ``x_0``, same shape as ``x_t``.
        """
        # x_t.shape = noise_pred.shape = [batch_size, ...]

        sqrt_recip_alphas_cumprod_t = self.extract(self.sqrt_recip_alphas_cumprod, t, x_t.dim())
        sqrt_recipm1_alphas_cumprod_t = self.extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.dim())

        x_0_reconstructed = sqrt_recip_alphas_cumprod_t * x_t - sqrt_recipm1_alphas_cumprod_t * noise_pred
        # x_0_reconstructed.shape = x_t.shape

        return x_0_reconstructed

    def backward_diffusion(self, x_t: Tensor, pred: Tensor, t: Tensor, timestep: int) -> Tensor:
        """Take one reverse step from ``x_t`` towards the previous timestep.

        Args:
            x_t: Current noised data, batch first.
            pred: Model output, same shape as ``x_t``. Interpreted as predicted
                noise when ``self.target == 'noise'``, otherwise as predicted ``x_0``.
            t: Integer timestep indices, shape ``[batch_size]``.
            timestep: Scalar step index; when it is ``0`` the posterior mean is
                returned without adding noise.

        Returns:
            The sample for the previous step, same shape as ``x_t``.
        """
        # x_t.shape = pred.shape = [batch_size, ...]

        x_0_pred = self.predict_start_from_noise(x_t, pred, t) if self.target == 'noise' else pred
        posterior_mean_coef1_t = self.extract(self.posterior_mean_coef1, t, x_t.dim())
        posterior_mean_coef2_t = self.extract(self.posterior_mean_coef2, t, x_t.dim())

        model_mean = posterior_mean_coef1_t * x_0_pred + posterior_mean_coef2_t * x_t

        if timestep == 0:
            x_tm1 = model_mean
        else:
            model_log_variance_t = self.extract(self.model_log_variance, t, x_t.dim())
            noise = torch.randn_like(x_0_pred)
            # exp(0.5 * log_var) is the posterior standard deviation.
            x_tm1 = model_mean + torch.exp(.5 * model_log_variance_t) * noise
        # x_tm1.shape = x_t.shape

        return x_tm1

    @staticmethod
    def extract(tensor: Tensor, t: Tensor, n_dims: int = 3) -> Tensor:
        """Look up per-sample schedule values and shape them for broadcasting.

        Args:
            tensor: 1-D schedule tensor, shape ``[diffusion_timesteps]``.
            t: Timestep indices, shape ``[batch_size]``. Any integer dtype and
                any device are accepted; the indices are converted to int64 on
                ``tensor``'s device because ``gather`` requires that.
            n_dims: Rank of the data tensor the result will be multiplied with.
                The default of 3 matches ``[batch_size, seq_len, n_features]``.

        Returns:
            Tensor of shape ``[batch_size, 1, ..., 1]`` with ``n_dims`` dimensions.
        """
        index = t.to(device=tensor.device, dtype=torch.long)
        trailing = (1,) * (n_dims - 1)
        return tensor.gather(-1, index).reshape(index.shape[0], *trailing)

    def __beta_schedule(self, diffusion_timesteps) -> Tensor:
        """Build the beta schedule selected by ``self.beta_scheduler_type``.

        Any value other than ``'Linear'``, ``'Quadratic'`` or ``'Cosine'``
        falls back to the sigmoid schedule.
        """
        if self.beta_scheduler_type == 'Linear':
            return self.__linear_beta_schedule(diffusion_timesteps)
        elif self.beta_scheduler_type == 'Quadratic':
            return self.__quadratic_beta_schedule(diffusion_timesteps)
        elif self.beta_scheduler_type == 'Cosine':
            return self.__cosine_beta_schedule(diffusion_timesteps)
        else:
            return self.__sigmoid_beta_schedule(diffusion_timesteps)

    def __linear_beta_schedule(self, diffusion_timesteps, start: float = 0.0001, end: float = 0.02) -> Tensor:
        """Betas evenly spaced from ``start`` to ``end``."""
        return torch.linspace(start, end, diffusion_timesteps, device=self.device)

    def __quadratic_beta_schedule(self, diffusion_timesteps, start: float = 0.0001, end: float = 0.02) -> Tensor:
        """Betas whose square roots are evenly spaced from ``sqrt(start)`` to ``sqrt(end)``."""
        return torch.linspace(start ** 0.5, end ** 0.5, diffusion_timesteps, device=self.device) ** 2

    def __cosine_beta_schedule(self, diffusion_timesteps, s: float = 0.008) -> Tensor:
        """Betas derived from a squared-cosine cumulative-alpha curve with offset ``s``."""
        steps = diffusion_timesteps + 1
        x = torch.linspace(0, diffusion_timesteps, steps, device=self.device)
        alphas_cumprod = torch.cos(((x / diffusion_timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
        # Normalise so the curve starts at exactly 1.
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        # Consecutive ratios of the cumulative product give the per-step alphas.
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        return torch.clip(betas, 0.0001, 0.9999)

    def __sigmoid_beta_schedule(self, diffusion_timesteps, start: float = 0.0001, end: float = 0.02) -> Tensor:
        """Betas following a sigmoid over [-6, 6], rescaled into ``[start, end]``."""
        betas = torch.linspace(-6, 6, diffusion_timesteps, device=self.device)
        return torch.sigmoid(betas) * (end - start) + start

    def __repr__(self) -> str:
        return (f"{self.__class__.__name__}("
                f"beta_scheduler_type={self.beta_scheduler_type},"
                f"diffusion_timesteps={self.diffusion_timesteps}"
                f")")
