"""Elucidating the Design Space of Diffusion-Based Generative Models"""
import torch
import wandb
from tqdm.auto import trange
import pickle
import numpy as np

from ddlm.modeling.diffusion import (
    DiffusionTransformer, DiffusionOutput,
)
from ddlm.sampler.early_exit import LogStrategy, Strategy, NoStrategy


def append_zero(x):
    """Return ``x`` with one zero appended along its first dimension.

    The appended zero is created with ``Tensor.new_zeros`` so it shares the
    dtype and device of ``x``.
    """
    tail = x.new_zeros(1)
    return torch.cat((x, tail), dim=0)


def append_dims(x, target_dims):
    """Return ``x`` indexed with trailing ``None`` axes until it has ``target_dims`` dims.

    Raises:
        ValueError: If ``x`` already has more than ``target_dims`` dimensions.
    """
    n_missing = target_dims - x.ndim
    if n_missing < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    trailing = tuple(None for _ in range(n_missing))
    return x[(Ellipsis, *trailing)]


def to_d(x, sigma, denoised):
    """Turn a denoiser prediction into the Karras ODE derivative.

    Computes ``(x - denoised) / sigma``, where ``sigma`` gets trailing
    singleton dims appended so it broadcasts against ``x``.
    """
    residual = x - denoised
    return residual / append_dims(sigma, x.ndim)


def get_sigmas_karras(
    n, sigma_min, sigma_max, rho=8.0, timedelta=0.0, device="cpu", tw_model=None
):
    """Build the Karras et al. (2022) noise schedule, with a trailing zero.

    A uniform grid on [0, 1] is shifted back by ``timedelta`` and clipped at 0.
    The grid is then mapped to noise levels by interpolating between
    ``sigma_max ** (1/rho)`` and ``sigma_min ** (1/rho)`` and raising the
    result to ``rho``.

    If ``tw_model`` is given, the first element of its output on the grid
    replaces the grid. That output is reversed relative to the grid, so it is
    flipped and complemented, and ``rho`` is forced to 1.

    Returns ``n + 1`` values on ``device``. The last value is 0.
    """
    t = torch.linspace(0, 1, n, device=device) - timedelta
    t = t.masked_fill(t < 0, 0)
    if tw_model is not None:
        warped = tw_model(t)[0]
        t = 1 - warped.flip(0)  # tw_model's output runs in the opposite direction
        rho = 1.0

    inv_rho = 1 / rho
    lo = sigma_min ** inv_rho
    hi = sigma_max ** inv_rho
    sigmas = (hi + t * (lo - hi)) ** rho
    return append_zero(sigmas).to(device)


@torch.no_grad()
def sample_euler(
    model: DiffusionTransformer,
    sigmas,
    input_ids,
    conditioning_mask,
    continuation_number,
    batch_index: int,
    disable=None,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
    self_conditioning: bool = True,
    initial_noise_scale: float = 1.0,
    simplified_inputs: bool = False,
    renormalization: bool = False,
    interpolate: bool = False,
    strategy: Strategy = NoStrategy,
):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022).

    Sampling starts from Gaussian noise at the non-conditioning positions and
    takes one Euler step per consecutive pair in ``sigmas``. The early-exit
    ``strategy`` flags rows in ``state["exit_mask"]``. Flagged rows are no
    longer passed to the model or updated, and their last logits are kept.

    Args:
        model: Provides ``get_embeddings(input_ids)``. It is called once per
            step with ``output_denoised=True``, and its outputs must expose
            ``logits`` and ``denoised``.
        sigmas: Noise levels. ``len(sigmas) - 1`` steps are taken.
        input_ids: Token ids. ``input_ids.size(0)`` is the batch size.
        conditioning_mask: Boolean mask that is True at conditioning
            positions. Those positions keep their embeddings and get no noise.
        continuation_number, batch_index: Used only to name the log file when
            ``strategy`` is a ``LogStrategy``.
        disable: Forwarded to ``trange`` to hide the progress bar.
        s_churn, s_tmin, s_tmax, s_noise: Stochastic churn parameters. Extra
            noise is injected only when ``s_tmin <= sigma <= s_tmax``.
        self_conditioning: Feed the previous step's denoised estimate back
            into the model.
        initial_noise_scale: Multiplier on the initial Gaussian noise.
        simplified_inputs: Carry the conditioning embeddings inside the noisy
            states instead of passing ``conditioning_hidden_states``.
        renormalization: L2-normalize the states along the last dim after
            every step. This also zeroes the conditioning positions.
        interpolate: Overwrite each row's initial noise with a linear
            interpolation between the noise of rows 1 and -2.
        strategy: Early-exit strategy instance. A class (such as the default
            ``NoStrategy``) is instantiated with no arguments.

    Returns:
        A tuple ``(encoded, metrics)``.

        - With a ``LogStrategy``, ``encoded`` is the per-step list of argmax
          token ids, and the log is pickled to
          ``log_{continuation_number}_{batch_index}.pickle``.
        - Otherwise ``encoded`` is a one-element list holding the argmax of
          the last logits computed for every row.
        - ``metrics["observed_steps"]`` is the total number of row
          evaluations the model performed.

    Raises:
        ValueError: If no step was taken (``len(sigmas) < 2``) and no
            ``LogStrategy`` is used.
    """
    # The default is the ``NoStrategy`` class itself, not an instance.
    # Instantiate it so get_initial_state/new_step bind as instance methods.
    # NOTE(review): assumes strategy classes take no constructor arguments.
    if isinstance(strategy, type):
        strategy = strategy()
    is_logging = type(strategy) == LogStrategy

    batch_size = input_ids.size(0)
    s_in = torch.ones(batch_size, device=input_ids.device)
    # Nonzero at positions to generate, zero at conditioning positions.
    free_mask = (~conditioning_mask).unsqueeze(-1)

    embeddings = model.get_embeddings(input_ids)
    c_x = embeddings * conditioning_mask.unsqueeze(-1)
    noise = torch.randn_like(embeddings) * free_mask * initial_noise_scale
    n_x = c_x + noise if simplified_inputs else noise

    state = strategy.get_initial_state(n_x)
    if interpolate:
        # NOTE(review): requires batch_size >= 2. Because lambda_ < 1, no row
        # equals ``start`` exactly. Every row also inherits the masked content
        # of rows 1 and -2 rather than its own.
        start = n_x[-2].clone()
        end = n_x[1].clone()
        for row in range(batch_size):
            lambda_ = row / batch_size
            n_x[row] = lambda_ * start + (1 - lambda_) * end

    denoised = None  # last denoised estimate, one row per still-active row
    observed_steps = 0
    final_logits = None

    log = {
        "patience": [],
        "entropy": [],
        "encoded": [],
        "denoised": [],
        "kl": [],
    }
    prev_denoised = None

    for i in trange(len(sigmas) - 1, disable=disable):
        gamma = (
            min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
            if s_tmin <= sigmas[i] <= s_tmax
            else 0.0
        )
        # eps is drawn every step, even when gamma == 0, so RNG consumption
        # does not depend on the churn settings.
        eps = torch.randn_like(n_x) * s_noise * free_mask
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            additional_noise_var = sigma_hat**2 - sigmas[i] ** 2
            n_x = n_x + eps * additional_noise_var**0.5
            # NOTE(review): this rescaling is not part of Karras et al.'s
            # Algorithm 2. Confirm it is intended.
            scale = 1 / (1 + additional_noise_var) ** 0.5
            n_x = scale * n_x

        # Rows still being sampled at this step (a fresh tensor, unaffected
        # by any in-place update new_step may make to the state).
        active = ~state["exit_mask"]
        observed_steps += active.sum().item()
        # ``denoised`` was already filtered to the active rows at the end of
        # the previous step (or is None on the first step). Pass it as-is.
        self_cond = denoised if self_conditioning else None
        if simplified_inputs:
            outputs = model(
                timestamps=(sigma_hat * s_in)[active],
                conditioning_mask=conditioning_mask[active],
                noisy_hidden_states=n_x[active],
                self_conditioning_hidden_states=self_cond,
                output_denoised=True,
            )
        else:
            outputs = model(
                timestamps=(sigma_hat * s_in)[active],
                conditioning_mask=conditioning_mask[active],
                conditioning_hidden_states=c_x[active],
                noisy_hidden_states=n_x[active],
                self_conditioning_hidden_states=self_cond,
                output_denoised=True,
            )

        if final_logits is None:
            # Full-batch buffer, so rows that exit early keep their last logits.
            final_logits = outputs.logits.new_zeros(
                (batch_size, *outputs.logits.shape[1:])
            )
        final_logits[active] = outputs.logits

        state = strategy.new_step(outputs=outputs, state=state)
        still_active = ~state["exit_mask"]
        if not still_active.any():
            break

        # Select the rows of this step's outputs that stay active next step.
        keep = still_active[active]
        denoised = outputs.denoised[keep] * free_mask[still_active]

        d = to_d(n_x[still_active], sigma_hat, denoised)
        if simplified_inputs:
            # Keep the conditioning tokens stored in n_x fixed. The mask is
            # restricted to active rows so shapes match after early exits.
            d = d * free_mask[still_active]
        dt = sigmas[i + 1] - sigma_hat
        # Euler method, applied to the active rows only.
        n_x[still_active] = n_x[still_active] + d * dt

        if renormalization:
            # NOTE(review): multiplying by free_mask also wipes the
            # conditioning embeddings that simplified_inputs keeps in n_x.
            n_x = torch.nn.functional.normalize(n_x, dim=-1) * free_mask

        if is_logging:
            if prev_denoised is None:
                prev_denoised = outputs.denoised
            curr = outputs.denoised.cpu().numpy()
            prev = prev_denoised.cpu().numpy()
            # Per-row Frobenius norm (over axes 1 and 2) of the change in the
            # denoised estimate since the previous step.
            denoised_diff = np.linalg.norm(curr - prev, axis=(1, 2))
            log["patience"].append(state["patience"].cpu().tolist())
            log["entropy"].append(state["entropy"].cpu().tolist())
            log["denoised"].append(denoised_diff)
            log["kl"].append(state["kl"])
            log["encoded"].append(outputs.logits.argmax(-1).cpu().tolist())
            prev_denoised = outputs.denoised

    metrics = {"observed_steps": observed_steps}

    if is_logging:
        log_path = f"log_{continuation_number}_{batch_index}.pickle"
        with open(log_path, "wb") as f:
            pickle.dump(log, f)
        # Register with wandb only after the file is closed and fully written.
        wandb.save(log_path)
        return log["encoded"], metrics

    if final_logits is None:
        raise ValueError(
            "sigmas must contain at least two entries to take a sampling step"
        )
    encoded = final_logits.argmax(-1)

    return [encoded], metrics