
import torch
from collections import namedtuple

from open_source.rlpyt.rlpyt.algos.qpg.ddpg import DDPG
from open_source.rlpyt.rlpyt.utils.quick_args import save__init__args
from open_source.rlpyt.rlpyt.utils.logging import logger
from open_source.rlpyt.rlpyt.replays.non_sequence.uniform import (UniformReplayBuffer,
    AsyncUniformReplayBuffer)
from open_source.rlpyt.rlpyt.replays.non_sequence.time_limit import (TlUniformReplayBuffer,
    AsyncTlUniformReplayBuffer)
from open_source.rlpyt.rlpyt.utils.collections import namedarraytuple
from open_source.rlpyt.rlpyt.utils.tensor import valid_mean
from open_source.rlpyt.rlpyt.algos.utils import valid_from_done

from torch.autograd import Variable
import pdb

# Diagnostics returned by DDPG_EWC.optimize_agent: each field is a list that
# optimize_agent appends one value to per Q update (qLoss, qGradNorm) or per
# policy update (muLoss, muGradNorm).
OptInfo = namedtuple(
    "OptInfo",
    ("muLoss", "qLoss", "muGradNorm", "qGradNorm"),
)
# Record layout for replay-buffer samples (built with rlpyt's namedarraytuple).
# NOTE(review): not referenced by DDPG_EWC in this file; the inherited
# samples_to_buffer presumably uses the parent module's definition — confirm.
SamplesToBuffer = namedarraytuple(
    "SamplesToBuffer",
    ["observation", "action", "reward", "done", "timeout"],
)


class DDPG_EWC(DDPG):
    """DDPG with an Elastic Weight Consolidation (EWC) penalty.

    Trains like :class:`DDPG` from a replay buffer. When ``use_ewc`` is
    True, it also adds the quadratic penalty

        fisher_lambda / 2 * sum(F * (p - p_anchor) ** 2)

    to the Q loss (over ``agent.q_model`` parameters) and to the policy loss
    (over ``agent.model`` parameters).

    The anchors ``p_anchor`` and per-element weights ``F`` are registered as
    buffers on the networks themselves. :meth:`consolidate` takes the weights
    from e.g. :meth:`calculate_fisher`; :meth:`consolidate_zero` uses all-zero
    weights. Buffer names are ``<param>_mu_mean`` / ``<param>_mu_fisher`` on
    the policy and ``<param>_q_mean`` / ``<param>_q_fisher`` on the Q network,
    where ``<param>`` is the parameter name with ``.`` replaced by ``_``.
    Until a network has these buffers, its penalty is zero.
    """

    opt_info_fields = tuple(OptInfo._fields)  # copy

    def __init__(
            self,
            discount=0.99,
            batch_size=64,
            min_steps_learn=int(1e4),
            replay_size=int(1e6),
            replay_ratio=64,  # data_consumption / data_generation
            target_update_tau=0.01,
            target_update_interval=1,
            policy_update_interval=1,
            learning_rate=1e-4,
            q_learning_rate=1e-3,
            OptimCls=torch.optim.Adam,
            optim_kwargs=None,
            initial_optim_state_dict=None,
            clip_grad_norm=1e8,
            q_target_clip=1e6,
            n_step_return=1,
            updates_per_sync=1,  # For async mode only.
            bootstrap_timelimit=True,
            ReplayBufferCls=None,
            replay_buffer_loaded=None,
            fisher_lambda=400,
            fisher_batch_size=100,
            use_ewc=True,
            ):
        """Saves input arguments.

        Arguments shared with DDPG keep their DDPG meaning. EWC-specific ones:

        Args:
            replay_buffer_loaded: Stored as an attribute; not read by this
                class (presumably used by the caller — confirm).
            fisher_lambda: Weight of the EWC penalty (multiplied by 1/2).
            fisher_batch_size: Number of single-transition gradient
                evaluations averaged by :meth:`calculate_fisher`.
            use_ewc: If True, :meth:`optimize_agent` adds the EWC penalties
                to the Q and policy losses.
        """
        if optim_kwargs is None:
            optim_kwargs = dict()

        self.replay_buffer_loaded = replay_buffer_loaded
        self.fisher_lambda = fisher_lambda
        self.fisher_batch_size = fisher_batch_size
        self._batch_size = batch_size
        del batch_size  # Property.
        save__init__args(locals())

    # ------------------------------------------------------------------ #
    # Helpers
    # ------------------------------------------------------------------ #

    def _valid_mask(self, samples):
        """Return the float mask of transitions that contribute to the losses."""
        if self.mid_batch_reset and not self.agent.recurrent:
            valid = torch.ones_like(samples.done, dtype=torch.float)
        else:
            valid = valid_from_done(samples.done)
        if self.bootstrap_timelimit:
            # To avoid non-use of bootstrap when environment is 'done' due to
            # time-limit, turn off training on these samples.
            valid *= (1 - samples.timeout_n.float())
        return valid

    def _accumulate_squared_grads(self, model, fisher_matrix):
        """Add ``grad ** 2 / fisher_batch_size`` of each parameter into ``fisher_matrix``."""
        with torch.no_grad():
            for n, p in model.named_parameters():
                if p.grad is None:
                    # Parameter not reached by this loss: contributes no information.
                    continue
                fisher_matrix[n.replace('.', '_')] += p.grad ** 2 / self.fisher_batch_size

    @staticmethod
    def _zero_penalty(cuda=False):
        """Return a 0-dim zero tensor, on the default CUDA device if ``cuda``."""
        zero = torch.zeros(1)
        return (zero.cuda() if cuda else zero).squeeze()

    def _ewc_penalty(self, model, tag, cuda=False):
        """Return the EWC penalty of ``model`` against its ``<param>_<tag>_*`` buffers.

        If any parameter lacks its ``_mean`` or ``_fisher`` buffer (nothing
        consolidated yet), or ``model`` has no parameters, returns a zero
        penalty from :meth:`_zero_penalty`.
        """
        losses = []
        for n, p in model.named_parameters():
            n = n.replace('.', '_')
            mean = getattr(model, '{}_{}_mean'.format(n, tag), None)
            fisher = getattr(model, '{}_{}_fisher'.format(n, tag), None)
            if mean is None or fisher is None:
                # ewc loss is 0 if there's no consolidated parameters
                return self._zero_penalty(cuda)
            losses.append((fisher * (p - mean) ** 2).sum())
        if not losses:
            return self._zero_penalty(cuda)
        return (self.fisher_lambda / 2) * sum(losses).squeeze()

    @staticmethod
    def _register_anchor(model, tag, fisher_matrix=None):
        """Register current params as ``<param>_<tag>_mean`` and the weights as ``<param>_<tag>_fisher``.

        ``fisher_matrix`` maps ``.``-to-``_`` parameter names to tensors
        shaped like the parameter. If it is None, all-zero weights are used.
        Weights are placed on the parameter's device so the penalty can be
        evaluated without device mismatches.
        """
        for n, p in model.named_parameters():
            n = n.replace('.', '_')
            model.register_buffer('{}_{}_mean'.format(n, tag), p.detach().clone())
            if fisher_matrix is None:
                fisher = torch.zeros_like(p.detach())
            else:
                fisher = fisher_matrix[n].detach().to(p.device).clone()
            model.register_buffer('{}_{}_fisher'.format(n, tag), fisher)

    # ------------------------------------------------------------------ #
    # Fisher estimation
    # ------------------------------------------------------------------ #

    def calculate_fisher(self):
        """Estimate the diagonal empirical Fisher of the policy and Q networks.

        For each of ``fisher_batch_size`` single-transition batches drawn
        from the replay buffer, back-propagates the policy loss and then the
        Q loss. It accumulates each parameter's squared gradient, averaged
        over the draws.

        Both networks are put in eval mode during the computation. Their
        previous train/eval modes are restored afterwards. The gradients of
        the last draw are left in the parameters' ``.grad``.

        Returns:
            ``(mu_fisher_matrix, q_fisher_matrix)``: dicts mapping each
            parameter name of ``agent.model`` / ``agent.q_model``
            (``.`` replaced by ``_``) to a tensor shaped like that parameter.
        """
        mu_model, q_model = self.agent.model, self.agent.q_model
        mu_fisher_matrix = {n.replace('.', '_'): torch.zeros_like(p.detach())
                            for n, p in mu_model.named_parameters()}
        q_fisher_matrix = {n.replace('.', '_'): torch.zeros_like(p.detach())
                           for n, p in q_model.named_parameters()}

        mu_was_training, q_was_training = mu_model.training, q_model.training
        mu_model.eval()
        q_model.eval()
        try:
            for _ in range(self.fisher_batch_size):
                samples_from_replay = self.replay_buffer.sample_batch(1)
                valid = self._valid_mask(samples_from_replay)

                self.mu_optimizer.zero_grad()
                mu_loss = self.mu_loss(samples_from_replay, valid)
                mu_loss.backward()
                self._accumulate_squared_grads(mu_model, mu_fisher_matrix)

                # Zeroes the Q grads the policy backward also populated.
                self.q_optimizer.zero_grad()
                q_loss = self.q_loss(samples_from_replay, valid)
                q_loss.backward()
                self._accumulate_squared_grads(q_model, q_fisher_matrix)
        finally:
            mu_model.train(mu_was_training)
            q_model.train(q_was_training)
        return mu_fisher_matrix, q_fisher_matrix

    def get_fisher(self):
        """Alias of :meth:`calculate_fisher`."""
        return self.calculate_fisher()

    def get_mu_parameters(self):
        """Return ``agent.model.named_parameters()`` (a generator of (name, param))."""
        return self.agent.model.named_parameters()

    # ------------------------------------------------------------------ #
    # EWC penalties
    # ------------------------------------------------------------------ #

    def mu_penalty(self, cuda=False):
        """EWC penalty for the policy network (``<param>_mu_*`` buffers).

        Args:
            cuda: Place the zero penalty on CUDA when nothing is consolidated.
        """
        return self._ewc_penalty(self.agent.model, 'mu', cuda)

    def q_penalty(self, cuda=False):
        """EWC penalty for the Q network (``<param>_q_*`` buffers).

        Args:
            cuda: Place the zero penalty on CUDA when nothing is consolidated.
        """
        return self._ewc_penalty(self.agent.q_model, 'q', cuda)

    # ------------------------------------------------------------------ #
    # Training
    # ------------------------------------------------------------------ #

    def optimize_agent(self, itr, samples=None, sampler_itr=None):
        """
        Extracts the needed fields from input samples and stores them in the
        replay buffer.  Then samples from the replay buffer to train the agent
        by gradient updates (with the number of updates determined by replay
        ratio, sampler batch size, and training batch size).

        When ``use_ewc`` is True, the EWC penalties are added to the Q and
        policy losses. The recorded ``qLoss`` / ``muLoss`` include them.
        """
        itr = itr if sampler_itr is None else sampler_itr  # Async uses sampler_itr.
        if samples is not None:
            samples_to_buffer = self.samples_to_buffer(samples)
            self.replay_buffer.append_samples(samples_to_buffer)
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
        if itr < self.min_itr_learn:
            return opt_info
        for _ in range(self.updates_per_optimize):
            samples_from_replay = self.replay_buffer.sample_batch(self.batch_size)
            valid = self._valid_mask(samples_from_replay)

            self.q_optimizer.zero_grad()
            q_loss = self.q_loss(samples_from_replay, valid)
            if self.use_ewc:
                q_loss = q_loss + self.q_penalty()
            q_loss.backward()
            q_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.q_parameters(), self.clip_grad_norm)
            self.q_optimizer.step()
            opt_info.qLoss.append(q_loss.item())
            # float() accepts both the tensor (new torch) and float (old torch) returns.
            opt_info.qGradNorm.append(float(q_grad_norm))

            self.update_counter += 1
            if self.update_counter % self.policy_update_interval == 0:
                self.mu_optimizer.zero_grad()
                mu_loss = self.mu_loss(samples_from_replay, valid)
                if self.use_ewc:
                    mu_loss = mu_loss + self.mu_penalty()
                mu_loss.backward()
                mu_grad_norm = torch.nn.utils.clip_grad_norm_(
                    self.agent.mu_parameters(), self.clip_grad_norm)
                self.mu_optimizer.step()
                opt_info.muLoss.append(mu_loss.item())
                opt_info.muGradNorm.append(float(mu_grad_norm))
            if self.update_counter % self.target_update_interval == 0:
                self.agent.update_target(self.target_update_tau)
        return opt_info

    # ------------------------------------------------------------------ #
    # Consolidation
    # ------------------------------------------------------------------ #

    def consolidate(self, mu_fisher_matrix, q_fisher_matrix):
        """Anchor both networks at their current parameters with the given Fisher weights.

        Args:
            mu_fisher_matrix: Dict as returned by :meth:`calculate_fisher` for the policy.
            q_fisher_matrix: Dict as returned by :meth:`calculate_fisher` for the Q network.
        """
        self._register_anchor(self.agent.model, 'mu', mu_fisher_matrix)
        self._register_anchor(self.agent.q_model, 'q', q_fisher_matrix)

    def consolidate_zero(self):
        """Anchor both networks at their current parameters with all-zero weights.

        The penalties then evaluate to zero, but the buffers exist.
        """
        self._register_anchor(self.agent.model, 'mu')
        self._register_anchor(self.agent.q_model, 'q')


