import torch
import copy
import numpy as np
from copy import deepcopy
from loguru import logger
from torch.functional import F

from tianshou.data import Batch

from offlinerl.algo.base import BaseAlgo
from offlinerl.utils.data import to_torch, sample
from offlinerl.utils.net.common import MLP, Net, Swish
from offlinerl.utils.net.continuous import Critic
from offlinerl.utils.net.tanhpolicy import TanhGaussianPolicy
from offlinerl.utils.exp import setup_seed

from offlinerl.algo.brac import divergences
import os
import torch
import copy
import numpy as np
from copy import deepcopy
from loguru import logger
from torch.functional import F

from tianshou.data import Batch

from offlinerl.algo.base import BaseAlgo
from offlinerl.utils.data import to_torch, sample
from offlinerl.utils.net.common import MLP, Net, Swish
from offlinerl.utils.net.continuous import Critic
from offlinerl.utils.net.tanhpolicy import TanhGaussianPolicy
from offlinerl.utils.exp import setup_seed

from offlinerl.algo.brac.divergence2 import _kl_divergence
from offlinerl.algo.brac import utils


# NOTE: training needs a pretrained meta policy checkpoint (it is loaded in AlgoTrainer.train).


def algo_init(args):
    """Construct every network and optimizer used by ``AlgoTrainer``.

    Args:
        args: configuration mapping. Keys read here include ``seed``, ``device``,
            ``hidden_layer_size``, ``hidden_layers``, ``transition_layers``,
            ``transition_init_num``, ``transition_lr``, ``actor_lr``, ``c_lr``, ``critic_lr``,
            ``use_automatic_entropy_tuning`` and ``lagrange_thresh``. ``obs_shape`` and
            ``action_shape`` are optional. When they are missing or empty, they are inferred
            from ``task`` and written back into ``args``. ``target_entropy`` is optional;
            when it is falsy, it defaults to ``-prod(action_shape)``.

    Returns:
        dict mapping component names (``transition``, ``actor``, ``c``, ``critic1``,
        ``critic2``, ``log_alpha`` and, when ``lagrange_thresh >= 0``, ``log_alpha_prime``)
        to ``{"net": ..., "opt": ...}``. The ``log_alpha`` entry also carries
        ``target_entropy``.

    Raises:
        NotImplementedError: if neither explicit shapes nor a ``task`` are available.
    """
    logger.info('Run algo_init function')

    setup_seed(args['seed'])

    if args.get("obs_shape") and args.get("action_shape"):
        obs_shape, action_shape = args["obs_shape"], args["action_shape"]
    elif "task" in args:
        from offlinerl.utils.env import get_env_shape
        obs_shape, action_shape = get_env_shape(args['task'])
        args["obs_shape"], args["action_shape"] = obs_shape, action_shape
    else:
        raise NotImplementedError

    device = args['device']
    hidden_size = args['hidden_layer_size']

    # transition is the (ensemble) dynamics model
    transition = EnsembleTransition(obs_shape, action_shape, hidden_size, args['transition_layers'],
                                    args['transition_init_num']).to(device)
    transition_optim = torch.optim.Adam(transition.parameters(), lr=args['transition_lr'], weight_decay=0.000075)

    # actor
    net_a = Net(layer_num=args['hidden_layers'],
                state_shape=obs_shape,
                hidden_layer_size=hidden_size)
    actor = TanhGaussianPolicy(preprocess_net=net_a,
                               action_shape=action_shape,
                               hidden_layer_size=hidden_size,
                               conditioned_sigma=True).to(device)
    actor_optim = torch.optim.Adam(actor.parameters(), lr=args['actor_lr'])

    # "c" network (presumably the dual function for the BRAC divergence estimators, whose use is
    # currently disabled in AlgoTrainer). Moved to the device like every other network.
    net_c = Net(layer_num=args['hidden_layers'],
                state_shape=obs_shape,
                action_shape=action_shape,
                concat=True,
                hidden_layer_size=hidden_size).to(device)
    c_optim = torch.optim.Adam(net_c.parameters(), lr=args['c_lr'])

    def build_critic():
        """Create one Q(s, a) critic on ``device`` together with its Adam optimizer."""
        preprocess = Net(layer_num=args['hidden_layers'],
                         state_shape=obs_shape,
                         action_shape=action_shape,
                         concat=True,
                         hidden_layer_size=hidden_size)
        critic = Critic(preprocess_net=preprocess,
                        hidden_layer_size=hidden_size,
                        ).to(device)
        return critic, torch.optim.Adam(critic.parameters(), lr=args['critic_lr'])

    # double Q critics
    critic1, critic1_optim = build_critic()
    critic2, critic2_optim = build_critic()

    # Entropy temperature. AlgoTrainer reads nets["log_alpha"] unconditionally, so the parameter
    # and its optimizer are always created. ``use_automatic_entropy_tuning`` only controls whether
    # the trainer steps it. (Previously these were bound only when tuning was enabled, which
    # raised NameError otherwise.)
    target_entropy = args.get("target_entropy") or -np.prod(action_shape).item()
    log_alpha = torch.zeros(1, requires_grad=True, device=device)
    alpha_optimizer = torch.optim.Adam([log_alpha], lr=args["actor_lr"])

    nets = {
        "transition": {"net": transition, "opt": transition_optim},
        "actor": {"net": actor, "opt": actor_optim},
        "c": {"net": net_c, "opt": c_optim},
        "critic1": {"net": critic1, "opt": critic1_optim},
        "critic2": {"net": critic2, "opt": critic2_optim},
        "log_alpha": {"net": log_alpha, "opt": alpha_optimizer, "target_entropy": target_entropy},
    }

    # Lagrange multiplier for the CQL penalty (only when a threshold is configured)
    if args["lagrange_thresh"] >= 0:
        log_alpha_prime = torch.zeros(1, requires_grad=True, device=device)
        alpha_prime_optimizer = torch.optim.Adam([log_alpha_prime], lr=args["critic_lr"])
        nets["log_alpha_prime"] = {"net": log_alpha_prime, "opt": alpha_prime_optimizer}

    return nets


# NOTE(review): the string literal below is dead code. It is an older variant of the critic setup
# (two MLP Q-networks sharing one optimizer), kept as a module-level string, so it is never
# executed. The active construction is in ``algo_init``.
'''''
    # what is log_alpha
    log_alpha = torch.zeros(1, requires_grad=True, device=args['device'])
    alpha_optimizer = torch.optim.Adam([log_alpha], lr=args["actor_lr"])

    #critic
    q1 = MLP(obs_shape + action_shape, 1, args['hidden_layer_size'], args['hidden_layers'], norm=None,
             hidden_activation='swish').to(args['device'])
    q2 = MLP(obs_shape + action_shape, 1, args['hidden_layer_size'], args['hidden_layers'], norm=None,
             hidden_activation='swish').to(args['device'])
    critic_optim = torch.optim.Adam([*q1.parameters(), *q2.parameters()], lr=args['actor_lr'])

    return {
        "transition": {"net": transition, "opt": transition_optim},
        "actor": {"net": actor, "opt": actor_optim},
        "log_alpha": {"net": log_alpha, "opt": alpha_optimizer},
        "critic": {"net": [q1, q2], "opt": critic_optim},
    }
'''''


def soft_clamp(x: torch.Tensor, _min=None, _max=None):
    """Differentiably squash ``x`` towards ``[_min, _max]`` using softplus.

    Unlike ``torch.clamp``, the gradient never becomes exactly zero. The upper bound is
    applied first and the lower bound second; either may be ``None`` to skip it. Each single
    bound is respected strictly because softplus is positive. With both bounds, however, the
    second (lower-bound) pass can push values slightly above ``_max``.
    """
    result = x
    if _max is not None:
        result = _max - F.softplus(_max - result)
    if _min is not None:
        result = _min + F.softplus(result - _min)
    return result


# transition model

class EnsembleLinear(torch.nn.Module):
    """A bank of ``ensemble_size`` independent affine layers evaluated with a single einsum.

    ``weight`` has shape ``(ensemble_size, in_features, out_features)`` and ``bias`` has shape
    ``(ensemble_size, 1, out_features)``. ``set_select`` restricts the forward pass to a subset
    of members; by default all members are used.
    """

    def __init__(self, in_features, out_features, ensemble_size=1):
        super().__init__()

        self.ensemble_size = ensemble_size

        self.register_parameter('weight', torch.nn.Parameter(torch.zeros(ensemble_size, in_features, out_features)))
        self.register_parameter('bias', torch.nn.Parameter(torch.zeros(ensemble_size, 1, out_features)))

        # NOTE: trunc_normal_ truncates at the absolute values a=-2, b=2 (not at +/-2 std), so with
        # this small std the truncation is effectively inactive.
        torch.nn.init.trunc_normal_(self.weight, std=1 / (2 * in_features ** 0.5))

        self.select = list(range(0, self.ensemble_size))

    def forward(self, x):
        """Apply every selected member to ``x``.

        Args:
            x: either ``(batch, in_features)``, which is shared by all selected members, or
                ``(n_selected, batch, in_features)``, with one slice per selected member.

        Returns:
            Tensor of shape ``(n_selected, batch, out_features)``.
        """
        weight = self.weight[self.select]
        bias = self.bias[self.select]

        if len(x.shape) == 2:
            # the same input is multiplied by every selected member's weight
            x = torch.einsum('ij,bjk->bik', x, weight)
        else:
            # member b uses its own input slice x[b]
            x = torch.einsum('bij,bjk->bik', x, weight)

        return x + bias

    def set_select(self, indexes):
        """Restrict subsequent forward passes to the members in ``indexes``.

        Raises:
            ValueError: if ``indexes`` is empty, longer than the ensemble, or contains an index
                outside ``[0, ensemble_size)``. (An explicit check replaces the previous
                ``assert``, which is stripped under ``python -O`` and let negative indexes
                through.)
        """
        # copy so later mutation of the caller's sequence cannot change the selection
        indexes = [int(i) for i in indexes]
        if not indexes:
            raise ValueError("set_select needs at least one ensemble index")
        if len(indexes) > self.ensemble_size or not all(0 <= i < self.ensemble_size for i in indexes):
            raise ValueError(f"ensemble indexes must lie in [0, {self.ensemble_size}) and number at most "
                             f"{self.ensemble_size}, got {indexes}")
        self.select = indexes


class EnsembleTransition(torch.nn.Module):
    """Ensemble of probabilistic dynamics models over ``[next_obs, reward]``.

    Each member is a stack of ``EnsembleLinear`` layers with Swish activations. The output layer
    emits a mean and a log-std for every predicted feature: ``obs_dim`` observation features
    plus one reward feature when ``with_reward`` is true. The log-std is softly bounded by two
    learnable vectors. In ``'local'`` mode the network predicts the observation *change*, and
    the current observation (the first ``obs_dim`` inputs) is added back to the mean.
    """

    def __init__(self, obs_dim, action_dim, hidden_features, hidden_layers, ensemble_size=1, mode='local',
                 with_reward=True):
        super().__init__()
        self.obs_dim = obs_dim
        self.mode = mode
        self.with_reward = with_reward
        self.ensemble_size = ensemble_size

        self.activation = Swish()

        # the first layer consumes the concatenated (obs, action); the others map hidden -> hidden
        self.backbones = torch.nn.ModuleList([
            EnsembleLinear(hidden_features if depth else obs_dim + action_dim, hidden_features, ensemble_size)
            for depth in range(hidden_layers)
        ])

        n_targets = obs_dim + self.with_reward
        self.output_layer = EnsembleLinear(hidden_features, 2 * n_targets, ensemble_size)

        # learnable soft bounds for the predicted log-std, initialised to 1 and -5
        self.max_logstd = torch.nn.Parameter(torch.ones(n_targets) * 1, requires_grad=True)
        self.min_logstd = torch.nn.Parameter(torch.ones(n_targets) * -5, requires_grad=True)

    def forward(self, obs_action):
        """Return a ``Normal`` over ``[next_obs, reward]`` for the concatenated ``obs_action`` input."""
        hidden = obs_action
        for layer in self.backbones:
            hidden = self.activation(layer(hidden))
        mu, logstd = torch.chunk(self.output_layer(hidden), 2, dim=-1)
        logstd = soft_clamp(logstd, self.min_logstd, self.max_logstd)

        if self.mode == 'local':
            current_obs = obs_action[..., :self.obs_dim]
            if self.with_reward:
                delta_obs, reward = torch.split(mu, [self.obs_dim, 1], dim=-1)
                mu = torch.cat([delta_obs + current_obs, reward], dim=-1)
            else:
                mu = mu + current_obs
        return torch.distributions.Normal(mu, torch.exp(logstd))

    def set_select(self, indexes):
        """Restrict every layer to the ensemble members listed in ``indexes``."""
        for layer in (*self.backbones, self.output_layer):
            layer.set_select(indexes)


# data buffer
class COMBOBuffer:
    """FIFO buffer of model-generated transitions, held as one concatenated CPU batch.

    Batches are appended with ``put``. Once more than ``buffer_size`` rows are stored, the
    oldest rows are dropped.
    """

    def __init__(self, buffer_size):
        self.data = None
        self.buffer_size = int(buffer_size)

    def put(self, batch_data):
        """Convert ``batch_data`` to CPU tensors (in place), append it and evict any overflow."""
        batch_data.to_torch(device='cpu')

        if self.data is not None:
            self.data.cat_(batch_data)
        else:
            self.data = batch_data

        overflow = len(self) - self.buffer_size
        if overflow > 0:
            self.data = self.data[overflow:]

    def __len__(self):
        return 0 if self.data is None else self.data.shape[0]

    def sample(self, batch_size):
        """Draw ``batch_size`` rows uniformly at random, with replacement."""
        picks = np.random.randint(0, len(self), size=batch_size)
        return self.data[picks]


class AlgoTrainer(BaseAlgo):
    def __init__(self, algo_init, args):
        """Unpack the networks built by ``algo_init`` and set up training state.

        Args:
            algo_init: dict returned by ``algo_init(args)``.
            args: configuration mapping. The optional keys ``kl_beta`` (default 1),
                ``kl_gamma`` (default 0.5) and ``n_div_samples`` (default 4) override the
                policy-regularisation constants, which were previously hard-coded.
        """
        super(AlgoTrainer, self).__init__(args)
        self.args = args

        # dynamics model
        self.transition = algo_init['transition']['net']
        self.transition_optim = algo_init['transition']['opt']
        self.selected_transitions = None

        # actor
        self.actor = algo_init['actor']['net']
        self.actor_optim = algo_init['actor']['opt']

        # Entropy temperature. ``log_alpha_optim`` (used by _sac_update) and ``alpha_opt``
        # (used by _cql_update) refer to the same optimizer.
        self.log_alpha = algo_init['log_alpha']['net']
        self.log_alpha_optim = algo_init['log_alpha']['opt']
        if args["use_automatic_entropy_tuning"]:
            self.alpha_opt = self.log_alpha_optim
            self.target_entropy = algo_init["log_alpha"]["target_entropy"]

        self.c = algo_init['c']['net']
        self.c_optim = algo_init['c']['opt']

        # double Q critics and their target copies
        self.critic1 = algo_init["critic1"]["net"]
        self.critic1_optim = algo_init["critic1"]["opt"]
        self.critic2 = algo_init["critic2"]["net"]
        self.critic2_optim = algo_init["critic2"]["opt"]
        self.critic1_target = copy.deepcopy(self.critic1)
        self.critic2_target = copy.deepcopy(self.critic2)

        # Lagrange multiplier for the CQL penalty
        if self.args["lagrange_thresh"] >= 0:
            self.log_alpha_prime = algo_init["log_alpha_prime"]["net"]
            self.alpha_prime_opt = algo_init["log_alpha_prime"]["opt"]

        self.critic_criterion = torch.nn.MSELoss()

        self._n_train_steps_total = 0
        self._current_epoch = 0

        self.device = args['device']
        self._divergence = _kl_divergence(self.device)

        # Weights of the actor regularisers in _cql_update:
        #   beta * gamma * KL(pi || meta) + beta * (1 - gamma) * NLL(batch actions)
        self.beta = args.get('kl_beta', 1)
        self.gamma = args.get('kl_gamma', 0.5)
        self._n_div_samples = args.get('n_div_samples', 4)

    # train model, then keep model, train policy
    def train(self, train_buffer, val_buffer, callback_fn):
        """Obtain the dynamics model and the meta policy, then train the policy.

        The transition model is loaded from ``~/OfflineRL-master/offlinerl_tmp/<task>.pt`` when
        that file exists. Otherwise it is trained on ``train_buffer``. The meta policy is loaded
        from ``args['meta_policy_path']`` if that key is set and non-empty, and otherwise from
        the original hard-coded ``~/OfflineRL-master/offlinerl_tmp/45.pt``.

        Returns:
            The trained policy, as returned by ``train_policy``.

        Raises:
            FileNotFoundError: if the meta-policy checkpoint does not exist. The policy loss
                needs it, so this is checked before any (possibly long) transition training.
        """
        tmp_dir = os.path.join(os.path.expanduser('~'), 'OfflineRL-master', 'offlinerl_tmp')

        model_path = self.args.get('meta_policy_path') or os.path.join(tmp_dir, str(45) + '.pt')
        if not os.path.exists(model_path):
            # previously this only printed a message and later crashed with UnboundLocalError
            raise FileNotFoundError(f"meta_policy is not found: {model_path}")

        transition_path = os.path.join(tmp_dir, str(self.args['task']) + '.pt')
        if os.path.exists(transition_path):
            logger.info("load transition")
            transition = self.load_transition(transition_path)
        else:
            logger.info("train transition")
            transition = self.train_transition(train_buffer)

        logger.info("load meta_policy")
        meta_policy = self.load_model(model_path)

        # the dynamics model stays frozen while the policy is trained
        transition.requires_grad_(False)
        return self.train_policy(train_buffer, val_buffer, transition, meta_policy, callback_fn)

    def get_policy(self):
        """Return the actor network that this trainer optimises."""
        policy = self.actor
        return policy

    def get_transition(self):
        """Return the dynamics (transition) model held by this trainer."""
        model = self.transition
        return model

    # transition model training, train with train_buffer, test with valdata, select one model
    def train_transition(self, buffer):
        """Train the ensemble dynamics model and keep its best members.

        ``min(int(0.2 * N) + 1, 1000)`` transitions are held out for validation. Every ensemble
        member trains on its own bootstrap resample (sampling with replacement) of the rest.
        Training stops once no member's validation MSE has improved for
        ``args.get('transition_patience', 5)`` consecutive epochs. Before this fix the check
        was ``cnt >= 0``, which always stopped after one epoch; a patience of 0 reproduces that.
        Finally the ``args['transition_select_num']`` members with the lowest validation MSE are
        selected, and ``self.log_transition`` is called.

        Returns:
            ``self.transition`` with the member selection applied.
        """
        data_size = len(buffer)
        val_size = min(int(data_size * 0.2) + 1, 1000)
        train_size = data_size - val_size
        train_splits, val_splits = torch.utils.data.random_split(range(data_size), (train_size, val_size))
        train_buffer = buffer[train_splits.indices]
        valdata = buffer[val_splits.indices]
        batch_size = self.args['transition_batch_size']
        patience = self.args.get('transition_patience', 5)

        val_losses = [float('inf')] * self.transition.ensemble_size

        epoch = 0
        cnt = 0
        while True:
            # one bootstrap resample of the training rows per ensemble member
            idxs = np.random.randint(train_buffer.shape[0], size=[self.transition.ensemble_size, train_buffer.shape[0]])
            num_batches = int(np.ceil(idxs.shape[-1] / batch_size))
            for batch_num in range(num_batches):
                batch_idxs = idxs[:, batch_num * batch_size:(batch_num + 1) * batch_size]
                batch = train_buffer[batch_idxs]
                self._train_transition(self.transition, batch, self.transition_optim)
            new_val_losses = self._eval_transition(self.transition, valdata)
            logger.info(f"transition epoch {epoch}: val losses {new_val_losses}")
            epoch += 1

            # keep the best loss per member; stop when no member improves for `patience` epochs
            change = False
            for i, (new_loss, old_loss) in enumerate(zip(new_val_losses, val_losses)):
                if new_loss < old_loss:
                    change = True
                    val_losses[i] = new_loss

            cnt = 0 if change else cnt + 1
            if cnt >= patience:
                break

        val_losses = self._eval_transition(self.transition, valdata)
        indexes = self._select_best_indexes(val_losses, n=self.args['transition_select_num'])
        self.transition.set_select(indexes)
        self.log_transition(self.args['task'])
        return self.transition

    # train policy, iterating data collection with updated policy and policy update
    def train_policy(self, train_buffer, val_buffer, transition, meta_policy, callback_fn):
        """Alternate model rollouts and policy updates for ``args['max_epoch']`` epochs.

        Each epoch has three phases:

        1. Roll the current actor through the frozen ``transition`` model for
           ``args['horizon']`` steps, starting from ``args['data_collection_per_epoch']``
           states sampled from ``train_buffer``. The synthetic transitions go into a
           ``COMBOBuffer``.
        2. Run ``args['steps_per_epoch']`` ``_cql_update`` steps on batches that mix real and
           model data according to ``args['real_data_ratio']``.
        3. Evaluate via ``callback_fn`` and log the result.

        Args:
            train_buffer: real offline data supporting ``sample(n)``.
            val_buffer: not used here; kept for interface compatibility.
            transition: dynamics model returning a distribution over ``[next_obs, reward]``.
            meta_policy: reference policy forwarded to ``_cql_update``.
            callback_fn: evaluation hook. It is called with the policy and returns a dict of
                metrics.

        Returns:
            The trained policy (``self.actor``).
        """
        real_batch_size = int(self.args['policy_batch_size'] * self.args['real_data_ratio'])
        model_batch_size = self.args['policy_batch_size'] - real_batch_size

        model_buffer = COMBOBuffer(self.args['buffer_size'])
        # reward from the last rollout step; stays None when horizon == 0
        reward = None

        for epoch in range(self.args['max_epoch']):
            # collect data with the frozen transition model (no model update)
            logger.info(f"collect data. epoch: {epoch}")
            with torch.no_grad():
                obs = train_buffer.sample(int(self.args['data_collection_per_epoch']))['obs']
                obs = torch.tensor(obs, device=self.device)
                for t in range(self.args['horizon']):
                    action = self.actor(obs).sample()
                    obs_action = torch.cat([obs, action], dim=-1)
                    next_obs_dists = transition(obs_action)
                    next_obses = next_obs_dists.sample()
                    # one prediction per selected ensemble member; the last feature is the reward
                    rewards = next_obses[:, :, -1:]
                    next_obses = next_obses[:, :, :-1]

                    # pick a random ensemble member for every rollout
                    model_indexes = np.random.randint(0, next_obses.shape[0], size=(obs.shape[0]))
                    rollout_idx = np.arange(obs.shape[0])
                    next_obs = next_obses[model_indexes, rollout_idx]
                    reward = rewards[model_indexes, rollout_idx]

                    logger.debug(f"rollout step {t}: average reward {reward.mean().item()}")

                    # synthetic transitions are never marked terminal
                    dones = torch.zeros_like(reward)

                    model_buffer.put(Batch({
                        "obs": obs.cpu(),
                        "act": action.cpu(),
                        "rew": reward.cpu(),
                        "done": dones.cpu(),
                        "obs_next": next_obs.cpu(),
                    }))

                    obs = next_obs

            # actor / critic updates on mixed real + model batches
            for _ in range(self.args['steps_per_epoch']):
                batch = train_buffer.sample(real_batch_size)
                model_batch = model_buffer.sample(model_batch_size)
                batch.cat_(model_batch)
                batch.to_torch(device=self.device)

                self._cql_update(batch, meta_policy)

            res = callback_fn(self.get_policy())
            if reward is not None:
                res['reward'] = reward.mean().item()
            self.log_res(epoch, res)

        return self.get_policy()

    def _sac_update(self, batch_data):
        """Run a plain SAC update of both critics, the temperature and the actor (no CQL penalty).

        ``train_policy`` does not call this; it uses ``_cql_update``. This method is an
        alternative update rule. It previously referenced non-existent attributes (``q1``,
        ``q2``, ``target_q1``, ``target_q2``, ``critic_optim``). It now uses the same networks
        as ``_cql_update``: ``critic1``/``critic2``, their target copies and their per-critic
        optimizers.

        Args:
            batch_data: batch with keys ``obs``, ``act``, ``obs_next``, ``rew`` and ``done``.
        """
        obs = batch_data['obs']
        action = batch_data['act']
        next_obs = batch_data['obs_next']
        reward = batch_data['rew']
        done = batch_data['done']

        # update critics
        _q1 = self.critic1(obs, action)
        _q2 = self.critic2(obs, action)

        with torch.no_grad():
            next_action_dist = self.actor(next_obs)
            next_action = next_action_dist.sample()
            log_prob = next_action_dist.log_prob(next_action).sum(dim=-1, keepdim=True)
            _target_q1 = self.critic1_target(next_obs, next_action)
            _target_q2 = self.critic2_target(next_obs, next_action)
            alpha = torch.exp(self.log_alpha)
            y = reward + self.args['discount'] * (1 - done) * (torch.min(_target_q1, _target_q2) - alpha * log_prob)

        critic_loss = ((y - _q1) ** 2).mean() + ((y - _q2) ** 2).mean()

        self.critic1_optim.zero_grad()
        self.critic2_optim.zero_grad()
        critic_loss.backward()
        self.critic1_optim.step()
        self.critic2_optim.step()

        # soft target update
        self._sync_weight(self.critic1_target, self.critic1, soft_target_tau=self.args['soft_target_tau'])
        self._sync_weight(self.critic2_target, self.critic2, soft_target_tau=self.args['soft_target_tau'])

        # 'learnable_alpha' keeps its original meaning; fall back to the entropy-tuning flag
        if self.args.get('learnable_alpha', self.args['use_automatic_entropy_tuning']):
            target_entropy = getattr(self, 'target_entropy', None)
            if target_entropy is None:
                target_entropy = self.args['target_entropy']
            alpha_loss = - torch.mean(self.log_alpha * (log_prob + target_entropy).detach())

            self.log_alpha_optim.zero_grad()
            alpha_loss.backward()
            self.log_alpha_optim.step()

        # update actor
        action_dist = self.actor(obs)
        new_action = action_dist.rsample()
        action_log_prob = action_dist.log_prob(new_action)
        q = torch.min(self.critic1(obs, new_action), self.critic2(obs, new_action))
        actor_loss = - q.mean() + torch.exp(self.log_alpha) * action_log_prob.sum(dim=-1).mean()

        self.actor_optim.zero_grad()
        actor_loss.backward()
        self.actor_optim.step()

    def forward_cql(self, obs, reparameterize=True, return_log_prob=True):
        """Sample actions from the actor for the CQL losses.

        Args:
            obs: observations fed to ``self.actor``.
            reparameterize: when exactly ``True``, draw with ``rsample`` so gradients flow into
                the actor; any other value draws with ``sample``.
            return_log_prob: when true, also return the summed log-probability of each action,
                computed from the pre-tanh value returned by the sampler.

        Returns:
            ``(action, log_prob)``. ``log_prob`` is ``None`` when ``return_log_prob`` is false;
            otherwise it is summed over dim 1 with ``keepdim=True``.
        """
        tanh_normal = self.actor(obs, reparameterize=reparameterize, )
        draw = tanh_normal.rsample if reparameterize is True else tanh_normal.sample

        if not return_log_prob:
            return draw(), None

        action, pre_tanh_value = draw(return_pretanh_value=True)
        per_dim_log_prob = tanh_normal.log_prob(action, pre_tanh_value=pre_tanh_value)
        return action, per_dim_log_prob.sum(dim=1, keepdim=True)

    def _get_tensor_values_cql(self, obs, actions, network):
        """Evaluate ``network`` when every observation comes with several actions.

        ``actions`` holds ``k = int(len(actions) / len(obs))`` consecutive rows per row of
        ``obs``. Each observation is repeated ``k`` times to line up with its actions, and the
        outputs are reshaped to ``(len(obs), k, 1)``.
        """
        n_obs, obs_dim = obs.shape[0], obs.shape[1]
        repeats = int(actions.shape[0] / n_obs)
        tiled_obs = obs.unsqueeze(1).expand(-1, repeats, -1).reshape(n_obs * repeats, obs_dim)
        values = network(tiled_obs, actions)
        return values.view(n_obs, repeats, 1)

    def _get_policy_actions_cql(self, obs, num_actions, network=None):
        """Sample ``num_actions`` actions for every observation using ``network``.

        ``network`` is called like ``forward_cql`` with reparameterised sampling and
        log-probabilities. For continuous actions this returns ``(actions, log_pi)`` with
        ``log_pi`` reshaped to ``(len(obs), num_actions, 1)``. When ``args['discrete']`` is
        set, it returns only ``actions``.
        """
        n_obs, obs_dim = obs.shape[0], obs.shape[1]
        tiled_obs = obs.unsqueeze(1).expand(-1, num_actions, -1).reshape(n_obs * num_actions, obs_dim)
        sampled_actions, sampled_log_pi = network(
            tiled_obs, reparameterize=True, return_log_prob=True,
        )
        if self.args["discrete"]:
            return sampled_actions
        return sampled_actions, sampled_log_pi.view(n_obs, num_actions, 1)

    def _cql_update(self, batch_data, meta_policy):
        """Run one update of the temperature, the actor and both critics (CQL-style).

        Optimizer steps happen in this order:

        1. the temperature, if ``use_automatic_entropy_tuning`` is on;
        2. the actor;
        3. the Lagrange multiplier ``alpha_prime``, if ``lagrange_thresh >= 0``;
        4. critic1;
        5. critic2;
        6. a soft update of both target critics.

        The statement order matters because later losses reuse parts of the same autograd graph.

        The actor loss is the SAC loss plus two regularisers weighted by ``self.beta`` and
        ``self.gamma``: ``_kl_divergence_meta`` (a KL term towards ``meta_policy``) and
        ``_kl_divergence_beha`` (the negative log-likelihood of the batch actions).

        Args:
            batch_data: batch with keys ``obs``, ``act``, ``obs_next``, ``rew`` and ``done``.
                ``train_policy`` moves it to ``self.device`` before calling this method.
            meta_policy: reference policy passed to ``_kl_divergence_meta``. It must expose
                ``forward_std(obs) -> (mean, std)``.
        """
        obs = batch_data['obs']
        actions = batch_data['act']
        next_obs = batch_data['obs_next']
        rewards = batch_data['rew']
        terminals = batch_data['done']

        """
        Policy and Alpha Loss
        """
        # reparameterised actions and their log-probs; both carry gradients into the actor
        new_obs_actions, log_pi = self.forward_cql(obs)

        # temperature update: log_pi is detached, so only log_alpha receives gradient here
        if self.args["use_automatic_entropy_tuning"]:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
            self.alpha_opt.zero_grad()
            alpha_loss.backward()
            self.alpha_opt.step()
            # NOTE(review): alpha is not detached, so policy_loss.backward() below also writes into
            # log_alpha.grad. alpha_opt.zero_grad() clears it before the next temperature step.
            alpha = self.log_alpha.exp()
        else:
            alpha_loss = 0
            alpha = 1

        # if self._current_epoch < self.args["policy_bc_steps"]:
        #     """
        #     For the initial few epochs, try doing behaivoral cloning, if needed
        #     conventionally, there's not much difference in performance with having 20k
        #     gradient steps here, or not having it
        #     """
        #     policy_log_prob = self.actor.log_prob(obs, actions)
        #     policy_loss = (alpha * log_pi - policy_log_prob).mean()
        # else:
        #     q_new_actions = torch.min(
        #         self.critic1(obs, new_obs_actions),
        #         self.critic2(obs, new_obs_actions),
        #     )
        #
        #     policy_loss = (alpha * log_pi - q_new_actions).mean()
        # clipped double-Q estimate of the freshly sampled actions
        q_new_actions = torch.min(
            self.critic1(obs, new_obs_actions),
            self.critic2(obs, new_obs_actions),
        )
        # policy_loss = (alpha * log_pi - q_new_actions).mean()

        # evaluate divergence between pi and pi_beta without estimating pi_beta
        # div_behavior = self._divergence.dual_estimate(obs, new_obs_actions, actions, self._c_fn)

        # evaluate divergence between pi and pi_meta
        div_meta = self._kl_divergence_meta(obs, self.actor, meta_policy)

        # negative log-likelihood of the batch actions (real and model-generated) under the actor
        div_beha = self._kl_divergence_beha(obs, actions, self.actor)

        # SAC loss + beta * gamma * KL(pi || meta) + beta * (1 - gamma) * NLL(batch actions)
        policy_loss = (alpha * log_pi + self.beta * self.gamma * div_meta + self.beta * (1-self.gamma) * div_beha- q_new_actions).mean()

        # this backward also leaves gradients on the critics; the critic zero_grad() calls below clear them
        self.actor_optim.zero_grad()
        policy_loss.backward()
        self.actor_optim.step()

        """
        QF Loss
        """
        q1_pred = self.critic1(obs, actions)
        q2_pred = self.critic2(obs, actions)

        # new_next_actions feed the Bellman target. new_curr_actions / new_curr_log_pi are
        # computed but not used below.
        new_next_actions, new_log_pi = self.forward_cql(
            next_obs, reparameterize=True, return_log_prob=True,
        )
        new_curr_actions, new_curr_log_pi = self.forward_cql(
            obs, reparameterize=True, return_log_prob=True,
        )

        # Bellman backup variant selected by args["type_q_backup"]
        if self.args["type_q_backup"] == "max":
            target_q_values = torch.max(
                self.critic1_target(next_obs, new_next_actions),
                self.critic2_target(next_obs, new_next_actions),
            )
            target_q_values = target_q_values - alpha * new_log_pi

        elif self.args["type_q_backup"] == "min":
            target_q_values = torch.min(
                self.critic1_target(next_obs, new_next_actions),
                self.critic2_target(next_obs, new_next_actions),
            )
            target_q_values = target_q_values - alpha * new_log_pi
        elif self.args["type_q_backup"] == "medium":
            # convex mix of the min and max of the two target critics
            target_q1_next = self.critic1_target(next_obs, new_next_actions)
            target_q2_next = self.critic2_target(next_obs, new_next_actions)
            target_q_values = self.args["q_backup_lmbda"] * torch.min(target_q1_next, target_q2_next) \
                              + (1 - self.args["q_backup_lmbda"]) * torch.max(target_q1_next, target_q2_next)
            target_q_values = target_q_values - alpha * new_log_pi

        else:
            """when using max q backup"""
            # As written, this branch evaluates the online critics (not the targets) on 10 sampled
            # next actions, takes the per-state max, and has no entropy term.
            next_actions_temp, _ = self._get_policy_actions_cql(next_obs, num_actions=10, network=self.forward_cql)
            target_qf1_values = self._get_tensor_values_cql(next_obs, next_actions_temp, network=self.critic1).max(1)[
                0].view(-1, 1)
            target_qf2_values = self._get_tensor_values_cql(next_obs, next_actions_temp, network=self.critic2).max(1)[
                0].view(-1, 1)
            target_q_values = torch.min(target_qf1_values, target_qf2_values)

        # the target is detached; only q1_pred / q2_pred receive gradients from the Bellman loss
        q_target = self.args["reward_scale"] * rewards + (1. - terminals) * self.args[
            "discount"] * target_q_values.detach()

        qf1_loss = self.critic_criterion(q1_pred, q_target)
        qf2_loss = self.critic_criterion(q2_pred, q_target)

        ## add CQL
        # Conservative penalty: logsumexp of Q(obs, .) over uniform random actions in [-1, 1],
        # policy actions sampled at obs, and policy actions sampled at next_obs (all evaluated at obs).
        random_actions_tensor = torch.FloatTensor(q2_pred.shape[0] * self.args["num_random"],
                                                  actions.shape[-1]).uniform_(-1, 1).to(self.args["device"])
        curr_actions_tensor, curr_log_pis = self._get_policy_actions_cql(obs, num_actions=self.args["num_random"],
                                                                         network=self.forward_cql)
        new_curr_actions_tensor, new_log_pis = self._get_policy_actions_cql(next_obs,
                                                                            num_actions=self.args["num_random"],
                                                                            network=self.forward_cql)
        q1_rand = self._get_tensor_values_cql(obs, random_actions_tensor, network=self.critic1)
        q2_rand = self._get_tensor_values_cql(obs, random_actions_tensor, network=self.critic2)
        q1_curr_actions = self._get_tensor_values_cql(obs, curr_actions_tensor, network=self.critic1)
        q2_curr_actions = self._get_tensor_values_cql(obs, curr_actions_tensor, network=self.critic2)
        q1_next_actions = self._get_tensor_values_cql(obs, new_curr_actions_tensor, network=self.critic1)
        q2_next_actions = self._get_tensor_values_cql(obs, new_curr_actions_tensor, network=self.critic2)

        # concatenate along the sampled-action axis (dim 1); q*_pred presumably has shape (batch, 1)
        cat_q1 = torch.cat([q1_rand, q1_pred.unsqueeze(1), q1_next_actions, q1_curr_actions], 1)
        cat_q2 = torch.cat([q2_rand, q2_pred.unsqueeze(1), q2_next_actions, q2_curr_actions], 1)

        if self.args["min_q_version"] == 3:
            # importance sammpled version
            # Subtract each proposal's log-density. log(0.5 ** dim) is the log-density of the
            # uniform distribution on [-1, 1]^dim used for random_actions_tensor. This replaces
            # cat_q1 / cat_q2 above, so the dataset Q-values are not included here.
            random_density = np.log(0.5 ** curr_actions_tensor.shape[-1])
            cat_q1 = torch.cat(
                [q1_rand - random_density, q1_next_actions - new_log_pis.detach(),
                 q1_curr_actions - curr_log_pis.detach()], 1
            )
            cat_q2 = torch.cat(
                [q2_rand - random_density, q2_next_actions - new_log_pis.detach(),
                 q2_curr_actions - curr_log_pis.detach()], 1
            )

        # temperature-scaled logsumexp over sampled actions, averaged over the batch
        min_qf1_loss = torch.logsumexp(cat_q1 / self.args["temp"], dim=1, ).mean() * self.args["min_q_weight"] * \
                       self.args["temp"]
        min_qf2_loss = torch.logsumexp(cat_q2 / self.args["temp"], dim=1, ).mean() * self.args["min_q_weight"] * \
                       self.args["temp"]

        """Subtract the log likelihood of data"""
        min_qf1_loss = min_qf1_loss - q1_pred.mean() * self.args["min_q_weight"]
        min_qf2_loss = min_qf2_loss - q2_pred.mean() * self.args["min_q_weight"]

        if self.args["lagrange_thresh"] >= 0:
            # alpha_prime is updated to maximise the constrained penalty (hence the negated loss).
            # retain_graph keeps the graph alive because min_qf*_loss are reused in the critic losses.
            alpha_prime = torch.clamp(self.log_alpha_prime.exp(), min=0.0, max=1000000.0)
            min_qf1_loss = alpha_prime * (min_qf1_loss - self.args["lagrange_thresh"])
            min_qf2_loss = alpha_prime * (min_qf2_loss - self.args["lagrange_thresh"])

            self.alpha_prime_opt.zero_grad()
            alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5
            alpha_prime_loss.backward(retain_graph=True)
            self.alpha_prime_opt.step()

        # weight the Bellman loss by args["explore"] and the CQL penalty by (2 - explore)
        qf1_loss = self.args["explore"] * qf1_loss + (2 - self.args["explore"]) * min_qf1_loss
        qf2_loss = self.args["explore"] * qf2_loss + (2 - self.args["explore"]) * min_qf2_loss

        """
        Update critic networks
        """
        # the two critic losses share parts of the graph (e.g. the sampled policy actions),
        # so it is retained for qf2_loss.backward()
        self.critic1_optim.zero_grad()
        qf1_loss.backward(retain_graph=True)
        self.critic1_optim.step()

        self.critic2_optim.zero_grad()
        qf2_loss.backward()
        self.critic2_optim.step()

        """
        Soft Updates target network
        """
        self._sync_weight(self.critic1_target, self.critic1, self.args["soft_target_tau"])
        self._sync_weight(self.critic2_target, self.critic2, self.args["soft_target_tau"])

        # self._n_train_steps_total += 1

    def _select_best_indexes(self, metrics, n):
        """Return the positions of the ``n`` smallest values in ``metrics``, best first.

        Ties keep their original order (the sort is stable). An ``IndexError`` is raised when
        ``n`` exceeds ``len(metrics)``.
        """
        ranked = sorted(range(len(metrics)), key=lambda position: metrics[position])
        return [ranked[k] for k in range(n)]

    def _train_transition(self, transition, data, optim):
        """Take one maximum-likelihood gradient step on the dynamics model.

        ``data`` is converted to tensors on ``self.device`` in place. The loss is the mean
        negative log-likelihood of ``[obs_next, rew]`` given ``[obs, act]``, plus a small penalty
        (``0.01 * mean(max_logstd) - 0.01 * mean(min_logstd)``) that keeps the log-std bounds tight.
        """
        data.to_torch(device=self.device)
        predicted = transition(torch.cat([data['obs'], data['act']], dim=-1))
        targets = torch.cat([data['obs_next'], data['rew']], dim=-1)
        nll = - predicted.log_prob(targets)

        loss = nll.mean() + 0.01 * transition.max_logstd.mean() - 0.01 * transition.min_logstd.mean()

        optim.zero_grad()
        loss.backward()
        optim.step()

    def _eval_transition(self, transition, valdata):
        """Return the validation MSE of the predicted mean, one value per selected ensemble member.

        ``valdata`` is converted to tensors on ``self.device`` in place. The squared error is
        averaged over dimensions 1 and 2 of the model output.
        """
        with torch.no_grad():
            valdata.to_torch(device=self.device)
            predicted_mean = transition(torch.cat([valdata['obs'], valdata['act']], dim=-1)).mean
            target = torch.cat([valdata['obs_next'], valdata['rew']], dim=-1)
            per_member_mse = (predicted_mean - target).pow(2).mean(dim=(1, 2))
        return list(per_member_mse.cpu().numpy())

    def _kl_divergence_meta(self, state, policy_1, policy_2):
        """Closed-form KL(policy_1 || policy_2) between diagonal Gaussians, averaged over states.

        ``policy_2`` is treated as a fixed reference. Its statistics are computed under
        ``torch.no_grad()``, so no graph is built through it and its parameters no longer
        accumulate never-cleared gradients. Gradients with respect to ``policy_1`` are unchanged.

        Args:
            state: batch of observations fed to both policies.
            policy_1: policy being optimised; must expose ``forward_std(state) -> (mean, std)``.
            policy_2: reference (meta) policy with the same ``forward_std`` interface.

        Returns:
            Scalar tensor: the KL summed over dim 1 (assumed to be the action dimension of
            ``(batch, action_dim)`` outputs) and averaged over the batch.
        """
        mean1, std1 = policy_1.forward_std(state)
        with torch.no_grad():
            mean2, std2 = policy_2.forward_std(state)
        # KL(N(m1, s1^2) || N(m2, s2^2)) = log(s2 / s1) + (s1^2 + (m1 - m2)^2) / (2 s2^2) - 1/2
        kl_matrix = ((torch.log(std2 / std1)) + 0.5 * (std1.pow(2)
                                                       + (mean1 - mean2).pow(2)) / std2.pow(2) - 0.5)

        # Sum over action dim, average over all states
        return kl_matrix.sum(1).mean()

    def _kl_divergence_beha(self, state, action, policy_1):
        """Behaviour-regularisation term for the actor loss.

        Despite the name, this is not a KL divergence. It is the negative log-likelihood of
        ``action`` under the distribution ``policy_1.forward(state)``, summed over dimension 1
        (presumably the action dimension) and averaged over the batch. Minimising it pulls the
        policy towards the given actions without an explicit behaviour-policy model.
        """
        action_dist = policy_1.forward(state)
        neg_log_likelihood = -action_dist.log_prob(action)
        return neg_log_likelihood.sum(1).mean()