import argparse
import torch
import tools
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import os


# Placeholders; the __main__ section rebinds both from --sample_num / --model.
SAMPLE_NUM = MODEL = None
# Path handed to from_pretrained() for both the tokenizer and the reranker model.
MODEL_PATH = "{}/models/models--Qwen--Qwen3-Reranker-8B".format(tools.machine_pather())


def format_instruction(instruction, query, doc):
    """Render one reranker input in the ``<Instruct>/<Query>/<Document>`` layout.

    Args:
        instruction: task instruction placed after ``<Instruct>:``; required.
        query: text placed after ``<Query>:``.
        doc: text placed after ``<Document>:``.

    Returns:
        The three fields joined by newlines, each with its tag prefix.

    Raises:
        ValueError: if ``instruction`` is None (no default is substituted).
    """
    if instruction is not None:
        return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"
    raise ValueError("Instruction is None")


# Instruction text that the __main__ section passes to format_instruction() for
# every pair (it fills the <Instruct> field).
# NOTE: this string is model input; changing its wording changes the scores.
relevance_prompt = """
You are given a Query and a Document. The Query contains a user question and related information, and the Document contains a reasoning process that leads to a final answer. Your task is to retrieve the Document with reasoning that:
- is relevant to the information in the Query.
- directly helps to answer the user question.
- is logically coherent and leads to the final answer.
""".strip()


def process_inputs(pairs, model):
    """Tokenize formatted pairs into a padded batch on ``model.device``.

    Each pair is first truncated so that, once wrapped in the chat-template
    ``prefix_tokens`` / ``suffix_tokens``, it fits within ``max_length``
    tokens. The wrapped sequences are then padded to the longest one in the
    batch, and every resulting tensor is moved to ``model.device``.

    Args:
        pairs: list of strings, e.g. produced by ``format_instruction``.
        model: only its ``device`` attribute is read.

    Returns:
        The tokenizer's padded batch (``input_ids`` and ``attention_mask``
        tensors) placed on ``model.device``.
    """
    # Leave room for the fixed prefix/suffix tokens added below.
    budget = max_length - len(prefix_tokens) - len(suffix_tokens)
    encoded = tokenizer(
        pairs,
        padding=False,
        truncation='longest_first',
        return_attention_mask=False,
        max_length=budget,
    )
    encoded['input_ids'] = [prefix_tokens + ids + suffix_tokens
                            for ids in encoded['input_ids']]
    # pad() also builds the attention mask for the wrapped sequences.
    batch = tokenizer.pad(encoded, padding=True,
                          return_tensors="pt", max_length=max_length)
    device = model.device
    for name in list(batch.keys()):
        batch[name] = batch[name].to(device)
    return batch


@torch.no_grad()
def compute_logits(inputs, model, **kwargs):
    """Return the probability of the "yes" answer for every row of a batch.

    The model's logits at the last sequence position are restricted to the
    "no" and "yes" token ids. A two-way softmax (computed as
    ``exp(log_softmax)``) is taken over them.

    Args:
        inputs: tokenized batch, e.g. from ``process_inputs``; passed as
            ``model(**inputs)``.
        model: causal LM whose output exposes ``.logits``.
        **kwargs: accepted and ignored.

    Returns:
        list[float]: one probability per row.
    """
    last_logits = model(**inputs).logits[:, -1, :]
    # Column 0 holds the "no" logit, column 1 the "yes" logit.
    two_way = last_logits[:, [token_false_id, token_true_id]]
    log_probs = torch.nn.functional.log_softmax(two_way, dim=1)
    return log_probs[:, 1].exp().tolist()


# Shared tokenizer. Left padding keeps each row's real last token at position
# -1, which is where compute_logits() reads the logits.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, padding_side='left')

# Vocabulary ids of the two answer tokens whose logits are compared.
token_false_id, token_true_id = (
    tokenizer.convert_tokens_to_ids(answer) for answer in ("no", "yes"))

# Total token budget per model input, template prefix/suffix included.
max_length = 16384

# Chat-template text wrapped around every formatted pair; the suffix ends with
# an empty <think> block so the next token predicted is the yes/no answer.
prefix = "<|im_start|>system\nJudge whether the Document meets the requirements based on the Query and the Instruct provided. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
prefix_tokens, suffix_tokens = (
    tokenizer.encode(text, add_special_tokens=False) for text in (prefix, suffix))


def process_part(data, partid, partnum):
    """Return the ``partid``-th of ``partnum`` contiguous shards of ``data``.

    Every shard except the last holds ``len(data) // partnum`` items. The
    last shard also takes the remainder, so the shards concatenate back to
    ``data``. When there are fewer items than shards, the leading shards are
    empty and the last shard holds everything.

    Args:
        data: a sliceable sequence (e.g. a list).
        partid: 1-based shard index, ``1 <= partid <= partnum``.
        partnum: total number of shards, a positive integer.

    Returns:
        The selected slice of ``data`` (``data`` itself when ``partnum == 1``).

    Raises:
        ValueError: if ``partnum`` < 1 or ``partid`` is outside ``1..partnum``.
    """
    if partnum < 1:
        raise ValueError("partnum must be a positive integer")
    if not 1 <= partid <= partnum:
        raise ValueError(f"partid must be 1~{partnum}")
    if partnum == 1:
        return data
    size = len(data) // partnum
    start = (partid - 1) * size
    # The last shard runs to the end so no trailing items are dropped.
    stop = None if partid == partnum else partid * size
    return data[start:stop]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--sample_num", type=int, default=20,
                        help="number of samples to process")
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--subset", type=str, required=True)
    parser.add_argument("--traindataset", type=str, required=True)

    args = parser.parse_args()
    if args.sample_num <= 0:
        # sample_num is used as a divisor in the rank_id check below.
        parser.error("--sample_num must be a positive integer")
    SAMPLE_NUM = args.sample_num
    MODEL = args.model

    # One worker per GPU: the 0-based RANK (presumably exported by the job
    # launcher) selects both this worker's data shard and the only GPU it sees.
    rank = os.environ.get("RANK")
    if rank is None:
        raise RuntimeError(
            "RANK environment variable is not set; expected a 0-based worker index")
    part_id = int(rank) + 1
    # Must be set before CUDA is initialised; in this file the first GPU use
    # is the .cuda() call on the model further below.
    os.environ["CUDA_VISIBLE_DEVICES"] = rank
    part_num = 8
    subset = args.subset

    INPUT_FILE = f"{tools.machine_pather()}/works/DPO/judge/input/{args.traindataset}/{subset}/{MODEL}/judge_input.jsonl"

    raw_inputs = tools.read_jsonl(INPUT_FILE)
    print("start to process the inputs")
    queries = []
    docs = []
    for index, each_json in enumerate(tqdm(raw_inputs)):
        # Records must carry rank_id == f"{index // SAMPLE_NUM}_{index % SAMPLE_NUM}_0",
        # presumably SAMPLE_NUM consecutive samples per query. Checked explicitly
        # (not with assert) so the check survives `python -O`.
        expected_rank_id = f"{index // SAMPLE_NUM}_{index % SAMPLE_NUM}_0"
        if each_json['rank_id'] != expected_rank_id:
            raise ValueError(
                f"unexpected rank_id {each_json['rank_id']!r} at record {index} "
                f"(expected {expected_rank_id!r}); does --sample_num match the input?")

        queries.append(each_json['assembly_question'])
        docs.append(each_json['assembly_reasoning'])

    pairs = [format_instruction(relevance_prompt, query, doc)
             for query, doc in zip(queries, docs)]

    OUTPUT_DIR = f"{tools.machine_pather()}/works/DPO/judge/output/{args.traindataset}/{subset}/{MODEL}/relevance_score"

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    print("output to >>>", OUTPUT_DIR)

    score_file = os.path.join(OUTPUT_DIR, f"score_part{part_id}of{part_num}.jsonl")
    tools.check_existence(score_file)
    # Start from an empty score file; one record per pair is written below
    # (tools.write_jsonl presumably appends — confirm).
    open(score_file, 'w').close()
    relevance_model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH, torch_dtype=torch.float16, attn_implementation="flash_attention_2").cuda().eval()

    scores = []
    with torch.inference_mode():
        # Pairs are scored one at a time (batch of 1); each result is written
        # immediately after it is computed.
        for pair in tqdm(process_part(pairs, part_id, part_num)):
            input_pair = process_inputs([pair], relevance_model)
            score_pair = compute_logits(input_pair, relevance_model)[0]

            scores.append(score_pair)
            tools.write_jsonl({"score_name": 'relevance_score', "score": score_pair, "input": pair},
                              score_file)
    # Completion markers; only reached if every pair in this shard was scored.
    open(os.path.join(OUTPUT_DIR, f"success_part{part_id}of{part_num}.tag"), "w").close()
    open(os.path.join(OUTPUT_DIR, MODEL_PATH.split("/")[-1] + ".tag"), "w").close()
