import os
import json

from torch.utils.data import Dataset, IterableDataset, get_worker_info
from torchvision.datasets.utils import download_url

from PIL import Image

from data.utils import pre_caption
from collections import defaultdict
from tqdm import tqdm

class flickr30k_train_al_iter(IterableDataset):
    """Iterable Flickr30k training set that mixes original and augmented samples.

    Each annotation entry is read for the keys ``source``, ``image``,
    ``caption``, ``caption_id`` (in ``__iter__``) and ``sample_id``
    (in ``build_index_dict`` / ``remove``).
    """

    def __init__(self, transform, image_root, ann_root, aug_data_root, train_file_name, max_words=30, prompt=''):
        '''
        transform (callable): applied to every loaded PIL image
        image_root (string): root directory of the original images (e.g. flickr30k/)
        ann_root (string): directory containing the annotation file
        aug_data_root (string): root of the augmented data; images of samples whose
            "source" is not "original" are read from ``<aug_data_root>/images``
        train_file_name (string): annotation JSON file name inside ``ann_root``
        max_words (int): caption length limit forwarded to ``pre_caption``
        prompt (string): text prepended to every caption
        '''
        with open(os.path.join(ann_root, train_file_name), "r") as input_json:
            self.annotation = json.load(input_json)

        self.sample_dict = self.build_index_dict()
        self.transform = transform
        self.image_root = image_root
        self.aug_image_root = os.path.join(aug_data_root, "images")
        self.max_words = max_words
        self.prompt = prompt

    def __len__(self):
        return len(self.annotation)

    def __iter__(self):
        # NOTE: no sharding happens here. With DataLoader(num_workers > 0), pass
        # ``worker_init_fn`` so each worker's replica keeps only its own shard;
        # otherwise every worker yields the full annotation list (duplicates).
        for ann in self.annotation:
            if ann["source"] == "original":
                image_path = os.path.join(self.image_root, ann['image'])
            else:
                image_path = os.path.join(self.aug_image_root, ann['image'])

            image = Image.open(image_path).convert('RGB')
            image = self.transform(image)

            caption = self.prompt + pre_caption(ann['caption'], self.max_words)

            yield image, caption, image_path, ann['caption_id']

    def worker_init_fn(self, worker_id=None):
        """Restrict the current DataLoader worker's dataset replica to its shard.

        Meant to be passed as ``DataLoader(..., worker_init_fn=dataset.worker_init_fn)``.
        The DataLoader calls it with the worker id. The id is accepted for that
        calling convention, but the shard is computed from ``get_worker_info()``.
        Outside a worker process ``get_worker_info()`` returns None and this is a no-op.
        """
        worker_info = get_worker_info()
        if worker_info is None:
            return
        dataset = worker_info.dataset
        dataset.annotation = self._shard(dataset.annotation, worker_info.id, worker_info.num_workers)

    @staticmethod
    def _shard(items, worker_id, num_workers):
        """Return the contiguous, near-equal slice of ``items`` owned by ``worker_id``.

        The slices for worker ids 0..num_workers-1 partition ``items`` exactly,
        so no trailing samples are dropped when len(items) % num_workers != 0.
        """
        total = len(items)
        start = worker_id * total // num_workers
        end = (worker_id + 1) * total // num_workers
        return items[start:end]

    def build_index_dict(self):
        """Return a dict mapping each annotation's ``sample_id`` to the annotation."""
        sample_dict = {}
        print("build indexed dict with image names...")
        for sample in tqdm(self.annotation, total=len(self.annotation)):
            sample_dict[sample["sample_id"]] = sample
        return sample_dict

    def remove(self, high_loss_samples):
        """Drop the given samples (matched by ``sample_id``) from ``self.sample_dict``.

        Returns the remaining annotations as a list. ``self.annotation`` itself is
        NOT modified; the caller decides what to do with the returned list.
        Samples whose id is not present are reported and skipped.
        """
        for sample in high_loss_samples:
            sid = sample["sample_id"]
            if sid in self.sample_dict:
                print("remove sample from sample dict... ", self.sample_dict[sid])
                del self.sample_dict[sid]
            else:
                print("sample not found: ", sample)
        updated_ann = list(self.sample_dict.values())
        return updated_ann




class flickr30k_train_al(Dataset):
    """
    data set for training with active learning
    """
    def __init__(self, transform, image_root, aug_data_root, ann, max_words=30, prompt='', epoch=-1):
        '''
        transform (callable): applied to every loaded PIL image
        image_root (string): root directory of the original images (e.g. flickr30k/)
        aug_data_root (string): root of the augmented images; used as-is when the
            path ends with "sd", otherwise its "images" subdirectory is used
        ann (list): in-memory annotation entries (not read from disk here)
        max_words (int): caption length limit forwarded to pre_caption
        prompt (string): prefix added to captions that do not already start with it
        epoch (int): stored on the instance; not read elsewhere in this class
        '''
        self.annotation = ann
        self.transform = transform
        self.image_root = image_root
        self.aug_image_root = (
            aug_data_root if aug_data_root.endswith("sd") else os.path.join(aug_data_root, "images")
        )
        self.max_words = max_words
        self.prompt = prompt
        self.epoch = epoch

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        entry = self.annotation[index]
        # NOTE(review): the second test is a plain substring check, so any image
        # name containing "sd" is routed to the augmented root.
        from_augmented = entry["source"] == "sd" or "sd" in entry["image"]
        base_dir = self.aug_image_root if from_augmented else self.image_root
        image_path = os.path.join(base_dir, entry['image'])

        picture = self.transform(Image.open(image_path).convert('RGB'))

        text = pre_caption(entry['caption'], self.max_words)
        text = text if text.startswith(self.prompt) else self.prompt + text

        return picture, text, image_path, entry['caption_id'], entry['sample_id']
        

class flickr30k_train(Dataset):
    def __init__(self, transform, image_root, ann_root, max_words=30, prompt=''):
        '''
        transform (callable): applied to every loaded PIL image
        image_root (string): root directory of images (e.g. flickr30k/)
        ann_root (string): directory the annotation file is downloaded into and read from
        max_words (int): caption length limit forwarded to pre_caption
        prompt (string): text prepended to every caption
        '''
        url = 'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_train.json'
        download_url(url, ann_root)

        ann_path = os.path.join(ann_root, 'flickr30k_train.json')
        with open(ann_path, "r") as handle:
            self.annotation = json.load(handle)
        self.transform = transform
        self.image_root = image_root
        self.max_words = max_words
        self.prompt = prompt

        # img_ids: dense 0..K-1 index per distinct image_id, in first-seen order.
        # path2captions: raw captions grouped by resolved image path.
        self.img_ids = {}
        self.path2captions = defaultdict(list)
        for entry in self.annotation:
            self.img_ids.setdefault(entry['image_id'], len(self.img_ids))
            resolved = os.path.join(self.image_root, entry['image'].replace("flickr30k-images/", ""))
            self.path2captions[resolved].append(entry["caption"])

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        entry = self.annotation[index]
        # The stored name carries a "flickr30k-images/" prefix that is stripped
        # so the file resolves directly under image_root.
        image_path = os.path.join(self.image_root, entry['image'].replace("flickr30k-images/", ""))
        picture = self.transform(Image.open(image_path).convert('RGB'))
        text = self.prompt + pre_caption(entry['caption'], self.max_words)
        return picture, text, self.img_ids[entry['image_id']], image_path
    
class flickr30k_retrieval_eval(Dataset):
    def __init__(self, transform, image_root, ann_root, split, max_words=30):
        '''
        transform (callable): applied to every loaded PIL image
        image_root (string): root directory of images (e.g. flickr30k/)
        ann_root (string): directory the annotation file is downloaded into and read from
        split (string): val or test
        max_words (int): caption length limit forwarded to pre_caption

        Builds a flat caption list (self.text), an image list (self.image) and the
        index maps img2txt (image index -> caption indices) and txt2img
        (caption index -> image index).
        '''
        urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
                'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
        filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}

        download_url(urls[split], ann_root)

        with open(os.path.join(ann_root, filenames[split]), "r") as handle:
            self.annotation = json.load(handle)
        self.transform = transform
        self.image_root = image_root

        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        for img_id, entry in enumerate(self.annotation):
            self.image.append(entry['image'])
            self.img2txt[img_id] = []
            for caption in entry['caption']:
                # Captions are numbered globally in insertion order, so the next
                # id is simply the current length of the flat caption list.
                txt_id = len(self.text)
                self.text.append(pre_caption(caption, max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        relative = self.annotation[index]['image'].replace("flickr30k-images/", "")
        picture = Image.open(os.path.join(self.image_root, relative)).convert('RGB')
        return self.transform(picture), index


class flickr30k_train_sd(Dataset):
    def __init__(self, transform, ann_file, max_words=30, prompt=''):
        '''
        transform (callable): applied to every loaded PIL image
        ann_file (string): path to a JSON list of annotation entries; each entry's
            "image" value is used directly as the image path
        max_words (int): caption length limit forwarded to pre_caption
        prompt (string): text prepended to every caption
        '''
        with open(os.path.join(ann_file), "r") as handle:
            self.annotation = json.load(handle)
        self.transform = transform
        self.max_words = max_words
        self.prompt = prompt

        # Dense 0..K-1 index per distinct image_id, assigned in first-seen order.
        unique_ids = dict.fromkeys(entry['image_id'] for entry in self.annotation)
        self.img_ids = {img_id: position for position, img_id in enumerate(unique_ids)}

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        entry = self.annotation[index]
        image_path = os.path.join(entry['image'])
        picture = self.transform(Image.open(image_path).convert('RGB'))
        text = self.prompt + pre_caption(entry['caption'], self.max_words)
        return picture, text, self.img_ids[entry['image_id']]

class flickr30k_train_sd_loss(Dataset):
    def __init__(self, transform, ann_file, max_words=30, prompt=''):
        '''
        transform (callable): applied to every loaded PIL image
        ann_file (string): path to a JSON list of annotation entries; each entry's
            "image" value is used directly as the image path
        max_words (int): caption length limit forwarded to pre_caption
        prompt (string): text prepended to every caption

        Unlike flickr30k_train_sd, __getitem__ also returns the image path.
        '''
        with open(os.path.join(ann_file), "r") as handle:
            self.annotation = json.load(handle)
        self.transform = transform
        self.max_words = max_words
        self.prompt = prompt

        # Each new image_id receives the next free index (first-seen order).
        self.img_ids = {}
        for entry in self.annotation:
            self.img_ids.setdefault(entry['image_id'], len(self.img_ids))

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        entry = self.annotation[index]
        image_path = os.path.join(entry['image'])
        tensor = self.transform(Image.open(image_path).convert('RGB'))
        text = self.prompt + pre_caption(entry['caption'], self.max_words)
        return tensor, text, self.img_ids[entry['image_id']], image_path


class flickr30k_caption_eval(Dataset):
    def __init__(self, transform, image_root, ann_root, split):
        '''
        transform (callable): applied to every loaded PIL image
        image_root (string): root directory of images (e.g. flickr30k/)
        ann_root (string): directory containing flickr30k_val.json / flickr30k_test.json
        split (string): val or test
        '''
        ann_name = {'val': 'flickr30k_val.json', 'test': 'flickr30k_test.json'}[split]
        with open(os.path.join(ann_root, ann_name), "r") as handle:
            self.annotation = json.load(handle)
        self.transform = transform
        self.image_root = image_root

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        # Strip the stored "flickr30k-images/" prefix so the name resolves
        # directly under image_root; the id is the name without ".jpg".
        image_name = self.annotation[index]['image'].replace("flickr30k-images/", "")
        picture = Image.open(os.path.join(self.image_root, image_name)).convert('RGB')
        return self.transform(picture), image_name.replace(".jpg", "")

class flickr30k_caption_eval_sd(Dataset):
    def __init__(self, transform, image_root, ann_root, split, val_file):
        '''
        transform (callable): applied to every loaded PIL image
        image_root (string): directory holding the "val" images, each named
            "<image_id>.jpg" (only used for the "val" split)
        ann_root (string): directory containing flickr30k_test.json
        split (string): val or test
        val_file (string): path to a JSON file whose "annotations" list is used
            for the "val" split; required when split == "val"

        Raises:
            ValueError: if split == "val" and val_file is None.
        '''
        if split == "val":
            # The val split only comes from val_file. Previously a None here reached
            # os.path.join(ann_root, None) and failed with an opaque TypeError.
            if val_file is None:
                raise ValueError("flickr30k_caption_eval_sd: split='val' requires val_file")
            with open(val_file, "r") as input_json:
                self.annotation = json.load(input_json)["annotations"]
        else:
            filenames = {'test': 'flickr30k_test.json'}
            with open(os.path.join(ann_root, filenames[split]), "r") as input_json:
                self.annotation = json.load(input_json)
        self.transform = transform
        self.image_root = image_root
        self.split = split

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        entry = self.annotation[index]
        if self.split == "val":
            # Val entries carry a bare image id; the file lives under image_root.
            image_path = os.path.join(self.image_root, entry['image_id'] + ".jpg")
        else:
            # Test entries store the image path as-is (not joined with image_root).
            image_path = entry['image']

        image = Image.open(image_path).convert('RGB')
        image = self.transform(image)

        # Id = file name without ".jpg", truncated at the first "_ann_" marker.
        img_id = os.path.basename(image_path).replace(".jpg", "").split("_ann_")[0]

        return image, img_id