import torch
import torch.nn.functional as F
from transformers import Trainer
from transformers import DefaultDataCollator
import random
from tqdm import tqdm
import pickle
import torch.distributed as dist
from torch.distributions import Geometric

import os, json, sqlite3, threading
import numpy as np
from scipy.stats import rankdata


class dLLMTrainer(Trainer):
    """Trainer for the absorbing-state (masked) diffusion objective.

    Batches must carry ``labels`` (-100 on unsupervised positions), a timestep
    tensor ``t`` that broadcasts against the per-token loss, and
    ``num_prompt_tokens``; these three keys are removed before the model call.
    """

    def compute_loss(self, model, inputs, num_items_in_batch=None, return_outputs=False):
        """
        Absorbing state diffusion loss computation.

        Per-token cross-entropy is divided by ``t`` and summed, then normalised
        by the number of non-prompt slots in the batch
        (``input_ids.numel() - num_prompt_tokens``).
        """
        labels = inputs.pop("labels")
        t = inputs.pop("t")
        num_prompt_tokens = inputs.pop("num_prompt_tokens")

        outputs = model(**inputs)
        logits = outputs.logits
        batch_size, vocab_size = logits.shape[0], logits.shape[-1]

        # reduction="none" keeps one value per position; ignored (-100) labels give 0.
        per_token_ce = F.cross_entropy(
            logits.view(-1, vocab_size),
            labels.view(-1),
            reduction="none",
        ).view(batch_size, -1)

        should_log = (self.state.global_step + 1) % self.args.logging_steps == 0
        if should_log:
            # Plain mean CE over supervised tokens, without the 1/t weighting.
            num_supervised = (labels != -100).sum()
            self.log({"unscaled_loss": (per_token_ce.sum() / num_supervised).item()})

        num_response_slots = inputs["input_ids"].numel() - num_prompt_tokens
        loss = (per_token_ce / t).sum() / num_response_slots

        if return_outputs:
            return loss, outputs
        return loss


class dLLMSFTDataset(torch.utils.data.Dataset):
    """
    Map-style dataset over pre-tokenised examples, like an AR SFT dataset.

    The one difference is evaluation mode: there, every example is paired with
    a fixed timestep ``t`` taken from an evenly spaced grid over [0, 1], so the
    noise level seen by each eval example is the same on every pass.
    """

    def __init__(self, data, tokenizer, max_length, eval=False):
        super().__init__()
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.eval = eval
        if self.eval:
            # One fixed timestep per example, evenly spaced in [0, 1].
            self.t = torch.linspace(0, 1, len(self.data))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return example ``idx``; in eval mode a copy with its fixed ``"t"`` added."""
        example = self.data[idx]
        if not self.eval:
            return example
        # Build a shallow copy instead of writing "t" into the stored example:
        # mutating self.data would leak the fixed timestep into every other
        # holder of the same example dicts.
        return {**example, "t": self.t[idx]}


class dLLMDataCollator(DefaultDataCollator):
    """
    Collator that adds the forward (masking) noise process to each batch.

    After default collation it:
      * takes ``batch["t"]`` if the dataset supplied it, otherwise samples one
        uniform ``t`` per sequence, and masks each token independently with
        probability ``(1 - eps) * t + eps``;
      * sets ``labels`` to the clean ids at masked positions and -100 elsewhere;
      * when ``prompt_lengths`` is present, restores the prompt tokens,
        excludes them from the loss and stores their count in
        ``num_prompt_tokens`` (otherwise ``num_prompt_tokens`` is 0).

    Modify ``forward_process`` to change the noise schedule.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        tokenizer = kwargs["tokenizer"]
        self.tokenizer = tokenizer
        # The tokenizer's mask id wins; kwargs["mask_token_id"] is only the fallback.
        self.mask_token_id = tokenizer.mask_token_id
        if "max_length" in kwargs:
            self.max_length = kwargs["max_length"]
        if self.mask_token_id is None:
            # Explicit check rather than `assert`, which `python -O` strips.
            if "mask_token_id" not in kwargs:
                raise ValueError(
                    "For dLLM models, pass a mask_token_id or set it equal to tokenizer.mask_token_id"
                )
            self.mask_token_id = kwargs["mask_token_id"]

    def forward_process(self, batch, eps=1e-3):
        """
        Mask ``batch["input_ids"]`` with a per-sequence probability.

        Returns ``(noisy_batch, t, mask_indices)``:
          * ``t`` is rescaled to ``(1 - eps) * t + eps`` and repeated to (B, N);
          * ``mask_indices`` is True where a token was replaced by ``self.mask_token_id``.
        """
        input_ids = batch["input_ids"]
        B, N = input_ids.shape
        if "t" not in batch:
            t = torch.rand((B,), device=input_ids.device)
        else:
            t = batch["t"]

        # Keep t away from 0 so a 1/t loss weight (as in dLLMTrainer) stays bounded.
        t = (1 - eps) * t + eps
        t = t[:, None].repeat(1, N)

        mask_indices = torch.rand((B, N), device=input_ids.device) < t
        noisy_batch = torch.where(mask_indices, self.mask_token_id, input_ids)
        return noisy_batch, t, mask_indices

    def __call__(self, batch):
        batch = super().__call__(batch)
        batch["labels"] = batch["input_ids"].clone()
        noisy_batch, batch["t"], mask_indices = self.forward_process(batch)
        # Only masked positions are supervised.
        batch["labels"][~mask_indices] = -100
        batch["num_prompt_tokens"] = 0
        if "prompt_lengths" in batch:
            # Reshape to (B, 1) so the comparison below is per sequence; this
            # accepts prompt lengths collated as either (B,) or (B, 1).
            prompt_lengths = batch.pop("prompt_lengths").reshape(-1, 1)
            positions = torch.arange(noisy_batch.shape[1], device=noisy_batch.device).unsqueeze(0)
            prompt_mask = positions < prompt_lengths
            # Prompt tokens are never noised and never contribute to the loss.
            # Boolean indexing already returns a copy, so no clone is needed.
            noisy_batch[prompt_mask] = batch["input_ids"][prompt_mask]
            batch["labels"][prompt_mask] = -100
            batch["num_prompt_tokens"] = prompt_mask.sum()
        batch["input_ids"] = noisy_batch.long()
        return batch


# Instruction text prepended (followed by "\n\n") to every question in
# preprocess_dataset. It asks for <reasoning>...</reasoning> and
# <answer>...</answer> sections, the same tags used to build the assistant
# targets there.
# NOTE: this exact text, including its leading and trailing newline, becomes
# part of the tokenized training inputs; editing it changes the training data.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
Your reasoning here
</reasoning>
<answer>
...
</answer>
"""


def preprocess_dataset(data, tokenizer, max_length, test_split=0.01, seed=None):
    """
    Tokenise reasoning-trace examples and split them into train/test lists.

    Each example must provide ``question``, ``thinking_trajectories`` (only the
    first trajectory is used) and ``attempt``. The user turn is SYSTEM_PROMPT
    followed by the question. The assistant turn wraps the trajectory and the
    attempt in <reasoning>/<answer> tags.

    The full conversation is tokenised, padded and truncated to
    ``max_length``. ``prompt_lengths`` is the token count of the prompt-only
    chat text, which lets a collator keep those positions un-noised.

    Args:
        data: iterable of example dicts (a list or any dataset yielding dicts).
        tokenizer: tokenizer exposing ``apply_chat_template`` and ``__call__``
            with ``return_tensors="pt"`` (Hugging Face style).
        max_length: sequence length used for padding and truncation.
        test_split: fraction of examples placed in the test list (rounded down).
        seed: optional seed for the shuffle. ``None`` shuffles with the global
            ``random`` state, as before.

    Returns:
        ``(train_data, test_data)``: lists of dicts with ``input_ids`` (1-D
        tensor of length ``max_length``) and ``prompt_lengths`` (tensor of
        shape (1,)).
    """
    preprocessed_data = []
    for example in tqdm(data, desc="Preprocessing dataset"):
        question = SYSTEM_PROMPT + "\n\n" + example["question"]
        trajectory = f"<reasoning>{example['thinking_trajectories'][0]}</reasoning>\n<answer>{example['attempt']}</answer>"
        prompt = [{"role": "user", "content": question}]
        response = [{"role": "assistant", "content": trajectory}]
        inputs = tokenizer.apply_chat_template(prompt + response, tokenize=False)
        # NOTE(review): the prompt length is measured on the user turn plus "\n",
        # not with add_generation_prompt=True. Any assistant-header tokens the
        # chat template emits therefore count as response tokens. Confirm this
        # is intended.
        prompt_text = tokenizer.apply_chat_template(prompt, tokenize=False) + "\n"
        tokenized_input = tokenizer(
            inputs, return_tensors="pt", truncation=True, max_length=max_length, padding="max_length"
        ).input_ids.squeeze(0)
        tokenized_prompt = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=max_length)
        preprocessed_data.append(
            {
                "input_ids": tokenized_input,
                "prompt_lengths": tokenized_prompt.attention_mask.sum(-1),
            }
        )

    if seed is None:
        random.shuffle(preprocessed_data)
    else:
        random.Random(seed).shuffle(preprocessed_data)
    num_test = int(len(preprocessed_data) * test_split)
    test_data = preprocessed_data[:num_test]
    train_data = preprocessed_data[num_test:]
    return train_data, test_data


class WeightedDLLMDataCollator(DefaultDataCollator):
    """
    Collator that prepares clean batches; it applies no noise itself.

    On top of default collation it adds:
      * ``labels``: a copy of the clean ``input_ids`` (no -100 masking here);
      * ``t``: one uniform base timestep per sequence, unless the dataset
        already supplied ``t``;
      * ``mask_token_id``: a 0-dim long tensor.

    These are the keys WeightedDLLMTrainer.compute_loss pops, and that method
    performs the masking.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.tokenizer = kwargs["tokenizer"]
        # An explicit mask_token_id kwarg takes precedence over the tokenizer's.
        self.mask_token_id = kwargs.get("mask_token_id", self.tokenizer.mask_token_id)
        # Explicit check rather than `assert`, which `python -O` strips.
        if self.mask_token_id is None:
            raise ValueError(
                "WeightedDLLMDataCollator needs a mask token id: pass mask_token_id "
                "or use a tokenizer whose mask_token_id is set"
            )
        self.max_length = kwargs.get("max_length", None)

    def __call__(self, features):
        batch = super().__call__(features)
        batch["labels"] = batch["input_ids"].clone()

        if "t" not in batch:
            B = batch["input_ids"].shape[0]
            batch["t"] = torch.rand((B,), dtype=torch.float)

        # Sent as a tensor so the trainer can read it back from its inputs.
        batch["mask_token_id"] = torch.tensor(self.mask_token_id, dtype=torch.long)
        return batch


class WeightedDLLMTrainer(Trainer):
    """
    Diffusion-LM trainer whose per-token masking rate adapts to model uncertainty.

    Batches are expected to match WeightedDLLMDataCollator's output: clean
    ``input_ids``, ``labels``, a per-sequence base timestep ``t``, a 0-dim
    ``mask_token_id`` tensor and, optionally, ``prompt_lengths``. All noising
    happens inside ``compute_loss``.
    """

    def _build_y_mask_bool(self, inputs, prompt_lengths):
        """
        Return a (B, N) bool mask that is True on response (non-prompt) positions.

        ``prompt_lengths`` may be:
          * None: there is no prompt, so every position counts as response;
          * a scalar;
          * a tensor with one entry per sequence, e.g. shape (B,) or (B, 1).
        """
        input_ids = inputs["input_ids"]
        device = input_ids.device
        B, N = input_ids.shape
        if prompt_lengths is None:
            return torch.ones((B, N), dtype=torch.bool, device=device)

        prompt_lengths = torch.as_tensor(prompt_lengths, device=device)
        if prompt_lengths.dim() > 0:
            # Force (B, 1) so a (B,) tensor broadcasts across positions instead
            # of being lined up against the sequence dimension.
            prompt_lengths = prompt_lengths.reshape(B, 1)
        idxs = torch.arange(N, device=device).view(1, N).expand(B, N)
        return idxs >= prompt_lengths

    @torch.no_grad()
    def _estimate_token_betas(self, model, clean_inputs, labels, y_mask_bool, mask_token_id):
        """
        Score each response token by sqrt(entropy) of the model's prediction.

        Every response position is replaced by ``mask_token_id`` and the model is
        run once without gradients. Prompt positions get a score of 0.
        ``labels`` is not used; it is kept for signature compatibility.
        """
        masked_inputs = {k: v.clone() if isinstance(v, torch.Tensor) else v
                         for k, v in clean_inputs.items()}
        masked_inputs["input_ids"] = masked_inputs["input_ids"].masked_fill(y_mask_bool, mask_token_id)

        outputs = model(**masked_inputs)
        logits = outputs.logits
        probs = F.softmax(logits, dim=-1)
        log_probs = F.log_softmax(logits, dim=-1)

        entropy = -(probs * log_probs).sum(dim=-1)
        betas = torch.where(y_mask_bool, torch.sqrt(entropy.clamp_min(0.0)), torch.zeros_like(entropy))
        return betas

    def compute_loss(self, model, inputs, num_items_in_batch=None, return_outputs=False):
        """
        Uncertainty-weighted masked diffusion loss.

        Steps:
          1. ``beta_i`` = sqrt(entropy) at each response position, computed from
             a fully masked forward pass. ``beta_ref`` is its per-sequence mean.
          2. Per-token masking probability
             ``t_i = 1 - (1 - t) ** (beta_i / beta_ref)``, mapped to
             ``(1 - eps) * t_i + eps``, clamped to >= 0.005, and set to 0 on
             prompt positions.
          3. Mask response tokens with probability ``t_i``. Return the sum of
             ``CE / t_i`` over masked tokens divided by the number of masked
             tokens (at least 1).
        """
        labels = inputs.pop("labels")
        base_t = inputs.pop("t")
        mask_token_id = int(inputs.pop("mask_token_id").item())
        prompt_lengths = inputs.pop("prompt_lengths", None)

        device = inputs["input_ids"].device
        B, N = inputs["input_ids"].shape
        eps = 1e-3

        y_mask_bool = self._build_y_mask_bool(inputs, prompt_lengths)

        with torch.no_grad():
            betas = self._estimate_token_betas(model, inputs, labels, y_mask_bool, mask_token_id)
            # _record_percentiles is not defined on this class. Treat it as an
            # optional hook that a subclass may supply, instead of raising
            # AttributeError on every training step.
            record_percentiles = getattr(self, "_record_percentiles", None)
            if record_percentiles is not None:
                record_percentiles(inputs["input_ids"], betas)
            # Clamp the count so a sequence with no response tokens gives 0
            # (then 1e-8) rather than a 0/0 NaN.
            y_counts = y_mask_bool.sum(dim=1).clamp_min(1)
            beta_ref = (betas * y_mask_bool).sum(dim=1) / y_counts
            beta_ref = beta_ref.clamp_min(1e-8)

        base_t = base_t.to(device).clamp(min=eps, max=1 - eps)
        base_t_full = base_t.unsqueeze(1).expand(B, N)
        ratio = betas / beta_ref.unsqueeze(1)

        # A token with above-average uncertainty (ratio > 1) is masked more often
        # than the base rate t; one with below-average uncertainty, less often.
        t_i_y = 1.0 - torch.pow(1.0 - base_t_full, ratio.clamp_min(0.0))
        t_i_y = ((1.0 - eps) * t_i_y + eps).clamp_min(0.005)
        t_i = torch.where(y_mask_bool, t_i_y, torch.zeros_like(t_i_y))

        mask_sample = torch.rand((B, N), device=device)
        mask_indices = (mask_sample < t_i) & y_mask_bool

        noisy_inputs = {k: v.clone() if isinstance(v, torch.Tensor) else v
                        for k, v in inputs.items()}
        noisy_inputs["input_ids"] = noisy_inputs["input_ids"].masked_fill(mask_indices, mask_token_id)

        # Only masked positions are supervised.
        effective_labels = labels.clone()
        effective_labels[~mask_indices] = -100

        outputs = model(**noisy_inputs)
        logits = outputs.logits

        ce = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            effective_labels.view(-1),
            reduction="none"
        ).view(B, N)

        inv_t = torch.ones_like(t_i)
        inv_t[mask_indices] = 1.0 / t_i[mask_indices].clamp_min(eps)

        scaled = ce * inv_t
        valid_mask = (effective_labels != -100)
        denom = valid_mask.sum().clamp_min(1)

        loss = scaled[valid_mask].sum() / denom

        if (self.state.global_step + 1) % self.args.logging_steps == 0:
            unscaled_mean = (ce[valid_mask].sum() / denom).item()
            mean_t = t_i[mask_indices].mean().item() if mask_indices.any() else 0.0
            log_dict = {
                "unscaled_loss": unscaled_mean,
                "masked_tokens": int(denom.item()),
                "mean_t": mean_t,
            }
            self.log(log_dict)

        return (loss, outputs) if return_outputs else loss


class DreamDLLMTrainer(Trainer):
    """
    Masked-diffusion trainer whose per-token loss is weighted by proximity to clean tokens.

    Each position's weight is a sum, over the clean (unmasked, attended)
    tokens, of a geometric kernel in their distance. This is averaged over the
    success probabilities in ``args.geo_ps`` (default ``(0.3,)``) and halved.
    Rows with no clean token fall back to weight 1. The weighted CE is divided
    by ``t`` and normalised by ``input_ids.numel() - num_prompt_tokens``.
    """

    # Mask id used when ``args.mask_token_id`` is absent or None. This is the
    # value the trainer previously hard-coded. NOTE(review): it must equal the
    # id the collator writes into input_ids; confirm it for the model in use.
    DEFAULT_MASK_TOKEN_ID = 126336

    def _proximity_kernel(self, seq_len, device):
        """Return the symmetric (L, L) kernel ``0.5 * mean_p(p * (1 - p) ** d(i, j))``."""
        pos = torch.arange(seq_len, device=device)
        # d(i, j) = number of positions strictly between i and j (0 for i == j
        # and for neighbours). The local name avoids shadowing the module-level
        # `dist` (torch.distributed) import.
        distance = ((pos[None, :] - pos[:, None]).abs().float() - 1.0).clamp_min(0.0)

        ps = getattr(self.args, "geo_ps", (0.3,))
        if not isinstance(ps, (list, tuple)):
            ps = (float(ps),)
        if len(ps) == 0:
            raise ValueError("geo_ps must contain at least one probability")
        kernel = sum(((1.0 - float(p)) ** distance) * float(p) for p in ps) / float(len(ps))
        return 0.5 * kernel

    def compute_loss(self, model, inputs, num_items_in_batch=None, return_outputs=False):
        """Proximity-weighted, 1/t-scaled masked-token cross-entropy."""
        labels = inputs.pop("labels")
        t = inputs.pop("t")
        num_prompt_tokens = inputs.pop("num_prompt_tokens")

        outputs = model(**inputs)
        logits = outputs.logits

        per_token_ce = F.cross_entropy(
            logits.view(-1, logits.shape[-1]),
            labels.view(-1),
            reduction="none"
        ).view_as(labels).float()

        input_ids = inputs["input_ids"]
        attn = inputs.get("attention_mask")
        if attn is None:
            attn = torch.ones_like(input_ids)
        attn_bool = attn.bool()

        # Overridable through args (same pattern as `geo_ps`) rather than hard-coded.
        mask_id = getattr(self.args, "mask_token_id", None)
        if mask_id is None:
            mask_id = self.DEFAULT_MASK_TOKEN_ID
        is_masked = (input_ids == mask_id) & attn_bool
        is_clean = (~is_masked) & attn_bool

        kernel = self._proximity_kernel(input_ids.shape[1], input_ids.device)

        # weights_pos[b, i] = sum over clean j of kernel[i, j]; padding gets 0.
        clean_float = is_clean.float()
        weights_pos = torch.matmul(clean_float, kernel.t())
        weights_pos = weights_pos * attn.float()

        # Sequences with no clean token at all fall back to uniform weight 1.
        no_clean = (clean_float.sum(dim=-1, keepdim=True) == 0)
        weights_pos = torch.where(no_clean, torch.ones_like(weights_pos), weights_pos)

        weighted_ce = per_token_ce * weights_pos
        loss = weighted_ce / t
        loss = loss.sum() / (input_ids.numel() - num_prompt_tokens)

        if (self.state.global_step + 1) % self.args.logging_steps == 0:
            avg_unscaled = (per_token_ce.sum() / (labels != -100).sum()).item()
            self.log({"unscaled_loss": avg_unscaled,
                      "mean_weight": weights_pos[is_masked].mean().item() if is_masked.any() else 0.0})

        return loss if not return_outputs else (loss, outputs)