import os
import sys
from typing import List

import fire
import torch
import random
import torch.nn as nn
import torch.nn.functional as F
import transformers
from datasets import load_dataset
from typing import List, Optional, Union
from transformers import Trainer
from transformers.trainer_pt_utils import LabelSmoother
from torch.nn import CrossEntropyLoss
from util import *


class ProbTrainerThreshold_Next_ToBE(Trainer):
    """Trainer whose loss is next-token CE plus a weighted future-token term.

    The future term scores, at each position, the labels that lie
    ``1..future_steps`` steps beyond the next token.  Per-offset weights are
    the product of a decay term (``Generate_Decay_Prob``) and a gate derived
    from the adapter-free base model's probability of each future label,
    normalised over the offsets.  The weights are computed under
    ``torch.no_grad()`` and therefore act as constants during backprop.

    NOTE(review): ``expand_future_labels``, ``get_future_logits`` and
    ``Generate_Decay_Prob`` come from ``util`` and are not visible here; the
    shape comments below assume they return ``[B, L, future_steps + 1]``
    tensors with ``L == valid_length`` -- TODO confirm.
    """

    def __init__(
        self,
        *args,
        future_steps: int = 1,
        prob_threshold: float = 0.2,  # Threshold T
        g_beta: float = 0.1,  # Power β for g(x)
        stop_token_ids=None,  # Stop token IDs
        stop_token_texts=None,  # Stop token texts
        tokenizer=None,
        offset_weighting: str = "none",  # "none" | "linear" | "exp"
        alpha: float = 0.5,  # Decay strength
        combined_type: str = "time_left_multiplication",
        x=1,  # Future loss coefficient
        **kwargs,
    ):
        """Store the loss hyper-parameters and build the stop-token id list.

        Args:
            future_steps: Number of offsets beyond the next token to score.
            prob_threshold: Base-model probabilities not strictly above this
                value get a zero gate weight.
            g_beta: Exponent applied to the base-model probability.
            stop_token_ids: Optional iterable of token ids treated as stops.
            stop_token_texts: Optional strings; each one that ``tokenizer``
                encodes to exactly one id contributes that id.
            tokenizer: Tokenizer used only to encode ``stop_token_texts``.
                It is not forwarded to ``Trainer.__init__``.
            offset_weighting: Stored; not read anywhere in this class.
            alpha: Passed positionally to ``Generate_Decay_Prob``.
            combined_type: Stored; ``compute_loss`` currently passes the
                literal ``"addition"`` instead (see the note there).
            x: Multiplier applied to the future loss in the total loss.
            *args, **kwargs: Forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.future_steps = int(future_steps)
        self.prob_threshold = float(prob_threshold)
        self.g_beta = float(g_beta)
        self.offset_weighting = offset_weighting
        self.alpha = float(alpha)
        self.combined_type = combined_type
        self.x = x
        # Build stop token ID set
        stop_ids = set(stop_token_ids or [])
        if stop_token_texts and tokenizer is not None:
            for w in stop_token_texts:
                toks = tokenizer(w, add_special_tokens=False)["input_ids"]
                # Only texts that map to a single token can be used as an id.
                if isinstance(toks, list) and len(toks) == 1:
                    stop_ids.add(int(toks[0]))
        self._stop_token_ids_list = sorted(stop_ids)

    def get_base_model(self, model):
        """Return ``model.module`` when the attribute exists, else ``model``.

        This strips one wrapper level such as DistributedDataParallel.
        """
        if hasattr(model, "module"):
            # DistributedDataParallel case
            return model.module
        return model

    def _build_stop_mask(self, future_labels: torch.Tensor) -> torch.Tensor:
        """Return a bool mask, shaped like ``future_labels``, of stop tokens.

        The mask is all ``False`` when no stop ids were configured.
        NOTE(review): ``compute_loss`` in this class never calls this helper.
        """
        if not self._stop_token_ids_list:
            return torch.zeros_like(future_labels, dtype=torch.bool)
        stop_ids = torch.tensor(
            self._stop_token_ids_list,
            dtype=torch.long,
            device=future_labels.device,
        )
        return (future_labels.unsqueeze(-1) == stop_ids).any(dim=-1)

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """Return ``ce_loss + self.x * future_loss`` for one batch.

        Args:
            model: Model to train.  After unwrapping with ``get_base_model``
                it must expose a ``disable_adapter()`` context manager.
            inputs: Mapping with ``input_ids`` and ``labels`` and optionally
                ``attention_mask``; label ``-100`` marks ignored positions.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted for Trainer compatibility; unused.

        Returns:
            The scalar loss, or ``(loss, outputs)`` if ``return_outputs``.
            When ``future_steps <= 0`` or the sequence is too short to hold a
            future window, only the CE loss is returned and nothing is
            printed.
        """
        attention_mask = inputs.get("attention_mask", None)

        # ===== Adapter-enabled forward (carries gradients) =====
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        logits = outputs.logits  # [B, T, V]
        # Move labels once so that the CE target and every future-label
        # slice below live on the same device as the logits.
        labels = inputs["labels"].to(logits.device)  # [B, T]

        # ===== CE Loss =====
        shift_logits = logits[..., :-1, :].contiguous()  # [B, T-1, V]
        shift_labels = labels[..., 1:].contiguous()  # [B, T-1]
        ce_loss = CrossEntropyLoss(ignore_index=-100)(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )

        valid_length = max(0, shift_labels.size(1) - self.future_steps)
        # Without a future window the future term is zero.  Returning here
        # also avoids calling ``.item()`` on the float accumulator below,
        # which would otherwise never become a tensor.
        if self.future_steps <= 0 or valid_length == 0:
            return (ce_loss, outputs) if return_outputs else ce_loss

        log_probs = torch.log_softmax(shift_logits, dim=-1).clamp(
            min=-50
        )  # [B, T-1, V]
        probs = torch.exp(log_probs)  # [B, T-1, V]
        expand_labels = expand_future_labels(
            shift_labels, future_steps=self.future_steps + 1
        )
        expand_logits = get_future_logits(
            probs, expand_labels, future_steps=self.future_steps + 1
        )
        # Probability mass on the future labels (offset >= 1); used below to
        # renormalise each future label's log-probability.
        renorm_logits2 = torch.sum(expand_logits[..., 1:], dim=-1).clamp(
            min=1e-12
        )  # [B, L]

        # ===== Weights: base forward + decay, no gradient =====
        with torch.no_grad():
            base_model = self.get_base_model(model)

            with base_model.disable_adapter():
                # Only the logits are read from this pass, so hidden states
                # are not requested.
                base_outputs = base_model(
                    input_ids=inputs["input_ids"],
                    attention_mask=attention_mask,
                )
                base_logits = base_outputs.logits[
                    ..., :-1, :
                ].contiguous()  # [B, T-1, V]
                base_prob = get_future_logits(
                    torch.softmax(base_logits, dim=-1),
                    expand_labels,
                    future_steps=self.future_steps + 1,
                )  # [B, L, future_steps+1]

                # Zero the weight wherever the base probability is not
                # strictly above the threshold.
                mask = (base_prob > self.prob_threshold).float()
                base_weight = (base_prob**self.g_beta) * mask

            # NOTE(review): ``combined_type`` is hard-coded to "addition";
            # ``self.combined_type`` is not consulted -- confirm intent.
            decay = Generate_Decay_Prob(
                model,
                expand_labels,
                10,
                self.alpha,
                combined_type="addition",
                type="acc",
            )  # [B, L, future_steps+1]

            # Index 0 (the next token) stays at weight 0 so that only the
            # future offsets take part in the normalisation.
            combined = torch.zeros_like(decay)
            combined[..., 1:] = decay[..., 1:] * base_weight[..., 1:]

            norm_factor = combined.sum(dim=-1, keepdim=True).clamp(
                min=1e-12
            )  # [B, L, 1]
            final_weight = combined / norm_factor

        # Loop-invariant pieces of the per-offset loss.
        window_log_probs = log_probs[:, :valid_length, :]  # [B, L, V]
        log_renorm = torch.log(renorm_logits2)  # [B, L]

        future_loss = 0.0
        for offset in range(1, self.future_steps + 1):
            future_labels = labels[
                :, 1 + offset : 1 + offset + valid_length
            ]  # [B, L]
            valid_positions = future_labels != -100

            # Replace ignored labels by index 0 so gather() stays in range;
            # those positions are masked out of the loss below.
            safe_future_labels = future_labels.masked_fill(~valid_positions, 0)
            lp = window_log_probs.gather(
                -1, safe_future_labels.unsqueeze(-1)
            ).squeeze(-1)  # [B, L]
            # Division by the future-label mass, done in log space.
            lp = lp - log_renorm
            step_loss = -(lp.clamp_min(-60.0)) * final_weight[..., offset]

            valid_float = valid_positions.float()
            step_loss = (step_loss * valid_float).sum() / valid_float.sum().clamp(
                min=1
            )

            future_loss = future_loss + step_loss
        total_loss = ce_loss + self.x * future_loss
        print(
            f"total_loss: {total_loss.item()}, ce_loss: {ce_loss.item()}, future_loss: {future_loss.item()}"
        )
        return (total_loss, outputs) if return_outputs else total_loss


class ProbTrainer_topk(Trainer):
    """Trainer adding a top-k based future-token loss to next-token CE.

    For each offset ``1..future_steps`` the loss rewards probability mass
    that the current prediction puts on the future label, but only when that
    label appears among ranks 2..``top_k`` of the current distribution (the
    top-1 candidate is dropped).

    NOTE(review): ``expand_future_labels``, ``get_future_logits`` and
    ``Generate_Decay_Prob`` come from ``util`` and are not visible here; the
    shape comments assume ``[B, L, future_steps + 1]`` outputs with
    ``L == valid_length`` -- TODO confirm.
    """

    def __init__(
        self,
        *args,
        top_k=5,
        future_steps=1,
        alpha=0.2,
        combined_type="time_left_multiplication",
        **kwargs,
    ):
        """Store the loss hyper-parameters.

        Args:
            top_k: Number of highest-probability candidates taken per
                position; the first one is discarded before matching.
            future_steps: Number of offsets beyond the next token to score.
            alpha: Passed positionally to ``Generate_Decay_Prob``.
            combined_type: Passed to ``Generate_Decay_Prob``.
            *args, **kwargs: Forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.top_k = top_k
        self.future_steps = future_steps
        self.alpha = alpha
        self.combined_type = combined_type

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """Return ``ce_loss + future_loss`` for one batch.

        Args:
            model: Model to train.
            inputs: Mapping with ``input_ids`` and ``labels`` and optionally
                ``attention_mask``; label ``-100`` marks ignored positions.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted for Trainer compatibility; unused.

        Returns:
            The scalar loss, or ``(loss, outputs)`` if ``return_outputs``.
            Only the CE loss is returned when ``future_steps <= 0`` or the
            sequence is too short to hold a future window.
        """
        # A missing attention mask is passed through as None instead of
        # raising KeyError.
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask", None),
        )
        logits = outputs.logits  # [batch, seq_len, vocab]
        # Keep CE targets and future-label slices on the logits' device.
        labels = inputs["labels"].to(logits.device)  # [batch, seq_len]

        # Basic cross-entropy loss (predict next token)
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        loss_fct = CrossEntropyLoss(ignore_index=-100)
        ce_loss = loss_fct(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )

        # A negative window length would turn the slices below into
        # "all but the last n" slices, so clamp and bail out early.
        valid_length = max(0, shift_labels.size(1) - self.future_steps)
        if self.future_steps <= 0 or valid_length == 0:
            return (ce_loss, outputs) if return_outputs else ce_loss

        # Future prediction enhancement loss
        probs = torch.softmax(shift_logits, dim=-1)
        expand_labels = expand_future_labels(
            shift_labels, future_steps=self.future_steps + 1
        )
        expand_logits = get_future_logits(
            probs, expand_labels, future_steps=self.future_steps + 1
        )
        # Clamp the denominator: an unclamped zero gives 0/0 = NaN, and
        # ``-log(NaN) * 0`` would still propagate NaN into the loss.
        renorm_logits2 = torch.sum(expand_logits[..., 1:], dim=-1).clamp(
            min=1e-12
        )  # [B, L]
        with torch.no_grad():
            decay = Generate_Decay_Prob(
                model,
                expand_labels,
                10,
                self.alpha,
                combined_type=self.combined_type,
                type="acc",
            )

        # The top-k of the current prediction does not depend on the offset,
        # so it is computed once.  Index 0 (the top-1 candidate) is dropped.
        topk_probs, topk_indices = torch.topk(
            probs[:, :valid_length, :], self.top_k, dim=-1
        )  # [B, L, K]
        cand_probs = topk_probs[..., 1:]  # [B, L, K-1]
        cand_indices = topk_indices[..., 1:]  # [B, L, K-1]

        future_loss = 0.0
        for offset in range(1, self.future_steps + 1):
            # True labels `offset` steps beyond the next token
            future_labels = labels[
                :, 1 + offset : 1 + offset + valid_length
            ]  # [B, L]

            # Probability of the future label if it is one of the retained
            # candidates, else 0.  A -100 label never matches an index.
            future_mask = (
                future_labels.unsqueeze(-1) == cand_indices
            ).float()  # [B, L, K-1]
            weighted_probs = (cand_probs * future_mask).sum(dim=-1)  # [B, L]

            weighted_probs = weighted_probs / renorm_logits2  # [B, L]
            valid_positions = (future_labels != -100).float()
            # Only positions where the label was actually found contribute.
            valid_positions = valid_positions * (weighted_probs > 0).float()
            step_loss = -torch.log(weighted_probs + 1e-12) * valid_positions
            future_loss = future_loss + (
                decay[..., offset] * step_loss
            ).sum() / valid_positions.sum().clamp(min=1)
        total_loss = ce_loss + future_loss
        return (total_loss, outputs) if return_outputs else total_loss



class ProbTrainerThreshold1(Trainer):
    """Trainer adding a self-gated future-token loss to next-token CE.

    For every offset ``1..future_steps`` the model's own probability of the
    future label is turned into a gate (zero below ``prob_threshold``, at
    stop tokens, and at ignored labels), multiplied by a decay weight, and
    normalised across offsets.  The resulting weights are constants for
    backprop; gradients flow only through the renormalised log-probability.

    NOTE(review): ``expand_future_labels``, ``get_future_logits`` and
    ``Generate_Decay_Prob`` come from ``util`` and are not visible here; the
    shape comments assume ``[B, L, S + 1]`` outputs with
    ``L == valid_length`` -- TODO confirm.
    """

    def __init__(
        self,
        *args,
        future_steps: int = 1,
        prob_threshold: float = 0.2,  # Threshold T: gate is open when p>=T
        g_beta: float = 0.1,  # β: gating power
        # Stop token configuration
        stop_token_ids=None,
        stop_token_texts=None,
        tokenizer=None,
        # Time/semantic decay wrapper
        offset_weighting: str = "none",  # "none" | "linear" | "exp"
        alpha: float = 0.5,  # Passed to Generate_Decay_Prob and _offset_weight
        combined_type: str = "time_left_multiplication",
        # Global coefficient (future loss weight)
        future_global_coeff: float = 1.0,
        **kwargs,
    ):
        """Store the loss hyper-parameters and build the stop-token id list.

        Args:
            future_steps: Number of offsets beyond the next token to score.
            prob_threshold: Minimum model probability for an open gate.
            g_beta: Exponent applied to the probability inside the gate.
            stop_token_ids: Optional iterable of token ids treated as stops.
            stop_token_texts: Optional strings; each one that ``tokenizer``
                encodes to exactly one id contributes that id.
            tokenizer: Tokenizer used only to encode ``stop_token_texts``.
                It is not forwarded to ``Trainer.__init__``.
            offset_weighting: Scheme used by ``_offset_weight``.
            alpha: Decay strength.
            combined_type: Passed to ``Generate_Decay_Prob``.
            future_global_coeff: Multiplier on the future loss.
            *args, **kwargs: Forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.future_steps = int(future_steps)
        self.prob_threshold = float(prob_threshold)
        self.g_beta = float(g_beta)
        self.offset_weighting = offset_weighting
        self.alpha = float(alpha)
        self.combined_type = combined_type
        self.future_global_coeff = float(future_global_coeff)

        # Build stop token ID set
        stop_ids = set(stop_token_ids or [])
        if stop_token_texts and tokenizer is not None:
            for w in stop_token_texts:
                toks = tokenizer(w, add_special_tokens=False)["input_ids"]
                # Only texts that map to a single token can be used as an id.
                if isinstance(toks, list) and len(toks) == 1:
                    stop_ids.add(int(toks[0]))
        self._stop_token_ids_list = sorted(stop_ids)

    def _offset_weight(self, offset: int) -> float:
        """Return the scalar fallback weight for ``offset``.

        ``"linear"`` gives ``1 - alpha * offset / (future_steps + 1)`` floored
        at ``1e-6``; ``"exp"`` gives ``exp(-alpha * offset)``; ``"none"`` and
        any unrecognised scheme give ``1.0``.
        """
        if self.offset_weighting == "none":
            return 1.0
        if self.offset_weighting == "linear":
            return max(
                1e-6, 1.0 - self.alpha * offset / (self.future_steps + 1)
            )
        if self.offset_weighting == "exp":
            return float(torch.exp(torch.tensor(-self.alpha * offset)))
        return 1.0

    @torch.no_grad()
    def _build_stop_mask(self, labels: torch.Tensor) -> torch.Tensor:
        """Return a bool mask, shaped like ``labels``, marking stop tokens.

        The mask is all ``False`` when no stop ids were configured.
        """
        if not self._stop_token_ids_list:
            return torch.zeros_like(labels, dtype=torch.bool)
        stop_ids = torch.tensor(
            self._stop_token_ids_list, dtype=torch.long, device=labels.device
        )
        return (labels.unsqueeze(-1) == stop_ids).any(dim=-1)

    def compute_loss(
        self, model, inputs, return_outputs=False, num_items_in_batch=None
    ):
        """Return ``ce_loss + future_global_coeff * future_loss``.

        Args:
            model: Model to train.
            inputs: Mapping with ``input_ids`` and ``labels`` and optionally
                ``attention_mask``; label ``-100`` marks ignored positions.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Accepted for Trainer compatibility; unused.

        Returns:
            The scalar loss, or ``(loss, outputs)`` if ``return_outputs``.
            Only the CE loss is returned (and nothing is printed) when
            ``future_steps <= 0`` or no future window fits the sequence.
        """
        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask", None),
        )
        logits = outputs.logits  # [B, T, V]
        # Move labels once so that the CE target and every future-label
        # slice below live on the same device as the logits.
        labels = inputs["labels"].to(logits.device)  # [B, T]

        # Standard NTP (next token)
        shift_logits = logits[..., :-1, :].contiguous()  # [B, T-1, V]
        shift_labels = labels[..., 1:].contiguous()  # [B, T-1]
        ce_loss = CrossEntropyLoss(ignore_index=-100)(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )

        if self.future_steps <= 0:
            return (ce_loss, outputs) if return_outputs else ce_loss

        # Number of positions that still have `future_steps` labels ahead
        valid_length = max(0, shift_labels.size(1) - self.future_steps)
        if valid_length == 0:
            return (ce_loss, outputs) if return_outputs else ce_loss

        log_probs = torch.log_softmax(shift_logits, dim=-1).clamp(
            min=-50
        )  # [B, T-1, V]
        probs = torch.exp(log_probs)  # [B, T-1, V]

        # Future window auxiliary structure (index 0 is the next token)
        expand_labels = expand_future_labels(
            shift_labels, future_steps=self.future_steps + 1
        )
        # Probability mass on the future labels (offset >= 1), used as the
        # denominator of the relative probability.
        expand_logits = get_future_logits(
            probs, expand_labels, future_steps=self.future_steps + 1
        )
        renorm_logits2 = torch.sum(expand_logits[..., 1:], dim=-1).clamp(
            min=1e-12
        )  # [B, L]

        # Time/semantic decay, presumably [B, L, S+1] with index 0 for the
        # next token -- TODO confirm against util.
        with torch.no_grad():
            decay_full = Generate_Decay_Prob(
                model,
                expand_labels,
                10,
                self.alpha,
                combined_type=self.combined_type,
                type="acc",
            )
        # Anything other than a 3-D tensor falls back to scalar weights
        # from _offset_weight.
        use_decay_tensor = (
            isinstance(decay_full, torch.Tensor) and decay_full.dim() == 3
        )

        T = self.prob_threshold
        S = self.future_steps
        eps = 1e-12

        # Loop-invariant pieces of the per-offset loss.
        window_log_probs = log_probs[:, :valid_length, :]  # [B, L, V]
        log_renorm = torch.log(renorm_logits2)  # [B, L]

        d_stack_list = []  # S tensors of [B, L]: decay weights
        g_stack_list = []  # S tensors of [B, L]: gated p**beta
        step_base_list = []  # S tensors of [B, L]: -log relative prob

        for offset in range(1, S + 1):
            future_labels = labels[:, 1 + offset : 1 + offset + valid_length]
            valid_pos = future_labels != -100
            # Replace ignored labels by index 0 so gather() stays in range;
            # the gate below zeroes those positions.
            safe_future = future_labels.masked_fill(~valid_pos, 0)

            lp = window_log_probs.gather(
                -1, safe_future.unsqueeze(-1)
            ).squeeze(-1)  # [B, L]
            lp_rel = lp - log_renorm  # [B, L]
            step_base = -(lp_rel.clamp_min(-60.0))  # [B, L]

            # Gate and decay are only ever consumed as constant weights, so
            # build them without autograd tracking.
            with torch.no_grad():
                p = torch.exp(lp)
                stop_mask = self._build_stop_mask(future_labels)
                gate = (p >= T) & (~stop_mask) & valid_pos  # [B, L]
                g_raw = (p**self.g_beta) * gate.float()  # [B, L]

                if use_decay_tensor:
                    d_raw = decay_full[..., offset]  # [B, L]
                else:
                    d_raw = torch.full_like(
                        step_base, self._offset_weight(offset)
                    )

            step_base_list.append(step_base)
            g_stack_list.append(g_raw)
            d_stack_list.append(d_raw)

        # Stack to [S, B, L]
        step_base_stack = torch.stack(step_base_list, dim=0)
        g_stack = torch.stack(g_stack_list, dim=0)
        d_stack = torch.stack(d_stack_list, dim=0)

        with torch.no_grad():
            # Normalise the product gate * decay over the offset axis; the
            # weights sum to ~1 wherever at least one gate is open.
            w_raw = g_stack * d_stack  # [S, B, L]
            w_sum = w_raw.sum(dim=0, keepdim=True)  # [1, B, L]
            w_stack = w_raw / (w_sum + eps)
            # Average only over positions with at least one open gate.
            any_gate = w_sum.squeeze(0) > 0  # [B, L]
            denom = any_gate.float().sum().clamp(min=1.0)

        # Aggregate future loss
        future_loss = (w_stack * step_base_stack).sum() / denom

        # Total loss: NTP CE + global coefficient * future term
        total_loss = ce_loss + self.future_global_coeff * future_loss
        print(
            f"total_loss: {total_loss.item()}, ce_loss: {ce_loss.item()}, future_loss: {future_loss.item()}"
        )
        return (total_loss, outputs) if return_outputs else total_loss
