from collections import OrderedDict
from copy import deepcopy

from ml_collections import ConfigDict

import numpy as np
import torch
import torch.optim as optim
from torch import nn as nn
import torch.nn.functional as F

from .model import Scalar, soft_target_update
from .utils import prefix_metrics


class ConservativeSAC(object):

    @staticmethod
    def get_default_config(updates=None):
        """Return the default hyper-parameters, optionally overridden.

        Args:
            updates: Optional mapping (or ``ConfigDict``) whose entries are
                merged on top of the defaults after its field references
                have been resolved.

        Returns:
            A ``ConfigDict`` containing every default listed below, with
            ``updates`` applied when given.
        """
        defaults = (
            # Actor-critic / entropy settings read by ``__init__`` and ``train``.
            ('discount', 0.99),
            ('alpha_multiplier', 1.0),
            ('use_automatic_entropy_tuning', True),
            ('backup_entropy', False),
            ('target_entropy', 0.0),
            ('policy_lr', 1e-4),
            ('qf_lr', 3e-4),
            ('optimizer_type', 'adam'),
            ('soft_target_update_rate', 5e-3),
            ('target_update_period', 1),
            # Settings read by the CQL branch of ``train``.
            ('use_cql', True),
            ('cql_n_actions', 10),
            ('cql_importance_sample', True),
            ('cql_lagrange', True),
            ('cql_target_action_gap', 0.2),
            ('cql_temp', 1.0),
            ('cql_min_q_weight', 5.0),
            ('cql_max_target_backup', True),
            ('cql_clip_diff_min', -200.),
            ('cql_clip_diff_max', np.inf),
            # Settings read by ``train_osd`` and by the sample weighting in ``train``.
            ('osd_alpha', 0.1),
            ('lamb_scale', 1.0),
            ('v_l2_reg', 1e-4),
            ('osd_lower', 0.1),
            ('osd_higher', 10.0),
            ('beta', 1e-3),
        )

        config = ConfigDict()
        for key, value in defaults:
            config[key] = value

        if updates is None:
            return config

        config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    def __init__(self, config, policy, qf1, qf2, target_qf1, target_qf2, nu_network):
        """Store the networks, build their optimizers and define the OSD helpers.

        Args:
            config: Overrides merged into ``get_default_config()``.
            policy: Policy module; ``train`` calls it as ``policy(observations)``
                and ``policy(observations, repeat=n)``.
            qf1: First critic, optimised together with ``qf2``.
            qf2: Second critic.
            target_qf1: Target network paired with ``qf1``.
            target_qf2: Target network paired with ``qf2``.
            nu_network: Network mapping observations to the values used by
                ``train_osd`` and by the sample weighting in ``train``.

        Raises:
            KeyError: If ``config.optimizer_type`` is neither ``'adam'`` nor
                ``'sgd'``.
        """
        self.config = ConservativeSAC.get_default_config(config)
        self.policy = policy
        self.qf1 = qf1
        self.qf2 = qf2
        self.target_qf1 = target_qf1
        self.target_qf2 = target_qf2
        self.nu_network = nu_network
        self.osd_alpha = self.config.osd_alpha
        self.lamb_scale = self.config.lamb_scale
        self.beta = self.config.beta

        optimizer_class = {
            'adam': torch.optim.Adam,
            'sgd': torch.optim.SGD,
        }[self.config.optimizer_type]

        self.policy_optimizer = optimizer_class(
            self.policy.parameters(), self.config.policy_lr,
        )
        self.qf_optimizer = optimizer_class(
            list(self.qf1.parameters()) + list(self.qf2.parameters()), self.config.qf_lr
        )
        # The nu network shares the critics' learning rate.
        self.nu_optimizer = optimizer_class(
            self.nu_network.parameters(), self.config.qf_lr
        )

        if self.config.use_automatic_entropy_tuning:
            self.log_alpha = Scalar(0.0)
            self.alpha_optimizer = optimizer_class(
                self.log_alpha.parameters(),
                lr=self.config.policy_lr,
            )
        else:
            self.log_alpha = None

        # Learnable scalar subtracted from the residual e_v in train_osd/train.
        self.lam_v = Scalar(0.0)
        self.lam_v_optimizer = optimizer_class(
            self.lam_v.parameters(),
            lr=self.config.qf_lr,
        )

        if self.config.cql_lagrange:
            self.log_alpha_prime = Scalar(1.0)
            self.alpha_prime_optimizer = optimizer_class(
                self.log_alpha_prime.parameters(),
                lr=self.config.qf_lr,
            )

        # Rate 1.0 — presumably a full copy of the critics into the targets;
        # confirm against soft_target_update in .model.
        self.update_target_network(1.0)
        self._total_steps = 0

        self.zero = torch.zeros(1)

        def _non_positive(x):
            # min(x, 0) evaluated on x's device. ``self.zero`` is looked up at
            # call time because train_osd re-binds it after moving devices.
            return torch.minimum(x, self.zero.to(x.device))

        # torch.where evaluates both branches, so the exponent is capped at 0
        # to keep exp() bounded for inputs that end up in the other branch.
        self._f_fn = lambda x: torch.where(
            x < 1, x * (torch.log(x + 1e-10) - 1) + 1, 0.5 * (x - 1) ** 2
        )
        self._f_prime_inv_fn = lambda x: torch.where(
            x < 0, torch.exp(_non_positive(x)), x + 1
        )
        # Previously this used ``self.zero`` without moving it to x's device,
        # unlike the neighbouring helpers; it now goes through _non_positive.
        self._g_fn = lambda x: torch.where(
            x < 0,
            torch.exp(_non_positive(x)) * (_non_positive(x) - 1) + 1,
            0.5 * x ** 2,
        )
        self._r_fn = lambda x: self._f_prime_inv_fn(x)
        self._log_r_fn = lambda x: torch.where(
            x < 0, x, torch.log(torch.maximum(x, self.zero.to(x.device)) + 1)
        )

    def update_target_network(self, soft_target_update_rate):
        """Update both target critics from their online counterparts.

        Args:
            soft_target_update_rate: Rate forwarded unchanged to
                ``soft_target_update`` for each (critic, target) pair.
                NOTE(review): presumably a Polyak-averaging coefficient;
                confirm against ``soft_target_update`` in ``.model``.
        """
        critic_pairs = (
            (self.qf1, self.target_qf1),
            (self.qf2, self.target_qf2),
        )
        for online_qf, target_qf in critic_pairs:
            soft_target_update(online_qf, target_qf, soft_target_update_rate)

    def orthogonal_regularization(self, network):
        reg = 0
        # for layer in network.layers:
        #   if isinstance(layer, tf.keras.layers.Dense):
        #     prod = tf.matmul(tf.transpose(layer.kernel), layer.kernel)
        #     reg += tf.reduce_sum(tf.math.square(prod * (1 - tf.eye(prod.shape[0]))))
        for k,v in network.named_parameters():
            # print(k,v.shape)
            if 'weight' in k:
                # print(k,v.shape)
                prod = torch.mm(v, v.T)
                reg += (torch.square(prod * (1 - torch.eye(prod.shape[0],device=v.device)))).sum()
        #print('reg',reg.shape)
        
        return reg

        
    def train_osd(self, batch):
        self._total_steps += 1
        observations = batch['observations']
        actions = batch['actions']
        rewards = (batch['rewards']+5)/100.
        next_observations = batch['next_observations']
        dones = batch['dones']      
        init_observations =  batch['init_observations']

        initial_v_values = self.nu_network(init_observations)
        v_values = self.nu_network(observations)
        next_v_values = self.nu_network(next_observations)
        e_v = rewards + (1 - dones) * self.config.discount * next_v_values - v_values
        self.lam_v = self.lam_v.to(dones.device)
        self.zero= self.zero.to(dones.device)

        preactivation_v = (e_v - self.lamb_scale * self.lam_v()) / self.osd_alpha
        w_v = self._r_fn(preactivation_v)
        f_w_v = self._g_fn(preactivation_v)        
        
        v_loss0 = ((1 - self.config.discount) * initial_v_values).mean()
        v_loss1 = - self.osd_alpha * f_w_v.mean()

        v_loss2 = (w_v * (e_v - self.lam_v())).mean()
        v_loss3 = self.lam_v()
        v_loss4 = torch.square(e_v).mean()

        v_loss = v_loss0 + v_loss1 + v_loss2 + v_loss3 + self.beta *v_loss4
        v_l2_norm = self.orthogonal_regularization(self.nu_network)
        v_loss = v_loss + self.config.v_l2_reg * v_l2_norm
        
        lamb_v_loss = (- self.osd_alpha * f_w_v.detach() + w_v.detach() * (e_v.detach() - self.config.lamb_scale * self.lam_v()) + self.beta * torch.square(e_v) + self.lam_v()).mean()
        # lamb_v_loss = (- self.osd_alpha * f_w_v + w_v * (e_v - self.config.lamb_scale * self.lam_v()) + self.lam_v()).mean()
        
        self.nu_optimizer.zero_grad()
        v_loss.backward(retain_graph=True)
        self.nu_optimizer.step()


        self.lam_v_optimizer.zero_grad()
        lamb_v_loss.backward(retain_graph=True)
        self.lam_v_optimizer.step()

        
        metrics = dict(
            v_loss0=v_loss0.item(),
            v_loss1=v_loss1.item(),
            v_loss2=v_loss2.item(),
            v_loss3=v_loss3.item(),
            v_loss=v_loss.item(),
            v_l2_norm=v_l2_norm.item(),
            lamb_v_loss=lamb_v_loss.item(),
            lam_v=self.lam_v().item(),
            w_v_mean=w_v.mean().item(),
            w_v_min=w_v.min().item(),
            w_v_max=w_v.max().item(),
            total_steps=self.total_steps,
        )
        return metrics

    def train(self, batch, bc=False):
        self._total_steps += 1

        observations = batch['observations']
        actions = batch['actions']
        rewards = batch['rewards']
        next_observations = batch['next_observations']
        dones = batch['dones']

        """OSD weight"""
        # random_matrix = torch.rand_like(rewards)
        # weights= torch.mul(random_matrix, 9.9) + 0.1
        
        v_values = self.nu_network(observations)
        next_v_values = self.nu_network(next_observations)
        e_v = rewards + (1 - dones) * self.config.discount * next_v_values - v_values
        preactivation_v = (e_v -  self.lam_v()) / self.osd_alpha
        w_v = self._r_fn(preactivation_v)
        medium = w_v.quantile(q=0.5)
        weights = torch.clamp(w_v - medium + 1., self.config.osd_lower, self.config.osd_higher).detach()
        # weights = torch.clamp(w_v, self.config.osd_lower, self.config.osd_higher).detach()
        
        new_actions, log_pi = self.policy(observations)

        if self.config.use_automatic_entropy_tuning:
            alpha_loss = -((self.log_alpha() * (log_pi + self.config.target_entropy).detach()) * weights).mean()/weights.mean()
            alpha = self.log_alpha().exp() * self.config.alpha_multiplier
        else:
            alpha_loss = observations.new_tensor(0.0)
            alpha = observations.new_tensor(self.config.alpha_multiplier)
            
            


        """ Policy loss """
        if bc:
            log_probs = self.policy.log_prob(observations, actions)
            policy_loss = (alpha*log_pi - log_probs).mean()
        else:
            q_new_actions = torch.min(
                self.qf1(observations, new_actions),
                self.qf2(observations, new_actions),
            )
            policy_loss = ((alpha*log_pi - q_new_actions) * weights).mean()/weights.mean()

        """ Q function loss """
        q1_pred = self.qf1(observations, actions)
        q2_pred = self.qf2(observations, actions)

        if self.config.cql_max_target_backup:
            new_next_actions, next_log_pi = self.policy(next_observations, repeat=self.config.cql_n_actions)
            target_q_values, max_target_indices = torch.max(
                torch.min(
                    self.target_qf1(next_observations, new_next_actions),
                    self.target_qf2(next_observations, new_next_actions),
                ),
                dim=-1
            )
            next_log_pi = torch.gather(next_log_pi, -1, max_target_indices.unsqueeze(-1)).squeeze(-1)
        else:
            new_next_actions, next_log_pi = self.policy(next_observations)
            target_q_values = torch.min(
                self.target_qf1(next_observations, new_next_actions),
                self.target_qf2(next_observations, new_next_actions),
            )

        if self.config.backup_entropy:
            target_q_values = target_q_values - alpha * next_log_pi

        td_target = rewards + (1. - dones) * self.config.discount * target_q_values
        qf1_loss =  ((q1_pred - td_target).pow(2) * weights).mean()/(weights.mean())
        qf2_loss = ((q2_pred - td_target).pow(2) * weights).mean()/(weights.mean())


        ### CQL
        if not self.config.use_cql:
            qf_loss = qf1_loss + qf2_loss
        else:
            batch_size = actions.shape[0]
            action_dim = actions.shape[-1]
            cql_random_actions = actions.new_empty((batch_size, self.config.cql_n_actions, action_dim), requires_grad=False).uniform_(-1, 1)
            cql_current_actions, cql_current_log_pis = self.policy(observations, repeat=self.config.cql_n_actions)
            cql_next_actions, cql_next_log_pis = self.policy(next_observations, repeat=self.config.cql_n_actions)
            cql_current_actions, cql_current_log_pis = cql_current_actions.detach(), cql_current_log_pis.detach()
            cql_next_actions, cql_next_log_pis = cql_next_actions.detach(), cql_next_log_pis.detach()

            cql_q1_rand = self.qf1(observations, cql_random_actions)
            cql_q2_rand = self.qf2(observations, cql_random_actions)
            cql_q1_current_actions = self.qf1(observations, cql_current_actions)
            cql_q2_current_actions = self.qf2(observations, cql_current_actions)
            cql_q1_next_actions = self.qf1(observations, cql_next_actions)
            cql_q2_next_actions = self.qf2(observations, cql_next_actions)

            cql_cat_q1 = torch.cat(
                [cql_q1_rand, torch.unsqueeze(q1_pred, 1), cql_q1_next_actions, cql_q1_current_actions], dim=1
            )
            cql_cat_q2 = torch.cat(
                [cql_q2_rand, torch.unsqueeze(q2_pred, 1), cql_q2_next_actions, cql_q2_current_actions], dim=1
            )
            cql_std_q1 = torch.std(cql_cat_q1, dim=1)
            cql_std_q2 = torch.std(cql_cat_q2, dim=1)

            if self.config.cql_importance_sample:
                random_density = np.log(0.5 ** action_dim)
                cql_cat_q1 = torch.cat(
                    [cql_q1_rand - random_density,
                     cql_q1_next_actions - cql_next_log_pis.detach(),
                     cql_q1_current_actions - cql_current_log_pis.detach()],
                    dim=1
                )
                cql_cat_q2 = torch.cat(
                    [cql_q2_rand - random_density,
                     cql_q2_next_actions - cql_next_log_pis.detach(),
                     cql_q2_current_actions - cql_current_log_pis.detach()],
                    dim=1
                )

            cql_qf1_ood = torch.logsumexp(cql_cat_q1 / self.config.cql_temp, dim=1) * self.config.cql_temp
            cql_qf2_ood = torch.logsumexp(cql_cat_q2 / self.config.cql_temp, dim=1) * self.config.cql_temp

            """Subtract the log likelihood of data"""
            cql_qf1_diff = (torch.clamp(
                cql_qf1_ood - q1_pred,
                self.config.cql_clip_diff_min,
                self.config.cql_clip_diff_max,
            ) * weights).mean()/weights.mean()
            cql_qf2_diff = (torch.clamp(
                cql_qf2_ood - q2_pred,
                self.config.cql_clip_diff_min,
                self.config.cql_clip_diff_max,
            )*  weights).mean()/weights.mean()

            if self.config.cql_lagrange:
                alpha_prime = torch.clamp(torch.exp(self.log_alpha_prime()), min=0.0, max=1000000.0)
                cql_min_qf1_loss = alpha_prime * self.config.cql_min_q_weight * (cql_qf1_diff - self.config.cql_target_action_gap)
                cql_min_qf2_loss = alpha_prime * self.config.cql_min_q_weight * (cql_qf2_diff - self.config.cql_target_action_gap)

                self.alpha_prime_optimizer.zero_grad()
                alpha_prime_loss = (-cql_min_qf1_loss - cql_min_qf2_loss)*0.5
                alpha_prime_loss.backward(retain_graph=True)
                self.alpha_prime_optimizer.step()
            else:
                cql_min_qf1_loss = cql_qf1_diff * self.config.cql_min_q_weight
                cql_min_qf2_loss = cql_qf2_diff * self.config.cql_min_q_weight
                alpha_prime_loss = observations.new_tensor(0.0)
                alpha_prime = observations.new_tensor(0.0)


            qf_loss = qf1_loss + qf2_loss + cql_min_qf1_loss + cql_min_qf2_loss


        if self.config.use_automatic_entropy_tuning:
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()

        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        self.qf_optimizer.zero_grad()
        qf_loss.backward()
        self.qf_optimizer.step()

        if self.total_steps % self.config.target_update_period == 0:
            self.update_target_network(
                self.config.soft_target_update_rate
            )


        metrics = dict(
            log_pi=log_pi.mean().item(),
            policy_loss=policy_loss.item(),
            qf1_loss=qf1_loss.item(),
            qf2_loss=qf2_loss.item(),
            alpha_loss=alpha_loss.item(),
            alpha=alpha.item(),
            average_qf1=q1_pred.mean().item(),
            average_qf2=q2_pred.mean().item(),
            average_target_q=target_q_values.mean().item(),
            total_steps=self.total_steps,
        )

        if self.config.use_cql:
            metrics.update(prefix_metrics(dict(
                cql_std_q1=cql_std_q1.mean().item(),
                cql_std_q2=cql_std_q2.mean().item(),
                cql_q1_rand=cql_q1_rand.mean().item(),
                cql_q2_rand=cql_q2_rand.mean().item(),
                cql_min_qf1_loss=cql_min_qf1_loss.mean().item(),
                cql_min_qf2_loss=cql_min_qf2_loss.mean().item(),
                cql_qf1_diff=cql_qf1_diff.mean().item(),
                cql_qf2_diff=cql_qf2_diff.mean().item(),
                cql_q1_current_actions=cql_q1_current_actions.mean().item(),
                cql_q2_current_actions=cql_q2_current_actions.mean().item(),
                cql_q1_next_actions=cql_q1_next_actions.mean().item(),
                cql_q2_next_actions=cql_q2_next_actions.mean().item(),
                alpha_prime_loss=alpha_prime_loss.item(),
                alpha_prime=alpha_prime.item(),
            ), 'cql'))

        return metrics

    def torch_to_device(self, device):
        for module in self.modules:
            module.to(device)
        self.lam_v = self.lam_v.to(device)


    @property
    def modules(self):
        modules = [self.policy, self.qf1, self.qf2, self.target_qf1, self.target_qf2, self.nu_network, self.lam_v]
        if self.config.use_automatic_entropy_tuning:
            modules.append(self.log_alpha)
        if self.config.cql_lagrange:
            modules.append(self.log_alpha_prime)
        # modules.append(self.lamb_scale)
        # modules.append(self.osd_alpha)
        return modules

    @property
    def total_steps(self):
        """Number of ``train`` and ``train_osd`` calls made so far (both increment it)."""
        return self._total_steps
