import json
import os
from typing import Tuple

import numpy as np
import pandas as pd

# Checkpoint directories for the "SoS + PPO" configuration.
# main() computes val/test accuracy for each entry and reports mean/std
# across the list; each directory must contain the `stats_final_<split>*`
# JSON files read by get_accuracy().
# NOTE(review): presumably one entry per independent training run — confirm.
SOS_PPO_CKPTS = [
    "/home/user/train-countdown/tril/outputs/hppo/240820_073947/model_100",
    "/home/user/train-countdown/tril/outputs/hppo/240820_193933/model_150",
    "/home/user/train-countdown/tril/outputs/hppo/240821_150618/model_100",
]

# Checkpoint directories for the "SoS + STaR + PPO" configuration.
# main() computes val/test accuracy for each entry and reports mean/std
# across the list; each directory must contain the `stats_final_<split>*`
# JSON files read by get_accuracy().
# NOTE(review): presumably one entry per independent training run — confirm.
SOS_STAR_PPO_CKPTS = [
    "/home/user/train-countdown/tril/outputs/hppo/240818_172757/model_200",
    "/home/user/train-countdown/tril/outputs/hppo/240818_173641/model_200",
    "/home/user/train-countdown/tril/outputs/hppo/240820_073438/model_200",
]


def _positive_fraction(ratings: list) -> float:
    """Return the fraction of *ratings* that are strictly greater than zero."""
    if not ratings:
        # The fraction is undefined with no samples; return NaN explicitly
        # instead of triggering NumPy's "Mean of empty slice" RuntimeWarning.
        return float("nan")
    return float(np.mean([r > 0 for r in ratings]))


def get_accuracy(ckpt: str, split: str = "test") -> Tuple[float, float]:
    """Compute accuracy from the final stats files of a checkpoint.

    Every file in ``ckpt`` whose name starts with ``stats_final_{split}`` is
    parsed as JSON, and its ``"gen_ratings"`` and ``"ref_ratings"`` lists are
    concatenated across files (in sorted filename order). Accuracy is the
    fraction of ratings that are strictly positive.

    Args:
        ckpt: Path to the checkpoint directory holding the stats files.
        split: Split name used to build the filename prefix.

    Returns:
        ``(gen_accuracy, ref_accuracy)``. A component is NaN when the matching
        files contain no ratings of that kind.

    Raises:
        FileNotFoundError: If ``ckpt`` does not exist, or contains no file
            starting with ``stats_final_{split}``.
        KeyError: If a stats file lacks ``"gen_ratings"`` or ``"ref_ratings"``.
    """
    prefix = f"stats_final_{split}"
    names = sorted(name for name in os.listdir(ckpt) if name.startswith(prefix))
    if not names:
        # Fail loudly: a missing split would otherwise surface only as a NaN
        # that propagates silently into every downstream aggregate.
        raise FileNotFoundError(
            f"No files starting with {prefix!r} found in {ckpt!r}"
        )

    gen_ratings = []
    ref_ratings = []
    for name in names:
        with open(os.path.join(ckpt, name), "r", encoding="utf-8") as f:
            stats = json.load(f)
        gen_ratings.extend(stats["gen_ratings"])
        ref_ratings.extend(stats["ref_ratings"])

    return _positive_fraction(gen_ratings), _positive_fraction(ref_ratings)


def _summarize_accuracy(ckpts: list) -> dict:
    """Return mean/std of generation accuracy over *ckpts* for val and test."""
    summary = {}
    for split in ("val", "test"):
        # get_accuracy returns (gen_accuracy, ref_accuracy); only the
        # generation accuracy is aggregated here.
        accs = [get_accuracy(ckpt, split=split)[0] for ckpt in ckpts]
        summary[f"{split}_acc_mean"] = np.mean(accs)
        # np.std is called with its default ddof=0 (population std).
        summary[f"{split}_acc_std"] = np.std(accs)
    return summary


def main():
    """Aggregate accuracy per model group and write ``data/sos_rl_acc.csv``.

    For each checkpoint group, the generation accuracy on the ``val`` and
    ``test`` splits is computed per checkpoint, then reduced to a mean and
    standard deviation across the group. The result is one CSV row per group.
    """
    groups = {
        "SoS+PPO": SOS_PPO_CKPTS,
        "SoS+STaR+PPO": SOS_STAR_PPO_CKPTS,
    }
    rows = [
        {"model": model, **_summarize_accuracy(ckpts)}
        for model, ckpts in groups.items()
    ]

    # Save data to CSV; the column order is pinned explicitly.
    columns = [
        "model",
        "val_acc_mean",
        "val_acc_std",
        "test_acc_mean",
        "test_acc_std",
    ]
    df = pd.DataFrame(rows, columns=columns)

    out_path = "data/sos_rl_acc.csv"
    # Create the output directory up front so the computed results are not
    # lost to an OSError when `data/` does not exist yet.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    df.to_csv(out_path, index=False)


if __name__ == "__main__":
    # Run the aggregation only when executed as a script, not on import.
    main()
