import numpy as np
from Env import FiniteStateFiniteActionMDP
import matplotlib.pyplot as plt

class Qlearning_gen:
    """Tabular optimistic (UCB-style) Q-learning agent for a finite-horizon MDP.

    The agent keeps a single Q-table ``global_Q`` of shape ``(H, S, A)`` and
    acts greedily with respect to it. After every episode each visited
    ``(h, s, a)`` triple is updated with learning rate ``(H + 1) / (H + N)``
    and exploration bonus ``c * (H - h - 1) * sqrt(H / N)``, where ``N`` is
    the triple's cumulative visit count.

    The ``mdp`` object must expose ``H``, ``S``, ``A``, ``reset()``,
    ``step(action) -> (next_state, reward)``, ``best_gen()`` returning
    ``(best_value, best_policy, best_Q)`` and ``value_gen(policy)`` returning
    per-state values; these are the only members used here.

    NOTE(review): ``regret`` and ``raw_gap`` keep accumulating across repeated
    ``learn()`` calls, while ``regret_cum`` is reset at the start of each call.
    """

    def __init__(self, mdp, c, total_episodes):
        """
        Args:
            mdp: environment object (see class docstring for required members).
            c: scale of the exploration bonus; ``0`` disables the bonus.
            total_episodes: number of episodes run by ``learn``.
        """
        self.mdp = mdp
        self.c = c
        self.total_episodes = total_episodes

        H, S, A = self.mdp.H, self.mdp.S, self.mdp.A

        # V_func[h, s] = max_a global_Q[h, s, a]; the extra row H stays 0
        # as the terminal value.
        self.V_func = np.zeros((H + 1, S), dtype=np.float32)
        # V_next[h, s, a] = V_func[h + 1, s'] for the successor s' observed
        # the last time action a was taken in state s at step h.
        self.V_next = np.zeros((H, S, A), dtype=np.float32)

        # Optimistic initialisation Q[h, :, :] = H - h. Presumably an upper
        # bound on the return-to-go assuming per-step rewards <= 1 -- TODO confirm.
        self.global_Q = np.empty((H, S, A), dtype=np.float32)
        self.global_Q[...] = (H - np.arange(H))[:, None, None]

        # N: cumulative visit counts; n: 0/1 "visited this episode" flags,
        # set by run_episode and cleared by update_Q.
        self.N = np.zeros((H, S, A), dtype=np.int32)
        self.n = np.zeros((H, S, A), dtype=np.int32)

        self.regret = []      # running average regret after each episode
        self.raw_gap = []     # per-episode best_value - value at the initial state
        self.regret_cum = 0   # cumulative regret; reset at the start of learn()

    def run_episode(self):
        """Roll out one episode with the current greedy policy.

        Side effects: flags every visited ``(step, state, action)`` in
        ``self.n`` and stores its bootstrapped successor value in
        ``self.V_next``.

        Returns:
            rewards: float array ``(H, S, A)`` holding the reward observed at
                each visited triple (zero elsewhere).
            state_init: the state returned by ``mdp.reset()``.
        """
        actions_policy = self.choose_action()
        state = self.mdp.reset()
        state_init = state
        rewards = np.zeros((self.mdp.H, self.mdp.S, self.mdp.A))

        for step in range(self.mdp.H):
            action = np.argmax(actions_policy[step, state])
            next_state, reward = self.mdp.step(action)

            # Exactly one state is occupied per step, so each (step, state)
            # is visited at most once per episode: a 0/1 flag suffices here,
            # and update_Q turns it into a count increment.
            self.n[step, state, action] = 1
            self.V_next[step, state, action] = self.V_func[step + 1, next_state]

            rewards[step, state, action] = reward
            state = next_state
        return rewards, state_init

    def choose_action(self):
        """Return the greedy policy w.r.t. ``global_Q`` as a one-hot ``(H, S, A)`` array.

        Ties are broken toward the lowest action index (``np.argmax`` semantics).
        """
        actions = np.zeros([self.mdp.H, self.mdp.S, self.mdp.A])
        best = np.argmax(self.global_Q, axis=2)
        np.put_along_axis(actions, best[..., None], 1, axis=2)
        return actions

    def update_Q(self, rewards):
        """Apply one optimistic Q-learning update to every triple visited this episode.

        For each visited ``(h, s, a)``: ``N += 1``, then
        ``Q <- (1 - alpha) * Q + alpha * (r + V_next + bonus)`` with
        ``alpha = (H + 1) / (H + N)`` and ``bonus = c * (H - h - 1) * sqrt(H / N)``.
        The per-episode visit flags ``self.n`` are cleared afterwards.

        Args:
            rewards: array ``(H, S, A)`` of rewards to use for visited triples.
        """
        H = self.mdp.H
        h_idx, s_idx, a_idx = np.nonzero(self.n)

        self.N[h_idx, s_idx, a_idx] += 1
        counts = self.N[h_idx, s_idx, a_idx]
        step_size = (H + 1) / (H + counts)
        ucb_bonus = self.c * (H - h_idx - 1) * np.sqrt(H / counts)
        target = (rewards[h_idx, s_idx, a_idx]
                  + self.V_next[h_idx, s_idx, a_idx]
                  + ucb_bonus)
        # Computed in float64, then stored back into the float32 table.
        self.global_Q[h_idx, s_idx, a_idx] = (
            (1 - step_size) * self.global_Q[h_idx, s_idx, a_idx]
            + step_size * target
        )
        self.n.fill(0)

    def _refresh_V(self):
        """Recompute ``V_func[h, s] = max_a global_Q[h, s, a]`` for all h < H."""
        self.V_func[: self.mdp.H] = self.global_Q.max(axis=2)

    @staticmethod
    def _record_rewards(rewards, run_reward, actions_policy):
        """Copy episode rewards into ``rewards`` (in place) for greedy triples still at 0."""
        greedy = np.argmax(actions_policy, axis=2)  # chosen action per (h, s)
        h_idx, s_idx = np.indices(greedy.shape)
        stored = rewards[h_idx, s_idx, greedy]
        observed = run_reward[h_idx, s_idx, greedy]
        rewards[h_idx, s_idx, greedy] = np.where(stored == 0, observed, stored)

    def learn(self):
        """Run ``total_episodes`` episodes, updating Q after each one.

        For every episode, the running average regret is appended to
        ``self.regret`` and the raw value gap at the initial state to
        ``self.raw_gap``.

        Returns:
            ``(best_value, best_Q, value, global_Q, raw_gap)``, where ``value``
            is ``mdp.value_gen`` of the policy used in the last episode, or of
            the initial greedy policy when ``total_episodes`` is 0.
        """
        self.regret_cum = 0
        best_value, _, best_Q = self.mdp.best_gen()

        # Reward table filled from observations. Rewards are treated as
        # deterministic, so the first non-zero observation of a triple is
        # kept; a stored 0 means "not yet observed (or genuinely zero)".
        rewards = np.zeros((self.mdp.H, self.mdp.S, self.mdp.A))
        self._refresh_V()
        actions_policy = self.choose_action()
        value = None

        for episode in range(self.total_episodes):
            run_reward, state_init = self.run_episode()
            value = self.mdp.value_gen(actions_policy)
            self.regret_cum = self.regret_cum + best_value[state_init] - value[state_init]
            self.regret.append(self.regret_cum / (episode + 1))
            self.raw_gap.append(best_value[state_init] - value[state_init])

            self._record_rewards(rewards, run_reward, actions_policy)

            self.update_Q(rewards)
            actions_policy = self.choose_action()
            self._refresh_V()

        if value is None:
            # No episode ran: report the initial greedy policy's value
            # instead of failing on an unbound local.
            value = self.mdp.value_gen(actions_policy)
        return best_value, best_Q, value, self.global_Q, self.raw_gap