import os
import numpy as np
import torch
from collections import deque
from tqdm import tqdm

from .agent_off import AgentBaseStepOff
from algorithm.sac_mini import SAC_mini
from util.pickle import save_obj


class AgentSAC(AgentBaseStepOff):
    def __init__(self, venv, CONFIG, CONFIG_ARCH, CONFIG_ENV, verbose=True):
        """Build the SAC learner and prepare the folder for its checkpoints.

        Args:
            venv: environment object, passed through to the base-class
                constructor.
            CONFIG (Class object): hyper-parameter configuration; also used to
                construct the `SAC_mini` learner.
            CONFIG_ARCH: architecture configuration. Its `APPEND_DIM` entries
                for 'critic' and 'actor' are overwritten with
                `self.append_dim` before the networks are built.
            CONFIG_ENV: environment configuration, passed through to the
                base-class constructor.
            verbose (bool, optional): print info or not. Defaults to True.
        """
        super().__init__(venv, CONFIG, CONFIG_ENV)

        print("= Constructing policy agent")

        # Both networks receive the same appended-input dimension
        for net_name in ('critic', 'actor'):
            CONFIG_ARCH.APPEND_DIM[net_name] = self.append_dim

        # No optimizer is built when the agent is only used for evaluation
        self.learner = SAC_mini(CONFIG)
        self.learner.build_network(
            CONFIG_ARCH,
            build_optimizer=not self.eval,
            verbose=verbose,
        )

        # Aliases: modules paired with the folders they are stored in
        self.module_all = [self.learner]
        model_dir = os.path.join(self.out_folder, 'model')
        os.makedirs(model_dir, exist_ok=True)
        self.module_folder_all = [model_dir]


    def evaluate(self, num_eval_task=45, num_episode=20, num_adapt=10):
        """Adapt the policy to a subset of the test tasks and save the rewards.

        For every selected test task the agent is switched to adapt mode and
        then alternates `num_adapt` times between collecting experience on
        that single task and running `self.num_update_per_opt` learner
        updates. The reward returned by each collection round is recorded and
        the whole record is written with `save_obj` to
        `<self.out_folder>/adapt`.

        Args:
            num_eval_task (int, optional): maximum number of test tasks to
                adapt to. If it is not smaller than the number of test tasks,
                all of them are used; otherwise the tasks are picked at evenly
                spaced indices. Defaults to 45.
            num_episode (int, optional): multiplied by `self.max_train_steps`
                to obtain the step budget of one collection round.
                Defaults to 20.
            num_adapt (int, optional): number of collect/update rounds per
                task. Defaults to 10.
        """
        # NOTE(review): the rest of this class uses `self.max_env_step` for
        # the episode length - confirm `max_train_steps` is the intended
        # attribute here.
        num_step = num_episode * self.max_train_steps

        # Choose which test tasks to adapt to
        num_task_total = len(self.task_test_all)
        if num_eval_task >= num_task_total:
            self.eval_task_id_all = range(num_task_total)
        else:
            # Deterministic, evenly spaced task indices
            self.eval_task_id_all = np.linspace(start=0,
                                                stop=num_task_total,
                                                num=num_eval_task,
                                                endpoint=False,
                                                dtype=int)
        print('Adapting to ', self.eval_task_id_all)

        # For all selected test tasks
        record = []
        for task_id in tqdm(self.eval_task_id_all, leave=False):
            task_reward = []

            # Presumably resets the replay buffer and optimizer (implemented
            # outside this class)
            self.set_adapt_mode()

            # Adapt multiple iterations
            for _ in range(num_adapt):

                # Collect on this single task
                _, reward = self.run_steps(
                    task_all=[self.task_test_all[task_id]],
                    num_steps=num_step)
                task_reward.append(reward)

                # Update
                for timer in range(self.num_update_per_opt):
                    batch = self.unpack_batch(self.sample_batch())
                    self.learner.update(batch, timer,
                                        update_period=self.update_period)

            # One single-element tuple per task, holding its list of rewards
            record.append((task_reward,))
        save_obj(record, os.path.join(self.out_folder, 'adapt'))


    def learn(self):
        """Run the training loop until `self.max_sample_steps` is reached.

        After an optional warm-up collection, each pass through the loop
        either evaluates (when `self.eval_mode` is set) or collects
        `self.opt_freq` steps and performs `self.num_update_per_opt` learner
        updates. Evaluation is requested after every `self.check_opt_freq`
        optimizations; it runs on the train and the test tasks, logs both
        statistics and calls `self.save` with the train statistic selected by
        `self.save_metric`.

        Returns:
            Whatever the final `self.save(force_save=True)` call returns.

        Raises:
            NotImplementedError: if `self.save_metric` is not one of
                'success', 'cum_reward' or 'best_reward'.
        """
        # Position of each supported save metric inside the statistics array
        stat_index = {'success': 0, 'cum_reward': 1, 'best_reward': 2}

        # Warm-up: collect samples before any optimization takes place
        self.set_train_mode()
        if self.min_step_b4_opt > 0:
            self.cnt_step, _ = self.run_steps(
                task_all=self.task_all, num_steps=self.min_step_b4_opt)

        # Keep sampling while optimizing the policy
        num_opt = 0
        while self.cnt_step < self.max_sample_steps:
            print(self.cnt_step, end='\r')

            ################### Eval ###################
            if self.eval_mode:
                # Train tasks
                episodes_done = self.run_steps(
                    task_all=self.task_all,
                    num_episodes=self.num_episode_per_eval)
                train_stat = self.finish_eval(episodes_done)
                print('Train: ', train_stat)

                # Test tasks; eval mode is set again before this second run
                self.set_eval_mode()
                episodes_done = self.run_steps(
                    task_all=self.task_test_all,
                    num_episodes=self.num_episode_per_eval)
                test_stat = self.finish_eval(episodes_done)
                print('Test: ', test_stat)

                # Log
                self.eval_record[self.cnt_step] = (train_stat, test_stat)

                # Save the model, ranked by the chosen train statistic
                if self.save_metric not in stat_index:
                    raise NotImplementedError
                self.save(metric=train_stat[stat_index[self.save_metric]])

                # Back to training, with a fresh simulation to clear pb cache
                self.set_train_mode()
                self.reset_sim()
                continue

            ################### Train ###################
            steps_added, _ = self.run_steps(task_all=self.task_all,
                                            num_steps=self.opt_freq)
            self.cnt_step += steps_added

            # Update critic/actor and accumulate the losses of every update
            loss_sum = np.zeros(4)
            for timer in range(self.num_update_per_opt):
                batch = self.unpack_batch(self.sample_batch())
                step_losses = self.learner.update(
                    batch, timer, update_period=self.update_period)
                for slot, value in enumerate(step_losses):
                    loss_sum[slot] += value

            # Record mean of: loss_q, loss_pi, loss_entropy, loss_alpha
            self.loss_record[self.cnt_step] = (
                loss_sum / self.num_update_per_opt)

            # Count number of optimization
            num_opt += 1

            # Clear GPU cache
            torch.cuda.empty_cache()

            # Request an evaluation on the next pass through the loop
            if num_opt % self.check_opt_freq == 0:
                self.set_eval_mode()

        ################### Done ###################
        return self.save(force_save=True)


    def run_steps(self, task_all,
                  num_steps=None,
                  num_episodes=None,
                  return_info=False):
        """Roll out the current policy in all environments.

        In training mode (`self.eval_mode` is false) every transition is
        handed to `self.store_transition` and the loop runs until at least
        `num_steps` steps were taken, counted `self.n_envs` at a time. In
        evaluation mode nothing is stored, the evaluation accumulators on
        `self` are updated instead, and the loop runs until at least
        `num_episodes` episodes have finished.

        Args:
            task_all: tasks to run; forwarded to `reset_env_all`, `reset_env`
                and `get_append`.
            num_steps (int, optional): step budget; required in training mode.
            num_episodes (int, optional): episode budget; required in
                evaluation mode.
            return_info (bool, optional): currently ignored; kept so existing
                callers keep working.

        Returns:
            In evaluation mode, the number of finished episodes. In training
            mode, a tuple `(cnt, reward)` where `cnt` is the number of steps
            taken and `reward` is the mean per-step reward over the whole
            rollout multiplied by `self.max_env_step`.

        Raises:
            TypeError: if the budget required by the active mode is None.
        """
        cnt_target = num_episodes if self.eval_mode else num_steps
        if cnt_target is None:
            # Fail early with a clear message instead of on `cnt < None`
            raise TypeError(
                "run_steps() requires num_episodes in eval mode and "
                "num_steps in train mode")

        def blank_history(dim, length):
            # Fixed-length history pre-filled with -1 placeholder vectors
            return deque([[-1] * dim for _ in range(length)], maxlen=length)

        cnt = 0

        # Reset
        s, task_ids = self.reset_env_all(task_all=task_all)
        _prev_s = [
            blank_history(self.venv.state_dim, self.num_state_stack)
            for _ in range(self.n_envs)
        ]
        _prev_a = [
            blank_history(self.venv.action_dim, self.num_action_stack)
            for _ in range(self.n_envs)
        ]

        # Initialize language
        if self.use_lang:
            self.lang_ind_all = self.rng.integers(
                0, self.num_lang_per_env, size=self.n_envs)

        # NOTE(review): `info_episode` is filled in evaluation mode but never
        # returned, and `return_info` is never read - confirm whether the
        # episode infos were meant to be returned.
        info_episode = []

        # Per-iteration rewards are gathered here and concatenated once at the
        # end; the leading empty float tensor keeps the result defined (and
        # its dtype promotion unchanged) even when the loop never runs.
        reward_chunks = [torch.tensor([])]

        # Run
        while cnt < cnt_target:

            # Extra input for the current step
            append_all = self.get_append(task_all, task_ids, _prev_s, _prev_a)

            # Select action
            with torch.no_grad():
                a_all = self.forward(s, append=append_all, latent=None)

            # Apply action
            s_all, r_all, done_all, info_all = self.step(a_all)

            # Record state, newest entry first
            for env_ind, _prev_s_env in enumerate(_prev_s):
                _prev_s_env.appendleft(info_all[env_ind]['s'])

            # Record prev_action, newest entry first
            for env_ind, _prev_a_env in enumerate(_prev_a):
                _prev_a_env.appendleft(a_all[env_ind].clone().cpu().numpy())

            # Extra input as seen after this step
            append_nxt_all = self.get_append(task_all, task_ids, _prev_s,
                                             _prev_a)

            # Check all envs
            for env_ind, (s_, r, done, info) in enumerate(
                    zip(s_all, r_all, done_all, info_all)):

                # Save append
                if append_all is not None:
                    info['append'] = append_all[env_ind].unsqueeze(0)
                    info['append_nxt'] = append_nxt_all[env_ind].unsqueeze(0)

                if self.eval_mode:
                    # Track cumulative and best reward of the running episode
                    reward = r.item()
                    self.eval_reward_cumulative[env_ind] += reward
                    self.eval_reward_best[env_ind] = max(
                        self.eval_reward_best[env_ind], reward)
                else:
                    # Store the transition in memory
                    action = a_all[env_ind].unsqueeze(0).clone()
                    self.store_transition(
                        s[env_ind].unsqueeze(0).to(self.image_device),
                        action, r,
                        s_.unsqueeze(0).to(self.image_device), done, info)

                # Increment step count for the env
                self.env_step_cnt[env_ind] += 1

                # Check done for particular env
                if done or self.env_step_cnt[env_ind] > self.max_env_step:
                    if self.eval_mode:
                        # Fold the finished episode into the running totals
                        self.num_eval_success += (
                            self.eval_reward_best[env_ind]
                            > self.success_threshold)
                        self.eval_reward_cumulative_all += (
                            self.eval_reward_cumulative[env_ind])
                        self.eval_reward_best_all += (
                            self.eval_reward_best[env_ind])
                        self.eval_reward_cumulative[env_ind] = 0
                        self.eval_reward_best[env_ind] = 0

                        # Record info of the episode
                        info_episode.append(info)

                        # Count for eval mode
                        cnt += 1

                    # Reset
                    s_one, task_id_one = self.reset_env(env_ind,
                                                        task_all=task_all,
                                                        verbose=False)

                    # Start the new episode with blank histories
                    s_all[env_ind] = s_one
                    task_ids[env_ind] = task_id_one
                    _prev_s[env_ind] = blank_history(self.venv.state_dim,
                                                     self.num_state_stack)
                    _prev_a[env_ind] = blank_history(self.venv.action_dim,
                                                     self.num_action_stack)
                    if self.use_lang:
                        self.lang_ind_all[env_ind] = self.rng.integers(
                            0, self.num_lang_per_env)

            # Count for train mode
            if not self.eval_mode:
                # Rewards are only needed for the training-mode return value
                reward_chunks.append(r_all.view(-1).cpu())
                cnt += self.n_envs

                # Update gamma, lr etc., once per environment step taken
                for _ in range(self.n_envs):
                    self.learner.update_hyper_param()

            # Update "prev" states
            s = s_all

        if self.eval_mode:
            return cnt

        # Single concatenation instead of re-copying the tensor every step
        r_train = torch.cat(reward_chunks)
        return cnt, torch.mean(r_train)*self.max_env_step


    def finish_eval(self, num_episodes_run):
        eval_success = self.num_eval_success / num_episodes_run
        eval_reward_cumulative = self.eval_reward_cumulative_all / num_episodes_run
        eval_reward_best = self.eval_reward_best_all / num_episodes_run
        return np.array([eval_success, eval_reward_cumulative, eval_reward_best])
