"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import os
from collections import OrderedDict

from lavis.datasets.datasets.base_dataset import BaseDataset
from PIL import Image
from lavis.datasets.datasets.dataloader_utils import insert_img_backdoor_image_captioning, insert_img_backdoor_image_captioning_eval
import random
import copy
import json
import pickle
import numpy as np
import torch
import torchvision.transforms as transforms

def RandomPad(sum_w, sum_h, fill=0):
    """Build every ``transforms.Pad`` that adds exactly ``sum_w`` x ``sum_h`` pixels.

    Each returned transform splits ``sum_w`` extra columns between the left and
    right borders and ``sum_h`` extra rows between the top and bottom borders,
    covering all (sum_w + 1) * (sum_h + 1) possible splits.

    Args:
        sum_w: total horizontal padding (left + right), in pixels.
        sum_h: total vertical padding (top + bottom), in pixels.
        fill: pixel fill value forwarded to ``transforms.Pad`` (previously
            accepted but silently ignored).

    Returns:
        list of ``transforms.Pad`` instances, ordered by left padding, then top
        padding.
    """
    # Pad's 4-tuple is (left, top, right, bottom).
    return [
        transforms.Pad(padding=(left, top, sum_w - left, sum_h - top), fill=fill)
        for left in range(sum_w + 1)
        for top in range(sum_h + 1)
    ]


def build_ShrinkPad(size_map, pad):
    """Return a transform that shrinks an image, then pads it back out.

    The image is resized to ``(size_map - pad, size_map - pad)``. One of the
    padding layouts produced by ``RandomPad(pad, pad)`` is then picked at random
    and applied, which adds ``pad`` pixels in each dimension at a random offset.
    """
    inner_size = size_map - pad
    shrink = transforms.Resize((inner_size, inner_size))
    random_offset_pad = transforms.RandomChoice(RandomPad(sum_w=pad, sum_h=pad))
    return transforms.Compose([shrink, random_offset_pad])



class __DisplMixin:
    def displ_item(self, index):
        """Summarize item ``index`` for display.

        Returns an ``OrderedDict`` with keys ``file`` and ``caption``, taken from
        the raw annotation, and ``image``, taken from the processed sample
        returned by ``__getitem__``. The annotation is assumed to carry a
        ``caption`` key.
        """
        processed = self.__getitem__(index)
        record = self.annotation[index]

        summary = OrderedDict()
        summary["file"] = record["image"]
        summary["caption"] = record["caption"]
        summary["image"] = processed["image"]
        return summary

class CaptionDataset(BaseDataset, __DisplMixin):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths, config):
        """
        Training caption dataset. Each sample also carries the text-processed
        ``config.trigger`` string as ``poisoned_text_input``.

        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths: annotation file path(s), forwarded to BaseDataset
        config: dataset config; ``config.trigger`` is read in ``__getitem__``
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

        self.config = config

        # Map each distinct raw image_id to a contiguous index (0, 1, 2, ...),
        # assigned in order of first appearance in the annotations.
        self.img_ids = {}
        for ann in self.annotation:
            self.img_ids.setdefault(ann["image_id"], len(self.img_ids))

        print("annotation size", len(self.annotation))

    def __getitem__(self, index):
        """Load, process and return one training sample.

        Returns a dict with the processed ``image``, the processed caption
        (``text_input``), the processed trigger text (``poisoned_text_input``)
        and the contiguous ``image_id`` index built in ``__init__``.
        """
        # TODO this assumes image input, not general enough
        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"])
        image = Image.open(image_path).convert("RGB")
        image = self.vis_processor(image)

        caption = self.text_processor(ann["caption"])
        # The configured trigger text goes through the same text processor as
        # the real caption.
        poisoned_caption = self.text_processor(self.config.trigger)

        return {
            "image": image,
            "text_input": caption,
            "poisoned_text_input": poisoned_caption,
            "image_id": self.img_ids[ann["image_id"]],
        }











class CaptionEvalDataset(BaseDataset, __DisplMixin):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths, config):
        """
        Evaluation (val/test) caption dataset.

        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths: annotation file path(s), forwarded to BaseDataset
        config: accepted for signature compatibility with CaptionDataset;
            not used here
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths)

    def __getitem__(self, index):
        """Return the processed image plus its ``image_id`` and ``instance_id``."""
        record = self.annotation[index]
        # Used for testing: do NOT modify here; make dataset-specific changes in
        # coco_caption_datasets.py instead.
        full_path = os.path.join(self.vis_root, record["image"])
        rgb = Image.open(full_path).convert("RGB")
        processed = self.vis_processor(rgb)

        return dict(
            image=processed,
            image_id=record["image_id"],
            instance_id=record["instance_id"],
        )


class Flickr8kCapEvalDataset(CaptionEvalDataset):
    def __init__(self, vis_processor, text_processor, vis_root, ann_paths, config):
        """
        vis_root (string): Root directory of images (e.g. coco/images/)
        ann_paths: annotation file path(s), forwarded to the parent dataset
        config: forwarded to the parent dataset

        image_id: COCO Caption and Flickr30k ids look like ``a_num`` or ``num``.
            Flickr8k image file names look like ``num_a``, so the id is taken
            from the file name as the part before the first underscore. This
            does not affect training or evaluation itself, but it does affect
            computing the evaluation metrics.
        """
        super().__init__(vis_processor, text_processor, vis_root, ann_paths, config)

    @staticmethod
    def _image_id_from_path(image_path):
        """Return the file stem of ``image_path`` up to its first underscore.

        e.g. "images/1000268201_693b08cb0e.jpg" -> "1000268201".
        """
        file_name = image_path.split("/")[-1]
        # splitext removes the real extension; the previous str.strip(".jpg")
        # stripped any of the characters '.', 'j', 'p', 'g' from both ends.
        stem = os.path.splitext(file_name)[0]
        return stem.split("_")[0]

    def __getitem__(self, index):
        """Return the processed image, the file-name-derived id, and ``instance_id``."""
        ann = self.annotation[index]

        image_path = os.path.join(self.vis_root, ann["image"])
        image = Image.open(image_path).convert("RGB")

        image = self.vis_processor(image)

        # Derived from the file name rather than ann["image_id"]; see __init__.
        img_id = self._image_id_from_path(ann["image"])

        return {
            "image": image,
            "image_id": img_id,
            "instance_id": ann["instance_id"],
        }




