import typing

import numpy as np
import numpy.typing as npt

# Shorthand type alias for numpy.typing.NDArray (an ndarray of unspecified dtype).
Matrix = npt.NDArray


def compute_delta_and_L_and_mu(matrices, center_matrix, dim):
    """Measure spread and extreme singular values of a family of matrices.

    Args:
        matrices: Non-empty sequence of 2-D arrays A_m.
        center_matrix: Reference matrix C that the spread is measured against
            (e.g. the output of ``compute_center_matrix``).
        dim: Matrix dimension. Unused; kept for backward compatibility.

    Returns:
        ``(delta, L, mu)`` where
        ``delta = max_m ||A_m - C||_2``,
        ``L = max_m ||A_m||_2`` (largest singular value over the family), and
        ``mu = min_m sigma_min(A_m)`` (smallest singular value over the family;
        this equals the smallest eigenvalue when every A_m is symmetric
        positive definite).

    Raises:
        ValueError: If ``matrices`` is empty.
    """
    # len() rather than truthiness so a stacked 3-D ndarray is also accepted.
    if len(matrices) == 0:
        raise ValueError("matrices must be a non-empty sequence")

    L = np.max([np.linalg.norm(mtx, ord=2) for mtx in matrices])
    mu = np.min([np.linalg.norm(mtx, ord=-2) for mtx in matrices])
    # The spectral norm is symmetric (||A - C|| == ||C - A||), so one norm per
    # matrix is enough to get the maximum deviation from the center.
    delta = np.max(
        [np.linalg.norm(mtx - center_matrix, ord=2) for mtx in matrices]
    )
    return delta, L, mu


def compute_center_matrix(matrices, dim):
    """Return the arithmetic mean of ``matrices`` as a ``(dim, dim)`` array.

    Each ``A_m / len(matrices)`` is added in turn into a float64 zero matrix,
    so an empty ``matrices`` yields ``np.zeros((dim, dim))``.
    """
    count = len(matrices)
    center = np.zeros((dim, dim))
    for mtx in matrices:
        center += mtx / count
    return center


def compute_scaffold_conditioning_mtx(matrices, center_matrix, dim):
    """Return the average of ``A_m^{-1} @ center_matrix`` over ``matrices``.

    Args:
        matrices: Sequence of square, invertible ``(dim, dim)`` arrays A_m.
        center_matrix: ``(dim, dim)`` array C multiplied on the right of each
            inverse.
        dim: Matrix dimension; sets the shape of the result.

    Returns:
        ``(1/M) * sum_m A_m^{-1} C`` as a float64 ``(dim, dim)`` array, or
        ``np.zeros((dim, dim))`` when ``matrices`` is empty.

    Raises:
        numpy.linalg.LinAlgError: If any A_m is singular.
    """
    n_matrices = len(matrices)
    cond_mtx = np.zeros((dim, dim))
    for mtx in matrices:
        # solve(A, C) computes A^{-1} C from an LU factorisation, which is
        # cheaper and more accurate than forming inv(A) and multiplying.
        cond_mtx += np.linalg.solve(mtx, center_matrix) / n_matrices
    return cond_mtx


def make_symmetric(x):
    """Return the symmetric part of square array ``x``, i.e. ``(x + x^T) / 2``."""
    summed = x.T + x
    return summed / 2


def generate_delta_related_matrices(
    num_matrices: int,
    dimension: int,
    mu: float,
    L: float,
    delta: float,
    *,
    rng: typing.Optional[np.random.Generator] = None,
    verbose: bool = True,
) -> typing.Tuple[np.ndarray, typing.List[np.ndarray], float, float, float]:
    """Generate a random family of symmetric matrices around a shared base.

    Each matrix is ``A_i = B + N_i + s_i * I``:
      * ``B`` is a random symmetric matrix (uniform [0, 1) entries,
        symmetrised) rescaled to spectral norm ``L``;
      * ``N_i`` is independent random symmetric noise of the same kind,
        scaled by ``delta / sqrt(dimension)``;
      * ``s_i = |lambda_min(B + N_i)| + mu``, so ``lambda_min(A_i) >= mu``.

    The input ``mu``, ``L`` and ``delta`` are construction targets. The
    returned constants are the ones *measured* on the generated family (via
    ``compute_delta_and_L_and_mu``) and will generally differ from them.
    NOTE(review): the noise entries are non-negative, so ``||N_i||_2`` can be
    well above ``delta``; rely on the returned (measured) delta.

    Args:
        num_matrices: Number of matrices to generate (must be >= 1).
        dimension: Side length ``d`` of each square matrix.
        mu: Target lower bound on the smallest eigenvalue of every A_i.
        L: Spectral norm given to the shared base matrix B.
        delta: Noise scale.
        rng: Optional NumPy ``Generator``. When omitted, the legacy global
            ``np.random`` state is used, exactly as before.
        verbose: When True (default), print the measured constants.

    Returns:
        ``(center_mtx, matrices, L, delta, mu)``: the mean matrix, the list of
        generated matrices, and the measured L, delta and mu.

    Raises:
        ValueError: If ``num_matrices`` is 0 (no matrices to measure).
    """
    d = dimension

    def _random_symmetric():
        # Uniform [0, 1) entries, then symmetrised.
        sample = np.random.rand(d, d) if rng is None else rng.random((d, d))
        return make_symmetric(sample)

    base_mtx = _random_symmetric()
    base_mtx = base_mtx * L / np.linalg.norm(base_mtx, ord=2)

    matrices = []
    for _ in range(num_matrices):
        noise_mtx = _random_symmetric() * delta / np.sqrt(d)
        a_i = base_mtx + noise_mtx
        # a_i is exactly symmetric, so eigvalsh is the right solver. It always
        # returns real eigenvalues, whereas eigvals may return a complex array
        # (spurious imaginary parts) on which np.min compares lexicographically.
        lambda_min = np.min(np.linalg.eigvalsh(a_i))
        a_i = a_i + (abs(lambda_min) + mu) * np.identity(d)
        matrices.append(a_i)

    center_mtx = compute_center_matrix(matrices, d)
    measured_delta, measured_L, measured_mu = compute_delta_and_L_and_mu(
        matrices,
        center_mtx,
        d,
    )
    if verbose:
        print(
            "L = {}, delta = {}, mu = {}".format(
                measured_L, measured_delta, measured_mu
            )
        )
    return center_mtx, matrices, measured_L, measured_delta, measured_mu
