import numpy as __np
import time
from typing import Tuple, Optional
import torch

def _normalize_device(d) -> torch.device:
    """Return ``d`` as a ``torch.device``.

    * a ``torch.device`` is returned unchanged (the same object);
    * a ``str`` is parsed with ``torch.device(d)``;
    * anything else (e.g. ``None``) falls back to ``"cuda"`` when CUDA is
      available, otherwise ``"cpu"``.
    """
    if isinstance(d, str):
        return torch.device(d)
    if isinstance(d, torch.device):
        return d
    fallback = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(fallback)

def fit(
  D,
  lambda_=None,
  epsilon1=1e-7,
  epsilon2=1e-5,
  mu=None,
  rho=1.6,
  max_iter=1000,
  verbose=True,
):
  '''
  Inexact augmented Lagrange multiplier method for robust PCA.

  Decomposes ``D`` into a low-rank part ``A`` and a sparse part ``E`` such
  that ``D ≈ A + E``.

  Parameters
  ----------
  D : array_like
    `m` x `n` matrix of observations/data; converted to a float64 array.
  lambda_ : float, default=1 / np.sqrt(m)
    Weight on sparse error term in the cost function.
  epsilon1 : float, default=1e-7
    Tolerance on the relative residual ||D - A - E||_F / ||D||_F.
  epsilon2 : float, default=1e-5
    Tolerance on the scaled change of E, mu * ||E_k - E_{k-1}||_F / ||D||_F.
    Whenever it holds, `mu` is multiplied by `rho`.
  mu : float, default=1.25 / norm_two
    Initial penalty parameter; `norm_two` is the spectral norm of `D`.
  rho : float, default=1.6
    `mu` update parameter, which should be greater than 1.
  max_iter : int, default=1000
    Maximum number of iterations.
  verbose : bool, default=True
    Printing verbose messages.

  Returns
  -------
  A : np.ndarray
    Low rank matrix.
  E : np.ndarray
    Sparse matrix.

  Raises
  ------
  ValueError
    If `D` is not two-dimensional.

  Notes
  -----
  An all-zero (or empty) `D` returns two zero matrices immediately; the
  initial scaling would otherwise divide by zero.

  References
  ----------
  The function implements Algorithm 5 of [1]_.

  .. [1] Lin, Zhouchen, Minming Chen, and Yi Ma. "The Augmented Lagrange
     Multiplier Method for Exact Recovery of Corrupted Low-Rank Matrices."
  '''
  # Convert first so that lists / integer arrays are accepted.
  D = __np.asarray(D, dtype=__np.float64)
  if D.ndim != 2:
    raise ValueError(f'D must be a 2-D matrix, got shape {D.shape}')
  m = D.shape[0]

  d_norm = __np.linalg.norm(D, 'fro')
  if d_norm == 0:
    # Trivial decomposition; also avoids 0/0 in the scaling below.
    return __np.zeros_like(D), __np.zeros_like(D)

  if lambda_ is None:
    lambda_ = 1 / __np.sqrt(m)

  # Y0 = D / J(D), with J(D) = max(||D||_2, ||D||_inf / lambda), where
  # ||.||_inf is the largest absolute entry (dual of the entrywise l1 norm),
  # not NumPy's induced matrix inf-norm (max absolute row sum).
  norm_two = __np.linalg.norm(D, 2)
  norm_inf = __np.abs(D).max() / lambda_
  Y = D / max(norm_two, norm_inf)

  A = __np.zeros_like(D)
  E = __np.zeros_like(D)
  tol_proj = epsilon2 * d_norm

  if mu is None:
    mu = 1.25 / norm_two

  if verbose:
    print(
      'lambda_',
      lambda_,
      'epsilon1',
      epsilon1,
      'epsilon2',
      epsilon2,
      'mu',
      mu,
      'rho',
      rho,
      'max_iter',
      max_iter,
      sep='\t',
    )

  for iter_ in range(1, max_iter + 1):
    # E-step: entrywise soft-thresholding of D - A + Y/mu at lambda/mu.
    temp_T = D - A + (1 / mu) * Y
    temp_E = (
      __np.maximum(temp_T - lambda_ / mu, 0) +
      __np.minimum(temp_T + lambda_ / mu, 0)
    )

    # A-step: singular value thresholding of D - E + Y/mu at 1/mu.
    U, S, Vh = __np.linalg.svd(D - temp_E + (1 / mu) * Y, full_matrices=False)
    svp = __np.count_nonzero(S > 1 / mu)
    # Scaling U's columns equals U @ diag(s) without building the diagonal;
    # svp is the rank of the new A by construction.
    A = (U[:, :svp] * (S[:svp] - 1 / mu)) @ Vh[:svp, :]

    Z = D - A - temp_E
    Y += mu * Z
    # mu grows only once E has nearly stopped changing; convergence needs
    # this test AND the relative residual below epsilon1.
    proj_ok = mu * __np.linalg.norm(E - temp_E, 'fro') < tol_proj
    if proj_ok:
      mu *= rho
    E = temp_E

    stop_criterion = __np.linalg.norm(Z, 'fro') / d_norm
    converged = proj_ok and stop_criterion < epsilon1

    if verbose:
      print(
        '#svd',
        iter_,
        'r(A)',
        svp,
        '|E|_0',
        __np.count_nonzero(E),
        'stopCriterion',
        stop_criterion,
        sep='\t',
      )

    if converged:
      break
  else:
    if verbose:
      print('max iter reached')

  return A, E


@torch.no_grad()
def _shrink(T: torch.Tensor, tau: float) -> torch.Tensor:
    """Elementwise soft-thresholding operator.

    Computes ``sign(T) * max(|T| - tau, 0)``: entries whose magnitude is at
    most ``tau`` become zero, the others move toward zero by ``tau``.
    """
    magnitude = (T.abs() - tau).clamp(min=0.0)
    return magnitude * T.sign()

@torch.no_grad()
def _fit_torch_iterate(
    D: torch.Tensor,
    lambda_: float,
    epsilon1: float,
    epsilon2: float,
    mu: Optional[float],
    rho: float,
    max_iter: int,
    verbose: bool,
    approx_svd: bool,
    approx_rank: Optional[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the IALM iterations of ``fit_torch`` on a prepared, nonzero 2-D ``D``."""
    dev = D.device
    m, n = D.shape
    lam = float(lambda_)

    # Y0 = D / J(D), with J(D) = max(||D||_2, ||D||_inf / lambda), where
    # ||.||_inf is the largest absolute entry (dual of the entrywise l1 norm),
    # not the induced matrix inf-norm (max absolute row sum).
    norm_two = torch.linalg.matrix_norm(D, ord=2)
    norm_inf = D.abs().max() / lam
    Y = D / torch.maximum(norm_two, norm_inf)

    A = torch.zeros_like(D)
    E = torch.zeros_like(D)
    d_norm = torch.linalg.matrix_norm(D, ord='fro')
    tol_proj = epsilon2 * d_norm

    if mu is None:
        mu = (1.25 / norm_two).item()

    if verbose:
        print(
            'lambda_', lambda_,
            'epsilon1', epsilon1,
            'epsilon2', epsilon2,
            'mu', mu,
            'rho', rho,
            'max_iter', max_iter,
            'dtype', str(D.dtype),
            sep='\t',
        )

    def _now() -> float:
        # Wait for queued CUDA kernels so each phase is charged for its own
        # work; skipped when timings are not reported to avoid needless syncs.
        if verbose and dev.type == "cuda":
            torch.cuda.synchronize()
        return time.perf_counter()

    t_svd = t_gemm = t_norm = t_misc = 0.0
    iter_ = 0
    converged = False
    while not converged and iter_ < max_iter:
        iter_ += 1
        t0 = _now()

        # E-step: entrywise soft-thresholding of D - A + Y/mu at lambda/mu.
        temp_T = D - A + (1.0 / mu) * Y
        temp_E = _shrink(temp_T, lam / mu)
        t1 = _now()

        # A-step: SVD of D - E + Y/mu (exact thin SVD or randomized low-rank).
        M = D - temp_E + (1.0 / mu) * Y
        if approx_svd:
            k = min(approx_rank or 64, m, n)
            U, S, V = torch.svd_lowrank(M, q=k)
            Vh = V.T
        else:
            U, S, Vh = torch.linalg.svd(M, full_matrices=False)
        t2 = _now()

        # Singular value thresholding at 1/mu; svp is the rank of the new A.
        svp = int((S > (1.0 / mu)).sum().item())
        if svp > 0:
            A = (U[:, :svp] * (S[:svp] - (1.0 / mu))) @ Vh[:svp, :]
        else:
            A.zero_()
        t3 = _now()

        Z = D - A - temp_E
        Y = Y + mu * Z
        # mu grows only once E has nearly stopped changing (epsilon2 test);
        # convergence needs that test AND the relative residual < epsilon1.
        proj_ok = bool(mu * torch.linalg.matrix_norm(E - temp_E, ord='fro') < tol_proj)
        if proj_ok:
            mu *= rho
        E = temp_E
        stop_criterion = (torch.linalg.matrix_norm(Z, ord='fro') / d_norm).item()
        converged = proj_ok and stop_criterion < epsilon1
        t4 = _now()

        t_misc += t1 - t0
        t_svd += t2 - t1
        t_gemm += t3 - t2
        t_norm += t4 - t3

        if verbose:
            print(
                '#svd', iter_,
                'r(A)', svp,
                '|E|_0', int((E != 0).sum().item()),
                'stopCriterion', stop_criterion,
                sep='\t',
            )

    if verbose:
        if not converged:
            print('max iter reached')
        total = t_svd + t_gemm + t_norm + t_misc
        if total > 0:
            print(f"[Timing] SVD: {t_svd:.3f}s ({t_svd/total:.1%}) | GEMM: {t_gemm:.3f}s ({t_gemm/total:.1%}) | Norms: {t_norm:.3f}s ({t_norm/total:.1%}) | Misc: {t_misc:.3f}s ({t_misc/total:.1%})")

    return A, E


@torch.no_grad()
def fit_torch(
    D,
    lambda_: Optional[float] = None,
    epsilon1: float = 1e-7,
    epsilon2: float = 1e-5,
    mu: Optional[float] = None,
    rho: float = 1.6,
    max_iter: int = 1000,
    verbose: bool = True,
    device: Optional[str] = None,
    dtype: torch.dtype = torch.float32,
    approx_svd: bool = False,
    approx_rank: Optional[int] = None,
    allow_tf32: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Torch implementation of inexact-ALM robust PCA. Returns ``(A, E)``.

    Splits ``D`` into a low-rank ``A`` and a sparse ``E`` with ``D ≈ A + E``,
    using the same iteration as the NumPy ``fit``.

    Parameters
    ----------
    D : array-like or torch.Tensor
        ``m`` x ``n`` observation matrix, converted with ``torch.as_tensor``.
    lambda_ : float or str, optional
        Weight of the sparse term. Numeric strings are parsed; ``None`` or a
        non-positive number falls back to ``1 / sqrt(m)``.
    epsilon1 : float
        Tolerance on the relative residual ``||D - A - E||_F / ||D||_F``.
    epsilon2 : float
        Tolerance on ``mu * ||E_k - E_{k-1}||_F / ||D||_F``; whenever it holds,
        ``mu`` is multiplied by ``rho``.
    mu : float, optional
        Initial penalty; defaults to ``1.25 / ||D||_2``.
    rho : float
        Growth factor for ``mu``.
    max_iter : int
        Maximum number of iterations.
    verbose : bool
        Print parameters, per-iteration progress and a timing breakdown.
    device : str or torch.device, optional
        Target device; defaults to CUDA when available, otherwise CPU.
    dtype : torch.dtype
        Computation dtype. float32 rounding (~1e-7) is about the size of the
        default ``epsilon1``; use float64 when tight convergence matters.
    approx_svd : bool
        Use ``torch.svd_lowrank`` instead of an exact thin SVD.
    approx_rank : int, optional
        ``q`` for ``torch.svd_lowrank`` (default 64, capped at ``min(m, n)``).
    allow_tf32 : bool
        Enable TF32 matmuls on CUDA for the duration of this call only; the
        previous global settings are restored afterwards. TF32 matmuls are
        far less precise than float32, so the residual typically cannot reach
        the default ``epsilon1``; only enable it with a loose tolerance.

    Returns
    -------
    A, E : torch.Tensor
        Low-rank and sparse components on ``device`` with ``dtype``.

    Raises
    ------
    ValueError
        If ``lambda_`` is a non-numeric string or ``D`` is not 2-D.
    """
    dev = _normalize_device(device)
    D = torch.as_tensor(D, dtype=dtype, device=dev)
    if D.dim() != 2:
        raise ValueError(f"D must be a 2-D matrix, got shape {tuple(D.shape)}")
    m = D.shape[0]

    # lambda_: accept numeric strings.
    if isinstance(lambda_, str):
        try:
            lambda_ = float(lambda_.strip())
        except ValueError:
            raise ValueError(f"lambda_ must be numeric, got {lambda_!r}") from None

    # An all-zero (or empty) D decomposes trivially; the initial scaling and
    # the relative stopping criterion would otherwise divide by zero.
    if D.numel() == 0 or not bool(D.any()):
        return torch.zeros_like(D), torch.zeros_like(D)

    if lambda_ is None or (isinstance(lambda_, (int, float)) and lambda_ <= 0):
        lambda_ = 1.0 / (m ** 0.5)

    # TF32 is opt-in and scoped to this call instead of being switched on
    # process-wide and left enabled.
    prev_tf32 = None
    if allow_tf32 and dev.type == "cuda":
        prev_tf32 = (
            torch.backends.cuda.matmul.allow_tf32,
            torch.get_float32_matmul_precision(),
        )
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.set_float32_matmul_precision("high")
    try:
        return _fit_torch_iterate(
            D, lambda_, epsilon1, epsilon2, mu, rho, max_iter, verbose,
            approx_svd, approx_rank,
        )
    finally:
        if prev_tf32 is not None:
            # Restore the precision last so it wins if the two settings are
            # coupled in this PyTorch version.
            torch.backends.cuda.matmul.allow_tf32 = prev_tf32[0]
            torch.set_float32_matmul_precision(prev_tf32[1])
