import math  # Added import for math.pi

import numpy as np
import torch


def expand_t_like_x(t, x_cur):
    """Function to reshape time t to broadcastable dimension of x
    Args:
      t: [batch_dim,], time vector
      x: [batch_dim,...], data point
    """
    dims = [1] * (len(x_cur.size()) - 1)
    t = t.view(t.size(0), *dims)
    return t


def get_score_from_velocity(vt, xt, t, path_type="linear"):
    """Wrapper function: transfrom velocity prediction model to score
    Args:
        velocity: [batch_dim, ...] shaped tensor; velocity model output
        x: [batch_dim, ...] shaped tensor; x_t data point
        t: [batch_dim,] time tensor
    """
    t = expand_t_like_x(t, xt)
    if path_type == "linear":
        alpha_t, d_alpha_t = 1 - t, torch.ones_like(xt, device=xt.device) * -1
        sigma_t, d_sigma_t = t, torch.ones_like(xt, device=xt.device)
    elif path_type == "cosine":
        alpha_t = torch.cos(t * np.pi / 2)
        sigma_t = torch.sin(t * np.pi / 2)
        d_alpha_t = -np.pi / 2 * torch.sin(t * np.pi / 2)
        d_sigma_t = np.pi / 2 * torch.cos(t * np.pi / 2)
    else:
        raise NotImplementedError

    mean = xt
    reverse_alpha_ratio = alpha_t / d_alpha_t
    var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
    score = (reverse_alpha_ratio * vt - mean) / var

    return score


def compute_diffusion(t_cur):
    return 2 * t_cur


def euler_sampler(
    model,
    latents,
    y,
    chs=None,
    cts=None,
    num_steps=20,
    heun=False,
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",  # not used, just for compatability
):
    # setup conditioning
    if cfg_scale > 1.0:
        y_null = torch.tensor([1139] * y.size(0), device=y.device)
        if chs is not None:
            ch_null = torch.tensor([6] * chs.size(0), device=chs.device)
        if cts is not None:
            ct_null = torch.tensor([4] * cts.size(0), device=cts.device)
    _dtype = latents.dtype
    t_steps = torch.linspace(1, 0, num_steps + 1, dtype=torch.float64)
    x_next = latents.to(torch.float64)
    device = x_next.device
    with torch.no_grad():
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
            x_cur = x_next
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                model_input = torch.cat([x_cur] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
                if chs is not None:
                    ch_cur = torch.cat([chs, ch_null], dim=0)
                else:
                    ch_cur = None
                if cts is not None:
                    ct_cur = torch.cat([cts, ct_null], dim=0)
                else:
                    ct_cur = None
            else:
                model_input = x_cur
                y_cur = y
                if chs is not None:
                    ch_cur = chs
                else:
                    ch_cur = None
                if cts is not None:
                    ct_cur = cts
                else:
                    ct_cur = None
            kwargs = dict(y=y_cur, ch=ch_cur, ct=ct_cur)
            time_input = (
                torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64)
                * t_cur
            )
            d_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs
            )[0].to(torch.float64)
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)
            x_next = x_cur + (t_next - t_cur) * d_cur
            if heun and (i < num_steps - 1):
                if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                    model_input = torch.cat([x_next] * 2)
                    y_cur = torch.cat([y, y_null], dim=0)
                    if chs is not None:
                        ch_cur = torch.cat([chs, ch_null], dim=0)
                    else:
                        ch_cur = None
                    if cts is not None:
                        ct_cur = torch.cat([cts, ct_null], dim=0)
                    else:
                        ct_cur = None
                else:
                    model_input = x_next
                    y_cur = y
                    if chs is not None:
                        ch_cur = chs
                    else:
                        ch_cur = None
                    if cts is not None:
                        ct_cur = cts
                    else:
                        ct_cur = None
                kwargs = dict(y=y_cur, ch=ch_cur, ct=ct_cur)
                time_input = (
                    torch.ones(model_input.size(0)).to(
                        device=model_input.device, dtype=torch.float64
                    )
                    * t_next
                )
                d_prime = model(
                    model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs
                )[0].to(torch.float64)
                if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                    d_prime_cond, d_prime_uncond = d_prime.chunk(2)
                    d_prime = d_prime_uncond + cfg_scale * (
                        d_prime_cond - d_prime_uncond
                    )
                x_next = x_cur + (t_next - t_cur) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next


def euler_maruyama_sampler(
    model,
    latents,
    y,
    chs=None,
    cts=None,
    num_steps=20,
    heun=False,  # not used, just for compatability
    cfg_scale=1.0,
    guidance_low=0.0,
    guidance_high=1.0,
    path_type="linear",
):
    # setup conditioning
    if cfg_scale > 1.0:
        y_null = torch.tensor([1139] * y.size(0), device=y.device)
        if chs is not None:
            chs_null = torch.tensor([6] * chs.size(0), device=chs.device)
        if cts is not None:
            cts_null = torch.tensor([4] * cts.size(0), device=cts.device)

    _dtype = latents.dtype

    t_steps = torch.linspace(1.0, 0.04, num_steps, dtype=torch.float64)
    t_steps = torch.cat([t_steps, torch.tensor([0.0], dtype=torch.float64)])
    x_next = latents.to(torch.float64)
    device = x_next.device

    with torch.no_grad():
        for i, (t_cur, t_next) in enumerate(zip(t_steps[:-2], t_steps[1:-1])):
            dt = t_next - t_cur
            x_cur = x_next
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                model_input = torch.cat([x_cur] * 2, dim=0)
                y_cur = torch.cat([y, y_null], dim=0)
                if chs is not None:
                    ch_cur = torch.cat([chs, chs_null], dim=0)
                else:
                    ch_cur = None
                if cts is not None:
                    ct_cur = torch.cat([cts, cts_null], dim=0)
                else:
                    ct_cur = None
            else:
                model_input = x_cur
                y_cur = y
                if chs is not None:
                    ch_cur = chs
                else:
                    ch_cur = None
                if cts is not None:
                    ct_cur = cts
                else:
                    ct_cur = None
            kwargs = dict(y=y_cur, ch=ch_cur, ct=ct_cur)
            time_input = (
                torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64)
                * t_cur
            )
            diffusion = compute_diffusion(t_cur)
            eps_i = torch.randn_like(x_cur).to(device)
            deps = eps_i * torch.sqrt(torch.abs(dt))

            # compute drift
            v_cur = model(
                model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs
            )[0].to(torch.float64)
            s_cur = get_score_from_velocity(
                v_cur, model_input, time_input, path_type=path_type
            )
            d_cur = v_cur - 0.5 * diffusion * s_cur
            if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
                d_cur_cond, d_cur_uncond = d_cur.chunk(2)
                d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

            x_next = x_cur + d_cur * dt + torch.sqrt(diffusion) * deps

    # last step
    t_cur, t_next = t_steps[-2], t_steps[-1]
    dt = t_next - t_cur
    x_cur = x_next
    if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
        model_input = torch.cat([x_cur] * 2, dim=0)
        y_cur = torch.cat([y, y_null], dim=0)
        if chs is not None:
            ch_cur = torch.cat([chs, chs_null], dim=0)
        else:
            ch_cur = None
        if cts is not None:
            ct_cur = torch.cat([cts, cts_null], dim=0)
        else:
            ct_cur = None
    else:
        model_input = x_cur
        y_cur = y
        if chs is not None:
            ch_cur = chs
        else:
            ch_cur = None
        if cts is not None:
            ct_cur = cts
        else:
            ct_cur = None
    kwargs = dict(y=y_cur, ch=ch_cur, ct=ct_cur)
    time_input = (
        torch.ones(model_input.size(0)).to(device=device, dtype=torch.float64) * t_cur
    )

    # compute drift
    v_cur = model(model_input.to(dtype=_dtype), time_input.to(dtype=_dtype), **kwargs)[
        0
    ].to(torch.float64)
    s_cur = get_score_from_velocity(v_cur, model_input, time_input, path_type=path_type)
    diffusion = compute_diffusion(t_cur)
    d_cur = v_cur - 0.5 * diffusion * s_cur
    if cfg_scale > 1.0 and t_cur <= guidance_high and t_cur >= guidance_low:
        d_cur_cond, d_cur_uncond = d_cur.chunk(2)
        d_cur = d_cur_uncond + cfg_scale * (d_cur_cond - d_cur_uncond)

    mean_x = x_cur + dt * d_cur

    return mean_x.to(latents.dtype)
