"""Small shared linear-algebra helpers for the example SDPs."""

from __future__ import annotations

import numpy as np


def unit(dim: int, index: int) -> np.ndarray:
    """Return the standard basis vector of length ``dim`` with a 1.0 at ``index``.

    ``index`` follows NumPy indexing rules, so negative values count from the
    end and out-of-range values raise ``IndexError``.
    """
    basis = np.zeros(dim)
    basis[index] = 1.0
    return basis


def inner_matrix(left: np.ndarray, right: np.ndarray) -> np.ndarray:
    """Return the symmetric part of ``outer(left, right)``.

    For any symmetric matrix ``G`` the result ``M`` satisfies
    ``trace(M @ G) == left @ G @ right``, i.e. ``M`` represents the bilinear
    term ``<left, right>`` as a symmetric matrix.
    """
    # outer(right, left) is exactly the transpose of outer(left, right).
    product = np.outer(left, right)
    return 0.5 * (product + product.T)


def square_matrix(vector: np.ndarray) -> np.ndarray:
    """Return the rank-one matrix ``vector vector^T``.

    The input is flattened first, so a ``(1, d)`` or ``(d, 1)`` input yields a
    ``(d, d)`` result.
    """
    flat = np.asarray(vector).ravel()
    # Column times row via broadcasting: entry (r, c) is flat[r] * flat[c].
    return flat[:, np.newaxis] * flat[np.newaxis, :]


def smooth_strongly_convex_interpolation_matrix(
    xi: np.ndarray,
    gi: np.ndarray,
    xj: np.ndarray,
    gj: np.ndarray,
    L: float,
    mu: float = 0.0,
) -> np.ndarray:
    """Quadratic part of the smooth strongly-convex interpolation inequality.

    Builds the symmetric matrix ``M`` whose bilinear form collects the terms

        <g_j, x_i - x_j> + 1/(2L) ||g_i - g_j||^2
            + mu / (2 (1 - mu/L)) ||x_i - x_j - (g_i - g_j)/L||^2

    of the interpolation inequality between points ``i`` and ``j``.  The
    function-value terms are not part of the result.  The inputs are
    presumably coefficient vectors of iterates/gradients in the basis of the
    SDP's Gram matrix (NOTE(review): confirm against callers).

    Args:
        xi, gi: Point and gradient vectors for index ``i``.
        xj, gj: Point and gradient vectors for index ``j``.
        L: Smoothness constant; must be positive.  ``np.inf`` is accepted,
            in which case every ``1/L`` term vanishes.
        mu: Strong-convexity constant; must be strictly smaller than ``L``.

    Returns:
        A symmetric matrix; ``(d, d)`` for 1-D inputs of length ``d``.

    Raises:
        ValueError: If ``L`` is not positive or ``mu >= L``.
    """
    # Reject parameters for which the formula divides by zero (L == 0 or
    # mu == L) or silently yields negative, meaningless weights (L < 0, mu > L).
    if not L > 0:
        raise ValueError(f"L must be positive, got L={L!r}")
    if mu >= L:
        raise ValueError(
            f"mu must be strictly smaller than L, got mu={mu!r}, L={L!r}"
        )

    displacement = xi - xj
    grad_delta = gi - gj
    matrix = inner_matrix(gj, displacement) + (0.5 / L) * square_matrix(grad_delta)
    if mu:
        # The strong-convexity correction vanishes for mu == 0, so skip it.
        shifted = displacement - grad_delta / L
        weight = mu / (2.0 * (1.0 - mu / L))
        matrix = matrix + weight * square_matrix(shifted)
    return matrix


def symmetrize(matrix: np.ndarray) -> np.ndarray:
    """Return the symmetric part ``(A + A^T) / 2`` of ``matrix`` as floats.

    Any array-like input is accepted and converted to a float array first.
    """
    arr = np.asarray(matrix, dtype=float)
    return (arr + arr.T) / 2.0
