import logging
import torch
from tqdm import tqdm


__all__ = ['SDE_sampler_manifolds']


@torch.no_grad()
def SDE_sampler_manifolds(sde,
                          manifold,
                          init,
                          reverse,
                          score_net=None,
                          keep_quiet=False,
                          **kwargs):
    """Simulate an SDE on a manifold with a projected Euler-Maruyama scheme.

    Each of the ``sde.N`` steps draws Gaussian noise ``z``, forms the increment
    ``drift * dt + diffusion * sqrt(|dt|) * z``, projects it onto the tangent
    space at the current point and maps it back onto the manifold with
    ``manifold.project_onto_manifold_with_base``. A trajectory is kept only if
    that projection reported convergence at every step.

    Args:
        sde: object exposing ``N`` (number of steps), ``T`` (terminal time),
            ``sde(x, t) -> (drift, diffusion)`` and ``reverse(score_net)``
            (whose result exposes the same ``sde`` method).
        manifold: object exposing ``project_onto_tangent_space(v, base_point=)``
            and ``project_onto_manifold_with_base(v, base_point=, keep_quiet=)``,
            the latter returning ``(new_point, converged_mask)``.
        init: starting points; the first dimension is the batch size ``bsz``.
        reverse: if True, integrate the reverse-time SDE ``sde.reverse(score_net)``
            from ``sde.T`` down to 0; otherwise integrate ``sde`` from 0 to ``sde.T``.
        score_net: forwarded to ``sde.reverse``; only used when ``reverse`` is True.
        keep_quiet: disables the progress bar and the final log message and is
            forwarded to the manifold projection.
        **kwargs: accepted for interface compatibility; unused.

    Returns:
        ``(x_final, x_hist, other_dict)`` where ``x_final`` holds the final states
        of the converged trajectories, ``x_hist`` has shape
        ``[N+1, n_converged, ...]`` (index 0 is ``init``), and ``other_dict`` has
        ``"tangent_vec_hist"`` ``[N, bsz, ...]``, ``"std_hist"`` ``[N, bsz]``
        (per-step noise scale), ``"converged_traj"`` (``[bsz]`` bool mask) and
        ``"x_hist_all"`` ``[N+1, bsz, ...]`` (unfiltered history).
    """
    device = init.device
    bsz = init.shape[0]
    N = sde.N

    if reverse:
        drift_diffusion_fn = sde.reverse(score_net).sde
        t_start, t_end = sde.T, 0.
    else:
        drift_diffusion_fn = sde.sde
        t_start, t_end = 0., sde.T
    # One row of per-sample times per grid point: shape [N+1, bsz].
    timesteps = (torch.linspace(t_start, t_end, N + 1, device=device)
                 .reshape(-1, 1)
                 .repeat(1, bsz))

    def predictor_step(x, t, delta_t):
        """One projected Euler-Maruyama step.

        Returns ``(x_next, converged, std, tangent_vec)``; history bookkeeping
        is left to the caller so this helper reads no loop state.
        """
        # View per-sample scalars as [bsz, 1, ..., 1] so they broadcast against x
        # of any rank (identical to the former [:, None] handling for 2-D/3-D x).
        bcast_shape = (-1,) + (1,) * (x.dim() - 1)
        delta_t = delta_t.reshape(bcast_shape)
        z = torch.randn_like(x, device=device)
        drift, diffusion = drift_diffusion_fn(x, t)
        # NOTE(review): assumes diffusion holds one value per sample ([bsz]).
        std = diffusion.reshape(bcast_shape) * torch.sqrt(torch.abs(delta_t))
        tangent_vec = manifold.project_onto_tangent_space(drift * delta_t + std * z, base_point=x)
        x_next, converged = manifold.project_onto_manifold_with_base(
            tangent_vec, base_point=x, keep_quiet=keep_quiet)
        return x_next.detach(), converged, std, tangent_vec

    # Allocate histories with init's dtype and device (not the float32 default)
    # so float64 trajectories are not silently truncated.
    x_hist = init.new_zeros((N + 1, *init.shape))
    x_hist[0] = init
    std_hist = init.new_zeros((N, bsz))
    tangent_vec_hist = init.new_zeros((N, *init.shape))
    converged_traj = torch.ones(bsz, dtype=torch.bool, device=device)

    x = init
    for i in tqdm(range(N), mininterval=2., disable=keep_quiet):
        vec_t = timesteps[i]
        dt = timesteps[i + 1] - timesteps[i]
        x, converged_one_step, std, tangent_vec = predictor_step(x, vec_t, dt)
        # A trajectory stays valid only if every step's projection converged.
        converged_traj = torch.logical_and(converged_traj, converged_one_step)
        # Slice assignment copies, so no explicit clone is needed.
        x_hist[i + 1] = x
        std_hist[i] = std.reshape(-1)
        tangent_vec_hist[i] = tangent_vec

    num_unconverged_traj = bsz - int(converged_traj.sum())
    if not keep_quiet:
        logging.info(f'{num_unconverged_traj} of {bsz} trajectories are dropped.')
    other_dict = {
        "tangent_vec_hist": tangent_vec_hist,
        "std_hist": std_hist,
        "converged_traj": converged_traj,
        "x_hist_all": x_hist,
    }
    return x[converged_traj], x_hist[:, converged_traj], other_dict



