# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Base classes for the datasets that also provide non-audio metadata,
e.g. description, text transcription etc.
"""
import logging
import math
import re
import typing as tp
from dataclasses import dataclass

import torch

from ..environment import AudioCraftEnvironment
from ..modules.conditioners import (ConditioningAttributes,
                                    SegmentWithAttributes)
from .audio_dataset import AudioDataset, AudioMeta

logger = logging.getLogger(__name__)


def _clusterify_meta(meta: AudioMeta) -> AudioMeta:
    """Rewrite the paths stored in a single meta entry through the dataset mappers.

    Both `meta.path` and, when an info path is attached, `meta.info_path.zip_path`
    are passed through `AudioCraftEnvironment.apply_dataset_mappers`
    (presumably remapping dataset locations for the current cluster -- confirm
    against `AudioCraftEnvironment`).

    Args:
        meta (AudioMeta): Entry to update. It is modified in place.
    Returns:
        AudioMeta: The same object that was passed in.
    """
    remap = AudioCraftEnvironment.apply_dataset_mappers
    meta.path = remap(meta.path)
    info_path = meta.info_path
    if info_path is not None:
        info_path.zip_path = remap(info_path.zip_path)
    return meta


def clusterify_all_meta(meta: tp.List[AudioMeta]) -> tp.List[AudioMeta]:
    """Apply `_clusterify_meta` to every entry of a meta list.

    Args:
        meta (list of AudioMeta): Entries to update; each one is modified in place.
    Returns:
        list of AudioMeta: A new list holding the same (updated) entries, in order.
    """
    return list(map(_clusterify_meta, meta))


@dataclass
class AudioInfo(SegmentWithAttributes):
    """Dummy SegmentInfo with empty attributes.

    The InfoAudioDataset is expected to return metadata that inherits
    from SegmentWithAttributes class and can return conditioning attributes.

    This basically guarantees all datasets will be compatible with current
    solver that contain conditioners requiring this.

    On top of the fields inherited from SegmentWithAttributes, this class only
    adds the optional `audio_tokens` tensor, which defaults to None.
    """

    # Optional tensor of audio tokens; left as None unless a caller fills it in.
    audio_tokens: tp.Optional[torch.Tensor] = None  # populated when using cached batch for training a LM.

    def to_condition_attributes(self) -> ConditioningAttributes:
        """Return a default-constructed ConditioningAttributes.

        No field of this instance is read: the result is always an empty
        `ConditioningAttributes()`.
        """
        return ConditioningAttributes()


class InfoAudioDataset(AudioDataset):
    """AudioDataset whose metadata, when requested, is wrapped as an `AudioInfo`.

    The meta list is passed through `clusterify_all_meta` before being handed
    to the parent constructor.

    See `audiocraft.data.audio_dataset.AudioDataset` for initialization arguments.
    """

    def __init__(self, meta: tp.List[AudioMeta], **kwargs):
        """Remap the paths of `meta`, then defer to `AudioDataset.__init__`.

        Args:
            meta (list of AudioMeta): Dataset entries; updated in place by
                `clusterify_all_meta`.
            **kwargs: Forwarded untouched to the parent constructor.
        """
        remapped_meta = clusterify_all_meta(meta)
        super().__init__(remapped_meta, **kwargs)

    def __getitem__(self, index: int) -> tp.Union[torch.Tensor, tp.Tuple[torch.Tensor, SegmentWithAttributes]]:
        """Fetch one item from the parent dataset.

        Args:
            index (int): Index forwarded to the parent `__getitem__`.
        Returns:
            The waveform alone when `self.return_info` is falsy, otherwise a
            `(waveform, AudioInfo)` pair where the `AudioInfo` is rebuilt from
            the `to_dict()` of the info object returned by the parent.
        """
        if self.return_info:
            waveform, segment_info = super().__getitem__(index)
            info_fields = segment_info.to_dict()
            return waveform, AudioInfo(**info_fields)
        audio_only = super().__getitem__(index)
        # Without info the parent is expected to hand back a bare tensor.
        assert isinstance(audio_only, torch.Tensor)
        return audio_only


def get_keyword_or_keyword_list(
    value: tp.Optional[str],
) -> tp.Union[tp.Optional[str], tp.Optional[tp.List[str]]]:
    """Preprocess a value that is either a single keyword or a list of keywords.

    Args:
        value: A list is dispatched to `get_keyword_list`; anything else is
            dispatched to `get_keyword`.
    Returns:
        Whatever the selected preprocessing function returns.
    """
    if not isinstance(value, list):
        return get_keyword(value)
    return get_keyword_list(value)


def get_string(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a single free-form string.

    Surrounding whitespace is removed; the case of the text is left untouched.

    Args:
        value (str, optional): Raw value. Non-string values are treated as missing.
    Returns:
        str or None: The stripped string, or None when the value is missing,
            is not a string, is empty or whitespace-only, or is the literal "None".
    """
    if not isinstance(value, str):
        return None
    # Strip before testing for emptiness so that whitespace-only input and a
    # padded "None" are reported as missing instead of leaking through as "".
    stripped = value.strip()
    # NOTE(review): "None" is presumably a stringified Python None coming from
    # the metadata files -- confirm with the dataset producers.
    if not stripped or stripped == "None":
        return None
    return stripped


def get_keyword(value: tp.Optional[str]) -> tp.Optional[str]:
    """Preprocess a single keyword.

    Surrounding whitespace is removed and the keyword is lower-cased.

    Args:
        value (str, optional): Raw keyword. Non-string values are treated as missing.
    Returns:
        str or None: The stripped, lower-cased keyword, or None when the value
            is missing, is not a string, is empty or whitespace-only, or is the
            literal "None".
    """
    if not isinstance(value, str):
        return None
    # Strip before testing for emptiness so that whitespace-only input and a
    # padded "None" are reported as missing instead of leaking through as "".
    stripped = value.strip()
    # The "None" comparison is deliberately case-sensitive and done before
    # lower-casing, as in the original check.
    if not stripped or stripped == "None":
        return None
    return stripped.lower()


def get_keyword_list(values: tp.Union[str, tp.List[str]]) -> tp.Optional[tp.List[str]]:
    """Preprocess a list of keywords.

    Args:
        values: Either a list of raw keywords, or a single string that is split
            on commas and whitespace. A float NaN is treated as an empty list.
            Any other non-list value is logged at debug level and handled as a
            one-element list containing its string form.
    Returns:
        list of str or None: The keywords for which `get_keyword` returned a
            value, in their original order, or None when none is left.
    """
    candidates: tp.Any = values
    if isinstance(values, str):
        candidates = [piece.strip() for piece in re.split(r"[,\s]", values)]
    elif isinstance(values, float) and math.isnan(values):
        # NaN is mapped to "no keywords" rather than to an unexpected value.
        candidates = []
    if not isinstance(candidates, list):
        logger.debug(f"Unexpected keyword list {values}")
        candidates = [str(values)]

    cleaned: tp.List[str] = []
    for raw_keyword in candidates:
        keyword = get_keyword(raw_keyword)
        if keyword is not None:
            cleaned.append(keyword)
    # An empty result is reported as None, not as an empty list.
    return cleaned or None
