from BaseAgent import BaseAgent
import numpy as np
import copy

class CoupledQ(BaseAgent):
    """Coupled Q-learning agent with linear function approximation.

    Two weight vectors are maintained and updated together on every
    transition (see ``update``), both from a snapshot of their pre-step
    values:

    * ``weights_u`` is moved (step size ``alpha``) towards
      ``features[action, state] * action_value(state, action, v)``. It is
      the vector used to choose and evaluate the next action in the TD
      target.
    * ``weights_v`` is moved (step size ``beta``) along
      ``features[action, state]`` by the TD error.

    ``self.features``, ``self.num_features``, ``self.action_value`` and
    ``self.greedy_policy`` are not defined here. They are presumably
    provided by ``BaseAgent`` (TODO confirm).
    """

    def __init__(self, env, config):
        """Create the agent and initialize its weights.

        Args:
            env: Environment object. ``env.env_name`` selects the weight
                initializer (see ``init_weight``). The "Baird" initializer
                also reads ``env.SEVENTH_STATE``.
            config: Mapping with ``"gamma"`` (discount applied to the
                bootstrapped value), ``"alpha"`` (step size for
                ``weights_u``) and ``"beta"`` (step size for ``weights_v``).

        Raises:
            KeyError: if a config key is missing or ``env.env_name`` has no
                registered initializer.
        """
        super().__init__(env)

        self.env = env
        self.gamma = config["gamma"]
        self.beta = config["beta"]
        self.alpha = config["alpha"]

        self.weights_u = None
        self.weights_v = None

        # Snapshots taken at the start of ``update`` so that both vectors
        # are updated from the same pre-step values.
        self.prev_weights_u = None
        self.prev_weights_v = None

        self.init_weight(self.env.env_name)

    def primal_weight(self):
        """Return ``weights_v``, the vector reported as the primal weights."""
        return self.weights_v

    def dual_weight(self):
        """Return ``weights_u``, the vector reported as the dual weights."""
        return self.weights_u

    def baird_weight(self):
        """Initialize both weight vectors to ones, except that
        ``weights_v[env.SEVENTH_STATE]`` is set to 10.

        NOTE(review): presumably the standard starting point for Baird's
        counterexample. Confirm against the environment definition.
        """
        self.weights_u = np.ones(self.num_features)
        self.weights_v = np.ones(self.num_features)
        self.weights_v[self.env.SEVENTH_STATE] = 10

    def theta_two_theta(self):
        """Initialize both weight vectors to all ones."""
        self.weights_u = np.ones(self.num_features)
        self.weights_v = np.ones(self.num_features)

    def init_weight(self, weight_initializer_key):
        """Run the weight initializer registered under ``weight_initializer_key``.

        Args:
            weight_initializer_key: Either ``"Baird"`` or ``"ThetaTwoTheta"``.

        Raises:
            KeyError: if the key is not a registered initializer. The
                message lists the supported keys.
        """
        initializers = {
            "Baird": self.baird_weight,
            "ThetaTwoTheta": self.theta_two_theta,
        }
        try:
            initializer = initializers[weight_initializer_key]
        except KeyError:
            raise KeyError(
                f"Unknown weight initializer {weight_initializer_key!r}; "
                f"expected one of {sorted(initializers)}"
            ) from None
        initializer()

    def update_u(self, state, next_state, action, reward, done_mask):
        """Set ``weights_u`` from the snapshotted weights.

        u <- u_prev + alpha * (phi * action_value(state, action, v_prev) - u_prev)

        Here ``phi`` is ``self.features[action, state]``.
        ``prev_weights_u`` and ``prev_weights_v`` must already be populated
        (``update`` does this). ``next_state``, ``reward`` and ``done_mask``
        are unused; they are accepted for a uniform update signature.
        """
        # The estimate is deliberately taken with the *v* weights: u tracks
        # phi * (value of (state, action) under v).
        v_estimate = self.action_value(state, action, self.prev_weights_v)
        tracking_error = self.features[action, state] * v_estimate - self.prev_weights_u

        self.weights_u = self.prev_weights_u + self.alpha * tracking_error

    def update_v(self, state, next_state, action, reward, done_mask):
        """Set ``weights_v`` by a semi-gradient TD step.

        The TD target bootstraps from ``prev_weights_u``: the next action is
        chosen greedily w.r.t. u and evaluated with u. The current value is
        evaluated with ``prev_weights_v``. ``done_mask`` scales the
        bootstrapped term; presumably it is 0 on terminal transitions and 1
        otherwise (TODO confirm with caller).
        """
        current_value = self.action_value(state, action, self.prev_weights_v)
        greedy_next_action = self.greedy_policy(next_state, self.prev_weights_u)
        next_value = self.action_value(next_state, greedy_next_action, self.prev_weights_u)

        td_error = reward + done_mask * self.gamma * next_value - current_value

        # For a linear value function, the semi-gradient w.r.t. v is the
        # feature vector of (state, action).
        phi = self.features[action, state]

        self.weights_v = self.prev_weights_v + self.beta * td_error * phi

    def update(self, state, next_state, action, reward, done_mask):
        """Apply one coupled update for the transition (state, action, reward, next_state).

        Both vectors are snapshotted first, so ``update_v`` sees the pre-step
        ``weights_u`` even though ``update_u`` runs before it. Each vector is
        computed from the same old values.
        """
        self.prev_weights_u = copy.deepcopy(self.weights_u)
        self.prev_weights_v = copy.deepcopy(self.weights_v)

        self.update_u(state, next_state, action, reward, done_mask)
        self.update_v(state, next_state, action, reward, done_mask)
