from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch
from GCR.src.llms.base_language_model import BaseLanguageModel
import os
import dotenv
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from peft import PeftConfig
import numpy as np

# Populate os.environ from a .env file (when one is present) before reading credentials.
dotenv.load_dotenv()
# Hugging Face access token used for gated/private repos; None when the variable is unset.
HF_TOKEN = os.environ.get("HF_TOKEN")


class HfCausalModel(BaseLanguageModel):
    """Causal language model backed by Hugging Face ``transformers``.

    The tokenizer, model and optional assistant model are cached on the class,
    so every instance in the process shares one copy. They are reloaded only
    when one of the arguments that select the weights changes (see
    ``_load_key``). Call ``prepare_for_inference()`` before any scoring or
    generation method.
    """

    DTYPE = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}

    # Class-level state shared by all instances.
    _shared_tokenizer = None
    _shared_model = None
    _shared_assistant_model = None
    _shared_generation_cfg = None
    _model_loaded = False
    _current_model_path = None
    # Tuple of loading arguments the shared weights were loaded with (see _load_key).
    _current_load_key = None

    # Tokenizers without a configured max length report a huge sentinel
    # (transformers' VERY_LARGE_INTEGER = int(1e30)); values at or above this
    # bound are treated as "not configured".
    _UNSET_MAX_LENGTH = 10 ** 9

    @staticmethod
    def add_args(parser):
        """Register this model's command-line options on an ``argparse`` parser."""

        parser.add_argument(
            "--model_path", type=str, help="HUGGING FACE MODEL or model path",
            default=""
        )

        parser.add_argument("--maximun_token", type=int, help="max length", default=4096)
        parser.add_argument(
            "--max_new_tokens", type=int, help="max length", default=1024
        )
        parser.add_argument("--dtype", choices=["fp32", "fp16", "bf16"], default="bf16")
        parser.add_argument("--quant", choices=["none", "4bit", "8bit"], default="none")
        parser.add_argument(
            "--attn_implementation",
            default="flash_attention_2",
            choices=["eager", "sdpa", "flash_attention_2"],
            help="enable flash attention 2",
        )
        parser.add_argument(
            "--generation_mode",
            type=str,
            default="greedy",
            choices=["greedy", "beam", "sampling", "group-beam", "beam-early-stopping", "group-beam-early-stopping"],
        )
        parser.add_argument(
            "--k", type=int, default=3, help="number of paths to generate"
        )
        parser.add_argument("--chat_model", default='true', type=lambda x: (str(x).lower() == 'true'))
        parser.add_argument("--use_assistant_model", default='false', type=lambda x: (str(x).lower() == 'true'))
        parser.add_argument("--assistant_model_path", type=str, help="HUGGING FACE MODEL or model path", default=None)

    def __init__(self, args):
        self.args = args
        # Initial limit; prepare_for_inference() replaces it with the tokenizer's
        # configured model_max_length when one exists.
        self.maximun_token = args.maximun_token

    def token_len(self, text):
        """Return the number of tokens produced by ``tokenizer.tokenize(text)``."""
        return len(self.tokenizer.tokenize(text))

    def _load_key(self):
        """Return the arguments that determine which weights ``_load_model`` loads."""
        use_assistant = bool(self.args.use_assistant_model)
        return (
            self.args.model_path,
            use_assistant,
            self.args.assistant_model_path if use_assistant else None,
            self.args.dtype,
            self.args.quant,
            self.args.attn_implementation,
        )

    def prepare_for_inference(self):
        """Load (or reuse) the shared weights and build this instance's generation config."""
        cls = type(self)
        load_key = self._load_key()
        # Reload on first use, or when any weight-selecting argument changed
        # (not just model_path, so e.g. toggling the assistant model takes effect).
        if not cls._model_loaded or cls._current_load_key != load_key:
            print(f"Loading model: {self.args.model_path}")
            self._load_model()
            cls._model_loaded = True
            cls._current_model_path = self.args.model_path
            cls._current_load_key = load_key
        else:
            print(f"Reusing already loaded model: {self.args.model_path}")

        # Expose the shared components on the instance.
        self.tokenizer = cls._shared_tokenizer
        self.model = cls._shared_model
        self.assistant_model = cls._shared_assistant_model

        # Rebuilt on every call because generation arguments may differ per instance.
        self._setup_generation_config()

    def _load_model(self):
        """Load the tokenizer, main model and (optionally) assistant model onto the class.

        Raises:
            ValueError: if ``--use_assistant_model`` is true but no
                ``--assistant_model_path`` was given.
        """
        # Fail fast, before spending time loading the main model.
        if self.args.use_assistant_model and not self.args.assistant_model_path:
            raise ValueError(
                "--assistant_model_path must be set when --use_assistant_model is true"
            )

        cls = type(self)
        # Options shared by the main and the assistant model.
        model_kwargs = dict(
            device_map="auto",
            token=HF_TOKEN,
            torch_dtype=self.DTYPE.get(self.args.dtype, None),
            load_in_8bit=self.args.quant == "8bit",
            load_in_4bit=self.args.quant == "4bit",
            trust_remote_code=True,
            attn_implementation=self.args.attn_implementation,
        )

        # NOTE(review): force_download=True re-downloads hub files on every load;
        # presumably a workaround for a stale/corrupted cache - confirm it is still needed.
        print("Loading tokenizer...")
        cls._shared_tokenizer = AutoTokenizer.from_pretrained(
            self.args.model_path, token=HF_TOKEN, trust_remote_code=True, force_download=True, resume_download=False
        )

        print("Loading main model...")
        cls._shared_model = AutoModelForCausalLM.from_pretrained(
            self.args.model_path,
            force_download=True,
            resume_download=False,
            **model_kwargs,
        )

        if self.args.use_assistant_model:
            print("Loading assistant model...")
            cls._shared_assistant_model = AutoModelForCausalLM.from_pretrained(
                self.args.assistant_model_path,
                **model_kwargs,
            )
        else:
            cls._shared_assistant_model = None

    def _setup_generation_config(self):
        """Build ``self.generation_cfg`` from the model defaults and ``--generation_mode``."""
        model_max_length = self.tokenizer.model_max_length
        if model_max_length is not None and model_max_length < self._UNSET_MAX_LENGTH:
            self.maximun_token = model_max_length
        # Otherwise keep the --maximun_token value set in __init__.

        try:
            self.generation_cfg = GenerationConfig.from_pretrained(self.args.model_path, token=HF_TOKEN)
        except Exception:
            # model_path may be a PEFT adapter without its own generation config:
            # fall back to the generation config of the adapter's base model.
            sft_peft_config = PeftConfig.from_pretrained(self.args.model_path)
            self.generation_cfg = GenerationConfig.from_pretrained(
                sft_peft_config.base_model_name_or_path, token=HF_TOKEN
            )

        cfg = self.generation_cfg
        cfg.max_new_tokens = self.args.max_new_tokens
        cfg.return_dict_in_generate = True

        k = self.args.k
        beam = dict(do_sample=False, num_beams=k, num_return_sequences=k)
        group_beam = dict(beam, num_beam_groups=k, diversity_penalty=1.)
        mode_settings = {
            "greedy": dict(do_sample=False, num_return_sequences=1),
            "sampling": dict(do_sample=True, num_return_sequences=k),
            # NOTE(review): unlike the other beam modes, "beam" enables sampling
            # (beam-sample decoding) - confirm this is intended.
            "beam": dict(beam, do_sample=True),
            "beam-early-stopping": dict(beam, early_stopping=True),
            "group-beam": group_beam,
            "group-beam-early-stopping": dict(group_beam, early_stopping=True),
        }
        # Unknown modes leave the model's default generation settings untouched.
        for attr, value in mode_settings.get(self.args.generation_mode, {}).items():
            setattr(cfg, attr, value)

    def prepare_model_prompt(self, query):
        """Wrap ``query`` as a single user turn in the chat template when ``--chat_model`` is true.

        Otherwise ``query`` is returned unchanged.
        """
        if self.args.chat_model:
            chat_query = [
                {"role": "user", "content": query}
            ]
            return self.tokenizer.apply_chat_template(chat_query, tokenize=False, add_generation_prompt=True)
        else:
            return query

    def get_loglikelihood(self,
                          prefix: str,
                          contents: list[str],
                          **kwargs):
        """Log-likelihood of each content's tokens that follow a shared ``prefix``.

        Each entry of ``contents`` is the full text (prefix + continuation) and must
        start with ``prefix``'s tokens. The batch is scored with one forward pass;
        the logits at position ``i - 1`` give the probability of token ``i``.

        Args:
            prefix: shared leading text whose tokens are conditioned on, not scored.
            contents: full texts to score.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            Tuple ``(total, mean)`` of float32 numpy arrays of length ``len(contents)``:
            the summed log-probability of the scored tokens, and that sum divided by
            the number of scored tokens (divisor clamped to at least 1).

        Raises:
            AssertionError: if a content does not start with the prefix tokens.
        """
        if len(contents) == 0:
            return torch.zeros(0).numpy(), torch.zeros(0).numpy()

        token_lists = self.tokenizer(list(contents)).input_ids
        prefix_ids = self.tokenizer(prefix).input_ids
        prefix_length = len(prefix_ids)

        # Every content must literally start with the prefix at the token level.
        for ids in token_lists:
            assert list(ids[:prefix_length]) == list(prefix_ids), \
                "content does not start with the prefix tokens"

        # Right-pad manually so this works regardless of the tokenizer's padding
        # side and of whether it defines a pad token; padded positions are masked.
        batch_size = len(token_lists)
        seq_len = max(len(ids) for ids in token_lists)
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            pad_id = 0  # any valid id works: these positions are masked out
        input_ids = torch.full((batch_size, seq_len), pad_id, dtype=torch.long)
        attention_mask = torch.zeros((batch_size, seq_len), dtype=torch.long)
        for row, ids in enumerate(token_lists):
            input_ids[row, :len(ids)] = torch.tensor(ids, dtype=torch.long)
            attention_mask[row, :len(ids)] = 1

        with torch.no_grad():
            logits = self.model(
                input_ids=input_ids.to(self.model.device),
                attention_mask=attention_mask.to(self.model.device),
            ).logits  # [batch_size, seq_len, vocab_size]

            # With device_map="auto" the logits may live on a different device
            # than the inputs, so accumulate wherever the logits are.
            score_device = logits.device
            targets = input_ids.to(score_device)
            valid = attention_mask.to(score_device).bool()
            acc_loglikelihood = torch.zeros(batch_size, device=score_device)
            token_counts = torch.zeros(batch_size, device=score_device)

            # Token 0 has no preceding logits, so scoring never starts before index 1.
            for i in range(max(prefix_length, 1), seq_len):
                # log_softmax is the numerically stable form of log(softmax(...)).
                step_log_probs = torch.log_softmax(logits[:, i - 1, :].float(), dim=-1)
                token_ll = step_log_probs.gather(1, targets[:, i:i + 1]).squeeze(1)
                acc_loglikelihood += torch.where(valid[:, i], token_ll, torch.zeros_like(token_ll))
                token_counts += valid[:, i].float()

        # Average per scored token (avoid dividing by zero).
        avg_loglikelihood = acc_loglikelihood / token_counts.clamp(min=1)

        return acc_loglikelihood.cpu().numpy(), avg_loglikelihood.cpu().numpy()

    def evaluate_text(self, prefix_text: str, pos_token_str, neg_token_str) -> float:
        """Score ``prefix_text`` by the next-token logit margin of a positive vs. a negative answer.

        No softmax is applied: the score is ``Logit(pos) - Logit(neg)`` read at the
        last position of ``prefix_text``. Only the first token of each answer
        string is compared.

        Args:
            prefix_text: prompt to score (e.g. "...Answer: This action is").
            pos_token_str: text whose first token is the positive answer.
            neg_token_str: text whose first token is the negative answer.

        Returns:
            float: the logit difference in (-inf, +inf); larger is better.

        Raises:
            ValueError: if ``pos_token_str`` or ``neg_token_str`` encodes to no tokens.
        """
        pos_ids = self.tokenizer.encode(pos_token_str, add_special_tokens=False)
        neg_ids = self.tokenizer.encode(neg_token_str, add_special_tokens=False)
        if not pos_ids or not neg_ids:
            raise ValueError("pos_token_str and neg_token_str must each encode to at least one token")
        pos_token_id = pos_ids[0]
        neg_token_id = neg_ids[0]

        inputs = self.tokenizer(prefix_text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_mask = inputs["attention_mask"].to(self.model.device)

        self.model.eval()

        with torch.no_grad():
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=False,
                return_dict=True,
                use_cache=False,  # a single forward pass needs no KV cache
            )
            logits = outputs.logits  # [1, seq_len, vocab_size]

            # Last position where attention_mask is 1 (the final token of this
            # single, unpadded sequence).
            last_token_idx = int(attention_mask[0].sum().item()) - 1
            final_token_logits = logits[0, last_token_idx, :]  # [vocab_size]

            # Upcast before subtracting so the margin is computed in float32.
            reward = final_token_logits[pos_token_id].float() - final_token_logits[neg_token_id].float()

            return reward.item()

    @torch.inference_mode()
    def generate_sentence(self, llm_input, *args, **kwargs):
        """Generate continuations of ``llm_input`` using ``self.generation_cfg``.

        Returns:
            The decoded continuation (prompt tokens stripped) as a ``str`` when one
            sequence is generated, a ``list[str]`` when several are, or ``None``
            when ``generate`` raised (the error is printed, not propagated).
        """
        inputs = self.tokenizer(llm_input, return_tensors="pt", add_special_tokens=False)
        input_ids = inputs.input_ids.to(self.model.device)
        attention_mask = inputs.attention_mask.to(self.model.device)
        try:
            # NOTE(review): self.assistant_model is loaded but not passed to
            # generate() here - confirm whether assisted decoding is intended.
            res = self.model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                generation_config=self.generation_cfg,
                return_dict_in_generate=True,
                pad_token_id=self.tokenizer.eos_token_id
            )
        except Exception as e:
            # Best-effort: report the error and signal failure with None.
            print(e)
            return None
        prompt_len = input_ids.shape[1]
        response = [
            self.tokenizer.decode(seq[prompt_len:], skip_special_tokens=True)
            for seq in res.sequences
        ]
        if len(response) == 1:
            return response[0]
        return response