import json
from utils.utils import read_json, write_jsonl
from check_hidden_state import *
from llms import alive_client


def wash():
    """Launch ``run`` on the NQ RAG test file with a fixed configuration.

    Every setting (input file, model and probe-weight locations, batch size,
    GPU id, output directory) is hard-coded here. ``run`` is not defined in
    this file; presumably it comes from the ``check_hidden_state`` star
    import -- verify there for the meaning of each key.
    """
    config = dict(
        # Input data and task description.
        source="/home/xygxzs28/knowledge_boundary/NQ-retrieval/nq_rag_test.jsonl",
        type="qa",
        ra="none",
        # Model and generation settings.
        model_path="/home/xygxzs28/llms/Llama3-8B-Instruct",
        batch_size=8,
        task="nq",
        max_new_tokens=526,
        # Hidden-state extraction settings.
        hidden_states=1,
        hidden_idx_mode="first",
        need_layers="mid",
        gpu_device="0",
        # Probe weights and where the results are written.
        weight_path="/home/xygxzs28/knowledge_boundary/LKBP/data/nq_open_crash/result/mid_layer/res/seed_0/weights/26.pth",
        hidden_prob_output_dir="/home/xygxzs28/knowledge_boundary/NQ-retrieval/wash/",
    )
    run(config)


def wash4eval(index):
    """Launch ``run`` on one shard of the rerank evaluation metadata.

    ``index`` is interpolated into three settings: the shard file name
    (``nq_rerank_eval_meta_{index}.jsonl``), the ``gpu_device`` string, and
    the output directory (``hidden_prob_eval_{index}/``). All other settings
    are hard-coded. ``run`` is not defined in this file; presumably it comes
    from the ``check_hidden_state`` star import.

    :param index: shard number, also used as the GPU device id
    """
    rerank_source = f"/home/xygxzs28/knowledge_boundary/LKBP/data/nq_rerank/nq_rerank_eval_meta_{index}.jsonl"
    rerank_output = f"/home/xygxzs28/knowledge_boundary/LKBP/data/nq_rerank/hidden_prob_eval_{index}/"
    config = dict(
        source=rerank_source,
        type="qa",
        ra="none",
        model_path="/home/xygxzs28/llms/Llama3-8B-Instruct",
        batch_size=8,
        task="nq",
        max_new_tokens=526,
        hidden_states=1,
        hidden_idx_mode="first",
        need_layers="mid",
        gpu_device=f"{index}",
        weight_path="/home/xygxzs28/knowledge_boundary/LKBP/data/nq_open_crash/result/mid_layer/res/seed_0/weights/26.pth",
        hidden_prob_output_dir=rerank_output,
    )
    run(config)


def split_rerank_data(num_shards: int = 8):
    """Split the converted NQ RAG test file into round-robin shards.

    Record ``i`` of the source file goes to shard ``i % num_shards``. Each
    shard is written next to the source as
    ``nq_rag_test_convert_<shard>.jsonl``; a shard file is written even when
    the shard is empty.

    :param num_shards: number of output shards (defaults to 8)
    :raises ValueError: if ``num_shards`` is smaller than 1
    """
    if num_shards < 1:
        raise ValueError(f"num_shards must be at least 1, got {num_shards}")

    source = "/home/xygxzs28/knowledge_boundary/NQ-retrieval/nq_rag_test_convert.jsonl"
    data = read_json(source)

    # One bucket per shard; the modulo picks the bucket directly, so no
    # per-shard branch is needed.
    shards = [[] for _ in range(num_shards)]
    for position, record in enumerate(data):
        shards[position % num_shards].append(record)

    for shard_id, shard in enumerate(shards):
        write_jsonl(
            shard,
            f"/home/xygxzs28/knowledge_boundary/NQ-retrieval/nq_rag_test_convert_{shard_id}.jsonl",
        )


def merge_rerank_data():
    """Interleave the eight rerank metadata shards back into one file.

    Shards 0-7 are read in order and merged row by row: row 0 of every
    shard, then row 1 of every shard, and so on. The number of rows visited
    is the length of shard 0, so any rows another shard has beyond that
    length are not copied. A shard shorter than shard 0 is skipped once it
    runs out. The result is written to ``nq_rerank_eval_meta_no_5.jsonl``.
    """
    shard_template = "/home/xygxzs28/knowledge_boundary/LKBP/data/nq_rerank/nq_rerank_eval_meta_{}.jsonl"
    shards = [read_json(shard_template.format(shard_id)) for shard_id in range(8)]

    # Shard 0 alone decides how many rows are visited.
    row_count = len(shards[0])
    merged = [
        shard[row]
        for row in range(row_count)
        for shard in shards
        if row < len(shard)
    ]
    write_jsonl(merged, "/home/xygxzs28/knowledge_boundary/LKBP/data/nq_rerank/nq_rerank_eval_meta_no_5.jsonl")


def convert_file(meta_data_path, output_path):
    """Expand each question into one bare row plus one row per context.

    For every record read from ``meta_data_path`` the output contains:

    * a row with the original question and an empty ``context``;
    * for each entry of the record's ``context`` list, a row whose
      ``question`` is the result of ``generate_rag_prompt`` for that
      question/context pair and whose ``context`` is the entry itself.

    All rows of one record share a ``question_id`` equal to the record's
    0-based position in the input. The rows are written to ``output_path``.

    :param meta_data_path: path of the input file, read with ``read_json``
    :param output_path: path of the JSONL file to write
    """
    rows = []
    for qid, entry in enumerate(read_json(meta_data_path)):
        passages = entry['context']
        # The context-free row comes first for every question.
        rows.append({
            "question": entry['question'],
            "question_id": qid,
            "context": ""
        })
        rows.extend(
            {
                "question": generate_rag_prompt(question=entry["question"], context=passage),
                "question_id": qid,
                "context": passage
            }
            for passage in passages
        )
    write_jsonl(rows, output_path)


def generate_rerank_preference_dataset():
    """Build the reranker preference file from probe outputs.

    Reads the probe output (``hidden_prob_by_mlp.jsonl``) and the metadata
    (``nq_rerank_eval_meta.jsonl``) and pairs them by row index. Rows are
    grouped by consecutive ``question_id``. The first row of each group
    provides the baseline pair and starts a new ``{"query", "pos", "neg"}``
    record. Every later row in the group is compared with that baseline:

    * first value greater than the baseline's first value -> ``neg``;
    * otherwise second value greater than the baseline's second value
      -> ``pos``;
    * otherwise the row is dropped.

    The pair is presumably (negative, positive) probabilities, going by the
    variable names -- confirm against the probe that writes the file.

    Each finished group keeps at most five contexts per side, ordered by the
    largest increase over the baseline, and is written to
    ``nq_rerank_eval.jsonl``.
    """

    def _rank_filter(_meta):
        """Reduce ``pos``/``neg`` to their top-5 context strings and add the prompt."""
        for _side in ("pos", "neg"):
            _ranked = sorted(_meta[_side], key=lambda x: x["diff_prob"], reverse=True)[:5]
            _meta[_side] = [j["context"] for j in _ranked]
        _meta["prompt"] = "Given a question, retrieve Wikipedia passages that answer the question."
        return _meta

    prob_data = read_json("/home/xygxzs28/knowledge_boundary/LKBP/data/nq_rerank/hidden_prob_eval/hidden_prob_by_mlp.jsonl")
    meta_data = read_json("/home/xygxzs28/knowledge_boundary/LKBP/data/nq_rerank/nq_rerank_eval_meta.jsonl")

    question_id = None
    prob_tag = []
    tmp_dic = None  # record being built for the current question; None before the first row
    result = []
    for i, _prob_row in enumerate(prob_data):
        _meta = meta_data[i]
        if tmp_dic is None or _meta["question_id"] != question_id:
            # A new question starts: flush the previous group, then use this
            # row as the baseline for the new one.
            if tmp_dic is not None:
                result.append(_rank_filter(tmp_dic))
            question_id = _meta["question_id"]
            prob_tag = _prob_row[0]
            tmp_dic = {"query": _meta["question"], "pos": [], "neg": []}
            continue

        _neg_prob, _pos_prob = _prob_row[0]
        if _neg_prob > prob_tag[0]:
            tmp_dic["neg"].append({"context": _meta["context"], "diff_prob": float(_neg_prob) - float(prob_tag[0])})
        elif _pos_prob > prob_tag[1]:
            tmp_dic["pos"].append({"context": _meta["context"], "diff_prob": float(_pos_prob) - float(prob_tag[1])})

    # Groups are only flushed when the next question begins, so the last
    # group has to be flushed here or it would be lost.
    if tmp_dic is not None:
        result.append(_rank_filter(tmp_dic))

    write_jsonl(result, "/home/xygxzs28/knowledge_boundary/LKBP/data/nq_rerank/nq_rerank_eval.jsonl")


def filter_zero_data():
    """Drop preference records that lack positives or negatives.

    Reads ``nq_rerank_eval.jsonl``, keeps only the records whose ``pos`` and
    ``neg`` are both non-empty, prints how many were removed, and writes the
    remainder to ``nq_rerank_eval_filter.jsonl``.
    """
    records = read_json("/home/xygxzs28/knowledge_boundary/LKBP/data/nq_rerank/nq_rerank_eval.jsonl")
    kept = [record for record in records if record["pos"] and record["neg"]]
    dropped = len(records) - len(kept)
    print(f"Filter {dropped} data from {len(records)}")
    write_jsonl(kept, "/home/xygxzs28/knowledge_boundary/LKBP/data/nq_rerank/nq_rerank_eval_filter.jsonl")


def filter_file_with_question_id(filter_id_list):
    """Remove all rows of the given questions from three aligned files.

    The metadata file, the hidden-state file and the probe-probability file
    are treated as row-aligned: row ``i`` of each belongs to the same sample.
    Rows whose metadata ``question_id`` is in ``filter_id_list`` are removed
    from all three, and each file is overwritten in place.

    :param filter_id_list: iterable of hashable question ids to remove
    :raises AssertionError: if the three files do not have the same number
        of rows; nothing is written in that case
    """
    # A set gives constant-time membership tests inside the loop below.
    excluded_ids = set(filter_id_list)

    meta_path = "/home/xygxzs28/knowledge_boundary/LKBP/data/nq_rerank/nq_rerank_eval_meta_no_5.jsonl"
    hidden_state_path = "/home/xygxzs28/knowledge_boundary/LKBP/data/nq_rerank/hidden_prob_eval/hidden_state_by_llm.jsonl"
    prob_path = "/home/xygxzs28/knowledge_boundary/LKBP/data/nq_rerank/hidden_prob_eval/hidden_prob_by_mlp.jsonl"

    meta_data = read_json(meta_path)
    hidden_state_data = read_json(hidden_state_path)
    prob_data = read_json(prob_path)

    # Raised explicitly instead of using ``assert`` so that the check still
    # runs under ``python -O``: the files are overwritten below and must not
    # be filtered when they are misaligned.
    if not (len(meta_data) == len(hidden_state_data) == len(prob_data)):
        raise AssertionError(
            "Row counts differ: "
            f"meta={len(meta_data)}, hidden_state={len(hidden_state_data)}, prob={len(prob_data)}"
        )

    filter_meta_data = []
    filter_hidden_state_data = []
    filter_prob_data = []
    for meta_row, hidden_row, prob_row in zip(meta_data, hidden_state_data, prob_data):
        if meta_row["question_id"] in excluded_ids:
            continue
        filter_meta_data.append(meta_row)
        filter_hidden_state_data.append(hidden_row)
        filter_prob_data.append(prob_row)

    write_jsonl(filter_meta_data, meta_path)
    write_jsonl(filter_hidden_state_data, hidden_state_path)
    write_jsonl(filter_prob_data, prob_path)


def generate_similar_context(query: str, label: str, existing_contexts: list, count: int) -> list:
    """
    Ask the LLM client to generate contexts similar to a query.

    The reply is only accepted when it parses as a JSON list with exactly
    ``count`` elements. Any failure (client error, invalid JSON, wrong shape)
    is printed and an empty list is returned, so callers never see an
    exception from this function.

    :param query: the original query text
    :param label: kind of context to generate ('pos' or 'neg'); any value
        other than 'pos' is worded as negative in the prompt
    :param existing_contexts: contexts that already exist (to avoid duplicates)
    :param count: number of contexts to generate
    :return: list of newly generated contexts, or ``[]`` on failure
    """
    # Build an explicit prompt.
    prompt = f"""请生成 {count} 条与以下查询文本语义相似的{ '正面' if label=='pos' else '负面'}上下文：
                查询：“{query}”
                现有上下文：{existing_contexts}
                
                要求：
                1. 生成内容需与查询主题强相关
                2. 避免与现有上下文完全重复
                3. 正面上下文应当可以促进对问题query的回答，负面上下文应当无法回答问题query。
                4. 每条长度控制在 50 words左右
                5. 应当以列表形式返回
                6. 返回格式形如["xxx", "xxx", "xxx"]
                """

    # Call the LLM API and parse its reply. The client's exception types are
    # not visible from this file, so the catch stays broad on purpose:
    # generation is best-effort and a failure must not stop the caller.
    try:
        response = alive_client.chat(prompt)
        generated = json.loads(response)
    except Exception as e:
        print(f"{e}")
        return []

    # Validate explicitly rather than with ``assert``, which is removed when
    # Python runs with -O and would let a malformed reply through.
    if not isinstance(generated, list) or len(generated) != count:
        print(f"Unexpected generation result (expected a list of {count} items): {generated!r}")
        return []
    return generated


def complete_dataset(dataset_path: str):
    """Top up every record to five positive and five negative contexts.

    For each record read from ``dataset_path``, first ``pos`` and then
    ``neg`` is checked; a list with fewer than five entries is extended with
    whatever ``generate_similar_context`` returns for the missing amount.
    The records are modified in place and written to
    ``nq_rerank_eval_completed.jsonl``.

    :param dataset_path: path of the preference dataset to complete
    """
    records = read_json(dataset_path)
    total = len(records)
    for position, record in enumerate(records, start=1):
        print(f"Will process {position} / {total}")
        # Positives are completed before negatives.
        for polarity in ("pos", "neg"):
            current = record[polarity]
            missing = 5 - len(current)
            if missing <= 0:
                continue
            # Only the existing contexts of the same polarity are passed on
            # as the "avoid duplicates" list.
            generated = generate_similar_context(
                query=record['query'],
                label=polarity,
                existing_contexts=current,
                count=missing
            )
            current.extend(generated)

    write_jsonl(records, "/home/xygxzs28/knowledge_boundary/LKBP/data/nq_rerank/nq_rerank_eval_completed.jsonl")
