import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms
from torchvision import transforms
from lavis.common.registry import registry
from lavis.common.utils import get_abs_path
from lavis.models.albef_models import AlbefBase
from lavis.models.base_model import MomentumDistilationMixin, SharedQueueMixin
from lavis.models.med import BertForMaskedLM
from lavis.models.vit import VisionTransformerEncoder
from transformers import BertConfig

from lavis.models.blip2_models.blip2 import Blip2Base
from lavis.models.blip2_models.blip2_t5 import Blip2T5

from lavis.models.ldm_models.models.diffusion.ddpm import LatentDiffusion, ImageEmbeddingConditionedLatentDiffusion, default
from lavis.models.ldm_models.models.diffusion.ddim import DDIMSampler
from lavis.models.mlps.mlps import MLPWithLayerNorm

import argparse, os
import cv2
from PIL import Image
from einops import rearrange


@registry.register_model("img_spec_cdt")
class ImgSpecConditioning(AlbefBase, MomentumDistilationMixin, SharedQueueMixin):
    """Latent-diffusion model conditioned on text plus image-specific embeddings.

    An image encoder produces embeddings for the input image, an MLP
    (``ldm_c_proj``) projects them, and the projected embeddings are
    concatenated with the text conditioning of a latent diffusion model
    (``ldm``) before the denoising loss is computed.

    The model is registered in the LAVIS registry under the name
    ``"img_spec_cdt"``.
    """
    # Maps a pretrained-variant name to the path of its model config file.
    # NOTE(review): presumably resolved by the base-class config helpers --
    # this class does not read the dict itself.
    PRETRAINED_MODEL_CONFIG_DICT = {
        "albef": "configs/models/albef_retrieval_flickr.yaml",
        "blip2_t5": "configs/models/blip2/blip2_caption_flant5xl.yaml"
    }

    def __init__(
        self,
        image_encoder,
        ldm,
        proj_input_dim,
        proj_output_dim,
        proj_hidden_dim=2048,
        proj_num_layers=3,
        vit_model=None,
    ):
        """Wire together the image encoder, the conditioning projector and the LDM.

        Args:
            image_encoder: Module used to embed images; stored as
                ``self.visual_encoder``.
            ldm: Latent diffusion model; stored as ``self.ldm``.
            proj_input_dim: Input feature size of the projection MLP.
            proj_output_dim: Output feature size of the projection MLP.
            proj_hidden_dim: Hidden feature size of the projection MLP.
            proj_num_layers: Number of layers of the projection MLP.
            vit_model: Name of the vision backbone variant (or ``None``);
                stored as ``self.vit_model`` and used to select the
                embedding path.
        """
        super().__init__()

        # Settings of the MLP that maps image embeddings into the space of
        # the LDM conditioning.
        projector_kwargs = dict(
            input_dim=proj_input_dim,
            hidden_dim=proj_hidden_dim,
            output_dim=proj_output_dim,
            num_layers=proj_num_layers,
        )

        # Assignment order is kept stable so that sub-module / parameter
        # ordering does not change.
        self.visual_encoder = image_encoder
        self.ldm_c_proj = MLPWithLayerNorm(**projector_kwargs)
        self.ldm = ldm
        self.vit_model = vit_model

    def _rampup_factor(self, epoch, iters, num_iters_per_epoch):
        return min(1, (epoch * num_iters_per_epoch + iters) / (2 * num_iters_per_epoch))

    def get_image_embeds(self, image):
        if self.vit_model != None:
            if self.vit_model == "clip_L":
                pass  # TODO: implements for bilp2-base
            elif self.vit_model == "eva_clip_g":
                # with self.maybe_autocast():
                resize = transforms.Resize((364, 364))  # todo: fast implement
                image_embeds = self.visual_encoder.ln_vision(self.visual_encoder.visual_encoder(resize(image)))
                image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
                    image.device
                )

                query_tokens = self.visual_encoder.query_tokens.expand(image_embeds.shape[0], -1, -1)
                query_output = self.visual_encoder.Qformer.bert(
                    query_embeds=query_tokens,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                    return_dict=True,
                )

                return query_output.last_hidden_state

        else:
            return self.visual_encoder.forward_features(image)

    def forward(self, samples):
        """Compute the diffusion training loss with mixed text/image conditioning.

        Steps:

        1. Embed the image with :meth:`get_image_embeds`.
        2. Resize the image to 384x384 and encode it into the LDM latent space.
        3. Encode the caption with the LDM conditioning model, keep its first
           45 positions along dimension 1 and append the projected image
           embeddings.
        4. Draw a random timestep and noise per sample, noise the latent and
           run the LDM on it with the mixed conditioning.
        5. Combine the per-sample loss into the log-variance weighted term and
           the ``lvlb_weights`` weighted term.

        Args:
            samples: Mapping with the keys ``"image"`` (image batch) and
                ``"text_input"`` (captions).

        Returns:
            Scalar loss tensor equal to
            ``l_simple_weight * mean(loss / exp(logvar_t) + logvar_t)
            + original_elbo_weight * mean(lvlb_weights[t] * loss)``.

        Raises:
            NotImplementedError: If ``self.ldm.parameterization`` is not one
                of ``"x0"``, ``"eps"`` or ``"v"``.
        """
        image = samples["image"]
        caption = samples["text_input"]

        image_embeds = self.get_image_embeds(image)

        resize_for_ldm = torchvision.transforms.Resize((384, 384))
        x = resize_for_ldm(image)  # TODO: resize may damage the information
        # Move the image to the latent space of the LDM.
        latent_x = self.ldm.get_first_stage_encoding(self.ldm.encode_first_stage(x))
        c = self.ldm.get_learned_conditioning(caption)

        img_spec_cdt = self.ldm_c_proj(image_embeds)
        # TODO: make the number of kept text positions (45) configurable.
        c_mixed = torch.cat([c[:, :45, :], img_spec_cdt], dim=1).to(latent_x.device)

        t = torch.randint(
            0, self.ldm.num_timesteps, (x.shape[0],), device=self.ldm.device
        ).long()
        noise = torch.randn_like(latent_x)
        x_noisy = self.ldm.q_sample(x_start=latent_x, t=t, noise=noise)
        x_noisy.requires_grad_(True)
        model_output = self.ldm.apply_model(x_noisy, t, c_mixed)

        # The regression target depends on what the LDM was set up to predict.
        parameterization = self.ldm.parameterization
        if parameterization == "x0":
            target = latent_x
        elif parameterization == "eps":
            target = noise
        elif parameterization == "v":
            target = self.ldm.get_v(latent_x, noise, t)
        else:
            raise NotImplementedError()

        # Per-sample loss; shared by the "simple" and the VLB-weighted terms.
        per_sample_loss = self.ldm.get_loss(
            model_output, target, mean=False
        ).mean([1, 2, 3])

        # ``logvar`` may live on another device than the timestep indices,
        # which makes the indexing below fail; move it once, device-agnostically.
        if self.ldm.logvar.device != self.ldm.device:
            self.ldm.logvar = self.ldm.logvar.to(self.ldm.device)
        logvar_t = self.ldm.logvar[t]

        loss = per_sample_loss / torch.exp(logvar_t) + logvar_t
        loss = self.ldm.l_simple_weight * loss.mean()

        loss_vlb = (self.ldm.lvlb_weights[t] * per_sample_loss).mean()
        loss = loss + self.ldm.original_elbo_weight * loss_vlb

        return loss

    def gene_val_imgs(self, samples):
        """Generate and save qualitative samples for the first item of a batch.

        Three generations are produced with DDIM (50 steps, guidance scale
        9.0, eta 0.0) and written as PNG files below ``./results``:

        * ``mixed_imgs``: text conditioning mixed with image-specific
          conditioning.
        * ``orig_imgs``: text conditioning only.
        * ``spec_imgs``: image-specific conditioning only.

        Args:
            samples: Mapping with the keys ``"image"``, ``"text_input"`` and
                ``"image_id"``; only the first entry of each is used.

        Returns:
            Number of images generated from the image-specific conditioning.
        """
        image = samples["image"][:1, ...]
        caption = samples["text_input"][:1]
        print(caption)
        image_id = samples["image_id"][0]

        b = image.shape[0]

        with torch.no_grad():
            if self.vit_model == "eva_clip_g":
                image_embeds = self.get_image_embeds(image)
                c = self.ldm.get_learned_conditioning(caption)

                img_spec_cdt = self.ldm_c_proj(image_embeds)
                # TODO: make the number of kept text positions configurable.
                c_mixed = torch.cat([c[:, :45, :], img_spec_cdt], dim=1).to(c.device)

            else:
                # NOTE(review): `self.extractor` and `self.ebd_mapping_fc` are
                # not assigned in this class's `__init__`; this path presumably
                # relies on attributes provided elsewhere -- confirm before use.
                image_embeds = self.visual_encoder.forward_features(image)
                image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
                    self.device
                )

                spec_ebd = torch.randn((b, 38, 768)).to(image.device)  # TODO: make this configurable
                extractor_out = self.extractor.bert(
                    encoder_embeds=spec_ebd,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                    return_dict=True,
                    mode="fusion",
                )
                img_spec_cdt = self.ebd_mapping_fc(extractor_out.last_hidden_state)

                c = self.ldm.get_learned_conditioning(caption)

                c_mixed = torch.cat([c[:, :39, :], img_spec_cdt], dim=1).to(c.device)

            # TODO: make the sampler settings configurable.
            sampler = DDIMSampler(self.ldm, device=self.ldm.device)
            shape = [4, 512 // 8, 512 // 8]
            uc = self.ldm.get_learned_conditioning(b * [""])

            # (file tag, conditioning, unconditional conditioning); the
            # unconditional input is cut to the same length as the
            # image-specific conditioning for the "spec" run.
            runs = (
                ("mixed", c_mixed, uc),
                ("orig", c, uc),
                ("spec", img_spec_cdt, uc[:, :img_spec_cdt.shape[1], :]),
            )
            num_generated = 0
            for tag, conditioning, unconditional in runs:
                num_generated = self._sample_and_save(
                    sampler, conditioning, unconditional, b, shape, tag, image_id
                )

        # Count of the last ("spec") run.
        return num_generated

    def _sample_and_save(self, sampler, conditioning, unconditional_conditioning,
                         batch_size, shape, tag, image_id):
        """Sample latents with DDIM, decode them and save the images as PNG.

        Files are written to ``./results/{tag}_imgs`` and named
        ``{tag}_{index:05}_{image_id}.png`` where ``index`` starts at the
        number of entries already present in the directory. Intended to be
        called inside a ``torch.no_grad()`` context.

        Returns:
            Number of decoded images.
        """
        latents, _ = sampler.sample(
            S=50,
            conditioning=conditioning,
            batch_size=batch_size,
            shape=shape,
            verbose=False,
            unconditional_guidance_scale=9.0,
            unconditional_conditioning=unconditional_conditioning,
            eta=0.0,
        )

        images = self.ldm.decode_first_stage(latents)
        # Map from [-1, 1] to [0, 1].
        images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)

        sample_path = f"./results/{tag}_imgs"
        os.makedirs(sample_path, exist_ok=True)
        base_count = len(os.listdir(sample_path))

        # Each image gets its own index so that several images of one batch
        # never overwrite each other.
        for offset, x_sample in enumerate(images):
            x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')
            img = Image.fromarray(x_sample.astype(np.uint8))
            img.save(os.path.join(sample_path, f"{tag}_{base_count + offset:05}_{image_id}.png"))

        return len(images)

    @classmethod
    def from_config(cls, cfg=None):
        """Build the model (image encoder, LDM and projector) from a config.

        Config keys read here:

        * ``vit_model``: ``"clip_L"``, ``"eva_clip_g"`` or anything else
          (falls back to ``VisionTransformerEncoder``).
        * ``img_size``, ``drop_path_rate``, ``use_grad_checkpoint``,
          ``vit_precision``: only used for ``"clip_L"``.
        * ``ldm_configs``: mapping with ``"target"`` and ``"params"`` for the
          latent diffusion model.
        * ``freeze_ldm`` (default ``False``): disable gradients of the LDM.
        * ``ldm_ckpt`` (optional): checkpoint path; defaults to the SD 2.1
          unCLIP checkpoint path used so far.
        * ``proj_input_dim`` (optional): overrides the projector input size,
          which otherwise is the hidden size of the encoder's Q-Former.
        * ``proj_hidden_dim`` (default 2048), ``proj_num_layers`` (default 3).

        Args:
            cfg: Model config supporting ``.get`` and attribute access.

        Returns:
            The constructed model instance.

        Raises:
            ValueError: If ``cfg`` is ``None`` or the projector input size
                cannot be determined.
        """
        if cfg is None:
            raise ValueError("from_config requires a model config, got None")

        vit_model = cfg.get("vit_model", None)
        if vit_model == "clip_L":
            # NOTE(review): get_image_embeds has no "clip_L" path yet.
            image_encoder = Blip2Base.init_vision_encoder(
                vit_model,
                cfg.get("img_size", 224),
                cfg.get("drop_path_rate", 0),
                cfg.get("use_grad_checkpoint", False),
                cfg.get("vit_precision", "fp16"),
            )
        elif vit_model == "eva_clip_g":
            blip2_t5 = Blip2T5.from_config(cfg)
            # Only the vision side and the Q-Former are used; drop the T5 parts.
            del blip2_t5.t5_tokenizer
            del blip2_t5.t5_model
            image_encoder = blip2_t5
        else:
            image_encoder = VisionTransformerEncoder.from_config(cfg, from_pretrained=True)

        if cfg.ldm_configs["target"] == "ldm.models.diffusion.ddpm.ImageEmbeddingConditionedLatentDiffusion":
            ldm = ImageEmbeddingConditionedLatentDiffusion(**cfg.ldm_configs["params"])
        else:
            ldm = LatentDiffusion(**cfg.ldm_configs["params"])

        if cfg.get("freeze_ldm", False):
            for param in ldm.parameters():
                param.requires_grad = False

        ldm_ckpt = cfg.get("ldm_ckpt", "../stablediffusion-main/checkpoints/sd21-unclip-l.ckpt")
        print(f"Loading model from {ldm_ckpt}")
        # NOTE: torch.load unpickles the file; only load checkpoints from a
        # trusted source.
        pl_sd = torch.load(ldm_ckpt, map_location="cpu")
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
        state_dict = pl_sd["state_dict"]
        missing_keys, unexpected_keys = ldm.load_state_dict(state_dict, strict=False)
        if len(missing_keys) > 0:
            print("missing keys:")
            print(missing_keys)
        if len(unexpected_keys) > 0:
            print("unexpected keys:")
            print(unexpected_keys)
        # Keep the previous behavior on GPU machines without crashing on
        # CPU-only ones.
        if torch.cuda.is_available():
            ldm.cuda()

        proj_input_dim = cfg.get("proj_input_dim", None)
        if proj_input_dim is None:
            # Encoders without a Q-Former need an explicit `proj_input_dim`.
            qformer = getattr(image_encoder, "Qformer", None)
            if qformer is None:
                raise ValueError(
                    "Cannot infer the projector input size: the image encoder "
                    "has no `Qformer`; set `proj_input_dim` in the config."
                )
            proj_input_dim = qformer.config.hidden_size
        proj_output_dim = ldm.cond_stage_model.model.ln_final.normalized_shape[0]
        print(f"In/Out dims:{proj_input_dim, proj_output_dim}")

        model = cls(
            image_encoder=image_encoder,
            ldm=ldm,
            proj_input_dim=proj_input_dim,
            proj_output_dim=proj_output_dim,
            proj_hidden_dim=cfg.get("proj_hidden_dim", 2048),
            proj_num_layers=cfg.get("proj_num_layers", 3),
            vit_model=vit_model,
        )

        return model