import os, argparse, torch
from pmc.config import Configurator
import optuna
import sys
import functools
sys.path.append('..')
import time

# Command-line interface for a single run or an Optuna hyperparameter search.
parser = argparse.ArgumentParser(description='Autonomous Diffusion Model (ADM)')
parser.add_argument(
    "--config", "-c",
    type=str,
    help="Path to config file"
)

parser.add_argument(
    '--optuna',
    type=int,
    # main() maps any value <= 0 to "no trial limit" (n_trials=None).
    help='Number of optuna trials (a value <= 0 runs trials until interrupted)'
)

parser.add_argument(
    '--name',
    type=str,
    default='dummy',
    help=(
        'Experiment name. If --optuna is given, this is the Optuna study name '
        "and each trial's experiment is named '<name>_NNN' "
        '(NNN = zero-padded trial number).'
    ),
)

def main():
    """Parse CLI arguments, save the resolved config, then train or tune.

    Without ``--optuna`` a single experiment is built from the config and run.
    With ``--optuna N`` an Optuna study named ``--name`` is run, maximizing the
    value returned by ``objective``. The study runs N trials, or has no trial
    limit when N <= 0.
    """
    args = parser.parse_args()

    # Build the configuration and keep a copy of it in the experiment directory.
    cc = Configurator(args)
    os.makedirs(cc.cfg.exp_dir, exist_ok=True)
    with open(os.path.join(cc.cfg.exp_dir, 'config.yaml'), 'w', encoding='utf-8') as f:
        f.write(str(cc.cfg))

    if args.optuna is None:
        # Regular single run.
        exp, model, dataloader, callbacks = cc.init_all()
        exp(model, dataloader, callbacks)
        return

    # Hyperparameter tuning. The partial must get a new name: assigning to
    # `objective` here would make it a local variable for the whole function,
    # and reading it on the right-hand side would raise UnboundLocalError.
    n_trials = args.optuna if args.optuna > 0 else None
    bound_objective = functools.partial(objective, cc=cc)
    optimize_hyperparameters(bound_objective, 'maximize', n_trials=n_trials, study_name=args.name)
    

def optimize_hyperparameters(
    objective,
    direction: str = 'minimize',
    n_trials: int | None = None,
    study_name: str = 'dummy',
    save_dir: str = '.'
):
    """Run (or resume) an Optuna study backed by SQLite and print a summary.

    Args:
        objective: Callable that takes an ``optuna.Trial`` and returns the
            value to optimize.
        direction: ``'minimize'`` or ``'maximize'``.
        n_trials: Number of trials to run in this call; ``None`` means no
            limit (the study runs until the process is stopped).
        study_name: Study name inside the database. An existing study with
            this name is resumed (``load_if_exists=True``).
        save_dir: Directory holding ``optuna.db``; created if missing.
    """
    # TrialState is not imported at module level.
    from optuna.trial import TrialState

    # Persist the study in <save_dir>/optuna.db so runs can be resumed.
    os.makedirs(save_dir, exist_ok=True)
    db_path = os.path.join(save_dir, 'optuna.db')
    storage_url = f'sqlite:///{os.path.abspath(db_path)}'

    study = optuna.create_study(
        storage=storage_url,
        direction=direction,
        study_name=study_name,
        load_if_exists=True,
    )

    study.optimize(objective, n_trials=n_trials)

    pruned_trials = study.get_trials(deepcopy=False, states=(TrialState.PRUNED,))
    complete_trials = study.get_trials(deepcopy=False, states=(TrialState.COMPLETE,))

    print('Study statistics: ')
    print('  Number of finished trials: ', len(study.trials))
    print('  Number of pruned trials: ', len(pruned_trials))
    print('  Number of complete trials: ', len(complete_trials))

    if not complete_trials:
        # study.best_trial raises ValueError when no trial has completed.
        print('No completed trials; no best trial to report.')
        return

    print('Best trial:')
    best_trial = study.best_trial

    print('  Value: ', best_trial.value)

    print('  Params: ')
    for key, value in best_trial.params.items():
        print(f'    {key}: {value}')

def objective(trial: optuna.Trial, cc):
    """Sample one hyperparameter configuration, run the experiment, return its score.

    Args:
        trial: Optuna trial used to sample ``p``, ``b_small`` and ``b_large``.
        cc: Configurator that receives the sampled values via
            ``overwrite_cfg`` before the experiment is built.
            NOTE(review): presumably this mutates cc's config in place and is
            shared across trials; confirm against Configurator.

    Returns:
        ``exp.logger.log_dict['batch0_xrecon_psnr']`` after the run.
        ``main`` maximizes this value.
    """
    params = {
        'model.p': trial.suggest_float('p', 0.1, 0.5, step=0.1),
        'model.b_small': trial.suggest_int('b_small', 1000, 10000, step=1000),
        'name': f'{trial.study.study_name}_{trial.number:03d}',
    }

    # Constraint: b_small <= b_large. b_large is sampled from [b_small, 10000],
    # so the two values may be equal.
    params['model.b_large'] = trial.suggest_int('b_large', params['model.b_small'], 10000, step=1000)

    cc.overwrite_cfg(params)
    exp, model, dataloader, callbacks = cc.init_all()

    start_time = time.perf_counter()
    exp(model, dataloader, callbacks)
    # Store the run's wall-clock time on the trial so it is saved with the study.
    trial.set_user_attr('elapsed_sec', time.perf_counter() - start_time)

    return exp.logger.log_dict['batch0_xrecon_psnr']

if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()
