from baselines.baselines import ZeroBaseline, LinearFeatureBaseline
from envs.ml_four_paths import FourPaths
from envs.ml_simple_bandits_env import SimpleBandits
from envs.ml_ant_movement import AntMovementEnv
from envs.ml_ant_obstaclesgen import AntObstaclesGenEnv
from envs.ml_ant_obstaclesgen_noblocks import AntObstaclesGenNoMazeEnv
from envs.taxi import Taxi
from utils.mpi_adam import MpiAdam
from torch.optim import SGD
from gym.spaces import Discrete, Box
from gym.wrappers.time_limit import TimeLimit
import gym
import re
from policies.policies import *
from algorithms.algorithms import *
from utils.utils import set_seeds, get_space_io_size
from torch.nn.functional import softplus


def initialize_algorithm(opts=None):
    """Build the LVC algorithm (env, policy, optimizer, baseline) described by ``opts``.

    Three start modes are supported:

    * ``opts.continue_run`` is set: resume a run. The options stored in that
      checkpoint replace ``opts``; only ``run_data_dir``, ``run_name`` and
      ``rank`` are kept from the caller. Policy weights, optimizer state and
      RNG state are all restored from the checkpoint. ``opts.epoch_start`` is
      set to the first number found in the checkpoint file name, plus one.
    * ``opts.load_chp`` is set: start from a checkpoint's policy weights.
      Architecture-related options are copied from the checkpoint. Optimizer
      and RNG state are restored only if ``opts.load_optimizer`` /
      ``opts.load_rng`` are true.
    * Neither is set: fresh initialization.

    After construction, policy parameters are synchronized across
    ``MPI.COMM_WORLD``.

    Args:
        opts: parsed run options. Despite the ``None`` default, it is required.

    Returns:
        ``(algorithm, opts)``. The returned ``opts`` may be a different object
        from the argument, because it is replaced when continuing a run.

    Raises:
        ValueError: if ``opts`` is None, or the ``continue_run`` file name
            contains no number to use as the epoch.
        RuntimeError: for an unknown environment, action space type or baseline.
    """
    if opts is None:
        raise ValueError("initialize_algorithm requires an opts object")

    # Offset the seed by the process rank so each worker gets a distinct seed.
    set_seeds(opts.seed + opts.rank)

    if opts.continue_run is not None:
        # Resuming: the checkpoint's opts replace ours, but keep the
        # identifiers of this launch (output dir, run name, rank).
        run_data_dir = opts.run_data_dir
        run_name = opts.run_name
        rank = opts.rank

        load_all_path = opts.continue_run
        checkpoint = torch.load(load_all_path)
        opts = checkpoint['opts']
        opts.continue_run = load_all_path

        # The first run of digits in the file name is taken as the saved epoch.
        epoch_match = re.search(r'\d+', os.path.basename(load_all_path))
        if epoch_match is None:
            raise ValueError(f"Cannot infer the epoch number from checkpoint file name: {load_all_path}")
        opts.epoch_start = int(epoch_match.group()) + 1

        opts.run_data_dir = run_data_dir
        opts.run_name = run_name
        opts.rank = rank

    elif opts.load_chp is not None:
        # Warm start: copy only these options from the checkpoint. They
        # presumably define the network structure, so the saved state dict
        # fits the policy built below.
        saved_opts = torch.load(opts.load_chp)['opts']
        opts.options = saved_opts.options
        opts.no_bias = saved_opts.no_bias
        # Checkpoints that lack learn_lr_inner fall back to False.
        opts.learn_lr_inner = getattr(saved_opts, 'learn_lr_inner', False)
        opts.hidden_sizes_base = saved_opts.hidden_sizes_base
        opts.hidden_sizes_option = saved_opts.hidden_sizes_option
        opts.hidden_sizes_termination = saved_opts.hidden_sizes_termination
        opts.hidden_sizes_subpolicy = saved_opts.hidden_sizes_subpolicy

    # Initialize env. "taxi" is matched exactly before the substring test for "grid".
    if "taxi" == opts.env:
        env = Taxi(exclude_envs=opts.exclude_envs, random_move_prob=opts.random_move_prob)
    elif "grid" in opts.env:
        env = FourPaths(map_name=opts.env, exclude_envs=tuple(opts.exclude_envs),
                        random_move_prob=opts.random_move_prob)
    elif "bandits" == opts.env:
        env = SimpleBandits()
    elif "ant_bandits" == opts.env:
        env = TimeLimit(env=AntMovementEnv(exclude_envs=opts.exclude_envs), max_episode_steps=600)
    elif "ant_maze" == opts.env:
        env = TimeLimit(env=AntObstaclesGenEnv(exclude_envs=opts.exclude_envs), max_episode_steps=1000)
    elif "ant_maze_noreset" == opts.env:
        env = TimeLimit(env=AntObstaclesGenEnv(exclude_envs=opts.exclude_envs, enable_resets=False),
                        max_episode_steps=1000)
    elif "ant_maze_noblocks" == opts.env:
        env = TimeLimit(env=AntObstaclesGenNoMazeEnv(exclude_envs=opts.exclude_envs), max_episode_steps=1000)
    else:
        raise RuntimeError(f"Unknown environment specified: {opts.env}")

    # Network input/output sizes (for discrete spaces the io size differs from the space dim).
    obs_size = get_space_io_size(env.observation_space)
    act_size = get_space_io_size(env.action_space)
    if isinstance(env.action_space, Discrete):
        action_type = "discrete"
    elif isinstance(env.action_space, Box):
        action_type = "continuous"
    else:
        raise RuntimeError("Unknown action space type")

    # Policy weights: from continue_run, else from load_chp (None for a fresh run).
    if opts.continue_run is not None:
        policy_checkpoint = opts.continue_run
    else:
        policy_checkpoint = opts.load_chp
    policy = create_policy(opts=opts, obs_size=obs_size, act_size=act_size, action_type=action_type,
                           checkpoint_path=policy_checkpoint)

    # Optimizer state: always restored when continuing; only on request for load_chp.
    if opts.continue_run is not None:
        optimizer_checkpoint = opts.continue_run
    elif opts.load_chp is not None and opts.load_optimizer:
        optimizer_checkpoint = opts.load_chp
    else:
        optimizer_checkpoint = None
    optimizer = create_optimizer(opts=opts, policy=policy, checkpoint_path=optimizer_checkpoint)

    # Initialize baseline
    if opts.baseline == "none":
        baseline = ZeroBaseline()
    elif opts.baseline == "linear":
        baseline = LinearFeatureBaseline()
    else:
        raise RuntimeError("Baseline {} not supported use none/linear instead".format(opts.baseline))

    algorithm = LVC(env, policy, baseline, optimizer, opts)

    # Synchronize model parameters across all MPI processes.
    policy.synchronize(comm=MPI.COMM_WORLD)

    # RNG state: always restored when continuing; only on request for load_chp.
    if opts.continue_run is not None:
        load_rng(opts.continue_run)
    elif opts.load_chp is not None and opts.load_rng:
        load_rng(opts.load_chp)

    return algorithm, opts


def load_rng(checkpoint_path):
    """Restore the global torch and numpy RNG states stored in a checkpoint.

    The checkpoint must provide ``"pytorch_rng_states"`` and
    ``"numpy_rng_states"``, each a sequence of saved states.

    With MPI available, both sequences must hold exactly one entry per rank
    of ``MPI.COMM_WORLD``. They are scattered from rank 0, so rank ``i``
    installs entry ``i``. Without MPI, entry 0 of each sequence is installed.

    Raises:
        RuntimeError: if MPI is available and either sequence length differs
            from the communicator size.
    """
    saved = torch.load(checkpoint_path)
    torch_states = saved["pytorch_rng_states"]
    numpy_states = saved["numpy_rng_states"]

    if MPI is None:
        # Single-process fallback: use the first saved state.
        torch.set_rng_state(torch_states[0])
        np.random.set_state(numpy_states[0])
        return

    world = MPI.COMM_WORLD
    n_ranks = world.Get_size()
    if not (len(torch_states) == n_ranks and len(numpy_states) == n_ranks):
        raise RuntimeError(f"Len of rng_states does not match the comm size: {len(torch_states)}, "
                           f"{len(numpy_states)}, {n_ranks}")

    # scatter is a collective call: every rank must reach it, in this order.
    my_torch_state = world.scatter(torch_states, root=0)
    my_numpy_state = world.scatter(numpy_states, root=0)
    torch.set_rng_state(my_torch_state)
    np.random.set_state(my_numpy_state)


def create_policy(opts, obs_size, act_size, action_type, checkpoint_path=None):
    """Construct a tanh ``CategoricalOptionsPolicy`` and optionally restore its weights.

    Args:
        opts: run options forwarded to the policy constructor.
        obs_size: observation input size.
        act_size: action output size.
        action_type: ``"discrete"`` or ``"continuous"``, as chosen by the caller.
        checkpoint_path: if given, the checkpoint's ``"policy_params"`` entry is
            loaded into the new policy via ``load_state_dict``.

    Returns:
        The policy.
    """
    new_policy = CategoricalOptionsPolicy(
        obs_dim=obs_size,
        action_dim=act_size,
        opts=opts,
        action_type=action_type,
        nonlinearity=torch.tanh,
    )
    if checkpoint_path is None:
        return new_policy

    saved = torch.load(checkpoint_path)
    new_policy.load_state_dict(saved["policy_params"])
    return new_policy


def create_optimizer(opts, policy, checkpoint_path=None):
    """Create the optimizer for the parameter subset chosen by ``opts.learn_params``.

    Supported ``opts.learn_params`` values:

    * ``"outer"``      -- MpiAdam over ``policy.outer_params``, lr ``opts.lr_outer``.
    * ``"all"``        -- MpiAdam with one group each for subpolicy, base,
                          termination and option parameters, all ``opts.lr_outer``.
    * ``"inner"``      -- SGD over ``policy.inner_params``. If ``opts.learn_lr_inner``
                          is true, each parameter gets its own group with lr
                          ``softplus(policy.lr_params[name + "lr"])``. Otherwise a
                          single group uses ``opts.lr_inner``. A missing
                          ``learn_lr_inner`` attribute is treated as False.
    * ``"inner_adam"`` -- MpiAdam over ``policy.inner_params``, lr ``opts.lr_inner``.
    * ``"outer_sgd"``  -- SGD over ``policy.outer_params``, lr ``opts.lr_outer``.

    Args:
        opts: run options.
        policy: object exposing the parameter dicts named above.
        checkpoint_path: optional checkpoint whose ``"optimizer_params"`` entry
            is loaded into the new optimizer.

    Returns:
        The optimizer.

    Raises:
        RuntimeError: if ``opts.learn_params`` is not one of the values above.
    """
    mode = opts.learn_params
    if mode == "outer":
        optimizer = MpiAdam([{"params": policy.outer_params.values(), "lr": opts.lr_outer}])
    elif mode == "all":
        optimizer = MpiAdam([
            {"params": param_dict.values(), "lr": opts.lr_outer}
            for param_dict in (policy.subpolicy_params, policy.base_params,
                               policy.termination_params, policy.option_params)
        ])
    elif mode == "inner":
        # opts restored from older checkpoints (e.g. when continuing a run) may
        # not define learn_lr_inner; a missing flag means "not learned".
        if getattr(opts, "learn_lr_inner", False):
            # One group per parameter so each has its own step size. softplus
            # keeps it positive.
            # NOTE(review): softplus is evaluated once here, so the group's lr
            # does not track later changes to policy.lr_params — confirm intended.
            optimizer = SGD([{"params": policy.inner_params[name],
                              "lr": softplus(policy.lr_params[name + "lr"])}
                             for name in policy.inner_params.keys()])
        else:
            optimizer = SGD([{"params": policy.inner_params.values(), "lr": opts.lr_inner}])
    elif mode == "inner_adam":
        optimizer = MpiAdam([{"params": policy.inner_params.values(), "lr": opts.lr_inner}])
    elif mode == "outer_sgd":
        optimizer = SGD([{"params": policy.outer_params.values(), "lr": opts.lr_outer}])
    else:
        raise RuntimeError("learn_params \"{}\" not supported".format(opts.learn_params))

    if checkpoint_path is not None:
        checkpoint = torch.load(checkpoint_path)
        optimizer.load_state_dict(checkpoint["optimizer_params"])
    return optimizer
