from tqdm import tqdm
from typing import Callable

import torch
import numpy as np
from torch.distributions import Categorical
from samplers import bridge_statistics

from algorithms import cfg_denoiser


def _fk_resample_indices(log_weights: torch.Tensor, n_draws: int) -> torch.Tensor:
    """Multinomially resample ``n_draws`` particles (with replacement) per context.

    Args:
        log_weights: ``(n_ctx, n_particles)`` unnormalized log-weights. Row ``c``
            scores the particles stored in rows
            ``[c * n_particles, (c + 1) * n_particles)`` of the flat particle batch.
        n_draws: number of particles drawn for every context.

    Returns:
        ``(n_ctx * n_draws,)`` integer indices into the flat particle batch,
        grouped by context (all draws of context 0 first, then context 1, ...).
    """
    n_ctx, n_particles = log_weights.shape
    # sample() returns (n_draws, n_ctx); transpose so a context's draws are contiguous
    local_idx = Categorical(logits=log_weights).sample((n_draws,)).T
    # local indices live in [0, n_particles); shift them onto the context's rows.
    # Built on the weights' device so the addition also works on GPU.
    offsets = torch.arange(n_ctx, device=log_weights.device) * n_particles
    return (local_idx + offsets[:, None]).reshape(-1)


def _fk_cfg_log_potentials(
    sq_gap: torch.Tensor,
    cfg_scale: float,
    sigma_t: torch.Tensor,
    sigma_t_prev: torch.Tensor,
) -> torch.Tensor:
    """Return the per-particle log-weights of one corrector transition.

    The log-weight is ``coef * sq_gap`` with
    ``coef = w (w - 1) (sigma_t^2 - sigma_t_prev^2) / (2 sigma_t^2 sigma_t_prev^2)``
    and ``w = cfg_scale``.

    ``coef`` is undefined when ``sigma_t_prev == 0`` (division by zero gives
    non-finite logits). In that case the limit of the induced resampling
    distribution is returned instead: uniform over the particles with the
    largest gap (``coef -> +inf``), with the smallest gap (``coef -> -inf``),
    or uniform over all particles when ``w (w - 1) == 0``.

    Args:
        sq_gap: ``(n_ctx, n_particles)`` squared norm of ``D_cond - D_uncond``.
        cfg_scale: guidance scale ``w``.
        sigma_t: current noise level.
        sigma_t_prev: next (smaller) noise level.
    """
    scale_term = cfg_scale * (cfg_scale - 1)
    if sigma_t_prev > 0:
        coef = (
            scale_term
            * (sigma_t**2 - sigma_t_prev**2)
            / (2 * sigma_t**2 * sigma_t_prev**2)
        )
        return coef * sq_gap

    if scale_term == 0:
        return torch.zeros_like(sq_gap)
    # degenerate limit: all the mass goes to the extreme particle(s) of each context
    if scale_term > 0:
        extreme = sq_gap.amax(dim=-1, keepdim=True)
    else:
        extreme = sq_gap.amin(dim=-1, keepdim=True)
    return torch.zeros_like(sq_gap).masked_fill(sq_gap != extreme, float("-inf"))


def fk_corrector_sampler(
    initial_noise: torch.Tensor,
    denoiser_fn: Callable,
    ctx: torch.Tensor,
    # --- sampler param
    n_particles: int = 5,
    n_steps: int = 32,
    cfg_scale: float = 2.0,
    # ---
    mk_sigmas_fn: Callable[..., torch.Tensor] = None,
    # --- base sampler
    sigma_min: float = 0.0,
    sigma_max: float = 1e4,
    rho: float = 7.0,
    **kwargs,
):
    """CFG sampler with a resampling (FK-corrector) step after every transition.

    ``n_particles`` particles per context are propagated with the CFG
    prediction through ``bridge_statistics``. After each transition they are
    resampled within their context according to log-weights proportional to
    ``||D_cond - D_uncond||^2``. The last resampling keeps only
    ``n_samples_per_ctx = initial_noise.shape[0] // len(ctx)`` particles per
    context.

    Args:
        initial_noise: only its shape, device and dtype are used; fresh noise
            is drawn for every particle.
        denoiser_fn: called as ``denoiser_fn(x=..., sigma=..., ctx=..., use_cfg=True)``.
            Its output is split along the batch dim into
            ``(unconditional, conditional)`` predictions.
        ctx: conditioning, one entry per context. Particles of context ``c``
            occupy the contiguous rows ``[c * n_particles, (c + 1) * n_particles)``.
            The denoiser is assumed to follow that layout — TODO confirm.
        n_particles: particles per context; must be ``>= n_samples_per_ctx``.
        n_steps, sigma_min, sigma_max, rho: forwarded to ``mk_sigmas_fn``.
        cfg_scale: guidance scale ``w``.
        mk_sigmas_fn: builds the noise schedule. The loop runs over the flipped
            schedule starting at ``sigmas[-1]``, so the schedule is presumably
            ascending — verify against the schedule function in use.
        **kwargs: accepted and ignored.

    Returns:
        Tensor of ``len(ctx) * n_samples_per_ctx`` samples.

    Raises:
        ValueError: if ``n_particles < n_samples_per_ctx``.
        TypeError: if ``mk_sigmas_fn`` is not provided.
    """
    device = initial_noise.device
    n_ctx = len(ctx)
    # all non-batch dims, reduced over when scoring particles
    x_dim = list(range(1, initial_noise.dim()))

    # n_particles must be greater than the number of sample per ctx
    n_samples_per_ctx = initial_noise.shape[0] // n_ctx
    if n_particles < n_samples_per_ctx:
        raise ValueError(
            "``n_particles`` must be ``>=`` than the number of samplers per ctx.\n"
            f"Got {n_particles=} and {n_samples_per_ctx=}"
        )
    if mk_sigmas_fn is None:
        raise TypeError("``mk_sigmas_fn`` is required to build the noise schedule.")

    # fresh noise for every particle, on the caller's device/dtype
    noise = torch.randn(
        (n_particles * n_ctx, *initial_noise.shape[1:]),
        device=device,
        dtype=initial_noise.dtype,
    )

    # init sigmas
    sigmas = mk_sigmas_fn(
        n_steps=n_steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho
    ).to(device)
    sigmas_flipped = sigmas.flip(0)
    n_transitions = len(sigmas_flipped) - 1
    pbar = tqdm(zip(sigmas_flipped[:-1], sigmas_flipped[1:]), total=n_transitions)

    # sampling loop
    x_t = sigmas[-1] * noise
    for step, (sigma_t, sigma_t_prev) in enumerate(pbar):
        # detect the last transition by position: comparing sigma values would
        # fire early if the schedule repeats its final value
        is_last = step == n_transitions - 1

        pred_x0 = denoiser_fn(
            x=x_t,
            sigma=sigma_t,
            ctx=ctx,
            use_cfg=True,
        )
        predx0_uncond, predx0_cond = pred_x0.chunk(2)

        # CFG update: uncond + w * (cond - uncond)
        cfg_pred_x0 = predx0_cond + (cfg_scale - 1) * (predx0_cond - predx0_uncond)
        mean, std = bridge_statistics(x_t, cfg_pred_x0, sigma_t, sigma_t_prev, eta=1.0)
        x_t = mean + std * torch.randn_like(mean)

        # per-particle log-weights, one row per context
        sq_gap = ((predx0_cond - predx0_uncond) ** 2).sum(dim=x_dim)
        log_weights = _fk_cfg_log_potentials(
            sq_gap.reshape(n_ctx, n_particles), cfg_scale, sigma_t, sigma_t_prev
        )

        # resample within each context; the final step keeps only the requested samples
        n_draws = n_samples_per_ctx if is_last else n_particles
        x_t = x_t[_fk_resample_indices(log_weights, n_draws)]

    return x_t


def _particle_repulsion_grad(
    latents: torch.Tensor,
    n_groups: int,
    sigma,
    coeff: float,
    power: float,
) -> torch.Tensor:
    """Return the particle-guidance repulsive term, shaped like ``latents``.

    The rows of ``latents`` are split into ``n_groups`` contiguous groups (one
    per context). A particle is only repelled by particles of its own group.
    The kernel and scaling are adapted from
    https://github.com/gcorso/particle-guidance/blob/ce7de745191168c10b2c125b78d94ad8de119488/stable_diffusion/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_particle_copy.py#L784-L809

    ``sigma`` is expected to be a scalar (Python number or 0-d tensor).
    """
    flat = latents.reshape(len(latents), -1)
    # (G, M, d); raises if the batch does not split evenly into groups
    vecs = flat.reshape(n_groups, -1, flat.shape[-1])
    n_per_group = vecs.shape[1]
    if n_per_group < 2:
        # a lone particle has nothing to be repelled from, and the bandwidth
        # below would be 0/0 = nan
        return torch.zeros_like(latents)

    # (G, M, M, d): diff[g, i, j] = x_i - x_j
    diff = vecs.unsqueeze(2) - vecs.unsqueeze(1)
    # (G, M, M, 1)
    distance = torch.norm(diff, p=2, dim=-1, keepdim=True)
    # per-particle bandwidth, (G, M, 1, 1); the M/(M-1) factor removes the
    # zero self-distance from the mean
    h_t = (
        distance.mean(dim=2, keepdim=True) * n_per_group / (n_per_group - 1)
    ) ** 2 / np.log(n_per_group)
    # guard against coincident particles (h_t == 0 -> 0/0); no effect otherwise
    h_t = h_t.clamp_min(torch.finfo(h_t.dtype).tiny)
    weights = torch.exp(-(distance**power / h_t))
    grad_phi = 2 * weights * diff / h_t * 2 * sigma * coeff
    return grad_phi.sum(dim=2).reshape(latents.shape)


def particle_guidance(
    initial_noise: torch.Tensor,
    denoiser_fn: Callable,
    ctx: torch.Tensor,
    mk_sigmas_fn: Callable[..., torch.Tensor],
    n_steps: int,
    cfg_scale: float,
    # --- algo params
    coeff: float = 30,
    power: float = 2,
    n_particles: int = 4,
    # ---
    sampler: Callable = None,
    **sampler_kwargs,
):
    """Sample with particle guidance: jointly sampled particles repel each other.

    For each context, ``n_particles`` particles are drawn from fresh Gaussian
    noise and integrated with ``sampler``. The denoiser seen by the sampler adds
    a repulsive term between particles of the same context to the CFG
    prediction. The first ``n_samples_per_ctx = initial_noise.shape[0] // len(ctx)``
    particles of every context are returned.

    Args:
        initial_noise: only its shape, device and dtype are used.
        denoiser_fn: base denoiser, wrapped with ``cfg_denoiser``.
        ctx: conditioning, one entry per context. Particles of context ``c``
            occupy the rows ``[c * n_particles, (c + 1) * n_particles)``.
        mk_sigmas_fn, n_steps, **sampler_kwargs: forwarded to ``sampler``.
        cfg_scale: guidance scale passed to ``cfg_denoiser``. The
            ``cfg_cond_fn`` given to it returns ``scale > 1.0``.
        coeff: strength of the repulsive term.
        power: exponent of the distance in the kernel ``exp(-d**power / h_t)``.
        n_particles: particles per context; must be ``>= n_samples_per_ctx``.
        sampler: base sampler. It is assumed to call the denoiser with all
            particles and to return them in the row order of its
            ``initial_noise`` — TODO confirm per sampler.

    Returns:
        Tensor of shape ``(len(ctx) * n_samples_per_ctx, *initial_noise.shape[1:])``.

    Raises:
        ValueError: if ``n_particles < n_samples_per_ctx``.
        TypeError: if ``sampler`` is not provided.
    """
    device = initial_noise.device
    x_shape = initial_noise.shape[1:]

    # n_particles must be greater than the number of sample per ctx
    n_samples_per_ctx = initial_noise.shape[0] // len(ctx)
    if n_particles < n_samples_per_ctx:
        raise ValueError(
            "``n_particles`` must be ``>=`` than the number of samplers per ctx.\n"
            f"Got {n_particles=} and {n_samples_per_ctx=}"
        )
    if sampler is None:
        raise TypeError("``sampler`` is required to run particle guidance.")

    cfg_denoiser_fn = cfg_denoiser(
        denoiser_fn,
        cfg_cond_fn=lambda sigma, scale: scale > 1.0,
        cfg_scale=cfg_scale,
    )

    def repulsive_denoiser(x, sigma, ctx):
        # repulsion only acts between particles of the same context
        grad_phi = _particle_repulsion_grad(x, len(ctx), sigma, coeff, power)
        pred_x0 = cfg_denoiser_fn(x, sigma, ctx)

        # The reference derives ``grad_phi`` for eps-prediction and updates
        # ``eps' = eps - grad_phi``. With ``x0 = x - sigma * eps`` this becomes
        # ``x0' = x0 + sigma * grad_phi`` for the denoiser.
        return pred_x0 + sigma * grad_phi

    # augment initial noise with n_particles per ctx
    noise = torch.randn(
        (len(ctx) * n_particles, *x_shape), device=device, dtype=initial_noise.dtype
    )

    samples = sampler(
        initial_noise=noise,
        denoiser_fn=repulsive_denoiser,
        ctx=ctx,
        mk_sigmas_fn=mk_sigmas_fn,
        n_steps=n_steps,
        **sampler_kwargs,
    )

    # keep ``n_samples_per_ctx`` particles of every ctx
    samples = samples.reshape(len(ctx), n_particles, *x_shape)
    samples = samples[:, :n_samples_per_ctx]
    return samples.reshape(-1, *x_shape)
