import psutil
import time
import torch
import math
from collections import deque

from open_source.rlpyt.rlpyt.runners.base import BaseRunner
from open_source.rlpyt.rlpyt.runners.minibatch_rl import MinibatchRlEval
from open_source.rlpyt.rlpyt.utils.quick_args import save__init__args
from open_source.rlpyt.rlpyt.utils.seed import set_seed, make_seed
from open_source.rlpyt.rlpyt.utils.logging import logger
from open_source.rlpyt.rlpyt.utils.prog_bar import ProgBarCounter
from open_source.rlpyt.rlpyt.samplers.buffer import build_samples_buffer
from open_source.rlpyt.rlpyt.samplers.collections import BatchSpec
import multiprocessing as mp
from open_source.rlpyt.rlpyt.runners.minibatch_rl_pcgrad import MinibatchRlEval_PCGrad
import pdb

class MinibatchRlEval_PCGrad_Parallel(MinibatchRlEval_PCGrad):
    ### samplers : list of samplers ###
    _eval = False

    def __init__(
            self,
            algo,
            agent,
            sampler,
            n_steps,
            seed=None,
            affinity=None,
            log_interval_steps=1e5,
            start_itr=0
    ):
        """
        Store the run configuration and create a multiprocessing pool.

        Args:
            algo: Algorithm object; ``startup`` calls ``algo.initialize`` and
                ``train`` calls ``algo.optimize_agent`` on it.
            agent: Agent object handed to ``sampler.initialize`` in ``startup``.
            sampler: Sampler used for collecting training samples and for the
                evaluation runs.
            n_steps: Coerced to ``int``.  Presumably the total environment-step
                budget consumed by ``get_n_itr`` in the parent class -- confirm.
            seed: Random seed; when ``None``, ``startup`` draws one with
                ``make_seed``.
            affinity: Mapping of hardware settings; ``None`` becomes an empty
                dict.  ``startup`` reads the keys ``master_cpus``,
                ``set_affinity``, ``master_torch_threads`` and ``cuda_idx``.
            log_interval_steps: Coerced to ``int``.  Presumably the number of
                steps between evaluations/log dumps -- confirm against the
                parent class.
            start_itr: The training loop in ``train`` starts at
                ``start_itr + 1``.
        """
        # NOTE(review): the parent ``__init__`` is not invoked here; confirm
        # that relying on ``save__init__args`` alone is intended.
        n_steps = int(n_steps)
        log_interval_steps = int(log_interval_steps)
        affinity = dict() if affinity is None else affinity
        # Relies on the exact local names above: after this call ``self.algo``
        # (and the other arguments read by the methods below) are attributes.
        save__init__args(locals())
        # Evaluation and snapshotting are gated on this iteration threshold;
        # defaults to 0 when the algorithm does not define it.
        self.min_itr_learn = getattr(self.algo, 'min_itr_learn', 0)
        # One worker per CPU.  ``__getstate__`` drops this attribute from the
        # pickled state.
        self.pool = mp.Pool(mp.cpu_count())

    def startup(self):
        """
        Prepare the runner before the training loop starts.

        Applies the master-process CPU affinity and Torch thread count from
        ``self.affinity``, fixes the random seed, initializes the sampler
        (which is expected to initialize the agent), moves the agent to its
        device (wrapping it for data-parallel use when ``world_size > 1``),
        initializes the algorithm and finally the logging state.

        Returns:
            The number of training iterations reported by ``get_n_itr()``.
        """
        process = psutil.Process()
        try:
            master_cpus = self.affinity.get("master_cpus", None)
            pin_requested = (master_cpus is not None and
                self.affinity.get("set_affinity", True))
            if pin_requested:
                process.cpu_affinity(master_cpus)
            affinity_report = process.cpu_affinity()
        except AttributeError:
            # ``cpu_affinity`` is not available on every platform.
            affinity_report = "UNAVAILABLE MacOS"
        logger.log(
            f"Runner {getattr(self, 'rank', '')} master CPU affinity: "
            f"{affinity_report}."
        )

        torch_threads = self.affinity.get("master_torch_threads", None)
        if torch_threads is not None:
            torch.set_num_threads(torch_threads)
        logger.log(
            f"Runner {getattr(self, 'rank', '')} master Torch threads: "
            f"{torch.get_num_threads()}."
        )

        # Draw a seed only when the caller did not provide one.
        if self.seed is None:
            self.seed = make_seed()
        set_seed(self.seed)

        # Fall back to a single-process setup when rank/world_size are unset.
        rank = getattr(self, "rank", 0)
        world_size = getattr(self, "world_size", 1)
        self.rank = rank
        self.world_size = world_size

        bootstrap = getattr(self.algo, "bootstrap_value", False)
        traj_kwargs = self.get_traj_info_kwargs()
        examples = self.sampler.initialize(
            agent=self.agent,  # The sampler initializes the agent.
            affinity=self.affinity,
            seed=self.seed + 1,
            bootstrap_value=bootstrap,
            traj_info_kwargs=traj_kwargs,
            rank=rank,
            world_size=world_size,
        )

        batch_spec = self.sampler.batch_spec
        self.itr_batch_size = batch_spec.size * world_size
        n_itr = self.get_n_itr()

        cuda_idx = self.affinity.get("cuda_idx", None)
        self.agent.to_device(cuda_idx)
        if world_size > 1:
            self.agent.data_parallel()

        self.algo.initialize(
            agent=self.agent,
            n_itr=n_itr,
            batch_spec=self.sampler.batch_spec,
            mid_batch_reset=self.sampler.mid_batch_reset,
            examples=examples,
            world_size=world_size,
            rank=rank,
        )
        self.initialize_logging()
        return n_itr

    def train(self, log_dir, writer):
        """
        Run the full training loop and persist the replay buffers.

        Performs startup, evaluates the initial agent, then alternates between
        ``sampler.obtain_samples()`` and ``algo.optimize_agent()``.  At every
        log interval the agent is evaluated on all environments (results go to
        ``writer``) and on the current evaluation environments (results go to
        the logger).  Training stops early as soon as every evaluation
        trajectory has a positive return.

        Args:
            log_dir: Directory in which ``buffers.pkl`` is written.
            writer: Object exposing ``add_scalar(tag, value, step)``; it is
                passed on to ``write_logs``.

        Returns:
            The iteration at which training stopped early, or ``n_itr`` when
            the loop ran to completion.
        """
        import os
        import pickle

        n_itr = self.startup()
        # Best return seen so far on the first evaluation trajectory.
        Return_max = float('-inf')

        with logger.prefix("itr #0 "):
            eval_traj_infos, eval_time = self.evaluate_agent(0)
            self.log_diagnostics(0, eval_traj_infos, eval_time)

        final_itr = n_itr
        for itr in range(self.start_itr + 1, n_itr):
            logger.set_iteration(itr)
            with logger.prefix(f"itr #{itr} "):
                self.agent.sample_mode(itr)
                samples, traj_infos = self.sampler.obtain_samples(itr)
                self.stra_samples = samples
                self.agent.train_mode(itr)
                opt_info = self.algo.optimize_agent(itr, self.stra_samples)
                self.store_diagnostics(itr, traj_infos, opt_info)

                if (itr + 1) % self.log_interval_itrs == 0:
                    # Record the returns of every environment.
                    # NOTE(review): the time spent here is not passed to
                    # ``log_diagnostics`` and so counts as training time in
                    # the throughput statistics -- confirm this is intended.
                    all_eval_traj_infos, _ = self.evaluate_agent_all(itr)
                    all_corner_infos = [
                        env_kwargs['kwargs']['corner']
                        for env_kwargs in self.sampler.all_eval_env_kwargs_list
                    ]
                    self.write_logs(
                        all_corner_infos, all_eval_traj_infos, writer, itr)

                    # Evaluate on the current evaluation environments.
                    eval_traj_infos, eval_time = self.evaluate_agent(itr)

                    # ``evaluate_agent`` returns an empty list while
                    # ``itr < min_itr_learn - 1``; guard before indexing.
                    if eval_traj_infos and itr >= self.min_itr_learn - 1:
                        first_return = eval_traj_infos[0]['Return']
                        if first_return > Return_max:
                            Return_max = first_return
                            self.save_best_snapshot(itr)

                    self.log_diagnostics(itr, eval_traj_infos, eval_time)

                    # Early stop once every evaluation trajectory has a
                    # positive return; skipped when nothing was evaluated.
                    if eval_traj_infos:
                        min_Return = min(
                            info['Return'] for info in eval_traj_infos)
                        if min_Return > 0:
                            self.save_cur_snapshot(n_itr - 1, "last")
                            self.consolidate_zero(n_itr)
                            final_itr = itr
                            break

                # Last scheduled iteration reached without an early stop.
                if itr == n_itr - 1:
                    if itr > 1000:
                        self.consolidate(n_itr)
                    else:
                        self.consolidate_zero(n_itr)

        # Save the replay buffers next to the other run artifacts.
        replay_buffers = self.algo.replay_buffers
        print('saving buffers in %s' % (log_dir))
        with open(os.path.join(log_dir, 'buffers.pkl'), 'wb') as fh:
            pickle.dump(replay_buffers, fh, protocol=4)
        print('buffers saved')
        self.shutdown()
        return final_itr

    def __getstate__(self):
        self_dict = self.__dict__.copy()
        del self_dict['pool']
        return self_dict

    def evaluate_agent(self, itr):
        """
        Run an offline evaluation through ``sampler.evaluate_agent()``.

        The evaluation only runs at iteration 0 or once ``itr`` has reached
        ``min_itr_learn - 1``; otherwise an empty result is returned.

        Returns:
            Tuple ``(traj_infos, eval_time)``: the sampler's evaluation result
            and the wall-clock seconds it took, or ``([], 0.0)`` when the
            evaluation was skipped.
        """
        if itr > 0:
            self.pbar.stop()

        # Defaults for the "evaluation skipped" case.
        traj_infos, eval_time = [], 0.0
        learn_threshold = self.min_itr_learn - 1
        if itr == 0 or itr >= learn_threshold:
            logger.log("Evaluating agent in parallel...")
            self.agent.eval_mode(itr)  # May be the agent held by the sampler.
            started = time.time()
            traj_infos = self.sampler.evaluate_agent(itr)
            eval_time = time.time() - started
        logger.log("Evaluation runs complete.")
        return traj_infos, eval_time

    def evaluate_agent_all(self, itr):
        """
        Run an offline evaluation through ``sampler.evaluate_agent_all()``.

        The evaluation only runs at iteration 0 or once ``itr`` has reached
        ``min_itr_learn - 1``; otherwise an empty result is returned.

        Returns:
            Tuple ``(traj_infos, eval_time)``: the sampler's evaluation result
            and the wall-clock seconds it took, or ``([], 0.0)`` when the
            evaluation was skipped.
        """
        if itr > 0:
            self.pbar.stop()

        # Defaults for the "evaluation skipped" case.
        traj_infos, eval_time = [], 0.0
        learn_threshold = self.min_itr_learn - 1
        if itr == 0 or itr >= learn_threshold:
            logger.log("Evaluating agent in all envs...")
            self.agent.eval_mode(itr)  # May be the agent held by the sampler.
            started = time.time()
            traj_infos = self.sampler.evaluate_agent_all(itr)
            eval_time = time.time() - started
        logger.log("Evaluation runs complete.")
        return traj_infos, eval_time

    def get_itr_snapshot(self, itr):
        """
        Returns all state needed for full checkpoint/snapshot of training run,
        including agent parameters and optimizer parameters.
        """
        return dict(
            itr=itr,
            cum_steps=itr * self.sampler.batch_size * self.world_size,
            agent_state_dict=self.agent.state_dict(),
            optimizer_state_dict=self.algo.optim_state_dict(),
        )

    def log_diagnostics(self, itr, traj_infos=None, eval_time=0, prefix='Diagnostics/'):
        """
        Record timing/throughput statistics and dump the tabular log.

        Stops the progress bar (after iteration 0), saves an iteration
        snapshot once ``itr >= min_itr_learn - 1``, records the counters
        below under ``prefix``, logs ``traj_infos`` via ``_log_infos`` and
        dumps the table.  A new progress bar is started unless this is the
        last iteration.

        Args:
            itr: Current iteration index.
            traj_infos: Trajectory infos forwarded to ``_log_infos``.
            eval_time: Seconds spent evaluating, excluded from training time.
            prefix: Tabular prefix for the recorded statistics.
        """
        if itr > 0:
            self.pbar.stop()
        if itr >= self.min_itr_learn - 1:
            self.save_itr_snapshot(itr)

        now = time.time()
        self._cum_time = now - self._start_time
        elapsed_training = now - self._last_time - eval_time
        update_count = self.algo.update_counter
        updates_since_last = update_count - self._last_update_counter
        samples_since_last = (self.sampler.batch_size * self.world_size *
            self.log_interval_itrs)

        # No meaningful rate exists for the very first dump.
        if itr == 0:
            updates_rate = float('nan')
            samples_rate = float('nan')
        else:
            updates_rate = updates_since_last / elapsed_training
            samples_rate = samples_since_last / elapsed_training

        replay_ratio = (updates_since_last * self.algo.batch_size *
            self.world_size / samples_since_last)
        # world_size cancels in the cumulative ratio.
        cum_replay_ratio = (self.algo.batch_size * update_count /
            ((itr + 1) * self.sampler.batch_size))
        cum_steps = (itr + 1) * self.sampler.batch_size * self.world_size

        rows = []
        if self._eval:
            rows.append(('CumTrainTime', self._cum_time - self._cum_eval_time))
        rows.extend([
            ('Iteration', itr),
            ('CumTime (s)', self._cum_time),
            ('CumSteps', cum_steps),
            ('CumCompletedTrajs', self._cum_completed_trajs),
            ('CumUpdates', update_count),
            ('StepsPerSecond', samples_rate),
            ('UpdatesPerSecond', updates_rate),
            ('ReplayRatio', replay_ratio),
            ('CumReplayRatio', cum_replay_ratio),
        ])
        with logger.tabular_prefix(prefix):
            for key, value in rows:
                logger.record_tabular(key, value)
        self._log_infos(traj_infos)
        logger.dump_tabular(with_prefix=False)

        # Reference points for the next interval.
        self._last_time = now
        self._last_update_counter = self.algo.update_counter
        if itr < self.n_itr - 1:
            logger.log(f"Optimizing over {self.log_interval_itrs} iterations.")
            self.pbar = ProgBarCounter(self.log_interval_itrs)

    def write_logs(self, corner_infos, eval_traj_infos, writer, itr):
        # pdb.set_trace()
        rewards = [dict['Return'] for dict in eval_traj_infos]
        corner = []
        for info in corner_infos:
            corner.append('%s_%s_%s' % (info['process'], info['temp'], info['vdd']))
        for index, reward in enumerate(rewards):
            writer.add_scalar(corner[index] + '_reward', reward, itr)


    def shutdown(self):
        """
        Finish the run and release the resources held by the runner.

        Logs the completion message, stops the progress bar if one was
        created, shuts the sampler down, and closes the multiprocessing pool
        created in ``__init__`` so its worker processes do not outlive the run.
        """
        logger.log("PCGrad Parallel Training complete.")
        # ``log_diagnostics`` only creates a bar while ``itr < n_itr - 1``, so
        # very short runs may never have one.
        progress_bar = getattr(self, "pbar", None)
        if progress_bar is not None:
            progress_bar.stop()
        self.sampler.shutdown()
        # The attribute is kept (only closed) so ``__getstate__`` still finds it.
        pool = getattr(self, "pool", None)
        if pool is not None:
            pool.close()
            pool.join()