import argparse
from datetime import datetime
import numpy as np
import os
import torch
from mpi4py import MPI
from utils import logger
from utils.initalization import initialize_algorithm
from utils.utils import GradUpdateType
import shutil


def main(opts=None):
    """Initialise the algorithm, prepare the run directory and logger, then train.

    The root process (rank 0, or the only process when MPI is unavailable)
    sets ``opts.base_dir`` to ``<run_data_dir>/<run_name>``. It deletes any
    existing ``logs`` sub-directory there and creates a fresh one. The chosen
    ``base_dir`` is then broadcast to every other rank so all processes use
    the same location.

    Args:
        opts: Parsed command-line options. They are passed to
            ``initialize_algorithm``, which returns the algorithm together
            with the options object used from then on.
    """
    algorithm, opts = initialize_algorithm(opts)
    comm = None if MPI is None else MPI.COMM_WORLD
    is_root = comm is None or comm.Get_rank() == 0

    if is_root:
        opts.base_dir = f"{opts.run_data_dir}/{run_name(opts)}"
        root_log_dir = f"{opts.base_dir}/logs"
        # Start every run with an empty log directory.
        if os.path.exists(root_log_dir):
            shutil.rmtree(root_log_dir)
        os.makedirs(root_log_dir, exist_ok=False)
    if comm is not None:
        # Non-root ranks have not computed base_dir; receive rank 0's value.
        # (The original scattered N identical copies and crashed when MPI was None.)
        opts.base_dir = comm.bcast(opts.base_dir if is_root else None, root=0)

    log_dir = f"{opts.base_dir}/logs"
    if is_root:
        logger.configure(dir=log_dir)
        logger.info("\nRun parameters:")
        for key, val in vars(opts).items():
            logger.info(str(key) + ": " + str(val))
        processes = 1 if comm is None else comm.Get_size()
        logger.info(f"Running {processes} processes")
    else:
        # NOTE(review): format_strs=[] presumably disables output on non-root
        # ranks; confirm against utils.logger.
        logger.configure(dir=log_dir, format_strs=[])

    # with torch.autograd.detect_anomaly():
    run_algorithm(opts, algorithm)


def run_algorithm(opts, algorithm):
    """Run the training loop with per-epoch logging, plotting and checkpointing.

    Epochs are numbered from ``opts.epoch_start`` to
    ``opts.epoch_start + opts.epochs - 1``, inclusive.

    - Before the first epoch, the initial policy is plotted under
      ``epoch<start-1>``.
    - After each epoch, logged values are merged across MPI processes (when MPI
      is available) and dumped.
    - Plots are written every ``opts.plot_freq`` epochs.
    - A checkpoint is written every ``opts.chp_freq`` epochs and always after
      the last epoch of this run, including resumed runs where
      ``epoch_start > 1``.

    Args:
        opts: Options namespace. Reads ``epochs``, ``epoch_start``,
            ``plot_freq``, ``chp_freq``, ``base_dir`` and ``rank``.
        algorithm: Object providing ``epoch()``,
            ``plot_epoch(name_prefix=..., plot_updates=...)``, ``policy`` and
            ``policy_optimizer``.
    """
    last_epoch = opts.epoch_start + opts.epochs - 1
    # Pad epoch numbers to the width of the largest epoch in this run, so that
    # plot and checkpoint names sort lexicographically (also when resuming).
    epoch_digits = len(str(last_epoch))

    def epoch_tag(n):
        # Return e.g. "epoch007" for n=7 when epoch_digits is 3.
        return "epoch{:0{prec}d}".format(n, prec=epoch_digits)

    plots_dir = opts.base_dir + "/plots/"
    # torch.autograd.set_detect_anomaly(True)
    for epoch in range(opts.epoch_start, last_epoch + 1):
        # Plot the initial (untrained) policy, labelled as the preceding epoch.
        if epoch == opts.epoch_start:
            algorithm.plot_epoch(name_prefix=plots_dir + epoch_tag(epoch - 1) + "/",
                                 plot_updates=False)
        algorithm.epoch()

        # Average data across processes and log
        if MPI is not None:
            logger.merge_thread_data(comm=MPI.COMM_WORLD)

        logger.log(f"\nEpoch {epoch}:")
        logger.dumpkvs()
        logger.increment_global_step()

        if epoch % opts.plot_freq == 0:
            algorithm.plot_epoch(name_prefix=plots_dir + epoch_tag(epoch) + "/",
                                 plot_updates=False)

        # Always checkpoint the final epoch of this run, not just epoch == opts.epochs.
        if epoch % opts.chp_freq == 0 or epoch == last_epoch:
            _save_checkpoint(opts, algorithm, epoch, epoch_tag(epoch) + ".tar")


def _save_checkpoint(opts, algorithm, epoch, filename):
    """Gather RNG states from all processes and save a checkpoint from rank 0.

    Every process must call this, because the MPI gathers are collective
    operations. Only rank 0 writes ``<base_dir>/checkpoints/<filename>``.

    Args:
        opts: Options namespace. Reads ``rank`` and ``base_dir``; the whole
            object is also stored in the checkpoint.
        algorithm: Provides ``policy`` and ``policy_optimizer``, whose state
            dicts are saved.
        epoch: Epoch number stored in the checkpoint.
        filename: File name inside the checkpoints directory.
    """
    if MPI is not None:
        pytorch_rng_states = MPI.COMM_WORLD.gather(torch.get_rng_state(), root=0)
        numpy_rng_states = MPI.COMM_WORLD.gather(np.random.get_state(), root=0)
    else:
        pytorch_rng_states = [torch.get_rng_state()]
        numpy_rng_states = [np.random.get_state()]
    if opts.rank != 0:
        return
    state_dicts = {
        "epoch": epoch,
        "policy_params": algorithm.policy.state_dict(),
        "optimizer_params": algorithm.policy_optimizer.state_dict(),
        "opts": opts,
        "pytorch_rng_states": pytorch_rng_states,
        "numpy_rng_states": numpy_rng_states
    }
    save_dir = opts.base_dir + "/checkpoints/"
    os.makedirs(save_dir, exist_ok=True)
    torch.save(state_dicts, save_dir + filename)


def run_name(opts):
    """Return the name of the run directory, which is simply ``opts.run_name``.

    ``main`` joins it with ``opts.run_data_dir`` to form the run's base directory.
    """
    name = opts.run_name
    return "{}".format(name)


def _build_parser():
    """Build the command-line parser for the experiment program."""
    parser = argparse.ArgumentParser(description="Master Thesis experiments program")

    # Algorithm
    parser.add_argument('--epochs', type=int, default=200, help='Number of algorithm epochs (outer updates)')
    parser.add_argument('--envs_per_process', type=int, default=15,
                        help='Total number of envs = num_processes x envs_per_process')
    parser.add_argument('--episodes', type=int, default=20,
                        help='How many episodes should be sampled per meta-update/update')
    parser.add_argument('--lookaheads', type=int, default=1,
                        help='How many gradient steps to optimize for (best performance after lookaheads updates)')
    parser.add_argument('--grad_update_type', type=GradUpdateType.argparse, default=GradUpdateType.META,
                        choices=list(GradUpdateType), help='Which gradient update to use for training')
    parser.add_argument('--maml', action='store_true', help='Enable maml with 1 option')

    # Environment options
    parser.add_argument('--env', type=str, default="taxi",
                        choices=("taxi", "grid_large", "grid_small", "grid_medium", "bandits",
                                 "ant_bandits", "ant_maze", "ant_maze_noreset", "ant_maze_noblocks"))
    parser.add_argument('--exclude_envs', type=int, nargs='*', default=[], help="Environments that should be excluded")
    parser.add_argument('--fixed_env', type=int, default=-1,
                        help='Used for task specific methods and adaptation to fix task')
    parser.add_argument('--random_move_prob', type=float, default=0.0, help='Random move prob in taxi/grid environment')

    # RL settings
    parser.add_argument('--gae_discount', type=float, default=0.98, help='GAE lambda parameter for advantage estimator')
    parser.add_argument('--dice_discount', type=float, default=0, help='Past influence discount for loaded dice')
    parser.add_argument('--return_discount', type=float, default=0.95, help='Discount factor for returns')
    parser.add_argument('--baseline', type=str, default="none", choices=("none", "linear"),
                        help='Baseline to be used for loss calculation')
    parser.add_argument('--entropy_reg', type=float, default=0, help='Entropy regularization coefficient')

    # Optimizers/Learning rates
    parser.add_argument('--lr_outer', type=float, default=1e-1, help='Outer learning rate')
    parser.add_argument('--lr_inner', type=float, default=10.0, help='Inner learning rate')
    # TODO: --learn_params has no help text yet.
    parser.add_argument('--learn_params', type=str, default="outer", choices=("outer", "all", "inner",
                                                                              "inner_adam", "outer_sgd"))
    parser.add_argument('--learn_lr_inner', action='store_true', help='Learn learning rate for inner steps')

    # Policy options
    parser.add_argument('--options', type=int, default=4, help='Number of options')
    parser.add_argument('--termination_prior', type=float, default=-1, help='Prior termination probability')
    parser.add_argument('--fixed_std', type=float, default=0, help='Standard deviation for continuous policy')

    parser.add_argument('--hidden_sizes_base', type=int, nargs='*', default=[], help='Layer sizes for shared encoder')
    parser.add_argument('--hidden_sizes_option', type=int, nargs='*', default=[], help='Layer sizes for options')
    parser.add_argument('--hidden_sizes_termination', type=int, nargs='*', default=[],
                        help='Layer sizes for terminations')
    parser.add_argument('--hidden_sizes_subpolicy', type=int, nargs='*', default=[], help='Layer sizes for subpolicy')
    parser.add_argument('--no_bias', action='store_true', help='Do not use bias for policy networks')

    parser.add_argument('--temperature_options', type=float, default=1, help='Temp for softmax policy over options')
    parser.add_argument('--temperature_terminations', type=float, default=1, help='Temp for sigmoid terminations')
    parser.add_argument('--max_option_prob', type=float, default=1, help='Max probability of selecting an option')

    # Seeding
    parser.add_argument('--seed', type=int, default=1234, help='Random seed to use')

    # Logging, plotting, saving
    # plot_freq / chp_freq are checked against the epoch counter in run_algorithm.
    parser.add_argument('--run_data_dir', type=str, default='../run_data', help='Path to directory with run data')
    parser.add_argument('--run_name', type=str, default='test', help='Name to identify the run')
    parser.add_argument('--plot_freq', type=int, default=20, help='Plotting frequency (in epochs)')
    parser.add_argument('--chp_freq', type=int, default=20, help='Checkpoint frequency (in epochs)')
    parser.add_argument('--log_level', type=int, default=1, help='Specifies how detailed the logs should be 0-2')
    parser.add_argument('--save_trajs', action='store_true', help='Use pickle to save trajectories (for heatmaps)')

    # Load model
    parser.add_argument('--load_chp', type=str, nargs='?', help='Path to checkpoint')
    parser.add_argument('--load_rng', action='store_true', help='Load RNG states')
    parser.add_argument('--continue_run', type=str, nargs='?', help='Continue run')
    parser.add_argument('--load_optimizer', type=str, nargs='?', help='Load optimizer state')

    return parser


if __name__ == '__main__':
    opts = _build_parser().parse_args()
    # Rank of this process; 0 when running without MPI.
    opts.rank = 0 if MPI is None else MPI.COMM_WORLD.Get_rank()
    opts.epoch_start = 1
    main(opts)


