# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin
from collections import defaultdict

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributions as D

from rsl_rl.modules import ActorCritic
from rsl_rl.storage import RolloutStorage
from rsl_rl.algorithms import PPO


def weight_init(m):
    """Initialise a layer in place: orthogonal weights, zeroed bias.

    Intended for ``module.apply(weight_init)``. ``nn.Linear`` weights get
    unit gain; ``nn.Conv2d`` and ``nn.ConvTranspose2d`` weights get the ReLU
    gain. A bias is zeroed only when one is present (``bias=None`` layers are
    skipped). Every other module type is left untouched.
    """
    if isinstance(m, nn.Linear):
        gain = 1  # same as orthogonal_'s default gain
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        gain = nn.init.calculate_gain('relu')
    else:
        return

    nn.init.orthogonal_(m.weight.data, gain)
    layer_bias = m.bias
    if hasattr(layer_bias, 'data'):
        layer_bias.data.fill_(0.0)


# LBS latent model (LBS_net):
#   prior:      (s_t, a_t)          -> z_{t+1}
#   posterior:  (s_t, a_t, s_{t+1}) -> z_{t+1}
#   decoder:    z_{t+1}             -> s_{t+1}
class LBS_net(nn.Module):
    """Latent dynamics model with a prior, a posterior and a decoder.

    * prior     ``pri_forward_net``:    cat(obs, action)           -> z mean
    * posterior ``pos_forward_net``:    cat(obs, action, next_obs) -> z mean
    * decoder   ``reconstruction_net``: posterior z mean           -> next_obs mean

    All three outputs are treated as means of unit-variance diagonal Gaussians.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim):
        super().__init__()

        prior_in_dim = obs_dim + action_dim
        posterior_in_dim = 2 * obs_dim + action_dim

        # Construction order is significant: it fixes both state_dict keys
        # and the sequence of random draws used for initialisation.
        self.pri_forward_net = nn.Sequential(
            nn.Linear(prior_in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.pos_forward_net = nn.Sequential(
            nn.Linear(posterior_in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.reconstruction_net = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, obs_dim),
        )

        self.apply(weight_init)

    @staticmethod
    def _unit_gaussian(mean):
        """Diagonal unit-variance Gaussian whose last dim is the event dim."""
        return D.Independent(D.Normal(mean, 1.0), 1)

    def forward(self, obs, action, next_obs):
        """Score a batch of transitions under the latent model.

        Returns:
            kl_error: scalar mean of KL(posterior || prior).
            reco_error: scalar mean negative log-likelihood of ``next_obs``
                under the decoder distribution.
            kl_div: detached per-sample KL, of shape ``obs.shape[:-1]``.
        """
        leading = obs.shape[0]
        assert leading == next_obs.shape[0]
        assert leading == action.shape[0]

        prior_mean = self.pri_forward_net(torch.cat([obs, action], dim=-1))
        posterior_mean = self.pos_forward_net(
            torch.cat([obs, action, next_obs], dim=-1))
        decoded_mean = self.reconstruction_net(posterior_mean)

        prior = self._unit_gaussian(prior_mean)
        posterior = self._unit_gaussian(posterior_mean)
        decoded = self._unit_gaussian(decoded_mean)

        per_sample_kl = D.kl_divergence(posterior, prior)
        nll = -decoded.log_prob(next_obs).mean()

        return per_sample_kl.mean(), nll, per_sample_kl.detach()


class LBS_PRED(nn.Module):
    """Predict the LBS latent KL from ``(obs, action)`` alone.

    A two-layer MLP outputs a single value per sample. That value is used as
    the mean of a unit-variance Gaussian with a size-1 event dimension.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim):
        super().__init__()

        self.pred_net = nn.Sequential(
            nn.Linear(obs_dim + action_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

        self.apply(weight_init)

    def forward(self, obs, action, next_obs):
        """Return the predicted-KL distribution for each ``(obs, action)``.

        ``next_obs`` is only used to check that the leading dimension matches;
        it does not feed the prediction.
        """
        rows = obs.shape[0]
        assert rows == next_obs.shape[0]
        assert rows == action.shape[0]

        predicted_mean = self.pred_net(torch.cat((obs, action), dim=-1))
        return D.Independent(D.Normal(predicted_mean, 1.0), 1)


class LBS(PPO):
    """PPO that additionally trains an LBS latent model and a KL predictor.

    On every minibatch, two auxiliary models get one optimizer step each:

    * ``self.lbs`` (``LBS_net``): minimises KL(posterior || prior) plus the
      reconstruction NLL of ``next_obs``.
    * ``self.lbs_pred`` (``LBS_PRED``): regresses the detached per-sample KL
      from ``(obs, action)``.

    When ``reward_free`` is true, the environment reward is replaced by the
    mean of ``self.lbs_pred``'s prediction at each environment step.
    """

    def __init__(self,
                 actor_critic,
                 num_learning_epochs=1,
                 num_mini_batches=1,
                 clip_param=0.2,
                 gamma=0.998,
                 lam=0.95,
                 value_loss_coef=1.0,
                 entropy_coef=0.0,
                 learning_rate=1e-3,
                 max_grad_norm=1.0,
                 use_clipped_value_loss=True,
                 clip_min_std=1e-15,  # clip the policy.std if it supports, check update()
                 optimizer_class_name="Adam",
                 schedule="fixed",
                 desired_kl=0.01,
                 device='cpu',
                 reward_free=True,
                 env=None,
                 ):
        """Build the PPO base and the two LBS networks with their optimizers.

        Args:
            reward_free: if true, rewards come from the KL predictor instead
                of the environment.
            env: used only to size the LBS networks. It must expose
                ``num_obs``, ``num_privileged_obs`` and ``num_actions``.
                It is required despite the ``None`` default.

        Raises:
            ValueError: if ``env`` is None.
        """
        if env is None:
            raise ValueError("LBS requires `env` to size its latent models; got env=None.")

        super().__init__(actor_critic, num_learning_epochs, num_mini_batches,
                         clip_param, gamma, lam, value_loss_coef, entropy_coef,
                         learning_rate, max_grad_norm, use_clipped_value_loss,
                         clip_min_std, optimizer_class_name, schedule,
                         desired_kl, device)
        self.reward_free = reward_free

        # NOTE(review): with privileged observations the LBS nets are sized by
        # `num_privileged_obs`, but they are fed `transition.observations` /
        # `minibatch.obs`. Confirm those carry the privileged obs, otherwise
        # the input dimension will not match.
        if env.num_privileged_obs is not None:
            self.num_obs = env.num_privileged_obs
        else:
            self.num_obs = env.num_obs

        self.num_actions = env.num_actions

        self.lbs = LBS_net(self.num_obs, self.num_actions,
                           hidden_dim=128).to(self.device)
        self.lbs_opt = torch.optim.Adam(self.lbs.parameters(), lr=1e-4)

        self.lbs_pred = LBS_PRED(self.num_obs, self.num_actions,
                                 hidden_dim=128).to(self.device)
        self.lbs_pred_opt = torch.optim.Adam(self.lbs_pred.parameters(), lr=1e-4)

        self.lbs.train()
        self.lbs_pred.train()

    def process_env_step(self, next_obs, rewards, dones, infos, task_ids):
        """Finalize the pending transition and push it into rollout storage.

        Args:
            next_obs: observations after the step. They are stored on the
                transition and passed to the KL predictor.
            rewards: environment rewards; ignored when ``self.reward_free``.
            dones: episode-termination flags. They are stored and used to
                reset the actor-critic.
            infos: step info dict. If it has ``'time_outs'``, the reward of
                timed-out envs is bootstrapped with ``gamma * value``.
            task_ids: stored on the transition as-is.
        """
        if self.reward_free:
            # The intrinsic reward is a fixed target for PPO, not something to
            # differentiate through. Without no_grad, the predictor's autograd
            # graph would be carried into the rollout buffer.
            with torch.no_grad():
                kl_pred = self.lbs_pred(self.transition.observations,
                                        self.transition.actions,
                                        next_obs)
            self.transition.rewards = kl_pred.mean.reshape(-1).clone()
        else:
            self.transition.rewards = rewards.clone()
        self.transition.dones = dones
        self.transition.next_observations = next_obs
        self.transition.task_ids = task_ids
        # Bootstrapping on time outs
        if 'time_outs' in infos:
            self.transition.rewards += self.gamma * torch.squeeze(
                self.transition.values * infos['time_outs'].unsqueeze(1).to(self.device), 1)

        # Record the transition
        self.storage.add_transitions(self.transition)
        self.transition.clear()
        self.actor_critic.reset(dones)

    def update_lbs(self, minibatch):
        """Take one optimizer step on ``self.lbs`` and one on ``self.lbs_pred``.

        Returns:
            (lbs_loss, lbs_pred_loss): the scalar losses of the two steps.
        """
        obs = minibatch.obs
        masks = minibatch.masks
        if masks is not None:
            # Drop padded steps using the mask, mirroring rsl_rl's trajectory
            # unpadding. This assumes actions / next_obs are already unpadded;
            # TODO confirm against RolloutStorage.
            obs = obs.transpose(1, 0)[masks.transpose(1, 0)].view(
                -1, masks.shape[0], obs.shape[-1]).transpose(1, 0)

        action = minibatch.actions
        next_obs = minibatch.next_obs

        # 1) Latent model: KL(posterior || prior) + reconstruction NLL.
        kl_error, reco_error, kl_div = self.lbs(obs, action, next_obs)
        lbs_loss = kl_error.mean() + reco_error.mean()

        self.lbs_opt.zero_grad(set_to_none=True)
        lbs_loss.backward()
        nn.utils.clip_grad_norm_(self.lbs.parameters(), self.max_grad_norm)
        self.lbs_opt.step()

        # 2) Predictor: fit the per-sample KL (a fixed target) from (obs, action).
        kl_pred = self.lbs_pred(obs, action, next_obs)
        # unsqueeze(-1) appends the predictor's size-1 event dim for any batch
        # rank. It matches the former reshape(T, B, 1) when kl_div is 2-D.
        kl_target = kl_div.detach().unsqueeze(-1)
        lbs_pred_loss = -kl_pred.log_prob(kl_target).mean()

        self.lbs_pred_opt.zero_grad(set_to_none=True)
        lbs_pred_loss.backward()
        nn.utils.clip_grad_norm_(self.lbs_pred.parameters(), self.max_grad_norm)
        self.lbs_pred_opt.step()

        return lbs_loss, lbs_pred_loss

    def update(self, current_learning_iteration):
        """Run all epochs/minibatches, updating the LBS models and the policy.

        Returns:
            (mean_losses, average_stats): dicts whose values are divided by
            ``num_learning_epochs * num_mini_batches``.
        """
        self.current_learning_iteration = current_learning_iteration
        mean_losses = defaultdict(lambda: 0.)
        average_stats = defaultdict(lambda: 0.)
        if self.actor_critic.is_recurrent:
            generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
        else:
            generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
        for minibatch in generator:
            lbs_loss, lbs_pred_loss = self.update_lbs(minibatch)
            # Accumulate (not overwrite) so the division below yields the mean
            # over minibatches. Detach so logged values don't hold graphs.
            mean_losses["lbs_loss"] = mean_losses["lbs_loss"] + lbs_loss.detach()
            mean_losses["lbs_pred_loss"] = mean_losses["lbs_pred_loss"] + lbs_pred_loss.detach()

            losses, _, stats = self.compute_losses(minibatch)

            # Weight each loss by `<name>_coef` when such an attribute exists.
            loss = 0.
            for name, value in losses.items():
                loss += getattr(self, name + "_coef", 1.) * value
                mean_losses[name] = mean_losses[name] + value.detach()
            mean_losses["total_loss"] = mean_losses["total_loss"] + loss.detach()
            for name, value in stats.items():
                average_stats[name] = average_stats[name] + value.detach()

            # Gradient step
            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
            self.optimizer.step()

        num_updates = self.num_learning_epochs * self.num_mini_batches
        for name in mean_losses.keys():
            mean_losses[name] = mean_losses[name] / num_updates
        for name in average_stats.keys():
            average_stats[name] = average_stats[name] / num_updates
        self.storage.clear()
        if hasattr(self.actor_critic, "clip_std"):
            self.actor_critic.clip_std(min=self.clip_min_std)

        return mean_losses, average_stats

    def compute_losses(self, minibatch):
        """Compute the PPO losses for one minibatch.

        Side effect: under the ``'adaptive'`` schedule, this adjusts
        ``self.learning_rate`` and the optimizer's lr from the policy KL.

        Returns:
            losses: ``surrogate_loss``, ``value_loss`` and, when the
                actor-critic exposes it, ``entropy`` (the negative mean entropy).
            inter_vars: intermediate tensors, for inspection.
            stats: always an empty dict.
        """
        self.actor_critic.act(minibatch.obs, masks=minibatch.masks, hidden_states=minibatch.hid_states[0])
        actions_log_prob_batch = self.actor_critic.get_actions_log_prob(minibatch.actions)
        value_batch = self.actor_critic.evaluate(minibatch.critic_obs, masks=minibatch.masks,
                                                 hidden_states=minibatch.hid_states[1])
        mu_batch = self.actor_critic.action_mean
        sigma_batch = self.actor_critic.action_std
        try:
            entropy_batch = self.actor_critic.entropy
        except (AttributeError, NotImplementedError):
            # Actor-critics without a usable entropy simply skip the entropy term.
            entropy_batch = None

        adaptive_kl = self.desired_kl is not None and self.schedule == 'adaptive'

        # KL
        if adaptive_kl:
            with torch.inference_mode():
                kl = torch.sum(
                    torch.log(sigma_batch / minibatch.old_sigma + 1.e-5) + (
                        torch.square(minibatch.old_sigma) + torch.square(minibatch.old_mu - mu_batch)) / (
                        2.0 * torch.square(sigma_batch)) - 0.5, axis=-1)
                kl_mean = torch.mean(kl)

                if kl_mean > self.desired_kl * 2.0:
                    self.learning_rate = max(1e-5, self.learning_rate / 1.5)
                elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
                    self.learning_rate = min(1e-2, self.learning_rate * 1.5)

                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = self.learning_rate

        # Surrogate loss
        advantages = torch.squeeze(minibatch.advantages)
        ratio = torch.exp(actions_log_prob_batch - torch.squeeze(minibatch.old_actions_log_prob))
        surrogate = -advantages * ratio
        surrogate_clipped = -advantages * torch.clamp(ratio, 1.0 - self.clip_param,
                                                      1.0 + self.clip_param)
        surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()

        # Value function loss
        if self.use_clipped_value_loss:
            value_clipped = minibatch.values + (value_batch - minibatch.values).clamp(-self.clip_param,
                                                                                      self.clip_param)
            value_losses = (value_batch - minibatch.returns).pow(2)
            value_losses_clipped = (value_clipped - minibatch.returns).pow(2)
            value_loss = torch.max(value_losses, value_losses_clipped).mean()
        else:
            value_loss = (minibatch.returns - value_batch).pow(2).mean()

        return_ = dict(
            surrogate_loss=surrogate_loss,
            value_loss=value_loss,
        )
        if entropy_batch is not None:
            return_["entropy"] = - entropy_batch.mean()

        inter_vars = dict(
            ratio=ratio,
            surrogate=surrogate,
            surrogate_clipped=surrogate_clipped,
        )
        if adaptive_kl:
            inter_vars["kl"] = kl
        if self.use_clipped_value_loss:
            inter_vars["value_clipped"] = value_clipped
        return return_, inter_vars, dict()
