from BaseAgent import BaseAgent
import numpy as np
import copy


class GreedyGQ(BaseAgent):
    """Greedy-GQ control agent with linear action-value features.

    Keeps two weight vectors: ``weights_theta`` (primal; used by
    ``action_value``/``greedy_policy`` for the action-value estimate) and
    ``weights_w`` (dual/auxiliary). Each :meth:`update` snapshots both
    vectors first and computes the new theta and w from those snapshots, so
    the two updates are simultaneous rather than sequential.

    ``features``, ``num_features``, ``action_value`` and ``greedy_policy``
    are not defined in this class; they are presumably provided by
    ``BaseAgent`` -- verify against that class.
    """

    def __init__(self, env, config):
        """Create the agent and initialise its weights.

        Args:
            env: Environment object. Its ``env_name`` selects the weight
                initializer (see :meth:`init_weight`).
            config: Mapping providing ``'gamma'``, ``'beta'``, ``'alpha'``
                and ``'eta'``.
        """
        super().__init__(env)

        self.env = env

        self.gamma = config['gamma']  # discount applied to the bootstrap term
        self.beta = config["beta"]    # step size for the dual weights w
        self.alpha = config["alpha"]  # step size for the primal weights theta
        self.eta = config['eta']      # scale of the -(phi . w) term in the w update

        self.updates = 0  # number of completed update() calls
        self.weights_theta = None
        self.weights_w = None

        # Snapshots taken at the start of each update(); both theta and w
        # are computed from these so neither sees the other's new value.
        self.prev_weights_theta = None
        self.prev_weights_w = None

        self.init_weight(self.env.env_name)

    def primal_weight(self):
        """Return the current primal weight vector ``weights_theta``."""
        return self.weights_theta

    def dual_weight(self):
        """Return the current dual weight vector ``weights_w``."""
        return self.weights_w

    def init_weight(self, weight_initializer):
        """Initialise ``weights_theta`` and ``weights_w`` by name.

        Args:
            weight_initializer: ``'Baird'`` or ``'ThetaTwoTheta'``.

        Raises:
            KeyError: if ``weight_initializer`` is not a known name.
        """
        initializers = {
            'Baird': self.baird_weight,
            'ThetaTwoTheta': self.theta_two_theta,
        }
        try:
            initializer = initializers[weight_initializer]
        except KeyError:
            raise KeyError(
                "Unknown weight initializer {!r}; expected one of {}".format(
                    weight_initializer, sorted(initializers)
                )
            ) from None
        initializer()

    def theta_two_theta(self):
        """Set both weight vectors to all ones of length ``num_features``."""
        self.weights_theta = np.ones(self.num_features)
        self.weights_w = np.ones(self.num_features)

    def baird_weight(self):
        """Initialise weights for the Baird setting.

        theta is all ones except the component at ``env.SEVENTH_STATE``,
        which is set to 10; w starts at all zeros.
        """
        self.weights_theta = np.ones(self.num_features)
        self.weights_theta[self.env.SEVENTH_STATE] = 10
        self.weights_w = np.zeros(self.num_features)

    def _greedy_next_action(self, next_state):
        """Return the greedy action in ``next_state`` under the theta snapshot."""
        return self.greedy_policy(next_state, self.prev_weights_theta)

    def _td_error_given(self, state, action, next_state, next_action, reward, done_mask):
        """Return the TD error for an already-chosen bootstrap action.

        delta = reward + done_mask * gamma * Q(next_state, next_action)
                - Q(state, action), with Q evaluated at the theta snapshot.
        """
        q_sa = self.action_value(state, action, self.prev_weights_theta)
        next_q = self.action_value(next_state, next_action, self.prev_weights_theta)
        return reward + done_mask * self.gamma * next_q - q_sa

    def td_error(self, state, action, next_state, reward, done_mask):
        """Return the TD error using the greedy action in ``next_state``.

        ``done_mask`` multiplies the bootstrap term; presumably it is 0 on
        terminal transitions and 1 otherwise -- confirm with the caller.
        Requires ``prev_weights_theta`` to have been set (done by update()).
        """
        next_action = self._greedy_next_action(next_state)
        return self._td_error_given(
            state, action, next_state, next_action, reward, done_mask
        )

    def update_theta(self, state, action, next_state, reward, done_mask, next_action=None):
        """Compute the new ``weights_theta`` from the snapshots.

        theta <- theta + alpha * (delta * phi
                                  - done_mask * gamma * (phi . w) * phi')

        where phi = features[action, state] and phi' is the feature vector
        of the greedy action in ``next_state``.

        Args:
            next_action: Bootstrap action to use. When ``None`` the greedy
                action under the theta snapshot is chosen here; update()
                passes it in so theta and w share the same choice.
        """
        if next_action is None:
            next_action = self._greedy_next_action(next_state)
        # delta and phi' must refer to the same bootstrap action.
        td_error = self._td_error_given(
            state, action, next_state, next_action, reward, done_mask
        )
        phi = self.features[action, state]
        next_phi = self.features[next_action, next_state]
        phi_w = np.dot(phi, self.prev_weights_w)

        gradient = td_error * phi - done_mask * self.gamma * phi_w * next_phi
        self.weights_theta = self.prev_weights_theta + self.alpha * gradient

    def update_w(self, state, action, next_state, reward, done_mask, next_action=None):
        """Compute the new ``weights_w`` from the snapshots.

        w <- w + beta * (delta - eta * (phi . w)) * phi

        Args:
            next_action: Bootstrap action to use; ``None`` means choose the
                greedy action under the theta snapshot here.
        """
        if next_action is None:
            next_action = self._greedy_next_action(next_state)
        td_error = self._td_error_given(
            state, action, next_state, next_action, reward, done_mask
        )
        phi = self.features[action, state]
        phi_w = np.dot(phi, self.prev_weights_w)

        self.weights_w = self.prev_weights_w + self.beta * (td_error - self.eta * phi_w) * phi

    def update(self, state, next_state, action, reward, done_mask):
        """Perform one simultaneous Greedy-GQ step on theta and w.

        Note the argument order here is (state, next_state, action, ...),
        unlike update_theta/update_w which take (state, action, next_state, ...).
        """
        # Independent copies so in-place changes elsewhere cannot alter them.
        self.prev_weights_theta = self.weights_theta.copy()
        self.prev_weights_w = self.weights_w.copy()

        # Choose the bootstrap action once so both updates use the same one.
        next_action = self._greedy_next_action(next_state)
        self.update_theta(state, action, next_state, reward, done_mask, next_action=next_action)
        self.update_w(state, action, next_state, reward, done_mask, next_action=next_action)

        self.updates += 1