import torch

from .absorbing import DiffusionCore
from functools import partial
from core.sampling import AncestralSampler, AnalyticSampler
from tqdm import trange
from data.utils import params2key
from pathlib import Path
import os

from loguru import logger
import os

import torch

from pathlib import Path
from loguru import logger
from pathlib import Path
from transformers import AutoTokenizer

from data import dataloader
from lightning.fabric import Fabric
from utils import str_to_dtype
from tqdm import trange
import numpy as np
from einops import rearrange
import lightning as L
from itertools import product
import pandas as pd


@torch.jit.script
def post_process_model_output(logits: torch.Tensor, xt: torch.Tensor, mask_index: int, neg_infty: float) -> torch.Tensor:
    """Turn backbone logits into log-probabilities that respect observed tokens.

    Positions of ``xt`` that are not ``mask_index`` are made deterministic: every
    vocabulary entry gets ``neg_infty`` except the observed token, which gets 0.
    Masked positions are only normalized; the mask token's own logit is not
    suppressed here. NOTE: ``logits`` is modified in place before normalization.
    """
    observed = xt != mask_index
    observed_tokens = xt[observed]
    logits[observed] = neg_infty
    logits[observed, observed_tokens] = 0
    # TODO: a torch.where-based variant of the two assignments above might be
    # faster; validate it against a trained model before switching.
    return torch.log_softmax(logits, dim=-1)

class MDLM(DiffusionCore, AncestralSampler, AnalyticSampler):
    """Masked diffusion language model (continuous-time only).

    Combines ``DiffusionCore`` (from the ``absorbing`` module) with the
    ancestral and analytic samplers. Backbone logits are post-processed so that
    positions that are already unmasked always predict their current token.
    """

    def __init__(self, config, tokenizer):
        DiffusionCore.__init__(self, config, tokenizer)
        AncestralSampler.__init__(self, config)
        AnalyticSampler.__init__(self, config)
        self.validate_config()
        self.log_loss_buckets = self.config.parameterization.log_loss_buckets

        # Large finite negative logit used to rule tokens out; presumably chosen
        # over -inf to keep downstream arithmetic finite.
        self.neg_infinity = -1000000.0
        # Reuse the attribute so there is a single source of truth for the value.
        self._post_process_outputs = partial(
            post_process_model_output,
            mask_index=self.mask_index,
            neg_infty=self.neg_infinity,
        )

    def forward(self, xt, cond):
        """Return per-token log-probabilities over the vocabulary.

        Args:
            xt: Noisy token ids; entries equal to ``mask_index`` are masked.
            cond: Conditioning passed to the backbone (the noise level ``sigma``
                in ``diffusion_elbo``); replaced by zeros when
                ``time_conditioning`` is off.
        """
        if not self.time_conditioning:
            cond = torch.zeros_like(cond)
        with torch.amp.autocast("cuda", dtype=torch.float32):
            logits = self.backbone(xt, cond)
        return self._post_process_outputs(logits, xt)

    def validate_config(self):
        """Reject discrete-time configurations (only ``T == 0`` is supported)."""
        assert self.T == 0, "Only continuous mode implemented"

    def diffusion_elbo(self, x0, t=None):
        """Per-token continuous-time ELBO terms for clean tokens ``x0``.

        Samples ``t`` when not given, corrupts ``x0`` with ``q_xt``, and weights
        the negative log-likelihood of the clean tokens by
        ``dsigma / expm1(sigma)``. While a trainer is training or validating,
        the raw NLL and the weighted loss are also logged via ``_log_buckets``.

        Raises:
            ValueError: if ``change_of_variables`` or ``importance_sampling`` is
                enabled (not supported by this implementation).
        """
        if t is None:
            t = self._sample_t(x0.shape[0], x0.device)

        sigma, move_chance, dsigma = self._t_to_sigma(t)
        xt = self.q_xt(x0, move_chance)

        sigma = sigma.squeeze(-1)  # Original shape [bs, 1]
        logits = self.forward(xt, sigma)

        log_p_theta = torch.gather(input=logits, dim=-1, index=x0[:, :, None]).squeeze(-1)

        if self.change_of_variables or self.importance_sampling:
            raise ValueError(
                "MDLM.diffusion_elbo does not support change_of_variables "
                "or importance_sampling"
            )

        elbo = -log_p_theta * (dsigma / torch.expm1(sigma))[:, None]
        if self.trainer.training or self.trainer.validating:
            mode = "train" if self.trainer.training else "valid"
            # Log the loss without scaling, then the ELBO with scaling.
            self._log_buckets((-log_p_theta).mean(-1), t, mode=mode, key="raw_loss")
            self._log_buckets(elbo.mean(-1), t, mode=mode, key="scaled_loss")
        return elbo

    def loss(self, x, t=None, attention_mask=None):
        """Average ELBO over tokens, restricted to ``attention_mask`` if given."""
        elbo = self.diffusion_elbo(x, t)
        if attention_mask is not None:
            elbo = elbo * attention_mask
            loss = elbo.sum() / attention_mask.sum()
        else:
            loss = elbo.mean()
        return loss

    @torch.no_grad()
    def sample(
        self,
        n_samples=8,
        num_steps=256,
        seq_len=1024,
        sampler="ancestral",
        cache_preds=False,
        verbose=False,
        add_bos=False,
        add_eos=False,
        project_fn=lambda x: x,
    ):
        """Generate token sequences by reverse diffusion from the prior.

        Args:
            n_samples: Number of sequences to generate.
            num_steps: Number of reverse-diffusion steps.
            seq_len: Sequence length; ``None`` falls back to ``config.model.length``.
            sampler: ``"ancestral"`` (uses ``_ddpm_update``) or ``"analytic"``
                (uses ``_analytic_update``).
            cache_preds: Prediction caching; not implemented, must be False.
            verbose: Show a progress bar.
            add_bos: Overwrite the first position with ``tokenizer.bos_token_id``.
            add_eos: Overwrite the last position with ``tokenizer.eos_token_id``.
            project_fn: Applied to the batch after the prior draw and after every
                update (e.g. to clamp a conditioning prefix).

        Returns:
            The sampled token ids (as returned by the last update / projection).
        """
        assert not cache_preds, "Not implemented"
        assert sampler in ("ancestral", "analytic")
        if seq_len is None:
            seq_len = self.config.model.length

        batch = project_fn(self._sample_prior(n_samples, seq_len))
        if add_bos:
            batch[:, 0] = self.tokenizer.bos_token_id
        if add_eos:
            batch[:, -1] = self.tokenizer.eos_token_id

        # The sampler choice is fixed for the whole run.
        update_fn = self._ddpm_update if sampler == "ancestral" else self._analytic_update

        # num_steps + 1 points because the last value is used for denoising.
        ts = torch.linspace(1.0, self.sampling_eps, steps=num_steps + 1)
        dt = (1 - self.sampling_eps) / num_steps

        for i in trange(num_steps, desc="sampling...", disable=not verbose):
            t = ts[i] * torch.ones(n_samples, 1, device=self.device)
            _, batch = update_fn(batch, t, dt)
            batch = project_fn(batch)

        # Final denoising step, only run if some positions are still masked.
        if (batch == self.mask_index).any():
            t = ts[-1] * torch.ones(n_samples, 1, device=self.device)
            _, batch = self._ddpm_update(
                batch, t, dt, denoise=True, mask_idx=self.mask_index
            )
            batch = project_fn(batch)

        return batch


def sample_uncond(module):
    """Draw unconditional samples from ``module`` and save them as ``.npz``.

    Settings come from ``config.parameterization.sampling.uncond``. Every Fabric
    rank samples ``num_samples / (world_size * batch_size)`` batches. Rank 0
    gathers them and writes ``samples`` and ``metadata`` to
    ``./samples/uncond/<params2key(metadata)>.npz``. When ``from_ema`` is set,
    EMA weights are swapped in for sampling, and the original weights are
    restored afterwards even if sampling or saving fails.

    Raises:
        AssertionError: if the output file already exists, or ``num_samples`` is
            not a multiple of ``world_size * batch_size``.
    """
    logger.info("Starting unconditional sampling.")
    config = module.config
    uncond_cfg = config.parameterization.sampling.uncond

    metadata = dict(
        num_samples=uncond_cfg.num_samples,
        from_ema=uncond_cfg.from_ema,
        num_steps=uncond_cfg.num_steps,
        seq_len=uncond_cfg.seq_len,
        sampler=uncond_cfg.sampler,
        add_bos=uncond_cfg.add_bos,
        add_eos=uncond_cfg.add_eos,
        checkpoint_name=config.checkpointing.resume_ckpt_path,
    )

    save_fname = params2key(**metadata) + ".npz"
    save_path = Path(os.getcwd()) / "samples" / "uncond" / save_fname
    assert not save_path.exists(), save_fname

    fabric = Fabric(
        accelerator=config.trainer.accelerator,
        precision=config.trainer.precision,
        num_nodes=config.trainer.num_nodes,
        devices=config.trainer.devices,
    )
    fabric.launch()
    L.seed_everything(100 + fabric.global_rank)
    # Note: fabric.setup(module) creates a bug when calling functions from the
    # module, so the module is only moved to the device.
    pl_module = module
    fabric.to_device(pl_module)

    bs = uncond_cfg.batch_size
    num_steps = uncond_cfg.num_steps
    seq_len = uncond_cfg.seq_len
    target_num_samples = uncond_cfg.num_samples
    # Use the launched world size: num_nodes * devices from the config is not
    # the device count for settings such as devices=-1, "auto" or a list.
    samples_per_round = fabric.world_size * bs
    assert target_num_samples % samples_per_round == 0
    n_sampling_rounds = target_num_samples // samples_per_round

    if uncond_cfg.from_ema:
        pl_module.store_ema()
    try:
        all_samples = []
        for _ in trange(
            n_sampling_rounds,
            desc=f"Sampling with n_steps={num_steps}, seq_len={seq_len}",
            disable=fabric.global_rank > 0,
        ):
            with fabric.autocast():
                out = pl_module.sample(
                    n_samples=bs,
                    num_steps=num_steps,
                    seq_len=seq_len,
                    sampler=uncond_cfg.sampler,
                    add_bos=uncond_cfg.add_bos,
                    add_eos=uncond_cfg.add_eos,
                    cache_preds=uncond_cfg.cache_preds,
                )
            out = fabric.all_gather(data=out)
            if fabric.global_rank == 0:
                if out.ndim == 3:  # ndim == 2 when running on one device
                    out = rearrange(out, "dev bs l -> (dev bs) l")
                all_samples.append(out.cpu())
            del out

        # Join and save to disk
        if fabric.global_rank == 0:
            samples = torch.cat(all_samples, dim=0).numpy()
            samples = samples[:target_num_samples]

            save_path.parent.mkdir(exist_ok=True, parents=True)
            np.savez(save_path, samples=samples, metadata=metadata)
            logger.info(f"Saved {len(samples)} samples in {save_path}")
    finally:
        # Restore orig model weights, even if sampling or saving failed.
        if uncond_cfg.from_ema:
            pl_module.restore_ema()


def sample_cond_prefix(module):
    """Sample continuations of dataset prefixes and save them as ``.npz``.

    Settings come from ``config.parameterization.sampling.cond_prefix``.
    1. The first ``num_samples`` validation documents of ``dataset`` are loaded.
       Rank 0 preprocesses them first; the other ranks wait at a barrier.
    2. For each document, the first ``prefix_len`` tokens are clamped, and
       ``num_cont_per_prefix`` continuations are sampled.
    3. Rank 0 saves ``samples``, ``references`` (the original documents) and
       ``metadata`` to ``./samples/cond/<params2key(metadata)>.npz``.

    Samples are stored round by round. Each round covers ``world_size * bs``
    consecutive prefixes and contributes ``num_cont_per_prefix`` consecutive
    blocks, each containing one continuation of every prefix in the round.

    When ``from_ema`` is set, EMA weights are used, and the original weights are
    restored afterwards even on failure.

    Raises:
        AssertionError: if the output file already exists, ``num_samples`` is not
            a multiple of ``world_size * batch_size``, or the dataset is too small.
    """
    logger.info("Starting conditional sampling (cond on prefix).")
    config = module.config
    cond_cfg = config.parameterization.sampling.cond_prefix

    metadata = dict(
        checkpoint_name=config.checkpointing.resume_ckpt_path,
        num_samples=cond_cfg.num_samples,
        from_ema=cond_cfg.from_ema,
        dataset=cond_cfg.dataset,
        seq_len=cond_cfg.seq_len,
        prefix_len=cond_cfg.prefix_len,
        num_cont_per_prefix=cond_cfg.num_cont_per_prefix,
        min_seq_len=cond_cfg.min_seq_len,
        num_steps=cond_cfg.num_steps,
        sampler=cond_cfg.sampler,
        add_bos=cond_cfg.add_bos,
        add_eos=cond_cfg.add_eos,
    )

    save_fname = params2key(**metadata) + ".npz"
    save_path = Path(os.getcwd()) / "samples" / "cond" / save_fname
    assert not save_path.exists(), save_fname
    # Extract args from cfg
    bs = cond_cfg.batch_size
    prefix_len = cond_cfg.prefix_len
    num_steps = cond_cfg.num_steps
    seq_len = cond_cfg.seq_len
    target_num_samples = cond_cfg.num_samples
    num_cont = cond_cfg.num_cont_per_prefix
    # Tokenizer for the prefix dataset
    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.name)

    fabric = Fabric(
        accelerator=config.trainer.accelerator,
        precision=config.trainer.precision,
        num_nodes=config.trainer.num_nodes,
        devices=config.trainer.devices,
    )
    fabric.launch()
    L.seed_everything(200 + fabric.global_rank)

    # The same quantity is used as the loop stride below. Divisibility gives every
    # rank the same number of iterations, so all_gather calls stay matched.
    samples_per_round = fabric.world_size * bs
    assert target_num_samples % samples_per_round == 0

    if fabric.global_rank > 0:
        fabric.barrier()  # Make sure that only the first device does the preprocessing

    dataset = dataloader.get_dataset(
        cond_cfg.dataset,
        tokenizer,
        mode="valid",
        cache_dir=config.data_preprocess.data_cache,
        num_proc=config.trainer.devices * config.loader.num_workers,
        min_seq_len=cond_cfg.min_seq_len,
        seq_len=seq_len,
        group_text=False,
        remove_text=True,
        add_bos=cond_cfg.add_bos,
        add_eos=cond_cfg.add_eos,
    )

    if fabric.global_rank == 0:
        fabric.barrier()  # Make sure the data was preprocessed on one device before starting

    assert len(dataset) >= target_num_samples
    dataset = dataset.select(range(target_num_samples))

    pl_module = module
    fabric.to_device(pl_module)

    if cond_cfg.from_ema:
        pl_module.store_ema()
    try:
        all_samples = []
        # Rank r handles the batches starting at r*bs, r*bs + world_size*bs, ...
        for idx in trange(
            fabric.global_rank * bs,
            target_num_samples,
            samples_per_round,
            desc=f"Sampling with n_steps={num_steps}, seq_len={seq_len}",
            disable=fabric.global_rank > 0,
        ):
            docs = dataset[idx : idx + bs]["input_ids"]
            # Move the prefix to the sampling device once per batch, instead of
            # copying it from the host at every diffusion step.
            prefixes = docs[:, :prefix_len].to(pl_module.device)

            def project_fn(batch, _prefixes=prefixes):
                # Clamp the prefix positions to the conditioning tokens.
                batch[:, :prefix_len] = _prefixes
                return batch

            # Generate potentially multiple continuations per prefix (typically 5)
            for _ in range(num_cont):
                with fabric.autocast():
                    out = pl_module.sample(
                        n_samples=bs,
                        num_steps=num_steps,
                        seq_len=seq_len,
                        sampler=cond_cfg.sampler,
                        add_bos=cond_cfg.add_bos,
                        add_eos=cond_cfg.add_eos,
                        cache_preds=cond_cfg.cache_preds,
                        project_fn=project_fn,
                    )
                out = fabric.all_gather(data=out)
                if fabric.global_rank == 0:
                    # unstack after all_gather (ndim == 2 on a single device)
                    if out.ndim == 3:
                        out = rearrange(out, "dev bs l -> (dev bs) l")
                    all_samples.append(out.cpu())
                del out

        # Join and save to disk
        if fabric.global_rank == 0:
            samples = torch.cat(all_samples, dim=0).numpy()
            samples = samples[: target_num_samples * num_cont]

            save_path.parent.mkdir(exist_ok=True, parents=True)
            references = dataset[:target_num_samples]["input_ids"].numpy()
            np.savez(
                save_path, samples=samples, references=references, metadata=metadata
            )
            logger.info(f"Saved samples in {save_path}")
    finally:
        # Restore orig model weights, even if sampling or saving failed.
        if cond_cfg.from_ema:
            pl_module.restore_ema()


def _eval_suffix_nll_generators(module: MDLM, config, prefix: torch.Tensor, suffix):
    """Yield noised copies of ``prefix + suffix`` for a Monte-Carlo suffix NLL.

    Draws ``config.eval.lambada_openai.num_samples`` times with ``_sample_t`` up
    front and yields them in chunks of ``batch_size``. In every yielded batch the
    prefix tokens are restored to their clean values, and the last token is also
    restored when ``add_eos`` is set. As a result, only the rest of the suffix
    can be masked.

    Yields:
        Tuples ``(xt, y, scale, sigma, t)``:
        - ``xt``: the noised tokens.
        - ``y``: the clean sentence repeated ``batch_size`` times.
        - ``scale``: the ELBO weight ``dsigma / expm1(sigma)`` with a trailing
          singleton dimension.
        - ``sigma``: the squeezed noise level.
        - ``t``: the times used for the batch.

    Raises:
        AssertionError: if ``num_samples`` is not a multiple of ``batch_size``.
    """
    lambada_cfg = config.eval.lambada_openai
    batch_size = lambada_cfg.batch_size
    num_samples = lambada_cfg.num_samples
    add_eos = lambada_cfg.add_eos
    assert num_samples % batch_size == 0

    device = module.device
    prefix_len = len(prefix)
    all_t = module._sample_t(num_samples, device=device)
    # The same clean sentence is used for every batch; only the masking changes.
    full_sentence = torch.cat([prefix, suffix], dim=-1).repeat(batch_size, 1).to(device)

    for idx in range(0, num_samples, batch_size):
        curr_t = all_t[idx : idx + batch_size]
        sigma, move_chance, dsigma = module._t_to_sigma(curr_t)
        sigma = sigma.squeeze(-1)

        xt = module.q_xt(full_sentence, move_chance)
        # Keep the conditioning prefix (and optionally the final EOS) unmasked.
        xt[:, :prefix_len] = full_sentence[:, :prefix_len]
        if add_eos:
            xt[:, -1] = full_sentence[:, -1]

        scale = (dsigma / torch.expm1(sigma))[:, None]

        # .to(device) is kept defensively: the project helpers above may return
        # tensors on another device (no-op when already on `device`).
        yield (
            xt.to(device),
            full_sentence.to(device),
            scale.to(device),
            sigma.to(device),
            curr_t.to(device),
        )


@torch.no_grad
def eval_suffix_nll(config, module: MDLM, prefix, suffix, sigma):
    """Monte-Carlo estimate of the suffix negative log-likelihood.

    For every noised batch produced by ``_eval_suffix_nll_generators``, the
    negative log-likelihood of the clean token is taken at each masked position
    and weighted by the ELBO scale. It is summed over the sequence and averaged
    over the batch; the per-batch values are then averaged.

    NOTE: the ``sigma`` argument is not used. Each batch is evaluated at the
    noise level drawn by the generator.
    """
    per_batch = []
    batches = _eval_suffix_nll_generators(module, config, prefix, suffix)
    for noisy, target, weight, noise_level, _t in batches:
        # module(...) already returns log-probs (MDLM.forward); re-normalizing
        # is effectively a no-op safeguard.
        log_probs = module(noisy, noise_level).log_softmax(-1)
        token_nll = -log_probs.gather(-1, target.unsqueeze(-1)).squeeze(-1)
        # Only masked positions contribute; kept prefix/EOS tokens count as 0.
        masked = noisy == module.mask_index
        weighted = torch.where(masked, token_nll, 0.0) * weight
        per_batch.append(float(weighted.sum(-1).mean()))
    return float(np.mean(per_batch))


@torch.no_grad
def eval_lambada(module: MDLM):
    """Evaluate ``module`` on OpenAI LAMBADA and write accuracy/ppl to CSV.

    For each example, every suffix token except the trailing EOS is masked.
    The greedy (argmax, mask token excluded) prediction must exactly match the
    whole suffix, EOS excluded, to count as correct. The suffix NLL is estimated
    with ``eval_suffix_nll``, and ``ppl`` is ``exp`` of its mean over the
    dataset. Results go to ``./csv/<CURR_DATETIME_STR>/lambada.csv``.

    When ``from_ema`` is set, EMA weights are used for the evaluation, and the
    original weights are restored afterwards even on failure.

    Raises:
        AssertionError: if more than one device is configured.
        NotImplementedError: if ``add_eos`` is False (only that variant exists).
    """
    logger.info("Starting eval acc/ppl on openai lambada")
    config = module.config
    lambada_cfg = config.eval.lambada_openai

    # Validate before touching weights or data, so a failure leaves the model as is.
    tot_num_device = config.trainer.num_nodes * config.trainer.devices
    assert tot_num_device == 1, "Code only works with one device"
    add_eos = lambada_cfg.add_eos
    if not add_eos:
        raise NotImplementedError("LAMBADA evaluation is only implemented with add_eos=True")

    if lambada_cfg.from_ema:
        module.store_ema()
    try:
        tokenizer = module.tokenizer
        dataset = dataloader.get_dataset(
            "EleutherAI/lambada_openai",
            tokenizer,
            mode="test",
            cache_dir=config.data_preprocess.data_cache,
            num_proc=config.trainer.devices * config.loader.num_workers,
            group_text=False,
            remove_text=False,
            add_bos=lambada_cfg.add_bos,
            add_eos=lambada_cfg.add_eos,
        )

        pl_module = module.cuda()
        t = torch.tensor([pl_module.sampling_eps], device="cuda")
        sigma = pl_module._t_to_sigma(t)[0][0]

        all_losses = []
        all_correct = []
        for idx in trange(len(dataset), desc="Evaluating lambada..."):
            prefix = dataset[idx]["prefix_ids"]
            suffix = dataset[idx]["suffix_ids"]
            # Mask every suffix token except the trailing EOS.
            suffix_mask = suffix.clone()
            suffix_mask[:-1] = pl_module.mask_index

            input_ids = torch.cat([prefix, suffix_mask]).cuda().reshape(1, -1)
            preds = pl_module(input_ids, sigma)

            # The mask token must be the last vocab entry so it can be sliced off.
            assert pl_module.mask_index == preds.shape[-1] - 1
            greedy_tokens = preds[0, :, :-1].argmax(-1)
            suff_len = len(suffix)

            correct = (greedy_tokens[-suff_len:-1].cpu() == suffix[:-1]).all().item()
            loss = eval_suffix_nll(config, pl_module, prefix, suffix, sigma)

            all_losses.append(loss)
            all_correct.append(correct)

        acc = np.mean(all_correct)
        avg_loss = np.mean(all_losses)

        from run_eval import CURR_DATETIME_STR
        csv_save_path = Path(os.getcwd()) / "csv" / CURR_DATETIME_STR / "lambada.csv"
        header = [
            "num_samples",
            "from_ema",
            "add_bos",
            "add_eos",
            "checkpoint_path",
            "acc",
            "ppl",
        ]
        row = [
            lambada_cfg.num_samples,
            lambada_cfg.from_ema,
            lambada_cfg.add_bos,
            lambada_cfg.add_eos,
            config.checkpointing.resume_ckpt_path,
            float(acc),
            float(np.exp(avg_loss)),
        ]

        df = pd.DataFrame([row], columns=header)
        csv_save_path.parent.mkdir(parents=True, exist_ok=True)
        df.to_csv(csv_save_path)
        logger.info(f"Lambada results: \n{df}\n{'=' * 50}")
    finally:
        # Restore orig model weights, even if evaluation failed.
        if lambada_cfg.from_ema:
            module.restore_ema()

