### Preamble ##########################################################################################################

"""
Stats runner for computing image generation quality and classification performance statistics for experiments.
"""

#######################################################################################################################

### Imports ###########################################################################################################

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, RandomSampler
from torchvision.transforms import Resize
import pandas as pd
from concurrent.futures import ThreadPoolExecutor
import os
import timm
import pathlib
import json
import pickle
import shutil
import logging
from tqdm import tqdm
import argparse

from gcontrol.utils.mem_utils import flush

from torchmetrics.image.inception import InceptionScore
from torchmetrics.image.fid import FrechetInceptionDistance

from typing import Union, Iterable, Optional, Tuple, List, Any, Callable, Dict

from torchvision.models import (
    resnet50,
    inception_v3,
    vit_h_14,
    resnet152,
    swin_v2_b,
    vit_b_16,
    maxvit_t,
    ResNet50_Weights,
    Inception_V3_Weights,
    ViT_H_14_Weights,
    ResNet152_Weights,
    Swin_V2_B_Weights,
    ViT_B_16_Weights,
    MaxVit_T_Weights,
)
from misc.classifier_pipeline import (
    resnet50_config,
    inceptionv3_config,
    vith14_config,
    resnet152_config,
    swinb_config,
    deit_config,
    maxvit_config,
)

from misc.dataset_readers import NatAdvDiffImageDataset, ImageNetDataset, ImageDataset
from misc.classifier_pipeline import ClassifierPipeline
from misc.path_configs import (
    CACHE_DIR,
    EXPERIMENT_DIR,
    IMAGENET_PATH,
    IMAGENET_A_PATH,
    IMAGENET_CLASSES,
)

from gcontrol.utils import get_timm_config

#######################################################################################################################

# Directory containing this file. Used as a fallback root when a config path given on the command line does not exist
# relative to the current working directory.
FILE_PATH = pathlib.Path(__file__).parent
# Directory of ImageNet validation images; loaded as one of the "real" reference datasets for FID.
IMAGENET_VAL_PATH = IMAGENET_PATH / "ILSVRC" / "Data" / "CLS-LOC" / "val"

BATCH_SIZE = 100  # Batch size to be used for data loaders
NUM_WORKERS = 15  # Number of worker processes to be used in each data loader

# Names of the classifiers evaluated by `stats_runner`. The same names are used as keys in the stats report and as
# column prefixes in the collated csv files, so renaming an entry changes the output schema. Every name listed here
# must also be handled by the classifier construction logic in `stats_runner`, which raises `ValueError` otherwise.
CLASSIFIER_NAMES = [
    "resnet50",
    "inceptionv3",
    "vit",
    "adv_resnet",
    "adv_inception",
    "resnet152",
    "swinb",
    "DeIT",
    "max_vit",
]


def create_dataset_loader(dataset: Dataset, num_replacement_sample: Optional[int] = None) -> DataLoader:
    """
    :param dataset: torch.utils.data.Dataset
        Pytorch dataset to wrap in a DataLoader.
    :param num_replacement_sample: None, int
        When an integer is given, every pass over the returned loader draws `num_replacement_sample` items from
        `dataset` at random, with replacement. When `None`, the loader walks the whole dataset in order.

    Builds a DataLoader over `dataset` using the module-wide `BATCH_SIZE` and `NUM_WORKERS` settings. The optional
    random sampling exists so that loaders built over datasets of different sizes can be made to yield the same
    number of samples.
    """

    # A sampler is only needed when a fixed-size random draw was requested
    sampler = (
        None
        if num_replacement_sample is None
        else RandomSampler(dataset, replacement=True, num_samples=num_replacement_sample)
    )

    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        sampler=sampler,
        num_workers=NUM_WORKERS,
    )


def compute_FID(gen_dataloader: DataLoader, real_dataloader: DataLoader) -> dict:
    """
    :param gen_dataloader: DataLoader
        Pytorch dataloader of generated (fake) images. Each batch must be a tensor of images (no class labels).
    :param real_dataloader: DataLoader
        Pytorch dataloader of real images. Each batch must be a tensor of images (no class labels).

    Feeds every generated batch and then every real batch through a `FrechetInceptionDistance` metric on the GPU and
    returns the resulting Fréchet inception distance (FID) as `{"fid": <float>}`. A CUDA device is required.
    """

    metric = FrechetInceptionDistance()
    metric.cuda()

    # One progress bar spans both passes; generated images are processed first, then real images
    total_batches = len(gen_dataloader) + len(real_dataloader)
    with tqdm(total=total_batches) as progress:
        for loader, is_real in ((gen_dataloader, False), (real_dataloader, True)):
            for images in loader:
                metric.update(images.to("cuda"), real=is_real)
                progress.update(1)

    return {"fid": metric.compute().item()}


def compute_IS(dataloader: DataLoader) -> dict:
    """
    :param dataloader: DataLoader
        Pytorch dataloader of images. Each batch must be a tensor of images (no class labels).

    Accumulates every batch into an `InceptionScore` metric on the GPU and returns `{"is": ..., "se": ...}`, where
    the two values are the first and second elements returned by the metric's `compute()`. A CUDA device is required.
    """

    scorer = InceptionScore()
    scorer.cuda()

    for images in tqdm(dataloader):
        scorer.update(images.to("cuda"))

    score = scorer.compute()
    return {"is": score[0].item(), "se": score[1].item()}


def FID_from_dataset(gen_dataset: Dataset, real_dataset: Dataset, bootstrap_undersample: Optional[int] = None) -> dict:
    """
    :param gen_dataset: Dataset
        Pytorch dataset of generated (fake) images. The dataset must return images only (no class labels).
    :param real_dataset: Dataset
        Pytorch dataset of real images. The dataset must return images only (no class labels).
    :param bootstrap_undersample: None or int
        Whether to bootstrap the FID estimate in cases where sample sizes differ. FID is known to be sensitive to
        discrepancies in sample size. Bootstrapping the FID estimate will undersample the larger dataset
        `bootstrap_undersample` times and return an estimate of FID and standard error. Ignored if the size of
        `gen_dataset` and `real_dataset` do not differ. Must be at least 1 when it is used.

    Wrapper function for `compute_FID`. Creates dataloaders and returns the Fréchet inception distance (FID) between
    the generated and real datasets, with optional bootstrapping for unequal sample sizes.

    Returns `{"fid": <float>}`, plus an `"se"` entry (standard deviation of the bootstrap estimates) when more than
    one bootstrap round was run.

    Raises `ValueError` if bootstrapping is required but `bootstrap_undersample` is smaller than 1.
    """

    gen_size = len(gen_dataset)
    real_size = len(real_dataset)

    # No bootstrapping requested, or nothing to balance: a single pass over both full datasets
    if (bootstrap_undersample is None) or (gen_size == real_size):
        gen_loader = create_dataset_loader(dataset=gen_dataset, num_replacement_sample=None)
        real_loader = create_dataset_loader(dataset=real_dataset, num_replacement_sample=None)
        return compute_FID(gen_loader, real_loader)

    # Zero (or negative) rounds would leave no estimate to report, so reject it before building any loaders
    if bootstrap_undersample < 1:
        raise ValueError(f"`bootstrap_undersample` must be a positive integer, got {bootstrap_undersample}")

    # Only the larger dataset is randomly sampled (with replacement) down to the size of the smaller one
    smaller_size = min(gen_size, real_size)
    gen_loader = create_dataset_loader(
        dataset=gen_dataset, num_replacement_sample=smaller_size if gen_size > real_size else None
    )
    real_loader = create_dataset_loader(
        dataset=real_dataset, num_replacement_sample=smaller_size if real_size > gen_size else None
    )

    # Each pass over a loader backed by a random sampler draws a fresh sample, so the loaders can be reused
    fid_ests = [compute_FID(gen_loader, real_loader)["fid"] for _ in range(bootstrap_undersample)]

    if bootstrap_undersample > 1:
        return {"fid": np.mean(fid_ests).item(), "se": np.std(fid_ests).item()}
    return {"fid": fid_ests[0]}


def _subsample_indices(dataset_size: int, sub_sample: int, sub_sample_type: str) -> List[int]:
    """Returns exactly `sub_sample` distinct indices into a dataset of `dataset_size` items."""

    if sub_sample > dataset_size:
        raise ValueError("`sub_sample` can not be larger than the size of the experiment dataset")
    if sub_sample_type == "random":
        # Sampling WITHOUT replacement so that no image is counted twice
        idxs = np.random.choice(dataset_size, size=sub_sample, replace=False)
    elif sub_sample_type == "step":
        # Evenly spaced indices; equals a fixed stride of dataset_size // sub_sample when that division is exact
        idxs = (np.arange(sub_sample) * dataset_size) // sub_sample
    else:
        raise ValueError(f"Unknown subsampling type, got {sub_sample_type} expected one of `random`, `step`")

    # Plain python ints are the safest index type to hand to arbitrary dataset implementations
    return idxs.tolist()


def _build_classifier(class_name: str) -> Tuple[Any, dict]:
    """Loads the pretrained classifier called `class_name` onto the GPU and returns it with a private config copy."""

    torchvision_classifiers = {
        "resnet50": (resnet50, ResNet50_Weights.IMAGENET1K_V2, resnet50_config),
        "inceptionv3": (inception_v3, Inception_V3_Weights.IMAGENET1K_V1, inceptionv3_config),
        "vit": (vit_h_14, ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1, vith14_config),
        "resnet152": (resnet152, ResNet152_Weights.IMAGENET1K_V2, resnet152_config),
        "swinb": (swin_v2_b, Swin_V2_B_Weights.IMAGENET1K_V1, swinb_config),
        "DeIT": (vit_b_16, ViT_B_16_Weights.IMAGENET1K_V1, deit_config),
        "max_vit": (maxvit_t, MaxVit_T_Weights.IMAGENET1K_V1, maxvit_config),
    }
    timm_classifiers = {
        "adv_resnet": "inception_resnet_v2.tf_ens_adv_in1k",
        "adv_inception": "adv_inception_v3.tf_adv_in1k",
    }

    if class_name in torchvision_classifiers:
        constructor, weights, base_config = torchvision_classifiers[class_name]
        # `weights` is passed by keyword; passing it positionally is deprecated in torchvision
        classifier_model = constructor(weights=weights).eval().to(dtype=torch.float32, device="cuda")
        classifier_config = base_config
    elif class_name in timm_classifiers:
        classifier_model = (
            timm.create_model(model_name=timm_classifiers[class_name], pretrained=True)
            .eval()
            .to(dtype=torch.float32, device="cuda")
        )
        classifier_config = get_timm_config(classifier_model)
    else:
        raise ValueError(
            f"`classifier` must be one of {CLASSIFIER_NAMES} if"
            f" `guidance_type` != 'adversarial', got {class_name}"
        )

    # Copy so that callers can edit the config without touching the shared module-level dictionaries
    return classifier_model, dict(classifier_config)


def stats_runner(
    config_dict: dict,
    force_bootstrap: Optional[int] = None,
    sub_sample: Optional[int] = None,
    sub_sample_type: str = "random",
    do_crop: bool = True,
) -> dict:
    """
    :param config_dict: dict
        The experiment configuration dictionary. Must contain `experiment_name`; may contain an `Experiment Stats`
        section with `stats_seed`, `imval_bootstrap_upsample`, `ima_bootstrap_upsample` and `sub_sample` entries.
    :param force_bootstrap: int
        The number of bootstraps to run during FID calculation. If `None`, then defaults to what is specified in the
        experiment config.
    :param sub_sample: int
        Whether to run the image quality evaluation on a subset of the full dataset. A sample of size `sub_sample`
        will be drawn without replacement from the experiment dataset. This is useful when comparing image quality
        between experiments of different sizes. If `None`, then defaults to what is specified in the experiment
        config.
    :param sub_sample_type: str
        The type of sub-sampling to perform. `random` generates a random selection from the dataset. `step` generates
        samples by evenly stepping through the dataset. Use `step` when comparing NatADiff samples to structured
        ImageNet datasets containing samples from a large selection of classes.
    :param do_crop: bool
        Whether to run the classifiers with the resize and crop preprocessing transforms specified on their associated
        torchvision model card. Defaults to True, if False then the image will be resized to the dimensions required
        by the classifier and no crop will be applied.

    Takes the experiment configuration dictionary and computes the image generation quality and classification
    statistics of the associated experiment.

    Returns a dictionary with the keys `experiment_name`, `IS`, `ImageNetVal_FID`, `ImageNetA_FID`,
    `ImQual_Sample_Size`, `Classifier Accuracy`, `Naive Attack Success`, `Attack Success` (only when the experiment
    dataset has an adversarial target) and `Preprocessing_Crop`.

    Raises `ValueError` if `sub_sample` is smaller than 1 or larger than the experiment dataset, or if
    `sub_sample_type` is not one of `random`, `step` while sub-sampling is active.
    """

    # Getting stats settings (if present); explicit arguments take priority over the experiment config
    stats_settings = config_dict.get("Experiment Stats", {})
    if "stats_seed" in stats_settings:
        # NOTE(review): only NumPy's generator is seeded here. The bootstrap samplers used for FID are torch based,
        # so whether bootstrapped FID is reproducible likely depends on torch being seeded elsewhere - confirm.
        np.random.seed(stats_settings["stats_seed"])

    if force_bootstrap is not None:
        imval_bootstrap_upsample = force_bootstrap
        ima_bootstrap_upsample = force_bootstrap
    else:
        imval_bootstrap_upsample = stats_settings.get("imval_bootstrap_upsample")
        ima_bootstrap_upsample = stats_settings.get("ima_bootstrap_upsample")

    if sub_sample is None:
        sub_sample = stats_settings.get("sub_sample")

    # Fail fast on bad sub-sampling arguments, before any dataset or model is loaded
    if sub_sample is not None:
        if sub_sample < 1:
            raise ValueError(f"`sub_sample` must be a positive integer, got {sub_sample}")
        if sub_sample_type not in ("random", "step"):
            raise ValueError(f"Unknown subsampling type, got {sub_sample_type} expected one of `random`, `step`")

    experiment_name = config_dict["experiment_name"]
    experiment_results = {"experiment_name": experiment_name}

    # Computing image generation quality statistics
    ## Getting relevant datasets
    quality_dataset = ImageDataset(EXPERIMENT_DIR / experiment_name, Resize((299, 299)))
    imagenet_dataset = ImageDataset(IMAGENET_VAL_PATH, Resize((299, 299)))
    imagenet_a_dataset = ImageDataset(IMAGENET_A_PATH, Resize((299, 299)))

    if sub_sample is None:
        imqual_experiment_dataset = quality_dataset
    else:
        idxs = _subsample_indices(len(quality_dataset), sub_sample, sub_sample_type)
        imqual_experiment_dataset = torch.utils.data.Subset(quality_dataset, idxs)

    print(f"=== COMPUTING {experiment_name} IS ===")

    im_IS = IS_from_dataset(imqual_experiment_dataset)
    experiment_results["IS"] = im_IS
    print(im_IS)

    print(f"=== COMPUTING {experiment_name} vs ImageNet-val FID ===")

    imval_FID = FID_from_dataset(imqual_experiment_dataset, imagenet_dataset, imval_bootstrap_upsample)
    experiment_results["ImageNetVal_FID"] = imval_FID
    print(imval_FID)

    print(f"=== COMPUTING {experiment_name} vs ImageNet-a FID ===")

    ima_FID = FID_from_dataset(imqual_experiment_dataset, imagenet_a_dataset, ima_bootstrap_upsample)
    experiment_results["ImageNetA_FID"] = ima_FID
    print(ima_FID)

    experiment_results["ImQual_Sample_Size"] = len(imqual_experiment_dataset)

    flush()  # Clear out cuda memory

    # Computing classification accuracy statistics
    ## Getting relevant datasets
    labelled_dataset = NatAdvDiffImageDataset(EXPERIMENT_DIR / experiment_name)
    experiment_loader = DataLoader(labelled_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    num_samples = len(labelled_dataset)
    has_adversarial_target = labelled_dataset.adversarial_target is not None

    experiment_results["Classifier Accuracy"] = {}
    experiment_results["Naive Attack Success"] = {}
    if has_adversarial_target:
        experiment_results["Attack Success"] = {}

    for class_name in CLASSIFIER_NAMES:
        classifier_model, classifier_config = _build_classifier(class_name)
        if not do_crop:
            # Safe to edit in place: `_build_classifier` hands back a private copy of the config
            classifier_config["crop_pct"] = None

        classifier_pipeline = ClassifierPipeline(classifier_model, **classifier_config)
        flush()  # Clear out cuda memory

        print(f"=== COMPUTING {experiment_name} vs {class_name} ACCURACY ===")

        # Counts are kept as python ints so the reported rates are computed in double precision
        acc_total = 0
        att_total = 0
        # Inference only, so autograd bookkeeping is switched off to save GPU memory
        with torch.no_grad():
            for x, y, z in tqdm(experiment_loader):
                x = x.to(dtype=torch.float32, device="cuda")
                y = y.to("cuda")
                yhat = torch.argmax(classifier_pipeline(x), dim=1)
                acc_total += torch.sum(yhat == y).item()
                if has_adversarial_target:
                    att_total += torch.sum(yhat == z.to("cuda")).item()

        accuracy = acc_total / num_samples
        experiment_results["Classifier Accuracy"][class_name] = accuracy
        # Any misclassification counts as a "naive" success, regardless of which wrong class was predicted
        experiment_results["Naive Attack Success"][class_name] = 1 - accuracy
        if has_adversarial_target:
            experiment_results["Attack Success"][class_name] = att_total / num_samples

        print("Classifier Accuracy:", experiment_results["Classifier Accuracy"])
        print("Naive Attack Success:", experiment_results["Naive Attack Success"])
        if has_adversarial_target:
            print("Attack Success:", experiment_results["Attack Success"])

    experiment_results["Preprocessing_Crop"] = do_crop

    return experiment_results


def IS_from_dataset(dataset: Dataset):
    """
    :param dataset: Dataset
        Pytorch dataset of images. The dataset must return images only (no class labels).

    Convenience wrapper around `compute_IS`: wraps `dataset` in a dataloader (no random sampling) and returns the
    inception score (IS) result computed over it.
    """

    return compute_IS(create_dataset_loader(dataset))


def _build_arg_parser() -> argparse.ArgumentParser:
    """Creates the command line parser for the stats runner script."""

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-rcnfg",
        "--run_from_config",
        help="Computes the statistics for a completed experiment run based on the specified config.",
        type=str,
    )
    parser.add_argument(
        "-rallcnfg",
        "--run_all_config",
        action="store_true",
        help="Computes the statistics for all experiments that do not yet have results.",
    )
    parser.add_argument(
        "-jimq",
        "--joint_imqual",
        nargs="+",
        help="Computes image quality metrics jointly across the specified experiment configs.",
        type=str,
    )
    parser.add_argument(
        "-fboot", "--force_bootstrap", help="Forces FID to be bootsrapped the specified number of times.", type=int
    )
    parser.add_argument(
        "-sub_samp",
        "--sub_sample",
        help="Subsamples the experiment dataset to the specified size when computing image quality metrics. Useful for "
        "comparing image quality across experiments with different sample sizes.",
        type=int,
    )
    parser.add_argument(
        "-sub_samp_type",
        "--sub_sample_type",
        help="The type of subsampling to perform if computing sub-sampled image quality, one of 'random' or 'step'. "
        "Note this is important when comparing ImageNet datasets. Many datasets are constructed with one of each "
        "class, comparisons with randomly sampled ImageNet datasets will have degraded image quality metrics. Setting "
        "`step` allows for fairer image quality comparisons between NatADiff and other handcrafted datasets.",
        type=str,
        default="random",
    )
    parser.add_argument(
        "-no_crop",
        "--no_crop",
        help="Whether to run the classifier with the resize and crop preprocessing transforms specified on the "
        "associated torchvision model card. If this flag is called then the image will be resized to the dimensions "
        "required by the classifier and no crop will be applied.",
        action="store_true",
    )
    parser.add_argument(
        "-mkcsv",
        "--make_csv",
        action="store_true",
        help="Collects all currently computed results into a useful csv file.",
    )
    parser.add_argument(
        "-get_avg",
        "--get_avg",
        action="store_true",
        help="Collects all currently computed results and computes the average classifier attack performance across "
        "all classifiers.",
    )
    return parser


def _load_config(path_str: str) -> dict:
    """Loads a .json experiment config, looking relative to the working directory first and this file second."""

    if not path_str.endswith(".json"):
        raise FileNotFoundError("Configs must be .json files.")

    config_path = pathlib.Path(path_str)
    if not config_path.exists():
        fallback_path = FILE_PATH / path_str
        if not fallback_path.exists():
            raise FileNotFoundError(f"Unable to find config at {path_str} or {fallback_path}")
        config_path = fallback_path

    with open(config_path, "r") as fstrm:
        return json.load(fstrm)


def _run_and_save_stats(config_dict: dict, args: argparse.Namespace, do_crop: bool) -> None:
    """Runs `stats_runner` for one experiment and writes `stats_report.json` into its experiment directory."""

    experiment_results = stats_runner(
        config_dict, args.force_bootstrap, args.sub_sample, args.sub_sample_type, do_crop
    )

    # Saving stats summary
    with open(EXPERIMENT_DIR / config_dict["experiment_name"] / "stats_report.json", "w") as fstrm:
        json.dump(experiment_results, fstrm, indent=4)


def _experiment_run_dirs() -> List[pathlib.Path]:
    """Returns every sub-directory of `EXPERIMENT_DIR`, sorted so that processing order is deterministic."""

    return sorted(f for f in EXPERIMENT_DIR.iterdir() if f.is_dir())


def _load_stats_report(exp_run: pathlib.Path) -> Optional[dict]:
    """Returns the parsed `stats_report.json` of `exp_run`, or `None` (after logging a skip) if it has none."""

    report_path = exp_run / "stats_report.json"
    if not report_path.exists():
        print(f"=== SKIPPING {exp_run} ===")
        return None

    with open(report_path, "r") as fstrm:
        return json.load(fstrm)


def _run_joint_image_quality(config_paths: List[str], force_bootstrap: Optional[int]) -> None:
    """Computes and prints IS and FID over the union of the images of all the given experiments."""

    exp_dirs = []
    for path in config_paths:
        config_dict = _load_config(path)
        exp_path = EXPERIMENT_DIR / config_dict["experiment_name"]
        if not exp_path.exists():
            raise FileNotFoundError(f"No experiment exists at {exp_path}")
        exp_dirs.append(exp_path)

    imagenet_dataset = ImageDataset(IMAGENET_VAL_PATH, Resize((299, 299)))
    imagenet_a_dataset = ImageDataset(IMAGENET_A_PATH, Resize((299, 299)))
    image_dataset = ImageDataset(exp_dirs, Resize((299, 299)))

    print(f"=== COMPUTING JOINT IS ===")
    joint_is = IS_from_dataset(image_dataset)
    print(f"=== COMPUTING JOINT vs ImageNet-val FID ===")
    joint_fid_val = FID_from_dataset(image_dataset, imagenet_dataset, bootstrap_undersample=force_bootstrap)
    print(f"=== COMPUTING JOINT vs ImageNet-a FID ===")
    joint_fid_a = FID_from_dataset(image_dataset, imagenet_a_dataset, bootstrap_undersample=force_bootstrap)

    print("JOINT IS:", joint_is)
    print("JOINT FID vs IMAGENET-val", joint_fid_val)
    print("JOINT FID vs IMAGENET-a", joint_fid_a)


def _write_results_csvs() -> None:
    """Collates every available stats report into `image_quality.csv` and `classifier_performance.csv`."""

    image_quality = {
        "Experiment_Name": [],
        "IS": [],
        "IS_SE": [],
        "FID_ImageNet_Val": [],
        "FID_ImageNet_Val_SE": [],
        "FID_ImageNet_A": [],
        "FID_ImageNet_A_SE": [],
        "ImQual_Sample_Size": [],
    }
    fid_columns = (("ImageNetVal_FID", "FID_ImageNet_Val"), ("ImageNetA_FID", "FID_ImageNet_A"))

    # (report section, column suffix) pairs. Columns are grouped by section rather than by classifier, i.e. all
    # accuracy columns come first, then all naive attack success columns, then all attack success columns.
    classifier_sections = (
        ("Classifier Accuracy", "_Acc"),
        ("Naive Attack Success", "_Naive_AttSucc"),
        ("Attack Success", "_AttSucc"),
    )
    classifier_performance = {"Experiment_Name": []}
    for _, suffix in classifier_sections:
        for class_name in CLASSIFIER_NAMES:
            classifier_performance[class_name + suffix] = []
    classifier_performance["Preprocessing_Crop"] = []

    for exp_run in _experiment_run_dirs():
        stats_report = _load_stats_report(exp_run)
        if stats_report is None:
            continue

        print(f"=== COLLATING STATISTICS FOR {stats_report['experiment_name']} ===")

        # Image Quality
        image_quality["Experiment_Name"].append(stats_report["experiment_name"])
        image_quality["IS"].append(stats_report["IS"]["is"])
        image_quality["IS_SE"].append(stats_report["IS"]["se"])
        for report_key, column in fid_columns:
            fid_stats = stats_report[report_key]
            image_quality[column].append(fid_stats["fid"])
            # A standard error is only present when the FID estimate was bootstrapped
            image_quality[column + "_SE"].append(fid_stats.get("se"))
        image_quality["ImQual_Sample_Size"].append(stats_report.get("ImQual_Sample_Size"))

        # Classifier Accuracy; anything missing from the report is recorded as an empty cell
        classifier_performance["Experiment_Name"].append(stats_report["experiment_name"])
        for section, suffix in classifier_sections:
            section_stats = stats_report.get(section, {})
            for class_name in CLASSIFIER_NAMES:
                classifier_performance[class_name + suffix].append(section_stats.get(class_name))
        classifier_performance["Preprocessing_Crop"].append(stats_report.get("Preprocessing_Crop"))

    image_quality = pd.DataFrame(image_quality).sort_values(by="Experiment_Name")
    classifier_performance = pd.DataFrame(classifier_performance).sort_values(by="Experiment_Name")

    # Nullable integer dtype keeps sample sizes as integers even when some reports do not provide one
    image_quality["ImQual_Sample_Size"] = image_quality["ImQual_Sample_Size"].astype("Int64")

    image_quality.to_csv(EXPERIMENT_DIR / "image_quality.csv", index=False)
    classifier_performance.to_csv(EXPERIMENT_DIR / "classifier_performance.csv", index=False)


def _write_average_performance_csv() -> None:
    """Averages each classifier statistic across classifiers per experiment into `average_performance.csv`."""

    averaged_sections = (
        ("Classifier Accuracy", "Average_Acc"),
        ("Naive Attack Success", "Average_Naive_AttSucc"),
        ("Attack Success", "Average_AttSucc"),
    )
    average_performance = {"Experiment_Name": []}
    for _, column in averaged_sections:
        average_performance[column] = []

    for exp_run in _experiment_run_dirs():
        stats_report = _load_stats_report(exp_run)
        if stats_report is None:
            continue

        print(f"=== COMPUTING AVERAGE STATISTICS FOR {stats_report['experiment_name']} ===")

        average_performance["Experiment_Name"].append(stats_report["experiment_name"])
        for section, column in averaged_sections:
            section_stats = stats_report.get(section, {})
            # Each average is taken over that section's own values, for the classifiers present in the report
            values = [section_stats[class_name] for class_name in CLASSIFIER_NAMES if class_name in section_stats]
            average_performance[column].append(np.mean(values).item() if values else None)

    average_performance = pd.DataFrame(average_performance).sort_values(by="Experiment_Name")

    average_performance.to_csv(EXPERIMENT_DIR / "average_performance.csv", index=False)


def main() -> None:
    """Parses the command line arguments and runs every requested action in turn."""

    # Get call args
    parser = _build_arg_parser()
    args = parser.parse_args()

    do_crop = not args.no_crop

    # Show the usage text only when no action at all was requested
    no_action_requested = (
        (args.run_from_config is None)
        and (not args.run_all_config)
        and (not args.joint_imqual)
        and (not args.make_csv)
        and (not args.get_avg)
    )
    if no_action_requested:
        parser.print_help()

    if args.run_from_config is not None:
        _run_and_save_stats(_load_config(args.run_from_config), args, do_crop)

    if args.run_all_config:
        for exp_run in _experiment_run_dirs():
            has_config = (exp_run / "experiment_config.json").exists()
            has_report = (exp_run / "stats_report.json").exists()
            # Only experiments that have a config but no results yet are processed
            if has_config and not has_report:
                with open(exp_run / "experiment_config.json", "r") as fstrm:
                    config_dict = json.load(fstrm)

                print(f"=== COMPUTING STATISTICS FOR {config_dict['experiment_name']} ===")

                _run_and_save_stats(config_dict, args, do_crop)
            else:
                print(f"=== SKIPPING {exp_run} ===")

    if args.joint_imqual is not None:
        _run_joint_image_quality(args.joint_imqual, args.force_bootstrap)

    if args.make_csv:
        _write_results_csvs()

    if args.get_avg:
        _write_average_performance_csv()


if __name__ == "__main__":
    main()


#######################################################################################################################
