from environment import Environment
from algorithm.MUCB import MUCB
from algorithm.CUSUM_UCB import CusumUCB
from algorithm.GLR_UCB import GLRUCB
from algorithm.DUCB import DUCB
from algorithm.DTS import DTS
from algorithm.Master import Master
from algorithm.META import Meta
from algorithm.AdSwitch import AdSwitch
from algorithm.ArmSwitch import ArmSwitch

# Default experiment settings.
HORIZON = 20_000          # presumably the number of rounds per run — confirm with callers
REPETITIONS = 100         # default for select_alg(repetitions=...)
NB_SEGS = 5               # default for select_alg(nb_break_points=...)
NB_OF_INSTANCES = 1       # not referenced in this module; presumably used by callers

def select_alg(
            repetitions = REPETITIONS,
            nb_arms = int,
            nb_break_points = NB_SEGS,
            horizon = int,
            env = Environment,
            alg = str, 
            diminishing = False,
            skip = False,
            arg = {},
            path = str
        ):
    match alg:
        case "MUCB":
            alpha = str_to_int(arg.get("alpha"))
            w = str_to_int(arg.get("w"))
            b = str_to_int(arg.get("b"))
            gamma = str_to_int(arg.get("gamma"))
            klUCB = str_to_bool(arg.get("klUCB"))
            skip_uncertainty = str_to_int(arg.get("skip_uncertainty"))
            algorithm = MUCB(
                repetitions,
                nb_arms,
                nb_break_points,
                horizon,
                env,
                klUCB = klUCB,
                path = path,
                diminishing = diminishing,
                skip = skip,
                alpha = alpha,
                gamma = gamma,
                w = w,
                b = b,
                skip_uncertainty = skip_uncertainty                
            )
        case "CUSUM_UCB":
            alpha = str_to_int(arg.get("alpha"))
            h = str_to_int(arg.get("h"))
            epsilon = str_to_int(arg.get("epsilon"))
            m = str_to_int(arg.get("M"))
            gamma = str_to_int(arg.get("gamma"))
            klUCB = str_to_bool(arg.get("klUCB"))
            skip_uncertainty = str_to_int(arg.get("skip_uncertainty"))
            algorithm = CusumUCB(
                repetitions,
                nb_arms,
                nb_break_points,
                horizon,
                env,
                path = path,
                klUCB = klUCB,
                diminishing = diminishing,
                skip = skip,
                alpha = alpha,
                gamma = gamma,
                h = h, 
                epsilon = epsilon,
                m = m, 
                skip_uncertainty = skip_uncertainty                
            )
        case "GLR_UCB":
            alpha = str_to_int(arg.get("alpha"))
            gamma = str_to_int(arg.get("gamma"))
            klUCB = str_to_bool(arg.get("klUCB"))
            nb_break_points_known = str_to_bool(arg.get("nb_break_points_known"))
            skip_uncertainty = str_to_int(arg.get("skip_uncertainty"))
            algorithm = GLRUCB(
                repetitions,
                nb_arms,
                nb_break_points,
                horizon,
                env,
                path = path,
                klUCB = klUCB,
                diminishing = diminishing,
                skip = skip,
                alpha = alpha,
                gamma = gamma,
                nb_break_points_known = nb_break_points_known,
                skip_uncertainty = skip_uncertainty                
            )   
        case "DUCB":
            gamma = str_to_int(arg.get("gamma"))
            # klUCB = str_to_int(arg.get("klUCB")) 
            klUCB = True                      
            algorithm = DUCB(
                repetitions,
                nb_arms,
                nb_break_points,
                horizon,
                env,
                path = path,  
                klUCB = klUCB, 
                gamma = gamma              
            )   
        case "DTS": 
            gamma = str_to_int(arg.get("gamma"))
            algorithm = DTS(
                repetitions,
                nb_arms,
                nb_break_points,
                horizon,
                env,
                path = path,   
                gamma = gamma              
            )         
        case "master":
            algorithm = Master(
                repetitions,
                nb_arms,
                nb_break_points,
                horizon,
                env,
                path = path                
            )
        case "meta":
            algorithm = Meta(
                repetitions,
                nb_arms,
                nb_break_points,
                horizon,
                env,
                path = path                
            )         
        case "AdSwitch":
            algorithm = AdSwitch(
                repetitions,
                nb_arms,
                nb_break_points,
                horizon,
                env,
                path = path                 
            )  
        case "ArmSwitch":
            algorithm = ArmSwitch(
                repetitions,
                nb_arms,
                nb_break_points,
                horizon,
                env,
                path = path                 
            )    
            
    return algorithm

def str_to_int(string = None):
    """Parse a numeric configuration value.

    A string that is a valid integer literal (e.g. ``"3"``) becomes an
    ``int``. Any other numeric string (e.g. ``"0.5"`` or ``"1e-3"``) becomes
    a ``float``. ``None`` and non-string values (for example numbers already
    parsed upstream) are returned unchanged.

    Args:
        string: The value to parse, or ``None``.

    Returns:
        An ``int`` or ``float`` for numeric strings; otherwise the input
        unchanged.

    Raises:
        ValueError: If ``string`` is a string that is not a valid number.
    """
    # Covers None as well as already-numeric values, which previously
    # crashed on ``.find``.
    if not isinstance(string, str):
        return string
    try:
        return int(string)
    except ValueError:
        # Not an integer literal: fall back to float so that decimals and
        # scientific notation such as "1e-3" are accepted as well.
        return float(string)

def str_to_bool(string = None):
    """Parse a boolean configuration value.

    ``bool(s)`` is True for *every* non-empty string, including "False" and
    "0", so strings are matched explicitly (case-insensitive, surrounding
    whitespace ignored):

    * true:  "true", "1", "yes", "y", "on"
    * false: "false", "0", "no", "n", "off", "" (empty string)

    Args:
        string: The value to parse, or ``None``.

    Returns:
        ``None`` if ``string`` is ``None``; ``bool(string)`` for non-string
        values; otherwise the parsed boolean.

    Raises:
        ValueError: If ``string`` is a string with no recognised boolean
            meaning.
    """
    if string is None:
        return None
    if not isinstance(string, str):
        # Already a bool/number: use normal truthiness.
        return bool(string)
    normalized = string.strip().lower()
    if normalized in ("true", "1", "yes", "y", "on"):
        return True
    if normalized in ("false", "0", "no", "n", "off", ""):
        return False
    raise ValueError(f"Cannot interpret {string!r} as a boolean")