# %%
from functools import partial
import os
from matplotlib.lines import Line2D

import numpy as np
import nilearn
import nibabel as nib
import matplotlib.pyplot as plt
from matplotlib import cm, ticker
import cortex
import torch

# %%
from config_utils import load_from_yaml
from datamodule import AllDatamodule, build_dm
from models import VEModel

device = "cuda:0"
# device = 'cpu'
# %%
from diffusers import StableDiffusionPipeline
from diffusers import UNet2DConditionModel, AutoencoderKL

# %%
# Stable Diffusion v1.5: fp16 weights on GPU to save memory, fp32 on CPU.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16 if device != "cpu" else torch.float32,
)
# Memory savers: sliced attention plus sliced/tiled VAE decoding.
pipe.enable_attention_slicing("max")
pipe.enable_vae_slicing()
pipe.enable_vae_tiling()
# diffusers documents xformers memory-efficient attention as GPU-only, so the
# `device = 'cpu'` configuration above would fail with it enabled; turn it on for
# non-CPU devices only.
if device != "cpu":
    pipe.enable_xformers_memory_efficient_attention()
# %%
from diffusers import (
    DDIMScheduler,
    DPMSolverSinglestepScheduler,
    DPMSolverMultistepScheduler,
)

# Swap in the DPM-Solver single-step scheduler, reusing the current scheduler config.
pipe.scheduler = DPMSolverSinglestepScheduler.from_config(pipe.scheduler.config)

# %%
pipe = pipe.to(device)
# # %%
prompt = "a photo from COCO dataset"
# image = pipe(prompt, num_inference_steps=10).images[0]
# plt.imshow(image)
# plt.show()
# %%
# @torch.no_grad()
# @replace_example_docstring(EXAMPLE_DOC_STRING)
from typing import Any, Callable, Dict, List, Optional, Union


def my_decode_latents(self, latents):
    """Decode VAE latents into an image tensor without clamping or CPU conversion.

    The latents are first un-scaled by ``self.vae.config.scaling_factor`` and then
    decoded with ``self.vae``. The decoded values are mapped with ``x / 2 + 0.5``.
    Unlike the stock pipeline, the result is neither clamped nor moved to NumPy, so
    it stays a tensor that gradients can flow through.

    Args:
        self: the pipeline object providing ``vae``.
        latents: latent tensor to decode.

    Returns:
        The decoded image tensor.
    """
    inv_scale = 1 / self.vae.config.scaling_factor
    decoded = self.vae.decode(inv_scale * latents).sample
    return decoded / 2 + 0.5


def mycall(
    self,
    prompt: Union[str, List[str]] = None,
    height: Optional[int] = None,
    width: Optional[int] = None,
    num_inference_steps: int = 10,
    start_step: int = 5,
    end_step: int = 10,
    guidance_scale: float = 7.5,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    num_images_per_prompt: Optional[int] = 1,
    eta: float = 0.0,
    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
    latents: Optional[torch.FloatTensor] = None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    output_type: Optional[str] = "tensor",
    return_dict: bool = True,
    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
    callback_steps: int = 1,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
    r"""
    Partial, differentiable variant of ``StableDiffusionPipeline.__call__``.

    Differences from the stock pipeline call:

    * Only the scheduler timesteps ``timesteps[start_step:end_step]`` are run, so one
      denoising trajectory can be split across several calls.
    * No ``torch.no_grad()`` is applied, so gradients can flow back into ``latents``.
    * Caller-supplied ``latents`` are used as-is. They are not re-scaled by the
      scheduler, so a partially denoised latent can be resumed.
    * With the default ``output_type="tensor"`` the decoded image is returned as an
      unclamped tensor rather than PIL images.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. If not defined, `prompt_embeds` must be
            passed instead.
        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            Image height in pixels. Used for input validation and, when `latents` is None, to size the
            freshly sampled latents.
        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
            Image width in pixels; same usage as `height`.
        num_inference_steps (`int`, *optional*, defaults to 10):
            Number of timesteps the scheduler is configured with. `start_step`/`end_step` index into
            this schedule.
        start_step (`int`, *optional*, defaults to 5):
            Index of the first scheduler timestep to run.
        end_step (`int`, *optional*, defaults to 10):
            Index one past the last timestep to run (Python slice semantics).
        guidance_scale (`float`, *optional*, defaults to 7.5):
            Classifier-free guidance weight `w` (eq. 2 of the Imagen paper). Guidance is applied only when
            `guidance_scale > 1`.
        negative_prompt (`str` or `List[str]`, *optional*):
            Prompt(s) for the unconditional branch of classifier-free guidance. Ignored when guidance is off.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            Number of images per prompt. Forwarded to prompt encoding and to latent sampling.
        eta (`float`, *optional*, defaults to 0.0):
            DDIM eta. Only applies to [`schedulers.DDIMScheduler`]; it is ignored by other schedulers.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            RNG(s) forwarded to the scheduler step kwargs and, when `latents` is None, to latent sampling.
        latents (`torch.FloatTensor`, *optional*):
            Starting latents for timestep `timesteps[start_step]`. They are used unmodified. If None, fresh
            latents are drawn with `self.prepare_latents`.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-computed text embeddings; used instead of encoding `prompt`.
        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-computed negative text embeddings; used instead of encoding `negative_prompt`.
        output_type (`str`, *optional*, defaults to `"tensor"`):
            `"latent"` returns the final latents. `"pil"` decodes, runs the safety checker and converts to
            PIL images. Any other value decodes with `my_decode_latents` and returns the image tensor,
            mapped with `x / 2 + 0.5`, not clamped.
        return_dict (`bool`, *optional*, defaults to `True`):
            If True, return only the image. If False, return the tuple `(image, has_nsfw_concept)`.
        callback (`Callable`, *optional*):
            Accepted for signature compatibility only; it is never invoked.
        callback_steps (`int`, *optional*, defaults to 1):
            Passed to `check_inputs` for validation only.
        cross_attention_kwargs (`dict`, *optional*):
            Forwarded to the UNet as `cross_attention_kwargs`.

    Returns:
        The image in the format selected by `output_type`, or `(image, has_nsfw_concept)` when
        `return_dict` is False. `has_nsfw_concept` is None unless `output_type == "pil"`.
    """
    # 0. Default height and width to unet
    height = height or self.unet.config.sample_size * self.vae_scale_factor
    width = width or self.unet.config.sample_size * self.vae_scale_factor

    # 1. Check inputs. Raise error if not correct
    self.check_inputs(
        prompt,
        height,
        width,
        callback_steps,
        negative_prompt,
        prompt_embeds,
        negative_prompt_embeds,
    )

    # 2. Define call parameters
    if isinstance(prompt, str):
        batch_size = 1
    elif isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device
    # guidance_scale == 1 (or below) means no classifier-free guidance.
    do_classifier_free_guidance = guidance_scale > 1.0

    # 3. Encode input prompt. Deliberately not wrapped in torch.no_grad() so the
    #    whole call stays differentiable.
    prompt_embeds = self._encode_prompt(
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt,
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
    )

    # 4. Prepare the full schedule, then keep only the requested window.
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps[start_step:end_step]

    # 5. Latents: caller-provided latents are used untouched (no init-noise scaling)
    #    so a partially denoised latent can be resumed. Only draw fresh noise when
    #    none was given.
    if latents is None:
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            self.unet.config.in_channels,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            None,
        )

    # 6. Prepare extra step kwargs (eta / generator, depending on the scheduler).
    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

    # 7. Denoising loop over the selected timesteps.
    for t in timesteps:
        # Duplicate the latents for the unconditional + text branches of CFG.
        latent_model_input = (
            torch.cat([latents] * 2) if do_classifier_free_guidance else latents
        )
        latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual.
        noise_pred = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            cross_attention_kwargs=cross_attention_kwargs,
        ).sample

        # Classifier-free guidance.
        if do_classifier_free_guidance:
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        # Compute the previous noisy sample x_t -> x_t-1.
        latents = self.scheduler.step(
            noise_pred, t, latents, **extra_step_kwargs
        ).prev_sample

    # 8-10. Post-processing. Default the NSFW flag so every branch can be returned
    # as a tuple when return_dict=False.
    has_nsfw_concept = None
    if output_type == "latent":
        image = latents
    elif output_type == "pil":
        image = self.decode_latents(latents)
        image, has_nsfw_concept = self.run_safety_checker(
            image, device, prompt_embeds.dtype
        )
        image = self.numpy_to_pil(image)
    else:
        # Differentiable decode; the safety checker is skipped on this path.
        image = my_decode_latents(self, latents)

    if not return_dict:
        return (image, has_nsfw_concept)

    return image


# %%
# Attach the partial/differentiable sampler to this pipeline instance:
# pipe.mycall(...) calls mycall(pipe, ...).
pipe.mycall = partial(mycall, pipe)
# Prompt used by the generations below; replaces the earlier COCO prompt.
prompt = "a photo of a man surfing"
# prompt = ""


# # %%
# latents = torch.randn(1, 4, 64, 64).to(device).half()
# latents.requires_grad = True
# # %%
# with torch.no_grad():
#     latents = pipe.mycall(
#         prompt,
#         start_step=0,
#         end_step=6,
#         num_inference_steps=10,
#         latents=latents,
#         output_type='latent',
#     )
# # %%
# # with torch.no_grad():
# latents.requires_grad = True
# image = pipe.mycall(
#     prompt,
#     start_step=6,
#     num_inference_steps=10,
#     latents=latents,
# )
# # reshape to (1, 3, 224, 224)
# from torch.nn.functional import interpolate
# image = interpolate(image, size=(224, 224), mode='bilinear', align_corners=False)

# # %%
# plt.imshow(image[0].permute(1, 2, 0).detach().cpu().numpy().astype(np.float32))
# plt.show()
# image.sum().backward()
# print(latents.grad.sum())
# latents.grad = None
# torch.cuda.empty_cache()


def show_latent(latents, num_inference_steps=15, size=224):
    """Denoise ``latents`` with the global ``pipe``/``prompt`` and display the result.

    The whole schedule (steps ``0`` .. ``num_inference_steps``) runs without
    classifier-free guidance and with a generator seeded to 0. The decoded image is
    resized to ``size`` x ``size`` and shown with matplotlib. Nothing is returned.

    Args:
        latents: starting latents, passed unchanged to ``pipe.mycall``.
        num_inference_steps: number of scheduler steps (default 15). Pass the
            step count used during optimisation to preview exactly the image that
            is fed to the model there.
        size: side length in pixels of the displayed image (default 224).
    """
    from torch.nn.functional import interpolate

    with torch.no_grad():
        image = pipe.mycall(
            prompt,
            start_step=0,
            end_step=num_inference_steps,
            num_inference_steps=num_inference_steps,
            latents=latents,
            guidance_scale=0,
            generator=torch.Generator(device=device).manual_seed(0),
        )
        # Resize to (1, 3, size, size).
        image = interpolate(
            image, size=(size, size), mode="bilinear", align_corners=False
        )

    # The decoded tensor is not clamped. Clamp it here so imshow does not have to
    # clip it (with a warning) on every call; the displayed image is the same.
    frame = image[0].permute(1, 2, 0).clamp(0, 1).cpu().numpy().astype(np.float32)
    plt.figure(figsize=(3, 3))
    plt.imshow(frame)
    plt.show()


# %%
# Build the datamodule from the experiment config. Batch size 1, so `batch` below
# holds a single sample.
cfg = load_from_yaml("/workspace/configs/dino_mania.yaml")
cfg.DATAMODULE.BATCH_SIZE = 1
dm: AllDatamodule = build_dm(cfg)
dm.setup()

# VEModel constructor arguments, taken from the datamodule's dataset metadata.
model_args = (
    cfg,
    dm.num_voxel_dict,
    dm.roi_dict,
    dm.neuron_coords_dict,
    dm.noise_ceiling_dict,
)

model = VEModel(*model_args)

path = "/data/results/xgaa/yesgt_1/soup.pth"
# Load on CPU first; the model is moved to `device` below. torch.load unpickles,
# so only point this at trusted checkpoints.
sd = torch.load(path, map_location="cpu")
# NOTE(review): strict=False silently ignores missing/unexpected keys and the
# returned report is discarded. Consider printing it to confirm the checkpoint
# matches the model.
model.load_state_dict(sd, strict=False)
# remove backbone no grad
# presumably lets gradients reach the input image through the backbone, which the
# latent optimisation needs. Confirm against VEModel.
model.backbone.yes_grad = True
# NOTE(review): eval() is commented out, so the model stays in train mode. If VEModel
# has dropout/batch-norm layers, predictions during optimisation are stochastic.
# Confirm intent.
# model.eval()
model = model.to(device)
model.move_device()
# %%
# One validation batch for subject NSD_01 (unpacked in the next cell).
batch = next(iter(dm.val_dataloader(subject="NSD_01")))


# %%
def show_img(img):
    """Plot the first image of a normalised batch tensor with matplotlib.

    The per-channel normalisation is undone with the standard ImageNet mean/std
    values before plotting. ``img`` is assumed to be (batch, 3, H, W); TODO confirm.
    """
    channel_std = np.array([0.229, 0.224, 0.225])
    channel_mean = np.array([0.485, 0.456, 0.406])
    first = img[0]
    hwc = first.permute(1, 2, 0).detach().cpu().numpy().astype(np.float32)
    # denormalize: x * std + mean
    plt.imshow(hwc * channel_std + channel_mean)
    plt.show()


# Unpack the single validation sample. `y` is compared against the model's voxel
# predictions in the optimisation loop.
img, y, subject_id, session_id, eye_coords, darkness = batch
show_img(img)
# Drop the batch dimension (batch size is 1) and move the responses to `device`.
y = y[0].to(device)
# %%

# Random initial latents, fp16 on GPU to match the fp16 pipeline. No seed is set
# here, so each run starts from different noise. (1, 4, 64, 64) is presumably the
# SD-v1.5 latent grid for 512x512 images; confirm.
latents = torch.randn(
    1, 4, 64, 64, dtype=torch.float32 if device == "cpu" else torch.float16
).to(device)
# Track gradients w.r.t. the latents (the optimisation cell re-creates this leaf).
latents.requires_grad = True
# with torch.no_grad():
#     latents = pipe.mycall(
#         prompt,
#         start_step=0,
#         end_step=6,
#         num_inference_steps=10,
#         latents=latents,
#         output_type='latent',
#     )
# Preview the image these starting latents denoise to.
show_latent(latents)
# %%
torch.cuda.empty_cache()
from torch.nn.functional import interpolate

# Optimise the initial latents so that the encoding model's prediction for the
# generated image matches `y` on the "mid" ROI voxels of NSD_01.
n_iter = 50
lr = 0.03
beta = 0.9  # decay of the gradient EMA kept in `grad_momentum`

# Loop-invariant: the selected voxel indices and the matching slice of the target.
# voxel_index = torch.randperm(len(y))[:8000]  # (move into the loop to resample per step)
voxel_index = torch.tensor(dm.roi_dict["NSD_01"]["mid"]).long()
target = y.float()[voxel_index]

latents = latents.detach().clone()
latents.requires_grad = True
grad_momentum = torch.zeros_like(latents)
for i in range(n_iter):
    # Differentiable 6-step denoise from the current latents (seed 0, no CFG).
    image = pipe.mycall(
        prompt,
        start_step=0,
        end_step=6,
        num_inference_steps=6,
        latents=latents,
        guidance_scale=0,
        generator=torch.Generator(device=device).manual_seed(0),
    )
    # Resize to (1, 3, 224, 224).
    image = interpolate(image, size=(224, 224), mode="bilinear", align_corners=False)

    neuron_out = model(
        image.float(),
        subject_id,
        session_id,
        voxel_indices_dict={"NSD_01": voxel_index},
    )[0][0]
    # Mean absolute error between predicted and measured responses.
    loss = (neuron_out - target).abs().mean()
    loss.backward()
    print(loss.item())

    if i % 10 == 0:
        show_latent(latents)

    # Ternarise the gradient in place: +1 above its 90th percentile, -1 below its
    # 10th percentile, 0 in between.
    grad = latents.grad
    grad_np = grad.detach().cpu().numpy()
    upper = np.percentile(grad_np, 90)
    lower = np.percentile(grad_np, 10)
    top_mask = grad > upper
    bottom_mask = grad < lower
    grad[~(top_mask | bottom_mask)] = 0
    grad[top_mask] = 1
    grad[bottom_mask] = -1

    # EMA of the ternarised gradient (not used by the step below). Bias correction
    # must be applied to a copy, grad_momentum / (1 - beta ** (i + 1)). Writing it
    # back into the running state re-scales the state on every iteration and
    # inflates it by orders of magnitude.
    grad_momentum = beta * grad_momentum + (1 - beta) * grad
    # m_hat = grad_momentum / (1 - beta ** (i + 1))
    # latents = latents - lr * m_hat / m_hat.norm() - lr * grad / grad.norm()

    # Plain step on the ternarised gradient; the result is a fresh leaf tensor with
    # no gradient history, ready for the next iteration.
    latents = (latents.detach() - lr * grad).requires_grad_(True)

    # torch.cuda.empty_cache()
    # show_latent(latents)
    # break
