from typing import Dict, List, Optional, Any
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from dataclasses import dataclass

from transformers import PreTrainedTokenizerBase
from tqdm import tqdm
import json


# ====== Prefix state extractor: reuses the key_states produced by the first layer of BuddyModel ======
@torch.no_grad()
def extract_policy_state(buddy_lm, input_ids, attention_mask):
    """
    Build the policy state from the key states of the first decoder layer.

    Only ``layers[0]`` of ``buddy_lm.model.model`` is run, through the
    model's own ``layer_forward`` helper; the model structure is untouched
    and no gradients are recorded.

    Args:
        buddy_lm: wrapper whose ``.model.model`` exposes ``embed_tokens``,
            ``_update_causal_mask``, ``rotary_emb``, ``layers`` and
            ``layer_forward``.
        input_ids: token ids fed to ``embed_tokens``.
        attention_mask: mask forwarded to ``_update_causal_mask``.

    Returns:
        Tensor of shape [B, T', H], where H is the embedding width and T'
        is whatever length results from folding the key states into H-wide
        rows.
    """
    backbone = buddy_lm.model.model  # the BuddyModel sub-module

    embeddings = backbone.embed_tokens(input_ids)
    seq_len = embeddings.shape[1]

    # No KV cache is involved, so positions simply start at zero.
    cache_position = torch.arange(0, seq_len, device=embeddings.device)
    position_ids = cache_position.unsqueeze(0)

    # Positional arguments: (mask, embeds, cache_position, past_key_values, output_attentions)
    causal_mask = backbone._update_causal_mask(
        attention_mask, embeddings, cache_position, None, False
    )
    rotary = backbone.rotary_emb(embeddings, position_ids)

    # Run the first decoder layer only.
    first_layer_out = backbone.layer_forward(
        decoder_layer=backbone.layers[0],
        hidden_states=embeddings,
        attention_mask=causal_mask,
        position_ids=position_ids,
        past_key_values=None,
        output_attentions=False,
        use_cache=False,
        cache_position=cache_position,
        position_embeddings=rotary,
    )

    # The last output entry is the (key, value) pair; only the keys are used.
    # Keys are presumably laid out as (B, heads, T, head_dim) - confirm in layer_forward.
    keys, _ = first_layer_out[-1]
    batch_size = keys.size(0)
    hidden_size = embeddings.size(-1)

    # Move the head axis next to head_dim and fold back into rows of width H.
    return keys.transpose(1, 2).reshape(batch_size, -1, hidden_size)


def compute_ce(logits, labels, B, g, pad_token_id=0, ignore_index=-100):
    """
    Per-sample mean next-token cross-entropy.

    Logits at position t are scored against the label at position t + 1.
    Positions whose (shifted) label equals ``pad_token_id`` or
    ``ignore_index`` are excluded from both the sum and the token count.

    Args:
        logits: [B*g, L, V] prediction scores.
        labels: [B*g, L] target token ids.
        B: batch size before group expansion.
        g: group size (samples per batch element).
        pad_token_id: label value marking padding (default 0).
        ignore_index: additional label value to skip; also forwarded to
            ``F.cross_entropy`` (default -100).

    Returns:
        Tensor of shape [B*g] with the mean CE over valid tokens of each
        sample; 0 for a sample with no valid token.
    """
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    ce_token = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        reduction='none',
        ignore_index=ignore_index,
    ).view(B * g, -1)  # [B*g, L-1]

    # cross_entropy already yields 0 at ignore_index positions; they must
    # also be left out of the denominator, otherwise the mean is diluted.
    valid = ((shift_labels != pad_token_id) & (shift_labels != ignore_index)).view(B * g, -1)

    # Clamp the count so that a fully-masked sample yields 0 instead of NaN.
    token_count = valid.sum(-1).clamp_min(1)
    return (ce_token * valid).sum(-1) / token_count


@dataclass
class CausalLMDataCollator:
    """
    Collate causal-LM features into a padded batch on ``device``.

    ``input_ids`` and ``attention_mask`` are padded by ``tokenizer.pad``;
    ``labels`` are padded (or truncated) separately to the same length
    using ``label_pad_token_id``, on the same side the tokenizer pads.
    """

    tokenizer: PreTrainedTokenizerBase
    # Value used to fill label positions that correspond to padding.
    label_pad_token_id: int = 0
    # Tensor-Core friendly length multiple; set to None to disable.
    pad_to_multiple_of: Optional[int] = 8
    # Device the finished batch is moved to.
    device: str = "cuda"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """
        Pad ``features`` to a common length and return the batch.

        Args:
            features: items providing ``input_ids``, ``attention_mask`` and
                ``labels`` (list or tensor).

        Returns:
            The object returned by ``tokenizer.pad`` with an added
            ``labels`` LongTensor, moved to ``self.device``.
        """
        labels = [f["labels"] for f in features]
        model_inputs = [
            {"input_ids": f["input_ids"], "attention_mask": f["attention_mask"]}
            for f in features
        ]

        batch = self.tokenizer.pad(
            model_inputs,
            padding=True,
            return_tensors="pt",
            pad_to_multiple_of=self.pad_to_multiple_of,
        )
        max_len = batch["input_ids"].shape[1]

        # tokenizer.pad honours the tokenizer's padding side; the labels have
        # to be padded on that same side or they no longer line up with
        # input_ids.
        pad_on_left = getattr(self.tokenizer, "padding_side", "right") == "left"

        padded_labels = []
        for label in labels:
            label = label.tolist() if isinstance(label, torch.Tensor) else list(label)
            if len(label) >= max_len:
                label = label[:max_len]
            else:
                filler = [self.label_pad_token_id] * (max_len - len(label))
                label = filler + label if pad_on_left else label + filler
            padded_labels.append(label)
        batch["labels"] = torch.tensor(padded_labels, dtype=torch.long)

        return batch.to(self.device)


# =========================
# 4) GRPO configuration
# =========================
@dataclass
class GRPOConfig:
    """Hyper-parameters for GRPO training of the budget predictor."""

    # Number of samples drawn per input within a group (g).
    group_size: int = 4
    # Entropy regularisation coefficient.
    entropy_coef: float = 0.01
    # Compute-cost coefficient (larger values favour cheaper budgets).
    cost_coef: float = 0.3
    # Performance coefficient: r = -perf_scale * CE - cost_coef * cost.
    perf_scale: float = 1.0
    # Optimizer learning rate.
    lr: float = 1e-4
    # Gradient-norm clipping threshold.
    max_grad_norm: float = 1.0
    # Device used for the model and the batches.
    device: str = "cuda"
    # Logging interval in steps (presumably; confirm how the trainer consumes it).
    log_every: int = 20

    # --- PPO / GRPO specific parameters ---
    # Number of optimisation passes over each collected rollout.
    grpo_epochs: int = 4
    # Clipping range for the probability ratio.
    clip_eps: float = 0.2

    # Temperature handed to the policy when sampling / evaluating actions.
    policy_temperature: float = 1.0
    # Mixture epsilon handed to the policy when sampling / evaluating actions.
    epsilon_mixture: float = 0.05


# =========================
# 5) GRPO trainer (trains only the predictor)
# =========================
class GRPOBudgetTrainer:
    """
    GRPO-style trainer that updates only the budget predictor.

    Each step samples ``group_size`` budgets per input, scores them with the
    frozen language model (reward = -perf_scale * CE - cost_coef * budget),
    normalises rewards inside each group and runs ``grpo_epochs`` clipped
    policy-gradient updates on the predictor.
    """

    def __init__(self, buddy_lm, tokenizer, cfg: GRPOConfig):
        """
        Args:
            buddy_lm: language model exposing ``.model.model.budget_predictor``
                and ``.model.config.num_hidden_layers``.
            tokenizer: stored as ``self.tok``.
            cfg: training hyper-parameters.
        """
        self.model = buddy_lm.to(cfg.device)
        self.budget_predictor = buddy_lm.model.model.budget_predictor
        self.tok = tokenizer
        self.cfg = cfg

        self.training_details = []

        # Freeze the whole model, then re-enable exactly the parameters that
        # are handed to the optimizer, so the trainable set and the optimized
        # set cannot drift apart.
        for p in self.model.parameters():
            p.requires_grad = False
        for p in self.budget_predictor.parameters():
            p.requires_grad = True

        self.opt = torch.optim.AdamW(self.budget_predictor.parameters(), lr=cfg.lr)

    @staticmethod
    def _repeat_on_group(x, g):
        """Repeat every batch row ``g`` times consecutively: [B, ...] -> [B*g, ...]."""
        B = x.size(0)
        return x.unsqueeze(1).expand(B, g, *x.shape[1:]).reshape(B * g, *x.shape[1:])

    @staticmethod
    def _update_ema(ema, cur, m=0.95):
        """Exponential moving average over a metrics dict (first call copies ``cur``)."""
        if ema is None:
            return cur.copy()
        return {k: m * ema[k] + (1 - m) * cur[k] for k in ema}

    def step_batch(self, batch: Dict[str, torch.Tensor], step_idx: int):
        """
        Run one rollout plus ``cfg.grpo_epochs`` policy updates on ``batch``.

        Args:
            batch: dict with ``input_ids``, ``attention_mask`` and ``labels``.
            step_idx: global step index; only used to throttle debug output
                via ``cfg.log_every``.

        Returns:
            Dict of floats with keys ``loss``, ``policy``, ``entropy``,
            ``CE``, ``budget`` and ``reward``. The first three come from the
            last optimisation epoch.
        """
        g = self.cfg.group_size
        device = self.cfg.device

        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        B = input_ids.size(0)
        flat_B = B * g
        L = self.model.model.config.num_hidden_layers

        # ===================================================================
        # A. Rollout / data collection (no gradients)
        # ===================================================================
        with torch.no_grad():
            state = extract_policy_state(self.model, input_ids, attention_mask)
            k_idx_old, logp_old, _ = self.budget_predictor.sample_k(
                state, sample_shape=torch.Size([g]),
                temperature=self.cfg.policy_temperature,
                eps=self.cfg.epsilon_mixture
            )
            # NOTE(review): the reshape(-1) / view(B, g) calls below assume
            # sample_k returns tensors ordered as [B, g]. A [g, B] layout
            # would misalign budgets with the repeated inputs - confirm
            # against budget_predictor.sample_k.

            k = k_idx_old + 1
            budgets = (k + 2).float() / float(L)

            flat_input_ids = self._repeat_on_group(input_ids, g)
            flat_attention = self._repeat_on_group(attention_mask, g)
            flat_labels = self._repeat_on_group(labels, g)
            flat_budgets_list = budgets.reshape(-1).tolist()

            out = self.model(
                input_ids=flat_input_ids,
                attention_mask=flat_attention,
                budgets=flat_budgets_list,
                use_cache=False,
                return_dict=True,
            )

            # Per-sample CE is computed here rather than taken from out.loss.
            ce = compute_ce(out.logits, flat_labels, B, g)
            flat_budgets_t = budgets.reshape(-1).to(device)

            # Debug output, throttled by cfg.log_every (0/None disables it).
            log_every = self.cfg.log_every
            if log_every and step_idx % log_every == 0:
                print("ce=", ce.shape, ce, "\n budgets=", flat_budgets_t.shape, flat_budgets_list)

            reward = -self.cfg.perf_scale * ce - self.cfg.cost_coef * flat_budgets_t
            reward = reward.view(B, g)

            # Group-relative advantage: normalise within each group of g samples.
            group_mean = reward.mean(dim=1, keepdim=True)
            group_std = reward.std(dim=1, unbiased=False, keepdim=True)
            group_std = torch.clamp(group_std, min=1e-3)
            advantages = (reward - group_mean) / (group_std + 1e-6)

        # ===================================================================
        # B. Optimisation (several gradient updates on the same rollout)
        # ===================================================================
        flat_state = self._repeat_on_group(state, g)
        flat_k_idx_old = k_idx_old.reshape(flat_B)
        flat_logp_old = logp_old.reshape(flat_B)
        flat_advantages = advantages.reshape(flat_B)

        # Values of the last epoch are reported in the returned metrics.
        final_loss, final_policy_loss, final_entropy_new = 0.0, 0.0, torch.tensor(0.0)

        for _ in range(self.cfg.grpo_epochs):
            logp_new, entropy_new = self.budget_predictor.evaluate_actions(
                flat_state, flat_k_idx_old,
                temperature=self.cfg.policy_temperature,
                eps=self.cfg.epsilon_mixture
            )
            ratio = torch.exp(logp_new - flat_logp_old.detach())

            # Clipped surrogate objective.
            surr1 = ratio * flat_advantages
            surr2 = torch.clamp(ratio, 1.0 - self.cfg.clip_eps, 1.0 + self.cfg.clip_eps) * flat_advantages

            policy_loss = -torch.min(surr1, surr2).mean()
            entropy_loss = -self.cfg.entropy_coef * entropy_new.mean()
            loss = policy_loss + entropy_loss

            self.opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.budget_predictor.parameters(), self.cfg.max_grad_norm)
            self.opt.step()

            final_loss = loss.item()
            final_policy_loss = policy_loss.item()
            final_entropy_new = entropy_new

        metrics = {
            "loss": final_loss,
            "policy": final_policy_loss,
            "entropy": float(final_entropy_new.mean().item()),
            "CE": float(ce.mean().item()),
            "budget": float(budgets.float().mean().item()),
            "reward": float(reward.mean().item()),
        }
        return metrics

    def train(self, dataloader: DataLoader, num_epochs: int, max_steps: Optional[int] = 10000):
        """
        Train for ``num_epochs`` passes over ``dataloader``.

        Per-step metrics are appended to ``self.training_details`` (reset at
        the start of every call) and an EMA of them is shown on the progress
        bar.

        Args:
            dataloader: yields batches accepted by ``step_batch``.
            num_epochs: number of passes over ``dataloader``.
            max_steps: training stops completely as soon as the number of
                finished steps exceeds this value (so at most
                ``max_steps + 1`` steps run, matching the former hard-coded
                ``> 10000`` check). ``None`` disables the limit.
        """
        total_steps = len(dataloader) * num_epochs
        if max_steps is not None:
            total_steps = min(total_steps, max_steps + 1)

        self.training_details = []
        global_step = 0
        ema = None  # smoothed metrics for display
        limit_reached = False

        pbar = tqdm(total=total_steps, desc="Training", dynamic_ncols=True)
        try:
            for _epoch in range(num_epochs):
                for batch in dataloader:
                    metrics = self.step_batch(batch, global_step)
                    self.training_details.append(metrics)

                    ema = self._update_ema(ema, metrics)
                    global_step += 1
                    pbar.set_postfix({
                        "loss": f"{ema['loss']:.3f}",
                        "CE": f"{ema['CE']:.3f}",
                        "budget": f"{ema['budget']:.3f}",
                        "reward": f"{ema['reward']:.3f}",
                    })
                    pbar.update(1)

                    if max_steps is not None and global_step > max_steps:
                        limit_reached = True
                        break
                # Leave the epoch loop too; otherwise every remaining epoch
                # would still process one extra batch.
                if limit_reached:
                    break
        finally:
            # Close the bar even if a step raises.
            pbar.close()

    def save(self, output_dir, save_details=True):
        """
        Write the predictor weights (and optionally the training log).

        Args:
            output_dir: target directory; created if missing.
            save_details: when true, also dump ``self.training_details`` to
                ``training_details.json`` inside ``output_dir``.
        """
        from safetensors.torch import save_file

        os.makedirs(output_dir, exist_ok=True)
        # Join path components properly so the weights land inside output_dir
        # whether or not it ends with a separator.
        save_file(
            self.budget_predictor.state_dict(),
            os.path.join(output_dir, "predictor_weights.safetensors"),
        )

        if save_details:
            details_path = os.path.join(output_dir, "training_details.json")
            with open(details_path, "w", encoding="utf-8") as file:
                json.dump(self.training_details, file)
