"""Simple, general-purpose utilities."""

import functools
import threading
from collections import defaultdict
from typing import Any

import numpy as np


def parse_long_key(s: str) -> tuple[str, float, float]:
    """Splits a long key of the form {name}_{start}_{end} into its parts.

    The string is split from the right, so the name may itself contain
    underscores.

    Raises:
        ValueError: If the string contains fewer than two underscores, or if
            start or end cannot be converted to a float.
    """
    pieces = s.rsplit("_", maxsplit=2)
    if len(pieces) != 3:
        raise ValueError(f"Invalid string format: {s}")
    return pieces[0], float(pieces[1]), float(pieces[2])


def normalize_slice_name(s: str) -> str:
    """Rewrites a {key}_{start}_{end} string with canonical number formatting.

    Start and end are parsed as floats and written back with exactly two
    digits after the decimal point; the key is left untouched. The string is
    split from the right, so the key may contain underscores.

    Raises:
        ValueError: If the string contains fewer than two underscores, or if
            start or end cannot be converted to a float.
    """
    pieces = s.rsplit("_", maxsplit=2)
    if len(pieces) != 3:
        raise ValueError(f"Invalid string format: {s}")
    key, start_text, end_text = pieces
    start, end = float(start_text), float(end_text)
    return "_".join((key, f"{start:.2f}", f"{end:.2f}"))


def seconds_to_idx(
    start_s: float,
    end_s: float,
    num_vectors: int = 10,
    clip_length_s: int = 15,
) -> tuple[int, int]:
    """Maps a time window inside a clip onto the closest vector indices.

    The clip is treated as num_vectors equally long segments. With the
    defaults (15s clips, 10 vectors) each segment covers 1.5s, so the window
    from second 0 to second 3 maps to (0, 2), i.e. the first two vectors.

    Both boundaries are rounded independently with the built-in round(), which
    resolves exact ties towards the even index. A window much shorter than one
    segment can therefore produce two identical indices.

    Raises:
        ValueError: If clip_length_s or num_vectors is not positive, or if the
            window is empty, reversed, or not contained in [0, clip_length_s].
    """
    if clip_length_s <= 0 or num_vectors <= 0:
        raise ValueError("Clip length and number of vectors must be positive.")

    window_is_invalid = start_s < 0 or end_s > clip_length_s or start_s >= end_s
    if window_is_invalid:
        raise ValueError("Invalid start or end times.")

    seconds_per_vector = clip_length_s / num_vectors
    start_idx, end_idx = (round(t / seconds_per_vector) for t in (start_s, end_s))
    return start_idx, end_idx


def group_stim_names(names: list[str]) -> list[list[str]]:
    """Buckets stim names by prefix and orders each bucket by start time.

    Stim names are of the form {prefix}_{start_time}_{end_time}. Names that
    share a prefix end up in the same list, sorted by ascending start time
    (names with equal start times keep their input order). The lists
    themselves are returned in ascending list order.

    Raises:
        ValueError: If any name cannot be parsed by parse_long_key.
    """
    entries_by_prefix = defaultdict(list)
    for name in names:
        prefix, start, _ = parse_long_key(name)
        entries_by_prefix[prefix].append((start, name))

    groups = []
    for entries in entries_by_prefix.values():
        # Sort on the start second only, so equal starts keep their input order.
        entries.sort(key=lambda entry: entry[0])
        groups.append([name for _, name in entries])

    groups.sort()
    return groups


def normalize_preds(mat: np.ndarray) -> np.ndarray:
    """Normalize rows of matrix to unit length.

    Args:
        mat: Array whose vectors lie along axis 1 (the rows of a 2D matrix).

    Returns:
        A new array of the same shape in which every non-zero row has an L2
        norm of 1. All-zero rows are returned as zeros rather than NaN.
    """
    norms = np.linalg.norm(mat, axis=1, keepdims=True)
    # A zero row divided by its zero norm would become NaN (and emit a
    # RuntimeWarning). Dividing by 1 instead leaves such rows as zeros.
    # `norms` is a fresh array, so editing it in place does not touch `mat`.
    norms[norms == 0] = 1
    return mat / norms


def synchronized(func):
    """Wraps a function with a threading lock.

    We use this mainly to prevent embedding loading functions from being
    called by multiple threads simultaneously.

    Each decorated function gets its own lock, created once at decoration
    time and shared by all callers of the wrapper. The lock is a plain,
    non-reentrant threading.Lock, so a decorated function must not call its
    own wrapper (directly or indirectly) or it will deadlock.

    Args:
        func: The callable whose invocations should be serialized.

    Returns:
        A wrapper that carries the metadata of func (name, docstring, ...),
        holds the lock for the duration of every call, and forwards
        arguments, return value and exceptions unchanged.
    """
    lock = threading.Lock()

    # Without wraps() the decorated function would show up as "wrapper" in
    # tracebacks and lose its docstring.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # The with-statement releases the lock even if func raises.
        with lock:
            return func(*args, **kwargs)

    return wrapper


def rotate_rows(matrix: np.ndarray, n: int) -> np.ndarray:
    """Rotates the rows of a matrix by a given n.

    Rotation is 'upwards', with the top rows being added at the bottom.

    Args:
        matrix: Array to rotate along its first axis.
        n: Number of rows to shift. Values outside [0, num_rows) wrap around,
            so a negative n rotates downwards.

    Returns:
        A new array of the same shape with the rows rotated. The input is
        not modified. An array with zero rows is returned as an unchanged copy.
    """
    num_rows = matrix.shape[0]
    if num_rows == 0:
        # Nothing to rotate; this also avoids a modulo by zero below.
        return matrix.copy()

    n = n % num_rows  # Normalize the n if it exceeds the number of rows.

    return np.concatenate((matrix[n:], matrix[:n]), axis=0)


def str_from_float_dict(d: dict[str, float]) -> str:
    """Formats a dictionary of floats as 'key: value' pairs with 4 decimals.

    The pairs are joined with ', ' in the dictionary's iteration order. The
    resulting string is returned, not printed.
    """
    entries = []
    for key, value in d.items():
        entries.append(f"{key}: {value:.4f}")
    return ", ".join(entries)


def add_key_prefix(prefix: str, d: dict[str, Any]) -> dict[str, Any]:
    """Returns a new dictionary in which every key is preceded by the prefix.

    Values are carried over as-is and the input dictionary is not modified.
    """
    prefixed = {}
    for key, value in d.items():
        prefixed[f"{prefix}{key}"] = value
    return prefixed


def avg_dict_of_lists(d: dict[str, list[float]]) -> dict[str, float]:
    """Reduces every list in the dictionary to its arithmetic mean.

    Means are computed with np.mean and converted to plain Python scalars,
    keyed like the input.
    """
    averages = {}
    for key, values in d.items():
        averages[key] = np.mean(values).item()
    return averages
