import argparse
import numpy as np
import time
from tqdm import tqdm

import Algorithm
import Bandits


def parse_args(argv=None):
    """
    Specifies command line arguments for the program.

    Args:
        argv: optional list of argument strings to parse. When None (the
            default), argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with attributes ``seed``, ``n``, ``arm_dist``,
        ``delta``, ``epsilon``, ``method`` and ``num_sim``.
    """
    parser = argparse.ArgumentParser(description='Best arm identification')

    parser.add_argument('--seed', default=1, type=int,
                        help='Seed for random number generators')
    # default best-arm options
    parser.add_argument('--n', default=10, type=int,
                        help='number of total arms')
    # Restrict to the distributions main() knows how to generate, so a typo
    # is rejected at parse time instead of failing later.
    parser.add_argument('--arm_dist', default='normal',
                        choices=['uniform', 'normal'],
                        help='distribution used to draw the non-best arm means')
    parser.add_argument('--delta', default=1e-200, type=float,
                        help='1 - target confidence')
    parser.add_argument('--epsilon', default=0.01, type=float,
                        help='epsilon (used only by Tri-BBAI and Opt-BBAI)')
    # Restrict to the algorithms main() dispatches on.
    parser.add_argument('--method', default='Opt-BBAI',
                        choices=['TrackandStop', 'Tri-BBAI', 'Opt-BBAI',
                                 'Top-k', 'EGE'],
                        help='best-arm identification algorithm to run')
    parser.add_argument('--num_sim', default=100, type=int,
                        help='number of independent simulation runs')
    return parser.parse_args(argv)


# Algorithms main() knows how to dispatch to (see _run_trial).
_SUPPORTED_METHODS = ('TrackandStop', 'Tri-BBAI', 'Opt-BBAI', 'Top-k', 'EGE')


def _generate_arm_means(arm_dist, num_arms):
    """Draw, print and return the arm means for one bandit instance.

    Arm 0 is always given a mean strictly above every other arm, so it is the
    unique best arm:

    * 'uniform': other arms ~ U[0.2, 0.4), arm 0 set to 0.5.
    * 'normal': other arms ~ normal(mean 0.2, std 0.2), redrawn one at a time
      until they fall in [0, 0.4]; arm 0 set to 0.6.

    The sequence of RNG calls is kept exactly as in the original script so
    results stay reproducible for a given ``--seed``.

    Raises:
        ValueError: if ``arm_dist`` is neither 'uniform' nor 'normal'.
    """
    if arm_dist == 'uniform':
        arm_means = np.random.uniform(0.2, 0.4, num_arms)
        arm_means[0] = 0.5
        print('arm_means_uniform:', arm_means)
    elif arm_dist == 'normal':
        arm_means = np.random.normal(0.2, 0.2, num_arms)
        # Rejection-sample each arm into [0, 0.4] (arm 0 included, even
        # though it is overwritten below, to keep the RNG stream unchanged).
        for i in range(num_arms):
            mean = arm_means[i]
            while mean > 0.4 or mean < 0:
                mean = np.random.normal(0.2, 0.2)
            arm_means[i] = mean
        arm_means[0] = 0.6
        print('arm_means_normal:', arm_means)
    else:
        raise ValueError(
            f"unknown --arm_dist {arm_dist!r}; expected 'uniform' or 'normal'")
    return arm_means


def _run_trial(method, num_arms, arm_means, epsilon, delta, L1, L2, L3, alpha):
    """Run ``method`` once on a fresh simulator.

    Returns:
        ``(total_sample, best_arm, condition, batch)``. ``condition`` and
        ``batch`` are None for methods that do not report them
        (TrackandStop, Top-k, EGE).

    Raises:
        ValueError: if ``method`` is not one of ``_SUPPORTED_METHODS``.
    """
    sim = Bandits.Simulator(num_arms=num_arms, arm_means=arm_means)
    if method == 'TrackandStop':
        total_sample, _mu, best_arm = Algorithm.TrackandStop(delta, alpha, sim)
        return total_sample, best_arm, None, None
    if method == 'Tri-BBAI':
        total_sample, _mu, best_arm, condition, batch = Algorithm.Tri_BBAI(
            epsilon, delta, L1, L2, L3, alpha, sim)
        return total_sample, best_arm, condition, batch
    if method == 'Opt-BBAI':
        total_sample, _mu, best_arm, condition, batch = Algorithm.Opt_BBAI_New(
            epsilon, delta, L1, L2, L3, alpha, sim)
        return total_sample, best_arm, condition, batch
    if method == 'Top-k':
        total_sample, _mu, best_arm = Algorithm.Top1DeltaEliminate(delta, sim)
        return total_sample, best_arm, None, None
    if method == 'EGE':
        total_sample, _mu, best_arm = Algorithm.ExponentialGapElimination(
            delta, sim)
        return total_sample, best_arm, None, None
    raise ValueError(f'unknown method: {method!r}')


def main():
    """Run repeated best-arm-identification simulations and print statistics.

    Reads options from the command line (see ``parse_args``), draws one fixed
    set of arm means, runs ``--num_sim`` independent trials of ``--method`` on
    it, and prints:

    * the error rate (fraction of trials whose returned arm is not the best);
    * the fraction of trials in which the algorithm returned a truthy
      ``condition`` flag (always 0 for methods that do not report one);
    * mean/std of wall-clock runtime, total samples and batch count.

    Raises:
        ValueError: on invalid ``--n``, ``--delta``, ``--num_sim``,
            ``--method`` or ``--arm_dist``.
    """
    args = parse_args()
    np.random.seed(args.seed)
    np.set_printoptions(threshold=np.inf)
    num_arms = args.n
    delta = args.delta
    epsilon = args.epsilon
    # Fail early with a clear message instead of an obscure crash later
    # (log of a non-positive number, division by zero, empty results).
    if num_arms < 1:
        raise ValueError(f'--n must be at least 1, got {num_arms}')
    if not 0 < delta < 1:
        raise ValueError(f'--delta must lie in (0, 1), got {delta}')
    if args.num_sim < 1:
        raise ValueError(f'--num_sim must be at least 1, got {args.num_sim}')
    if args.method not in _SUPPORTED_METHODS:
        raise ValueError(f'unknown --method {args.method!r}; '
                         f'expected one of {_SUPPORTED_METHODS}')
    print('epsilon:', epsilon)
    alpha = 1.001
    # -log(delta) equals log(1/delta) but stays finite for tiny (subnormal)
    # delta, where 1/delta would overflow to inf.
    log_inv_delta = -np.log(delta)
    L1 = int(np.ceil(2 * np.log(num_arms) * np.sqrt(log_inv_delta)))
    print('L1:', L1)
    L2 = int(np.ceil(80 * log_inv_delta * np.log(log_inv_delta) / num_arms))
    print('L2:', L2)
    L3 = int(np.ceil(log_inv_delta * log_inv_delta))
    print('L3:', L3)
    arm_means = _generate_arm_means(args.arm_dist, num_arms)

    # The original comparison implies the algorithms report arms 1-based,
    # hence the +1. Loop-invariant, so computed once.
    true_best_arm = np.argmax(arm_means) + 1
    time_list = []
    sample_list = []
    batch_list = []
    num_error = 0
    num_condition = 0
    for _ in tqdm(range(args.num_sim)):
        start_time = time.perf_counter()
        total_sample, best_arm, condition, batch = _run_trial(
            args.method, num_arms, arm_means, epsilon, delta, L1, L2, L3, alpha)
        time_list.append(time.perf_counter() - start_time)
        sample_list.append(total_sample)
        if batch is not None:
            batch_list.append(batch)
        if best_arm != true_best_arm:
            num_error += 1
        if condition:
            num_condition += 1

    print('error_rate:', num_error / args.num_sim)
    print('meet_condition_rate:', num_condition / args.num_sim)
    print('Runtime Mean:', np.mean(time_list))
    print('Runtime Std:', np.std(time_list))
    print('Sample Mean:', np.mean(sample_list))
    print('Sample Std:', np.std(sample_list))
    # Only Tri-BBAI / Opt-BBAI report a batch count; print nan for the other
    # methods without triggering numpy's empty-slice RuntimeWarning.
    batch_mean = np.mean(batch_list) if batch_list else float('nan')
    batch_std = np.std(batch_list) if batch_list else float('nan')
    print('Batch Mean:', batch_mean)
    print('Batch Std:', batch_std)


if __name__ == "__main__":
    # Run the experiment only when executed as a script, not on import.
    main()
