import os
import re
import json
import numpy as np

from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset
from torchvision.datasets.utils import download_url

from .constants import COCO_ROOT, FLICKR_ROOT
from .utils import AverageMeter


def pre_caption(caption, max_words=50):
    """Normalize a raw caption and cap its length in words.

    Steps:
      1. lower-case the text;
      2. replace each of the characters  . ! " ( ) * # : ; ~  with a space;
      3. collapse every run of two or more whitespace characters into one space;
      4. drop trailing newlines, then leading/trailing spaces;
      5. keep at most ``max_words`` space-separated words.

    Args:
        caption: the raw caption string.
        max_words: maximum number of words kept in the result.

    Returns:
        The normalized (and possibly truncated) caption string.
    """
    text = re.sub(r"([.!\"()*#:;~])", " ", caption.lower())
    text = re.sub(r"\s{2,}", " ", text)
    text = text.rstrip("\n").strip(" ")

    words = text.split(" ")
    # Only rebuild the string when truncation is actually needed.
    return " ".join(words[:max_words]) if len(words) > max_words else text


class COCO_Retrieval(Dataset):
    """Image-text retrieval dataset built from a COCO Karpathy split annotation file.

    Indexing yields one image per item; captions live in ``self.text`` and are
    linked to images through the ``img2txt`` / ``txt2img`` maps, which
    ``evaluate_scores`` uses to score a similarity matrix.
    """

    def __init__(
        self,
        image_preprocess=None,
        root_dir=COCO_ROOT,
        max_words=30,
        split="test",
        image_perturb_fn=None,
        download=False,
    ):
        """
        COCO Retrieval Dataset.
        image_preprocess: image preprocessing function, applied to each RGB PIL image.
        root_dir: The directory of the coco dataset. The split's annotation JSON is
            downloaded into it, and the "image" paths listed in that JSON are
            resolved relative to it.
        max_words: Captions are truncated to at most max_words words.
        split: 'val' or 'test'
        image_perturb_fn: image perturbation function for patch permutation experiments,
            applied after image_preprocess.
        download: Whether to download the images (val2014 and test2014 zips) if
            root_dir does not exist.

        Raises:
            RuntimeError: if root_dir does not exist and download is False.
            KeyError: if split is not 'val' or 'test'.
        """
        self.root_dir = root_dir
        if not os.path.exists(root_dir):
            print("Directory for COCO could not be found!")
            if download:
                print("Downloading COCO now.")
                self.download()
            else:
                raise RuntimeError(
                    "Please either download the dataset by letting `--download` or specify the correct directory."
                )

        urls = {
            "val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_val.json",
            "test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/coco_karpathy_test.json",
        }
        filenames = {"val": "coco_karpathy_val.json", "test": "coco_karpathy_test.json"}
        download_url(urls[split], root_dir)

        # Use a context manager so the handle is closed deterministically, and
        # decode as UTF-8 (the JSON encoding) instead of the locale default.
        annotation_path = os.path.join(root_dir, filenames[split])
        with open(annotation_path, "r", encoding="utf-8") as f:
            self.annotation = json.load(f)
        self.image_preprocess = image_preprocess
        self.image_perturb_fn = image_perturb_fn
        self.image_root = root_dir

        # Flatten every caption into one list. Each caption gets a global
        # txt_id, and we record the image <-> caption correspondence both ways.
        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        txt_id = 0
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann["image"])
            self.img2txt[img_id] = []
            for caption in ann["caption"]:
                self.text.append(pre_caption(caption, max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id
                txt_id += 1

    def __len__(self):
        """Number of images (not captions) in the split."""
        return len(self.annotation)

    def __getitem__(self, index):
        """Load image ``index`` as RGB and apply preprocessing, then perturbation.

        Returns:
            dict with keys "image" (the transformed image) and "idx" (``index``).
        """
        image_path = os.path.join(self.image_root, self.annotation[index]["image"])
        # convert() returns a new, fully loaded image, so the source file can be
        # closed as soon as the conversion is done.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        if self.image_preprocess is not None:
            image = self.image_preprocess(image)

        if self.image_perturb_fn is not None:
            image = self.image_perturb_fn(image)

        return {"image": image, "idx": index}

    def download(self):
        """Fetch and unpack the COCO val2014 and test2014 image archives into root_dir.

        Requires the ``wget`` and ``unzip`` executables on PATH. train2014 is not
        fetched.
        """
        import subprocess

        os.makedirs(self.root_dir, exist_ok=True)

        # NOTE(review): subprocess.call return codes are ignored, so a failed
        # wget/unzip only surfaces later as a missing image file.
        subprocess.call(
            ["wget", "http://images.cocodataset.org/zips/val2014.zip"],
            cwd=self.root_dir,
        )
        subprocess.call(["unzip", "val2014.zip"], cwd=self.root_dir)

        subprocess.call(
            ["wget", "http://images.cocodataset.org/zips/test2014.zip"],
            cwd=self.root_dir,
        )
        subprocess.call(["unzip", "test2014.zip"], cwd=self.root_dir)

    def evaluate_scores(self, scores):
        """Compute Prec@1 / Prec@5 for text retrieval and image retrieval.

        Args:
            scores: either a single similarity matrix of shape
                (N_images, N_texts), used for both directions, or a tuple
                ``(scores_i2t, scores_t2i)``. In the tuple form, scores_i2t is
                (N_images, N_texts) and scores_t2i is (N_texts, N_images); the
                latter is transposed here. Higher scores mean better matches.

        Returns:
            A one-element list holding a dict with keys "ImagePrec@1",
            "ImagePrec@5", "TextPrec@1" and "TextPrec@5".

        Text retrieval counts a hit for an image when any of its captions is
        among the top-k texts. Image retrieval counts a hit for a caption when
        its own image is among the top-k images.
        """
        if isinstance(scores, tuple):
            scores_i2t = scores[0]
            scores_t2i = scores[1].T  # Make it N_ims x N_text

        else:
            scores_t2i = scores
            scores_i2t = scores

        print(f"COCO results across {scores_i2t.shape} samples. ")
        # assumes AverageMeter.update(val) keeps a running mean in .avg
        # (booleans count as 0/1) — TODO confirm against .utils
        prec_at_1 = AverageMeter()
        prec_at_5 = AverageMeter()

        # Text retrieval: rank all captions for each image. argsort is
        # ascending, so the last 5 indices are the top 5 and the last is top 1.
        tqdm_iterator = tqdm(range(len(self.img2txt)))
        for i in tqdm_iterator:
            top5_captions = np.argsort(scores_i2t[i])[-5:]
            true_captions = set(self.img2txt[i])

            prec_at_1.update(len(true_captions & set(top5_captions[-1:])) > 0)
            prec_at_5.update(len(true_captions & set(top5_captions)) > 0)

            tqdm_iterator.set_description(
                f"Text Retrieval Prec@1: {prec_at_1.avg:.3f}, Prec@5: {prec_at_5.avg:.3f}"
            )

        # Image retrieval: rank all images for each caption (one column of the
        # N_ims x N_text matrix).
        image_prec_at_1 = AverageMeter()
        image_prec_at_5 = AverageMeter()

        tqdm_iterator = tqdm(range(len(self.txt2img)))
        for i in tqdm_iterator:
            top5_images = np.argsort(scores_t2i[:, i])[-5:]
            true_image = self.txt2img[i]

            image_prec_at_1.update(true_image in top5_images[-1:])
            image_prec_at_5.update(true_image in top5_images)

            tqdm_iterator.set_description(
                f"Image Retrieval Prec@1: {image_prec_at_1.avg:.3f}, Prec@5: {image_prec_at_5.avg:.3f}"
            )

        records = [
            {
                "ImagePrec@1": image_prec_at_1.avg,
                "ImagePrec@5": image_prec_at_5.avg,
                "TextPrec@1": prec_at_1.avg,
                "TextPrec@5": prec_at_5.avg,
            }
        ]
        return records


class Flickr30k_Retrieval(Dataset):
    """Image-text retrieval dataset built from a Flickr30k split annotation file.

    Indexing yields one image per item; captions live in ``self.text`` and are
    linked to images through the ``img2txt`` / ``txt2img`` maps, which
    ``evaluate_scores`` uses to score a similarity matrix.
    """

    def __init__(
        self,
        image_preprocess,
        split,
        root_dir=FLICKR_ROOT,
        max_words=30,
        image_perturb_fn=None,
        *args,
        **kwargs,
    ):
        """
        Flickr30k dataset for retrieval.
        image_preprocess: image preprocessing function, applied to each RGB PIL image.
        split: 'val' or 'test'
        root_dir: The directory of the Flickr30k dataset. It must already exist (the
            images cannot be downloaded automatically). The split's annotation JSON is
            downloaded into it, and the "image" paths listed in that JSON are
            resolved relative to it.
        max_words: Captions are truncated to at most max_words words.
        image_perturb_fn: image perturbation function for patch permutation experiments,
            applied after image_preprocess.
        *args, **kwargs: accepted for call-site compatibility and ignored (for example,
            a `download=` keyword is swallowed here).

        Raises:
            RuntimeError: if root_dir does not exist.
            KeyError: if split is not 'val' or 'test'.
        """
        urls = {
            "val": "https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_val.json",
            "test": "https://storage.googleapis.com/sfr-vision-language-research/datasets/flickr30k_test.json",
        }
        filenames = {"val": "flickr30k_val.json", "test": "flickr30k_test.json"}

        if not os.path.exists(root_dir):
            print("Directory for Flickr30k could not be found!")
            flickr_url = "https://forms.illinois.edu/sec/229675"
            raise RuntimeError(
                f"You need to manually sign up and download the dataset from {flickr_url} and place it in the `root_dir`."
            )

        download_url(urls[split], root_dir)

        # Use a context manager so the handle is closed deterministically, and
        # decode as UTF-8 (the JSON encoding) instead of the locale default.
        annotation_path = os.path.join(root_dir, filenames[split])
        with open(annotation_path, "r", encoding="utf-8") as f:
            self.annotation = json.load(f)
        self.image_preprocess = image_preprocess
        self.image_perturb_fn = image_perturb_fn
        self.root_dir = root_dir
        # Same attribute name as COCO_Retrieval, so code can treat both alike.
        self.image_root = root_dir

        # Flatten every caption into one list. Each caption gets a global
        # txt_id, and we record the image <-> caption correspondence both ways.
        self.text = []
        self.image = []
        self.txt2img = {}
        self.img2txt = {}

        txt_id = 0
        for img_id, ann in enumerate(self.annotation):
            self.image.append(ann["image"])
            self.img2txt[img_id] = []
            for caption in ann["caption"]:
                self.text.append(pre_caption(caption, max_words))
                self.img2txt[img_id].append(txt_id)
                self.txt2img[txt_id] = img_id
                txt_id += 1

    def __len__(self):
        """Number of images (not captions) in the split."""
        return len(self.annotation)

    def __getitem__(self, index):
        """Load image ``index`` as RGB and apply preprocessing, then perturbation.

        Returns:
            dict with keys "image" (the transformed image) and "idx" (``index``).
        """
        image_path = os.path.join(self.root_dir, self.annotation[index]["image"])
        # convert() returns a new, fully loaded image, so the source file can be
        # closed as soon as the conversion is done.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        if self.image_preprocess is not None:
            image = self.image_preprocess(image)
        if self.image_perturb_fn is not None:
            image = self.image_perturb_fn(image)

        return {"image": image, "idx": index}

    def evaluate_scores(self, scores):
        """Compute Prec@1 / Prec@5 for text retrieval and image retrieval.

        Args:
            scores: either a single similarity matrix of shape
                (N_images, N_texts), used for both directions, or a tuple
                ``(scores_i2t, scores_t2i)``. In the tuple form, scores_i2t is
                (N_images, N_texts) and scores_t2i is (N_texts, N_images); the
                latter is transposed here. Higher scores mean better matches.

        Returns:
            A one-element list holding a dict with keys "ImagePrec@1",
            "ImagePrec@5", "TextPrec@1" and "TextPrec@5".

        Text retrieval counts a hit for an image when any of its captions is
        among the top-k texts. Image retrieval counts a hit for a caption when
        its own image is among the top-k images.
        """
        if isinstance(scores, tuple):
            scores_i2t = scores[0]
            scores_t2i = scores[1].T  # Make it N_ims x N_text

        else:
            scores_t2i = scores
            scores_i2t = scores

        print(f"Flickr30k Retrieval results across {scores_i2t.shape} samples. ")
        # assumes AverageMeter.update(val) keeps a running mean in .avg
        # (booleans count as 0/1) — TODO confirm against .utils
        prec_at_1 = AverageMeter()
        prec_at_5 = AverageMeter()

        # Text retrieval: rank all captions for each image. argsort is
        # ascending, so the last 5 indices are the top 5 and the last is top 1.
        tqdm_iterator = tqdm(range(len(self.img2txt)))
        for i in tqdm_iterator:
            top5_captions = np.argsort(scores_i2t[i])[-5:]
            true_captions = set(self.img2txt[i])

            prec_at_1.update(len(true_captions & set(top5_captions[-1:])) > 0)
            prec_at_5.update(len(true_captions & set(top5_captions)) > 0)

            tqdm_iterator.set_description(
                f"Text Retrieval Prec@1: {prec_at_1.avg:.3f}, Prec@5: {prec_at_5.avg:.3f}"
            )

        # Image retrieval: rank all images for each caption (one column of the
        # N_ims x N_text matrix).
        image_prec_at_1 = AverageMeter()
        image_prec_at_5 = AverageMeter()

        tqdm_iterator = tqdm(range(len(self.txt2img)))
        for i in tqdm_iterator:
            top5_images = np.argsort(scores_t2i[:, i])[-5:]
            true_image = self.txt2img[i]

            image_prec_at_1.update(true_image in top5_images[-1:])
            image_prec_at_5.update(true_image in top5_images)

            tqdm_iterator.set_description(
                f"Image Retrieval Prec@1: {image_prec_at_1.avg:.3f}, Prec@5: {image_prec_at_5.avg:.3f}"
            )

        records = [
            {
                "ImagePrec@1": image_prec_at_1.avg,
                "ImagePrec@5": image_prec_at_5.avg,
                "TextPrec@1": prec_at_1.avg,
                "TextPrec@5": prec_at_5.avg,
            }
        ]
        return records

    def download(self):
        """Always raises: Flickr30k must be obtained manually (see __init__)."""
        raise NotImplementedError("Flickr30k dataset is not available for download.")


def get_coco_retrieval(
    image_preprocess,
    image_perturb_fn,
    text_perturb_fn,
    max_words=30,
    download=False,
    root_dir=COCO_ROOT,
    split="test",
):
    """Construct a ``COCO_Retrieval`` dataset from the shared factory arguments.

    ``text_perturb_fn`` is accepted so all dataset factories share one
    signature, but it is not used here: captions are never perturbed.

    Returns:
        The configured ``COCO_Retrieval`` instance.
    """
    dataset_kwargs = dict(
        image_preprocess=image_preprocess,
        image_perturb_fn=image_perturb_fn,
        max_words=max_words,
        split=split,
        root_dir=root_dir,
        download=download,
    )
    return COCO_Retrieval(**dataset_kwargs)


def get_flickr30k_retrieval(
    image_preprocess,
    image_perturb_fn,
    text_perturb_fn,
    max_words=30,
    download=False,
    root_dir=FLICKR_ROOT,
    split="test",
):
    """Construct a ``Flickr30k_Retrieval`` dataset from the shared factory arguments.

    ``text_perturb_fn`` is accepted so all dataset factories share one
    signature, but it is not used here. ``download`` is forwarded, but
    ``Flickr30k_Retrieval`` only receives it through ``**kwargs`` and does
    not act on it: Flickr30k must be downloaded manually.

    Returns:
        The configured ``Flickr30k_Retrieval`` instance.
    """
    dataset_kwargs = dict(
        image_preprocess=image_preprocess,
        image_perturb_fn=image_perturb_fn,
        max_words=max_words,
        split=split,
        root_dir=root_dir,
        download=download,
    )
    return Flickr30k_Retrieval(**dataset_kwargs)
