import torch
from ..diffusion.mdlm import MDLM, sample_uncond, sample_cond_prefix, eval_lambada
from core.sampling.ancestral import sample_categorical
from transformers import AutoModelForMaskedLM
from loguru import logger
from models.loading_utils import get_backbone
import torch.nn.functional as F
import numpy as np
from math import sqrt, log
import re
import time
from pathlib import Path
import os
import sys
import copy

@torch.jit.script
def tv_dist(log_p: torch.Tensor, log_q: torch.Tensor):
    """Total variation distance between two batches of categorical distributions.

    Both inputs are log-probabilities over the last dimension. The per-row
    distance ``0.5 * sum |p - q|`` is averaged over all leading dimensions.
    """
    abs_gap = torch.abs(torch.exp(log_p) - torch.exp(log_q))
    # Halving is exact in floating point, so 0.5 * x == x / 2 bit-for-bit.
    return 0.5 * abs_gap.sum(-1).mean()


@torch.jit.script
def correct_jsd(log_p: torch.Tensor, log_q: torch.Tensor):
    """Jensen-Shannon divergence between two batches of distributions.

    Inputs are log-probabilities over the last dimension. Returns
    ``(KL(p || m) + KL(q || m)) / 2`` with ``m = (p + q) / 2``, where each KL
    is summed over the last dimension and averaged over the remaining ones.
    """
    # log m = log((p + q) / 2), computed stably in log space.
    log_m = torch.logaddexp(log_p, log_q) - log(2)

    div_p = (torch.exp(log_p) * (log_p - log_m)).sum(-1).mean()
    div_q = (torch.exp(log_q) * (log_q - log_m)).sum(-1).mean()

    return 0.5 * (div_p + div_q)


@torch.jit.script
def log_cosh_logspace(log_p: torch.Tensor, log_q: torch.Tensor):
    """Log-cosh loss on the difference of two log-probability tensors.

    Computes ``sum_last_dim(log(cosh(log_p - log_q)))`` averaged over the
    leading dimensions.

    ``cosh`` is never evaluated directly because it overflows to ``inf`` for
    |x| > ~89 in float32, which is a realistic gap between log-probabilities.
    The code uses the algebraically identical, overflow-free form
    ``log(cosh(x)) = |x| + log1p(exp(-2|x|)) - log(2)`` instead.
    """
    abs_diff = (log_p - log_q).abs()
    # exp(-2|x|) lies in (0, 1], so neither term can overflow.
    log_cosh = abs_diff + torch.log1p(torch.exp(abs_diff * (-2.0))) - log(2.0)
    return log_cosh.sum(-1).mean()


@torch.jit.script
def log_cosh_probs(log_p: torch.Tensor, log_q: torch.Tensor):
    """Log-cosh loss on the difference of two probability tensors.

    Inputs are log-probabilities; they are exponentiated first, so the
    argument of ``cosh`` lies in [-1, 1] and cannot overflow. The result is
    summed over the last dimension and averaged over the rest.
    """
    prob_gap = torch.exp(log_p) - torch.exp(log_q)
    return torch.log(torch.cosh(prob_gap)).sum(-1).mean()


class DistillMDLMDoubleEveryK(MDLM):
    """Progressively distill an MDLM sampler into one that takes fewer steps.

    A frozen teacher takes ``num_distill_steps`` sampling steps of size
    ``self.dt`` from a noised input. The student, evaluated once at the
    starting time, is trained to match the teacher's ``log p(x0)`` on every
    masked token (see ``_teacher_logprobs_on_mask``). Every ``grow_dt_every``
    optimizer steps the student is copied into the teacher and ``dt`` is
    multiplied by ``num_distill_steps`` (see ``training_step``).

    Several attributes used here come from ``MDLM``, which is defined outside
    this file: ``sampling_eps``, ``mask_index``, ``neg_infinity``,
    ``vocab_size``, ``time_conditioning``, ``backbone`` and the EMA helpers.
    """

    def __init__(self, config, tokenizer):
        MDLM.__init__(self, config, tokenizer)
        self.teacher = None
        self.num_distill_steps = self.config.parameterization.num_distill_steps
        self.tot_num_sampl_steps = self.config.parameterization.orig_num_sampling_steps
        self.min_num_sampl_steps = self.config.parameterization.min_num_sampling_steps
        self.distill_mode = self.config.parameterization.distill_mode
        self.start_from_hf = self.config.parameterization.start_from_hf
        self.reset_optimizer_on_growth = (
            config.parameterization.reset_optimizer_on_growth
        )
        self.use_ema_on_growth = config.parameterization.use_ema_on_growth

        # 0-d tensor used as the lower clamp of the teacher's time grid.
        self.sampling_eps_tensor = torch.tensor(self.sampling_eps)
        self.sampling_mode = self.config.parameterization.sampling_mode
        assert self.sampling_mode in ("ancestral", "analytic")

        self.grow_dt_every = config.parameterization.grow_dt_every
        # Global step of the last dt growth. Under gradient accumulation,
        # `training_step` runs several times with the same global step, and
        # growth must happen only once per boundary.
        self._last_growth_step = None

        # Teacher step size; multiplied by `num_distill_steps` at each growth.
        self.dt = (1 - self.sampling_eps) / self.tot_num_sampl_steps
        self.loss_precision = self.config.parameterization.loss_precision

        mode = self.distill_mode
        # Each fn compares (preds, target), both given as log-probabilities.
        loss_fns = {
            "mse": self._mse,
            "tvd": self._tvd,
            "kl-fwd": self._fwd_kl,
            "kl-bwd": self._bwd_kl,
            "js": self._js_div,
            "fwd-bwd": self._fwd_bwd_avg,
            "logcosh_logspace": log_cosh_logspace,
            "logcosh_probspace": log_cosh_probs,
        }
        if mode not in loss_fns:
            raise ValueError(mode)
        self._loss_fn = loss_fns[mode]

        logger.info(f"Distillation loss: {mode}")

        self.prepare_teacher_and_student()

    def prepare_teacher_and_student(self, verbose=True):
        """Build the teacher and, optionally, load weights into teacher and student.

        If ``start_from_hf``:
            - student backbone and teacher are both loaded from the
              ``kuleshov-group/mdlm-owt`` Hugging Face checkpoint.
        Else:
            - the teacher is a freshly built backbone (``get_backbone``).

        If ``checkpoint_path`` is not ``kuleshov-group/mdlm-owt``, its
        ``state_dict`` is loaded into both teacher and student. In every case
        the teacher is frozen and the EMA is re-initialized from the resulting
        student weights.
        """
        if verbose:
            logger.info("Loading teacher checkpoint...")
        ckpt_path = self.config.parameterization.checkpoint_path

        if self.start_from_hf:
            assert self.config.data_preprocess.legacy_start_end_bos
            self.backbone = AutoModelForMaskedLM.from_pretrained(
                "kuleshov-group/mdlm-owt", trust_remote_code=True
            )

            # Wrapping in a list keeps nn.Module from registering the teacher
            # as a child, so it stays out of state_dict / parameters().
            self.teacher = [
                AutoModelForMaskedLM.from_pretrained(
                    "kuleshov-group/mdlm-owt", trust_remote_code=True
                ).eval()
            ]

            if self.config.compile:
                self.teacher[0] = torch.compile(self.teacher[0])

            # Use the legacy logits processing (less numerically stable).
            self.forward = self._forward_legacy
            self.forward_teacher = self._forward_teacher_legacy
        else:
            self.teacher = [get_backbone(self.config, self.vocab_size)]
            # Models trained in the original MDLM codebase also use the legacy
            # logits processing.
            self.forward = self._forward_legacy
            self.forward_teacher = self._forward_teacher_legacy

        if ckpt_path != "kuleshov-group/mdlm-owt":
            if verbose:
                logger.info(f"Loading checkpoint in teacher from `{ckpt_path}`.")
            ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]

            if not self.config.compile:
                ckpt = {k.replace("_orig_mod.", ""): v for k, v in ckpt.items()}

            # If loading a distilled version of the original (hf) MDLM model,
            # keys start with `backbone.backbone.`, but exactly one `backbone.`
            # prefix is needed. Note that `str.replace` strips every
            # occurrence; the single prefix is re-added below for hf models.
            ckpt = {k.replace("backbone.", ""): v for k, v in ckpt.items()}

            if self.start_from_hf:
                ckpt = {"backbone." + k: v for k, v in ckpt.items()}

            self.teacher[0].load_state_dict(ckpt)
            self.backbone.load_state_dict(ckpt)

        self.teacher[0].eval()
        self.teacher[0].requires_grad_(False)
        # Reset EMA so it starts from the (possibly just loaded) weights.
        self.init_ema()

        if verbose:
            logger.info("Teacher checkpoint loaded.")

    def _forward_legacy(self, xt, cond):
        """Student forward pass returning log-probs after the SUBS parameterization."""
        if not self.time_conditioning:
            cond = torch.zeros_like(cond)

        with torch.amp.autocast("cuda", dtype=torch.float32):
            logits = self.backbone(xt, cond)
        logits = self._subs_parameterization(logits, xt)
        return logits

    def _forward_teacher_legacy(self, xt, cond):
        """Teacher forward pass returning log-probs after the SUBS parameterization."""
        if not self.time_conditioning:
            cond = torch.zeros_like(cond)

        with torch.amp.autocast("cuda", dtype=torch.float32):
            logits = self.teacher[0](xt, cond)

        logits = self._subs_parameterization(logits, xt)
        return logits

    def _subs_parameterization(self, logits, xt):
        """Turn raw logits into normalized log-probs, modifying ``logits`` in place.

        The mask token's logit is pushed down by ``self.neg_infinity``, and the
        vocab dimension is normalized with ``logsumexp``. Positions whose input
        token is not the mask get a one-hot distribution on that token:
        ``neg_infinity`` everywhere except 0 at the token itself.
        """
        logits[:, :, self.mask_index] += self.neg_infinity

        logits = logits - torch.logsumexp(logits, dim=-1, keepdim=True)

        unmasked_indices = xt != self.mask_index
        logits[unmasked_indices] = self.neg_infinity
        logits[unmasked_indices, xt[unmasked_indices]] = 0
        return logits

    def to(self, device):
        """Move the module and the unregistered teacher to ``device``.

        Returns ``self``, like ``nn.Module.to``, so ``m = m.to(device)`` works.
        """
        MDLM.to(self, device=device)
        self.teacher[0].to(device=device)
        return self

    @torch.no_grad
    def _teacher_logprobs_on_mask(self, xt, t_start):
        """Roll the teacher forward and collect its predictions for all mask tokens.

        The teacher takes ``num_distill_steps`` sampling steps of size
        ``self.dt``. Step ``i`` starts at time ``t_start - i * dt``, clamped
        below at ``sampling_eps``. Each token that changes during the rollout
        keeps the teacher's ``log_p_x0`` from the step in which it changed.
        Tokens still masked after the last step get the last step's
        ``log_p_x0``.

        Args:
            xt: noised token ids (assumed ``(batch, seq)`` — TODO confirm).
            t_start: starting time per batch element; any shape is accepted,
                and the time grid gets a new leading "step" dimension.

        Returns:
            Tensor of shape ``(*xt.shape, vocab_size)`` with teacher
            log-probs. Entries for tokens that neither changed nor are masked
            at the end stay 0.
        """
        dt = self.dt

        # Start times of the k consecutive teacher steps: t_start - i * dt.
        step_idx = torch.arange(
            self.num_distill_steps, device=t_start.device, dtype=torch.float64
        ).view(-1, *([1] * t_start.dim()))
        ts = t_start.double().unsqueeze(0) - dt * step_idx
        # Ensure we don't feed the model values smaller than sampling_eps
        ts = torch.maximum(ts, self.sampling_eps_tensor)

        teacher_predictions = torch.zeros(
            (*xt.shape, self.vocab_size), device=xt.device
        )
        ever_updated = torch.zeros(xt.shape, dtype=torch.bool, device=xt.device)
        curr_x = xt

        for t in ts:
            t = t.float()
            # TODO: add analytic sampler
            if self.sampling_mode == "ancestral":
                log_p_x0, q_xs = self._compute_ddpm_update(
                    curr_x, t, dt, forward=self.forward_teacher
                )
                update = sample_categorical(q_xs)
                new_batch = self._ddpm_sample_update(curr_x, update)

            elif self.sampling_mode == "analytic":
                log_p_x0, new_batch = self._analytic_update(
                    curr_x,
                    t,
                    dt,
                    forward=self.forward_teacher,
                )
            else:
                raise ValueError(self.sampling_mode)

            updated = curr_x != new_batch
            # Keep the prediction from the step in which each token changed.
            teacher_predictions[updated] = log_p_x0[updated]
            ever_updated |= updated
            curr_x = new_batch

        # Tokens still masked after the rollout use the last step's prediction.
        still_masked = (curr_x == self.mask_index) & ~ever_updated
        return torch.where(still_masked[..., None], log_p_x0, teacher_predictions)

    def loss(self, x, t=None, attention_mask=None):
        """Distillation loss between student and teacher on the masked tokens of ``xt``.

        Args:
            x: clean token ids.
            t: optional diffusion times; sampled with ``_sample_t`` when None.
            attention_mask: must be all ones (padding is not supported).

        Returns:
            Scalar loss from the configured ``distill_mode``.
        """
        if attention_mask is not None:
            assert (
                (attention_mask.to(int) == 1).all().item()
            ), "attention mask not supported"

        x0 = x
        if t is None:
            t = self._sample_t(x0.shape[0], x0.device)

        sigma, move_chance, dsigma = self._t_to_sigma(t)
        xt = self.q_xt(x0, move_chance)
        sigma = sigma.squeeze(-1)  # Original shape [bs, 1]
        # Loss on all masked tokens
        teacher_preds = self._teacher_logprobs_on_mask(xt, t)
        student_preds = self.forward(xt, sigma)
        is_mask = xt == self.mask_index

        target = teacher_preds[is_mask]
        preds = student_preds[is_mask]

        if self.loss_precision == "64":
            target = target.to(torch.float64)
            preds = preds.to(torch.float64)
        elif self.loss_precision == "32":
            target = target.to(torch.float32)
            preds = preds.to(torch.float32)

        loss = self._loss_fn(preds, target)
        return loss

    def _mse(self, preds, target):
        # Mean squared error between log-probabilities.
        return F.mse_loss(preds, target)

    def _tvd(self, preds, target):
        # Total variation distance between the distributions whose
        # log-probabilities are `preds` and `target` (not an L1 on log-probs).
        return tv_dist(preds, target)

    def _fwd_kl(self, preds, target):
        # F.kl_div(input, target) = KL(target || input): KL(teacher || student).
        return F.kl_div(preds, target, log_target=True, reduction="batchmean")

    def _bwd_kl(self, preds, target):
        # KL(student || teacher).
        return F.kl_div(target, preds, log_target=True, reduction="batchmean")

    def _js_div(self, preds, target):
        return correct_jsd(preds, target)

    def _fwd_bwd_avg(self, preds, target):
        # Symmetrized KL: mean of KL(teacher || student) and KL(student || teacher).
        fwd_loss = F.kl_div(preds, target, log_target=True, reduction="batchmean")
        bwd_loss = F.kl_div(target, preds, log_target=True, reduction="batchmean")
        return (fwd_loss + bwd_loss) / 2

    def training_step(self, batch):
        """Run one training step, growing ``dt`` at each ``grow_dt_every`` boundary.

        At a growth boundary, ``dt`` is set to the base step size times
        ``num_distill_steps ** round``. If the resulting number of sampling
        steps falls below ``min_num_sampl_steps``, the process exits via
        ``sys.exit()``. Otherwise the student is copied into the teacher, and
        the optimizers are optionally reset.
        """
        if self.ema is not None:
            assert (
                not self._using_ema_weights
            ), "SHOULD NOT USE EMA WEIGHTS DURING TRAINING!!!"
        x = batch["input_ids"]
        attention_mask = batch.get("attention_mask", None)

        step = self.trainer.global_step
        is_growth_boundary = step > 0 and step % self.grow_dt_every == 0
        # `global_step` counts optimizer steps, so with gradient accumulation
        # this method sees the same step several times; grow only once.
        if is_growth_boundary and step != self._last_growth_step:
            self._last_growth_step = step
            curr_round = step // self.grow_dt_every
            self.dt = (
                (1 - self.sampling_eps) / self.tot_num_sampl_steps * (self.num_distill_steps**curr_round)
            )
            # dt spans (1 - sampling_eps) of the time interval, not all of it.
            effective_num_steps = round((1 - self.sampling_eps) / self.dt)
            if effective_num_steps < self.min_num_sampl_steps:
                logger.info(
                    f"Reached below the minimal effective number of sampling steps, stopping..."
                )
                sys.exit()
            else:
                logger.info(
                    f"Step {step}: Doubling `dt`! New effective number of steps: {effective_num_steps}."
                )
            self._student_to_teacher()
            if self.reset_optimizer_on_growth:
                logger.info("Resetting optimizers...")
                self.trainer.strategy.setup_optimizers(self.trainer)

        loss = self.loss(x, attention_mask=attention_mask)
        self.log(
            name="train/loss",
            value=loss,
            on_step=True,
            on_epoch=False,
            sync_dist=True,
        )
        return loss

    def _student_to_teacher(
        self,
    ):
        """Copy the student's weights into the teacher and save a checkpoint.

        With ``use_ema_on_growth``, the EMA weights are extracted between
        ``store_ema``/``restore_ema`` (presumably these swap EMA weights into
        the backbone and back — confirm in MDLM). They are then loaded into
        both student and teacher, and the EMA is re-initialized. The
        checkpoint is written to ``./student_checkpoints/<global_step>.ckpt``.
        """
        start = time.perf_counter()
        if self.use_ema_on_growth:
            # Use EMA as teacher and student, and reset EMA for next round
            self.store_ema()
            student_ckpt = copy.deepcopy(self.backbone.state_dict())
            self.restore_ema()
            self.backbone.load_state_dict(student_ckpt)
            self.init_ema()
        else:
            student_ckpt = self.backbone.state_dict()

        self.teacher[0].load_state_dict(student_ckpt)
        end = time.perf_counter()

        logger.info(f"Swapped student into teacher in {end - start:.2f} seconds.")
        save_path = (
            Path(os.getcwd())
            / "student_checkpoints"
            / f"{self.trainer.global_step}.ckpt"
        )
        self.trainer.save_checkpoint(save_path)
