import os
import pickle
import time

import numpy as np
import pandas as pd
from tqdm import tqdm

from mwrmab import environments
from mwrmab.algos import decoupled_methods, whittle_binary_search
from mwrmab.baselines import combinatorial_mdp, hawkins_methods

def get_actions(env, algorithm, gamma=0.95, per_arm_indexes=None, per_arm_adjusted_indexes=None, comb_mdp=None, comb_mdp_fair=None):
    """Choose the action vector for one step of ``algorithm`` on ``env``.

    Args:
        env: environment exposing ``N``, ``M``, ``B``, ``C``, ``T``, ``R`` and a
            ``current_states`` mapping keyed by algorithm name.
        algorithm: name of the policy to query.
        gamma: discount factor; only forwarded to the ``'hawkins'`` baseline.
        per_arm_indexes: index table addressed as ``[n, state, m]``; read by
            ``'MWRMAB'`` and ``'MWRMAB_bt'``.
        per_arm_adjusted_indexes: table with the same addressing; read by
            ``'S1S2'``, ``'MWRMAB_adj'`` and ``'MWRMAB_adj_bt'``.
        comb_mdp: solved combinatorial MDP queried for ``'OPT'``.
        comb_mdp_fair: solved combinatorial MDP queried for ``'OPT_fair'``.

    Returns:
        Whatever the selected policy returns. For ``'no_actions'`` and for any
        name not listed here, an all-zero integer array of length ``env.N``.
    """
    # Built unconditionally: this is both the 'no_actions' answer and the
    # fallback for names that match none of the branches below.
    idle = np.zeros(env.N, dtype=int)

    def gather(table):
        # One row per m in range(env.M), one column per arm n in range(env.N),
        # each entry read at the arm's current state for this algorithm.
        rows = []
        for m in range(env.M):
            row = [table[n, env.current_states[algorithm][n], m] for n in range(env.N)]
            rows.append(np.array(row))
        return np.array(rows)

    if algorithm == 'S1S2':
        return decoupled_methods.jointGreedyAllocation(
            gather(per_arm_adjusted_indexes), env.C, env.N, env.B, env.M)

    # The four greedy variants differ only in the table they read and in
    # whether ties are broken.
    greedy_variants = (
        ('MWRMAB_bt', per_arm_indexes, True),
        ('MWRMAB_adj_bt', per_arm_adjusted_indexes, True),
        ('MWRMAB', per_arm_indexes, False),
        ('MWRMAB_adj', per_arm_adjusted_indexes, False),
    )
    for name, table, break_ties in greedy_variants:
        if algorithm == name:
            return decoupled_methods.greedyAllocation(
                gather(table), env.C, env.N, env.B, env.M, breakties=break_ties)

    if algorithm == 'OPT':
        return comb_mdp.get_action(env.current_states[algorithm])

    if algorithm == 'OPT_fair':
        return comb_mdp_fair.get_action(env.current_states[algorithm])

    if algorithm == 'hawkins':
        # NOTE(review): the meaning of the literal 0 (6th positional argument)
        # is defined by hawkins_methods and is not visible here.
        return hawkins_methods.get_hawkins_actions(
            env.N, env.T, env.R, env.C, env.B, 0, env.current_states[algorithm], gamma)

    if algorithm == 'random':
        return decoupled_methods.randomUntilBudget(env.C, env.B, env.M, env.N)

    if algorithm == 'no_actions':
        return np.zeros(env.N, dtype=int)

    return idle

def saveMeanResults(exp, rewards, usedBudget, path):
    """Summarise an experiment per algorithm and write the table as CSV.

    The CSV has one row per algorithm (in ``rewards`` key order) and four columns:

    * ``reward`` / ``std_reward``: mean and standard deviation, across axis 0,
      of the axis-1 mean of ``rewards[algo]``.
    * ``fraction_fair`` / ``fraction_fair_std``: mean and standard deviation,
      across axis 0, of the per-row fraction of steps whose spread
      ``max - min`` over axis 2 of ``usedBudget[algo]`` is at most
      ``exp.env.C.max()``.

    Args:
        exp: object exposing ``env.C`` (only ``C.max()`` is used).
        rewards: mapping algorithm name -> 2-D array.
        usedBudget: mapping algorithm name -> 3-D array.
        path: destination handed to ``DataFrame.to_csv``. When it is a
            filesystem path, missing parent directories are created first.
    """
    algos = list(rewards.keys())

    means = []
    std = []
    for algo in algos:
        per_epoch_reward = rewards[algo].mean(axis=1)
        means.append(per_epoch_reward.mean())
        std.append(per_epoch_reward.std())

    # Key the fairness statistics by algorithm so that they are matched to the
    # column labels by name rather than by dictionary iteration order.
    fair_mean = {}
    fair_std = {}
    for algo, budget in usedBudget.items():
        spread = budget.max(axis=2) - budget.min(axis=2)
        per_epoch_fair = (spread <= exp.env.C.max()).mean(axis=1)
        fair_mean[algo] = per_epoch_fair.mean()
        fair_std[algo] = per_epoch_fair.std()

    # Algorithms without budget data get NaN instead of shifting other columns.
    fraction = [fair_mean.get(algo, np.nan) for algo in algos]
    fraction_std = [fair_std.get(algo, np.nan) for algo in algos]

    df = pd.DataFrame(
        [means, std, fraction, fraction_std],
        index=['reward', 'std_reward', 'fraction_fair', 'fraction_fair_std'],
        columns=algos,
    ).T

    # to_csv does not create directories; make sure a finished experiment is
    # not lost because the log folder is missing. Buffers are passed through.
    if isinstance(path, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(path))
        if parent:
            os.makedirs(parent, exist_ok=True)
    df.to_csv(path)


class Experiment:
    """Run and time a set of allocation algorithms on one environment.

    Each epoch first computes the planning artefacts the requested algorithms
    need (index tables, combinatorial MDP solutions), then steps every
    algorithm ``STEPS`` times through ``self.env`` while recording rewards,
    actions, spent budget and cumulative wall-clock time.

    Attributes:
        env: environment built from ``data`` by the ``environments`` module.
        algos: names of the algorithms to run.
        time: algorithm -> total seconds (planning plus action selection),
            accumulated over every call to :meth:`run`.
        pretime: algorithm -> seconds spent pre-computing index tables; only
            the ``MWRMAB*`` entries are ever increased.
    """

    # NOTE(review): 'no_action' is not matched by name in get_actions (which
    # checks 'no_actions'); it yields all-zero actions through that function's
    # default. The spelling is kept because it is used as a key in the results.
    _DEFAULT_ALGOS = ('MWRMAB', 'MWRMAB_adj', 'OPT_fair', 'OPT', 'hawkins', 'random', 'no_action')

    def __init__(self, N, M=None, B=None, seed=None, cost=None, minC=None, maxC=None, data='constantCosts', EPOCHS=50, STEPS=50, changeT=False, algos=None, index_lb=-1, index_ub=1, gamma=0.95, tolerance=1e-4):
        """Store the configuration and build the environment.

        Args:
            N, M, B, seed, cost, minC, maxC: forwarded to the environment
                constructor selected by ``data``.
            data: one of ``'constantCosts'``, ``'orderedWorkers'`` or
                ``'decoupledCounterexample'``.
            EPOCHS: number of epochs per :meth:`run`.
            STEPS: number of environment steps per epoch.
            changeT: when true, ``env.newT()`` is called at the start of every epoch.
            algos: algorithm names to run; ``None`` selects the default line-up.
            index_lb, index_ub, tolerance: forwarded to the index computations
                in ``whittle_binary_search``.
            gamma: discount factor forwarded to the index computations, the
                MDP value iteration and ``get_actions``.

        Raises:
            ValueError: if ``data`` is not a supported environment name.
        """
        if algos is None:
            # Fresh list per instance: the default is never shared or mutated.
            algos = list(self._DEFAULT_ALGOS)

        self.index_lb = index_lb
        self.index_ub = index_ub
        self.gamma = gamma
        self.tolerance = tolerance
        self.changeT = changeT
        self.algos = algos
        self.EPOCHS = EPOCHS
        self.STEPS = STEPS
        self.data = data

        if data == 'constantCosts':
            self.env = environments.constantCosts(N, M, B, cost, seed, algos=algos)
        elif data == 'orderedWorkers':
            self.env = environments.orderedWorkers(N, M, B, minC=minC, maxC=maxC, cost=cost, seed=seed, algos=algos)
        elif data == 'decoupledCounterexample':
            self.env = environments.decoupledCounterexample(N, B, seed, algos=algos)
        else:
            # Fail here instead of with an AttributeError on self.env in run().
            raise ValueError(
                f"unknown data {data!r}; expected 'constantCosts', "
                "'orderedWorkers' or 'decoupledCounterexample'")

        self.time = {algo: 0 for algo in algos}
        self.pretime = {algo: 0 for algo in algos}

    def run(self, saveResults=True):
        """Run every algorithm for ``EPOCHS`` x ``STEPS`` steps.

        Args:
            saveResults: when true, write the mean-results CSV and the runtime
                pickle under ``./logs``.

        Returns:
            ``(rewards, actions, usedBudget, runtime)``, four dicts keyed by
            algorithm name holding arrays of shape ``(EPOCHS, STEPS)``,
            ``(EPOCHS, STEPS, env.N)``, ``(EPOCHS, STEPS, env.M)`` and
            ``(EPOCHS, STEPS + 1)``. ``runtime[algo][e, 0]`` is the planning
            time of epoch ``e`` and ``runtime[algo][e, i + 1]`` the cumulative
            time after step ``i``.
        """
        rewards = {}
        actions = {}
        usedBudget = {}
        runtime = {}
        for algo in self.algos:
            rewards[algo] = np.ndarray(shape=(self.EPOCHS, self.STEPS))
            actions[algo] = np.ndarray(shape=(self.EPOCHS, self.STEPS, self.env.N))
            usedBudget[algo] = np.ndarray(shape=(self.EPOCHS, self.STEPS, self.env.M))
            runtime[algo] = np.zeros(shape=(self.EPOCHS, int(self.STEPS + 1)))

        for epoch in tqdm(range(self.EPOCHS)):
            if self.changeT:
                self.env.newT()

            per_arm_indexes, per_arm_adjusted_indexes = self._precompute_indexes(runtime, epoch)
            # A new CombMDP is solved per epoch because T may have changed.
            comb_mdp_fair = self._solve_comb_mdp('OPT_fair', None, runtime, epoch)
            comb_mdp = self._solve_comb_mdp('OPT', 99999999, runtime, epoch)

            for i in range(self.STEPS):
                # Every algorithm starts this step from the same global NumPy
                # RNG state, so draws from that RNG are comparable across them.
                saved_random_state = np.random.get_state()
                for algo in self.algos:
                    start_time = time.time()
                    np.random.set_state(saved_random_state)
                    actions_algo = get_actions(
                        self.env, algo, gamma=self.gamma,
                        per_arm_indexes=per_arm_indexes,
                        per_arm_adjusted_indexes=per_arm_adjusted_indexes,
                        comb_mdp=comb_mdp,
                        comb_mdp_fair=comb_mdp_fair,
                    )
                    # Measure once so both timers record the same duration.
                    elapsed = time.time() - start_time
                    self.time[algo] += elapsed
                    runtime[algo][epoch, i + 1] = runtime[algo][epoch, i] + elapsed
                    current_state, spent_budget, reward = self.env.step(actions_algo, algo)
                    rewards[algo][epoch, i] = reward
                    actions[algo][epoch, i] = actions_algo
                    usedBudget[algo][epoch, i] = spent_budget

        if saveResults:
            self._save(rewards, usedBudget, runtime)

        return rewards, actions, usedBudget, runtime

    def _precompute_indexes(self, runtime, epoch):
        """Compute the (N,S,M) index tables required by the requested algorithms.

        Returns ``(per_arm_indexes, per_arm_adjusted_indexes)``; a table that no
        requested algorithm needs is ``None``.
        """
        env = self.env
        plain_timed = [a for a in ('MWRMAB', 'MWRMAB_bt') if a in self.algos]
        adjusted_timed = [a for a in ('MWRMAB_adj', 'MWRMAB_adj_bt') if a in self.algos]
        # NOTE(review): 'S1S2' reads the adjusted table too but, as before, is
        # not charged for computing it.
        needs_adjusted = bool(adjusted_timed) or 'S1S2' in self.algos

        per_arm_indexes = None
        per_arm_adjusted_indexes = None

        # The adjusted computation takes the plain table as input, so the plain
        # table is needed whenever either family is requested.
        if plain_timed or needs_adjusted:
            start_time = time.time()
            per_arm_indexes = whittle_binary_search.all_per_arm_indexes(
                env.N, env.S, env.M, env.T, env.R, env.C,
                index_lb=self.index_lb, index_ub=self.index_ub,
                gamma=self.gamma, tolerance=self.tolerance)
            self._charge_pretime(plain_timed, time.time() - start_time, runtime, epoch)

        if needs_adjusted:
            start_time = time.time()
            per_arm_adjusted_indexes = whittle_binary_search.all_per_arm_adjusted_indexes(
                env.N, env.S, env.M, env.T, env.R, env.C, per_arm_indexes,
                index_lb=self.index_lb, index_ub=self.index_ub,
                gamma=self.gamma, tolerance=self.tolerance)
            self._charge_pretime(adjusted_timed, time.time() - start_time, runtime, epoch)

        return per_arm_indexes, per_arm_adjusted_indexes

    def _charge_pretime(self, algos, seconds, runtime, epoch):
        """Add ``seconds`` of planning time to each algorithm in ``algos``."""
        for algo in algos:
            self.time[algo] += seconds
            self.pretime[algo] += seconds
            runtime[algo][epoch, 0] += seconds

    def _solve_comb_mdp(self, algo, fairness_epsilon, runtime, epoch):
        """Build and solve a CombMDP for ``algo``; return ``None`` if not requested."""
        if algo not in self.algos:
            return None
        start_time = time.time()
        mdp = combinatorial_mdp.CombMDP()
        mdp.make_mdp(self.env.T, self.env.R, self.env.C, self.env.B, LB=0, fairness_epsilon=fairness_epsilon)
        mdp.value_iteration(self.gamma)
        elapsed = time.time() - start_time
        self.time[algo] += elapsed
        runtime[algo][epoch, 0] += elapsed
        return mdp

    def _save(self, rewards, usedBudget, runtime):
        """Write the mean-results CSV and the runtime pickle under ``./logs``."""
        if self.data == 'orderedWorkers':
            # NOTE(review): the fallback lacks the leading underscore used by
            # every other variant; left unchanged so existing file names and
            # any scripts reading them keep working.
            path_cost = f'_minC{self.env.minC}_maxC{self.env.maxC-1}' if self.env.minC is not None else f'C{self.env.cost}'
        elif self.data == 'constantCosts':
            path_cost = f'_C{self.env.cost}'
        else:
            path_cost = ''
        filename = f'{self.data}_N{self.env.N}_B{int(self.env.B)}_M{self.env.M}{path_cost}'

        mean_dir = f'./logs/meanResults/{self.data}'
        runtime_dir = f'./logs/runtimes/{self.data}'
        # Create the log folders so a finished run is never lost to a missing
        # directory.
        os.makedirs(mean_dir, exist_ok=True)
        os.makedirs(runtime_dir, exist_ok=True)

        path = f'{mean_dir}/{filename}.csv'
        saveMeanResults(self, rewards, usedBudget, path)

        with open(f'{runtime_dir}/{filename}.pkl', 'wb') as f:
            pickle.dump(runtime, f)

        
