import re
import torch
import tools
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import os
import argparse
import torch


# Module-level placeholders. SAMPLE_NUM and MODEL_PATH are overwritten from
# the command-line arguments when this file is executed as a script.
# SAMPLE_NUM is the divisor used by the 'rank_id' ordering check on the inputs.
SAMPLE_NUM = None
# NOTE(review): JUDGED_MODEL is never assigned or read in the visible code —
# confirm whether it is still needed.
JUDGED_MODEL = None
# Directory of the model checkpoint passed to the HuggingFace loaders.
MODEL_PATH = None


def replace_last_answer(text: str, answer, method) -> str:
    """Optionally overwrite the final ``<answer>...</answer>`` span of *text*.

    Args:
        text: Text that may contain one or more ``<answer>`` spans.
        answer: Value interpolated into the replacement span.
        method: ``"gt"`` replaces the last span with *answer*;
            ``"answer"`` leaves *text* untouched.

    Returns:
        For ``"gt"``: *text* with its last span replaced by
        ``<answer> {answer} </answer>`` and surrounding whitespace stripped,
        or *text* unchanged (not stripped) when it contains no span.
        For ``"answer"``: *text* unchanged.

    Raises:
        ValueError: If *method* is neither ``"gt"`` nor ``"answer"``.
    """
    if method == "answer":
        return text
    if method != "gt":
        raise ValueError(f"Unknown method: {method}. Use 'gt' or 'answer'.")

    # Walk all non-overlapping matches and keep only the final one.
    final_span = None
    for final_span in re.finditer(r"<answer>.*?</answer>", text, re.DOTALL):
        pass
    if final_span is None:
        return text

    head = text[:final_span.start()]
    tail = text[final_span.end():]
    return (head + f"<answer> {answer} </answer>" + tail).strip()


def process_part(data, partid, partnum):
    """Return the ``partid``-th of ``partnum`` contiguous shards of ``data``.

    Every shard except the last holds ``len(data) // partnum`` items; the
    last shard starts where the previous one ended and runs to the end of
    ``data``, so it also receives the remainder.

    Args:
        data: A sequence supporting ``len()`` and slicing.
        partid: 1-based index of the shard to return.
        partnum: Total number of shards; any positive integer.

    Returns:
        The requested slice of ``data`` (``data`` itself when
        ``partnum == 1``).

    Raises:
        ValueError: If ``partnum`` is not a positive integer, or ``partid``
            is not in ``1..partnum``.
    """
    if partnum == 1:
        # A single shard is the whole input; partid is intentionally not
        # validated here, matching the long-standing behaviour.
        return data
    if not isinstance(partnum, int) or partnum < 1:
        raise ValueError("partnum must be a positive integer")
    if partid not in range(1, partnum + 1):
        # Name the offending argument for every partnum (previously an
        # invalid partid with partnum == 4 was reported as a partnum error).
        if partnum == 2:
            raise ValueError("partid must be 1 or 2")
        raise ValueError(f"partid must be 1~{partnum}")

    shard = int(partid)
    size = len(data) // partnum
    start = (shard - 1) * size
    if shard == partnum:
        # Last shard absorbs the remainder of the integer division.
        return data[start:]
    return data[start:start + size]


if __name__ == "__main__":
    # Script entry point. Each process (identified by the RANK environment
    # variable) scores one of `part_num` shards of the judge inputs: it runs
    # the causal LM over "chat-templated question + reasoning" and records
    # the per-token log-probabilities plus the token span of the last
    # <answer>...</answer> block.
    parser = argparse.ArgumentParser()
    parser.add_argument("--sample_num", type=int, default=20,
                        help="number of samples to process")
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--subset", type=str, required=True)
    parser.add_argument("--traindataset", type=str, required=True)
    # Reject unsupported methods at parse time instead of failing on the
    # first record inside replace_last_answer.
    parser.add_argument("--method", type=str, required=True,
                        choices=["gt", "answer"])

    args = parser.parse_args()
    if args.sample_num <= 0:
        # SAMPLE_NUM is used as a divisor in the rank_id check below.
        parser.error("--sample_num must be a positive integer")
    SAMPLE_NUM = args.sample_num
    MODEL = args.model
    MODEL_PATH = f"{tools.machine_pather()}/models/{MODEL}"
    method = args.method
    subset = args.subset

    part_num = 8
    rank_env = os.environ.get("RANK")
    if rank_env is None:
        raise RuntimeError(
            f"RANK environment variable must be set to an integer in [0, {part_num})")
    part_id = int(rank_env) + 1
    if not 1 <= part_id <= part_num:
        # Fail before touching output files or loading the model.
        raise ValueError(
            f"RANK must be in [0, {part_num}), got {rank_env!r}")
    # Restrict this process to the GPU whose index equals RANK; set before
    # the model is moved to CUDA below.
    os.environ["CUDA_VISIBLE_DEVICES"] = rank_env

    INPUT_FILE = f"{tools.machine_pather()}/works/DPO/judge/input/{args.traindataset}/{subset}/{MODEL}/judge_input.jsonl"

    raw_inputs = tools.read_jsonl(
        INPUT_FILE)
    user_msg = []
    assistant_msg = []
    print("start to process the inputs")
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_PATH, trust_remote_code=True)
    for index, each_json in enumerate(tqdm(raw_inputs)):
        # Inputs must be ordered as SAMPLE_NUM consecutive samples per
        # group. Raised explicitly (not `assert`) so the check survives -O.
        expected_rank_id = f"{index // SAMPLE_NUM}_{index % SAMPLE_NUM}_0"
        if each_json['rank_id'] != expected_rank_id:
            raise ValueError(
                f"unexpected rank_id at record {index}: "
                f"got {each_json['rank_id']!r}, expected {expected_rank_id!r}")
        user_msg.append([{"role": "user",
                         "content": each_json['assembly_question']}])
        assistant_msg.append({"role": "assistant",
                              "content": replace_last_answer(each_json['assembly_reasoning'], each_json['assembly_answer'], method)})
    # Render only the user turn with the generation prompt, then append the
    # raw assistant text so the scored string is "prompt + response".
    user_msg = tokenizer.apply_chat_template(
        user_msg, tokenize=False, add_generation_prompt=True, enable_thinking=True)
    prompts = [user_msg_i+assistant_msg_i["content"]
               for user_msg_i, assistant_msg_i in zip(user_msg, assistant_msg)]

    OUTPUT_DIR = f"{tools.machine_pather()}/works/DPO/judge/output/{args.traindataset}/{subset}/{MODEL}/logp@{method}/"

    os.makedirs(OUTPUT_DIR, exist_ok=True)
    print("output to >>>", OUTPUT_DIR)

    score_path = os.path.join(
        OUTPUT_DIR, f"score_part{part_id}of{part_num}.jsonl")
    tools.check_existence(score_path)
    # Create/truncate the shard's output file before records are written.
    with open(score_path, 'w'):
        pass

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_PATH, torch_dtype=torch.float16, attn_implementation="flash_attention_2").cuda().eval()

    # Compiled once; used to locate the last answer span of every prompt.
    answer_pattern = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)

    with torch.inference_mode():
        for msg in tqdm(process_part(prompts, part_id, part_num)):
            enc = tokenizer(msg, return_tensors="pt").to(model.device)
            input_ids = enc["input_ids"]

            outputs = model(**enc)
            logits = outputs.logits

            # logp_list[i] is the log-probability of token i given tokens
            # < i; the first token has no context and is recorded as 0.
            logp_list = []
            for i in range(input_ids.size(1)):
                if i == 0:
                    logp_list.append(0)
                else:
                    logp = torch.log_softmax(
                        logits[0, i-1], dim=-1)[input_ids[0, i]].item()
                    logp_list.append(logp)

            pre_match = list(answer_pattern.finditer(msg))
            if not pre_match:
                # -1 marks "no answer span found" in the output record.
                answer_start = -1
                answer_end = -1
            else:
                last_match = pre_match[-1]
                answer_text = last_match.group(1)

                # Token index of the answer start is estimated by
                # tokenizing everything before it the same way as `msg`.
                prefix_text = msg[:last_match.start(1)]
                prefix_ids = tokenizer(prefix_text)["input_ids"]

                # Special tokens (e.g. BOS) that some tokenizers add must
                # not be counted as answer tokens, otherwise answer_end
                # overshoots the span.
                answer_ids = tokenizer(
                    answer_text, add_special_tokens=False)["input_ids"]

                answer_start = len(prefix_ids)
                answer_end = answer_start + len(answer_ids)

            tools.write_jsonl(
                {
                    "answer_start": answer_start,
                    "answer_end": answer_end,
                    "logp_list": logp_list,
                    "input": msg
                },
                score_path
            )

    # Empty marker files signalling that this shard finished.
    with open(os.path.join(
            OUTPUT_DIR, f"success_part{part_id}of{part_num}.tag"), "w"):
        pass
    with open(os.path.join(
            OUTPUT_DIR, MODEL_PATH.split("/")[-1]+".tag"), "w"):
        pass
