import torch
from lavis.common.registry import registry

from lavis.models.albef_models import AlbefBase
from lavis.models.base_model import MomentumDistilationMixin, SharedQueueMixin
from lavis.models.blip2_models.blip2_t5 import Blip2T5

from lavis.models.ldm_models.models.diffusion.ddpm import FrozenOpenCLIPEmbedder, disabled_train
from lavis.models.ldm_models.models.modules.distributions.distributions import DiagonalGaussianDistribution
from lavis.models.mlps.mlps import MLPWithLayerNorm



@registry.register_model("txt_spec_cdt")
class TxtSpecConditioning(AlbefBase, MomentumDistilationMixin, SharedQueueMixin):
    """Condition a BLIP-2 (FlanT5) captioner on features from a separate text encoder.

    Caption features from ``text_encoder`` are mapped by ``blip2_proj`` (an MLP)
    and then by BLIP-2's ``t5_proj`` into T5 input space. They are concatenated
    after the image query embeddings produced by the BLIP-2 vision encoder and
    Q-Former. When built via :meth:`from_config`, the BLIP-2 model is frozen and
    the text encoder is kept trainable.
    """
    PRETRAINED_MODEL_CONFIG_DICT = {
        "albef": "configs/models/albef_retrieval_flickr.yaml",
        "blip2_t5": "configs/models/blip2/blip2_caption_flant5xl.yaml"
    }

    def __init__(
        self,
        text_encoder,
        blip2_model,
        proj_input_dim,
        proj_output_dim,
        proj_hidden_dim=2048,
        proj_num_layers=3,
    ):
        """
        Args:
            text_encoder: module producing caption features. Its ``encode``
                method is used when it has one; otherwise it is called directly.
            blip2_model: BLIP-2 T5 model providing ``visual_encoder``,
                ``ln_vision``, ``query_tokens``, ``Qformer``, ``t5_proj``,
                ``t5_tokenizer`` and ``t5_model``.
            proj_input_dim (int): feature width of the text-encoder outputs.
            proj_output_dim (int): output width of ``blip2_proj``. Its result is
                fed to ``t5_proj``; ``from_config`` uses the Q-Former hidden size.
            proj_hidden_dim (int): hidden width passed to ``MLPWithLayerNorm``.
            proj_num_layers (int): ``num_layers`` passed to ``MLPWithLayerNorm``.
        """
        super().__init__()

        self.text_encoder = text_encoder
        self.blip2_proj = MLPWithLayerNorm(
            input_dim=proj_input_dim,
            hidden_dim=proj_hidden_dim,
            output_dim=proj_output_dim,
            num_layers=proj_num_layers,
        )
        self.blip2_model = blip2_model
        # TODO: temp code for task adaptation. ``ldm`` aliases the very same module
        # as ``text_encoder`` (presumably for task/runner code that expects ``.ldm``).
        self.ldm = text_encoder

    def _rampup_factor(self, epoch, iters, num_iters_per_epoch):
        """Linear warm-up factor that reaches 1 after two epochs' worth of iterations."""
        return min(1, (epoch * num_iters_per_epoch + iters) / (2 * num_iters_per_epoch))

    def get_text_embeds(self, captions):
        """Encode ``captions`` with the text encoder.

        Uses ``text_encoder.encode`` when it exists and is callable. If that returns
        a ``DiagonalGaussianDistribution``, its mode is taken. Otherwise the
        encoder module is called directly.
        """
        if hasattr(self.text_encoder, 'encode') and callable(self.text_encoder.encode):
            c = self.text_encoder.encode(captions)
            if isinstance(c, DiagonalGaussianDistribution):
                c = c.mode()
        else:
            # Fall back to the module's forward. (The old code referenced the
            # undefined attribute ``cond_stage_model`` here and would raise.)
            c = self.text_encoder(captions)
        return c

    @staticmethod
    def _ones_mask(embeds, device):
        """All-ones ``long`` attention mask covering every dim of ``embeds`` except the last."""
        return torch.ones(embeds.size()[:-1], dtype=torch.long, device=device)

    def _encode_image_to_t5(self, image):
        """Run the BLIP-2 vision encoder and Q-Former, then project the queries with ``t5_proj``."""
        blip2 = self.blip2_model
        with blip2.maybe_autocast():
            image_embeds = blip2.ln_vision(blip2.visual_encoder(image))
        image_atts = self._ones_mask(image_embeds, image.device)

        query_tokens = blip2.query_tokens.expand(image_embeds.shape[0], -1, -1)
        query_output = blip2.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        return blip2.t5_proj(query_output.last_hidden_state)

    def _project_text_to_t5(self, text_embeds):
        """Map text-encoder features through ``blip2_proj`` and then ``t5_proj``."""
        return self.blip2_model.t5_proj(self.blip2_proj(text_embeds))

    def _generate_text(self, inputs_embeds, attention_mask, **generate_kwargs):
        """Run ``t5_model.generate`` on precomputed encoder inputs and decode to strings."""
        outputs = self.blip2_model.t5_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            **generate_kwargs,
        )
        return self.blip2_model.t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)

    def forward(self, samples):
        """Compute the T5 captioning loss conditioned on image and caption features.

        The T5 encoder receives, in order: the image query embeddings, the
        projected text-encoder features of the same captions, and the token
        embeddings of the fixed prompt ``"a picture of "``. The captions
        themselves are the decoder labels.

        Args:
            samples (dict): must contain ``"image"`` (image tensor batch) and
                ``"text_input"`` (list of caption strings).

        Returns:
            torch.Tensor: ``outputs.loss`` reported by the T5 model.
        """
        image = samples["image"]
        captions = samples["text_input"]
        device = image.device
        blip2 = self.blip2_model
        tokenizer = blip2.t5_tokenizer

        text_embeds = self.get_text_embeds(captions)
        inputs_t5 = self._encode_image_to_t5(image)
        inputs_t5_spec = self._project_text_to_t5(text_embeds)
        inputs_t5_mixed = torch.cat([inputs_t5, inputs_t5_spec], dim=1)
        atts_t5 = self._ones_mask(inputs_t5_mixed, device)

        with blip2.maybe_autocast(dtype=torch.bfloat16):
            input_tokens = tokenizer(
                ["a picture of "] * len(captions),
                padding="longest",
                truncation=True,
                max_length=blip2.max_txt_len,
                return_tensors="pt",
            ).to(device)
            output_tokens = tokenizer(
                captions,
                padding="longest",
                truncation=True,
                max_length=blip2.max_txt_len,
                return_tensors="pt",
            ).to(device)

            encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)

            # Pad positions get label -100, the value Hugging Face LM losses ignore.
            targets = output_tokens.input_ids.masked_fill(
                output_tokens.input_ids == tokenizer.pad_token_id, -100
            )

            prompt_embeds = blip2.t5_model.encoder.embed_tokens(input_tokens.input_ids)
            inputs_embeds = torch.cat([inputs_t5_mixed, prompt_embeds], dim=1)

            outputs = blip2.t5_model(
                inputs_embeds=inputs_embeds,
                attention_mask=encoder_atts,
                decoder_attention_mask=output_tokens.attention_mask,
                return_dict=True,
                labels=targets,
            )
        return outputs.loss

    def gene_val_txts(
        self,
        samples,
        *,
        use_nucleus_sampling=False,
        num_beams=5,
        max_length=30,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.0,
        length_penalty=1.0,
        num_captions=1,
        temperature=1,
    ):
        """Val generation effect by sampling one training sample.

        Uses the first image/caption in ``samples`` and prints the captions T5
        generates from three encoder inputs, each followed by
        ``self.blip2_model.prompt``:

        1. the image query embeddings only ("Original"),
        2. the projected caption features only ("Specific"),
        3. both concatenated ("Mixed").

        Gradients are disabled throughout.

        The keyword-only arguments are forwarded to ``t5_model.generate``.
        ``use_nucleus_sampling`` maps to ``do_sample``, ``max_length`` to
        ``max_new_tokens``, and ``num_captions`` to ``num_return_sequences``.
        The defaults reproduce the previously hard-coded settings.

        Returns:
            int: number of decoded strings for the "Specific" variant.
        """
        image = samples["image"][:1, ...]
        caption = samples["text_input"][:1]
        print(f"Original caption: {caption}")

        device = image.device
        blip2 = self.blip2_model
        generate_kwargs = dict(
            do_sample=use_nucleus_sampling,
            top_p=top_p,
            temperature=temperature,
            num_beams=num_beams,
            max_new_tokens=max_length,
            min_length=min_length,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            num_return_sequences=num_captions,
        )

        with torch.no_grad():
            text_embeds = self.get_text_embeds(caption)
            inputs_t5 = self._encode_image_to_t5(image)
            inputs_t5_spec = self._project_text_to_t5(text_embeds)
            inputs_t5_mixed = torch.cat([inputs_t5, inputs_t5_spec], dim=1)

            input_tokens = blip2.t5_tokenizer(
                blip2.prompt, padding="longest", return_tensors="pt"
            ).to(device)

            # (label used in the printout, embeddings placed before the prompt)
            variants = (
                ("Original", inputs_t5),
                ("Specific", inputs_t5_spec),
                ("Mixed", inputs_t5_mixed),
            )
            decoded = {}
            with blip2.maybe_autocast(dtype=torch.bfloat16):
                prompt_embeds = blip2.t5_model.encoder.embed_tokens(input_tokens.input_ids)
                for label, query_embeds in variants:
                    encoder_atts = torch.cat(
                        [self._ones_mask(query_embeds, device), input_tokens.attention_mask],
                        dim=1,
                    )
                    decoded[label] = self._generate_text(
                        torch.cat([query_embeds, prompt_embeds], dim=1),
                        encoder_atts,
                        **generate_kwargs,
                    )
                    print(f"{label} output: {decoded[label]}")

        return len(decoded["Specific"])

    @classmethod
    def from_config(cls, cfg=None):
        """Build the model from a config.

        Reads the following keys:

        - ``vit_model``: only ``"eva_clip_g"`` is supported.
        - ``cond_stage_config``: its ``params`` are passed to ``FrozenOpenCLIPEmbedder``.
        - ``proj_hidden_dim`` / ``proj_num_layers`` (optional, default 2048 / 3):
          sizes of the projection MLP.

        Raises:
            NotImplementedError: if ``vit_model`` is not ``"eva_clip_g"``.
            AssertionError: if ``cond_stage_config`` is missing.
        """
        vit_model = cfg.get("vit_model", None)
        if vit_model != "eva_clip_g":
            raise NotImplementedError(
                f"Unsupported vit_model {vit_model!r}; only 'eva_clip_g' is supported."
            )
        blip2_t5 = Blip2T5.from_config(cfg)

        # Freeze BLIP-2. Put it in eval mode and stop its parameters from
        # receiving gradients. ``train`` is replaced by the LDM helper
        # ``disabled_train``, presumably so a later ``.train()`` on the parent
        # module does not switch it back to training mode.
        blip2_t5.eval()
        blip2_t5.train = disabled_train
        for param in blip2_t5.parameters():
            param.requires_grad = False

        cond_stage_config = cfg.get("cond_stage_config", None)
        assert cond_stage_config is not None, "cfg must define 'cond_stage_config'"
        ldm_cond_model = FrozenOpenCLIPEmbedder(**cond_stage_config.get("params", dict()))
        # Unlike BLIP-2, the text encoder is deliberately left trainable.
        ldm_cond_model.train()
        for param in ldm_cond_model.parameters():
            param.requires_grad = True

        proj_input_dim = ldm_cond_model.model.ln_final.normalized_shape[0]
        proj_output_dim = blip2_t5.Qformer.config.hidden_size
        print(f"In/Out dims:{proj_input_dim, proj_output_dim}")
        return cls(
            text_encoder=ldm_cond_model,
            blip2_model=blip2_t5,
            proj_input_dim=proj_input_dim,
            proj_output_dim=proj_output_dim,
            proj_hidden_dim=cfg.get("proj_hidden_dim", 2048),
            proj_num_layers=cfg.get("proj_num_layers", 3),
        )