import json
import os
import re
import string
import time
from collections import Counter
from datetime import datetime
from pathlib import Path

import orjson
import torch


# Directory names of the backbone models; get_models() looks them up under
# "<machine root>/models/" and derives "-SWIFT" / "-gcore" variants from them.
BACKBONE_MODELS = [
    "models--deepseek-ai--DeepSeek-R1-Distill-Llama-8B",
    "models--Qwen--Qwen3-8B",
    "models--meta-llama--Llama-3.1-8B-Instruct",
]

# Backbone directory name -> directory names of the baseline models listed
# for it (may be empty). Insertion order matters: get_models() iterates
# BASELINES.values() to build its result.
BASELINES = {
    "models--meta-llama--Llama-3.1-8B-Instruct": ["models--openbmb--RLPR-Llama3.1-8B-Inst"],
    "models--Qwen--Qwen3-8B": ["models--deepseek-ai--DeepSeek-R1-0528-Qwen3-8B"],
    "models--deepseek-ai--DeepSeek-R1-Distill-Llama-8B": [],
}

# Method suffixes that get_models() appends in addition to METHODS.
GRPO_METHODS = ["RLVR", "GTreasoning_reward", "0.5RLVR-0.5GTreasoning_reward"]

# METHODS holds every "<pick_method>|<result_type>" combination, sorted.
# Each row below pairs a group of pick methods with the result types that are
# generated for that group.
METHODS = []
for _pick_methods, _result_types in (
    (
        (
            "random",
            "0.7*judger_score+0.2*relevance@all_score+0.1*logp@reasoning_N@answer_score",
            "1.0*judger_score+1.0*relevance@all_score+1.0*logp@reasoning_N@answer_score",
        ),
        ("any", "T+F", "T+T", "F+F"),
    ),
    (
        ("judger_score", "relevance@all_score", "logp@reasoning_N@answer_score"),
        ("any",),
    ),
):
    for pick_method in _pick_methods:
        for result_type in _result_types:
            METHODS.append(f"{pick_method}|{result_type}")

METHODS.sort()


def machine_pather() -> str:
    """Return the portion of this file's resolved path before the first '/works'.

    If the resolved path contains no '/works' substring, the whole path of
    this file is returned unchanged.

    NOTE(review): the match is a plain substring test, so a path component
    such as '/workspace' also counts as '/works' - confirm that is acceptable
    for the machines this runs on.
    """
    resolved_path = str(Path(__file__).resolve())
    prefix, _separator, _remainder = resolved_path.partition('/works')
    return prefix


def normalize_answer(text: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace.

    Steps are applied in this order: lowercase, drop every character found in
    ``string.punctuation``, replace the standalone words "a", "an" and "the"
    with a space, then collapse all whitespace runs to single spaces and trim
    both ends.

    Args:
        text: The string to normalize.

    Returns:
        The normalized string (possibly empty).
    """
    # A deletion table removes all punctuation in one pass, instead of
    # rebuilding a set of punctuation characters for every input character.
    without_punctuation = text.lower().translate(
        str.maketrans('', '', string.punctuation))
    # Articles are matched after punctuation removal, so "a-the" becomes
    # "athe" and is left alone rather than being treated as two articles.
    without_articles = re.sub(r'\b(a|an|the)\b', ' ', without_punctuation)
    return ' '.join(without_articles.split())


def compute_em(prediction: str, gold: str):
    """Return True when both strings are equal after normalize_answer()."""
    normalized_prediction = normalize_answer(prediction)
    normalized_gold = normalize_answer(gold)
    return normalized_prediction == normalized_gold


def compute_f1(prediction: str, gold: str) -> float:
    """Compute the token-level F1 score between a prediction and a gold answer.

    Both strings are passed through normalize_answer() and split on
    whitespace. Overlap is counted as a multiset: a token that appears k
    times in one side and m times in the other contributes min(k, m).

    Args:
        prediction: The predicted answer text.
        gold: The reference answer text.

    Returns:
        The harmonic mean of token precision and recall. If either side has
        no tokens after normalization, returns 1.0 when both are empty and
        0.0 otherwise. Returns 0.0 when no token is shared.
    """
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(gold).split()
    # Settle the empty cases first so the divisions below are always safe.
    if not pred_tokens or not gold_tokens:
        return float(pred_tokens == gold_tokens)
    # Counter intersection keeps the minimum count per token, which is the
    # multiset overlap, without rescanning both lists for every shared word.
    overlap = Counter(pred_tokens) & Counter(gold_tokens)
    num_same = sum(overlap.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def get_machine_code() -> str:
    """Return the machine code, which is the fixed string "test"."""
    machine_code = "test"
    return machine_code


def write_string_with_time(content, filename=None):
    """Append ``content`` as one timestamped line to ``filename``.

    The emitted line is ``[YYYY-MM-DD HH:MM:SS] <content>`` followed by a
    newline, using the current local time.

    Args:
        content: Value to log; it is formatted into the line with an f-string.
        filename: Path of the file to append to (opened as UTF-8 text). When
            None, the line is written to standard output instead.
    """
    now = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    line = f"[{now}] {content}\n"
    if filename is None:
        # The declared default used to reach open(None) and fail with a
        # TypeError; fall back to stdout so the default is usable.
        print(line, end="")
        return
    with open(filename, 'a', encoding='utf-8') as f:
        f.write(line)


def print_with_time(content, print_file_obj=None):
    """Print ``content`` prefixed with the current local time.

    The output is ``[YYYY-MM-DD HH:MM:SS]`` followed by three spaces, then
    ``content`` and a newline.

    Args:
        content: Object to print; rendered the way ``print`` renders it.
        print_file_obj: Passed to ``print`` as ``file``; None means standard
            output.
    """
    # The timestamp is built outside the f-string: reusing double quotes (and
    # a line break) inside a replacement field only parses on Python 3.12+,
    # and on older interpreters it breaks importing the whole module.
    timestamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    print(f"[{timestamp}]   ", content, sep="", file=print_file_obj)


def read_jsonl(filepath):
    """Read a UTF-8 JSON Lines file and return its parsed objects as a list.

    Each line is stripped of surrounding whitespace; lines that are empty
    after stripping are skipped, and every remaining line is parsed with
    ``orjson.loads``.
    """
    with open(filepath, 'r', encoding='utf-8') as handle:
        stripped_lines = (raw_line.strip() for raw_line in handle)
        return [orjson.loads(entry) for entry in stripped_lines if entry]


def write_jsonl(one_json, filepath):
    """Append ``one_json`` to ``filepath`` as a single JSON line.

    The file is opened in append mode as UTF-8, and non-ASCII characters are
    written as-is rather than escaped.
    """
    with open(filepath, mode='a', encoding='utf-8') as sink:
        serialized = json.dumps(one_json, ensure_ascii=False)
        sink.write(f"{serialized}\n")


def check_existence(filepath):
    """Raise ValueError if ``filepath`` already exists; otherwise do nothing."""
    if not os.path.exists(filepath):
        return
    raise ValueError(f"{filepath} already exists, please remove it first")


def extract_tag(string, tag, check=True):
    """Return the stripped text between ``<tag>`` and ``</tag>`` in ``string``.

    The text is taken after the last ``<tag>`` and up to the first ``</tag>``
    that follows it (or to the end of the string when no closing tag follows).

    Args:
        string: Text to search.
        tag: Tag name without angle brackets.
        check: When True, require that ``</tag>`` is present and that neither
            the opening nor the closing tag occurs more than once.

    Returns:
        The extracted text with surrounding whitespace removed, or "" when
        ``<tag>`` does not occur in ``string``.

    Raises:
        AssertionError: If ``check`` is True and the validation above fails.
    """
    start_tag = f"<{tag}>"
    end_tag = f"</{tag}>"
    if check:
        # Raised explicitly instead of via `assert` so the validation is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if end_tag not in string:
            raise AssertionError(f"{end_tag} not found")
        if string.count(start_tag) > 1 or string.count(end_tag) > 1:
            raise AssertionError(f"{tag} should only appear once in {string}")
    if start_tag not in string:
        return ""
    return string.split(start_tag)[-1].split(end_tag)[0].strip()


def find_last_boxed_content(s):
    r"""Return the content of the last brace-balanced ``\boxed{...}`` in ``s``.

    Scanning runs left to right. After a ``\boxed{`` marker, braces are
    counted until the opening brace is closed, and the search for the next
    marker resumes from that closing brace, so a ``\boxed{`` nested inside
    another one is part of the outer content and is not matched separately.
    A marker whose brace is never closed yields no match and ends the scan.

    Returns:
        The text between the braces of the last completed match (may be an
        empty string), or None when there is no completed match.
    """
    marker = r'\boxed{'
    last_span = None
    cursor = s.find(marker)
    while cursor != -1:
        content_start = cursor + len(marker)
        depth = 1
        pos = content_start
        while pos < len(s):
            char = s[pos]
            if char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
                if depth == 0:
                    last_span = (content_start, pos)
                    break
            pos += 1
        # pos is either the matching '}' or len(s); neither can begin a new
        # marker, so searching from here visits the same markers as a
        # character-by-character scan would.
        cursor = s.find(marker, pos)
    if last_span is None:
        return None
    begin, end = last_span
    return s[begin:end]


def remove_html_tags(text):
    """Return ``text`` with every ``<...>`` span removed.

    A span is a '<', one or more characters that are not '>', and a closing
    '>'; an empty "<>" is therefore left in place.
    """
    tag_pattern = re.compile(r'<[^>]+>')
    return tag_pattern.sub('', text)


def get_model_from_dirs(models_dir):
    """Collect the "checkpoint*" entries found in the given directories.

    Args:
        models_dir: Iterable of directory paths, or None.

    Returns:
        A sorted list of "<directory>/<entry>" strings for every entry whose
        name starts with "checkpoint". Paths that do not exist are skipped;
        None yields an empty list.
    """
    if models_dir is None:
        return []
    checkpoints = []
    for directory in models_dir:
        if not os.path.exists(directory):
            continue
        checkpoints.extend(
            f"{directory}/{entry}"
            for entry in os.listdir(directory)
            if entry.startswith("checkpoint")
        )
    return sorted(checkpoints)


def get_cuda0_memory_gb():
    """Return the total memory of device ``cuda:0`` as a float.

    The value is the device's ``total_memory`` in bytes divided by 1024**3.

    Raises:
        RuntimeError: If ``torch.cuda.is_available()`` is False.
    """
    if torch.cuda.is_available():
        properties = torch.cuda.get_device_properties(torch.device('cuda:0'))
        bytes_per_unit = 1024 ** 3
        return float(properties.total_memory / bytes_per_unit)
    raise RuntimeError("CUDA is not available.")


def get_models():
    """Return the list of model paths available on this machine.

    The result is built in this order:

    1. every entry of BACKBONE_MODELS under "<root>/models/";
    2. every "checkpoint*" directory found (via get_model_from_dirs) under
       "<root>/works/DPO/src/trainedmodels/<run>", where <run> is
       "{model}-{train}-{traindataset}{subset}-{method}" for each combination
       enumerated below, sorted case-insensitively;
    3. every baseline listed in BASELINES under "<root>/models/".

    <root> is the value returned by machine_pather().

    Returns:
        A list of path strings. Backbone and baseline paths are included
        without checking that they exist; only checkpoint paths come from
        the filesystem.
    """
    # Resolve the machine root once: machine_pather() resolves a filesystem
    # path, and it used to be re-evaluated for every candidate directory.
    root = machine_pather()
    trained_root = f"{root}/works/DPO/src/trainedmodels/"
    pretrained_root = f"{root}/models/"

    model_list = [each + "-SWIFT" for each in BACKBONE_MODELS]
    model_list += [each + "-gcore" for each in BACKBONE_MODELS]
    for baseline_group in BASELINES.values():
        model_list.extend(baseline_group)

    train_list = ["DPO", "GRPO"]
    subset_list = ["train_ANS", "train_NOANS", "combine"]
    methods = METHODS + GRPO_METHODS

    # Enumerate every run directory that training could have produced;
    # get_model_from_dirs() ignores the ones that do not exist.
    models_dir = []
    for model in model_list:
        for train in train_list:
            for subset in subset_list:
                for method in methods:
                    for traindataset in ["RAG", ""]:
                        models_dir.append(
                            f"{trained_root}{model}-{train}-{traindataset}{subset}-{method}")

    models = get_model_from_dirs(models_dir)
    models.sort(key=str.lower)

    models = [pretrained_root + each for each in BACKBONE_MODELS] + models
    for baseline_group in BASELINES.values():
        models.extend(pretrained_root + m for m in baseline_group)

    return models


def get_unfinished_tasks(machine_code,
                         output_root=f"{machine_pather()}/works/DPO/output",
                         verbose=False
                         ):
    """List the (model, dataset, input file) combinations with no output yet.

    Every model from get_models() is paired with every entry of
    "<machine root>/works/DPO/data" (sorted case-insensitively) and the input
    file "test_NOANS.jsonl". A combination is unfinished when
    "<output_root>/<dataset>/<input file stem>/<model short name>" does not
    exist.

    Args:
        machine_code: Not used by this function.
        output_root: Directory under which outputs are looked up. The default
            is computed once, when the function is defined.
        verbose: When True, print a message for each combination that is
            skipped because its output already exists.

    Returns:
        A list of (model, dataset, input_file, model_short_name, output_dir)
        tuples, ordered by model, then dataset, then input file.
    """
    models = get_models()
    datasets = sorted(
        os.listdir(f"{machine_pather()}/works/DPO/data"), key=str.lower)
    input_files = ["test_NOANS.jsonl"]

    pending = []
    for model in models:
        # The short name is the last path component; checkpoint paths keep
        # their parent directory as well.
        path_parts = model.split("/")
        model_short_name = path_parts[-1]
        if "checkpoint" in model:
            model_short_name = path_parts[-2] + "/" + path_parts[-1]
        for dataset in datasets:
            for input_file in input_files:
                input_stem = input_file.split('.')[0]
                output_dir = os.path.join(
                    output_root, dataset, input_stem, model_short_name)
                if not os.path.exists(output_dir):
                    pending.append(
                        (model, dataset, input_file, model_short_name, output_dir))
                elif verbose:
                    print(
                        f"Skip {dataset} {input_file} {model_short_name} because output already exists")
    return pending


# Example usage: print every model path that get_models() returns.
if __name__ == "__main__":
    discovered_models = get_models()
    print(discovered_models)
