# Copyright (c) 2022 Aviral Kumar
# Adapted from https://raw.githubusercontent.com/aviralkumar2907/CQL/refs/heads/master/d4rl/rlkit/torch/sac/cql.py
# Modifications Copyright (c) 2025 King.com Ltd
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam


from collections import OrderedDict
from numbers import Number

from cql.dist import TanhNormal

# Bounds for the Gaussian policy head, applied in CQL._get_policy_dist:
# the predicted log standard deviation is clamped to [LOG_SIG_MIN, LOG_SIG_MAX]
# and the predicted mean to [MEAN_MIN, MEAN_MAX].
LOG_SIG_MIN, LOG_SIG_MAX = -20, 2
MEAN_MIN, MEAN_MAX = -9.0, 9.0

def create_stats_ordered_dict(
        name,
        data,
        stat_prefix=None,
        always_show_all_stats=True,
        exclude_max_min=False,
):
    """Summarise ``data`` as an ordered dict of logging statistics.

    Args:
        name: base key of the statistics; ``stat_prefix`` is prepended if given.
        data: a number (returned as-is), a tuple (each element summarised
            recursively under ``"<name>_<index>"`` with the same options), a
            list (concatenated with ``np.concatenate`` when its elements are
            iterable) or an array-like.
        stat_prefix: optional string prepended to ``name``.
        always_show_all_stats: when False, a single-element ndarray is reported
            as ``{name: value}`` instead of mean/std/max/min.
        exclude_max_min: when True, only the ``" Mean"`` and ``" Std"`` entries
            are produced.

    Returns:
        ``OrderedDict`` with ``"<name> Mean"``, ``"<name> Std"`` and (unless
        excluded) ``"<name> Max"``, ``"<name> Min"``. Empty input yields an
        empty dict.
    """
    if stat_prefix is not None:
        name = "{}{}".format(stat_prefix, name)
    if isinstance(data, Number):
        return OrderedDict({name: data})

    if isinstance(data, tuple):
        # Propagate the display options so every element is summarised the
        # same way as a top-level call would be.
        ordered_dict = OrderedDict()
        for index, item in enumerate(data):
            ordered_dict.update(create_stats_ordered_dict(
                "{0}_{1}".format(name, index),
                item,
                always_show_all_stats=always_show_all_stats,
                exclude_max_min=exclude_max_min,
            ))
        return ordered_dict

    if isinstance(data, list):
        if len(data) == 0:
            return OrderedDict()
        try:
            iter(data[0])
        except TypeError:
            pass
        else:
            # A list of array-likes is flattened into one array.
            data = np.concatenate(data)

    if isinstance(data, np.ndarray):
        # Use .size rather than len(): len() fails on 0-d arrays and misses
        # arrays such as shape (3, 0) that hold no elements.
        if data.size == 0:
            return OrderedDict()
        if data.size == 1 and not always_show_all_stats:
            # .item() avoids NumPy's deprecated float(ndarray with ndim > 0).
            return OrderedDict({name: float(data.item())})
    elif len(data) == 0:
        return OrderedDict()

    stats = OrderedDict([
        (name + ' Mean', np.mean(data)),
        (name + ' Std', np.std(data)),
    ])
    if not exclude_max_min:
        stats[name + ' Max'] = np.max(data)
        stats[name + ' Min'] = np.min(data)
    return stats


class MLP(nn.Module):
    """Fully connected network with ``num_hidden_layers`` hidden layers.

    Every hidden ``nn.Linear`` (width ``hidden_width``) is followed by a fresh
    instance of ``activation``; the output layer is linear with no activation.

    Args:
        input_dim: size of the input's last dimension.
        output_dim: size of the output's last dimension.
        hidden_width: number of units in each hidden layer.
        activation: ``nn.Module`` class, instantiated once per hidden layer.
        num_hidden_layers: number of hidden layers. The default of 3 matches the
            previously hard-coded depth; 0 gives a single linear layer.

    Raises:
        ValueError: if ``num_hidden_layers`` is negative.
    """

    def __init__(self, input_dim, output_dim, hidden_width=256, activation=nn.ReLU, num_hidden_layers=3):
        super().__init__()
        if num_hidden_layers < 0:
            raise ValueError(f"num_hidden_layers must be non-negative, got {num_hidden_layers}")

        self.input_dim = input_dim
        self.output_dim = output_dim
        self.activation = activation
        self.hidden_width = hidden_width
        self.num_hidden_layers = num_hidden_layers

        # Layers are created in the same order as the former hand-written
        # Sequential, so for the default depth the parameter names
        # (net.0, net.2, ...) and the initialisation order are unchanged.
        layers = []
        in_features = input_dim
        for _ in range(num_hidden_layers):
            layers.append(nn.Linear(in_features, hidden_width))
            layers.append(self.activation())
            in_features = hidden_width
        layers.append(nn.Linear(in_features, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x`` (last dimension ``input_dim``)."""
        return self.net(x)


class CQL(nn.Module):
    """Conservative Q-Learning (CQL) agent: twin Q-functions plus a Gaussian policy.

    Adapted from the reference CQL implementation (see file header). The module
    owns its optimizers; call :meth:`train_from_torch` once per gradient step
    with a batch of transitions.

    Args:
        state_dim: observation size; Q-networks take ``state_dim + act_dim`` inputs.
        act_dim: action size; the policy network outputs ``2 * act_dim`` values
            (per-dimension mean followed by log standard deviation).
        policy_lr: Adam learning rate of the policy and of ``log_alpha``.
        qf_lr: Adam learning rate of both Q-functions and of ``log_alpha_prime``.
        discount: Bellman discount factor.
        reward_scale: multiplier applied to rewards in the TD target.
        min_q_version: ``3`` corrects the logsumexp terms by the sampling density
            (uniform density for random actions, log-probability for policy
            actions); other values use the raw Q values.
        temp: temperature of the CQL logsumexp.
        min_q_weight: weight of the conservative regulariser.
        max_q_backup: back up the max over 10 sampled next actions instead of a
            single reparameterised sample.
        deterministic_backup: when False, subtract ``alpha * log_pi`` from the
            TD target (soft backup).
        num_random: random and policy actions sampled per state for the
            regulariser.
        with_lagrange: learn the regulariser weight ``alpha_prime`` with its own
            optimizer, using ``lagrange_thresh`` as the target action gap.
        lagrange_thresh: target action gap used when ``with_lagrange`` is set.
        use_automatic_entropy_tuning: learn the entropy coefficient ``alpha``;
            otherwise it is fixed to 0.1.
        target_entropy: entropy target for ``alpha``; ``None`` means ``-act_dim``.
        policy_bc_loss_steps: Q-functions are only trained once
            ``train_from_torch`` has been called more than this many times.
        soft_target_tau: Polyak coefficient of the target-network update.
        device: device that policy inputs and Q inputs are moved to.
            NOTE(review): the sub-networks are built on CPU, so presumably the
            caller moves the module with ``.to(device)`` — confirm.
    """

    def __init__(
            self,
            state_dim,
            act_dim,
            policy_lr=1e-4,
            qf_lr=1e-4,
            discount=0.95,
            reward_scale=1.0,
            min_q_version=3,
            temp=0.1,
            min_q_weight=1.0,
            max_q_backup=False,
            deterministic_backup=True,
            num_random=10,
            with_lagrange=False,
            lagrange_thresh=5.0,
            use_automatic_entropy_tuning=False,
            target_entropy=None,
            policy_bc_loss_steps=0,
            soft_target_tau=1e-5,
            device="cuda:0"
    ):
        super().__init__()

        self.state_dim = state_dim
        self.act_dim = act_dim
        self.device = device

        self.discount = discount
        self.reward_scale = reward_scale

        # Twin Q-functions and their target copies (initialised to the same weights).
        self.qf1 = MLP(input_dim=state_dim + act_dim, output_dim=1)
        self.qf2 = MLP(input_dim=state_dim + act_dim, output_dim=1)

        self.target_qf1 = MLP(input_dim=state_dim + act_dim, output_dim=1)
        self.target_qf2 = MLP(input_dim=state_dim + act_dim, output_dim=1)

        self.target_qf1.load_state_dict(self.qf1.state_dict())
        self.target_qf2.load_state_dict(self.qf2.state_dict())

        # Policy head: first act_dim outputs are the mean, the rest the log-std.
        self.policy = MLP(input_dim=state_dim, output_dim=act_dim * 2)

        self.policy_optimizer = Adam(self.policy.parameters(), lr=policy_lr)
        self.qf1_optimizer = Adam(self.qf1.parameters(), lr=qf_lr)
        self.qf2_optimizer = Adam(self.qf2.parameters(), lr=qf_lr)

        # min Q
        self.temp = temp
        self.min_q_version = min_q_version
        self.min_q_weight = min_q_weight
        self.soft_target_tau = soft_target_tau

        # Not used inside this class; kept for external callers / upstream parity.
        self.softmax = torch.nn.Softmax(dim=1)
        self.softplus = torch.nn.Softplus(beta=self.temp, threshold=20)

        self.max_q_backup = max_q_backup
        self.deterministic_backup = deterministic_backup
        self.num_random = num_random

        self.qf_criterion = nn.MSELoss()
        self.vf_criterion = nn.MSELoss()  # not used inside this class

        self.use_automatic_entropy_tuning = use_automatic_entropy_tuning
        if self.use_automatic_entropy_tuning:
            # `is not None` so that an explicit target_entropy of 0 is honoured.
            if target_entropy is not None:
                self.target_entropy = target_entropy
            else:
                self.target_entropy = -np.prod(self.act_dim).item()
            # Plain CPU tensor (not a registered parameter); it is moved to the
            # loss device with `.to(...)` wherever it is combined with losses.
            self.log_alpha = torch.zeros(1, requires_grad=True)
            self.alpha_optimizer = Adam([self.log_alpha], lr=policy_lr)

        self.with_lagrange = with_lagrange
        if self.with_lagrange:
            self.target_action_gap = lagrange_thresh
            # Same CPU-tensor convention as log_alpha above.
            self.log_alpha_prime = torch.zeros(1, requires_grad=True)
            self.alpha_prime_optimizer = Adam([self.log_alpha_prime], lr=qf_lr)

        self.policy_bc_loss_steps = policy_bc_loss_steps

        self._need_to_update_eval_statistics = True
        self.eval_statistics = OrderedDict()
        self._current_epoch = 0  # incremented once per train_from_torch call
        self._num_q_update_steps = 0
        self._num_policy_update_steps = 0
        self._n_train_steps_total = 0

    def _get_tensor_values(self, obs, actions, network=None):
        """Evaluate ``network`` on every (state, action) pair of a repeated batch.

        ``actions`` holds ``num_repeat`` consecutive rows per row of ``obs``
        (``actions.shape[0] == obs.shape[0] * num_repeat``), the layout produced
        by :meth:`_get_policy_actions`. Returns a tensor of shape
        ``(obs.shape[0], num_repeat, 1)``.
        """
        batch_size = obs.shape[0]
        num_repeat = actions.shape[0] // batch_size
        obs_temp = obs.unsqueeze(1).repeat(1, num_repeat, 1).view(batch_size * num_repeat, obs.shape[1])
        preds = network(torch.cat([obs_temp, actions], dim=-1).to(self.device))
        return preds.view(batch_size, num_repeat, 1)

    def _get_policy_dist(self, obs):
        """Return ``(Normal(mean, exp(log_std)), mean, log_std)`` for ``obs``.

        The mean is clamped to ``[MEAN_MIN, MEAN_MAX]`` and the log-std to
        ``[LOG_SIG_MIN, LOG_SIG_MAX]`` before building the distribution.
        """
        dist_params = self.policy(obs.to(torch.float32).to(self.device))
        mean = dist_params[:, 0:self.act_dim]
        mean = torch.clamp(mean, MEAN_MIN, MEAN_MAX)

        log_std = dist_params[:, self.act_dim:]
        log_std = torch.clamp(log_std, LOG_SIG_MIN, LOG_SIG_MAX)
        std = torch.exp(log_std)

        policy_dist = torch.distributions.Normal(mean, std)

        return policy_dist, mean, log_std

    def _get_policy_actions(self, obs, num_actions, reparameterize=False):
        """Sample ``num_actions`` policy actions for every row of ``obs``.

        Returns:
            ``(actions, log_probs, mean, log_std)`` where ``actions`` has shape
            ``(obs.shape[0] * num_actions, act_dim)`` (consecutive rows share a
            state) and ``log_probs`` has shape ``(obs.shape[0], num_actions, 1)``.
            With ``reparameterize=True`` the samples use ``rsample`` so gradients
            flow back into the policy.
        """
        obs_temp = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1])
        policy_dist, mean, log_std = self._get_policy_dist(obs_temp.to(self.device))

        if reparameterize:
            actions = policy_dist.rsample()
        else:
            actions = policy_dist.sample()

        logits = policy_dist.log_prob(actions)

        if isinstance(policy_dist, torch.distributions.Normal):
            # Independent per-dimension log-probs -> joint log-prob of the action.
            logits = logits.sum(dim=-1, keepdim=True)

        return actions, logits.view(obs.shape[0], num_actions, 1), mean, log_std

    def _current_alpha(self):
        """Return the entropy coefficient: ``exp(log_alpha)`` if auto-tuned, else 0.1."""
        if self.use_automatic_entropy_tuning:
            return self.log_alpha.exp()
        return torch.tensor(0.1, dtype=torch.float)

    @staticmethod
    def _soft_update(target_net, source_net, tau):
        """Polyak-average ``source_net`` into ``target_net`` in place.

        Computes ``t <- (1 - tau) * t + tau * s`` for every parameter pair
        without recording an autograd graph.
        """
        with torch.no_grad():
            for target_param, param in zip(target_net.parameters(), source_net.parameters()):
                target_param.copy_((1 - tau) * target_param + tau * param)

    def train_from_torch(self, batch):
        """Run one training step on ``batch`` and refresh ``self.eval_statistics``.

        Args:
            batch: mapping with keys ``'rewards'``, ``'terminals'``,
                ``'observations'``, ``'actions'`` and ``'next_observations'``.
                NOTE(review): presumably 2-D float tensors already on the
                networks' device, with rewards/terminals of shape (B, 1) — confirm.

        Returns:
            ``(qf1_loss + qf2_loss + policy_loss, self.eval_statistics)``; the Q
            losses are 0 while Q updates are disabled by ``policy_bc_loss_steps``.
        """
        self._current_epoch += 1

        rewards = batch['rewards']
        terminals = batch['terminals']
        obs = batch['observations']
        actions = batch['actions']
        next_obs = batch['next_observations']

        """
        Policy and Alpha Loss
        """
        alpha_loss = None
        if self._current_epoch % 2 == 0:
            # Behavioural-cloning step: maximise the log-likelihood of the
            # dataset actions.
            # NOTE(review): this alternates BC and SAC-style policy updates on
            # every other call, while the inherited upstream comment describes
            # BC only "for the initial few epochs" — confirm this is intended.
            policy_dist, policy_mean, policy_log_std = self._get_policy_dist(obs.to(torch.float32))
            policy_log_prob = policy_dist.log_prob(actions)

            if isinstance(policy_dist, torch.distributions.Normal):
                policy_log_prob = policy_log_prob.sum(dim=-1, keepdim=True)

            policy_loss = -policy_log_prob.mean()
            log_pi = policy_log_prob
            # alpha is still needed by the soft backup and the statistics.
            alpha = self._current_alpha()
        else:
            new_obs_actions, log_pi, policy_mean, policy_log_std = self._get_policy_actions(obs, 1, reparameterize=True)
            # (B, 1, 1) -> (B, 1) so log_pi lines up with the (B, 1) Q values
            # instead of broadcasting to a (B, B, 1) tensor.
            log_pi = log_pi.view(-1, 1)

            if self.use_automatic_entropy_tuning:
                alpha_loss = -(self.log_alpha.to(log_pi.device) * (log_pi + self.target_entropy).detach()).mean()
                self.alpha_optimizer.zero_grad()
                alpha_loss.backward()
                self.alpha_optimizer.step()
            alpha = self._current_alpha()

            q_new_actions = torch.min(
                self.qf1(torch.cat([obs, new_obs_actions], dim=-1)),
                self.qf2(torch.cat([obs, new_obs_actions], dim=-1)),
            )

            policy_loss = (alpha.to(log_pi.device) * log_pi - q_new_actions).mean()

        """
        QF Loss
        """
        q_updates_enabled = self._current_epoch > self.policy_bc_loss_steps

        if q_updates_enabled:
            q1_pred = self.qf1(torch.cat([obs, actions], dim=-1))
            q2_pred = self.qf2(torch.cat([obs, actions], dim=-1))

            if self.max_q_backup:
                # Back up the max over 10 sampled next actions per next state.
                next_actions_temp, _, _, _ = self._get_policy_actions(next_obs, num_actions=10)
                target_qf1_values = self._get_tensor_values(
                    next_obs, next_actions_temp, network=self.target_qf1).max(1)[0].view(-1, 1)
                target_qf2_values = self._get_tensor_values(
                    next_obs, next_actions_temp, network=self.target_qf2).max(1)[0].view(-1, 1)
                target_q_values = torch.min(target_qf1_values, target_qf2_values)
            else:
                new_next_actions, new_log_pi, _, _ = self._get_policy_actions(next_obs, 1, reparameterize=True)
                target_q_values = torch.min(
                    self.target_qf1(torch.cat([next_obs, new_next_actions], dim=-1)),
                    self.target_qf2(torch.cat([next_obs, new_next_actions], dim=-1)),
                )

                if not self.deterministic_backup:
                    # Soft backup; flatten new_log_pi to (B, 1) to avoid a
                    # (B, B, 1) broadcast against target_q_values.
                    target_q_values = target_q_values - alpha.to(new_log_pi.device) * new_log_pi.view(-1, 1)

            q_target = self.reward_scale * rewards + (1. - terminals) * self.discount * target_q_values
            q_target = q_target.detach()

            qf1_loss = self.qf_criterion(q1_pred, q_target)
            qf2_loss = self.qf_criterion(q2_pred, q_target)

            # add CQL: Q values at uniform-random actions and at policy actions.
            random_actions_tensor = torch.empty(
                q2_pred.shape[0] * self.num_random, actions.shape[-1], dtype=torch.float32
            ).uniform_(-1, 1).to(q2_pred.device)
            curr_actions_tensor, curr_log_pis, _, _ = self._get_policy_actions(obs, num_actions=self.num_random)
            new_curr_actions_tensor, new_log_pis, _, _ = self._get_policy_actions(next_obs, num_actions=self.num_random)
            q1_rand = self._get_tensor_values(obs, random_actions_tensor, network=self.qf1)
            q2_rand = self._get_tensor_values(obs, random_actions_tensor, network=self.qf2)
            q1_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf1)
            q2_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf2)
            # Actions sampled at next_obs but evaluated at obs (as upstream).
            q1_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf1)
            q2_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf2)

            cat_q1 = torch.cat(
                [q1_rand, q1_pred.unsqueeze(1), q1_next_actions, q1_curr_actions], 1
            )
            cat_q2 = torch.cat(
                [q2_rand, q2_pred.unsqueeze(1), q2_next_actions, q2_curr_actions], 1
            )
            std_q1 = torch.std(cat_q1, dim=1)
            std_q2 = torch.std(cat_q2, dim=1)

            if self.min_q_version == 3:
                # Importance-sampling correction: subtract the log-density each
                # action was drawn from (uniform on [-1, 1]^act_dim, or the policy).
                random_density = np.log(0.5 ** curr_actions_tensor.shape[-1])
                cat_q1 = torch.cat(
                    [q1_rand - random_density, q1_next_actions - new_log_pis.detach(),
                     q1_curr_actions - curr_log_pis.detach()], 1
                )
                cat_q2 = torch.cat(
                    [q2_rand - random_density, q2_next_actions - new_log_pis.detach(),
                     q2_curr_actions - curr_log_pis.detach()], 1
                )

            min_qf1_loss = torch.logsumexp(cat_q1 / self.temp, dim=1, ).mean() * self.min_q_weight * self.temp
            min_qf2_loss = torch.logsumexp(cat_q2 / self.temp, dim=1, ).mean() * self.min_q_weight * self.temp

            """Subtract the log likelihood of data"""
            min_qf1_loss = min_qf1_loss - q1_pred.mean() * self.min_q_weight
            min_qf2_loss = min_qf2_loss - q2_pred.mean() * self.min_q_weight

            if self.with_lagrange:
                # log_alpha_prime lives on CPU; move it next to the losses.
                alpha_prime = torch.clamp(
                    self.log_alpha_prime.exp(), min=0.0, max=1000000.0
                ).to(min_qf1_loss.device)
                min_qf1_loss = alpha_prime * (min_qf1_loss - self.target_action_gap)
                min_qf2_loss = alpha_prime * (min_qf2_loss - self.target_action_gap)

                self.alpha_prime_optimizer.zero_grad()
                alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5
                # Graph is reused by the Q-function backward passes below.
                alpha_prime_loss.backward(retain_graph=True)
                self.alpha_prime_optimizer.step()

            qf1_loss = qf1_loss + min_qf1_loss
            qf2_loss = qf2_loss + min_qf2_loss

        """
        Update networks
        """
        self.policy_optimizer.zero_grad()
        policy_loss.backward(retain_graph=False)
        self.policy_optimizer.step()
        self._num_policy_update_steps += 1

        if q_updates_enabled:
            # zero_grad also clears gradients that policy_loss / alpha_prime_loss
            # pushed into the Q-networks.
            self.qf1_optimizer.zero_grad()
            qf1_loss.backward(retain_graph=True)
            self.qf1_optimizer.step()

            self.qf2_optimizer.zero_grad()
            qf2_loss.backward(retain_graph=True)
            self.qf2_optimizer.step()
            self._num_q_update_steps += 1
        else:
            qf1_loss = 0
            qf2_loss = 0

        """
        Soft Updates
        """
        self._soft_update(self.target_qf1, self.qf1, self.soft_target_tau)
        self._soft_update(self.target_qf2, self.qf2, self.soft_target_tau)

        """
        Save some statistics for eval
        """
        if self._need_to_update_eval_statistics:
            # NOTE(review): nothing in this class resets this flag, so the
            # statistics are recomputed on every call.

            self.eval_statistics.update(create_stats_ordered_dict(
                'actions',
                actions.detach().cpu().numpy()
            ))
            self.eval_statistics.update(create_stats_ordered_dict(
                'rewards',
                rewards.detach().cpu().numpy()
            ))

            self.eval_statistics['Num Q Updates'] = self._num_q_update_steps
            self.eval_statistics['Num Policy Updates'] = self._num_policy_update_steps
            self.eval_statistics['training/Policy Loss'] = np.mean(
                policy_loss.detach().cpu().numpy()
            )
            if q_updates_enabled:
                self.eval_statistics['training/QF1 Loss'] = np.mean(qf1_loss.detach().cpu().numpy())
                self.eval_statistics['training/min QF1 Loss'] = np.mean(min_qf1_loss.detach().cpu().numpy())
                self.eval_statistics['training/QF2 Loss'] = np.mean(qf2_loss.detach().cpu().numpy())
                self.eval_statistics['training/min QF2 Loss'] = np.mean(min_qf2_loss.detach().cpu().numpy())

                self.eval_statistics['Std QF1 values'] = np.mean(std_q1.detach().cpu().numpy())
                self.eval_statistics['Std QF2 values'] = np.mean(std_q2.detach().cpu().numpy())
                self.eval_statistics.update(create_stats_ordered_dict(
                    'QF1 in-distribution values',
                    q1_curr_actions.detach().cpu().numpy(),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'QF2 in-distribution values',
                    q2_curr_actions.detach().cpu().numpy(),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'QF1 random values',
                    q1_rand.detach().cpu().numpy(),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'QF2 random values',
                    q2_rand.detach().cpu().numpy(),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'QF1 next_actions values',
                    q1_next_actions.detach().cpu().numpy(),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'QF2 next_actions values',
                    q2_next_actions.detach().cpu().numpy(),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Q1 Predictions',
                    q1_pred.detach().cpu().numpy(),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Q2 Predictions',
                    q2_pred.detach().cpu().numpy(),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Q Targets',
                    q_target.detach().cpu().numpy(),
                ))

                self.eval_statistics.update(create_stats_ordered_dict(
                    'Log Pis',
                    log_pi.detach().cpu().numpy(),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Policy mu',
                    policy_mean.detach().cpu().numpy(),
                ))
                self.eval_statistics.update(create_stats_ordered_dict(
                    'Policy log std',
                    policy_log_std.detach().cpu().numpy(),
                ))

                if self.use_automatic_entropy_tuning:
                    self.eval_statistics['Alpha'] = alpha.item()
                    # alpha_loss only exists on SAC-style (odd) steps.
                    if alpha_loss is not None:
                        self.eval_statistics['training/Alpha Loss'] = alpha_loss.item()

                if self.with_lagrange:
                    self.eval_statistics['Alpha_prime'] = alpha_prime.item()
                    self.eval_statistics['training/min_q1_loss'] = min_qf1_loss.mean().item()
                    self.eval_statistics['training/min_q2_loss'] = min_qf2_loss.mean().item()
                    self.eval_statistics['threshold action gap'] = self.target_action_gap
                    self.eval_statistics['alpha prime loss'] = alpha_prime_loss.item()

        self._n_train_steps_total += 1

        return qf1_loss + qf2_loss + policy_loss, self.eval_statistics

    def get_diagnostics(self):
        """Return the statistics dict last filled by :meth:`train_from_torch`."""
        return self.eval_statistics

