import torch.nn as nn
import torch
# import math
# from torchvision import transforms
import os
# from timm.models import create_model
from typing import Any, Dict, List, Optional, Union
from transformers import LlamaTokenizer
from diffusers import DiffusionPipeline
# from torchvision.transforms.functional import pil_to_tensor

# import torch
from PIL import Image
from torchvision import transforms

# from qformer.qformer_quantizer import Blip2QformerQuantizer
# from diffusers import StableUnCLIPImg2ImgPipeline
from .pipeline_stable_unclip_img2img import StableUnCLIPImg2ImgPipeline

# Quantizer checkpoint filename looked up inside the tokenizer's ``name_or_path``
# directory when no explicit ``encoder_url`` is given.
WEIGHTS_NAME: str = 'seed_quantizer.pt'
# Subdirectory name for a bundled diffusion model inside a pretrained tokenizer
# directory (presumably; confirm against callers before relying on it).
DIFFUSION_NAME: str = 'diffusion_model'


class ImageTokenizer(nn.Module):
    """Wrap the SEED Q-Former quantizer and, optionally, an unCLIP diffusion decoder.

    Images are converted to discrete codebook indices with :meth:`encode`; indices can be
    rendered back into images with :meth:`decode` when a diffusion pipeline was loaded.

    Args:
        model_path: Checkpoint path/URL passed to ``Blip2QformerQuantizer.from_pretrained``.
        diffusion_model_path: Path/ID of a ``StableUnCLIPImg2ImgPipeline``; only used when
            ``load_diffusion`` is true.
        load_diffusion: Whether to load the diffusion decoder required by :meth:`decode`.
        image_size: Side length that :attr:`processor` resizes images to (square output).
        device: Device the quantizer, the optional diffusion pipeline and the fixed
            latents/noise are created on.
        fp16: If true, the quantizer, the diffusion pipeline and the fixed latents/noise use
            float16; otherwise float32.
        model_root_dir: Forwarded to ``Blip2QformerQuantizer.from_pretrained``.
        **kwargs: Extra keyword arguments forwarded to ``Blip2QformerQuantizer.from_pretrained``.
    """

    def __init__(self,
                 model_path,
                 diffusion_model_path=None,
                 load_diffusion=False,
                 image_size=224,
                 device='cuda',
                 fp16=True,
                 model_root_dir="models",
                 **kwargs):
        super().__init__()
        from .seed_qformer.qformer_quantizer import Blip2QformerQuantizer

        # A single dtype for everything handed to the diffusion pipeline: the fixed latents
        # must match the pipeline weights, otherwise denoising fails with a dtype mismatch
        # when fp16=False.
        dtype = torch.float16 if fp16 else torch.float32

        model = Blip2QformerQuantizer.from_pretrained(pretrained_model_path=model_path,
                                                      vit_precision='fp16' if fp16 else 'fp32',
                                                      model_root_dir=model_root_dir,
                                                      **kwargs).eval()
        if diffusion_model_path is not None and load_diffusion:
            diffusion_model = StableUnCLIPImg2ImgPipeline.from_pretrained(diffusion_model_path,
                                                                          torch_dtype=dtype)
            self.diffusion_model = diffusion_model.to(device)
        else:
            self.diffusion_model = None

        model = model.to(device)
        if fp16:
            model = model.half()

        # Bicubic resize straight to a square (no center crop), then normalization with the
        # OpenAI CLIP image mean/std.
        processor = transforms.Compose([
            transforms.Resize((image_size, image_size), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
        ])

        # Fixed initial latents reused by every decode() call, so the same indices always
        # start denoising from the same noise. Batch size is hard-coded to 1; the 96x96
        # spatial size presumably matches the pipeline's latent resolution (TODO confirm).
        self.latents = torch.randn(torch.Size([1, 4, 96, 96]), device=device, dtype=dtype)
        # Fixed noise tensor; decode() does not currently pass it to the pipeline.
        self.noise = torch.randn(torch.Size([1, 1024]), device=device, dtype=dtype)

        self.model = model
        self.processor = processor
        self.device = device
        self.fp16 = fp16

    def __len__(self):
        """Return the codebook size reported by the quantizer (``model.n_embed``)."""
        return self.model.n_embed

    def encode(self, image_torch):
        """Convert a batch of preprocessed images to codebook indices.

        Args:
            image_torch: Image tensor of shape ``[b, c, h, w]``, or a single ``[c, h, w]``
                image (a batch dimension is added). It is not moved here, so it must
                already be on the quantizer's device.

        Returns:
            Tensor of shape ``[b, -1]`` holding the codebook indices of each image.
        """
        if image_torch.dim() == 3:
            image_torch = image_torch.unsqueeze(0)

        img = image_torch.half() if self.fp16 else image_torch
        with torch.no_grad():
            indices, _ = self.model.get_codebook_indices(img)
        return indices.view(img.shape[0], -1)

    def decode(self, indices, negative_indices=None, guidance_scale=10, num_inference_steps=20):
        """Render images from codebook indices with the diffusion pipeline.

        Args:
            indices: Codebook indices, e.g. as returned by :meth:`encode`.
            negative_indices: Optional indices whose codebook embeddings are passed to the
                pipeline as ``negative_image_embeds``; must have the same shape as ``indices``.
            guidance_scale: Guidance scale forwarded to the pipeline.
            num_inference_steps: Number of denoising steps forwarded to the pipeline.

        Returns:
            The ``images`` attribute of the pipeline output.

        Raises:
            RuntimeError: If this tokenizer was created without a diffusion model.
            ValueError: If ``negative_indices`` and ``indices`` differ in shape.
        """
        # Fail fast with an actionable message instead of "'NoneType' is not callable".
        if self.diffusion_model is None:
            raise RuntimeError(
                'No diffusion model is loaded; construct ImageTokenizer with '
                'diffusion_model_path and load_diffusion=True to decode images')
        if negative_indices is not None and indices.shape != negative_indices.shape:
            raise ValueError('Negative indices must have the same shape with indices')

        image_embeds = self.model.get_codebook_entry(indices)
        negative_image_embeds = (
            self.model.get_codebook_entry(negative_indices) if negative_indices is not None else None
        )

        return self.diffusion_model(
            image_embeds=image_embeds,
            negative_image_embeds=negative_image_embeds,
            guidance_scale=guidance_scale,
            noise_level=0,
            num_inference_steps=num_inference_steps,
            latents=self.latents,
        ).images


class SeedLlamaTokenizer(LlamaTokenizer):
    """LLaMA text tokenizer extended with a SEED image tokenizer.

    Text is handled by :class:`LlamaTokenizer`. Images are converted to discrete codes (and
    back) by an :class:`ImageTokenizer`, which is created eagerly in ``__init__``.
    ``pad_token`` is always overridden to ``unk_token``.

    Args:
        vocab_file, unk_token, bos_token, eos_token, pad_token, sp_model_kwargs,
        add_bos_token, add_eos_token, clean_up_tokenization_spaces:
            Forwarded to :class:`LlamaTokenizer`.
        device: Device for the image tokenizer and for image tensors passed to it.
        fp16: Whether the image tokenizer runs in float16.
        load_diffusion: Whether to load the diffusion decoder needed by :meth:`decode_image`.
        encoder_url: Explicit quantizer checkpoint path/URL. When ``None``, the file
            ``WEIGHTS_NAME`` inside ``name_or_path`` is used.
        diffusion_path: Path/ID of the diffusion pipeline used by :meth:`decode_image`.
        **kwargs: Forwarded to :class:`LlamaTokenizer`.
    """

    def __init__(self,
                 vocab_file,
                 unk_token="<unk>",
                 bos_token="<s>",
                 eos_token="</s>",
                 pad_token=None,
                 sp_model_kwargs: Optional[Dict[str, Any]] = None,
                 add_bos_token=True,
                 add_eos_token=False,
                 clean_up_tokenization_spaces=False,
                 device='cuda',
                 fp16=True,
                 load_diffusion=False,
                 encoder_url=None,
                 diffusion_path=None,
                 **kwargs):
        # Pass by keyword: robust against parameter-order changes in LlamaTokenizer.__init__.
        super().__init__(vocab_file,
                         unk_token=unk_token,
                         bos_token=bos_token,
                         eos_token=eos_token,
                         pad_token=pad_token,
                         sp_model_kwargs=sp_model_kwargs,
                         add_bos_token=add_bos_token,
                         add_eos_token=add_eos_token,
                         clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                         **kwargs)
        self.device = device
        self.fp16 = fp16
        self.pad_token = self.unk_token
        self.load_diffusion = load_diffusion
        self.encoder_url = encoder_url
        self.diffusion_path = diffusion_path

        self.load_image_tokenizer()

    def _resolve_encoder_path(self):
        """Return the quantizer checkpoint location: ``encoder_url`` or ``<name_or_path>/WEIGHTS_NAME``.

        Raises:
            ValueError: If ``encoder_url`` is unset and ``name_or_path`` is not an existing path.
        """
        if self.encoder_url is not None:
            return self.encoder_url
        name_or_path = getattr(self, 'name_or_path', None)
        if not name_or_path or not os.path.exists(name_or_path):
            raise ValueError(
                f'Cannot locate the image quantizer: encoder_url is not set and '
                f'name_or_path={name_or_path!r} is not an existing path')
        return os.path.join(name_or_path, WEIGHTS_NAME)

    def load_image_tokenizer(self):
        """Create the :class:`ImageTokenizer` if it does not exist yet; later calls are no-ops."""
        if hasattr(self, '_image_tokenizer'):
            return
        model_path = self._resolve_encoder_path()
        self._image_tokenizer = ImageTokenizer(model_path=model_path,
                                               diffusion_model_path=self.diffusion_path,
                                               load_diffusion=self.load_diffusion,
                                               device=self.device,
                                               fp16=self.fp16,
                                               model_root_dir="/".join(model_path.split('/')[:-1]))

    @property
    def image_tokenizer(self):
        """The :class:`ImageTokenizer`, created on first access if it is missing."""
        self.load_image_tokenizer()
        return self._image_tokenizer

    @property
    def num_image_tokens(self):
        """Number of image codes; hard-coded so it is available without loading the image tokenizer."""
        return 8192  # self.image_tokenizer.num_tokens # allow not load

    def named_parameters(self, *args, **kwargs):
        """Proxy to the image tokenizer's ``nn.Module.named_parameters`` (arguments forwarded)."""
        return self.image_tokenizer.named_parameters(*args, **kwargs)

    def to(self, device):
        """Move the image tokenizer and everything it uses for decoding to ``device``.

        Returns:
            ``self``, mirroring ``nn.Module.to`` so calls can be chained.
        """
        self.device = device
        if hasattr(self, '_image_tokenizer'):
            image_tokenizer = self._image_tokenizer
            image_tokenizer.to(device=device)
            # nn.Module.to() only moves registered parameters/buffers. The diffusion pipeline
            # and the latents/noise tensors are plain attributes, so move them explicitly and
            # keep the wrapper's own device attribute in sync.
            image_tokenizer.device = device
            if getattr(image_tokenizer, 'diffusion_model', None) is not None:
                image_tokenizer.diffusion_model = image_tokenizer.diffusion_model.to(device)
            for attr_name in ('latents', 'noise'):
                tensor = getattr(image_tokenizer, attr_name, None)
                if isinstance(tensor, torch.Tensor):
                    setattr(image_tokenizer, attr_name, tensor.to(device))
        return self

    def encode_image(
        self,
        image_path=None,
        image_pil=None,
        image_torch=None,
        image_size: int = 224,
    ):
        """Encode exactly one image source into codebook indices.

        Args:
            image_path: Path to an image file; opened with PIL and converted to RGB.
            image_pil: A PIL image; preprocessed with the image tokenizer's ``processor``.
            image_torch: An already-preprocessed image tensor (``[c, h, w]`` or ``[b, c, h, w]``).
            image_size: Unused. The resize size is fixed by the image tokenizer's processor;
                the argument is kept for backward compatibility.

        Returns:
            The indices tensor returned by ``ImageTokenizer.encode``.

        Raises:
            ValueError: If not exactly one of the three image arguments is given.
        """
        provided = sum(arg is not None for arg in (image_path, image_pil, image_torch))
        if provided != 1:
            raise ValueError('Exactly one of image_path, image_pil or image_torch must be provided')

        if image_path is not None:
            # Context manager closes the file handle; convert() returns an in-memory copy.
            with Image.open(image_path) as opened:
                image_pil = opened.convert('RGB')

        if image_pil is not None:
            image_torch = self.image_tokenizer.processor(image_pil)

        # Move every input (including caller-supplied tensors) to the tokenizer's device.
        return self.image_tokenizer.encode(image_torch.to(self.device))

    def decode_image(self, indices, negative_indices=None, guidance_scale=10, num_inference_steps=20):
        """Decode codebook indices into images via the image tokenizer's diffusion pipeline.

        Args:
            indices: Codebook indices; moved to ``self.device`` first.
            negative_indices: Optional indices for negative image embeddings (same shape as
                ``indices``); moved to ``self.device`` first.
            guidance_scale: Guidance scale forwarded to ``ImageTokenizer.decode``.
            num_inference_steps: Denoising steps forwarded to ``ImageTokenizer.decode``.

        Returns:
            The images returned by ``ImageTokenizer.decode``.
        """
        indices = indices.to(self.device)
        if negative_indices is not None:
            negative_indices = negative_indices.to(self.device)
        return self.image_tokenizer.decode(
            indices,
            negative_indices=negative_indices,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
        )
