import json
import os
from typing import Tuple

import numpy as np
import pandas as pd

# All checkpoint lists below share one outputs directory and the same
# "checkpoint-20000" leaf; run directories differ only in their tag and in the
# s0/s1/s2 suffix (presumably the training seed).
_OUTPUTS_ROOT = "/home/user/train-countdown/stream-of-search/outputs"
_RUN_INDICES = (0, 1, 2)


def _seed_checkpoints(run_tag: str) -> list:
    """Return the checkpoint-20000 path of every s0..s2 run named `run_tag`."""
    return [
        f"{_OUTPUTS_ROOT}/{run_tag}-s{index}-gpt2/checkpoint-20000"
        for index in _RUN_INDICES
    ]


# SoS + STaR
SOS_STAR_CKPTS = _seed_checkpoints("star3-final")

# GSoS (first)
GSOS_FIRST_CKPTS = _seed_checkpoints("star3-final-first")

# GSoS (rand)
GSOS_RAND_CKPTS = _seed_checkpoints("star3-final-rand")

# GSoS (last)
GSOS_LAST_CKPTS = _seed_checkpoints("star3-final-last")


def get_accuracy(ckpt: str, split: str = "test") -> Tuple[float, float]:
    """Compute generated and reference accuracy for one checkpoint directory.

    Every file in `ckpt` whose name starts with ``stats_final_<split>`` is read
    as JSON, and its ``gen_ratings`` and ``ref_ratings`` lists are pooled across
    files (in sorted file-name order). A rating counts as correct when it is
    strictly greater than zero.

    Args:
        ckpt: Directory containing the stats files.
        split: Split name used to build the file-name prefix.

    Returns:
        ``(gen_accuracy, ref_accuracy)``: the fraction of positive ratings in
        the pooled generated and reference ratings, respectively.

    Raises:
        FileNotFoundError: If `ckpt` does not exist, or contains no file whose
            name starts with ``stats_final_<split>``.
    """
    prefix = f"stats_final_{split}"
    stats_files = sorted(name for name in os.listdir(ckpt) if name.startswith(prefix))
    if not stats_files:
        # Without this check np.mean([]) would return NaN and the missing
        # results would silently end up in the aggregated table.
        raise FileNotFoundError(f"No '{prefix}*' stats files found in {ckpt!r}")

    gen_ratings = []
    ref_ratings = []
    for name in stats_files:
        path = os.path.join(ckpt, name)
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(path, "r", encoding="utf-8") as f:
            stats = json.load(f)
        gen_ratings.extend(stats["gen_ratings"])
        ref_ratings.extend(stats["ref_ratings"])

    # Accuracy is the share of strictly positive ratings.
    gen_accuracy = float(np.mean([r > 0 for r in gen_ratings]))
    ref_accuracy = float(np.mean([r > 0 for r in ref_ratings]))

    return gen_accuracy, ref_accuracy


def _seed_stats(ckpts):
    """Return (val_mean, val_std, test_mean, test_std) of generated accuracy over `ckpts`."""
    val_accs = []
    test_accs = []
    for ckpt in ckpts:
        # Only the generated accuracy is reported; the reference one is dropped.
        val_acc, _ = get_accuracy(ckpt, split="val")
        test_acc, _ = get_accuracy(ckpt, split="test")
        val_accs.append(val_acc)
        test_accs.append(test_acc)
    # np.std defaults to ddof=0, i.e. the population standard deviation.
    return np.mean(val_accs), np.std(val_accs), np.mean(test_accs), np.std(test_accs)


def main():
    """Aggregate val/test accuracy per model and save it to ``data/gsos_acc.csv``.

    For each model, the generated accuracy of every checkpoint in its list is
    computed on the "val" and "test" splits, then summarized by mean and
    standard deviation across checkpoints. The CSV has one row per model with
    the columns ``model``, ``val_acc_mean``, ``val_acc_std``, ``test_acc_mean``
    and ``test_acc_std``.
    """
    # (label written to the CSV, checkpoints of that model), in row order.
    runs = [
        ("SoS+STaR", SOS_STAR_CKPTS),
        ("GSoS (first)", GSOS_FIRST_CKPTS),
        ("GSoS (rand)", GSOS_RAND_CKPTS),
        ("GSoS (last)", GSOS_LAST_CKPTS),
    ]
    columns = (
        "model",
        "val_acc_mean",
        "val_acc_std",
        "test_acc_mean",
        "test_acc_std",
    )

    data = {column: [] for column in columns}
    for label, ckpts in runs:
        row = (label, *_seed_stats(ckpts))
        for column, value in zip(columns, row):
            data[column].append(value)

    # Save data to CSV
    out_path = "data/gsos_acc.csv"
    # to_csv does not create missing parent directories.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    df = pd.DataFrame(data)
    df.to_csv(out_path, index=False)


if __name__ == "__main__":
    # Run the aggregation only when executed as a script, not when imported.
    main()
