import numpy as np
import tqdm
from utils import *



class RMAB():
    """
    Restless Multi-Armed Bandit with ``N`` homogeneous arms over ``d`` states.

    The environment is tracked at two levels:

    * ``self.config`` -- length-``d`` array with the number of arms in each
      state (the aggregate configuration).
    * ``self.global_state_agent`` -- length-``N`` array with the state index of
      every individual arm (used by :meth:`step_by_agent`).

    Each step, a budget of ``alpha * N`` arms may be activated (action 1); the
    remaining arms are passive (action 0).

    Args:
        N: Number of arms; must be divisible by ``d`` (checked in :meth:`reset`).
        d: Number of states per arm (only ``d == 3`` is supported here).
        alpha: Fraction of arms that may be activated per step.
        gamma: Discount factor; stored on the instance but not read by the
            methods of this class.
    """
    def __init__(self, N, d, alpha, gamma):
        self.N = N
        self.d = d
        self.alpha = alpha
        self.gamma = gamma
        self.Transition_0, self.Transition_1 = self.generate_Transition(d)
        self.Reward_0, self.Reward_1 = self.generate_Reward(d)
        self.m_star = None  # cached result of find_m_star
        self.global_state_agent = np.zeros(self.N)
        self.global_action_agent = np.zeros(self.N)
        self.reset()

    def reset(self):
        """
        Spread the arms evenly: ``N / d`` arms in each state.

        Arms ``[i*N/d, (i+1)*N/d)`` are placed in state ``i``.

        Returns:
            The new configuration (also stored in ``self.config``).
        """
        assert self.N % self.d == 0, "N must be divisible by d"
        self.config = np.array([self.N/self.d]*self.d)
        par = self.N // self.d
        for i in range(self.d):
            self.global_state_agent[i*par:(i+1)*par] = i

        return self.config

    def generate_Transition(self, d):
        """
        Return the ``(d, d)`` transition matrices for the passive and active actions.

        Row ``s`` is the next-state distribution from state ``s``. Passive arms
        never move (identity matrix).
        """
        assert d==3, "Only d=3 is supported"
        Transition_0 = np.eye(self.d)
        Transition_1 = np.array([[0.3, 0.3, 0.4], [0.25, 0.25, 0.5], [0.3, 0.4, 0.3]])

        return Transition_0, Transition_1

    def generate_Reward(self, d):
        """Return the per-state rewards for the passive and active actions."""
        assert d==3, "Only d=3 is supported"
        # Reward_0 = np.array([0.8, 0.4, 0.3])
        # Reward_1 = np.array([0.6, 0.7, 0.2])
        Reward_0 = np.array([0, 0, 0])
        Reward_1 = np.array([6, 7, 2])
        # Reward_0 = np.random.rand(d)*10
        # Reward_1 = np.random.rand(d)*10

        return Reward_0, Reward_1

    def step(self, priority):
        """
        Advance the aggregate configuration by one step under a priority policy.

        States are visited in ``priority`` order and arms are activated greedily
        until the budget ``alpha * N`` is used up. In the state where the budget
        runs out, ``floor(remaining budget)`` arms are activated and the rest
        are passive. Every later state is fully passive. Each group of arms then
        moves according to the transition matrix of its action.

        Args:
            priority: Iterable of state indices, highest priority first. It must
                cover every state so that all ``N`` arms are accounted for.

        Returns:
            The new configuration (also stored in ``self.config``). Its entries
            always sum to ``N``.
        """
        next_config = np.zeros(self.d)
        budget = self.alpha * self.N
        cumulative = 0
        for state in priority:
            n_state = int(self.config[state])
            cumulative += self.config[state]
            if cumulative <= budget:
                # The whole state fits in the remaining budget.
                n_active = n_state
            elif cumulative <= self.config[state] + budget:
                # The budget runs out inside this state. Activate the integer
                # part of the remaining budget and leave *all* other arms of
                # this state passive, so no arm is lost when alpha*N is
                # fractional.
                n_active = int(budget - (cumulative - self.config[state]))
            else:
                # The budget is already exhausted.
                n_active = 0
            n_passive = n_state - n_active
            # One multinomial draw per action is equivalent in distribution to
            # sampling each arm's next state independently.
            next_config += np.random.multinomial(n_active, self.Transition_1[state])
            next_config += np.random.multinomial(n_passive, self.Transition_0[state])

        assert np.isclose(cumulative, self.N), "Fractional allocation error"
        self.config = next_config

        return self.config

    def step_by_agent(self, priority):
        """
        Advance every individual arm by one step.

        Per-state activation fractions come from ``get_policy`` (from ``utils``),
        evaluated on the normalised configuration. In state ``i``,
        ``int(policy[i] * n_i)`` of its ``n_i`` arms are chosen uniformly at
        random without replacement and activated; all other arms are passive.

        Args:
            priority: State priority order, forwarded to ``get_policy``.

        Returns:
            Tuple ``(states, actions, rewards)`` of length-``N`` arrays: each
            arm's state and action *before* the transition, and the reward it
            earned. ``self.global_state_agent`` and ``self.config`` are updated
            in place to the post-transition values.
        """
        # NOTE(review): assumes get_policy returns a per-state activation
        # fraction in [0, 1] -- confirm in utils.
        policy = get_policy(self.config/self.N, priority, self.alpha)
        self.global_action_agent = np.zeros(self.N)

        for i in range(self.d):
            agents_in_state_i = np.where(self.global_state_agent == i)[0]
            num_agents_in_state_i = len(agents_in_state_i)

            if num_agents_in_state_i > 0:
                num_to_activate = int(policy[i] * num_agents_in_state_i)
                # Randomly choose which arms in this state to activate.
                activated_agents = np.random.choice(agents_in_state_i, num_to_activate, replace=False)
                self.global_action_agent[activated_agents] = 1

        cur_global_state = self.global_state_agent.copy().astype(int)
        cur_global_action = self.global_action_agent.copy().astype(int)

        # The reward depends only on the pre-transition state and the action.
        cur_global_reward = np.where(
            cur_global_action == 1,
            self.Reward_1[cur_global_state],
            self.Reward_0[cur_global_state],
        ).astype(float)

        # Sample transitions arm by arm (same random-draw order as before).
        for i in range(self.N):
            transition = self.Transition_1 if cur_global_action[i] == 1 else self.Transition_0
            self.global_state_agent[i] = np.random.choice(self.d, p=transition[cur_global_state[i]])

        # Recount arms per state, writing into the existing config array.
        self.config[:] = np.bincount(self.global_state_agent.astype(int), minlength=self.d)

        return cur_global_state, cur_global_action, cur_global_reward

    def find_m_star(self, priority, n_iter=100):
        """
        Find the asymptotic (fixed-point) configuration of the mean-field dynamics.

        Starting from the uniform distribution over states, iterate the
        expected one-step update ``n_iter`` times under the given priority.

        Args:
            priority: State indices, highest priority first.
            n_iter: Number of expected-dynamics iterations.

        Returns:
            The normalised configuration ``m*`` (also cached in ``self.m_star``).
        """
        def step_expected(config, priority):
            """
            Step a normalised configuration (fractions summing to 1) in expectation.
            """
            current_frac = 0
            next_config = np.zeros(self.d)
            for state in priority:
                current_frac += config[state]
                if current_frac <= self.alpha:
                    next_config += self.Transition_1[state] * config[state]
                elif current_frac > self.alpha and current_frac <= config[state] + self.alpha:
                    residual_frac = self.alpha - (current_frac - config[state])
                    next_config += self.Transition_1[state] * residual_frac
                    next_config += self.Transition_0[state] * (config[state] - residual_frac)
                elif current_frac > config[state] + self.alpha:
                    next_config += self.Transition_0[state] * config[state]
            return next_config

        m = np.array([1/self.d]*self.d)
        for _ in range(n_iter):
            m = step_expected(m, priority)
        assert np.allclose(np.sum(m), 1), f"Fractional allocation error {np.sum(m)}"
        print(f"m*: {m} under Index Priority: {priority}")
        self.m_star = m

        return m

    def reset_to_m_star(self):
        """
        Reset the environment to the integer configuration closest below ``m* * N``.

        If ``m*`` has not been computed yet, it is found with the identity
        priority ``0, 1, ..., d-1``. Each state gets ``floor(m*_i * N)`` arms and
        the rounding remainder goes to the last state. The per-arm states are
        rebuilt to match, so :meth:`step_by_agent` also starts from ``m*``.

        Returns:
            The new configuration (also stored in ``self.config``).
        """
        if self.m_star is None:
            self.find_m_star(np.arange(self.d))

        self.config = np.floor(self.m_star * self.N)
        self.config[-1] += self.N - np.sum(self.config)
        # Keep per-arm states consistent with the aggregate configuration.
        self.global_state_agent = np.repeat(
            np.arange(self.d), self.config.astype(int)
        ).astype(float)

        return self.config

    def Value_Learning_MC(self, priority, trajectory_length=30, num_sims=1000):
        """
        Estimate the value of a priority policy by parallel Monte Carlo rollouts.

        Runs ``num_sims`` independent calls of ``MC_once`` (from ``utils``) in a
        process pool.

        Args:
            priority: State priority order, forwarded to ``MC_once``.
            trajectory_length: Rollout length, forwarded to ``MC_once``.
            num_sims: Number of independent rollouts.

        Returns:
            Tuple ``(mean, standard error)`` of the rollout values.
        """
        Values = []

        import concurrent.futures

        with concurrent.futures.ProcessPoolExecutor() as executor:
            futures = [executor.submit(MC_once, self, priority, trajectory_length) for _ in range(num_sims)]
            for future in tqdm.tqdm(concurrent.futures.as_completed(futures), total=num_sims, desc="Value Learning MC"):
                Values.append(future.result())

        mean_value = np.mean(Values)
        std_value = np.std(Values)

        print(f"Value Learning MC: {mean_value} +- {std_value/np.sqrt(num_sims)}")

        return mean_value, std_value/np.sqrt(num_sims)


class Circular_Env(RMAB):
    """
    Circular environment with ``d == 4`` states arranged on a ring.

    Passive arms (action 0) stay put or move to state ``s - 1 (mod 4)``, each
    with probability 0.5. Active arms (action 1) stay put or move to state
    ``s + 1 (mod 4)``, each with probability 0.5. Rewards do not depend on the
    action: -1 in state 0, 0 in states 1 and 2, and +1 in state 3.
    """
    def __init__(self, N, d, alpha, gamma):
        assert d== 4, "Only d=4 is supported"
        # RMAB.__init__ calls self.generate_Transition, self.generate_Reward
        # and self.reset, which dispatch to the overrides below. No further
        # initialisation is needed here.
        super().__init__(N, d, alpha, gamma)

    def reset(self):
        """
        Place N/6 arms in state 0, N/3 in state 1, N/2 in state 2 and none in state 3.

        Arms are assigned to states in contiguous index ranges, in that order.

        Returns:
            The new configuration (also stored in ``self.config``).
        """
        assert self.N % 6 == 0, "N must be divisible by 6"
        par_0, par_1, par_2 = self.N // 6, self.N // 3, self.N // 2
        self.config = np.array([par_0, par_1, par_2, 0], dtype=float)

        self.global_state_agent[:par_0] = 0
        self.global_state_agent[par_0:par_0 + par_1] = 1
        self.global_state_agent[par_0 + par_1:par_0 + par_1 + par_2] = 2

        return self.config

    def generate_Transition(self, d):
        """Return the ``(4, 4)`` ring transition matrices for the passive and active actions."""
        assert d==4, "Only d=4 is supported"
        Transition_0 = np.array([[0.5, 0, 0, 0.5], [0.5, 0.5, 0, 0], [0, 0.5, 0.5, 0], [0, 0, 0.5, 0.5]])
        Transition_1 = np.array([[0.5, 0.5, 0, 0], [0, 0.5, 0.5, 0], [0, 0, 0.5, 0.5], [0.5, 0, 0, 0.5]])

        return Transition_0, Transition_1

    def generate_Reward(self, d):
        """Return the per-state rewards, which are identical for both actions."""
        assert d==4, "Only d=4 is supported"
        Reward_0 = np.array([-1, 0, 0, 1])
        Reward_1 = np.array([-1, 0, 0, 1])

        return Reward_0, Reward_1



