import glob
import json
import os

import cv2
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision.models.shufflenetv2 import channel_shuffle
import tqdm
from loguru import logger
from PIL import Image
from torchvision import models, transforms
from utils.helper_utils import (
    sequence_mask,
)

# Nearest neighbor baseline with Resnet 101 features
class NearestNeighborResnet(nn.Module):
    """Parameter-free nearest-neighbour baseline.

    The module owns no layers: ``forward`` works directly on the tensors found
    in the batch and ranks candidates by their L2 distance to the
    second-to-last entry of ``batch["images"]``.
    """

    def __init__(self, params):
        # ``params`` is not read by this class.
        super().__init__()

    def forward(self, batch):
        """Score the true last image and every distractor against the reference.

        Args:
            batch: mapping with
                ``"images"`` -- tensor indexed as (batch, position, feature);
                ``"distractor_images"`` -- tensor indexed as
                (batch, distractor, feature).

        Returns:
            Tensor of shape (batch, num_distractors + 1) holding the negated
            L2 distance of each candidate to the reference, so a higher score
            means a closer candidate. Column 0 is the ground-truth candidate
            (the last entry of ``"images"``); the distractors follow in order.
        """
        story_feats = batch["images"]
        distractor_feats = batch["distractor_images"]

        # Reference = second-to-last position; ground truth = last position.
        anchor = story_feats[:, -2, :]
        pool = torch.cat([story_feats[:, -1:, :], distractor_feats], dim=1)

        # Broadcast the anchor over the candidate axis and reduce over features.
        distances = torch.norm(pool - anchor.unsqueeze(1), p=2, dim=2)

        # Lower distance is better, so negate to obtain scores.
        return -distances

class LSTM(nn.Module):
    """LSTM that predicts the last image feature from the ones before it.

    The initial hidden and cell states are learned parameters shared by every
    sequence in a batch.
    """

    def __init__(self, params):
        """Build the recurrent network, the output projection and the initial states.

        Args:
            params: config mapping; reads ``IMG_FEAT_SIZE``, ``LSTM_HIDDEN_DIM``,
                ``LSTM_LAYERS`` and ``LSTM_DROPOUT`` from ``params["VIST"]``.
        """
        super().__init__()
        vist_cfg = params["VIST"]
        feat_dim = vist_cfg["IMG_FEAT_SIZE"]
        hidden_dim = vist_cfg["LSTM_HIDDEN_DIM"]
        n_layers = vist_cfg["LSTM_LAYERS"]

        self.lstm = nn.LSTM(
            input_size=feat_dim,
            hidden_size=hidden_dim,
            num_layers=n_layers,
            dropout=vist_cfg["LSTM_DROPOUT"],
            batch_first=True,
        )
        # Maps the final LSTM output back into image-feature space.
        self.image_predictor = nn.Linear(hidden_dim, feat_dim)
        logger.info(self.lstm)

        # Trainable initial states, one row per layer (h0 is drawn before c0).
        self.h0 = nn.Parameter(torch.randn(n_layers, hidden_dim), requires_grad=True)
        self.c0 = nn.Parameter(torch.randn(n_layers, hidden_dim), requires_grad=True)

    def forward(self, batch):
        """Predict the last image feature and score candidates against the prediction.

        Args:
            batch: mapping with
                ``"images"`` -- tensor indexed as (batch, position, feature);
                ``"distractor_images"`` -- tensor indexed as
                (batch, distractor, feature).

        Returns:
            A pair ``(reconstruction_loss, scores)``:
            ``reconstruction_loss`` is the batch mean of the L2 distance between
            the predicted and the true last feature; ``scores`` has shape
            (batch, num_distractors + 1) and holds the negated L2 distance of
            each candidate to the prediction, ground truth in column 0.
        """
        story_feats = batch["images"]
        distractor_feats = batch["distractor_images"]
        batch_size = story_feats.shape[0]

        # Give every sequence in the batch its own copy of the learned states.
        init_state = tuple(
            state.unsqueeze(1).repeat(1, batch_size, 1)
            for state in (self.h0, self.c0)
        )

        # Feed every position except the last; project the final output.
        hidden_seq, _ = self.lstm(story_feats[:, :-1, :], init_state)
        predicted = self.image_predictor(hidden_seq[:, -1, :])

        target = story_feats[:, -1, :]
        reconstruction_loss = torch.norm(predicted - target, p=2, dim=1).mean()

        # The prediction (not an observed image) is the reference for ranking.
        pool = torch.cat([target.unsqueeze(1), distractor_feats], dim=1)
        distances = torch.norm(pool - predicted.unsqueeze(1), p=2, dim=2)

        # Lower distance is better, so negate to obtain scores.
        return reconstruction_loss, -distances

    def save_pretrained(self, path):
        """Write the module's state dict to ``path`` under ``"model_state_dict"``."""
        checkpoint = {"model_state_dict": self.state_dict()}
        torch.save(checkpoint, path)

class BertCaptions(nn.Module):
    """Scores candidate last captions with a sequence-pair classifier.

    For every story the leading captions are packed into one segment and
    paired with each candidate (true last caption first, then the distractors)
    as ``[CLS] history [SEP] candidate [SEP]``.
    """

    def __init__(self, model, tokenizer, cfg):
        super().__init__()
        self.cfg = cfg
        self.tokenizer = tokenizer
        self.model = model

    def forward(self, batch):
        """Build the paired inputs and return the wrapped model's logits.

        Args:
            batch: mapping with
                ``"captions"`` -- token ids indexed as (batch, caption, token);
                ``"captions_length"`` -- lengths indexed as (batch, caption);
                ``"distractor_captions"`` -- token ids indexed as
                (batch, distractor, token);
                ``"distractor_captions_length"`` -- lengths indexed as
                (batch, distractor).
                NOTE(review): the token axis is presumably
                ``cfg["VIST"]["MAX_LEN_CAPTION"]`` wide, since masks of that
                width are used to index it -- confirm against the dataset.

        Returns:
            ``outputs.logits`` of the wrapped model, one row per
            (story, candidate) pair; the candidates of one story are
            consecutive rows with the ground truth first.
        """
        vist_cfg = self.cfg["VIST"]
        caption_width = vist_cfg["MAX_LEN_CAPTION"]
        num_candidates = vist_cfg["NUM_DISTRACTORS"] + 1

        story_tokens = batch["captions"]
        story_lengths = batch["captions_length"]
        distractor_tokens = batch["distractor_captions"]
        distractor_lengths = batch["distractor_captions_length"]
        device = story_tokens.device

        # ---- pack all captions but the last into one left-aligned row ----
        history_tokens = story_tokens[:, :-1, :]
        num_history = history_tokens.shape[-2]
        history_tokens = history_tokens.view(history_tokens.shape[0], -1)
        history_lengths = story_lengths[:, :-1].sum(dim=1)
        longest = torch.max(history_lengths)

        packed = torch.zeros(
            history_lengths.shape[0], longest, dtype=torch.long, device=device
        )
        # One validity mask per history caption, laid side by side so it lines
        # up with the flattened history tokens.
        per_caption_mask = torch.cat(
            [
                sequence_mask(story_lengths[:, idx], max_len=caption_width)
                for idx in range(num_history)
            ],
            dim=1,
        )
        assert torch.equal(per_caption_mask.sum(dim=1), history_lengths)

        packed[
            sequence_mask(history_lengths, max_len=longest)
        ] = history_tokens[per_caption_mask]

        # ---- candidates: ground truth at index 0, distractors after it ----
        candidates = torch.cat([story_tokens[:, -1:, :], distractor_tokens], dim=1)
        candidate_lengths = torch.cat(
            [story_lengths[:, -1:], distractor_lengths], dim=1
        )
        assert candidates.shape[1] == num_candidates

        flat_candidates = candidates.view(-1, candidates.shape[2])
        flat_candidate_lengths = candidate_lengths.view(-1)

        cls_id = self.tokenizer.encode(
            self.tokenizer.cls_token, add_special_tokens=False
        )[0]
        sep_id = self.tokenizer.encode(
            self.tokenizer.sep_token, add_special_tokens=False
        )[0]

        # ---- history -> [CLS] history [SEP] ----
        num_stories = packed.shape[0]
        cls_column = torch.full(
            (num_stories, 1), cls_id, dtype=torch.long, device=device
        )
        spare_column = torch.zeros(num_stories, 1, dtype=torch.long, device=device)
        history = torch.cat([cls_column, packed, spare_column], dim=1)
        # SEP goes right after the last real token (shifted by one for CLS).
        history[torch.arange(history.shape[0]), history_lengths + 1] = sep_id
        history_lengths = history_lengths + 2

        # One copy of the history per candidate.
        history_rep = torch.repeat_interleave(history, num_candidates, dim=0)
        history_len_rep = torch.repeat_interleave(
            history_lengths, num_candidates, dim=0
        )

        # ---- candidate -> candidate [SEP] ----
        flat_candidates = torch.cat(
            [
                flat_candidates,
                torch.zeros(
                    flat_candidates.shape[0], 1, dtype=torch.long, device=device
                ),
            ],
            dim=1,
        )
        flat_candidates[
            torch.arange(flat_candidates.shape[0]), flat_candidate_lengths
        ] = sep_id
        flat_candidate_lengths = flat_candidate_lengths + 1

        # ---- stitch history and candidate into the final model inputs ----
        total_rows = history_rep.shape[0]
        total_width = history_rep.shape[1] + caption_width + 1
        input_tokens = torch.zeros(
            total_rows, total_width, dtype=torch.long, device=device
        )
        # Segment ids start at 0, which already marks the history segment.
        token_type_ids = torch.zeros_like(input_tokens)
        attention_mask = torch.zeros_like(input_tokens)

        hist_lens = history_len_rep.tolist()
        for row, cand_len in enumerate(flat_candidate_lengths.tolist()):
            hist_len = hist_lens[row]
            end = hist_len + cand_len
            input_tokens[row, :hist_len] = history_rep[row, :hist_len]
            input_tokens[row, hist_len:end] = flat_candidates[row, :cand_len]
            token_type_ids[row, hist_len:end] = 1
            attention_mask[row, :end] = 1

        outputs = self.model(
            input_ids=input_tokens,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        return outputs.logits

class DistillBertCaptions(nn.Module):
    """Scores candidate last captions with a single-segment classifier.

    Inputs are laid out as ``[CLS] history [SEP] candidate [SEP]``; no
    token-type ids are produced. The tokenizer is taken from
    ``dataset.tokenizer``.
    """

    def __init__(self, model, dataset, cfg):
        super().__init__()
        self.cfg = cfg
        self.dataset = dataset
        self.model = model

    def forward(self, batch, inference=False):
        """Build the paired inputs and return the wrapped model's logits.

        Args:
            batch: mapping with
                ``"captions"`` -- token ids indexed as (batch, caption, token);
                ``"captions_length"`` -- lengths indexed as (batch, caption);
                ``"distractor_captions"`` -- token ids indexed as
                (batch, distractor, token);
                ``"distractor_captions_length"`` -- lengths indexed as
                (batch, distractor).
                NOTE(review): the token axis is presumably
                ``cfg["VIST"]["MAX_LEN_CAPTION"]`` wide, since masks of that
                width are used to index it -- confirm against the dataset.
            inference: when true, the whole pass runs with gradient tracking
                disabled.

        Returns:
            ``outputs.logits`` of the wrapped model, one row per
            (story, candidate) pair; the candidates of one story are
            consecutive rows with the ground truth first.
        """
        with torch.set_grad_enabled(not inference):
            vist_cfg = self.cfg["VIST"]
            caption_width = vist_cfg["MAX_LEN_CAPTION"]
            num_candidates = vist_cfg["NUM_DISTRACTORS"] + 1

            story_tokens = batch["captions"]
            story_lengths = batch["captions_length"]
            distractor_tokens = batch["distractor_captions"]
            distractor_lengths = batch["distractor_captions_length"]
            device = story_tokens.device

            # ---- pack all captions but the last into one left-aligned row ----
            history_tokens = story_tokens[:, :-1, :]
            num_history = history_tokens.shape[-2]
            history_tokens = history_tokens.view(history_tokens.shape[0], -1)
            history_lengths = story_lengths[:, :-1].sum(dim=1)
            longest = torch.max(history_lengths)

            packed = torch.zeros(
                history_lengths.shape[0], longest, dtype=torch.long, device=device
            )
            # One validity mask per history caption, laid side by side so it
            # lines up with the flattened history tokens.
            per_caption_mask = torch.cat(
                [
                    sequence_mask(story_lengths[:, idx], max_len=caption_width)
                    for idx in range(num_history)
                ],
                dim=1,
            )
            assert torch.equal(per_caption_mask.sum(dim=1), history_lengths)

            packed[
                sequence_mask(history_lengths, max_len=longest)
            ] = history_tokens[per_caption_mask]

            # ---- candidates: ground truth at index 0, distractors after it ----
            candidates = torch.cat(
                [story_tokens[:, -1:, :], distractor_tokens], dim=1
            )
            candidate_lengths = torch.cat(
                [story_lengths[:, -1:], distractor_lengths], dim=1
            )
            assert candidates.shape[1] == num_candidates

            flat_candidates = candidates.view(-1, candidates.shape[2])
            flat_candidate_lengths = candidate_lengths.view(-1)

            tokenizer = self.dataset.tokenizer
            cls_id = tokenizer.encode(tokenizer.cls_token, add_special_tokens=False)[0]
            sep_id = tokenizer.encode(tokenizer.sep_token, add_special_tokens=False)[0]
            # The pad id is resolved here but never read; padding slots stay 0.
            _pad_id = tokenizer.encode(tokenizer.pad_token, add_special_tokens=False)[0]

            # ---- history -> [CLS] history [SEP] ----
            num_stories = packed.shape[0]
            cls_column = torch.full(
                (num_stories, 1), cls_id, dtype=torch.long, device=device
            )
            spare_column = torch.zeros(
                num_stories, 1, dtype=torch.long, device=device
            )
            history = torch.cat([cls_column, packed, spare_column], dim=1)
            # SEP goes right after the last real token (shifted by one for CLS).
            history[torch.arange(history.shape[0]), history_lengths + 1] = sep_id
            history_lengths = history_lengths + 2

            # One copy of the history per candidate.
            history_rep = torch.repeat_interleave(history, num_candidates, dim=0)
            history_len_rep = torch.repeat_interleave(
                history_lengths, num_candidates, dim=0
            )

            # ---- candidate -> candidate [SEP] ----
            flat_candidates = torch.cat(
                [
                    flat_candidates,
                    torch.zeros(
                        flat_candidates.shape[0], 1, dtype=torch.long, device=device
                    ),
                ],
                dim=1,
            )
            flat_candidates[
                torch.arange(flat_candidates.shape[0]), flat_candidate_lengths
            ] = sep_id
            flat_candidate_lengths = flat_candidate_lengths + 1

            # ---- stitch history and candidate into the final model inputs ----
            total_rows = history_rep.shape[0]
            total_width = history_rep.shape[1] + caption_width + 1
            input_tokens = torch.zeros(
                total_rows, total_width, dtype=torch.long, device=device
            )
            attention_mask = torch.zeros_like(input_tokens)

            hist_lens = history_len_rep.tolist()
            for row, cand_len in enumerate(flat_candidate_lengths.tolist()):
                hist_len = hist_lens[row]
                end = hist_len + cand_len
                input_tokens[row, :hist_len] = history_rep[row, :hist_len]
                input_tokens[row, hist_len:end] = flat_candidates[row, :cand_len]
                attention_mask[row, :end] = 1

            # All-ones labels are handed to the model; only the logits of the
            # result are used here.
            outputs = self.model(
                input_ids=input_tokens,
                attention_mask=attention_mask,
                labels=torch.ones(input_tokens.shape[0]).long().to(device),
            )
            return outputs.logits

class GPT2Captions(nn.Module):
    """Scores candidate last captions with a causal model and a cached history.

    The packed history captions are run through the wrapped model once; the
    resulting key/value cache is replicated per candidate and each candidate
    (true last caption first, then the distractors) is scored on top of it.
    ``tokenizer`` is stored but not read by ``forward``.
    """

    def __init__(self, model, tokenizer, cfg):
        super().__init__()
        self.cfg = cfg
        self.tokenizer = tokenizer
        self.model = model

    def forward(self, batch):
        """Return the wrapped model's logits for every (story, candidate) pair.

        Args:
            batch: mapping with
                ``"captions"`` -- token ids indexed as (batch, caption, token);
                ``"captions_length"`` -- lengths indexed as (batch, caption);
                ``"distractor_captions"`` -- token ids indexed as
                (batch, distractor, token);
                ``"distractor_captions_length"`` -- lengths indexed as
                (batch, distractor).
                NOTE(review): the token axis is presumably
                ``cfg["VIST"]["MAX_LEN_CAPTION"]`` wide, since masks of that
                width are used to index it -- confirm against the dataset.

        Returns:
            ``outputs.logits`` of the second model call; asserted to have
            exactly two columns.
        """

        # NOTE: despite the name, this is the model config's *pad* token id.
        eos_token_id = self.model.config.pad_token_id

        captions = batch["captions"]
        captions_length = batch["captions_length"]

        distractor_captions = batch["distractor_captions"]
        distractor_captions_length = batch["distractor_captions_length"]

        device = captions.device

        # History = every caption except the last one, flattened per story.
        input_captions = captions[:, :-1, :]
        input_captions_length = captions_length[:, :-1]
        num_captions_history = input_captions.shape[-2]
        input_captions = input_captions.view(input_captions.shape[0], -1)
        input_captions_length = torch.sum(input_captions_length, dim=1)
        # Pack the history captions back to back, dropping each caption's padding.
        max_len = torch.max(input_captions_length)
        # Zero-filled buffer that receives the packed history tokens.
        input_captions_concat = (
            torch.zeros(input_captions_length.shape[0], max_len).long().to(device)
        )
        # One validity mask per history caption, laid side by side so it lines
        # up with the flattened history tokens.
        # presumably sequence_mask returns a (batch, max_len) boolean mask that
        # is True for positions below the given length -- verify in helper_utils.
        sequence_masks = []
        for i in range(num_captions_history):
            sequence_masks.append(
                sequence_mask(
                    captions_length[:, i],
                    max_len=self.cfg["VIST"]["MAX_LEN_CAPTION"],
                )
            )
        sequence_masks = torch.cat(sequence_masks, dim=1)
        # Sanity check: the masks select exactly as many tokens as the summed lengths.
        assert torch.equal(torch.sum(sequence_masks, dim=1), input_captions_length)

        input_captions_concat[
            sequence_mask(input_captions_length, max_len=max_len)
        ] = input_captions[sequence_masks]
        input_captions = input_captions_concat


        # Ground-truth candidate: the last caption of each story.
        target_captions = captions[:, -1, :]
        target_captions_length = captions_length[:, -1]
        # Candidates = ground truth at index 0 followed by the distractors.
        candidates = torch.cat(
            [target_captions.unsqueeze(1), distractor_captions], dim=1
        )
        candidates_length = torch.cat(
            [target_captions_length.unsqueeze(1), distractor_captions_length],
            dim=1,
        )

        assert candidates.shape[1] == self.cfg["VIST"]["NUM_DISTRACTORS"] + 1

        # One row per (story, candidate) pair.
        candidates_flattened = candidates.view(-1, candidates.shape[2])
        candidates_length_flattened = candidates_length.view(-1)


        # First pass: encode the packed history once and keep its key/value cache.
        input_caption_mask = sequence_mask(
            input_captions_length, max_len=max_len
        )
        outputs = self.model(
            input_ids=input_captions,
            attention_mask=input_caption_mask,
            use_cache=True,
        )

        cache = outputs.past_key_values
        # Repeat every cached tensor NUM_DISTRACTORS + 1 times along dim 0 so
        # each candidate row gets its story's history.
        # NOTE(review): assumes the cache is an iterable of tuples of tensors
        # with the batch on dim 0 -- confirm for the installed transformers version.
        cache_replicated = []
        for elem in cache:
            elem_mod = []
            for c in elem:
                elem_mod.append(
                    torch.repeat_interleave(
                        c,
                        torch.ones(c.shape[0]).long().to(device)
                        * (self.cfg["VIST"]["NUM_DISTRACTORS"] + 1),
                        dim=0,
                    )
                )
            elem_mod = tuple(elem_mod)
            cache_replicated.append(elem_mod)

        # Repeat the history attention mask the same way as the cache.

        input_caption_mask_repeated = torch.repeat_interleave(
            input_caption_mask,
            torch.ones(input_caption_mask.shape[0]).long().to(device)
            * (self.cfg["VIST"]["NUM_DISTRACTORS"] + 1),
            dim=0,
        )

        cache_replicated = tuple(cache_replicated)

        # Full attention mask = history mask followed by the candidate mask.
        candidates_attention_mask = torch.cat(
            [
                input_caption_mask_repeated,
                sequence_mask(
                    candidates_length_flattened,
                    max_len=self.cfg["VIST"]["MAX_LEN_CAPTION"],
                ),
            ],
            dim=1,
        )
        # Overwrite every padding slot of the candidates with the pad token id.
        # This writes in place into a view of ``candidates``, which is a fresh
        # tensor produced by torch.cat above.
        candidates_flattened[~sequence_mask(
                    candidates_length_flattened,
                    max_len=self.cfg["VIST"]["MAX_LEN_CAPTION"],
                )] = eos_token_id

        # Second pass: score each candidate on top of the replicated history cache.
        outputs = self.model(
            input_ids=candidates_flattened,
            past_key_values=cache_replicated,
            attention_mask=candidates_attention_mask,
        )
        logits = outputs.logits
        assert logits.shape[1] == 2
        return logits

class RobertaCaptions(nn.Module):
    """Scores candidate last captions by appending them to the packed history.

    Each model input is the packed history tokens immediately followed by the
    candidate tokens; this class inserts no special tokens of its own.
    ``tokenizer`` is stored but not read by ``forward``.
    """

    def __init__(self, model, tokenizer, cfg):
        super().__init__()
        self.cfg = cfg
        self.tokenizer = tokenizer
        self.model = model

    def forward(self, batch):
        """Build the concatenated inputs and return the wrapped model's logits.

        Args:
            batch: mapping with
                ``"captions"`` -- token ids indexed as (batch, caption, token);
                ``"captions_length"`` -- lengths indexed as (batch, caption);
                ``"distractor_captions"`` -- token ids indexed as
                (batch, distractor, token);
                ``"distractor_captions_length"`` -- lengths indexed as
                (batch, distractor).
                NOTE(review): the token axis is presumably
                ``cfg["VIST"]["MAX_LEN_CAPTION"]`` wide, since masks of that
                width are used to index it -- confirm against the dataset.

        Returns:
            ``outputs.logits`` of the wrapped model, asserted to have exactly
            two columns; one row per (story, candidate) pair with the
            candidates of one story in consecutive rows, ground truth first.
        """
        vist_cfg = self.cfg["VIST"]
        caption_width = vist_cfg["MAX_LEN_CAPTION"]
        num_candidates = vist_cfg["NUM_DISTRACTORS"] + 1

        story_tokens = batch["captions"]
        story_lengths = batch["captions_length"]
        distractor_tokens = batch["distractor_captions"]
        distractor_lengths = batch["distractor_captions_length"]
        device = story_tokens.device

        # ---- pack all captions but the last into one left-aligned row ----
        history_tokens = story_tokens[:, :-1, :]
        num_history = history_tokens.shape[-2]
        history_tokens = history_tokens.view(history_tokens.shape[0], -1)
        history_lengths = story_lengths[:, :-1].sum(dim=1)
        longest = torch.max(history_lengths)

        packed = torch.zeros(
            history_lengths.shape[0], longest, dtype=torch.long, device=device
        )
        # One validity mask per history caption, laid side by side so it lines
        # up with the flattened history tokens.
        per_caption_mask = torch.cat(
            [
                sequence_mask(story_lengths[:, idx], max_len=caption_width)
                for idx in range(num_history)
            ],
            dim=1,
        )
        assert torch.equal(per_caption_mask.sum(dim=1), history_lengths)

        packed[
            sequence_mask(history_lengths, max_len=longest)
        ] = history_tokens[per_caption_mask]

        # ---- candidates: ground truth at index 0, distractors after it ----
        candidates = torch.cat([story_tokens[:, -1:, :], distractor_tokens], dim=1)
        candidate_lengths = torch.cat(
            [story_lengths[:, -1:], distractor_lengths], dim=1
        )
        assert candidates.shape[1] == num_candidates

        flat_candidates = candidates.view(-1, candidates.shape[2])
        flat_candidate_lengths = candidate_lengths.view(-1)

        # One copy of the history per candidate.
        history_rep = torch.repeat_interleave(packed, num_candidates, dim=0)
        history_len_rep = torch.repeat_interleave(
            history_lengths, num_candidates, dim=0
        )

        # ---- stitch history and candidate into the final model inputs ----
        total_rows = history_rep.shape[0]
        total_width = history_rep.shape[1] + caption_width
        attention_mask = torch.zeros(
            total_rows, total_width, dtype=torch.float, device=device
        )
        input_tokens = torch.zeros(
            total_rows, total_width, dtype=torch.long, device=device
        )

        hist_lens = history_len_rep.tolist()
        for row, cand_len in enumerate(flat_candidate_lengths.tolist()):
            hist_len = hist_lens[row]
            end = hist_len + cand_len
            input_tokens[row, :hist_len] = history_rep[row, :hist_len]
            input_tokens[row, hist_len:end] = flat_candidates[row, :cand_len]
            attention_mask[row, :end] = 1

        outputs = self.model(
            input_ids=input_tokens,
            attention_mask=attention_mask,
        )
        logits = outputs.logits

        assert logits.shape[1] == 2
        return logits