import json
import os
import random

import numpy as np
from torch.utils.data import Dataset

from PIL import Image
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
Image.MAX_IMAGE_PIXELS = None

from dataset.utils import pre_caption

from collections import defaultdict
from itertools import product
from tqdm import tqdm


class re_train_dataset(Dataset):
    """Retrieval training set yielding one (image, caption, image index) per annotation.

    Args:
        ann_file: iterable of paths to JSON annotation files. Each file is
            expected to hold a list of dicts providing the "image", "caption"
            and "image_id" keys read by this class.
        transform: callable applied to the loaded RGB PIL image.
        image_root: directory prepended to every annotation's "image" path.
        max_words: forwarded to ``pre_caption`` as its second argument.

    Attributes:
        ann: concatenation of the annotation lists of all files, in order.
        img_ids: maps each raw "image_id" to a dense integer index assigned
            in order of first appearance (0, 1, 2, ...).
    """

    def __init__(self, ann_file, transform, image_root, max_words=30):
        self.ann = []
        for f in ann_file:
            # Close every annotation file deterministically instead of
            # relying on garbage collection.
            with open(f, "r") as fp:
                self.ann += json.load(fp)
        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words

        # Dense re-indexing of image ids: len(self.img_ids) is evaluated
        # before insertion, so the first unseen id gets 0, the next 1, ...
        self.img_ids = {}
        for ann in self.ann:
            self.img_ids.setdefault(ann["image_id"], len(self.img_ids))

        print("\nDataset Length: ", len(self.ann), "\n")

    def __len__(self):
        """Return the number of annotations (image-caption pairs)."""
        return len(self.ann)

    def __getitem__(self, index):
        """Return ``(transformed image, processed caption, dense image index)``."""
        ann = self.ann[index]

        image_path = os.path.join(self.image_root, ann["image"])
        # convert() returns a fully loaded image, so the underlying file can
        # be released as soon as the conversion is done.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image = self.transform(image)

        caption = pre_caption(ann["caption"], self.max_words)

        return image, caption, self.img_ids[ann["image_id"]]
    

class re_train_dataset_control_ratio(Dataset):
    """Retrieval training set that tracks original vs. augmented annotations.

    The first entry of ``ann_file`` is treated as the original dataset and
    every remaining entry as augmented data. Both parts are concatenated into
    ``self.ann`` (originals first), and their index ranges are exposed as
    ``ori_inds`` and ``aug_inds``.

    NOTE(review): this class does not sample anything itself. The index lists
    are presumably consumed by an external sampler that draws all original
    items plus a ratio of the augmented ones each epoch -- confirm against
    the training script.

    Args:
        ann_file: sequence of JSON annotation paths; ``ann_file[0]`` is the
            original set, ``ann_file[1:]`` are augmented sets. Each file is
            expected to hold a list of dicts with "image", "caption" and
            "image_id" keys.
        transform: callable applied to the loaded RGB PIL image.
        image_root: directory prepended to every annotation's "image" path.
        max_words: forwarded to ``pre_caption`` as its second argument.
    """

    def __init__(
        self,
        ann_file, transform, image_root, max_words=30,
    ):
        # Use context managers so annotation files are closed immediately.
        with open(ann_file[0], "r") as fp:
            self.orig_ann = json.load(fp)
        self.aug_ann = []
        for f in ann_file[1:]:
            with open(f, "r") as fp:
                self.aug_ann += json.load(fp)
        self.ann = self.orig_ann + self.aug_ann

        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words

        # Positions inside self.ann: originals come first, augmented after.
        orig_len = len(self.orig_ann)
        aug_len = len(self.aug_ann)
        self.ori_inds = list(range(orig_len))
        self.aug_inds = list(range(orig_len, orig_len + aug_len))

        print("\n===== Control Ratio Dataset =====")
        print("orig_len", orig_len)
        print("aug_len", aug_len)
        print("Total:", len(self.ann))
        print("\n")

        # Map raw image ids to dense indices in order of first appearance.
        self.img_ids = {}
        for ann in self.ann:
            self.img_ids.setdefault(ann["image_id"], len(self.img_ids))

    def __len__(self):
        """Return the total number of annotations (original + augmented)."""
        return len(self.ann)

    def __getitem__(self, index):
        """Return ``(transformed image, processed caption, dense image index)``."""
        ann = self.ann[index]

        image_path = os.path.join(self.image_root, ann["image"])
        # convert() yields a fully loaded copy, so the file can be closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image = self.transform(image)

        caption = pre_caption(ann["caption"], self.max_words)

        return image, caption, self.img_ids[ann["image_id"]]



# class re_train_dataset(Dataset):
#     def __init__(self, ann_file, transform, image_root, max_words=30, generative_aug=None, synthetic_dir=None):
#         self.ann = []
#         for f in ann_file:
#             self.ann += json.load(open(f,'r'))
#         self.transform = transform
#         self.image_root = image_root
#         self.max_words = max_words

#         self.generative_aug = generative_aug
#         self.synthetic_dir = synthetic_dir
#         print("--- generative_aug", generative_aug)
#         print("--- synthetic_dir", synthetic_dir)
#         self.synthetic_examples = defaultdict(list)

#         self.img_ids = {}

#         n = 0
#         for ann in self.ann:
#             img_id = ann['image_id']
#             if img_id not in self.img_ids.keys():
#                 self.img_ids[img_id] = n
#                 n += 1

#     def __len__(self):
#         return len(self.ann)

#     def get_image_by_idx(self, idx):
#         ann = self.ann[idx]
#         image_path = os.path.join(self.image_root, ann['image'])
#         return Image.open(image_path).convert('RGB')

#     def get_label_by_idx(self, idx):
#         return pre_caption(self.ann[idx]['caption'], self.max_words)

#     def get_metadata_by_idx(self, idx):
#         return self.ann[idx]

#     def generate_augmentations(self, num_repeats: int):

#         self.synthetic_examples.clear()
#         options = product(range(len(self)), range(num_repeats))

#         for idx, num in tqdm(list(
#                 options), desc="Generating Augmentations"):

#             image = self.get_image_by_idx(idx)
#             label = self.get_label_by_idx(idx)

#             image, label = self.generative_aug(
#                 image, label, self.get_metadata_by_idx(idx))

#             if self.synthetic_dir is not None:

#                 pil_image, image = image, os.path.join(
#                     self.synthetic_dir, f"aug-{idx}-{num}.png")

#                 pil_image.save(image)

#             self.synthetic_examples[idx].append((image, label))

#     def __getitem__(self, idx):

#         ann = self.ann[idx]

#         if len(self.synthetic_examples[idx]) > 0 and \
#                 np.random.uniform() < self.synthetic_probability:

#             image, caption = random.choice(self.synthetic_examples[idx])
#             if isinstance(image, str): image = Image.open(image)

#         else:

#             image = self.get_image_by_idx(idx)
#             caption = self.get_label_by_idx(idx)


#         return image, caption, self.img_ids[ann['image_id']]


# return all captions for each image
class re_train_dataset_set(Dataset):
    """Retrieval training set with one item per image, carrying all its captions.

    Annotations are grouped by "image_id" (order of first appearance); at
    most ``caps_k`` captions are kept per image and later ones are dropped.

    data length: number of distinct image ids
    return: image, caption (built from an empty string), raw image_id,
    all_captions

    Args:
        ann_file: iterable of JSON annotation paths, each holding a list of
            dicts with "image", "caption" and "image_id" keys.
        transform: callable applied to the loaded RGB PIL image.
        image_root: directory prepended to every "image" path.
        max_words: forwarded to ``pre_caption`` as its second argument.
        caps_k: maximum number of captions retained per image.
    """

    def __init__(self, ann_file, transform, image_root, max_words=30, caps_k=5):
        raw_ann = []
        for f in ann_file:
            # Close annotation files deterministically.
            with open(f, "r") as fp:
                raw_ann += json.load(fp)
        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words
        # Kept (empty) for attribute parity with the sibling dataset classes.
        self.img_ids = {}
        self.caps_k = caps_k

        # Group captions per image id; the image path of the first
        # annotation seen for an id is the one that is kept.
        grouped = {}
        for ann in raw_ann:
            img_id = ann["image_id"]
            if img_id not in grouped:
                grouped[img_id] = {"image": ann["image"], "captions": []}
            captions = grouped[img_id]["captions"]
            if len(captions) < caps_k:
                captions.append(ann["caption"])

        self.ann = [
            {
                "image": entry["image"],
                "caption": "",
                "image_id": img_id,
                "all_captions": entry["captions"],
            }
            for img_id, entry in grouped.items()
        ]

        print("len(ann)", len(self.ann))
        print("len(image_ids)", len(grouped))
        # Reports the caption count of the last image. Guarded so that an
        # empty annotation set no longer crashes on an unbound loop variable.
        last_captions = self.ann[-1]["all_captions"] if self.ann else []
        print("caps per image: ", len(last_captions))

    def __len__(self):
        """Return the number of distinct images."""
        return len(self.ann)

    def __getitem__(self, index):
        """
        returns: image, caption, image_id, all_captions (list of all kept captions for the image)

        ``caption`` is ``pre_caption`` applied to the empty placeholder
        string stored at construction time; ``image_id`` is the raw id from
        the annotation file (not a dense index).
        """
        ann = self.ann[index]
        image_path = os.path.join(self.image_root, ann["image"])
        # convert() yields a fully loaded copy, so the file can be closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image = self.transform(image)
        caption = pre_caption(ann["caption"], self.max_words)
        all_captions = [pre_caption(c, self.max_words) for c in ann["all_captions"]]
        return image, caption, ann["image_id"], all_captions


# one item per (image, caption) pair, keeping at most caps_k captions per image
class re_train_dataset_k(Dataset):
    """Retrieval training set limited to at most ``caps_k`` captions per image.

    Annotations are grouped by "image_id" (order of first appearance), only
    the first ``caps_k`` captions of each image are kept, and the result is
    flattened back into one item per (image, caption) pair.

    data length: number of kept (image, caption) pairs
    return: image, caption, raw image_id

    Args:
        ann_file: iterable of JSON annotation paths, each holding a list of
            dicts with "image", "caption" and "image_id" keys.
        transform: callable applied to the loaded RGB PIL image.
        image_root: directory prepended to every "image" path.
        max_words: forwarded to ``pre_caption`` as its second argument.
        caps_k: maximum number of captions retained per image.
    """

    def __init__(self, ann_file, transform, image_root, max_words=30, caps_k=5):
        raw_ann = []
        for f in ann_file:
            # Close annotation files deterministically.
            with open(f, "r") as fp:
                raw_ann += json.load(fp)
        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words
        # Kept (empty) for attribute parity with the sibling dataset classes.
        self.img_ids = {}
        self.caps_k = caps_k

        # Group captions per image id; the image path of the first
        # annotation seen for an id is the one that is kept.
        grouped = {}
        for ann in raw_ann:
            img_id = ann["image_id"]
            if img_id not in grouped:
                grouped[img_id] = {"image": ann["image"], "captions": []}
            captions = grouped[img_id]["captions"]
            if len(captions) < caps_k:
                captions.append(ann["caption"])

        # Flatten back to one record per kept caption, images in
        # first-appearance order and captions in file order.
        self.ann = [
            {"image": entry["image"], "caption": cap, "image_id": img_id}
            for img_id, entry in grouped.items()
            for cap in entry["captions"]
        ]

        print("len(ann)", len(self.ann))
        print("len(image_ids)", len(grouped))
        # Reports the caption count of the last image. Guarded so that an
        # empty annotation set no longer crashes on an unbound loop variable.
        last_captions = next(reversed(grouped.values()))["captions"] if grouped else []
        print("caps per image: ", len(last_captions))

    def __len__(self):
        """Return the number of kept (image, caption) pairs."""
        return len(self.ann)

    def __getitem__(self, index):
        """
        returns: image, caption, image_id

        ``image_id`` is the raw id from the annotation file (not a dense
        index).
        """
        ann = self.ann[index]
        image_path = os.path.join(self.image_root, ann["image"])
        # convert() yields a fully loaded copy, so the file can be closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image = self.transform(image)
        caption = pre_caption(ann["caption"], self.max_words)
        return image, caption, ann["image_id"]


class re_eval_dataset(Dataset):
    """Retrieval evaluation set: one item per image plus text/image lookup tables.

    Args:
        ann_file: path to a single JSON file holding a list of dicts with an
            "image" path and a "caption" entry that is iterated caption by
            caption (presumably a list of strings -- confirm against the
            annotation files).
        transform: callable applied to the loaded RGB PIL image.
        image_root: directory prepended to every "image" path.
        max_words: forwarded to ``pre_caption`` as its second argument.

    Attributes:
        image: image paths, indexed by image position in the file.
        text: processed captions, indexed by a running text id.
        img2txt: image position -> list of text ids belonging to it.
        txt2img: text id -> image position.
    """

    def __init__(self, ann_file, transform, image_root, max_words=30):
        # Close the annotation file deterministically.
        with open(ann_file, "r") as fp:
            self.ann = json.load(fp)
        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words

        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        for img_id, ann in enumerate(self.ann):
            self.image.append(ann["image"])
            self.img2txt[img_id] = []
            for caption in ann["caption"]:
                # The next text id is simply the current length of self.text.
                txt_id = len(self.text)
                self.text.append(pre_caption(caption, self.max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id

        # Debug output: show the first entry of each lookup table.
        print("img2txt")
        for k, v in self.img2txt.items():
            print(k, v)
            break
        print("txt2img")
        for k, v in self.txt2img.items():
            print(k, v)
            break

    def __len__(self):
        """Return the number of images."""
        return len(self.image)

    def __getitem__(self, index):
        """Return ``(transformed image, index)`` for the image at ``index``."""
        image_path = os.path.join(self.image_root, self.ann[index]["image"])
        # convert() yields a fully loaded copy, so the file can be closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image = self.transform(image)

        return image, index


class re_train_dataset_for_eval(Dataset):
    """Evaluation-style view over the first ``n`` images of training annotations.

    Training annotations hold one caption per record; this class regroups
    them into the same lookup tables that ``re_eval_dataset`` exposes
    (``image``, ``text``, ``img2txt``, ``txt2img``).

    Args:
        ann_file: iterable of JSON annotation paths, each holding a list of
            dicts with "image", "caption" and "image_id" keys.
        transform: callable applied to the loaded RGB PIL image.
        image_root: directory prepended to every "image" path.
        max_words: forwarded to ``pre_caption`` as its second argument.
        n: maximum number of distinct images to keep. Scanning stops at the
            first annotation that would introduce image number ``n + 1``.
        caps_k: maximum number of captions kept per image.

    Attributes:
        img_ids: raw "image_id" -> dense image index (first-appearance order).
        image: image paths indexed by dense image index.
        text: processed captions indexed by text id.
        img2txt: dense image index -> list of text ids.
        txt2img: text id -> dense image index.
    """

    def __init__(self, ann_file, transform, image_root, max_words=30, n=1000, caps_k=5):
        self.ann = []
        for f in ann_file:
            # Close annotation files deterministically.
            with open(f, "r") as fp:
                self.ann += json.load(fp)

        self.caps_k = caps_k

        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words
        self.img_ids = {}

        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        for ann in self.ann:
            img_id = ann["image_id"]

            if img_id not in self.img_ids:
                # Stop before registering an image beyond the limit. The
                # previous `img_idx > n` check ran after insertion and let
                # n + 2 images through, the last one with a single caption.
                if len(self.image) >= n:
                    break
                self.img_ids[img_id] = len(self.image)
                self.image.append(ann["image"])

            # Look the index up by id so that a caption is attributed to its
            # own image even when annotations of one image are not adjacent.
            img_idx = self.img_ids[img_id]
            txt_ids = self.img2txt.setdefault(img_idx, [])

            if len(txt_ids) < caps_k:
                txt_id = len(self.text)
                txt_ids.append(txt_id)
                self.txt2img[txt_id] = img_idx
                self.text.append(pre_caption(ann["caption"], self.max_words))

        # Debug output: first few image entries and the first text entry.
        print("img2txt")
        for k, v in self.img2txt.items():
            print(k, v)
            if k == 5:
                break
        print("txt2img")
        for k, v in self.txt2img.items():
            print(k, v)
            break

        print("Len of image:", len(self.image))
        print("Len of text:", len(self.text))

    def __len__(self):
        """Return the number of kept images."""
        return len(self.image)

    def __getitem__(self, index):
        """Return ``(transformed image, index)`` for the kept image at ``index``."""
        image_path = os.path.join(self.image_root, self.image[index])
        # convert() yields a fully loaded copy, so the file can be closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image = self.transform(image)

        return image, index


class re_train_dataset_subset(Dataset):
    """Lookup tables built from a chosen subset of training annotations.

    Only annotations whose position in the concatenated annotation list is
    contained in ``indices`` contribute to ``image``, ``text``, ``img2txt``
    and ``txt2img``.

    Args:
        ann_file: iterable of JSON annotation paths, each holding a list of
            dicts with "image", "caption" and "image_id" keys.
        transform: callable applied to the loaded RGB PIL image.
        image_root: directory prepended to every "image" path.
        max_words: forwarded to ``pre_caption`` as its second argument.
        caps_k: maximum number of captions kept per image.
        indices: iterable of annotation positions to use; each element is
            coerced with ``int``. ``None`` (the default) selects nothing,
            exactly like the former empty-list default.

    Attributes:
        ann: the full, unfiltered annotation list.
        img_ids: raw "image_id" -> dense image index (first-appearance order
            within the subset).
        image: image paths indexed by dense image index.
        text: processed captions indexed by text id.
        img2txt: dense image index -> list of text ids.
        txt2img: text id -> dense image index.
    """

    def __init__(self, ann_file, transform, image_root, max_words=30, caps_k=5, indices=None):
        self.ann = []
        for f in ann_file:
            # Close annotation files deterministically.
            with open(f, "r") as fp:
                self.ann += json.load(fp)

        self.caps_k = caps_k

        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words
        self.img_ids = {}

        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        # A set gives O(1) membership tests instead of scanning `indices`
        # once per annotation. int() also normalises array/tensor scalars so
        # they hash and compare like plain ints.
        selected = set() if indices is None else {int(i) for i in indices}

        for idx, ann in enumerate(self.ann):

            # only use the indices
            if idx not in selected:
                continue

            img_id = ann["image_id"]

            if img_id not in self.img_ids:
                self.img_ids[img_id] = len(self.image)
                self.image.append(ann["image"])

            # Look the index up by id so that a caption is attributed to its
            # own image even when annotations of one image are not adjacent.
            img_idx = self.img_ids[img_id]
            txt_ids = self.img2txt.setdefault(img_idx, [])

            if len(txt_ids) < caps_k:
                txt_id = len(self.text)
                txt_ids.append(txt_id)
                self.txt2img[txt_id] = img_idx
                self.text.append(pre_caption(ann["caption"], self.max_words))

        # Debug output: first few image entries and the first text entry.
        print("img2txt")
        for k, v in self.img2txt.items():
            print(k, v)
            if k == 5:
                break
        print("txt2img")
        for k, v in self.txt2img.items():
            print(k, v)
            break

        print("Len of image:", len(self.image))
        print("Len of text:", len(self.text))

    def __len__(self):
        """Return the number of distinct images in the selected subset."""
        return len(self.image)

    def __getitem__(self, index):
        """Return ``(transformed image, processed caption, raw image_id)``.

        NOTE(review): ``index`` addresses the full ``self.ann`` list, while
        ``__len__`` counts the distinct images of the selected subset, so
        iterating this dataset yields the first ``len(self)`` annotations of
        the file regardless of ``indices``. This looks unintended, but the
        desired semantics are not evident from this file -- confirm with the
        callers before changing it.
        """
        ann = self.ann[index]
        image_path = os.path.join(self.image_root, ann["image"])
        # convert() yields a fully loaded copy, so the file can be closed.
        with Image.open(image_path) as img:
            image = img.convert("RGB")
        image = self.transform(image)

        return image, pre_caption(ann["caption"], self.max_words), ann["image_id"]


class pretrain_dataset(Dataset):
    """Pretraining set yielding ``(image, caption)`` pairs.

    Unlike the retrieval datasets in this module there is no ``image_root``:
    each annotation's "image" value is opened as given.

    Args:
        ann_file: iterable of JSON annotation paths, each holding a list of
            dicts with "image" and "caption" keys. "caption" may be a single
            value or a list of alternatives.
        transform: callable applied to the loaded RGB PIL image.
        max_words: forwarded to ``pre_caption`` as its second argument.
    """

    def __init__(self, ann_file, transform, max_words=30):
        self.ann = []
        for f in ann_file:
            # Close annotation files deterministically.
            with open(f, "r") as fp:
                self.ann += json.load(fp)
        self.transform = transform
        self.max_words = max_words

    def __len__(self):
        """Return the number of annotations."""
        return len(self.ann)

    def __getitem__(self, index):
        """Return ``(transformed image, processed caption)`` for one annotation.

        When the annotation carries a list of captions, one of them is drawn
        at random on every access.
        """
        ann = self.ann[index]

        raw_caption = ann["caption"]
        if isinstance(raw_caption, list):
            raw_caption = random.choice(raw_caption)
        caption = pre_caption(raw_caption, self.max_words)

        # convert() yields a fully loaded copy, so the file can be closed.
        with Image.open(ann["image"]) as img:
            image = img.convert("RGB")
        image = self.transform(image)

        return image, caption
