import glob
from pathlib import Path

import numpy as np
import pandas as pd
from ser.utils.metric import calculate_score_classification


def run(
    main_path,
    *,
    predictions_filename="test_predictions_log_softmax.csv",
    output_filename="cross_validation_metrics_logsoftmax.csv",
):
    """Aggregate per-fold classification metrics into a cross-validation summary.

    Every visible sub-directory of ``main_path`` is treated as one fold. Each fold
    must contain ``predictions_filename``, a CSV with at least a ``label`` column
    (ground truth) and a ``top1`` column (predicted class). For each fold the
    metrics returned by ``calculate_score_classification`` (labels 0-3) are
    collected. The mean and population standard deviation (``np.std``, ddof=0)
    of every metric across folds are written to ``main_path / output_filename``
    as a single-row CSV.

    Args:
        main_path: Directory containing one sub-directory per fold.
        predictions_filename: Name of the per-fold predictions CSV.
        output_filename: Name of the summary CSV written into ``main_path``.

    Returns:
        The single-row summary ``pandas.DataFrame`` that was written to disk.
        Its columns are ``mean_<metric>`` and ``std_<metric>``, in metric order.

    Raises:
        FileNotFoundError: If ``main_path`` does not exist, contains no fold
            directories, or a fold directory lacks ``predictions_filename``.
        RuntimeError: If the mean/std of a collected metric cannot be computed.
    """
    base = Path(main_path)

    # One fold per visible sub-directory. Keeping only directories skips the
    # summary CSV written by earlier runs and any other stray file. Skipping
    # dot-entries (e.g. .ipynb_checkpoints) mirrors glob's "*" pattern. Sorting
    # only makes the processing order deterministic; mean/std do not depend on it.
    fold_dirs = sorted(p for p in base.iterdir() if p.is_dir() and not p.name.startswith("."))
    if not fold_dirs:
        # np.mean([]) would silently yield NaN (with a RuntimeWarning) and an
        # all-NaN summary would be written, so fail loudly instead.
        raise FileNotFoundError(f"No fold directories found under {base}")

    labels = [0, 1, 2, 3]
    # Column suffixes for the per-class F1 values, in label order 0..3.
    class_names = ("angry", "neutral", "happy", "sad")

    # Insertion order determines the column order of the summary CSV.
    all_metrics = {
        "balanced_accuracy": [],
        "macro_recall": [],
        "macro_f1": [],
        "macro_precision": [],
        "accuracy": [],
        "f1_weighted": [],
        "precision_weighted": [],
        # NOTE(review): these keys say "f1w" but receive the per-class list
        # `f1pcl`; presumably per-class F1 -- confirm against
        # calculate_score_classification.
        "f1w_angry": [],
        "f1w_neutral": [],
        "f1w_happy": [],
        "f1w_sad": [],
    }

    for fold_dir in fold_dirs:
        # fold_dir already includes main_path, so join the file name directly.
        df = pd.read_csv(fold_dir / predictions_filename)
        # Only top-1 predictions are scored.
        y_true = list(df["label"])
        y_pred = list(df["top1"])
        (
            ba_ac,
            ma_recall,
            ma_f1,
            f1pcl,
            ma_pr,
            _cm,  # confusion matrix: not aggregated
            acc,
            wa_f1,
            wa_pr,
        ) = calculate_score_classification(y_pred, y_true, labels=labels)

        all_metrics["balanced_accuracy"].append(ba_ac)
        all_metrics["accuracy"].append(acc)
        all_metrics["macro_recall"].append(ma_recall)
        all_metrics["macro_f1"].append(ma_f1)
        all_metrics["precision_weighted"].append(wa_pr)
        all_metrics["macro_precision"].append(ma_pr)
        all_metrics["f1_weighted"].append(wa_f1)
        for idx, class_name in enumerate(class_names):
            all_metrics[f"f1w_{class_name}"].append(f1pcl[idx])

    all_metrics_stats = {}
    for key, values in all_metrics.items():
        try:
            all_metrics_stats[f"mean_{key}"] = np.mean(values)
            # np.std defaults to ddof=0, i.e. the population standard deviation.
            all_metrics_stats[f"std_{key}"] = np.std(values)
        except (TypeError, ValueError) as err:
            # Chain the cause so both the failing key and the real error are
            # visible; unlike `assert`, this is not stripped under `python -O`.
            raise RuntimeError(f"Failed to compute mean for key: {key}") from err

    metrics_df = pd.DataFrame(all_metrics_stats, index=[0])
    metrics_df.to_csv(base / output_filename, index=False)
    return metrics_df


if __name__ == "__main__":
    # Script entry point: the results directory comes from the command line
    # instead of a value that must be edited in the source.
    import argparse

    parser = argparse.ArgumentParser(
        description="Summarize per-fold test predictions into cross-validation metrics."
    )
    # Optional positional whose default matches the previously hard-coded
    # value, so a bare invocation behaves exactly as before.
    parser.add_argument(
        "main_path",
        nargs="?",
        default="",
        help="Directory containing one sub-directory per cross-validation fold.",
    )
    main_path = parser.parse_args().main_path
    run(main_path)
