import sys
import os
from collections import defaultdict

# Work out where the code is being run from. This assumes the current working
# directory IS the 'notebook' folder; hard-code an absolute path here instead if
# that is not the case.
notebook_dir = os.getcwd()

# The project root ('your_project_root') is the parent of that folder;
# os.path.split(...)[0] yields the same head component as os.path.dirname.
parent_dir = os.path.split(notebook_dir)[0]

# Make the project root importable (the `util` package is imported from it),
# adding it at most once so re-executing this code does not grow sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

# Import the objective wrapper and the optimizers from the project's util package
from util.OPT_utilities import objectiveFcn, grad_desc, coor_desc, ssd, ssd_ls_temp, ssd_bt_temp, ssd_hbt, ssd_sag, spsa, nesterov_grad_desc
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import argparse

def truncate_to_matrix(array_list):
    """Cut arrays to a common length and stack them row-wise.

    Each array is sliced along its first axis to the length of the shortest
    array in ``array_list``. The slices are then combined with ``np.vstack``.
    For 1-D inputs the result is a 2-D matrix with one row per input array.

    Raises ValueError if ``array_list`` is empty.
    """
    lengths = [arr.shape[0] for arr in array_list]
    common_len = min(lengths)
    return np.vstack([arr[:common_len] for arr in array_list])

def f_lr(x, lbd, r):
    """Worst function in the world by Nesterov 2013, shifted so its minimum is 0.

    Computes::

        lbd/4 * ( (x[0]**2 + sum_{i=0}^{r-2} (x[i] - x[i+1])**2 + x[r-1]**2) / 2 - x[0] )
            + lbd * r / (8 * (r + 1))

    Only the first ``r`` coordinates of ``x`` affect the value. The additive
    constant cancels the unshifted minimum, which is -lbd * r / (8 * (r + 1)).
    The minimum is therefore 0, attained at ``x[i] = 1 - (i + 1) / (r + 1)``
    for ``i < r``.

    Args:
        x: Indexable sequence of numbers (e.g. a NumPy array) with ``len(x) >= r``.
        lbd: Scale factor. It plays the role of the gradient Lipschitz
            constant in Nesterov's construction.
        r: Number of leading coordinates that enter the function (1 <= r <= len(x)).

    Returns:
        The function value.

    Raises:
        ValueError: If ``r < 1`` or ``r > len(x)``.
    """
    # Without this check r == 0 would silently read x[-1] (the LAST
    # coordinate) through x[r-1] and return a meaningless value.
    if r < 1:
        raise ValueError('r must be at least 1')
    if r > len(x):
        raise ValueError('r must be less than or equal to the length of x')
    # Squared differences between consecutive coordinates x[0], ..., x[r-1].
    chain = sum((x[i] - x[i + 1]) ** 2 for i in range(r - 1))
    sums = (x[0] ** 2 + chain + x[r - 1] ** 2) / 2 - x[0]
    return lbd * sums / 4 + lbd * r / (8 * (r + 1))

def parse_parameters(argv=None):
    """Parse the command-line options of the SSD experiment.

    Args:
        argv: List of argument strings to parse. ``None`` (the default) means
            ``sys.argv[1:]``, which is what the script uses. Passing an explicit
            list (e.g. ``[]``) makes the function usable from a notebook, where
            ``sys.argv`` holds the kernel's own arguments, and from tests.

    Returns:
        argparse.Namespace with attributes ``lmba``, ``d``, ``r1``, ``r2``,
        ``ell``, ``epochs``, ``line_iter``, ``L0``, ``c`` and ``num_trials``.
    """
    parser = argparse.ArgumentParser(description='SSD Optimization Parameters')
    # Note the option is spelled '--lmba' (stored as args.lmba), not '--lmda'.
    parser.add_argument('--lmba', type=float, default=20.0, help='Lambda value')
    parser.add_argument('--d', type=int, default=1000, help='Problem Dimension')
    parser.add_argument('--r1', type=int, default=100, help='HF Dimension')
    parser.add_argument('--r2', type=int, default=2, help='LF Dimension')
    parser.add_argument('--ell', type=int, default=10, help='Subspace Dimension')
    parser.add_argument('--epochs', type=int, default=10, help='Number of Epochs')
    parser.add_argument('--line_iter', type=int, default=10,
                        help='Maximal Number of Line Search Iterations')
    parser.add_argument('--L0', type=float, default=1.0,
                        help='Initial Learning Rate for Line Search')
    parser.add_argument('--c', type=float, default=0.9, help='Armijo Shrinking Factor')
    parser.add_argument('--num_trials', type=int, default=3, help='Number of Trials')
    return parser.parse_args(argv)

def _run_methods(methods, x0, obj, num_trials):
    """Run each (name, method, params) triple num_trials times and collect histories.

    Returns a ``defaultdict(list)`` mapping ``name`` to the stacked
    ``obj.returnHistory()`` results and ``f'{name}_time'`` to the stacked
    ``obj.timHistory`` results. Each value is cut to the shortest trial by
    ``truncate_to_matrix``.
    """
    res = defaultdict(list)
    for _trial in tqdm(range(num_trials)):
        for name, method, params in methods:
            # Hand each run its own copy of the start point, so a method that
            # updates its input in place cannot shift the starting point of the
            # methods/trials that run after it.
            method(x0.copy(), obj, **params)
            res[name].append(obj.returnHistory())
            res[f'{name}_time'].append(np.asarray(obj.timHistory))
    # Stack the per-trial histories of every key into one matrix.
    for key, histories in res.items():
        res[key] = truncate_to_matrix(histories)
    return res


def main():
    """Benchmark the optimizers on the worst-case function and save the histories.

    Reads the settings from the command line via ``parse_parameters``. Every
    method in ``methods`` is run ``num_trials`` times from ``x0 = 0``. The
    results go to ``results/worst/worst-d...-ell...-c....npz`` and contain:

    * ``res``: the dict built by ``_run_methods``. ``np.savez`` stores it as a
      pickled object array, so load it with ``np.load(..., allow_pickle=True)``
      and ``.item()``.
    * ``bf_ratio``: ``linesearch_iter * r2 / ((ell + 1) * r1)``.
    """
    # Parse the parameters
    args = parse_parameters()
    lmda = args.lmba
    d = args.d
    r1 = args.r1
    r2 = args.r2
    ell = args.ell
    num_iterations = args.epochs
    linesearch_iter = args.line_iter
    L0 = args.L0
    c = args.c
    num_trials = args.num_trials
    # Step size 1/lambda for the full-dimensional methods. The subspace (SSD)
    # variants scale it by ell/d.
    learning_rate = 1 / lmda
    learning_rate_ssd = learning_rate * ell / d

    # High-fidelity function uses the first r1 coordinates, low-fidelity the first r2.
    f = lambda x: f_lr(x, lmda, r1)
    f_LF = lambda x: f_lr(x, lmda, r2)
    x0 = np.zeros(d)

    # High-fidelity objective function
    obj = objectiveFcn(f, label='Low-rank Function')
    # Low-fidelity objective function
    obj_lowFi = objectiveFcn(f_LF)

    # Define methods and their parameters.
    # NOTE(review): several iteration budgets below are floats
    # (num_iterations / 2, num_iterations * d / ell); presumably the optimizers
    # accept non-integer counts -- confirm in util.OPT_utilities.
    methods = [
        ('gd', grad_desc, {'learning_rate': learning_rate, 'num_iterations': num_iterations}),
        ('ngd', nesterov_grad_desc, {'learning_rate': learning_rate, 'num_iterations': num_iterations}),
        ('cd', coor_desc, {'learning_rate': learning_rate, 'num_iterations': num_iterations / 2}),
        ('ssd', ssd, {'ell': ell, 'learning_rate': learning_rate_ssd, 'num_iterations': num_iterations * d / ell}),
        ('spsa', spsa, {'num_iterations': num_iterations * d}),
        ('rgfm', ssd, {'ell': 1, 'learning_rate': learning_rate_ssd, 'num_iterations': num_iterations * d}),
        ('ssd_bf', ssd_bt_temp, {'ell': ell, 'obj_lowFi': obj_lowFi, 'c': c, 'num_iterations': num_iterations * d / ell,
                                'linesearch_iter': linesearch_iter, 'L0': L0}),
        ('ssd_hf', ssd_hbt, {'ell': ell, 'c': c, 'num_iterations': num_iterations * d / ell,
                            'linesearch_iter': linesearch_iter, 'L0': L0}),
        ('ssd_sag', ssd_sag, {'ell': ell, 'learning_rate': learning_rate_ssd, 'num_iterations': num_iterations * d / ell}),
    ]

    # Run every method num_trials times; histories are truncated to a common length.
    res = _run_methods(methods, x0, obj, num_trials)
    bf_ratio = linesearch_iter * r2 / ((ell + 1) * r1)

    save_path = f'results/worst/worst-d{d}-rH{r1}-rL{r2}-lmda{lmda}-ell{ell}-c{c}.npz'
    # Create the output directory here too, so main() also works when it is
    # called directly rather than through the __main__ guard.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    np.savez(save_path, res=res, bf_ratio=bf_ratio)
    # Report only after the file has actually been written.
    print(f'Saved results to {save_path}')
    print('Done!')
    
if __name__ == '__main__':
    # Ensure the output directory used by main() exists. exist_ok=True avoids
    # the exists()/makedirs() race that raises FileExistsError when several
    # runs (e.g. a parameter sweep) start at the same time.
    os.makedirs('results/worst', exist_ok=True)
    main()
