import numpy as np
import math
import TwoPlayerGame
import pickle

class RobsutQlearning():
    """Tabular robust Q-learning for two agents facing two state-perturbing adversaries.

    Q-tables ``q1``/``q2`` are indexed ``[state][agent1][agent2][adv1][adv2]``
    (two states, two actions per player). Each adversary's action decides
    whether its agent observes the true state or the flipped one (see
    :meth:`get_noised_state`). Agents act on the state they observe, while the
    adversaries' policies are indexed by the true state.

    Args:
        q1, q2: Initial Q-tables for agent 1 and agent 2.
        game_model: Object exposing ``reward_func`` (indexed like the Q-tables)
            and ``transition_prob`` (indexed like the Q-tables, then by next
            state).
        lr: Learning rate of the Q update.
        gamma: Discount factor.

    ``self.policy`` and ``self.v`` are refreshed at the end of every
    :meth:`train` call, so they always describe the current Q-tables.
    """

    def __init__(self, q1, q2, game_model, lr, gamma):
        self.gamma = gamma
        self.lr = lr
        # observed_state = state_noised[true_state][adversary_action]
        self.state_noised = self.get_noised_state()
        self.q1 = q1
        self.q2 = q2
        self.policy = self.get_policy(self.q1, self.q2)
        self.reward_func = game_model.reward_func
        self.transition_prob = game_model.transition_prob
        self.v = self.get_v(self.q1, self.q2)

    @staticmethod
    def _softmax(logits, axis=-1):
        """Numerically stable softmax of ``logits`` along ``axis``."""
        # Subtracting the max leaves the result unchanged but keeps exp() from
        # overflowing (math.e ** x raises OverflowError for x above ~709).
        shifted = logits - logits.max(axis=axis, keepdims=True)
        weights = np.exp(shifted)
        return weights / weights.sum(axis=axis, keepdims=True)

    def get_noised_state(self):
        """Return the 2x2 observation table ``state_noised[true_state][adv]``.

        Adversary action 0 flips the observed state; action 1 leaves it intact.
        """
        return np.array([[1 - state if adv == 0 else state for adv in range(2)]
                         for state in range(2)])

    def get_policy(self, q1, q2):
        """Derive softmax policies from marginalised Q-values.

        Each player's logits are its Q-table summed over every axis except the
        state axis and that player's own action axis.

        Returns:
            Tuple ``(pi1, pi2, rho1, rho2)`` of ``(n_states, n_actions)``
            arrays whose rows are probability distributions: agent 1, agent 2,
            adversary 1 and adversary 2 action probabilities.
        """
        q1 = np.asarray(q1, dtype=float)
        q2 = np.asarray(q2, dtype=float)
        # Axes: 0=state, 1=agent1, 2=agent2, 3=adv1, 4=adv2. Explicit axis sums
        # are required: chained ``[:]`` indexing never reaches deeper axes.
        pi1 = self._softmax(q1.sum(axis=(2, 3, 4)))
        pi2 = self._softmax(q2.sum(axis=(1, 3, 4)))
        # NOTE(review): adversaries use softmax(+Q), i.e. they favour actions
        # that raise the agents' value; a worst-case (robust) formulation would
        # typically use -Q. Confirm intent.
        rho1 = self._softmax(q1.sum(axis=(1, 2, 4)))
        rho2 = self._softmax(q2.sum(axis=(1, 2, 3)))
        return pi1, pi2, rho1, rho2

    def get_total_sum(self, q1, q2, next_state, policy=None):
        """Return the expected Q-value of both agents in ``next_state``.

        Each joint action ``(agent1, agent2, adv1, adv2)`` is weighted by
        ``pi1[obs1][agent1] * pi2[obs2][agent2] * rho1[s][adv1] * rho2[s][adv2]``,
        where ``obs_i = state_noised[s][adv_i]`` is the state agent *i* sees.

        Args:
            q1, q2: Q-tables of agent 1 and agent 2.
            next_state: True state to evaluate.
            policy: Optional precomputed ``get_policy(q1, q2)`` result; it is
                recomputed when omitted.

        Returns:
            ``(value1, value2)`` for agent 1 and agent 2.
        """
        if policy is None:
            policy = self.get_policy(q1, q2)
        pi1, pi2, rho1, rho2 = policy
        q1 = np.asarray(q1, dtype=float)
        q2 = np.asarray(q2, dtype=float)
        observed = np.asarray(self.state_noised[next_state])
        # pi1_obs[adv1, agent1] = pi1[state observed under adv1][agent1].
        pi1_obs = pi1[observed]
        pi2_obs = pi2[observed]
        # weights[agent1, agent2, adv1, adv2]; the weights sum to 1.
        weights = np.einsum('ia,jb,i,j->abij', pi1_obs, pi2_obs,
                            rho1[next_state], rho2[next_state])
        return (q1[next_state] * weights).sum(), (q2[next_state] * weights).sum()

    def update(self, q1, q2):
        """Perform one synchronous Q-learning sweep over every (state, joint action).

        The successor state is taken deterministically as the more likely of the
        two next states (ties go to state 1).

        Returns:
            New ``(q1, q2)`` float arrays; the inputs are not written to.
        """
        q1 = np.asarray(q1, dtype=float)
        q2 = np.asarray(q2, dtype=float)
        policy = self.get_policy(q1, q2)
        # The bootstrapped value depends only on the successor state, so it is
        # computed once per possible successor instead of once per table entry.
        next_values = [self.get_total_sum(q1, q2, s, policy) for s in range(2)]
        q1_new = np.empty_like(q1)
        q2_new = np.empty_like(q2)
        for idx in np.ndindex(*q1.shape):
            state, agent1, agent2, adv1, adv2 = idx
            probs = self.transition_prob[state][agent1][agent2][adv1][adv2]
            next_state = 1 if probs[1] >= probs[0] else 0
            target1, target2 = next_values[next_state]
            reward = self.reward_func[state][agent1][agent2][adv1][adv2]
            q1_new[idx] = (1 - self.lr) * q1[idx] + self.lr * (reward + self.gamma * target1)
            q2_new[idx] = (1 - self.lr) * q2[idx] + self.lr * (reward + self.gamma * target2)
        return q1_new, q2_new

    def train(self, step):
        """Run ``step`` update sweeps, then refresh ``self.policy`` and ``self.v``."""
        for _ in range(step):
            self.q1, self.q2 = self.update(self.q1, self.q2)
        # policy and v are derived from the Q-tables; without this refresh they
        # would keep describing the initial tables forever.
        self.policy = self.get_policy(self.q1, self.q2)
        self.v = self.get_v(self.q1, self.q2)

    def get_v(self, q1, q2):
        """Return state values as a 2x2 array ``v[agent][state]``."""
        policy = self.get_policy(q1, q2)
        v1_0, v2_0 = self.get_total_sum(q1, q2, 0, policy)
        v1_1, v2_1 = self.get_total_sum(q1, q2, 1, policy)
        return np.array([[v1_0, v1_1], [v2_0, v2_1]])

    def test_func(self, pi1, pi2, rho1, rho2, step):
        """Roll out the given policies for ``step`` steps from state 0 and sum the rewards.

        Actions are sampled with ``np.random.choice``. Transitions go to the
        more likely next state (ties go to state 1), as in :meth:`update`.
        """
        def test_performance(pi1, pi2, rho1, rho2, s=0):
            # Adversaries act on the true state; agents act on what they observe.
            b1 = np.random.choice([0, 1], p=rho1[s])
            b2 = np.random.choice([0, 1], p=rho2[s])
            s_noise_1 = self.state_noised[s][b1]
            s_noise_2 = self.state_noised[s][b2]
            a1 = np.random.choice([0, 1], p=pi1[s_noise_1])
            a2 = np.random.choice([0, 1], p=pi2[s_noise_2])
            r = self.reward_func[s][a1][a2][b1][b2]
            if self.transition_prob[s][a1][a2][b1][b2][0] > self.transition_prob[s][a1][a2][b1][b2][1]:
                new_s = 0
            else:
                new_s = 1
            return (s, a1, a2, b1, b2, r, new_s)

        s = 0
        r_sum = 0
        # Every simulated step's reward is counted, including the last one.
        for _ in range(step):
            _, _, _, _, _, r, s = test_performance(pi1, pi2, rho1, rho2, s)
            r_sum += r
        return r_sum

if __name__ == "__main__":
    # Train robust Q-learning on the two-player game, report a rollout's total
    # reward, then persist and echo the per-epoch Q-table histories.
    game_model = TwoPlayerGame.two_player_game()
    robsut_q = RobsutQlearning(
        q1=np.zeros((2, 2, 2, 2, 2)),
        q2=np.zeros((2, 2, 2, 2, 2)),
        game_model=game_model,
        lr=0.1,
        gamma=0.99,
    )

    v_list = []
    q1_list = []
    q2_list = []
    for i in range(800):
        robsut_q.train(25)
        v_list.append(robsut_q.v)
        q1_list.append(robsut_q.q1)
        q2_list.append(robsut_q.q2)
        print("----------- " + str(i) + " epsides-----------")

    pi1, pi2, rho1, rho2 = robsut_q.policy
    print(robsut_q.test_func(pi1, pi2, rho1, rho2, 1000))

    # NOTE: both files store their history under the key 'v' (kept as-is so
    # existing readers of these pickles keep working). `with` closes the files.
    histories = (('q_list1.pickle', q1_list), ('q_list2.pickle', q2_list))
    for path, history in histories:
        with open(path, 'wb') as f:
            pickle.dump({'v': history}, f)

    # Read back the files this script just wrote (trusted local data).
    for path, _ in histories:
        with open(path, 'rb') as f:
            print(pickle.load(f))

