import os
import pandas as pd
import numpy as np
from tqdm import tqdm

from experiments.runner import run
from tests.test_utils import construct_all_ones_query, construct_sign_query
from experiments.plotter import get_scaling_param, plot_scaling
from adaptive_softmax.constants import (
    SCALING_POINTS,
    NUM_TRIALS,
    SCALING_RESULT_DIR,
    ALL_ONES_QUERY,
    SIGN_QUERY,
)

def get_run_results(run_data: pd.DataFrame):
    """Return the average per-arm budget over all rows of ``run_data``.

    Each row's ``budget_total`` is normalized by that row's ``n``, and the
    resulting ratios are averaged with ``np.mean``.
    """
    budget_per_arm = run_data['budget_total'] / run_data['n']
    mean_budget = np.mean(budget_per_arm)
    return mean_budget

def scaling_synthetic(n, initial_d, dataset, path_dir, num_points=1):
    """Run adaptive softmax on a synthetic dataset at one or more dimensions.

    For each ``point`` in ``1..num_points`` the dimension is
    ``initial_d * 10 * point``. At each dimension, ``NUM_TRIALS`` seeded trials
    are run, and the mean per-arm budget and its standard error are recorded.

    Args:
        n: First size argument passed to the query constructors (presumably
            the number of arms / rows of A -- confirm against the constructors).
        initial_d: Base dimension; the swept dimension is ``initial_d * 10 * point``.
        dataset: Either ``SIGN_QUERY`` or ``ALL_ONES_QUERY``.
        path_dir: Directory under which each dimension's run is saved as
            ``f"{path_dir}/d={d}"``.
        num_points: Number of dimensions to sweep. Defaults to 1, which matches
            the previous single-point behavior; pass ``SCALING_POINTS`` for a
            full sweep.

    Returns:
        dict with keys:
            'd': the last dimension run (kept for backward compatibility),
            'dimensions': list of every dimension run, in order,
            'n': the ``n`` argument,
            'per_arm_budgets': mean per-arm budget for each dimension,
            'stderr': standard error of those means (population std / sqrt(NUM_TRIALS)).

    Raises:
        ValueError: if ``dataset`` is not a supported synthetic dataset, or if
            ``num_points`` is less than 1.
    """
    # Fail fast: previously an unknown dataset left A/x/noise_bound unbound and
    # surfaced as an UnboundLocalError deep inside the trial loop.
    if dataset not in (SIGN_QUERY, ALL_ONES_QUERY):
        raise ValueError(f"unsupported dataset: {dataset!r}")
    if num_points < 1:
        raise ValueError(f"num_points must be >= 1, got {num_points}")

    dimensions = []
    avg_per_arm_budgets = []
    std_errs = []

    for point in range(1, num_points + 1):
        curr_d = initial_d * (10 * point)
        dimensions.append(curr_d)
        per_arm_budgets = []

        for trial in range(NUM_TRIALS):
            # Seed per trial so every dimension sees the same sequence of seeds.
            np.random.seed(trial)

            # NOTE: A is rebuilt on every trial (not just x); the original note
            # says reusing A would result in exact search -- TODO clarify.
            if dataset == SIGN_QUERY:
                A, x = construct_sign_query(n, curr_d)
                noise_bound = 1
            else:  # ALL_ONES_QUERY (validated above)
                A, x = construct_all_ones_query(n, curr_d, scale=(curr_d ** -0.9))
                noise_bound = None

            path = f"{path_dir}/d={curr_d}"
            run_data = run(
                save_to=path,
                model="scaling synthetic",
                dataset=dataset,
                A=A,
                X=np.array(x, ndmin=2),  # runner.py expects a 2-D batch of queries
                multiplicative_error=0.3,
                failure_probability=0.1,
                noise_bound=noise_bound,
                use_true_sftm=True,
                use_tune=False,
                train_size=1,
                quiet=True,
            )

            per_arm_budgets.append(get_run_results(run_data))

        avg_per_arm_budgets.append(np.mean(per_arm_budgets))
        # np.std defaults to ddof=0 (population std); kept as-is to preserve
        # previously reported numbers.
        std_errs.append(np.std(per_arm_budgets) / np.sqrt(NUM_TRIALS))

    return {
        'd': dimensions[-1],
        'dimensions': dimensions,
        'n': n,
        'per_arm_budgets': avg_per_arm_budgets,
        'stderr': std_errs,
    }

        
def run_synthetic(n, init_d, curr_time=None):
    """Run (if needed) and plot the dimension-scaling experiment for each synthetic dataset.

    For each of ``ALL_ONES_QUERY`` and ``SIGN_QUERY``, results are stored in
    ``f"{SCALING_RESULT_DIR}/{dataset}/n{n}_init_d{init_d}"``. The experiment
    is only (re)run when that directory is empty. The saved results are then
    loaded with ``get_scaling_param`` and plotted with ``plot_scaling``.

    Args:
        n: Passed through to ``scaling_synthetic`` and used in the results path.
        init_d: Initial dimension; passed through and used in the results path.
        curr_time: Currently unused. Rerunning is controlled solely by whether
            the results directory is empty.
    """
    for dataset in [ALL_ONES_QUERY, SIGN_QUERY]:
        save_dir = f"{SCALING_RESULT_DIR}/{dataset}/n{n}_init_d{init_d}"
        os.makedirs(save_dir, exist_ok=True)

        # any() stops after the first entry, so the scandir iterator is never
        # exhausted; close it explicitly to avoid leaking a directory handle
        # (and the resulting ResourceWarning).
        with os.scandir(save_dir) as entries:
            has_results = any(entries)

        if not has_results:
            print("running again")
            scaling_synthetic(n=n, initial_d=init_d, dataset=dataset, path_dir=save_dir)

        dimensions, budgets, naive_budgets, stds, percentages, success_rates = get_scaling_param(save_dir)
        # NOTE(review): naive_budgets is passed before budgets here, the reverse
        # of get_scaling_param's return order -- presumably matching
        # plot_scaling's signature; confirm.
        plot_scaling(dimensions, naive_budgets, budgets, stds, percentages, success_rates, save_dir, dataset)


if __name__ == "__main__":
    # run_synthetic also accepts curr_time, but it is currently unused:
    # experiments are skipped only when the results directory is non-empty.
    run_synthetic(n=100, init_d=1000)
    
  
