import optuna
from omegaconf import DictConfig, OmegaConf
from rich.pretty import pprint
import subprocess
from colorama import Fore, Style
import math

import datetime

import stoix.systems.ppo.ff_ppo_outer_parallel_seeds

# Studies in this file are serialised with joblib and reloaded by later runs, so the
# Optuna version is pinned to keep the saved objects loadable.
assert optuna.__version__ == "3.6.1"  # must be the same version for loading / saving

import hydra
import joblib
import neptune

import numpy as np


# Persist a study object to a local file (joblib serialisation).
def save_study(study, filename):
    """Write ``study`` to ``filename`` using ``joblib.dump``.

    Args:
        study: Object to serialise (an Optuna study in this file's usage).
        filename: Destination path. It is opened in binary write mode, so an
            existing file at that path is overwritten.
    """
    with open(filename, mode="wb") as sink:
        joblib.dump(study, sink)


# Restore a study object from a local file (joblib deserialisation).
def load_study(filename):
    """Read back an object previously written with ``save_study``.

    Args:
        filename: Path of the file to read; opened in binary read mode.

    Returns:
        Whatever object ``joblib.load`` reconstructs from the file.

    Note:
        joblib files are pickle-based, and unpickling can execute arbitrary
        code. Only load files that come from a trusted source.
    """
    with open(filename, mode="rb") as source:
        restored = joblib.load(source)
    return restored


def save_study_to_neptune(study, project_name, filename, optuna_neptune_dir):
    """Save ``study`` locally and upload the resulting file to a Neptune project.

    Args:
        study: Study object to persist.
        project_name: Neptune project identifier passed to ``neptune.init_project``.
        filename: Local file name to write; also used as the last component of
            the Neptune field path.
        optuna_neptune_dir: Field-path prefix under which the file is stored.
    """
    # Write the study to disk first; the upload below reads this file.
    save_study(study, filename)
    # The project handle is opened per call so no background Neptune process is
    # left running between checkpoints.
    project = neptune.init_project(project=project_name)
    try:
        key = f"{optuna_neptune_dir}/{filename}"
        project[key].upload(filename)
    finally:
        # stop() forces a sync; doing it in ``finally`` also releases the
        # handle when the upload call itself raises.
        project.stop()


def load_study_from_neptune(project_name, filename, optuna_neptune_dir):
    """Download a pickled study from a Neptune project and load it.

    Args:
        project_name: Neptune project identifier passed to ``neptune.init_project``.
        filename: Last component of the Neptune field path holding the study.
        optuna_neptune_dir: Field-path prefix under which the study is stored.

    Returns:
        The object loaded from the downloaded file.
    """
    project = neptune.init_project(project=project_name, mode="read-only")
    try:
        key = f"{optuna_neptune_dir}/{filename}"
        project[key].download()
        project.wait()
    finally:
        # Always release the project handle, including when the field is
        # missing and download() raises; previously it was never stopped.
        project.stop()
    print(f"Downloaded! {filename}")
    # The file is read from ``<filename>.pkl``. NOTE(review): this presumably
    # relies on Neptune appending the uploaded file's extension to the field
    # name when downloading to the working directory - confirm.
    study = load_study(f"{filename}.pkl")
    return study


def optuna_objective(trial: optuna.Trial, cfg: DictConfig) -> float:
    """Run one PPO experiment with hyperparameters drawn from ``trial``.

    The branch taken depends on ``cfg.sweep_name``; each branch writes the
    suggested values into ``cfg`` before the experiment is launched.

    Args:
        trial: Optuna trial object.
        cfg: DictConfig containing the configuration for the experiment.
            It is modified in place.

    Returns:
        float: Evaluation performance of the system.
    """
    uses_outer_ppo = cfg.system.run_outer_ppo
    sweep = cfg.sweep_name

    # Outer PPO is incompatible with the plain baseline sweep, and normal PPO
    # is only allowed with one of the baseline sweep variants.
    assert not uses_outer_ppo or sweep != 'baseline', 'Cannot run outer PPO with baseline sweep'
    assert uses_outer_ppo or sweep in ('baseline', 'baseline_ext', 'baseline_ablate_lr'), 'Cannot run normal PPO with non-baseline sweep'

    if cfg.scramble_seeds:
        # Replace the configured seeds with the same number of distinct
        # random seeds drawn from [0, 10000).
        seed_count = len(cfg.parallel_seeds)
        drawn = np.random.choice(np.arange(10000), size=(seed_count,), replace=False)
        cfg.parallel_seeds = drawn.tolist()

    # Record the seeds actually used alongside the trial.
    trial.set_user_attr("parallel_seeds", cfg.parallel_seeds)

    def suggest_pow2(name, low, high):
        # The trial stores the exponent; the config receives 2 ** exponent.
        return 2 ** trial.suggest_int(name, low, high)

    if sweep == 'baseline':
        # Upper bound on the env-count exponent depends on the task.
        task_name = cfg.env.scenario.task_name
        if 'sokoban' in task_name:
            env_exponent_cap = 8
        elif 'pacman' in task_name:
            env_exponent_cap = 9
        else:
            env_exponent_cap = 10

        # Suggestion order is kept fixed: it determines how a seeded sampler
        # consumes its random stream.
        cfg.arch.total_num_envs = suggest_pow2("arch.total_num_envs", 6, env_exponent_cap)
        cfg.system.rollout_length = suggest_pow2("system.rollout_length", 2, 8)
        cfg.system.actor_lr = trial.suggest_float("system.actor_lr", 1e-5, 1e-3, log=True)
        cfg.system.critic_lr = trial.suggest_float("system.critic_lr", 1e-5, 1e-3, log=True)
        cfg.system.epochs = trial.suggest_int("system.epochs", 1, 16)
        cfg.system.num_minibatches = suggest_pow2("system.num_minibatches", 0, 6)
        cfg.system.gamma = trial.suggest_float("system.gamma", 0.9, 1.0)
        cfg.system.gae_lambda = trial.suggest_float("system.gae_lambda", 0.0, 1.0)
        cfg.system.max_grad_norm = trial.suggest_float("system.max_grad_norm", 0.1, 5.0)
        cfg.system.reward_scaling = trial.suggest_float("system.reward_scaling", 1e-1, 100, log=True)
        cfg.system.clip_eps = trial.suggest_float("system.clip_eps", 0.1, 0.5)

    elif sweep == 'outer_lr':
        print('adding constant outer lr')
        # NOTE(review): perform_sweep's grid for this parameter extends to 4.0
        # while the range declared here stops at 2.0 - confirm which is intended.
        outer_lr = trial.suggest_float("system.outer_optimizer.learning_rate.peak_value", 0.0, 2.0)
        # float() coerces whatever the sampler returned to a built-in float.
        cfg.system.outer_optimizer.learning_rate.peak_value = float(outer_lr)

    elif sweep == 'bias_init':
        print('adding bias init parameters')  # 2 params
        free_momentum = trial.suggest_float("system.free_step_momentum", 0.0, 0.9)
        cfg.system.free_step_momentum = float(free_momentum)
        free_lr = trial.suggest_float("system.free_step_learning_rate.peak_value", 0.0, 1.0)
        cfg.system.free_step_learning_rate.peak_value = float(free_lr)

    elif sweep in ('nest', 'hb'):
        print('adding momenutm')  # 2 params
        momentum = trial.suggest_float("system.outer_optimizer.momentum", 0.0, 0.9)
        cfg.system.outer_optimizer.momentum = float(momentum)
        outer_lr = trial.suggest_float("system.outer_optimizer.learning_rate.peak_value", 0.1, 1.0)
        cfg.system.outer_optimizer.learning_rate.peak_value = float(outer_lr)

        # 'nest' enables the nesterov flag, 'hb' disables it.
        cfg.system.outer_optimizer.nesterov = 'nest' in sweep

    if uses_outer_ppo:
        print('running outer PPO')

    run_experiment = stoix.systems.ppo.ff_ppo_outer_parallel_seeds.run_experiment
    score = run_experiment(cfg)

    print(f"{Fore.CYAN}{Style.BRIGHT}PPO experiment completed{Style.RESET_ALL}")

    return score


def _grid_search_space(sweep_name):
    """Return ``(search_space, seed_offset)`` for a grid sweep named ``sweep_name``.

    ``search_space`` maps parameter names to the grid values to try and
    ``seed_offset`` is added to the sampler seed so different sweeps do not
    share seeds.

    Raises:
        ValueError: If no grid is defined for ``sweep_name``.
    """
    if 'outer_lr' in sweep_name:
        space = {
            "system.outer_optimizer.learning_rate.peak_value": np.arange(0.0, 4.0 + 0.1, 0.1),
        }
        return space, 10

    if sweep_name == 'bias_init':
        space = {
            "system.free_step_momentum": np.arange(0.0, 0.9 + 0.1, 0.1),
            "system.free_step_learning_rate.peak_value": np.arange(0.0, 1.0 + 0.1, 0.1),
        }
        return space, 20

    if sweep_name in ('nest', 'hb'):
        # optuna_objective suggests the same two parameters for 'nest' and
        # 'hb' (they differ only in the nesterov flag), so both use this grid.
        # 'hb' previously had no branch and crashed on an unbound seed_offset.
        space = {
            "system.outer_optimizer.learning_rate.peak_value": np.arange(0.1, 1.0 + 0.1, 0.1),
            "system.outer_optimizer.momentum": np.arange(0.0, 0.9 + 0.1, 0.1),
        }
        # NOTE(review): 40 for 'hb' continues the 10/20/30 pattern - confirm.
        return space, (30 if sweep_name == 'nest' else 40)

    raise ValueError(
        f"No grid search space is defined for sweep_name={sweep_name!r}; "
        "supported non-baseline sweeps are names containing 'outer_lr', "
        "'bias_init', 'nest' and 'hb'"
    )


def perform_sweep(cfg: DictConfig):
    """Create or resume an Optuna study for ``cfg.sweep_name`` and run it in chunks.

    Steps, in order:
      1. For non-baseline sweeps, build the grid search space, set
         ``checkpoint_interval`` to 1, derive ``n_trials`` from the grid size
         and offset the sampler seed.
      2. Probe Neptune for existing checkpoint files of this study and check
         that their presence agrees with ``cfg.resume``.
      3. Create a new study (TPE sampler for 'baseline', grid sampler
         otherwise) or load the latest downloaded checkpoint.
      4. For non-baseline sweeps, copy the best parameters of a previously
         saved baseline study into ``cfg``.
      5. Run the study ``checkpoint_interval`` trials at a time, uploading a
         checkpoint to Neptune after each chunk unless ``cfg.debug_dont_save``.

    Args:
        cfg: Experiment configuration; modified in place.

    Raises:
        ValueError: If the sweep name has no grid defined, if checkpoints
            already exist while not resuming, or if resuming while the first
            checkpoint is missing.
        AssertionError: If ``n_trials`` (or the resumed trial count) is not a
            multiple of ``checkpoint_interval``.
    """
    search_space = {}

    if cfg.sweep_name != 'baseline':
        # Grid sweeps checkpoint after every trial.
        cfg.study_config.checkpoint_interval = 1

        search_space, seed_offset = _grid_search_space(cfg.sweep_name)
        print(search_space)

        # One trial per point of the Cartesian product of all grids.
        n_trials = 1
        for values in search_space.values():
            n_trials *= len(values)
        cfg.study_config.n_trials = n_trials
        print('n_trials', n_trials)

        cfg.study_config.sampler.seed += seed_offset  # avoids running grid on same seeds

    assert (
        cfg.study_config.n_trials % cfg.study_config.checkpoint_interval == 0
    ), "n_trials must be divisible by checkpoint_interval"

    # FINDING WHICH STUDY FILES ALREADY EXIST IN NEPTUNE
    project = neptune.init_project(project=cfg.logger.kwargs.neptune_project, mode="read-only")
    latest_study_file = None
    try:
        # One candidate checkpoint file per checkpoint interval.
        for i in range(0, cfg.study_config.n_trials, cfg.study_config.checkpoint_interval):
            filename = f"{cfg.study_config.study_name}_{i+cfg.study_config.checkpoint_interval}.pkl"
            try:
                # try and load the file from neptune
                project[f"{cfg.optuna_neptune_dir}/{filename}"].download()
                project.wait()
                latest_study_file = filename
            except neptune.exceptions.MissingFieldException:
                # This checkpoint has not been uploaded; keep the last one found.
                pass

            if latest_study_file is not None and not cfg.resume:
                # if we're not resuming and the file is found, raise an error - do not want to overwrite
                raise ValueError(
                    f"NOT RESUMING AND {cfg.optuna_neptune_dir}/{filename} ALREADY EXISTS IN NEPTUNE!!"
                )
            if latest_study_file is None and cfg.resume:
                # Nothing found so far; this triggers when the first checkpoint is missing.
                raise ValueError(
                    f"RESUMING BUT {cfg.optuna_neptune_dir}/{filename} DOES NOT EXIST IN NEPTUNE!!"
                )
    finally:
        # Release the read-only handle even when one of the checks above raises.
        project.stop()

    # STUDY CREATION OR LOADING
    if not cfg.resume:
        np.random.seed(cfg.study_config.sampler.seed)  # for scrambled seeds

        if cfg.sweep_name == 'baseline':
            sampler = optuna.samplers.TPESampler(
                n_startup_trials=cfg.study_config.sampler.n_startup_trials,
                seed=cfg.study_config.sampler.seed,
                multivariate=cfg.study_config.sampler.multivariate,
            )
        else:
            sampler = optuna.samplers.GridSampler(search_space=search_space)

        # Not strictly necessary, but makes it explicit that nothing is pruned.
        pruner = optuna.pruners.NopPruner()
        study = optuna.create_study(
            sampler=sampler,
            pruner=pruner,
            direction=cfg.study_config.direction,
            study_name=cfg.study_config.study_name,
        )
        start_trial = 0
        print('Creating new study')
    else:
        # Load the most recent checkpoint downloaded above. The extra ".pkl"
        # mirrors the path used in load_study_from_neptune.
        study = load_study(f"{latest_study_file}.pkl")
        start_trial = len(study.trials)
        print("Resuming from trial:", start_trial)
        assert (
            start_trial % cfg.study_config.checkpoint_interval == 0
        ), "start_trial must be divisible by checkpoint_interval"

    if cfg.sweep_name != 'baseline':
        # Non-baseline sweeps start from the best parameters of a saved baseline study.
        base_study = load_study_from_neptune(cfg.base_project, f'ff_ppo_{cfg.env.scenario.task_name}_{cfg.base_sweep_name}_{cfg.base_trial_num}.pkl', 'optuna_study_pickles')

        base_params = base_study.best_params
        print(f'Using best trial from baseline: {base_study.best_trial.number} with value {base_study.best_trial.value} and params {base_study.best_trial.params}')

        # The baseline study stores exponents for the power-of-two parameters,
        # so those are expanded with 2 ** here.
        cfg.arch.total_num_envs = 2 ** base_params["arch.total_num_envs"]
        cfg.system.rollout_length = 2 ** base_params["system.rollout_length"]
        cfg.system.actor_lr = base_params["system.actor_lr"]
        cfg.system.critic_lr = base_params["system.critic_lr"]
        cfg.system.epochs = base_params["system.epochs"]
        cfg.system.num_minibatches = 2 ** base_params["system.num_minibatches"]
        cfg.system.gamma = base_params["system.gamma"]
        cfg.system.gae_lambda = base_params["system.gae_lambda"]
        cfg.system.max_grad_norm = base_params["system.max_grad_norm"]
        cfg.system.reward_scaling = base_params["system.reward_scaling"]
        cfg.system.clip_eps = base_params["system.clip_eps"]

    def objective(trial):
        # Bind the shared config so Optuna only has to supply the trial.
        return optuna_objective(trial, cfg)

    for i in range(start_trial, cfg.study_config.n_trials, cfg.study_config.checkpoint_interval):
        # Run one checkpoint interval's worth of trials.
        study.optimize(objective, n_trials=cfg.study_config.checkpoint_interval)

        # Checkpoint files are numbered by the trial count reached so far.
        filename = f"{cfg.study_config.study_name}_{i + cfg.study_config.checkpoint_interval}.pkl"

        print(f"saving {filename}")

        # Save the study to neptune
        if not cfg.debug_dont_save:
            save_study_to_neptune(study, cfg.logger.kwargs.neptune_project, filename, cfg.optuna_neptune_dir)


@hydra.main(config_path="./configs", config_name="sweep.yaml", version_base="1.2")
def sweep_hydra_entry_point(cfg: DictConfig) -> None:
    """Hydra entry point: prepare the composed config and launch the sweep.

    Args:
        cfg: Configuration composed by Hydra from ``./configs/sweep.yaml``.
    """
    # Turn struct mode off so keys can be added to or removed from the config.
    OmegaConf.set_struct(cfg, False)
    # Drop the single architecture seed before sweeping - presumably because
    # runs are seeded through cfg.parallel_seeds instead; confirm.
    del cfg.arch.seed
    perform_sweep(cfg)


# Launch the sweep only when this file is executed as a script, not on import.
if __name__ == "__main__":
    sweep_hydra_entry_point()
