from .base_critic import BaseCritic
import torch
import torch.optim as optim
from torch import nn
from copy import deepcopy
import torch.nn.functional as F
import numpy as np
import math

from  infreastructure import pytorch_util as ptu

# In SAC, a large alpha encourages exploration while a small alpha makes the policy more stable, so it is better to train alpha as a learnable parameter.
class CQLCritic(BaseCritic):
    def __init__(self, hparams, optimizer_class, config, q1, q2, q1_target, q2_target, v_net, beta, **kwargs):
        """Wire up the externally built networks and create their optimizers.

        Args:
            hparams: mapping read for the keys ``'ob_dim'``, ``'ac_dim'``,
                ``'q_arch'``, ``'orthogonal_init'``, ``'q_lr'`` and ``'device'``.
            optimizer_class: callable used to build every optimizer; it is
                called as ``optimizer_class(params, lr)`` and
                ``optimizer_class(params, lr=...)``.
            config: object read for ``expectile``, ``v_lr``,
                ``use_automatic_entropy_tuning`` and ``policy_lr`` here.
            q1, q2: online Q-networks (trained jointly by one optimizer).
            q1_target, q2_target: target copies of the Q-networks.
            v_net: state-value network, trained by its own optimizer.
            beta: weight of the penalty applied in ``update`` to the Q-values
                of the actions passed as ``new_ac``; ``0`` disables it.
            **kwargs: forwarded unchanged to ``BaseCritic.__init__``.
        """
        super().__init__(**kwargs)

        # Plain hyper-parameters copied off the hparams mapping.
        self.ob_dim = hparams['ob_dim']
        self.ac_dim = hparams['ac_dim']
        self.arch = hparams['q_arch']
        self.orthogonal_init = hparams['orthogonal_init']
        self.q_lr = hparams['q_lr']

        # Networks are constructed by the caller and only stored here.
        self.q1_net = q1
        self.q2_net = q2
        self.q1_net_target = q1_target
        self.q2_net_target = q2_target
        self.config = config

        device = hparams['device']
        for q_network in (self.q1_net, self.q2_net, self.q1_net_target, self.q2_net_target):
            q_network.to(device)

        # A single optimizer drives both online Q-networks.
        online_q_params = [*self.q1_net.parameters(), *self.q2_net.parameters()]
        self.q_optimizer = optimizer_class(online_q_params, self.q_lr)

        self.v_net = v_net
        self.v_net.to(device)
        self.expectile = self.config.expectile
        self.v_optimizer = optimizer_class(self.v_net.parameters(), self.config.v_lr)

        self.beta = beta

        # The entropy temperature is learnable only when requested; otherwise
        # log_alpha stays None and no alpha optimizer is created.
        if self.config.use_automatic_entropy_tuning:
            self.log_alpha = ptu.Scalar(0.0)
            self.alpha_optimizer = optimizer_class(
                self.log_alpha.parameters(),
                lr=self.config.policy_lr,
            )
        else:
            self.log_alpha = None

        # Rate 1.0 on the very first sync — presumably a hard copy of the
        # online weights into the targets; confirm against
        # ptu.soft_target_updata.
        self.update_target_network(1.0)

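    # On-policy methods do not need to recompute new_ac and log_pi: the policy interacts directly and the critic only computes the Q and V values. Off-policy and offline methods must compute new_ac and log_pi, because the sampled actions were not proposed by the current policy.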
    def update(self, ob_no, ac_na, next_ob_no, re_n, terminal_n,new_ac,log_pi):
        #zeta_loss = self.update_weight(ob_no, ac_na, next_ob_no, init_obs)
        #new_ac, log_pi = self.policy.get_action(ob_no)
        #new_ac, log_pi = self.policy.get_action(ob_no,deterministic=True)

        if self.config.use_automatic_entropy_tuning:
            alpha_loss = -(self.log_alpha() * (log_pi + self.config.target_entropy).detach()).mean()
            alpha = self.log_alpha().exp() * self.config.alpha_multiplier
        else:
            alpha_loss = ob_no.new_tensor(0.0)
            alpha = ob_no.new_tensor(self.config.alpha_multiplier)

        with torch.no_grad():
            q_target_values = torch.min(self.q1_net_target(ob_no, ac_na), self.q2_net_target(ob_no, ac_na))

        v_values = self.v_net(ob_no).squeeze()
        v_loss = self.loss(q_target_values - v_values, self.expectile).mean()
        self.v_optimizer.zero_grad()
        v_loss.backward()
        self.v_optimizer.step()

        if self.beta != 0:
            random_q1, random_q2 = self.q1_net(ob_no, new_ac), self.q2_net(ob_no, new_ac)
            random_q = torch.min(random_q1, random_q2)

        with torch.no_grad():
            next_v_values = self.v_net(next_ob_no).squeeze()
            q_target = re_n + (1 - terminal_n) * self.config.discount * next_v_values
        q1_values = self.q1_net(ob_no, ac_na)
        q2_values = self.q2_net(ob_no, ac_na)

        q1_loss = F.mse_loss(q1_values, q_target)
        q2_loss = F.mse_loss(q2_values, q_target)
        if self.beta == 0:
            q_loss = q1_loss + q2_loss
        else:
            q1_ood_loss = random_q1.mean() * self.beta#(random_q1 - v_values.detach()).mean() * self.beta#random_q1.mean() * self.beta
            q2_ood_loss = random_q2.mean() * self.beta#(random_q2 - v_values.detach()).mean() * self.beta#random_q2.mean() * self.beta
            q_loss = q1_loss + q2_loss + q1_ood_loss + q2_ood_loss

        self.q_optimizer.zero_grad()
        q_loss.backward()
        self.q_optimizer.step()

        q_values = torch.min(q1_values, q2_values)


        '''
        with torch.no_grad():
            q_target_values_0 = torch.min(self.q1_net_0_target(ob_no, ac_na), self.q2_net_0_target(ob_no, ac_na))

        v_values_0 = self.v_net_0(ob_no).squeeze()
        v_loss_0 = self.loss(q_target_values_0 - v_values_0, self.expectile).mean()
        self.v_optimizer_0.zero_grad()
        v_loss_0.backward()
        self.v_optimizer_0.step()

        new_ac_0, log_pi_0 = self.policy.get_action(ob_no)
        random_q1_0,random_q2_0 = self.q1_net_0(ob_no,new_ac_0),self.q2_net_0(ob_no,new_ac_0)
        random_q_0 = torch.min(random_q1_0,random_q2_0)

        with torch.no_grad():
            next_v_values_0 = self.v_net_0(next_ob_no).squeeze()
            q_target_0 = re_n + (1 - terminal_n) * self.config.discount * next_v_values_0
        q1_values_0 = self.q1_net_0(ob_no, ac_na)
        q2_values_0 = self.q2_net_0(ob_no, ac_na)
        q1_loss_0 = F.mse_loss(q1_values_0, q_target_0)
        q2_loss_0 = F.mse_loss(q2_values_0, q_target_0)
        q1_ood_loss_0 = (random_q1_0 - v_values_0.detach()).mean() * 1
        q2_ood_loss_0 = (random_q2_0 - v_values_0.detach()).mean() * 1
        q_loss_0 = q1_loss_0 + q2_loss_0 + q1_ood_loss_0 + q2_ood_loss_0

        self.q_optimizer_0.zero_grad()
        q_loss_0.backward()
        self.q_optimizer_0.step()

        
        q_values_0 = torch.min(q1_values_0,q2_values_0)

        a = random_q - v_values
        a_0 = random_q_0 - v_values_0
        
        a_d = a_0.max() - a.max()
        
        
        init_ac, _ = self.policy.get_action(init_obs)
        zeta = self.zeta(ob_no, ac_na)
        f_zeta = zeta ** self.p / self.p
        B_nu = self.nu(ob_no, ac_na) - self.gamma * self.nu(next_ob_no, next_ac.detach())
        zeta_loss = - torch.sum(
            (B_nu * zeta - f_zeta) - (1 - self.gamma) * self.nu(init_obs, init_ac.detach()))
        nu_loss = -zeta_loss
        '''



        if self.config.use_automatic_entropy_tuning:
            self.alpha_optimizer.zero_grad()
            alpha_loss.backward()
            self.alpha_optimizer.step()



        #with torch.no_grad():
            #q_values = torch.min(self.q1_net_target(ob_no, ac_na), self.q2_net_target(ob_no, ac_na))
            #v_value = self.v_net(ob_no).squeeze()
        #policy_loss = self.policy.update(v_value, q_values, ob_no, ac_na)


        #new_ac, log_pi = self.policy.get_action(ob_no)
        #q_new_values = torch.min(self.q1_net_target(ob_no,new_ac),self.q2_net_target(ob_no,new_ac))
        #policy_loss = self.policy.update(alpha,q_new_values,log_pi)


        log = dict(
            #log_pi=log_pi.mean().item(),
            #policy_loss=policy_loss,
            q1_loss=q1_loss.item(),
            q2_loss=q2_loss.item(),
            alpha_loss=alpha_loss.item(),
            alpha=alpha.item(),
            average_q1=q1_values.mean().item(),
            average_q2=q2_values.mean().item(),
            average_reward=q_values.mean().item(),
            max_reward=q_target_values.max().item(),
            min_reward=q_target_values.min().item(),
            v_values=v_values.mean().item(),
            r=re_n[0].item(),
            v0_values=v_values[0].item(),
            next_v0_values=next_v_values[0].item(),

            q0_values=q_values[0].item(),
            #a_max=a.max().item(),
            #a_0_max=a_0.max().item(),
            #a_d_max=a_d.item(),
            #beta0_v0=v_values_0[0].item(),
            #beta0_random_q1_0=random_q1_0[0].item(),
            #beta0_random_q2_0=random_q2_0[0].item(),
            #beta0_q0_values=q_values_0[0].item(),
        )
        '''
        if self.beta != 0:
            log['rand_q_max'] = random_q.max().item(),
            log['rand_q_min'] = random_q.min().item(),
            log['rand_q_mean'] = random_q.mean().item(),
            log['random_q1_0'] = random_q1[0].item(),
            log['random_q2_0'] = random_q2[0].item(),
        '''
        return log,alpha

    '''
        def update_weight(self,ob_n,ac_n,next_ob,init_obs):
        init_ac,_ = self.policy.get_action(init_obs)
        next_ac,_ = self.policy.get_action(next_ob)
        f_zeta = self.zeta(ob_n,ac_n) ** self.p / self.p
        B_nu = self.nu(ob_n,ac_n) - self.gamma * self.nu(next_ob,next_ac)
        zeta_loss = - torch.sum((B_nu * self.zeta(ob_n,ac_n) - f_zeta) - (1 - self.gamma) * self.nu(init_obs,init_ac))
        nu_loss = -zeta_loss
        loss = zeta_loss + nu_loss
        self.weight_opt.zero_grad()
        loss.backward()
        self.weight_opt.step()
        return zeta_loss
    '''

    def update_target_network(self, soft_target_update_rate):
        """Sync each target Q-network from its online counterpart.

        Calls ``ptu.soft_target_updata(online, target, rate)`` once for the
        Q1 pair and once for the Q2 pair, in that order.

        Args:
            soft_target_update_rate: rate forwarded to the helper. Presumably
                the Polyak mixing coefficient, with 1.0 meaning a full copy —
                confirm against ``ptu.soft_target_updata``.
        """
        network_pairs = (
            (self.q1_net, self.q1_net_target),
            (self.q2_net, self.q2_net_target),
        )
        for online_net, target_net in network_pairs:
            ptu.soft_target_updata(online_net, target_net, soft_target_update_rate)

    def loss(self,value,expectile=0.8):
        weight = torch.where(value > 0 ,expectile,(1 - expectile))
        return weight * (value ** 2)

    def change_beta(self,beta):
        self.beta = beta

    def get_q(self,obs,acs):
        return torch.min(self.q1_net(obs,acs),self.q2_net(obs,acs))