from __future__ import annotations

import sys
import time
import timeit
from dataclasses import dataclass, field
from typing import Callable, Tuple

import numpy as np
import torch

# -- repo path hack ----------------------------------------------------------
# Append the repository working directory (found via
# ``sorcerun.git_utils.get_repo``) to ``sys.path`` so that the repo-level
# ``globals`` module imported below can be resolved.
from sorcerun.git_utils import get_repo

ROOT = get_repo().working_dir
sys.path.append(ROOT)

# ``_sign`` / ``_sqrt`` / ``_inv`` / ``_proot`` build the action names that are
# used as ``ACTIONS`` registry keys throughout this module.
from globals import SIGN, SQRT, INV, PROOT, _sign, _inv, _sqrt, _proot

# --------------------------------------------------------------------------- #
#                               Small utilities                               #
# --------------------------------------------------------------------------- #


def _splice(base: Tuple, *updates):
    """Return a tuple identical in length to *base* with the first
    ``len(updates)`` elements replaced by *updates*.
    """
    return (*updates, *base[len(updates) :])


def _safe_inverse(M: torch.Tensor) -> torch.Tensor:
    """SPD‑aware inverse/solve that avoids explicit ``torch.linalg.inv``.

    * If Cholesky succeeds we call ``torch.cholesky_inverse``.
    * Otherwise we fall back to an LU solve.
    * Handles batched matrices by repeating an identity tensor to the
      batch shape.
    """
    try:
        L = torch.linalg.cholesky(M)
        return torch.cholesky_inverse(L)
    except RuntimeError:
        LU, piv = torch.linalg.lu_factor(M)
        n = M.shape[-1]
        I = torch.eye(n, dtype=M.dtype, device=M.device)
        if M.ndim > 2:  # batched case → repeat identity along batch dims
            I = I.repeat(*M.shape[:-2], 1, 1)
        return torch.linalg.lu_solve(LU, piv, I)


# Registry ------------------------------------------------------------------ #
# Filled at import time: each ``ACTIONS[...] = Action(...)`` statement in this
# module adds one entry.
ACTIONS: dict[str, Action] = dict()


# --------------------------------------------------------------------------- #
#                                   I/O                                       #
# --------------------------------------------------------------------------- #


@dataclass(slots=True)
class ActionInput:
    current_spectra: Tuple[np.ndarray, ...]
    a_spectrum: np.ndarray
    theta: np.ndarray


@dataclass(slots=True)
class MatrixActionInput:
    current_matrices: Tuple[torch.Tensor | np.ndarray, ...]
    a_matrix: torch.Tensor | np.ndarray
    theta: np.ndarray


@dataclass
class Action:
    """One parameterized iteration step, in spectral and matrix form.

    ``spectral_iteration`` applies the step to eigenvalue spectra,
    ``baseline_spectral_iteration`` applies it with fixed hand-picked
    parameters, and ``matrix_iteration`` applies it to actual matrices and
    also reports the elapsed time. ``theta_bounds`` bounds the ``k`` tunable
    parameters.
    """

    name: str
    spectral_iteration: Callable[[ActionInput], Tuple]
    baseline_spectral_iteration: Callable[[ActionInput], Tuple]
    matrix_iteration: Callable[[MatrixActionInput], Tuple]
    theta_bounds: Tuple[np.ndarray, np.ndarray]  # (lo, hi)
    time: float  # measured runtime weight (relative to a reference action)
    k: int  # len(theta)
    num_matrices: int = 1  # number of matrices this action updates
    # default_factory gives each instance its own array: a bare ndarray
    # default would be shared by every Action, and dataclasses on
    # Python >= 3.11 reject it outright as an unhashable (mutable) default.
    default_params: np.ndarray = field(default_factory=lambda: np.zeros(3))
    absolute_time: float = -1  # measured runtime in seconds; -1 until measured

    # --------------------------------------------------------------------- #
    #                             convenience                              #
    # --------------------------------------------------------------------- #
    def sample_theta(self, i, rng: np.random.Generator) -> np.ndarray:
        """Draw parameters uniformly inside ``theta_bounds``.

        Args:
            i: if negative, sample the whole length-``k`` vector; otherwise
               sample only component ``i - 1`` (``i`` is 1-based).
            rng: random generator to draw from.
        """
        if i < 0:
            lo = self.theta_bounds[0]
            hi = self.theta_bounds[1]
            return rng.uniform(lo, hi, size=self.k)
        lo = self.theta_bounds[0][i - 1]
        hi = self.theta_bounds[1][i - 1]
        return rng.uniform(lo, hi)

    def sample_jitter(
        self,
        i,
        mean: np.ndarray,
        stddev_scale: float,
        rng: np.random.Generator,
        eps: float = 5e-2,
    ) -> np.ndarray:
        """Sample parameter component ``i - 1`` (1-based ``i``) near *mean*.

        Draws uniformly from ``mean ± stddev_scale * (hi - lo) / 2``, clipped
        to that component's bounds ``[lo, hi]``. With probability *eps* the
        mean is ignored and the draw covers the full ``[lo, hi]`` range.
        Returns an empty array when ``k == 0``.
        """
        if self.k == 0:
            return np.empty(0)

        lo = self.theta_bounds[0][i - 1]
        hi = self.theta_bounds[1][i - 1]
        if rng.random() < eps:
            return rng.uniform(lo, hi)

        half_width = stddev_scale * (hi - lo) / 2.0
        new_lo = np.clip(mean - half_width, lo, hi)
        new_hi = np.clip(mean + half_width, lo, hi)
        return rng.uniform(new_lo, new_hi)


# --------------------------------------------------------------------------- #
#                                timing helper                               #
# --------------------------------------------------------------------------- #
def estimate_relative_times(actions, size, repeats=None, device="cuda"):
    """Benchmark each action's matrix iteration and store its runtime.

    For every action, ``repeats`` random ``size x size`` problems are built on
    *device* and passed to ``action.matrix_iteration``. The per-call times it
    reports are averaged. Each action then gets:

    * ``absolute_time``: the mean time in seconds;
    * ``time``: that mean divided by the mean of ``actions[0]``.

    Args:
        actions: sequence of ``Action``; the first one is the reference.
        size: matrix dimension.
        repeats: calls per action. ``None`` keeps the previous defaults
            (30 on CUDA, 50 on CPU). An explicit value is honoured on both
            devices.
        device: a CUDA device (``"cuda"``, ``"cuda:N"``), ``"cpu"``, or a
            ``torch.device``.

    Raises:
        ValueError: if *device* is neither a CUDA nor a CPU device.
    """
    device_type = torch.device(device).type
    if device_type not in ("cuda", "cpu"):
        raise ValueError(f"unsupported device {device!r}; expected a CUDA or CPU device")
    if repeats is None:
        repeats = 30 if device_type == "cuda" else 50
    if not actions:
        return

    rng = np.random.default_rng()
    # Coupling actions read two current matrices even if registered with fewer.
    couple_names = {_sqrt("couple"), _proot("couple")}

    times = np.zeros(len(actions))
    for i, action in enumerate(actions):
        num = action.num_matrices
        if action.name in couple_names:
            num = max(num, 2)

        run_times = []
        for _ in range(repeats):
            current_matrices = tuple(
                torch.randn(size, size, device=device) for _ in range(num)
            )
            a_matrix = torch.randn(size, size, device=device)
            theta = action.sample_theta(-1, rng)
            inp = MatrixActionInput(
                current_matrices=current_matrices,
                a_matrix=a_matrix,
                theta=theta,
            )
            _, t_run = action.matrix_iteration(inp)
            run_times.append(t_run)
        times[i] = np.mean(run_times)

    rel_times = times / times[0]
    for action, rel, absolute in zip(actions, rel_times, times):
        if len(actions) > 1:
            print(
                f"Action {action.name} took {rel} times longer than {actions[0].name}"
            )
        action.time = rel
        action.absolute_time = absolute


# --------------------------------------------------------------------------- #
#                               SIGN  actions                                #
# --------------------------------------------------------------------------- #
# -- Newton–Schulz (linear two‑term) ----------------------------------------
def sign_ns(inp: ActionInput):
    """Spectral Newton–Schulz step for the sign function.

    Each eigenvalue ``s`` of the first iterate is first scaled by
    ``theta[1]`` and then mapped to ``s + theta[0] * (s - s**3)``. The other
    spectra are passed through unchanged.
    """
    gain = inp.theta[0]
    scale = inp.theta[1]
    scaled = scale * inp.current_spectra[0]
    updated = scaled + gain * (scaled - scaled**3)
    return _splice(inp.current_spectra, updated)


def sign_ns_baseline(inp: ActionInput):
    """Scaled Newton–Schulz baseline with ``theta = [0.5, sc]``.

    ``sc = sqrt(3) / sqrt(1 + r + r**2) / max|s|``, where ``r`` is the ratio
    of the smallest to the largest eigenvalue magnitude of the first
    spectrum. Replaces ``inp.theta`` and returns the stepped spectra together
    with it.
    """
    magnitudes = np.abs(inp.current_spectra[0])
    biggest = np.max(magnitudes)
    ratio = np.min(magnitudes) / biggest
    factor = (np.sqrt(3) / np.sqrt(1 + ratio + ratio * ratio)).item()
    inp.theta = [0.5, factor / biggest]
    stepped = sign_ns(inp)
    return stepped, inp.theta


def sign_ns_matrix(inp: MatrixActionInput):
    """Timed matrix Newton–Schulz step for the sign function.

    With ``c0, c1 = theta[0], theta[1]`` and ``X = c1 * current_matrices[0]``,
    computes ``X + c0 * (X - X @ X @ X)``. Returns the spliced matrix tuple
    and the elapsed time in seconds (CUDA events for CUDA tensors,
    ``time.perf_counter`` otherwise).
    """

    def _step():
        c0, c1 = inp.theta[0], inp.theta[1]
        X = c1 * inp.current_matrices[0]
        return X + c0 * (X - X @ X @ X)

    # The step is evaluated exactly once, inside the timed region.
    if "cuda" in str(inp.current_matrices[0].device):
        # CUDA work is asynchronous: bracket it with events and synchronize
        # before reading the elapsed time.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out = _step()
        end.record()
        torch.cuda.synchronize()
        elapsed_time = start.elapsed_time(end) / 1000  # ms -> s
    else:
        # perf_counter is monotonic and high-resolution, unlike time.time.
        start_time = time.perf_counter()
        out = _step()
        elapsed_time = time.perf_counter() - start_time

    return _splice(inp.current_matrices, out), elapsed_time

    
# Newton–Schulz sign step; θ = (c0, c1) sampled from [0, 5] x [0, 5].
ACTIONS[_sign("ns")] = Action(
    name=_sign("ns"),
    k=2,
    theta_bounds=(np.zeros(2), np.array([5, 5])),
    time=-1,
    spectral_iteration=sign_ns,
    baseline_spectral_iteration=sign_ns_baseline,
    matrix_iteration=sign_ns_matrix,
)


# -- Newton ------------------------------------------------------------------
def sign_newton(inp: ActionInput):
    """Spectral scaled Newton step for the sign function.

    Maps each eigenvalue ``s`` of the first iterate to ``0.5 * (x + 1/x)``
    with ``x = s * theta[0]``. The other spectra are passed through.
    """
    x = inp.current_spectra[0] * inp.theta[0]
    reciprocal = 1.0 / x
    return _splice(inp.current_spectra, 0.5 * (x + reciprocal))


def sign_newton_baseline(inp: ActionInput):
    """Baseline Newton scaling ``theta = [1 / sqrt(max(s))]``.

    ``s`` is the first spectrum. Replaces ``inp.theta`` and returns the
    stepped spectra together with it.
    NOTE(review): uses the signed maximum; a non-positive maximum yields nan.
    Confirm that the spectra are positive here.
    """
    peak = np.max(inp.current_spectra[0], axis=-1)
    inp.theta = [1 / np.sqrt(peak)]
    stepped = sign_newton(inp)
    return stepped, inp.theta


def sign_newton_matrix(inp: MatrixActionInput):
    """Timed scaled Newton step for the matrix sign function.

    Computes ``0.5 * (X + X^{-1})`` with ``X = current_matrices[0] * theta[0]``
    (inverse via ``_safe_inverse``). Returns the spliced matrix tuple and the
    elapsed time in seconds.
    """

    def _step():
        X = inp.current_matrices[0] * inp.theta[0]
        X_inv = _safe_inverse(X)
        return 0.5 * (X + X_inv)

    if "cuda" in str(inp.current_matrices[0].device):
        # CUDA work is asynchronous: time with events, then synchronize.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out = _step()
        end.record()
        torch.cuda.synchronize()
        elapsed_time = start.elapsed_time(end) / 1000  # ms -> s
    else:
        # perf_counter is monotonic and high-resolution, unlike time.time.
        start_time = time.perf_counter()
        out = _step()
        elapsed_time = time.perf_counter() - start_time

    return _splice(inp.current_matrices, out), elapsed_time


# Scaled Newton sign step; θ = (scale,) sampled from [0, 40].
ACTIONS[_sign("newton")] = Action(
    name=_sign("newton"),
    k=1,
    theta_bounds=(np.array([0.0]), np.array([40.0])),
    time=-1,
    spectral_iteration=sign_newton,
    baseline_spectral_iteration=sign_newton_baseline,
    matrix_iteration=sign_newton_matrix,
)


# -- Newton variant ----------------------------------------------------------
def sign_newton_variant(inp: ActionInput):
    """Spectral Newton-variant sign step: ``2x / (1 + x^2)``.

    ``x = s * theta[0]`` for each eigenvalue ``s`` of the first iterate.
    The other spectra are passed through.
    """
    x = inp.current_spectra[0] * inp.theta[0]
    damping = 1 / (1 + x * x)
    return _splice(inp.current_spectra, 2 * (x * damping))


def sign_newton_variant_baseline(inp: ActionInput):
    """Baseline scaling for the Newton variant.

    Sets ``theta[0] = 1 / sqrt(max(s))`` in place (``s`` = first spectrum)
    and returns the stepped spectra together with ``inp.theta``.
    """
    top = np.max(inp.current_spectra[0], axis=-1)
    inp.theta[0] = 1 / np.sqrt(top)
    stepped = sign_newton_variant(inp)
    return stepped, inp.theta


def sign_newton_variant_matrix(inp: MatrixActionInput):
    """Timed matrix Newton-variant sign step: ``2 (I + X^2)^{-1} X``.

    ``X = current_matrices[0] * theta[0]``. The inverse is applied through a
    linear solve. Returns the spliced matrix tuple and the elapsed time in
    seconds.
    """

    def _step():
        X = inp.current_matrices[0] * inp.theta[0]
        I = torch.eye(X.shape[-1], dtype=X.dtype, device=X.device)
        M = I + X @ X
        Z = torch.linalg.solve(M, X)
        return 2.0 * Z

    if "cuda" in str(inp.current_matrices[0].device):
        # CUDA work is asynchronous: time with events, then synchronize.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out = _step()
        end.record()
        torch.cuda.synchronize()
        elapsed_time = start.elapsed_time(end) / 1000  # ms -> s
    else:
        # perf_counter is monotonic and high-resolution, unlike time.time.
        start_time = time.perf_counter()
        out = _step()
        elapsed_time = time.perf_counter() - start_time

    return _splice(inp.current_matrices, out), elapsed_time


# Newton-variant sign step; θ = (scale,) sampled from [0, 1].
ACTIONS[_sign("newton_variant")] = Action(
    name=_sign("newton_variant"),
    k=1,
    theta_bounds=(np.zeros(1), np.array([1])),
    time=-1,
    spectral_iteration=sign_newton_variant,
    baseline_spectral_iteration=sign_newton_variant_baseline,
    matrix_iteration=sign_newton_variant_matrix,
)


# -- Quintic -----------------------------------------------------------------
def sign_quintic(inp: ActionInput):
    """Spectral odd quintic step: ``a0 s - a1 s^3 + a2 s^5`` with ``a = theta``.

    Applied to the first spectrum; the other spectra are passed through.
    """
    s = inp.current_spectra[0]
    coeffs = inp.theta
    s2 = s * s
    s4 = s * s * s * s
    nxt = coeffs[0] * s - coeffs[1] * s2 * s + coeffs[2] * s4 * s
    return _splice(inp.current_spectra, nxt)


def sign_quintic_baseline(inp: ActionInput):
    """Fixed quintic coefficients: ``2 s - 1.5 s^3 + 0.5 s^5``.

    Writes them into ``inp.theta`` in place and returns the stepped spectra
    together with ``inp.theta``.
    """
    coefficients = [2, 1.5, 0.5]
    inp.theta[:] = coefficients
    stepped = sign_quintic(inp)
    return stepped, inp.theta


def sign_quintic_matrix(inp: MatrixActionInput):
    """Timed matrix quintic step: ``a0 X - a1 X^3 + a2 X^5`` with ``a = theta``.

    Evaluated as ``a0 X - (a1 X^2 - a2 X^4) @ X`` to share the final product.
    Returns the spliced matrix tuple and the elapsed time in seconds.
    """

    def _step():
        X = inp.current_matrices[0]
        quad = X @ X
        quart = quad @ quad
        a = inp.theta
        return a[0] * X - (a[1] * quad - a[2] * quart) @ X

    if "cuda" in str(inp.current_matrices[0].device):
        # CUDA work is asynchronous: time with events, then synchronize.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out = _step()
        end.record()
        torch.cuda.synchronize()
        elapsed_time = start.elapsed_time(end) / 1000  # ms -> s
    else:
        # perf_counter is monotonic and high-resolution, unlike time.time.
        start_time = time.perf_counter()
        out = _step()
        elapsed_time = time.perf_counter() - start_time

    return _splice(inp.current_matrices, out), elapsed_time


# Quintic sign step; θ = (a0, a1, a2) sampled from [0, 10]^3.
ACTIONS[_sign("quintic")] = Action(
    name=_sign("quintic"),
    k=3,
    theta_bounds=(np.zeros(3), np.array([10,10,10])),
    time=-1,
    spectral_iteration=sign_quintic,
    baseline_spectral_iteration=sign_quintic_baseline,
    matrix_iteration=sign_quintic_matrix,
)


# -- Halley ------------------------------------------------------------------
def sign_halley(inp: ActionInput):
    """Spectral rational step ``s (a0 + a1 s^2) / (1 + a2 s^2)``, ``a = theta``.

    Applied to the first spectrum; the other spectra are passed through.
    """
    s = inp.current_spectra[0]
    a = inp.theta
    s2 = s * s
    numerator = s * (a[0] + a[1] * s2)
    denominator = 1 + a[2] * s2
    return _splice(inp.current_spectra, numerator / denominator)


def sign_halley_baseline(inp: ActionInput):
    """Halley's iteration for the sign function as the baseline.

    ``theta = [3, 1, 3]`` makes ``sign_halley`` compute
    ``s (3 + s^2) / (1 + 3 s^2)``, Halley's sign iteration. The coefficients
    ``[3, 3, 1]`` would reduce the step to ``3 s``, which does not converge
    to the sign. Writes ``inp.theta`` in place and returns the stepped
    spectra together with it.
    """
    inp.theta[:] = [3, 1, 3]
    return sign_halley(inp), inp.theta


def sign_halley_matrix(inp: MatrixActionInput):
    """Timed QR-based rational sign step (inverse-free).

    With ``a = theta``, evaluates
    ``(a1/a2) X + (a0 - a1/a2) X (I + a2 X^T X)^{-1}``. The thin QR of
    ``[sqrt(a2) X; I] = [Q1; Q2] R`` gives
    ``Q1 @ Q2^T = sqrt(a2) X (I + a2 X^T X)^{-1}``. For symmetric ``X`` this
    matches the spectral map ``s (a0 + a1 s^2) / (1 + a2 s^2)``.
    Requires ``a2 > 0``. Returns the spliced matrix tuple and the elapsed
    time in seconds.
    """

    def _step():
        X = inp.current_matrices[0]
        m = X.shape[0]
        d = X.shape[1]
        a = inp.theta
        # Match X's dtype so the stacked system is not promoted to the
        # default dtype.
        I = torch.eye(d, dtype=X.dtype, device=X.device)
        A = torch.vstack((np.sqrt(a[2]) * X, I))
        Q, _ = torch.linalg.qr(A)
        return a[1]/a[2] * X + 1/np.sqrt(a[2])*(a[0] - a[1]/a[2]) * Q[0:m,:] @ Q[m:, :].T

    if "cuda" in str(inp.current_matrices[0].device):
        # CUDA work is asynchronous: time with events, then synchronize.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out = _step()
        end.record()
        torch.cuda.synchronize()
        elapsed_time = start.elapsed_time(end) / 1000  # ms -> s
    else:
        # perf_counter is monotonic and high-resolution, unlike time.time.
        start_time = time.perf_counter()
        out = _step()
        elapsed_time = time.perf_counter() - start_time

    return _splice(inp.current_matrices, out), elapsed_time


# Rational (Halley-type) sign step; θ = (a0, a1, a2) sampled from [0, 10]^3.
# NOTE(review): a2 = 0 is inside these bounds but makes sign_halley_matrix
# divide by zero (it uses a1/a2 and 1/sqrt(a2)). Confirm the lower bound.
ACTIONS[_sign("halley")] = Action(
    name=_sign("halley"),
    k=3,
    theta_bounds=(np.zeros(3), 10*np.ones(3)),
    time=-1,
    spectral_iteration=sign_halley,
    baseline_spectral_iteration=sign_halley_baseline,
    matrix_iteration=sign_halley_matrix,
)

# ----------- SQRT Actions -------------------
# %%


# ------------- SQRT Visser (uncoupled) --------------------
def sqrt_visser(inp: ActionInput) -> Tuple:
    """Uncoupled Visser iteration for the matrix square root, spectral form.

    θ = (a0, a1):  y_{k+1} = a0·λ + a1·y_k − a0·y_k²

    Returns the spectra tuple with its first entry replaced by y_{k+1}
    (not a bare array).
    """
    y = inp.current_spectra[0]
    lam = inp.a_spectrum
    a = inp.theta
    out = a[0] * lam + a[1] * y - a[0] * y * y
    return _splice(inp.current_spectra, out)


def sqrt_visser_baseline(inp: ActionInput):
    """Classical Visser parameters (a0, a1) = (0.5, 1.0): ``y + (λ - y²)/2``.

    Writes them into ``inp.theta`` in place and returns the stepped spectra
    together with ``inp.theta``.
    """
    inp.theta[:] = [0.5, 1.0]
    stepped = sqrt_visser(inp)
    return stepped, inp.theta


def sqrt_visser_matrix(inp: MatrixActionInput):
    """Timed matrix Visser step: ``a0 A + a1 Y - a0 Y^2`` with ``a = theta``.

    ``Y = current_matrices[0]`` and ``A = a_matrix``. Returns the spliced
    matrix tuple and the elapsed time in seconds.
    """

    def _step():
        Y, A = inp.current_matrices[0], inp.a_matrix
        a = inp.theta
        return a[0] * A + a[1] * Y - a[0] * (Y @ Y)

    if "cuda" in str(inp.current_matrices[0].device):
        # CUDA work is asynchronous: time with events, then synchronize.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out = _step()
        end.record()
        torch.cuda.synchronize()
        elapsed_time = start.elapsed_time(end) / 1000  # ms -> s
    else:
        # perf_counter is monotonic and high-resolution, unlike time.time.
        start_time = time.perf_counter()
        out = _step()
        elapsed_time = time.perf_counter() - start_time

    return _splice(inp.current_matrices, out), elapsed_time


# Uncoupled Visser sqrt step; θ = (a0, a1) sampled from [0, 10]^2.
ACTIONS[_sqrt("visser")] = Action(
    name=_sqrt("visser"),
    k=2,
    theta_bounds=(np.zeros(2), np.array([10, 10])),
    time=-1,
    spectral_iteration=sqrt_visser,
    baseline_spectral_iteration=sqrt_visser_baseline,
    matrix_iteration=sqrt_visser_matrix,
)


# -- Coupled Visser ---------------------------------------------------
def sqrt_visser_coupled(inp: ActionInput):
    """Spectral coupled Visser step on the first two spectra (y1, y2).

        y1 <- a0·λ + a1·y1 − a0·y1²
        y2 <- a0   + a1·y2 − a0·y2·y1

    with ``a = theta``. The remaining spectra are passed through.
    """
    first, second = inp.current_spectra[:2]
    lam = inp.a_spectrum
    c0, c1 = inp.theta[0], inp.theta[1]
    new_first = c0 * lam + c1 * first - c0 * first * first
    new_second = c0 + c1 * second - c0 * (second * first)
    return _splice(inp.current_spectra, new_first, new_second)


def sqrt_visser_coupled_baseline(inp: ActionInput):
    """Coupled Visser baseline with (a0, a1) = (0.5, 1.0).

    Writes the parameters into ``inp.theta`` in place and returns the stepped
    spectra together with ``inp.theta``.
    """
    inp.theta[:] = [0.5, 1.0]
    stepped = sqrt_visser_coupled(inp)
    return stepped, inp.theta


def sqrt_visser_coupled_matrix(inp: MatrixActionInput):
    """Timed matrix coupled Visser step.

        Y1 <- a0 A + a1 Y1 - a0 Y1 Y1
        Y2 <- a0 I + a1 Y2 - a0 Y2 Y1

    with ``a = theta`` and ``A = a_matrix``. Returns the spliced matrix tuple
    and the elapsed time in seconds.
    """

    def _step():
        Y1, Y2 = inp.current_matrices[:2]
        A = inp.a_matrix
        a = inp.theta
        # Identity in the iterate's dtype, so Y2 is not promoted to the
        # default dtype (as sqrt_nsv_matrix already does).
        I = torch.eye(A.shape[0], dtype=Y2.dtype, device=A.device)
        out1 = a[0] * A + a[1] * Y1 - a[0] * (Y1 @ Y1)
        out2 = a[0] * I + a[1] * Y2 - a[0] * (Y2 @ Y1)
        return out1, out2

    if "cuda" in str(inp.current_matrices[0].device):
        # CUDA work is asynchronous: time with events, then synchronize.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out1, out2 = _step()
        end.record()
        torch.cuda.synchronize()
        elapsed_time = start.elapsed_time(end) / 1000  # ms -> s
    else:
        # perf_counter is monotonic and high-resolution, unlike time.time.
        start_time = time.perf_counter()
        out1, out2 = _step()
        elapsed_time = time.perf_counter() - start_time

    return _splice(inp.current_matrices, out1, out2), elapsed_time


# Coupled Visser sqrt step on two matrices; θ sampled from [0, 0.5] x [0, 1].
ACTIONS[_sqrt("visser_coupled")] = Action(
    name=_sqrt("visser_coupled"),
    k=2,
    num_matrices=2,
    theta_bounds=(np.zeros(2), np.array([0.5, 1.0])),
    time=-1,
    spectral_iteration=sqrt_visser_coupled,
    baseline_spectral_iteration=sqrt_visser_coupled_baseline,
    matrix_iteration=sqrt_visser_coupled_matrix,
)

# Newton ----------------------------------------------------------------
def sqrt_newton(inp: ActionInput):
    """Spectral Newton sqrt step: ``0.5 (a0 y + a1 λ / y)`` with ``a = theta``.

    Applied to the first spectrum; the other spectra are passed through.
    """
    y = inp.current_spectra[0]
    a = inp.theta
    nxt = 0.5 * (a[0] * y + a[1] * inp.a_spectrum / y)
    return _splice(inp.current_spectra, nxt)

def sqrt_newton_baseline(inp: ActionInput):
    """Plain Newton sqrt step, ``theta = [1, 1]``: ``0.5 (y + λ / y)``.

    Writes ``inp.theta`` in place and returns the stepped spectra with it.
    """
    inp.theta[:] = [1, 1]
    stepped = sqrt_newton(inp)
    return stepped, inp.theta

def sqrt_newton_matrix(inp: MatrixActionInput):
    """Timed matrix Newton sqrt step: ``0.5 (a0 Y + a1 Y^{-1} A)``.

    ``Y = current_matrices[0]``, ``A = a_matrix``, ``a = theta``; the inverse
    comes from ``_safe_inverse``. Returns the spliced matrix tuple and the
    elapsed time in seconds.
    """

    def _step():
        Y1 = inp.current_matrices[0]
        A = inp.a_matrix
        a = inp.theta
        return 0.5*(a[0]*Y1 + a[1]*_safe_inverse(Y1)@A)

    if "cuda" in str(inp.current_matrices[0].device):
        # CUDA work is asynchronous: time with events, then synchronize.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out = _step()
        end.record()
        torch.cuda.synchronize()
        elapsed_time = start.elapsed_time(end) / 1000  # ms -> s
    else:
        # perf_counter is monotonic and high-resolution, unlike time.time.
        start_time = time.perf_counter()
        out = _step()
        elapsed_time = time.perf_counter() - start_time

    return _splice(inp.current_matrices, out), elapsed_time

# Newton sqrt step; θ = (a0, a1) sampled from [0, 10]^2.
ACTIONS[_sqrt("newton")] = Action(
    name=_sqrt("newton"),
    k=2,
    theta_bounds=(np.zeros(2), np.array([10, 10])),
    time=-1,
    spectral_iteration=sqrt_newton,
    baseline_spectral_iteration=sqrt_newton_baseline,
    matrix_iteration=sqrt_newton_matrix,
)

# Coupled Newton ------------------------------------------------------
def sqrt_newton_coupled(inp: ActionInput):
    """Spectral form of the coupled Newton step (mirrors
    ``sqrt_newton_matrix_coupled``).

        y1 <- 0.5 * (a0*y1 + a1*λ/y1)
        y2 <- 0.5 * (a0*λ*y2 + a1/y2)

    with ``a = theta``. ``a1`` scales the ``1/y2`` term, matching the matrix
    version's ``a1 * inv(Y2)``, so for simultaneously diagonalizable iterates
    spectra and matrices evolve identically.
    NOTE(review): the ``y2`` update's fixed point satisfies
    ``y2**2 = a1 / (2 - a0*λ)``, not ``1/λ``. Confirm this is intended.
    """
    y1, y2 = inp.current_spectra[:2]
    lam = inp.a_spectrum
    a = inp.theta
    out1 = 0.5 * (a[0] * y1 + a[1] * lam / y1)
    out2 = 0.5 * (a[0] * lam * y2 + a[1] / y2)
    return _splice(inp.current_spectra, out1, out2)

def sqrt_newton_baseline_coupled(inp: ActionInput):
    """Baseline coupled Newton step with ``theta = [1, 1]``.

    Writes the parameters into ``inp.theta`` in place and advances *both*
    coupled spectra via ``sqrt_newton_coupled``. Calling the uncoupled
    ``sqrt_newton`` would leave the second spectrum untouched.
    """
    inp.theta[:] = [1, 1]
    return sqrt_newton_coupled(inp), inp.theta

def sqrt_newton_matrix_coupled(inp: MatrixActionInput):
    """Timed matrix coupled Newton step.

        Y1 <- 0.5 (a0 Y1 + a1 Y1^{-1} A)
        Y2 <- 0.5 (a0 A Y2 + a1 Y2^{-1})

    with ``a = theta`` and inverses from ``_safe_inverse``. Returns the
    spliced matrix tuple and the elapsed time in seconds.
    """

    def _step():
        Y1, Y2 = inp.current_matrices[:2]
        A = inp.a_matrix
        a = inp.theta
        out1 = 0.5*(a[0]*Y1 + a[1]*_safe_inverse(Y1)@A)
        out2 = 0.5*(a[0]*A@Y2 + a[1]*_safe_inverse(Y2))
        return out1, out2

    if "cuda" in str(inp.current_matrices[0].device):
        # CUDA work is asynchronous: time with events, then synchronize.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out1, out2 = _step()
        end.record()
        torch.cuda.synchronize()
        elapsed_time = start.elapsed_time(end) / 1000  # ms -> s
    else:
        # perf_counter is monotonic and high-resolution, unlike time.time.
        start_time = time.perf_counter()
        out1, out2 = _step()
        elapsed_time = time.perf_counter() - start_time

    return _splice(inp.current_matrices, out1, out2), elapsed_time

# Coupled Newton sqrt step on two matrices; θ sampled from [0, 10]^2.
ACTIONS[_sqrt("newton_coupled")] = Action(
    name=_sqrt("newton_coupled"),
    k=2,
    num_matrices=2,
    theta_bounds=(np.zeros(2), np.array([10, 10])),
    time=-1,
    spectral_iteration=sqrt_newton_coupled,
    baseline_spectral_iteration=sqrt_newton_baseline_coupled,
    matrix_iteration=sqrt_newton_matrix_coupled,
)

# --------------------------------------------------------------------------- #
#                        INVERSE Newton–Schulz family                         #
# --------------------------------------------------------------------------- #
def inv_ns(inp: ActionInput):
    """Spectral Newton–Schulz step for the inverse: ``a0 x - a1 x λ x``.

    ``a = theta``. Applied to the first spectrum; the other spectra are
    passed through.
    """
    x = inp.current_spectra[0]
    lam = inp.a_spectrum
    alpha, beta = inp.theta[0], inp.theta[1]
    return _splice(inp.current_spectra, alpha * x - beta * (x * lam * x))


def inv_ns_baseline(inp: ActionInput):
    """Classical Newton–Schulz inverse step, ``theta = [2, 1]``: ``2x - xλx``.

    Writes ``inp.theta`` in place and returns the stepped spectra with it.
    """
    inp.theta[:] = [2.0, 1.0]
    stepped = inv_ns(inp)
    return stepped, inp.theta


def inv_ns_matrix(inp: MatrixActionInput):
    """Timed matrix Newton–Schulz inverse step: ``a0 X - a1 X A X``.

    ``X = current_matrices[0]``, ``A = a_matrix``, ``a = theta``. Returns
    the spliced matrix tuple and the elapsed time in seconds.
    """

    def _step():
        X, A = inp.current_matrices[0], inp.a_matrix
        a = inp.theta
        return a[0] * X - a[1] * (X @ A @ X)

    if "cuda" in str(inp.current_matrices[0].device):
        # CUDA work is asynchronous: time with events, then synchronize.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out = _step()
        end.record()
        torch.cuda.synchronize()
        elapsed_time = start.elapsed_time(end) / 1000  # ms -> s
    else:
        # perf_counter is monotonic and high-resolution, unlike time.time.
        start_time = time.perf_counter()
        out = _step()
        elapsed_time = time.perf_counter() - start_time

    return _splice(inp.current_matrices, out), elapsed_time


# Newton–Schulz inverse step; θ = (a0, a1) sampled from [0, 10]^2.
ACTIONS[_inv("ns")] = Action(
    name=_inv("ns"),
    k=2,
    theta_bounds=(np.zeros(2), np.array([10, 10])),
    time=-1,
    spectral_iteration=inv_ns,
    baseline_spectral_iteration=inv_ns_baseline,
    matrix_iteration=inv_ns_matrix,
)


# -- Chebyshev extension -----------------------------------------------------
def inv_ns_chebyshev(inp: ActionInput):
    """Spectral Chebyshev-type inverse step.

    ``a0 x + x (-a1 (xλ) + a2 (xλ)^2)`` with ``a = theta``, applied to the
    first spectrum. The other spectra are passed through.
    """
    x = inp.current_spectra[0]
    a = inp.theta
    prod = x * inp.a_spectrum
    correction = -a[1] * prod + a[2] * prod * prod
    return _splice(inp.current_spectra, a[0] * x + x * correction)


def inv_ns_chebyshev_baseline(inp: ActionInput):
    """Chebyshev inverse step with ``theta = [3, 3, 1]``: ``x(3 - 3xλ + (xλ)^2)``.

    Writes ``inp.theta`` in place and returns the stepped spectra with it.
    """
    inp.theta[:] = [3.0, 3.0, 1.0]
    stepped = inv_ns_chebyshev(inp)
    return stepped, inp.theta


def inv_ns_chebyshev_matrix(inp: MatrixActionInput):
    """Timed matrix Chebyshev-type inverse step.

    ``a0 X + (-a1 XA + a2 (XA)^2) X`` with ``X = current_matrices[0]``,
    ``A = a_matrix``, ``a = theta``. Returns the spliced matrix tuple and the
    elapsed time in seconds.
    """

    def _step():
        X, A = inp.current_matrices[0], inp.a_matrix
        a = inp.theta
        XA = X @ A
        Xs = XA @ XA
        return a[0] * X + (-a[1] * XA + a[2] * Xs) @ X

    if "cuda" in str(inp.current_matrices[0].device):
        # CUDA work is asynchronous: time with events, then synchronize.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out = _step()
        end.record()
        torch.cuda.synchronize()
        elapsed_time = start.elapsed_time(end) / 1000  # ms -> s
    else:
        # perf_counter is monotonic and high-resolution, unlike time.time.
        start_time = time.perf_counter()
        out = _step()
        elapsed_time = time.perf_counter() - start_time

    return _splice(inp.current_matrices, out), elapsed_time


# Chebyshev inverse step; θ = (a0, a1, a2) sampled from [0, 10]^3.
ACTIONS[_inv("ns_chebyshev")] = Action(
    name=_inv("ns_chebyshev"),
    k=3,
    theta_bounds=(np.zeros(3), np.array([10, 10, 10])),
    time=-1,
    spectral_iteration=inv_ns_chebyshev,
    baseline_spectral_iteration=inv_ns_chebyshev_baseline,
    matrix_iteration=inv_ns_chebyshev_matrix,
)


# =========================================================================== #
#                             COUPLED  √A ITERATORS                           #
# =========================================================================== #
# -- Denman–Beavers ----------------------------------------------------------
def sqrt_db(inp: ActionInput):
    """Spectral scaled Denman–Beavers step on the first two spectra.

    With ``y = a0 Y`` and ``z = a1 Z`` (``a = theta``):
    ``Y <- 0.5 (y + 1/z)`` and ``Z <- 0.5 (z + 1/y)``.
    """
    Y, Z = inp.current_spectra[:2]
    a0, a1 = inp.theta[:2]
    y = a0 * Y
    z = a1 * Z
    new_y = 0.5 * (y + 1.0 / z)
    new_z = 0.5 * (z + 1.0 / y)
    return _splice(inp.current_spectra, new_y, new_z)


def sqrt_db_baseline(inp: ActionInput):
    """Determinant-scaled Denman–Beavers baseline.

    Both parameters are set to ``exp(-1/(2n) * sum(log|y| + log|z|))``, i.e.
    ``|det Y det Z|^(-1/(2n))`` on the spectra. A ``1e-14`` offset keeps the
    logs finite. Writes ``inp.theta`` in place and returns the stepped
    spectra with it.
    """
    Y, Z = inp.current_spectra[:2]
    eps = 1e-14
    logs = np.log(np.abs(Y) + eps) + np.log(np.abs(Z) + eps)
    scale = np.exp(-0.5 / len(Y) * np.sum(logs))
    inp.theta[:] = [scale, scale]
    stepped = sqrt_db(inp)
    return stepped, inp.theta


def sqrt_db_matrix(inp: MatrixActionInput):
    """Timed matrix scaled Denman–Beavers step.

    With ``Y = a0 Y``, ``Z = a1 Z`` (``a = theta``):
    ``Y <- 0.5 (Y + Z^{-1})`` and ``Z <- 0.5 (Z + Y^{-1})``, with inverses
    from ``_safe_inverse``. Returns the spliced matrix tuple and the elapsed
    time in seconds.
    """

    def _step():
        Y, Z = inp.current_matrices[:2]
        a0, a1 = inp.theta[:2]
        Y = a0 * Y
        Z = a1 * Z
        invZ = _safe_inverse(Z)
        invY = _safe_inverse(Y)
        return 0.5 * (Y + invZ), 0.5 * (Z + invY)

    if "cuda" in str(inp.current_matrices[0].device):
        # CUDA work is asynchronous: time with events, then synchronize.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        Yn, Zn = _step()
        end.record()
        torch.cuda.synchronize()
        elapsed_time = start.elapsed_time(end) / 1000  # ms -> s
    else:
        # perf_counter is monotonic and high-resolution, unlike time.time.
        start_time = time.perf_counter()
        Yn, Zn = _step()
        elapsed_time = time.perf_counter() - start_time

    return _splice(inp.current_matrices, Yn, Zn), elapsed_time


# Denman–Beavers sqrt step on two matrices; θ sampled from [0, 1]^2.
ACTIONS[_sqrt("db")] = Action(
    name=_sqrt("db"),
    k=2,
    num_matrices=2,
    theta_bounds=(np.zeros(2), np.ones(2)),
    time=-1,
    spectral_iteration=sqrt_db,
    baseline_spectral_iteration=sqrt_db_baseline,
    matrix_iteration=sqrt_db_matrix,
)


# -- Newton–Schulz‑Variant pair ----------------------------------------------
def sqrt_nsv(inp: ActionInput):
    """Spectral coupled Newton–Schulz sqrt step on the first two spectra.

    With ``m = a0 - a1 Y Z`` (``a = theta``): ``Y <- 0.5 Y m`` and
    ``Z <- 0.5 Z m``.
    """
    Y, Z = inp.current_spectra[:2]
    a0, a1 = inp.theta[:2]
    m = a0 - a1 * Y * Z
    return _splice(inp.current_spectra, 0.5 * Y * m, 0.5 * Z * m)


def sqrt_nsv_baseline(inp: ActionInput):
    """Classical coupled Newton–Schulz parameters ``theta = [3, 1]``.

    Writes ``inp.theta`` in place and returns the stepped spectra with it.
    """
    inp.theta[:] = [3, 1]
    stepped = sqrt_nsv(inp)
    return stepped, inp.theta


def sqrt_nsv_matrix(inp: MatrixActionInput):
    """Timed matrix coupled Newton–Schulz sqrt step.

    With ``M = a0 I - a1 Z Y`` (``a = theta``): ``Y <- 0.5 Y M`` and
    ``Z <- 0.5 M Z``. Returns the spliced matrix tuple and the elapsed time
    in seconds.
    """

    def _step():
        Y, Z = inp.current_matrices[:2]
        ZY = Z @ Y
        a0, a1 = inp.theta[:2]
        intermediate = a0 * torch.eye(Y.shape[-1], device=Y.device, dtype=Y.dtype) - a1 * ZY
        return 0.5 * Y @ intermediate, 0.5 * intermediate @ Z

    if "cuda" in str(inp.current_matrices[0].device):
        # CUDA work is asynchronous: time with events, then synchronize.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        Yn, Zn = _step()
        end.record()
        torch.cuda.synchronize()
        elapsed_time = start.elapsed_time(end) / 1000  # ms -> s
    else:
        # perf_counter is monotonic and high-resolution, unlike time.time.
        start_time = time.perf_counter()
        Yn, Zn = _step()
        elapsed_time = time.perf_counter() - start_time

    return _splice(inp.current_matrices, Yn, Zn), elapsed_time


# Coupled Newton–Schulz sqrt step on two matrices; θ sampled from [0, 5]^2.
ACTIONS[_sqrt("nsv")] = Action(
    name=_sqrt("nsv"),
    k=2,
    num_matrices=2,
    theta_bounds=(np.zeros(2), 5.0 * np.ones(2)),
    time=-1,
    spectral_iteration=sqrt_nsv,
    baseline_spectral_iteration=sqrt_nsv_baseline,
    matrix_iteration=sqrt_nsv_matrix,
)

# -- coupled --------------------------------------------------------------------------------
def couple(inp: ActionInput):
    """Spectral coupling: keep ``Y`` and set the second spectrum to ``Y / λ``.

    Requires at least two current spectra; the second one's old value is
    ignored.
    """
    Y, _ = inp.current_spectra[:2]
    return _splice(inp.current_spectra, Y, Y / inp.a_spectrum)


def couple_baseline(inp: ActionInput):
    """Parameter-free: apply ``couple`` and return ``inp.theta`` unchanged."""
    stepped = couple(inp)
    return stepped, inp.theta


def couple_matrix(inp: MatrixActionInput):
    """Timed matrix coupling: keep ``Y`` and set the second matrix to ``Y A^{-1}``.

    Requires at least two current matrices; the second one's old value is
    ignored. The inverse comes from ``_safe_inverse``. Returns the spliced
    matrix tuple and the elapsed time in seconds.
    """

    def _step():
        Y, _ = inp.current_matrices[:2]
        Ainv = _safe_inverse(inp.a_matrix)
        return Y, Y @ Ainv

    if "cuda" in str(inp.current_matrices[0].device):
        # CUDA work is asynchronous: time with events, then synchronize.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        Y, K = _step()
        end.record()
        torch.cuda.synchronize()
        elapsed_time = start.elapsed_time(end) / 1000  # ms -> s
    else:
        # perf_counter is monotonic and high-resolution, unlike time.time.
        start_time = time.perf_counter()
        Y, K = _step()
        elapsed_time = time.perf_counter() - start_time

    return _splice(inp.current_matrices, Y, K), elapsed_time


# Parameter-free coupling step on two matrices (k = 0, empty bounds).
# NOTE(review): this entry is keyed by the literal "sqrt_couple", while every
# other entry is keyed by its _sqrt(...)/_sign(...) name. Confirm that
# _sqrt("couple") == "sqrt_couple", or lookups by name will miss it.
ACTIONS["sqrt_couple"] = Action(
    name=_sqrt("couple"),
    k=0,
    num_matrices=2,
    theta_bounds=([],[]),
    time=-1,
    spectral_iteration=couple,
    baseline_spectral_iteration=couple_baseline,
    matrix_iteration=couple_matrix,
)

# =========================================================================== #
#                             1/3 root ITERATORS                              #
# =========================================================================== #
# -- Newton ----------------------------------------------------------

def proot_newton(inp: ActionInput):
    """Spectral cube-root step ``(a0 y + a1 λ / y^2) / 3`` with ``a = theta``.

    Applied to the first spectrum; the other spectra are passed through.
    """
    y = inp.current_spectra[0]
    a0, a1 = inp.theta[:2]
    nxt = (a0 * y + a1 * inp.a_spectrum / y / y) / 3
    return _splice(inp.current_spectra, nxt)

def proot_newton_baseline(inp: ActionInput):
    """Newton's method for the cube root as the baseline.

    ``theta = [2, 1]`` makes ``proot_newton`` compute ``(2 y + λ / y^2) / 3``,
    Newton's step for ``y^3 = λ``. Its derivative vanishes at the root, so
    convergence is quadratic. The swapped choice ``[1, 2]`` has the same
    fixed point, but the derivative there is ``-1``, so it is not locally
    contractive. Writes ``inp.theta`` in place and returns the stepped
    spectra with it.
    """
    inp.theta[:] = [2, 1]
    return proot_newton(inp), inp.theta

def proot_newton_matrix(inp: MatrixActionInput):
    """Timed matrix cube-root step ``(a0 Y + a1 A Y^{-1} Y^{-1}) / 3``.

    ``Y = current_matrices[0]``, ``A = a_matrix``, ``a = theta``; the inverse
    comes from ``_safe_inverse``. Returns the spliced matrix tuple and the
    elapsed time in seconds.
    """

    def _step():
        Y = inp.current_matrices[0]
        A = inp.a_matrix
        a0, a1 = inp.theta[:2]
        Yinv = _safe_inverse(Y)
        return (a0*Y + a1*A@Yinv@Yinv)/3

    if "cuda" in str(inp.current_matrices[0].device):
        # CUDA work is asynchronous: time with events, then synchronize.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        Yn = _step()
        end.record()
        torch.cuda.synchronize()
        elapsed_time = start.elapsed_time(end) / 1000  # ms -> s
    else:
        # perf_counter is monotonic and high-resolution, unlike time.time.
        start_time = time.perf_counter()
        Yn = _step()
        elapsed_time = time.perf_counter() - start_time

    return _splice(inp.current_matrices, Yn), elapsed_time


# Newton cube-root step; θ = (a0, a1) sampled from [0, 1]^2.
# NOTE(review): Newton's own coefficients (a0, a1) = (2, 1) lie outside these
# bounds. Confirm the search box is intentional.
ACTIONS[_proot("newton")] = Action(
    name=_proot("newton"),
    k=2,
    theta_bounds=(np.zeros(2), 1.0 * np.ones(2)),
    time=-1,
    spectral_iteration=proot_newton,
    baseline_spectral_iteration=proot_newton_baseline,
    matrix_iteration=proot_newton_matrix,
)

# -- Visser ----------------------------------------------------------

def proot_visser(inp: ActionInput):
    """Parameterised Visser cube-root step on the spectra.

    With ``x = inp.current_spectra[0]``, ``a = inp.a_spectrum`` and
    ``a0, a1 = inp.theta[:2]``, the first spectrum becomes
    ``a0*x + a1*(a - x*x*x)``; the remaining spectra are unchanged.
    """
    a0, a1 = inp.theta[:2]
    x = inp.current_spectra[0]
    residual = inp.a_spectrum - x * x * x
    return _splice(inp.current_spectra, a0 * x + a1 * residual)

def proot_visser_baseline(inp: ActionInput):
    """Run the reference Visser step with ``theta = (1, 1/3)``.

    Overwrites ``inp.theta`` in place and returns ``(new_spectra, inp.theta)``.
    """
    inp.theta[:] = [1, 1 / 3]
    new_spectra = proot_visser(inp)
    return new_spectra, inp.theta

def proot_visser_matrix(inp: MatrixActionInput):
    """Apply one parameterised Visser cube-root step to the matrices and time it.

    With ``X = inp.current_matrices[0]``, ``A = inp.a_matrix`` and
    ``a0, a1 = inp.theta[:2]`` this computes
    ``X_next = a0*X + a1*(A - X @ X @ X)``.

    Returns:
        ``(matrices, elapsed_time)``:
        - ``matrices`` is ``inp.current_matrices`` with its first entry
          replaced by ``X_next``.
        - ``elapsed_time`` is the step time in seconds. It is measured with
          CUDA events on CUDA tensors and with ``time.perf_counter``
          otherwise.
    """

    def _step():
        X = inp.current_matrices[0]
        A = inp.a_matrix
        a0, a1 = inp.theta[:2]
        return a0 * X + a1 * (A - X @ X @ X)

    if "cuda" in str(inp.current_matrices[0].device):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        Xn = _step()
        end.record()
        # Wait for the asynchronous kernels; elapsed_time is in milliseconds.
        torch.cuda.synchronize()
        elapsed_time = start.elapsed_time(end) / 1000
    else:
        # Monotonic, high-resolution timer (time.time() is neither).
        start_time = time.perf_counter()
        Xn = _step()
        elapsed_time = time.perf_counter() - start_time

    return _splice(inp.current_matrices, Xn), elapsed_time

# Visser cube-root action: two coefficients (a0, a1), each bounded to [0, 1].
ACTIONS[_proot("visser")] = Action(
    name=_proot("visser"),
    k=2,
    time=-1,
    theta_bounds=(np.zeros(2), np.full(2, 1.0)),
    spectral_iteration=proot_visser,
    baseline_spectral_iteration=proot_visser_baseline,
    matrix_iteration=proot_visser_matrix,
)

# Iannazzo iteration ----------------------------------------------
def proot_iannazzo(inp: ActionInput):
    """Parameterised coupled Iannazzo cube-root step on the spectra.

    With ``x, y = inp.current_spectra[:2]``, ``a0, a1 = inp.theta[:2]`` and
    ``m = a0 + a1*y``, the first two spectra become:

        x <- x*m / 3
        y <- y / (m/3)**3

    Any further spectra are passed through unchanged.
    """
    x, y = inp.current_spectra[:2]
    a0, a1 = inp.theta[:2]
    mix = a0 + a1 * y
    new_x = x * mix / 3
    new_y = y / ((mix / 3) ** 3)
    return _splice(inp.current_spectra, new_x, new_y)

def proot_iannazzo_baseline(inp: ActionInput):
    """Run the reference Iannazzo step with ``theta = (2, 1)``.

    Overwrites ``inp.theta`` in place and returns ``(new_spectra, inp.theta)``.
    """
    inp.theta[:] = [2, 1]
    new_spectra = proot_iannazzo(inp)
    return new_spectra, inp.theta

def proot_iannazzo_matrix(inp: MatrixActionInput):
    """Apply one parameterised coupled Iannazzo step to the matrices and time it.

    With ``X, Y = inp.current_matrices[:2]`` and ``a0, a1 = inp.theta[:2]``:

        Yav    = (a0*I + a1*Y) / 3
        X_next = X @ Yav
        Y_next = Yav^{-1} @ Yav^{-1} @ Yav^{-1} @ Y

    ``Yav^{-1}`` comes from ``_safe_inverse``.

    Returns:
        ``(matrices, elapsed_time)``:
        - ``matrices`` is ``inp.current_matrices`` with its first two entries
          replaced by ``X_next, Y_next``.
        - ``elapsed_time`` is the step time in seconds. It is measured with
          CUDA events on CUDA tensors and with ``time.perf_counter``
          otherwise.
    """

    def _step():
        X, Y = inp.current_matrices[:2]
        a0, a1 = inp.theta[:2]
        # Identity of the trailing matrix size; broadcasts over batch dims.
        I = torch.eye(X.shape[-1], dtype=X.dtype, device=X.device)
        Yav = (a0 * I + a1 * Y) / 3
        Yavinv = _safe_inverse(Yav)
        Xn = X @ Yav
        Yn = Yavinv @ Yavinv @ Yavinv @ Y
        return Xn, Yn

    if "cuda" in str(inp.current_matrices[0].device):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        Xn, Yn = _step()
        end.record()
        # Wait for the asynchronous kernels; elapsed_time is in milliseconds.
        torch.cuda.synchronize()
        elapsed_time = start.elapsed_time(end) / 1000
    else:
        # Monotonic, high-resolution timer (time.time() is neither).
        start_time = time.perf_counter()
        Xn, Yn = _step()
        elapsed_time = time.perf_counter() - start_time

    return _splice(inp.current_matrices, Xn, Yn), elapsed_time

# Coupled Iannazzo action: acts on two matrices, two coefficients in [0, 1].
ACTIONS[_proot("iannazzo")] = Action(
    name=_proot("iannazzo"),
    num_matrices=2,
    k=2,
    time=-1,
    theta_bounds=(np.zeros(2), np.full(2, 1.0)),
    spectral_iteration=proot_iannazzo,
    baseline_spectral_iteration=proot_iannazzo_baseline,
    matrix_iteration=proot_iannazzo_matrix,
)

# Coupling -----------------------------------------------

def proot_couple(inp: ActionInput):
    """Recompute the second spectrum from the first: ``y <- a / x**3``.

    ``x`` is ``inp.current_spectra[0]`` and ``a`` is ``inp.a_spectrum``; ``x``
    itself is kept. At least two spectra must be present.
    """
    x, _ = inp.current_spectra[:2]
    coupled = inp.a_spectrum / (x ** 3)
    return _splice(inp.current_spectra, x, coupled)

def proot_couple_baseline(inp: ActionInput):
    """Baseline coupling: same as :func:`proot_couple`; ``inp.theta`` untouched."""
    new_spectra = proot_couple(inp)
    return new_spectra, inp.theta

def proot_couple_matrix(inp: MatrixActionInput):
    """Recompute the second matrix from the first and time it.

    With ``X = inp.current_matrices[0]`` and ``A = inp.a_matrix`` this computes
    ``Y_next = A @ X^{-1} @ X^{-1} @ X^{-1}``. ``X^{-1}`` comes from
    ``_safe_inverse``.

    Returns:
        ``(matrices, elapsed_time)``:
        - ``matrices`` is ``inp.current_matrices`` with its first two entries
          replaced by ``X, Y_next``.
        - ``elapsed_time`` is the step time in seconds. It is measured with
          CUDA events on CUDA tensors and with ``time.perf_counter``
          otherwise.
    """

    def _step():
        X = inp.current_matrices[0]
        A = inp.a_matrix
        Xinv = _safe_inverse(X)
        return X, A @ Xinv @ Xinv @ Xinv

    if "cuda" in str(inp.current_matrices[0].device):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        X, Yn = _step()
        end.record()
        # Wait for the asynchronous kernels; elapsed_time is in milliseconds.
        torch.cuda.synchronize()
        elapsed_time = start.elapsed_time(end) / 1000
    else:
        # Monotonic, high-resolution timer (time.time() is neither).
        start_time = time.perf_counter()
        X, Yn = _step()
        elapsed_time = time.perf_counter() - start_time

    return _splice(inp.current_matrices, X, Yn), elapsed_time

# Cube-root coupling action: two matrices, no tunable coefficients (k=0).
ACTIONS[_proot("couple")] = Action(
    name=_proot("couple"),
    num_matrices=2,
    k=0,
    time=-1,
    theta_bounds=([], []),
    spectral_iteration=proot_couple,
    baseline_spectral_iteration=proot_couple_baseline,
    matrix_iteration=proot_couple_matrix,
)

# swap ----------------------------------------------
def proot_swap(inp: ActionInput):
    """Exchange the first two spectra; any further spectra are passed through."""
    first, second = inp.current_spectra[:2]
    return _splice(inp.current_spectra, second, first)

def proot_swap_baseline(inp: ActionInput):
    """Baseline swap: same as :func:`proot_swap`; ``inp.theta`` untouched."""
    new_spectra = proot_swap(inp)
    return new_spectra, inp.theta

def proot_swap_matrix(inp: MatrixActionInput):
    """Exchange the first two matrices.

    No arithmetic is performed, so the reported elapsed time is ``0``.
    """
    first, second = inp.current_matrices[:2]
    swapped = _splice(inp.current_matrices, second, first)
    return swapped, 0

# Swap action: exchanges the two tracked matrices, no coefficients (k=0).
ACTIONS[_proot("swap")] = Action(
    name=_proot("swap"),
    num_matrices=2,
    k=0,
    time=-1,
    theta_bounds=([], []),
    spectral_iteration=proot_swap,
    baseline_spectral_iteration=proot_swap_baseline,
    matrix_iteration=proot_swap_matrix,
)