### Preamble ##########################################################################################################

"""
Experiment runner for natural adversarial sampling experiments. Utilises diffusion pipelines from the `gcontrol` 
package which is based on the HuggingFace package suite. 
"""

#######################################################################################################################

### Imports ###########################################################################################################

import numpy as np
import torch
from concurrent.futures import ThreadPoolExecutor
from torchvision.utils import save_image
from datetime import datetime, timedelta
import os
import timm
import pathlib
import json
import pickle
import shutil
import logging
import argparse

from typing import Union, Iterable, Optional, Tuple, List, Any, Callable, Dict

from torchvision.models import (
    resnet50,
    inception_v3,
    vit_h_14,
    ResNet50_Weights,
    Inception_V3_Weights,
    ViT_H_14_Weights,
)
from misc.classifier_pipeline import resnet50_config, inceptionv3_config, vith14_config

from misc.path_configs import CACHE_DIR, EXPERIMENT_DIR, IMAGENET_PATH, IMAGENET_A_PATH, IMAGENET_CLASSES

from diffusers.schedulers import DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler

from gcontrol.diffusion_pipelines import GCStableDiffusionPipeline
from gcontrol.guidance_controllers.controller_utils import GController

from guidance_controllers.classifier_free import ClassifierFreeGuidance
from guidance_controllers.adversarial_classifier import AdversarialClassifierGuidance
from guidance_controllers.mixture_guidance import MixtureGuidance

from gcontrol.utils import get_timm_config

from experiment_samplers import (
    classifier_free_sample,
    mixture_sample,
    natural_adversarial_sample,
    similarity_natural_adversarial_sample,
)

#######################################################################################################################

FILE_PATH = pathlib.Path(__file__).parent


def _get_logger(fpath: Union[str, pathlib.PosixPath]) -> logging.Logger:
    """
    param fpath: str or pathlib.PosixPath
        Directory path to create the `log.txt` file associated with the logger.

    Returns a logger that outputs to a `log.txt` file ceated in the provided directory.
    """

    logger = logging.getLogger("ExperimentLogger")
    logger.setLevel(logging.DEBUG)

    # File handler
    file_handler = logging.FileHandler(fpath, mode="a")
    file_handler.setLevel(logging.DEBUG)
    file_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    file_handler.setFormatter(file_formatter)

    logger.addHandler(file_handler)

    return logger


def _write_metadata(
    metadata_path: Union[str, pathlib.PosixPath],
    sample_num: int,
    diff_class: int,
    run_data: dict,
) -> None:
    """
    :param metadata_path: str or pathlib.PosixPath
        Directory path to write to or create the `metadata.txt` file.
    :param sample_num: int
        The sample number of the sample (a number in [1, TOTAL_SAMPLES]).
    :param diff_class: int
        The class targeted by the diffusion model.
    :param run_data: dict
        Additional metadata information to be stored in the `metadata.txt` file.

    Writes metadata information to the specified `metadata.txt` file.
    """

    if run_data is None:
        run_data = {}
    with open(metadata_path, "a") as metadata_strm:
        if sample_num == 1:
            wrtlist = ["Sample Number", "Target Diffusion Class"] + list(run_data.keys())
            metadata_strm.write("\t".join(map(str, wrtlist)) + "\n")
        wrtlist = [f"{sample_num:05d}", diff_class] + list(run_data.values())
        metadata_strm.write("\t".join(map(str, wrtlist)) + "\n")


def _get_runtime_est(start_time: datetime, curr_iter: int, tot_iters: int) -> datetime:
    """
    :param start_time: datetime.datetime
        The time that the experiment run was started (or restarted).
    :param curr_iter: int
        The current iteration of the sampler (this is with respect to total number of iterations).
    :param tot_iters: int
        Total number of sample iterations required for this experiment.

    Returns a datetime object representing the estimated experiment conclusion time.
    """

    curr_time = datetime.now()
    avg_time = (curr_time - start_time) / curr_iter
    est_finish_time = datetime.now() + avg_time * (tot_iters - curr_iter)
    est_finish_time = est_finish_time.strftime("%Y-%m-%d %H:%M:%S")
    return est_finish_time


def _load_required_models(
    guidance_type: str,
    scheduler: str,
    device: Union[str, torch.device],
    classifier: Optional[str],
) -> Tuple[GCStableDiffusionPipeline, GController, Optional[torch.nn.Module]]:
    """
    :param guidance_type: str
        The guidance to be used in the diffusion process, can be one of "adversarial", "adversarial similarity",
        "classifier-free", or "multiclass-diffusion".
    :param scheduler: str
        The scheduler to be used in the diffusion process, can be one of "DDPM", "DDIM", or "Euler".
    :param device: str or torch.device
        The device to place the diffusion pipeline and the classifier (if any) on.
    :param classifier: str
        The classifier to be used in the guidance controller. Only consulted for the "adversarial" and
        "adversarial similarity" guidance types, where it must be one of "resnet50", "inceptionv3", "vit",
        "adv_resnet", or "adv_inception".

    Returns the required diffusion pipeline, guidance controller, and classifier (`None` when the guidance type does
    not use one). Raises a ValueError for an unsupported `guidance_type`, `scheduler`, or `classifier`.
    """

    # Validate the scheduler name before any weights are loaded so a typo fails immediately
    scheduler_classes = {
        "DDPM": DDPMScheduler,
        "DDIM": DDIMScheduler,
        "Euler": EulerDiscreteScheduler,
    }
    if scheduler not in scheduler_classes:
        raise ValueError(f"Unsupported `scheduler`, must be one of 'DDPM', 'DDIM', 'Euler', got {scheduler}")

    # Getting guidance controller
    classifier_model = None
    if guidance_type == "adversarial" or guidance_type == "adversarial similarity":
        # Loading classifier (torchvision weights are passed by keyword; positional use is deprecated)
        if classifier == "resnet50":
            classifier_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)
            classifier_config = resnet50_config
        elif classifier == "inceptionv3":
            classifier_model = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1)
            classifier_config = inceptionv3_config
        elif classifier == "vit":
            classifier_model = vit_h_14(weights=ViT_H_14_Weights.IMAGENET1K_SWAG_LINEAR_V1)
            classifier_config = vith14_config
        elif classifier == "adv_resnet":
            classifier_model = timm.create_model(model_name="inception_resnet_v2.tf_ens_adv_in1k", pretrained=True)
            classifier_config = get_timm_config(classifier_model)
        elif classifier == "adv_inception":
            classifier_model = timm.create_model(model_name="adv_inception_v3.tf_adv_in1k", pretrained=True)
            classifier_config = get_timm_config(classifier_model)
        else:
            raise ValueError(
                "`classifier` must be one of 'resnet50', 'inceptionv3', 'vit', 'adv_resnet', or 'adv_inception' when"
                f" `guidance_type` is 'adversarial' or 'adversarial similarity', got {classifier}"
            )
        # Honour the requested device rather than assuming CUDA
        classifier_model = classifier_model.eval().to(dtype=torch.bfloat16, device=device)

        # Instantiating guidance controller
        guidance_controller = AdversarialClassifierGuidance(classifier_model, **classifier_config)
    elif guidance_type == "multiclass-diffusion":
        guidance_controller = MixtureGuidance()
    elif guidance_type == "classifier-free":
        guidance_controller = ClassifierFreeGuidance()
    else:
        raise ValueError(
            "Unsupported `guidance_type`, must be one of 'adversarial', 'multiclass-diffusion', "
            f"'classifier-free', 'adversarial similarity', got {guidance_type}"
        )

    # Loading diffusion model
    model_id = "sd-legacy/stable-diffusion-v1-5"
    pipe = GCStableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        cache_dir=CACHE_DIR,
        use_safetensors=True,
        guidance_controller=guidance_controller,
    )
    pipe = pipe.to(device)

    # Swap in the requested scheduler, re-using the configuration of the pipeline's default scheduler
    pipe.scheduler = scheduler_classes[scheduler].from_config(pipe.scheduler.config)

    return pipe, guidance_controller, classifier_model


def _save_metadata_ckpt(
    metadata_path: Union[str, pathlib.PosixPath],
    ckpt_metadata_path: Union[str, pathlib.PosixPath],
) -> None:
    """
    :param metadata_path: str or pathlib.PosixPath
        Directory path to write to or create the `metadata.txt` file.
    :param ckpt_metadata_path: str or pathlib.PosixPath
        Directory path to copy the `metadata.txt` file for checkpointing

    Checkpoints the `metadata.txt` file.
    """

    if not metadata_path.exists():
        raise FileNotFoundError(f"Unable to find metadata at {str(ckpt_metadata_path)}")
    shutil.copy(metadata_path, ckpt_metadata_path)


def _save_rng_ckpt(
    np_rng_generator: np.random.Generator,
    pt_rng_generator: torch.Generator,
    ckpt_rng_path: Union[str, pathlib.PosixPath],
) -> None:
    """
    :param np_rng_generator: np.random.Generator
        Random number generator for all numpy functions.
    :param pt_rng_generator: torch.Generator
        Random number generator for all torch functions.
    :param ckpt_rng_path: str or pathlib.PosixPath
        Directory path to save a pkl of the rng state.

    Pickles and saves the rng state to the specified save directory.
    """

    states = {
        "np_rng_state": np_rng_generator.bit_generator.state,
        "pt_rng_state": pt_rng_generator.get_state(),
        "pt_rng_device": pt_rng_generator.device,
    }
    with open(ckpt_rng_path, "wb") as f:
        pickle.dump(states, f)


def _load_rng_ckpt(ckpt_pkl_path: Union[str, pathlib.PosixPath]) -> Tuple[np.random.Generator, torch.Generator]:
    """
    :param ckpt_rng_path: str or pathlib.PosixPath
        Directory path to retrieve saved pkl of the rng state.

    Loads the pickled rng state and returns the torch and numpy random generators.
    """

    with open(ckpt_pkl_path, "rb") as f:
        states = pickle.load(f)
    np_rng_generator = np.random.default_rng()
    np_rng_generator.bit_generator.state = states["np_rng_state"]

    pt_rng_generator = torch.Generator(device=states["pt_rng_device"])
    pt_rng_generator.set_state(states["pt_rng_state"])

    return np_rng_generator, pt_rng_generator


def _load_latest_checkpoint(
    ckpt_path: Union[str, pathlib.PosixPath]
) -> Tuple[str, int, np.random.Generator, torch.Generator]:
    """
    :param ckpt_path: str or pathlib.PosixPath
        Directory path where ckpts are saved.

    Loads the latest checkpoint and returns the checkpoint timestamp, the next sample index, numpy random generator
    and torch random generator.
    """

    txt_files = set()
    pkl_files = set()

    for file in os.listdir(ckpt_path):
        if file[-4:] == ".txt":
            txt_files.add(file[:-4])
        elif file[-4:] == ".pkl":
            pkl_files.add(file[:-4])

    # Find common base names for pairs
    paired_files = list(txt_files & pkl_files)
    paired_files.sort(reverse=True)

    start_idx = None
    np_rng_generator = None
    pt_rng_generator = None
    for timestamp in paired_files:
        try:
            np_rng_generator, pt_rng_generator = _load_rng_ckpt(ckpt_path / (timestamp + ".pkl"))
            with open(ckpt_path / (timestamp + ".txt")) as fstrm:
                lines = fstrm.readlines()
                last_line = lines[-1].strip()
                fields = last_line.split("\t")
                start_idx = int(fields[0])
            shutil.copy(ckpt_path / (timestamp + ".txt"), ckpt_path.parent / "metadata.txt")
            break
        except Exception:
            pass

    return timestamp, start_idx, np_rng_generator, pt_rng_generator


def experiment_runner(
    experiment_name: str,
    num_samples_per_class: int,
    guidance_type: str,
    scheduler: str,
    experiment_args: dict,
    classifier: Optional[str] = None,
    start_idx: Optional[int] = None,
    np_rng: Optional[Union[int, np.random.Generator]] = None,
    pt_rng: Optional[Union[int, torch.Generator]] = None,
    samples_per_ckpt: Optional[int] = 100,
    device: Union[str, torch.device] = "cuda",
) -> None:
    """
    :param experiment_name: str
        The name of the experiment. Will be used as the name of the output directory that the experiment results are
        saved to.
    :param num_samples_per_class: int
        The number of images to generate for each of the 1000 class indices.
    :param guidance_type: str
        The diffusion guidance type to use in the experiment, can be one of "adversarial", "adversarial similarity",
        "classifier-free", or "multiclass-diffusion".
    :param scheduler: str
        The scheduler to use during the diffusion process, can be one of "DDPM", "DDIM", or "Euler".
    :param experiment_args: dict
        Experiment specific keyword arguments forwarded to the sampler. The sampler is always called with
        `diffusion_model`, `class_idx`, `np_rng_generator`, `pt_rng_generator` and `logger`; the entries of
        `experiment_args` are passed in addition to these.
    :param classifier: str
        The classifier to be used in the guidance controller (if necessary); it is forwarded unchanged to
        `_load_required_models`, which validates it.
    :param start_idx: int
        The 0-based index of the first sample to generate, i.e., the number of samples already completed. If not
        provided sampling starts from the beginning.
    :param np_rng: int or np.random.Generator
        Either a numpy random generator or the seed to initialise a numpy random generator. If `None` then a random
        seed is used.
    :param pt_rng: int or torch.Generator
        Either a torch random generator or the seed to initialise a torch random generator. If `None` then an
        unseeded generator on `device` is used.
    :param samples_per_ckpt: int
        The number of samples between checkpoints. `None` or `0` disables checkpointing.
    :param device: str or torch.device
        The device to run the experiment on.

    Takes the experimental parameters and runs the experiment whilst managing metadata updating and checkpointing.
    Images and metadata are written by a single worker thread so that saving overlaps with the generation of the
    next sample. Raises a ValueError for an unrecognised `guidance_type` or an out-of-range `start_idx`, and a
    TypeError for an unsupported `np_rng` / `pt_rng` value.
    """

    num_classes = 1000

    pipe, _, _ = _load_required_models(
        guidance_type=guidance_type,
        scheduler=scheduler,
        classifier=classifier,
        device=device,
    )

    if guidance_type == "adversarial":
        sampler = natural_adversarial_sample
    elif guidance_type == "adversarial similarity":
        sampler = similarity_natural_adversarial_sample
    elif guidance_type == "multiclass-diffusion":
        sampler = mixture_sample
    elif guidance_type == "classifier-free":
        sampler = classifier_free_sample
    else:
        raise ValueError(f"Unrecognised sampler, {guidance_type}")

    experiment_dir = EXPERIMENT_DIR / experiment_name
    log_path = experiment_dir / "log.txt"
    metadata_path = experiment_dir / "metadata.txt"
    ckpt_path = experiment_dir / "ckpts"
    os.makedirs(ckpt_path, exist_ok=True)

    # Resolve the numpy generator from the *argument* (generator, seed, or None)
    if isinstance(np_rng, np.random.Generator):
        np_rng_generator = np_rng
    elif isinstance(np_rng, int):
        np_rng_generator = np.random.default_rng(np_rng)
    elif np_rng is None:
        np_rng_generator = np.random.default_rng()
    else:
        raise TypeError(f"`np_rng` must be None, an int or a np.random.Generator, got {type(np_rng)}")

    # Resolve the torch generator from the *argument* (generator, seed, or None)
    if isinstance(pt_rng, torch.Generator):
        pt_rng_generator = pt_rng
    elif isinstance(pt_rng, int):
        pt_rng_generator = torch.Generator(device=device).manual_seed(pt_rng)
    elif pt_rng is None:
        pt_rng_generator = torch.Generator(device=device)
    else:
        raise TypeError(f"`pt_rng` must be None, an int or a torch.Generator, got {type(pt_rng)}")

    total_samples = num_classes * num_samples_per_class
    if start_idx is None:
        samples_done = 0
    elif (0 <= start_idx) and (start_idx < total_samples):
        samples_done = start_idx
    else:
        raise ValueError(
            f"Invalid `start_idx`, expected None or integer between 0 and {total_samples}, got {start_idx}"
        )
    lower_i = samples_done // num_samples_per_class
    lower_ii = samples_done % num_samples_per_class

    with ThreadPoolExecutor(
        max_workers=1
    ) as executor:  # Use a single worker thread to save images, metadata, and assist with checkpointing
        pending = []  # Futures of the most recent image save / metadata write
        logger = _get_logger(log_path)
        logger.info("Starting diffusion sampling")

        start_time = datetime.now()

        for i in range(lower_i, num_classes):
            # Only the class that the run resumes in is entered part-way; later classes start at their first sample
            first_ii = lower_ii if i == lower_i else 0
            for ii in range(first_ii, num_samples_per_class):
                sample_num = i * num_samples_per_class + ii + 1

                im_tensor, run_data = sampler(
                    diffusion_model=pipe,
                    class_idx=i,
                    np_rng_generator=np_rng_generator,
                    pt_rng_generator=pt_rng_generator,
                    logger=logger,
                    **experiment_args,
                )

                # Get runtime estimate; iterations are counted relative to this (re)start so that the average
                # time per sample is not skewed by samples produced in an earlier run
                est_finish_time = _get_runtime_est(
                    start_time, sample_num - samples_done, total_samples - samples_done
                )

                # Wait for the previous save/log to complete, if any; this also surfaces their exceptions
                for previous in pending:
                    previous.result()
                pending = []
                if im_tensor is not None:
                    pending.append(
                        executor.submit(
                            save_image,
                            im_tensor,
                            experiment_dir / f"image_{sample_num:05d}.png",
                        )
                    )
                pending.append(executor.submit(_write_metadata, metadata_path, sample_num, i, run_data))
                logger.info(f"Completed sample {sample_num} / {total_samples}. Estimated finish: {est_finish_time}")

                if samples_per_ckpt and sample_num % samples_per_ckpt == 0:
                    logger.info("Creating a checkpoint.")
                    ckpt_name = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    _save_rng_ckpt(
                        np_rng_generator,
                        pt_rng_generator,
                        ckpt_path / (ckpt_name + ".pkl"),
                    )
                    # Ensure that image saving and metadata writing have completed before copying the metadata
                    for previous in pending:
                        previous.result()
                    _save_metadata_ckpt(metadata_path, ckpt_path / (ckpt_name + ".txt"))
                    logger.info(f"Successfully created checkpoint: {ckpt_name}")

        # Ensure the last save/log completes (no-op when nothing was sampled)
        for previous in pending:
            previous.result()


def _locate_config(cli_value: Any, flag: str) -> pathlib.Path:
    """
    :param cli_value: Any
        The raw value received for the command line flag.
    :param flag: str
        The short flag name ("-rcnfg" or "-rckpt"), only used in the TypeError message.

    Returns the path of the config file named by `cli_value`, looked up first as given and then relative to the
    directory of this file.
    """

    if not isinstance(cli_value, str):
        raise TypeError(f"{flag} expects a `str`, got {type(cli_value)}")
    if not cli_value.endswith(".json"):
        raise FileNotFoundError("Configs must be .json files.")

    as_given = pathlib.Path(cli_value)
    if as_given.exists():
        return as_given
    next_to_script = FILE_PATH / cli_value
    if next_to_script.exists():
        return next_to_script
    raise FileNotFoundError(f"Unable to find config at {cli_value} or {next_to_script}")


def _read_experiment_config(config_path: pathlib.Path) -> Tuple[dict, pathlib.Path]:
    """
    :param config_path: pathlib.Path
        Path of the experiment config json file.

    Returns the parsed config together with the experiment directory derived from its 'experiment_name' field.
    """

    with open(config_path, "r") as fstrm:
        config_dict = json.load(fstrm)

    if "experiment_name" not in config_dict.keys():
        raise KeyError("The experiment config requires an 'experiment_name' field.")
    return config_dict, EXPERIMENT_DIR / config_dict["experiment_name"]


def _start_fresh_experiment(cli_value: Any) -> None:
    """
    :param cli_value: Any
        The value received for the `-rcnfg` flag.

    Creates the experiment directory for the given config (refusing to reuse an existing one), stores a copy of the
    config in it, and starts the experiment runner.
    """

    # Check for errors to prevent overwriting or breaking experiments
    config_path = _locate_config(cli_value, "-rcnfg")
    config_dict, experiment_run_path = _read_experiment_config(config_path)

    if experiment_run_path.exists():
        raise FileExistsError(
            "There is already an experiment with the specified name in the experiment directory,"
            f" {experiment_run_path} already exists. Did you mean to load from checkpoint?"
        )

    # Make directory and copy the config file
    os.mkdir(experiment_run_path)
    shutil.copy(config_path, experiment_run_path / "experiment_config.json")

    # Start the experiment runner
    experiment_runner(**config_dict)


def _resume_experiment(cli_value: Any) -> None:
    """
    :param cli_value: Any
        The value received for the `-rckpt` flag.

    Checks that the given config matches the one stored with an existing experiment, loads the latest usable
    checkpoint, and restarts the experiment runner from it.
    """

    # Check for errors to prevent overwriting or breaking experiments
    config_path = _locate_config(cli_value, "-rckpt")
    config_dict, experiment_run_path = _read_experiment_config(config_path)

    # Check that experiment has been run before
    if not experiment_run_path.exists():
        raise FileNotFoundError(f"No existing experiment directory found at {experiment_run_path}")
    existing_config_path = experiment_run_path / "experiment_config.json"
    if not existing_config_path.exists():
        raise FileNotFoundError(f"No experiment config file found at {existing_config_path}")

    # Verify that the previous experiment has the same config as the specified one
    with open(existing_config_path, "r") as fstrm:
        existing_config_dict = json.load(fstrm)
    if config_dict != existing_config_dict:
        raise ValueError(f"Config dictionary at {config_path} and {existing_config_path} do not match.")

    # Verify checkpoints exist
    ckpt_path = experiment_run_path / "ckpts"
    if not ckpt_path.exists():
        raise FileNotFoundError(f"No checkpoint directory found at {ckpt_path}")

    # Attempt to load checkpoints
    timestamp, start_idx, np_rng_generator, pt_rng_generator = _load_latest_checkpoint(ckpt_path=ckpt_path)
    if timestamp is None or start_idx is None or np_rng_generator is None or pt_rng_generator is None:
        raise RuntimeError("Failed to load checkpoint")
    print(f"Loaded checkpoint made on {timestamp}")

    # Restart experiment
    config_dict.update(pt_rng=pt_rng_generator, np_rng=np_rng_generator, start_idx=start_idx)
    experiment_runner(**config_dict)


def _main() -> None:
    """
    Command line entry point. Exactly one of `-rcnfg` (start a new experiment from a config) or `-rckpt` (resume an
    experiment from its latest checkpoint) must be given; with neither the help text is printed.
    """

    # Get call args
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-rcnfg", "--run_from_config", help="Runs the experiment from the specified config.json file.", type=str
    )
    parser.add_argument(
        "-rckpt",
        "--run_from_checkpoint",
        help="Reloads and runs an incomplete experiment from the latest successful checkpoint.",
        type=str,
    )

    args = parser.parse_args()
    config_arg = args.run_from_config
    ckpt_arg = args.run_from_checkpoint

    if config_arg is None and ckpt_arg is None:
        parser.print_help()
        return
    if config_arg is not None and ckpt_arg is not None:
        raise RuntimeError("Can not provide both `-rcnfg` and `-rckpt` arguments.")

    if config_arg is not None:
        _start_fresh_experiment(config_arg)
    else:
        _resume_experiment(ckpt_arg)


if __name__ == "__main__":
    _main()

#######################################################################################################################
