import numpy as np
import os
import torch
from bandit import *
from alg_neural import *

from util import *

def main():
    """Evaluate every configured bandit algorithm on every dataset, then plot.

    For each dataset in ``dataset``, ``N`` ``NeuralBandit`` environments are
    built (seeded ``0 .. N-1``) and each algorithm in ``algorithms`` is run
    through ``evaluate`` with horizon ``T``.  The regret returned by
    ``evaluate`` is cumulated along axis 0 and written as comma-separated
    text to ``./Results/<info>_<N>_<T>/<dataset>/<alg_name>``.  If that file
    already exists the algorithm is skipped, so re-running the script does
    not recompute finished results.  Finally ``plot_results`` is called with
    all dataset descriptors and algorithm names.
    """
    N = 10  # number of environments per dataset; used as their seeds 0..N-1
    T = 10000  # horizon passed to ``evaluate`` (presumably rounds per episode)
    model = "CNN"  # CNN or MLP
    node = 100  # passed as ``node`` to every algorithm (presumably hidden width)
    shuffle = True
    # Other datasets: "letter", "adult", "covertype", "mushroom", "shuttle",
    # "isolet", "fashion".
    dataset = ["mnist"]
    info = "CNN_0001"  # run tag; prefixes the results directory name

    filename = f"{info}_{N}_{T}"
    base_dir = os.path.join(".", "Results", filename)
    os.makedirs("./Log", exist_ok=True)

    # These only take effect if set before CUDA is first initialised.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # number GPUs by PCI bus order
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # expose only GPU 0 to this process
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Keyword arguments shared by every algorithm; each entry below gets its
    # own dict (same key order as before) with the algorithm-specific extras.
    common = {"mode": model, "device": device, "node": node, "batch_size": 32}
    algorithms = [
        (EpsilonGreedy, {**common, "perturbation_std": 0.0}, "epsilon"),
        (DeepFL, {**common, "perturbation_std": 1.0, 'type': 0}, "DeepFPL"),
        (NeuralTS, {**common, "perturbation_std": 0.0, "style": "UCB"}, "NeuralUCB"),
        (NeuralTS, {**common, "perturbation_std": 0.0, "style": "TS"}, "NeuralTS"),
        (DeepFP, {**common, "perturbation_std": 1.0}, "DeepFP"),
    ]

    environments = []
    for data in dataset:
        print("================== running dataset", data, "==================")
        envs = [
            NeuralBandit(data, shuffle=shuffle, seed=episode, device=device, mode=model)
            for episode in range(N)
        ]
        print("data dim: ", envs[0].d, ", K: ", envs[0].K.item())
        environments.append([data, envs[0].K, envs[0].d])

        res_dir = os.path.join(base_dir, data)
        os.makedirs(res_dir, exist_ok=True)

        for alg_class, alg_params, alg_name in algorithms:
            fname = os.path.join(res_dir, alg_name)
            # Results are cached on disk: skip algorithms that already ran.
            if os.path.exists(fname):
                print('File exists. Will load saved file. Moving on to the next algorithm')
                continue
            regret, _ = evaluate(alg_class, alg_params, envs, T, device)
            cum_regret = regret.cumsum(axis=0)
            np.savetxt(fname, cum_regret, delimiter=",")
        print("\n\n")

    # Previously this call sat outside the __main__ guard, so merely importing
    # the module raised NameError.
    plot_results(T, filename, environments, [name for _, _, name in algorithms])


if __name__ == '__main__':
    main()
