### Preamble ##########################################################################################################

"""
Experiment samplers for natural adversarial sampling experiments. Each sampler implements a separate experiment and is
called with the following arguments:

`diffusion_model`: The GCDiffusion pipeline
`class_idx`: The target diffusion class
`np_rng_generator`: A numpy rng generator
`pt_rng_generator`: A torch rng generator
`logger`: A logger for logging experiment progress/info
`**experiment_args`: Arguments defined in the experiment configuration file under the "experiment_args" key.

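Note that the `experiment_args` dictionary must contain a dictionary under the key "diffusion_args", which is passed to
the GCDiffusion pipeline as `**diffusion_args`.

Each sampler returns a tuple `(image, info)`: `image` is the sampled image tensor (or `None` if the final sampling
attempt produced NaN values) and `info` is a dictionary of per-sample metadata.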
"""

#######################################################################################################################

### Imports ###########################################################################################################

import numpy as np
import torch
from concurrent.futures import ThreadPoolExecutor
from torchvision.utils import save_image
from datetime import datetime, timedelta
import os
import timm
import pathlib
import json
import pickle
import shutil
import logging
import argparse

from typing import Union, Iterable, Optional, Tuple, List, Any, Callable, Dict

from torchvision.models import (
    resnet50,
    inception_v3,
    vit_h_14,
    ResNet50_Weights,
    Inception_V3_Weights,
    ViT_H_14_Weights,
)
from misc.experiment_helpers import get_embedding, get_distance_mat
from misc.classifier_pipeline import resnet50_config, inceptionv3_config, vith14_config
from misc.path_configs import CACHE_DIR, EXPERIMENT_DIR, IMAGENET_PATH, IMAGENET_A_PATH, IMAGENET_CLASSES

from diffusers.schedulers import DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler

from gcontrol.diffusion_pipelines import GCStableDiffusionPipeline
from gcontrol.guidance_controllers.controller_utils import GController

from guidance_controllers.classifier_free import ClassifierFreeGuidance
from guidance_controllers.adversarial_classifier import AdversarialClassifierGuidance
from guidance_controllers.mixture_guidance import MixtureGuidance

from gcontrol.utils import get_timm_config

#######################################################################################################################


def classifier_free_sample(
    diffusion_model: GCStableDiffusionPipeline,
    class_idx: int,
    np_rng_generator: np.random.Generator,
    pt_rng_generator: torch.Generator,
    logger: logging.Logger,
    diffusion_args: dict,
    prompt_format: str,
    sampling_attempts: int,
) -> Tuple[Optional[torch.Tensor], dict]:
    """
    Performs one experiment sample step and returns the sampled image and sampling data.

    The prompt is built from `prompt_format` using the full ImageNet label of `class_idx`. Sampling is repeated, up to
    `sampling_attempts` times, while the generated image contains NaN values.

    :param diffusion_model: GCStableDiffusionPipeline
        Diffusion pipeline.
    :param class_idx: int
        The ImageNet class index of the diffusion target.
    :param np_rng_generator: np.random.Generator
        The numpy random number generator (not used by this sampler; accepted for a uniform sampler interface).
    :param pt_rng_generator: torch.Generator
        The torch random number generator, passed to the diffusion pipeline.
    :param logger: logging.Logger
        A python logger to print progress updates to.
    :param diffusion_args: dict
        A dictionary containing the arguments to be passed to the diffusion pipeline.
    :param prompt_format: str
        A string denoting the format of the diffusion prompt. Replaces `<class name>` with the target ImageNet class.
    :param sampling_attempts: int
        The number of sampling attempts to perform at any given sample stage, i.e., how many times to resample if a
        nan value is encountered. Must be at least 1.
    :return: Tuple[Optional[torch.Tensor], dict]
        The first image of the pipeline output (``None`` if every attempt contained NaN values) and a dictionary with
        key "NA Detected" telling whether the final attempt contained NaN values.
    :raises ValueError:
        If `sampling_attempts` is less than 1.
    """
    if sampling_attempts < 1:
        # Without at least one attempt there is no image (or NaN flag) to return.
        raise ValueError(f"sampling_attempts must be at least 1, got {sampling_attempts}")

    # The prompt is identical for every attempt, so build it once.
    prompt = prompt_format.replace("<class name>", IMAGENET_CLASSES["id2label"][class_idx])

    im_tensor = None
    na_detected = False
    for _ in range(sampling_attempts):
        im_tensor = diffusion_model(
            prompt=prompt,
            output_type="pt",
            generator=pt_rng_generator,
            **diffusion_args,
        ).images[0]

        na_detected = bool(im_tensor.isnan().any())
        if not na_detected:
            break

        logger.warning("NA detected during sampling")
        im_tensor = None

    return im_tensor, {"NA Detected": na_detected}


def mixture_sample(
    diffusion_model: GCStableDiffusionPipeline,
    class_idx: int,
    np_rng_generator: np.random.Generator,
    pt_rng_generator: torch.Generator,
    logger: logging.Logger,
    diffusion_args: dict,
    prompt_format: str,
    sampling_attempts: int,
) -> Tuple[Optional[torch.Tensor], dict]:
    """
    Performs one experiment sample step and returns the sampled image and sampling data.

    A random adversarial ImageNet class (different from `class_idx`) is drawn once. Then the pipeline is called with
    the target-class prompt and the guidance prompt "<adversarial name> and <class name>". Sampling is repeated, up to
    `sampling_attempts` times, while the generated image contains NaN values.

    :param diffusion_model: GCStableDiffusionPipeline
        Diffusion pipeline.
    :param class_idx: int
        The ImageNet class index of the diffusion target.
    :param np_rng_generator: np.random.Generator
        The numpy random number generator, used to draw the adversarial class.
    :param pt_rng_generator: torch.Generator
        The torch random number generator, passed to the diffusion pipeline.
    :param logger: logging.Logger
        A python logger to print progress updates to.
    :param diffusion_args: dict
        A dictionary containing the arguments to be passed to the diffusion pipeline.
    :param prompt_format: str
        A string denoting the format of the diffusion prompt. Replaces `<class name>` with the target ImageNet class.
    :param sampling_attempts: int
        The number of sampling attempts to perform at any given sample stage, i.e., how many times to resample if a
        nan value is encountered. Must be at least 1.
    :return: Tuple[Optional[torch.Tensor], dict]
        The first image of the pipeline output (``None`` if every attempt contained NaN values) and a dictionary with
        keys "Target Adversarial Class" and "NA Detected".
    :raises ValueError:
        If `sampling_attempts` is less than 1.
    """
    if sampling_attempts < 1:
        # Without at least one attempt there is no image (or NaN flag) to return.
        raise ValueError(f"sampling_attempts must be at least 1, got {sampling_attempts}")

    # Draw the adversarial class from [0, 1000); a collision with the target moves to the next index (999 wraps to 0).
    # NOTE(review): this makes that next index twice as likely as any other class. It is kept as-is so seeded runs
    # reproduce earlier results.
    adv_idx = np_rng_generator.integers(low=0, high=1000, size=1).item()
    if adv_idx == class_idx:
        adv_idx = adv_idx + 1 if class_idx < 999 else 0

    # Only the first comma-separated synonym of each ImageNet label is used; these strings never change between
    # attempts.
    class_name = IMAGENET_CLASSES["id2label"][class_idx].split(",")[0]
    adv_name = IMAGENET_CLASSES["id2label"][adv_idx].split(",")[0]
    prompt = prompt_format.replace("<class name>", class_name)
    gprompt = f"{adv_name} and {class_name}"

    im_tensor = None
    na_detected = False
    for _ in range(sampling_attempts):
        im_tensor = diffusion_model(
            prompt=prompt,
            gprompt=gprompt,
            output_type="pt",
            generator=pt_rng_generator,
            **diffusion_args,
        ).images[0]

        na_detected = bool(im_tensor.isnan().any())
        if not na_detected:
            break

        logger.warning("NA detected during sampling")
        im_tensor = None

    return im_tensor, {"Target Adversarial Class": adv_idx, "NA Detected": na_detected}


def natural_adversarial_sample(
    diffusion_model: GCStableDiffusionPipeline,
    class_idx: int,
    np_rng_generator: np.random.Generator,
    pt_rng_generator: torch.Generator,
    logger: logging.Logger,
    diffusion_args: dict,
    prompt_format: str,
    sampling_attempts: int,
    gm_delta: Optional[float] = None,
    gs_delta: Optional[float] = 20,
) -> Tuple[Optional[torch.Tensor], dict]:
    r"""
    Performs one experiment sample step and returns the sampled image and sampling data. Adversarial label selected
    randomly.

    Every attempt starts from the same initial latents. After each attempt, the image is quantised to uint8 (as when
    saving to .png) and re-classified with the guidance controller's classifier. Sampling stops once the prediction
    equals the adversarial class. Between unsuccessful attempts, g_m / g_s are raised by `gm_delta` / `gs_delta`.

    :param diffusion_model: GCStableDiffusionPipeline
        Diffusion pipeline.
    :param class_idx: int
        The ImageNet class index of the diffusion target.
    :param np_rng_generator: np.random.Generator
        The numpy random number generator, used to draw the adversarial class.
    :param pt_rng_generator: torch.Generator
        The torch random number generator, used for the initial latents and passed to the diffusion pipeline.
    :param logger: logging.Logger
        A python logger to print progress updates to.
    :param diffusion_args: dict
        A dictionary containing the arguments to be passed to the diffusion pipeline. It is not modified.
    :param prompt_format: str
        A string denoting the format of the diffusion prompt. Replaces `<class name>` with the target ImageNet class.
    :param sampling_attempts: int
        The number of sampling attempts to perform at any given sample stage, i.e., how many times to resample if the
        sampler fails to achieve the desired adversarial class or a nan value is encountered. Must be at least 1.
    :param gm_delta: Optional[float]
        The amount to increase gm ($\mu$) by in the case that the sample is not adversarial. Note that gm can not be
        increased past 1 (anything over will be clipped to 1). ``None`` leaves gm unchanged.
    :param gs_delta: Optional[float]
        The amount to increase gs ($s$) by in the case that the sample is not adversarial. ``None`` leaves gs
        unchanged.
    :return: Tuple[Optional[torch.Tensor], dict]
        The image from the last attempt made (``None`` if it contained NaN values) and a dictionary with keys
        "Target Adversarial Class", "NA Detected" and "Adversarial Success".
    :raises ValueError:
        If `sampling_attempts` is less than 1, or the guidance controller has no classifier.
    """
    if sampling_attempts < 1:
        # Without at least one attempt there is no image (or NaN flag) to return.
        raise ValueError(f"sampling_attempts must be at least 1, got {sampling_attempts}")

    # Draw the adversarial class from [0, 1000); a collision with the target moves to the next index (999 wraps to 0).
    # NOTE(review): this makes that next index twice as likely as any other class. It is kept as-is so seeded runs
    # reproduce earlier results.
    adv_idx = np_rng_generator.integers(low=0, high=1000, size=1).item()
    if adv_idx == class_idx:
        adv_idx = adv_idx + 1 if class_idx < 999 else 0

    logger.info(f"Target diffusion class: {class_idx}, target adv class {adv_idx}")

    # A single latent draw, reused by every attempt. Height/width default to 512 when absent (or explicitly None).
    height = diffusion_args.get("height")
    width = diffusion_args.get("width")
    latent_shape = (
        1,
        diffusion_model.unet.config.in_channels,
        (512 if height is None else int(height)) // diffusion_model.vae_scale_factor,
        (512 if width is None else int(width)) // diffusion_model.vae_scale_factor,
    )
    noise_patch = torch.randn(
        size=latent_shape, generator=pt_rng_generator, dtype=diffusion_model.dtype, device=diffusion_model.device
    )

    # Only the first comma-separated synonym of each ImageNet label is used; these strings never change between
    # attempts.
    class_name = IMAGENET_CLASSES["id2label"][class_idx].split(",")[0]
    adv_name = IMAGENET_CLASSES["id2label"][adv_idx].split(",")[0]
    prompt = prompt_format.replace("<class name>", class_name)
    gprompt = f"{adv_name} and {class_name}"

    # Copy so the g_m / g_s updates below never mutate the caller's dictionary.
    diffusion_args = diffusion_args.copy()

    im_tensor = None
    na_detected = False
    samp_success = False
    for attempt in range(1, sampling_attempts + 1):
        im_tensor = diffusion_model(
            prompt=prompt,
            gprompt=gprompt,
            target_idx=adv_idx,
            output_type="pt",
            generator=pt_rng_generator,
            latents=noise_patch,
            **diffusion_args,
        ).images[0]

        na_detected = bool(im_tensor.isnan().any())
        if na_detected:
            logger.warning("NA detected during sampling")
            im_tensor = None
        else:
            controller = diffusion_model.guidance_controller
            if controller.classifier is None:
                raise ValueError("The guidance controller has no classifier to verify the adversarial target with")
            # Checking if adversarial sample survives conversion to .png (rounding to uint8).
            # NOTE(review): pixel values stay in [0, 255] and are cast to bfloat16. Confirm that the
            # preprocessor/classifier expect this range and dtype.
            im_tensor_255 = torch.clamp(im_tensor * 255 + 0.5, 0, 255).to(dtype=torch.uint8).to(dtype=torch.bfloat16)
            im_tensor_255 = im_tensor_255.unsqueeze(0)
            # Inference-only check: no autograd graph is needed.
            with torch.no_grad():
                if getattr(controller, "preprocessor", None) is not None:
                    im_tensor_255 = controller.preprocessor(im_tensor_255)
                class_logits = controller.classifier(im_tensor_255)
            pred_class = torch.argmax(class_logits, dim=-1).item()
            samp_success = pred_class == adv_idx
            logger.info(
                f"Target diffusion class: {class_idx}, target adv class {adv_idx}, predicted class {pred_class}"
            )

        if samp_success:
            break

        # Strengthen the guidance before the next attempt (none after the last one).
        if attempt < sampling_attempts:
            # NOTE(review): the increment is applied to `diffusion_model.gcontrol_args`, not to `diffusion_args`. It
            # only accumulates across retries if the pipeline stores the latest call's g_m / g_s there — confirm.
            if gm_delta is not None:
                diffusion_args["g_m"] = min(diffusion_model.gcontrol_args["g_m"] + gm_delta, 1)
            if gs_delta is not None:
                diffusion_args["g_s"] = diffusion_model.gcontrol_args["g_s"] + gs_delta
            # `.get` because a delta of None leaves the key possibly unset (None is logged in that case).
            logger.info(f"Updated g_m to: {diffusion_args.get('g_m')} and g_s to: {diffusion_args.get('g_s')}")

    return im_tensor, {
        "Target Adversarial Class": adv_idx,
        "NA Detected": na_detected,
        "Adversarial Success": samp_success,
    }


def similarity_natural_adversarial_sample(
    diffusion_model: GCStableDiffusionPipeline,
    class_idx: int,
    np_rng_generator: np.random.Generator,
    pt_rng_generator: torch.Generator,
    logger: logging.Logger,
    diffusion_args: dict,
    prompt_format: str,
    sampling_attempts: int,
    gm_delta: Optional[float] = None,
    gs_delta: Optional[float] = 20,
    distance_metric: str = "cosine",
    nth_closest: int = 1,
) -> Tuple[Optional[torch.Tensor], dict]:
    r"""
    Performs one experiment sample step and returns the sampled image and sampling data. Adversarial label selected
    based on proximity to the true label in the embedding space.

    Every attempt starts from the same initial latents. After each attempt, the image is quantised to uint8 (as when
    saving to .png) and re-classified with the guidance controller's classifier. Sampling stops once the prediction
    equals the adversarial class. Between unsuccessful attempts, g_m / g_s are raised by `gm_delta` / `gs_delta`.

    :param diffusion_model: GCStableDiffusionPipeline
        Diffusion pipeline.
    :param class_idx: int
        The ImageNet class index of the diffusion target.
    :param np_rng_generator: np.random.Generator
        The numpy random number generator (not used by this sampler; accepted for a uniform sampler interface).
    :param pt_rng_generator: torch.Generator
        The torch random number generator, used for the initial latents and passed to the diffusion pipeline.
    :param logger: logging.Logger
        A python logger to print progress updates to.
    :param diffusion_args: dict
        A dictionary containing the arguments to be passed to the diffusion pipeline. It is not modified.
    :param prompt_format: str
        A string denoting the format of the diffusion prompt. Replaces `<class name>` with the target ImageNet class.
    :param sampling_attempts: int
        The number of sampling attempts to perform at any given sample stage, i.e., how many times to resample if the
        sampler fails to achieve the desired adversarial class or a nan value is encountered. Must be at least 1.
    :param gm_delta: Optional[float]
        The amount to increase gm ($\mu$) by in the case that the sample is not adversarial. Note that gm can not be
        increased past 1 (anything over will be clipped to 1). ``None`` leaves gm unchanged.
    :param gs_delta: Optional[float]
        The amount to increase gs ($s$) by in the case that the sample is not adversarial. ``None`` leaves gs
        unchanged.
    :param distance_metric: str
        The distance metric used when computing the distance between class embeddings in the CLIP embedding space, one
        of "l2" or "cosine".
    :param nth_closest: int
        How close the adversarial class should be to the true class in the embedding space. A value of n is the
        nth closest ImageNet class in the CLIP embedding space. Note that a value of 1 is the closest class in the
        CLIP embedding space that is not the true class. Must be at least 1.
    :return: Tuple[Optional[torch.Tensor], dict]
        The image from the last attempt made (``None`` if it contained NaN values) and a dictionary with keys
        "Target Adversarial Class", "NA Detected" and "Adversarial Success".
    :raises ValueError:
        If `sampling_attempts` or `nth_closest` is less than 1, or the guidance controller has no classifier.
    """
    if sampling_attempts < 1:
        # Without at least one attempt there is no image (or NaN flag) to return.
        raise ValueError(f"sampling_attempts must be at least 1, got {sampling_attempts}")
    if nth_closest < 1:
        raise ValueError(f"nth_closest must be at least 1, got {nth_closest}")

    # NOTE(review): assumes `id2label` iterates in class-index order, so row/column i of `dist_mat` is class i.
    embeddings = get_embedding(
        list(IMAGENET_CLASSES["id2label"].values()), diffusion_model=diffusion_model, embedding_type="eos"
    )
    dist_mat = get_distance_mat(embedding=embeddings, distance_type=distance_metric)
    # Take one extra candidate so that dropping `class_idx` still leaves `nth_closest` classes. Removing `class_idx`
    # by value (rather than assuming it ranks first) keeps the choice correct if other classes tie with it, e.g. when
    # several labels produce identical embeddings.
    closest = torch.topk(dist_mat[class_idx], k=nth_closest + 1, largest=False, sorted=True).indices.tolist()
    candidates = [idx for idx in closest if idx != class_idx]
    adv_idx = candidates[nth_closest - 1]

    logger.info(f"Target diffusion class: {class_idx}, target adv class {adv_idx}")

    # A single latent draw, reused by every attempt. Height/width default to 512 when absent (or explicitly None).
    height = diffusion_args.get("height")
    width = diffusion_args.get("width")
    latent_shape = (
        1,
        diffusion_model.unet.config.in_channels,
        (512 if height is None else int(height)) // diffusion_model.vae_scale_factor,
        (512 if width is None else int(width)) // diffusion_model.vae_scale_factor,
    )
    noise_patch = torch.randn(
        size=latent_shape, generator=pt_rng_generator, dtype=diffusion_model.dtype, device=diffusion_model.device
    )

    # Only the first comma-separated synonym of each ImageNet label is used; these strings never change between
    # attempts.
    class_name = IMAGENET_CLASSES["id2label"][class_idx].split(",")[0]
    adv_name = IMAGENET_CLASSES["id2label"][adv_idx].split(",")[0]
    prompt = prompt_format.replace("<class name>", class_name)
    gprompt = f"{adv_name} and {class_name}"

    # Copy so the g_m / g_s updates below never mutate the caller's dictionary.
    diffusion_args = diffusion_args.copy()

    im_tensor = None
    na_detected = False
    samp_success = False
    for attempt in range(1, sampling_attempts + 1):
        im_tensor = diffusion_model(
            prompt=prompt,
            gprompt=gprompt,
            target_idx=adv_idx,
            output_type="pt",
            generator=pt_rng_generator,
            latents=noise_patch,
            **diffusion_args,
        ).images[0]

        na_detected = bool(im_tensor.isnan().any())
        if na_detected:
            logger.warning("NA detected during sampling")
            im_tensor = None
        else:
            controller = diffusion_model.guidance_controller
            if controller.classifier is None:
                raise ValueError("The guidance controller has no classifier to verify the adversarial target with")
            # Checking if adversarial sample survives conversion to .png (rounding to uint8).
            # NOTE(review): pixel values stay in [0, 255] and are cast to bfloat16. Confirm that the
            # preprocessor/classifier expect this range and dtype.
            im_tensor_255 = torch.clamp(im_tensor * 255 + 0.5, 0, 255).to(dtype=torch.uint8).to(dtype=torch.bfloat16)
            im_tensor_255 = im_tensor_255.unsqueeze(0)
            # Inference-only check: no autograd graph is needed.
            with torch.no_grad():
                if getattr(controller, "preprocessor", None) is not None:
                    im_tensor_255 = controller.preprocessor(im_tensor_255)
                class_logits = controller.classifier(im_tensor_255)
            pred_class = torch.argmax(class_logits, dim=-1).item()
            samp_success = pred_class == adv_idx
            logger.info(
                f"Target diffusion class: {class_idx}, target adv class {adv_idx}, predicted class {pred_class}"
            )

        if samp_success:
            break

        # Strengthen the guidance before the next attempt (none after the last one).
        if attempt < sampling_attempts:
            # NOTE(review): the increment is applied to `diffusion_model.gcontrol_args`, not to `diffusion_args`. It
            # only accumulates across retries if the pipeline stores the latest call's g_m / g_s there — confirm.
            if gm_delta is not None:
                diffusion_args["g_m"] = min(diffusion_model.gcontrol_args["g_m"] + gm_delta, 1)
            if gs_delta is not None:
                diffusion_args["g_s"] = diffusion_model.gcontrol_args["g_s"] + gs_delta
            # `.get` because a delta of None leaves the key possibly unset (None is logged in that case).
            logger.info(f"Updated g_m to: {diffusion_args.get('g_m')} and g_s to: {diffusion_args.get('g_s')}")

    return im_tensor, {
        "Target Adversarial Class": adv_idx,
        "NA Detected": na_detected,
        "Adversarial Success": samp_success,
    }


#######################################################################################################################
