
from __future__ import annotations
import torch
from torch import Tensor
from typing import List, Optional, Sequence, Tuple, Union


# ==============================
# Utilities
# ==============================

def col_norms(M: Tensor) -> Tensor:
    """Column-wise Euclidean norms.

    For a ``(p, n)`` input the result has shape ``(n,)`` and entry ``j`` is
    ``||M[:, j]||_2``.
    """
    reduce_over_rows = 0
    return torch.linalg.norm(M, dim=reduce_over_rows)

def soft_threshold(x: Tensor, lam: Tensor | float) -> Tensor:
    """Elementwise soft-thresholding ``sign(x) * max(|x| - lam, 0)``.

    ``lam`` may be a Python float or a tensor broadcastable against ``x``.
    """
    magnitude = (x.abs() - lam).clamp(min=0.0)
    return magnitude * x.sign()

def unit_norm(v: Tensor, dim: int = 0, eps: float = 1e-12) -> Tensor:
    """Scale ``v`` to unit 2-norm along ``dim``.

    The norm is clamped from below by ``eps``, so an all-zero slice stays
    zero instead of becoming NaN.
    """
    length = torch.linalg.norm(v, dim=dim, keepdim=True)
    return v / torch.clamp(length, min=eps)

def orthonormalize_via_svd(G: Tensor) -> Tensor:
    """Procrustes / polar factor of ``G``.

    With the thin SVD ``G = U S V^T`` this returns ``U V^T``, a matrix with
    orthonormal columns and the same shape as ``G`` (for ``p >= m``).
    """
    left, _singular_values, right_h = torch.linalg.svd(G, full_matrices=False)
    return torch.matmul(left, right_h)

def torch_qr_init(A: Tensor, m: int) -> Tuple[Tensor, float]:
    """
    Orthonormal starting point for the block GPower iterations.

    Mirrors the MATLAB initialisation
    ``qr([A(:,i_max)/norm(A(:,i_max)), randn(p,m-1)], 0)``. The first column is
    the largest-norm column of ``A`` scaled to unit length. The other ``m - 1``
    columns are standard-normal draws. A thin QR orthonormalizes the result.

    Parameters
    ----------
    A : (p, n) tensor.
    m : number of starting columns.

    Returns
    -------
    Q : (p, k) tensor with orthonormal columns, k = min(p, m).
    rho_max : float, the largest column 2-norm of ``A``. This is the scale the
        block algorithm multiplies the relative sparsity weights by.
        ``|R[0, 0]|`` of the QR cannot serve as that scale: it is the norm of
        the already-normalized first column, i.e. always ~1.
    """
    p = A.shape[0]
    norms = torch.linalg.norm(A, dim=0)
    imax = int(torch.argmax(norms))
    rho_max = norms[imax].item()

    first = A[:, imax] / norms[imax].clamp_min(1e-12)
    if m > 1:
        rest = torch.randn(p, m - 1, device=A.device, dtype=A.dtype)
    else:
        rest = A.new_empty(p, 0)
    start = torch.cat([first[:, None], rest], dim=1)

    # Thin QR: Q has orthonormal columns spanning `start`.
    Q, _ = torch.linalg.qr(start, mode='reduced')
    return Q, rho_max

# --- put this small helper near the top of the file ---
def _to_1d_tensor(x, device, dtype, length=None, name="arg"):
    if isinstance(x, torch.Tensor):
        # don't wrap a tensor with torch.tensor(...)
        t = x.detach().to(device=device, dtype=dtype).flatten()
    else:
        # list/tuple/np -> tensor
        t = torch.as_tensor(x, device=device, dtype=dtype).flatten()
    if length is not None and t.numel() != length:
        raise ValueError(f"{name} length mismatch: expected {length}, got {t.numel()}")
    return t



def compute_X(A: torch.Tensor, Z: torch.Tensor, mu: torch.Tensor | None = None) -> torch.Tensor:
    """
    Orthonormal basis for the column space of ``A Z N``.

    A: (p, n), Z: (n, m), mu: (m,) column weights (N = diag(mu)) or None.
    Returns X: (p, k), k = min(p, m), the left singular vectors of ``A Z N``
    from a thin SVD, so X^T X = I_k.

    NOTE(review): this is U, not the Procrustes factor U V^T used by
    ``orthonormalize_via_svd`` / ``pattern_filling``. The span is the same but
    the basis is rotated. Confirm which one callers expect.
    """
    weighted = Z if mu is None else Z * mu[None, :]
    return torch.linalg.svd(A @ weighted, full_matrices=False).U


# ==============================
# pattern_filling subroutine
# ==============================

def _normalize_nonzero_columns(Z: Tensor) -> Tensor:
    """Return ``Z`` with every non-zero column scaled to unit 2-norm.

    All-zero columns (e.g. from an empty sparsity pattern) stay zero instead
    of turning into NaNs.
    """
    coln = torch.linalg.norm(Z, dim=0)
    return Z / torch.where(coln > 0, coln, torch.ones_like(coln))


def pattern_filling(
    A: Tensor,               # (p, n)
    P: Tensor,               # (n, m) sparsity pattern (non-zero/True means keep); (n,) allowed when m == 1
    Z_init: Optional[Tensor] = None,  # (n, m) optional initialization for values
    mu: Optional[Tensor] = None,      # (m,) optional weights for block case
    iter_max: int = 1000,
    eps: float = 1e-6
) -> Tensor:
    """
    Fill in the values of a fixed sparsity pattern, i.e. compute a local
    maximizer of

        max_{X,Z} trace(X^T A Z N)
        s.t. X^T X = I_m, Z(P^c) = 0, diag(Z^T Z) = I_m

    with N = diag(mu), or the identity when mu is None. Follows the structure
    of the MATLAB implementation.

    Single unit (m == 1): the support entries get the dominant right singular
    vector of A[:, support], found by power iteration on A_s^T A_s. ``mu`` is
    ignored. A support of size 1 gets the value 1.0, and an empty support
    yields zeros.

    Block (m > 1): alternates the Procrustes step X = polar(A Z N) with
    Z = mask(A^T X N), renormalizing the non-zero columns of Z each time.

    Parameters
    ----------
    A : (p, n) data matrix.
    P : (n, m) pattern of any dtype; non-zero entries are kept. A 1-D (n,)
        pattern is treated as (n, 1).
    Z_init : optional starting values, reshaped to (n, m). When m == 1, a
        start that is zero on the whole support is replaced by a random one,
        because a zero vector is a fixed point of the power iteration. When
        omitted, a random start is used.
    mu : optional (m,) column weights (block case only).
    iter_max : iteration cap.
    eps : relative change of the objective below which iteration stops.

    Returns
    -------
    Z : (n, m) tensor, zero outside P, with non-zero columns of unit 2-norm.
    """
    n = A.shape[1]
    device = A.device
    dtype = A.dtype

    if P.dim() == 1:
        P = P[:, None]
    # Work with a real boolean mask: `~P` on an integer tensor would be a
    # bitwise NOT (0 -> -1, 1 -> -2), not the complement of the pattern.
    P = P.to(device=device).bool()
    nP_rows, m = P.shape
    assert nP_rows == n, "P must have shape (n, m)"

    if Z_init is not None:
        Z_init = Z_init.to(device=device, dtype=dtype).reshape(n, m)

    if m == 1:
        # ----- single-unit case -----
        support = torch.nonzero(P[:, 0], as_tuple=False).flatten()
        z = A.new_zeros(n, 1)
        k = support.numel()
        if k == 0:
            # Nothing may be non-zero: keep the zero vector.
            return z
        if k == 1:
            z[support[0], 0] = 1.0
            return z

        u = Z_init[support, 0] if Z_init is not None else None
        if u is None or not bool(torch.any(u != 0)):
            # No usable start: a zero vector would stay zero forever.
            u = torch.randn(k, device=device, dtype=dtype)
        u = u / torch.linalg.norm(u).clamp_min(1e-12)

        # Power iteration on A_s^T A_s restricted to the support.
        A_red = A[:, support]  # (p, k)
        f_prev = None
        for _ in range(iter_max):
            tmp = A_red.T @ (A_red @ u)  # (k,)
            u = tmp / torch.linalg.norm(tmp).clamp_min(1e-12)
            f = (-2.0 * (u @ tmp)).item()
            if f_prev is not None and abs(f - f_prev) / (abs(f_prev) + 1e-12) < eps:
                break
            f_prev = f
        z[support, 0] = u
        return z

    # ----- block case (m > 1) -----
    # Start values respect the pattern; clone so the caller's Z_init is untouched.
    Z = torch.randn(n, m, device=device, dtype=dtype) if Z_init is None else Z_init.clone()
    Z[~P] = 0.0
    Z = _normalize_nonzero_columns(Z)

    use_mu = mu is not None
    if use_mu:
        mu = torch.as_tensor(mu, device=device, dtype=dtype).flatten()
        assert mu.numel() == m

    f_prev = None
    for _ in range(iter_max):
        AZ = A @ Z  # (p, m)
        if use_mu:
            AZ = AZ * mu[None, :]  # weight columns: A Z N

        # Procrustes step: X = U V^T maximizes trace(X^T A Z N) over X^T X = I.
        U, _, Vh = torch.linalg.svd(AZ, full_matrices=False)
        X = U @ Vh  # (p, m)

        # Objective: sum_i X[:, i]^T (A Z N)[:, i]
        ff = torch.sum(X * AZ).item()

        # Z update: project A^T X N onto the pattern, renormalize columns.
        Z = A.T @ X  # (n, m)
        if use_mu:
            Z = Z * mu[None, :]
        Z[~P] = 0.0
        Z = _normalize_nonzero_columns(Z)

        if f_prev is not None and abs(ff - f_prev) / (abs(f_prev) + 1e-12) < eps:
            break
        f_prev = ff

    return Z


# ==============================
# Main GPower function
# ==============================

def gpower(
    A: Tensor,                      # (p, n)
    RHO: Tensor | list | float,     # sparsity weights; if list/tensor length m
    m: int,
    penalty: str,                   # 'l1' or 'l0'
    block: int,                     # 0 or 1
    mu: Optional[Tensor | list] = None,
    iter_max: int = 1000,
    epsilon: float = 1e-4,
) -> Tensor:
    """
    Sparse PCA by the generalized power method (PyTorch port of MATLAB GPower).

    Behavior mirrors:
        Z = GPower(A, rho, m, penalty, block, mu)

    Parameters
    ----------
    A : (p, n) tensor. It is not modified; all work happens on a clone.
    RHO : relative sparsity weight(s). A Python int/float is used for all m
        components. Anything else must have exactly m elements. Each weight is
        multiplied by a data-dependent scale before thresholding: the largest
        column norm for 'l1', or its square for 'l0'. In the block branch the
        scale is also multiplied by mu when mu is not all ones.
    m : number of sparse components.
    penalty : 'l1' (soft-thresholding) or 'l0' (threshold on squared values).
    block : 0 -> single-unit algorithm, one component at a time with deflation.
        1 -> block algorithm computing all m components jointly; it is only
        used when m > 1, and m == 1 always takes the single-unit path.
    mu : (m,) component weights. Required when block == 1 (also when m == 1),
        otherwise ignored.
    iter_max : maximum number of iterations of each power loop.
    epsilon : a loop stops once the signed relative improvement of its
        objective drops below this value.

    Returns
    -------
    Z : (n, m) tensor of sparse loadings. Non-zero columns are normalized to
        unit 2-norm. A column can be entirely zero when its weight thresholds
        every entry away.

    Raises
    ------
    ValueError : block == 1 without mu, or RHO / mu without exactly m elements.
    AssertionError : penalty not in {'l1', 'l0'} or block not in {0, 1}.
    """
    assert penalty in ('l1', 'l0')
    assert block in (0, 1)

    device = A.device
    dtype = A.dtype
    p, n = A.shape

    # RHO -> 1-D tensor of length m on A's device/dtype. Python scalars are
    # broadcast; _to_1d_tensor moves tensors instead of re-wrapping them
    # with torch.tensor(...).
    if isinstance(RHO, (float, int)):
        RHO = torch.full((m,), float(RHO), device=device, dtype=dtype)
    else:
        RHO = _to_1d_tensor(RHO, device, dtype, length=m, name="RHO")

    if block == 1:
        if mu is None:
            raise ValueError("Block algorithm requires mu.")
        mu = _to_1d_tensor(mu, device, dtype, length=m, name="mu")

    # Output loadings; each branch below fills in or replaces the columns.
    Z = torch.zeros(n, m, device=device, dtype=dtype)
    # Working copy: the single-unit path deflates A_work, never the caller's A.
    A_work = A.clone()

    # No autograd graph is recorded for any of the iterations below.
    with torch.no_grad():
        # =========================================
        # Single-unit algorithm (deflation)
        # =========================================
        if (m == 1) or (m > 1 and block == 0):
            if penalty == 'l1':
                for comp in range(m):
                    rho = RHO[comp]
                    norms = col_norms(A_work)  # (n,)
                    rho_max, i_max = torch.max(norms, dim=0)
                    # RHO is relative: scale by the largest column norm of the
                    # current (deflated) matrix.
                    rho_scaled = rho * rho_max

                    # Start from the direction of the largest column.
                    x = A_work[:, i_max] / rho_max.clamp_min(1e-12)  # init (p,)
                    x = unit_norm(x, dim=0)

                    f_prev = None
                    for it in range(iter_max):
                        Ax = A_work.T @ x  # (n,)
                        tresh = soft_threshold(Ax, rho_scaled)  # (n,)
                        # Objective: squared norm of the soft-thresholded correlations.
                        f = torch.sum(tresh * tresh).item()
                        if f == 0.0:
                            # Threshold removes every entry; keep the current x.
                            break
                        # tresh already carries sign(Ax), so this is
                        # A * (max(|Ax| - rho, 0) .* sign(Ax)).
                        grad = A_work @ tresh  # (p,)
                        gnorm = torch.linalg.norm(grad).item()
                        if gnorm < 1e-12:
                            break
                        x = grad / gnorm
                        if f_prev is not None:
                            # Signed relative change: a decrease also stops the loop.
                            rel = (f - f_prev) / (abs(f_prev) + 1e-12)
                            if rel < epsilon:
                                break
                        f_prev = f

                    # Read the sparsity pattern and initial values off the final x.
                    Ax = A_work.T @ x
                    pattern = (torch.abs(Ax) - rho_scaled) > 0  # (n,)
                    z = soft_threshold(Ax, rho_scaled)
                    nz = torch.linalg.norm(z).item()
                    if nz > 0:
                        z = z / nz
                    # Polish the values on the fixed pattern.
                    z = pattern_filling(A_work, pattern[:, None], z[:, None]).squeeze(1)

                    y = A_work @ z  # (p,)
                    # Deflation: A <- A - (A z) z^T, removing the found direction.
                    A_work = A_work - torch.outer(y, z)
                    Z[:, comp] = z

            else:  # 'l0'
                for comp in range(m):
                    rho = RHO[comp]
                    norms = col_norms(A_work)
                    rho_max, i_max = torch.max(norms, dim=0)
                    # The l0 threshold acts on squared correlations, so scale by rho_max^2.
                    rho_scaled = rho * (rho_max ** 2)

                    x = A_work[:, i_max] / rho_max.clamp_min(1e-12)
                    x = unit_norm(x, dim=0)

                    f_prev = None
                    for it in range(iter_max):
                        Ax = A_work.T @ x  # (n,)
                        tresh = torch.clamp(Ax * Ax - rho_scaled, min=0.0)  # (n,)
                        f = torch.sum(tresh).item()
                        if f == 0.0:
                            break
                        # grad = A * ((tresh > 0) .* Ax)
                        grad = A_work @ ( (tresh > 0) * Ax )
                        gnorm = torch.linalg.norm(grad).item()
                        if gnorm < 1e-12:
                            break
                        x = grad / gnorm
                        if f_prev is not None:
                            # Signed relative change: a decrease also stops the loop.
                            rel = (f - f_prev) / (abs(f_prev) + 1e-12)
                            if rel < epsilon:
                                break
                        f_prev = f

                    # Extract the pattern and deflate (no pattern_filling in the l0 path).
                    Ax = A_work.T @ x
                    pattern = (Ax * Ax - rho_scaled) > 0
                    y = x.clone()

                    z = A_work.T @ y
                    z[~pattern] = 0.0
                    nz = torch.linalg.norm(z).item()
                    if nz > 0:
                        z = z / nz
                        y = y * nz
                    # A <- A - y z^T with y = ||z_raw|| x.
                    A_work = A_work - torch.outer(y, z)
                    Z[:, comp] = z

        # =========================================
        # Block algorithm
        # =========================================
        else:
            # QR init: x has orthonormal columns, (p, m) when p >= m.
            # rho_max (second output of torch_qr_init) scales the relative
            # weights below. NOTE(review): it should be the largest column norm
            # of A_work, as in the single-unit branch. Confirm torch_qr_init
            # returns that and not |R[0,0]|, which is ~1 for its normalized
            # first column and would make the scaling a no-op.
            x, rho_max = torch_qr_init(A_work, m)  # x: (p, m)

            if penalty == 'l1':
                # Equal weights (mu all ones): the simpler unweighted iteration.
                if torch.allclose(mu, torch.ones_like(mu)):
                    RHO_scaled = RHO * rho_max
                    f_prev = None
                    for it in range(iter_max):
                        Ax = A_work.T @ x  # (n, m)
                        # Per-column threshold via broadcasting.
                        tresh = torch.clamp(torch.abs(Ax) - RHO_scaled[None, :], min=0.0)
                        f = torch.sum(tresh * tresh).item()
                        if f == 0.0:
                            break

                        # Per-column gradient A[:, pattern] (tresh .* sign(Ax)), restricted
                        # to entries that survive the threshold.
                        grad = torch.zeros_like(x)
                        for i in range(m):
                            pattern_i = (tresh[:, i] > 0)
                            if pattern_i.any():
                                grad[:, i] = A_work[:, pattern_i] @ ( tresh[pattern_i, i] * torch.sign(Ax[pattern_i, i]) )
                        # Back onto the set of matrices with orthonormal columns.
                        x = orthonormalize_via_svd(grad)
                        if f_prev is not None:
                            rel = (f - f_prev) / (abs(f_prev) + 1e-12)
                            if rel < epsilon:
                                break
                        f_prev = f

                    # Initial loadings: soft-thresholded correlations, column-normalized.
                    Ax = A_work.T @ x
                    for i in range(m):
                        Zi = soft_threshold(Ax[:, i], RHO_scaled[i])
                        nz = torch.linalg.norm(Zi).item()
                        if nz > 0:
                            Zi = Zi / nz
                        Z[:, i] = Zi
                    # Polish the values on the joint pattern.
                    pattern = (torch.abs(Ax) - RHO_scaled[None, :]) > 0
                    Z = pattern_filling(A_work, pattern, Z)

                else:
                    # Non-equal mu: correlations are weighted column-wise by mu,
                    # and the thresholds are scaled by mu as well.
                    RHO_scaled = RHO * mu * rho_max  # elementwise
                    f_prev = None
                    for it in range(iter_max):
                        Ax = A_work.T @ x  # (n, m)
                        Ax = Ax * mu[None, :]  # weight columns
                        tresh = torch.clamp(torch.abs(Ax) - RHO_scaled[None, :], min=0.0)
                        f = torch.sum(tresh * tresh).item()
                        if f == 0.0:
                            break
                        grad = torch.zeros_like(x)
                        for i in range(m):
                            pattern_i = (tresh[:, i] > 0)
                            if pattern_i.any():
                                grad[:, i] = (A_work[:, pattern_i] @ ( tresh[pattern_i, i] * torch.sign(Ax[pattern_i, i]) )) * mu[i]
                        x = orthonormalize_via_svd(grad)
                        if f_prev is not None:
                            rel = (f - f_prev) / (abs(f_prev) + 1e-12)
                            if rel < epsilon:
                                break
                        f_prev = f

                    Ax = A_work.T @ x
                    Ax = Ax * mu[None, :]
                    Z = torch.zeros(n, m, device=device, dtype=dtype)
                    for i in range(m):
                        Zi = soft_threshold(Ax[:, i], RHO_scaled[i])
                        nz = torch.linalg.norm(Zi).item()
                        if nz > 0:
                            Zi = Zi / nz
                        Z[:, i] = Zi
                    pattern = (torch.abs(Ax) - RHO_scaled[None, :]) > 0
                    # Weighted polishing: pattern_filling maximizes trace(X^T A Z N) with N = diag(mu).
                    Z = pattern_filling(A_work, pattern, Z, mu=mu)

            else:  # 'l0'
                if torch.allclose(mu, torch.ones_like(mu)):
                    # Squared-correlation thresholds, hence rho_max ** 2.
                    RHO_scaled = RHO * (rho_max ** 2)
                    f_prev = None
                    for it in range(iter_max):
                        Ax = A_work.T @ x
                        tresh = torch.clamp(Ax * Ax - RHO_scaled[None, :], min=0.0)
                        f = torch.sum(tresh).item()
                        if f == 0.0:
                            break
                        grad = torch.zeros_like(x)
                        for i in range(m):
                            pattern_i = (tresh[:, i] > 0)
                            if pattern_i.any():
                                grad[:, i] = A_work[:, pattern_i] @ Ax[pattern_i, i]
                        x = orthonormalize_via_svd(grad)
                        if f_prev is not None:
                            rel = (f - f_prev) / (abs(f_prev) + 1e-12)
                            if rel < epsilon:
                                break
                        f_prev = f

                    # Loadings: correlations masked to the pattern (no pattern_filling for l0).
                    pattern = ((A_work.T @ x) ** 2 - RHO_scaled[None, :]) > 0
                    Z = A_work.T @ x
                    Z[~pattern] = 0.0
                    # normalize columns (all-zero columns are left as zeros)
                    coln = torch.linalg.norm(Z, dim=0)
                    for i in range(m):
                        if coln[i] > 0:
                            Z[:, i] /= coln[i]
                else:
                    RHO_scaled = RHO * (mu * rho_max) ** 2
                    f_prev = None
                    for it in range(iter_max):
                        Ax = A_work.T @ x
                        Ax = Ax * mu[None, :]
                        tresh = torch.clamp(Ax * Ax - RHO_scaled[None, :], min=0.0)
                        f = torch.sum(tresh).item()
                        if f == 0.0:
                            break
                        grad = torch.zeros_like(x)
                        for i in range(m):
                            pattern_i = (tresh[:, i] > 0)
                            if pattern_i.any():
                                grad[:, i] = (A_work[:, pattern_i] @ Ax[pattern_i, i]) * mu[i]
                        x = orthonormalize_via_svd(grad)
                        if f_prev is not None:
                            rel = (f - f_prev) / (abs(f_prev) + 1e-12)
                            if rel < epsilon:
                                break
                        f_prev = f

                    # Pattern comes from the mu-weighted correlations.
                    Ax = A_work.T @ x
                    Ax = Ax * mu[None, :]
                    pattern = (Ax * Ax - RHO_scaled[None, :]) > 0
                    # Values come from the unweighted correlations. For positive mu
                    # the per-column weight would cancel in the normalization anyway.
                    Z = A_work.T @ x
                    Z[~pattern] = 0.0
                    coln = torch.linalg.norm(Z, dim=0)
                    for i in range(m):
                        if coln[i] > 0:
                            Z[:, i] /= coln[i]

    return Z

def unfold(X: Tensor, mode: int) -> Tensor:
    """
    Mode-``mode`` unfolding (matricization) of a tensor.

    Axis ``mode`` becomes the rows. The remaining axes, kept in their original
    order, are flattened into the columns. The result has shape
    ``(I_mode, prod_{k != mode} I_k)`` and is contiguous.
    """
    assert 0 <= mode < X.dim()
    rows = X.size(mode)
    leading = torch.movedim(X, mode, 0).contiguous()
    return leading.view(rows, -1)

def fold(M: Tensor, mode: int, shape: Sequence[int]) -> Tensor:
    """
    Inverse of ``unfold``: rebuild a tensor of ``shape`` from its mode-``mode``
    unfolding ``M`` of shape ``(I_mode, prod_{k != mode} I_k)``.

    The result is contiguous.
    """
    dims = list(shape)
    order = [mode, *(axis for axis in range(len(dims)) if axis != mode)]
    stacked = M.view([dims[axis] for axis in order])
    # argsort of `order` is its inverse permutation
    inverse = sorted(range(len(order)), key=order.__getitem__)
    return stacked.permute(inverse).contiguous()

def mode_n_product(X: Tensor, U: Tensor, mode: int, transpose: bool = False) -> Tensor:
    """
    Mode-``mode`` product ``Y = X x_mode U`` (or ``X x_mode U^T`` when
    ``transpose=True``).

    - transpose=False: U has shape (J, I_mode).
    - transpose=True:  U has shape (I_mode, J).

    Y has the shape of X with axis ``mode`` replaced by J.
    """
    In = X.size(mode)
    if transpose:
        assert U.shape[0] == In, "U^T rows must equal I_mode"
        W = U.T  # (J, I_mode)
    else:
        assert U.shape[1] == In, "U cols must equal I_mode"
        W = U    # (J, I_mode)
    out_shape = list(X.shape)
    out_shape[mode] = W.shape[0]
    # Multiply the unfolding from the left, then fold back.
    return fold(W @ unfold(X, mode), mode, out_shape)

def _as_1d_tensor(x, device, dtype, length: Optional[int] = None, name: str = "arg") -> Tensor:
    """
    Coerce ``x`` into a flat 1-D tensor on ``device`` with ``dtype``.

    Kept as a thin wrapper over ``_to_1d_tensor``, which had an identical
    body, so the conversion rules live in one place. Tensors are detached and
    moved; other inputs go through ``torch.as_tensor``. Raises ``ValueError``
    when ``length`` is given and the element count differs.
    """
    return _to_1d_tensor(x, device, dtype, length=length, name=name)

def tucker_spca_gpower(
    X: Tensor,                                   # shape (I1,...,IN)
    R: Union[int, Sequence[int]],                # scalar K or [R1,...,RN]
    RHO: Union[float, Sequence, Sequence[Sequence], List[Tensor]],
    penalty: str,
    block: int,
    MU: Optional[Union[Sequence, List[Tensor], Tensor]] = None,
    opts: Optional[dict] = None,
) -> Tuple[List[Tensor], Tensor]:
    """
    Tucker-style decomposition with GPower sparse PCA on every mode unfolding.

    Steps for each mode n:
    1. Form the unfolding ``A_n = unfold(X, n)`` of shape
       ``(I_n, prod_{k != n} I_k)``, optionally centered.
    2. Pass it to ``gpower`` to get sparse loadings ``Z_n`` of shape
       ``(prod_{k != n} I_k, R_n)``.
    3. Compute the factor ``U_n`` (``(I_n, R_n)`` when I_n >= R_n, orthonormal
       columns) with ``compute_X`` from ``A_n Z_n``. The product is
       column-weighted by the block weights ``mu_n`` when ``block == 1``.

    The core is ``G = X x_1 U_1^T x_2 ... x_N U_N^T``.

    Parameters
    ----------
    X : N-way tensor.
    R : target rank per mode; an int applies to every mode.
    RHO : sparsity weight(s).
        - A Python scalar or a tensor (1 or R_n elements) is used for every mode.
        - A list/tuple of length N is read per mode (entry n for mode n).
        - Any other list/tuple is used for every mode.
        Each per-mode value must have 1 element (broadcast) or R_n elements.
        NOTE: a flat per-component list whose length happens to equal N is
        read per mode.
    penalty : 'l1' or 'l0', forwarded to gpower.
    block : 0 or 1, forwarded to gpower.
    MU : block weights, required when block == 1. Same layout rules as RHO.
    opts : {"center": bool}. When true (the default), each unfolding has its
        mean over dim 0 subtracted before GPower. The core is always computed
        from the uncentered X.

    Returns
    -------
    (U_list, G) : the factor matrices (I_n rows each) and the core, whose mode-n
        size is ``U_list[n].shape[1]``.

    Raises
    ------
    ValueError : malformed R, RHO or MU, or block == 1 without MU.
    """
    assert penalty in ("l1", "l0")
    assert block in (0, 1)

    if opts is None:
        opts = {}
    center = bool(opts.get("center", True))

    device, dtype = X.device, X.dtype
    N = X.dim()

    if isinstance(R, (int, float)):
        R_list = [int(R)] * N
    else:
        R_list = list(map(int, R))
        if len(R_list) != N:
            raise ValueError("R must have length N=ndims(X) or be a scalar.")

    def broadcast_to_rank(vec: Tensor, r: int, err: str) -> Tensor:
        # A single value is repeated for all r components; otherwise the length must be r.
        if vec.numel() == 1:
            return vec.expand(r)
        if vec.numel() != r:
            raise ValueError(err)
        return vec

    def get_rho_n(n: int) -> Tensor:
        r = R_list[n]
        if isinstance(RHO, (int, float)):
            return torch.full((r,), float(RHO), device=device, dtype=dtype)
        if isinstance(RHO, torch.Tensor):
            return broadcast_to_rank(_as_1d_tensor(RHO, device, dtype), r, f"RHO length != R[{n}]")
        if isinstance(RHO, (list, tuple)):
            rho_n = RHO[n] if len(RHO) == N else RHO
            # Tensor and non-tensor entries follow the same 1-or-R_n rule.
            return broadcast_to_rank(
                _as_1d_tensor(rho_n, device, dtype), r, f"RHO[{n}] length != R[{n}]"
            )
        raise ValueError("Unsupported RHO format.")

    def get_mu_n(n: int) -> Optional[Tensor]:
        if block != 1:
            return None
        if MU is None:
            raise ValueError("block=1 requires MU.")
        r = R_list[n]
        if isinstance(MU, torch.Tensor):
            return broadcast_to_rank(_as_1d_tensor(MU, device, dtype), r, f"MU length != R[{n}]")
        if isinstance(MU, (list, tuple)):
            mu_obj = MU[n] if len(MU) == N else MU
            return broadcast_to_rank(
                _as_1d_tensor(mu_obj, device, dtype), r, f"MU[{n}] length != R[{n}]"
            )
        if isinstance(MU, (int, float)):
            return torch.full((r,), float(MU), device=device, dtype=dtype)
        raise ValueError("Unsupported MU format.")

    U_list: List[Tensor] = []
    for n in range(N):
        A = unfold(X, n)  # (I_n, prod_{k != n} I_k)
        if center:
            # Subtract the mean over dim 0, i.e. center every column.
            A = A - A.mean(dim=0, keepdim=True)

        rho_n = get_rho_n(n)  # (R_n,)
        mu_n = get_mu_n(n)    # (R_n,) when block == 1, otherwise None

        Z = gpower(A, rho_n, R_list[n], penalty=penalty, block=block, mu=mu_n)  # (prod_others, R_n)

        # The column weights of A Z must be the block weights mu_n (N = diag(mu)
        # in the GPower objective trace(X^T A Z N)), not the sparsity weights:
        # a zero sparsity weight would wipe out the corresponding column of A Z.
        U_list.append(compute_X(A, Z, mu_n))

    # Core: G = X x_1 U_1^T ... x_N U_N^T
    G = X
    for n, Un in enumerate(U_list):
        assert Un.shape[0] == G.shape[n], \
            f"U_list[{n}] has {Un.shape[0]} rows but G mode-{n} is {G.shape[n]}"
        G = mode_n_product(G, Un, mode=n, transpose=True)

    return U_list, G
