import numpy as np
import scipy
import scipy.special
import torch


class off_policy_surrogate:
    """Sample-based surrogate objective over a fixed batch of (observation, action) pairs.

    The policy and the two critics are evaluated once in ``__init__`` on
    ``N_samples`` stacked copies of ``observations`` paired with ``act``. The
    results are cached as NumPy arrays reshaped to ``(N_samples, -1)``.

    Axis 0 indexes the samples. Every distribution below is normalised over
    axis 0 with ``logsumexp``, and per-column expectations are then averaged
    with ``np.mean``.

    Calling the object with ``(eta, lmbda)`` returns the scalar objective. The
    terms ``eta * eps``, ``-eta * KL(q || pi)`` and
    ``lmbda * (x - kappa / lmbda)`` suggest a Lagrangian-dual style objective.
    NOTE(review): confirm against the derivation this is meant to implement.
    """

    def __init__(
        self, policy, qfn, cfn, observations, affinity, eps, budget, kappa, act,
        N_samples=32, log_reweight=None
    ):
        """Evaluate the policy and critics once and cache the results as arrays.

        Args:
            policy: callable ``policy(obs, act)`` returning a pair. The first
                element is the log-probability tensor of ``act``; the second
                element is discarded.
            qfn: callable ``qfn(obs, act)`` returning a tensor of values that
                are weighted by ``q`` in the objective (presumably action
                values).
            cfn: callable ``cfn(obs, act)`` returning a tensor compared
                against ``budget`` (presumably costs).
            observations: tensor repeated ``N_samples`` times along dim 0 so
                it can be paired with ``act``.
            affinity: positive scalar/array; ``log(affinity)`` offsets ``g``.
            eps: multiplier of ``eta`` in the objective (``eta * eps``).
            budget: threshold; only the excess ``c_fn - budget`` above zero is
                subtracted in ``g``.
            kappa: coefficient of the ``-kappa * log(lmbda)`` term.
            act: action tensor. Presumably its rows line up with the stacked
                observations; TODO confirm with callers.
            N_samples: number of stacked copies; the size of axis 0 of every
                cached array.
            log_reweight: optional array added to ``logp`` to form the
                reference distribution used by ``__call__`` and ``get_x``.
                Defaults to zeros.
        """

        def to_samples(t):
            # .numpy() rejects tensors on an accelerator (e.g. CUDA);
            # detach()/cpu() are no-ops for plain CPU tensors.
            return t.reshape(N_samples, -1).detach().cpu().numpy()

        self.budget = budget
        with torch.no_grad():
            self.obs_supersample = torch.cat([observations] * N_samples, 0)
            self.N_samples = N_samples
            self.act = act
            log_prob, _ = policy(self.obs_supersample, self.act)
            self.logp = to_samples(log_prob)
            self.q_fn = to_samples(qfn(self.obs_supersample, self.act))
            self.c_fn = to_samples(cfn(self.obs_supersample, self.act))

        self.affinity = affinity
        # "Likelihood" of satisfying the constraint, in log space: log(affinity),
        # reduced by however much the cost exceeds the budget.
        self.g = np.log(self.affinity) - np.maximum(self.c_fn - self.budget, 0)
        self.eps = eps
        self.kappa = kappa
        if log_reweight is None:
            self.log_reweight = np.zeros_like(self.logp)
        else:
            self.log_reweight = log_reweight

    @staticmethod
    def _normalize(log_w):
        """Return ``log_w`` shifted so that ``exp`` of it sums to 1 along axis 0."""
        return log_w - scipy.special.logsumexp(log_w, 0)

    @staticmethod
    def _expect(log_w, values):
        """Sum ``exp(log_w) * values`` over axis 0, then average the remaining entries."""
        return np.mean(np.sum(np.exp(log_w) * values, 0))

    def _reference_log_pi(self):
        """Normalised log-weights of the reweighted reference ``log_reweight + logp``."""
        return self._normalize(self.log_reweight + self.logp)

    def _constraint_gap(self, logq, normalized_pi):
        """Expected ``g`` under ``logq`` minus expected ``g`` under ``normalized_pi``."""
        return self._expect(logq, self.g) - self._expect(normalized_pi, self.g)

    def __call__(self, l):
        """Evaluate the surrogate objective at ``l = (eta, lmbda)``.

        Returns the sum of these five terms:

        * ``E_q[q_fn]``
        * ``-kappa * log(lmbda)``
        * ``eta * eps``
        * ``-eta * KL(q || pi_ref)``
        * ``lmbda * (x - kappa / lmbda)``

        Here ``x`` is ``get_x(eta, lmbda)`` and ``pi_ref`` is the normalised
        ``log_reweight + logp``. ``lmbda`` must be positive for ``log(lmbda)``
        to be finite.
        """
        eta, lmbda = l

        logq = self.get_q(eta, lmbda)
        normalized_pi = self._reference_log_pi()
        kappa = self.kappa
        # Same value get_x would return, but reuses logq / normalized_pi
        # instead of recomputing both.
        x = self._constraint_gap(logq, normalized_pi)
        k = kappa / lmbda

        fst = self._expect(logq, self.q_fn)
        gap = -kappa * np.log(lmbda)
        snd = eta * self.eps
        trd = -eta * self._expect(logq, logq - normalized_pi)
        fth = lmbda * (x - k)
        return fst + gap + snd + trd + fth

    def get_x(self, eta, lmbda):
        """Return ``E_q[g] - E_pi_ref[g]``.

        ``q`` comes from ``get_q(eta, lmbda)``. ``pi_ref`` is the normalised
        ``log_reweight + logp``.
        """
        return self._constraint_gap(self.get_q(eta, lmbda), self._reference_log_pi())

    def get_q(self, eta, lmbda, log_pi=None, q_fn=None, g=None, log_reweight=None):
        """Return ``log q``, with ``q`` proportional to ``exp(log_pi + (q_fn + lmbda * g) / eta)``.

        The result is normalised over axis 0. ``log_pi``, ``q_fn`` and ``g``
        default to the cached ``self.logp``, ``self.q_fn`` and ``self.g``.
        The ``log_reweight`` argument is accepted but not used by this
        implementation.

        NOTE(review): the default ``log_pi`` is ``self.logp`` without
        ``self.log_reweight``. ``__call__`` and ``get_x`` compare ``q``
        against the normalised ``log_reweight + logp``, so the two agree
        only when ``log_reweight`` is zero. Confirm this is intended.
        """
        if log_pi is None:
            log_pi = self.logp
        if q_fn is None:
            q_fn = self.q_fn
        if g is None:
            g = self.g
        return self._normalize(log_pi + (q_fn + lmbda * g) / eta)

    def get_kl_div(self, eta, lmbda,):
        """Return ``KL(q || pi)``, averaged over columns.

        ``pi`` here is the normalised *un-reweighted* ``logp``.
        NOTE(review): this differs from the reweighted reference used in
        ``__call__`` whenever ``log_reweight`` is non-zero.
        """
        normalized_q = self.get_q(eta, lmbda)
        normalized_pi = self._normalize(self.logp)
        return self._expect(normalized_q, normalized_q - normalized_pi)

    def get_constraint_diff(self, eta, lmbda):
        """Return ``(new, old)`` expected ``g`` values, each weighted by ``exp(log_reweight)``.

        NOTE(review): the reference is already ``normalize(log_reweight + logp)``,
        so adding ``log_reweight`` again applies the reweighting twice to the
        ``old`` term. ``get_x`` applies it once, and only to the reference.
        Left unchanged; confirm which is intended.
        """
        logq = self.get_q(eta, lmbda)
        normalized_pi = self._reference_log_pi()
        new = self._expect(self.log_reweight + logq, self.g)
        old = self._expect(self.log_reweight + normalized_pi, self.g)
        return new, old

    def expected_returns(self, eta, lmbda,):
        """Return ``(E_q[q_fn], E_pi[q_fn])``, with ``pi`` the normalised un-reweighted ``logp``."""
        logq = self.get_q(eta, lmbda)
        normalized_pi = self._normalize(self.logp)
        return self._expect(logq, self.q_fn), self._expect(normalized_pi, self.q_fn)
