import sys
from collections import defaultdict
from pathlib import Path
from typing import List, Union

import numpy as np
import pandas as pd
from tqdm import tqdm

sys.path.append("../")
from utils_retrieval import FrmEmbeddings


class WebVidSim:
    """Score pairs of WebVid captions by the similarity of their videos.

    Attributes built by ``__init__``:

    * ``caption2paths`` -- ``defaultdict(list)`` mapping a caption (the ``name``
      column of the training CSV) to the ids of all videos with that caption.
      A video id has the form ``"<parent directory name>/<file stem>"``.
    * ``path2url`` -- dict mapping each video id to its ``contentUrl`` value.
    * ``frm_embs`` -- ``FrmEmbeddings`` object, indexed by video id in
      ``get_video_similarity``.
    """

    def __init__(self, dataset_dir):
        """Load the caption table and open the frame-embedding store.

        Args:
            dataset_dir: Dataset root (``str`` or ``Path``). It is expected to
                contain ``docs/clean_2M_train.csv`` and
                ``clip-retrieval/cr-vitb32/frames_embeddings/all_embeddings``.
        """
        dataset_dir = Path(dataset_dir)

        df = pd.read_csv(dataset_dir / "docs/clean_2M_train.csv")
        # Keep only rows whose path is a string (drops e.g. NaN placeholders).
        has_path = df.path.apply(lambda value: isinstance(value, str))
        df = df[has_path].reset_index(drop=True)

        def to_path_id(raw_path):
            # "<...>/<dir>/<file>.<ext>" -> "<dir>/<file>"
            pth = Path(raw_path)
            return f"{pth.parent.name}/{pth.stem}"

        df["path_id"] = df.path.apply(to_path_id)

        caption2paths = defaultdict(list)
        for caption, path_id in zip(df.name, df.path_id):
            caption2paths[caption].append(path_id)
        self.caption2paths = caption2paths
        self.path2url = dict(zip(df.path_id, df.contentUrl))

        emb_dir = (
            dataset_dir / "clip-retrieval/cr-vitb32/frames_embeddings/all_embeddings"
        )

        def get_frameid(frame_path):
            # Text after "/frames/", minus its first 7 and last 4 characters
            # (presumably a fixed-width prefix and the file extension --
            # confirm against the frame naming scheme).
            return frame_path.split("/frames/")[1][7:-4]

        def get_videoid(frame_id):
            # Everything before the last "_f" marker in the frame id.
            return frame_id.rsplit("_f", 1)[0]

        self.frm_embs = FrmEmbeddings(emb_dir, get_frameid, get_videoid)

    def get_video_similarity(self, txt1, txt2):
        """Compare every video of caption ``txt1`` with every video of ``txt2``.

        Args:
            txt1: First caption.
            txt2: Second caption.

        Returns:
            A list of ``(similarity, path1, path2)`` tuples, one per video
            pair, sorted in descending tuple order (highest similarity first).
            ``similarity`` is ``emb1 @ emb2.T`` of the two entries returned by
            ``self.frm_embs``. The list is empty when either caption has no
            associated video.
        """
        # .get() rather than indexing: indexing a defaultdict would insert an
        # empty list for every unknown caption that is queried.
        pths1 = self.caption2paths.get(txt1, [])
        pths2 = self.caption2paths.get(txt2, [])

        # Always return a list so callers can iterate the result uniformly.
        if not pths1 or not pths2:
            return []

        # Fetch each embedding once instead of once per (pth1, pth2) pair.
        embs2 = [(pth2, self.frm_embs[pth2]) for pth2 in pths2]

        sim_scores = []
        for pth1 in pths1:
            vid_emb1 = self.frm_embs[pth1]
            for pth2, vid_emb2 in embs2:
                sim_scores.append((vid_emb1 @ vid_emb2.T, pth1, pth2))
        sim_scores.sort(reverse=True)
        return sim_scores


def main(dataset_dir, csv_idx):
    """Compute video similarities for one shard of similar-caption pairs.

    Reads the CSV in ``similar-sentences/txt_filtered-duplicates`` whose name
    ends with ``_<csv_idx>.csv``, looks up the video similarity of every
    caption pair in it, and writes one row per video pair to a CSV of the same
    name under ``similar-sentences/vid``.

    Args:
        dataset_dir: Dataset root as a ``Path``.
        csv_idx: Index selecting the input CSV by its ``_<csv_idx>.csv`` suffix.

    Raises:
        IndexError: If no input CSV matches ``csv_idx``.
    """
    csv_in = dataset_dir / "similar-sentences/txt_filtered-duplicates"
    csv_out = dataset_dir / "similar-sentences/vid"
    csv_out.mkdir(exist_ok=True)

    suffix = f"_{csv_idx}.csv"
    # Sorted so the choice is deterministic should several files match.
    matches = sorted(f for f in Path(csv_in).glob("*.csv") if str(f).endswith(suffix))
    if not matches:
        # IndexError is kept (same type as indexing an empty list) but now
        # carries a message that says what was being looked for.
        raise IndexError(f"no CSV ending in {suffix!r} found in {csv_in}")
    csv_pth = matches[0]
    print(csv_pth)

    wb_sim = WebVidSim(dataset_dir)

    df_sim = pd.read_csv(csv_pth)
    df_sim.columns = ["txt1", "txt2", "sim_txt"]
    # Optional text-similarity band filter, currently disabled:
    # df_sim = df_sim[df_sim["sim_txt"] < 0.96]
    # df_sim = df_sim[df_sim["sim_txt"] > 0.60]

    out_columns = ["txt1", "txt2", "sim_txt", "pth1", "pth2", "sim_vid"]
    data = []
    n_skipped = 0
    for row in tqdm(df_sim.itertuples(index=False), total=len(df_sim)):
        try:
            sims = wb_sim.get_video_similarity(row.txt1, row.txt2)
            pair_rows = [
                (row.txt1, row.txt2, row.sim_txt, pth1, pth2, sim_vid)
                for sim_vid, pth1, pth2 in sims
            ]
        except Exception:
            # Best effort: a pair whose similarity cannot be computed is
            # skipped, but counted so failures are visible. Unlike a bare
            # ``except``, this lets KeyboardInterrupt/SystemExit propagate.
            n_skipped += 1
            continue
        data.extend(pair_rows)

    if n_skipped:
        print(f"skipped {n_skipped} of {len(df_sim)} caption pairs")

    # Passing ``columns`` to the constructor also works when ``data`` is
    # empty; assigning six names to an empty frame afterwards would raise.
    df_out = pd.DataFrame(data, columns=out_columns)
    df_out.to_csv(csv_out / f"{csv_pth.stem}.csv", index=False)


if __name__ == "__main__":
    # Command-line entry point: ``csv_idx`` selects which input CSV shard
    # ``main`` processes.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("csv_idx", type=int)
    args = parser.parse_args()

    dataset_dir = Path("./datasets/WebVid/2M/")
    # Explicit check instead of ``assert``, which is removed under ``python -O``.
    if not dataset_dir.is_dir():
        raise FileNotFoundError(f"dataset directory not found: {dataset_dir}")

    main(dataset_dir, args.csv_idx)
