import json
import os
from typing import List, Tuple

import numpy as np
import pandas as pd

# Checkpoint directories scored by main(). get_accuracy() reads every file in
# each directory whose name starts with "stats_final_{split}" (split is "val"
# or "test"), so those files must already exist in every directory below.

# SoS: single checkpoint. main() also takes the "Symbolic" row from the
# ref_ratings stored in this checkpoint's stats files.
SOS_CKPT = "/home/user/train-countdown/stream-of-search/outputs/sft-gpt2/checkpoint-50000"

# SoS + STaR: one directory per run; the s0/s1/s2 suffixes presumably denote
# seeds. main() reports the mean/std of their accuracies.
SOS_STAR_CKPTS = [
    "/home/user/train-countdown/stream-of-search/outputs/star3-final-s0-gpt2/checkpoint-20000",
    "/home/user/train-countdown/stream-of-search/outputs/star3-final-s1-gpt2/checkpoint-20000",
    "/home/user/train-countdown/stream-of-search/outputs/star3-final-s2-gpt2/checkpoint-20000",
]

# SoS + PPO: three runs (directory names look like YYMMDD_HHMMSS run
# timestamps -- presumably independent seeds; confirm). The model_N suffix
# differs between runs, so the selected step is not the same for each.
SOS_PPO_CKPTS = [
    "/home/user/train-countdown/tril/outputs/hppo/240820_073947/model_100",
    "/home/user/train-countdown/tril/outputs/hppo/240820_193933/model_150",
    "/home/user/train-countdown/tril/outputs/hppo/240821_150618/model_100",
]

# GSoS: one directory per run (s0/s1/s2 suffixes, presumably seeds).
GSOS_CKPTS = [
    "/home/user/train-countdown/stream-of-search/outputs/star3-final-rand-s0-gpt2/checkpoint-20000",
    "/home/user/train-countdown/stream-of-search/outputs/star3-final-rand-s1-gpt2/checkpoint-20000",
    "/home/user/train-countdown/stream-of-search/outputs/star3-final-rand-s2-gpt2/checkpoint-20000",
]

# GSoS + PPO: three runs identified by (presumed) timestamp directories.
GSOS_PPO_CKPTS = [
    "/home/user/train-countdown/tril/outputs/hppo/240815_233415/model_200",
    "/home/user/train-countdown/tril/outputs/hppo/240817_092910/model_200",
    "/home/user/train-countdown/tril/outputs/hppo/240817_092935/model_200",
]


def get_accuracy(ckpt: str, split: str = "test") -> Tuple[float, float]:
    """Compute the generated and reference solve rates for one checkpoint/split.

    Reads every file in ``ckpt`` whose name starts with
    ``stats_final_{split}``, in sorted name order. Each file is parsed as
    JSON, and its ``gen_ratings`` and ``ref_ratings`` lists are concatenated.
    A rating counts as a success when it is strictly positive.

    Args:
        ckpt: Directory containing the ``stats_final_{split}*`` files.
        split: Split name used to build the file-name prefix.

    Returns:
        ``(gen_accuracy, ref_accuracy)``: the fraction of strictly positive
        entries in the concatenated ``gen_ratings`` and ``ref_ratings``.

    Raises:
        FileNotFoundError: If ``ckpt`` contains no file matching the prefix.
    """
    prefix = f"stats_final_{split}"
    files = sorted(name for name in os.listdir(ckpt) if name.startswith(prefix))
    if not files:
        # Without this check np.mean([]) silently returns nan (only a
        # RuntimeWarning is emitted), and the nan ends up in the results table.
        raise FileNotFoundError(f"no '{prefix}*' files found in {ckpt!r}")

    gen_ratings = []
    ref_ratings = []
    for name in files:
        with open(os.path.join(ckpt, name), "r", encoding="utf-8") as f:
            data = json.load(f)
        gen_ratings.extend(data["gen_ratings"])
        ref_ratings.extend(data["ref_ratings"])

    # Success == strictly positive rating; the mean of booleans is the rate.
    gen_accuracy = np.mean([r > 0 for r in gen_ratings])
    ref_accuracy = np.mean([r > 0 for r in ref_ratings])

    return gen_accuracy, ref_accuracy


def _seed_summary(ckpts: List[str]) -> dict:
    """Return the mean/std of the generated accuracy across several checkpoints.

    Each checkpoint is scored with ``get_accuracy`` on the "val" and "test"
    splits, and only the generated-solution accuracy is kept. The std is
    numpy's default population std (ddof=0).

    Returns:
        A dict with keys ``val_acc_mean``, ``val_acc_std``, ``test_acc_mean``
        and ``test_acc_std`` (the numeric column names of the output CSV).
    """
    val_accs = []
    test_accs = []
    for ckpt in ckpts:
        val_acc, _ = get_accuracy(ckpt, split="val")
        test_acc, _ = get_accuracy(ckpt, split="test")
        val_accs.append(val_acc)
        test_accs.append(test_acc)
    return {
        "val_acc_mean": np.mean(val_accs),
        "val_acc_std": np.std(val_accs),
        "test_acc_mean": np.mean(test_accs),
        "test_acc_std": np.std(test_accs),
    }


def main(out_path: str = "data/acc.csv") -> None:
    """Build the accuracy table for all models and write it as CSV.

    Rows, in order:

    - "Symbolic": the ref_ratings accuracy stored in the SoS checkpoint.
    - "SoS": the single SoS checkpoint.
    - "SoS+STaR", "SoS+PPO", "GSoS", "GSoS+PPO": mean and std over each
      method's list of checkpoints.

    The single-checkpoint rows get a std of 0.

    Args:
        out_path: Destination CSV path. Its parent directory is created if it
            is missing. Defaults to ``data/acc.csv``.
    """
    sos_val_acc, ref_val_acc = get_accuracy(SOS_CKPT, split="val")
    sos_test_acc, ref_test_acc = get_accuracy(SOS_CKPT, split="test")

    rows = [
        {
            "model": "Symbolic",
            "val_acc_mean": ref_val_acc,
            "val_acc_std": 0,
            "test_acc_mean": ref_test_acc,
            "test_acc_std": 0,
        },
        {
            "model": "SoS",
            "val_acc_mean": sos_val_acc,
            "val_acc_std": 0,
            "test_acc_mean": sos_test_acc,
            "test_acc_std": 0,
        },
    ]

    multi_ckpt_models = [
        ("SoS+STaR", SOS_STAR_CKPTS),
        ("SoS+PPO", SOS_PPO_CKPTS),
        ("GSoS", GSOS_CKPTS),
        ("GSoS+PPO", GSOS_PPO_CKPTS),
    ]
    for model, ckpts in multi_ckpt_models:
        rows.append({"model": model, **_seed_summary(ckpts)})

    # Explicit column order keeps the CSV layout identical to the original.
    columns = ["model", "val_acc_mean", "val_acc_std", "test_acc_mean", "test_acc_std"]
    df = pd.DataFrame(rows, columns=columns)

    out_dir = os.path.dirname(out_path)
    if out_dir:
        # DataFrame.to_csv does not create missing directories; without this
        # the script failed with OSError when ./data did not exist.
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(out_path, index=False)


if __name__ == "__main__":
    # Only when executed as a script (not on import): build and save the
    # accuracy table.
    main()
