from typing import Optional, Tuple, List, Dict, Callable, Union

import numpy as np
import os

def build_X_paths(
    ns: List[int],
    d: int,
    num_trials: int,
    data_dir: str,
) -> List[str]:
    """Return one input-matrix ``.npz`` path per sample size in *ns*.

    The paths come back in the same order as *ns*, each located in
    *data_dir* and encoding ``n``, ``d`` and ``num_trials`` in its name.
    """

    def _path_for(n: int) -> str:
        return os.path.join(data_dir, f"X_input_n{n}_d{d}_num_trials{num_trials}.npz")

    return list(map(_path_for, ns))

def build_gsw_assign_paths(
    ns: List[int],
    d: int,
    num_trials: int,
    ws: List[np.ndarray],
    phis: List[float],
    data_dir: str,
    balance: bool = False,
    low_rank_approx: bool = False,
) -> List[List[Tuple[np.ndarray, float, str]]]:
    """Build GSW-assignment ``.npz`` paths for every ``(n, w, phi)`` combination.

    Args:
        ns: Sample sizes; one inner list is produced per entry, in order.
        d: Dimension, embedded in every file name.
        num_trials: Trial count, embedded in every file name.
        ws: Weight vectors; each is rendered in the file name as its entries
            converted to ``float`` and joined with ``-``. Traversed only once,
            so an iterator is accepted too.
        phis: Values embedded in every file name as ``phi{phi}``.
        data_dir: Directory joined in front of every file name.
        balance: Use the ``gsw_assignment_balance`` prefix and skip weight
            vectors that are all (close to) zero per ``np.allclose``.
        low_rank_approx: Use the ``gsw_assignment_lowrank`` prefix. Takes
            precedence over ``balance`` for naming; the zero-weight skip
            triggered by ``balance`` still applies.

    Returns:
        A list parallel to *ns*. Each entry is a list of ``(w, phi, path)``
        tuples ordered by weight vector, then by phi.
    """
    # The prefix depends only on the flags, so pick it once up front.
    if low_rank_approx:
        prefix = "gsw_assignment_lowrank"
    elif balance:
        prefix = "gsw_assignment_balance"
    else:
        prefix = "gsw_assignment"

    # Filter and tag each weight vector once. The tag depends only on w,
    # so there is no need to rebuild it for every (n, phi) pair.
    tagged_ws = [
        (w, "-".join(str(float(x)) for x in w))
        for w in ws
        if not (balance and np.allclose(w, 0.0))
    ]

    all_paths = []
    for n in ns:
        per_n = []
        for w, w_tag in tagged_ws:
            for phi in phis:
                fname = f"{prefix}_n{n}_d{d}_w{w_tag}_phi{phi}_num_trials{num_trials}.npz"
                per_n.append((w, phi, os.path.join(data_dir, fname)))
        all_paths.append(per_n)
    return all_paths

def build_classic_assign_path(
    n: int,
    d: int,
    num_trials: int,
    data_dir: str,
) -> str:
    """Return the classic-assignment ``.npz`` path for a single ``(n, d)`` setting.

    The file lives in *data_dir*, and its name encodes ``n``, ``d`` and
    ``num_trials``.
    """
    file_name = f"classic_assignment_n{n}_d{d}_num_trials{num_trials}.npz"
    return os.path.join(data_dir, file_name)


def _rows_2d(X: np.ndarray) -> np.ndarray:
    """Return *X* as a 2-D ``(n_rows, row_len)`` array.

    A 1-D input is treated as ``n`` rows of length 1, so each output row is
    the corresponding scalar raised to the tensor power. Inputs with more
    than two dimensions are rejected.
    """
    X = np.asarray(X)
    if X.ndim == 1:
        return X[:, None]
    if X.ndim != 2:
        raise ValueError(f"expected a 1-D or 2-D array, got shape {X.shape}")
    return X


def _row_kron(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Row-wise Kronecker product: row ``i`` equals ``np.kron(A[i], B[i])``."""
    # For 1-D a, b: kron(a, b)[p * len(b) + q] == a[p] * b[q], which is the
    # C-order flattening of the outer product. Each entry is one product, so
    # results match np.kron bit for bit. The explicit target shape (rather
    # than -1) keeps the reshape well-defined when there are zero rows.
    return (A[:, :, None] * B[:, None, :]).reshape(A.shape[0], A.shape[1] * B.shape[1])


def make_X2_from_X(X: np.ndarray) -> np.ndarray:
    """Return matrix whose rows are vec(x_i ⊗ x_i) for each row x_i of X.

    Vectorized over rows: an ``(n, d)`` input yields ``(n, d**2)``. An input
    with zero rows yields an empty ``(0, d**2)`` array instead of raising.
    """
    X = _rows_2d(X)
    return _row_kron(X, X)


def make_X3_from_X(X: np.ndarray) -> np.ndarray:
    """Return matrix whose rows are vec(x_i ⊗ x_i ⊗ x_i) for each row x_i of X.

    An ``(n, d)`` input yields ``(n, d**3)``; zero rows give an empty result.
    """
    X = _rows_2d(X)
    # Same nesting as kron(x, kron(x, x)) so floating-point results are identical.
    return _row_kron(X, _row_kron(X, X))


def make_X4_from_X(X: np.ndarray) -> np.ndarray:
    """Return matrix whose rows are vec(x_i ⊗ x_i ⊗ x_i ⊗ x_i).

    An ``(n, d)`` input yields ``(n, d**4)``; zero rows give an empty result.
    """
    X = _rows_2d(X)
    XX = _row_kron(X, X)
    # (x ⊗ x) ⊗ (x ⊗ x), reusing the degree-2 rows.
    return _row_kron(XX, XX)


def make_X5_from_X(X: np.ndarray) -> np.ndarray:
    """Return matrix whose rows are vec(x_i ⊗ x_i ⊗ x_i ⊗ x_i ⊗ x_i).

    An ``(n, d)`` input yields ``(n, d**5)``; zero rows give an empty result.
    """
    X = _rows_2d(X)
    XX = _row_kron(X, X)
    # ((x ⊗ x) ⊗ x) ⊗ (x ⊗ x): same nesting as the per-row np.kron form.
    return _row_kron(_row_kron(XX, X), XX)
