from collections import OrderedDict

import numpy as np
import torch
import torch.optim as optim
from torch import nn as nn
from torch.distributions.normal import Normal

import rlkit.torch.pytorch_util as ptu
from rlkit.core.eval_util import create_stats_ordered_dict
from rlkit.torch.torch_rl_algorithm import TorchTrainer


def maybe_int(s):
    """
    Convert ``s`` to ``int`` when possible, otherwise return ``s`` unchanged.

    :param s: value to convert (typically the text after the ``v`` in a
        ``"<name>-v<version>"`` tag).
    :return: ``int(s)`` if the conversion succeeds, else ``s`` itself.
    """
    try:
        return int(s)
    # Only the failures int() raises for unconvertible input are treated as
    # "not an int"; KeyboardInterrupt / SystemExit and friends must propagate.
    except (TypeError, ValueError, OverflowError):
        return s


def min_q_target(batch, policy, target_qf, qf_data, reward_scale=1.0, discount=0.99, target_q_value_type=0):
    """
    Build a one-step TD target whose next-state value is the element-wise
    minimum of ``target_qf`` and ``qf_data``.

    Variants (``target_q_value_type``):
        0: both networks are evaluated at the action sampled from ``policy``.
        1: ``target_qf`` is evaluated at the policy action, ``qf_data`` at the
           dataset action ``batch['next_actions']``.

    :param batch: mapping with ``'rewards'``, ``'terminals'``, ``'next_obs'``
        and, for variant 1 only, ``'next_actions'``.
    :param policy: callable ``policy(obs, reparameterize=..., return_log_prob=...)``
        whose first return value is the sampled action.
    :param target_qf: callable ``target_qf(obs, actions)``.
    :param qf_data: callable ``qf_data(obs, actions)``.
    :param reward_scale: multiplier applied to the rewards.
    :param discount: discount factor applied to the next-state value.
    :param target_q_value_type: variant selector, see above.
    :return: ``reward_scale * r + (1 - terminal) * discount * min(...)``.
        The result is not detached from the autograd graph.
    :raises NotImplementedError: if ``target_q_value_type`` is not 0 or 1.
    """
    # Validate before touching any network so a bad configuration fails
    # immediately and the exception itself carries the explanation.
    possible_target_q_value_type = [0, 1]
    if target_q_value_type not in possible_target_q_value_type:
        raise NotImplementedError(
            f'target_q_value_type ({target_q_value_type}) should be one of {possible_target_q_value_type}'
        )

    rewards = batch['rewards']
    terminals = batch['terminals']
    next_obs = batch['next_obs']

    new_next_actions, *_ = policy(
        next_obs, reparameterize=True, return_log_prob=False,
    )
    target_value = target_qf(next_obs, new_next_actions)
    if target_q_value_type == 0:
        data_value = qf_data(next_obs, new_next_actions)
    else:
        # Dataset next-actions are only required by this variant, so they are
        # only looked up here.
        data_value = qf_data(next_obs, batch['next_actions'])
    target_q_values = torch.min(target_value, data_value)

    return reward_scale * rewards + (1. - terminals) * discount * target_q_values


def min_q_bound(batch, policy, target_qf, qf_data, reg_coef=1.0, sigma=1.0, reward_scale=1.0, discount=0.99, target_q_value_type=0):
    """
    Build a one-step TD target whose next-state value is capped by a bound
    derived from ``qf_data`` at the dataset action.

    For variant 0 the next-state value is::

        min(target_qf(s', a_pi),
            qf_data(s', a_data) + reg_coef * exp(mean_d log N(a_pi - a_data; 0, sigma)))

    where ``a_pi`` is sampled from ``policy``, ``a_data`` is
    ``batch['next_actions']`` and the mean runs over dim 1 of the per-element
    Gaussian log-densities.

    :param batch: mapping with ``'rewards'``, ``'terminals'``, ``'next_obs'``
        and ``'next_actions'``.
    :param policy: callable ``policy(obs, reparameterize=..., return_log_prob=...)``
        whose first return value is the sampled action.
    :param target_qf: callable ``target_qf(obs, actions)``.
    :param qf_data: callable ``qf_data(obs, actions)``.
    :param reg_coef: weight of the Gaussian proximity bonus.
    :param sigma: standard deviation of the zero-mean Gaussian.
    :param reward_scale: multiplier applied to the rewards.
    :param discount: discount factor applied to the next-state value.
    :param target_q_value_type: variant selector; only 0 is implemented.
    :return: ``reward_scale * r + (1 - terminal) * discount * value``.
        The result is not detached from the autograd graph.
    :raises NotImplementedError: if ``target_q_value_type`` is not 0.
    """
    # Validate before touching any network so a bad configuration fails
    # immediately and the exception itself carries the explanation.
    possible_target_q_value_type = [0]
    if target_q_value_type not in possible_target_q_value_type:
        raise NotImplementedError(
            f'target_q_value_type ({target_q_value_type}) should be one of {possible_target_q_value_type}'
        )

    rewards = batch['rewards']
    terminals = batch['terminals']
    next_obs = batch['next_obs']
    next_actions = batch['next_actions']

    new_next_actions, *_ = policy(
        next_obs, reparameterize=True, return_log_prob=False,
    )

    dist = Normal(loc=torch.zeros_like(next_actions), scale=sigma * torch.ones_like(next_actions))
    # Bonus is largest when the policy action coincides with the dataset
    # action and decays as the two move apart.
    proximity_bonus = torch.exp(
        torch.mean(dist.log_prob(new_next_actions - next_actions), dim=1, keepdim=True)
    )
    target_q_values = torch.min(
        target_qf(next_obs, new_next_actions),
        qf_data(next_obs, next_actions) + reg_coef * proximity_bonus,
    )
    # NOTE: a disabled earlier variant additionally lower-bounded the result
    # with torch.max(target_q_values, qf_data(next_obs, next_actions)).

    return reward_scale * rewards + (1. - terminals) * discount * target_q_values


class OneStepQRegTrainer(TorchTrainer):
    """
    Trainer that improves a policy against a regularized Q-function.

    Policy: trained to maximize ``qf`` at its own reparameterized actions,
    optionally penalized by ``alpha`` times a sampled estimate of
    ``E[log pi(a|s) - log pi_data(a|s)]`` (enabled by ``kl_reg``).

    Q-function: ``qf`` is trained by TD regression, with ``target_qf`` tracking
    ``qf`` through soft updates. The regularizer is chosen by
    ``q_reg_type_version``, formatted ``"<type>"`` or ``"<type>-v<version>"``
    with ``<type>`` in ``Possible_Q_Reg_Type``; every regularizer uses
    ``qf_data`` as a reference.

    Only ``policy`` and ``qf`` are given optimizers here; ``policy_data`` and
    ``qf_data`` are never stepped by this trainer (presumably they are fitted
    to the dataset beforehand -- confirm against the launcher).
    """
    Possible_Q_Reg_Type = ['min_q_target', 'min_q_bound', 'q_reg', 'conservative_q_reg']

    def __init__(
            self,
            env,
            policy,
            policy_data,
            qf,
            target_qf,
            qf_data,

            kl_reg=True,
            q_reg_type_version='min_q_target-v0',
            reg_coef=1.0,
            sigma=1.0,
            n_actions=10,
            alpha=1.0,
            discount=0.99,
            reward_scale=1.0,

            policy_lr=1e-4,
            qf_lr=1e-4,
            optimizer_class=optim.Adam,

            soft_target_tau=5e-3,
            target_update_period=1,
    ):
        """
        :param env: environment handle; only stored.
        :param policy: policy being trained.
        :param policy_data: reference policy exposing ``log_prob(obs, actions)``.
        :param qf: Q-function being trained.
        :param target_qf: slowly-updated copy of ``qf`` used in TD targets.
        :param qf_data: reference Q-function used by every regularizer.
        :param kl_reg: add the sampled KL penalty to the policy loss.
        :param q_reg_type_version: ``"<type>"`` or ``"<type>-v<version>"``.
        :param reg_coef: regularizer weight (``min_q_bound``, ``q_reg``,
            ``conservative_q_reg-v2``).
        :param sigma: Gaussian width for ``min_q_bound``.
        :param n_actions: actions sampled per observation for the KL estimate
            and for ``conservative_q_reg-v1``.
        :param alpha: weight of the KL penalty.
        :param discount: discount factor.
        :param reward_scale: multiplier applied to rewards.
        :param policy_lr: learning rate of the policy optimizer.
        :param qf_lr: learning rate of the Q-function optimizer.
        :param optimizer_class: optimizer constructor taking ``(params, lr=...)``.
        :param soft_target_tau: interpolation factor for the target update.
        :param target_update_period: train steps between target updates.
        """
        super().__init__()
        self.env = env
        self.policy = policy
        self.policy_data = policy_data
        self.qf = qf
        self.target_qf = target_qf
        self.qf_data = qf_data

        self.kl_reg = kl_reg
        reg_type_t = q_reg_type_version.split('-')
        self.q_reg_type = reg_type_t[0]
        # The version defaults to 0 unless the tag is exactly "<type>-v<something>".
        has_version = len(reg_type_t) == 2 and reg_type_t[1].startswith('v')
        self.q_reg_version = maybe_int(reg_type_t[1][1:]) if has_version else 0
        assert self.q_reg_type in self.Possible_Q_Reg_Type, f'self.q_reg_type ({self.q_reg_type}) should be in {self.Possible_Q_Reg_Type}'
        self.reg_coef = reg_coef
        self.sigma = sigma

        print('self.kl_reg: \t', self.kl_reg)
        print('Possible Q Reg Target: \t', self.Possible_Q_Reg_Type)
        print('self.q_reg_type: \t', self.q_reg_type)
        print('self.q_reg_version: \t', self.q_reg_version)
        if self.q_reg_type == 'min_q_bound':
            print('self.reg_coef: \t', self.reg_coef)
            print('self.sigma: \t', self.sigma)
        elif self.q_reg_type in ['q_reg', 'conservative_q_reg']:
            print('self.reg_coef: \t', self.reg_coef)

        self.n_actions = n_actions
        self.alpha = alpha

        self.soft_target_tau = soft_target_tau
        self.target_update_period = target_update_period

        self.qf_criterion = nn.MSELoss()

        # Kept so set_snapshot() can rebuild optimizers for restored networks.
        self._optimizer_class = optimizer_class
        self._policy_lr = policy_lr
        self._qf_lr = qf_lr

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(),
            lr=policy_lr,
        )

        self.qf_optimizer = optimizer_class(
            self.qf.parameters(),
            lr=qf_lr,
        )

        self.discount = discount
        self.reward_scale = reward_scale
        self.eval_statistics = OrderedDict()
        self._n_train_steps_total = 0
        self._need_to_update_eval_statistics = True

        self.discrete = False

    def _get_tensor_values(self, obs, actions, network=None):
        """
        Evaluate ``network`` on several actions per observation.

        ``actions`` holds ``num_repeat = actions.shape[0] // obs.shape[0]``
        rows per observation; each row of the 2-D ``obs`` is repeated that
        many times to line up with them.

        :return: predictions viewed as ``(obs.shape[0], num_repeat, 1)``.
        """
        batch_size = obs.shape[0]
        num_repeat = actions.shape[0] // batch_size
        obs_temp = obs.unsqueeze(1).repeat(1, num_repeat, 1).view(batch_size * num_repeat, obs.shape[1])
        preds = network(obs_temp, actions)
        return preds.view(batch_size, num_repeat, 1)

    def _get_policy_actions(self, obs, num_actions, network=None):
        """
        Sample ``num_actions`` actions per observation from ``network``
        (called with ``reparameterize=False``).

        :return: ``(actions, log_pi)`` with ``actions`` stacked along dim 0 and
            ``log_pi`` viewed as ``(obs.shape[0], num_actions, 1)``; only the
            actions when ``self.discrete`` is set.
        """
        obs_temp = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1])
        new_obs_actions, _, _, new_obs_log_pi, *_ = network(
            obs_temp, reparameterize=False, return_log_prob=True,
        )
        if not self.discrete:
            return new_obs_actions, new_obs_log_pi.view(obs.shape[0], num_actions, 1)
        else:
            return new_obs_actions

    def _bellman_target(self, batch):
        """Return the detached TD target ``r + (1 - done) * discount * target_qf(s', pi(s'))``."""
        next_obs = batch['next_obs']
        new_next_actions, *_ = self.policy(next_obs, reparameterize=True, return_log_prob=False,)
        target_q_values = self.target_qf(next_obs, new_next_actions)
        q_target = self.reward_scale * batch['rewards'] + (1. - batch['terminals']) * self.discount * target_q_values
        # The target is a regression constant: nothing it depends on is
        # stepped by qf_optimizer, so there is no reason to backprop into it.
        return q_target.detach()

    def _action_reg_loss(self, obs, actions):
        """
        Regress ``qf`` at policy actions onto ``qf_data`` minus a quadratic
        penalty on the distance between the policy and dataset actions.
        """
        new_actions, *_ = self.policy(obs)
        # Sampled actions are fixed inputs for this Q-regression.
        new_actions = new_actions.detach()
        q_new_pred = self.qf(obs, new_actions)
        q_reg = self.qf_data(obs, new_actions) - 0.5 * self.reg_coef * torch.sum((new_actions - actions) ** 2, dim=1, keepdim=True)
        return self.qf_criterion(q_new_pred, q_reg.detach())

    def _out_of_data_loss(self, obs, next_obs, actions):
        """
        Mean squared gap between ``qf`` and ``qf_data`` on actions that are
        not the dataset action: uniform random actions in [-1, 1], policy
        actions sampled at ``next_obs`` and policy actions sampled at ``obs``.
        All of them are evaluated at ``obs``.
        """
        curr_actions_tensor, *_ = self._get_policy_actions(obs, num_actions=self.n_actions, network=self.policy)
        new_curr_actions_tensor, *_ = self._get_policy_actions(next_obs, num_actions=self.n_actions, network=self.policy)
        # new_empty inherits device and dtype from the batch, so no manual
        # .cuda() transfer is needed and non-default devices work too.
        random_actions_tensor = actions.new_empty(
            (obs.shape[0] * self.n_actions, actions.shape[-1])
        ).uniform_(-1, 1)

        q_pred_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf)
        q_pred_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf)
        q_pred_rand = self._get_tensor_values(obs, random_actions_tensor, network=self.qf)

        q_data_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf_data)
        q_data_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf_data)
        q_data_rand = self._get_tensor_values(obs, random_actions_tensor, network=self.qf_data)

        cat_q_pred = torch.cat(
            [q_pred_rand, q_pred_next_actions, q_pred_curr_actions], 1
        )
        cat_q_data = torch.cat(
            [q_data_rand, q_data_next_actions, q_data_curr_actions], 1
        )
        return torch.mean((cat_q_data.detach() - cat_q_pred) ** 2)

    def _compute_qf_loss(self, batch, q_pred):
        """
        Return the Q-function loss for the configured regularizer.

        :param batch: training batch (see ``train_from_torch``).
        :param q_pred: ``qf(obs, actions)`` on the dataset actions.
        :raises NotImplementedError: for an unsupported type or version.
        """
        obs = batch['obs']
        actions = batch['actions']

        if self.q_reg_type == 'min_q_target':
            q_target = min_q_target(batch, self.policy, self.target_qf, self.qf_data,
                                    reward_scale=self.reward_scale, discount=self.discount,
                                    target_q_value_type=self.q_reg_version)
            return self.qf_criterion(q_pred, q_target.detach())

        if self.q_reg_type == 'min_q_bound':
            q_target = min_q_bound(batch, self.policy, self.target_qf, self.qf_data,
                                   reward_scale=self.reward_scale, discount=self.discount,
                                   target_q_value_type=self.q_reg_version, reg_coef=self.reg_coef, sigma=self.sigma)
            return self.qf_criterion(q_pred, q_target.detach())

        if self.q_reg_type == 'q_reg':
            td_loss = self.qf_criterion(q_pred, self._bellman_target(batch))
            return td_loss + self._action_reg_loss(obs, actions)

        if self.q_reg_type == 'conservative_q_reg':
            possible_versions = [0, 1, 2]
            if self.q_reg_version not in possible_versions:
                # Fail with an explanation instead of crashing later on
                # ``None.backward()``.
                raise NotImplementedError(
                    f'q_reg_version ({self.q_reg_version}) for conservative_q_reg '
                    f'should be one of {possible_versions}'
                )
            td_loss = self.qf_criterion(q_pred, self._bellman_target(batch))
            q_data = self.qf_data(obs, actions).detach()
            # One-sided penalty: only punishes qf for falling below qf_data on
            # dataset actions.
            in_data_loss = torch.mean(torch.clamp(q_data - q_pred, min=0.) ** 2)
            if self.q_reg_version == 0:
                return td_loss + in_data_loss
            if self.q_reg_version == 1:
                return td_loss + in_data_loss + self._out_of_data_loss(obs, batch['next_obs'], actions)
            return td_loss + in_data_loss + self._action_reg_loss(obs, actions)

        raise NotImplementedError(
            f'self.q_reg_type ({self.q_reg_type}) should be in {self.Possible_Q_Reg_Type}'
        )

    def train_from_torch(self, batch):
        """
        Run one policy update followed by one Q-function update.

        :param batch: mapping with ``'obs'``, ``'actions'``, ``'rewards'``,
            ``'next_obs'``, ``'terminals'`` (plus ``'next_actions'`` for the
            regularizers that use dataset next-actions).
        """
        obs = batch['obs']
        actions = batch['actions']

        # ---- Policy loss ----
        new_obs_actions, *_ = self.policy(
            obs, reparameterize=True, return_log_prob=True,
        )

        if self.kl_reg:
            # Monte-Carlo KL estimate from n_actions samples per observation.
            obs_stack = torch.unsqueeze(obs, 1).repeat(1, self.n_actions, 1).reshape((-1, obs.shape[1]))
            new_obs_actions_stack, _, _, log_pi_stack, *_ = self.policy(obs_stack, reparameterize=True, return_log_prob=True,)

            log_pi = torch.mean(log_pi_stack.reshape((-1, self.n_actions)), dim=1)

            log_pi_data_stack = self.policy_data.log_prob(obs_stack, new_obs_actions_stack)
            log_pi_data = torch.mean(log_pi_data_stack.reshape((-1, self.n_actions)), dim=1)

            kl = (log_pi - log_pi_data).mean()
            policy_loss = self.alpha * kl - self.qf(obs, new_obs_actions).mean()
        else:
            policy_loss = - 1 * self.qf(obs, new_obs_actions).mean()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        # ---- Q-function loss ----
        q_pred = self.qf(obs, actions)
        qf_loss = self._compute_qf_loss(batch, q_pred)

        # ---- Update Q-function ----
        # zero_grad() also clears what the policy loss backpropagated into qf.
        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

        # ---- Soft target update ----
        if self._n_train_steps_total % self.target_update_period == 0:
            ptu.soft_update_from_to(
                self.qf, self.target_qf, self.soft_target_tau
            )

        # ---- Statistics ----
        # Recorded only for the first batch after each end_epoch().
        if self._need_to_update_eval_statistics:
            self._need_to_update_eval_statistics = False

            self.eval_statistics['Policy Loss'] = np.mean(ptu.get_numpy(policy_loss))
            self.eval_statistics['QF Loss'] = np.mean(ptu.get_numpy(qf_loss))
            if self.kl_reg:
                self.eval_statistics['KL'] = np.mean(ptu.get_numpy(kl))

        self._n_train_steps_total += 1

    def get_diagnostics(self):
        """Return the statistics recorded by ``train_from_torch``."""
        return self.eval_statistics

    def end_epoch(self, epoch):
        """Re-arm statistics collection for the next epoch's first batch."""
        self._need_to_update_eval_statistics = True

    @property
    def networks(self):
        """All networks held by the trainer."""
        return [
            self.policy,
            self.policy_data,
            self.qf,
            self.target_qf,
            self.qf_data,
        ]

    def get_snapshot(self):
        """Return the networks keyed by name."""
        return dict(
            policy=self.policy,
            policy_data=self.policy_data,
            qf=self.qf,
            target_qf=self.target_qf,
            qf_data=self.qf_data,
        )

    def set_snapshot(self, snapshot):
        """
        Replace the networks with those in ``snapshot``.

        When ``policy`` or ``qf`` is replaced by a different object, its
        optimizer is rebuilt (with fresh optimizer state); otherwise the
        optimizer would keep stepping the parameters of the discarded network.
        """
        policy_replaced = snapshot['policy'] is not self.policy
        qf_replaced = snapshot['qf'] is not self.qf

        self.policy = snapshot['policy']
        self.policy_data = snapshot['policy_data']
        self.qf = snapshot['qf']
        self.target_qf = snapshot['target_qf']
        self.qf_data = snapshot['qf_data']

        if policy_replaced:
            self.policy_optimizer = self._optimizer_class(
                self.policy.parameters(),
                lr=self._policy_lr,
            )
        if qf_replaced:
            self.qf_optimizer = self._optimizer_class(
                self.qf.parameters(),
                lr=self._qf_lr,
            )
