'''
Generate results on the Simple dataset with empirical risk minimization (ERM) and deceptive risk minimization (DRM), reporting the mean and standard deviation over several random seeds.
'''

####################################################################################################    
# Number of independent runs (seeds); results are reported as mean +- std over them
num_seeds = 10

# Parameters
# Dataset split sizes (number of examples per split)
train_len = 2000
detect_len = 1000
test_len = 1000

# Loader batch sizes
train_batch_size = 64      # examples per training batch
detect_batch_size = 1000   # detection examples loaded per batch

# Hyperparameters forwarded to the DRM training routine
martingale_penalty = 5e2               # martingale penalty
temperature = 1.0                      # softmax temperature
softrank_regularization_factor = 1e-3  # Softrank regularization factor

num_training_epochs = 2  # training epochs (passed to the DRM run)


####################################################################################################    

####################################################################################################    
import os
import shutil
import sys

# Make modules in the parent of the *current working directory* importable
# (presumably where the `core` package imported below lives — confirm).
# NOTE(review): this depends on the process cwd, not on this file's location.
_parent_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
sys.path.append(_parent_dir)

# Choose the compute device: the GPU when torch can see one, otherwise the CPU.
import torch

use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
    # Extra DataLoader options applied only on the GPU path
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    device = torch.device("cpu")
    kwargs = {}

# Simple dataset
from simple_dataset import *

# Define neural network
from net import *

# Training code
from core.train import *
####################################################################################################

####################################################################################################

# Explicit imports: `np` was previously only available via the star imports
# above; `random` is needed for per-seed seeding below.
import random

import numpy as np

# Store results: one row per seed for each (method, split) combination
train_success_all_erm = np.zeros((num_seeds, 1))
test_success_all_erm = np.zeros((num_seeds, 1))
train_success_all_drm = np.zeros((num_seeds, 1))
test_success_all_drm = np.zeros((num_seeds, 1))

# Directory the datasets are written to; it is wiped before every seed
data_root = './data'

for seed in range(num_seeds):
    ################################################################################################
    print("Seed:", seed+1, "/", num_seeds)
    print(" ")

    # Seed the global RNGs with the loop index so each run is reproducible
    # and runs differ only by their seed.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Delete existing dataset so the loaders below start from an empty directory
    # (presumably SimpleExample regenerates the data when it is missing — confirm).
    if os.path.exists(data_root):
        shutil.rmtree(data_root)

    # Define train and test loaders
    train_loader = torch.utils.data.DataLoader(
        SimpleExample(root=data_root, env='train', train_len=train_len, detect_len=detect_len, test_len=test_len),
        batch_size=train_batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        SimpleExample(root=data_root, env='test', train_len=train_len, detect_len=detect_len, test_len=test_len),
        batch_size=1000, shuffle=True, **kwargs)

    # Define batch detection set loader (loads detection set in batches, in a fixed order)
    detect_loader_batch = torch.utils.data.DataLoader(
        SimpleExample(root=data_root, env='detect', train_len=train_len, detect_len=detect_len, test_len=test_len),
        batch_size=detect_batch_size, shuffle=False, **kwargs)

    # Run training with empirical risk minimization (ERM)
    # NOTE(review): unlike the DRM call below, num_training_epochs is not passed
    # here, so ERM uses train_and_test_erm's own default — confirm the two runs
    # are trained for comparable numbers of epochs.
    model = Net().to(device)
    train_success_erm, test_success_erm = train_and_test_erm(train_loader, test_loader, model)

    # Run training with deceptive risk minimization (DRM) on a freshly initialised model
    model = Net().to(device)
    train_success_drm, test_success_drm = train_and_test_drm(
        train_loader, test_loader, detect_loader_batch, model,
        martingale_penalty=martingale_penalty,
        temperature=temperature,
        softrank_regularization_factor=softrank_regularization_factor,
        num_training_epochs=num_training_epochs)

    # Store results
    train_success_all_erm[seed] = train_success_erm
    test_success_all_erm[seed] = test_success_erm
    train_success_all_drm[seed] = train_success_drm
    test_success_all_drm[seed] = test_success_drm

    ################################################################################################

# Print results (mean +- standard deviation across seeds)
print('Train success (ERM):', np.mean(train_success_all_erm), '+-', np.std(train_success_all_erm))
print('Test success (ERM):', np.mean(test_success_all_erm), '+-', np.std(test_success_all_erm))
print('Train success (DRM):', np.mean(train_success_all_drm), '+-', np.std(train_success_all_drm))
print('Test success (DRM):', np.mean(test_success_all_drm), '+-', np.std(test_success_all_drm))

# Save per-seed results
np.savez('results.npz',
         train_success_all_erm=train_success_all_erm,
         test_success_all_erm=test_success_all_erm,
         train_success_all_drm=train_success_all_drm,
         test_success_all_drm=test_success_all_drm)
####################################################################################################

    
    
