import pickle
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.optim as optim


from utils import *
from networks import *
from data_loader import DatasetLoader

# Command-line interface. Every option carries a default, so the script can
# be launched without any argument.
parser = argparse.ArgumentParser()
for _flag, _type, _default, _help in (
    ('--dataset', str, 'breast', "The name of the dataset"),
    ('--epochs', int, 1000, "The number of epochs"),
    ('--n_runs', int, 10, "The number of runs"),
    ('--n_experts', int, 100, "The number of experts"),
):
    parser.add_argument(_flag, type=_type, default=_default, help=_help)

args = parser.parse_args()
dataset_name, n_runs = args.dataset, args.n_runs
num_epochs, num_experts = args.epochs, args.n_experts

# Echo the effective configuration before any work starts.
print('Experiments on', dataset_name)
print('Epochs:', num_epochs, 'runs:', n_runs, 'num experts:', num_experts)

# Result containers filled in by the experiment loops.
experiment_dict = dict()
experiment_dict_accuracy = dict()

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Experiment grid: one baseline without LDP, then one LDP experiment per
# epsilon value. Each entry is a (name, epsilon) pair.
epsilons_list = [0, 0.5, 1, 2, 4, 5, 10]
experiments_list = [('No LDP', 0)]
experiments_list.extend(('LDP', ep) for ep in epsilons_list)

# Training phase: for every experiment, train `n_runs` models and keep them
# (grouped per experiment) in `models_list`.
random_state = 42
data_loader = DatasetLoader(random_state=random_state)
models_list = []
for experiment in tqdm(experiments_list):
    # An experiment is a (name, epsilon) pair. Resolve both once per
    # experiment so the label below never depends on the inner loop having
    # executed at least once.
    experiment_name, epsilon = experiment
    use_ldp = experiment_name == 'LDP'

    # Running averages over the runs of this experiment.
    train_acc = 0.0
    test_acc = 0.0
    train_loss = 0.0
    test_loss = 0.0

    # Seeds for reproducibility: reset at the start of every experiment.
    torch.manual_seed(random_state)
    np.random.seed(random_state)
    torch.cuda.manual_seed_all(random_state)
    models = []
    for _ in range(n_runs):
        # Load and split the dataset into training and testing sets
        X_train, X_test, y_train, y_test = data_loader.load(dataset_name)

        # Adjust input_size and output_size
        input_size = X_train.shape[1]
        output_size = 1  # Binary classification

        # epsilon is only forwarded for LDP experiments; the baseline keeps
        # whatever default MoEModel defines for it.
        if use_ldp:
            model = MoEModel(input_size, output_size, num_experts=num_experts,
                             ldp_condition=True, epsilon=epsilon).to(device)
        else:
            model = MoEModel(input_size, output_size, num_experts=num_experts,
                             ldp_condition=False).to(device)

        # Convert NumPy arrays to PyTorch tensors
        X_train = torch.FloatTensor(X_train).to(device)
        y_train = torch.FloatTensor(y_train).to(device)
        X_test = torch.FloatTensor(X_test).to(device)
        y_test = torch.FloatTensor(y_test).to(device)

        criterion = ProbitLoss(reduction='none')

        # Optimization using the Stochastic Gradient Descent
        optimizer = optim.SGD(model.parameters(), lr=0.1)
        train_losses_list, test_losses_list, train_accuracy, test_accuracy = train(
            model, X_train, y_train, X_test, y_test,
            criterion, optimizer, batch_size=256, num_epochs=num_epochs)

        # Accumulate the per-run contribution to the experiment averages;
        # for the losses only the last recorded value of each run is used.
        train_acc += train_accuracy/n_runs
        test_acc += test_accuracy/n_runs
        train_loss += train_losses_list[-1]/n_runs
        test_loss += test_losses_list[-1]/n_runs
        models.append(model)

    # LDP experiments are keyed as "<epsilon>-LDP", the baseline as "No LDP".
    label = (str(epsilon) + '-') if use_ldp else ''

    experiment_dict[label + experiment[0]] = [train_loss, test_loss]
    models_list.append(models)
    

# Store the trained models in "<dataset>_models_weights.pts". The context
# manager guarantees the file is closed even if pickling fails.
# NOTE(review): this pickles the model objects themselves, so loading the
# file requires the same class definitions to be importable.
with open(dataset_name+'_models_weights.pts', 'wb') as file:
    pickle.dump(models_list, file)

# Evaluation phase: the result containers are reset here, so whatever the
# training phase stored in them is discarded.
experiment_dict = {}
experiment_dict_accuracy = {}
experiment_dict_std = {}

# We use the exact same seed as for training to make sure we have the same
# train and test sets to evaluate our models.
data_loader = DatasetLoader(random_state=random_state)

for experiment, models in zip(experiments_list, models_list):
    # Resolve epsilon and the result key once per experiment, so the key is
    # correct even when the experiment holds no model.
    epsilon = experiment[1]
    label = (str(epsilon) + '-') if experiment[0] == 'LDP' else ''

    train_accuracies = []
    test_accuracies = []
    train_losses = []
    test_losses = []
    for model in models:
        X_train, X_test, y_train, y_test = data_loader.load(dataset_name)
        # Convert NumPy arrays to PyTorch tensors
        X_train = torch.FloatTensor(X_train).to(device)
        y_train = torch.FloatTensor(y_train).to(device)
        X_test = torch.FloatTensor(X_test).to(device)
        y_test = torch.FloatTensor(y_test).to(device)
        m = len(X_train)  # The size of the training set

        # NOTE(review): train_test_split is not imported explicitly in this
        # file; presumably it comes from one of the star imports. Assuming
        # scikit-learn semantics, 20% of the test set is carved out as X_val
        # and the reported test loss only covers the remaining 80%.
        X_test, X_val, y_test, y_val = train_test_split(X_test, y_test, test_size=0.2, random_state=42)
        # Mean feature vector of the held-out part, shaped (1, n_features).
        X_mean = X_val.mean(axis=0).reshape(1, X_val.shape[1])

        criterion = ProbitLoss(reduction='none')  # Probit loss for binary classification
        test_loss, test_acc = validate(model, X_test, y_test, criterion)
        train_loss, train_acc = validate(model, X_train, y_train, criterion)

        train_losses.append(train_loss.item())
        test_losses.append(test_loss.item())

        # Gating output at the mean point, combined with the experts' norms.
        rho_g = model.gating_network(X_mean).flatten()
        experts_norms = model.get_experts_norm().flatten()
        # NOTE(review): `bound` is computed but never stored or reported in
        # the visible code; confirm whether it should be part of the output.
        if model.ldp_condition:
            kl = rho_g.dot(experts_norms).item()
            bound = np.exp(epsilon)*train_loss.item() * kl/(2*m) + np.sqrt(np.log(1/0.05))/m
        else:
            bound = np.nan

    # Mean and standard deviation of the losses over the runs.
    experiment_dict[label + experiment[0]] = [np.mean(train_losses), np.mean(test_losses)]
    experiment_dict_std[label + experiment[0]] = [np.std(train_losses), np.std(test_losses)]

    
# Report the aggregated losses: one column per experiment, one row per split.
_loss_rows = ['Train loss', 'Test loss']

print('Losses for experiment on', dataset_name, ':')
df_bc_loss = pd.DataFrame(experiment_dict, index=_loss_rows)
df_bc_loss = df_bc_loss.head()
print(df_bc_loss)

print('The associated standard deviation:')
df_bc_std = pd.DataFrame(experiment_dict_std, index=_loss_rows)
df_bc_std = df_bc_std.head()
print(df_bc_std)
