import numpy as np
from Env import Env
import math


class SoftAQLearner:
    """Tabular agent that learns two Q tables and mixes their bootstrap targets.

    The agent keeps two action-value tables, ``Q1`` and ``Q2``, of shape
    ``(world_size, world_size, NUM_ACTIONS)``.  On every learning step one
    table is picked at random and updated.  The bootstrap value for the
    update is read either from the table being updated or from the other
    table; which one is decided by a random draw whose probabilities are a
    softmax over the absolute TD errors of the two candidate targets.

    Args:
        epsilon: Base exploration rate; the effective rate for a state is
            ``epsilon / sqrt(visits)``.
        gamma: Discount factor applied to the bootstrap value.
        learningRate: Base step size; the effective step size for a
            state-action pair is ``learningRate / updates`` of that pair in
            the table being updated.
        parameter: Multiplier applied to the absolute TD errors before the
            softmax (an inverse temperature).
    """

    # Size of the action axis of every table created by ``init_Q_table``.
    NUM_ACTIONS = 4

    def __init__(self, epsilon=1.0, gamma=0.95, learningRate=1.0, parameter=1):
        self.learningRate = learningRate
        self.epsilon = epsilon
        self.gamma = gamma
        self.parameter = parameter
        self.init_Q_table()

    def init_Q_table(self):
        """Create the Q tables and the visit/update counters.

        The Q tables start as small Gaussian noise (mean 0, std 0.01); all
        counters start at zero.
        """
        # Build the environment once: only its ``world_size`` is needed, and a
        # single read guarantees that every table gets the same dimensions.
        world_size = Env().world_size
        table_shape = (world_size, world_size, self.NUM_ACTIONS)
        self.Q1 = np.random.normal(0, 0.01, size=table_shape)
        self.Q2 = np.random.normal(0, 0.01, size=table_shape)
        self.Count_S_A_1 = np.zeros(table_shape)
        self.Count_S_A_2 = np.zeros(table_shape)
        self.Count_S = np.zeros((world_size, world_size))

    def explore(self, state):
        """Pick an action for ``state`` with a decaying epsilon-greedy rule.

        The visit counter of ``state`` is incremented first, so the
        exploration rate is ``epsilon / sqrt(visits)`` with ``visits >= 1``.

        Args:
            state: Indexable whose first two items are the grid indices.

        Returns:
            The index of the chosen action: the argmax of the mean of the two
            Q tables when exploiting, a uniformly random action otherwise.
        """
        row, col = state[0], state[1]
        self.Count_S[row, col] += 1
        epsilon_temp = self.epsilon / np.power(self.Count_S[row, col], 0.5)
        if np.random.random() >= epsilon_temp:
            mean_q = (self.Q1[row, col] + self.Q2[row, col]) / 2.0
            return np.argmax(mean_q)
        # Take the number of actions from the table itself so exploration
        # always matches the table that is actually in use.
        return np.random.choice(self.Q1.shape[-1])

    def _soft_select(self, own, other, state, action, reward, next_state):
        """Randomly decide which table supplies the bootstrap value.

        Two candidate targets are formed for the greedy next action of
        ``own``: one bootstrapped from ``own`` and one from ``other``.  Their
        absolute TD errors against ``own[state, action]``, scaled by
        ``self.parameter``, are turned into softmax probabilities for the
        outcomes ``False`` and ``True`` respectively.

        Returns:
            ``True`` or ``False``; ``learning`` bootstraps from ``own`` on
            ``True`` and from ``other`` on ``False``.
        """
        next_row, next_col = next_state[0], next_state[1]
        greedy = np.argmax(own[next_row, next_col])
        current = own[state[0], state[1], action]
        positive = abs(reward + self.gamma * own[next_row, next_col, greedy] - current)
        negative = abs(reward + self.gamma * other[next_row, next_col, greedy] - current)

        logits = (positive * self.parameter, negative * self.parameter)
        if not (math.isfinite(logits[0]) and math.isfinite(logits[1])):
            # NaN/inf cannot be turned into probabilities; keep a
            # deterministic rule instead of raising in the middle of training.
            return not positive > negative

        # Shift by the largest logit before exponentiating: the probabilities
        # are mathematically unchanged, but exp() can no longer overflow and
        # the normaliser is always at least 1.
        peak = max(logits)
        weights = [math.exp(value - peak) for value in logits]
        total = sum(weights)
        return np.random.choice([False, True], p=[w / total for w in weights])

    def softmax1(self, state, action, reward, next_state):
        """Bootstrap-source decision for an update of ``Q1``.

        Returns:
            ``True`` to bootstrap from ``Q1``, ``False`` to bootstrap from
            ``Q2`` (both evaluated at the greedy next action of ``Q1``).
        """
        return self._soft_select(self.Q1, self.Q2, state, action, reward, next_state)

    def softmax2(self, state, action, reward, next_state):
        """Bootstrap-source decision for an update of ``Q2``.

        Returns:
            ``True`` to bootstrap from ``Q2``, ``False`` to bootstrap from
            ``Q1`` (both evaluated at the greedy next action of ``Q2``).
        """
        return self._soft_select(self.Q2, self.Q1, state, action, reward, next_state)

    def learning(self, state, action, reward, next_state, done):
        """Apply one update to a randomly chosen Q table.

        Args:
            state: Indexable whose first two items are the grid indices.
            action: Index of the action taken in ``state``.
            reward: Reward received for the transition.
            next_state: State reached; same layout as ``state``.
            done: When true the target is the bare reward (no bootstrap).
        """
        # Each table is updated with probability 1/2; everything below is
        # symmetric in the roles "table being updated" / "other table".
        if np.random.random() >= 0.5:
            own, other = self.Q1, self.Q2
            counts, use_own = self.Count_S_A_1, self.softmax1
        else:
            own, other = self.Q2, self.Q1
            counts, use_own = self.Count_S_A_2, self.softmax2

        row, col = state[0], state[1]
        counts[row, col, action] += 1
        # Step size decays as 1/n with the number of updates of this pair.
        step = self.learningRate / counts[row, col, action]

        target = reward
        if not done:
            next_row, next_col = next_state[0], next_state[1]
            greedy = np.argmax(own[next_row, next_col])
            source = own if use_own(state, action, reward, next_state) else other
            target = target + self.gamma * source[next_row, next_col, greedy]

        own[row, col, action] += step * (target - own[row, col, action])
