import os
import datasets
import torch
import pandas as pd
import re
import tqdm
import numpy as np
import torch.nn.functional as F
import wandb
import torch.nn as nn
from torch.nn import MSELoss
from typing import List, Optional
from torch.utils.data import DataLoader, Dataset
from transformers import Trainer, AutoModelForCausalLM
from transformers.trainer_utils import EvalLoopOutput, has_length
from transformers.trainer_pt_utils import IterableDatasetShard
from transformers.deepspeed import deepspeed_init
from transformers.trainer import logger
from sklearn.metrics import accuracy_score
from utils import check_answer_correctness
from constants import INDENTIFIER2NAME
from operator import itemgetter

def safe_str(obj):
    """Return ``str(obj)``, degrading gracefully on ``UnicodeEncodeError``.

    If ``str(obj)`` raises ``UnicodeEncodeError`` and ``obj`` exposes an
    ``encode`` method, the ASCII-only form of ``obj`` is returned
    (non-ASCII characters are dropped). If ``obj`` has no ``encode``
    method, an empty string is returned instead of letting an
    ``AttributeError`` escape from the fallback path.

    Args:
        obj: Any object.

    Returns:
        A ``str`` representation of ``obj`` (possibly lossy or empty).
    """
    try:
        return str(obj)
    except UnicodeEncodeError:
        encode = getattr(obj, "encode", None)
        if encode is None:
            # Nothing sensible to fall back to; the original code had an
            # (unreachable) `return ""` expressing this intent.
            return ""
        return encode('ascii', 'ignore').decode('ascii')

def extract_answer(generation, position):
    """Split a generation into the model answer and its ``type<N>_`` tag number.

    The generation is expected to contain an answer plus a tag that
    includes the pattern ``type<digits>_``.

    Args:
        generation: Decoded model output.
        position: ``'right'`` when the tag follows the answer (the text is
            split on ``'<'``; the answer is the first piece and the tag is
            searched in the last piece). Any other value means the tag
            precedes the answer (the text is split on ``'>'``; the answer
            is the last piece and the tag is searched in the first piece).

    Returns:
        A ``(model_answer, type_id)`` tuple, where ``type_id`` is the digit
        string captured from the tag.

    Raises:
        ValueError: If no ``type<digits>_`` tag is found. (Previously this
            path failed with an accidental ``UnboundLocalError`` at the
            ``return`` statement.)
    """
    if position == 'right':
        pieces = generation.split('<')
        model_answer, tag_text = pieces[0], pieces[-1]
    else:
        pieces = generation.split('>')
        model_answer, tag_text = pieces[-1], pieces[0]

    match = re.search(r'type(\d+)_', tag_text)
    if match is None:
        print("type not found")
        raise ValueError(f"no 'type<N>_' tag found in generation: {generation!r}")

    # The captured digits are returned as a string; callers convert as needed.
    return model_answer, match.group(1)

def get_model_device(model):
    """Return the device of the first parameter yielded by ``model.parameters()``."""
    first_param = next(model.parameters())
    return first_param.device


class MyTrainer(Trainer):
    """``Trainer`` subclass that adds a confidence-head loss and a custom,
    generation-based evaluation loop.

    Training adds ``weight * BCE(conf_head(pooled answer states), targets)``
    to the language-modeling loss. Evaluation generates answers, scores them
    with ``check_answer_correctness`` and writes CSV files under
    ``./results/<model name>/``.
    """

    def __init__(self, custom_args, ref_model=None, **kwargs):
        """Create the trainer.

        Args:
            custom_args: Project argument object. This class reads
                ``weight``, ``exp_name`` and ``model_name_or_path`` /
                ``base_model_name_or_path`` from it.
            ref_model: Optional reference model; stored as
                ``self.ref_model`` only when it is not ``None``.
            **kwargs: Forwarded unchanged to ``Trainer.__init__``.
        """
        super(MyTrainer, self).__init__(**kwargs)
        self.custom_args = custom_args
        self.weight = custom_args.weight
        if ref_model is not None:
            print('use ref model')
            self.ref_model = ref_model
        self.count = 0
        # Allocate the accumulators on the device chosen by the training
        # arguments instead of a hard-coded 'cuda', so CPU-only runs work.
        device = self.args.device
        self.lm_loss_sum = torch.tensor(0.0, device=device)
        self.dist_loss_sum = torch.tensor(0.0, device=device)
        self.total_loss_sum = torch.tensor(0.0, device=device)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the SFT loss plus the confidence-head loss.

        - Language-modeling loss: the loss returned by the model for
          ``labels`` (unchanged).
        - Confidence loss: hidden states of the answer span
          (``labels != -100``) are mean-pooled, passed through
          ``model.conf_head`` and fitted to ``targets`` (clamped to [0, 1])
          with ``BCEWithLogitsLoss``.

        Args:
            model: Model to run; must expose a ``conf_head`` attribute.
            inputs: Batch with ``input_ids``, ``attention_mask``, ``labels``
                and optionally ``targets``.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                the loss alone.
            **kwargs: Ignored; accepted so that extra keyword arguments
                passed by the base ``Trainer`` do not break the override.

        Returns:
            ``lm_loss + self.weight * dist_loss`` (optionally with the model
            outputs).

        Raises:
            ValueError: If the total loss is NaN.
        """
        input_ids = inputs["input_ids"]
        attention_mask = inputs["attention_mask"]
        labels = inputs["labels"]
        targets = inputs.get("targets", None)  # presumably [B] or [B, 1]

        # 1) Regular forward pass; hidden states are required for pooling.
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
        )
        lm_loss = outputs.loss

        # 2) Mean-pool the last-layer states over the answer span.
        hs = outputs.hidden_states[-1]                                  # [B, T, H]
        ans_mask = (labels != -100).float()                             # [B, T]
        tok_cnt = ans_mask.sum(dim=1, keepdim=True).clamp_min(1.0)      # [B, 1]
        pooled = (hs * ans_mask.unsqueeze(-1)).sum(dim=1) / tok_cnt     # [B, H]

        # 3) Confidence head -> logits -> BCEWithLogitsLoss.
        assert hasattr(model, "conf_head"), "model.conf_head not found. Did you attach it in train.py?"
        conf_out = model.conf_head(pooled).squeeze(-1)  # logits or probabilities

        # Compatibility: if every output of this batch lies in [0, 1] the head
        # is assumed to emit probabilities, which are converted to logits.
        # NOTE(review): this is decided per batch from the values themselves;
        # a logits head whose outputs all happen to fall in [0, 1] would be
        # misclassified. Consider making the head's output type explicit.
        if conf_out.detach().min().item() >= 0.0 and conf_out.detach().max().item() <= 1.0:
            conf_out = conf_out.clamp(1e-6, 1 - 1e-6)
            conf_logits = torch.log(conf_out) - torch.log1p(-conf_out)
        else:
            conf_logits = conf_out  # already logits

        dist_loss = torch.tensor(0.0, device=lm_loss.device)
        if targets is not None:
            targets = targets.squeeze(-1) if targets.ndim > 1 else targets
            targets = targets.to(conf_logits.dtype).clamp(0.0, 1.0)
            dist_loss = nn.BCEWithLogitsLoss()(conf_logits, targets)

        total_loss = lm_loss + self.weight * dist_loss

        # Logging (best effort: a wandb failure must not abort training).
        _step = self.state.global_step + 1
        if self.state.global_step != 0 and _step % self.args.logging_steps == 0:
            try:
                wandb.log({
                    "lm_loss": float(lm_loss.detach().cpu()),
                    "dist_loss": float(dist_loss.detach().cpu()),
                    "total_loss": float(total_loss.detach().cpu()),
                    "conf_head/logit_mean": float(conf_logits.detach().mean().cpu()),
                    "conf_head/prob_mean": float(torch.sigmoid(conf_logits.detach()).mean().cpu()),
                }, step=_step)
            except Exception:
                # Keep the failure visible without spamming or crashing.
                logger.debug("wandb.log failed at step %s", _step, exc_info=True)

        if torch.isnan(total_loss):
            raise ValueError("Loss is NaN")
        return (total_loss, outputs) if return_outputs else total_loss

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Returns the evaluation [`~torch.utils.data.DataLoader`].

        No `collate_fn` is passed to the `DataLoader`, so PyTorch's default
        collation is used; the evaluation loop of this class tokenizes
        `batch["x"]` itself.

        Args:
            eval_dataset (`torch.utils.data.Dataset`, *optional*):
                If provided, will override `self.eval_dataset`. If it is a [`~datasets.Dataset`], it is passed
                through `self._remove_unused_columns`.

        Raises:
            ValueError: If neither `eval_dataset` nor `self.eval_dataset` is set.
        """
        if eval_dataset is None and self.eval_dataset is None:
            raise ValueError("Trainer: evaluation requires an eval_dataset.")
        eval_dataset = eval_dataset if eval_dataset is not None else self.eval_dataset

        if isinstance(eval_dataset, datasets.Dataset):
            eval_dataset = self._remove_unused_columns(eval_dataset, description="evaluation")

        dataloader_params = {
            "batch_size": self.args.eval_batch_size,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
        }

        # Samplers / drop_last only apply to map-style datasets.
        if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
            dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last

        return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))

    def evaluation_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> EvalLoopOutput:
        """
        Prediction/evaluation loop, shared by `Trainer.evaluate()` and `Trainer.predict()`.

        Prepares the model for evaluation, then runs either the
        ``cls_baseline`` evaluation (when ``custom_args.exp_name`` equals
        ``'cls_baseline'``) or the confidence-head evaluation. Results are
        written to CSV files; the returned ``EvalLoopOutput`` carries empty
        metrics and the number of observed samples.

        ``prediction_loss_only``, ``ignore_keys`` and ``metric_key_prefix``
        are accepted for interface compatibility and are not used.
        """
        args = self.args

        # if eval is called w/o train, handle model prep here
        if self.is_deepspeed_enabled and self.deepspeed is None:
            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)

        model = self._wrap_model(self.model, training=False, dataloader=dataloader)

        if len(self.accelerator._models) == 0 and model is self.model:
            model = (
                self.accelerator.prepare(model)
                if self.is_deepspeed_enabled
                else self.accelerator.prepare_model(model, evaluation_mode=True)
            )

            if self.is_fsdp_enabled:
                self.model = model

            if model is not self.model:
                self.model_wrapped = model

            if self.is_deepspeed_enabled:
                self.deepspeed = self.model_wrapped

        if not self.is_in_train:
            if args.fp16_full_eval:
                model = model.to(dtype=torch.float16, device=args.device)
            elif args.bf16_full_eval:
                model = model.to(dtype=torch.bfloat16, device=args.device)

        batch_size = self.args.eval_batch_size

        logger.info(f"***** Running {description} *****")
        if has_length(dataloader):
            logger.info(f"  Num examples = {self.num_examples(dataloader)}")
        else:
            logger.info("  Num examples: Unknown")
        logger.info(f"  Batch size = {batch_size}")

        model.eval()
        self.callback_handler.eval_dataloader = dataloader
        if args.past_index >= 0:
            self._past = None

        if self.custom_args.exp_name == 'cls_baseline':
            num_samples = self._evaluate_cls_baseline(dataloader)
        else:
            num_samples = self._evaluate_confidence_head(dataloader)

        # Trainer needs an EvalLoopOutput; no metrics are computed here.
        return EvalLoopOutput(predictions=None, label_ids=None, metrics={}, num_samples=num_samples)

    def _generation_max_length(self, prompt_len):
        """Return the ``max_length`` passed to ``generate`` for a prompt of ``prompt_len`` tokens."""
        # NOTE(review): `max` means a very large `tokenizer.model_max_length`
        # wins over `prompt_len + 512`; confirm that `min` was not intended.
        return max(self.tokenizer.model_max_length, prompt_len + 512)

    def _generated_token_mask(self, y_pred):
        """Return a bool mask [B, Ty] that is True up to and including the first EOS of each row.

        Tokens after the first ``tokenizer.eos_token_id`` are treated as
        padding added to align the batch. If the tokenizer has no EOS id,
        every position is considered valid.
        """
        eos_id = self.tokenizer.eos_token_id
        if eos_id is None:
            return torch.ones_like(y_pred, dtype=torch.bool)
        is_eos = (y_pred == eos_id).long()
        # Number of EOS tokens strictly before each position.
        eos_before = is_eos.cumsum(dim=1) - is_eos
        return eos_before == 0

    def _evaluate_cls_baseline(self, dataloader):
        """Run the ``cls_baseline`` evaluation and return the number of observed samples.

        Generates an answer per prompt, parses it with ``extract_answer``,
        and writes per-type accuracy (``cls_baseline.csv``) and per-sample
        results (``cls_baseline_all.csv``).
        """
        all_preds, all_inputs, all_gt, all_predtypes = [], [], [], []
        observed_num_examples = 0
        invalid_count = 0
        diff_count = 0  # never updated; kept so the summary line below is unchanged

        for batch in tqdm.tqdm(dataloader):
            # `batch` is indexed by column name, so its own length is the
            # number of columns; count samples via the prompt column.
            observed_num_examples += len(batch["x"])

            encoding = self.tokenizer(batch["x"], padding=True, return_tensors='pt')
            encoding = self._prepare_inputs(encoding)
            prompt_len = encoding['input_ids'].size(1)
            # Only the generated sequences are used, so per-step scores are
            # not requested.
            generation_outputs = self.model.generate(
                **encoding,
                max_length=self._generation_max_length(prompt_len),
                return_dict_in_generate=True,
            )
            generated_ids = generation_outputs['sequences']

            try:
                generated_texts = self.tokenizer.batch_decode(
                    generated_ids[:, prompt_len:],
                    skip_special_tokens=True
                )
                for i, generation in enumerate(generated_texts):
                    # NOTE(review): the other evaluation branch reads
                    # batch['gt_answer'][i]; confirm which layout is right.
                    gt_answer = batch['gt_answer'][0][i]
                    model_answer, typ = extract_answer(generation, 'right')
                    if model_answer.strip() == '':
                        invalid_count += 1
                        continue
                    # Resolve every value before appending so the result
                    # lists cannot get out of sync if something raises.
                    pred_type = int(typ)
                    all_inputs.append(batch['x'][i])
                    all_preds.append(model_answer)
                    all_predtypes.append(pred_type)
                    all_gt.append(gt_answer)
            except Exception:
                # The remaining samples of this batch are skipped.
                print("error in decoding")
                print(generated_ids)

        correctnesses = check_answer_correctness(all_gt, all_preds)
        all_results = pd.DataFrame({
            'all_preds': all_preds,
            'all_inputs': all_inputs,
            'all_predtypes': all_predtypes,
            'all_correctnesses': correctnesses
        })
        acc = all_results.groupby('all_predtypes')['all_correctnesses'].mean().reset_index()
        count = all_results.groupby('all_predtypes').size().reset_index(name='Size')
        acc.rename(columns={'all_correctnesses': 'acc'}, inplace=True)
        acc = acc.merge(count, on='all_predtypes')

        # NOTE(review): this branch reads `base_model_name_or_path`, the
        # confidence-head branch reads `model_name_or_path` — confirm.
        model_name = INDENTIFIER2NAME[self.custom_args.base_model_name_or_path]
        save_dir = f'./results/{model_name}/'
        os.makedirs(save_dir, exist_ok=True)
        acc.to_csv(os.path.join(save_dir, 'cls_baseline.csv'), index=False)
        all_results.to_csv(os.path.join(save_dir, 'cls_baseline_all.csv'), index=False)
        print(f'invalid predictions: {invalid_count}')
        print(f'different types: {diff_count}')
        return observed_num_examples

    def _evaluate_confidence_head(self, dataloader):
        """Generate answers, score them with the confidence head, and save the results.

        For each batch: generate an answer, run one teacher-forced forward
        pass over ``prompt + y_pred[:-1]``, mean-pool the last-layer hidden
        states of the generated span, and feed the pooled vector to
        ``model.conf_head``. Returns the number of observed samples.
        """
        all_preds, all_inputs, all_gt, all_conf = [], [], [], []
        observed_num_examples = 0
        invalid_count = 0

        for batch in tqdm.tqdm(dataloader):
            # Count samples, not dict keys.
            observed_num_examples += len(batch["x"])

            # Encode the prompts and move them to the right device.
            encoding = self.tokenizer(batch["x"], padding=True, return_tensors='pt')
            encoding = self._prepare_inputs(encoding)
            prompt_ids = encoding['input_ids']
            prompt_mask = encoding['attention_mask']
            prompt_len = prompt_ids.size(1)

            # Generate the answer sequence and keep only the new tokens.
            gen_out = self.model.generate(
                **encoding,
                max_length=self._generation_max_length(prompt_len),
                return_dict_in_generate=True,
            )
            y_pred = gen_out['sequences'][:, prompt_len:]                # [B, Ty]
            gen_len = y_pred.size(1)
            generated_texts = self.tokenizer.batch_decode(
                y_pred, skip_special_tokens=True
            )

            # Teacher-forced input: prompt + y_pred[:-1], whose last Ty
            # positions predict the tokens of y_pred.
            # NOTE(review): training pools the states *at* the answer tokens
            # (labels != -100), i.e. one position later than here — confirm
            # this offset is intended.
            if gen_len > 0:
                # Rows that finished early are padded after EOS; exclude
                # those positions from attention and from the pooling so
                # batched evaluation matches the training-time pooling.
                gen_mask = self._generated_token_mask(y_pred)           # [B, Ty]
                tf_input_ids = torch.cat([prompt_ids, y_pred[:, :-1]], dim=1)
                tf_attn_mask = torch.cat(
                    [prompt_mask, gen_mask[:, :-1].to(prompt_mask.dtype)], dim=1
                )
            else:
                tf_input_ids, tf_attn_mask = prompt_ids, prompt_mask

            with torch.no_grad():
                out2 = self.model(input_ids=tf_input_ids,
                                  attention_mask=tf_attn_mask,
                                  output_hidden_states=True)
                last_hidden = out2.hidden_states[-1]
                if gen_len > 0:
                    # Hidden states of the generated span only: [B, Ty, H].
                    hs_gen = last_hidden[:, -gen_len:, :]
                    weights = gen_mask.unsqueeze(-1).to(hs_gen.dtype)
                    tok_cnt = gen_mask.sum(dim=1, keepdim=True).clamp_min(1).to(hs_gen.dtype)
                    # Masked mean over the generated span -> [B, H].
                    pooled_gen = (hs_gen * weights).sum(dim=1) / tok_cnt
                else:
                    # Edge case: nothing was generated; fall back to the
                    # state at index (number of attended tokens - 1).
                    # NOTE(review): that index is the last prompt token only
                    # for right-padded prompts — confirm the padding side.
                    lengths = tf_attn_mask.sum(dim=1) - 1
                    rows = torch.arange(last_hidden.size(0), device=last_hidden.device)
                    pooled_gen = last_hidden[rows, lengths]

                # Confidence head.
                assert hasattr(self.model, "conf_head"), "confidence head missing in eval"
                conf_out = self.model.conf_head(pooled_gen).squeeze(-1)  # logits or probabilities
                # Compatibility: values outside [0, 1] are treated as logits
                # and squashed with a sigmoid; otherwise used as-is.
                if conf_out.detach().min().item() < 0.0 or conf_out.detach().max().item() > 1.0:
                    pred_conf = torch.sigmoid(conf_out)                   # [B]
                else:
                    pred_conf = conf_out.clamp(0.0, 1.0)                  # [B]

            # Collect the results.
            try:
                for i, generation in enumerate(generated_texts):
                    if generation.strip() == '':
                        invalid_count += 1
                        continue
                    # Resolve every value before appending so the result
                    # lists stay aligned even if one lookup raises.
                    prompt_text = batch['x'][i]
                    confidence = float(pred_conf[i].detach().cpu().item())
                    gt_answer = batch['gt_answer'][i]
                    all_inputs.append(prompt_text)
                    all_preds.append(generation)
                    all_conf.append(confidence)
                    all_gt.append(gt_answer)
            except Exception as e:
                print(f'confidences: {pred_conf} \nmodel answers: {generated_texts}', e)
                invalid_count += 1

        # Score and persist.
        correctnesses = check_answer_correctness(all_gt, all_preds)
        all_results = pd.DataFrame({
            'all_preds': all_preds,
            'all_inputs': all_inputs,
            'all_conf': all_conf,            # confidence predicted by the confidence head
            'all_correctnesses': correctnesses
        })

        model_name = INDENTIFIER2NAME[self.custom_args.model_name_or_path]
        save_dir = f'./results/{model_name}/'
        os.makedirs(save_dir, exist_ok=True)
        save_path_all = os.path.join(save_dir, self.custom_args.exp_name + '.csv')
        all_results.to_csv(save_path_all, index=False)

        print(f'invalid predictions: {invalid_count}')
        print(f'The results are saved in {save_path_all}')
        return observed_num_examples


        
        