import numpy as np
from infreastructure import pytorch_util as ptu

class SACPolicy(object):
    """Actor wrapper for Soft Actor-Critic: holds a policy network and its optimizer.

    The network itself is injected by the caller rather than built here; the
    architecture-related hyper-parameters read from ``param`` are only stored
    on the instance.
    """

    def __init__(self, param, optimizer_class, policy, dropout=None):
        """Store hyper-parameters, place the policy on its device and build the optimizer.

        Args:
            param: Mapping that must provide the keys ``'ob_dim'``, ``'ac_dim'``,
                ``'policy_arch'``, ``'policy_log_std_multiplier'``,
                ``'policy_log_std_offset'``, ``'orthogonal_init'``, ``'po_rl'``
                and ``'device'``.
            optimizer_class: Callable invoked as
                ``optimizer_class(policy.parameters(), param['po_rl'])``.
            policy: Policy network; must expose ``parameters()``, ``to()``,
                ``log_prob()``, ``save()``, ``load()`` and be callable.
            dropout: Currently unused; kept for interface compatibility.

        Raises:
            KeyError: If a required key is missing from ``param``.
        """
        self.ob_dim = param['ob_dim']
        self.ac_dim = param['ac_dim']
        self.arch = param['policy_arch']
        self.log_std_multiplier = param['policy_log_std_multiplier']
        self.log_std_offset = param['policy_log_std_offset']
        self.orthogonal_init = param['orthogonal_init']
        # Second positional argument of the optimizer -- presumably the policy
        # learning rate; confirm against the optimizer_class used by callers.
        self.po_rl = param['po_rl']
        self._policy = policy
        # Move the network to its device *before* creating the optimizer, as
        # recommended by the torch.optim documentation, so the optimizer (and
        # any state it allocates eagerly) is built against the final parameters.
        self._policy.to(param['device'])
        self.policy_optimizer = optimizer_class(self._policy.parameters(), self.po_rl)

    def get_action(self, ob, deterministic=False, repeat=None):
        """Run the policy on ``ob`` and return its ``(action, log_pi)`` pair.

        ``deterministic`` and ``repeat`` are forwarded positionally to the
        policy unchanged; their meaning is defined by the policy network.
        """
        action, log_pi = self._policy(ob, deterministic, repeat)
        return action, log_pi

    def get_policy(self):
        """Return the wrapped policy network."""
        return self._policy

    def get_policy_log(self, ob, ac):
        """Return ``policy.log_prob(ob, ac)`` for the given observation/action."""
        return self._policy.log_prob(ob, ac)

    def update(self, alpha, q, log_pi):
        """Take one optimizer step on the SAC policy objective.

        The loss is ``mean(alpha * log_pi - q)``. It comes from the KL
        divergence between the policy and ``exp(q - log Z)``: the update pulls
        the policy towards that target, and ``log Z`` is dropped because it is
        a constant that does not affect the gradient.

        Args:
            alpha: Entropy temperature multiplying ``log_pi``.
            q: Q-value estimate for the sampled actions.
            log_pi: Log-probability of the sampled actions under the policy.

        Returns:
            The loss object that was back-propagated.
        """
        loss = (alpha * log_pi - q).mean()
        self.policy_optimizer.zero_grad()
        # retain_graph=True keeps the autograd graph alive after this backward
        # pass -- presumably because other losses back-propagate through the
        # same graph afterwards; TODO confirm with the training loop.
        loss.backward(retain_graph=True)
        self.policy_optimizer.step()
        return loss

    def save(self, path):
        """Delegate to ``policy.save(path)``."""
        self._policy.save(path)

    def load(self, path):
        """Delegate to ``policy.load(path)``."""
        self._policy.load(path)