
import torch
import numpy as np
from collections import namedtuple

from open_source.rlpyt.rlpyt.algos.qpg.ddpg import DDPG
from open_source.rlpyt.rlpyt.algos.qpg.ddpg_ewc import DDPG_EWC
from open_source.rlpyt.rlpyt.algos.qpg.pcgrad import PCGrad
from open_source.rlpyt.rlpyt.utils.quick_args import save__init__args
from open_source.rlpyt.rlpyt.utils.logging import logger
from open_source.rlpyt.rlpyt.replays.non_sequence.uniform import (UniformReplayBuffer,
    AsyncUniformReplayBuffer)
from open_source.rlpyt.rlpyt.replays.non_sequence.time_limit import (TlUniformReplayBuffer,
    AsyncTlUniformReplayBuffer)
from open_source.rlpyt.rlpyt.utils.collections import namedarraytuple
from open_source.rlpyt.rlpyt.utils.tensor import valid_mean
from open_source.rlpyt.rlpyt.algos.utils import valid_from_done
from open_source.rlpyt.rlpyt.samplers.collections import (Samples, AgentSamples, AgentSamplesBsv,
    EnvSamples)
from open_source.rlpyt.rlpyt.replays.non_sequence.time_limit import SamplesFromReplayTL
from open_source.rlpyt.rlpyt.agents.base import AgentInputs
from open_source.rlpyt.rlpyt.utils.collections import namedarraytuple
from open_source.rlpyt.rlpyt.utils.buffer import torchify_buffer
import pdb


# Container returned by ``optimize_agent``; it is created there with one
# (initially empty) list per field.
OptInfo = namedtuple("OptInfo",
    ["muLoss", "qLoss", "muGradNorm", "qGradNorm"])
# Fields handed to a replay buffer's ``append_samples``; built by
# ``samples_to_buffer`` and used as the allocation example in
# ``initialize_replay_buffer``.
SamplesToBuffer = namedarraytuple("SamplesToBuffer",
    ["observation", "action", "reward", "done", "timeout", "taskID"])



class DDPG_PCGrad_Parallel(DDPG):
    """Multi-task deep deterministic policy gradient algorithm, training from
    one replay buffer per task id and combining the per-task losses through
    ``PCGrad``-wrapped optimizers."""

    opt_info_fields = tuple(f for f in OptInfo._fields)  # copy

    def __init__(
            self,
            all_task_id,
            cur_task_id,
            discount=0.99,
            batch_size=64,
            min_steps_learn=int(1e4),
            replay_size=int(1e6),
            replay_ratio=64,  # data_consumption / data_generation
            target_update_tau=0.01,
            target_update_interval=1,
            policy_update_interval=1,
            learning_rate=1e-4,
            q_learning_rate=1e-3,
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            initial_optim_state_dict=None,
            clip_grad_norm=1e8,
            q_target_clip=1e6,
            n_step_return=1,
            updates_per_sync=1,  # For async mode only.
            bootstrap_timelimit=True,
            ReplayBufferCls=None,
            replay_buffer_loaded=None,
            num_buffers=1,
            fisher_lambda=400,
            fisher_batch_size=100,
            use_ewc=True,
            ):
        """Saves input arguments.

        Arguments specific to this class (the remaining ones presumably keep
        the meaning they have in the ``DDPG`` base class -- TODO confirm):

        Args:
            all_task_id: Iterable of task ids; ``initialize_replay_buffer``
                allocates one replay buffer per id.
            cur_task_id: Sized collection of the task ids that
                ``optimize_agent`` stores samples for and trains on.
            replay_buffer_loaded: Optional object with ``.items()`` yielding
                ``(task_id, replay_buffer)`` pairs; when given it replaces
                the freshly allocated buffers.
            num_buffers: Stored as an attribute (presumably the number of
                replay buffers in use -- TODO confirm).
            fisher_lambda: Scale factor of the penalties returned by
                ``mu_penalty`` and ``q_penalty`` (applied as ``lambda / 2``).
            fisher_batch_size: Number of single-transition draws that
                ``calculate_fisher`` averages over.
            use_ewc: Stored as an attribute; not read anywhere in the code
                visible in this class.
        """
        if optim_kwargs is None:
            optim_kwargs = dict()

        self.replay_buffer_loaded = replay_buffer_loaded
        self._batch_size = batch_size
        # Remove the local so that the locals() snapshot below has no
        # ``batch_size`` entry; only ``_batch_size`` is set explicitly.
        del batch_size  # Property.
        # NOTE(review): presumably copies the remaining constructor arguments
        # found in locals() onto ``self`` -- confirm against quick_args.
        save__init__args(locals())

    def initialize(self, agent, n_itr, batch_spec, mid_batch_reset, examples,
                  world_size=1, rank=0):
        """Set up the algorithm for use in non-async runners.

        Stores the input arguments, then builds the replay buffers and the
        optimizers.  The number of gradient updates per optimization
        iteration is `(replay_ratio * sampler-batch-size /
        training-batch_size)`, rounded and never less than one.  The sampler
        batch size is fixed to 1 here; ``batch_spec.size`` is not used."""
        sampler_bs = 1
        self.agent = agent
        self.n_itr = n_itr
        self.mid_batch_reset = mid_batch_reset
        self.sampler_bs = sampler_bs

        rounded_updates = round(
            self.replay_ratio * sampler_bs / self.batch_size)
        self.updates_per_optimize = max(1, rounded_updates)
        message = (f"From sampler batch size {sampler_bs}, training "
            f"batch size {self.batch_size}, and replay ratio "
            f"{self.replay_ratio}, computed {self.updates_per_optimize} "
            f"updates per iteration.")
        logger.log(message)
        self.min_itr_learn = int(self.min_steps_learn // sampler_bs)

        # Lookup table used when decoding the first entry of a task code.
        self.decodeID = ['tt', 'ff', 'ss', 'sf', 'fs']
        # Must exist before the replay buffers are built: ids of buffers
        # that already hold data are appended to it there.
        self.used_buffer = []
        self.initialize_replay_buffer(examples, batch_spec)
        self.optim_initialize(rank)

    def async_initialize(self, agent, sampler_n_itr, batch_spec, mid_batch_reset,
            examples, world_size=1):
        """Used in async runner only; returns replay buffer allocated in shared
        memory, does not instantiate optimizer. """
        self.agent = agent
        self.n_itr = sampler_n_itr
        self.initialize_replay_buffer(examples, batch_spec, async_=True)
        self.mid_batch_reset = mid_batch_reset
        self.sampler_bs = sampler_bs = batch_spec.size
        self.updates_per_optimize = self.updates_per_sync
        self.min_itr_learn = int(self.min_steps_learn // sampler_bs)
        return self.replay_buffer

    def optim_initialize(self, rank=0):
        """Called in initialize or by async runner after forking sampler.

        Builds one ``OptimCls`` instance for the policy (mu) parameters and
        one for the Q parameters, each wrapped in ``PCGrad``, and then loads
        ``initial_optim_state_dict`` (keys ``"q"`` and ``"mu"``) if given."""
        self.rank = rank
        mu_base = self.OptimCls(
            self.agent.mu_parameters(),
            lr=self.learning_rate,
            **self.optim_kwargs,
        )
        self.mu_optimizer = PCGrad(mu_base)
        q_base = self.OptimCls(
            self.agent.q_parameters(),
            lr=self.q_learning_rate,
            **self.optim_kwargs,
        )
        self.q_optimizer = PCGrad(q_base)

        saved_state = self.initial_optim_state_dict
        if saved_state is None:
            return
        self.q_optimizer.load_state_dict(saved_state["q"])
        self.mu_optimizer.load_state_dict(saved_state["mu"])

    def initialize_replay_buffer(self, examples, batch_spec, async_=False):
        """
        Sets ``self.replay_buffers`` to one replay buffer per task id.

        When ``replay_buffer_loaded`` was given it is used as-is and nothing
        new is allocated.  Otherwise one buffer per entry of ``all_task_id``
        is allocated from ``examples`` with the fields of the
        `SamplesToBuffer` namedarraytuple.  Ids whose buffer already holds
        data (``t > 0``) are appended to ``self.used_buffer``.

        Args:
            examples: Mapping with "observation", "action", "reward", "done"
                and "env_info" entries.
            batch_spec: Unused; every buffer is allocated with ``B=1``.
            async_: Select the async variant of the replay buffer class.
        """
        if not hasattr(self, "used_buffer"):
            # Normally created by the caller; guard against direct use.
            self.used_buffer = []

        env_info = examples["env_info"]
        example_to_buffer = SamplesToBuffer(
            observation=examples["observation"],
            action=examples["action"],
            reward=examples["reward"],
            done=examples["done"],
            timeout=getattr(env_info, "timeout", None),
            taskID=getattr(env_info, "taskID", 0),
        )
        replay_kwargs = dict(
            example=example_to_buffer,
            size=self.replay_size,
            B=1,
            n_step_return=self.n_step_return,
        )
        if not self.bootstrap_timelimit:
            ReplayCls = AsyncUniformReplayBuffer if async_ else UniformReplayBuffer
        else:
            ReplayCls = AsyncTlUniformReplayBuffer if async_ else TlUniformReplayBuffer
        if self.ReplayBufferCls is not None:
            ReplayCls = self.ReplayBufferCls
            logger.log(f"WARNING: ignoring internal selection logic and using"
                f" input replay buffer class: {ReplayCls} -- compatibility not"
                " guaranteed.")

        if self.replay_buffer_loaded is not None:
            # Previously loaded buffers replace the allocation entirely, so
            # do not build (and immediately discard) a buffer per task.
            self.replay_buffers = self.replay_buffer_loaded
        else:
            # dict.fromkeys() keeps the first-seen order and drops duplicates.
            self.replay_buffers = {
                task_id: ReplayCls(**replay_kwargs)
                for task_id in dict.fromkeys(self.all_task_id)
            }

        for task_id, replay_buffer in self.replay_buffers.items():
            if replay_buffer.t > 0:
                self.used_buffer.append(task_id)

    def optimize_agent(self, itr, samples=None, sampler_itr=None):
        """
        Stores new samples in the per-task replay buffers, then trains the
        agent from those buffers.

        For each of ``updates_per_optimize`` updates, one batch is drawn from
        the buffer of every task in ``cur_task_id``.  The per-task Q losses
        are passed together to the Q optimizer's ``pc_backward`` followed by
        ``step``; every ``policy_update_interval`` updates the same is done
        for the policy (mu) losses, and every ``target_update_interval``
        updates the target networks are updated.  Each task's loss uses the
        validity mask computed from that task's own batch.

        Args:
            itr: Current iteration; replaced by ``sampler_itr`` when given.
            samples: Optional sampler output to store before training.
            sampler_itr: Sampler iteration (async mode).

        Returns:
            ``OptInfo`` of lists.  ``qLoss`` / ``muLoss`` receive the mean of
            the per-task losses for each update performed; the gradient-norm
            fields stay empty because no gradient clipping is applied.
        """
        itr = itr if sampler_itr is None else sampler_itr  # Async uses sampler_itr.
        if samples is not None:
            self._store_samples(itr, samples)

        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        if itr < self.min_itr_learn:
            return opt_info

        for _ in range(self.updates_per_optimize):
            # Losses are ordered like cur_task_id; batches are drawn in
            # replay-buffer order.
            batches = dict.fromkeys(self.cur_task_id)
            for task_id, replay_buffer in self.replay_buffers.items():
                if task_id not in self.cur_task_id:
                    continue
                batches[task_id] = replay_buffer.sample_batch(self.batch_size)
            valids = {task_id: self._valid_mask(batch)
                for task_id, batch in batches.items()}

            self.q_optimizer.zero_grad()
            q_losses = [self.q_loss(batch, valids[task_id])
                for task_id, batch in batches.items()]
            self.q_optimizer.pc_backward(q_losses)
            self.q_optimizer.step()
            opt_info.qLoss.append(self._mean_loss(q_losses))

            self.update_counter += 1
            if self.update_counter % self.policy_update_interval == 0:
                self.mu_optimizer.zero_grad()
                mu_losses = [self.mu_loss(batch, valids[task_id])
                    for task_id, batch in batches.items()]
                self.mu_optimizer.pc_backward(mu_losses)
                self.mu_optimizer.step()
                opt_info.muLoss.append(self._mean_loss(mu_losses))
            if self.update_counter % self.target_update_interval == 0:
                self.agent.update_target(self.target_update_tau)
        return opt_info

    def _store_samples(self, itr, samples):
        """Split sampler output by task and append each slice to the replay
        buffer of its task."""
        het_samples = {}
        # One sampler column per current task (slice width 1).
        for index in range(len(self.cur_task_id)):
            cur_sample = samples[:, index:index + 1]
            task_id = self._decode_task_id(
                cur_sample.env.env_info.taskID[0][0])
            # Keep only the first slice seen for each current task.
            if task_id not in het_samples and task_id in self.cur_task_id:
                het_samples[task_id] = cur_sample

        warming_up = itr < self.min_itr_learn
        for task_id, cur_sample in het_samples.items():
            # Buffers that already held data at initialization are not
            # topped up during the initial warm-up phase.
            if warming_up and task_id in self.used_buffer:
                continue
            self.replay_buffers[task_id].append_samples(
                self.samples_to_buffer(cur_sample))

    def _decode_task_id(self, task_code):
        """Turn a 3-entry task code into its ``'<process>_<temp>_<vdd>'`` id."""
        process = self.decodeID[int(task_code[0])]
        temp = str(int(task_code[1]))
        vdd = str(float(task_code[2]))
        return '%s_%s_%s' % (process, temp, vdd)

    def _valid_mask(self, batch):
        """Return the float validity mask for one sampled replay batch."""
        if self.mid_batch_reset and not self.agent.recurrent:
            valid = torch.ones_like(batch.done, dtype=torch.float)
        else:
            valid = valid_from_done(batch.done)
        if self.bootstrap_timelimit:
            # To avoid non-use of bootstrap when environment is 'done' due to
            # time-limit, turn off training on these samples.
            valid *= (1 - batch.timeout_n.float())
        return valid

    @staticmethod
    def _mean_loss(losses):
        """Average a non-empty list of scalar loss tensors into a float."""
        return sum(loss.item() for loss in losses) / len(losses)

    def samples_to_buffer(self, samples):
        """Pick the fields of sampler output that go into a replay buffer.

        Observation, reward and done come from ``samples.env`` and the action
        from ``samples.agent``; ``timeout`` and ``taskID`` are read from
        ``samples.env.env_info`` and default to ``None`` and ``0`` when
        absent.  Called in optimize_agent() if samples are provided to that
        method."""
        env_samples = samples.env
        fields = dict(
            observation=env_samples.observation,
            action=samples.agent.action,
            reward=env_samples.reward,
            done=env_samples.done,
        )
        env_info = env_samples.env_info
        fields["timeout"] = getattr(env_info, "timeout", None)
        fields["taskID"] = getattr(env_info, "taskID", 0)
        return SamplesToBuffer(**fields)

    def encode_taskID(self, samples_to_buffer):
        pass

    def initialize_samples_from_buffer(self, B, dim_obs, dim_action):
        """Build an all-zero ``SamplesFromReplayTL`` batch with leading
        dimension ``B`` and pass it through ``torchify_buffer``.

        Observations are ``[B, dim_obs]``, actions ``[B, dim_action]``,
        rewards/returns ``[B]`` and the done/timeout flags ``[B]`` booleans.
        """
        def zero_flags():
            return torch.zeros([B], dtype=torch.bool)

        def zero_inputs():
            return AgentInputs(
                observation=torch.zeros([B, dim_obs]),
                prev_action=torch.zeros([B, dim_action]),
                prev_reward=torch.zeros([B]),
            )

        fields = dict(
            agent_inputs=zero_inputs(),
            action=torch.zeros([B, dim_action]),
            return_=torch.zeros([B]),
            done=zero_flags(),
            done_n=zero_flags(),
            target_inputs=zero_inputs(),
            timeout=zero_flags(),
            timeout_n=zero_flags(),
        )
        return torchify_buffer(SamplesFromReplayTL(**fields))

    def consolidate(self, mu_fisher_matrix, q_fisher_matrix):
        # pdb.set_trace()
        for n, p in self.agent.model.named_parameters():
            n = n.replace('.', '_')
            self.agent.model.register_buffer('{}_mu_mean'.format(n), p.data.clone())
            self.agent.model.register_buffer('{}_mu_fisher'.format(n), mu_fisher_matrix[n].data.clone())

        for n, p in self.agent.q_model.named_parameters():
            n = n.replace('.', '_')
            self.agent.q_model.register_buffer('{}_q_mean'.format(n), p.data.clone())
            self.agent.q_model.register_buffer('{}_q_fisher'.format(n), q_fisher_matrix[n].data.clone())

    def consolidate_zero(self):
        # pdb.set_trace()
        for n, p in self.agent.model.named_parameters():
            n = n.replace('.', '_')
            self.agent.model.register_buffer('{}_mu_mean'.format(n), p.data.clone())
            self.agent.model.register_buffer('{}_mu_fisher'.format(n), torch.zeros(p.size()))

        for n, p in self.agent.q_model.named_parameters():
            n = n.replace('.', '_')
            self.agent.q_model.register_buffer('{}_q_mean'.format(n), p.data.clone())
            self.agent.q_model.register_buffer('{}_q_fisher'.format(n), torch.zeros(p.size()))

    def calculate_fisher(self):
        mu_fisher_matrix = {}
        q_fisher_matrix = {}
        ### calculate for policy fisher matrix and q model fisher matrix
        for n, p in self.agent.model.named_parameters():
            n = n.replace('.', '_')
            mu_fisher_matrix[n] = p.clone().detach().fill_(0)

        for n, p in self.agent.q_model.named_parameters():
            n = n.replace('.', '_')
            q_fisher_matrix[n] = p.clone().detach().fill_(0)

        for _ in range(self.fisher_batch_size):
            samples_from_replay = next(iter(self.replay_buffers.values())).sample_batch(1)
            if self.mid_batch_reset and not self.agent.recurrent:
                valid = torch.ones_like(samples_from_replay.done, dtype=torch.float)
            else:
                valid = valid_from_done(samples_from_replay.done)
            if self.bootstrap_timelimit:
                # To avoid non-use of bootstrap when environment is 'done' due to
                # time-limit, turn off training on these samples.
                valid *= (1 - samples_from_replay.timeout_n.float())

            self.agent.model.eval()
            self.mu_optimizer.zero_grad()
            mu_loss = self.mu_loss(samples_from_replay, valid)
            mu_loss.backward()

            for n, p in self.agent.model.named_parameters():
                n = n.replace('.', '_')
                mu_fisher_matrix[n].data += p.grad.data ** 2 / self.fisher_batch_size

            self.agent.q_model.eval()
            self.q_optimizer.zero_grad()
            q_loss = self.q_loss(samples_from_replay, valid)
            q_loss.backward()

            for n, p in self.agent.q_model.named_parameters():
                n = n.replace('.', '_')
                q_fisher_matrix[n].data += p.grad.data ** 2 / self.fisher_batch_size

        mu_fisher_matrix = {n : p for n, p in mu_fisher_matrix.items()}
        q_fisher_matrix = {n : p for n, p in q_fisher_matrix.items()}
        return mu_fisher_matrix, q_fisher_matrix

    def get_fisher(self):
        return self.calculate_fisher()

    def get_mu_parameters(self):
        return self.agent.model.named_parameters()

    def mu_penalty(self, cuda=False):
        try:
            losses = []
            # pdb.set_trace()
            for n, p in self.agent.model.named_parameters():
                n = n.replace('.', '_')
                mean = getattr(self.agent.model, '{}_mu_mean'.format(n))
                mu_fisher = getattr(self.agent.model, '{}_mu_fisher'.format(n))

                mean = Variable(mean)
                mu_fisher = Variable(mu_fisher)

                losses.append((mu_fisher * (p - mean)**2).sum())
                # print("load mu fisher!")
            return (self.fisher_lambda/2)*sum(losses).squeeze()
        except AttributeError:
            # ewc loss is 0 if there's no consolidated paramters
            return(
                Variable(torch.zeros(1).cuda()).squeeze() if cuda else
                Variable(torch.zeros(1)).squeeze()
            )

    def q_penalty(self, cuda=False):
        try:
            losses = []
            # pdb.set_trace()
            for n, p in self.agent.q_model.named_parameters():
                n = n.replace('.', '_')
                mean = getattr(self.agent.q_model, '{}_q_mean'.format(n))
                q_fisher = getattr(self.agent.q_model, '{}_q_fisher'.format(n))

                mean = Variable(mean)
                q_fisher = Variable(q_fisher)

                losses.append((q_fisher * (p - mean)**2).sum())
                # print("load q fisher!")
            return (self.fisher_lambda/2)*sum(losses).squeeze()
        except AttributeError:
            # ewc loss is 0 if there's no consolidated paramters
            return(
                Variable(torch.zeros(1).cuda()).squeeze() if cuda else
                Variable(torch.zeros(1)).squeeze()
            )
