import numpy as np
from Env import FiniteStateFiniteActionMDP
import matplotlib.pyplot as plt

class Qlearning_gen_early:
    """Tabular, finite-horizon Q-learning with UCB, LCB and reference-advantage estimates.

    Three Q tables are learned per (step h, state s, action a):

    * ``QU`` -- optimistic estimate with bonus ``c1 * (H-h-1) * sqrt(H / N)``;
    * ``QL`` -- pessimistic estimate using the same bonus with a minus sign;
    * ``QR`` -- reference-advantage estimate whose bonus is built from the
      empirical variances of the reference value and of the advantage.

    The agent acts greedily with respect to ``global_Q``, the running
    element-wise minimum of ``QU``, ``QR`` and its own previous value.
    The structure appears to mirror Q-EarlySettled-Advantage (Li et al., 2021)
    -- NOTE(review): confirm against the intended reference.

    Args:
        mdp: environment object exposing ``H``, ``S``, ``A``, ``reset()``,
            ``step(action) -> (next_state, reward)``, ``best_gen()`` returning
            ``(best_value, best_policy, best_Q)`` and ``value_gen(policy)``.
        c1: scale of the ``QU``/``QL`` bonus.
        c2: scale of the variance-based part of the ``QR`` bonus.
        c3: scale of the ``(H-h-1)**2 / N**(3/4)`` part of the ``QR`` bonus.
        total_episodes: number of episodes run by :meth:`learn`.
        beta: gap threshold; the reference value at (h, s) is frozen once
            ``V_func[h, s] - VL[h, s] < beta``.
    """

    def __init__(self, mdp, c1, c2, c3, total_episodes, beta):
        self.mdp = mdp
        self.c1 = c1
        self.c2 = c2
        self.c3 = c3
        self.total_episodes = total_episodes
        self.beta = beta

        H, S, A = self.mdp.H, self.mdp.S, self.mdp.A

        # State-value tables; row H is the terminal step and is never written.
        self.V_func = np.zeros((H + 1, S), dtype=np.float32)
        self.VL = np.zeros((H + 1, S), dtype=np.float32)
        self.V_ref_func = np.zeros((H + 1, S), dtype=np.float32)
        # 1 once the reference value at (h, s) has been frozen.
        self.V_ref_trigger = np.zeros((H, S), dtype=np.int32)

        # Next-state values recorded at the most recent visit of (h, s, a).
        self.V_next = np.zeros((H, S, A), dtype=np.float32)
        self.VL_next = np.zeros((H, S, A), dtype=np.float32)
        self.Vref_next = np.zeros((H, S, A), dtype=np.float32)

        # Optimistic initialisation Q[h, :, :] = H - h, presumably an upper
        # bound on the return-to-go (assumes per-step rewards in [0, 1]).
        q_init = np.empty((H, S, A), dtype=np.float32)
        q_init[:] = np.arange(H, 0, -1, dtype=np.float32)[:, None, None]
        self.global_Q = q_init.copy()
        self.QU = q_init.copy()
        self.QR = q_init.copy()
        self.QL = np.zeros((H, S, A), dtype=np.float32)

        # N: total visit counts; n: per-episode visit flags (cleared every update).
        self.N = np.zeros((H, S, A), dtype=np.int32)
        self.n = np.zeros((H, S, A), dtype=np.int32)

        # Plain sums of V_ref(s') and V_ref(s')**2 over all visits.
        self.Vref_sum = np.zeros((H, S, A), dtype=np.float32)
        self.Vref2_sum = np.zeros((H, S, A), dtype=np.float32)
        # Step-size-weighted running moments of the advantage V(s') - V_ref(s').
        self.Vadv_sum = np.zeros((H, S, A), dtype=np.float32)
        self.Vadv2_sum = np.zeros((H, S, A), dtype=np.float32)

        # Variance-based bonus B and its most recent change delta.
        self.delta = np.zeros((H, S, A), dtype=np.float32)
        self.B = np.zeros((H, S, A), dtype=np.float32)

        # Per-episode diagnostics filled by learn().
        self.regret_cum = 0
        self.regret = []
        self.raw_gap = []

    def run_episode(self):
        """Roll out one episode with the current greedy policy.

        For every visited (step, state, action) this flags the visit in
        ``self.n``, records the next-state values under ``V_func``, ``VL`` and
        ``V_ref_func``, and accumulates the reference sums used by
        :meth:`update_Qearly`.

        Returns:
            tuple: ``(rewards, state_init)``. ``rewards`` has shape (H, S, A)
            and holds the observed reward at visited entries (0 elsewhere).
            ``state_init`` is the state returned by ``mdp.reset()``.
        """
        actions_policy = self.choose_action()
        state = self.mdp.reset()
        state_init = state
        rewards = np.zeros((self.mdp.H, self.mdp.S, self.mdp.A))

        for step in range(self.mdp.H):
            action = np.argmax(actions_policy[step, state])
            next_state, reward = self.mdp.step(action)

            # Visit flag for this episode (cleared by update_Qearly). Each step
            # index occurs once per episode, so a flag suffices.
            self.n[step, state, action] = 1

            self.V_next[step, state, action] = self.V_func[step + 1, next_state]
            self.VL_next[step, state, action] = self.VL[step + 1, next_state]
            v_ref = self.V_ref_func[step + 1, next_state]
            self.Vref_next[step, state, action] = v_ref
            self.Vref_sum[step, state, action] += v_ref
            self.Vref2_sum[step, state, action] += v_ref ** 2

            rewards[step, state, action] = reward
            state = next_state
        return rewards, state_init

    def choose_action(self):
        """Return the greedy policy with respect to ``global_Q`` as a one-hot array.

        Returns:
            np.ndarray: float array of shape (H, S, A) with exactly one 1 per
            (h, s), at ``argmax_a global_Q[h, s, a]`` (lowest index on ties).
        """
        actions = np.zeros((self.mdp.H, self.mdp.S, self.mdp.A))
        best = np.argmax(self.global_Q, axis=2)
        np.put_along_axis(actions, best[..., np.newaxis], 1, axis=2)
        return actions

    def update_Qearly(self, rewards):
        """Update QU, QL, QR and global_Q at every entry visited this episode.

        Uses the learning rate ``(H + 1) / (H + N)``. The visit flags in
        ``self.n`` are cleared on exit.

        Args:
            rewards: array of shape (H, S, A); only entries flagged in
                ``self.n`` are read.
        """
        H = self.mdp.H
        # Only flagged entries change, so iterate over them directly instead of
        # scanning all H*S*A cells. Entries are independent, so order is irrelevant.
        for h, s, a in np.argwhere(self.n).tolist():
            self.N[h, s, a] += 1
            N_h_k = self.N[h, s, a]
            step_size = (H + 1) / (H + N_h_k)
            keep = 1 - step_size

            # Optimistic / pessimistic updates share the same bonus.
            ucb_bonus = self.c1 * (H - h - 1) * np.sqrt(H / N_h_k)
            self.QU[h, s, a] = keep * self.QU[h, s, a] + \
                step_size * (rewards[h, s, a] + self.V_next[h, s, a] + ucb_bonus)
            self.QL[h, s, a] = keep * self.QL[h, s, a] + \
                step_size * (rewards[h, s, a] + self.VL_next[h, s, a] - ucb_bonus)

            # Running first and second moments of the advantage V(s') - V_ref(s').
            adv = self.V_next[h, s, a] - self.Vref_next[h, s, a]
            self.Vadv_sum[h, s, a] = keep * self.Vadv_sum[h, s, a] + step_size * adv
            self.Vadv2_sum[h, s, a] = keep * self.Vadv2_sum[h, s, a] + step_size * adv ** 2

            # Empirical variances; negative values come from float rounding
            # and are clamped to a small positive number.
            sigma2_vref = self.Vref2_sum[h, s, a] / self.N[h, s, a] - \
                (self.Vref_sum[h, s, a] / self.N[h, s, a]) ** 2
            sigma2_vadv = self.Vadv2_sum[h, s, a] - (self.Vadv_sum[h, s, a]) ** 2
            if sigma2_vref < 0:
                sigma2_vref = 1e-8
            if sigma2_vadv < 0:
                sigma2_vadv = 1e-8

            Bnext = self.c2 * (np.sqrt(sigma2_vref / N_h_k) + np.sqrt(H * sigma2_vadv / N_h_k))
            self.delta[h, s, a] = Bnext - self.B[h, s, a]
            self.B[h, s, a] = Bnext
            b = self.B[h, s, a] + keep * self.delta[h, s, a] / step_size + \
                self.c3 * (H - h - 1) ** 2 / N_h_k ** (3 / 4)
            self.QR[h, s, a] = keep * self.QR[h, s, a] + step_size * (
                rewards[h, s, a] + self.V_next[h, s, a] - self.Vref_next[h, s, a]
                + self.Vref_sum[h, s, a] / N_h_k + b)

            # global_Q never increases.
            self.global_Q[h, s, a] = min(self.QU[h, s, a], self.QR[h, s, a], self.global_Q[h, s, a])
        self.n.fill(0)

    def update_reference(self, h, s):
        """Freeze the reference value at (h, s) once the confidence gap is below beta.

        Does nothing if the reference at (h, s) is already frozen. Otherwise,
        if ``V_func[h, s] - VL[h, s] < beta``, copies ``V_func[h, s]`` into
        ``V_ref_func[h, s]`` and marks it frozen.

        NOTE(review): until the trigger fires, ``V_ref_func[h, s]`` is not
        modified, so it keeps the value set at the start of :meth:`learn`.
        The published Q-EarlySettled-Advantage also refreshes the reference
        to ``V`` while the gap is large. Confirm this difference is intended.
        """
        if self.V_ref_trigger[h, s] == 1:
            return
        if self.V_func[h, s] - self.VL[h, s] < self.beta:
            self.V_ref_trigger[h, s] = 1
            self.V_ref_func[h, s] = self.V_func[h, s]

    def _cache_rewards(self, rewards, run_reward, actions_policy):
        """Fill still-zero cached rewards at each (h, s)'s policy action from ``run_reward``."""
        acts = np.argmax(actions_policy, axis=2)[..., np.newaxis]
        cached = np.take_along_axis(rewards, acts, axis=2)
        observed = np.take_along_axis(run_reward, acts, axis=2)
        np.put_along_axis(rewards, acts, np.where(cached == 0, observed, cached), axis=2)

    def _refresh_values(self):
        """Recompute V_func and VL from the Q tables, then update the references."""
        H = self.mdp.H
        self.V_func[:H] = self.global_Q.max(axis=2)
        # VL is monotonically non-decreasing.
        self.VL[:H] = np.maximum(self.QL.max(axis=2), self.VL[:H])
        for h in range(H):
            for s in range(self.mdp.S):
                self.update_reference(h, s)

    def learn(self):
        """Run ``total_episodes`` episodes, updating the estimates after each one.

        Appends the running average regret to ``self.regret`` and the
        per-episode gap to ``self.raw_gap``. These lists are not cleared
        between calls.

        Returns:
            tuple: ``(best_value, best_Q, value, global_Q, raw_gap)``.
            ``best_value`` and ``best_Q`` come from ``mdp.best_gen()``.
            ``value`` is ``mdp.value_gen`` of the policy used in the last
            episode, or of the initial greedy policy if no episode was run.
            ``global_Q`` is the final Q table. ``raw_gap`` lists
            ``best_value[s0] - value[s0]`` per episode.
        """
        self.regret_cum = 0
        best_value, best_policy, best_Q = self.mdp.best_gen()
        H = self.mdp.H

        # Cache of the first nonzero reward seen at each (h, s, a). Rewards
        # are treated as deterministic, so one observation is kept.
        rewards = np.zeros((H, self.mdp.S, self.mdp.A))
        self.V_func[:H] = self.global_Q.max(axis=2)
        self.V_ref_func[:H] = self.global_Q.max(axis=2)
        actions_policy = self.choose_action()
        value = None

        for episode in range(self.total_episodes):
            run_reward, state_init = self.run_episode()
            value = self.mdp.value_gen(actions_policy)
            gap = best_value[state_init] - value[state_init]
            self.regret_cum = self.regret_cum + gap
            self.regret.append(self.regret_cum / (episode + 1))
            self.raw_gap.append(gap)

            self._cache_rewards(rewards, run_reward, actions_policy)
            self.update_Qearly(rewards)
            actions_policy = self.choose_action()
            self._refresh_values()

        if value is None:
            # No episode ran: report the value of the initial greedy policy
            # instead of failing with an unbound local.
            value = self.mdp.value_gen(actions_policy)
        return best_value, best_Q, value, self.global_Q, self.raw_gap