#!/usr/bin/env python3

import argparse
import sys
import os
import numpy as np
import matplotlib.pyplot as plt

sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
from main import STESimulator

# Global Matplotlib styling applied once at import time: sans-serif text,
# inward-pointing ticks, 1.0-wide tick/frame lines, tight data margins,
# square frameless fully transparent legends, and STIX math fonts.
plt.rcParams.update({
    'font.family': 'sans-serif',
    'xtick.direction': 'in',
    'ytick.direction': 'in',
    'xtick.major.width': 1.0,
    'ytick.major.width': 1.0,
    'axes.linewidth': 1.0,
    'axes.xmargin': 0.01,
    'axes.ymargin': 0.01,
    'legend.fancybox': False,
    'legend.framealpha': 0,
    'legend.edgecolor': 'none',
    'mathtext.fontset': 'stix',
})


def parse_arguments(argv=None):
    """Parse command-line options and attach the derived step settings.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, which matches the previous
            zero-argument behaviour.

    Returns:
        argparse.Namespace holding every parsed option plus two derived
        attributes:

        * ``T`` -- ``int(tau * d)``.
        * ``k_step`` -- ``int(k_step_tau * d / log_every)``; the error-bar
          interval expressed as a number of log entries (presumably one entry
          is recorded every ``log_every`` steps -- confirm against the
          simulator).

    Raises:
        SystemExit: raised by argparse (exit status 2) for malformed options
            or when ``--d``, ``--log_every`` or ``--num_runs`` is not
            positive.
    """
    parser = argparse.ArgumentParser(description='STE vs ODE analysis with omega=1.0, b=3')

    # Simulation parameters.
    parser.add_argument('--d', type=int, default=400, help='dimension (default: 400)')
    parser.add_argument('--tau', type=float, default=1500, help='max tau value (default: 1500)')
    parser.add_argument('--eta', type=float, default=0.005, help='learning rate (default: 0.005)')
    parser.add_argument('--lam', type=float, default=1.0, help='regularization parameter (default: 1.0)')
    parser.add_argument('--sigma', type=float, default=0.01, help='noise standard deviation (default: 0.01)')
    parser.add_argument('--seed', type=int, default=0, help='random seed (default: 0)')
    parser.add_argument('--log_every', type=int, default=100, help='log interval (default: 100)')

    # Experiment / analysis parameters.
    parser.add_argument('--num_runs', type=int, default=5, help='number of simulation runs (default: 5)')
    parser.add_argument('--substeps', type=int, default=12, help='ODE RK4 substeps (default: 12)')
    parser.add_argument('--k_step_tau', type=float, default=125, help='error bar display interval in tau units (default: 125)')

    # Output parameters.
    parser.add_argument('--output_dir', type=str, default='fig', help='output directory (default: fig)')
    parser.add_argument('--figsize', type=float, nargs=2, default=[6.4, 4.8], help='figure size (default: [6.4, 4.8])')

    args = parser.parse_args(argv)

    # Reject counts that would otherwise surface as a ZeroDivisionError here
    # (log_every) or as division-by-zero / empty-result failures later in the
    # script (d, num_runs).
    for option in ('d', 'log_every', 'num_runs'):
        value = getattr(args, option)
        if value < 1:
            parser.error(f"--{option} must be a positive integer (got {value})")

    args.T = int(args.tau * args.d)
    args.k_step = int(args.k_step_tau * args.d / args.log_every)

    return args


def run_experiment(args):
    """Run the STE simulations and the ODE reference for omega=1.0, b=3.

    Args:
        args: Namespace providing ``d``, ``T``, ``eta``, ``lam``, ``sigma``,
            ``seed``, ``log_every``, ``num_runs`` and ``substeps``.

    Returns:
        Tuple ``(tau, ste_mean, ste_std, ode_eps)``:

        * ``tau`` -- the first history's ``'steps'`` divided by ``d``.
        * ``ste_mean`` / ``ste_std`` -- mean and standard deviation of the
          aggregated ``'epsilon_g'`` histories along axis 1 (presumably the
          run axis -- confirm against ``aggregate_histories``).
        * ``ode_eps`` -- column 0 of the aggregated ODE ``'epsilon_g'``.
    """
    print("Starting experiment with omega=1.0, b=3...")

    # omega, b, the quantisation method and the snapshot interval are fixed
    # for this experiment; everything else comes from the command line.
    simulator = STESimulator(
        d=args.d,
        T=args.T,
        eta=args.eta,
        lam=args.lam,
        sigma=args.sigma,
        seed=args.seed,
        log_every=args.log_every,
        quant_method='step',
        snapshot_every=5000,
        omega=1.0,
        b=3,
    )
    histories, ode = simulator.run_histories(
        num_runs=args.num_runs,
        substeps=args.substeps,
        shared_init=True,
        ode_once=True,
    )

    # Rescale the logged step counter by the dimension to get tau.
    tau = histories[0]['steps'] / float(args.d)

    metric = ('epsilon_g',)

    ste_runs = STESimulator.aggregate_histories(histories, metric)['epsilon_g']
    ste_mean = ste_runs.mean(axis=1)
    ste_std = ste_runs.std(axis=1)

    ode_eps = STESimulator.aggregate_odes(ode, metric)['epsilon_g'][:, 0]

    return tau, ste_mean, ste_std, ode_eps


def visualize_results(tau, ste_mean, ste_std, ode_eps, args):
    """Plot the ODE curve against STE mean/std markers and save PDF and PNG.

    Args:
        tau: x-axis values; indexed with an integer array, so a NumPy array
            is expected.
        ste_mean: STE mean values, sub-sampled every ``args.k_step`` entries
            for the error-bar markers.
        ste_std: Values used as the error-bar half-lengths, sub-sampled the
            same way.
        ode_eps: ODE values drawn as a continuous line over ``tau``.
        args: Namespace providing ``figsize``, ``k_step``, ``output_dir``
            and ``d``.

    Returns:
        None. Writes ``<output_dir>/ste_ode_comparison_omega1.0_b3_tau<..>_d<..>``
        with ``.pdf`` and ``.png`` extensions and prints both paths.
    """
    print("Visualizing results...")

    fig, ax = plt.subplots(figsize=args.figsize)
    try:
        ax.plot(tau, ode_eps, color='tab:blue', linewidth=2.0, label='ODE')

        # Show only every k_step-th STE point; the stride is clamped to at
        # least 1 so a zero or negative k_step cannot break np.arange.
        idx = np.arange(0, len(tau), max(int(args.k_step), 1))
        ax.errorbar(tau[idx], ste_mean[idx], yerr=ste_std[idx],
                    capsize=3, fmt='o',
                    ecolor='tab:orange', ms=7, mfc='None', mec='tab:orange', color='tab:orange',
                    label='STE')

        ax.grid(True, linestyle='--', alpha=0.7)

        ax.set_xlabel(r'$\tau$', fontsize=32)
        ax.set_ylabel(r'$\varepsilon_{g}$', fontsize=32)

        ax.minorticks_on()
        ax.tick_params(axis='both', which='major', direction='in', length=5, width=1.2, top=True, right=True, labelsize=16)
        ax.tick_params(axis='both', which='minor', direction='in', length=2, width=0.8, top=True, right=True)

        # Lower limit sits 0.05 below the smallest plotted value. nanmin
        # ignores NaN entries, whereas the built-in min() gives a result that
        # depends on where a NaN happens to sit in the sequence.
        y_values = np.concatenate([np.ravel(ode_eps), np.ravel(ste_mean)])
        y_min = np.nanmin(y_values) - 0.05
        ax.set_ylim(y_min, 1.0)

        ax.legend(loc='best', fontsize=14)

        fig.tight_layout(pad=0.2)

        os.makedirs(args.output_dir, exist_ok=True)
        base_path = os.path.join(args.output_dir, f"ste_ode_comparison_omega1.0_b3_tau{tau[-1]:.0f}_d{args.d}")
        pdf_path = base_path + '.pdf'
        png_path = base_path + '.png'

        # Save through the figure object rather than pyplot's implicit
        # "current figure" so unrelated open figures cannot be picked up.
        fig.savefig(pdf_path, bbox_inches='tight')
        fig.savefig(png_path, dpi=300, bbox_inches='tight')

        print("Figures saved:")
        print(f"  PDF: {pdf_path}")
        print(f"  PNG: {png_path}")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)


def main():
    """Parse the options, run the experiment, and save the comparison plot."""
    args = parse_arguments()

    banner = (
        "=== STE vs ODE comparison experiment with omega=1.0, b=3 ===",
        f"Parameters: d={args.d}, T={args.T}, eta={args.eta}, lam={args.lam}, sigma={args.sigma}",
        f"Number of runs: {args.num_runs}",
        f"Output directory: {args.output_dir}",
        "",  # blank separator line before the simulator's own output
    )
    for line in banner:
        print(line)

    # run_experiment returns (tau, ste_mean, ste_std, ode_eps), which are the
    # leading positional arguments of visualize_results.
    visualize_results(*run_experiment(args), args)

    print("\n=== Experiment completed ===")


if __name__ == "__main__":
    main()
