
import psutil
import time
import torch
import math
from collections import deque

from open_source.rlpyt.rlpyt.runners.base import BaseRunner
from open_source.rlpyt.rlpyt.runners.minibatch_rl import MinibatchRlEval
from open_source.rlpyt.rlpyt.utils.quick_args import save__init__args
from open_source.rlpyt.rlpyt.utils.seed import set_seed, make_seed
from open_source.rlpyt.rlpyt.utils.logging import logger
from open_source.rlpyt.rlpyt.utils.prog_bar import ProgBarCounter
from open_source.rlpyt.rlpyt.samplers.buffer import build_samples_buffer
from open_source.rlpyt.rlpyt.samplers.collections import BatchSpec
import pdb

class MinibatchRlEval_PCGrad(MinibatchRlEval):
    ### samplers : mapping of id -> sampler (iterated with .items() / indexed by id) ###
    _eval = False

    def __init__(
            self,
            algo,
            agent,
            samplers,
            n_steps,
            seed=None,
            affinity=None,
            log_interval_steps=1e5,
            start_itr=0
    ):
        """
        Normalizes the constructor arguments and stores them on the runner.

        Args:
            algo: learning algorithm; read back as ``self.algo`` below.
            agent: agent handed to every sampler in ``startup()``.
            samplers: mapping of id -> sampler (the other methods iterate it
                with ``.items()`` / ``.values()``).
            n_steps: step budget, cast to ``int``.  Presumably consumed by the
                inherited ``get_n_itr()`` -- confirm against the parent class.
            seed: random seed; ``startup()`` draws one when this is ``None``.
            affinity: hardware affinity mapping; ``None`` becomes an empty dict.
            log_interval_steps: logging/evaluation interval in steps, cast to
                ``int``.
            start_itr: ``train()`` starts its loop at ``start_itr + 1``.
        """
        n_steps = int(n_steps)
        log_interval_steps = int(log_interval_steps)
        affinity = dict() if affinity is None else affinity
        # NOTE: save__init__args receives locals(), so the local names above
        # must keep matching the parameter names.  It is expected to copy them
        # onto ``self`` (``self.algo`` is read on the very next line).
        save__init__args(locals())
        # Algorithms that do not define ``min_itr_learn`` are treated as
        # learning from iteration 0.
        self.min_itr_learn = getattr(self.algo, 'min_itr_learn', 0)

    def startup(self):
        """
        Sets hardware affinities, then initializes the following: 1) every
        sampler (which should initialize the agent), 2) agent device and
        data-parallel wrapper (if applicable), 3) algorithm, 4) the stratified
        sample buffer that ``train()`` fills from all samplers, 5) logger.

        The algorithm and the stratified buffer are configured from the LAST
        sampler in ``self.samplers`` (its batch spec, ``mid_batch_reset``,
        environment class and the examples returned by its ``initialize()``).

        Returns:
            The number of training iterations returned by ``get_n_itr()``.

        Raises:
            ValueError: if ``self.samplers`` is empty.
        """
        if not self.samplers:
            # Everything below is sized from the last sampler; fail early with
            # a clear message instead of a KeyError/NameError further down.
            raise ValueError(
                "MinibatchRlEval_PCGrad.startup() requires at least one sampler.")

        p = psutil.Process()
        try:
            if (self.affinity.get("master_cpus", None) is not None and
                    self.affinity.get("set_affinity", True)):
                p.cpu_affinity(self.affinity["master_cpus"])
            cpu_affin = p.cpu_affinity()
        except AttributeError:
            # psutil does not expose cpu_affinity() on every platform.
            cpu_affin = "UNAVAILABLE MacOS"
        logger.log(f"Runner {getattr(self, 'rank', '')} master CPU affinity: "
            f"{cpu_affin}.")
        if self.affinity.get("master_torch_threads", None) is not None:
            torch.set_num_threads(self.affinity["master_torch_threads"])
        logger.log(f"Runner {getattr(self, 'rank', '')} master Torch threads: "
            f"{torch.get_num_threads()}.")
        if self.seed is None:
            self.seed = make_seed()
        set_seed(self.seed)
        self.rank = rank = getattr(self, "rank", 0)
        self.world_size = world_size = getattr(self, "world_size", 1)

        # Every sampler is initialized with the same agent and the same seed;
        # only the examples returned by the last one are kept.
        samplers = list(self.samplers.values())
        for sampler in samplers:
            examples = sampler.initialize(
                agent=self.agent,  # Agent gets initialized in sampler.
                affinity=self.affinity,
                seed=self.seed + 1,
                bootstrap_value=getattr(self.algo, "bootstrap_value", False),
                traj_info_kwargs=self.get_traj_info_kwargs(),
                rank=rank,
                world_size=world_size,
            )
        last_sampler = samplers[-1]

        ### initialize with the last sampler ###
        self.itr_batch_size = last_sampler.batch_spec.size * world_size
        n_itr = self.get_n_itr()
        self.agent.to_device(self.affinity.get("cuda_idx", None))
        if world_size > 1:
            self.agent.data_parallel()

        ### initialize multiple buffers ###
        self.algo.initialize(
            agent=self.agent,
            n_itr=n_itr,
            batch_spec=last_sampler.batch_spec,
            mid_batch_reset=last_sampler.mid_batch_reset,
            examples=examples,
            world_size=world_size,
            rank=rank,
        )

        ### initialize a stratified sample batch ###
        # One slot of width B per sampler along the batch dimension.
        # NOTE(review): this assumes every sampler uses the same batch_spec.B
        # as the last one -- confirm against the launch scripts.
        env = last_sampler.EnvCls(**last_sampler.env_kwargs)
        stra_batch_spec = BatchSpec(T=last_sampler.batch_spec.T,
                                    B=len(self.samplers) * last_sampler.batch_spec.B)
        # The third return value (examples) is not needed here.
        stra_samples_pyt, stra_samples_np, _ = build_samples_buffer(
            self.agent, env, stra_batch_spec, bootstrap_value=False,
            agent_shared=False, env_shared=False, subprocess=False)
        self.stra_samples = stra_samples_pyt
        self.stra_samples_np = stra_samples_np

        self.initialize_logging()
        return n_itr

    def train(self, log_dir):
        """
        Performs startup, evaluates the initial agent, then loops by
        alternating between ``sampler.obtain_samples()`` (once per sampler,
        gathered into ``self.stra_samples``) and ``algo.optimize_agent()``.
        Pauses to evaluate the agent at the specified log interval.

        Training stops early as soon as an evaluation yields a strictly
        positive ``Return`` for every evaluated trajectory.  After the loop,
        ``algo.replay_buffers`` is pickled to ``<log_dir>/buffers.pkl`` and
        the runner is shut down.

        Args:
            log_dir: directory that receives ``buffers.pkl``.

        Returns:
            The iteration at which training stopped early, or the total number
            of iterations when the loop was never broken out of.
        """
        import os
        import pickle

        n_itr = self.startup()
        ### best first-trajectory evaluation Return seen so far
        best_return = float('-inf')

        with logger.prefix("itr #0 "):
            eval_traj_infos, eval_time = self.evaluate_agent(0)
            self.log_diagnostics(0, eval_traj_infos, eval_time)

        final_itr = n_itr
        for itr in range(self.start_itr + 1, n_itr):
            logger.set_iteration(itr)
            with logger.prefix(f"itr #{itr} "):
                self.agent.sample_mode(itr)
                # Each sampler fills its own slot of width B along the batch
                # dimension of the stratified buffer.
                # NOTE(review): only time index 0 is copied -- looks like this
                # expects batch_spec.T == 1; confirm.
                for slot, sampler in enumerate(self.samplers.values()):
                    samples, traj_infos = sampler.obtain_samples(itr)
                    batch_B = sampler.batch_spec.B
                    self.stra_samples[0, slot * batch_B:(slot + 1) * batch_B] = \
                        samples[0, 0:batch_B]
                self.agent.train_mode(itr)
                opt_info = self.algo.optimize_agent(itr, self.stra_samples)
                # NOTE(review): traj_infos holds only the LAST sampler's
                # trajectories at this point.
                self.store_diagnostics(itr, traj_infos, opt_info)
                if (itr + 1) % self.log_interval_itrs == 0:
                    eval_traj_infos, eval_time = self.evaluate_agent(itr)
                    # evaluate_agent() returns an empty list before
                    # min_itr_learn (and may when no trajectory completes), so
                    # check for that before indexing into it.
                    if (eval_traj_infos
                            and itr >= self.min_itr_learn - 1
                            and eval_traj_infos[0]['Return'] > best_return):
                        best_return = eval_traj_infos[0]['Return']
                        self.save_best_snapshot(itr)
                    self.log_diagnostics(itr, eval_traj_infos, eval_time)
                    # Early stop: every evaluated trajectory has Return > 0.
                    if (eval_traj_infos
                            and min(info['Return'] for info in eval_traj_infos) > 0):
                        self.save_cur_snapshot(n_itr - 1, "last")
                        self.consolidate_zero(n_itr)
                        final_itr = itr
                        break
                if itr == n_itr - 1:
                    # Last iteration without an early stop: only long runs
                    # (more than 1000 iterations) go through the Fisher-based
                    # consolidation.
                    if itr > 1000:
                        self.consolidate(n_itr)
                    else:
                        self.consolidate_zero(n_itr)

        ### save replay buffer ###
        replay_buffers = self.algo.replay_buffers
        print('saving buffers in %s' % (log_dir))
        with open(os.path.join(log_dir, 'buffers.pkl'), 'wb') as fh:
            pickle.dump(replay_buffers, fh, protocol=4)
        print('buffers saved')
        self.shutdown()
        return final_itr

    def shutdown(self):
        """
        Logs completion, stops the progress bar and shuts down every sampler
        in ``self.samplers``.
        """
        logger.log("Training complete.")
        self.pbar.stop()
        for sampler in self.samplers.values():
            sampler.shutdown()

    ### evaluate the agent with every sampler in self.samplers ###
    def evaluate_agent(self, itr):
        """
        Runs offline evaluation of the agent through ``sampler.evaluate_agent()``
        of every sampler and concatenates the trajectory infos they return.

        Evaluation only happens at ``itr == 0`` or once
        ``itr >= min_itr_learn - 1``; otherwise an empty list and an
        evaluation time of ``0.0`` are returned.

        Returns:
            Tuple ``(traj_infos, eval_time)``: the concatenated trajectory
            infos and the wall-clock seconds spent evaluating.
        """
        if itr > 0:
            self.pbar.stop()

        collected_infos = []
        elapsed = 0.0
        if itr >= self.min_itr_learn - 1 or itr == 0:
            logger.log("Evaluating agent...")
            self.agent.eval_mode(itr)  # Might be agent in sampler.
            started_at = time.time()
            ### each sampler contributes its own list of trajectory infos ###
            for sampler in self.samplers.values():
                collected_infos.extend(sampler.evaluate_agent(itr))
            elapsed = time.time() - started_at
        logger.log("Evaluation runs complete.")
        return collected_infos, elapsed

    def save_cur_snapshot(self, itr, prefix):
        """
        Builds the snapshot for ``itr`` with ``get_itr_snapshot()`` and passes
        it, together with ``prefix``, to ``logger.save_cur_params()``.
        Progress messages are logged before and after.
        """
        logger.log("saving %s snapshot..." % prefix)
        snapshot = self.get_itr_snapshot(itr)
        logger.save_cur_params(prefix, snapshot)
        logger.log("saved")

    def get_itr_snapshot(self, itr):
        """
        Assembles the checkpoint dictionary for iteration ``itr``.

        Returns:
            A dict with the keys ``itr``, ``cum_steps`` (``itr`` times the
            first sampler's ``batch_size`` times ``world_size``),
            ``agent_state_dict`` and ``optimizer_state_dict``.
        """
        # The step count is derived from the first sampler in the mapping.
        first_sampler = next(iter(self.samplers.values()))
        steps_so_far = itr * first_sampler.batch_size * self.world_size
        return {
            "itr": itr,
            "cum_steps": steps_so_far,
            "agent_state_dict": self.agent.state_dict(),
            "optimizer_state_dict": self.algo.optim_state_dict(),
        }

    def save_best_snapshot(self, itr):
        """
        Builds the snapshot for ``itr`` with ``get_itr_snapshot()`` and passes
        it, together with ``itr``, to ``logger.save_best_params()``.
        Progress messages are logged before and after.
        """
        logger.log("saving best snapshot...")
        snapshot = self.get_itr_snapshot(itr)
        logger.save_best_params(itr, snapshot)
        logger.log("saved")

    def consolidate(self, n_itr):
        """
        Asks the algorithm for its two Fisher estimates via
        ``algo.calculate_fisher()``, feeds them to ``algo.consolidate()`` and
        then saves the "last" snapshot tagged with iteration ``n_itr - 1``.
        """
        print(
            'Estimating diagonals of the fisher information matrix...',
            flush=True, end='',
        )
        fisher_mu, fisher_q = self.algo.calculate_fisher()
        self.algo.consolidate(
            mu_fisher_matrix=fisher_mu,
            q_fisher_matrix=fisher_q,
        )
        self.save_cur_snapshot(n_itr - 1, "last")
        print(' Consolidated!')

    def consolidate_zero(self, n_itr):
        """
        Calls ``algo.consolidate_zero()`` in place of the Fisher-based
        consolidation, then saves the "last" snapshot tagged with iteration
        ``n_itr - 1``.
        """
        last_itr = n_itr - 1
        self.algo.consolidate_zero()
        self.save_cur_snapshot(last_itr, "last")
        print('Not Consolidated!')

    ### replaces the parent-class logging (does not call super().log_diagnostics) ###
    def log_diagnostics(self, itr, eval_traj_infos, eval_time, prefix='Diagnostics/'):
        """
        Records evaluation and training-throughput statistics for ``itr`` and
        dumps them through the logger.

        The parent implementation is not called; the evaluation counters and
        the timing/throughput records are all written here.  Step counts use
        the ``batch_size`` of the first sampler in ``self.samplers``.

        Args:
            itr: current iteration index.
            eval_traj_infos: trajectory infos from ``evaluate_agent()``; each
                entry is indexed with ``"Length"``.
            eval_time: seconds spent in the evaluation that just finished.
            prefix: tabular prefix applied to every record written here.
        """
        if not eval_traj_infos:
            logger.log("WARNING: had no complete trajectories in eval.")
        eval_steps = sum(info["Length"] for info in eval_traj_infos)
        with logger.tabular_prefix(prefix):
            logger.record_tabular('StepsInEval', eval_steps)
            logger.record_tabular('TrajsInEval', len(eval_traj_infos))
            self._cum_eval_time += eval_time
            logger.record_tabular('CumEvalTime', self._cum_eval_time)

        if itr > 0:
            self.pbar.stop()
        if itr >= self.min_itr_learn - 1:
            self.save_itr_snapshot(itr)

        # Timing: the evaluation time is excluded from the training time.
        now = time.time()
        self._cum_time = now - self._start_time
        train_time_elapsed = now - self._last_time - eval_time

        sampler_batch_size = next(iter(self.samplers.values())).batch_size
        new_updates = self.algo.update_counter - self._last_update_counter
        new_samples = (sampler_batch_size * self.world_size *
                       self.log_interval_itrs)
        if itr == 0:
            # No training has happened yet, so the rates are undefined.
            updates_per_second = float('nan')
            samples_per_second = float('nan')
        else:
            updates_per_second = new_updates / train_time_elapsed
            samples_per_second = new_samples / train_time_elapsed
        replay_ratio = (new_updates * self.algo.batch_size * self.world_size /
                        new_samples)
        steps_per_rank = (itr + 1) * sampler_batch_size
        cum_replay_ratio = (self.algo.batch_size * self.algo.update_counter /
                            steps_per_rank)  # world_size cancels.
        cum_steps = steps_per_rank * self.world_size

        records = [
            ('Iteration', itr),
            ('CumTime (s)', self._cum_time),
            ('CumSteps', cum_steps),
            ('CumCompletedTrajs', self._cum_completed_trajs),
            ('CumUpdates', self.algo.update_counter),
            ('StepsPerSecond', samples_per_second),
            ('UpdatesPerSecond', updates_per_second),
            ('ReplayRatio', replay_ratio),
            ('CumReplayRatio', cum_replay_ratio),
        ]
        with logger.tabular_prefix(prefix):
            if self._eval:
                # Already added new eval_time.
                logger.record_tabular('CumTrainTime',
                                      self._cum_time - self._cum_eval_time)
            for key, value in records:
                logger.record_tabular(key, value)
        self._log_infos(eval_traj_infos)
        logger.dump_tabular(with_prefix=False)

        # Remember where this logging interval ended for the next call.
        self._last_time = now
        self._last_update_counter = self.algo.update_counter
        if itr < self.n_itr - 1:
            logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
            self.pbar = ProgBarCounter(self.log_interval_itrs)