# Copyright (c) Meta Platforms, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import os
from collections.abc import Iterator
from typing import Any

import numpy as np

import pandas as pd

from hydra import compose, initialize, initialize_config_dir

from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from torch import nn


def instantiate_optimizer_and_scheduler(
    params: Iterator[nn.Parameter],
    optimizer_config: DictConfig,
    lr_scheduler_config: DictConfig | None,
) -> dict[str, Any]:
    """Build an optimizer (and optionally an LR scheduler) from Hydra configs.

    Args:
        params: Parameters handed positionally to the optimizer's constructor.
        optimizer_config: Hydra config instantiated as ``optimizer(params)``.
        lr_scheduler_config: Optional config whose ``scheduler`` sub-config is
            instantiated with the optimizer. The whole config is then
            instantiated again with ``scheduler`` replaced by that object, and
            converted to a plain container.

    Returns:
        ``{"optimizer": ...}``, plus an ``"lr_scheduler"`` entry when
        ``lr_scheduler_config`` is given. The result is presumably consumed
        as a Lightning ``configure_optimizers`` return value — confirm with
        the caller.
    """
    optimizer = instantiate(optimizer_config, params)
    if lr_scheduler_config is None:
        return {"optimizer": optimizer}

    # Build the concrete scheduler first, then inject it so the remaining
    # scheduler-config fields wrap the live object instead of its config.
    built_scheduler = instantiate(lr_scheduler_config.scheduler, optimizer)
    scheduler_entry = instantiate(lr_scheduler_config, scheduler=built_scheduler)
    return {
        "optimizer": optimizer,
        "lr_scheduler": OmegaConf.to_container(scheduler_entry),
    }


def generate_hydra_config_from_overrides(
    config_path: str = "../config",
    version_base: str | None = None,
    config_name: str = "base",
    overrides: list[str] | None = None,
) -> DictConfig:
    """Compose a Hydra config programmatically, applying command-line style overrides.

    Args:
        config_path: Directory containing the config tree. Absolute paths use
            ``initialize_config_dir``. Relative paths use ``initialize``, which
            Hydra resolves against the calling module's location.
        version_base: Forwarded to the Hydra initializer.
        config_name: Name of the primary config to compose.
        overrides: Hydra override strings (e.g. ``"key=value"``); ``None``
            means no overrides.

    Returns:
        The composed ``DictConfig``.
    """
    override_list = overrides if overrides is not None else []

    # The initializer must be created here, in this function's frame, so that
    # Hydra's caller-relative path detection behaves the same for relative paths.
    if os.path.isabs(config_path):
        hydra_context = initialize_config_dir(
            config_dir=config_path, version_base=version_base
        )
    else:
        hydra_context = initialize(config_path=config_path, version_base=version_base)

    with hydra_context:
        return compose(config_name=config_name, overrides=override_list)


def load_splits(
    metadata_file: str,
    subsample: float = 1.0,
    random_seed: int = 0,
) -> dict[str, list[str]]:
    """Load the per-split filename lists from a metadata CSV.

    The CSV must contain a ``split`` column and a ``filename`` column. Each
    split is sampled independently with
    ``DataFrame.sample(frac=subsample, random_state=random_seed)``, reusing
    the same seed for every split. Note that even the default
    ``subsample=1.0`` shuffles the row order within each split.

    Args:
        metadata_file: Path (or any source accepted by ``pd.read_csv``).
        subsample: Fraction of rows to keep from each split.
        random_seed: Seed for the per-split sampling.

    Returns:
        Mapping from each split value to its list of sampled filenames. Rows
        with a missing ``split`` are dropped, as ``groupby`` does by default.
    """
    df = pd.read_csv(metadata_file)

    splits: dict[str, list[str]] = {}
    # Sample each group directly rather than via ``groupby(...).apply``.
    # pandas >= 2.2 deprecates apply() operating on the grouping column; once
    # that column is excluded, re-grouping the result on "split" would raise
    # KeyError. Sampling per group with the same seed yields the same rows.
    for split, group in df.groupby("split"):
        sampled = group.sample(frac=subsample, random_state=random_seed)
        splits[split] = list(sampled["filename"])

    return splits


def get_contiguous_ones(binary_vector: np.ndarray) -> list[tuple[int, int]]:
    """Get a list of (start_idx, end_idx) for each contiguous block of truthy values.

    Args:
        binary_vector: 1-D array-like. Any non-zero / True entry counts as "on".

    Returns:
        One ``(start_idx, end_idx)`` pair per run, in ascending order, as
        Python ints. Both indices are inclusive, so an isolated ``1`` at
        position ``k`` yields ``(k, k)``. Returns ``[]`` when there are no
        truthy entries, including for empty input.
    """
    # Indices of truthy entries along the first axis (equivalent to np.where(x)[0]).
    ones = np.nonzero(np.asarray(binary_vector))[0]
    # Without this guard, ones[0] / ones[-1] below would raise IndexError.
    if ones.size == 0:
        return []

    # Positions in `ones` where the next index is not consecutive. Each one
    # ends a run, and the following position (+1) starts the next run.
    breaks = np.nonzero(np.diff(ones) != 1)[0]
    starts = ones[np.concatenate(([0], breaks + 1))]
    ends = ones[np.concatenate((breaks, [ones.size - 1]))]
    # Convert from NumPy integers so the result matches the annotation.
    return [(int(s), int(e)) for s, e in zip(starts, ends)]


def get_ik_failures_mask(joint_angles: np.ndarray) -> np.ndarray:
    """Compute a mask that is True where there are no IK failures.

    An entry counts as an IK failure when every value along the last axis
    (the joint axis) is ``np.isclose`` to zero.

    Args:
        joint_angles: Array of shape ``(..., joint)``.

    Returns:
        Boolean array of shape ``(...)``, True where at least one joint angle
        is not close to zero.
    """
    near_zero = np.isclose(joint_angles, np.zeros_like(joint_angles))  # (..., joint)
    # De Morgan: "not all near zero" == "any not near zero".
    return np.any(~near_zero, axis=-1)
