import os
import torch
import argparse
import numpy as np
import torch.nn as nn

from utils import str2bool
from tester import ANPValidator
from datasets import get_dataset
from models import ANPSubsetSelect

# Command-line configuration for the evaluation script.
# NOTE: arguments are parsed at import time, so `args` is available as a
# module-level global to anything that imports this module.
parser = argparse.ArgumentParser()
parser.add_argument('--mode', default='train', type=str, help='training mode')
parser.add_argument('--stage', default='candidate', type=str, help='sss stage')
parser.add_argument('--run', default='0', type=str, help='experiment number:run')
parser.add_argument('--debug', default=False, type=str2bool, help='debug mode or not.')
parser.add_argument('--dataset', default='function', type=str, help='dataset to load')
parser.add_argument('--CNP_mode', default='transformer', type=str, help='cnp mode.')
parser.add_argument('--hidden_dim', default=128, type=int, help='hidden dimension')
parser.add_argument('--CNP_encoder_num_layers', default=4, type=int, help='num layers in cnp encoder')
parser.add_argument('--CNP_decoder_num_layers', default=2, type=int, help='num layers in cnp decoder')
parser.add_argument('--subset_encoder_num_layers', default=3, type=int, help='num layers in sss')
parser.add_argument('--train_with_real_mask', default=True, type=str2bool, help='use real mask for decoder at train time')
parser.add_argument('--minibatch_per_epoch', default=2000, type=int, help='number of function at each mini-batch')
parser.add_argument('--num_total_points', default=400, type=int, help='number of elements in a set')
parser.add_argument('--x_dim', default=1, type=int, help='context dimension')
parser.add_argument('--y_dim', default=1, type=int, help='target dimension')
parser.add_argument('--max_output_points', default=30, type=int, help='size of subset')
# type=int so values given on the command line match the integer default
# ([15]); without it, `--subset_sizes 10 20` would yield ['10', '20'].
parser.add_argument('--subset_sizes', default=[15], nargs='+', type=int, help='subset sizes to evaluate on')
parser.add_argument('--element_jump', default=5, type=int, help='number of points to greedy select')
parser.add_argument('--eval_per_run', default=20, type=int, help='number of times to evaluate the same model')
parser.add_argument('--epochs', default=100, type=int, help='number of epochs')
parser.add_argument('--BATCH_SIZE', default=64, type=int, help='used to sample BATCH_SIZE*minibatch_per_epoch for each mini-batch')
parser.add_argument('--batch_size', default=128, type=int, help='batch size for training')
parser.add_argument('--total_iter_valid', default=200, type=int, help='number of function to evaluate on')
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('--reg_scale', default=0.01, type=float, help='multiplicative scale for candidate regularization')
parser.add_argument('--temperature', default=0.05, type=float, help='temperature for relaxed distributions')
parser.add_argument('--alpha', default=1e-1, type=float, help='prior sparsity level')
parser.add_argument('--thres', default=0.499, type=float, help='threshold for selected elements')
parser.add_argument('--resume', default=False, type=str2bool, help='resume training from checkpoint')
# When set, the script calls the validator's visualize_function() instead of validate().
parser.add_argument('--visualize', default=False, type=str2bool, help='visualize functions instead of running validation')
args = parser.parse_args()

if __name__ == '__main__':
    # Seed numpy, torch (CPU) and torch CUDA, in that order, with a fixed seed
    # so repeated evaluation runs are comparable.
    for _seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        _seed_fn(111)
    print('Seeded numpy and torch')

    # Derived settings that are attached to `args` before the dataset, model
    # and validator are constructed from it.
    args.total_iter_train = args.minibatch_per_epoch * args.BATCH_SIZE
    if torch.cuda.is_available():
        args.device = 'cuda'
    else:
        args.device = 'cpu'

    # Only the validation loader is needed for evaluation.
    _trainloader, validloader = get_dataset(args=args)

    supported_stages = ('candidate', 'autoregressive', 'random', 'sss', 'randomautoregressive')
    if args.stage not in supported_stages:
        raise NotImplementedError()
    model = ANPSubsetSelect(args=args).to(args.device)

    tester = ANPValidator(model=model, validloader=validloader, args=args)

    if args.visualize:
        tester.visualize_function()
    else:
        results = tester.validate()

        # `results` maps subset size -> {stat name: value}. Each value is
        # indexed as [0] and [1] and both are normalized by the dataset's
        # num_total_points. Presumably these are (mean, spread); confirm in ANPValidator.
        for subset_size, valid_stats in results.items():
            print('Subset size: {} '.format(subset_size))
            for stat_name, stat in valid_stats.items():
                scale = validloader.dataset.num_total_points
                print('\t\t{:<20} : {:.4f} +/- {:.4f}'.format(stat_name, stat[0] / scale, stat[1] / scale))
            print('\n')
