import numpy as np
import pandas as pd
from typing import Dict, List
import multiprocessing as mp
from functools import partial
import argparse
import time

# ----------------------------
#  Parse command line arguments
# ----------------------------
def parse_args(argv=None):
    """Build the command-line parser and parse the experiment configuration.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. ``None`` (the default) parses ``sys.argv[1:]``
        exactly as before; passing an explicit list lets the configuration be
        built programmatically (tests, notebooks) without touching sys.argv.

    Returns
    -------
    argparse.Namespace
        Options seed, d, n_data, n_iters, lr, mu, m, n_runs_zo, n_processes,
        sigma and lam.

    Raises
    ------
    SystemExit
        Through ``parser.error`` when an option is outside its valid range.
    """
    parser = argparse.ArgumentParser(description='Riemannian optimization with multiprocessing - Logistic Loss')
    parser.add_argument('--seed', type=int, default=42, help='Global RNG seed (default: 42)')
    parser.add_argument('--d', type=int, default=16, help='Manifold dimension (default: 16)')
    parser.add_argument('--n_data', type=int, default=100, help='Number of data points (default: 100)')
    parser.add_argument('--n_iters', type=int, default=1000, help='Number of iterations (default: 1000)')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate (default: 0.001)')
    parser.add_argument('--mu', type=float, default=0.0001, help='Smoothing radius for ZO finite difference (default: 0.0001)')
    parser.add_argument('--m', type=int, default=16, help='Number of directions averaged per step (default: 16)')
    parser.add_argument('--n_runs_zo', type=int, default=5, help='Number of independent ZO runs (default: 5)')
    parser.add_argument('--n_processes', type=int, default=2, help='Number of parallel processes (default: 2)')
    parser.add_argument('--sigma', type=float, default=1.0, help='Noise level (default: 1.0)')
    parser.add_argument('--lam', type=float, default=0.1, help='Regularization parameter (default: 0.1)')
    args = parser.parse_args(argv)

    # Fail fast on values that would otherwise only blow up later inside a
    # worker process: m and mu are divisors in the zeroth-order estimator,
    # n_data is the denominator of the loss mean, and mp.Pool needs >= 1 process.
    for name in ('d', 'n_data', 'm', 'n_processes'):
        value = getattr(args, name)
        if value < 1:
            parser.error(f"--{name} must be >= 1 (got {value})")
    if args.mu <= 0:
        parser.error(f"--mu must be > 0 (got {args.mu})")
    return args

# Parse the command line once, when the module is loaded. The optimizers and
# the pool worker below read the resulting module-level names as globals
# (n_iters, lr, mu, m, SEED).
args = parse_args()

# ----------------------------
#  Hyper-parameters (from command line or defaults)
# ----------------------------
SEED = args.seed
d, n_data, n_iters = args.d, args.n_data, args.n_iters
lr, mu, m = args.lr, args.mu, args.m
n_runs_zo, n_processes = args.n_runs_zo, args.n_processes
sigma, lam = args.sigma, args.lam

# Output CSV name, built from the run configuration.
# NOTE(review): --seed, --sigma and --lam are not encoded in the name, so runs
# that differ only in those options overwrite each other's CSV.
FILE_NAME = f"logistic_loss_DIM_{d}_NDATA_{n_data}_NITERS_{n_iters}_LR_{lr}_MU_{mu}_ZOOBATCH_{m}_NRUNS_{n_runs_zo}_PROCESSES_{n_processes}.csv"

# ----------------------------
#  Data generation and problem setup
# ----------------------------
def generate_data(d: int, n_data: int, seed: int):
    """Draw a synthetic binary-classification dataset.

    Reseeds NumPy's global RNG with ``seed``, then draws, in this order:
    standard-normal features, a ground-truth weight vector (rescaled to unit
    Euclidean norm), and one uniform number per example to sample its label.

    Returns
    -------
    X : ndarray, shape (n_data, d)
        Standard-normal feature matrix.
    y : ndarray, shape (n_data,)
        Labels in {-1.0, +1.0}; +1 with probability sigmoid(X_i @ w_true).
    w_true : ndarray, shape (d,)
        Unit-norm weight vector used to generate the labels.
    """
    np.random.seed(seed)

    features = np.random.randn(n_data, d)

    direction = np.random.randn(d)
    direction = direction / np.linalg.norm(direction)

    # Bernoulli labels: success probability is the sigmoid of the true score,
    # then {False, True} is mapped onto {-1.0, +1.0}.
    p_positive = 1 / (1 + np.exp(-(features @ direction)))
    coin = np.random.rand(n_data) < p_positive
    labels = coin.astype(float) * 2 - 1

    return features, labels, direction

def setup_problem(d: int, n_data: int, seed: int, cond: float = 1e4):
    """Build the metric, the regularizer and the data for one experiment.

    Parameters
    ----------
    d : int
        Dimension of the parameter vector.
    n_data : int
        Number of labelled examples to generate.
    seed : int
        Seed for NumPy's global RNG (reseeded here and again in generate_data).
    cond : float
        Condition number of the metric A; its eigenvalues are spaced
        geometrically over [1, cond].

    Returns
    -------
    A : ndarray, shape (d, d)
        SPD metric Q diag(evals) Q^T with Q from the QR of a Gaussian matrix.
    B : ndarray, shape (d, d)
        SPD regularization matrix M^T M + d * I.
    X, y, w_true
        The output of generate_data(d, n_data, seed).
    """
    np.random.seed(seed)

    # Orthogonal eigenbasis for the metric.
    basis, _ = np.linalg.qr(np.random.randn(d, d))

    # Geometric eigenvalue ladder from 1 up to cond -> badly conditioned A.
    smallest = 1.0
    spectrum = np.geomspace(smallest, cond * smallest, d)
    A = basis @ np.diag(spectrum) @ basis.T

    # NOTE(review): generate_data reseeds with the same seed, so X begins with
    # exactly the random draws that were just QR-factorised into A's
    # eigenbasis (for n_data >= d, X[:d] equals that matrix). Confirm this
    # coupling between data and metric is intended.
    X, y, w_true = generate_data(d, n_data, seed)

    # M^T M is positive semidefinite; adding d * I makes B strictly positive definite.
    M = np.random.randn(d, d)
    B = M.T @ M + d * np.eye(d)

    return A, B, X, y, w_true

# ----------------------------
#  Loss function and gradient
# ----------------------------
def f(x: np.ndarray, X: np.ndarray, y: np.ndarray, B: np.ndarray, lam: float) -> float:
    """Logistic loss with L2 (B-weighted) regularization.

    Computes ``mean_i log(1 + exp(-y_i * X_i @ x)) + 0.5 * lam * x^T B x``.

    Parameters
    ----------
    x : ndarray, shape (d,)
        Weight vector being evaluated.
    X : ndarray, shape (n, d)
        Feature matrix.
    y : ndarray, shape (n,)
        Labels in {-1, +1}.
    B : ndarray, shape (d, d)
        Regularization matrix.
    lam : float
        Regularization strength.

    Returns
    -------
    float
        Scalar loss value.
    """
    neg_margins = -y * (X @ x)
    # np.logaddexp(0, z) == log(1 + exp(z)), evaluated without overflow for
    # large z and (via log1p internally) without rounding the tiny loss of
    # confidently-correct examples (z << 0) to exactly zero.
    logistic_loss = np.mean(np.logaddexp(0.0, neg_margins))

    reg_term = 0.5 * lam * (x @ B @ x)

    return logistic_loss + reg_term

def grad_euclidean(x: np.ndarray, X: np.ndarray, y: np.ndarray, B: np.ndarray, lam: float) -> np.ndarray:
    """Euclidean gradient of the regularized logistic loss ``f`` w.r.t. ``x``.

    Parameters
    ----------
    x : ndarray, shape (d,)
        Point at which the gradient is evaluated.
    X : ndarray, shape (n, d)
        Feature matrix.
    y : ndarray, shape (n,)
        Labels in {-1, +1}.
    B : ndarray, shape (d, d)
        Regularization matrix (need not be symmetric).
    lam : float
        Regularization strength.

    Returns
    -------
    ndarray, shape (d,)
        mean_i (sigmoid(X_i @ x) - (y_i + 1) / 2) X_i + 0.5 * lam * (B + B^T) x
    """
    logits = X @ x
    # Overflow-free sigmoid: exp is only evaluated at -|logit| <= 0, so very
    # negative logits no longer overflow. For logits >= 0 this is the same
    # expression as 1 / (1 + exp(-logit)).
    e = np.exp(-np.abs(logits))
    probs = np.where(logits >= 0, 1.0 / (1.0 + e), e / (1.0 + e))
    # (y + 1) / 2 maps the {-1, +1} labels onto {0, 1}.
    logistic_grad = X.T @ (probs - (y + 1) / 2) / len(y)

    # d/dx [0.5 * lam * x^T B x] = 0.5 * lam * (B + B^T) x. This equals
    # lam * B @ x for a symmetric B, and stays consistent with f for any B.
    reg_grad = 0.5 * lam * (B @ x + B.T @ x)

    return logistic_grad + reg_grad

def grad_riemannian(x: np.ndarray, A: np.ndarray, X: np.ndarray, y: np.ndarray, B: np.ndarray, lam: float) -> np.ndarray:
    """Riemannian gradient of ``f`` when ``A`` is used as a constant metric.

    This is A^{-1} applied to the Euclidean gradient. It is computed by solving
    the linear system A g = grad_euclidean(...) instead of forming A^{-1}.
    """
    return np.linalg.solve(A, grad_euclidean(x, X, y, B, lam))

# ----------------------------
#  Direction samplers: unit vectors in the A-norm (v^T A v = 1)
# ----------------------------
def sample_rejection(A: np.ndarray, n_samples: int) -> np.ndarray:
    """Draw ``n_samples`` directions on the ellipsoid {v : v^T A v = 1}.

    Each proposal maps a uniformly random unit vector through
    V diag(1 / sqrt(eigvals)), where A = V diag(eigvals) V^T, which puts it on
    the ellipsoid. The proposal is accepted with probability
    sqrt(v^T A^2 v / lambda_max). Each proposal consumes one
    ``np.random.normal(size=d)`` and one ``np.random.rand()`` from NumPy's
    global RNG.

    Returns an array of shape (n_samples, d) with one accepted direction per row.
    """
    dim = A.shape[0]
    w, V = np.linalg.eigh(A)
    whiten = V @ np.diag(1.0 / np.sqrt(w))
    top = w.max()
    A_sq = A @ A

    samples = np.empty((n_samples, dim))
    accepted = 0
    while accepted < n_samples:
        g = np.random.normal(size=dim)
        cand = whiten @ (g / np.linalg.norm(g))
        accept_prob = np.sqrt(cand @ (A_sq @ cand) / top)
        if np.random.rand() < accept_prob:
            samples[accepted] = cand
            accepted += 1
    return samples

def sample_direct(A: np.ndarray, n_samples: int) -> np.ndarray:
    """Draw ``n_samples`` Gaussian directions rescaled onto {v : v^T A v = 1}.

    Each row is a standard-normal vector z divided by its A-norm
    sqrt(z^T A z), drawn from NumPy's global RNG. Returns shape (n_samples, d).
    """
    gauss = np.random.normal(size=(n_samples, A.shape[0]))
    a_norm_sq = np.einsum("ij,ij->i", gauss @ A, gauss)
    return gauss / np.sqrt(a_norm_sq)[:, None]

# ----------------------------
#  Optimizers
# ----------------------------
def first_order(x0: np.ndarray, A: np.ndarray, X: np.ndarray, y: np.ndarray, B: np.ndarray, lam: float, sigma: float) -> List[float]:
    """Noisy Riemannian gradient-descent baseline.

    Starting from a copy of ``x0``, performs ``n_iters`` (module-level) steps
        x <- x - lr * (grad_riemannian(x) + xi),   xi ~ N(0, sigma^2 I).
    The noise comes from NumPy's global RNG in whatever state the caller left
    it; this function does not reseed.

    Returns the loss ``f`` before the first step and after every step
    (n_iters + 1 values). Progress is printed every 100 iterations.
    """
    x = x0.copy()
    history = [f(x, X, y, B, lam)]
    for step in range(n_iters):
        riem_grad = grad_riemannian(x, A, X, y, B, lam)
        noise = np.random.randn(*x.shape) * sigma
        x -= lr * (riem_grad + noise)
        history.append(f(x, X, y, B, lam))
        if not step % 100:
            print(f"First-order iteration {step}, loss: {history[-1]:.4f}")
    return history

def zeroth_order(
        x0: np.ndarray,
        A: np.ndarray,
        X: np.ndarray,
        y: np.ndarray,
        B: np.ndarray,
        lam: float,
        sampler,
        run_seed: int,
        sigma: float
    ) -> List[float]:
    """Noisy zeroth-order (finite-difference) Riemannian descent.

    Reseeds NumPy's global RNG with ``run_seed``. Then, for each of ``n_iters``
    (module-level) steps:
      1. draws ``m`` directions with ``sampler(A, m)``;
      2. averages central differences (f(x + mu v) - f(x - mu v)) / (2 mu) * v;
      3. takes the step x <- x - lr * (estimate + xi), with xi ~ N(0, sigma^2 I).

    NOTE(review): no dimension factor (e.g. d) multiplies the averaged
    estimate, so its scale relative to grad_riemannian depends on the
    sampler's second moment E[v v^T]. Confirm this is intended when comparing
    against first_order at the same lr.

    Returns the loss ``f`` before the first step and after every step
    (n_iters + 1 values). Progress is printed every 100 iterations.
    """
    np.random.seed(run_seed)
    x = x0.copy()
    history = [f(x, X, y, B, lam)]
    for step in range(n_iters):
        estimate = np.zeros_like(x)
        directions = sampler(A, m)
        for v in directions:
            # Central finite difference of f along direction v.
            slope = (f(x + mu * v, X, y, B, lam) - f(x - mu * v, X, y, B, lam)) / (2 * mu)
            estimate += slope * v
        estimate /= m

        noise = np.random.randn(*x.shape) * sigma

        x -= lr * (estimate + noise)
        history.append(f(x, X, y, B, lam))
        if not step % 100:
            print(f"Zeroth-order iteration {step} completed by process {mp.current_process().name}, loss: {history[-1]:.4f}")
    return history

# ----------------------------
#  Worker function for multiprocessing
# ----------------------------
def run_optimization(run_id: int, x0: np.ndarray, A: np.ndarray, X: np.ndarray, y: np.ndarray, B: np.ndarray, lam: float, sigma: float) -> Dict[str, List[float]]:
    """Pool worker: run both zeroth-order variants for one run id.

    Direct sampling runs first, seeded with SEED + 100 + run_id; rejection
    sampling runs second, seeded with SEED + 200 + run_id.

    NOTE(review): these seeds collide once run ids span more than 100 (direct
    run 101 + k reuses the seed of rejection run 1 + k).

    Returns {"zeroth_order_direct_<run_id>": losses,
             "zeroth_order_rejection_<run_id>": losses}.
    """
    variants = (
        ("direct", sample_direct, SEED + 100 + run_id),
        ("rejection", sample_rejection, SEED + 200 + run_id),
    )
    results = {}
    for label, sampler, seed in variants:
        print(f"Starting run {run_id} with {label} sampling...")
        results[f"zeroth_order_{label}_{run_id}"] = zeroth_order(x0, A, X, y, B, lam, sampler, seed, sigma)
    return results

def main():
    """Run the whole experiment and save the loss curves to FILE_NAME.

    Steps:
      1. Build the problem from SEED and draw a small random starting point.
      2. Run the first-order baseline in this process.
      3. Run n_runs_zo zeroth-order workers on a pool of n_processes
         processes; each worker does a direct-sampling and a
         rejection-sampling run.
      4. Write one CSV column per run and print the final losses and the
         total runtime.
    """
    # perf_counter is monotonic and high-resolution, so the reported runtime
    # is not skewed by system-clock adjustments during long runs (unlike
    # time.time()).
    start_time = time.perf_counter()

    # Setup the optimization problem
    print(f"Setting up problem with d={d}, n_data={n_data}...")
    A, B, X, y, w_true = setup_problem(d, n_data, SEED)
    x0 = np.random.randn(d) * 0.1  # smaller initial point

    print(f"Initial loss: {f(x0, X, y, B, lam):.4f}")
    print(f"True weight norm: {np.linalg.norm(w_true):.4f}")

    # Run first-order optimization (in the parent process)
    print("Running first-order optimization...")
    results = {"first_order": first_order(x0, A, X, y, B, lam, sigma)}

    # One pool task per zeroth-order run id (1-based)
    run_ids = range(1, n_runs_zo + 1)

    print(f"Starting {n_runs_zo} zeroth-order runs with {n_processes} processes...")
    with mp.Pool(processes=n_processes) as pool:
        # Bind the shared problem data so the worker only varies by run_id
        worker_func = partial(run_optimization, x0=x0, A=A, X=X, y=y, B=B, lam=lam, sigma=sigma)
        parallel_results = pool.map(worker_func, run_ids)

    # Merge each worker's {column_name: losses} dict into the results
    for res in parallel_results:
        results.update(res)

    # Every run records n_iters + 1 losses, so all columns have equal length
    df = pd.DataFrame(results)
    df.to_csv(FILE_NAME, index=False)
    print(f"Optimization completed and results saved to {FILE_NAME}!")

    print("\nFinal losses:")
    for key, vals in results.items():
        print(f"{key}: {vals[-1]:.4f}")

    elapsed_time = time.perf_counter() - start_time
    print(f"Total runtime: {elapsed_time:.2f} seconds ({elapsed_time/60:.2f} minutes)")

if __name__ == "__main__":
    # Keep the main() call behind this guard: multiprocessing workers may
    # re-import this module (e.g. with the "spawn" start method) and must not
    # re-run the experiment themselves.
    # Example invocations:
    # python multi-pro-logistic.py --d 8 --n_data 50 --n_iters 1000000 --n_runs_zo 16 --n_processes 8 --lr 0.00001
    # python multi-pro-logistic.py --d 16 --n_data 50 --n_iters 1000000 --n_runs_zo 16 --n_processes 8 --lr 0.000001


    main() 