import copy
import json
import os
import random
import time
from collections import namedtuple

import cv2
import numpy as np
import torch
import torch.utils
from loguru import logger
from PIL import Image
from torchvision import models, transforms
from ..helper_utils import list2tensorpad


class VISTDataset(torch.utils.data.Dataset):
    """VIST story-in-sequence dataset with generated stories and distractors.

    For every split in ``self.subsets`` the constructor loads:

    * the story-in-sequence annotations ``<SIS_ROOT>/<split>.story-in-sequence.json``,
    * the generated stories from ``params["GENERATED_STORIES_PATH_<SPLIT>"]``,
    * the distractor samples from ``params["DISTRACTOR_PATH_<SPLIT>"]``.

    Items are indexed into the distractor list of the currently selected
    split (see :attr:`split`).

    Args:
        params: configuration mapping. The constructor reads ``OVERFIT``,
            ``SIS_ROOT``, ``GENERATED_STORIES_PATH_{TRAIN,VAL,TEST}``,
            ``DISTRACTOR_PATH_{TRAIN,VAL,TEST}`` and, when overfitting,
            ``DATASET_SAMPLE_OVERFIT``. ``__getitem__`` additionally reads
            ``CAPTION_TYPE``, ``MAX_LEN_CAPTION`` and, with ``load_images``,
            ``IMG_FEAT_ROOT``.
        tokenizer: object providing ``encode(s, add_special_tokens=..., padding=...)``
            and ``decode(ids)``.
        load_images: when True, ``__getitem__`` also loads ``<image_id>.npy``
            feature files for the story and distractor images.
    """

    def __init__(self, params, tokenizer, load_images=False):
        self.num_data_points_per_split = {}
        self.logger = logger.bind(task="not-eval")
        self.load_images = load_images
        # load caption information for different splits
        self.subsets = ["train", "val", "test"]
        self.overfit = params["OVERFIT"]
        self.params = params
        self._split = "train"
        self._split_idx = 0
        root_dir = params["SIS_ROOT"]
        data_file = "%s.story-in-sequence.json"

        paths = [
            params["GENERATED_STORIES_PATH_TRAIN"],
            params["GENERATED_STORIES_PATH_VAL"],
            params["GENERATED_STORIES_PATH_TEST"],
        ]
        self.generated_stories_all = self.process_generated_stories(paths)

        distractor_paths = [
            params["DISTRACTOR_PATH_TRAIN"],
            params["DISTRACTOR_PATH_VAL"],
            params["DISTRACTOR_PATH_TEST"],
        ]
        self.distractors = self.process_distractors(distractor_paths)

        VIST_data = namedtuple(
            "VIST_data",
            "captions_storyline, images_storyline, albumid_2_idxs, images, imageid_2_idx, albumid_2_imageid, album_ids, image_ids",
        )
        # One VIST_data record per split, in the order of self.subsets.
        self.all_data = []
        for subset_id, subset in enumerate(self.subsets):
            annotation_path = os.path.join(root_dir, data_file % subset)
            # Context manager so the annotation file handle is always closed.
            with open(annotation_path, encoding="utf-8") as f:
                cur_data = json.load(f)
            (
                captions_storyline,
                images_storyline,
                albumid_2_idxs,
            ) = self.process_annotations(cur_data)
            images, imageid_2_idx, albumid_2_imageid = self.process_images(cur_data)
            self.all_data.append(
                VIST_data(
                    captions_storyline=captions_storyline,
                    images_storyline=images_storyline,
                    albumid_2_idxs=albumid_2_idxs,
                    images=images,
                    imageid_2_idx=imageid_2_idx,
                    albumid_2_imageid=albumid_2_imageid,
                    album_ids=sorted(albumid_2_imageid.keys()),
                    image_ids=sorted(imageid_2_idx.keys()),
                )
            )
            # The split length is the number of distractor samples, unless
            # overfitting, in which case a fixed sample count is used.
            self.num_data_points_per_split[subset] = len(self.distractors[subset_id])
            if params["OVERFIT"]:
                self.num_data_points_per_split[subset] = params[
                    "DATASET_SAMPLE_OVERFIT"
                ]
        self._tokenizer = tokenizer

    def process_annotations(self, data):
        """Group the flat annotation list into stories of five entries.

        ``data["annotations"]`` is consumed in consecutive chunks of five.
        Each entry is accessed as ``ann[0]`` and must provide ``album_id``,
        ``text`` and ``photo_flickr_id``. A story's album id is taken from
        its first entry.

        Returns:
            ``(captions_storyline, images_storyline, albumid_2_idxs)``:
            per-story lists of texts, per-story lists of photo ids, and a
            mapping from album id to the indices of its stories.
        """
        annotations = data["annotations"]
        assert len(annotations) % 5 == 0

        captions_storyline = []
        images_storyline = []
        albumid_2_idxs = {}
        for j in range(len(annotations) // 5):
            story = [ann[0] for ann in annotations[j * 5 : (j + 1) * 5]]
            albumid_2_idxs.setdefault(story[0]["album_id"], []).append(j)
            captions_storyline.append([ann["text"] for ann in story])
            images_storyline.append([ann["photo_flickr_id"] for ann in story])

        return captions_storyline, images_storyline, albumid_2_idxs

    def process_distractors(self, paths):
        """Load one JSON distractor file per path and return the parsed objects in order."""
        all_distractors = []
        for p in paths:
            with open(p, encoding="utf-8") as f:
                all_distractors.append(json.load(f))
        return all_distractors

    def process_generated_stories(self, paths):
        """Load generated stories for each path.

        Each file must be JSON with an ``"output_stories"`` list. Each entry
        of that list carries ``"photo_sequence"`` (photo ids) and
        ``"story_text_normalized"``. The latter is presumably one sentence
        per photo: ``__getitem__`` extends a caption list with it and takes
        its ``[-1]`` element.

        Returns:
            A list with one dict per path, mapping
            ``"-".join(photo_sequence)`` to ``story_text_normalized``. If a
            sequence repeats, the later entry wins.
        """
        all_gen_stories = []
        for p in paths:
            self.logger.info("loading generated stories from {}", p)
            with open(p, encoding="utf-8") as f:
                data = json.load(f)
            all_gen_stories.append(
                {
                    "-".join(story["photo_sequence"]): story["story_text_normalized"]
                    for story in data["output_stories"]
                }
            )
        return all_gen_stories

    def process_images(self, data):
        """Index the image metadata of one split.

        Returns:
            ``(images, imageid_2_idx, albumid_2_imageid)``:
            ``images`` is the metadata pruned to ``album_id``, ``url_o`` and
            ``id`` (whichever are present). ``imageid_2_idx`` maps image id
            to its position in ``images``. ``albumid_2_imageid`` maps album id
            to its image ids, in file order.
        """
        images = []
        imageid_2_idx = {}
        albumid_2_imageid = {}
        # prune the unnecessary attributes in the image metadata
        attributes = ["album_id", "url_o", "id"]
        for j, image in enumerate(data["images"]):
            imageid_2_idx[image["id"]] = j
            albumid_2_imageid.setdefault(image["album_id"], []).append(image["id"])
            images.append({attr: image[attr] for attr in attributes if attr in image})

        return images, imageid_2_idx, albumid_2_imageid

    def __len__(self):
        return self.num_data_points_per_split[self._split]

    @property
    def split(self):
        """Name of the active split (one of ``self.subsets``)."""
        return self._split

    @split.setter
    def split(self, split):
        assert split in self.subsets
        self._split = split
        self._split_idx = self.subsets.index(split)

    @property
    def tokenizer(self):
        return self._tokenizer

    def tokenize(self, strings, max_len):
        """Tokenize ``strings`` into padded id rows plus their lengths.

        Each string is encoded without special tokens, truncated to
        ``max_len`` ids, and padded with
        ``list2tensorpad(..., max_seq_len=max_len)``.

        Returns:
            ``(ids, lengths)``. ``ids`` is the dim-0 concatenation of the
            padded rows, cast to long. ``lengths`` is a 1-D LongTensor with
            the truncated length of each string.
        """
        encoded_strings = []
        encoded_len = []
        for s in strings:
            encoded_string = self._tokenizer.encode(
                s, add_special_tokens=False, padding=False
            )[:max_len]
            encoded_strings.append(list2tensorpad(encoded_string, max_seq_len=max_len))
            # Already truncated above, so this never exceeds max_len.
            encoded_len.append(len(encoded_string))
        return torch.cat(encoded_strings, dim=0).long(), torch.LongTensor(encoded_len)

    def decode(self, tensorized_tokens, seqlen):
        """Decode each row of ``tensorized_tokens`` (batch x seq len) up to ``seqlen[i]`` tokens."""
        assert tensorized_tokens.shape[0] == seqlen.shape[0]
        return [
            self._tokenizer.decode(tensorized_tokens[i][: seqlen[i]].tolist())
            for i in range(tensorized_tokens.shape[0])
        ]

    @staticmethod
    def get_default_transform():
        """Resize to 224x224, convert to a tensor and normalize per channel.

        The mean/std values match the usual ImageNet statistics.
        """
        return transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                ),
            ]
        )

    def get_image_urls(self, image_ids):
        """Return the ``url_o`` of each image id in the current split.

        Ids are converted with ``str`` before lookup. Images without a
        ``url_o`` field map to the placeholder ``"not found"``.
        """
        cur_data = self.all_data[self._split_idx]
        return [
            cur_data.images[cur_data.imageid_2_idx[str(image_id)]].get(
                "url_o", "not found"
            )
            for image_id in image_ids
        ]

    @staticmethod
    def _load_image_features(img_feat_root, img_ids):
        """Stack ``<img_feat_root>/<id>.npy`` arrays along a new dim 0; None if any file is missing."""
        paths = [os.path.join(img_feat_root, "%s.npy" % img_id) for img_id in img_ids]
        if not all(os.path.isfile(p) for p in paths):
            return None
        return torch.stack([torch.from_numpy(np.load(p)) for p in paths], dim=0)

    def __getitem__(self, index):
        """Build sample ``index`` of the current split.

        Keys of the returned dict:
            images / distractor_images: stacked ``.npy`` features, only when
                ``load_images`` is set.
            id: ``LongTensor([index])``.
            image_id, distractor_image_ids: photo ids as LongTensors.
            captions, captions_length: tokenized generated story sentences.
            distractor_captions, distractor_captions_length: for each
                distractor, the tokenized last sentence of the story
                generated with the final photo replaced by that distractor.

        Returns ``None`` in three cases: ``CAPTION_TYPE`` is neither
        ``"stories"`` nor ``"stories-easy-distractors"``; a generated story
        is missing; or a feature file is missing. Callers must filter these
        out, e.g. in the collate function.
        """
        if self.params["CAPTION_TYPE"] not in ("stories", "stories-easy-distractors"):
            return None

        cur_generated_stories = self.generated_stories_all[self._split_idx]
        cur_sample = self.distractors[self._split_idx][index]
        cur_images = cur_sample["photo_sequence"]
        distractor_ids = cur_sample["distractors"]

        story_key = "-".join(cur_images)
        if story_key not in cur_generated_stories:
            self.logger.info("generated captions not found for image in album")
            return None

        max_len = self.params["MAX_LEN_CAPTION"]
        cur_captions = list(cur_generated_stories[story_key])
        cur_captions_encoded, cur_captions_length = self.tokenize(cur_captions, max_len)

        # get captions corresponding to the distractors: the story generated
        # for the same prefix with the final photo swapped for the distractor
        prefix = "-".join(cur_images[:-1])
        distractor_captions = []
        for distractor_id in distractor_ids:
            distractor_key = "%s-%s" % (prefix, distractor_id)
            if distractor_key not in cur_generated_stories:
                return None
            distractor_captions.append(cur_generated_stories[distractor_key][-1])
        distractor_captions_encoded, distractor_captions_length = self.tokenize(
            distractor_captions, max_len=max_len
        )

        item = {}
        if self.load_images:
            # load gt images; discriminator images
            img_feat_root = os.path.join(self.params["IMG_FEAT_ROOT"], self._split)
            gt_image_feats = self._load_image_features(img_feat_root, cur_images)
            if gt_image_feats is None:
                logger.warning("image file not found")
                return None
            distractor_image_feats = self._load_image_features(
                img_feat_root, distractor_ids
            )
            if distractor_image_feats is None:
                logger.warning("distractor image file not found")
                return None
            item["images"] = gt_image_feats
            item["distractor_images"] = distractor_image_feats

        item["id"] = torch.LongTensor([index])
        item["image_id"] = torch.LongTensor([int(c) for c in cur_images])
        item["distractor_image_ids"] = torch.LongTensor(
            [int(d) for d in distractor_ids]
        )
        item["captions"] = cur_captions_encoded
        item["distractor_captions"] = distractor_captions_encoded
        item["captions_length"] = cur_captions_length
        item["distractor_captions_length"] = distractor_captions_length
        return item


class VISTDatasetCaptions(VISTDataset):
    """:class:`VISTDataset` whose samples also carry entries from captioning datasets.

    Args:
        params, tokenizer: forwarded to :class:`VISTDataset`.
        captioning_datasets: one dataset per split, ordered like
            ``self.subsets``. Each must expose ``key2index`` (looked up with
            ``str(image_id)``) and support ``dataset[i][1]``, which yields a
            sequence of tensors.
        load_images: accepted for signature compatibility; see the note in
            ``__init__``.
    """

    def __init__(self, params, tokenizer, captioning_datasets, load_images=False):
        # NOTE(review): ``load_images`` is accepted but NOT forwarded; the base
        # dataset is always built with load_images=False. Confirm whether this
        # is intentional before relying on the argument.
        super().__init__(params, tokenizer, load_images=False)
        self.captioning_datasets = captioning_datasets

    def __getitem__(self, index):
        """Return the base sample extended with stacked captioning entries.

        The base sample is fetched first; ``None`` is passed through as-is.
        For every id in ``image_id`` / ``distractor_image_ids`` the matching
        entry ``[1]`` is fetched from the active split's captioning dataset.
        The k-th fields of those entries are then stacked into one tensor per
        field and stored under ``"caption_batch"`` / ``"distractor_batch"``.
        """
        batch = super().__getitem__(index)
        if batch is None:
            return None

        # get captioning batch
        caption_dataset = self.captioning_datasets[self._split_idx]

        def to_indices(id_tensor):
            return [caption_dataset.key2index[str(key)] for key in id_tensor.view(-1).tolist()]

        def stack_fields(rows):
            # rows[j][k] is field k of entry j; stack field k across all entries
            field_count = len(rows[0])
            return [torch.stack([row[k] for row in rows]) for k in range(field_count)]

        gt_indices = to_indices(batch["image_id"])
        neg_indices = to_indices(batch["distractor_image_ids"])
        gt_rows = [caption_dataset[idx][1] for idx in gt_indices]
        neg_rows = [caption_dataset[idx][1] for idx in neg_indices]

        gt_stacked = stack_fields(gt_rows)
        neg_stacked = stack_fields(neg_rows)

        batch["caption_batch"] = gt_stacked
        batch["distractor_batch"] = neg_stacked
        return batch