import numpy as np
import gym
from mwrmab import initialize, evaluation

# Two-state environment: all M action types share one constant cost, and the reward is 1 in the last state and 0 in every other state.
class constantCosts(gym.Env):
    """Two-state environment in which every paid action type has the same cost.

    The environment has ``N`` arms, ``S = 2`` states and ``M`` action types.
    Column 0 of the cost matrix is 0 and the other ``M`` columns equal ``cost``
    (presumably column 0 is the passive action -- confirm against mwrmab).
    The reward is 1 in the last state and 0 otherwise. Arm states and rewards
    are tracked separately for every algorithm name in ``algos``, so several
    policies can be simulated side by side on the same instance.
    """

    def __init__(self, N, M, B, cost, seed,
                 algos=('MWRMAB', 'MWRMAB_adj', 'OPT_fair', 'OPT', 'hawkins', 'random', 'no_action')):
        """Create the instance and draw the shared start states.

        Args:
            N: number of arms.
            M: number of action types besides column 0 of the cost matrix.
            B: budget; only stored on the instance, never read by this class.
            cost: cost assigned to every non-zero action column.
            seed: seed used for T (via ``get_experiment``) and the start states.
            algos: names of the algorithms whose state is tracked. A list copy is
                stored, so neither the default nor the caller's list is aliased.
        """
        self.N = N
        self.S = 2
        self.M = M
        self.B = B
        self.algos = list(algos)
        self.init_seed = seed
        self.cost = cost

        self.T, self.R, self.C = self.get_experiment()

        self.current_states = {}
        self.rewards = {}
        # Reseed so the start states depend only on `seed`, not on how much
        # randomness get_experiment consumed.
        np.random.seed(seed)
        start_state = np.random.choice(list(range(self.S)), size=N, replace=True)
        self._set_start_state(start_state)

    def get_experiment(self):
        """Build the transition, reward and cost arrays from ``init_seed``.

        Returns:
            T: tensor from ``initialize.actionTypeTransitions_all``
               (the original author notes its shape as N, S, M, S).
            R: (N, S) array, 1 in the last state and 0 elsewhere.
            C: (N, M + 1) array; column 0 is 0, the rest equal ``cost``.
        """
        np.random.seed(self.init_seed)
        C = np.array([[0] + [self.cost] * self.M for _ in range(self.N)])
        T = initialize.actionTypeTransitions_all(self.N, self.S, self.M)  # Shape is N,S,M,S
        R = np.array([[0] * (self.S - 1) + [1] for _ in range(self.N)])

        return T, R, C

    def newT(self):
        """Replace ``self.T`` with a tensor drawn from the current global RNG state.

        Unlike ``get_experiment`` this does not reseed, so repeated calls give
        new instances (assuming the helper draws from ``np.random``).
        """
        T = initialize.actionTypeTransitions_all(self.N, self.S, self.M)
        self.T = T
        return T

    def step(self, actions, algo):
        """Advance the arms tracked for ``algo`` by one step.

        Args:
            actions: per-arm actions, in the format expected by ``mwrmab.evaluation``.
            algo: key of ``current_states``; must be one of ``self.algos``.

        Returns:
            ``(next_state, spent_budget, reward)`` as computed by
            ``evaluation.nextState``, ``evaluation.usedBudget`` and
            ``evaluation.getReward``. The new state and reward are also stored
            under ``algo``.
        """
        next_state = evaluation.nextState(actions, self.current_states[algo], self.S, self.T)
        spent_budget = evaluation.usedBudget(self.C, actions, self.M)
        reward = evaluation.getReward(next_state, self.R)

        self.current_states[algo] = next_state
        self.rewards[algo] = reward

        return next_state, spent_budget, reward

    def reset(self):
        """Restore the start states and transition tensor created in ``__init__``."""
        np.random.seed(self.init_seed)
        start_state = np.random.choice(list(range(self.S)), size=self.N, replace=True)

        # Rebuild through get_experiment so reset uses the same generator
        # (actionTypeTransitions_all) and seeding as construction did.
        self.T, self.R, self.C = self.get_experiment()

        self._set_start_state(start_state)

    def _set_start_state(self, start_state):
        """Give every tracked algorithm its own copy of ``start_state`` and zero rewards."""
        for algo in self.algos:
            # Copy per algorithm so an in-place update for one cannot leak into another.
            self.current_states[algo] = start_state.copy()
            self.rewards[algo] = [0] * self.M
    


# Two-state environment whose M action types have ordered active transitions.
# Costs are either random integers in [minC, maxC) (maxC exclusive, as with np.random.randint) or a constant `cost`; N and B are fixed.
class orderedWorkers(gym.Env):
    """Two-state environment with ordered active transitions for M action types.

    Costs are either random integers drawn with ``np.random.randint(minC, maxC)``
    (so ``maxC`` is exclusive) or a constant ``cost``. When both are given,
    ``cost`` takes precedence. The reward is 1 in the last state and 0
    otherwise. Arm states and rewards are tracked separately for every
    algorithm name in ``algos``.
    """

    def __init__(self, N, M, B, minC=None, maxC=None, cost=None, seed=1,
                 algos=('MWRMAB', 'MWRMAB_adj', 'OPT_fair', 'OPT', 'hawkins', 'random', 'no_action')):
        """Create the instance and draw the shared start states.

        Args:
            N: number of arms.
            M: number of action types besides column 0 of the cost matrix.
            B: budget; only stored on the instance, never read by this class.
            minC, maxC: bounds for random integer costs (``maxC`` exclusive).
            cost: constant cost for every non-zero action column.
            seed: seed used for costs, T and the start states.
            algos: names of the algorithms whose state is tracked. A list copy is
                stored, so neither the default nor the caller's list is aliased.

        Raises:
            ValueError: if neither ``minC`` nor ``cost`` is provided.
        """
        self.N = N
        self.S = 2
        self.M = M
        self.B = B
        self.algos = list(algos)
        self.minC = minC
        self.maxC = maxC
        self.cost = cost
        self.init_seed = seed

        self.T, self.R, self.C = self.get_experiment()

        self.current_states = {}
        self.rewards = {}
        # Reseed so the start states depend only on `seed`.
        np.random.seed(seed)
        start_state = np.random.choice(list(range(self.S)), size=N, replace=True)
        self._set_start_state(start_state)

    def get_experiment(self):
        """Build the transition, reward and cost arrays from ``init_seed``.

        Returns:
            T: tensor from ``initialize.ordered_worker_T_2state``.
            R: (N, S) array, 1 in the last state and 0 elsewhere.
            C: (N, M + 1) array with column 0 equal to 0.

        Raises:
            ValueError: if neither ``minC`` nor ``cost`` is set. Previously this
                surfaced as an UnboundLocalError on ``C``.
        """
        if self.minC is None and self.cost is None:
            raise ValueError("orderedWorkers requires either minC/maxC or cost to build the cost matrix")

        np.random.seed(self.init_seed)
        C = None
        if self.minC is not None:
            # These draws happen before T is generated, so they affect T as well.
            C = np.array([[0] + list(np.random.randint(self.minC, self.maxC, self.M)) for _ in range(self.N)])
        if self.cost is not None:
            # A constant cost overrides the random costs.
            C = np.array([[0] + [self.cost] * self.M for _ in range(self.N)])

        T = initialize.ordered_worker_T_2state(self.N, self.M + 1, always_positive_index=True)
        R = np.array([[0] * (self.S - 1) + [1] for _ in range(self.N)])

        return T, R, C

    def newT(self):
        """Replace ``self.T`` with a tensor drawn from the current global RNG state.

        Unlike ``get_experiment`` this does not reseed, so repeated calls give
        new instances (assuming the helper draws from ``np.random``).
        """
        T = initialize.ordered_worker_T_2state(self.N, self.M + 1, always_positive_index=True)
        self.T = T
        return T

    def step(self, actions, algo):
        """Advance the arms tracked for ``algo`` by one step.

        Args:
            actions: per-arm actions, in the format expected by ``mwrmab.evaluation``.
            algo: key of ``current_states``; must be one of ``self.algos``.

        Returns:
            ``(next_state, spent_budget, reward)`` as computed by
            ``evaluation.nextState``, ``evaluation.usedBudget`` and
            ``evaluation.getReward``. The new state and reward are also stored
            under ``algo``.
        """
        next_state = evaluation.nextState(actions, self.current_states[algo], self.S, self.T)
        spent_budget = evaluation.usedBudget(self.C, actions, self.M)
        reward = evaluation.getReward(next_state, self.R)

        self.current_states[algo] = next_state
        self.rewards[algo] = reward

        return next_state, spent_budget, reward

    def reset(self):
        """Restore the start states and transition tensor created in ``__init__``."""
        np.random.seed(self.init_seed)
        start_state = np.random.choice(list(range(self.S)), size=self.N, replace=True)

        # get_experiment reseeds and replays the random cost draws before
        # generating T. Reseeding and calling ordered_worker_T_2state directly
        # would skip those draws and, with minC/maxC costs, yield a different T.
        self.T, self.R, self.C = self.get_experiment()

        self._set_start_state(start_state)

    def _set_start_state(self, start_state):
        """Give every tracked algorithm its own copy of ``start_state`` and zero rewards."""
        for algo in self.algos:
            # Copy per algorithm so an in-place update for one cannot leak into another.
            self.current_states[algo] = start_state.copy()
            self.rewards[algo] = [0] * self.M

# Counterexample environment: 3 states, M = 2 action types each with unit cost, reward 1 only in the last state (0 elsewhere); transitions come from initialize.counter_example_T.
class decoupledCounterexample(gym.Env):
    """Three-state counterexample instance with two unit-cost action types.

    Transitions come from ``initialize.counter_example_T``. Column 0 of the cost
    matrix is 0 and both action columns cost 1. The reward is 1 in the last
    state and 0 otherwise. Arm states and rewards are tracked separately for
    every algorithm name in ``algos``.
    """

    def __init__(self, N, B, seed,
                 algos=('MWRMAB', 'MWRMAB_adj', 'OPT_fair', 'OPT', 'hawkins', 'random', 'no_action')):
        """Create the instance and draw the shared start states.

        Args:
            N: number of arms.
            B: budget; only stored on the instance, never read by this class.
            seed: seed applied in ``get_experiment`` before building the arrays.
            algos: names of the algorithms whose state is tracked. A list copy is
                stored, so neither the default nor the caller's list is aliased.
        """
        self.N = N
        self.S = 3
        self.M = 2
        self.B = B
        self.algos = list(algos)
        self.init_seed = seed

        self.T, self.R, self.C = self.get_experiment()

        self.current_states = {}
        self.rewards = {}
        # No explicit reseed here: the draw continues from the RNG state left by
        # get_experiment (seeded with init_seed).
        # NOTE(review): reset() starts every arm in state 0, while this draws
        # random start states -- confirm which is intended.
        start_state = np.random.choice(list(range(self.S)), size=N, replace=True)
        self._set_start_state(start_state)

    def get_experiment(self):
        """Build the transition, reward and cost arrays.

        Returns:
            T: tensor from ``initialize.counter_example_T``
               (the original author notes its shape as N, S, M, S).
            R: (N, S) array, 1 in the last state and 0 elsewhere.
            C: (N, M + 1) array; column 0 is 0, the rest are 1.
        """
        np.random.seed(self.init_seed)
        C = np.array([[0] + [1] * self.M for _ in range(self.N)])
        T = initialize.counter_example_T(self.N)  # Shape is N,S,M,S
        R = np.array([[0] * (self.S - 1) + [1] for _ in range(self.N)])

        return T, R, C

    def step(self, actions, algo):
        """Advance the arms tracked for ``algo`` by one step.

        Args:
            actions: per-arm actions, in the format expected by ``mwrmab.evaluation``.
            algo: key of ``current_states``; must be one of ``self.algos``.

        Returns:
            ``(next_state, spent_budget, reward)`` as computed by
            ``evaluation.nextState``, ``evaluation.usedBudget`` and
            ``evaluation.getReward``. The new state and reward are also stored
            under ``algo``.
        """
        next_state = evaluation.nextState(actions, self.current_states[algo], self.S, self.T)
        spent_budget = evaluation.usedBudget(self.C, actions, self.M)
        reward = evaluation.getReward(next_state, self.R)

        self.current_states[algo] = next_state
        self.rewards[algo] = reward

        return next_state, spent_budget, reward

    def reset(self):
        """Put every arm in state 0 and clear rewards for every tracked algorithm.

        T, R and C are left unchanged.
        """
        start_state = np.zeros(self.N, dtype=int)
        self._set_start_state(start_state)

    def _set_start_state(self, start_state):
        """Give every tracked algorithm its own copy of ``start_state`` and zero rewards."""
        for algo in self.algos:
            # Copy per algorithm so an in-place update for one cannot leak into another.
            self.current_states[algo] = start_state.copy()
            self.rewards[algo] = [0] * self.M