import argparse
import os
import random
import sys
from datetime import datetime

import numpy as np

from algorithms.CLUB import CLUB
from algorithms.LinUCB import LinUCB, LinUCB_Ind
from algorithms.SCLUB import SCLUB
from algorithms.UniCorn import UniCorn
from algorithms.UniSCorn import UniSCorn
from env import Arm, User
from env.Environment import Environment
from utils.utils import set_random_seed

# Pool of random seeds; --seedindex selects which entry is used for a run.
seeds_set = [2756048, 675510, 807110, 927, 218, 495, 515, 452]

# Names accepted by --algorithms, listed in the order the algorithms are run.
KNOWN_ALGORITHMS = ("SCLUB", "CLUB", "LinUCB-One", "LinUCB-Ind", "UniCorn", "UniSCorn")


def build_arg_parser():
    """Return the command-line parser for this experiment script."""
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('--in_folder', dest='in_folder', default='input_data/synthetic', help='folder with input files')
    parser.add_argument('--out_folder', dest='out_folder', default="output_data", help='the folder to output')
    parser.add_argument('--d', dest='d', type=int, default=50, help='dimension of the arm')
    parser.add_argument('--user_num', dest='user_num', type=int, default=50, help='user_num')
    parser.add_argument('--item_num', dest='item_num', type=int, default=100, help='item_num')
    # horizons: T
    parser.add_argument('--horizons', dest='horizons', type=int, default=10000, help='horizons')
    # Space-separated subset of KNOWN_ALGORITHMS.
    parser.add_argument('--algorithms', dest='algorithms', default="LinUCB-Ind", help='algorithms name')
    # Index into seeds_set.
    parser.add_argument('--seedindex', dest='seedindex', type=int, default=1, help='seedIndex')
    # Parsed for command-line compatibility; this script does not read it.
    parser.add_argument('--thread_num', dest='thread_num', type=int, default=1, help='thread_num')
    return parser


def parse_algorithm_names(spec):
    """Turn the ``--algorithms`` string into a validated list of names.

    Names are separated by any amount of whitespace and duplicates are
    dropped. The result follows the order of ``KNOWN_ALGORITHMS`` rather than
    the order given on the command line.

    Raises:
        ValueError: if no name is given or a name is not in
            ``KNOWN_ALGORITHMS``.
    """
    requested = set((spec or '').split())
    if not requested:
        raise ValueError("Please input the algorithms name")
    unknown = sorted(requested.difference(KNOWN_ALGORITHMS))
    if unknown:
        raise ValueError(f"[main] Wrong algorithms setup: unknown algorithm(s) {unknown}")
    return [name for name in KNOWN_ALGORITHMS if name in requested]


def result_file_name(label, args):
    """Return the ``.npz`` file name for the results of algorithm *label*."""
    fn_suffix = f'_users{args.user_num}_items{args.item_num}_d{args.d}_T{args.horizons}'
    return f'{label}{fn_suffix}_seed_{args.seedindex}.npz'


if __name__ == '__main__':

    parser = build_arg_parser()
    args = parser.parse_args()

    # Validate up front with parser.error: unlike assert, it is not stripped
    # under ``python -O`` and it reports the problem as a usage error.
    if not args.in_folder:
        parser.error("Please input the input folder")
    if not -len(seeds_set) <= args.seedindex < len(seeds_set):
        parser.error(f"--seedindex must be between {-len(seeds_set)} and {len(seeds_set) - 1}")
    try:
        algorithms_name = parse_algorithm_names(args.algorithms)
    except ValueError as err:
        parser.error(str(err))

    seed = seeds_set[args.seedindex]
    set_random_seed(seed)

    # generate environment
    AM = Arm.ArmManager(args.in_folder)
    AM.loadArms()
    print(f'[main] Finish loading arms: {AM.n_arms}')
    UM = User.UserManager(args.in_folder)
    UM.loadUser()
    print(f'[main] Finish loading users: {UM.n_user}')
    environment = Environment(d=args.d, num_users=args.user_num, num_items=args.item_num, arms=AM.arms, users=UM.users, type="Stochastic")

    # Command-line name -> (label used in logs and output file names, class).
    # UniCorn / UniSCorn keep their existing "UniCLUB" / "UniSCLUB" labels so
    # that output file names do not change.
    registry = {
        "SCLUB": ("SCLUB", SCLUB),
        "CLUB": ("CLUB", CLUB),
        "LinUCB-One": ("LinUCB-One", LinUCB),
        "LinUCB-Ind": ("LinUCB-Ind", LinUCB_Ind),
        "UniCorn": ("UniCLUB", UniCorn),
        "UniSCorn": ("UniSCLUB", UniSCorn),
    }
    algorithms = {}
    for cli_name in algorithms_name:
        label, algo_cls = registry[cli_name]
        algorithms[label] = algo_cls(args.user_num, args.d, args.horizons)
    print(f'[main] Finish setting up algorithms: {list(algorithms.keys())} at {datetime.now().strftime("%m-%d %H:%M")}')

    # Make sure the output directory exists. A relative --out_folder is
    # resolved against this script's directory, not the working directory.
    root_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = os.path.join(root_dir, args.out_folder)
    print(f'[main] Output directory: {output_dir}')
    os.makedirs(output_dir, exist_ok=True)

    # run the algorithms
    for name, algo in algorithms.items():
        # Reset the seed before every algorithm so each run starts from the
        # same random state.
        set_random_seed(seed)

        print(f'[main] Start running algorithm {name} at {datetime.now().strftime("%m-%d %H:%M")}')
        results = algo.run(environment)
        print(f'[main] Finish running algorithm {name} at {datetime.now().strftime("%m-%d %H:%M")}')
        print(results["regret"][-1])
        # Save into the directory that was created above; the file must not
        # depend on the current working directory.
        np.savez(os.path.join(output_dir, result_file_name(name, args)), **results)
        print(f'[main] Finish saving results of algorithm {name} at {datetime.now().strftime("%m-%d %H:%M")}')

    print(f'[main] Finish running {len(algorithms)} algorithms at {datetime.now().strftime("%m-%d %H:%M")}')
