import torch
from typing import Optional


class Transport:
    '''
    Shared interface and hyper-parameter holder for the transport formulations in this file.

    The base class only stores configuration. Every method is a placeholder that
    returns None and is overridden by the concrete subclasses (OT_FM, TrigFlow,
    EDM, VP_SDE, VE_SDE). The placeholder signatures mirror the subclasses',
    including the optional `enhance_target` and `s_ratio` arguments, so code
    written against `Transport` can pass them.
    '''

    def __init__(self, sigma_d, T_max, T_min, enhance_target=False, w_gt=1.0, w_cond=0.0, w_start=0.0, w_end=1.0):
        '''
        Store the configuration shared by all formulations.

        Args:
            sigma_d: scale constant; in this file only EDM reads it
                (presumably the data standard deviation -- TODO confirm).
            T_max: presumably the largest time value of the path (stored only).
            T_min: presumably the smallest time value of the path (stored only).
            enhance_target: stored flag. The subclasses' `target` methods take
                their own `enhance_target` argument and do not read this attribute.
            w_gt: weight the subclasses' `target` give the ground-truth term inside the window.
            w_cond: weight the subclasses' `target` give `F_t_cond` inside the window.
            w_start: inclusive lower bound of the time window for w_gt / w_cond.
            w_end: inclusive upper bound of the time window for w_gt / w_cond.
        '''
        self.sigma_d = sigma_d
        self.T_max = T_max
        self.T_min = T_min
        self.enhance_target = enhance_target
        self.w_gt = w_gt
        self.w_cond = w_cond
        self.w_start = w_start
        self.w_end = w_end

    def sample_t(self, batch_size, dtype, device):
        '''Placeholder: subclasses draw `batch_size` training times. Returns None here.'''

    def c_noise(self, t: torch.Tensor):
        '''Placeholder: subclasses map `t` to a time-conditioning value. Returns None here.'''

    def interpolant(self, t: torch.Tensor):
        '''Placeholder: subclasses return (alpha_t, sigma_t, d_alpha_t, d_sigma_t). Returns None here.'''

    def target(self, x_t: torch.Tensor, v_t: torch.Tensor, x: torch.Tensor, z: torch.Tensor, t: torch.Tensor,
               r: torch.Tensor, dF_dv_dt: torch.Tensor, F_t_cond: Optional[torch.Tensor] = 0.0,
               F_t_uncond: Optional[torch.Tensor] = 0.0, enhance_target=False):
        '''
        Placeholder: subclasses build the regression target. Returns None here.

        `F_t_cond`, `F_t_uncond` and `enhance_target` carry the same defaults
        as in the subclasses so both call styles work on the base type.
        '''

    def from_x_t_to_x_r(self, x_t: torch.Tensor, t: torch.Tensor, r: torch.Tensor, F: torch.Tensor, s_ratio=0.0):
        '''
        Placeholder: subclasses move a sample from time `t` to time `r`. Returns None here.

        `s_ratio` matches the optional argument accepted by every subclass.
        '''


class OT_FM(Transport):
    def __init__(self, P_mean=0.0, P_std=1.0, sigma_d=1.0, T_max=1.0, T_min=0.0, enhance_target=False, w_gt=1.0,
                 w_cond=0.0, w_start=0.0, w_end=1.0):
        '''
        Flow-matching with linear path formulation from the paper:
        "SiT: Exploring Flow and Diffusion-based Generative Models with Scalable Interpolant Transformers"
        '''
        self.P_mean = P_mean
        self.P_std = P_std
        super().__init__(sigma_d, T_max, T_min, enhance_target, w_gt, w_cond, w_start, w_end)

    def interpolant(self, t: torch.Tensor):
        alpha_t = 1 - t
        sigma_t = t
        d_alpha_t = -1
        d_sigma_t = 1
        return alpha_t, sigma_t, d_alpha_t, d_sigma_t

    def sample_t(self, batch_size, dtype, device):
        rnd_normal = torch.randn((batch_size,), dtype=dtype, device=device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        t = sigma / (1 + sigma)  # [0, 1]
        return t

    def c_noise(self, t: torch.Tensor):
        return t

    def target(
            self,
            x_t: torch.Tensor,
            v_t: torch.Tensor,
            x: torch.Tensor,
            z: torch.Tensor,
            t: torch.Tensor,
            r: torch.Tensor,
            dF_dv_dt: torch.Tensor,
            F_t_cond: Optional[torch.Tensor] = 0.0,
            F_t_uncond: Optional[torch.Tensor] = 0.0,
            enhance_target=False,
    ):
        if enhance_target:
            w_gt = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_gt, 1.0)
            w_cond = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_cond, 0.0)
            v_t = w_gt * v_t + w_cond * F_t_cond + (1 - w_gt - w_cond) * F_t_uncond
        F_target = v_t - (t - r) * dF_dv_dt
        return F_target

    def from_x_t_to_x_r(self, x_t: torch.Tensor, t: torch.Tensor, r: torch.Tensor, F: torch.Tensor, s_ratio=0.0):
        x_r = x_t - (t - r) * F
        if s_ratio > 0.0:
            z = x_t + (1 - t) * F
            epsilon = torch.randn_like(z)
            dt = t - r
            x_r = x_r - s_ratio * z * dt + torch.sqrt(s_ratio * 2 * t * dt) * epsilon
        return x_r


class TrigFlow(Transport):
    def __init__(self, P_mean=-1.0, P_std=1.6, sigma_d=0.5, T_max=1.57, T_min=0.0, enhance_target=False, w_gt=1.0,
                 w_cond=0.0, w_start=0.0, w_end=1.0):
        '''
        TrigFlow formulation from the paper:
        "SIMPLIFYING, STABILIZING & SCALING CONTINUOUS-TIME CONSISTENCY MODELS"
        '''
        self.P_mean = P_mean
        self.P_std = P_std
        super().__init__(sigma_d, T_max, T_min, enhance_target, w_gt, w_cond, w_start, w_end)

    def interpolant(self, t: torch.Tensor):
        alpha_t = torch.cos(t)
        sigma_t = torch.sin(t)
        d_alpha_t = -torch.sin(t)
        d_sigma_t = torch.cos(t)
        return alpha_t, sigma_t, d_alpha_t, d_sigma_t

    def sample_t(self, batch_size, dtype, device):
        rnd_normal = torch.randn((batch_size,), dtype=dtype, device=device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        t = torch.atan(sigma)  # [0, pi/2]
        return t

    def c_noise(self, t: torch.Tensor):
        return t

    def target(
            self,
            x_t: torch.Tensor,
            v_t: torch.Tensor,
            x: torch.Tensor,
            z: torch.Tensor,
            t: torch.Tensor,
            r: torch.Tensor,
            dF_dv_dt: torch.Tensor,
            F_t_cond: Optional[torch.Tensor] = 0.0,
            F_t_uncond: Optional[torch.Tensor] = 0.0,
            enhance_target=False,
    ):
        if enhance_target:
            w_gt = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_gt, 1.0)
            w_cond = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_cond, 0.0)
            v_t = w_gt * v_t + w_cond * F_t_cond + (1 - w_gt - w_cond) * F_t_uncond
        F_target = v_t - torch.tan(t - r) * (x_t + dF_dv_dt)
        return F_target

    def from_x_t_to_x_r(self, x_t: torch.Tensor, t: torch.Tensor, r: torch.Tensor, F: torch.Tensor, s_ratio=0.0):
        x_r = torch.cos(t - r) * x_t - torch.sin(t - r) * F
        return x_r


class EDM(Transport):
    '''
    EDM-style transport with alpha_t = 1 / sqrt(t**2 + sigma_d**2) and
    sigma_t = t / sqrt(t**2 + sigma_d**2), for t > 0.
    '''

    def __init__(self, P_mean=-1.2, P_std=1.2, sigma_d=0.5, T_max=80.0, T_min=0.01, enhance_target=False, w_gt=1.0,
                 w_cond=0.0, w_start=0.0, w_end=1.0):
        '''
        EDM formulation from the paper:
        "Elucidating the Design Space of Diffusion-Based Generative Models"

        Args:
            P_mean: mean of the normal variable whose exponential is the sampled time.
            P_std: standard deviation of that normal variable.
            sigma_d: scale constant used throughout the closed-form expressions
                (presumably the data standard deviation -- TODO confirm).
            T_max, T_min, enhance_target, w_gt, w_cond, w_start, w_end:
                forwarded unchanged to `Transport.__init__`.
        '''
        self.P_mean = P_mean
        self.P_std = P_std
        super().__init__(sigma_d, T_max, T_min, enhance_target, w_gt, w_cond, w_start, w_end)

    def interpolant(self, t: torch.Tensor):
        '''
        Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for the EDM path.

        The derivatives returned here match the definitions used in the code
        (note: alpha_t has NO sigma_d factor in its numerator) and can be
        re-derived with:
        # from sympy import *
        # t, sigma_d = symbols('t sigma_d')
        # alpha_t = (t**2 + sigma_d**2) ** (-0.5)
        # sigma_t = t * ((t**2 + sigma_d**2) ** (-0.5))
        # d_alpha_t = diff(alpha_t, t)
        # d_sigma_t = diff(sigma_t, t)
        # print(d_alpha_t)
        # print(d_sigma_t)
        '''
        sigma_d = self.sigma_d
        alpha_t = 1 / (t ** 2 + sigma_d ** 2).sqrt()
        sigma_t = t / (t ** 2 + sigma_d ** 2).sqrt()
        # d/dt of the two coefficients above.
        d_alpha_t = -t / ((sigma_d ** 2 + t ** 2) ** 1.5)
        d_sigma_t = (sigma_d ** 2) / ((sigma_d ** 2 + t ** 2) ** 1.5)
        return alpha_t, sigma_t, d_alpha_t, d_sigma_t

    def sample_t(self, batch_size, dtype, device):
        '''Draw `batch_size` log-normal times: exp(N(P_mean, P_std**2)), shape (batch_size,).'''
        rnd_normal = torch.randn((batch_size,), dtype=dtype, device=device)
        sigma = (rnd_normal * self.P_std + self.P_mean).exp()
        t = sigma  # t > 0
        return t

    def c_noise(self, t: torch.Tensor):
        '''Time conditioning log(t) / 4; finite only for t > 0.'''
        return torch.log(t) / 4

    def target(
            self,
            x_t: torch.Tensor,
            v_t: torch.Tensor,
            x: torch.Tensor,
            z: torch.Tensor,
            t: torch.Tensor,
            r: torch.Tensor,
            dF_dv_dt: torch.Tensor,
            F_t_cond: Optional[torch.Tensor] = 0.0,
            F_t_uncond: Optional[torch.Tensor] = 0.0,
            enhance_target=False,
    ):
        '''
        Build the regression target for the EDM formulation.

        The result is
            diffusion_target + Bt_dv_dBt * (d_alpha_hat_t * x + d_sigma_hat_t * z - dF_dv_dt)
        where diffusion_target = alpha_hat_t * x + sigma_hat_t * z, optionally
        blended with `F_t_cond` / `F_t_uncond` when the `enhance_target`
        argument is true (weights w_gt / w_cond apply where
        w_start <= t <= w_end; elsewhere the unblended target is kept).

        `x_t` and `v_t` are not used by this formulation, and
        `self.enhance_target` is not read -- only the argument is.
        '''
        sigma_d = self.sigma_d
        # Coefficients of x and z in the base target ...
        alpha_hat_t = t / (sigma_d * (t ** 2 + sigma_d ** 2).sqrt())
        sigma_hat_t = - sigma_d / (t ** 2 + sigma_d ** 2).sqrt()
        # ... and their closed-form derivatives with respect to t.
        d_alpha_hat_t = -t ** 2 / (sigma_d * (sigma_d ** 2 + t ** 2) ** (3 / 2)) + 1 / (
                    sigma_d * (sigma_d ** 2 + t ** 2).sqrt())
        d_sigma_hat_t = sigma_d * t / ((sigma_d ** 2 + t ** 2) ** (3 / 2))
        diffusion_target = alpha_hat_t * x + sigma_hat_t * z
        # Scalar factor on the correction term; its numerator carries (t - r),
        # so the correction vanishes when r == t.
        # NOTE(review): the `sigma_d ** 3` terms here (and in from_x_t_to_x_r)
        # sit next to `sigma_d ** 2` terms; the two coincide only for
        # sigma_d == 1 -- confirm the exponent against the derivation.
        Bt_dv_dBt = (t - r) * (sigma_d ** 2 + t ** 2) * (sigma_d ** 3 + t ** 2) / (
                2 * t * (r - t) * (sigma_d ** 2 + t ** 2) - t * (r - t) * (sigma_d ** 3 + t ** 2) + (
                    sigma_d ** 2 + t ** 2) * (sigma_d ** 3 + t ** 2)
        )
        if enhance_target:
            # Inside [w_start, w_end] use (w_gt, w_cond); outside fall back to (1, 0).
            w_gt = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_gt, 1.0)
            w_cond = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_cond, 0.0)
            diffusion_target = w_gt * diffusion_target + w_cond * F_t_cond + (1 - w_gt - w_cond) * F_t_uncond
        F_target = diffusion_target + Bt_dv_dBt * (d_alpha_hat_t * x + d_sigma_hat_t * z - dF_dv_dt)
        return F_target

    def from_x_t_to_x_r(self, x_t: torch.Tensor, t: torch.Tensor, r: torch.Tensor, F: torch.Tensor, s_ratio=0.0):
        '''
        Move `x_t` from time `t` to time `r` as the linear map A_t * x_t + B_t * F.

        For r == t the coefficients reduce to A_t = 1 and B_t = 0, so `x_t` is
        returned unchanged. `s_ratio` is accepted for signature parity and ignored.
        '''
        sigma_d = self.sigma_d
        # Common factor shared by both coefficients.
        ratio = (t ** 2 + sigma_d ** 2).sqrt() / (r ** 2 + sigma_d ** 2).sqrt() / (sigma_d ** 3 + t ** 2)
        A_t = (sigma_d ** 3 + t * r) * ratio
        B_t = (sigma_d ** 2) * (t - r) * ratio
        x_r = A_t * x_t + B_t * F
        return x_r


class VP_SDE(Transport):
    def __init__(self, beta_min=0.1, beta_d=19.9, epsilon_t=1e-5, T=1000, sigma_d=1.0, enhance_target=False, w_gt=1.0,
                 w_cond=0.0, w_start=0.0, w_end=1.0):
        '''
        Variance preserving (VP) formulation from the paper:
        "Score-Based Generative Modeling through Stochastic Differential Equations".
        '''
        self.beta_min = beta_min
        self.beta_d = beta_d
        self.epsilon_t = epsilon_t
        self.T = T
        super().__init__(sigma_d, 1.0, epsilon_t, enhance_target, w_gt, w_cond, w_start, w_end)

    def interpolant(self, t: torch.Tensor):
        '''
        The d_alpha_t and d_sigma_t are easy to obtain:
        # from sympy import *
        # from scipy.stats import *
        # t, beta_d, beta_min = symbols('t beta_d beta_min')
        # sigma = sqrt(exp(0.5 * beta_d * (t ** 2) + beta_min * t) - 1)
        # d_sigma_d_t = diff(sigma, t)
        # print(d_sigma_d_t)
        # sigma = symbols('sigma')
        # alpha_t = (sigma**2 + 1) ** (-0.5)
        # sigma_t = sigma * (sigma**2 + 1) ** (-0.5)
        # d_alpha_d_sigma = diff(alpha_t, sigma)
        # print(d_alpha_d_sigma)
        # d_sigma_d_sigma = diff(sigma_t, sigma)
        # print(d_sigma_d_sigma)
        '''
        beta_t = self.beta(t)
        alpha_t = 1 / torch.sqrt(beta_t ** 2 + 1)
        sigma_t = beta_t / torch.sqrt(beta_t ** 2 + 1)
        d_alpha_t = -0.5 * (self.beta_d * t + self.beta_min) / (beta_t ** 2 + 1).sqrt()
        d_sigma_t = 0.5 * (self.beta_d * t + self.beta_min) / (beta_t * (beta_t ** 2 + 1).sqrt())
        return alpha_t, sigma_t, d_alpha_t, d_sigma_t

    def beta(self, t: torch.Tensor):
        return torch.sqrt((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1)

    def sample_t(self, batch_size, dtype, device):
        rnd_uniform = torch.rand((batch_size,), dtype=dtype, device=device)
        t = 1 + rnd_uniform * (self.epsilon_t - 1)  # [epsilon_t, 1]
        return t

    def c_noise(self, t: torch.Tensor):
        return (self.T - 1) * t

    def target(
            self,
            x_t: torch.Tensor,
            v_t: torch.Tensor,
            x: torch.Tensor,
            z: torch.Tensor,
            t: torch.Tensor,
            r: torch.Tensor,
            dF_dv_dt: torch.Tensor,
            F_t_cond: Optional[torch.Tensor] = 0.0,
            F_t_uncond: Optional[torch.Tensor] = 0.0,
            enhance_target=False,
    ):
        if enhance_target:
            w_gt = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_gt, 1.0)
            w_cond = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_cond, 0.0)
            z = w_gt * z + w_cond * F_t_cond + (1 - w_gt - w_cond) * F_t_uncond
        beta_t = self.beta(t)
        beta_r = self.beta(r)
        d_beta_t = (self.beta_d * t + self.beta_min) * (beta_t ** 2 + 1) / (2 * beta_t)
        F_target = z - dF_dv_dt * (beta_t - beta_r) / d_beta_t
        return F_target

    def from_x_t_to_x_r(self, x_t: torch.Tensor, t: torch.Tensor, r: torch.Tensor, F: torch.Tensor, s_ratio=0.0):
        beta_t = self.beta(t)
        beta_r = self.beta(r)
        A_t = (beta_t ** 2 + 1).sqrt() / (beta_r ** 2 + 1).sqrt()
        B_t = (beta_r - beta_t) / (beta_r ** 2 + 1).sqrt()
        x_r = A_t * x_t + B_t * F
        return x_r


class VE_SDE(Transport):
    def __init__(self, sigma_min=0.02, sigma_max=100, sigma_d=1.0, enhance_target=False, w_gt=1.0, w_cond=0.0,
                 w_start=0.0, w_end=1.0):
        '''
        Variance exploding (VE) formulation from the paper:
        "Score-Based Generative Modeling through Stochastic Differential Equations".
        '''
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        super().__init__(sigma_d, sigma_max, sigma_min, enhance_target, w_gt, w_cond, w_start, w_end)

    def interpolant(self, t: torch.Tensor):
        alpha_t = 1
        sigma_t = t
        d_alpha_t = 0
        d_sigma_t = 1
        return alpha_t, sigma_t, d_alpha_t, d_sigma_t

    def sample_t(self, batch_size, dtype, device):
        rnd_uniform = torch.rand((batch_size,), dtype=dtype, device=device)
        t = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform)  # [sigma_min, sigma_max]
        return t

    def c_noise(self, t: torch.Tensor):
        return torch.log(0.5 * t)

    def target(
            self,
            x_t: torch.Tensor,
            v_t: torch.Tensor,
            x: torch.Tensor,
            z: torch.Tensor,
            t: torch.Tensor,
            r: torch.Tensor,
            dF_dv_dt: torch.Tensor,
            F_t_cond: Optional[torch.Tensor] = 0.0,
            F_t_uncond: Optional[torch.Tensor] = 0.0,
            enhance_target=False,
    ):
        if enhance_target:
            w_gt = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_gt, 1.0)
            w_cond = torch.where((t >= self.w_start) & (t <= self.w_end), self.w_cond, 0.0)
            z = w_gt * z + w_cond * (-F_t_cond) + (1 - w_gt - w_cond) * (-F_t_uncond)
        F_target = (r - t) * dF_dv_dt - z
        return F_target

    def from_x_t_to_x_r(self, x_t: torch.Tensor, t: torch.Tensor, r: torch.Tensor, F: torch.Tensor, s_ratio=0.0):
        x_r = x_t + (t - r) * F
        return x_r
