import numpy as np
from Env import FiniteStateFiniteActionMDP
import matplotlib.pyplot as plt

class Qlearning_gen_adv:
    """Stage-based optimistic Q-learning with a reference/advantage estimator.

    The agent keeps an optimistic Q table ``global_Q`` that is only ever
    lowered. Visits to each (step, state, action) are grouped into "stages"
    (see ``check_stage_triggered``). When a stage ends, two upper-confidence
    estimates are formed:

    * Q1 comes from the stage's next-step values.
    * Q2 comes from a reference value function plus an advantage term.

    ``global_Q`` is then set to min(Q1, Q2, previous value).
    NOTE(review): this looks modeled on reference-advantage decomposition
    (UCB-Advantage style) — confirm against the intended algorithm.

    Args:
        mdp: environment used through ``H``, ``S``, ``A``, ``reset()``,
            ``step(action) -> (next_state, reward)``,
            ``best_gen() -> (best_value, best_policy, best_Q)`` and
            ``value_gen(policy)``.
        total_episodes: number of episodes run by ``learn``.
        c1: bonus scale for the plain estimate Q1.
        c2: scale of the two variance-based bonuses in Q2.
        c3: scale of the ``H / n**(3/4)`` correction terms in Q2.
        using_adv_min: total visit count of (step, state) at which the
            reference value for that pair is frozen.
    """

    def __init__(self, mdp, total_episodes, c1, c2, c3, using_adv_min):
        self.mdp = mdp
        self.c1 = c1
        self.c2 = c2
        self.c3 = c3
        self.total_episodes = total_episodes
        # Number of episodes after which the greedy policy changed.
        self.Nswitch = 0
        # Cumulative regret; reset again at the start of learn().
        self.regret_cum = 0

        H, S, A = self.mdp.H, self.mdp.S, self.mdp.A

        # V_func[h, s]: current value estimate (max_a Q[h, s, a]). Row H is
        # never written, so the terminal value stays 0.
        self.V_func = np.zeros((H + 1, S), dtype=np.float32)
        # V_ref_func[h, s]: reference value used by the advantage estimator.
        self.V_ref_func = np.zeros((H + 1, S), dtype=np.float32)
        self.using_adv_min = using_adv_min

        # V_ref_trigger[h, s] == 1 once the reference of (h, s) is frozen.
        self.V_ref_trigger = np.zeros((H, S), dtype=np.int32)

        # N: visits folded in at completed stages (cumulative over stages).
        self.N = np.zeros((H, S, A), dtype=np.int32)
        # Length of the last completed stage / visits in the current stage.
        self.n_previous_st = np.zeros((H, S, A), dtype=np.float32)
        self.n_current_st = np.zeros((H, S, A), dtype=np.float32)

        # Optimistic initialisation: Q[h, :, :] = H - h.
        self.global_Q = np.full((H, S, A), H, dtype=np.float32)
        for i in range(H):
            self.global_Q[i, :, :] = H - i

        # Running sums of next-step values. They are kept in float64 on
        # purpose:
        # - Vref_sum / Vref2_sum are never reset, so they grow without bound.
        # - update_Qadv computes variances as E[X^2] - E[X]^2.
        # In float32 that difference loses its precision to cancellation and
        # can even come out negative.
        self.V_sum = np.zeros((H, S, A), dtype=np.float64)
        # NOTE(review): V2_sum is accumulated and reset but never read here.
        self.V2_sum = np.zeros((H, S, A), dtype=np.float64)
        self.Vref_sum = np.zeros((H, S, A), dtype=np.float64)
        self.Vref2_sum = np.zeros((H, S, A), dtype=np.float64)
        self.Vadv_sum = np.zeros((H, S, A), dtype=np.float64)
        self.Vadv2_sum = np.zeros((H, S, A), dtype=np.float64)
        # Despite the name, this holds the previous one-hot greedy policy;
        # learn() compares against it to count policy switches.
        self.n_switch = np.zeros((H, S, A))

        self.regret = []
        self.globalcost = []

    def run_episode(self):
        """Roll out one episode with the current greedy policy.

        For every visited (step, state, action), this increments the
        current-stage visit count. It also adds the next-step value,
        reference value and advantage (and their squares) to the running
        sums.

        Returns:
            (rewards, state_init):
            - ``rewards[h, s, a]`` holds the reward observed at step ``h``.
              It is 0 for triples not visited in this episode.
            - ``state_init`` is the state returned by ``mdp.reset()``.
        """
        actions_policy = self.choose_action()
        state = self.mdp.reset()
        state_init = state
        rewards = np.zeros((self.mdp.H, self.mdp.S, self.mdp.A))

        for step in range(self.mdp.H):
            action = np.argmax(actions_policy[step, state])
            next_state, reward = self.mdp.step(action)

            # V_func / V_ref_func are not modified during the episode.
            v_next = self.V_func[step + 1, next_state]
            v_ref_next = self.V_ref_func[step + 1, next_state]

            self.n_current_st[step, state, action] += 1
            self.V_sum[step, state, action] += v_next
            self.V2_sum[step, state, action] += v_next**2
            self.Vref_sum[step, state, action] += v_ref_next
            self.Vref2_sum[step, state, action] += v_ref_next**2
            self.Vadv_sum[step, state, action] += v_next - v_ref_next
            self.Vadv2_sum[step, state, action] += (v_next - v_ref_next) ** 2

            rewards[step, state, action] = reward
            state = next_state

        return rewards, state_init

    def choose_action(self):
        """Return the greedy policy w.r.t. ``global_Q`` as a one-hot array.

        Returns:
            Array of shape (H, S, A) with a single 1 per (h, s), at
            ``argmax_a global_Q[h, s, a]``. Ties go to the lowest action
            index.
        """
        actions = np.zeros([self.mdp.H, self.mdp.S, self.mdp.A])

        for step in range(self.mdp.H):
            for state in range(self.mdp.S):
                best_action = np.argmax(self.global_Q[step, state])
                actions[step, state, best_action] = 1

        return actions

    def check_stage_triggered(self, step, state, action):
        """Return True when the current stage of (step, state, action) is over.

        The first stage lasts H visits. Each later stage lasts
        ``int((1 + 1/H) * previous_stage_length)`` visits.
        """
        previous_state_visit = self.n_previous_st[step, state, action]
        current_state_visit = self.n_current_st[step, state, action]

        return current_state_visit >= self.mdp.H * (previous_state_visit == 0) + int(
            (1 + 1 / self.mdp.H) * previous_state_visit)

    def update_Qadv(self, rewards):
        """End every stage that is due and tighten ``global_Q`` for it.

        Let n be the length of the finished stage and N the cumulative visit
        count. For each finished (h, s, a):

        * Q1 = r + mean_stage(V') + c1 * sqrt((H-h-1)^2 / n)
        * Q2 = r + sum(V_ref')/N + mean_stage(V' - V_ref')
          + c2 * sqrt(var_ref / N) + c2 * sqrt(var_adv / n)
          + c3 * H * (n^-3/4 + N^-3/4)

        ``global_Q[h, s, a]`` becomes min(Q1, Q2, old value), so it never
        increases. The per-stage sums are then reset. ``N`` and the
        reference sums remain cumulative.

        Args:
            rewards: array of shape (H, S, A) holding the reward used for
                each triple (the caller treats rewards as deterministic).
        """
        H = self.mdp.H
        for h in range(H):
            for s in range(self.mdp.S):
                for a in range(self.mdp.A):
                    if not self.check_stage_triggered(h, s, a):
                        continue

                    # n >= 1 here: a stage ends only at a threshold >= 1.
                    n = self.n_current_st[h, s, a]
                    self.N[h, s, a] += n
                    N = self.N[h, s, a]

                    Q1 = rewards[h, s, a] + self.V_sum[h, s, a] / n + self.c1 * np.sqrt(
                        (H - h - 1) ** 2 / n)

                    sigma2_vref = self.Vref2_sum[h, s, a] / N - (self.Vref_sum[h, s, a] / N) ** 2
                    sigma2_vadv = self.Vadv2_sum[h, s, a] / n - (self.Vadv_sum[h, s, a] / n) ** 2
                    # E[X^2] - E[X]^2 can still dip below 0 through rounding.
                    if sigma2_vref < 0:
                        sigma2_vref = 1e-8
                    if sigma2_vadv < 0:
                        sigma2_vadv = 1e-8

                    Q2 = rewards[h, s, a] + self.Vref_sum[h, s, a] / N + (
                        self.Vadv_sum[h, s, a] / n) + self.c2 * np.sqrt(
                        sigma2_vref / N) + self.c2 * np.sqrt(
                        sigma2_vadv / n) + self.c3 * H * (1 / n ** (3 / 4) + 1 / N ** (3 / 4))

                    self.global_Q[h, s, a] = min(Q1, Q2, self.global_Q[h, s, a])

                    # Start a new stage.
                    self.n_previous_st[h, s, a] = n
                    self.n_current_st[h, s, a] = 0.0
                    self.V_sum[h, s, a] = 0.0
                    self.V2_sum[h, s, a] = 0.0
                    self.Vadv_sum[h, s, a] = 0.0
                    self.Vadv2_sum[h, s, a] = 0.0

    def update_reference(self, h, s):
        """Freeze the reference value of (h, s) once it has been visited enough.

        The first time ``sum_a N[h, s, a]`` reaches ``using_adv_min``, this
        copies ``V_func[h, s]`` into ``V_ref_func[h, s]`` and marks it
        frozen. Later calls do nothing.
        """
        if self.V_ref_trigger[h, s] == 1:
            return
        if self.N[h, s, :].sum() >= self.using_adv_min:
            self.V_ref_trigger[h, s] = 1
            self.V_ref_func[h, s] = self.V_func[h, s]

    def learn(self):
        """Run ``total_episodes`` episodes, updating the Q table after each.

        For each episode:

        * Append the cumulative regret to ``self.regret``. Each episode
          contributes ``best_value[s0] - value[s0]``, where ``value`` comes
          from ``mdp.value_gen`` on the executed policy (presumably its
          value function — confirm).
        * Append the running number of policy switches to
          ``self.globalcost``.

        Returns:
            (best_Q, global_Q): the Q table from ``mdp.best_gen()`` and the
            learned optimistic Q table.
        """
        self.regret_cum = 0
        best_value, best_policy, best_Q = self.mdp.best_gen()
        # Reward cache. Rewards are assumed deterministic, so the first
        # non-zero observation of (h, s, a) is kept.
        rewards = np.zeros((self.mdp.H, self.mdp.S, self.mdp.A))
        for h in range(self.mdp.H):
            for s in range(self.mdp.S):
                self.V_func[h, s] = max(self.global_Q[h, s, :])
                self.V_ref_func[h, s] = self.V_func[h, s]

        for h in range(self.mdp.H):
            for s in range(self.mdp.S):
                self.update_reference(h, s)
        actions_policy = self.choose_action()
        self.n_switch = actions_policy

        for episode in range(self.total_episodes):
            value = self.mdp.value_gen(actions_policy)
            run_reward, state_init = self.run_episode()
            self.regret_cum = self.regret_cum + best_value[state_init] - value[state_init]
            self.regret.append(self.regret_cum)

            for h in range(self.mdp.H):
                for s in range(self.mdp.S):
                    a = np.argmax(actions_policy[h, s])
                    if rewards[h, s, a] == 0:
                        rewards[h, s, a] = run_reward[h, s, a]

            self.update_Qadv(rewards)
            actions_policy = self.choose_action()
            if not np.array_equal(self.n_switch, actions_policy):
                self.Nswitch += 1
            self.globalcost.append(self.Nswitch)
            self.n_switch = actions_policy

            for h in range(self.mdp.H):
                for s in range(self.mdp.S):
                    self.V_func[h, s] = max(self.global_Q[h, s, :])
                    self.update_reference(h, s)
        return best_Q, self.global_Q