import numpy as np
import pandas as pd
from tqdm import trange
from Agent import Agent
import argparse
from util import ratio_figure

def parse():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``runs``, ``experiment_id``,
        ``n0``, ``PCS``, ``delta``, ``sampling_rule``, ``obsve_budget`` and
        ``draw_ratio``.
    """
    # (flag, type, default, help) for every supported option, in CLI order.
    options = (
        ("-runs", int, 10, 'the number of macro replications'),
        ("-experiment_id", int, 2, 'experiment id'),
        ("-n0", int, 10, 'the number of initial budget'),
        ("-PCS", int, 0, 'whether to estimate PCS'),
        ("-delta", float, 0.1, 'the threshold on probability of correct selection'),
        ("-sampling_rule", str, "SEQSR", 'ASR, USR, ESR, SEQSR, FWSR'),
        ("-obsve_budget", int, 200000, 'number of observation collected'),
        ("-draw_ratio", int, 0, 'whether to draw the allocation ratio'),
    )
    cli = argparse.ArgumentParser()
    for flag, kind, default, text in options:
        cli.add_argument(flag, type=kind, default=default, help=text)
    return cli.parse_args()

def simulate(runs, obsve_budget, Agent, draw_ratio):
    """Run ``runs`` macro-replications of ``Agent`` and collect statistics.

    Each replication resets the agent and repeatedly calls ``sample`` ->
    ``step`` -> ``select`` until ``Agent.stop()`` returns True.

    Args:
        runs: number of macro replications.
        obsve_budget: length of the recorded PCS curve. Outcomes are recorded
            for sample counts ``1 .. obsve_budget - 1``.
        Agent: the agent instance (the parameter name shadows the imported
            ``Agent`` class; it is kept for interface compatibility). It must
            provide ``reset``, ``stop``, ``sample`` (returning a 3-tuple),
            ``step``, ``select``, ``opt_solution`` and, when ``draw_ratio``
            is set, ``ratio_hist``.
        draw_ratio: if truthy, pass ``Agent.ratio_hist`` to ``ratio_figure``
            after all runs.

    Returns:
        ``(results1, results2)`` where
        ``results1 = {"PCS": ..., "Error": ...}``.

        - ``PCS[t]`` is the fraction of runs whose selection matched
          ``Agent.opt_solution`` in every entry after ``t`` samples.
        - ``Error[t]`` is the standard error of that fraction. It is NaN
          when ``runs < 2``.
        - ``results2 = {"Complexity": ...}`` holds the number of samples each
          run took before ``stop()``.

        Index 0 (no samples yet) is never written and stays 0. Indices past a
        run's stopping point also stay 0, i.e. they count as incorrect.
        NOTE(review): confirm this post-stop convention is intended.
    """
    # One flag per (replication, sample count): True when every entry of the
    # current selection equals Agent.opt_solution. The original float64
    # array of shape (runs, obsve_budget, Agent.s) was only ever reduced with
    # np.all over the last axis. Storing the reduced bool directly gives the
    # same PCS/Error at a fraction of the memory.
    all_correct = np.zeros((runs, obsve_budget), dtype=bool)
    sample_complexity = np.zeros(runs)

    for r in trange(runs):
        Agent.reset()
        total_sample = 0
        while not Agent.stop():
            total_sample += 1
            next_task, next_arm, next_cons = Agent.sample()
            Agent.step(next_task, next_arm, next_cons)
            selected_arm = Agent.select()
            if total_sample < obsve_budget:
                all_correct[r, total_sample] = np.all(selected_arm == Agent.opt_solution)
        sample_complexity[r] = total_sample

    hits = all_correct.astype(int)
    PCS = hits.mean(axis=0)  # fixed budget setting
    if runs > 1:
        Error = np.sqrt(hits.var(axis=0, ddof=1) / runs)
    else:
        # The sample variance (ddof=1) is undefined for fewer than two
        # replications. Return NaN explicitly instead of triggering numpy's
        # degrees-of-freedom RuntimeWarning.
        Error = np.full(obsve_budget, np.nan)
    if draw_ratio:
        ratio_figure(Agent.ratio_hist)
    results1 = {"PCS": PCS.ravel(), "Error": Error}
    results2 = {"Complexity": sample_complexity}
    return results1, results2

def main(args):
    """Run one experiment configuration and save its sample complexities.

    Args:
        args: parsed command-line namespace as returned by ``parse()``.

    Writes one CSV under ``Results/`` with a ``Complexity`` column holding
    one row per macro replication. The file name encodes ``experiment_id``,
    ``sampling_rule``, ``n0``, ``PCS`` and ``delta``.
    """
    from pathlib import Path  # local import: only used to prepare the output dir

    runs = args.runs
    experiment_id = args.experiment_id
    n0 = args.n0
    PCS = args.PCS
    delta = args.delta
    sampling_rule = args.sampling_rule
    obsve_budget = args.obsve_budget
    draw_ratio = args.draw_ratio

    # Create the output directory before simulating. Otherwise to_csv raises
    # FileNotFoundError only after the (possibly long) simulation has finished.
    Path('Results').mkdir(parents=True, exist_ok=True)

    #simulation
    Alg = Agent(experiment_id, n0, PCS, delta, obsve_budget, sampling_rule)
    # Only the sample-complexity results are persisted; the PCS/Error curves
    # returned first by simulate() are discarded here.
    _, results2 = simulate(runs, obsve_budget, Alg, draw_ratio)
    pd.DataFrame(results2).to_csv(f'Results/EXP{experiment_id}_{sampling_rule}_n0_{n0}_PCS_{PCS}_delta_{delta}_sample_complexity.csv')

if __name__ == '__main__':
    # Script entry point: parse the CLI options and run the experiment.
    main(parse())