"""Utilities for plotting accuracy curves from a saved Flower training history."""


import argparse
import copy
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from flwr.common import Metrics, Scalar
from flwr.common.typing import NDArrays, Scalar
from mblearn import AttackModels, ShadowModels
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, precision_score, recall_score

# Load the saved training history. allow_pickle=True is required because the
# file holds a pickled Python object (presumably a Flower ``History``) rather
# than plain numeric data -- only load .npy files from trusted sources.
history = np.load(
    file='results/CIFAR_FATS/history_CIFAR10_20241114-160643.npy',
    allow_pickle=True,
)

def plot_metric_from_history(
    hist,
    save_plot_path,
    suffix: Optional[str] = "",
    metric_type: str = "distributed",
) -> None:
    """Plot the per-round accuracy stored in a Flower history and save it as PNG.

    Parameters
    ----------
    hist : History or numpy.ndarray
        Object exposing ``metrics_centralized`` and ``metrics_distributed``,
        or the 0-d object array returned by ``np.load(..., allow_pickle=True)``
        on a saved history (it is unwrapped with ``.item()``).
    save_plot_path : str or Path
        Folder to save the plot to.
    suffix : Optional[str]
        String placed at the start of the filename; ``None`` is treated as "".
    metric_type : str
        ``"centralized"`` plots ``hist.metrics_centralized`` and
        ``"distributed"`` plots ``hist.metrics_distributed``. Any other value
        keeps the legacy selection (centralized if non-empty, otherwise
        distributed). The value also appears in the title and the filename
        ``<suffix>_<metric_type>_metrics.png``.

    Raises
    ------
    KeyError
        If the selected metrics dict has no ``"accuracy"`` entry.
    ValueError
        If the ``"accuracy"`` series is empty.
    """
    # np.load on a pickled object yields a 0-d object array; unwrap it, but
    # also accept an already-unwrapped history object.
    if isinstance(hist, np.ndarray):
        hist = hist.item()

    # Honour metric_type so the plotted data always matches the label.
    if metric_type == "centralized":
        metric_dict = hist.metrics_centralized
    elif metric_type == "distributed":
        metric_dict = hist.metrics_distributed
    else:
        metric_dict = (
            hist.metrics_centralized
            if hist.metrics_centralized
            else hist.metrics_distributed
        )

    if "accuracy" not in metric_dict:
        raise KeyError(
            f"No 'accuracy' entry in the {metric_type} metrics of this history"
        )
    accuracy = metric_dict["accuracy"]
    if not accuracy:
        raise ValueError("The 'accuracy' series is empty; nothing to plot")
    # Series is a sequence of (round, value) pairs.
    rounds, values = zip(*accuracy)

    fig, axis = plt.subplots()
    axis.plot(np.asarray(rounds), np.asarray(values), label="FedAvg")
    axis.set_title(f"{metric_type.capitalize()} Validation Accuracy")
    axis.set_xlabel("Rounds")
    axis.set_ylabel("Accuracy")
    axis.legend(loc="lower right")

    # Make the plot area square by scaling the aspect with the axis ranges.
    xleft, xright = axis.get_xlim()
    ybottom, ytop = axis.get_ylim()
    axis.set_aspect(abs((xright - xleft) / (ybottom - ytop)) * 1.0)

    filename = f"{suffix or ''}_{metric_type}_metrics.png"
    fig.savefig(Path(save_plot_path) / filename)
    # Close this specific figure (not whatever pyplot considers "current").
    plt.close(fig)


import yaml

# Load the experiment configuration; it supplies the output folder
# (``results_dir_path``) and the dataset name used in the plot filename.
# YAML is read as UTF-8 explicitly so parsing does not depend on the locale.
with open(
    'FATS_supplement_debugged/conf/table2_cifar10_FATS_baseline.yaml',
    'r',
    encoding='utf-8',
) as file:
    cfg = yaml.safe_load(file)

# Filename stem "<dataset>_<YYYYmmdd-HHMMSS>", separated by an underscore like
# the history file name (history_CIFAR10_20241114-160643.npy).
plot_metric_from_history(
    history,
    cfg['results_dir_path'],
    cfg['dataset'] + "_" + time.strftime("%Y%m%d-%H%M%S"),
    metric_type='distributed',
)