### Imports ###########################################################################################################

import pytest
import numpy as np
import torch
import timm
from transformers import AutoImageProcessor, ResNetForImageClassification
from diffusers.schedulers import EulerDiscreteScheduler
from torchvision.transforms.v2 import Grayscale

from gcontrol.diffusion_pipelines import GCStableDiffusionPipeline
from gcontrol.guidance_controllers.stable_diffusion import AdversarialClassifierGuidance, ClassifierGuidance
from gcontrol.utils import get_timm_config

from conftest import (
    PRETRAINED_MODEL_PATH,
    DATASET_PATH,
)

#######################################################################################################################

# Absolute / relative numerical tolerances.
# NOTE(review): not referenced by the tests in this section — presumably intended for
# np.allclose-style comparisons elsewhere; confirm they are still needed.
ATOL, RTOL = 1e-8, 1e-8

### Tests #############################################################################################################


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda not enabled")
@pytest.mark.slow
@pytest.mark.parametrize(
    """
        model_id,
        dtype,
        cache_dir,
        num_inference_steps,
        guidance_scale,
        grad_norm,
        time_travel_sample,
        height,
        width,
        offload_strategy
    """,
    [
        (
            "sd-legacy/stable-diffusion-v1-5",
            torch.float16,
            PRETRAINED_MODEL_PATH,
            5,
            5,
            None,
            None,
            512,
            512,
            None,
        ),
        (
            "sd-legacy/stable-diffusion-v1-5",
            torch.float16,
            PRETRAINED_MODEL_PATH,
            5,
            5,
            None,
            None,
            512,
            512,
            "enable_sequential_cpu_offload",
        ),
        (
            "sd-legacy/stable-diffusion-v1-5",
            torch.float16,
            PRETRAINED_MODEL_PATH,
            5,
            5,
            None,
            None,
            512,
            512,
            "enable_model_cpu_offload",
        ),
        (
            "sd-legacy/stable-diffusion-v1-5",
            torch.float16,
            PRETRAINED_MODEL_PATH,
            5,
            5,
            2,
            None,
            512,
            512,
            None,
        ),
        (
            "sd-legacy/stable-diffusion-v1-5",
            torch.float16,
            PRETRAINED_MODEL_PATH,
            5,
            5,
            None,
            2,
            512,
            512,
            None,
        ),
    ],
)
def test_classifier_free(
    model_id,
    dtype,
    cache_dir,
    num_inference_steps,
    guidance_scale,
    grad_norm,
    time_travel_sample,
    height,
    width,
    offload_strategy,
):
    """Smoke-test classifier-free guidance sampling with ``GCStableDiffusionPipeline``.

    Loads the pipeline, optionally enables the named diffusers CPU-offload strategy,
    samples a batch of two identical prompts with a seeded generator, and checks that
    every generated image has the requested size and contains only finite values.
    """
    pipe = GCStableDiffusionPipeline.from_pretrained(
        model_id, torch_dtype=dtype, cache_dir=cache_dir, use_safetensors=True
    )

    if offload_strategy is None:
        pipe = pipe.to("cuda")
    else:
        # diffusers advises not moving the pipeline to the GPU before enabling CPU
        # offload: the offload helpers manage device placement themselves, so an
        # explicit .to("cuda") first is only a wasted round trip.
        getattr(pipe, offload_strategy)()

    generator = torch.Generator(device="cuda").manual_seed(1234)

    prompt = 2 * ["A photo-realistic natural image with a single realistic tiger, 8k"]
    output = pipe(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        grad_norm=grad_norm,
        time_travel_sample=time_travel_sample,
        height=height,
        width=width,
        generator=generator,
        # Ask for float arrays: PIL output is uint8, where NaNs have already been cast
        # away, so a NaN check on PIL images can never fail.
        output_type="np",
    )

    images = np.asarray(output.images)
    assert images.shape[:3] == (len(prompt), height, width)
    assert np.isfinite(images).all()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda not enabled")
@pytest.mark.slow
@pytest.mark.parametrize(
    """
        model_id,
        dtype,
        classifier_id,
        cache_dir,
        gprompt,
        num_inference_steps,
        g_w,
        g_p,
        g_m,
        g_s,
        grad_norm,
        augmentations,
        time_travel_sample,
        height,
        width,
        offload_strategy
    """,
    [
        (
            "sd-legacy/stable-diffusion-v1-5",
            torch.float16,
            "microsoft/resnet-50",
            PRETRAINED_MODEL_PATH,
            None,
            5,
            7,
            7,
            None,
            5,
            None,
            None,
            None,
            400,
            400,
            None,
        ),
        (
            "sd-legacy/stable-diffusion-v1-5",
            torch.float16,
            "microsoft/resnet-50",
            PRETRAINED_MODEL_PATH,
            None,
            5,
            7,
            7,
            None,
            5,
            None,
            None,
            None,
            400,
            400,
            "enable_sequential_cpu_offload",
        ),
        (
            "sd-legacy/stable-diffusion-v1-5",
            torch.float16,
            "microsoft/resnet-50",
            PRETRAINED_MODEL_PATH,
            None,
            5,
            7,
            7,
            None,
            5,
            None,
            None,
            None,
            400,
            400,
            "enable_model_cpu_offload",
        ),
        (
            "sd-legacy/stable-diffusion-v1-5",
            torch.float16,
            "microsoft/resnet-50",
            PRETRAINED_MODEL_PATH,
            None,
            5,
            7,
            7,
            None,
            5,
            2,
            None,
            None,
            400,
            400,
            None,
        ),
        (
            "sd-legacy/stable-diffusion-v1-5",
            torch.float16,
            "microsoft/resnet-50",
            PRETRAINED_MODEL_PATH,
            None,
            5,
            7,
            7,
            None,
            5,
            None,
            Grayscale(3),
            2,
            400,
            400,
            None,
        ),
        (
            "sd-legacy/stable-diffusion-v1-5",
            torch.float16,
            "microsoft/resnet-50",
            PRETRAINED_MODEL_PATH,
            2 * ["A photo-realistic natural image with a single realistic water buffalo, 8k"],
            5,
            7,
            7,
            0.5,
            5,
            None,
            None,
            None,
            400,
            400,
            None,
        ),
    ],
)
def test_product_adversarial(
    model_id,
    dtype,
    classifier_id,
    cache_dir,
    gprompt,
    num_inference_steps,
    g_w,
    g_p,
    g_m,
    g_s,
    grad_norm,
    augmentations,
    time_travel_sample,
    height,
    width,
    offload_strategy,
):
    """Smoke-test ``AdversarialClassifierGuidance`` with a Hugging Face ResNet classifier.

    Builds the guidance controller from the classifier and its image-processor config,
    swaps in an ``EulerDiscreteScheduler``, optionally enables the named diffusers
    CPU-offload strategy, samples a batch of two identical prompts with a seeded
    generator, and checks that every generated image has the requested size and
    contains only finite values.
    """
    # Loading ResNet model from disk
    ## Loading Processor
    resnet_processor = AutoImageProcessor.from_pretrained(
        pretrained_model_name_or_path=classifier_id, cache_dir=cache_dir
    )

    ## Loading Model
    resnet_model = ResNetForImageClassification.from_pretrained(
        pretrained_model_name_or_path=classifier_id, cache_dir=cache_dir
    )

    # Loading Stable diffusion model from disk
    pipe = GCStableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        cache_dir=cache_dir,
        use_safetensors=True,
        guidance_controller=AdversarialClassifierGuidance(resnet_model, **resnet_processor.to_dict()),
    )
    # `_exclude_from_cpu_offload` is declared as a class attribute on diffusers pipelines.
    # Extending it in place would mutate the list shared by every pipeline instance, and
    # append duplicates on each parametrization. Rebinding it creates a per-instance copy.
    pipe._exclude_from_cpu_offload = [*pipe._exclude_from_cpu_offload, "vae", "unet"]
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

    if offload_strategy is None:
        pipe = pipe.to("cuda")
    else:
        # diffusers advises not moving the pipeline to the GPU before enabling CPU
        # offload: the offload helpers manage device placement themselves.
        getattr(pipe, offload_strategy)()

    generator = torch.Generator(device="cuda").manual_seed(1234)

    prompt = 2 * ["A photo-realistic natural image with a single realistic tiger, 8k"]
    output = pipe(
        prompt=prompt,
        gprompt=gprompt,
        num_inference_steps=num_inference_steps,
        g_w=g_w,
        g_p=g_p,
        g_m=g_m,
        g_s=g_s,
        grad_norm=grad_norm,
        augmentations=augmentations,
        time_travel_sample=time_travel_sample,
        height=height,
        width=width,
        generator=generator,
        # Float arrays keep NaNs visible; uint8 PIL output would hide them.
        output_type="np",
    )

    images = np.asarray(output.images)
    assert images.shape[:3] == (len(prompt), height, width)
    assert np.isfinite(images).all()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda not enabled")
@pytest.mark.slow
@pytest.mark.parametrize(
    """
        model_id,
        dtype,
        classifier_id,
        cache_dir,
        num_inference_steps,
        g_w,
        grad_norm,
        augmentations,
        time_travel_sample,
        height,
        width,
        offload_strategy
    """,
    [
        (
            "sd-legacy/stable-diffusion-v1-5",
            torch.float16,
            "microsoft/resnet-50",
            PRETRAINED_MODEL_PATH,
            5,
            5,
            None,
            None,
            None,
            400,
            400,
            None,
        ),
        (
            "sd-legacy/stable-diffusion-v1-5",
            torch.float16,
            "microsoft/resnet-50",
            PRETRAINED_MODEL_PATH,
            5,
            5,
            None,
            None,
            None,
            400,
            400,
            "enable_sequential_cpu_offload",
        ),
        (
            "sd-legacy/stable-diffusion-v1-5",
            torch.float16,
            "microsoft/resnet-50",
            PRETRAINED_MODEL_PATH,
            5,
            5,
            None,
            None,
            None,
            400,
            400,
            "enable_model_cpu_offload",
        ),
        (
            "sd-legacy/stable-diffusion-v1-5",
            torch.float16,
            "microsoft/resnet-50",
            PRETRAINED_MODEL_PATH,
            5,
            5,
            2,
            None,
            None,
            400,
            400,
            None,
        ),
        (
            "sd-legacy/stable-diffusion-v1-5",
            torch.float16,
            "microsoft/resnet-50",
            PRETRAINED_MODEL_PATH,
            5,
            5,
            None,
            None,
            2,
            400,
            400,
            None,
        ),
        (
            "sd-legacy/stable-diffusion-v1-5",
            torch.float16,
            "microsoft/resnet-50",
            PRETRAINED_MODEL_PATH,
            5,
            5,
            None,
            Grayscale(3),
            None,
            400,
            400,
            None,
        ),
    ],
)
def test_classifier(
    model_id,
    dtype,
    classifier_id,
    cache_dir,
    num_inference_steps,
    g_w,
    grad_norm,
    augmentations,
    time_travel_sample,
    height,
    width,
    offload_strategy,
):
    """Smoke-test ``ClassifierGuidance`` with a Hugging Face ResNet classifier.

    Builds the guidance controller from the classifier and its image-processor config,
    swaps in an ``EulerDiscreteScheduler``, optionally enables the named diffusers
    CPU-offload strategy, samples a batch of two identical prompts with a seeded
    generator, and checks that every generated image has the requested size and
    contains only finite values.
    """
    # Loading ResNet model from disk
    ## Loading Processor
    resnet_processor = AutoImageProcessor.from_pretrained(
        pretrained_model_name_or_path=classifier_id, cache_dir=cache_dir
    )

    ## Loading Model
    resnet_model = ResNetForImageClassification.from_pretrained(
        pretrained_model_name_or_path=classifier_id, cache_dir=cache_dir
    )

    # Loading Stable diffusion model from disk
    pipe = GCStableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        cache_dir=cache_dir,
        use_safetensors=True,
        guidance_controller=ClassifierGuidance(resnet_model, **resnet_processor.to_dict()),
    )
    # `_exclude_from_cpu_offload` is declared as a class attribute on diffusers pipelines.
    # Extending it in place would mutate the list shared by every pipeline instance.
    # Rebinding it creates a per-instance copy with the extra exclusions.
    pipe._exclude_from_cpu_offload = [*pipe._exclude_from_cpu_offload, "vae", "unet"]
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

    if offload_strategy is None:
        pipe = pipe.to("cuda")
    else:
        # diffusers advises not moving the pipeline to the GPU before enabling CPU
        # offload: the offload helpers manage device placement themselves.
        getattr(pipe, offload_strategy)()

    generator = torch.Generator(device="cuda").manual_seed(1234)

    prompt = 2 * ["A photo-realistic natural image with a single realistic tiger, 8k"]
    output = pipe(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        g_w=g_w,
        grad_norm=grad_norm,
        augmentations=augmentations,
        time_travel_sample=time_travel_sample,
        height=height,
        width=width,
        generator=generator,
        # Float arrays keep NaNs visible; uint8 PIL output would hide them.
        output_type="np",
    )

    images = np.asarray(output.images)
    assert images.shape[:3] == (len(prompt), height, width)
    assert np.isfinite(images).all()


@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda not enabled")
@pytest.mark.slow
@pytest.mark.parametrize(
    """
        model_id,
        dtype,
        classifier_id,
        cache_dir,
        num_inference_steps,
        height,
        width,
    """,
    [
        (
            "sd-legacy/stable-diffusion-v1-5",
            torch.float16,
            "inception_v4",
            PRETRAINED_MODEL_PATH,
            5,
            400,
            400,
        ),
    ],
)
def test_timm_classifier(model_id, dtype, classifier_id, cache_dir, num_inference_steps, height, width):
    """Smoke-test ``ClassifierGuidance`` driven by a pretrained ``timm`` classifier.

    The classifier is put in eval mode on the GPU with the same ``dtype`` as the
    diffusion pipeline, and its config comes from ``get_timm_config``. The test then
    samples a batch of two identical prompts with a seeded generator and checks that
    every generated image has the requested size and contains only finite values.
    """
    # Loading `timm` model from disk.
    # Use the parametrized dtype (not a hard-coded float16) so the classifier always
    # matches the pipeline's precision.
    inception_model = (
        timm.create_model(
            model_name=classifier_id,
            pretrained=True,
        )
        .eval()
        .to(device="cuda", dtype=dtype)
    )

    inception_config = get_timm_config(inception_model)

    # Loading Stable diffusion model from disk
    pipe = GCStableDiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        cache_dir=cache_dir,
        use_safetensors=True,
        guidance_controller=ClassifierGuidance(inception_model, **inception_config),
    )
    # `_exclude_from_cpu_offload` is declared as a class attribute on diffusers pipelines.
    # Extending it in place would mutate the list shared by every pipeline instance.
    # Rebinding it creates a per-instance copy with the extra exclusions.
    pipe._exclude_from_cpu_offload = [*pipe._exclude_from_cpu_offload, "vae", "unet"]
    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
    pipe = pipe.to("cuda")

    generator = torch.Generator(device="cuda").manual_seed(1234)

    prompt = 2 * ["A photo-realistic natural image with a single realistic tiger, 8k"]
    output = pipe(
        prompt=prompt,
        num_inference_steps=num_inference_steps,
        height=height,
        width=width,
        generator=generator,
        # Float arrays keep NaNs visible; uint8 PIL output would hide them.
        output_type="np",
    )

    images = np.asarray(output.images)
    assert images.shape[:3] == (len(prompt), height, width)
    assert np.isfinite(images).all()


#######################################################################################################################
