import math, tqdm, json, os
import numpy as np
import argparse
from src.utils import set_gpus
from src.vmf_sampler import VMFSampler, SMAXSampler


if __name__ == "__main__":
    # Command-line entry point: runs the vMF sampling experiment and the
    # softmax (Boltzmann) sampling experiment for the requested configuration
    # and stores parameters and results under simulations/<configuration>/.
    parser = argparse.ArgumentParser()
    parser.add_argument("-k", "--kappa", type=float, required=True,
                        help="Concentration parameter")
    parser.add_argument("-a", "--alpha", type=float, required=True,
                        help="Similarity with main embedding")
    parser.add_argument("-d", "--d", type=int, required=True,
                        help="Dimension of embeddings")
    parser.add_argument("-N", "--N", type=int, required=True,
                        help="Number of point")
    parser.add_argument("-bs", "--bs", type=int, required=True,
                        help="Batch size")
    parser.add_argument("-nt", "--n_trials", type=int, required=True,
                        help="Number of trials")

    args = parser.parse_args()

    # The sample count recorded for a run is batch size times number of trials.
    total_samples = args.bs * args.n_trials
    exp_path = "simulations/k=%.1f_a=%.2f_d=%d_N=%d_samples=%d/" % (
        args.kappa, args.alpha, args.d, args.N, total_samples)
    os.makedirs(exp_path, exist_ok=True)

    # Persist the configuration next to the results.
    params = {"k": args.kappa,
              "a": args.alpha,
              "d": args.d,
              "N": args.N,
              "samples": total_samples}
    with open(os.path.join(exp_path, "params.json"), "w") as f:
        json.dump(params, f)

    # NOTE(review): 0.20 is presumably a GPU memory fraction -- confirm
    # against src.utils.set_gpus.
    dev = set_gpus(0.20)

    # Run Monte Carlo simulation for vMF Exploration
    vmf = VMFSampler(dev)
    exp = vmf.experiment(args.kappa, args.N, args.d, args.alpha,
                         args.n_trials, args.bs)
    # np.save appends the ".npy" extension to the given file name.
    np.save(os.path.join(exp_path, "data"), exp)

    # Run simulation for Boltzmann Exploration
    smSampler = SMAXSampler(dev)
    exp_sm = smSampler.experiment(args.kappa, args.N, args.d, args.alpha,
                                  args.n_trials)
    np.save(os.path.join(exp_path, "data_smax"), exp_sm)

    # Additional softmax run with a fixed configuration that does not depend
    # on the command-line arguments.
    # NOTE(review): this looks like a leftover reference run -- confirm it is
    # still wanted on every invocation.
    # A fresh sampler is created, as before, in case the sampler keeps state
    # between experiments (not visible from this file).
    ref_sampler = SMAXSampler(dev)
    ref_kappa = 1.0
    ref_N = 10**6
    ref_d = 4
    ref_alpha = 0.9
    # Only used to name the output directory; it is not passed to the sampler.
    ref_samples = 7680000
    ref_path = "simulations/k=%.1f_a=%.2f_d=%d_N=%d_samples=%d/" % (
        ref_kappa, ref_alpha, ref_d, ref_N, ref_samples)
    # Create the target directory before the run so the final save cannot
    # fail on a missing path after the experiment has finished.
    os.makedirs(ref_path, exist_ok=True)
    ref_exp_sm = ref_sampler.experiment(ref_kappa, ref_N, ref_d, ref_alpha, 1)
    np.save(os.path.join(ref_path, "data_smax"), ref_exp_sm)