import os
from queue import PriorityQueue
import numpy as np
import pickle
import torch


class AgentBase:
    """Shared bookkeeping for agents acting in a vectorized environment.

    The class stores configuration values, loads the task datasets, wraps
    the reset/step calls of the vectorized environment, and implements
    top-k checkpoint saving and weight restoring.

    Several attributes used by the methods below are NOT set in this class
    and are presumably provided by subclasses (verify against them):
    `module_all` / `module_folder_all` (used by `save`), `learner` (used by
    `restore`) and `lang_ind_all` (used by `get_append` when language is
    enabled).
    """

    def __init__(self, venv, CONFIG, CONFIG_ENV):
        """Store config values and load the task (and language) datasets.

        Args:
            venv: vectorized environment; must expose `state_dim`,
                `action_dim`, `reset`, `reset_one`, `step` and `env_method`.
            CONFIG: agent/training configuration object (attribute access).
            CONFIG_ENV: environment configuration object (attribute access).
        """
        self.venv = venv
        self.config = CONFIG
        self.rng = np.random.default_rng(seed=CONFIG.SEED)
        self.device = CONFIG.DEVICE
        self.image_device = CONFIG.IMAGE_DEVICE

        # Mode
        self.eval = CONFIG.EVAL

        # Vectorized envs
        self.n_envs = CONFIG.NUM_CPUS
        self.use_append = CONFIG_ENV.USE_APPEND
        self.max_train_steps = CONFIG_ENV.MAX_TRAIN_STEPS
        self.max_eval_steps = CONFIG_ENV.MAX_EVAL_STEPS
        self.env_step_cnt = [0 for _ in range(self.n_envs)]

        # Arch: width of the vector appended to the observation features
        self.num_state_stack = CONFIG.NUM_STATE_STACK
        self.num_action_stack = CONFIG.NUM_ACTION_STACK
        self.append_dim = (self.venv.state_dim * self.num_state_stack
                           + self.venv.action_dim * self.num_action_stack)

        # Save
        self.out_folder = CONFIG.OUT_FOLDER
        self.save_top_k = CONFIG.SAVE_TOP_K
        self.pq_top_k = PriorityQueue()
        self.save_metric = CONFIG.SAVE_METRIC

        # Save loss and eval info, key is step number
        self.loss_record = {}
        self.eval_record = {}

        # Load tasks
        dataset = CONFIG_ENV.DATASET
        print("= Loading tasks from", dataset)
        self.task_all = self._load_pickle(dataset)
        self.num_task = len(self.task_all)
        print(self.num_task, "tasks are loaded")

        # Load test tasks (optional)
        if hasattr(CONFIG_ENV, 'DATASET_TEST'):
            dataset = CONFIG_ENV.DATASET_TEST
            print("= Loading tasks from", dataset)
            self.task_test_all = self._load_pickle(dataset)
            print(len(self.task_test_all), "tasks are loaded")

        # Load ood tasks (optional)
        if hasattr(CONFIG_ENV, 'DATASET_OOD'):
            dataset = CONFIG_ENV.DATASET_OOD
            print("= Loading tasks from", dataset)
            self.task_ood_all = self._load_pickle(dataset)
            print(len(self.task_ood_all), "tasks are loaded")

        # Language
        self.use_lang = CONFIG_ENV.USE_LANG
        if self.use_lang:
            self.num_lang_per_env = CONFIG_ENV.NUM_LANG_PER_ENV
            self.append_dim += CONFIG_ENV.LANG_DIM
            dataset = CONFIG_ENV.DATASET_LANG
            print("= Loading language from", dataset)
            self.task_lang_all = self._load_pickle(dataset)

            # Move data to the compute device. Only existing keys are
            # reassigned, so mutating during iteration is safe.
            for key, value in self.task_lang_all.items():
                self.task_lang_all[key] = value.to(self.device)

        # Set starting step
        if CONFIG.CURRENT_STEP is None:
            self.cnt_step = 0
        else:
            self.cnt_step = CONFIG.CURRENT_STEP
            print("starting from {:d} steps".format(self.cnt_step))

    @staticmethod
    def _load_pickle(path):
        """Unpickle and return the object stored at `path`."""
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # the dataset paths at trusted files.
        with open(path, 'rb') as f:
            return pickle.load(f)

    def _resolve_tasks(self, task_all):
        """Return `task_all`, or `self.task_all` when it is None or empty."""
        # An explicit None/len check (instead of `not task_all`) also works
        # for numpy arrays, whose truth value is ambiguous.
        if task_all is None or len(task_all) == 0:
            return self.task_all
        return task_all

    def set_train_mode(self):
        """Switch to training mode and use the training step limit."""
        self.eval_mode = False
        self.max_env_step = self.max_train_steps

    def set_eval_mode(self):
        """Switch to evaluation mode and reset all evaluation counters."""
        self.num_eval_success = 0  # for calculating expected success rate
        self.num_eval_safe = 0  # for calculating expected safety rate
        # for calculating cumulative reward
        self.eval_reward_cumulative = [0 for _ in range(self.n_envs)]
        self.eval_reward_best = [0 for _ in range(self.n_envs)]
        self.eval_reward_cumulative_all = 0
        self.eval_reward_best_all = 0
        self.env_step_cnt = [0 for _ in range(self.n_envs)]

        self.eval_mode = True
        self.max_env_step = self.max_eval_steps

    # === Venv ===
    def step(self, action):
        """Forward `action` to the vectorized env and return its result."""
        return self.venv.step(action)

    def get_append(self, task_all, task_ids, _prev_s=None, _prev_a=None):
        """Build the per-env vector appended to the observation features.

        Args:
            task_all: indexable collection of tasks.
            task_ids: iterable of indices into `task_all`, one per env.
            _prev_s: per-env collections of previous states; required when
                `num_state_stack > 0`. Each entry is flattened with
                `np.hstack` and the entries are stacked row-wise.
            _prev_a: per-env collections of previous actions; required when
                `num_action_stack > 0`. Handled like `_prev_s`.

        Returns:
            A float tensor on `self.device` made of, in order, the language
            entry (if `use_lang`), the previous states and the previous
            actions, concatenated column-wise; or None when `use_append`
            is false.
        """
        cur_tasks = [task_all[task_id] for task_id in task_ids]
        if not self.use_append:
            return None

        if self.use_lang:
            # One language entry per env, selected by `lang_ind_all`, which
            # is not defined in this class (presumably set by a subclass).
            append_all = torch.vstack([
                self.task_lang_all[task['name']][lang_ind]
                for lang_ind, task in zip(self.lang_ind_all, cur_tasks)
            ]).float().to(self.device)
        else:
            # Zero-width placeholder so the hstack calls below still work.
            append_all = torch.empty((self.n_envs, 0)).to(self.device)

        # Get state
        if self.num_state_stack > 0:
            prev_s = torch.from_numpy(
                np.vstack([np.hstack(prev) for prev in _prev_s])
            ).float().to(self.device)
            append_all = torch.hstack((append_all, prev_s))

        # Get action
        if self.num_action_stack > 0:
            prev_a = torch.from_numpy(
                np.vstack([np.hstack(prev) for prev in _prev_a])
            ).float().to(self.device)
            append_all = torch.hstack((append_all, prev_a))
        return append_all

    def reset_sim(self):
        """Call `close_pb` on every sub-environment."""
        self.venv.env_method('close_pb')

    def reset_env_all(self, task_ids=None, task_all=None, verbose=False):
        """Reset every environment and zero all per-env step counters.

        Args:
            task_ids: indices into `task_all`, one per env. A single-element
                sequence is broadcast to all envs. When None, ids are
                sampled uniformly at random with `self.rng`.
            task_all: collection of tasks; defaults to `self.task_all` when
                None or empty.
            verbose: print the task id assigned to each environment.

        Returns:
            Tuple `(s, task_ids)`: the value returned by `venv.reset` and the
            task ids actually used.
        """
        task_all = self._resolve_tasks(task_all)
        num_task = len(task_all)
        if task_ids is None:
            task_ids = self.rng.integers(low=0,
                                         high=num_task,
                                         size=self.n_envs)
        if len(task_ids) == 1:
            if isinstance(task_ids, np.ndarray):
                # `array * n` would multiply the id itself rather than
                # repeat it, so arrays need an explicit repeat.
                task_ids = np.repeat(task_ids, self.n_envs)
            else:
                task_ids = task_ids * self.n_envs
        tasks = [task_all[task_id] for task_id in task_ids]
        s = self.venv.reset(tasks)
        if verbose:
            for index in range(self.n_envs):
                print("<-- Reset environment {} with task {}:".format(
                    index, task_ids[index]))
        self.env_step_cnt = [0 for _ in range(self.n_envs)]
        return s, task_ids

    def reset_env(self, env_ind, task_id=None, task_all=None, verbose=False):
        """Reset a single environment and zero its step counter.

        Args:
            env_ind: index of the environment to reset.
            task_id: index into `task_all`; sampled uniformly at random with
                `self.rng` when None.
            task_all: collection of tasks; defaults to `self.task_all` when
                None or empty.
            verbose: print the task id assigned to the environment.

        Returns:
            Tuple `(s, task_id)`: the value returned by `venv.reset_one` and
            the task id actually used.
        """
        task_all = self._resolve_tasks(task_all)
        num_task = len(task_all)
        if task_id is None:
            task_id = self.rng.integers(low=0,
                                        high=num_task)
        s = self.venv.reset_one(env_ind, task_all[task_id])
        if verbose:
            print("<-- Reset environment {} with task {}:".format(
                env_ind, task_id))
        self.env_step_cnt[env_ind] = 0
        return s, task_id

    # === Models ===
    def save(self, metric=None, force_save=False):
        """Save the current modules if they rank among the top-k by metric.

        `pq_top_k` holds `(metric, step)` pairs; being a min-priority queue,
        its head is the worst checkpoint currently kept. A forced save
        writes the modules without registering them in the queue.

        Args:
            metric: score of the current model (higher is better); required
                unless `force_save` is set.
            force_save: save unconditionally.
        """
        assert metric is not None or force_save, \
            "should provide metric or force save"
        save_current = False
        if force_save:
            save_current = True
        elif self.pq_top_k.qsize() < self.save_top_k:
            self.pq_top_k.put((metric, self.cnt_step))
            save_current = True
        elif (self.pq_top_k.qsize() > 0
              and metric > self.pq_top_k.queue[0][0]):  # overwrite
            # The qsize guard avoids indexing an empty queue when
            # save_top_k is 0. Evict the worst checkpoint and delete it.
            _, step_remove = self.pq_top_k.get()
            for module, module_folder in zip(self.module_all,
                                             self.module_folder_all):
                module.remove(int(step_remove), module_folder)
            self.pq_top_k.put((metric, self.cnt_step))
            save_current = True

        if save_current:
            print()
            print('Saving current model...')
            for module, module_folder in zip(self.module_all,
                                             self.module_folder_all):
                module.save(self.cnt_step, module_folder)
            print(self.pq_top_k.queue)

    def restore(self, step, logs_path, agent_type, actor_path=None):
        """Restore the weights of the neural network.

        The critic checkpoint is loaded into both `learner.critic` and
        `learner.critic_target`.

        Args:
            step (int): #updates trained.
            logs_path (str): the path of the directory; under
                `logs_path/agent_type` there should be critic/ and actor/
                folders.
            agent_type (str): sub-folder of `logs_path` holding the models.
            actor_path (str, optional): explicit actor checkpoint path; when
                None, `actor/actor-{step}.pth` in the model folder is used.
        """
        model_folder = os.path.join(logs_path, agent_type)
        path_c = os.path.join(model_folder, 'critic',
                              'critic-{}.pth'.format(step))
        if actor_path is not None:
            path_a = actor_path
        else:
            path_a = os.path.join(model_folder, 'actor',
                                  'actor-{}.pth'.format(step))
        # load_state_dict copies values into the module's own parameters,
        # so one loaded state dict can serve both critic and target.
        critic_state = torch.load(path_c, map_location=self.device)
        self.learner.critic.load_state_dict(critic_state)
        self.learner.critic_target.load_state_dict(critic_state)
        self.learner.actor.load_state_dict(
            torch.load(path_a, map_location=self.device))
        print('  <= Restore {} with {} updates from {}.'.format(
            agent_type, step, model_folder))
