import os
import re
import json
import numpy as np

from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url

from .utils import AverageMeter


def pre_caption(caption, max_words=50):
    """Normalize a raw caption string and cap its length in words.

    The caption is lower-cased, a fixed set of punctuation characters
    (``. ! " ( ) * # : ; ~``) is replaced by spaces, runs of two or more
    whitespace characters are collapsed to a single space, and trailing
    newlines plus surrounding spaces are removed.

    Args:
        caption: Raw caption text.
        max_words: Maximum number of space-separated words to keep.

    Returns:
        The cleaned caption, truncated to its first ``max_words`` words when
        it contains more than that.
    """
    lowered = caption.lower()
    no_punctuation = re.sub(r"([.!\"()*#:;~])", ' ', lowered)
    single_spaced = re.sub(r"\s{2,}", ' ', no_punctuation)
    cleaned = single_spaced.rstrip('\n').strip(' ')

    # Short enough already: hand back the cleaned text untouched.
    words = cleaned.split(' ')
    if len(words) <= max_words:
        return cleaned

    return ' '.join(words[:max_words])


class COCO_Retrieval(Dataset):
    """Image-text retrieval dataset built from the COCO Karpathy val/test annotations.

    Items are images (one per annotation entry). The cleaned captions and the
    image <-> caption index maps are exposed as attributes:

    * ``text``: flat list of cleaned captions, in annotation order.
    * ``image``: list of image paths taken from each annotation's ``'image'`` field.
    * ``img2txt``: image index -> list of indices into ``text``.
    * ``txt2img``: index into ``text`` -> image index.
    """

    def __init__(self, image_preprocess=None, root_dir='./', max_words=30, split="test",
                 image_perturb_fn=None, download=False):
        """
        COCO Retrieval Dataset.
        image_preprocess: image preprocessing function applied to the loaded PIL image (optional).
        root_dir: Parent directory. The dataset itself is expected in `<root_dir>/coco2014`;
                  image paths from the annotation file are resolved relative to that directory.
        max_words: Cropping the caption to max_words.
        split: 'val' or 'test'. Any other value raises KeyError.
        image_perturb_fn: image perturbation function for patch permutation experiments;
                          applied after image_preprocess (optional).
        download: Whether to download the images if `<root_dir>/coco2014` does not exist.

        Raises RuntimeError if the dataset directory is missing and download is False.
        """
        self.root_dir = os.path.join(root_dir, 'coco2014')
        if not os.path.exists(self.root_dir):
            print("Directory for COCO could not be found!")
            if download:
                print("Downloading COCO now.")
                self.download()
            else:
                raise RuntimeError("Please either download the dataset by letting `--download` or specify the correct directory.")

        urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json',
                'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json'}
        filenames = {'val':'coco_karpathy_val.json','test':'coco_karpathy_test.json'}
        # The annotation JSON is fetched into the dataset directory before being read.
        download_url(urls[split], self.root_dir)

        # Use a context manager so the annotation file handle is closed promptly.
        with open(os.path.join(self.root_dir, filenames[split]), 'r') as f:
            self.annotation = json.load(f)
        self.image_preprocess = image_preprocess
        self.image_perturb_fn = image_perturb_fn
        self.image_root = self.root_dir

        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        # Captions are numbered consecutively across all images.
        txt_id = 0
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann['image'])
            self.img2txt[img_id] = []
            for caption in ann['caption']:
                self.text.append(pre_caption(caption, max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id
                txt_id += 1

    def __len__(self):
        """Return the number of images (annotation entries)."""
        return len(self.annotation)

    def __getitem__(self, index):
        """Load image `index` as RGB, apply the optional transforms, and return
        ``{"image": image, "idx": index}``."""
        image_path = os.path.join(self.image_root, self.annotation[index]['image'])
        image = Image.open(image_path).convert('RGB')

        if self.image_preprocess is not None:
            image = self.image_preprocess(image)

        if self.image_perturb_fn is not None:
            image = self.image_perturb_fn(image)

        return {"image": image, "idx": index}

    def download(self):
        """Fetch and unpack the COCO val2014 and test2014 image archives into ``self.root_dir``.

        Shells out to ``wget`` and ``unzip``; their exit codes are not checked.
        The train2014 archive is not downloaded.
        """
        import subprocess
        os.makedirs(self.root_dir, exist_ok=True)
        for archive in ("val2014.zip", "test2014.zip"):
            subprocess.call(["wget", f"http://images.cocodataset.org/zips/{archive}"], cwd=self.root_dir)
            subprocess.call(["unzip", archive], cwd=self.root_dir)

    def evaluate_scores(self, scores):
        """Compute top-1 / top-5 retrieval metrics from similarity scores.

        scores: either a single score matrix indexed as ``[image, text]``, or a
                tuple ``(scores_i2t, scores_t2i)`` where the second element is
                transposed so that it is also indexed as ``[image, text]``.
                Higher scores are treated as better matches.

        Returns a dict with keys "T2IRecall@1", "T2IRecall@5", "I2TRecall@1" and
        "I2TRecall@5", each holding the ``avg`` of the corresponding AverageMeter.
        """
        if isinstance(scores, tuple):
            scores_i2t = scores[0]
            scores_t2i = scores[1].T # Make it N_ims x N_text
        else:
            scores_t2i = scores
            scores_i2t = scores

        print(f"COCO results across {scores_i2t.shape} samples. ")
        prec_at_1 = AverageMeter()
        prec_at_5 = AverageMeter()

        # Text retrieval: for each image, is any of its captions among the best-scoring texts?
        tqdm_iterator = tqdm(range(len(self.img2txt)))
        for i in tqdm_iterator:
            # argsort is ascending, so the last entries are the highest-scoring captions.
            top5_captions = np.argsort(scores_i2t[i])[-5:]
            true_captions = self.img2txt[i]

            prec_at_1.update(len(set(true_captions) & set(top5_captions[-1:]))>0)
            prec_at_5.update(len(set(true_captions) & set(top5_captions))>0)

            tqdm_iterator.set_description(f"Text Retrieval Prec@1: {prec_at_1.avg:.3f}, Prec@5: {prec_at_5.avg:.3f}")

        # Image Retrieval: for each caption, is its image among the best-scoring images?
        image_prec_at_1 = AverageMeter()
        image_prec_at_5 = AverageMeter()

        tqdm_iterator = tqdm(range(len(self.txt2img)))
        for i in tqdm_iterator:
            top5_images = np.argsort(scores_t2i[:, i])[-5:]
            true_image = self.txt2img[i]

            image_prec_at_1.update(true_image in top5_images[-1:])
            image_prec_at_5.update(true_image in top5_images)

            tqdm_iterator.set_description(f"Image Retrieval Prec@1: {image_prec_at_1.avg:.3f}, Prec@5: {image_prec_at_5.avg:.3f}")

        records = {"T2IRecall@1": image_prec_at_1.avg, "T2IRecall@5": image_prec_at_5.avg, "I2TRecall@1": prec_at_1.avg, "I2TRecall@5": prec_at_5.avg}
        return records



class Flickr30k_Retrieval(Dataset):
    """Image-text retrieval dataset built from the Flickr30k val/test annotations.

    Items are images (one per annotation entry). The cleaned captions and the
    image <-> caption index maps are exposed as attributes:

    * ``text``: flat list of cleaned captions, in annotation order.
    * ``image``: list of image paths taken from each annotation's ``'image'`` field.
    * ``img2txt``: image index -> list of indices into ``text``.
    * ``txt2img``: index into ``text`` -> image index.
    """

    def __init__(self, image_preprocess, split, root_dir='./', max_words=30,
                 image_perturb_fn=None, *args, **kwargs):
        """
        Flickr30k dataset for retrieval.
        image_preprocess: image preprocessing function applied to the loaded PIL image (may be None).
        split: 'val' or 'test'. Any other value raises KeyError.
        root_dir: Parent directory. The dataset itself is expected in `<root_dir>/flickr30k`;
                  image paths from the annotation file are resolved relative to that directory.
        max_words: Cropping the caption to max_words.
        image_perturb_fn: image perturbation function for patch permutation experiments;
                          applied after image_preprocess (optional).
        *args, **kwargs: accepted and ignored (e.g. a `download` flag passed by callers);
                         the images cannot be downloaded automatically.

        Raises RuntimeError if `<root_dir>/flickr30k` does not exist.
        """
        urls = {'val':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json',
                'test':'https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json'}
        filenames = {'val':'flickr30k_val.json','test':'flickr30k_test.json'}

        self.root_dir = os.path.join(root_dir, 'flickr30k')
        if not os.path.exists(self.root_dir):
            print("Directory for Flickr30k could not be found!")
            flickr_url = "https://forms.illinois.edu/sec/229675"
            raise RuntimeError(f"You need to manually sign up and download the dataset from {flickr_url} and place it in the `root_dir`.")

        # The annotation JSON is fetched into the dataset directory before being read.
        download_url(urls[split], self.root_dir)

        # Use a context manager so the annotation file handle is closed promptly.
        with open(os.path.join(self.root_dir, filenames[split]), 'r') as f:
            self.annotation = json.load(f)
        self.image_preprocess = image_preprocess
        self.image_perturb_fn = image_perturb_fn

        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        # Captions are numbered consecutively across all images.
        txt_id = 0
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann['image'])
            self.img2txt[img_id] = []
            for caption in ann['caption']:
                self.text.append(pre_caption(caption, max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id
                txt_id += 1

    def __len__(self):
        """Return the number of images (annotation entries)."""
        return len(self.annotation)

    def __getitem__(self, index):
        """Load image `index` as RGB, apply the optional transforms, and return
        ``{"image": image, "idx": index}``."""
        image_path = os.path.join(self.root_dir, self.annotation[index]['image'])
        image = Image.open(image_path).convert('RGB')
        if self.image_preprocess is not None:
            image = self.image_preprocess(image)
        if self.image_perturb_fn is not None:
            image = self.image_perturb_fn(image)

        return {"image": image, "idx": index}

    def evaluate_scores(self, scores):
        """Compute top-1 / top-5 retrieval metrics from similarity scores.

        scores: either a single score matrix indexed as ``[image, text]``, or a
                tuple ``(scores_i2t, scores_t2i)`` where the second element is
                transposed so that it is also indexed as ``[image, text]``.
                Higher scores are treated as better matches.

        Returns a dict with keys "T2IRecall@1", "T2IRecall@5", "I2TRecall@1" and
        "I2TRecall@5", each holding the ``avg`` of the corresponding AverageMeter.
        """
        if isinstance(scores, tuple):
            scores_i2t = scores[0]
            scores_t2i = scores[1].T # Make it N_ims x N_text
        else:
            scores_t2i = scores
            scores_i2t = scores

        print(f"Flickr30k Retrieval results across {scores_i2t.shape} samples. ")
        prec_at_1 = AverageMeter()
        prec_at_5 = AverageMeter()

        # Text retrieval: for each image, is any of its captions among the best-scoring texts?
        tqdm_iterator = tqdm(range(len(self.img2txt)))
        for i in tqdm_iterator:
            # argsort is ascending, so the last entries are the highest-scoring captions.
            top5_captions = np.argsort(scores_i2t[i])[-5:]
            true_captions = self.img2txt[i]

            prec_at_1.update(len(set(true_captions) & set(top5_captions[-1:]))>0)
            prec_at_5.update(len(set(true_captions) & set(top5_captions))>0)

            tqdm_iterator.set_description(f"Text Retrieval Prec@1: {prec_at_1.avg:.3f}, Prec@5: {prec_at_5.avg:.3f}")

        # Image Retrieval: for each caption, is its image among the best-scoring images?
        image_prec_at_1 = AverageMeter()
        image_prec_at_5 = AverageMeter()

        tqdm_iterator = tqdm(range(len(self.txt2img)))
        for i in tqdm_iterator:
            top5_images = np.argsort(scores_t2i[:, i])[-5:]
            true_image = self.txt2img[i]

            image_prec_at_1.update(true_image in top5_images[-1:])
            image_prec_at_5.update(true_image in top5_images)

            tqdm_iterator.set_description(f"Image Retrieval Prec@1: {image_prec_at_1.avg:.3f}, Prec@5: {image_prec_at_5.avg:.3f}")

        records = {"T2IRecall@1": image_prec_at_1.avg, "T2IRecall@5": image_prec_at_5.avg, "I2TRecall@1": prec_at_1.avg, "I2TRecall@5": prec_at_5.avg}
        return records

    def download(self):
        """Always raises NotImplementedError: Flickr30k must be obtained manually."""
        raise NotImplementedError("Flickr30k dataset is not available for download.")


class Laion_Retrieval(Dataset):
    """Image-text retrieval dataset over a LAION subset with one caption per image.

    Each annotation entry provides a ``'path'`` and a single ``'caption'``, so
    image ``i`` and text ``i`` are each other's only ground-truth match:

    * ``text``: list of cleaned captions, in annotation order.
    * ``image``: list of image paths taken from each annotation's ``'path'`` field.
    * ``img2txt``: image index -> ``[image index]``.
    * ``txt2img``: text index -> the same index.
    """

    def __init__(self, image_preprocess=None, root_dir='./', max_words=30, split="test",
                 image_perturb_fn=None, download=False):
        """
        Laion Retrieval Dataset.
        image_preprocess: image preprocessing function applied to the loaded PIL image (optional).
        root_dir: Parent directory. The dataset itself is expected in `<root_dir>/laion`;
                  the annotation file is read from there and image paths are resolved
                  relative to it.
        max_words: Cropping the caption to max_words.
        split: name of a known subset file without the `.json` suffix, e.g. 'subset_5000_00'.
               Note that the default value 'test' is not a known subset and is rejected.
        image_perturb_fn: image perturbation function for patch permutation experiments;
                          applied after image_preprocess (optional).
        download: Whether to attempt a download if `<root_dir>/laion` does not exist.
                  Automatic download is not implemented, so this raises NotImplementedError.

        Raises:
            RuntimeError: the dataset directory is missing and download is False.
            NotImplementedError: the dataset directory is missing and download is True.
            ValueError: split does not name a known subset file.
        """
        file_names = [
            "subset_50_00.json", "subset_50_01.json", "subset_50_02.json",
            "subset_50_03.json", "subset_50_04.json", "subset_50_05.json",
            "subset_50_06.json", "subset_50_07.json", "subset_50_08.json",
            "subset_50_09.json", "subset_50_10.json",
            "subset_1000_00.json", "subset_1000_01.json", "subset_1000_02.json", "subset_1000_03.json",
            "subset_1000_04.json", "subset_1000_05.json", "subset_1000_06.json", "subset_1000_07.json",
            "subset_1000_08.json", "subset_1000_09.json", "subset_100_00.json", "subset_100_01.json",
            "subset_2000_00.json", "subset_2000_01.json", "subset_2000_02.json",
            "subset_500_00.json", "subset_500_01.json", "subset_500_02.json",
            "subset_100_02.json", "subset_100_03.json", "subset_100_04.json", "subset_100_05.json",
            "subset_100_06.json", "subset_100_07.json", "subset_100_08.json", "subset_100_09.json",
            "subset_5000_00.json", "subset_5000_01.json"
        ]

        self.split = split
        # NOTE: the `download` flag is deliberately not stored as `self.download`;
        # doing so would shadow the `download()` method with a bool.
        self.root_dir = os.path.join(root_dir, 'laion')
        if not os.path.exists(self.root_dir):
            print("Directory for LAION could not be found!")
            if download:
                print("Downloading LAION now.")
                self.download()
            else:
                raise RuntimeError(
                    "Please either download the dataset by letting `--download` or specify the correct directory.")

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        annotation_file = f"{split}.json"
        if annotation_file not in file_names:
            raise ValueError(
                f"Unknown LAION split {split!r}: {annotation_file!r} is not one of the "
                "known subset files (e.g. 'subset_5000_00').")

        # Read the annotations from the dataset directory, honouring `root_dir`.
        with open(os.path.join(self.root_dir, annotation_file)) as f:
            self.annotation = json.load(f)
        self.image_preprocess = image_preprocess
        self.image_perturb_fn = image_perturb_fn
        self.image_root = self.root_dir

        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        # One caption per image, so text ids coincide with image ids.
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann['path'])
            self.text.append(pre_caption(ann['caption'], max_words))
            self.img2txt[img_id] = [img_id]
            self.txt2img[img_id] = img_id

    def __len__(self):
        """Return the number of images (annotation entries)."""
        return len(self.annotation)

    def __getitem__(self, index):
        """Load image `index` as RGB, apply the optional transforms, and return
        ``{"image": image, "idx": index}``."""
        image_path = os.path.join(
            self.image_root, self.annotation[index]['path'])
        image = Image.open(image_path).convert('RGB')

        if self.image_preprocess is not None:
            image = self.image_preprocess(image)

        if self.image_perturb_fn is not None:
            image = self.image_perturb_fn(image)

        return {"image": image, "idx": index}

    def download(self):
        """Always raises NotImplementedError: no automatic download exists for the LAION subsets."""
        raise NotImplementedError(
            "LAION dataset is not available for download; place the subset "
            "annotation files under `<root_dir>/laion`.")

    def evaluate_scores(self, scores):
        """Compute top-1 retrieval metrics from similarity scores.

        scores: either a single score matrix indexed as ``[image, text]``, or a
                tuple ``(scores_i2t, scores_t2i)`` where the second element is
                transposed so that it is also indexed as ``[image, text]``.
                Higher scores are treated as better matches.

        Returns a dict with keys "T2IRecall@1" and "I2TRecall@1", each holding
        the ``avg`` of the corresponding AverageMeter.
        """
        if isinstance(scores, tuple):
            scores_i2t = scores[0]
            scores_t2i = scores[1].T  # Make it N_ims x N_text
        else:
            scores_t2i = scores
            scores_i2t = scores

        print(f"LAION results across {scores_i2t.shape} samples. ")
        prec_at_1 = AverageMeter()

        # Text retrieval: for each image, is its caption the best-scoring text?
        tqdm_iterator = tqdm(range(len(self.img2txt)))
        for i in tqdm_iterator:
            # argsort is ascending, so the last entry is the highest-scoring caption.
            top1_caption = np.argsort(scores_i2t[i])[-1:]
            true_captions = self.img2txt[i]

            prec_at_1.update(
                len(set(true_captions) & set(top1_caption)) > 0)
            tqdm_iterator.set_description(
                f"Text Retrieval Prec@1: {prec_at_1.avg:.3f}")

        # Image Retrieval: for each caption, is its image the best-scoring image?
        image_prec_at_1 = AverageMeter()

        tqdm_iterator = tqdm(range(len(self.txt2img)))
        for i in tqdm_iterator:
            top1_image = np.argsort(scores_t2i[:, i])[-1:]
            true_image = self.txt2img[i]

            image_prec_at_1.update(true_image in top1_image)

            tqdm_iterator.set_description(
                f"Image Retrieval Prec@1: {image_prec_at_1.avg:.3f}")

        records = {"T2IRecall@1": image_prec_at_1.avg,
                   "I2TRecall@1": prec_at_1.avg,}
        return records



def get_coco_retrieval(image_preprocess, image_perturb_fn, max_words=30, download=False, root_dir='./', split="test"):
    """Construct and return a ``COCO_Retrieval`` dataset, forwarding every argument by keyword."""
    options = {
        "root_dir": root_dir,
        "split": split,
        "image_preprocess": image_preprocess,
        "image_perturb_fn": image_perturb_fn,
        "max_words": max_words,
        "download": download,
    }
    return COCO_Retrieval(**options)


def get_flickr30k_retrieval(image_preprocess, image_perturb_fn, max_words=30, download=False, root_dir='./', split="test"):
    """Construct and return a ``Flickr30k_Retrieval`` dataset, forwarding every argument by keyword."""
    options = {
        "root_dir": root_dir,
        "split": split,
        "image_preprocess": image_preprocess,
        "image_perturb_fn": image_perturb_fn,
        "max_words": max_words,
        "download": download,
    }
    return Flickr30k_Retrieval(**options)


def get_laion_retrieval(image_preprocess, image_perturb_fn, max_words=30, download=False, root_dir='./', split="subset_5000_00"):
    """Construct and return a ``Laion_Retrieval`` dataset, forwarding every argument by keyword."""
    options = {
        "root_dir": root_dir,
        "split": split,
        "image_preprocess": image_preprocess,
        "image_perturb_fn": image_perturb_fn,
        "max_words": max_words,
        "download": download,
    }
    return Laion_Retrieval(**options)
