from typing import Optional, Tuple, List, Dict, Callable, Union

# import matplotlib.pyplot as plt
import numpy as np
import os
import time
import gswalk_poly_weights as gs_weights
import gswalk_low_rank as gs_low_rank
from kv import CompleteRand, Bernoulli, BlockOrthant, PairwiseMatch, RerandMR
import kv
import test_helper as helper

# F0_FUNCTIONS: Dict[str, Callable[[np.ndarray], float]] = {
#     "linear": lambda xx: xx[0] - xx[1],
#     "quadratic": lambda xx: xx[0] - xx[1] + xx[0] ** 2 + xx[1] ** 2 - 2*xx[0]*xx[1],
#     "cubic": lambda xx: xx[0] - xx[1] + xx[0] ** 2 + xx[1] ** 2 - 2*xx[0]*xx[1] + xx[0] ** 3 - xx[1] ** 3 - 3*xx[0]**2*xx[1] + 3*xx[0]*xx[1]**2,
#     "sinusoidal": lambda xx: np.sin((np.pi/3) * (1 + xx[0] - 2*xx[1])) - 6 * np.sin(np.pi * (xx[0]/3 + xx[1]/4)) + 6 * np.sin(np.pi * (xx[0]/3 + xx[1]/6)),
# }

# Coefficients of the linear index shared by the polynomial test functions.
C = np.array([1, 2, 3, 4, 5], dtype=float)


def _linear_index(xx: np.ndarray) -> float:
    """Return the dot product of ``C`` with the leading coordinates of ``xx``.

    Only the first ``min(len(xx), len(C))`` coordinates take part.  For
    vectors with at least ``len(C)`` entries this equals ``np.dot(C, xx[:5])``;
    shorter vectors are paired with the matching prefix of ``C`` instead of
    triggering a shape-mismatch error inside ``np.dot``.
    """
    k = min(len(xx), len(C))
    return np.dot(C[:k], xx[:k])


# Candidate baseline response functions f0, keyed by name.  Each maps one
# feature vector (a row of X) to a scalar.
F0_FUNCTIONS: Dict[str, Callable[[np.ndarray], float]] = {
    "linear": lambda xx: float(_linear_index(xx)),
    "quadratic": lambda xx: float(_linear_index(xx) ** 2),
    "cubic": lambda xx: float(_linear_index(xx) ** 3),
    # Reads only the first two coordinates of xx.
    "sinusoidal": lambda xx: (
        np.sin((np.pi/3) * (1 + xx[0] - 2*xx[1]))
        - 6 * np.sin(np.pi * (xx[0]/3 + xx[1]/4))
        + 6 * np.sin(np.pi * (xx[0]/3 + xx[1]/6))
    ),
}


# Canonical plotting order for designs.  A design listed here is only drawn
# when it is actually present in the results being plotted.
DESIGN_ORDER = (
    # Classic randomization designs.
    ["comprand", "blocking", "pairmatch", "rerandom"]
    # Currently disabled entries:
    #   "gswalk_kernel_lin", "gswalk_kernel_quad",
    #   "gswalk_kernel_gaus", "gswalk_kernel_exp"
    # PSOD variants.
    + ["lin_PSOD", "quad_PSOD", "gaus_PSOD", "exp_PSOD"]
    # MSOD variants.
    + ["gaus_MSOD", "exp_MSOD"]
    # Polynomial GSwalk variants.
    + ["gswalk_poly_1", "gswalk_poly_2"]
)

def make_X(n: int, d: int, rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Draw an (n, d) standard-normal matrix and shrink long rows onto the unit ball.

    Rows whose Euclidean norm exceeds 1 are rescaled to norm 1; rows that are
    already inside the unit ball are returned unchanged, so every row
    satisfies ||x_i|| <= 1.

    ``rng`` may be an existing ``np.random.Generator`` (used as-is), ``None``
    (a fresh unseeded generator is created), or any other value accepted by
    ``np.random.default_rng`` as a seed.
    """
    if isinstance(rng, np.random.Generator):
        generator = rng
    else:
        # Covers both rng=None (fresh entropy) and seed-like values.
        generator = np.random.default_rng(rng)

    samples = generator.standard_normal((n, d))
    row_norms = np.linalg.norm(samples, axis=1, keepdims=True)
    # Dividing by max(norm, 1) leaves short rows untouched.
    return samples / np.maximum(row_norms, 1.0)


def generate_X_dataset(
    num_datasets: int = 1000,
    n: int = 50,
    d: int = 2,
    seed: int = 0,
    out_path: str = "ridge_input_n50_d2_1000.npz",
):
    """
    Build ``num_datasets`` feature matrices of shape (n, d) and store them in
    a compressed .npz archive.

    All matrices are drawn sequentially from one generator seeded with
    ``seed``.  The archive holds the stacked matrices under ``X`` together
    with the scalars ``n``, ``d`` and ``seed``.  Returns ``out_path``.
    """
    generator = np.random.default_rng(seed)
    stacked = np.empty((num_datasets, n, d), dtype=float)
    # Each ``slot`` is a view into ``stacked``; fill it in place, in order.
    for slot in stacked:
        slot[...] = make_X(n, d, generator)

    np.savez_compressed(out_path, X=stacked, n=n, d=d, seed=seed)
    print(f"Saved {num_datasets} X matrices to {out_path}")
    return out_path


def generate_gsw_assignments_from_X_npz(
    input_path: str,
    phi: float = 0.5,
    w: Optional[np.ndarray] = None,
    gsw_its: int = 100,
    out_path: str = "gsw_assignments.npz",
    balance: bool = False,
    exp: bool = False,
    low_rank_approx: bool = False
) -> str:
    """
    Load X matrices from an .npz, run GSwalk with fixed phi and w, and save assignments.

    Parameters:
      input_path: .npz archive holding the stacked matrices under key "X",
        shape (num_X, n, d).
      phi, balance, exp: forwarded unchanged to the GSwalk routine.
      w: weight vector forwarded as ``weights``; its length is also passed to
        the routine.  Defaults to ``[0.0]`` when ``exp`` is true and
        ``[1.0, 0.5]`` otherwise.
      gsw_its: number of assignments drawn per X.
      out_path: destination archive.
      low_rank_approx: when true use
        ``gs_low_rank.GSwalk_poly_low_rank_many`` (which additionally receives
        ``d``); otherwise use ``gs_weights.GSwalk_poly_aug_many``.

    Saves a compressed .npz with fields:
      assignments: shape (num_X, gsw_its, n)
      w, phi, gsw_its, n, d, source

    Per-X runtimes are measured and summarised on stdout but not saved.
    Returns ``out_path``.
    """
    if w is None:
        w = np.array([0.0]) if exp else np.array([1.0, 0.5])

    # Read the array inside the context manager so the archive's file handle
    # is released immediately instead of lingering until garbage collection.
    with np.load(input_path) as data:
        X_all = data["X"]
    num_X, n, d = X_all.shape
    k = len(w)

    assignments = np.empty((num_X, gsw_its, n), dtype=int)
    runtimes = np.empty(num_X, dtype=float)

    for i_x, X in enumerate(X_all):
        V = X.T  # shape (d, n)

        # perf_counter is monotonic, so durations cannot be skewed by
        # system clock adjustments.
        start_time = time.perf_counter()
        if low_rank_approx:
            gsw_assignments = gs_low_rank.GSwalk_poly_low_rank_many(
                V,
                k,
                d,  # use a rank dk matrix to approximate the Gram
                phi,
                gsw_its,
                weights=w,
                balance=balance,
                exp=exp,
            )
        else:
            gsw_assignments = gs_weights.GSwalk_poly_aug_many(
                V,
                k,
                phi,
                gsw_its,
                weights=w,
                balance=balance,
                exp=exp,
            )
        runtimes[i_x] = time.perf_counter() - start_time
        assignments[i_x] = np.array(gsw_assignments, dtype=int)

    np.savez_compressed(
        out_path,
        assignments=assignments,
        # runtimes=runtimes,
        w=w,
        phi=phi,
        gsw_its=gsw_its,
        n=n,
        d=d,
        source=input_path,
    )

    print(
        f"GSwalk runtimes: mean={runtimes.mean():.6g}s, std={runtimes.std(ddof=0):.6g}s"
    )
    print(f"Saved GSwalk assignments to {out_path}")
    return out_path

def generate_classic_assignments_from_X_npz(
    input_path: str,
    its: int = 1000,
    out_path: str = "classic_assignments.npz",
) -> str:
    """
    Load X matrices and save classic-design assignments in one file.

    For every X in the input archive (key "X", shape (num_X, n, d)) each
    classic design is sampled ``its`` times, always in the order
    comprand, bernoulli, blocking, pairmatch, rerandom.

    Saves a compressed .npz with fields:
      comprand_assignments: shape (num_X, its, n)
      bernoulli_assignments: shape (num_X, its, n)
      blocking_assignments: shape (num_X, its, n)
      pairmatch_assignments: shape (num_X, its, n)
      rerandom_assignments: shape (num_X, its, n)
      classic_its, n, d, source

    Returns ``out_path``.
    """
    # Read the array inside the context manager so the archive's file handle
    # is released immediately.
    with np.load(input_path) as data:
        X_all = data["X"]
    num_X, n, d = X_all.shape

    # Design name -> sampler taking one X.  Dict order fixes both the call
    # order per X and the order of the saved fields.
    samplers: Dict[str, Callable[[np.ndarray], object]] = {
        "comprand": lambda X: CompleteRand(n, its),
        "bernoulli": lambda X: Bernoulli(n, its),
        "blocking": lambda X: BlockOrthant(X, its),
        "pairmatch": lambda X: PairwiseMatch(X, its),
        # NOTE(review): 0.01 is passed straight to RerandMR; presumably a
        # rerandomization threshold — confirm against kv.RerandMR.
        "rerandom": lambda X: RerandMR(X, its, 0.01),
    }
    buffers = {
        name: np.empty((num_X, its, n), dtype=int) for name in samplers
    }

    for i_x, X in enumerate(X_all):
        for name, draw in samplers.items():
            buffers[name][i_x] = np.array(draw(X), dtype=int)

    np.savez_compressed(
        out_path,
        **{f"{name}_assignments": buf for name, buf in buffers.items()},
        classic_its=its,
        n=n,
        d=d,
        source=input_path,
    )
    print(f"Saved classic assignments to {out_path}")
    return out_path

def load_x_and_assignments(
    X_npz_path: str,
    assignments_npz_paths: Union[str, List[str]],
    classic_assignments_npz_path: str,
) -> Tuple[np.ndarray, int, int, List[Tuple[str, np.ndarray]], Dict[str, np.ndarray]]:
    """
    Load the X matrices plus the GSwalk and classic assignment archives.

    Parameters:
      X_npz_path: archive with the stacked matrices under key "X",
        shape (num_X, n, d).
      assignments_npz_paths: one path or a list of paths to GSwalk archives,
        each holding "assignments", "w" and "phi".
      classic_assignments_npz_path: archive holding the five
        ``<design>_assignments`` arrays.

    Returns ``(X_all, n, num_X, gsw_entries, classic_assignments)`` where
    ``gsw_entries`` is a list of ``(design_name, assignments)`` pairs in the
    order the paths were given and ``classic_assignments`` maps design name
    to its array.  The GSwalk design name embeds ``w`` and ``phi`` and gets a
    ``balance`` / ``lowrank`` tag when the file's basename starts with
    ``gsw_assignment_balance`` / ``gsw_assignment_lowrank``.

    Raises ValueError when no GSwalk path or no classic path is given, or
    when an assignment array's first / last dimension disagrees with
    ``num_X`` / ``n``.
    """
    # Every archive is opened through a context manager so its file handle is
    # closed as soon as the needed arrays have been read.
    with np.load(X_npz_path) as X_data:
        X_all = X_data["X"]
    num_X, n, _ = X_all.shape

    if isinstance(assignments_npz_paths, str):
        assignments_npz_paths = [assignments_npz_paths]
    if not assignments_npz_paths:
        raise ValueError("At least one assignments file is required.")
    if not classic_assignments_npz_path:
        raise ValueError("classic_assignments_npz_path is required.")

    gsw_entries = []
    for path in assignments_npz_paths:
        with np.load(path) as A_data:
            assignments = A_data["assignments"]  # shape (num_X, gsw_its, n)
            if assignments.shape[0] != num_X or assignments.shape[-1] != n:
                raise ValueError(
                    f"Mismatch between X matrices and assignments in {path}."
                )
            w = A_data["w"]
            phi = A_data["phi"]

        basename = os.path.basename(path)
        if basename.startswith("gsw_assignment_balance"):
            gsw_design = f"gswalk_poly_balance_w{w}_phi{phi}"
        elif basename.startswith("gsw_assignment_lowrank"):
            gsw_design = f"gswalk_poly_lowrank_w{w}_phi{phi}"
        else:
            gsw_design = f"gswalk_poly_w{w}_phi{phi}"
        gsw_entries.append((gsw_design, assignments))

    classic_designs = ("comprand", "bernoulli", "blocking", "pairmatch", "rerandom")
    with np.load(classic_assignments_npz_path) as classic_data:
        classic_assignments = {
            design_name: classic_data[f"{design_name}_assignments"]
            for design_name in classic_designs
        }
    for design_name, assigns in classic_assignments.items():
        if assigns.shape[0] != num_X or assigns.shape[-1] != n:
            raise ValueError(
                f"Mismatch between X matrices and {design_name} assignments."
            )

    return X_all, n, num_X, gsw_entries, classic_assignments

def _design_operator_norms(
    assignments_i: np.ndarray,
    gram_powers: np.ndarray,
) -> Tuple[float, List[float]]:
    """Return the top eigenvalue of U and of U^(1/2) G_p U^(1/2) for each G_p.

    ``U`` is the average of the outer products of the rows of
    ``assignments_i``; ``gram_powers[..., p]`` are the matrices G_p.
    """
    U = np.einsum("bi,bj->ij", assignments_i, assignments_i) / float(
        assignments_i.shape[0]
    )
    evals, evecs = np.linalg.eigh(U)
    # Clamp tiny negative eigenvalues from round-off before the square root.
    U_sqrt = (evecs * np.sqrt(np.maximum(evals, 0.0))) @ evecs.T

    # eigvalsh returns eigenvalues in ascending order, so [-1] is the largest.
    gram_eigs = [
        float(np.linalg.eigvalsh(U_sqrt @ gram_powers[..., i_p] @ U_sqrt)[-1])
        for i_p in range(gram_powers.shape[-1])
    ]
    return float(evals[-1]), gram_eigs


def evaluate_cov_from_saved(
    X_npz_path: str,
    assignments_npz_paths: Union[str, List[str]],
    classic_assignments_npz_path: str,
) -> Dict[str, float]:
    """
    Load X matrices, classic-design assignments, and GSwalk assignments, then
    average covariance matrix operator norm over all X.

    For every X and every design, U is the average outer product of that
    design's assignment vectors.  Two kinds of quantities are averaged over
    all X:

      "<design>_U":         largest eigenvalue of U
      "<design>_gram_p<p>": largest eigenvalue of U^(1/2) G_p U^(1/2), where
                            G_p is the elementwise p-th power of X X^T,
                            for p = 1..5

    Returns a flat dict with those keys.
    """
    X_all, _, num_X, gsw_entries, classic_assignments = load_x_and_assignments(
        X_npz_path=X_npz_path,
        assignments_npz_paths=assignments_npz_paths,
        classic_assignments_npz_path=classic_assignments_npz_path,
    )

    # Classic designs first, then GSwalk designs, as (name, assignments) pairs.
    all_designs = list(classic_assignments.items()) + list(gsw_entries)
    gram_power_exponents = list(range(1, 6))
    powers = np.array(gram_power_exponents, dtype=float)

    covu_sums = {f"{design}_U": 0.0 for design, _ in all_designs}
    covgram_sums = {
        f"{design}_gram_p{p}": 0.0
        for design, _ in all_designs
        for p in gram_power_exponents
    }

    for i_x, X in enumerate(X_all):
        # The Gram matrix and its elementwise powers depend only on X, so
        # compute them once per X rather than once per design.
        gram = X @ X.T  # n-by-n matrix
        gram_powers = gram[..., None] ** powers

        for design, assignments in all_designs:
            op_u, gram_eigs = _design_operator_norms(assignments[i_x], gram_powers)
            covu_sums[f"{design}_U"] += op_u
            for p, gram_eig in zip(gram_power_exponents, gram_eigs):
                covgram_sums[f"{design}_gram_p{p}"] += gram_eig

    report = {k: v / float(num_X) for k, v in covu_sums.items()}
    report.update({k: v / float(num_X) for k, v in covgram_sums.items()})
    return report

def evaluate_gsw_condvar_from_saved(
    X_npz_path: str,
    assignments_npz_paths: Union[str, List[str]],
    classic_assignments_npz_path: str,
    f0_funcs: Optional[Dict[str, Callable[[np.ndarray], float]]] = None,
    sigma: float = 0.0,   # the standard deviation of residuals
    rng: Optional[np.random.Generator] = None,
) -> Dict[str, Dict[str, float]]:
    """
    Load X matrices, classic-design assignments, and GSwalk assignments, then
    average conditional variances for each f0 across all X.

    For each X and each f0, the response is ``y0[i] = f0(X[i]) + sigma * e_i``
    with standard-normal ``e_i`` redrawn for every (X, f0) pair.  For each
    design the statistic ``z @ y0 / n`` is computed for every assignment
    vector ``z`` and the mean of its square is accumulated; the result is
    averaged over all X.

    Parameters:
      f0_funcs: name -> function mapping; defaults to ``F0_FUNCTIONS``.
      sigma: scale of the added noise.
      rng: optional generator for the noise, enabling reproducible results.
        When omitted the global ``np.random`` state is used.

    Returns ``{f0_name: {design_name: averaged value}}``.
    """
    if f0_funcs is None:
        f0_funcs = F0_FUNCTIONS
    X_all, n, num_X, gsw_entries, classic_assignments = load_x_and_assignments(
        X_npz_path=X_npz_path,
        assignments_npz_paths=assignments_npz_paths,
        classic_assignments_npz_path=classic_assignments_npz_path,
    )

    # Without an explicit generator, fall back to the global np.random state.
    draw_noise = np.random.randn if rng is None else rng.standard_normal

    # Classic designs first, then GSwalk designs, as (name, assignments) pairs.
    all_designs = list(classic_assignments.items()) + list(gsw_entries)
    condvar_sums = {
        fname: {design: 0.0 for design, _ in all_designs} for fname in f0_funcs
    }

    for i_x, X in enumerate(X_all):
        for fname, f0 in f0_funcs.items():
            y0 = np.array([f0(xi) for xi in X])
            # y0 = y0 - y0.mean() + sigma*np.random.randn(n)
            y0 = y0 + sigma * draw_noise(n)
            for design, assignments in all_designs:
                stats = assignments[i_x] @ y0 / n
                condvar_sums[fname][design] += float((stats ** 2).mean())

    report = {
        fname: {k: v / float(num_X) for k, v in sums.items()}
        for fname, sums in condvar_sums.items()
    }
    return report

def test(
    n: int,
    d: int,
    num_trials: int,
    phi: float,
    ws: List[np.ndarray],
    gsw_its: int = 1000,
    sigma: float = 0.1,
    data_dir: str = "data",
) -> None:
    """
    Generate and save every input file that ``evaluate`` later reads.

    Writes, under ``data_dir`` (created if missing) and at the paths supplied
    by ``test_helper``:
      - the X dataset (``num_trials`` matrices of shape (n, d), seed 0),
      - one GSwalk assignment file per (w, phi) entry returned by the helper,
      - the classic-design assignment file, with ``gsw_its`` draws per X.

    ``sigma`` is accepted but not used here.  Nothing is returned.
    """
    os.makedirs(data_dir, exist_ok=True)
    X_file = helper.build_X_paths([n], d, num_trials, data_dir)[0]
    generate_X_dataset(
        num_datasets=num_trials,
        n=n,
        d=d,
        seed=0,
        out_path=X_file,
    )

    gsw_assign_entries = helper.build_gsw_assign_paths(
        [n], d, num_trials, ws, [phi], data_dir
    )[0]
    for w, phi_val, gsw_assign_file in gsw_assign_entries:
        print(f"generate GSW assignment with weights {w} and phi {phi_val}")
        generate_gsw_assignments_from_X_npz(
            input_path=X_file,
            w=np.array(w, dtype=float),
            phi=phi_val,
            gsw_its=gsw_its,
            out_path=gsw_assign_file,
        )

    classic_assign_file = helper.build_classic_assign_path(
        n, d, num_trials, data_dir
    )
    generate_classic_assignments_from_X_npz(
        input_path=X_file,
        its=gsw_its,
        out_path=classic_assign_file,
    )


def evaluate(
    n: int,
    d: int,
    num_trials: int,
    phi: float,
    ws: List[np.ndarray],
    gsw_its: int = 1000,
    sigma: float = 0.1,
    data_dir: str = "data",
) -> Dict[str, Dict[str, float]]:
    """
    Evaluate previously saved assignments and return the conditional-variance
    report.

    File locations are derived from ``n``, ``d``, ``num_trials``, ``ws``,
    ``phi`` and ``data_dir`` through ``test_helper``.  ``gsw_its`` is accepted
    but not used here.  Returns the result of
    ``evaluate_gsw_condvar_from_saved`` computed with noise scale ``sigma``.
    """
    os.makedirs(data_dir, exist_ok=True)

    x_path = helper.build_X_paths([n], d, num_trials, data_dir)[0]
    gsw_triples = helper.build_gsw_assign_paths(
        [n], d, num_trials, ws, [phi], data_dir
    )[0]
    # Each entry is a 3-tuple whose last element is the assignment file path.
    gsw_paths = [path for _, _, path in gsw_triples]
    classic_path = helper.build_classic_assign_path(n, d, num_trials, data_dir)

    # NOTE(review): the covariance report is computed but never returned or
    # printed — confirm whether it should be part of the result.
    _ = evaluate_cov_from_saved(
        x_path,
        assignments_npz_paths=gsw_paths,
        classic_assignments_npz_path=classic_path,
    )

    return evaluate_gsw_condvar_from_saved(
        x_path,
        assignments_npz_paths=gsw_paths,
        classic_assignments_npz_path=classic_path,
        sigma=sigma,
    )
