import numpy as np
import argparse
import os
from env_list import env_list
from environment import Environment, EnvInM
from run_alg import run

# Default value for the ``--horizon`` command-line option.
HORIZON = 20000
# Default value for the ``--repetitions`` command-line option.
REPETITIONS = 100
# Root output directory; main() creates it if missing and nests experiment
# sub-folders under it.
# NOTE(review): machine-specific absolute path — consider making it configurable.
PLOT_DIR = "C:/Users/USER/Code/diminishing-exploration/plot"
# NOTE(review): not referenced in this part of the file; confirm it is still used.
NB_SEGS = 5
# Default value for the ``--nb_of_instances`` command-line option.
NB_OF_INSTANCES = 1

def main(variant):
    """Prepare the output folders and environment, then launch the experiment.

    Parameters
    ----------
    variant : dict
        Parsed command-line options (the ``argparse`` namespace converted with
        ``vars``). Keys read: ``alg``, ``diminishing``, ``skip``, ``alg_arg``,
        ``experiment``, ``horizon``, ``repetitions``, ``envId``,
        ``nb_of_instances``, ``rand_instance`` and ``seed``. ``ID`` may be
        present but is not used here.

    Raises
    ------
    ValueError
        If ``experiment`` is not one of ``'t'``, ``'M'``, ``'T'``, ``'K'``, or
        if an output path already exists as a regular file.
    NotImplementedError
        For the ``'T'`` and ``'K'`` scaling experiments, which do not build any
        environment samples yet.
    """
    ################### Argument setting ###################
    alg             = variant.get('alg')
    diminishing     = variant.get('diminishing')
    skip            = variant.get('skip')
    alg_arg         = variant.get('alg_arg')
    experiment      = variant.get('experiment')
    horizon         = variant.get('horizon')
    repetitions     = variant.get('repetitions')
    envId           = variant.get('envId')
    nb_of_instances = variant.get('nb_of_instances')
    rand_instance   = variant.get('rand_instance')
    seed            = variant.get('seed')

    # Reject unknown experiment types before any folder is created; previously
    # this surfaced as an UnboundLocalError when building ``cfg``.
    if experiment not in ('t', 'M', 'T', 'K'):
        raise ValueError(
            f"[ERROR] Unknown experiment {experiment!r}; expected one of 't', 'M', 'T', 'K'."
        )

    if seed is not None:
        np.random.seed(seed)

    alg_arg = parse_config(alg_arg)
    print(alg_arg)

    ################### Create the folder ###################
    if os.path.isdir(PLOT_DIR):
        print("{} is already a directory here...".format(PLOT_DIR))
    elif os.path.isfile(PLOT_DIR):
        raise ValueError("[ERROR] {} is a file, cannot use it as a directory !".format(PLOT_DIR))
    else:
        os.mkdir(PLOT_DIR)
        print(f"Created the directory {PLOT_DIR}")

    print('===================================')
    print(f'Experiment scaling in {experiment}')

    if rand_instance:
        print(f'Number of random instance: {nb_of_instances}')
    else:
        print(f'envId: {envId}')

    print(f'Repetitions: {repetitions}')
    if experiment != 'T':
        print(f'Horizon: {horizon}')

    subfolder = "/rand_instance" if rand_instance else "/normal"
    plot_dir = CreateFolder(PLOT_DIR, subfolder)

    env_mean = env_list(horizon, envId, rand_instance=rand_instance)
    # The arm count is taken from the first entry of ``listOfMeans``.
    nb_arms = len(env_mean["params"]["listOfMeans"][0])
    if experiment != 'K':
        print(f"Number of arms: {nb_arms}")
    nb_break_point = len(env_mean["params"]["changePoints"])
    if experiment != 'M':
        print(f"Number of change points: {nb_break_point}")

    print(env_mean["params"]["listOfMeans"])

    ################### The type of experiment ###################
    # Environment configuration shared by the 't' and 'M' experiments.
    cfg_env = {
        'repetitions': repetitions,
        'nb_arms': nb_arms,
        'nb_break_points': nb_break_point,
        'horizon': horizon,
    }

    if experiment == 't':
        # Scaling in t: a single environment stored under index 0.
        subfolder = f"/K{nb_arms}_T{horizon}_N{repetitions}_M{nb_break_point}_envId{envId}_seed{seed}"
        plot_dir = CreateFolder(plot_dir, subfolder)
        env_samples = Environment(cfg_env, env_mean)
        dict_index_list = [0]
        env_samples_dict = {0: env_samples}
    elif experiment == 'M':
        # Scaling in M: EnvInM produces the indices and sample dictionary.
        subfolder = f"/M_scaling_K{nb_arms}_T{horizon}_N{repetitions}_envId{envId}_seed{seed}"
        plot_dir = CreateFolder(plot_dir, subfolder)
        env_in_M = EnvInM(cfg_env)
        dict_index_list, env_samples_dict = env_in_M(env_mean)
    elif experiment == 'T':
        # Scaling in T: the output folder is named, but no environments are built yet.
        subfolder = f"/T_scaling_K{nb_arms}_N{repetitions}_M{nb_break_point}_envId{envId}_seed{seed}"
        CreateFolder(plot_dir, subfolder)
        raise NotImplementedError("[ERROR] The 'T' scaling experiment is not implemented yet.")
    else:
        # Scaling in K: the output folder is named, but no environments are built yet.
        subfolder = f"/K_scaling_T{horizon}_N{repetitions}_M{nb_break_point}_envId{envId}_seed{seed}"
        CreateFolder(plot_dir, subfolder)
        raise NotImplementedError("[ERROR] The 'K' scaling experiment is not implemented yet.")

    cfg = {
        'alg': alg,
        'diminishing': diminishing,
        'skip': skip,
        'alg_arg': alg_arg,
        'repetitions': repetitions,
        'nb_arms': nb_arms,
        'nb_break_points': nb_break_point,
        'horizon': horizon,
        'dict_index_list': dict_index_list,
        'env_samples_dict': env_samples_dict,
        'nb_of_instances': nb_of_instances,
        'experiment': experiment,
        'plot_dir': plot_dir,
    }
    run(cfg)
    
def CreateFolder(Plot_dir = PLOT_DIR, subfolder = ""):
    """Make sure the directory ``Plot_dir + subfolder`` exists and return its path.

    The two parts are joined by plain string concatenation, so ``subfolder``
    is expected to carry its own leading ``/`` (e.g. ``"/normal"``). Only the
    last path component is created; parent directories must already exist.

    Raises
    ------
    ValueError
        If the target path already exists as a regular file.
    """
    target = Plot_dir + subfolder

    # Guard: a regular file at the target path cannot be reused as a folder.
    if os.path.isfile(target):
        raise ValueError("[ERROR] {} is a file, cannot use it as a directory !".format(target))

    if not os.path.isdir(target):
        os.mkdir(target)
    else:
        print("{} is already a directory here...".format(target))

    print("Using sub folder = '{}' and plotting in '{}'...".format(subfolder, target))
    return target

def parse_config(config_str):
    """Parse a ``"key=value,key2=value2"`` string into a dict of strings.

    Whitespace around keys and values is stripped. Empty segments, such as
    those produced by a trailing comma or an empty string, are skipped. Only
    the first ``=`` of each pair separates key from value, so values may
    themselves contain ``=``. Values are returned as strings; any numeric
    conversion is left to the caller.

    Parameters
    ----------
    config_str : str or None
        Comma-separated assignments, e.g. ``"w=200"``. ``None`` or ``""``
        yields an empty dict.

    Returns
    -------
    dict[str, str]
        Mapping from key to raw string value. A repeated key keeps its last value.

    Raises
    ------
    ValueError
        If a non-empty segment has no ``=`` or has an empty key.
    """
    config = {}
    if not config_str:
        return config
    for raw_pair in config_str.split(','):
        pair = raw_pair.strip()
        if not pair:
            continue
        # partition splits on the first '=' only and never raises on
        # missing/extra separators, unlike tuple-unpacking split("=").
        key, sep, value = pair.partition('=')
        key = key.strip()
        if not sep or not key:
            raise ValueError(
                f"[ERROR] Malformed config entry {pair!r}: expected 'key=value'."
            )
        config[key] = value.strip()
    return config

def str2bool(value):
    """Convert a command-line string such as ``"True"``, ``"false"`` or ``"1"`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``. That is ``True`` for
    every non-empty string, including ``"False"``, so boolean options need an
    explicit converter like this one.

    Raises
    ------
    argparse.ArgumentTypeError
        If ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {value!r}.")


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--alg', type=str, default="MUCB")
    # str2bool instead of bool: bool("False") would evaluate to True.
    parser.add_argument('--diminishing', type=str2bool, default=False)
    parser.add_argument('--skip', type=str2bool, default=False)
    parser.add_argument('--alg_arg', type=str, default="w=200")
    parser.add_argument('--experiment', type=str, default='t', choices=['t', 'M', 'T', 'K'])
    parser.add_argument('--horizon', type=int, default=HORIZON)
    parser.add_argument('--repetitions', type=int, default=REPETITIONS)
    parser.add_argument('--envId', type=int, default=0)
    parser.add_argument('--nb_of_instances', type=int, default=NB_OF_INSTANCES)
    parser.add_argument('--rand_instance', type=str2bool, default=False)
    parser.add_argument('--ID', type=int, default=None)
    parser.add_argument('--seed', type=int, default=10)

    args = parser.parse_args()

    main(variant=vars(args))
    