# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image
from transformers import (
    AutoTokenizer,
    CLIPTextModel,
    PreTrainedTokenizer,
    T5EncoderModel,
)

from .modeling_maskgit_vqgan import MaskGitVQGAN
from .modeling_movq import MOVQ
from .modeling_paella_vq import PaellaVQModel
from .modeling_taming_vqgan import VQGANModel
from .modeling_transformer import MaskGitTransformer, MaskGiTUViT
from .sampling import get_mask_chedule


class PipelineMuse:
    def __init__(
        self,
        vae: Union[VQGANModel, MOVQ, MaskGitVQGAN, PaellaVQModel],
        transformer: Union[MaskGitTransformer, MaskGiTUViT],
        is_class_conditioned: bool = False,
        text_encoder: Optional[Union[T5EncoderModel, CLIPTextModel]] = None,
        tokenizer: Optional[PreTrainedTokenizer] = None,
    ) -> None:
        """
        Bundle the already-instantiated components of a Muse pipeline.

        Args:
            vae: VQ model whose `decode_code` turns generated token ids into images.
            transformer: Token-prediction model providing `generate` / `generate2`.
            is_class_conditioned: When True the pipeline is driven by class ids and the text encoder /
                tokenizer are not used by `to` or `save_pretrained`.
            text_encoder: Text encoder used for prompt conditioning (text-conditioned pipelines).
            tokenizer: Tokenizer matching `text_encoder`.
        """
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.vae = vae
        self.transformer = transformer
        self.is_class_conditioned = is_class_conditioned
        # Bookkeeping only: nothing is moved here. `to()` updates both attributes.
        self.device = "cpu"
        # Defined up-front so reading `pipe.dtype` before `to()` does not raise AttributeError;
        # None means "not set through `to()` yet".
        self.dtype = None

    def to(self, device="cpu", dtype=torch.float32) -> "PipelineMuse":
        """
        Move the pipeline components to `device` and cast them to `dtype`.

        The text encoder is only moved for text-conditioned pipelines. The VAE is moved to `device`
        but always kept in float32, regardless of `dtype`.

        Returns:
            The pipeline itself, so calls can be chained.
        """
        self.device = device
        self.dtype = dtype

        if not self.is_class_conditioned:
            self.text_encoder.to(device, dtype=dtype)
        self.transformer.to(device, dtype=dtype)
        # keep vae in fp32
        self.vae.to(device, dtype=torch.float32)

        return self

    @torch.no_grad()
    def __call__(
        self,
        text: Optional[Union[str, List[str]]] = None,
        negative_text: Optional[Union[str, List[str]]] = "",
        prompt_embeds: Optional[torch.Tensor] = None,
        pooled_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        negative_pooled_embeds: Optional[torch.Tensor] = None,
        class_ids: Optional[Union[int, List[int]]] = None,
        timesteps: int = 16,
        noise_schedule: str = "cosine",
        guidance_scale: float = 10.0,
        guidance_schedule=None,
        temperature: Union[float, Tuple[float, float]] = (2, 0),
        topk_filter_thres: float = 0.9,
        num_images_per_prompt: int = 1,
        use_maskgit_generate: bool = True,
        generator: Optional[torch.Generator] = None,
        use_fp16: bool = False,
        noise_type="mask",  # can be "mask" or "random_replace"
        predict_all_tokens=False,
        orig_size=(512, 512),
        crop_coords=(0, 0),
        aesthetic_score=6.0,
        return_intermediate: bool = False,
        use_tqdm=True,
        transformer_seq_len=None,
        clip_skip: Optional[int] = None,
    ):
        """
        Generate images from text prompts or from class ids.

        Exactly one of `text` and `class_ids` must be provided.

        Args:
            text: Prompt or list of prompts. Required on the text path even when `prompt_embeds` is
                passed (its length is used to replicate a single `negative_text`).
            negative_text: Negative prompt(s); a single string is replicated once per prompt. Ignored
                when `negative_prompt_embeds` is given. Pass None (and no negative embeds) to fall back
                to embeddings of the empty prompt.
            prompt_embeds: Pre-computed prompt hidden states, used instead of encoding `text`.
            pooled_embeds: Pre-computed pooled embeddings. Required together with `prompt_embeds` when
                the transformer config has `add_cond_embeds` enabled.
            negative_prompt_embeds: Pre-computed negative hidden states; takes precedence over
                `negative_text`.
            negative_pooled_embeds: Pre-computed negative pooled embeddings (optional).
            class_ids: Class id or list of class ids for class-conditioned generation.
            num_images_per_prompt: Number of images generated per prompt / class id.
            use_maskgit_generate: Use the transformer's `generate2` (True) or `generate` (False).
            use_fp16: Enables CUDA autocast around the generation call.
            orig_size, crop_coords, aesthetic_score: Concatenated into the micro-conditioning vector
                when the transformer config has `add_micro_cond_embeds` enabled.
            return_intermediate: Also return the decoded intermediate token states.
            clip_skip: When set, the prompt hidden states are taken from layer `-(clip_skip + 1)`
                instead of `-2` (only on the `add_cond_embeds` path).
            timesteps, noise_schedule, guidance_scale, guidance_schedule, temperature,
            topk_filter_thres, generator, noise_type, predict_all_tokens, use_tqdm,
            transformer_seq_len: Forwarded to the transformer's generate method (`noise_schedule` is
                first resolved through `get_mask_chedule`, `transformer_seq_len` is passed as `seq_len`).

        Returns:
            A list of PIL images, or `(images, intermediate_images)` when `return_intermediate` is True.

        Raises:
            ValueError: If neither or both of `text` / `class_ids` are given, or if `prompt_embeds` is
                given without `pooled_embeds` while the transformer requires pooled embeddings.
        """
        if text is None and class_ids is None:
            raise ValueError("Either text or class_ids must be provided.")

        if text is not None and class_ids is not None:
            raise ValueError("Only one of text or class_ids may be provided.")

        if class_ids is not None:
            if isinstance(class_ids, int):
                class_ids = [class_ids]

            class_ids = torch.tensor(class_ids, device=self.device, dtype=torch.long)
            # duplicate class ids for each generation per prompt
            class_ids = class_ids.repeat_interleave(num_images_per_prompt, dim=0)
            model_inputs = {"class_ids": class_ids}
            # No text embeddings exist on this path; the micro-conditioning dtype is resolved lazily below.
            cond_dtype = None
        else:
            if isinstance(text, str):
                text = [text]

            add_cond_embeds = self.transformer.config.add_cond_embeds

            if prompt_embeds is not None:
                # Pre-computed embeddings: no tokenization / encoding of `text` is needed.
                if add_cond_embeds and pooled_embeds is None:
                    raise ValueError(
                        "pooled_embeds must be provided together with prompt_embeds when the transformer uses"
                        " add_cond_embeds."
                    )
                encoder_dtype = self.text_encoder.dtype
                encoder_hidden_states = prompt_embeds.to(self.device, dtype=encoder_dtype)
                pooled_embeds = pooled_embeds.to(self.device, dtype=encoder_dtype) if add_cond_embeds else None
            else:
                clip_layer_idx = -(clip_skip + 1) if clip_skip is not None else -2
                encoder_hidden_states, pooled_embeds = self._encode_text(text, clip_layer_idx)

            # NOTE(review): negative and empty prompts always use hidden layer -2, even when `clip_skip`
            # selects a different layer for the positive prompt - confirm this asymmetry is intended.
            if negative_prompt_embeds is not None:
                # Explicit embeddings win over `negative_text` (whose default is "", not None).
                encoder_dtype = self.text_encoder.dtype
                negative_encoder_hidden_states = negative_prompt_embeds.to(self.device, dtype=encoder_dtype)
                if negative_pooled_embeds is not None:
                    negative_pooled_embeds = negative_pooled_embeds.to(self.device, dtype=encoder_dtype)
            elif negative_text is not None:
                if isinstance(negative_text, str):
                    negative_text = [negative_text] * len(text)
                negative_encoder_hidden_states, negative_pooled_embeds = self._encode_text(negative_text)
            else:
                negative_encoder_hidden_states = None
                negative_pooled_embeds = None

            # duplicate text embeddings for each generation per prompt, using mps friendly method
            encoder_hidden_states = self._repeat_per_prompt(encoder_hidden_states, num_images_per_prompt)
            if pooled_embeds is not None:
                pooled_embeds = self._repeat_per_prompt(pooled_embeds, num_images_per_prompt)
                if negative_pooled_embeds is not None:
                    negative_pooled_embeds = self._repeat_per_prompt(negative_pooled_embeds, num_images_per_prompt)
            if negative_encoder_hidden_states is not None:
                negative_encoder_hidden_states = self._repeat_per_prompt(
                    negative_encoder_hidden_states, num_images_per_prompt
                )

            if negative_encoder_hidden_states is None:
                # Without any negative conditioning, provide embeddings of the empty prompt instead.
                empty_input = self.tokenizer("", padding="max_length", return_tensors="pt").input_ids.to(
                    self.text_encoder.device
                )
                outputs = self.text_encoder(empty_input, output_hidden_states=True)
                empty_embeds = outputs.hidden_states[-2]
                empty_cond_embeds = outputs[0]
            else:
                empty_embeds, empty_cond_embeds = None, None

            model_inputs = {
                "encoder_hidden_states": encoder_hidden_states,
                "negative_embeds": negative_encoder_hidden_states,
                "cond_embeds": pooled_embeds,
                "negative_cond_embeds": negative_pooled_embeds,
                "empty_embeds": empty_embeds,
                "empty_cond_embeds": empty_cond_embeds,
            }
            cond_dtype = encoder_hidden_states.dtype

        if self.transformer.config.add_micro_cond_embeds:
            if cond_dtype is None:
                # Class-conditioned path: fall back to the dtype of the transformer weights
                # (assumes the transformer is a torch.nn.Module - it is moved with `.to()` elsewhere).
                cond_dtype = next(self.transformer.parameters()).dtype
            micro_conds = list(orig_size) + list(crop_coords) + [aesthetic_score]
            micro_conds = torch.tensor(micro_conds, device=self.device, dtype=cond_dtype)
            micro_conds = micro_conds.unsqueeze(0)
            model_inputs["micro_conds"] = micro_conds

        generate = self.transformer.generate2 if use_maskgit_generate else self.transformer.generate

        with torch.autocast("cuda", enabled=use_fp16):
            outputs = generate(
                **model_inputs,
                timesteps=timesteps,
                guidance_scale=guidance_scale,
                guidance_schedule=guidance_schedule,
                temperature=temperature,
                topk_filter_thres=topk_filter_thres,
                generator=generator,
                noise_type=noise_type,
                noise_schedule=get_mask_chedule(noise_schedule),
                predict_all_tokens=predict_all_tokens,
                return_intermediate=return_intermediate,
                use_tqdm=use_tqdm,
                seq_len=transformer_seq_len,
            )

            if return_intermediate:
                generated_tokens, intermediate = outputs
            else:
                generated_tokens = outputs

        # Decoding happens outside the autocast context.
        images = self.vae.decode_code(generated_tokens)
        if return_intermediate:
            intermediate_images = [self.vae.decode_code(tokens) for tokens in intermediate]

        # Convert to PIL images
        images = [self.to_pil_image(image) for image in images]
        if return_intermediate:
            intermediate_images = [
                [self.to_pil_image(image) for image in step_images] for step_images in intermediate_images
            ]
            return images, intermediate_images

        return images

    def _encode_text(self, prompts: List[str], hidden_state_idx: int = -2):
        """
        Tokenize `prompts` and run the text encoder.

        Returns `(hidden_states, pooled_embeds)`. With `add_cond_embeds` enabled in the transformer config,
        the hidden states come from layer `hidden_state_idx` and the pooled embeddings from the encoder's
        `text_embeds` output; otherwise the last hidden state is returned and `pooled_embeds` is None.
        """
        input_ids = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
        ).input_ids  # TODO: remove hardcode
        input_ids = input_ids.to(self.device)

        if self.transformer.config.add_cond_embeds:
            outputs = self.text_encoder(input_ids, return_dict=True, output_hidden_states=True)
            return outputs.hidden_states[hidden_state_idx], outputs.text_embeds
        return self.text_encoder(input_ids).last_hidden_state, None

    @staticmethod
    def _repeat_per_prompt(embeds: torch.Tensor, num_images_per_prompt: int) -> torch.Tensor:
        """Repeat each batch entry `num_images_per_prompt` times along dim 0 (repeat + view, mps friendly)."""
        batch_size = embeds.shape[0]
        repeats = [1, num_images_per_prompt] + [1] * (embeds.dim() - 2)
        return embeds.repeat(*repeats).view(batch_size * num_images_per_prompt, *embeds.shape[1:])

    def to_pil_image(self, image: torch.Tensor):
        """
        Convert a single image tensor into an RGB PIL image.

        The first axis is moved last before conversion (presumably channels-first input, i.e.
        (C, H, W) -> (H, W, C) - confirm against the VAE output). Values are clamped to the [0, 1]
        range, scaled by 255 and truncated to uint8.
        """
        pixels = image.permute(1, 2, 0).cpu().numpy()
        # Map to [-1, 1], clamp there, then map back to [0, 1].
        pixels = np.clip(2.0 * pixels - 1.0, -1.0, 1.0)
        pixels = (pixels + 1.0) / 2.0
        as_uint8 = (255 * pixels).astype(np.uint8)
        return Image.fromarray(as_uint8).convert("RGB")

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path: Optional[str] = None,
        text_encoder_path: Optional[str] = None,
        vae_path: Optional[str] = None,
        transformer_path: Optional[str] = None,
        vae=None,
        text_encoder=None,
        transformer=None,
        is_class_conditioned: bool = False,
    ) -> "PipelineMuse":
        """
        Instantiate a PipelineMuse from a pretrained model. Either model_name_or_path or all of text_encoder_path,
        vae_path, and transformer_path must be provided.

        When `model_name_or_path` is given, components are looked up in its `text_encoder`, `vae` and
        `transformer` subfolders. Already-instantiated `vae` / `transformer` objects are used as-is and their
        configs are not loaded.

        NOTE(review): this method only loads the tokenizer, transformer and VAE. It never loads a text encoder,
        so for text-conditioned pipelines the caller has to pass `text_encoder`; otherwise the pipeline is
        created with `text_encoder=None`.

        Raises:
            ValueError: If the required paths are missing, or a config names an unknown transformer / VAE class.
        """
        if model_name_or_path is None:
            if text_encoder_path is None or vae_path is None or transformer_path is None:
                raise ValueError(
                    "If model_name_or_path is None, then text_encoder_path, vae_path, and transformer_path must be"
                    " provided."
                )

            tokenizer_args = {"pretrained_model_name_or_path": text_encoder_path}
            vae_args = {"pretrained_model_name_or_path": vae_path}
            transformer_args = {"pretrained_model_name_or_path": transformer_path}
        else:
            tokenizer_args = {"pretrained_model_name_or_path": model_name_or_path, "subfolder": "text_encoder"}
            vae_args = {"pretrained_model_name_or_path": model_name_or_path, "subfolder": "vae"}
            transformer_args = {"pretrained_model_name_or_path": model_name_or_path, "subfolder": "transformer"}

        tokenizer = None
        if not is_class_conditioned:
            tokenizer = AutoTokenizer.from_pretrained(**tokenizer_args)

        if transformer is None:
            # The config's `_class_name` decides which transformer implementation to instantiate.
            transformer_class_name = MaskGitTransformer.load_config(**transformer_args)["_class_name"]
            if transformer_class_name == "MaskGitTransformer":
                transformer = MaskGitTransformer.from_pretrained(**transformer_args)
            elif transformer_class_name in ("MaskGiTUViT", "MaskGiTUViT_v2"):
                transformer = MaskGiTUViT.from_pretrained(**transformer_args)
            else:
                raise ValueError(f"Unknown Transformer class: {transformer_class_name}")

        if vae is None:
            # Hacky way to load different VQ models
            vae_classes = {
                "VQGANModel": VQGANModel,
                "MaskGitVQGAN": MaskGitVQGAN,
                "MOVQ": MOVQ,
                "PaellaVQModel": PaellaVQModel,
            }
            vae_class_name = MaskGitVQGAN.load_config(**vae_args)["_class_name"]
            if vae_class_name not in vae_classes:
                raise ValueError(f"Unknown VAE class: {vae_class_name}")
            vae = vae_classes[vae_class_name].from_pretrained(**vae_args)

        if is_class_conditioned:
            # Class-conditioned pipelines carry neither text encoder nor tokenizer.
            return cls(
                vae=vae,
                transformer=transformer,
                is_class_conditioned=is_class_conditioned,
            )

        return cls(
            vae=vae,
            transformer=transformer,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            is_class_conditioned=is_class_conditioned,
        )

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
    ) -> None:
        """
        Write every pipeline component into its own subfolder of `save_directory`.

        The text encoder and tokenizer share the `text_encoder` subfolder and are skipped for
        class-conditioned pipelines; the VAE goes to `vae` and the transformer to `transformer`.
        """
        if not self.is_class_conditioned:
            text_encoder_dir = os.path.join(save_directory, "text_encoder")
            self.text_encoder.save_pretrained(text_encoder_dir)
            self.tokenizer.save_pretrained(text_encoder_dir)

        vae_dir = os.path.join(save_directory, "vae")
        self.vae.save_pretrained(vae_dir)

        transformer_dir = os.path.join(save_directory, "transformer")
        self.transformer.save_pretrained(transformer_dir)


class PipelineMuseInpainting(PipelineMuse):
    """Inpainting variant of `PipelineMuse`: masked image tokens are regenerated with `generate2`."""

    @torch.no_grad()
    def __call__(
        self,
        image: Image.Image,
        mask: torch.BoolTensor,
        text: Optional[Union[str, List[str]]] = None,
        negative_text: Optional[Union[str, List[str]]] = None,
        class_ids: Optional[Union[int, List[int]]] = None,
        timesteps: int = 8,
        guidance_scale: float = 8.0,
        guidance_schedule=None,
        temperature: float = 1.0,
        topk_filter_thres: float = 0.9,
        num_images_per_prompt: int = 1,
        use_maskgit_generate: bool = True,
        generator: Optional[torch.Generator] = None,
        use_fp16: bool = False,
        image_size: int = 256,
        orig_size=(256, 256),
        crop_coords=(0, 0),
        aesthetic_score=6.0,
    ):
        """
        Regenerate the masked region of `image`, conditioned on text or class ids.

        The image is resized and center-cropped to `image_size`, encoded to tokens by the VAE, and the
        tokens selected by `mask` are replaced with the transformer's mask token before generation.

        Args:
            image: Input PIL image.
            mask: Boolean tensor selecting the image tokens to regenerate; it is indexed as `mask[None]`
                against the encoded tokens (presumably one flag per token position - confirm with caller).
            text: Prompt or list of prompts (mutually exclusive with `class_ids`).
            negative_text: Optional negative prompt(s); a single string is replicated once per prompt.
            class_ids: Class id or list of class ids (mutually exclusive with `text`).
            num_images_per_prompt: Number of images generated per prompt / class id.
            use_maskgit_generate: Must be True; only `generate2` is supported here.
            use_fp16: Enables CUDA autocast around the generation call.
            image_size: Target size for the resize / center-crop preprocessing.
            orig_size, crop_coords, aesthetic_score: Concatenated into the micro-conditioning vector
                when the transformer config has `add_micro_cond_embeds` enabled.
            timesteps, guidance_scale, guidance_schedule, temperature, topk_filter_thres, generator:
                Forwarded to the transformer's `generate2`.

        Returns:
            A list of PIL images.

        Raises:
            ValueError: If neither or both of `text` / `class_ids` are given.
        """
        from torchvision import transforms

        assert use_maskgit_generate, "PipelineMuseInpainting only supports use_maskgit_generate=True"
        if text is None and class_ids is None:
            raise ValueError("Either text or class_ids must be provided.")

        if text is not None and class_ids is not None:
            raise ValueError("Only one of text or class_ids may be provided.")

        encode_transform = transforms.Compose(
            [
                transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(image_size),
                transforms.ToTensor(),
            ]
        )
        pixel_values = encode_transform(image).unsqueeze(0).to(self.device)
        _, image_tokens = self.vae.encode(pixel_values)

        # Positions selected by the mask are handed to the transformer as "to be generated".
        image_tokens[mask[None]] = self.transformer.config.mask_token_id

        if class_ids is not None:
            if isinstance(class_ids, int):
                class_ids = [class_ids]

            class_ids = torch.tensor(class_ids, device=self.device, dtype=torch.long)
            # duplicate class ids for each generation per prompt
            class_ids = class_ids.repeat_interleave(num_images_per_prompt, dim=0)
            model_inputs = {"class_ids": class_ids}
            batch_size = class_ids.shape[0]
            # No text embeddings exist on this path; the micro-conditioning dtype is resolved lazily below.
            cond_dtype = None
        else:
            if isinstance(text, str):
                text = [text]

            add_cond_embeds = self.transformer.config.add_cond_embeds

            def encode(prompts):
                """Tokenize and encode prompts; returns (hidden_states, pooled_embeds_or_None)."""
                token_ids = self.tokenizer(
                    prompts,
                    return_tensors="pt",
                    padding="max_length",
                    truncation=True,
                    max_length=self.tokenizer.model_max_length,
                ).input_ids  # TODO: remove hardcode
                token_ids = token_ids.to(self.device)
                if add_cond_embeds:
                    encoded = self.text_encoder(token_ids, return_dict=True, output_hidden_states=True)
                    return encoded.hidden_states[-2], encoded.text_embeds
                return self.text_encoder(token_ids).last_hidden_state, None

            def repeat_per_prompt(embeds):
                """Repeat each batch entry `num_images_per_prompt` times (repeat + view, mps friendly)."""
                repeats = [1, num_images_per_prompt] + [1] * (embeds.dim() - 2)
                return embeds.repeat(*repeats).view(embeds.shape[0] * num_images_per_prompt, *embeds.shape[1:])

            encoder_hidden_states, pooled_embeds = encode(text)

            if negative_text is not None:
                if isinstance(negative_text, str):
                    # One negative prompt per positive prompt so the batch sizes line up.
                    negative_text = [negative_text] * len(text)
                # Encoded exactly like the positive prompt so both come from the same encoder layer.
                negative_encoder_hidden_states, negative_pooled_embeds = encode(negative_text)
            else:
                negative_encoder_hidden_states = None
                negative_pooled_embeds = None

            # duplicate text embeddings for each generation per prompt
            encoder_hidden_states = repeat_per_prompt(encoder_hidden_states)
            if pooled_embeds is not None:
                pooled_embeds = repeat_per_prompt(pooled_embeds)
            if negative_encoder_hidden_states is not None:
                negative_encoder_hidden_states = repeat_per_prompt(negative_encoder_hidden_states)
            if negative_pooled_embeds is not None:
                negative_pooled_embeds = repeat_per_prompt(negative_pooled_embeds)

            empty_input = self.tokenizer("", padding="max_length", return_tensors="pt").input_ids.to(
                self.text_encoder.device
            )
            outputs = self.text_encoder(empty_input, output_hidden_states=True)
            empty_embeds = outputs.hidden_states[-2]
            empty_cond_embeds = outputs[0]

            model_inputs = {
                "encoder_hidden_states": encoder_hidden_states,
                "negative_embeds": negative_encoder_hidden_states,
                "empty_embeds": empty_embeds,
                "empty_cond_embeds": empty_cond_embeds,
                "cond_embeds": pooled_embeds,
                "negative_cond_embeds": negative_pooled_embeds,
            }
            batch_size = encoder_hidden_states.shape[0]
            cond_dtype = encoder_hidden_states.dtype

        # One copy of the masked token grid per generated sample (prompts/classes x images per prompt).
        image_tokens = image_tokens.repeat(batch_size, 1)

        if self.transformer.config.add_micro_cond_embeds:
            if cond_dtype is None:
                # Class-conditioned path: fall back to the dtype of the transformer weights
                # (assumes the transformer is a torch.nn.Module - it is moved with `.to()` elsewhere).
                cond_dtype = next(self.transformer.parameters()).dtype
            micro_conds = list(orig_size) + list(crop_coords) + [aesthetic_score]
            micro_conds = torch.tensor(micro_conds, device=self.device, dtype=cond_dtype)
            micro_conds = micro_conds.unsqueeze(0)
            model_inputs["micro_conds"] = micro_conds

        generate = self.transformer.generate2
        with torch.autocast("cuda", enabled=use_fp16):
            generated_tokens = generate(
                input_ids=image_tokens,
                **model_inputs,
                timesteps=timesteps,
                guidance_scale=guidance_scale,
                guidance_schedule=guidance_schedule,
                temperature=temperature,
                topk_filter_thres=topk_filter_thres,
                generator=generator,
            )
        images = self.vae.decode_code(generated_tokens)

        # Convert to PIL images
        images = [self.to_pil_image(image) for image in images]
        return images
