import os
import jax
import time
import argparse
import numpy as np

from optimizers import run_optex, run_benchmark, run_standard, run_line_search, tuning_mattern
from functions.synthetic import noisy_levy, noisy_ackley, noisy_sphere, noisy_rosenbrock

def parse_arguments():
    """Parse the command line for the synthetic OptEx experiments.

    Reads ``sys.argv`` through :mod:`argparse`.

    Returns:
        argparse.Namespace with attributes ``func``, ``dim``, ``std``,
        ``opt_name``, ``lr``, ``num_epochs``, ``num_iters``, ``num_parall``,
        ``seed`` and ``num_runs``.
    """
    # (flag, default, type, help) for every option, in the order shown by --help.
    option_specs = (
        ('--func', 'noisy_levy', str, 'function to optimize'),
        ('--dim', 100000, int, 'input dimension'),
        ('--std', 0.0, float, 'noise std'),
        ('--opt_name', "sgd", str, 'optimizer name'),
        ('--lr', 0.1, float, 'learning rate'),
        ('--num_epochs', 20, int, 'number of epochs'),
        ('--num_iters', 1, int, 'number of iterations'),
        ('--num_parall', 5, int, 'number of parallel iterations'),
        ('--seed', 0, int, 'seed for random number generator'),
        ('--num_runs', 5, int, 'number of runs'),
    )

    cli = argparse.ArgumentParser(description='OptEx(Synthetic) experiments')
    for flag, default_value, converter, help_text in option_specs:
        cli.add_argument(flag, default=default_value, type=converter, help=help_text)

    return cli.parse_args()

def _tune_optex_length_scale(inter_results, num_parall):
    """Re-select OptEx's length scale from its recorded history and store it in ``inter_results``."""
    # NOTE(review): "x_history"/"g_history" are presumably populated by run_optex
    # (they are never written in this file) as lists of 2-D arrays — confirm.
    xs = np.concatenate(inter_results["x_history"], axis=0)
    ys = np.concatenate(inter_results["g_history"], axis=0)

    # The last `num_parall` rows are held out as the validation split for tuning_mattern.
    border = num_parall
    length_scale = tuning_mattern(
        xs[:-border, :],
        ys[:-border, :],
        xs[-border:, :],
        ys[-border:, :],
        choice=[0.01, 0.1, 1, 10, 100],
        effective_dim=5000,
    )
    inter_results["length_scale"] = length_scale
    print("optimized length scale:", length_scale)


def run_specific_case(name, args, epoch_idx=0):
    """Run one optimization trajectory with method ``name`` and return its objective curve.

    Args:
        name: method to run; one of "optex", "benchmark", "line_search" or
            "standard". This selects ``run_<name>`` from ``optimizers``.
        args: parsed command-line namespace (see ``parse_arguments``).
        epoch_idx: run index. It offsets the JAX PRNG seed so that repeated runs
            start from different random ``x0``.

    Returns:
        List of noise-free objective values ``func(x, std=0)``. The first entry is
        for ``x0``, followed by one entry per epoch.

    Raises:
        ValueError: if ``name`` is not a known method or ``args.func`` is not a
            known synthetic objective.
    """
    # Explicit dispatch tables instead of eval() on command-line strings.
    runners = {
        "optex": run_optex,
        "benchmark": run_benchmark,
        "line_search": run_line_search,
        "standard": run_standard,
    }
    objectives = {
        "noisy_levy": noisy_levy,
        "noisy_ackley": noisy_ackley,
        "noisy_sphere": noisy_sphere,
        "noisy_rosenbrock": noisy_rosenbrock,
    }
    if name not in runners:
        raise ValueError(f"unknown method {name!r}; expected one of {sorted(runners)}")
    if args.func not in objectives:
        raise ValueError(f"unknown function {args.func!r}; expected one of {sorted(objectives)}")
    runner = runners[name]
    func = objectives[args.func]

    key = jax.random.PRNGKey(args.seed + epoch_idx * 1234)
    x0 = jax.random.uniform(key, (args.dim,))
    # One [std] entry per (parallel x sequential) step — presumably consumed per step by the runner.
    datas = [[args.std] for _ in range(args.num_parall * args.num_iters)]

    arguments = {
        "opt_name":     "optax." + args.opt_name,
        "lr":           args.lr,
        "x0":           x0,
        "opt_state":    None,
        "num_iters":    args.num_iters,
        "num_parall":   args.num_parall,
        "datas":        datas,
    }
    num_epochs = args.num_epochs

    if name == "optex":
        arguments["inter_results"] = {"length_scale": 64.0}
    elif name == "line_search":
        arguments["inter_results"] = {}
    elif name == "standard":
        # The sequential baseline folds the parallel factor into epochs (when
        # num_iters == 1) or into iterations (when num_epochs == 1). This is
        # presumably done to match the parallel methods' step budget. When both
        # values exceed 1, no rescaling happens.
        if args.num_iters == 1:
            num_epochs *= args.num_parall
        elif args.num_epochs == 1:
            arguments["num_iters"] *= args.num_parall
        arguments["num_parall"] = 1

    results = [func(x0, std=0)]
    for i in range(num_epochs):
        x, _, opt_state = runner(func, **arguments)
        # Warm-start the next epoch from where this one ended.
        arguments["x0"] = x
        arguments["opt_state"] = opt_state

        clean_value = func(x, std=0)
        results.append(clean_value)

        # Every 5th epoch, OptEx re-tunes its length scale on the history collected so far.
        if name == "optex" and i % 5 == 4:
            _tune_optex_length_scale(arguments["inter_results"], args.num_parall)

        print("Final func(x)", "%.4f" % clean_value)
    return results

if __name__ == "__main__":
    args = parse_arguments()

    # Every method's results for this (function, dim, noise) setting share one directory.
    root = f"results/{args.func}({args.dim},{args.std})"

    for method in ["optex", "benchmark", "line_search", "standard"]:
        print(f"============== {method.upper()} ============================")
        all_results = []
        for r in range(args.num_runs):
            print("Run ", r)
            # perf_counter is monotonic, so it is the right clock for elapsed time.
            start = time.perf_counter()
            single_results = run_specific_case(method, args, epoch_idx=r)
            print("Total time: ", time.perf_counter() - start)
            all_results.append(single_results)

        # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
        os.makedirs(root, exist_ok=True)
        # One row per run. Each row is the objective curve returned by run_specific_case.
        np.save(f"{root}/{args.opt_name}({args.lr})-{args.num_epochs}x{args.num_parall}-{method}.npy", np.array(all_results))