"""Loads embeddings for the GTZAN and FMA datasets."""

import csv
from collections import defaultdict
import functools
import os

import numpy as np

from fmri2music import utils


# Module-level shorthand for os.path.join; the loaders below use it to build file paths.
join = os.path.join


def get_data_dir() -> str:
    """Resolves the directory from which embeddings and predictions are loaded.

    The location is taken from the DATA_DIR env var and falls back to "./data"
    if the variable is not set. The returned path is prepended to the relative
    paths used by the loaders in this module.

    Raises a ValueError if the resolved value is empty or not an existing directory.
    """
    fallback = join(".", "data")
    resolved = os.environ.get("DATA_DIR", fallback)
    if resolved and os.path.isdir(resolved):
        return resolved
    raise ValueError(
        f"Provided DATA_DIR does not exist: '{resolved}'. Set it in the .env file."
    )


@utils.synchronized
@functools.lru_cache(maxsize=16)
def get_gtzan_emb(emb_name: str) -> dict[str, np.ndarray]:
    """Loads GTZAN clip-level embeddings.

    Reads "gtzan-emb-{emb_name}.npz" from the "music-emb" folder of the data dir.
    The archive has to provide a "keys" and a "vecs" array of equal length.

    Returns a dict mapping from normalized slice name to embedding vector.
    Results are cached per emb_name, so callers share the returned dict.

    Raises a ValueError if the number of keys does not match the number of vectors.
    """
    file_name = f"gtzan-emb-{emb_name}.npz"
    with open(join(get_data_dir(), "music-emb", file_name), "rb") as fh:
        content = np.load(fh)
        # An .npz archive is loaded lazily: every content[...] access reads the
        # array from the file again. Read each array exactly once and reuse it.
        keys = content["keys"]
        vecs = content["vecs"]

    if len(keys) != len(vecs):
        raise ValueError(
            f"Found an invalid embedding file ({file_name}). "
            f"Number of keys #{len(keys)} does not match "
            f"number of vectors shape {vecs.shape}."
        )

    result = {utils.normalize_slice_name(k): v for k, v in zip(keys, vecs)}
    print(f"Loaded GTZAN embeddings from {file_name}. Shape: {vecs.shape}")
    return result


@utils.synchronized
@functools.lru_cache(maxsize=16)
def get_grouped_gtzan_embs(
    emb_name: str,
) -> tuple[dict[str, np.ndarray], dict[str, list[str]]]:
    """Loads GTZAN embeddings and collects them per GTZAN clip.

    Returns two dicts keyed by clip name:
     - Clip name -> np.ndarray holding the clip's slice embeddings stacked along
       the first axis, in the order produced by utils.group_stim_names
       (expected to be the temporal order of the slices).
     - Clip name -> slice names, in the same order as the stacked embeddings.
    """
    slice_to_emb = get_gtzan_emb(emb_name)

    vecs_by_clip: dict[str, list[np.ndarray]] = {}
    names_by_clip: dict[str, list[str]] = {}
    for slice_group in utils.group_stim_names(list(slice_to_emb)):
        for slice_name in slice_group:
            # The first component of the parsed key is the clip the slice belongs to.
            clip_name = utils.parse_long_key(slice_name)[0]
            vecs_by_clip.setdefault(clip_name, []).append(slice_to_emb[slice_name])
            names_by_clip.setdefault(clip_name, []).append(slice_name)

    stacked_by_clip = {clip: np.stack(vecs) for clip, vecs in vecs_by_clip.items()}
    return stacked_by_clip, names_by_clip


def get_grouped_gtzan_embs_for_clip_name(
    emb_name: str, clip_name: str
) -> tuple[np.ndarray, list[str]]:
    """Looks up the stacked GTZAN embeddings and the slice names of one clip.

    Both values are taken from the grouped result of get_grouped_gtzan_embs.
    """
    embs_by_clip, names_by_clip = get_grouped_gtzan_embs(emb_name)
    clip_embs = embs_by_clip[clip_name]
    clip_slice_names = names_by_clip[clip_name]
    return clip_embs, clip_slice_names


@utils.synchronized
@functools.lru_cache(maxsize=8)
def get_fma_emb(fma_size: str, emb_name: str) -> tuple[list[str], np.ndarray]:
    """Loads the embeddings of the FMA dataset.

    Reads "fma-{fma_size}-emb-{emb_name}.npz" from the "music-emb" folder of the data dir.

    Returns the normalized slice names and the array stored under "vecs".
    """
    emb_file = f"fma-{fma_size}-emb-{emb_name}.npz"
    emb_path = join(get_data_dir(), "music-emb", emb_file)
    with open(emb_path, "rb") as fh:
        archive = np.load(fh)
        slice_names = [utils.normalize_slice_name(key) for key in archive["keys"]]
        vectors = archive["vecs"]
    return slice_names, vectors


@utils.synchronized
@functools.lru_cache(maxsize=8)
def get_grouped_fma_embs(
    fma_size: str, emb_name: str, min_num_slices: int | None = None
) -> tuple[dict[str, np.ndarray], dict[str, list[str]]]:
    """Loads FMA embeddings and groups them by clip name.

    Returns two dicts keyed by clip name:
     - Clip name -> np.ndarray with the clip's slice embeddings stacked along
       the first axis.
     - Clip name -> slice names, in the same order as the stacked embeddings.

    The provided min_num_slices serves as a filter,
    only clips with at least min_num_slices slices are returned.
    If min_num_slices is None, no clip is filtered out.

    Results are cached, so callers share the returned dicts.
    """
    fma_embs: dict[str, np.ndarray] = dict(zip(*get_fma_emb(fma_size, emb_name)))

    key_groups = utils.group_stim_names(fma_embs.keys())
    grouped_vecs = defaultdict(list)
    grouped_names = defaultdict(list)
    for group in key_groups:
        for long_key in group:
            clip_name = utils.parse_long_key(long_key)[0]
            grouped_vecs[clip_name].append(fma_embs[long_key])
            grouped_names[clip_name].append(long_key)

    # Convert to plain dicts before returning on ANY path. Handing out the
    # defaultdicts would let a lookup of an unknown clip silently insert an
    # empty entry into the cached result instead of raising a KeyError.
    result = {clip_name: np.stack(vecs) for clip_name, vecs in grouped_vecs.items()}
    long_keys = dict(grouped_names)

    if min_num_slices is None:
        return result, long_keys

    # Collect first, delete afterwards: a dict must not change size while iterated.
    keys_to_delete = [k for k, v in result.items() if len(v) < min_num_slices]
    for k in keys_to_delete:
        del result[k]
        del long_keys[k]

    print(
        f"Deleted #{len(keys_to_delete)} clips with less than {min_num_slices} slices from {fma_size=}, {emb_name=}."
    )

    return result, long_keys


@utils.synchronized
@functools.lru_cache(maxsize=16)
def get_grouped_fma_matrix(
    fma_size: str,
    emb_name: str,
    min_num_slices: int,
    intersect_w_emb_name: str | None,
    min_num_slices_intersection: int | None,
) -> tuple[np.ndarray, list[list[str]]]:
    """Loads FMA embeddings grouped by clip name and packs them into one matrix.

    Every kept clip is truncated to its first min_num_slices slices. If
    intersect_w_emb_name is given, only clips that are also available for that
    embedding (filtered with min_num_slices_intersection) are kept.

    Embeddings are divided by their L2 norm along the last axis.

    Shapes:
     - Embeddings: (num_clips, min_num_slices, emb_dim)
     - Slice names: (num_clips, min_num_slices)
    """
    # Clip names that must also exist for the other embedding; None disables the check.
    allowed_clips = None
    if intersect_w_emb_name is not None:
        other_embs, _ = get_grouped_fma_embs(
            fma_size, intersect_w_emb_name, min_num_slices=min_num_slices_intersection
        )
        allowed_clips = set(other_embs.keys())

    emb_dict, name_dict = get_grouped_fma_embs(
        fma_size, emb_name, min_num_slices=min_num_slices
    )

    kept_embs, kept_names = [], []
    num_skipped = 0
    for clip_name, slice_names in name_dict.items():
        too_short = len(slice_names) < min_num_slices
        not_shared = allowed_clips is not None and clip_name not in allowed_clips
        if too_short or not_shared:
            num_skipped += 1
            continue
        kept_embs.append(emb_dict[clip_name][:min_num_slices])
        kept_names.append(slice_names[:min_num_slices])

    matrix = np.stack(kept_embs)
    norms = np.linalg.norm(matrix, axis=-1, keepdims=True)
    matrix = matrix / norms

    if num_skipped > 0:
        print(f"Skipped {num_skipped} clips which were too short.")

    return matrix, kept_names


def load_pred_emb_name(file_name: str) -> str:
    """Reads the name of the predicted embedding from an exported .npz file.

    The file is looked up in the "pred" folder of the data dir and the scalar
    stored under "emb_name" is returned.
    """
    pred_path = join(get_data_dir(), "pred", file_name)
    with open(pred_path, "rb") as fh:
        emb_name_entry = np.load(fh)["emb_name"]
        return emb_name_entry.item()


@utils.synchronized
@functools.lru_cache(maxsize=128)
def load_predictions(file_name: str) -> dict[str, np.ndarray]:
    """Loads a model's predictions for GTZAN from the exported .npz file.

    The file is looked up in the "pred" folder of the data dir. It has to provide
    the arrays "gtzan_slice_names" and "gtzan_preds" with equal length.

    Returns a dict mapping from slice name to predicted vector.
    Results are cached per file_name, so callers share the returned dict.

    Raises a ValueError if the number of slice names and predictions differ.
    """
    file_path = join(get_data_dir(), "pred", file_name)
    with open(file_path, "rb") as fh:
        content = np.load(fh)
        pred_gtzan_slice_names = content["gtzan_slice_names"]
        pred_vecs = content["gtzan_preds"]

    # Explicit check instead of an assert: asserts are removed under `python -O`,
    # which would let zip() silently truncate a malformed file.
    if len(pred_gtzan_slice_names) != len(pred_vecs):
        raise ValueError(
            f"Found an invalid prediction file ({file_name}). "
            f"Number of slice names #{len(pred_gtzan_slice_names)} does not match "
            f"number of predictions shape {pred_vecs.shape}."
        )
    return dict(zip(pred_gtzan_slice_names, pred_vecs))


@utils.synchronized
@functools.lru_cache(maxsize=1)
def get_audioset_class_map() -> dict[str, int]:
    """Loads the AudioSet class map from "music-emb/audioset-class-map.csv".

    The class map maps from class name (column "display_name") to the index
    in the probs vector (column "index").
    """
    csv_path = join(get_data_dir(), "music-emb", "audioset-class-map.csv")
    with open(csv_path, "r", encoding="utf-8") as f:
        class_map = {
            row["display_name"]: int(row["index"]) for row in csv.DictReader(f)
        }
    print(f"Done loading the AudioSet class map with {len(class_map)} classes.")
    return class_map


@utils.synchronized
@functools.lru_cache(maxsize=2)
def get_audioset_probs_for_fma(fma_size: str) -> dict[str, np.ndarray]:
    """Loads the AudioSet probabilities for the FMA dataset.

    Reads "fma-{fma_size}-audioset-15s.npz" from the "music-emb" folder of the data dir.

    Returns a dict mapping from FMA clip name to probs vector.
    """
    npz_name = f"fma-{fma_size}-audioset-15s.npz"
    npz_path = join(get_data_dir(), "music-emb", npz_name)
    with open(npz_path, "rb") as fh:
        archive = np.load(fh)
        clip_names = archive["keys"]
        prob_vecs = archive["vecs"]
        probs_by_clip = {name: probs for name, probs in zip(clip_names, prob_vecs)}
    print(f"Done loading {len(probs_by_clip)} AudioSet probs for FMA-{fma_size}.")
    return probs_by_clip


@utils.synchronized
@functools.lru_cache(maxsize=1)
def get_audioset_probs_for_gtzan() -> dict[str, np.ndarray]:
    """Loads the AudioSet probabilities for the GTZAN dataset.

    Reads "gtzan-audioset.npz" from the "music-emb" folder of the data dir.

    Returns a dict mapping from GTZAN clip name to probs vector.
    """
    npz_path = join(get_data_dir(), "music-emb", "gtzan-audioset.npz")
    with open(npz_path, "rb") as fh:
        archive = np.load(fh)
        name_prob_pairs = zip(archive["keys"], archive["vecs"])
        probs_by_clip = {name: probs for name, probs in name_prob_pairs}
    print(f"Done loading {len(probs_by_clip)} AudioSet probs for GTZAN.")
    return probs_by_clip


def get_audioset_probs_from_file(file_path: str) -> dict[str, np.ndarray]:
    """Loads AudioSet probabilities from an .npz file in the "pred" folder.

    file_path is resolved relative to the "pred" folder of the data dir.

    Returns a dict pairing each entry of "keys" with its entry of "vecs".
    """
    full_path = join(get_data_dir(), "pred", file_path)
    with open(full_path, "rb") as fh:
        archive = np.load(fh)
        names = archive["keys"]
        prob_vecs = archive["vecs"]
        probs_by_name = {name: probs for name, probs in zip(names, prob_vecs)}
    print(f"Done loading {len(probs_by_name)} AudioSet probs.")
    return probs_by_name


def get_musiclm_pred_emb(path: str) -> dict[str, np.ndarray]:
    """Loads embeddings of MusicLM predictions.

    path is resolved relative to the data dir.

    Returns a dict mapping from slice name to embedding vector.
    """
    emb_path = join(get_data_dir(), path)
    with open(emb_path, "rb") as fh:
        archive = np.load(fh)
        slice_names = archive["keys"]
        emb_vecs = archive["vecs"]
        return {name: vec for name, vec in zip(slice_names, emb_vecs)}
