import numpy as np
import torch
from tqdm import tqdm

import sys
sys.path.append('..')
from DRKSD.experiments import experiment_G1D, experiment_INTRACTABLE, experiment_MNIST, experiment_RBM


def _output_path(experiment_type, pi_hat_model, beta_hat_model, n, suffix=""):
    """Return the ``.npy`` path where results for one configuration are written.

    Layout: ``../data/<experiment_type>/<pi_hat_model>_<beta_hat_model>_<n><suffix>.npy``
    (relative to the current working directory).
    """
    return ("../data/" + experiment_type + "/" + pi_hat_model + "_"
            + beta_hat_model + "_" + str(n) + suffix + ".npy")


def _save(path, array):
    """Write ``array`` to ``path`` with ``np.save``, creating the parent directory if needed."""
    import os

    # Create the output directory up front so a missing folder does not crash
    # the script after all the (expensive) experiments have already been run.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(path, 'wb') as f:
        np.save(f, array)


def _run_seeds(experiment, n_experiments, len_params):
    """Call ``experiment(seed)`` for seeds ``0 .. n_experiments - 1`` and stack the results.

    Returns an array of shape ``(n_experiments, len_params)`` whose row ``seed``
    holds the value returned by ``experiment(seed)``.
    """
    results = np.zeros((n_experiments, len_params))
    for random_seed in tqdm(range(n_experiments)):
        results[random_seed, :] = experiment(random_seed)
    return results


if __name__ == "__main__":

    n_steps = 1000
    n_experiments = 100

    # Active sweep configuration. Other values this script has been configured with:
    #   pi_hat_model:    "boosting", "logistic", "RF"
    #   beta_hat_model:  "DRF", "CME", "1NN"
    #   experiment_type: "G1D", "INTRACTABLE", "MNIST", "RBM"
    pi_hat_model_list = ["logistic", "boosting"]
    beta_hat_model_list = ["CME", "1NN"]
    experiment_type_list = ["G1D"]
    n_list = np.arange(200, 850, 50)

    for experiment_type in experiment_type_list:
        for pi_hat_model in pi_hat_model_list:
            for beta_hat_model in beta_hat_model_list:
                for n in n_list:

                    print(experiment_type + "/" + pi_hat_model + "_" + beta_hat_model + "_" + str(n))

                    # A single dispatch chain: an unknown experiment_type must fail
                    # loudly instead of reusing a stale `experiment` from a previous
                    # iteration and writing its results under the wrong file name.
                    if experiment_type == "G1D":
                        print("Experiment type G1D!")
                        experiment = experiment_G1D(n=n, n_steps=n_steps, pi_hat_model=pi_hat_model, beta_hat_model=beta_hat_model,
                                                    mean=0., sigma=1, coef_ax=.5, stdx=1)
                        T_reshaped = _run_seeds(experiment, n_experiments, len_params=3)
                        _save(_output_path(experiment_type, pi_hat_model, beta_hat_model, n), T_reshaped)

                    elif experiment_type == "INTRACTABLE":
                        print("Experiment type INTRACTABLE!")
                        experiment = experiment_INTRACTABLE(n=n, n_steps=n_steps, pi_hat_model=pi_hat_model, beta_hat_model=beta_hat_model,
                                                            coef=torch.ones(1))
                        T_reshaped = _run_seeds(experiment, n_experiments, len_params=6)
                        _save(_output_path(experiment_type, pi_hat_model, beta_hat_model, n), T_reshaped)

                    elif experiment_type == "MNIST":
                        print("Experiment type MNIST!")
                        n_layer = 10
                        print("n_layer: ", n_layer)
                        for digit in range(10):
                            print("Training for digit: ", digit)
                            experiment = experiment_MNIST(n=n, n_steps=n_steps, pi_hat_model=pi_hat_model, beta_hat_model=beta_hat_model,
                                                          digit=digit, n_layer=n_layer)
                            T_reshaped = _run_seeds(experiment, n_experiments, len_params=3 * n_layer)
                            suffix = "digit" + str(digit) + "n_layer" + str(n_layer)
                            _save(_output_path(experiment_type, pi_hat_model, beta_hat_model, n, suffix), T_reshaped)

                    elif experiment_type == "RBM":
                        print("Experiment type RBM!")
                        dvisible = 2
                        dhidden = 1
                        # NOTE(review): `option` only appears in the output file name;
                        # it is not passed to experiment_RBM — confirm this is intended.
                        option = 3
                        experiment = experiment_RBM(n=n, n_burn=1000, pi_hat_model=pi_hat_model, beta_hat_model=beta_hat_model,
                                                    dvisible=dvisible, dhidden=dhidden)
                        # RBM runs a single experiment (seed 0) rather than a seed sweep.
                        values = experiment(0)
                        _save(_output_path(experiment_type, pi_hat_model, beta_hat_model, n, "option" + str(option)), values)

                    else:
                        raise ValueError("Unknown experiment_type: " + repr(experiment_type))

