import json
from dataclasses import asdict, dataclass
from typing import List, Tuple

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from DatasetInit import CustomDataset
from DatasetJudge import DatasetJudge, JudgeResult
from utils import load_model


@dataclass
class BeamSearchResult:
    """One generated candidate answer together with the judge's verdict on it."""

    # Index of this sequence among the sequences returned by generation.
    beam_id: int
    # Decoded text of the candidate answer.
    gen_text: str
    # gen_token: List[int]
    # Verdict returned by the judge for ``gen_text``.
    score: JudgeResult


@dataclass
class DatasetJudgeResult:
    """All judged candidate answers for a single dataset prompt."""

    # Index of the sample inside the dataset.
    ind: int
    # Prompt text that was fed to the model.
    prompt: str
    # prompt_token: List[int]
    # Reference answers considered correct / incorrect for this prompt.
    correct_answers: List[str]
    incorrect_answers: List[str]
    # Judged candidates generated for this prompt.
    result: List[BeamSearchResult]

    def is_known(self) -> bool:
        """Return True when at least one candidate was judged correct.

        The verdict is also cached on the instance as ``isKnown``. That
        attribute is not a dataclass field, so ``asdict`` does not emit it.
        """
        verdict = any(candidate.score.is_correct for candidate in self.result)
        self.isKnown = verdict
        return verdict


def _collect_stop_token_ids(tokenizer: AutoTokenizer) -> List[int]:
    """Return the de-duplicated token ids that terminate a generated answer.

    Every stop string is encoded twice: once on its own, and once appended to
    the word "Yes" with the first token of that encoding dropped.
    """
    stop_strings = [".", "!", "?", ".\n", "!\n", "?\n", "\n", "\n\n"]
    stop_ids = set()
    for stop in stop_strings:
        stop_ids.update(tokenizer.encode(text=stop))
        # NOTE(review): presumably this captures how the stop string is
        # tokenized when it follows a word; ``[1:]`` drops whatever token
        # comes first, so confirm the result for tokenizers that prepend BOS.
        stop_ids.update(tokenizer.encode(text="Yes" + stop)[1:])
    # Sorted only to make the id list deterministic across runs.
    return sorted(stop_ids)


def _answer_length(new_tokens: List[int], stop_ids) -> int:
    """Return how many leading tokens to keep.

    The answer runs up to and including the first stop token; when no stop
    token occurs, every token is kept.
    """
    for position, token_id in enumerate(new_tokens):
        if token_id in stop_ids:
            return position + 1
    return len(new_tokens)


def _append_jsonl(path: str, record: DatasetJudgeResult) -> None:
    """Append ``record`` to ``path`` as one JSON line."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(asdict(record), ensure_ascii=False) + "\n")


def get_isKnown(
    judge: DatasetJudge,
    ds: CustomDataset,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    known_save_path: str,
    unknown_save_path: str,
    check_num: int = 100,
    **gen_args,
) -> Tuple[List[DatasetJudgeResult], List[DatasetJudgeResult]]:
    """Split the dataset into prompts the model answers correctly and the rest.

    For every sample, several answers are generated and each non-empty one is
    scored by ``judge``. A sample counts as "known" when at least one answer
    is judged correct. Each sample's result is appended immediately, as one
    JSON line, to ``known_save_path`` or ``unknown_save_path``.

    Args:
        judge: Scores a generated answer against the reference answers.
        ds: Dataset whose items expose ``prompt``, ``correct_answers`` and
            ``incorrect_answers``.
        model: Model used for generation.
        tokenizer: Tokenizer matching ``model``.
        known_save_path: JSONL file that receives "known" samples (appended).
        unknown_save_path: JSONL file that receives the other samples (appended).
        check_num: Print running counts every ``check_num`` samples; values
            less than or equal to zero disable the printout.
        **gen_args: Extra keyword arguments forwarded to ``model.generate``.
            They must not include ``do_sample``, ``eos_token_id``,
            ``early_stopping`` or ``return_dict_in_generate``, which are
            passed explicitly here; a duplicate would raise ``TypeError``.

    Returns:
        A ``(known_list, unknown_list)`` tuple.
    """
    known_list: List[DatasetJudgeResult] = []
    unknown_list: List[DatasetJudgeResult] = []

    eos_token_id_list = _collect_stop_token_ids(tokenizer)
    stop_ids = frozenset(eos_token_id_list)  # O(1) membership while scanning

    for ind in tqdm(range(len(ds))):
        sample = ds[ind]  # fetch once instead of indexing the dataset per field
        correct_answers = sample.correct_answers
        incorrect_answers = sample.incorrect_answers
        input_text = sample.prompt

        model_inputs = tokenizer([input_text], return_tensors="pt").to(model.device)
        input_len = model_inputs.input_ids.size(1)

        generated_ans = model.generate(
            **model_inputs,
            **gen_args,
            do_sample=True,  # enable sampling
            eos_token_id=eos_token_id_list,
            early_stopping=True,
            return_dict_in_generate=True,
        )
        generated_ids = generated_ans["sequences"]

        beam_search_result: List[BeamSearchResult] = []
        for i, beam in enumerate(generated_ids):
            # Drop the prompt, then cut right after the first stop token so
            # anything that follows it in the sequence is never decoded.
            new_tokens = beam[input_len:].tolist()
            kept = _answer_length(new_tokens, stop_ids)
            generated_text = tokenizer.decode(
                new_tokens[:kept], skip_special_tokens=True
            )

            # Blank generations are not judged.
            if not generated_text.strip():
                continue

            beam_search_result.append(
                BeamSearchResult(
                    beam_id=i,
                    gen_text=generated_text,
                    score=judge.judge(
                        generated_text=generated_text,
                        correct_answer_list=correct_answers,
                        incorrect_answer_list=incorrect_answers,
                    ),
                )
            )

        cur_res = DatasetJudgeResult(
            ind=ind,
            prompt=input_text,
            correct_answers=correct_answers,
            incorrect_answers=incorrect_answers,
            result=beam_search_result,
        )
        if cur_res.is_known():
            bucket, save_path = known_list, known_save_path
        else:
            bucket, save_path = unknown_list, unknown_save_path
        bucket.append(cur_res)
        # Written per sample so partial progress survives an interrupted run.
        _append_jsonl(save_path, cur_res)

        # Guard against check_num == 0, which would raise ZeroDivisionError.
        if check_num > 0 and (ind + 1) % check_num == 0:
            print(f"known: {len(known_list)} \t unknown: {len(unknown_list)}")

    return known_list, unknown_list


def main(known_save_path: str, unknown_save_path: str):
    """Run the known/unknown split on the ``truthful_qa`` dataset.

    Builds the judge, dataset, model and tokenizer, empties both output
    files, then calls ``get_isKnown`` with the generation settings below.

    Args:
        known_save_path: JSONL file for samples with a correct answer.
        unknown_save_path: JSONL file for the remaining samples.
    """
    judge = DatasetJudge(
        rouge_threshold=0.7,
        sen_sim_threshold=0.5,
        correct_advantage=0.05,
        bleurt_model_path="/home/hust1/model/BLEURT-20",
    )

    ds = CustomDataset("truthful_qa")

    model_id = "Qwen2.5-7B-Instruct"
    model_path = f"/home/hust1/model/{model_id}"
    model, tokenizer = load_model(
        model_path,
        # torch_dtype="auto",
        device_map="auto",
        attn_implementation="eager",
    )

    # Start from empty output files; get_isKnown appends to them.
    for path in (known_save_path, unknown_save_path):
        with open(path, "w", encoding="utf-8"):
            pass

    get_isKnown(
        judge=judge,
        ds=ds,
        model=model,
        tokenizer=tokenizer,
        known_save_path=known_save_path,
        unknown_save_path=unknown_save_path,
        # Report about 20 times per run; clamp to 1 so a dataset with fewer
        # than 20 samples does not yield an interval of 0.
        check_num=max(1, len(ds) // 20),
        max_new_tokens=64,
        temperature=0.5,
        top_k=20,
        top_p=0.99,
        num_beams=10,
        num_return_sequences=10,
    )
