import torch
import os
import matplotlib.pyplot as plt
import argparse
import sys
import pandas as pd
from utils import gather_results
from models import test_domains
from typing import List


def parse_arguments(args_to_parse=None):
    """Parse the command-line arguments of the plotting script.

    Args:
        args_to_parse: list of argument strings, e.g. ``sys.argv[1:]``.
            ``None`` makes argparse fall back to ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``ckpt_folder`` (str), ``job_ids``
        (list of str, one entry per job id given on the command line) and
        ``ckpt_name`` (str).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("ckpt_folder", type=str, help="folder where runs are stored")
    # nargs="+" keeps each job id as its own string; ``type=list`` would
    # instead split a single argument into characters ("12" -> ["1", "2"]).
    parser.add_argument("job_ids", type=str, nargs="+", help="list of jobs to plot")
    parser.add_argument(
        "--ckpt_name",
        type=str,
        default="best_val_loss.ckpt",
        help="name of model checkpoint",
    )
    # Parse the list we were handed instead of implicitly re-reading sys.argv.
    return parser.parse_args(args_to_parse)


def plot_all_test_domains(ckpt_folder: str, job_ids: List[str], ckpt_name: str):
    """Scatter validation accuracy against per-domain test accuracy.

    For every job, data is loaded from ``<ckpt_folder>/<job_id>`` via
    ``test_domains``:
    - the validation accuracy, using ``load_val_accuracy(job_folder, ckpt_name)``;
    - the ``"domain_accuracies"`` values, using ``load_values``.

    One figure is then drawn and shown (``plt.show()``) per test domain
    ("id", "ood_b", "ood_r", "ood_br"), with one "+" marker per job.

    Args:
        ckpt_folder: folder containing one sub-folder per job.
        job_ids: job identifiers; converted to ``str`` to build the paths.
        ckpt_name: checkpoint file name forwarded to ``load_val_accuracy``.
    """
    domains = ["id", "ood_b", "ood_r", "ood_br"]
    jobs = [str(j) for j in job_ids]

    val_acc = {}
    test_acc = {}
    for job in jobs:
        job_folder = os.path.join(ckpt_folder, job)
        val_acc[job] = test_domains.load_val_accuracy(job_folder, ckpt_name)
        # NOTE(review): indexed by domain name below, so presumably a mapping
        # keyed by "id"/"ood_b"/"ood_r"/"ood_br" — confirm against test_domains.
        test_acc[job] = test_domains.load_values("domain_accuracies", job_folder)

    # Private generator with a fixed seed: each job keeps the same colour from
    # run to run, and torch's global RNG state is left untouched.
    gen = torch.Generator().manual_seed(0)
    colors = torch.rand(size=(3, len(jobs)), generator=gen).numpy()

    for d in domains:
        _, ax = plt.subplots(1, 1)
        ax.set_xlabel("Validation accuracy", fontsize=14)
        ax.set_ylabel("Test {0} accuracy".format(d), fontsize=14)
        ax.set_title("Test {0} accuracy".format(d), fontsize=14)
        ax.grid(which="major", axis="both", linestyle="--")

        for i, job in enumerate(jobs):
            ax.scatter(val_acc[job], test_acc[job][d], color=colors[:, i], marker="+")
        # Limits only need to be set once per figure.
        ax.set_xlim([0.8, 1])
        ax.set_ylim([0, 1])
        plt.show()


def plot_per_hparams(ckpt_folder, job_ids, ckpt_name, sorting_hparam):
    """Plot mean test vs. mean validation accuracy, grouped by one hyper-parameter.

    Steps:
    - Load results with ``gather_results.parse_results``.
    - Group them by ``sorting_hparam`` and aggregate each column to mean and std.
    - Draw a 2x2 grid with one panel per test domain, filled column by column.

    In each panel, every group is one point at (mean ``val_acc``, mean
    ``test_acc_<domain>``), with the std of that test accuracy as a vertical
    error bar. A single figure-level legend lists the groups.

    Args:
        ckpt_folder: forwarded to ``gather_results.parse_results``.
        job_ids: forwarded to ``gather_results.parse_results``.
        ckpt_name: forwarded to ``gather_results.parse_results``.
        sorting_hparam: column of the results frame to group by.

    Returns:
        The created matplotlib Figure, so the caller can show or save it.
    """
    domains = ["id", "ood_b", "ood_r", "ood_br"]

    df_results = gather_results.parse_results(ckpt_folder, job_ids, ckpt_name)
    # agg with a list of functions yields (column, statistic) MultiIndex
    # columns, e.g. ("val_acc", "mean").
    df_mean_std = df_results.groupby([sorting_hparam]).agg(["mean", "std"])
    n_rows = df_mean_std.shape[0]

    # Private generator seeded with 0: colours are stable across calls, and the
    # global torch RNG is not reseeded as a side effect.
    gen = torch.Generator().manual_seed(0)
    colors = torch.rand(size=(3, n_rows), generator=gen).numpy()

    val_mean = df_mean_std[("val_acc", "mean")]

    fig, axes = plt.subplots(2, 2, figsize=(10, 10))
    for k, d in enumerate(domains):
        # Column-major layout: k=0 -> (0, 0), 1 -> (1, 0), 2 -> (0, 1), 3 -> (1, 1).
        col, row = divmod(k, 2)
        ax = axes[row, col]
        ax.set_xlabel("Validation accuracy", fontsize=14)
        ax.set_ylabel("Test {0} accuracy".format(d), fontsize=14)
        ax.set_title("Test {0} accuracy".format(d), fontsize=14)
        ax.grid(which="major", axis="both", linestyle="--")

        test_col = "test_acc_{0}".format(d)
        test_mean = df_mean_std[(test_col, "mean")]
        test_std = df_mean_std[(test_col, "std")]
        for r, group_label in enumerate(df_mean_std.index):
            ax.errorbar(
                val_mean.iloc[r],
                test_mean.iloc[r],
                test_std.iloc[r],
                color=colors[:, r],
                marker=".",
                linestyle="None",
                label=group_label,
            )
        ax.set_xlim([0.8, 1])
        ax.set_ylim([0, 1])

    # Every panel carries the same group labels; take them from the last one.
    handles, labels = ax.get_legend_handles_labels()
    fig.legend(handles, labels, loc="center left")
    return fig



if __name__ == "__main__":
    # Script entry point: read the CLI options and draw one figure per domain.
    cli_args = parse_arguments(sys.argv[1:])
    plot_all_test_domains(
        ckpt_folder=cli_args.ckpt_folder,
        job_ids=cli_args.job_ids,
        ckpt_name=cli_args.ckpt_name,
    )
