from logging import Logger
import numpy as np
from numpy.typing import ArrayLike
from math import sqrt
from typing import Union, TypedDict
from scipy.spatial.distance import pdist
from scipy.linalg import toeplitz
import torch
from torch import Tensor
import matplotlib.pyplot as plt


class EigResult(TypedDict):
    """Dictionary holding the outcome of an eigen-decomposition.

    Every field stores either a NumPy array-like or a torch ``Tensor``; ``left`` may
    also be ``None`` (presumably when left eigenvectors were not computed — confirm
    against the producer of this dict).
    """

    values: Union[ArrayLike, Tensor]  # it corresponds to the log of eigenvalues
    left: Union[ArrayLike, Tensor, None]  # optional: may be None
    right: Union[ArrayLike, Tensor]
    bias: Union[ArrayLike, Tensor]

class ModeResult(TypedDict):
    """Dictionary holding per-mode quantities.

    Each field stores either a NumPy array-like or a torch ``Tensor``.
    """

    decay_rates: Union[ArrayLike, Tensor]
    frequencies: Union[ArrayLike, Tensor]
    modes: Union[ArrayLike, Tensor]

def topk(vec: ArrayLike, k: int):
    """Return the ``k`` largest entries of a 1D array together with their positions.

    Parameters
    ----------
    vec : ArrayLike
        One-dimensional array supporting fancy indexing (e.g. ``np.ndarray``).
    k : int
        Number of entries to keep; must be positive. If ``k`` exceeds ``len(vec)``,
        every entry is returned.

    Returns
    -------
    tuple
        ``(values, indices)``, both ordered from largest to smallest value.
    """
    assert np.ndim(vec) == 1, "'vec' must be a 1D array"
    assert k > 0, "k should be greater than 0"
    # argsort is ascending; reverse the permutation to get descending order.
    largest_first = np.argsort(vec)[::-1]
    chosen = largest_first[:k]
    return vec[chosen], chosen


def fuzzy_parse_complex(vec: ArrayLike, tol: float = 10.0):
    """Clean up round-off noise in a 1D complex array.

    Two adjustments are made, both with the absolute threshold
    ``tol * np.finfo(vec.dtype).eps``:

    * every pair of entries whose real parts differ by less than the threshold has
      both real parts replaced by the pair's average (pairs are processed one after
      another in ``pdist`` condensed order, so for clusters of three or more nearby
      values the final numbers depend on that order);
    * imaginary parts whose magnitude is below the threshold are set to exactly 0.

    Parameters
    ----------
    vec : ArrayLike
        1D NumPy array with a complex dtype.
    tol : float
        Multiplier applied to machine epsilon to form the threshold.

    Returns
    -------
    np.ndarray
        New complex array; ``vec`` is not modified.

    Raises
    ------
    AssertionError
        If ``vec`` does not have a complex dtype.
    """
    assert issubclass(
        vec.dtype.type, np.complexfloating
    ), "The input element should be complex"
    threshold = tol * np.finfo(vec.dtype).eps
    re_part = vec.real
    n_items = re_part.shape[0]
    # Condensed indices (pdist ordering) of pairs with nearly equal real parts.
    close_pairs = np.argwhere(pdist(re_part[:, None]) < threshold)[:, 0]

    snapped_re = re_part.copy()
    for condensed in close_pairs:
        row, col = row_col_from_condensed_index(n_items, condensed)
        midpoint = 0.5 * (snapped_re[row] + snapped_re[col])
        snapped_re[row] = midpoint
        snapped_re[col] = midpoint

    snapped_im = vec.imag.copy()
    snapped_im[np.abs(snapped_im) < threshold] = 0.0
    return snapped_re + 1j * snapped_im

def add_diagonal_(M: ArrayLike, alpha: float):
    """
    Shift the main diagonal of ``M`` by ``alpha``, modifying ``M`` in place.

    Parameters
    ----------
    M : ArrayLike
        2D NumPy array to update (non-square matrices are allowed; only the
        ``min(M.shape)`` main-diagonal entries change).
    alpha : float
        Amount added to every diagonal entry.

    Returns
    -------
    None
    """
    # M.diagonal() is a read-only view; the addition materializes a new array,
    # which is then written back onto the diagonal.
    shifted = M.diagonal() + alpha
    np.fill_diagonal(M, shifted)


def rank_reveal(
    values: np.ndarray,
    rank: int,  # Desired rank
    rcond: Union[float, None] = None,  # Threshold for the singular values
    ignore_warnings: bool = True,
):
    """Indices of the ``rank`` largest entries of ``values`` that exceed ``rcond``.

    Parameters
    ----------
    values : np.ndarray
        1D array of values to rank (e.g. singular values or eigenvalues).
    rank : int
        Number of leading entries requested.
    rcond : float or None
        Entries not strictly greater than this threshold are discarded. Defaults to
        ``10 * len(values) * eps`` for the dtype of ``values``.
    ignore_warnings : bool
        When False, print a message if some of the requested entries are discarded.

    Returns
    -------
    np.ndarray
        Indices into ``values`` of the retained entries, ordered from largest to
        smallest value. May hold fewer than ``rank`` indices.
    """
    if rcond is None:
        rcond = 10.0 * values.shape[0] * np.finfo(values.dtype).eps

    top_values, top_idxs = topk(values, rank)
    valid = top_values > rcond
    if np.all(valid):
        return top_idxs

    if not ignore_warnings:
        # Report on the entries actually dropped from the top-k selection
        # (top_values is sorted, so index into it rather than into `values`).
        discarded = top_values[np.logical_not(valid)]
        n_discarded = discarded.shape[0]
        largest_discarded = np.max(np.abs(discarded))
        print(f"Warning: Discarded {n_discarded} dimensions of the {rank} requested due to numerical instability. Consider decreasing the rank. The largest discarded value is: {largest_discarded:.3e}.")
    return top_idxs[valid]

def weighted_norm(A: ArrayLike, M: Union[ArrayLike, None] = None):
    r"""Weighted norm of the columns of A.

    Args:
        A (ndarray): 1D or 2D array. If 2D, the columns are treated as vectors.
        M (ndarray, optional): Weighting matrix, applied with ``np.dot`` (it must also
            expose ``.T``). The norm of the vector :math:`a` is
            :math:`\sqrt{\mathrm{Re}\,\tfrac{1}{2}(a^H M a + a^H M^T a)}`, which for a real
            symmetric :math:`M` equals :math:`\sqrt{\langle a, Ma\rangle}`. Defaults to None,
            corresponding to the Identity matrix (plain Euclidean norm). Warning: no checks
            are performed on M being a PSD operator.

    Returns:
        (ndarray or float): If ``A.ndim == 2`` returns 1D array of floats corresponding to the norms of
        the columns of A. Else return a float.

    Notes:
        Squared norms below ``10 * A.shape[0] * eps`` (including negative values, which can
        arise when M is not PSD) are clamped to 0 before the square root is taken.
    """
    assert A.ndim <= 2, "'A' must be a vector or a 2D array"
    if M is None:
        # Squared Euclidean norm: both branches must produce <a, Ma> so that the single
        # square root below yields the norm (np.linalg.norm here would already be
        # square-rooted, giving sqrt(||a||) instead of ||a||).
        sq_norm = np.sum(np.abs(A) ** 2, axis=0)
    else:
        MA = np.dot(M, A)
        MtA = np.dot(M.T, A)
        # a^H M a + a^H M^T a = a^H (M + M^T) a: only the symmetric part of M contributes.
        sq_norm = np.real(np.sum(0.5 * (np.conj(A) * MA + np.conj(A) * MtA), axis=0))
    rcond = 10.0 * A.shape[0] * np.finfo(A.dtype).eps
    sq_norm = np.where(sq_norm < rcond, 0.0, sq_norm)
    return np.sqrt(sq_norm)

def toeplitz_integrator(exp_decay: float, npts: int, context_length: int, dt: float = 1.,
                        symmetric: bool = False
                        ) -> ArrayLike:
    """Build an ``npts x npts`` upper-triangular Toeplitz matrix of integration weights.

    * ``context_length == 0``: the constant ``npts / (npts - 1)`` sits on the first
      superdiagonal; every other entry is zero.
    * otherwise: the first row holds, for lags ``k = 0, ..., context_length - 1``, the
      weights ``npts * dt * exp(-k * dt * exp_decay) / (npts - k)`` with the first and
      last weight halved, followed by zeros; each later row is the previous one shifted
      one column to the right.

    NOTE(review): with ``context_length == 1`` the single weight is halved twice
    (overall factor 0.25), and the ``context_length == 0`` branch does not scale by
    ``dt`` — confirm both are intended.

    Parameters
    ----------
    exp_decay : float
        Exponential decay rate applied per unit of ``k * dt``.
    npts : int
        Size of the square output matrix.
    context_length : int
        Number of nonzero lags; expected to satisfy ``0 <= context_length <= npts``.
    dt : float
        Time step.
    symmetric : bool
        If True, return ``(T + T^H) / 2`` instead of the triangular ``T``.

    Returns
    -------
    np.ndarray
        Array of shape ``(npts, npts)``.
    """
    if context_length == 0:
        superdiag = npts * np.ones(npts - 1) / (npts - 1)
        toep = np.diag(superdiag, 1)
    else:
        lags = np.arange(context_length)
        weights = npts * dt * np.exp(-lags * dt * exp_decay) / (npts - lags)
        # Halve the end weights (the same entry twice when context_length == 1).
        weights[0] *= 0.5
        weights[-1] *= 0.5
        first_col = np.concatenate((weights, np.zeros(npts - context_length)))
        # With an all-zero first row, toeplitz() is lower-triangular; transposing
        # turns `first_col` into the first row of an upper-triangular matrix.
        toep = toeplitz(first_col, np.zeros(npts)).T
    if symmetric:
        toep = (toep + toep.conj().T) / 2
    return toep

def row_col_from_condensed_index(d, index):
    """Convert a position in a condensed distance vector into a ``(row, col)`` pair.

    ``scipy.spatial.distance.pdist`` on ``d`` observations returns the strictly upper
    triangle (``row < col``) of the ``d x d`` distance matrix, flattened row by row.
    Given a position ``index`` in that vector, return the matching ``(row, col)``
    as Python ints.
    """
    # Closed-form inversion; credits to https://stackoverflow.com/a/14839010
    b = 1 - (2 * d)
    discriminant = b**2 - 8 * index
    row = (-b - sqrt(discriminant)) // 2
    col = index + row * (b + row + 2) // 2 + 1
    return int(row), int(col)

def tonp(x):
    """Convert a torch tensor to a NumPy array.

    The tensor is detached from the autograd graph and moved to the CPU before the
    conversion, so tensors that require grad or live on another device are accepted.
    For CPU tensors the returned array shares memory with ``x``.
    """
    host_tensor = x.detach().cpu()
    return host_tensor.numpy()

def frnp(x, device=None):
    """Convert array-like data to a torch tensor with the default floating dtype.

    Parameters
    ----------
    x : array-like
        NumPy array, (nested) Python sequence, Python scalar, or torch tensor.
    device : torch.device or str, optional
        Target device. ``None`` leaves NumPy/Python data on the CPU and keeps a
        tensor input on its current device.

    Returns
    -------
    torch.Tensor
        Tensor of dtype ``torch.get_default_dtype()``. Memory is shared with ``x``
        when no dtype/device conversion is needed.
    """
    # torch.Tensor(x) interprets an int argument as a *size* (returning an
    # uninitialized tensor) and rejects float scalars and tensors of another dtype;
    # torch.as_tensor always treats `x` as data.
    return torch.as_tensor(x, dtype=torch.get_default_dtype(), device=device)

def sqrtmh(A: torch.Tensor):
    """Square root of a Hermitian (or real symmetric) matrix via ``torch.linalg.eigh``.

    Eigenvalues that are not strictly greater than ``max_eigenvalue * n * eps`` are
    replaced by zero before their square roots are taken. This removes small and
    negative round-off eigenvalues, so no NaNs are produced. Leading batch dimensions
    are handled.

    Parameters
    ----------
    A : torch.Tensor
        Tensor of shape ``(..., n, n)``, assumed Hermitian (not checked).

    Returns
    -------
    torch.Tensor
        ``Q diag(sqrt(L)) Q^H`` with the thresholded eigenvalues ``L``.
    """
    eigvals, eigvecs = torch.linalg.eigh(A)
    n = eigvals.size(-1)
    cutoff = eigvals.max(-1).values * n * torch.finfo(eigvals.dtype).eps
    keep = eigvals > cutoff.unsqueeze(-1)
    zero = torch.zeros((), device=eigvals.device, dtype=eigvals.dtype)
    eigvals = torch.where(keep, eigvals, zero)
    # Scale each eigenvector column by its sqrt-eigenvalue, then recombine.
    scaled = eigvecs * eigvals.sqrt().unsqueeze(-2)
    return scaled @ eigvecs.mH

def primal_left(e, Z):
    """Compute ``Z.T @ e["left"] / sqrt(Z.shape[0])``.

    ``e`` is presumably an ``EigResult``-style dict; its ``"left"`` entry must not be
    ``None`` here (a None entry makes the matrix product fail). ``Z`` is any array or
    tensor supporting ``.T``, ``@`` and ``.shape``.
    """
    n_rows = Z.shape[0]
    projected = Z.T @ e["left"]
    return projected / np.sqrt(n_rows)