import ast
import random
from pathlib import Path

import pandas as pd
import torch
from lutils import openf
from PIL import Image
from torch.utils.data import Dataset

from data.utils import get_middle_frame, pre_caption


class WebVidImg(Dataset):
    """Training pairs of WebVid frames stored as PNG files, plus an edit caption.

    Items are ``(reference_img, target_img, caption, pairid)``; when
    ``split == "test"`` the target image is omitted.  The constructor only
    accepts ``split="train"``, so the test-only branches are currently
    unreachable.
    """

    def __init__(self, transform, image_root, ann_root, split, data, max_words=30):
        """
        transform: callable applied to each RGB ``PIL.Image``.
        image_root (string): directory containing one level of sub-directories
            with ``*.png`` files; ``<subdir>/<stem>.png`` is registered under the
            id ``"<subdir>/<part of stem before the first '_'>"``.
        ann_root (string): directory holding ``webvid2m.<data>.json``.
        split (string): must be ``"train"``.
        data (string): tag that selects the annotation file.
        max_words (int): caption length limit forwarded to ``pre_caption``.
        """
        assert split == "train", "Only train split is available"
        annotation_file = Path(f"webvid2m.{data}.json").name
        self.annotation = openf(Path(ann_root, annotation_file))
        self.transform = transform
        self.max_words = max_words
        self.split = split

        self.image_root = Path(image_root)
        assert self.image_root.exists(), f"Image root {self.image_root} does not exist"
        # When several PNGs map to the same id, the one globbed last wins.
        self.id2pth = {
            f"{png.parent.name}/{png.stem.split('_')[0]}": png
            for png in self.image_root.glob("*/*.png")
        }

        # Fail fast if an annotated id has no image on disk.
        for ann in self.annotation:
            assert ann["reference"] in self.id2pth, f"Path to reference {ann['reference']} not found"
            if split == "test":
                continue
            assert ann["target"] in self.id2pth, f"Path to target {ann['target']} not found"

    def __len__(self):
        return len(self.annotation)

    def _load_image(self, video_id):
        """Open the PNG registered for ``video_id`` as RGB and apply ``self.transform``."""
        return self.transform(Image.open(self.id2pth[video_id]).convert("RGB"))

    def __getitem__(self, index):
        ann = self.annotation[index]
        reference_img = self._load_image(ann["reference"])
        caption = pre_caption(ann["caption"], self.max_words)
        if self.split == "test":
            return reference_img, caption, ann["pairid"]
        target_img = self._load_image(ann["target"])
        return reference_img, target_img, caption, ann["pairid"]


class WebVidVid(Dataset):
    """Training pairs of WebVid ``.mp4`` videos plus an edit caption.

    Each video is reduced to the frame returned by ``get_middle_frame`` before
    ``transform`` is applied.  Items are ``(reference_vid, target_vid, caption,
    pairid)``; when ``split == "test"`` the target is omitted, although the
    constructor only accepts ``split="train"``.
    """

    def __init__(self, transform, video_root, ann_root, split, data, max_words=30):
        """
        transform: callable applied to the frame returned by ``get_middle_frame``.
        video_root (string): directory containing one level of sub-directories
            with ``*.mp4`` files; ``<subdir>/<stem>.mp4`` is registered under the
            id ``"<subdir>/<part of stem before the first '_'>"``.
        ann_root (string): directory holding ``webvid2m.<data>.json``.
        split (string): must be ``"train"``.
        data (string): tag that selects the annotation file.
        max_words (int): caption length limit forwarded to ``pre_caption``.
        """
        assert split == "train", "Only train split is available"
        annotation_file = Path(f"webvid2m.{data}.json").name
        self.annotation = openf(Path(ann_root, annotation_file))
        self.transform = transform
        self.max_words = max_words
        self.split = split

        self.video_root = Path(video_root)
        assert self.video_root.exists(), f"Video root {self.video_root} does not exist"
        # When several clips map to the same id, the one globbed last wins.
        self.id2pth = {
            f"{clip.parent.name}/{clip.stem.split('_')[0]}": clip
            for clip in self.video_root.glob("*/*.mp4")
        }

        # Fail fast if an annotated id has no video on disk.
        for ann in self.annotation:
            assert ann["reference"] in self.id2pth, f"Path to reference {ann['reference']} not found"
            if split == "test":
                continue
            assert ann["target"] in self.id2pth, f"Path to target {ann['target']} not found"

    def __len__(self):
        return len(self.annotation)

    def _load_clip(self, video_id):
        """Apply ``self.transform`` to ``get_middle_frame`` of the file for ``video_id``."""
        return self.transform(get_middle_frame(self.id2pth[video_id]))

    def __getitem__(self, index):
        ann = self.annotation[index]
        reference_vid = self._load_clip(ann["reference"])
        caption = pre_caption(ann["caption"], self.max_words)
        if self.split == "test":
            return reference_vid, caption, ann["pairid"]
        target_vid = self._load_clip(ann["target"])
        return reference_vid, target_vid, caption, ann["pairid"]


class WebVidRuleBased(Dataset):
    """WebVid clip pairs whose caption is generated from word-level templates.

    Each annotation names two clips (``pth1``/``pth2``, relative to
    ``video_root`` and without the ``.mp4`` suffix) and the words that differ
    between their descriptions (``diff_txt1``/``diff_txt2``).  When
    ``random.random() > 0.5`` the two clips (and their words) swap roles, so
    either one can act as the reference.
    """

    def __init__(self, transform, video_root, ann_root, split, data, max_words=30):
        """
        transform: callable applied to the frame returned by ``get_middle_frame``.
        video_root (string): directory that ``pth1``/``pth2`` are relative to.
        ann_root (string): directory holding ``webvid2m.<data>.json``.
        split (string): must be ``"train"``.
        data (string): tag that selects the annotation file.
        max_words (int): caption length limit forwarded to ``pre_caption``.
        """
        assert split == "train", "Only train split is available"
        self.annotation = openf(Path(ann_root, Path(f"webvid2m.{data}.json").name))
        self.transform = transform
        self.max_words = max_words
        self.split = split

        # Clip paths are built per item in __getitem__; only the root is checked here.
        self.video_root = Path(video_root)
        assert self.video_root.exists(), f"Video root {self.video_root} does not exist"

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        ann = self.annotation[index]
        clips = [
            Path(self.video_root, f"{ann['pth1']}.mp4"),
            Path(self.video_root, f"{ann['pth2']}.mp4"),
        ]
        words = [ann["diff_txt1"], ann["diff_txt2"]]

        # Randomly decide which clip of the pair is the reference.
        if random.random() > 0.5:
            clips.reverse()
            words.reverse()
        ref_pth, tgt_pth = clips
        ref_word, tgt_word = words

        reference_vid = self.transform(get_middle_frame(ref_pth))
        edit = self.generate_rule_based_edit(ref_word, tgt_word)
        caption = pre_caption(edit, self.max_words)
        target_vid = self.transform(get_middle_frame(tgt_pth))
        return reference_vid, target_vid, caption, ann["pairid"]

    @staticmethod
    def generate_rule_based_edit(txt1, txt2):
        """Fill a randomly chosen edit template with ``txt1`` (source) and ``txt2`` (target).

        Some templates use only one of the two words.  "Replace {txt1} with
        {txt2}" is listed twice, so it is drawn twice as often as the others.
        """
        templates = (
            "Remove {txt1}",
            "Take out {txt1} and add {txt2}",
            "Change {txt1} for {txt2}",
            "Replace {txt1} with {txt2}",
            "Replace {txt1} by {txt2}",
            "Replace {txt1} with {txt2}",
            "Make the {txt1} into {txt2}",
            "Add {txt2}",
            "Change it to {txt2}",
        )
        return random.choice(templates).format(txt1=txt1, txt2=txt2)


class WebVidImgEmbsDataset(Dataset):
    """WebVid frame pairs whose target is a precomputed BLIP embedding.

    Items are ``(reference_img, target_feat, caption, pairid)``.  The reference
    PNG is loaded as RGB and transformed.  The target is read from a ``.pth``
    file under ``<image_root parent>/blip-embs-<vit>``.  The constructor only
    accepts ``split="train"``, so the ``"test"`` branch of ``__getitem__`` is
    currently unreachable.
    """

    def __init__(
        self,
        transform,
        image_root,
        ann_root,
        split,
        data="",
        max_words=30,
        vit="large",
    ):
        """
        transform: callable applied to each reference ``PIL.Image`` (RGB).
        image_root (string): directory of ``<subdir>/*.png`` frames; its sibling
            directory ``blip-embs-<vit>`` must hold ``<subdir>/*.pth`` embeddings.
        ann_root (string): directory holding ``webvid2m.<data>.json``.
        split (string): must be ``"train"``.
        data (string): tag that selects the annotation file.
        max_words (int): caption length limit forwarded to ``pre_caption``.
        vit (string): ``"base"`` or ``"large"``; selects the embedding directory.
        """
        assert vit in [
            "base",
            "large",
        ], f"vit should be either base or large, not {vit}"

        urls = {
            "train": f"webvid2m.{data}.json",
        }
        assert split == "train", "Only train split is available"

        self.annotation = openf(Path(ann_root, Path(urls[split]).name))
        self.transform = transform
        self.max_words = max_words
        self.split = split

        self.image_root = Path(image_root)
        self.emb_root = self.image_root.parent / f"blip-embs-{vit}"
        assert self.image_root.exists(), f"Image root {self.image_root} does not exist"
        assert self.emb_root.exists(), f"Emb root {self.emb_root} does not exist"

        # Map "<subdir>/<video name>" ids to the image file and the embedding file.
        self.id2imgpth = self._index_by_video_id(self.image_root.glob("*/*.png"))
        self.id2embpth = self._index_by_video_id(self.emb_root.glob("*/*.pth"))

        # Fail fast if any annotated id lacks an image or an embedding.
        for ann in self.annotation:
            assert (
                ann["reference"] in self.id2imgpth
            ), f"Path to reference {ann['reference']} not found in {self.image_root}"
            assert (
                ann["reference"] in self.id2embpth
            ), f"Path to reference {ann['reference']} not found in {self.emb_root}"
            assert (
                ann["target"] in self.id2imgpth
            ), f"Path to target {ann['target']} not found in {self.image_root}"
            assert (
                ann["target"] in self.id2embpth
            ), f"Path to target {ann['target']} not found in {self.emb_root}"

    @staticmethod
    def _index_by_video_id(paths):
        """Map ``"<parent dir>/<stem before first '_'>"`` to each path; later paths win on clashes."""
        return {f"{pth.parent.name}/{pth.stem.split('_')[0]}": pth for pth in paths}

    def _load_emb(self, video_id):
        """Load the embedding file registered for ``video_id`` with tensors on the CPU.

        Without ``map_location``, tensors saved from a CUDA device are first
        restored onto that device (initializing CUDA inside DataLoader
        workers) and only then copied back.  Loading straight to the CPU also
        makes both branches of ``__getitem__`` return CPU data.
        """
        return torch.load(self.id2embpth[video_id], map_location="cpu")

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        ann = self.annotation[index]

        reference_img_pth = self.id2imgpth[ann["reference"]]
        reference_img = Image.open(reference_img_pth).convert("RGB")
        reference_img = self.transform(reference_img)

        caption = pre_caption(ann["caption"], self.max_words)

        if self.split == "test":
            reference_feat = self._load_emb(ann["reference"])
            return reference_img, reference_feat, caption, ann["pairid"]

        target_feat = self._load_emb(ann["target"])
        return reference_img, target_feat, caption, ann["pairid"]


class WebVidVidEmbsDataset(Dataset):
    """WebVid pairs: a reference video frame and a precomputed target video embedding.

    Items are ``(reference_vid, target_emb, caption, pairid)``.  The reference
    is ``video_root/<reference>.mp4``, passed through ``get_middle_frame`` and
    ``transform``.  The target is loaded from
    ``<video_root parent>/blip-vid-embs-<vit>-all/<target>.pth`` and reduced
    along its first dimension according to ``emb_method``.
    """

    def __init__(
        self,
        transform,
        video_root,
        ann_root,
        split,
        data="",
        max_words=30,
        vit="large",
        emb_method="middle",
    ):
        """
        transform: callable applied to the frame returned by ``get_middle_frame``.
        video_root (string): root directory of the ``.mp4`` videos.
        ann_root (string): directory holding ``webvid2m.<data>.json``.
        split (string): must be ``"train"``.
        data (string): tag that selects the annotation file.
        max_words (int): caption length limit forwarded to ``pre_caption``.
        vit (string): ``"base"`` or ``"large"``; selects the embedding directory.
        emb_method (string): ``"middle"`` keeps entry ``len // 2`` along dim 0 of
            the loaded embedding; ``"mean"`` averages over dim 0.  (Presumably
            dim 0 indexes frames — TODO confirm against the embedding script.)
        """
        assert vit in [
            "base",
            "large",
        ], f"vit should be either base or large, not {vit}"
        assert emb_method in [
            "middle",
            "mean",
        ], f"emb_method is {emb_method}, should be either middle or mean"

        urls = {
            "train": f"webvid2m.{data}.json",
        }
        assert split == "train", "Only train split is available"

        self.annotation = openf(Path(ann_root, Path(urls[split]).name))
        self.transform = transform
        self.max_words = max_words
        self.split = split
        self.emb_method = emb_method

        self.video_root = Path(video_root)
        self.emb_root = self.video_root.parent / f"blip-vid-embs-{vit}-all"
        assert self.video_root.exists(), f"Video root {self.video_root} does not exist"
        assert self.emb_root.exists(), f"Emb root {self.emb_root} does not exist"

    def __len__(self):
        return len(self.annotation)

    def __getitem__(self, index):
        ann = self.annotation[index]
        reference_pth = Path(self.video_root, f"{ann['reference']}.mp4")
        reference_vid = self.transform(get_middle_frame(reference_pth))

        caption = pre_caption(ann["caption"], self.max_words)

        target_pth = Path(self.emb_root, f"{ann['target']}.pth")
        # Deserialize straight onto the CPU: without map_location, CUDA-saved
        # tensors would first be restored on the GPU inside DataLoader workers.
        target_emb = torch.load(target_pth, map_location="cpu")
        if self.emb_method == "middle":
            target_emb = target_emb[len(target_emb) // 2]
        elif self.emb_method == "mean":
            target_emb = target_emb.mean(0)

        return reference_vid, target_emb, caption, ann["pairid"]


class WebVidVidEmbsIterateDataset(Dataset):
    """One item per distinct value of a CSV column, sampling a matching pair each time.

    The CSV must provide ``pth1``, ``pth2``, ``txt1``, ``txt2`` and ``edit``
    columns (plus ``scores`` for ``emb_method="query"``).  For item ``index``,
    ``__getitem__`` looks up the rows whose ``iterate`` value equals
    ``target_txts[index]``, picks one at random if there are several, and
    returns ``(reference_vid, target_emb, caption, index, soft_input,
    soft_target)``, where:

    * ``reference_vid`` is ``get_middle_frame(video_root/<pth1>.mp4)`` after ``transform``;
    * ``target_emb`` is ``<emb_root>/<pth2>.pth`` reduced along dim 0;
    * ``caption`` is the row's ``edit`` text after ``pre_caption``;
    * ``soft_input`` / ``soft_target`` are the positions of ``txt1`` / ``txt2``
      in the sorted vocabulary of all ``txt1``/``txt2`` values.
    """

    def __init__(
        self,
        transform,
        video_root,
        ann_root,
        split,
        data="",
        max_words=30,
        vit="large",
        emb_method="middle",
        iterate="input",
    ):
        """
        transform: callable applied to the frame returned by ``get_middle_frame``.
        video_root (string): root directory of the ``.mp4`` videos.
        ann_root (string): directory holding ``webvid2m.<data>.csv`` (train) or
            ``WebVid8M.<data>.csv`` (test).
        split (string): ``"train"`` or ``"test"``; the test CSV must have a
            distinct ``iterate`` value per row.
        data (string): tag that selects the CSV file.
        max_words (int): caption length limit forwarded to ``pre_caption``.
        vit (string): ``"base"`` or ``"large"``; selects the embedding directory.
        emb_method (string): ``"middle"`` (entry ``len // 2`` along dim 0),
            ``"mean"`` (average over dim 0) or ``"query"`` (softmax-weighted sum
            over dim 0 using the row's ``scores``).
        iterate (string): name of the CSV column whose distinct values define
            the dataset items.
        """
        assert vit in [
            "base",
            "large",
        ], f"vit should be either base or large, not {vit}"
        assert emb_method in [
            "middle",
            "mean",
            "query",
        ], f"emb_method is {emb_method}, should be one of middle, mean or query"

        urls = {"train": f"webvid2m.{data}.csv", "test": f"WebVid8M.{data}.csv"}
        assert split in ["train", "test"], "Only train or test split is available"

        df_pth = Path(ann_root, Path(urls[split]).name)
        self.df = pd.read_csv(df_pth)
        # Validate the column before its first use; otherwise a missing column
        # surfaces as a bare KeyError instead of this message.
        assert iterate in self.df.columns, f"{iterate} not in {df_pth.stem}"

        self.transform = transform
        self.max_words = max_words
        self.split = split
        self.emb_method = emb_method
        self.iterate = iterate
        # Distinct values in order of first appearance in the (unsorted) CSV.
        self.target_txts = self.df[iterate].unique()

        if self.emb_method == "query":
            assert "scores" in self.df.columns, "Query method requires scores column"

        if split == "test":
            assert (
                len(self.target_txts) == self.df.shape[0]
            ), "Test split should have one caption per row"

        self.df.sort_values(iterate, inplace=True)
        self.df.reset_index(drop=True, inplace=True)
        # Row position (after sorting) -> reference / target video id.
        self.pairid2ref = self.df["pth1"].to_dict()
        self.pairid2tar = self.df["pth2"].to_dict()
        # Index rows by the iterated value for `.loc` lookups, keeping it as a column too.
        self.df.set_index(iterate, inplace=True)
        self.df[iterate] = self.df.index

        # Vocabulary of every txt1/txt2 value, mapped to its sorted position.
        txts = sorted(set(self.df["txt1"].unique()) | set(self.df["txt2"].unique()))
        self.txt2idx = {txt: idx for idx, txt in enumerate(txts)}

        self.video_root = Path(video_root)
        self.emb_root = self.video_root.parent / f"blip-vid-embs-{vit}-all"
        assert self.video_root.exists(), f"Video root {self.video_root} does not exist"
        assert self.emb_root.exists(), f"Emb root {self.emb_root} does not exist"

    def __len__(self):
        return len(self.target_txts)

    def __getitem__(self, index):
        target_txt = self.target_txts[index]
        ann = self.df.loc[target_txt]
        # Several rows share this value: `.loc` returned a DataFrame, so pick one row.
        if ann.ndim > 1:
            ann = ann.sample().iloc[0]

        reference_pth = Path(self.video_root, f"{ann['pth1']}.mp4")
        reference_vid = self.transform(get_middle_frame(reference_pth))

        caption = pre_caption(ann["edit"], self.max_words)

        target_pth = Path(self.emb_root, f"{ann['pth2']}.pth")
        # Deserialize straight onto the CPU: without map_location, CUDA-saved
        # tensors would first be restored on the GPU inside DataLoader workers.
        target_emb = torch.load(target_pth, map_location="cpu")
        if self.emb_method == "middle":
            target_emb = target_emb[len(target_emb) // 2]
        elif self.emb_method == "mean":
            target_emb = target_emb.mean(0)
        elif self.emb_method == "query":
            vid_scores = torch.Tensor(ast.literal_eval(ann["scores"]))  # type: ignore
            # Softmax with temperature 0.1, then a weighted sum over dim 0 of target_emb.
            vid_scores = (vid_scores / 0.1).softmax(dim=0)
            target_emb = torch.einsum("f,fe->e", vid_scores, target_emb)

        soft_input = self.txt2idx[ann["txt1"]]
        soft_target = self.txt2idx[ann["txt2"]]

        return reference_vid, target_emb, caption, index, soft_input, soft_target


class WebVidVidEmbsIterateRuleBasedDataset(Dataset):
    """Like ``WebVidVidEmbsIterateDataset``, but captions come from word-level templates.

    The CSV must provide ``pth1``, ``pth2``, ``txt1`` and ``txt2`` columns
    (plus ``scores`` for ``emb_method="query"``).  ``add_different_words``
    adds ``diff_txt1``/``diff_txt2`` columns and drops rows without a string in
    both.  Items are ``(reference_vid, target_emb, caption, index, soft_input,
    soft_target)``, where the caption fills a random template with the
    reference word (``diff_txt1``) and the target word (``diff_txt2``).
    """

    def __init__(
        self,
        transform,
        video_root,
        ann_root,
        split,
        data="",
        max_words=30,
        vit="large",
        emb_method="middle",
        iterate="pth2",
    ):
        """
        transform: callable applied to the frame returned by ``get_middle_frame``.
        video_root (string): root directory of the ``.mp4`` videos.
        ann_root (string): directory holding ``webvid2m.<data>.csv`` (train) or
            ``WebVid8M.<data>.csv`` (test).
        split (string): ``"train"`` or ``"test"``; the test CSV must have a
            distinct ``iterate`` value per row.
        data (string): tag that selects the CSV file.
        max_words (int): caption length limit forwarded to ``pre_caption``.
        vit (string): ``"base"`` or ``"large"``; selects the embedding directory.
        emb_method (string): ``"middle"`` (entry ``len // 2`` along dim 0),
            ``"mean"`` (average over dim 0) or ``"query"`` (softmax-weighted sum
            over dim 0 using the row's ``scores``).
        iterate (string): name of the CSV column whose distinct values define
            the dataset items.
        """
        assert vit in [
            "base",
            "large",
        ], f"vit should be either base or large, not {vit}"
        assert emb_method in [
            "middle",
            "mean",
            "query",
        ], f"emb_method is {emb_method}, should be one of middle, mean or query"

        urls = {"train": f"webvid2m.{data}.csv", "test": f"WebVid8M.{data}.csv"}
        assert split in ["train", "test"], "Only train or test split is available"

        df_pth = Path(ann_root, Path(urls[split]).name)
        self.df = pd.read_csv(df_pth)
        # Validate the column before its first use; otherwise a missing column
        # surfaces as a bare KeyError instead of this message.
        assert iterate in self.df.columns, f"{iterate} not in {df_pth.stem}"

        self.transform = transform
        self.max_words = max_words
        self.split = split
        self.emb_method = emb_method
        self.iterate = iterate
        # Distinct values in order of first appearance in the (unsorted) CSV.
        self.target_txts = self.df[iterate].unique()

        if self.emb_method == "query":
            assert "scores" in self.df.columns, "Query method requires scores column"

        if split == "test":
            assert (
                len(self.target_txts) == self.df.shape[0]
            ), "Test split should have one caption per row"

        self.df.sort_values(iterate, inplace=True)
        self.df.reset_index(drop=True, inplace=True)
        # Row position (after sorting) -> reference / target video id.
        self.pairid2ref = self.df["pth1"].to_dict()
        self.pairid2tar = self.df["pth2"].to_dict()
        # Index rows by the iterated value for `.loc` lookups, keeping it as a column too.
        self.df.set_index(iterate, inplace=True)
        self.df[iterate] = self.df.index
        self.df = add_different_words(self.df)
        # add_different_words drops rows; keep only values that still have at
        # least one row, otherwise __getitem__ would raise KeyError in `.loc`.
        still_present = pd.Index(self.target_txts).isin(self.df.index)
        self.target_txts = self.target_txts[still_present]

        # Vocabulary of every txt1/txt2 value, mapped to its sorted position.
        txts = sorted(set(self.df["txt1"].unique()) | set(self.df["txt2"].unique()))
        self.txt2idx = {txt: idx for idx, txt in enumerate(txts)}

        self.video_root = Path(video_root)
        self.emb_root = self.video_root.parent / f"blip-vid-embs-{vit}-all"
        assert self.video_root.exists(), f"Video root {self.video_root} does not exist"
        assert self.emb_root.exists(), f"Emb root {self.emb_root} does not exist"

    def __len__(self):
        return len(self.target_txts)

    def __getitem__(self, index):
        target_txt = self.target_txts[index]
        ann = self.df.loc[target_txt]
        # Several rows share this value: `.loc` returned a DataFrame, so pick one row.
        if ann.ndim > 1:
            ann = ann.sample().iloc[0]

        reference_pth = Path(self.video_root, f"{ann['pth1']}.mp4")
        reference_vid = self.transform(get_middle_frame(reference_pth))

        # Reference word first, target word second (previously diff_txt1 was
        # passed twice, producing captions such as "Replace dog with dog").
        caption = self.generate_rule_based_edit(ann["diff_txt1"], ann["diff_txt2"])
        caption = pre_caption(caption, self.max_words)

        target_pth = Path(self.emb_root, f"{ann['pth2']}.pth")
        # Deserialize straight onto the CPU: without map_location, CUDA-saved
        # tensors would first be restored on the GPU inside DataLoader workers.
        target_emb = torch.load(target_pth, map_location="cpu")
        if self.emb_method == "middle":
            target_emb = target_emb[len(target_emb) // 2]
        elif self.emb_method == "mean":
            target_emb = target_emb.mean(0)
        elif self.emb_method == "query":
            vid_scores = torch.Tensor(ast.literal_eval(ann["scores"]))  # type: ignore
            # Softmax with temperature 0.1, then a weighted sum over dim 0 of target_emb.
            vid_scores = (vid_scores / 0.1).softmax(dim=0)
            target_emb = torch.einsum("f,fe->e", vid_scores, target_emb)

        soft_input = self.txt2idx[ann["txt1"]]
        soft_target = self.txt2idx[ann["txt2"]]
        return reference_vid, target_emb, caption, index, soft_input, soft_target

    @staticmethod
    def generate_rule_based_edit(txt1, txt2):
        """Fill a randomly chosen edit template with ``txt1`` (source) and ``txt2`` (target).

        Some templates use only one of the two words.  "Replace {txt1} with
        {txt2}" is listed twice, so it is drawn twice as often as the others.
        """
        templates = [
            "Remove {txt1}",
            "Take out {txt1} and add {txt2}",
            "Change {txt1} for {txt2}",
            "Replace {txt1} with {txt2}",
            "Replace {txt1} by {txt2}",
            "Replace {txt1} with {txt2}",
            "Make the {txt1} into {txt2}",
            "Add {txt2}",
            "Change it to {txt2}",
        ]
        template = random.choice(templates)
        return template.format(txt1=txt1, txt2=txt2)


def get_different_word_in_each_sentence(sentence1, sentence2):
    """Return the first pair of words at which two sentences differ.

    Each sentence is lower-cased, stripped of ``.`` and ``,`` and split on
    whitespace.  Words are compared position by position, only over the
    length of the shorter sentence.  The first mismatch is returned as
    ``(word_in_sentence1, word_in_sentence2)``.  If there is no mismatch
    (identical sentences, or one is a word-for-word prefix of the other),
    ``(None, None)`` is returned.
    """
    strip_punct = str.maketrans("", "", ".,")
    words1 = sentence1.lower().translate(strip_punct).split()
    words2 = sentence2.lower().translate(strip_punct).split()
    mismatches = ((w1, w2) for w1, w2 in zip(words1, words2) if w1 != w2)
    return next(mismatches, (None, None))


def add_different_words(df):
    """Attach each row's first differing words and keep only rows that have them.

    For every row, ``get_different_word_in_each_sentence(txt1, txt2)`` is
    evaluated.  Its two results are written as new ``diff_txt1`` /
    ``diff_txt2`` columns directly on the passed-in ``df``, so the caller's
    frame is mutated.  A filtered frame is then returned, keeping only rows
    where both columns hold a ``str``.
    """
    pairs = [
        get_different_word_in_each_sentence(row.txt1, row.txt2)
        for row in df.itertuples()
    ]
    df["diff_txt1"] = [first for first, _ in pairs]
    df["diff_txt2"] = [second for _, second in pairs]

    def is_str(value):
        return isinstance(value, str)

    for column in ("diff_txt1", "diff_txt2"):
        df = df[df[column].apply(is_str)]
    return df
