import os
from typing import List, Set

from core.model import DreamerNetwork
from collections import defaultdict


class BaseConfig(object):
    """Hyper-parameters, environment info and output paths for one experiment.

    The constructor stores its arguments and a set of fixed defaults.
    ``set_config`` then overlays command-line arguments, probes the environment
    for its observation/action spaces and derives the experiment directory.
    Subclasses must implement ``new_game``.
    """

    def __init__(self,
                 max_env_steps: int,
                 start_step: int,
                 valid_envs: List,
                 action_repeats: defaultdict,
                 action_repeat_set: List,
                 batch_size: int = 50,
                 policy_lr: float = 8e-5,
                 value_lr: float = 8e-5,
                 dynamics_lr: float = 1e-3,
                 tau: float = 0.005,
                 gamma: float = 0.99,
                 grad_clip_norm: float = 100,
                 save_model_freq: int = 1,
                 replay_memory_capacity: int = int(1e6),
                 fixed_action_repeat=2,
                 test_interval_steps=5000,
                 test_episodes: int = 5,
                 env_itr_steps: int = 100,
                 chunk_size: int = 50):
        """Store the given settings and initialise the remaining defaults.

        Args:
            max_env_steps: Stored as-is; also scales the thresholds used by
                ``visit_softmax_temperature_fn``.
            start_step: Stored as-is.
            valid_envs: Environment names accepted by ``set_config``.
            action_repeats: Stored privately; not read inside this class.
            action_repeat_set: Iterable of repeat values; stored as a sorted
                list.
            batch_size, policy_lr, value_lr, dynamics_lr, tau, gamma,
            grad_clip_norm, save_model_freq, fixed_action_repeat,
            test_interval_steps, test_episodes, env_itr_steps, chunk_size:
                Stored as-is under attributes of the same name.
            replay_memory_capacity: Stored after coercion to ``int`` (so a
                value such as ``1e6`` does not leak through as a float).
        """
        # training
        self.entropy_lr = 0.0003
        self._valid_envs = valid_envs
        self._action_repeats = action_repeats

        self.env_itr_steps = env_itr_steps
        self.max_env_steps = max_env_steps
        self.start_step = start_step
        self.batch_size = batch_size
        self.chunk_size = chunk_size
        self.grad_clip_norm = grad_clip_norm
        self.gamma = gamma
        self.policy_lr = policy_lr
        self.value_lr = value_lr
        self.dynamics_lr = dynamics_lr
        self.tau = tau
        self.device = 'cpu'
        self.action_repeat_mode = None
        self.seed = None
        self.save_model_freq = save_model_freq
        self.test_episodes = test_episodes
        self.free_nats = 3
        self.update_itrs = 100
        self.max_dynamics_update_itr = 500
        self.min_dynamics_update_itr = 100

        self.min_behaviour_update_itr = 20
        self.max_behaviour_update_itr = 100

        self.global_kl_beta = 0
        self.planning_horizon = 15
        self.disclam = 0.95
        self.actor_repeat_entropy_coeff = 0.01
        self.actor_entropy_coeff = 0.01
        self.action_repeat_set = sorted(action_repeat_set)

        self.proposal_action_sample = 100
        self.uniform_action_sample = 50

        # Root prior exploration noise.
        self.root_dirichlet_alpha = 0.25
        self.root_exploration_fraction = 0.25

        # mcts
        self.mcts_cpw = 1
        self.mcts_alpha = 0.5
        self.pb_c_init = 1.25
        self.pb_c_base = 19652
        self.num_simulations = 50

        # memory
        self.replay_memory_capacity = int(replay_memory_capacity)

        # test
        self.test_interval_steps = test_interval_steps

        # action info
        self.fixed_action_repeat = fixed_action_repeat

        # env info
        self.env_name = None
        self.observation_space = None
        self.action_space = None

        # paths
        self.exp_path = None
        self.model_path = None
        self.with_search_best_model_path = None
        self.no_search_best_model_path = None
        self.test_data_path = None
        self.recording_path = None

        # Populated by set_config(); declared here so that reading them
        # before set_config() yields None instead of an AttributeError.
        self.update_mode = None
        self.anneal_update_itr = None
        self.explore_mode = None
        self.optimize_with_search = None
        self.automatic_entropy_tuning = None
        self.best_model_path = None

    def new_game(self, seed=None):
        """Create an environment instance; must be overridden by subclasses.

        ``set_config`` expects the returned object to expose
        ``observation_space``, ``action_space`` and ``close()``.
        """
        raise NotImplementedError

    def get_uniform_network(self):
        """Build a ``DreamerNetwork`` for the configured environment.

        Reads ``self.observation_space.shape[0]``, so the observation space
        must have been assigned first (``set_config`` does this).
        """
        return DreamerNetwork(obs_size=self.observation_space.shape[0],
                              belief_size=200,
                              state_size=30,
                              hidden_size=200,
                              embedding_size=200,
                              action_space=self.action_space,
                              action_repeat_set=self.action_repeat_set)

    def get_hparams(self):
        """Return the instance attributes that count as hyper-parameters.

        An attribute is included when its name does not contain ``'path'``
        and its value is not ``None``.
        """
        return {name: value
                for name, value in self.__dict__.items()
                if 'path' not in name and value is not None}

    def visit_softmax_temperature_fn(self, num_moves=None, env_steps=None):
        """Return the visit-count softmax temperature for ``env_steps``.

        The temperature is 1.0 below 50% of ``max_env_steps``, 0.5 below 75%
        and 0.25 afterwards. ``num_moves`` is accepted but not used.

        Raises:
            TypeError: If ``env_steps`` is not supplied.
        """
        if env_steps is None:
            # Comparing None with a float would raise an opaque TypeError;
            # keep the exception type but make the cause explicit.
            raise TypeError('env_steps must be provided')

        if env_steps < 0.5 * self.max_env_steps:
            return 1.0
        if env_steps < 0.75 * self.max_env_steps:
            return 0.5
        return 0.25

    def set_config(self, args):
        """Apply parsed arguments, probe the environment and create paths.

        Args:
            args: Object exposing the attributes read below (``env``,
                ``seed``, ``update_mode``, ``anneal_update_itr``,
                ``explore_mode``, ``optimize_with_search``,
                ``uniform_action_sample``, ``device``,
                ``repeat_entropy_coeff``, ``automatic_entropy_tuning``,
                ``actor_entropy_coeff``, ``action_repeat_set``,
                ``result_dir`` and ``case``).

        Raises:
            AssertionError: If ``args.env`` is not one of the valid envs.

        Side effects:
            Creates the ``recordings`` directory (and its parents) on disk.
        """
        # env info
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O; the exception type is unchanged.
        if args.env not in self._valid_envs:
            raise AssertionError(' Invalid env. , It should be from  {}'.format(self._valid_envs))
        self.env_name = args.env

        # A throwaway environment is created only to read its spaces; it is
        # closed even if reading one of them fails.
        env = self.new_game()
        try:
            self.observation_space = env.observation_space
            self.action_space = env.action_space
        finally:
            env.close()

        # training
        self.seed = args.seed
        self.update_mode = args.update_mode
        self.anneal_update_itr = args.anneal_update_itr
        self.explore_mode = args.explore_mode
        self.optimize_with_search = args.optimize_with_search
        self.uniform_action_sample = args.uniform_action_sample
        self.device = args.device
        self.actor_repeat_entropy_coeff = args.repeat_entropy_coeff
        self.automatic_entropy_tuning = args.automatic_entropy_tuning
        if args.actor_entropy_coeff is not None:
            self.actor_entropy_coeff = args.actor_entropy_coeff
        if args.action_repeat_set is not None:
            # Comma-separated string such as "1,2,4" -> sorted list of ints.
            self.action_repeat_set = sorted(
                int(item) for item in args.action_repeat_set.split(","))

        # create experiment path
        # NOTE: the learning rates (dynamics/policy/value) are currently not
        # part of the directory layout.
        components = [
            args.result_dir,
            args.case,
            args.env,
            'act_entr_{}'.format(self.actor_entropy_coeff),
            'act_rep_entr_{}'.format(self.actor_repeat_entropy_coeff),
            'act_rep_set_{}'.format(self.action_repeat_set),
            'exp_{}'.format(self.explore_mode),
            'upd_{}'.format(self.update_mode),
            'annl_upd_itr_{}'.format(self.anneal_update_itr),
            'unf_sample_{}'.format(self.uniform_action_sample),
            'prop_sample_{}'.format(self.proposal_action_sample),
            'opt_with_search' if self.optimize_with_search else 'opt_no_search',
            'WITH_auto_ent_tuning' if self.automatic_entropy_tuning else 'NO_auto_ent_tuning',
            # seed
            'seed_{}'.format(self.seed),
        ]
        self.exp_path = os.path.join(*components)

        # model paths
        self.model_path = os.path.join(self.exp_path, 'model.p')
        self.best_model_path = {
            'with_search': os.path.join(self.exp_path, 'with_search_best_model.p'),
            'no_search': os.path.join(self.exp_path, 'no_search_best_model.p'),
        }
        self.test_data_path = os.path.join(self.exp_path, 'test_data.p')

        # recording paths
        self.recording_path = os.path.join(self.exp_path, 'recordings')
        os.makedirs(self.recording_path, exist_ok=True)
