'''
Generate results on the Colored MNIST dataset, comparing empirical risk minimization (ERM) with deceptive risk minimization (DRM) over several random seeds.
'''


####################################################################################################    
# How many independent repetitions (seeds) of the ERM vs. DRM comparison to run
num_seeds = 10

# Parameters

# --- Data sizes and training batch size ---
train_len = 2_000  # training examples (passed to ColoredMNIST as train_len)
test_len = 1_000  # test examples -- NOTE(review): not referenced elsewhere in this script; the test loader uses a literal batch_size=1000
train_batch_size = 64  # mini-batch size of the training DataLoader

# --- Detection data used by DRM ---
detect_len = 5_000  # detection examples (passed to ColoredMNIST as detection_len)
detect_batch_size = 1_000  # batch size of the detection DataLoader
num_detect_sets = 3  # detection sampler draws num_detect_sets * detect_batch_size examples, i.e. this many batches

# --- DRM objective ---
martingale_penalty = 5_000  # martingale penalty, passed to train_and_test_drm_batch

# --- Epoch budgets ---
num_epochs_erm = 2  # passed to the DRM run as num_epochs_erm; the ERM baseline itself trains for num_training_epochs
num_training_epochs = 3  # training epochs, used by both the ERM baseline and the DRM run

# --- Optimisation / Softrank settings ---
temperature = 1.0  # temperature for softmax; only forwarded to the DRM run
softrank_regularization_type = "l2"  # Softrank regularization type, forwarded to the DRM run
softrank_regularization_factor = 0.1  # regularization factor for Softrank
learning_rate = 0.005  # learning rate shared by the ERM and DRM runs
####################################################################################################    

####################################################################################################    
import os, sys, shutil
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
import numpy as np

# Choose device
import torch
from torchvision import transforms
# Run on the GPU whenever PyTorch can see one. In that case the DataLoaders get
# one worker process and pinned host memory; on CPU they get no extra options.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    device = torch.device("cpu")
    kwargs = {}

# Colored MNIST dataset
from colored_mnist_dataset import *

# Define neural network
from convnet import *

# Training code
from core.train import *
import random
####################################################################################################

####################################################################################################  
# Seed the global generators of Python's `random`, NumPy and PyTorch with 0.
# The three generators are independent, so the order of these calls is irrelevant.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

####################################################################################################

# Per-seed result storage: one row per seed and a single column. Each row is
# filled with the success value returned by the train_and_test_* call for that
# method (ERM / DRM) and split (train / test).
(train_success_all_erm, test_success_all_erm,
 train_success_all_drm, test_success_all_drm) = (np.zeros((num_seeds, 1)) for _ in range(4))

# Input normalization shared by the train, test and detection splits. Normalize
# takes three means and three stds, one per image channel. ToTensor and
# Normalize keep no per-call state, so a single Compose is reused by every
# dataset instead of building three identical pipelines per seed.
colored_mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307, 0.1307, 0.), (0.3081, 0.3081, 0.3081)),
])

for seed in range(num_seeds):
    ################################################################################################
    print("Seed:", seed+1, "/", num_seeds)
    print(" ")

    # Re-seed every global RNG from the loop index. This makes each repetition
    # reproducible on its own. Without it, repetition k would depend on how much
    # randomness repetitions 0..k-1 consumed during data generation, shuffling
    # and training.
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # Delete the dataset left over from the previous repetition.
    # NOTE(review): assumes ColoredMNIST regenerates ./data when missing -- confirm.
    if os.path.exists('./data'):
        shutil.rmtree('./data')

    # Training loader (the train split is constructed first, as before)
    train_loader = torch.utils.data.DataLoader(
        ColoredMNIST(root='./data', env='train', transform=colored_mnist_transform,
                     train_len=train_len, detection_len=detect_len),
        batch_size=train_batch_size, shuffle=True, **kwargs)

    # Test loader
    test_loader = torch.utils.data.DataLoader(
        ColoredMNIST(root='./data', env='test', transform=colored_mnist_transform,
                     train_len=train_len, detection_len=detect_len),
        batch_size=1000, shuffle=True, **kwargs)

    # Detection loader. The sampler draws num_detect_sets * detect_batch_size
    # examples (replacement=False), and the loader yields them in batches of
    # detect_batch_size, i.e. num_detect_sets batches per pass.
    detect_dataset = ColoredMNIST(root='./data', env='detect', transform=colored_mnist_transform,
                                  train_len=train_len, detection_len=detect_len)
    detect_loader_sampler = torch.utils.data.RandomSampler(
        detect_dataset, replacement=False, num_samples=num_detect_sets*detect_batch_size)
    detect_loader_batch = torch.utils.data.DataLoader(
        detect_dataset, sampler=detect_loader_sampler, batch_size=detect_batch_size, **kwargs)

    # Run training with empirical risk minimization (ERM) from a freshly built network
    model = ConvNet().to(device)
    train_success_erm, test_success_erm = train_and_test_erm(
        train_loader, test_loader, model,
        num_training_epochs=num_training_epochs, lr=learning_rate)

    # Run training with deceptive risk minimization (DRM) from another fresh network
    model = ConvNet().to(device)
    train_success_drm, test_success_drm = train_and_test_drm_batch(
        train_loader, test_loader, detect_loader_batch, model,
        martingale_penalty=martingale_penalty, temperature=temperature,
        softrank_regularization_factor=softrank_regularization_factor,
        num_training_epochs=num_training_epochs, lr=learning_rate,
        softrank_regularization_type=softrank_regularization_type,
        num_epochs_erm=num_epochs_erm, num_detect_sets=num_detect_sets,
        detect_batch_size=detect_batch_size)

    # Store this seed's results in row `seed`
    train_success_all_erm[seed] = train_success_erm
    test_success_all_erm[seed] = test_success_erm
    train_success_all_drm[seed] = train_success_drm
    test_success_all_drm[seed] = test_success_drm

    ################################################################################################
    
# Report the mean and standard deviation across seeds for each method and split.
# np.std uses its default ddof=0, so this is the population standard deviation.
for _label, _scores in (
    ('Train success (ERM):', train_success_all_erm),
    ('Test success (ERM):', test_success_all_erm),
    ('Train success (DRM):', train_success_all_drm),
    ('Test success (DRM):', test_success_all_drm),
):
    print(_label, np.mean(_scores), '+-', np.std(_scores))

# Write the per-seed arrays to results.npz. Each array is stored under its
# variable name, which is also its key when the file is read back with np.load.
_saved_arrays = dict(
    train_success_all_erm=train_success_all_erm,
    test_success_all_erm=test_success_all_erm,
    train_success_all_drm=train_success_all_drm,
    test_success_all_drm=test_success_all_drm,
)
np.savez('results.npz', **_saved_arrays)


    
    
