"""Preprocessing utilities.

Type transformation, range construction, min-max normalization and rescaling,
noise addition, interpolation, and objective-bound estimation."""

from scipy.interpolate import LinearNDInterpolator, NearestNDInterpolator
import numpy as np
from scipy.optimize import differential_evolution
from torch import Tensor
import torch
from einops import repeat
from utils.types import FloatListOrNested, NestedFloatList, FloatListOrNestedOrTensor
from typing import Tuple, Optional, List

# Default noise scale used by `min_max_normalize_with_noise` and `min_max_scale`;
# 0.0 disables noise because the noise term is multiplied by sigma.
SIGMA = 0.0


def tuple_list_to_nested_list(list_of_tuples: List[tuple]) -> NestedFloatList:
    """Convert every element of `list_of_tuples` into a list.

    Elements that already are lists are passed through unchanged (not copied);
    any other element is converted with `list(...)`.

    Args:
        list_of_tuples: N x [tuple or list]

    Returns:
        A new outer list holding N lists.
    """
    nested = []
    for item in list_of_tuples:
        nested.append(item if isinstance(item, list) else list(item))
    return nested


def make_range_nested_list(
    range_list: FloatListOrNested, num_dim: int
) -> NestedFloatList:
    """Make nested list of ranges for `num_dim` dimensions.

    Args:
        range_list: max_dim x [min, max] (nested) or a single [min, max] pair
            that is shared by all dimensions
        num_dim: number of dimensions to create ranges for

    Returns: num_dim x [min, max]. A single pair is copied into `num_dim`
        independent lists, so mutating one row does not affect the others.
        For nested input the first `num_dim` rows are returned.

    Raises:
        ValueError: if `range_list` is empty, or nested with fewer than
            `num_dim` rows.
    """
    import numbers

    if len(range_list) == 0:
        raise ValueError("Expected a non-empty range list.")

    # numbers.Real also matches NumPy scalars (e.g. np.int64), which are not
    # `int` subclasses and would otherwise be mistaken for nested rows.
    if isinstance(range_list[0], numbers.Real):
        # Single list: one fresh copy per dimension (no shared row objects)
        return [list(range_list) for _ in range(num_dim)]

    # Nested list: first `num_dim`
    if len(range_list) < num_dim:
        raise ValueError(
            f"Expected at least {num_dim} ranges, got {len(range_list)}."
        )
    return range_list[:num_dim]


def make_range_tensor(range_list: FloatListOrNestedOrTensor, num_dim: int) -> Tensor:
    """`range_list` -> Tensor [num_dim, 2] with float32 dtype.

    Lists are expanded/truncated by `make_range_nested_list`. Tensors get the
    same treatment, so both input kinds honour the [num_dim, 2] contract:
    a [2] tensor is repeated for every dimension, and a [max_dim, 2] tensor is
    truncated to its first `num_dim` rows. Tensors with any other rank are
    returned unchanged apart from the float32 conversion.

    Raises:
        ValueError: if a nested range (list or 2-D tensor) has fewer than
            `num_dim` rows.
    """
    if not isinstance(range_list, Tensor):
        range_tensor = torch.tensor(make_range_nested_list(range_list, num_dim))
    elif range_list.ndim == 1:
        # Single [min, max] pair shared by all dimensions (real copy, not a view)
        range_tensor = range_list.unsqueeze(0).repeat(num_dim, 1)
    elif range_list.ndim == 2:
        if range_list.shape[0] < num_dim:
            raise ValueError(
                f"Expected at least {num_dim} ranges, got {range_list.shape[0]}."
            )
        range_tensor = range_list[:num_dim]
    else:
        range_tensor = range_list

    # NOTE Explicitly convert to float32 for safety
    return range_tensor.float()


def add_gaussian_noise(data: Tensor, sigma: float) -> Tensor:
    """Return `data` plus i.i.d. zero-mean Gaussian noise with std `sigma`.

    Args:
        data: tensor of any shape (floating point, as required by `randn_like`)
        sigma: noise standard deviation; 0.0 leaves the values unchanged

    Returns:
        Tensor with the same shape, dtype and device as `data`.
    """
    # randn_like draws N(0, 1). rand_like would draw U[0, 1), which is neither
    # Gaussian nor zero-mean and would shift every sample by sigma / 2 on average.
    return data + torch.randn_like(data) * sigma


def min_max_normalize(data: Tensor, mins: Tensor, maxs: Tensor) -> Tensor:
    """Min-max scale every feature: (data - mins) / (maxs - mins).

    Values inside [mins, maxs] land in [0, 1]. Where maxs == mins the divisor
    is replaced by 1, so such features are only shifted, never divided by zero.

    Args:
        data [B, N, D], mins [B, D] or [D], maxs [B, D] or [D]
        ([D] bounds apply to every batch element)
    Returns:
        data_norm [B, N, D]
    """
    assert data.ndim == 3, "Data must be 3D tensor [B, N, D]."
    assert mins.ndim in [1, 2], "Mins must be shaped [B, D] or [D]."
    assert maxs.ndim in [1, 2], "Maxs must be shaped [B, D] or [D]."

    feat_dim = data.shape[2]
    assert mins.shape[-1] == feat_dim and maxs.shape[-1] == feat_dim

    def _to_broadcast(bound: Tensor) -> Tensor:
        # [D] -> [1, 1, D]; [B, D] -> [B, 1, D]; broadcasting covers B and N.
        return bound[None, None, :] if bound.ndim == 1 else bound[:, None, :]

    lower = _to_broadcast(mins)
    upper = _to_broadcast(maxs)

    span = upper - lower
    span = span.masked_fill(span == 0, 1)  # constant features: divide by 1

    return (data - lower) / span


def min_max_normalize_with_noise(
    data: Tensor, mins: Tensor, maxs: Tensor, sigma: Optional[float] = SIGMA
) -> Tensor:
    """Min-max normalize `data` and add noise of scale `sigma`.

    Args:
        data: shape [B, N, D] or [N, D]
        mins: minimum values, shape [B, D] or [D] ([D] required for [N, D] data)
        maxs: maximum values, same shape rules as `mins`
        sigma: noise level forwarded to `add_gaussian_noise`, default `SIGMA`.
            `None` is accepted (per the `Optional` annotation) and means 0.0.

    Returns:
        data_norm_noised: same shape as `data`, [B, N, D] or [N, D]
    """
    # The signature admits None, but `tensor * None` in the noise step would
    # raise a TypeError; treat it as "no noise".
    if sigma is None:
        sigma = 0.0

    # Lift unbatched input to [1, N, D] for min_max_normalize; undone below.
    reduce_batch_dim = data.ndim == 2
    if reduce_batch_dim:
        assert mins.ndim == 1 and maxs.ndim == 1
        data = data.unsqueeze(0)  # [1, N, D]

    # Normalize then add noise: [B, N, D]
    data_norm = min_max_normalize(data, mins, maxs)
    data_norm_noised = add_gaussian_noise(data_norm, sigma=sigma)

    if reduce_batch_dim:
        data_norm_noised = data_norm_noised.squeeze(0)

    return data_norm_noised


def min_max_scale(
    data: Tensor,
    mins: Tensor,
    maxs: Tensor,
    sigma: float = SIGMA,
    target_bounds: Optional[Tensor] = None,
) -> Tensor:
    """Min-max normalize data (plus noise) and optionally map it onto target bounds.

    Args:
        data: [N, D] or [B, N, D]
        mins: [D] or [B, D] ([D] is required when data is [N, D])
        maxs: same shape rules as `mins`
        sigma: noise scale forwarded to `min_max_normalize_with_noise`
        target_bounds: optional [lo, hi] pairs on the last axis, shaped [2]
            (shared by every feature) or [D, 2] (one row per feature). The
            pairs broadcast against the feature axis of `data`, so per-batch
            bounds shaped [B, 2] are NOT supported.

    Returns:
        data_norm_noised: same shape as `data`; when `target_bounds` is given,
        each normalized value x becomes lo + x * (hi - lo).

    Raises:
        ValueError: if `target_bounds` has no [lo, hi] pairs on its last axis,
            or a 2-D `target_bounds` has neither 1 nor D rows.
    """
    # Normalize and add noise
    data_norm_noised = min_max_normalize_with_noise(data, mins, maxs, sigma)

    if target_bounds is None:
        return data_norm_noised

    target_bounds = target_bounds.to(data)
    if target_bounds.ndim == 0 or target_bounds.shape[-1] != 2:
        raise ValueError(
            "target_bounds must hold [lo, hi] pairs on its last axis, "
            f"got shape {tuple(target_bounds.shape)}."
        )
    if target_bounds.ndim == 1:
        # [2] -> [1, 2]: the same pair applies to every feature
        target_bounds = target_bounds.unsqueeze(0)
    elif target_bounds.ndim == 2 and target_bounds.shape[0] not in (1, data.shape[-1]):
        raise ValueError(
            f"target_bounds must be shaped [2] or [D, 2] with D={data.shape[-1]}, "
            f"got shape {tuple(target_bounds.shape)}."
        )

    # Scale to target bounds: lower/scale_factor broadcast over the feature axis
    lower = target_bounds[..., 0]
    scale_factor = target_bounds[..., 1] - lower
    return lower + data_norm_noised * scale_factor


def scale_between_bounds(
    data: Tensor,
    inp_bounds: FloatListOrNestedOrTensor,
    out_bounds: FloatListOrNestedOrTensor,
    sigma: float = 0.0,
) -> Tensor:
    """Scale data from inp_bounds to out_bounds with optional noise.

    Each feature d is mapped linearly so that inp_bounds[d] -> out_bounds[d].
    Noise (see `min_max_scale`) is injected in the normalized space, before
    the mapping onto `out_bounds`.

    Args:
        data: [..., D] (as accepted by `min_max_scale`: [N, D] or [B, N, D])
        inp_bounds: Input bounds, tensor or list of [D, 2] or [2]
        out_bounds: Output bounds, tensor or list of [D, 2] or [2]
        sigma: Optional noise level, default 0.0

    Returns:
        scaled_data: [..., D]. When the bounds coincide and no noise is
        requested, `data` itself is returned unchanged.
    """
    dim = data.shape[-1]
    device = data.device

    # Prepare bound tensors: [D, 2]
    inp_bounds = make_range_tensor(inp_bounds, num_dim=dim).to(device)
    out_bounds = make_range_tensor(out_bounds, num_dim=dim).to(device)

    # Identity mapping: only skip the work when no noise was requested,
    # otherwise the caller's `sigma` would be silently ignored.
    if not sigma and torch.allclose(inp_bounds, out_bounds):
        return data

    # Compute in floating point so float bounds are not truncated by `.to(data)`.
    if not data.is_floating_point():
        data = data.float()
    inp_bounds = inp_bounds.to(data)
    out_bounds = out_bounds.to(data)

    # Scale from inp_bounds to out_bounds
    return min_max_scale(
        data=data,
        mins=inp_bounds[:, 0],  # [D]
        maxs=inp_bounds[:, 1],  # [D]
        sigma=sigma,
        target_bounds=out_bounds,
    )


def fill_nan():
    """Placeholder for filling NaN values in a tensor; not implemented yet.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()


class OneDInterpolatorExt:
    """Piecewise-linear interpolator for scalar (1-D) sample locations.

    Queries outside the sampled range take the value of the nearest end point
    (constant extrapolation via `left`/`right` of `np.interp`).

    Args:
        points (np.ndarray): sample locations, shape [N, 1]; must be distinct
        values (np.ndarray): samples, shape [N, DY]
    """

    def __init__(self, points: np.ndarray, values: np.ndarray):
        assert points.shape[-1] == 1, "1-dimensional points expected."

        # np.interp needs increasing sample locations: sort points and values
        # together. Duplicated locations remain non-strictly increasing and are
        # rejected by the assertion below.
        flat_points = points.flatten()  # [N, ]
        order = np.argsort(flat_points)
        self.points = flat_points[order]
        self.values = values[order]

        assert np.all(np.diff(self.points) > 0), "Points must be strictly increasing."

        self.points_min = self.points.min().item()
        self.points_max = self.points.max().item()

    def __call__(self, x_new: np.ndarray) -> np.ndarray:
        """Interpolate every output column at the query locations.

        Args:
            x_new: query locations, shape [N,] or [N, 1]
        Returns:
            Interpolated values, shape [N, DY]
        """
        num_outputs = self.values.shape[1]
        columns = [
            np.interp(
                x_new,
                self.points,
                self.values[:, col],
                left=self.values[0, col],
                right=self.values[-1, col],
            ).flatten()
            for col in range(num_outputs)
        ]
        return np.stack(columns, axis=-1)  # [N, DY]


class NDInterpolatorExt:
    """Piecewise-linear interpolator on scattered N-D points with nearest fallback.

    Inside the convex hull of `points` the result comes from
    `LinearNDInterpolator`. Where that yields NaN (outside the hull), the value
    falls back to `NearestNDInterpolator`.
    ref: https://github.com/NYCU-RL-Bandits-Lab/BOFormer/blob/main/Environment/benchmark_functions.py

    Args:
        points (np.ndarray): shape [N, D], D >= 2
        values (np.ndarray): shape [N, ...]
    """

    def __init__(self, points: np.ndarray, values: np.ndarray):
        assert points.shape[-1] > 1, "Points must have at least 2 dimensions."

        self.funcinterp = LinearNDInterpolator(points, values)
        self.funcnearest = NearestNDInterpolator(points, values)

    def __call__(self, *args) -> np.ndarray:
        """Interpolate at inputs.

        Only the entries whose linear estimate is NaN are replaced by the
        nearest-neighbour value. Query points inside the hull keep their linear
        interpolation even when other points in the same call fall outside it.

        Args:
            inputs: DX coordinates or `[N, DX]` array for N points.
        Returns:
            t (np.ndarray): Interpolated values, shape [N, DY]
        """
        linear = np.asarray(self.funcinterp(*args))
        missing = np.isnan(linear)
        if not missing.any():
            return linear

        # Both interpolators return xi.shape[:-1] + values.shape[1:], so the
        # arrays align element-wise.
        nearest = np.asarray(self.funcnearest(*args))
        return np.where(missing, nearest, linear)


class BatchNDInterpolatorExtTensor:
    """Torch front-end for the NumPy interpolators, with an optional batch dim.

    Uses `OneDInterpolatorExt` when the points are 1-D and `NDInterpolatorExt`
    otherwise. Interpolation runs on float64 NumPy copies. The result is then
    converted back to the dtype and device of the query tensor.

    Args:
        points (Tensor): sample locations, shape [N, DX]
        values (Tensor): sample values, shape [N, ...]
    """

    def __init__(self, points: Tensor, values: Tensor):
        points_np = points.double().detach().cpu().numpy()
        values_np = values.double().detach().cpu().numpy()

        # Dimensionality of a query point; used to lift single points to [1, DX].
        self.input_dim = points_np.shape[-1]

        if self.input_dim == 1:
            self.interpolator = OneDInterpolatorExt(points_np, values_np)
        else:
            self.interpolator = NDInterpolatorExt(points_np, values_np)

    def __call__(self, x: Tensor) -> Tensor:
        """Batch call for interpolation.

        Args:
            x (Tensor): Input points, shape [DX], [N, DX] or [B, N, DX]
                (a 1-D interpolator also accepts [N] as N query locations)

        Returns:
            y (Tensor): Interpolated values, shape [N, DY] ([1, DY] for a single
                [DX] point) or [B, N, DY]. Trailing value dims are flattened
                into DY, as in the batched path.

        Raises:
            ValueError: if `x` has more than 3 dimensions.
        """
        x_np = x.double().detach().cpu().numpy()

        if x_np.ndim <= 2:
            # No batch dimension: bring the queries to [M, DX]
            if self.input_dim == 1:
                x_2d = x_np.reshape(-1, 1)  # every entry is one query location
            else:
                x_2d = np.atleast_2d(x_np)  # [DX] -> [1, DX]
            y_np = np.asarray(self.interpolator(x_2d))
            if y_np.ndim != 2:
                # e.g. 1-D `values` give [M]; keep the documented [M, DY] layout
                y_np = y_np.reshape(x_2d.shape[0], -1)
        elif x_np.ndim == 3:
            # Flatten batch dim, interpolate, then reshape
            B, N, DX = x_np.shape
            y_flat = self.interpolator(x_np.reshape(-1, DX))
            y_np = y_flat.reshape(B, N, -1)
        else:
            raise ValueError(f"Invalid input shape {x_np.shape}")

        return torch.from_numpy(y_np).to(x)


def _scipy_adapted_function(function, x_np: np.ndarray) -> np.ndarray:
    """Adapt torch-based function for np input / output for use with SciPy optimization."""
    x_tensor = torch.tensor(x_np, dtype=torch.float64)
    result = function(x_tensor)
    return result.detach().cpu().numpy()


def _minimize_with_differential_evolution(
    function, bounds: NestedFloatList, **de_kwargs
) -> Tuple[float, List[float]]:
    """Minimize a bounded, torch-based function with differential evolution.

    Args:
        function: Torch-based objective. It receives a float64 tensor built
            from SciPy's candidate vector. SciPy needs the (NumPy-converted)
            output to reduce to a single number.
        bounds: Input bounds, DX x [x_min, x_max]
        **de_kwargs: Extra keyword arguments forwarded verbatim to
            `scipy.optimize.differential_evolution`, e.g. `seed` for
            reproducible runs, `maxiter`, `tol`, `polish`.

    Returns:
        min_value: `result.fun` reported by SciPy
        min_point: `result.x` (a NumPy array of length DX)
    """

    def objective(x_np: np.ndarray) -> np.ndarray:
        # SciPy works on NumPy arrays; adapt to and from the torch function.
        return _scipy_adapted_function(function, x_np)

    # Find global minimum
    result = differential_evolution(objective, bounds=bounds, **de_kwargs)

    return result.fun, result.x


def optimize_with_differential_evolution(
    func, bounds: NestedFloatList, minimize: bool
) -> Tuple[float, List[float]]:
    """Find the minimum or maximum of a bounded, torch-based function.

    Maximization minimizes `-func` instead, then flips the sign of the optimum
    value back. The optimum location is returned as found.

    Args:
        func: Torch-based function
        bounds: Input bounds for the function, DX x [[x_min, x_max]]
        minimize: True to minimize `func`, False to maximize it

    Returns:
        min_value / max_value: Optimum value
        min_point / max_point: Optimum location
    """
    if not minimize:

        def negated_func(x):
            return -func(x)

        negated_best, best_point = _minimize_with_differential_evolution(
            negated_func, bounds
        )
        return -negated_best, best_point

    return _minimize_with_differential_evolution(func, bounds)


def estimate_objective_bounds(
    func: callable, num_objectives: int, x_bounds: NestedFloatList
) -> NestedFloatList:
    """Estimate output bounds for each objective function over the input bounds.

    For every objective, differential evolution runs twice: once minimizing,
    once maximizing. No seed is passed, so the estimates can vary between runs
    and are not guaranteed to be the true extrema.

    Args:
        func: x -> objectives; its output is reshaped to [-1, num_objectives],
            so it must contain a multiple of `num_objectives` elements
        num_objectives: Number of objectives in the function
        x_bounds: Input bounds for the function, DX x [[x_min, x_max]]

    Returns:
        y_bounds: DY x [[y_min, y_max]], one row per objective
    """
    y_bounds = []

    for i in range(num_objectives):
        # Extract i-th objective as a [M, 1] column. `reshape` (unlike `view`)
        # also works when `func` returns a non-contiguous tensor. `i=i` binds
        # the current index instead of the loop variable's final value.
        def func_i(x, i=i):
            return func(x).reshape(-1, num_objectives)[..., [i]]

        # Find min and max
        min_i, _ = optimize_with_differential_evolution(
            func=func_i, bounds=x_bounds, minimize=True
        )
        max_i, _ = optimize_with_differential_evolution(
            func=func_i, bounds=x_bounds, minimize=False
        )

        y_bounds.append([min_i, max_i])

    return y_bounds

