import os
import numpy as np
from tqdm import tqdm
import torch
import pickle
from collections import deque

from .agent_off import AgentMetaStepOff
from algorithm.sac_meta import SAC_meta
from util.pickle import save_obj, load_obj


class AgentSACReptile(AgentMetaStepOff):
    """Reptile-style meta-learning agent built on a ``SAC_meta`` learner.

    Each meta iteration clones the meta actor/critic (``learner.clone``),
    adapts the clone to sampled training tasks with inner SAC updates,
    stores the resulting gradient (``learner.store_grad``) and applies a
    meta update (``learner.meta_update``). ``evaluate`` adapts the meta
    policy to test tasks and records pre-/post-adaptation statistics.
    """

    def __init__(self, venv,
                 CONFIG,
                 CONFIG_ARCH,
                 CONFIG_ENV,
                 verbose=True):
        """
        Args:
            venv: vectorized environment, forwarded to ``AgentMetaStepOff``.
            CONFIG (Class object): hyper-parameter configuration. ``NUM_T`` is
                read here; when ``EVAL`` is set, the checkpoint with the
                highest recorded training reward is selected for adaptation.
            CONFIG_ARCH (Class object): network architecture configuration.
                ``APPEND_DIM`` (and, in eval mode, ``ACTOR_PATH`` /
                ``CRITIC_PATH``) are overwritten in place.
            CONFIG_ENV (Class object): environment configuration, forwarded to
                ``AgentMetaStepOff``.
            verbose (bool, optional): print info or not. Defaults to True.

        Raises:
            ValueError: in eval mode, if ``train_details`` holds no evaluation
                records to choose a checkpoint from.
        """
        super().__init__(venv, CONFIG, CONFIG_ENV)
        self.num_t = CONFIG.NUM_T

        # Choose checkpoint if adapt
        if CONFIG.EVAL:
            details_path = os.path.join(self.out_folder, 'train_details')
            # train_details is the pickled dict written by learn(); it contains
            # numpy arrays, which recent torch releases refuse to load under
            # their default weights_only=True. The file is our own output.
            try:
                train_details = torch.load(details_path, weights_only=False)
            except TypeError:
                # Older torch without the ``weights_only`` keyword
                train_details = torch.load(details_path)

            # Candidate steps are those evaluated during learn(); rank them by
            # the training reward stored last in each loss_record entry.
            steps = list(train_details['eval_record'].keys())
            if not steps:
                raise ValueError(
                    f"No evaluation records in {details_path}; "
                    "cannot select a checkpoint to adapt from")
            train_reward_all = [train_details['loss_record'][step][-1]
                                for step in steps]
            best_step = steps[int(np.argmax(train_reward_all))]

            ckpt_folder = os.path.join(self.out_folder, 'model')
            CONFIG_ARCH.ACTOR_PATH = os.path.join(
                ckpt_folder, 'actor', 'actor-' + str(best_step) + '.pth')
            CONFIG_ARCH.CRITIC_PATH = os.path.join(
                ckpt_folder, 'critic', 'critic-' + str(best_step) + '.pth')

        print("= Constructing policy agent")
        CONFIG_ARCH.APPEND_DIM['critic'] = self.append_dim
        CONFIG_ARCH.APPEND_DIM['actor'] = self.append_dim
        self.learner = SAC_meta(CONFIG)
        self.learner.build_network(CONFIG_ARCH,
                                   build_optimizer=True,
                                   verbose=verbose)

        # alias
        self.module_all = [self.learner]
        model_folder = os.path.join(self.out_folder, 'model')
        os.makedirs(model_folder, exist_ok=True)
        self.module_folder_all = [model_folder]

    def evaluate(self):
        """Adapt the meta policy to test tasks and save per-task statistics.

        Uses every test task when ``num_eval_task`` covers them all, otherwise
        ``num_eval_task`` evenly spaced test-task indices. For each task the
        ``(pre_stat, post_stat, train_reward)`` tuple returned by ``adapt`` is
        collected, and the list is written to ``<out_folder>/adapt`` with
        ``save_obj``.
        """
        record = []
        num_test_task = len(self.task_test_all)
        if self.num_eval_task >= num_test_task:
            self.eval_task_id_all = range(num_test_task)
        else:
            # Deterministic, evenly spaced subset of the test tasks
            self.eval_task_id_all = np.linspace(start=0,
                                                stop=num_test_task,
                                                num=self.num_eval_task,
                                                endpoint=False,
                                                dtype=int)
        print('Adapting to ', self.eval_task_id_all)
        for task_id in tqdm(self.eval_task_id_all, leave=False):
            pre_stat, post_stat, train_reward = self.adapt(
                optimizer_state=None,
                task_id=task_id,
                task_all=self.task_test_all)
            record.append((pre_stat, post_stat, train_reward))
        save_obj(record, os.path.join(self.out_folder, 'adapt'))

    def learn(self):
        """Run the meta-training loop for ``num_itr`` iterations.

        Per iteration: sample ``num_adapt_task`` training tasks; for each one,
        clone the meta learner and, ``num_t`` times, run ``num_adapt_step``
        rounds of ``num_adapt_update`` inner SAC updates (fresh episodes are
        collected only while ``t_ind == 0``), store the gradient, fill the
        meta buffer and apply a meta update. Averaged losses and the epoch's
        training reward go to ``loss_record``.

        Every ``check_opt_freq`` iterations (excluding iteration 0) the meta
        policy is adapted to test tasks, averaged pre/post stats go to
        ``eval_record``, ``self.save`` is called with the running training
        reward average, and ``train_details`` is written. A final forced save
        happens at the end.
        """
        # Window of the last 10 epoch rewards (initially zeros)
        train_reward_running = deque([0 for _ in range(10)], maxlen=10)

        for self.cnt_step in tqdm(range(self.num_itr), leave=False):
            self.set_train_mode()

            # Adapt - assume only one task at a time
            self.adapt_task_id_all = self.rng.integers(low=0,
                                                       high=self.num_task,
                                                       size=self.num_adapt_task)

            # Inner optimizer state handed to clone(); never updated here
            optimizer_state = None

            # Accumulators indexed [all adapt steps, first step, last step]
            entropy_all = np.zeros(3)
            loss_q_all = np.zeros(3)
            loss_pi_all = np.zeros(3)

            # Training reward per task (from its last adapt step at t_ind == 0)
            train_reward_epoch_all = []

            # Assume one task!
            for task_id in self.adapt_task_id_all:

                # Clone - essentially reset critic/actor to meta critic/actor in sac_mini
                self.learner.clone(optimizer_state)

                # Reset online buffer
                self.memory.reset_online()

                for t_ind in range(self.num_t):

                    # Loop thru adapt steps
                    for adapt_step in range(self.num_adapt_step):

                        # Collect fresh episodes only in the first t iteration
                        if t_ind == 0:
                            train_reward = self.run_task(task_id,
                                                         self.num_adapt_episode,
                                                         task_all=self.task_all)

                        # Adapt - do not update language module at adaptation
                        for timer in range(self.num_adapt_update):
                            batch = self.unpack_batch(self.sample_batch(
                                online_weight=self.adapt_online_weight))
                            loss_q, loss_pi, loss_entropy, _ = self.learner.update(
                                batch, timer,
                                update_period=self.update_period,
                                detach_lang=True)
                            loss_q_all[0] += loss_q
                            loss_pi_all[0] += loss_pi
                            entropy_all[0] += loss_entropy
                            if adapt_step == 0:
                                loss_q_all[1] += loss_q
                                loss_pi_all[1] += loss_pi
                                entropy_all[1] += loss_entropy
                            if adapt_step == self.num_adapt_step - 1:
                                loss_q_all[2] += loss_q
                                loss_pi_all[2] += loss_pi
                                entropy_all[2] += loss_entropy

                    # Record reward from last adapt step
                    if t_ind == 0:
                        train_reward_epoch_all.append(train_reward)

                    # Store gradient in meta
                    self.learner.store_grad()

                    # Fill meta buffer with online samples
                    self.memory.fill_meta()

                    ################### Meta Update ###################

                    # Update rest with stored gradient. Note: this runs once
                    # per t iteration of every task, not once per meta step.
                    self.learner.meta_update(num_task=self.num_adapt_task,
                                             timer=self.cnt_step)

            ################### Record ###################
            # Every slot was summed over inner updates, tasks and t iterations;
            # slot 0 was additionally summed over all adapt steps.
            num_inner = self.num_adapt_update * self.num_adapt_task * self.num_t
            loss_q_all = loss_q_all / num_inner
            loss_pi_all = loss_pi_all / num_inner
            entropy_all = entropy_all / num_inner
            loss_q_all[0] = loss_q_all[0] / self.num_adapt_step    # avg
            loss_pi_all[0] = loss_pi_all[0] / self.num_adapt_step
            entropy_all[0] = entropy_all[0] / self.num_adapt_step
            train_reward_epoch = np.mean(train_reward_epoch_all)
            self.loss_record[self.cnt_step] = (loss_q_all, loss_pi_all,
                                               entropy_all, train_reward_epoch)

            # Running avg
            train_reward_running.appendleft(train_reward_epoch)
            train_reward_running_avg = np.mean(train_reward_running)

            ################## Eval ###################
            if self.cnt_step > 0 and self.cnt_step % self.check_opt_freq == 0:

                # Test envs
                num_test_task = len(self.task_test_all)
                if self.num_eval_task >= num_test_task:
                    self.eval_task_id_all = range(num_test_task)
                else:
                    self.eval_task_id_all = self.rng.integers(low=0,
                                                              high=num_test_task,
                                                              size=self.num_eval_task)

                # Average over the tasks actually evaluated - fewer than
                # num_eval_task when it exceeds the number of test tasks
                test_pre = np.zeros(3)
                test_post = np.zeros(3)
                for task_id in tqdm(self.eval_task_id_all, leave=False):
                    pre_stat, post_stat, _ = self.adapt(optimizer_state,
                                                        task_id=task_id,
                                                        task_all=self.task_test_all)
                    test_pre += pre_stat
                    test_post += post_stat
                num_evaluated = len(self.eval_task_id_all)
                if num_evaluated > 0:
                    test_pre /= num_evaluated
                    test_post /= num_evaluated

                # Eval stat
                self.eval_record[self.cnt_step] = (test_pre, test_post)
                print('Test - Pre: ', test_pre)
                print('Test - Post: ', test_post)

                # Saving (meta) model. Every supported save_metric currently
                # ranks checkpoints by the running training-reward average,
                # not by the post-adaptation test stats.
                if self.save_metric in ('success', 'cum_reward', 'best_reward'):
                    self.save(metric=train_reward_running_avg)
                else:
                    raise NotImplementedError

                # Save training details
                torch.save(
                    {
                        'loss_record': self.loss_record,
                        'eval_record': self.eval_record,
                    }, os.path.join(self.out_folder, 'train_details'))

                # Reset simulation
                self.reset_sim()

        ################### Done ###################
        self.save(force_save=True)

    def adapt(self, optimizer_state, task_id, task_all):
        """Adapt a fresh clone of the meta learner to a single task.

        Args:
            optimizer_state: forwarded to ``self.learner.clone``.
            task_id: index of the task within ``task_all``.
            task_all: task collection passed through to ``run_task``.

        Returns:
            tuple: ``(pre_stat, post_stat, train_reward_all)`` - eval-mode
            ``run_task`` statistics before and after adaptation, and the list
            of training rewards, one per adapt step.
        """
        self.learner.clone(optimizer_state)
        self.memory.reset_online()

        # Test before adapt - do not re-use optimizer state
        self.set_eval_mode()
        pre_stat = self.run_task(task_id,
                                 self.num_episode_per_eval,
                                 task_all=task_all)

        # Loop thru adapt steps
        self.set_train_mode()
        train_reward_all = []
        for _ in range(self.num_adapt_step):

            # Sample    # TODO: not always sample?
            train_reward = self.run_task(task_id,
                                         self.num_adapt_episode,
                                         task_all=task_all)
            train_reward_all.append(train_reward)

            # Adapt - do not update language module at adaptation
            for timer in range(self.num_adapt_update):
                batch = self.unpack_batch(
                    self.sample_batch(
                        online_weight=self.adapt_online_weight
                    ))
                self.learner.update(batch, timer,
                                    update_period=self.update_period,
                                    detach_lang=True)

        # Test - do not re-use optimizer state
        self.set_eval_mode()
        post_stat = self.run_task(task_id,
                                  self.num_episode_per_eval,
                                  task_all=task_all)
        return pre_stat, post_stat, train_reward_all

    def run_task(self, task_id, num_episode, task_all):
        """Run every env on one task until at least ``num_episode`` episodes end.

        In train mode, transitions are stored via ``store_transition`` and the
        mean per-step reward multiplied by ``max_env_step`` is returned as a
        float. In eval mode nothing is stored and
        ``np.array([success_rate, mean_cumulative_reward, mean_best_reward])``
        over the finished episodes is returned.
        """
        task_ids = [task_id for _ in range(self.n_envs)]

        def _blank_history(dim, length):
            # Fixed-length history pre-filled with -1 vectors; newer entries
            # are pushed on the left
            return deque([[-1 for _ in range(dim)] for _ in range(length)],
                         maxlen=length)

        s, _ = self.reset_env_all(task_ids=task_ids, task_all=task_all)
        _prev_s = [_blank_history(self.venv.state_dim, self.num_state_stack)
                   for _ in range(self.n_envs)]
        _prev_a = [_blank_history(self.venv.action_dim, self.num_action_stack)
                   for _ in range(self.n_envs)]
        num_episode_run = 0

        # Initialize language
        if self.use_lang:
            self.lang_ind_all = self.rng.integers(0, self.num_lang_per_env, size=self.n_envs)

        # Per-step reward vectors (train mode only). Seeded with an empty float
        # tensor so concatenation promotes dtypes to float and an empty run
        # still averages to NaN.
        r_train_chunks = [torch.tensor([])]
        while num_episode_run < num_episode:

            # Interact
            append_all = self.get_append(task_all, task_ids, _prev_s, _prev_a)

            # Select action
            with torch.no_grad():
                a_all = self.forward(s, append=append_all, latent=None)

            # Apply action - update heading
            s_all, r_all, done_all, info_all = self.step(a_all)

            # r_all already holds every env's reward for this step, so record
            # it once per step rather than once per env (which duplicated it
            # n_envs times and re-copied the growing tensor each time)
            if not self.eval_mode:
                r_train_chunks.append(r_all.reshape(-1))

            # Record state
            for env_ind, _prev_s_env in enumerate(_prev_s):
                _prev_s_env.appendleft(info_all[env_ind]['s'])

            # Record prev_action
            for env_ind, _prev_a_env in enumerate(_prev_a):
                _prev_a_env.appendleft(a_all[env_ind].clone().cpu().numpy())

            # Get new append
            append_nxt_all = self.get_append(task_all, task_ids, _prev_s, _prev_a)

            # Check all envs
            for env_ind, (s_, r, done, info) in enumerate(
                    zip(s_all, r_all, done_all, info_all)):

                # Save append
                if append_all is not None:
                    info['append'] = append_all[env_ind].unsqueeze(0)
                    info['append_nxt'] = append_nxt_all[env_ind].unsqueeze(0)

                # Store the transition in memory
                action = a_all[env_ind].unsqueeze(0).clone()
                if not self.eval_mode:
                    self.store_transition(
                        s[env_ind].unsqueeze(0).to(self.image_device),
                        action, r, s_.unsqueeze(0).to(self.image_device),
                        done, info)

                # Increment step count for the env
                self.env_step_cnt[env_ind] += 1

                # Check reward
                if self.eval_mode:
                    self.eval_reward_cumulative[env_ind] += r.item()
                    self.eval_reward_best[env_ind] = max(self.eval_reward_best[env_ind], r.item())

                # Check done for particular env
                if done or self.env_step_cnt[env_ind] > self.max_env_step:
                    if self.eval_mode:
                        self.num_eval_success += (self.eval_reward_best[env_ind] > self.success_threshold)
                        self.eval_reward_cumulative_all += self.eval_reward_cumulative[env_ind]
                        self.eval_reward_best_all += self.eval_reward_best[env_ind]
                        self.eval_reward_cumulative[env_ind] = 0
                        self.eval_reward_best[env_ind] = 0

                    # Reset - assume same task - only change task after meta update and reset_all, although not optimal design
                    s_one, _ = self.reset_env(env_ind,
                                              task_id=task_id,
                                              task_all=task_all,
                                              verbose=False)
                    s_all[env_ind] = s_one
                    _prev_s[env_ind] = _blank_history(self.venv.state_dim, self.num_state_stack)
                    _prev_a[env_ind] = _blank_history(self.venv.action_dim, self.num_action_stack)
                    if self.use_lang:
                        self.lang_ind_all[env_ind] = self.rng.integers(0, self.num_lang_per_env)

                    # Count
                    num_episode_run += 1

            # Update "prev" states
            s = s_all

        # Get eval stats
        # NOTE(review): the *_all eval accumulators are not reset in this
        # method - presumably set_eval_mode()/reset_env_all() does; confirm.
        if self.eval_mode:
            eval_success = self.num_eval_success / num_episode_run
            eval_reward_cumulative = self.eval_reward_cumulative_all / num_episode_run
            eval_reward_best = self.eval_reward_best_all / num_episode_run
            return np.array([eval_success, eval_reward_cumulative, eval_reward_best])
        return float(torch.mean(torch.cat(r_train_chunks)) * self.max_env_step)