import optuna
from omegaconf import DictConfig, OmegaConf
from rich.pretty import pprint
import subprocess
from colorama import Fore, Style

import datetime

import stoix.systems.ppo.ff_ppo_outer_parallel_seeds

# Study pickles must be loaded with the same Optuna version that saved them, so
# fail fast at import time. An explicit raise (unlike `assert`) is not stripped
# when Python runs with -O.
if optuna.__version__ != "3.6.1":
    raise RuntimeError(
        f"optuna==3.6.1 is required for loading / saving study pickles, "
        f"found {optuna.__version__}"
    )

import hydra
import joblib
import neptune

import numpy as np


def load_study(filename):
    """Deserialize and return the object (an Optuna study here) stored at *filename*.

    The file is read in binary mode and handed to ``joblib.load``, which
    unpickles it. Unpickling can execute arbitrary code, so only pass files
    from a trusted source.
    """
    with open(filename, mode="rb") as pickle_file:
        loaded = joblib.load(pickle_file)
    return loaded


def load_study_from_neptune(project_name, filename, optuna_neptune_dir):
    """Download a pickled Optuna study from a Neptune project and load it.

    Args:
        project_name: Neptune project to open (opened in read-only mode).
        filename: Name of the field under ``optuna_neptune_dir`` that holds
            the pickled study.
        optuna_neptune_dir: Namespace inside the project where study pickles
            are stored.

    Returns:
        The unpickled study object returned by ``load_study``.
    """
    project = neptune.init_project(project=project_name, mode="read-only")
    try:
        key = f"{optuna_neptune_dir}/{filename}"
        project[key].download()
        project.wait()
    finally:
        # Always release the Neptune connection, even if the download fails.
        project.stop()
    print(f"Downloaded! {filename}")
    # NOTE(review): the local file is read as "<filename>.pkl" even though
    # callers already pass names ending in ".pkl"; this presumably matches how
    # Neptune's download() names the file (field name + extension) — confirm.
    study = load_study(f"{filename}.pkl")
    return study


def eval_sweep(cfg: DictConfig, seeds_per_run: int = 4):
    """Evaluate the best trial of one tuned sweep over a range of seeds.

    Loads the sweep's Optuna study and the task's baseline study from Neptune.
    It merges their best parameters, with baseline values winning on
    overlapping keys, and writes them into ``cfg``. It then calls
    ``ff_ppo_outer_parallel_seeds.run_experiment`` once per group of
    ``seeds_per_run`` consecutive evaluation seeds.

    Args:
        cfg: Config with struct mode disabled. Mutated in place:
            ``eval_seeds``, possibly ``trial_num``, several ``system.*`` and
            ``arch.total_num_envs`` values, ``logger.kwargs.neptune_tag``.
        seeds_per_run: Positive number of seeds passed to each experiment
            run via ``parallel_seeds``. Defaults to 4, the previously
            hard-coded value.

    Raises:
        ValueError: If ``cfg.num_seeds`` is not a multiple of ``seeds_per_run``.
    """
    import copy

    # Explicit check instead of `assert` so it still runs under `python -O`.
    if cfg.num_seeds % seeds_per_run:
        raise ValueError(
            f"num_seeds ({cfg.num_seeds}) must be a multiple of "
            f"seeds_per_run ({seeds_per_run})"
        )
    cfg.eval_seeds = np.arange(cfg.base_seed, cfg.base_seed + cfg.num_seeds).tolist()

    # Numeric suffix of the saved study pickle for each known sweep. Other sweep
    # names fall back to a `trial_num` already present in cfg.
    best_trial_nums = {
        'outer_lr_best': 41,
        'bias_init_best': 110,
        'nest_best': 100,
    }
    if cfg.sweep_name in best_trial_nums:
        cfg.trial_num = best_trial_nums[cfg.sweep_name]

    task_name = cfg.env.scenario.task_name

    # Load the sweep study. `[:-5]` drops the trailing "_best" from the sweep
    # name to form the study file name.
    study = load_study_from_neptune('grids', f'ff_ppo_{task_name}_{cfg.sweep_name[:-5]}_{cfg.trial_num}.pkl', 'optuna_study_pickles')
    params = study.best_params
    print(f'Using sweep params: {study.best_trial.number} with value {study.best_trial.value} and params {study.best_trial.params}')

    baseline_study = load_study_from_neptune('baseline', f'ff_ppo_{task_name}_baseline_500.pkl', 'optuna_study_pickles')
    baseline_params = baseline_study.best_params
    print(f'Using baseline params: {baseline_study.best_trial.number} with value {baseline_study.best_trial.value} and params {baseline_study.best_trial.params}')
    # Baseline values override sweep values for any key present in both.
    params = {**params, **baseline_params}

    # Common to every sweep. Some params were sampled as exponents, so they
    # are applied as powers of two.
    cfg.arch.total_num_envs = 2 ** params["arch.total_num_envs"]
    cfg.system.rollout_length = 2 ** params["system.rollout_length"]
    cfg.system.actor_lr = params["system.actor_lr"]
    cfg.system.critic_lr = params["system.critic_lr"]
    cfg.system.epochs = params["system.epochs"]
    cfg.system.num_minibatches = 2 ** params["system.num_minibatches"]
    cfg.system.gamma = params["system.gamma"]
    cfg.system.gae_lambda = params["system.gae_lambda"]
    cfg.system.max_grad_norm = params["system.max_grad_norm"]
    cfg.system.reward_scaling = params["system.reward_scaling"]
    cfg.system.clip_eps = params["system.clip_eps"]

    # Sweep-specific parameters.
    if cfg.sweep_name == 'outer_lr_best':
        print('adding constant outer lr')  # 1 param
        cfg.system.outer_optimizer.learning_rate.peak_value = params["system.outer_optimizer.learning_rate.peak_value"]
    elif cfg.sweep_name == 'bias_init_best':
        print('adding bias init parameters')  # 2 params
        cfg.system.free_step_momentum = params["system.free_step_momentum"]
        cfg.system.free_step_learning_rate.peak_value = params["system.free_step_learning_rate.peak_value"]
    elif cfg.sweep_name == 'nest_best':
        cfg.system.outer_optimizer.learning_rate.peak_value = params["system.outer_optimizer.learning_rate.peak_value"]
        cfg.system.outer_optimizer.momentum = params["system.outer_optimizer.momentum"]

    cfg.system.system_name = cfg.sweep_name
    cfg.logger.kwargs.neptune_tag = [cfg.sweep_name]

    ppo_experiment = stoix.systems.ppo.ff_ppo_outer_parallel_seeds.run_experiment

    for start in range(0, cfg.num_seeds, seeds_per_run):
        # Deep copy so each run gets an independent config tree;
        # DictConfig.copy() is shallow and would share nested nodes.
        run_cfg = copy.deepcopy(cfg)
        run_cfg.parallel_seeds = cfg.eval_seeds[start:start + seeds_per_run]
        ppo_experiment(run_cfg)

def eval_all_sweep(cfg, sweep_names=('outer_lr_best', 'bias_init_best', 'nest_best')):
    """Run ``eval_sweep`` for each tuned sweep on an isolated copy of ``cfg``.

    Args:
        cfg: Base config with struct mode disabled. It is not mutated; each
            sweep works on its own deep copy.
        sweep_names: Sweeps to evaluate, in order. Defaults to the three
            sweeps previously hard-coded here.
    """
    import copy

    for sweep_name in sweep_names:
        # Deep copy, not cfg.copy(): DictConfig.copy() is shallow, so nested
        # nodes such as `system` would be shared. eval_sweep writes
        # sweep-specific keys (e.g. system.outer_optimizer.*) in place, and
        # those would otherwise leak into the sweeps evaluated after it.
        sweep_cfg = copy.deepcopy(cfg)

        sweep_cfg.sweep_name = sweep_name
        sweep_cfg.system.run_outer_ppo = True

        sweep_cfg.base_project = '/grids'

        eval_sweep(sweep_cfg)

@hydra.main(
    version_base="1.2",
    config_path="./configs",
    config_name="eval.yaml",
)
def sweep_hydra_entry_point(cfg: DictConfig) -> None:
    """Hydra entry point: evaluate all tuned sweeps on the composed eval config.

    Struct mode is turned off so downstream code can attach new keys to the
    config. ``arch.seed`` is then removed before evaluation starts.
    NOTE(review): presumably removed because seeds are supplied per run via
    ``parallel_seeds`` — confirm.
    """
    OmegaConf.set_struct(cfg, False)
    arch_cfg = cfg.arch
    del arch_cfg.seed
    eval_all_sweep(cfg)


if __name__ == "__main__":
    sweep_hydra_entry_point()
