
import numpy as np
from Env import FiniteStateFiniteActionMDP
import matplotlib.pyplot as plt

class FedQlearning_gen_adv:
    """Federated Q-learning with event-triggered synchronization and
    stage-based optimistic Q updates (plain, variance-based and
    reference-advantage estimates).

    Each outer iteration, ``num_agents`` agents play one episode each with the
    greedy policy of the shared ``global_Q``. Visit counts and next-step value
    statistics are accumulated in per-agent buffers. When any agent's local
    count reaches the threshold from ``check_sync_triggered``, the buffers are
    merged into the global tables (``aggregate_data``). ``global_Q`` may be
    lowered when a (step, state, action) completes a stage, and the greedy
    policy is recomputed.

    The ``mdp`` object must expose ``H``, ``S``, ``A``, ``reset()``,
    ``step(action) -> (next_state, reward)``,
    ``best_gen() -> (best_value, best_policy, best_Q)`` and
    ``value_gen(policy) -> value`` (value indexed by state), as used below.

    Args:
        mdp: Episodic tabular environment with the interface above.
        total_episodes: Episodes per agent. When ``is_fed`` is false, the run
            becomes a single agent playing ``total_episodes * num_agents``
            episodes.
        num_agents: Number of agents (federated mode only).
        is_fed: Use federated synchronization thresholds. Otherwise every
            visit triggers a synchronization.
        using_adv_min: Once ``N[h, s, :].sum()`` reaches this value, the
            reference value ``V_ref_func[h - 1, s]`` is snapshotted for good.
        is_adv: If 0, the reference-advantage estimate (Q3) is disabled.
        is_ber: If falsy, the variance-based estimate (Q2) is disabled.
    """

    def __init__(self, mdp, total_episodes, num_agents,
                 is_fed = False, using_adv_min = 100, is_adv = 1, is_ber = 0):
        self.mdp = mdp
        # Episodes per agent; total_episodes * num_agents episodes overall.
        self.total_episodes = total_episodes
        self.num_agents = num_agents
        if not is_fed:
            # Non-federated baseline: one agent plays every episode.
            self.total_episodes = total_episodes * num_agents
            self.num_agents = 1
        H, S, A = self.mdp.H, self.mdp.S, self.mdp.A
        M = self.num_agents

        # V_func[h, s] holds max_a global_Q[h + 1, s, a] (shifted by one step,
        # see _refresh_value_function). V_ref_func is the reference copy.
        self.V_func = np.zeros((H, S), dtype=np.float32)
        self.V_ref_func = np.zeros((H, S), dtype=np.float32)
        self.trigger_times = 0  # number of synchronization rounds
        self.comm_episode_collection = []  # episodes at which a sync happened
        self.using_adv_min = using_adv_min
        self.is_adv = is_adv
        self.is_ber = is_ber
        self.is_fed = is_fed

        # Global running sums of next-step values. They are float64 because the
        # variances in _close_stage use E[X^2] - E[X]^2, which cancels badly
        # when long sums are kept in float32.
        self.V_sum_stage = np.zeros((H, S, A), dtype=np.float64)
        self.V2_sum_stage = np.zeros((H, S, A), dtype=np.float64)
        self.Vref_sum_all = np.zeros((H, S, A), dtype=np.float64)   # never reset
        self.Vref2_sum_all = np.zeros((H, S, A), dtype=np.float64)  # never reset
        self.Vadv_sum_stage = np.zeros((H, S, A), dtype=np.float64)
        self.Vadv2_sum_stage = np.zeros((H, S, A), dtype=np.float64)

        # Set to 1 once V_ref_func[h - 1, s] has been snapshotted (update_reference).
        self.V_ref_trigger = np.zeros((H, S), dtype=np.int32)

        # Not read by any method of this class.
        self.count_variance = np.zeros((H, S, A), dtype=np.float32)

        self.N = np.zeros((H, S, A), dtype=np.int32)  # visits in completed stages
        self.n_previous_st = np.zeros((H, S, A), dtype=np.float32)  # length of last completed stage
        self.n_current_st = np.zeros((H, S, A), dtype=np.float32)   # visits in the current stage

        # Optimistic initialization: global_Q[h, :, :] = H - h.
        self.global_Q = np.empty((H, S, A), dtype=np.float32)
        self.global_Q[:] = (H - np.arange(H))[:, None, None]

        # Per-agent buffers, cleared after every aggregation. They are sized by
        # self.num_agents, so non-federated mode allocates a single agent.
        self.agent_N = np.zeros((M, H, S, A), dtype=np.int32)
        self.agent_V_sum = np.zeros((M, H, S, A), dtype=np.float64)
        self.agent_V2_sum = np.zeros((M, H, S, A), dtype=np.float64)
        self.agent_Vref_sum = np.zeros((M, H, S, A), dtype=np.float64)
        self.agent_Vref2_sum = np.zeros((M, H, S, A), dtype=np.float64)
        self.agent_Vadv_sum = np.zeros((M, H, S, A), dtype=np.float64)
        self.agent_Vadv2_sum = np.zeros((M, H, S, A), dtype=np.float64)

        # regret: regret_cum / (episode + 1), appended once per agent-episode.
        # raw_gap: best_value - value at the initial state, once per agent-episode.
        self.regret = []
        self.raw_gap = []

    def run_episode(self, agent_id):
        """Play one episode for ``agent_id`` with the current greedy policy.

        Increments the agent's local visit counts. For every step except the
        last, it also adds the next state's value, squared value, reference
        value and advantage (value - reference) to the agent's buffers.

        Returns:
            ``(rewards, event_triggered, state_init)``:
            - ``rewards`` is an (H, S, A) array with the observed reward at
              each visited (step, state, action) and 0 elsewhere.
            - ``event_triggered`` is True if any visit reached its sync
              threshold.
            - ``state_init`` is the state returned by ``mdp.reset()``.
        """
        H = self.mdp.H
        actions_policy = self.choose_action()
        state = self.mdp.reset()
        state_init = state
        rewards = np.zeros((H, self.mdp.S, self.mdp.A))
        event_triggered = False

        for step in range(H):
            action = np.argmax(actions_policy[step, state])
            next_state, reward = self.mdp.step(action)
            idx = (agent_id, step, state, action)
            self.agent_N[idx] += 1

            if step < H - 1:
                # V_func[step] already stores values of step + 1 (shifted),
                # so the last step has no next-step value to record.
                v = float(self.V_func[step, next_state])
                v_ref = float(self.V_ref_func[step, next_state])
                adv = v - v_ref
                self.agent_V_sum[idx] += v
                self.agent_V2_sum[idx] += v ** 2
                self.agent_Vref_sum[idx] += v_ref
                self.agent_Vref2_sum[idx] += v_ref ** 2
                self.agent_Vadv_sum[idx] += adv
                self.agent_Vadv2_sum[idx] += adv ** 2

            rewards[step, state, action] = reward
            if self.check_sync_triggered(agent_id, step, state, action, self.is_fed):
                event_triggered = True
            state = next_state
        return rewards, event_triggered, state_init

    def choose_action(self):
        """Return the greedy policy of ``global_Q`` as a one-hot (H, S, A) float array.

        Ties are broken toward the lowest action index (``np.argmax``).
        """
        greedy = np.argmax(self.global_Q, axis=2)
        return np.eye(self.mdp.A)[greedy]

    def check_sync_triggered(self, agent_id, step, state, action, is_fed):
        """Return True if the agent's local visit count for (step, state, action)
        has reached the synchronization threshold.

        The threshold is 1 (every visit triggers a sync) when not federated or
        before the first stage of this triple has completed. Otherwise it
        depends on how far the current stage has progressed relative to the
        previous stage length.
        """
        previous_state_visit = self.n_previous_st[step, state, action]
        current_state_visit = self.n_current_st[step, state, action]
        threshold = 1
        if is_fed == 1 and previous_state_visit > 0:
            if current_state_visit > (1 - 1 / self.mdp.H) * previous_state_visit:
                threshold = round(np.floor(previous_state_visit / self.num_agents / self.mdp.H))
            else:
                threshold = round(np.ceil((previous_state_visit - current_state_visit) / self.num_agents))
        return self.agent_N[agent_id, step, state, action] >= threshold

    def check_stage_triggered(self, step, state, action):
        """Return True when the current stage of (step, state, action) is complete.

        The first stage needs ``num_agents * H`` visits. Later stages need
        ``(1 + 1/H)`` times the previous stage length.
        """
        previous_state_visit = self.n_previous_st[step, state, action]
        current_state_visit = self.n_current_st[step, state, action]
        return current_state_visit >= (self.num_agents * self.mdp.H) * (previous_state_visit == 0) + (
            1 + 1 / self.mdp.H) * previous_state_visit

    def aggregate_data(self, policy_k, rewards, is_fed):  # after a round
        """Merge every agent's local statistics into the global tables.

        Only the greedy action of ``policy_k`` at each (h, s) is considered,
        since that is the only action the agents played this round. When that
        triple completes its stage, ``_close_stage`` updates ``global_Q``.
        All per-agent buffers are cleared at the end.

        Args:
            policy_k: One-hot (H, S, A) policy the agents followed this round.
            rewards: (H, S, A) reward table (rewards assumed deterministic).
            is_fed: Unused; kept for interface compatibility.
        """
        greedy = np.argmax(policy_k, axis=2)
        for h in range(self.mdp.H):
            for s in range(self.mdp.S):
                a = greedy[h, s]
                visits = self.agent_N[:, h, s, a].sum()
                if visits == 0:
                    continue  # not visited this round: nothing to merge
                self.n_current_st[h, s, a] += visits
                self.V_sum_stage[h, s, a] += self.agent_V_sum[:, h, s, a].sum()
                self.V2_sum_stage[h, s, a] += self.agent_V2_sum[:, h, s, a].sum()
                self.Vref_sum_all[h, s, a] += self.agent_Vref_sum[:, h, s, a].sum()
                self.Vref2_sum_all[h, s, a] += self.agent_Vref2_sum[:, h, s, a].sum()
                self.Vadv_sum_stage[h, s, a] += self.agent_Vadv_sum[:, h, s, a].sum()
                self.Vadv2_sum_stage[h, s, a] += self.agent_Vadv2_sum[:, h, s, a].sum()
                if self.check_stage_triggered(h, s, a):
                    self._close_stage(h, s, a, rewards[h, s, a])
        self._reset_agent_buffers()

    @staticmethod
    def _clamp_variance(var):
        """Replace a negative (round-off) variance estimate with 1e-8."""
        return 1e-8 if var < 0 else var

    def _close_stage(self, h, s, a, reward):
        """End the current stage of (h, s, a).

        Lowers ``global_Q[h, s, a]`` to the minimum of three optimistic
        estimates and its current value, then starts a new stage.
        """
        H = self.mdp.H
        n = float(self.n_current_st[h, s, a])
        self.N[h, s, a] += self.n_current_st[h, s, a]
        n_all = float(self.N[h, s, a])
        mean_v = self.V_sum_stage[h, s, a] / n

        # Q1: stage mean plus a variance-free bonus.
        q1 = reward + mean_v + np.sqrt(2 * (H - h - 1) * (H - h - 1) / n)

        # Q2: stage mean plus a variance-based bonus. It is only used when
        # is_ber is set and the stage has more than 10 samples; otherwise it
        # falls back to the trivial bound H - h.
        if self.is_ber and n > 10:
            sigma2_v = self._clamp_variance(
                self.V2_sum_stage[h, s, a] / n - mean_v ** 2)
            q2 = reward + mean_v + 2 * np.sqrt(sigma2_v / n)
        else:
            q2 = H - h

        # Q3: reference mean over all samples plus the advantage mean over
        # this stage, each with its own variance-based bonus.
        if self.is_adv == 0:
            q3 = H - h
        else:
            mean_ref = self.Vref_sum_all[h, s, a] / n_all
            mean_adv = self.Vadv_sum_stage[h, s, a] / n
            sigma2_vref = self._clamp_variance(
                self.Vref2_sum_all[h, s, a] / n_all - mean_ref ** 2)
            sigma2_vadv = self._clamp_variance(
                self.Vadv2_sum_stage[h, s, a] / n - mean_adv ** 2)
            q3 = (reward + mean_ref + mean_adv
                  + 2 * np.sqrt(sigma2_vref / n_all)
                  + 2 * np.sqrt(sigma2_vadv / n))

        self.global_Q[h, s, a] = min(q1, q2, q3, self.global_Q[h, s, a])

        # Start a new stage. The reference sums (Vref*_sum_all) are kept.
        self.n_previous_st[h, s, a] = n
        self.n_current_st[h, s, a] = 0.0
        self.V_sum_stage[h, s, a] = 0.0
        self.V2_sum_stage[h, s, a] = 0.0
        self.Vadv_sum_stage[h, s, a] = 0.0
        self.Vadv2_sum_stage[h, s, a] = 0.0

    def _reset_agent_buffers(self):
        """Zero all per-agent visit counts and value sums."""
        for buf in (self.agent_N, self.agent_V_sum, self.agent_V2_sum,
                    self.agent_Vref_sum, self.agent_Vref2_sum,
                    self.agent_Vadv_sum, self.agent_Vadv2_sum):
            buf.fill(0)

    def update_reference(self, h, s):
        """Snapshot ``V_ref_func[h - 1, s] = V_func[h - 1, s]`` once
        ``N[h, s, :].sum() >= using_adv_min``.

        After the snapshot, the value is never updated again. This is a no-op
        for ``h == 0``.
        """
        if h == 0 or self.V_ref_trigger[h, s] == 1:
            return
        if self.N[h, s, :].sum() >= self.using_adv_min:
            self.V_ref_trigger[h, s] = 1
            self.V_ref_func[h - 1, s] = self.V_func[h - 1, s]

    def _refresh_value_function(self):
        """Set ``V_func[h, s] = max_a global_Q[h + 1, s, a]`` for every h < H - 1."""
        self.V_func[:-1] = self.global_Q[1:].max(axis=2)

    def _update_references(self):
        """Apply ``update_reference`` to every (h, s) with h >= 1."""
        for h in range(1, self.mdp.H):
            for s in range(self.mdp.S):
                self.update_reference(h, s)

    def learn(self):
        """Run the full training loop.

        Returns:
            ``(best_value, best_Q, value, global_Q)``. ``value`` is what
            ``mdp.value_gen`` returns for the policy played in the last
            episode. If no episode is run, it is the value of the initial
            greedy policy.
        """
        # Cumulative regret summed over all agent-episodes.
        self.regret_cum = 0
        best_value, best_policy, best_Q = self.mdp.best_gen()
        event_triggered = False
        # Reward table filled from observed episodes. Rewards are assumed
        # deterministic, so the first nonzero observation of an entry is kept.
        rewards = np.zeros((self.mdp.H, self.mdp.S, self.mdp.A))
        self._refresh_value_function()
        self.V_ref_func[:-1] = self.V_func[:-1]
        self._update_references()

        actions_policy = self.choose_action()
        value = None
        for episode in range(self.total_episodes):
            if episode % 1000 == 0:
                print(episode)
            # NOTE(review): recomputed every episode although the policy only
            # changes after a sync; could be cached if value_gen is pure.
            value = self.mdp.value_gen(actions_policy)
            on_policy = actions_policy > 0  # one-hot mask of the greedy actions
            for agent_id in range(self.num_agents):
                agent_reward, agent_event_triggered, state_init = self.run_episode(agent_id)
                self.regret_cum = self.regret_cum + best_value[state_init] - value[state_init]
                self.regret.append(self.regret_cum / (episode + 1))
                self.raw_gap.append(best_value[state_init] - value[state_init])

                # Fill still-empty reward entries of the played actions.
                np.copyto(rewards, agent_reward, where=(rewards == 0) & on_policy)

                if agent_event_triggered:
                    event_triggered = True

            # Global aggregation and policy update after a triggered round.
            if event_triggered:
                self.trigger_times += 1
                self.comm_episode_collection.append(episode)
                self.aggregate_data(actions_policy, rewards, is_fed=self.is_fed)
                event_triggered = False
                actions_policy = self.choose_action()
                self._refresh_value_function()
                self._update_references()

        if value is None:
            # No episode was run (total_episodes == 0). Without this branch,
            # `value` would be unbound at the return.
            value = self.mdp.value_gen(actions_policy)
        return best_value, best_Q, value, self.global_Q