import numpy as np
import score_worker
from utils import NoPrint
import click
import os

# Training splits that are evaluated for each supported task.
_TRAIN_SPLITS = {
    "trec10": ["0", "1", "3", "4", "5"],
    "emotion": ["0", "1", "3", "4"],
    "agnews": ["0", "1", "2", "3"],
    "tacred": ["0"],
}

# Seed identifiers "1".."5"; each names a model dir "<model_name>_seed_<t>".
_TRIALS = [str(x) for x in range(1, 6)]

# (printed label, key in the score_worker result) for every reported metric.
_METRICS = (("AUAC", "auac"), ("AUROC", "auroc"), ("Acc", "id-acc"))


def _score_trial(model_name, task, split, trial):
    """Score one trained model and return ``{label: value}`` for ``_METRICS``.

    Returns None when the model directory exists but is empty. Any error from
    ``os.listdir``, ``score_worker.run`` or a missing metric key propagates so
    the caller can report the trial as failed.
    """
    model_dir = f"models/{task}/{split}/{model_name}_seed_{trial}/"
    if not os.listdir(model_dir):
        return None
    os.makedirs(f"sp_scores/{task}/{split}/", exist_ok=True)
    with NoPrint():
        result = score_worker.run(
            model_dir,
            f"sp_scores/{task}/{split}/{model_name}_seed_{trial}.pkl",
            "spred",
            f"data/{task}/{split}/train/",
            f"data/{task}/{split}/validation/",
            f"data/{task}/{split}/test-final/",
            use_cache=False,
        )
    # Only the first value of the returned mapping is read; presumably it holds
    # a single scoring method's metrics — confirm against score_worker.run.
    first = next(iter(result.values()))
    return {label: first[key] for label, key in _METRICS}


def _pooled_sem(split_stds):
    """Return sqrt(mean(std_i**2)) / sqrt(k) over k per-split std deviations."""
    return np.sqrt(np.mean(np.square(split_stds))) / np.sqrt(len(split_stds))


@click.command()
@click.argument("model_name")
@click.argument("task", type=click.Choice(sorted(_TRAIN_SPLITS)))
def main(model_name, task):
    """Score every seed of MODEL_NAME on TASK's splits and print aggregates.

    Each trained seed under models/TASK/<split>/ is scored with score_worker
    ("spred" method), collecting AUAC, AUROC and accuracy ("id-acc"). The
    mean over all successful trials is printed; the "+-" term is the
    root-mean-square of the per-split standard deviations divided by the
    square root of the number of contributing splits.
    """
    pooled = {label: [] for label, _ in _METRICS}      # every successful trial
    split_stds = {label: [] for label, _ in _METRICS}  # one std per split

    for split in _TRAIN_SPLITS[task]:
        per_split = {label: [] for label, _ in _METRICS}
        for trial in _TRIALS:
            try:
                metrics = _score_trial(model_name, task, split, trial)
            except Exception as e:
                # Best-effort sweep: one broken trial must not abort the rest.
                print(e)
                print("Split", split, "Trial", trial, "errored.")
                continue
            if metrics is None:
                continue
            for label, value in metrics.items():
                per_split[label].append(value)
                pooled[label].append(value)

        # All metrics are appended together, so checking one list suffices.
        # An empty split would make np.std return nan and poison the pooled
        # "+-" value, so it is left out of the error estimate instead.
        if not per_split["AUAC"]:
            print("Split", split, "had no successful trials; skipped.")
            continue
        for label, values in per_split.items():
            split_stds[label].append(np.std(values))

    if not pooled["AUAC"]:
        raise click.ClickException(
            f"No successful trials for {model_name} on {task}."
        )

    print(f"*** All stats: {model_name} {task} ***")
    print(pooled["AUAC"])
    print(pooled["AUROC"])
    for label, _ in _METRICS:
        print(
            f"{label}:",
            np.round(np.mean(pooled[label]), 4),
            "+-",
            np.round(_pooled_sem(split_stds[label]), 4),
        )

if __name__ == "__main__":
    # Run the click command when this file is executed as a script.
    main()
