import os
import torch
import argparse
import numpy as np
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR

from utils import str2bool
from tester import ANPValidator
from datasets import get_dataset
from models import ANPSubsetSelect

# Command-line configuration for evaluating (or visualizing) an ANPSubsetSelect
# model on the validation split. Parsed at import time into the module-level
# `args` namespace that the __main__ block reads (and extends with `device`).
parser = argparse.ArgumentParser()
parser.add_argument('--root', default='', type=str, help='dataset path')
parser.add_argument('--mode', default='train', type=str, help='training mode')
parser.add_argument('--stage', default='sss', type=str, help='sss stage')
parser.add_argument('--run', default='0', type=str, help='experiment number:run')
parser.add_argument('--debug', default=False, type=str2bool, help='debug mode or not.')
parser.add_argument('--dataset', default='celeba', type=str, help='dataset to load')
parser.add_argument('--CNP_mode', default='transformer', type=str, help='cnp mode.')
parser.add_argument('--hidden_dim', default=128, type=int, help='hidden dimension')
parser.add_argument('--CNP_encoder_num_layers', default=4, type=int, help='num layers in cnp encoder')
parser.add_argument('--CNP_decoder_num_layers', default=2, type=int, help='num layers in cnp decoder')
parser.add_argument('--subset_encoder_num_layers', default=3, type=int, help='num layers in sss')
parser.add_argument('--train_with_real_mask', default=True, type=str2bool, help='use real mask for decoder at train time')
# type=int keeps command-line values consistent with the integer defaults;
# without it, e.g. `--subset_sizes 40 60` would yield ['40', '60'] (strings).
parser.add_argument('--subset_sizes', default=[40, 60, 80, 100], type=int, nargs='+', help='subset sizes to evaluate on')
parser.add_argument('--size', default=[32, 32], type=int, nargs='+', help='size of image')
# NOTE(review): help text inferred from the option name; the previous text
# ('context dimension') was a copy of --x_dim's help.
parser.add_argument('--eval_per_run', default=1, type=int, help='number of evaluations per run')
parser.add_argument('--x_dim', default=2, type=int, help='context dimension')
parser.add_argument('--y_dim', default=3, type=int, help='target dimension')
parser.add_argument('--max_output_points', default=100, type=int, help='size of subset')
parser.add_argument('--element_jump', default=10, type=int, help='number of points to greedy select')
parser.add_argument('--epochs', default=100, type=int, help='number of epochs')
parser.add_argument('--batch_size', default=128, type=int, help='batch size for training')
parser.add_argument('--lr', default=1e-3, type=float, help='learning rate')
parser.add_argument('--reg_scale', default=0.01, type=float, help='multiplicative scale for candidate regularization')
parser.add_argument('--temperature', default=0.05, type=float, help='temperature for relaxed distributions')
parser.add_argument('--alpha', default=1e-1, type=float, help='prior sparsity level')
parser.add_argument('--thres', default=0.499, type=float, help='threshold for selected elements')
parser.add_argument('--resume', default=False, type=str2bool, help='resume training from checkpoint')
parser.add_argument('--visualize', default=False, type=str2bool, help='visualization.not used for training')
args = parser.parse_args()

if __name__ == '__main__':
    # Fix RNG seeds so evaluation runs are reproducible.
    np.random.seed(111)
    torch.manual_seed(111)
    torch.cuda.manual_seed(111)
    print('Seeded numpy and torch')

    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Only the validation loader is used here; the first return value is discarded.
    _, validloader = get_dataset(args=args)

    # All supported stages share one model class; presumably the stage is
    # interpreted inside ANPSubsetSelect via `args` — TODO confirm.
    supported_stages = ('candidate', 'autoregressive', 'random', 'sss', 'randomautoregressive')
    if args.stage not in supported_stages:
        raise NotImplementedError('unsupported stage: {}'.format(args.stage))
    model = ANPSubsetSelect(args=args).to(args.device)

    tester = ANPValidator(model=model, validloader=validloader, args=args)

    if args.visualize:
        tester.visualize_celeba()
    else:
        results = tester.validate()

        # Each statistic is divided by the product of validloader.dataset.size
        # (presumably the image dimensions, i.e. a per-pixel value — verify).
        # Loop-invariant, so computed once rather than twice per statistic.
        num_elements = np.prod(validloader.dataset.size)

        for subset_size, valid_stats in results.items():
            print('Subset size: {} '.format(subset_size))
            for key in valid_stats:
                # assumes value is a (mean, spread) pair, given the "+/-" format — TODO confirm
                value = valid_stats[key]
                print('\t\t{:<20} : {:.4f} +/- {:.4f}'.format(key, value[0] / num_elements, value[1] / num_elements))
            print('\n')
