import os
import json
from collections import defaultdict

# Directory whose immediate sub-folders are the prompt (prompting strategy) folders.
ROOT_DIR = "results/prompt_tuning/deepseek-chat"
# File the generated LaTeX table is written to.
OUTPUT_LATEX_PATH = "tables/prompt_tuning_table.txt"

# Sub-folder names looked up under "one_to_many"; also used as result keys.
K_VALUES = ["k=2", "k=3", "k=4", "k=5"]

# One label per numeric table column: the two fixed encodings, then each k.
ENCODING_COLUMNS = ["many→one", "one→one", *K_VALUES]

def get_prompt_folders(root_dir=None):
    """Return the sorted names of the sub-directories of the results root.

    Args:
        root_dir: Directory to scan. Defaults to the module-level ``ROOT_DIR``
            (resolved at call time) when omitted.

    Returns:
        A sorted list of directory names (not full paths). Plain files in
        ``root_dir`` are ignored.

    Raises:
        OSError: If ``root_dir`` cannot be listed (e.g. it does not exist).
    """
    if root_dir is None:
        root_dir = ROOT_DIR

    return sorted(
        name
        for name in os.listdir(root_dir)
        if os.path.isdir(os.path.join(root_dir, name))
    )

def _read_accuracy(json_path):
    """Return the value stored under "accuracy" in a JSON file, or None.

    None is returned when the key is missing or when the top-level JSON value
    is not an object. Malformed JSON still raises, so corrupt result files are
    not silently ignored.
    """
    with open(json_path, encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, dict):
        return None
    return data.get("accuracy")


def _collect_task_accuracies(acc_dict, tasks_root, prompt, column_key):
    """Store each task's accuracy found under ``tasks_root`` in ``acc_dict``.

    Every sub-directory of ``tasks_root`` is treated as a task; the result is
    written to ``acc_dict[task][prompt][column_key]``. A missing ``tasks_root``
    is skipped silently.
    """
    if not os.path.isdir(tasks_root):
        return

    for task in sorted(os.listdir(tasks_root)):
        task_path = os.path.join(tasks_root, task)
        if not os.path.isdir(task_path):
            continue

        # Sorted so that, when a task folder holds several result files, the
        # one that wins (the last with an accuracy) does not depend on the
        # arbitrary order returned by os.listdir.
        for file_name in sorted(os.listdir(task_path)):
            if not file_name.endswith(".json"):
                continue
            acc = _read_accuracy(os.path.join(task_path, file_name))
            if acc is not None:
                acc_dict[task][prompt][column_key] = acc


def collect_accuracies(prompt_folders, root_dir=None, k_values=None):
    """Gather accuracies from the results tree into a nested mapping.

    For each prompt folder the following layout is read:

        <root>/<prompt>/many_to_one/<task>/*.json
        <root>/<prompt>/one_to_one/<task>/*.json
        <root>/<prompt>/one_to_many/<k>/<task>/*.json

    Args:
        prompt_folders: Iterable of prompt folder names relative to the root.
        root_dir: Root of the results tree. Defaults to the module-level
            ``ROOT_DIR`` (resolved at call time).
        k_values: Names of the sub-folders to read under ``one_to_many``.
            Defaults to the module-level ``K_VALUES`` (resolved at call time).

    Returns:
        A nested ``defaultdict`` such that ``result[task][prompt][key]`` is the
        accuracy, where ``key`` is "many_to_one", "one_to_one" or one of
        ``k_values``. Missing folders and files without an "accuracy" entry
        contribute nothing.
    """
    if root_dir is None:
        root_dir = ROOT_DIR
    if k_values is None:
        k_values = K_VALUES

    acc_dict = defaultdict(lambda: defaultdict(dict))

    for prompt in prompt_folders:
        prompt_path = os.path.join(root_dir, prompt)

        for encoding in ("many_to_one", "one_to_one"):
            _collect_task_accuracies(
                acc_dict, os.path.join(prompt_path, encoding), prompt, encoding
            )

        one_to_many_path = os.path.join(prompt_path, "one_to_many")
        for k in k_values:
            _collect_task_accuracies(
                acc_dict, os.path.join(one_to_many_path, k), prompt, k
            )

    return acc_dict

def _latex_escape(text):
    """Return ``text`` with LaTeX text-mode special characters escaped."""
    specials = {
        "\\": "\\textbackslash{}",
        "&": "\\&",
        "%": "\\%",
        "$": "\\$",
        "#": "\\#",
        "_": "\\_",
        "{": "\\{",
        "}": "\\}",
        "~": "\\textasciitilde{}",
        "^": "\\textasciicircum{}",
    }
    # Character-by-character so that inserted backslashes/braces are never
    # escaped a second time.
    return "".join(specials.get(ch, ch) for ch in str(text))


def _format_accuracy(val):
    """Format a numeric accuracy with two decimals; anything else becomes "--"."""
    # bool is a subclass of int but is never a meaningful accuracy.
    if isinstance(val, bool) or not isinstance(val, (int, float)):
        return "--"
    return f"{val:.2f}"


def write_latex_table(acc_dict, prompt_folders, output_path, k_values=None):
    """Write the collected accuracies as a standalone LaTeX document.

    One group of rows is emitted per task (sorted), with one row per prompt
    folder. Columns are "many_to_one", "one_to_one" and then one column per
    entry of ``k_values``. Missing or non-numeric values are rendered as "--".

    Args:
        acc_dict: Mapping ``task -> prompt -> column key -> accuracy``.
        prompt_folders: Sequence of prompt names; defines the row order inside
            each task group.
        output_path: Destination file. Its parent directory is created when
            the path has one.
        k_values: Column keys for the "one -> many" columns. Defaults to the
            module-level ``K_VALUES`` (resolved at call time). The header label
            of each column is the part after "=" (e.g. "k=2" -> "2").

    Returns:
        None. The table is written to ``output_path`` and a confirmation line
        is printed.
    """
    if k_values is None:
        k_values = K_VALUES
    k_values = list(k_values)

    # os.makedirs("") raises, so only create a directory when there is one.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    n_k = len(k_values)
    n_cols = 2 + n_k  # many->one, one->one, plus one column per k
    column_keys = ["many_to_one", "one_to_one"] + k_values
    k_labels = [_latex_escape(str(k).split("=", 1)[-1]) for k in k_values]

    second_header = (
        " & & \\multirow{2}{*}{many $\\to$ one}"
        " & \\multirow{2}{*}{one $\\to$ one}"
    )
    third_header = " & & &"
    if n_k:
        second_header += " & \\multicolumn{" + str(n_k) + "}{c}{one $\\to$ many}"
        third_header += " & " + " & ".join(k_labels)

    lines = [
        "\\documentclass{article}\n",
        "\\usepackage{booktabs, multirow, xcolor}\n",
        "\\input{02-Paper/01-packages}\n",
        "\\begin{document}\n\n",
        "\\begin{table}[htbp]\n\\centering\n\\small\n",
        "\\begin{tabular}{ll" + "r" * n_cols + "}\n",
        "\\toprule\n",
        "\\multirow{3}{*}{\\textbf{Language}} & "
        "\\multirow{3}{*}{\\textbf{Prompting Strategy}} & "
        "\\multicolumn{" + str(n_cols) + "}{c}{\\textbf{Encoding Strategy}} \\\\\n",
        # Numeric columns start at column 3 (after Language and Prompt).
        "\\cmidrule(lr){3-" + str(2 + n_cols) + "}\n",
        second_header + " \\\\\n",
    ]
    if n_k:
        lines.append("\\cmidrule(lr){5-" + str(4 + n_k) + "}\n")
    lines.append(third_header + " \\\\\n")
    lines.append("\\midrule\n")

    tasks = sorted(acc_dict)
    for j, task in enumerate(tasks):
        per_prompt = acc_dict[task]
        for i, prompt in enumerate(prompt_folders):
            # The task label is only printed on the first row of its group.
            if i == 0:
                task_cell = (
                    "\\multirow{" + str(len(prompt_folders)) + "}{*}{"
                    + _latex_escape(task) + "}"
                )
            else:
                task_cell = ""

            scores = per_prompt.get(prompt, {})
            row = [task_cell, _latex_escape(prompt)]
            row.extend(_format_accuracy(scores.get(key)) for key in column_keys)
            lines.append(" & ".join(row) + " \\\\\n")

        # Separate task groups, but not after the last one (bottomrule follows).
        if j < len(tasks) - 1:
            lines.append("\\midrule\n")

    lines.append("\\bottomrule\n\\end{tabular}\n")
    lines.append(
        "\\caption{Accuracy per task, grouped by prompting strategy and encoding strategy.}\n"
    )
    lines.append("\\label{tab:prompt_tuning}\n\\end{table}\n\n\\end{document}\n")

    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(lines)

    print(f"LaTeX table saved to {output_path}")

if __name__ == "__main__":
    # Script entry point: discover the prompt folders, gather their
    # accuracies, and render everything into the LaTeX output file.
    prompts = get_prompt_folders()
    acc_dict = collect_accuracies(prompts)
    write_latex_table(
        acc_dict=acc_dict,
        prompt_folders=prompts,
        output_path=OUTPUT_LATEX_PATH,
    )
