"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import timm
import urllib
import torch
from torch import nn
from lavis.common.registry import registry

from lavis.models.texpel_models.texpel import TexpelBase
from lavis.models.texpel_models.texpel_outputs import (
    TexpelOutput,
    TexpelIntermediateOutput,
)
from lavis.models.med import XBertLMHeadDecoder
from lavis.models.vit import VisionTransformerEncoder
from lavis.mae import models_mae
from copy import deepcopy
from PIL import Image
from lavis.stable_diffusion import StableDiffusion

from torchvision import transforms
import warnings

# Install a global "ignore" filter so that no Python warnings are shown from
# this point on (this also hides deprecation notices from imported libraries).
warnings.filterwarnings("ignore")

import nltk
# nltk.download('punkt')
# nltk.download('stopwords')
from collections import Counter
from nltk.corpus import stopwords



def prepare_model(chkpt_dir, arch='mae_vit_large_patch16'):
    """Instantiate an MAE architecture and fill it from a checkpoint.

    Args:
        chkpt_dir: Location of the checkpoint handed to ``torch.load``. The
            loaded object must expose its weights under the ``'model'`` key.
        arch: Name of the factory looked up on ``models_mae``; it is called
            with no arguments to create the network.

    Returns:
        The constructed model after a non-strict ``load_state_dict``; the
        loading report (missing / unexpected keys) is printed.

    NOTE(review): ``torch.load`` may unpickle arbitrary objects depending on
    the PyTorch version and arguments -- only use checkpoints from a trusted
    source.
    """
    # Look up the architecture factory by name and build a fresh network.
    factory = getattr(models_mae, arch)
    model = factory()

    # Weights are always materialised on CPU, regardless of where they were saved.
    state_dict = torch.load(chkpt_dir, map_location='cpu')['model']

    # Non-strict: key mismatches are reported rather than raised.
    report = model.load_state_dict(state_dict, strict=False)
    print(report)
    return model


@registry.register_model("texpel_caption")
class TexpelCaption(TexpelBase):
    """
    TEXPEL captioning model.

    The image is encoded by a frozen ViT classifier (``classifier_encoder``),
    the resulting token sequence is flattened and mapped by ``translator``,
    and the translated sequence is fed to ``text_decoder`` as its encoder
    hidden states.

    Supported model types:
        - base_coco: fine-tuned TEXPEL base model on COCO caption dataset (Karpathy split).
        - large_coco: fine-tuned TEXPEL large model on COCO caption dataset (Karpathy split).

    Usage:
        >>> from lavis.models import load_model
        >>> model = load_model("texpel_caption", "base_coco")
        >>> model = load_model("texpel_caption", "large_coco")
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "base_coco": "configs/models/texpel_caption_base_coco.yaml",
        "large_coco": "configs/models/texpel_caption_large_coco.yaml",
    }

    # Shape (tokens, width) the translator output is reshaped to before it is
    # handed to the text decoder. Presumably 24x24 patches + 1 class token for
    # the 384px / patch-16 ViT built in ``from_config`` -- TODO confirm.
    _NUM_TOKENS = 577
    _EMBED_DIM = 768

    # Where the ImageNet class names are fetched from / cached to.
    _IMAGENET_CLASSES_URL = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
    _IMAGENET_CLASSES_FILE = "imagenet_classes.txt"

    # File that receives the generated captions (one per line) on every
    # ``generate`` call; the path is relative to the working directory.
    _CAPTIONS_FILE = './lavis/captions.txt'

    # Words dropped from the caption word statistics in addition to the NLTK
    # English stop-word list.
    _EXTRA_STOPWORDS = frozenset({
        'next', 'front', 'rear', 'besides', 'below', 'under', 'near', 'back', 'side',
        'background', 'foreground', 'behind', 'along', 'top', 'small', 'large', 'sitting',
        'driving', 'riding', 'laying', 'standing', 'looking', 'holding', 'wearing', 'outside', 'inside', 'another',
        'together', 'old', 'mouth', 'playing', 'open', 'close', 'swimming', 'new', 'one',
        'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine', 'ten', 'like', 'looks', 'owner',
    })

    def __init__(self, classifier_encoder, translator, classifier_pred, image_encoder, text_decoder, sd_model=None,
                 prompt=None,
                 max_txt_len=40, mse_loss_coef=1):
        """
        Args:
            classifier_encoder: ViT whose ``patch_embed``, ``cls_token``,
                ``pos_embed``, ``pos_drop``, ``blocks`` and ``norm`` members
                are used to encode the image.
            translator: Module applied to the flattened encoder tokens; its
                output is reshaped to ``(batch, 577, 768)``.
            classifier_pred: Classifier called on the raw image in
                ``generate`` to print its top-10 predictions.
            image_encoder: Encoder used only by ``forward_encoder``.
            text_decoder: Language-model decoder conditioned on the
                translated tokens.
            sd_model: Optional extra model; stored but not used in this class.
            prompt: Text prefix for every caption.
                NOTE(review): the value is passed straight to the tokenizer
                here and to ``len()`` in ``generate``, so the ``None`` default
                presumably is not usable -- confirm and pass a string.
            max_txt_len: Maximum tokenized length of the training captions.
            mse_loss_coef: Stored coefficient; not used in this class.
        """
        super().__init__()
        self.mse_loss = nn.MSELoss()
        self.mse_loss_coef = mse_loss_coef
        self.loss_heads = nn.MSELoss()
        self.classifier_encoder = classifier_encoder
        self.translator = translator
        self.tokenizer = self.init_tokenizer()

        self.visual_encoder = image_encoder
        self.text_decoder = text_decoder

        self.cls_pred = classifier_pred

        self.sd_model = sd_model

        self.prompt = prompt
        # Number of leading target positions that are masked out of the loss
        # in ``forward_decoder`` (tokenized prompt length minus one).
        self.prompt_length = len(self.tokenizer(self.prompt).input_ids) - 1

        self.max_txt_len = max_txt_len

        self.normalizer = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def forward_encoder(self, samples):
        """Return ``visual_encoder.forward_features`` of ``samples["image"]``."""
        image_embeds = self.visual_encoder.forward_features(samples["image"])
        return image_embeds

    def forward_decoder(self, samples, image_embeds):
        """Run the text decoder on the ground-truth captions.

        Args:
            samples (dict): Must contain ``text_input``, the raw caption
                strings.
            image_embeds (torch.Tensor): Sequence passed to the decoder as
                ``encoder_hidden_states``.

        Returns:
            tuple: ``(decoder_output, decoder_targets)`` where
            ``decoder_targets`` are the label ids given to the decoder.
        """
        # prepare inputs for forwarding decoder
        raw_text = samples["text_input"]
        text = self.tokenizer(
            raw_text,
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            return_tensors="pt",
        ).to(self.device)
        # The first token of every sequence is replaced by the BOS token.
        text.input_ids[:, 0] = self.tokenizer.bos_token_id

        # prepare targets for forwarding decoder: padding positions and the
        # prompt prefix are set to -100 (presumably the loss ignore index of
        # the decoder -- TODO confirm).
        decoder_targets = text.input_ids.masked_fill(
            text.input_ids == self.tokenizer.pad_token_id, -100
        )
        decoder_targets[:, : self.prompt_length] = -100

        # forward decoder; every position of image_embeds is attended to.
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(
            self.device
        )
        decoder_output = self.text_decoder(
            input_ids=text.input_ids,
            attention_mask=text.attention_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            labels=decoder_targets,
            return_dict=True
        )

        return decoder_output, decoder_targets

    def _translate_image(self, images):
        """Encode ``images`` with the classifier ViT and apply the translator.

        Shared by ``forward`` and ``generate`` so both use exactly the same
        pipeline. Returns a tensor reshaped to ``(batch, 577, 768)``.
        """
        encoder = self.classifier_encoder
        x = encoder.patch_embed(images)
        # Prepend the class token to every sequence in the batch.
        cls_token = encoder.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = encoder.pos_drop(x + encoder.pos_embed)
        x = encoder.blocks(x)
        x = encoder.norm(x)
        # The translator works on one flat vector per image.
        flat = x.view(x.shape[0], -1)
        return self.translator(flat).view(x.shape[0], self._NUM_TOKENS, self._EMBED_DIM)

    def forward(self, samples):
        r"""
        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): Batch of images given to the classifier ViT.
                - text_input (list): A list of caption strings, one per image.
        Returns:
            output (TexpelOutput): A TexpelOutput object containing the following
                attributes:
                - loss (torch.Tensor): The decoder loss. For TexpelCaption, this is the same as the LM loss.
                - loss_lm (torch.Tensor): The decoder (LM) loss.
                - intermediate_output (TexpelIntermediateOutput): Holds ``decoder_output`` and
                  ``decoder_labels``.
                  see :class:`lavis.models.texpel_models.texpel_outputs.TexpelOutput` for more details.

        Example:
        ```python
        >>> from PIL import Image
        >>> from lavis.models import load_model_and_preprocess
        >>> model, vis_processors, txt_processors = load_model_and_preprocess("texpel_caption")
        >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB")
        >>> image = vis_processors["eval"](raw_image).unsqueeze(0)
        >>> text_input = ["a large statue of a person spraying water from a fountain"]
        >>> samples = {"image": image, "text_input": text_input}
        >>> output = model(samples)
        ```"""
        translated_features = self._translate_image(samples["image"])

        decoder_output, decoder_targets = self.forward_decoder(samples, translated_features)

        return TexpelOutput(
            loss=decoder_output.loss,
            loss_lm=decoder_output.loss,
            intermediate_output=TexpelIntermediateOutput(
                decoder_output=decoder_output,
                decoder_labels=decoder_targets,
            ),
        )

    def generate(
            self,
            samples,
            use_nucleus_sampling=True,
            num_beams=100,
            max_length=30,
            min_length=20,
            top_p=0.95,
            repetition_penalty=1.0,
            num_captions=1000,
    ):
        """
        Generate many captions per image and summarise them as word counts.

        Side effects: prints the top-10 predictions of ``cls_pred`` for the
        first image of the batch (downloading ``imagenet_classes.txt`` into
        the working directory if it is not there yet) and overwrites
        ``./lavis/captions.txt`` with the generated captions, one per line.

        Args:
            samples (dict): A dictionary containing the following keys:
                - image (torch.Tensor): Batch of images given to the classifier ViT.
            use_nucleus_sampling (bool): Forwarded to the decoder's ``generate_from_encoder``.
            num_beams (int): Forwarded to the decoder's ``generate_from_encoder``.
            max_length (int): The maximum length of the sequence to be generated.
            min_length (int): The minimum length of the sequence to be generated.
            top_p (float): The cumulative probability for nucleus sampling.
            repetition_penalty (float): The parameter for repetition penalty.
            num_captions (int): Number of captions to be generated for each image.
        Returns:
            list: Up to ten ``(word, count)`` tuples, most frequent first,
            counted over all generated captions after lower-casing and
            removing stop words and non-alphanumeric tokens.

        Example:
        ```python
        >>> from PIL import Image
        >>> from lavis.models import load_model_and_preprocess
        >>> model, vis_processors, txt_processors = load_model_and_preprocess("texpel_caption")
        >>> raw_image = Image.open("docs/data/merlion.png").convert("RGB")
        >>> image = vis_processors["eval"](raw_image).unsqueeze(0)
        >>> top_words = model.generate({"image": image})
        ```
        """
        # prepare inputs for decoder generation: one copy of the translated
        # tokens per requested caption.
        encoder_out = self._translate_image(samples["image"])
        image_embeds = torch.repeat_interleave(encoder_out, num_captions, 0)

        prompt = [self.prompt] * image_embeds.size(0)
        prompt = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        prompt.input_ids[:, 0] = self.tokenizer.bos_token_id
        # Drop the last prompt token (presumably the trailing separator, so
        # the decoder continues the prompt -- TODO confirm).
        prompt.input_ids = prompt.input_ids[:, :-1]

        # get decoded text
        decoder_out = self.text_decoder.generate_from_encoder(
            tokenized_prompt=prompt,
            visual_embeds=image_embeds,
            sep_token_id=self.tokenizer.sep_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            use_nucleus_sampling=use_nucleus_sampling,
            num_beams=num_beams,
            max_length=max_length,
            min_length=min_length,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        )
        outputs = self.tokenizer.batch_decode(decoder_out, skip_special_tokens=True)
        # Strip the prompt prefix from every decoded string.
        captions = [output[len(self.prompt):] for output in outputs]

        self._print_imagenet_top10(samples["image"])
        self._save_captions(captions)

        return self._top_words(captions)

    def _print_imagenet_top10(self, images):
        """Print the ten most probable classes of ``cls_pred`` for the first image."""
        with torch.no_grad():
            out = self.cls_pred(images)
        # Only the first element of the batch is inspected.
        probabilities = torch.nn.functional.softmax(out[0], dim=0)

        categories = self._load_imagenet_categories()

        top10_prob, top10_catid = torch.topk(probabilities, 10)
        for catid, prob in zip(top10_catid.tolist(), top10_prob.tolist()):
            print(categories[catid], prob)

    @classmethod
    def _load_imagenet_categories(cls):
        """Return the class names from ``imagenet_classes.txt``, downloading the file once."""
        # Imported here so the ``request`` submodule is guaranteed to be
        # loaded; a bare ``import urllib`` does not ensure that.
        import os
        import urllib.request

        filename = cls._IMAGENET_CLASSES_FILE
        # Avoid hitting the network on every call: reuse an existing copy.
        if not os.path.exists(filename):
            urllib.request.urlretrieve(cls._IMAGENET_CLASSES_URL, filename)
        with open(filename, "r") as f:
            return [line.strip() for line in f]

    @classmethod
    def _save_captions(cls, captions):
        """Overwrite the captions file with one caption per line."""
        with open(cls._CAPTIONS_FILE, 'w', encoding="utf-8") as file:
            file.writelines(sentence + "\n" for sentence in captions)

    @classmethod
    def _top_words(cls, captions, k=10):
        """Return the ``k`` most frequent content words over ``captions`` as ``(word, count)`` tuples."""
        # Join with a space: concatenating captions without a separator would
        # fuse the last word of one caption with the first word of the next.
        text = " ".join(captions)

        stop_words = set(stopwords.words('english'))
        stop_words.update(cls._EXTRA_STOPWORDS)

        lowered = (word.lower() for word in nltk.word_tokenize(text))
        word_counts = Counter(
            word for word in lowered if word.isalnum() and word not in stop_words
        )
        return word_counts.most_common(k)

    @staticmethod
    def _freeze(module):
        """Put ``module`` in eval mode, disable gradients for its parameters and return it."""
        module.eval()
        for parameter in module.parameters():
            parameter.requires_grad = False
        return module

    @classmethod
    def from_config(cls, cfg):
        """Build the model from ``cfg`` and load its checkpoint.

        Reads ``prompt`` (default ``None``), ``max_txt_len`` (default 40) and
        ``mse_loss_coef`` (default 1) from ``cfg``; ``cfg`` is also passed to
        ``VisionTransformerEncoder.from_config``,
        ``XBertLMHeadDecoder.from_config`` and ``load_checkpoint_from_config``.
        Both ViT classifiers, the image encoder and the text decoder are
        frozen; the translator is the only component built here that is not.
        """
        flat_dim = cls._NUM_TOKENS * cls._EMBED_DIM
        translator_vit = nn.Sequential(
            nn.Linear(flat_dim, 500),
            nn.BatchNorm1d(500),
            nn.Linear(500, 1000),
            nn.BatchNorm1d(1000),
            nn.Linear(1000, flat_dim),
        )

        # Two separate instances of the same pretrained ViT: one provides the
        # features that get translated, the other the printed predictions.
        classifier_encoder_vit = cls._freeze(timm.create_model('vit_base_patch16_384', pretrained=True))
        cls_pred = cls._freeze(timm.create_model('vit_base_patch16_384', pretrained=True))

        image_encoder = cls._freeze(VisionTransformerEncoder.from_config(cfg))
        text_decoder = cls._freeze(XBertLMHeadDecoder.from_config(cfg))

        prompt = cfg.get("prompt", None)
        max_txt_len = cfg.get("max_txt_len", 40)
        # Fall back to the constructor default when the key is absent, like
        # the other optional settings above.
        mse_loss_coef = cfg.get("mse_loss_coef", 1)
        model = cls(classifier_encoder_vit, translator_vit, cls_pred, image_encoder, text_decoder,
                    prompt=prompt, max_txt_len=max_txt_len,
                    mse_loss_coef=mse_loss_coef)
        model.load_checkpoint_from_config(cfg)

        return model
