from pathlib import Path
from typing import List, Union

import numpy as np
import pandas as pd
from lutils import openf, writef


class Index:
    """Nearest-neighbour lookup over a FAISS index and its metadata.

    The index file ``<index_type>.index`` is read from ``index_dir`` and the
    matching metadata column (``caption`` for ``index_type == "text"``,
    ``image_path`` otherwise) is loaded from ``index_dir / "metadata"``.
    """

    import faiss

    def __init__(self, index_dir, index_type: str = "text"):
        """Read the FAISS index and the metadata column that labels its rows.

        Args:
            index_dir: Directory holding ``<index_type>.index`` and a
                ``metadata`` sub-directory of parquet files.
            index_type: ``"text"`` selects the ``caption`` column; any other
                value selects ``image_path``.
        """
        index_dir = Path(index_dir)
        assert index_dir.exists(), f"{index_dir} does not exist"
        self.index = self.faiss.read_index(str(index_dir / f"{index_type}.index"))

        df = load_metadata(index_dir / "metadata")
        key = "caption" if index_type == "text" else "image_path"
        self.data = df[key].tolist()
        # Map every entry to its position in ``self.data`` (the positions that
        # ``idx2data`` accepts). The DataFrame index labels are not used
        # because ``load_metadata`` concatenates the parquet files without
        # resetting the index, so the labels may repeat across files.
        self.data2id = {entry: pos for pos, entry in enumerate(self.data)}

    def idx2data(self, idx: int) -> str:
        """Return the metadata entry stored at position ``idx``."""
        return self.data[idx]

    def get_idxs(self, query_emb, num_results=1, threshold=None):
        """Query the FAISS index and return ``(idxs, distances)``.

        Args:
            query_emb: Query embedding; a 1-D array is reshaped to one row.
            num_results: Number of neighbours for the k-NN search (ignored
                when ``threshold`` is given).
            threshold: If not ``None``, run a range search with this radius
                instead of a k-NN search.

        Returns:
            Tuple ``(idxs, d)`` of equally long, aligned arrays.
        """
        if query_emb.ndim == 1:
            query_emb = query_emb.reshape(1, -1)

        if threshold is not None:
            _, d, idxs = self.index.range_search(query_emb, threshold)
            # Largest score first (presumably the index stores similarities,
            # e.g. inner product -- confirm against the index metric). The
            # same permutation is applied to both arrays so that d[i] still
            # belongs to idxs[i].
            order = np.argsort(-d)
            idxs, d = idxs[order], d[order]
        else:
            d, idxs = self.index.search(query_emb, num_results)
            d, idxs = d[0], idxs[0]

        return idxs, d

    def search(self, query_emb, num_results=1, threshold=None, return_d=False):
        """Return the metadata entries closest to ``query_emb``.

        Args:
            query_emb: Query embedding (1-D or a single-row 2-D array).
            num_results: Number of neighbours for the k-NN search.
            threshold: Optional radius; switches to a range search.
            return_d: If true, also return the distances of the returned
                entries.

        Returns:
            The list of entries, or ``(entries, distances)`` when ``return_d``
            is true. Distances are restricted to the entries actually
            returned, so both sequences stay aligned.
        """
        idxs, d = self.get_idxs(query_emb, num_results, threshold)

        captions = []
        kept = []
        for pos, idx in enumerate(idxs):
            # FAISS pads missing neighbours with -1; indexing the list with
            # -1 would silently return the *last* entry, so skip those.
            if idx < 0:
                continue
            try:
                captions.append(self.idx2data(idx))
            except IndexError:
                # Index row without a metadata entry: best effort, skip it.
                continue
            kept.append(pos)

        if return_d:
            return captions, np.asarray(d)[np.asarray(kept, dtype=np.intp)]
        return captions


def load_metadata(data_dir: Union[str, Path]):
    """Load every ``*.parquet`` file of a metadata directory as one DataFrame.

    Args:
        data_dir: Either the ``metadata`` directory itself or its parent; in
            the latter case ``metadata`` is appended automatically.

    Returns:
        The concatenation (in sorted file-name order) of all parquet files.
        The index of the individual files is kept as-is, i.e. it is not reset.

    Raises:
        AssertionError: If the metadata directory does not exist.
        ValueError: If the directory contains no parquet file.
    """
    data_dir = Path(data_dir)
    if data_dir.name != "metadata":
        data_dir = data_dir / "metadata"
    assert data_dir.exists(), f"{data_dir} does not exist"

    parquet_pths = sorted(data_dir.glob("*.parquet"))
    if not parquet_pths:
        # pd.concat would raise an opaque "No objects to concatenate" here.
        raise ValueError(f"No parquet files found in {data_dir}")

    return pd.concat(pd.read_parquet(parquet_pth) for parquet_pth in parquet_pths)


def load_emb(emb_dir: Union[str, Path]):
    """Load all ``*.npy`` files of a directory into one ``float32`` array.

    Args:
        emb_dir: Directory containing the ``.npy`` embedding shards.

    Returns:
        The shards, cast to ``float32`` and concatenated along axis 0 in
        sorted file-name order.

    Raises:
        AssertionError: If ``emb_dir`` does not exist.
        ValueError: If ``emb_dir`` contains no ``.npy`` file.
    """
    emb_dir = Path(emb_dir)
    assert emb_dir.exists(), f"{emb_dir} does not exist"

    emb_pths = sorted(emb_dir.glob("*.npy"))
    if not emb_pths:
        # np.concatenate would raise a message that does not name the folder.
        raise ValueError(f"No .npy files found in {emb_dir}")

    return np.concatenate(
        [np.load(emb_pth).astype("float32") for emb_pth in emb_pths],
        axis=0,
    )


class TxtEmbeddings:
    """Caption -> text-embedding lookup, with optional on-the-fly CLIP encoding.

    Captions come from the metadata parquet files of ``emb_dir`` and the
    embeddings from ``emb_dir / "text_emb"``; both are aligned by position.
    """

    import clip
    import torch

    def __init__(self, emb_dir: Union[str, Path], model: str = "None"):
        """Load captions and embeddings and, optionally, a CLIP model.

        Args:
            emb_dir: Directory holding ``metadata`` and ``text_emb``.
            model: Model name handed to ``clip.load``. The string ``"None"``
                (the default) means that no model is loaded; ``self.model``
                then keeps that string as a sentinel.
        """
        emb_dir = Path(emb_dir)
        assert emb_dir.exists(), f"{emb_dir} does not exist"

        self.df = load_metadata(emb_dir)
        self.df["caption"] = self.df["caption"].apply(lambda x: x.replace("\n", ""))
        self.df.reset_index(drop=True, inplace=True)
        self.captions = self.df["caption"]
        # For duplicated captions the last position wins.
        self.captions2idx = {cap: pos for pos, cap in enumerate(self.captions)}

        self.text_embs = load_emb(emb_dir / "text_emb")

        assert len(self.captions) == len(self.text_embs)

        self.device = "cuda" if self.torch.cuda.is_available() else "cpu"
        if model != "None":
            model, _ = self.clip.load(model, device=self.device)
            model.eval()
        self.model = model

    def encode_text(self, caption: str) -> np.ndarray:
        """Encode ``caption`` with the loaded CLIP model.

        Returns:
            The L2-normalised text feature as a 1-D ``float32`` array.

        Raises:
            AssertionError: If ``caption`` is not a string or no model was
                loaded.
        """
        assert isinstance(caption, str), f"{caption} is not a string"
        # "No model" is stored as the string "None" (see __init__), so a plain
        # ``is not None`` check can never fail; test the sentinel as well.
        assert self.model is not None and self.model != "None", "No model loaded"
        text = self.clip.tokenize([caption]).to(self.device)
        with self.torch.no_grad():
            text_features = self.model.encode_text(text)
            text_features /= text_features.norm(dim=-1, keepdim=True)
        return text_features.cpu().numpy().astype("float32")[0, :]

    def get_idx(self, caption):
        """Return the position of ``caption`` or ``None`` if it is unknown."""
        assert isinstance(caption, str), f"{caption} is not a string"
        return self.captions2idx.get(caption)

    def get_emb(self, caption):
        """Return the stored embedding of ``caption`` or encode it on the fly.

        Raises:
            ValueError: If the caption is unknown and no model is loaded.
        """
        idx = self.get_idx(caption)
        if idx is not None:
            return self.text_embs[idx]
        if self.model != "None":
            return self.encode_text(caption)
        raise ValueError(f"{caption} not found in index")

    def __len__(self):
        return len(self.captions)

    def __getitem__(self, caption):
        return self.get_emb(caption)

    def __iter__(self):
        """Yield ``(caption, embedding)`` pairs in storage order."""
        yield from zip(self.captions, self.text_embs)

    def __contains__(self, caption):
        return caption in self.captions2idx


class ImgEmbeddings:
    """Frame-id -> image-embedding lookup.

    Frame ids are derived from the ``image_path`` metadata column and the
    embeddings are read from ``emb_dir / "img_emb"``; both are aligned by
    position.
    """

    # NOTE: the extractor used to be written as an *annotation*
    # (``get_frameid: lambda ...``), which made it a required argument. It is
    # now a real default, so ``ImgEmbeddings(emb_dir)`` works.
    def __init__(self, emb_dir: Union[str, Path], get_frameid=lambda x: Path(x).stem):
        """Load metadata and embeddings from ``emb_dir``.

        Args:
            emb_dir: Directory holding ``metadata`` and ``img_emb``.
            get_frameid: Callable mapping an image path to its frame id;
                defaults to the file name without extension.

        Raises:
            AssertionError: If ``emb_dir`` is missing or frame ids repeat.
        """
        emb_dir = Path(emb_dir)
        assert emb_dir.exists(), f"{emb_dir} does not exist"

        data = load_metadata(emb_dir)

        data["frame_id"] = data["image_path"].apply(get_frameid)
        assert len(data) == len(data.frame_id.unique()), "Duplicate frame_ids"
        # Position of each frame in the metadata == row in ``frame_embs``.
        self.id2idx = {frame_id: pos for pos, frame_id in enumerate(data.frame_id)}
        self.frame_ids = sorted(data.frame_id.to_list())

        self.frame_embs = load_emb(emb_dir / "img_emb")

        self.data = data

    def get_emb(self, frame_id) -> Union[np.ndarray, None]:
        """Return the embedding of ``frame_id``.

        Raises:
            AssertionError: If ``frame_id`` is unknown.
        """
        assert frame_id in self.id2idx, f"{frame_id} not in frame_ids"
        return self.frame_embs[self.id2idx[frame_id]]

    def __len__(self) -> int:
        return len(self.frame_ids)

    def __getitem__(self, id):
        return self.get_emb(id)

    def __contains__(self, id):
        return id in self.id2idx


class FrmEmbeddings:
    """Embedding lookup by frame id or by video id.

    Frame ids are derived from the ``image_path`` metadata column, video ids
    from the frame ids. Embeddings come from ``emb_dir / "img_emb"`` and are
    aligned with the metadata by position.
    """

    # NOTE: both extractors used to be written as *annotations*
    # (``get_frameid: lambda ...``), which made them required arguments. They
    # are now real defaults, so ``FrmEmbeddings(emb_dir)`` works.
    def __init__(
        self,
        emb_dir: Union[str, Path],
        get_frameid=lambda x: Path(x).stem,
        get_videoid=lambda x: x[:-4],
    ):
        """Load metadata and embeddings and build the lookup tables.

        Args:
            emb_dir: Directory holding ``metadata`` and ``img_emb``.
            get_frameid: Callable mapping an image path to its frame id;
                defaults to the file name without extension.
            get_videoid: Callable mapping a frame id to its video id;
                defaults to dropping the last four characters.

        Raises:
            AssertionError: If ``emb_dir`` is missing or frame ids repeat.
        """
        emb_dir = Path(emb_dir)
        assert emb_dir.exists(), f"{emb_dir} does not exist"

        data = load_metadata(emb_dir)

        data["frame_id"] = data["image_path"].apply(get_frameid)
        assert len(data) == len(data.frame_id.unique()), "Duplicate frame_ids"
        # Position of each frame in the metadata == row in ``frame_embs``.
        self.frame2idx = {frame_id: pos for pos, frame_id in enumerate(data.frame_id)}
        self.frame_ids = sorted(data.frame_id.to_list())

        self.frame_embs = load_emb(emb_dir / "img_emb")

        data["video_id"] = data["frame_id"].apply(get_videoid)

        self.video2frames = data.groupby("video_id")["frame_id"].apply(list).to_dict()
        # Same grouping as ``video2frames``, expressed as embedding rows.
        self.video2idxs = {
            video_id: [self.frame2idx[frame_id] for frame_id in frames]
            for video_id, frames in self.video2frames.items()
        }
        self.video_ids = sorted(self.video2frames)

        self.data = data

    def get_emb(self, id: str, mean: bool = True) -> Union[np.ndarray, None]:
        """Return the embedding(s) for a video id or a frame id.

        Video ids take precedence over frame ids.

        Args:
            id: Video id or frame id.
            mean: For a video id, return the mean over its frame embeddings
                instead of the individual frame embeddings. Ignored for
                frame ids.

        Raises:
            ValueError: If ``id`` is neither a known video nor a known frame.
        """
        if id in self.video2idxs:
            frames = self.frame_embs[self.video2idxs[id]]
            return frames.mean(axis=0) if mean else frames
        if id in self.frame2idx:
            return self.frame_embs[self.frame2idx[id]]
        raise ValueError(f"{id} not in self.frame2idx or self.video2idxs")

    def get_embs(self, ids: List[str], mean: bool = True) -> np.ndarray:
        """Stack the results of :meth:`get_emb` for every id in ``ids``."""
        return np.stack([self.get_emb(id, mean=mean) for id in ids])

    def __len__(self) -> int:
        return len(self.frame_ids)

    def __getitem__(self, id: Union[str, List[str]]) -> np.ndarray:
        """Look up one id (``str``) or several ids (``list``).

        Raises:
            ValueError: If ``id`` is neither a ``str`` nor a ``list``.
        """
        if isinstance(id, list):
            return self.get_embs(id)
        if isinstance(id, str):
            return self.get_emb(id)
        raise ValueError(f"Invalid id type: {type(id)}")

    def __contains__(self, id):
        return id in self.frame2idx or id in self.video2idxs


class SentenceTransformers:
    """Thin wrapper around the ``all-MiniLM-L6-v2`` sentence-transformers model."""

    from sentence_transformers import SentenceTransformer, util

    def __init__(self):
        """Instantiate the sentence-transformers model."""
        model_name = "sentence-transformers/all-MiniLM-L6-v2"
        self.model = self.SentenceTransformer(model_name)

    def encode(self, text, convert_to_tensor=True):
        """Forward ``text`` to the model's ``encode`` method."""
        return self.model.encode(text, convert_to_tensor=convert_to_tensor)

    def similarity(self, iq, it):
        """Return the cosine similarity of ``iq`` and ``it`` as a Python scalar.

        Each argument may be a string (encoded first) or anything else, which
        is passed through unchanged as an already computed embedding.
        """
        # Encode the query first, then the target, strings only.
        query_emb = self.encode(iq) if isinstance(iq, str) else iq
        target_emb = self.encode(it) if isinstance(it, str) else it
        score = self.util.pytorch_cos_sim(query_emb, target_emb)
        return score.item()


class CLIP:
    """Convenience wrapper producing L2-normalised CLIP embeddings."""

    import clip
    import torch
    import torch.nn.functional as F

    # Chosen once, when the class body is executed.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def __init__(self, model="ViT-L/14"):
        """Load the CLIP model ``model`` and its preprocessing transform."""
        loaded = self.clip.load(model, device=self.device)
        self.model, self.preprocess = loaded

    def img_emb(self, img):
        """Return the normalised embedding of a single image (first batch row)."""
        with self.torch.no_grad():
            batch = self.preprocess(img).unsqueeze(0).to(self.device)
            features = self.model.encode_image(batch)
            return self.F.normalize(features, dim=-1)[0, :]

    def imgs_emb(self, imgs, mean=True):
        """Embed several images.

        Args:
            imgs: Either a list of images (each is preprocessed and the
                results are stacked) or an object that is moved to the device
                and passed to the model as-is.
            mean: Return the mean over the normalised embeddings (dim 0)
                instead of the individual embeddings.
        """
        with self.torch.no_grad():
            if isinstance(imgs, list):
                imgs = self.torch.stack(
                    [self.preprocess(single) for single in imgs], dim=0
                )

            features = self.model.encode_image(imgs.to(self.device))
            normalised = self.F.normalize(features, dim=-1)
            if mean:
                return normalised.mean(dim=0, keepdim=False)
            return normalised

    def txt_emb(self, txt):
        """Return the normalised text embedding(s) of ``txt``."""
        with self.torch.no_grad():
            tokens = self.clip.tokenize(txt).to(self.device)
            features = self.model.encode_text(tokens)
            return self.F.normalize(features, dim=-1)


class KtsEmbeddings:
    """Embedding lookup across several label sub-directories of ``embs_dir``.

    Every sub-directory of ``embs_dir`` is treated as one label. An
    id -> label table is built from the metadata of each sub-directory and
    stored as ``ids2lbl.pkl``; the per-label embeddings are stored as
    ``<label>/frms_embs.pkl``.
    """

    def __init__(self, embs_dir: Union[str, Path]):
        """Load (or build and store) the id -> label table of ``embs_dir``."""
        self.embs_dir = Path(embs_dir)
        self.ids2lbl = self.get_ids2lbl(self.embs_dir)

    @staticmethod
    def get_ids(data):
        """Return the sorted frame ids and video ids found in ``data``.

        ``data`` must provide an ``image_path`` column. The columns
        ``frame_id`` (file name without extension) and ``video_id`` (frame id
        without its last four characters) are added to ``data`` in place.

        Raises:
            AssertionError: If frame ids repeat or a frame id does not have
                ``"_"`` as its fourth character from the end.
        """
        data["frame_id"] = data["image_path"].apply(lambda x: Path(x).stem)
        assert len(data) == len(data.frame_id.unique()), "Duplicate frame_ids"

        # Every frame id must look like ``<video_id>_XXX``. Slicing instead of
        # indexing keeps ids shorter than four characters from raising
        # IndexError; they fail the assertion like any other malformed id.
        assert all(
            frame_id[-4:-3] == "_" for frame_id in data.frame_id
        ), "frame_id[-4] != '_'"
        data["video_id"] = data["frame_id"].apply(lambda x: x[:-4])
        video_ids = list(data.video_id.unique())
        frame_ids = data.frame_id.to_list()

        return sorted(set(video_ids + frame_ids))

    @staticmethod
    def get_ids2lbl(embs_dir):
        """Return the id -> label table, building and storing it if missing.

        Raises:
            AssertionError: If the same id occurs under two labels.
        """
        cache_pth = embs_dir / "ids2lbl.pkl"
        if cache_pth.exists():
            return openf(cache_pth)

        ids2lbl = {}
        for emb_dir in embs_dir.iterdir():
            if not emb_dir.is_dir():
                continue
            for known_id in KtsEmbeddings.get_ids(load_metadata(emb_dir)):
                assert known_id not in ids2lbl, f"{known_id} already in ids2lbl"
                ids2lbl[known_id] = emb_dir.name

        writef(ids2lbl, cache_pth)
        return ids2lbl

    def __contains__(self, id):
        return id in self.ids2lbl

    def get_frms_embs(self, lbl):
        """Return the ``FrmEmbeddings`` of label ``lbl``, building it if needed."""
        emb_dir = self.embs_dir / lbl
        frms_embs_pth = emb_dir / "frms_embs.pkl"
        if frms_embs_pth.exists():
            return openf(frms_embs_pth)

        # Pass the id extractors explicitly: FrmEmbeddings declares them as
        # parameters, and these are the same rules get_ids uses to build the
        # id -> label table, so both stay in agreement.
        frms_embs = FrmEmbeddings(
            emb_dir,
            get_frameid=lambda x: Path(x).stem,
            get_videoid=lambda x: x[:-4],
        )
        writef(frms_embs, frms_embs_pth)
        return frms_embs

    def get_emb(self, id, mean: bool = True):
        """Return the embedding of ``id`` from the label it belongs to.

        Raises:
            KeyError: If ``id`` is not in the id -> label table.
        """
        frms_embs = self.get_frms_embs(self.ids2lbl[id])
        return frms_embs.get_emb(id, mean=mean)

    def get_embs(self, ids: List[str], mean: bool = True) -> np.ndarray:
        """Stack the results of :meth:`get_emb` for every id in ``ids``."""
        return np.stack([self.get_emb(id, mean=mean) for id in ids])

    def __getitem__(self, id: Union[str, List[str]]) -> np.ndarray:
        """Look up one id (``str``) or several ids (``list``).

        Raises:
            ValueError: If ``id`` is neither a ``str`` nor a ``list``.
        """
        if isinstance(id, list):
            return self.get_embs(id)
        if isinstance(id, str):
            return self.get_emb(id)
        raise ValueError(f"Invalid id type: {type(id)}")
