import numpy as np
import pandas as pd
from typing import Dict, List
import multiprocessing as mp
from functools import partial
import argparse
import time

def parse_args(argv=None):
    """Parse and validate the command-line options for this experiment.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before.

    Returns:
        ``argparse.Namespace`` with ``seed``, ``d``, ``n_iters``, ``lr``,
        ``mu``, ``m``, ``n_runs_zo``, ``n_processes`` and ``sigma``.

    Exits with status 2 (via ``parser.error``) when a value would make the
    run meaningless. The zeroth-order estimator divides by ``2 * mu`` and by
    ``m``, so zero values produce NaNs instead of an error.
    ``multiprocessing.Pool`` rejects fewer than one process.
    """
    parser = argparse.ArgumentParser(description='Riemannian optimization with multiprocessing')
    parser.add_argument('--seed', type=int, default=42, help='Global RNG seed (default: 42)')
    parser.add_argument('--d', type=int, default=16, help='Manifold dimension (default: 16)')
    parser.add_argument('--n_iters', type=int, default=1000, help='Number of iterations (default: 1000)')
    parser.add_argument('--lr', type=float, default=0.001, help='Learning rate (default: 0.001)')
    parser.add_argument('--mu', type=float, default=0.0001, help='Smoothing radius for ZO finite difference (default: 0.0001)')
    parser.add_argument('--m', type=int, default=16, help='Number of directions averaged per step (default: 16)')
    parser.add_argument('--n_runs_zo', type=int, default=5, help='Number of independent ZO runs (default: 5)')
    parser.add_argument('--n_processes', type=int, default=5, help='Number of parallel processes (default: 5)')
    parser.add_argument('--sigma', type=float, default=1.0, help='Noise level (default: 1.0)')
    args = parser.parse_args(argv)

    # Fail fast with a usage message rather than deep inside worker processes.
    # `mu > 0` is also False for NaN, so NaN is rejected too.
    checks = (
        (args.d >= 1, "--d must be >= 1"),
        (args.n_iters >= 0, "--n_iters must be >= 0"),
        (args.mu > 0, "--mu must be > 0"),
        (args.m >= 1, "--m must be >= 1"),
        (args.n_runs_zo >= 0, "--n_runs_zo must be >= 0"),
        (args.n_processes >= 1, "--n_processes must be >= 1"),
    )
    for ok, message in checks:
        if not ok:
            parser.error(message)
    return args

# CLI options are parsed at import time and exposed as module-level globals.
# The functions below read n_iters, lr, mu, m and SEED directly instead of
# taking them as parameters.
args = parse_args()

SEED = args.seed
d, n_iters = args.d, args.n_iters
lr, mu, m = args.lr, args.mu, args.m
n_runs_zo, n_processes, sigma = args.n_runs_zo, args.n_processes, args.sigma

# NOTE(review): sigma and SEED are not encoded in the file name, so runs that
# differ only in those options overwrite each other's CSV.
FILE_NAME = f"loss_DIM_{d}_NITERS_{n_iters}_LR_{lr}_MU_{mu}_ZOOBATCH_{m}_NRUNS_{n_runs_zo}_PROCESSES_{n_processes}.csv"

def setup_problem(d: int, seed: int, cond: float = 1e4):
    """Build the random quadratic test problem ``(A, B)``.

    ``A = Q diag(evals) Q^T``, where ``Q`` is a random orthogonal matrix and
    ``evals`` are geometrically spaced from 1 to ``cond``. For ``cond >= 1``,
    the condition number of ``A`` is therefore ``cond``.

    ``B = M^T M + d I`` for a Gaussian ``M``, so every eigenvalue of ``B`` is
    at least ``d``.

    Side effect: reseeds the global NumPy RNG with ``seed``.
    """
    np.random.seed(seed)
    # The orthogonal factor of a Gaussian matrix gives a random basis.
    basis, _ = np.linalg.qr(np.random.randn(d, d))
    smallest = 1.0
    spectrum = np.geomspace(smallest, cond * smallest, d)
    A = basis @ np.diag(spectrum) @ basis.T
    gram_factor = np.random.randn(d, d)
    B = gram_factor.T @ gram_factor + d * np.eye(d)
    return A, B


def f(x: np.ndarray, B: np.ndarray) -> float:
    """Quadratic objective ``0.5 * x^T B x``."""
    half_x = 0.5 * x.T
    return half_x @ B @ x

def grad_riemannian(x: np.ndarray, A: np.ndarray, B: np.ndarray) -> np.ndarray:
    """Return ``A^{-1} (B x)``.

    For symmetric ``B`` this is the gradient of ``f(., B)`` under the metric
    ``<u, v> = u^T A v``. It is computed with a linear solve, without forming
    ``A^{-1}``.
    """
    euclidean_grad = B @ x
    return np.linalg.solve(A, euclidean_grad)

def sample_rejection(A: np.ndarray, n_samples: int) -> np.ndarray:
    """Draw ``n_samples`` points on the ellipsoid ``{v : v^T A v = 1}`` by rejection.

    Assumes ``A`` is symmetric positive definite (it uses ``eigh`` and takes
    square roots of the eigenvalues).

    Proposals come from mapping a uniform point ``u`` on the unit sphere
    through ``U diag(lambda^{-1/2})``. Each proposal ``v`` is accepted with
    probability ``|A v| / sqrt(lambda_max)``, which is at most 1. This weight
    is proportional to the map's area-scaling factor at ``v``, so accepted
    points are uniform with respect to the ellipsoid's surface area.

    Returns an ``(n_samples, d)`` array. Randomness comes from the global
    NumPy RNG.
    """
    dim = A.shape[0]
    lam, U = np.linalg.eigh(A)
    to_ellipsoid = U @ np.diag(1.0 / np.sqrt(lam))
    top_eig = lam.max()
    A_squared = A @ A
    samples = np.empty((n_samples, dim))
    filled = 0
    while filled < n_samples:
        gauss = np.random.normal(size=dim)
        candidate = to_ellipsoid @ (gauss / np.linalg.norm(gauss))
        # sqrt(v^T A^2 v) == |A v| for symmetric A.
        accept_prob = np.sqrt(candidate @ (A_squared @ candidate) / top_eig)
        if np.random.rand() < accept_prob:
            samples[filled] = candidate
            filled += 1
    return samples

def sample_direct(A: np.ndarray, n_samples: int) -> np.ndarray:
    """Draw ``n_samples`` points on the ellipsoid ``{v : v^T A v = 1}`` directly.

    Each Gaussian vector ``z`` is divided by its A-norm ``sqrt(z^T A z)``.
    This is cheaper than rejection sampling. Unless ``A`` is a multiple of
    the identity, however, the points are not uniform with respect to the
    ellipsoid's surface area.

    Returns an ``(n_samples, d)`` array. Randomness comes from the global
    NumPy RNG.
    """
    gauss = np.random.normal(size=(n_samples, A.shape[0]))
    # Row i of the einsum is z_i^T A z_i.
    a_norm = np.sqrt(np.einsum("ij,ij->i", gauss @ A, gauss))
    return gauss / a_norm[:, np.newaxis]

def first_order(x0: np.ndarray, A: np.ndarray, B: np.ndarray, sigma: float) -> List[float]:
    """Baseline: noisy descent along the exact gradient ``A^{-1} B x``.

    Runs ``n_iters`` steps of size ``lr`` (both module-level globals). Each
    step perturbs the gradient with Gaussian noise of standard deviation
    ``sigma``, drawn from the global NumPy RNG, which is not reseeded here.

    Returns ``n_iters + 1`` objective values: the initial one, then one per
    step. ``x0`` is copied and left unmodified.
    """
    iterate = x0.copy()
    history = [f(iterate, B)]
    for _ in range(n_iters):
        direction = grad_riemannian(iterate, A, B)
        noise = np.random.randn(*iterate.shape) * sigma
        iterate -= lr * (direction + noise)
        history.append(f(iterate, B))
    return history
 



def zeroth_order(
        x0: np.ndarray,
        A: np.ndarray,
        B: np.ndarray,
        sampler,
        run_seed: int,
        sigma: float
    ) -> List[float]:
    """Run noisy zeroth-order descent on ``f(., B)`` and record the loss.

    Each of the ``n_iters`` steps:

    1. Draws ``m`` directions ``vs = sampler(A, m)``.
    2. Averages the central finite differences
       ``(f(x + mu*v) - f(x - mu*v)) / (2*mu) * v`` over those directions.
    3. Adds Gaussian noise of standard deviation ``sigma``.
    4. Takes a step of size ``lr``.

    ``n_iters``, ``m``, ``mu`` and ``lr`` are module-level globals.

    Args:
        x0: Starting point; copied, never modified.
        A: Matrix handed to ``sampler``. For the samplers in this file,
            directions satisfy ``v^T A v = 1``.
        B: Matrix of the quadratic objective ``f``.
        sampler: Callable ``sampler(A, n)`` returning an ``(n, d)`` array of
            directions, e.g. ``sample_direct`` or ``sample_rejection``.
        run_seed: Seed for the global NumPy RNG, so the run is reproducible.
        sigma: Standard deviation of the additive Gaussian noise.

    Returns:
        ``n_iters + 1`` objective values: the initial one, then one per step.
    """
    # Seeds the *global* RNG; both the sampler and the noise draw use it.
    np.random.seed(run_seed)
    x = x0.copy()
    vals = [f(x, B)]
    for it in range(n_iters):
        grad_est = np.zeros_like(x)
        vs = sampler(A, m)
        for v in vs:
            fd = (f(x + mu * v, B) - f(x - mu * v, B)) / (2 * mu)
            grad_est += fd * v
        grad_est /= m
        # NOTE(review): no dimension factor (e.g. d) rescales this estimate,
        # so its magnitude depends on the sampler's second moment. Confirm
        # this is intended when comparing against first_order's exact gradient.
        Xi = np.random.randn(*x.shape) * sigma
        x -= lr * (grad_est + Xi)
        vals.append(f(x, B))
        if it % 1000 == 0:
            # Flush so progress from pool workers shows up promptly instead of
            # waiting in the child process's stdout buffer.
            print(f"Iteration {it} completed by process {mp.current_process().name}", flush=True)
    return vals

def run_optimization(run_id: int, x0: np.ndarray, A: np.ndarray, B: np.ndarray, sigma: float) -> Dict[str, List[float]]:
    """Pool worker: run ``zeroth_order`` once per sampler for ``run_id``.

    The direct sampler is seeded with ``SEED + 100 + run_id`` and the
    rejection sampler with ``SEED + 200 + run_id``. Runs execute in that
    order.

    Returns the loss histories keyed ``zeroth_order_direct_<run_id>`` and
    ``zeroth_order_rejection_<run_id>``.

    NOTE(review): the seed ranges overlap once ``n_runs_zo > 100``. Direct
    run ``k + 100`` then gets the same seed as rejection run ``k``.
    """
    jobs = (
        (f"zeroth_order_direct_{run_id}", sample_direct, SEED + 100 + run_id),
        (f"zeroth_order_rejection_{run_id}", sample_rejection, SEED + 200 + run_id),
    )
    return {
        key: zeroth_order(x0, A, B, sampler, seed, sigma)
        for key, sampler, seed in jobs
    }

def main():
    """Run the first-order baseline and all zeroth-order runs, then save a CSV.

    Steps:

    1. Build the problem from the globals ``d`` and ``SEED``, and draw a
       shared starting point ``x0``.
    2. Run ``first_order`` in this process.
    3. Run ``run_optimization`` for run ids ``1..n_runs_zo`` across a pool of
       ``n_processes`` workers.
    4. Write one column per trajectory to ``FILE_NAME``.
    """
    # perf_counter is the clock the time module provides for measuring
    # durations. time.time() follows the wall clock, which can jump if the
    # system time is adjusted mid-run.
    start_time = time.perf_counter()
    A, B = setup_problem(d, SEED)
    # Drawn from the global RNG right after setup_problem reseeds it, so x0
    # is reproducible for a given SEED.
    x0 = np.random.randn(d)
    results = {"first_order": first_order(x0, A, B, sigma)}
    run_ids = range(1, n_runs_zo + 1)
    with mp.Pool(processes=n_processes) as pool:
        worker_func = partial(run_optimization, x0=x0, A=A, B=B, sigma=sigma)
        parallel_results = pool.map(worker_func, run_ids)
    for res in parallel_results:
        results.update(res)
    # Every trajectory holds n_iters + 1 values, so the columns line up.
    df = pd.DataFrame(results)
    df.to_csv(FILE_NAME, index=False)
    print(f"Optimization completed and results saved to {FILE_NAME}!")
    elapsed_time = time.perf_counter() - start_time
    print(f"Total runtime: {elapsed_time:.2f} seconds ({elapsed_time/60:.2f} minutes)")


# The guard stops worker processes that re-import this module (e.g. under the
# "spawn" start method) from launching main() themselves.
if __name__ == "__main__":
    main()