import json
import os
from typing import Tuple

import numpy as np
import pandas as pd

# GSoS + PPO (token)
# Checkpoint directories evaluated for this variant. main() passes each entry
# to get_accuracy(), which reads the "stats_final_<split>*" JSON files in it.
# NOTE(review): the three entries look like independent training runs (e.g.
# different seeds) -- confirm. The paths are absolute and machine-specific.
GSOS_PPO_TOK_CKPTS = [
    "/home/user/train-countdown/tril/outputs/ppo/240815_233503/model_200",
    "/home/user/train-countdown/tril/outputs/ppo/240817_092910/model_200",
    "/home/user/train-countdown/tril/outputs/ppo/240817_093017/model_200",
]

# GSoS + PPO (operation)
# Checkpoint directories evaluated for this variant (stored under
# "outputs/hppo"). main() passes each entry to get_accuracy(), which reads the
# "stats_final_<split>*" JSON files in it.
# NOTE(review): the three entries look like independent training runs (e.g.
# different seeds) -- confirm. The paths are absolute and machine-specific.
GSOS_PPO_OP_CKPTS = [
    "/home/user/train-countdown/tril/outputs/hppo/240815_233415/model_200",
    "/home/user/train-countdown/tril/outputs/hppo/240817_092910/model_200",
    "/home/user/train-countdown/tril/outputs/hppo/240817_092935/model_200",
]


def _positive_fraction(ratings) -> float:
    """Return the fraction of ``ratings`` that are strictly greater than zero.

    An empty sequence yields NaN explicitly, so NumPy's "mean of empty slice"
    warning is not triggered.
    """
    if not ratings:
        return float("nan")
    return float(np.mean([rating > 0 for rating in ratings]))


def get_accuracy(ckpt: str, split: str = "test") -> Tuple[float, float]:
    """Compute generated and reference accuracy for one checkpoint directory.

    Every file in ``ckpt`` whose name starts with ``stats_final_<split>`` is
    parsed as JSON, and its ``"gen_ratings"`` and ``"ref_ratings"`` lists are
    pooled across files. A rating counts as correct when it is ``> 0``.

    Args:
        ckpt: Directory containing the stats files.
        split: Split name used to build the file-name prefix.

    Returns:
        ``(gen_accuracy, ref_accuracy)``: the fraction of positive ratings in
        the pooled generated and reference ratings. A component is NaN when
        its pooled rating list is empty.

    Raises:
        FileNotFoundError: If ``ckpt`` does not exist, or if it contains no
            file matching the prefix (previously this silently produced NaN).
        KeyError: If a stats file lacks ``"gen_ratings"`` or ``"ref_ratings"``.
    """
    prefix = f"stats_final_{split}"
    # Sorted so the files are pooled in a deterministic order.
    names = sorted(name for name in os.listdir(ckpt) if name.startswith(prefix))
    if not names:
        raise FileNotFoundError(
            f"No stats files starting with {prefix!r} found in {ckpt!r}"
        )

    gen_ratings = []
    ref_ratings = []
    for name in names:
        # JSON is UTF-8 by specification; do not rely on the locale default.
        with open(os.path.join(ckpt, name), "r", encoding="utf-8") as f:
            stats = json.load(f)
        gen_ratings.extend(stats["gen_ratings"])
        ref_ratings.extend(stats["ref_ratings"])

    return _positive_fraction(gen_ratings), _positive_fraction(ref_ratings)


def _summarize_accuracy(ckpts) -> dict:
    """Return mean/std of generated accuracy over ``ckpts`` for val and test.

    Keys are ``"<split>_acc_mean"`` and ``"<split>_acc_std"``. The std is
    ``np.std`` with its default ``ddof=0`` (population standard deviation).
    """
    summary = {}
    for split in ("val", "test"):
        # get_accuracy returns (generated, reference); only the generated
        # accuracy is reported.
        accs = [get_accuracy(ckpt, split=split)[0] for ckpt in ckpts]
        summary[f"{split}_acc_mean"] = np.mean(accs)
        summary[f"{split}_acc_std"] = np.std(accs)
    return summary


def main() -> None:
    """Aggregate accuracy per model variant and write ``data/gsos_rl_acc.csv``.

    For each variant, the validation and test accuracy of every checkpoint is
    computed with ``get_accuracy`` and reduced to a mean and standard
    deviation across checkpoints. The CSV has one row per variant with the
    columns ``model``, ``val_acc_mean``, ``val_acc_std``, ``test_acc_mean``
    and ``test_acc_std``.
    """
    models = [
        ("GSoS+PPO (token)", GSOS_PPO_TOK_CKPTS),
        ("GSoS+PPO (operation)", GSOS_PPO_OP_CKPTS),
    ]
    rows = [
        {"model": label, **_summarize_accuracy(ckpts)} for label, ckpts in models
    ]

    # Save data to CSV; the column order is pinned explicitly.
    columns = [
        "model",
        "val_acc_mean",
        "val_acc_std",
        "test_acc_mean",
        "test_acc_std",
    ]
    df = pd.DataFrame(rows, columns=columns)

    out_path = "data/gsos_rl_acc.csv"
    # Create the output directory up front so a missing "data/" folder does
    # not fail the run after all statistics have been computed.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    df.to_csv(out_path, index=False)


if __name__ == "__main__":
    # Run the aggregation only when executed as a script, not when imported.
    main()
