from __future__ import annotations

"""Baselines catalogue (clean version)
=====================================
Returns *pure* tuples `(action_fn, theta_vec)` so the calling code can
invoke the ACTION directly.  Registry keys are meant to be built with the
helpers `_sign`, `_sqrt`, `_inv` and `_proot`, which keeps the mapping
consistent with the registry defined in **actions.py** (imported below).

Public API
----------
```
baselines(baseline_name: str, T: int, size: int = 10, device: str = "cuda",
          modify_time: bool = False, baseline_args: list | None = None) -> list[tuple]
adaptive_baselines(baseline_name: str, A: torch.Tensor, T: int, size, device) -> list[tuple]
```
Each element of the returned list is a tuple rather than the dict used by
the old version.  Pass `baseline_args` by keyword: the third positional
parameter of `baselines` is `size`, not the baseline arguments.
"""

from typing import Callable, List

import numpy as np
import torch

from sorcerun.git_utils import get_repo
import sys

# Put the git repository's working directory on sys.path so the
# project-local modules imported below (``make_algorithm.actions`` and
# ``globals``) can be resolved.  These two lines must run before those
# imports.  NOTE(review): assumes get_repo() finds the repository that
# contains this file — confirm in sorcerun.git_utils.
ROOT = get_repo().working_dir
sys.path.append(ROOT)

from make_algorithm.actions import ACTIONS, estimate_relative_times  # global registry created in actions.py
# Naming helpers used to build every registry key in this module.
from globals import _sign, _sqrt, _inv, _proot

# Only the two façade functions are exported by ``from ... import *``.
__all__ = ["baselines", "adaptive_baselines"]


# -----------------------------------------------------------------------------
# Individual baseline builder functions
# -----------------------------------------------------------------------------


def _baseline_sign_newton(T: int, _args: list | None) -> List[tuple]:
    """Unscaled baseline for the ``newton`` sign action.

    Every one of the ``T`` steps uses theta = [1.0], each with its own
    freshly allocated array.  ``_args`` is accepted only for interface
    uniformity with the other builders and is ignored.
    """
    steps: List[tuple] = []
    for _step in range(T):
        steps.append((ACTIONS[_sign("newton")], np.array([1.0])))
    return steps


def _baseline_sign_ns(T: int, _args: list | None) -> List[tuple]:
    """Constant baseline for the ``ns`` sign action: theta = [0.5, 1.0].

    ``_args`` is ignored.  Each of the ``T`` steps gets its own theta array,
    so an in-place edit of one step's parameters cannot leak into the others.
    """
    return [(ACTIONS[_sign("ns")], np.array([0.5, 1.0])) for _ in range(T)]


def _baseline_sign_scaled_newton(T: int, args: list | None) -> List[tuple]:
    """Scaled baseline for the ``newton`` sign action.

    The scale schedule matches the (sub-optimal) Byers–Xu scaling:
        mu_0     = 1 / sqrt(alpha * beta)
        mu_1     = sqrt(2 * sqrt(alpha * beta) / (alpha + beta))
        mu_{k+1} = 1 / sqrt((mu_k + 1 / mu_k) / 2)      for k >= 1
    and step k emits theta = [mu_k].

    Args:
        T: number of steps.
        args: ``[alpha, beta]`` (presumably lower/upper bounds on the
            singular values of the input — TODO confirm); extra items ignored.

    Raises:
        ValueError: if fewer than two arguments are supplied.
    """
    if args is None or len(args) < 2:
        raise ValueError("Scaled‑Newton baseline expects [alpha, beta].")
    alpha, beta = args[:2]
    root_ab = np.sqrt(alpha * beta)
    first_scale = 1.0 / root_ab
    mu = np.sqrt(2.0 * root_ab / (alpha + beta))

    steps: List[tuple] = []
    if T > 0:
        steps.append((ACTIONS[_sign("newton")], np.array([first_scale])))
    # Remaining T - 1 steps follow the mu_k recurrence.
    for _ in range(T - 1):
        steps.append((ACTIONS[_sign("newton")], np.array([mu])))
        mu = 1.0 / np.sqrt(0.5 * (mu + 1.0 / mu))
    return steps


def _baseline_sign_scaled_ns(T: int, args: list | None) -> List[tuple]:
    """Scaled baseline for the ``ns`` sign action.

    Starting from ``rho_0 = args[0]`` each step computes
        a_k       = sqrt(3 / (1 + rho_k + rho_k**2))
        rho_{k+1} = 0.5 * a_k * rho_k * (3 - a_k**2 * rho_k**2)
    and emits theta = [0.5, a_k].  This appears to follow the Chen–Chow
    scaling of Newton–Schulz.  NOTE(review): rho_0 is presumably a lower
    bound on the singular values of the normalised input — confirm.

    Raises:
        ValueError: if ``args`` is None or empty.
    """
    if args is None or len(args) < 1:
        raise ValueError("Scaled‑NS baseline expects initial rho_0.")
    rho = float(args[0])
    schedule: List[tuple] = []
    for _ in range(T):
        scale = np.sqrt(3.0 / (1.0 + rho + rho**2))
        schedule.append((ACTIONS[_sign("ns")], np.array([0.5, scale])))
        rho = 0.5 * scale * rho * (3.0 - scale * scale * rho * rho)
    return schedule


def _baseline_sign_halley(T: int, args: list | None) -> List[tuple]:
    """Dynamically weighted Halley baseline for the ``halley`` sign action.

    The weights follow the QDWH recurrence of Nakatsukasa, Bai & Gygi.  For
    the current bound ``l``:
        d = (4 (1 - l^2) / l^4)^(1/3)
        a = sqrt(1 + d) + 0.5 sqrt(8 - 4d + 8 (2 - l^2) / (l^2 sqrt(1 + d)))
        b = (a - 1)^2 / 4,    c = a + b - 1
        l <- l (a + b l^2) / (1 + c l^2)
    Each step emits theta = [a, b, c].

    Args:
        T: number of steps.
        args: ``[lambda_0]`` with 0 < lambda_0 <= 1 (presumably a lower bound
            on the smallest singular value of the normalised input — TODO
            confirm).

    Raises:
        ValueError: if ``args`` is missing or lambda_0 is outside (0, 1].
    """
    if args is None or len(args) < 1:
        raise ValueError("Halley baseline expects initial lambda_0.")
    l0 = float(args[0])
    # Outside (0, 1] the recurrence divides by zero (l = 0) or raises a
    # negative number to the 1/3 power (l > 1), producing NaN/complex thetas.
    if not 0.0 < l0 <= 1.0:
        raise ValueError(f"Halley baseline expects 0 < lambda_0 <= 1, got {l0}.")
    out = []
    for _ in range(T):
        d = (4.0 * (1.0 - l0 * l0) / (l0**4)) ** (1.0 / 3.0)
        a = np.sqrt(1.0 + d) + 0.5 * np.sqrt(
            8.0 - 4.0 * d + 8.0 * (2.0 - l0 * l0) / (l0 * l0 * np.sqrt(1.0 + d))
        )
        b = (a - 1.0) ** 2 / 4.0
        c = a + b - 1.0
        out.append((ACTIONS[_sign("halley")], np.array([a, b, c])))
        # In exact arithmetic l stays in (0, 1] and converges to 1.  Once it
        # has converged, rounding can land it at 1 + eps, which would make
        # 1 - l^2 negative and d NaN/complex on the next step, so clamp it.
        l0 = min(l0 * (a + b * l0 * l0) / (1.0 + c * l0 * l0), 1.0)
    return out


def _baseline_sign_newton_variant(T: int, args: list | None) -> List[tuple]:
    """Scaled baseline for the ``newton_variant`` sign action.

    The schedule is the reciprocal of the scaled-Newton one:
        z_0     = 1 / sqrt(alpha * beta)
        z_1     = sqrt((alpha + beta) / (2 sqrt(alpha * beta)))   (= 1 / mu_1)
        z_{k+1} = sqrt((z_k + 1 / z_k) / 2)                      for k >= 1
    and step k emits theta = [z_k].  z_0 makes a spectrum in [alpha, beta]
    symmetric about 1.  NOTE(review): the remaining scales assume the action
    maps x -> 2 z x / (1 + z^2 x^2) — confirm in actions.py.  Under that
    assumption the first step leaves values in [m, 1] with
    m = 2 sqrt(alpha beta) / (alpha + beta), and z_1 = 1 / sqrt(m)
    re-symmetrises them.

    Args:
        T: number of steps.
        args: ``[alpha, beta]``; extra items ignored.

    Raises:
        ValueError: if fewer than two arguments are supplied.
    """
    if args is None or len(args) < 2:
        raise ValueError("Newton‑variant baseline expects [alpha, beta].")
    alpha, beta = args[:2]
    z0 = 1.0 / np.sqrt(alpha * beta)
    # The denominator must be parenthesised: z_1 divides by sqrt(alpha*beta).
    z1 = np.sqrt((alpha + beta) / (2.0 * np.sqrt(alpha * beta)))
    out = []
    for i in range(T):
        if i == 0:
            out.append((ACTIONS[_sign("newton_variant")], np.array([z0])))
        else:
            out.append((ACTIONS[_sign("newton_variant")], np.array([z1])))
            z1 = np.sqrt(0.5 * (z1 + 1.0 / z1))
    return out


def _baseline_inv_ns(T: int, _args: list | None) -> List[tuple]:
    """Constant baseline for the ``ns`` inverse action: theta = [2.0, 1.0].

    ``_args`` is ignored.  Each of the ``T`` steps gets its own theta array,
    so an in-place edit of one step's parameters cannot leak into the others.
    """
    return [(ACTIONS[_inv("ns")], np.array([2.0, 1.0])) for _ in range(T)]


def _baseline_inv_ns_cheb(T: int, _args: list | None) -> List[tuple]:
    """Constant baseline for the ``ns_chebyshev`` inverse action.

    Every step uses theta = [3.0, 3.0, 1.0].  ``_args`` is ignored.  Each of
    the ``T`` steps gets its own theta array, so an in-place edit of one
    step's parameters cannot leak into the others.
    """
    return [
        (ACTIONS[_inv("ns_chebyshev")], np.array([3.0, 3.0, 1.0])) for _ in range(T)
    ]


def _baseline_sqrt_db(T: int, _args: list | None) -> List[tuple]:
    """Constant baseline for the ``db`` sqrt action: theta = [1.0, 1.0].

    ``_args`` is ignored.  Each of the ``T`` steps gets its own theta array,
    so an in-place edit of one step's parameters cannot leak into the others.
    """
    return [(ACTIONS[_sqrt("db")], np.array([1.0, 1.0])) for _ in range(T)]


def _baseline_sqrt_nsv(T: int, _args: list | None) -> List[tuple]:
    """Constant baseline for the ``nsv`` sqrt action: theta = [3.0, 1.0].

    ``_args`` is ignored.  Each of the ``T`` steps gets its own theta array,
    so an in-place edit of one step's parameters cannot leak into the others.
    """
    return [(ACTIONS[_sqrt("nsv")], np.array([3.0, 1.0])) for _ in range(T)]


def _baseline_sqrt_visser(T: int, _args: list | None) -> List[tuple]:
    """Constant baseline for the ``visser`` sqrt action: theta = [0.5, 1.0].

    ``_args`` is ignored.  Each of the ``T`` steps gets its own theta array,
    so an in-place edit of one step's parameters cannot leak into the others.
    """
    return [(ACTIONS[_sqrt("visser")], np.array([0.5, 1.0])) for _ in range(T)]


def _baseline_sqrt_newton(T: int, _args: list | None) -> List[tuple]:
    """Constant baseline for the ``newton`` sqrt action: theta = [1.0, 1.0].

    ``_args`` is ignored.  Each of the ``T`` steps gets its own theta array,
    so an in-place edit of one step's parameters cannot leak into the others.
    """
    return [(ACTIONS[_sqrt("newton")], np.array([1.0, 1.0])) for _ in range(T)]

def _baseline_sqrt_visser_coupled(T: int, _args: list | None) -> List[tuple]:
    """Constant baseline for the ``visser_coupled`` sqrt action.

    Every step uses theta = [0.5, 1.0].  ``_args`` is ignored.  Each of the
    ``T`` steps gets its own theta array, so an in-place edit of one step's
    parameters cannot leak into the others.
    """
    return [
        (ACTIONS[_sqrt("visser_coupled")], np.array([0.5, 1.0])) for _ in range(T)
    ]

def _baseline_sqrt_newton_coupled(T: int, _args: list | None) -> List[tuple]:
    """Constant baseline for the ``newton_coupled`` sqrt action.

    Every step uses theta = [1.0, 1.0].  ``_args`` is ignored.  Each of the
    ``T`` steps gets its own theta array, so an in-place edit of one step's
    parameters cannot leak into the others.
    """
    return [
        (ACTIONS[_sqrt("newton_coupled")], np.array([1.0, 1.0])) for _ in range(T)
    ]

def _baseline_proot_newton(T: int, _args: list | None) -> List[tuple]:
    """Constant baseline for the ``newton`` p-th-root action: theta = [1.0, 2.0].

    ``_args`` is ignored.  Each of the ``T`` steps gets its own theta array,
    so an in-place edit of one step's parameters cannot leak into the others.
    """
    return [(ACTIONS[_proot("newton")], np.array([1.0, 2.0])) for _ in range(T)]

def _baseline_proot_visser(T: int, _args: list | None) -> List[tuple]:
    """Constant baseline for the ``visser`` p-th-root action: theta = [1.0, 1/3].

    ``_args`` is ignored.  Each of the ``T`` steps gets its own theta array,
    so an in-place edit of one step's parameters cannot leak into the others.
    """
    return [(ACTIONS[_proot("visser")], np.array([1.0, 1 / 3])) for _ in range(T)]

def _baseline_proot_iannazzo(T: int, _args: list | None) -> List[tuple]:
    """Constant baseline for the ``iannazzo`` p-th-root action: theta = [2.0, 1.0].

    ``_args`` is ignored.  Each of the ``T`` steps gets its own theta array,
    so an in-place edit of one step's parameters cannot leak into the others.
    """
    return [(ACTIONS[_proot("iannazzo")], np.array([2.0, 1.0])) for _ in range(T)]

# Map baseline name → builder function ------------------------------------------------------------

# Registry of deterministic baselines: key (built with the naming helpers)
# -> builder(T, args) returning a list of (action_fn, theta) tuples.
_BASELINE_BUILDERS: dict[str, Callable[[int, list | None], List[tuple]]] = dict(
    [
        # ---- sign family ----
        (_sign("newton"), _baseline_sign_newton),
        (_sign("ns"), _baseline_sign_ns),
        (_sign("scaled_newton"), _baseline_sign_scaled_newton),
        (_sign("scaled_ns"), _baseline_sign_scaled_ns),
        (_sign("halley"), _baseline_sign_halley),
        (_sign("newton_variant"), _baseline_sign_newton_variant),
        # ---- inverse family ----
        (_inv("ns"), _baseline_inv_ns),
        (_inv("ns_chebyshev"), _baseline_inv_ns_cheb),
        # ---- sqrt family ----
        (_sqrt("db"), _baseline_sqrt_db),
        (_sqrt("nsv"), _baseline_sqrt_nsv),
        (_sqrt("visser"), _baseline_sqrt_visser),
        (_sqrt("newton"), _baseline_sqrt_newton),
        (_sqrt("visser_coupled"), _baseline_sqrt_visser_coupled),
        (_sqrt("newton_coupled"), _baseline_sqrt_newton_coupled),
        # ---- p-th root family ----
        (_proot("newton"), _baseline_proot_newton),
        (_proot("visser"), _baseline_proot_visser),
        (_proot("iannazzo"), _baseline_proot_iannazzo),
    ]
)

# Default step count ("T") and builder arguments ("baseline_args") for every
# deterministic baseline registered in _BASELINE_BUILDERS.
DEFAULT_BASELINE_CONFIGS = {
    # sign family
    _sign("newton"): {"T": 20, "baseline_args": []},
    _sign("scaled_newton"): {"T": 20, "baseline_args": [1e-3, 1]},
    _sign("ns"): {"T": 50, "baseline_args": []},
    _sign("scaled_ns"): {"T": 50, "baseline_args": [1e-3]},
    _sign("newton_variant"): {"T": 20, "baseline_args": [1e-3, 1]},
    _sign("halley"): {"T": 20, "baseline_args": [1e-3]},
    # inverse family
    _inv("ns"): {"T": 50, "baseline_args": []},
    _inv("ns_chebyshev"): {"T": 40, "baseline_args": []},
    # sqrt family
    _sqrt("db"): {"T": 25, "baseline_args": []},
    _sqrt("nsv"): {"T": 25, "baseline_args": []},
    _sqrt("visser"): {"T": 150, "baseline_args": []},
    _sqrt("newton"): {"T": 25, "baseline_args": []},
    _sqrt("visser_coupled"): {"T": 150, "baseline_args": []},
    _sqrt("newton_coupled"): {"T": 25, "baseline_args": []},
    # p-th root family
    _proot("newton"): {"T": 50, "baseline_args": []},
    _proot("visser"): {"T": 50, "baseline_args": []},
    _proot("iannazzo"): {"T": 50, "baseline_args": []},
}

# Adaptive baselines derive their thetas from the input matrix, so their
# configs carry only a step count and no "baseline_args".
ADAPTIVE_BASELINE_CONFIGS = {
    _inv("ns_greedy"): {"T": 50},
    _sqrt("scaled_db"): {"T": 25},
}


# -----------------------------------------------------------------------------
# Public façade
# -----------------------------------------------------------------------------
def baselines(
    baseline_name: str,
    T: int,
    size: int = 10,
    device: str = "cuda",
    modify_time: bool = False,
    baseline_args: list | None = None,
) -> List[tuple]:
    """Return a list of ``(action_fn, theta)`` tuples for a *deterministic* baseline.

    Args:
        baseline_name: registry key; case-insensitive, surrounding whitespace
            is ignored.
        T: number of steps to generate.
        size: forwarded to ``estimate_relative_times`` when ``modify_time``
            is set; otherwise unused.
        device: forwarded to ``estimate_relative_times`` when ``modify_time``
            is set; otherwise unused.
        modify_time: if true, call ``estimate_relative_times`` on the
            baseline's action before returning.  Its return value is
            discarded, so it is called for its side effect (presumably
            updating the action's timing data — confirm in actions.py).
        baseline_args: builder-specific arguments; ``None`` is treated as ``[]``.

    Returns:
        The list produced by the registered builder.

    Raises:
        ValueError: if ``baseline_name`` is unknown, or the builder rejects
            ``baseline_args``.
    """
    key = baseline_name.lower().strip()
    try:
        builder = _BASELINE_BUILDERS[key]
    except KeyError as exc:
        raise ValueError(f"Unknown baseline: {baseline_name}") from exc

    baseline = builder(T, baseline_args or [])
    # An empty baseline (T <= 0) has no action to time; indexing it would raise.
    if modify_time and baseline:
        # Every builder in this module reuses a single action for all steps,
        # so timing the first step's action covers the whole baseline.
        estimate_relative_times(
            [baseline[0][0]],
            size=size,
            device=device,
        )
    return baseline

# Adaptive (matrix‑dependent) baselines ---------------------------------------


def _adaptive_inv_ns_greedy(A: torch.Tensor, T: int) -> List[tuple]:
    """Newton–Schulz inverse baseline with a greedy per-step step size.

    Starting from X = A, each step applies X <- X + a X E with E = I - A X.
    Since A X = I - E, the next residual is E + a (E^2 - E), so the ``a`` that
    minimises its Frobenius norm is a = -<E^2 - E, E> / <E^2 - E, E^2 - E>.
    Each step emits theta = [a + 1, a] for the ``ns`` inverse action.  The
    constant baseline for that action uses [2.0, 1.0], i.e. a = 1.

    Args:
        A: square matrix (``torch.eye(d, d)`` requires it).  NOTE(review):
            starting from X = A presumably relies on A being pre-normalised
            so the iteration converges — confirm with callers.
        T: number of steps.
    """
    X = A.clone()  # current iterate
    d = A.size(0)
    I = torch.eye(d, d, device=A.device, dtype=A.dtype)
    out: list[tuple] = []
    for _ in range(T):
        E = I - A @ X
        v = (E @ E - E).flatten()
        e = E.flatten()
        a = (-(v @ e) / (v @ v)).item()
        # If E^2 == E (e.g. E == 0 after exact convergence) the residual does
        # not depend on a and the ratio is 0/0 = NaN, which would poison X and
        # every later theta.  Fall back to the plain step a = 1.
        if not np.isfinite(a):
            a = 1.0
        X = X + a * X @ E
        out.append((ACTIONS[_inv("ns")], np.array([a + 1.0, a])))
    return out

def _adaptive_sqrt_db(A: torch.Tensor, T: int) -> List[tuple]:
    """Denman–Beavers square-root baseline with determinant scaling.

    Simulates the scaled iteration starting from X = A, Y = I:
        mu_k    = (|det X_k| * |det Y_k|) ** (-1 / (2 d))
        X_{k+1} = 0.5 * (mu_k X_k + (mu_k Y_k)^{-1})
        Y_{k+1} = 0.5 * (mu_k Y_k + (mu_k X_k)^{-1})
    Each step emits theta = [mu_k, mu_k] for the ``db`` sqrt action.  The
    constant baseline for that action uses [1.0, 1.0].

    Args:
        A: square matrix (``torch.eye(d, d)`` requires it); the iterates must
            stay invertible.
        T: number of steps.
    """
    X = A.clone()  # current iterate
    d = A.size(0)
    I = torch.eye(d, d, device=A.device, dtype=A.dtype)
    Y = I
    out: list[tuple] = []
    for _ in range(T):
        # Work with log|det| to avoid overflow/underflow of det itself.
        _, logXdet = torch.linalg.slogdet(X)
        _, logYdet = torch.linalg.slogdet(Y)
        scale = torch.exp(-(logXdet + logYdet) / 2 / d)
        X = scale * X
        Y = scale * Y
        # Simultaneous update: both right-hand sides use the scaled X and Y.
        X, Y = 0.5 * (X + torch.linalg.inv(Y)), 0.5 * (Y + torch.linalg.inv(X))
        # .item() copies the 0-d tensor to a Python float; np.array() cannot
        # convert a CUDA tensor directly.
        mu = scale.item()
        out.append((ACTIONS[_sqrt("db")], np.array([mu, mu])))
    return out


# Registry of adaptive (matrix-dependent) baselines: key -> builder(A, T).
# Keys are built with the same naming helpers as ADAPTIVE_BASELINE_CONFIGS,
# so the two tables cannot drift apart.
_ADAPTIVE_BUILDERS: dict[str, Callable[[torch.Tensor, int], List[tuple]]] = {
    _inv("ns_greedy"): _adaptive_inv_ns_greedy,
    _sqrt("scaled_db"): _adaptive_sqrt_db,
}


def adaptive_baselines(
    baseline_name: str,
    A: torch.Tensor,
    T: int,
    size: int = 10,
    device: str = "cuda",
) -> List[tuple]:
    """Return ``(action_fn, theta)`` tuples for a matrix-dependent baseline.

    Args:
        baseline_name: key into the adaptive registry; case-insensitive,
            surrounding whitespace is ignored.
        A: matrix the adaptive schedule is computed from.
        T: number of steps.
        size: not used by any adaptive builder.  Accepted so both façades can
            be called with the same arguments; defaults to the same value as
            in ``baselines``.
        device: not used by any adaptive builder.  Accepted so both façades
            can be called with the same arguments; defaults to the same value
            as in ``baselines``.

    Raises:
        ValueError: if ``baseline_name`` is not a registered adaptive baseline.
    """
    key = baseline_name.lower().strip()
    try:
        builder = _ADAPTIVE_BUILDERS[key]
    except KeyError as exc:
        raise ValueError(f"Unknown adaptive baseline: {baseline_name}") from exc
    return builder(A, T)