import argparse
from dataclasses import asdict
import json
import os

import torch
from DatasetInit import CustomDataset
from DatasetJudge import DatasetJudge
from get_hidden_state import get_hidden_state, InputQA
from get_known import get_isKnown
from mlp_trainer import TrainResult, train
from utils import jsonlload, load_model, seed_everything, get_proj, split_data


# Checkpoints loaded by this script, as Hugging Face Hub ids. To run from local
# copies, replace an id with a filesystem path (e.g. "/path/to/Qwen2.5-7B-Instruct").
MODEL_PATH = {
    "Qwen2.5-7B-Instruct": "Qwen/Qwen2.5-7B-Instruct",
    "Llama-3.1-8B": "meta-llama/Llama-3.1-8B",
    # Used as the answer judge (DatasetJudge), not as a probed generator.
    "BLEURT-20": "lucadiliello/BLEURT-20",
}

# QA datasets, as Hub ids; each may likewise be swapped for a local path
# (e.g. "/path/to/truthful_qa").
DATASET_PATH = {
    "truthful_qa": "truthfulqa/truthful_qa",
    "tydiqa": "google-research-datasets/tydiqa",
    "trivia_qa": "mandarjoshi/trivia_qa",
    "nq_open": "google-research-datasets/nq_open",
}

# Closed-book prompt shared by the datasets that provide only a question.
_QA_PROMPT = "Q: {question}\nA:"

# Prompt template per dataset, with {question} (and, for passage-based data,
# {context}) placeholders. An alternative truthful_qa prompt that can be used:
#   "Answer the question concisely.\nQ: {question}\nA:"
PROMPT_TEMPLATE = {
    "truthful_qa": _QA_PROMPT,
    "tydiqa": (
        "Concisely answer the following question based on the information in "
        "the given passage: \nPassage:\n{context}\nQ: {question}\nA:"
    ),
    "trivia_qa": _QA_PROMPT,
    "nq_open": _QA_PROMPT,
}

parser = argparse.ArgumentParser(
    description=(
        "Train MLP probes on projected last-layer hidden states to flag "
        "incorrect (hallucinated) answers, sweeping several projection sizes."
    )
)
parser.add_argument(
    "--model_name",
    type=str,
    default="Qwen2.5-7B-Instruct",
    # Validate up front: an unknown name would otherwise only fail with a
    # KeyError after output directories were created. BLEURT-20 is the judge
    # model, not a model whose hidden states are probed.
    choices=[name for name in MODEL_PATH if name != "BLEURT-20"],
    help="Key into MODEL_PATH of the model whose hidden states are probed.",
)
parser.add_argument(
    "--dataset_name",
    type=str,
    default="truthful_qa",
    # Validate up front: an unknown name would otherwise only fail after the
    # (large) model has already been loaded.
    choices=list(DATASET_PATH),
    help="Key into DATASET_PATH / PROMPT_TEMPLATE.",
)
parser.add_argument(
    "--train_ratio",
    type=float,
    default=0.75,
    help=(
        "Fraction of the *known* samples used for training; the remaining "
        "known samples plus all unknown samples form the validation set."
    ),
)
parser.add_argument("--train_batch_size", type=int, default=50)
parser.add_argument("--train_epochs", type=int, default=50)
parser.add_argument("--train_step_log", type=int, default=10)
parser.add_argument(
    "--epoch_log",
    type=int,
    default=None,
    help="Epoch logging interval passed to train(); defaults to --train_epochs.",
)

args = parser.parse_args()

model_quantity = args.model_name
ds_name = args.dataset_name
train_ratio = args.train_ratio
train_batch_size = args.train_batch_size
train_epochs = args.train_epochs
train_step_log = args.train_step_log
epoch_log = args.epoch_log if args.epoch_log is not None else train_epochs

# Projection sizes swept by the main loop: each probe sees the hidden state
# projected onto the last `proj_dim` columns of the unembedding's V matrix.
proj_dim_list = [32, 64, 128, 192, 256, 512, 1024]
random_seed = 42
hidden_dim = 1024  # passed to mlp_trainer.train as the probe's hidden size


def mkdir(path) -> None:
    """Create directory ``path`` (including missing parents) if it is absent.

    A message is printed only when the directory is actually created. If
    ``path`` already exists it is left untouched and nothing is printed.

    The directory is created first and an existing one is tolerated afterwards
    (EAFP). This avoids the race where two concurrent runs both see the path
    as missing and the second ``os.makedirs`` call then fails.
    """
    try:
        os.makedirs(path)
    except FileExistsError:
        # Already present (possibly created by a concurrent run): nothing to do.
        return
    print(f"makedir: {path}")


if __name__ == "__main__":
    # ---- output locations and devices -----------------------------------
    dataset_dir = "./dataset"
    hidden_states_dir = f"./hidden_state/{model_quantity}"
    svd_dir = "./svd"
    train_dir = f"./mlp_result/{model_quantity}"
    # auroc_img_save_folder = f"./mlp_result/{model_quantity}/auroc_img"
    data_device = "cuda:1"  # device cached hidden states are stored/loaded on
    model_device = "cuda:2"  # device_map handed to load_model

    # Files are located with os.path.join (not a "./" string prefix) so the
    # directories above may also be absolute paths.
    known_save_path = os.path.join(dataset_dir, f"[known]{ds_name}.jsonl")
    unknown_save_path = os.path.join(dataset_dir, f"[unknown]{ds_name}.jsonl")
    data_type_list = [
        "known",
        "unknown",
    ]

    for directory in (dataset_dir, hidden_states_dir, svd_dir, train_dir):
        mkdir(directory)

    # ---- model under study ----------------------------------------------
    print(f"Loading model: {model_quantity}...")
    model, tokenizer = load_model(
        MODEL_PATH[model_quantity],
        # torch_dtype="auto",
        # device_map="auto",
        device_map=model_device,
        attn_implementation="eager",
    )
    if tokenizer.pad_token is None:
        # Tokenizers without a pad token reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    model.eval()
    device = next(model.parameters()).device
    layers = model.model.config.num_hidden_layers

    with torch.no_grad():
        # ---- projection basis: SVD of the unembedding (lm_head) weight ----
        print("Loading projection matrix...")
        svd_path = os.path.join(svd_dir, f"[{model_quantity}]unembedding_svd.pt")
        if not os.path.exists(svd_path):
            w = model.lm_head.weight.data
            svd_res = torch.svd(w)
            V = svd_res.V.to(device)
            torch.save(svd_res, svd_path)
        else:
            # weights_only=False: the cache stores the (U, S, V) object returned
            # by torch.svd rather than a plain tensor. Only load caches you
            # created yourself, since this unpickles arbitrary objects.
            _, _, V = torch.load(svd_path, map_location=device, weights_only=False)

        print("=" * 20)

        # ---- known / unknown question split (cached as JSONL) --------------
        if os.path.exists(known_save_path) and os.path.exists(unknown_save_path):
            print("known and unknown file exists, skip...")
            known_list = jsonlload(known_save_path)
            unknown_list = jsonlload(unknown_save_path)
        else:
            print(f"Loading dataset: {DATASET_PATH[ds_name]}...")
            ds = CustomDataset(
                name=ds_name,
                path=DATASET_PATH[ds_name],
                prompt_template=PROMPT_TEMPLATE[ds_name],
            )

            print("Loading DatasetJudge...")
            judge = DatasetJudge(
                bleurt_model_path=MODEL_PATH["BLEURT-20"],
                sen_sim_threshold=0.5,
                rouge_threshold=0.7,
            )

            # Truncate any partial output from a previous run before
            # get_isKnown writes to these paths.
            with open(known_save_path, "w", encoding="utf-8") as f:
                f.write("")
            with open(unknown_save_path, "w", encoding="utf-8") as f:
                f.write("")

            print("Start get_isKnown...")
            known_list, unknown_list = get_isKnown(
                judge=judge,
                ds=ds,
                model=model,
                tokenizer=tokenizer,
                known_save_path=known_save_path,
                unknown_save_path=unknown_save_path,
                check_num=len(ds) // 20,
                max_new_tokens=64,
                temperature=0.5,
                top_k=5,
                top_p=0.99,
                num_beams=10,
                num_return_sequences=10,
            )
            # Normalise to plain dicts so both branches yield the same shape
            # as the JSONL-loaded lists above.
            known_list = [asdict(item) for item in known_list]
            unknown_list = [asdict(item) for item in unknown_list]
        print("=" * 20)

        # ---- last-layer hidden states for every generated answer ----------
        print("Start getting hidden states...")
        qa_list = {"known": known_list, "unknown": unknown_list}
        all_data = {}
        # With a {context} template the prompt is tokenized on its own and only
        # the generated text goes into answer_token (need_prompt=False below).
        # Otherwise prompt_token stays empty and answer_token covers
        # prompt + generation.
        have_context = "{context}" in PROMPT_TEMPLATE[ds_name]
        for data_type in data_type_list:
            print(f"Getting [{data_type}] data...")
            hidden_states_save_path = os.path.join(
                hidden_states_dir, f"[{data_type}]{ds_name}.pt"
            )
            if os.path.exists(hidden_states_save_path):
                print("Loading hidden states...")
                res = torch.load(
                    hidden_states_save_path,
                    map_location=data_device,
                    weights_only=True,
                )
                all_emb = res["all_emb"]
                all_hallucination_flag = res["all_hallucination_flag"]
                print(f"all_flag len: {len(all_hallucination_flag)}")
            else:
                all_qa = []
                all_hallucination_flag = []
                for item in qa_list[data_type]:
                    prompt = item["prompt"]
                    prompt_token = []

                    if have_context:
                        prompt_token = (
                            tokenizer(prompt, return_tensors="pt").input_ids[0].tolist()
                        )
                        prompt = ""

                    for beam in item["result"]:
                        answer_token = (
                            tokenizer(prompt + beam["gen_text"], return_tensors="pt")
                            .input_ids[0]
                            .tolist()
                        )
                        all_qa.append(InputQA(prompt_token, answer_token))
                        # Positive label = the judged answer is NOT correct.
                        all_hallucination_flag.append(
                            not beam["score"]["is_correct"]
                        )

                all_emb = get_hidden_state(
                    model=model,
                    tokenizer=tokenizer,
                    all_qa=all_qa,
                    key_list=[f"model.layers.{layers - 1}"],
                    save_device=data_device,
                    need_prompt=not have_context,
                )

                print("save...")
                res = {
                    "all_emb": all_emb,
                    "all_hallucination_flag": all_hallucination_flag,
                }
                torch.save(
                    res,
                    hidden_states_save_path,
                )
            all_data[data_type] = res
    print("=" * 20)

    # ---- MLP probe sweep over projection dimensions ---------------------
    data_type = "all"  # tag embedded in the output directory name
    all_auroc = {}
    key = f"model.layers.{layers - 1}"
    norm = None

    for proj_dim in proj_dim_list:
        # Re-seed so every proj_dim starts from the same RNG state (and hence,
        # presumably, the same shuffled train/valid split).
        seed_everything(seed=random_seed)
        with torch.no_grad():
            print("projecting...")
            # torch.svd returns singular values in descending order, so these
            # are the right singular vectors with the smallest singular values.
            # The slice is loop-invariant per proj_dim, so take it once.
            proj_V = V[:, -proj_dim:]

            known_sentences = [
                get_proj(V=proj_V, emb=emb.to(device), norm=norm).clone()
                for emb in all_data["known"]["all_emb"][key]
            ]
            known_flags = torch.tensor(all_data["known"]["all_hallucination_flag"]).to(
                device
            )
            known_data = list(zip(known_sentences, known_flags))

            unknown_sentences = [
                get_proj(V=proj_V, emb=emb.to(device), norm=norm).clone()
                for emb in all_data["unknown"]["all_emb"][key]
            ]
            unknown_flags = torch.tensor(
                all_data["unknown"]["all_hallucination_flag"]
            ).to(device)
            unknown_data = list(zip(unknown_sentences, unknown_flags))

            print("splitting data...")
            # Only known samples are split; every unknown sample is added to
            # the validation set.
            train_data, valid_data = split_data(
                known_data,
                train_ratio=train_ratio,
                shuffle=True,
            )
            valid_data.extend(unknown_data)

        print("training MLP...")
        save_path_prefix = os.path.join(
            train_dir,
            f"[{ds_name}_{data_type}][train_{int(train_ratio * 100):03d}][{proj_dim}_{hidden_dim}]",
        )
        mkdir(save_path_prefix)

        # Feature width of the projected samples (assumes 2-D projections).
        input_dim = train_data[0][0].shape[1]

        cur_img_save_folder = os.path.join(save_path_prefix, "auroc_img")
        mkdir(cur_img_save_folder)

        res: TrainResult = train(
            train_data=train_data,
            valid_data=valid_data,
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            batch_size=train_batch_size,
            step_log=train_step_log,
            epochs=train_epochs,
            epoch_log=epoch_log,
            max_grad_norm=1.0,
            random_seed=random_seed,
            lr=1e-4,
            weight_decay=3e-4,
            device=device,
            auroc_img_save_folder=cur_img_save_folder,
        )

        # ---- persist the probe and its metrics --------------------------
        torch.save(
            res.model,
            os.path.join(save_path_prefix, "model.pt"),
        )
        # Drop the module before asdict(), which would otherwise deep-copy it.
        res.model = None
        # Presumably the torch.device passed above, which json cannot encode;
        # tolerate its absence instead of raising KeyError.
        res.args_dict.pop("device", None)
        res_dict = asdict(res)
        res_dict.pop("model")

        all_auroc[proj_dim] = {
            "train_auroc": res.train_auroc,
            "valid_auroc": res.valid_auroc,
        }
        res_save_path = os.path.join(save_path_prefix, "train_result.json")
        # Explicit UTF-8: ensure_ascii=False may write non-ASCII characters,
        # which must not depend on the platform's default encoding.
        with open(res_save_path, "w", encoding="utf-8") as f:
            json.dump(res_dict, f, indent=4, ensure_ascii=False)
        print(f"save to {res_save_path}")

        print("=" * 20)

        # Release per-dimension tensors before the next sweep step.
        del proj_V
        del known_sentences, known_flags, known_data
        del unknown_sentences, unknown_flags, unknown_data
        del train_data, valid_data
        del res
        torch.cuda.empty_cache()

    for proj_dim, res in all_auroc.items():
        print(
            f"proj_dim: {proj_dim:04d} \t Train AUROC: {res['train_auroc']:.2%} \t Valid AUROC: {res['valid_auroc']:.2%}"
        )
        print("-" * 20)

    print("done.")
