import re
import os
import json
import argparse
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from tampering.rm.dataset.prompt import LABELING_PROMPT
from tampering.utils.openai import OpenAILLM
from copy import deepcopy


def parse_tags(text):
    """Collect every ``<tag>content</tag>`` pair found in *text*.

    Args:
        text: String to scan (e.g. an LLM completion).

    Returns:
        A dict mapping each tag name to its content. Names and contents are
        whitespace-stripped. Contents may span multiple lines. The closing tag
        must repeat the opening tag's name exactly. If a tag appears more than
        once, the last occurrence is kept.
    """
    tagged = {}
    for match in re.finditer(r'<(.*?)>(.*?)</\1>', text, flags=re.DOTALL):
        name, body = match.group(1), match.group(2)
        tagged[name.strip()] = body.strip()
    return tagged


def process_item(data, idx, llm):
    """Ask the LLM to rank the four candidate responses of one dataset item.

    Args:
        data: Mapping with at least the keys ``"messages"`` and
            ``"response_1"`` .. ``"response_4"``.
        idx: Position of the item in the dataset. It is returned unchanged so
            the caller can place the result even when futures complete out
            of order.
        llm: Client exposing ``chat_generate(messages=..., max_tokens=...,
            temperature=..., top_p=...)``. It presumably returns the
            completion text; the text is parsed for ``<thought>`` and
            ``<ranking_1>`` .. ``<ranking_4>`` tags.

    Returns:
        On success, ``(idx, labeled_item, None)``. ``labeled_item`` is a
        shallow copy of ``data`` with ``"thought"`` and ``"rank_1"`` ..
        ``"rank_4"`` added; each rank is a ``"response_N"`` key.
        On any failure, ``(idx, None, exc)``. Failures include a missing key,
        an empty response, missing tags, a letter outside A-D, or a ranking
        that is not a permutation. Exceptions are returned rather than raised
        so one bad item does not abort the whole batch.
    """
    choice_map = {"A": "response_1", "B": "response_2", "C": "response_3", "D": "response_4"}
    try:
        base = dict(data)

        conversation = base["messages"]
        a, b, c, d = base["response_1"], base["response_2"], base["response_3"], base["response_4"]

        prompt = LABELING_PROMPT.format(
            conversation=conversation,
            assistant_a=a, assistant_b=b, assistant_c=c, assistant_d=d
        )
        messages = [
            {"role": "system", "content": ""},
            {"role": "user", "content": prompt}
        ]
        # temperature=0 keeps the labels as reproducible as the API allows.
        response = llm.chat_generate(
            messages=messages,
            max_tokens=1024,
            temperature=0,
            top_p=1
        )
        if not response:
            raise ValueError("empty response from LLM")

        rd = parse_tags(response)

        ranks = {}
        for position in range(1, len(choice_map) + 1):
            letter = rd[f"ranking_{position}"].strip().upper()
            if letter not in choice_map:
                raise ValueError(f"ranking_{position} is {letter!r}, expected one of A-D")
            ranks[f"rank_{position}"] = choice_map[letter]
        # A valid ranking places each response exactly once. Duplicates would
        # otherwise be written out as contradictory preference labels.
        if len(set(ranks.values())) != len(choice_map):
            raise ValueError(f"rankings are not a permutation of A-D: {ranks}")

        base["thought"] = rd["thought"]
        base.update(ranks)
        return idx, base, None
    except Exception as e:
        return idx, None, e


def main(args):
    """Label every dataset item with an LLM ranking and write the results.

    Reads a JSON list from ``args.source_path``. Each item is passed to
    ``process_item`` on a pool of ``args.max_worker`` threads. The labeled
    list, in input order, is written to ``args.target_path``. Items that
    failed are left as ``null`` in the output, and their indices are printed
    in ascending order.
    """
    llm = OpenAILLM(model_name="gpt-4.1")

    # Read with the same encoding the output is written with.
    with open(args.source_path, "r", encoding="utf-8") as f:
        dataset = json.load(f)

    results = [None] * len(dataset)
    error_indices = []

    with ThreadPoolExecutor(max_workers=args.max_worker) as executor:
        futures = [executor.submit(process_item, item, i, llm) for i, item in enumerate(dataset)]
        for future in tqdm(as_completed(futures), total=len(futures)):
            idx, updated, err = future.result()
            if err is not None:
                error_indices.append(idx)
                # tqdm.write keeps the progress bar intact, unlike print().
                tqdm.write(f"[Error] index {idx}: {err}")
            else:
                results[idx] = updated

    # as_completed yields in completion order; sort for a readable report.
    error_indices.sort()

    # Write to a temp file, then os.replace() it over the target, so an
    # interrupted run never leaves a truncated target file.
    tmp_path = args.target_path + ".tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as wf:
            json.dump(results, wf, ensure_ascii=False, indent=4)
        os.replace(tmp_path, args.target_path)
    except BaseException:
        # Don't leave a partial .tmp file behind on failure.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    print("Erroneous indices:", error_indices)


if __name__ == "__main__":
    TAMPERING_HOME = os.getenv("TAMPERING_HOME")

    default_dataset = "hhrlhf"

    parser = argparse.ArgumentParser(description="Process dataset with LLM ranking")
    parser.add_argument("--dataset_name", type=str, default=default_dataset,
                        help=f"Dataset name used to build the default paths (default: {default_dataset})")
    parser.add_argument("--source_path", type=str, default=None,
                        help="Path to input JSON file (default: "
                             "$TAMPERING_HOME/datasets/<dataset_name>/rm/train/<dataset_name>_RM_5120_sft.json)")
    parser.add_argument("--target_path", type=str, default=None,
                        help="Path to output JSON file (default: "
                             "$TAMPERING_HOME/datasets/<dataset_name>/rm/train/<dataset_name>_RM_5120_pref.json)")
    parser.add_argument("--max_worker", type=int, default=30,
                        help="Number of parallel workers")

    args = parser.parse_args()

    # Derive default paths from --dataset_name so that flag actually takes
    # effect. Without this, the paths were always built for default_dataset.
    if args.source_path is None or args.target_path is None:
        if TAMPERING_HOME is None:
            parser.error("TAMPERING_HOME is not set; set it or pass --source_path and --target_path")
        train_dir = f"{TAMPERING_HOME}/datasets/{args.dataset_name}/rm/train"
        if args.source_path is None:
            args.source_path = f"{train_dir}/{args.dataset_name}_RM_5120_sft.json"
        if args.target_path is None:
            args.target_path = f"{train_dir}/{args.dataset_name}_RM_5120_pref.json"

    main(args)
