from collections import OrderedDict

import numpy as np
import torch
import torch.optim as optim
from torch import nn as nn

import rlkit.torch.pytorch_util as ptu
from rlkit.core.eval_util import create_stats_ordered_dict
from rlkit.torch.torch_rl_algorithm import TorchTrainer


class BCTrainer(TorchTrainer):
    """
    Trainer that regresses ``qf_beta_reg`` onto ``batch['gamma']``.

    Each call to ``train_from_torch`` takes one optimizer step on

        MSE(qf_beta_reg(obs, actions), batch['gamma'])
            + reg_const * sum_p ||p||_2^2

    where the sum runs over every parameter ``p`` of ``qf_beta_reg``.

    NOTE(review): despite the class name, ``policy`` is not updated by this
    trainer -- ``policy_optimizer`` is built but never stepped. ``discount``,
    ``reward_scale``, ``soft_target_tau`` and ``target_update_period`` are
    stored but not read anywhere in this class; presumably they are kept for
    interface compatibility with other trainers -- confirm before removing.
    """
    def __init__(
            self,
            env,
            policy,
            qf_beta_reg,

            discount=0.99,
            reward_scale=1.0,
            reg_const=0.01,

            policy_lr=1e-4,
            qf_lr=1e-4,
            optimizer_class=optim.Adam,

            soft_target_tau=5e-3,
            target_update_period=2,
    ):
        """
        Args:
            env: Stored on ``self.env``; not otherwise used here.
            policy: Policy network; included in ``networks`` and snapshots.
            qf_beta_reg: Network called as ``qf_beta_reg(obs, actions)``; the
                only module this trainer optimizes.
            discount: Stored; unused by this class.
            reward_scale: Stored; unused by this class.
            reg_const: Weight of the squared-L2 parameter penalty.
            policy_lr: Learning rate for ``policy_optimizer``.
            qf_lr: Learning rate for ``qf_optimizer``.
            optimizer_class: Optimizer constructor called as
                ``optimizer_class(params, lr=...)``.
            soft_target_tau: Stored; unused by this class.
            target_update_period: Stored; unused by this class.
        """
        super().__init__()
        self.env = env
        self.policy = policy
        self.qf_beta_reg = qf_beta_reg
        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period

        self.qf_criterion = nn.MSELoss()

        # Remember how the optimizers were built so set_snapshot() can rebind
        # them to freshly loaded networks.
        self._optimizer_class = optimizer_class
        self._policy_lr = policy_lr
        self._qf_lr = qf_lr
        self._build_optimizers()

        self.discount = discount
        self.reg_const = reg_const
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True

        self.discrete = False

    def _build_optimizers(self):
        """(Re)create both optimizers over the current networks' parameters."""
        self.policy_optimizer = self._optimizer_class(
            self.policy.parameters(),
            lr=self._policy_lr,
        )
        self.qf_optimizer = self._optimizer_class(
            self.qf_beta_reg.parameters(),
            lr=self._qf_lr,
        )

    def train_from_torch(self, batch):
        """
        Take one gradient step on ``qf_beta_reg``.

        Args:
            batch: Mapping with tensors under 'obs', 'actions' and 'gamma'
                (the regression target).
        """
        obs = batch['obs']
        actions = batch['actions']
        gamma_return = batch['gamma']

        # --- QF beta loss ---
        qf_beta_reg_pred = self.qf_beta_reg(obs, actions)
        # NOTE(review): nn.MSELoss broadcasts mismatched shapes (e.g. a (B, 1)
        # prediction against a (B,) target yields a (B, B) difference, with
        # only a UserWarning). Confirm the target shape matches the output.
        net_error = self.qf_criterion(qf_beta_reg_pred, gamma_return)

        # Squared L2 norm over all parameters; param.pow(2).sum() equals
        # torch.norm(param, p=2) ** 2 without the sqrt-then-square round trip.
        l2_reg = sum(
            param.pow(2).sum() for param in self.qf_beta_reg.parameters()
        )
        beta_loss = net_error + self.reg_const * l2_reg

        # --- Update networks ---
        self.qf_optimizer.zero_grad()
        beta_loss.backward()
        self.qf_optimizer.step()

        # --- Save some statistics for eval ---
        # Statistics are recorded only for the first batch after each
        # end_epoch() call, which resets this flag to True.
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False

            self.eval_statistics.update(create_stats_ordered_dict(
                'gamma_return',
                ptu.get_numpy(gamma_return),
            ))
            self.eval_statistics['QF gamma Loss'] = np.mean(ptu.get_numpy(net_error))
            self.eval_statistics['Beta Loss'] = np.mean(ptu.get_numpy(beta_loss))
            self.eval_statistics.update(create_stats_ordered_dict(
                'Q Predictions',
                ptu.get_numpy(qf_beta_reg_pred),
            ))

        self._n_train_steps_total += 1

    def get_diagnostics(self):
        """Return the statistics recorded by the most recent logged batch."""
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Arrange for the next training batch to refresh eval statistics."""
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        """All modules owned by this trainer."""
        return [
            self.policy,
            self.qf_beta_reg,
        ]

    def get_snapshot(self):
        """Return the networks keyed by attribute name."""
        return dict(
            policy=self.policy,
            qf_beta_reg=self.qf_beta_reg,
        )

    def set_snapshot(self, snapshot):
        """
        Restore networks from ``snapshot`` (as produced by get_snapshot).

        The optimizers are rebuilt so they update the restored networks'
        parameters; otherwise they would keep stepping the replaced modules.
        Optimizer state (e.g. Adam moments) is not part of the snapshot, so it
        starts fresh.
        """
        self.policy = snapshot['policy']
        self.qf_beta_reg = snapshot['qf_beta_reg']
        self._build_optimizers()