import numpy as np
import pandas as pd
from tqdm import trange
from Agent import Agent
import argparse
from util import ratio_figure

def parse():
    """Build the experiment's command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``runs``, ``fixed_budget``,
        ``experiment_id``, ``rho``, ``policy``, ``n0``, ``PCS``, ``delta``
        and ``draw_ratio``.
    """
    cli = argparse.ArgumentParser()
    # One (flag, value type, default, help text) row per option, registered
    # in this order.
    option_table = (
        ("-runs", int, 1, 'the number of experiment runs'),
        ("-fixed_budget", int, 2000, 'simulation budget per experiment'),
        ("-experiment_id", int, 2, 'experiment id'),
        ("-rho", float, 0.5, 'the correlation coefficient'),
        ("-policy", str, "TS", 'EA, AOP, BC, TS, CORSA'),
        ("-n0", int, 10, 'number of initial budget'),
        ("-PCS", int, 1, 'whether to estimate PCS'),
        ("-delta", float, 0.1,
         'the threshold on probability of correct selection'),
        ("-draw_ratio", int, 0, 'whether to draw the allocation ratio'),
    )
    for flag, value_type, default_value, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=default_value,
                         help=help_text)
    return cli.parse_args()

def simulate(runs, budget, Agent, draw_ratio):
    """Run ``runs`` independent replications of a selection agent.

    Each replication calls ``Agent.reset()`` and then, until ``Agent.stop()``
    returns True, samples one alternative (``Agent.sample()``), feeds it to
    ``Agent.step()``, and records whether ``Agent.select()`` equals
    ``Agent.best_alternative``.

    Args:
        runs: number of independent replications.
        budget: number of sampling steps tracked for the PCS curve (the
            column count of the per-step correctness matrix).
        Agent: the agent instance. Note that this parameter shadows the
            module-level ``Agent`` class inside this function. It must
            provide ``reset``, ``stop``, ``sample``, ``step``, ``select``,
            ``best_alternative`` and ``alternative_count``. When
            ``draw_ratio`` is truthy it must also provide ``ratio_hist``
            and ``policy``.
        draw_ratio: if truthy, pass ``Agent.ratio_hist`` and ``Agent.policy``
            to ``ratio_figure`` after the last run.

    Returns:
        Tuple ``(PCS, Error_bar, Nums, sample_complexity)``:

        - ``PCS``: length-``budget`` array with, for each step, the fraction
          of runs whose selection was correct. A step that a run never
          reached counts as incorrect (the matrix starts at zero).
        - ``Error_bar``: per-step standard error
          ``sqrt(PCS * (1 - PCS) / runs)``.
        - ``Nums``: ``Agent.alternative_count`` read after the final run
          only.
        - ``sample_complexity``: length-``runs`` array with the number of
          samples each run took.
    """
    best_action_counts = np.zeros((runs, budget))
    sample_complexity = np.zeros(runs)

    for r in trange(runs):
        Agent.reset()
        total_sample = 0
        while not Agent.stop():
            total_sample += 1
            next_alternative = Agent.sample()
            Agent.step(next_alternative)
            selected_alternative = Agent.select()
            # If the stopping rule is not tied to ``budget`` (e.g. a
            # confidence-based stop), a run can outlast the tracked horizon.
            # Those steps still count toward sample_complexity but have no
            # column to go in; indexing past ``budget`` used to raise
            # IndexError.
            if total_sample <= budget:
                # np.asarray accepts a Python bool, a NumPy bool or a
                # one-element array. A plain ``bool`` has no ``.astype``.
                best_action_counts[r, total_sample - 1] = np.asarray(
                    selected_alternative == Agent.best_alternative, dtype=int
                )
        sample_complexity[r] = total_sample

    PCS = best_action_counts.mean(axis=0)
    # Binomial standard error of the per-step proportion across runs.
    Error_bar = np.sqrt((PCS * (1 - PCS)) / runs)
    Nums = Agent.alternative_count
    if draw_ratio:
        ratio_figure(Agent.ratio_hist, Agent.policy)
    return PCS, Error_bar, Nums, sample_complexity

def main(args):
    """Run one experiment configured by ``args`` and write its results as CSV.

    All files go under ``Results/`` with the prefix
    ``EXP{experiment_id}_{rho}_{policy}_{delta}``:

    - ``_PCS.csv``: per-step PCS and its standard error. Written only when
      ``args.PCS`` is truthy.
    - ``_Complexity.csv``: number of samples taken in each run.
    - ``_Allocations.csv``: ``alternative_count`` reported by ``simulate``.

    Args:
        args: namespace as produced by ``parse()``.
    """
    import os

    runs = args.runs
    budget = args.fixed_budget
    experiment_id = args.experiment_id
    rho = args.rho
    policy = args.policy
    n0 = args.n0
    PCS = args.PCS
    delta = args.delta
    draw_ratio = args.draw_ratio

    # DataFrame.to_csv does not create missing directories. Create the
    # output folder before the (potentially long) simulation so a missing
    # folder cannot discard finished results.
    os.makedirs('Results', exist_ok=True)

    Alg = Agent(experiment_id, rho, policy, n0, PCS, budget, delta)
    PCS_value, Error_bar, counts, sample_complexity = simulate(runs, budget, Alg, draw_ratio)

    prefix = f'Results/EXP{experiment_id}_{rho}_{policy}_{delta}'
    if PCS:
        pcs_table = {"PCS": np.ravel(PCS_value), "Error": np.ravel(Error_bar)}
        pd.DataFrame(pcs_table).to_csv(f'{prefix}_PCS.csv')
    pd.DataFrame({"Complexity": sample_complexity}).to_csv(f'{prefix}_Complexity.csv')
    # np.ravel also accepts plain lists, so alternative_count need not be an
    # ndarray.
    pd.DataFrame({"Nums": np.ravel(counts)}).to_csv(f'{prefix}_Allocations.csv')

if __name__ == '__main__':
    # Script entry point: read the command-line options, then run the experiment.
    main(parse())



