import os
import torch
import argparse
import numpy as np
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR

from utils import str2bool
from trainer import ANPTraniner
from datasets import get_dataset
from models import ANPSubsetSelect

def _build_parser():
    """Create the command-line parser for this training script.

    Options are registered in the same order as they appear here, which
    is also the order ``--help`` lists them in. Boolean-like flags are
    converted with the project's ``str2bool`` helper.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # Run control.
    add('--mode', default='train', type=str, help='training mode')
    add('--stage', default='candidate', type=str, help='sss stage')
    add('--run', default='0', type=str, help='experiment number:run')
    add('--debug', default=False, type=str2bool, help='debug mode or not.')
    add('--dataset', default='function', type=str, help='dataset to load')

    # Model architecture.
    add('--CNP_mode', default='transformer', type=str, help='cnp mode.')
    add('--hidden_dim', default=128, type=int, help='hidden dimension')
    add('--CNP_encoder_num_layers', default=4, type=int, help='num layers in cnp encoder')
    add('--CNP_decoder_num_layers', default=2, type=int, help='num layers in cnp decoder')
    add('--subset_encoder_num_layers', default=3, type=int, help='num layers in sss')
    add('--train_with_real_mask', default=True, type=str2bool, help='use real mask for decoder at train time')

    # Sizes of sets, points and subsets.
    add('--minibatch_per_epoch', default=2000, type=int, help='number of function at each mini-batch')
    add('--num_total_points', default=400, type=int, help='number of elements in a set')
    add('--x_dim', default=1, type=int, help='context dimension')
    add('--y_dim', default=1, type=int, help='target dimension')
    add('--max_output_points', default=30, type=int, help='size of subset')
    add('--element_jump', default=5, type=int, help='number of points to greedy select')

    # Training schedule and loss hyper-parameters.
    add('--epochs', default=100, type=int, help='number of epochs')
    add('--BATCH_SIZE', default=64, type=int, help='used to sample BATCH_SIZE*minibatch_per_epoch for each mini-batch')
    add('--batch_size', default=128, type=int, help='batch size for training')
    add('--total_iter_valid', default=200, type=int, help='number of function to evaluate on')
    add('--lr', default=1e-3, type=float, help='learning rate')
    add('--reg_scale', default=0.01, type=float, help='multiplicative scale for candidate regularization')
    add('--temperature', default=0.05, type=float, help='temperature for relaxed distributions')
    add('--alpha', default=1e-1, type=float, help='prior sparsity level')
    add('--thres', default=0.499, type=float, help='threshold for selected elements')

    # Miscellaneous switches.
    add('--resume', default=False, type=str2bool, help='resume training from checkpoint')
    add('--visualize', default=False, type=str2bool, help='visualization.not used for training')
    return cli


# Parsed at import time, as before, so `parser` and `args` stay module-level names.
parser = _build_parser()
args = parser.parse_args()

# Stages for which ANPSubsetSelect is constructed by this script.
_SUPPORTED_STAGES = ('candidate', 'autoregressive', 'random', 'sss', 'randomautoregressive')


def _seed_everything(seed):
    """Seed numpy and torch (CPU and every visible CUDA device) with `seed`."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all (not cuda.manual_seed, which seeds only the current
    # device) also covers the extra GPUs used by nn.DataParallel in main().
    torch.cuda.manual_seed_all(seed)


def main(args):
    """Build the data loaders, model, optimizer and trainer, then run training.

    Args:
        args: parsed command-line namespace. It is mutated in place: the
            attributes ``total_iter_train`` (``minibatch_per_epoch * BATCH_SIZE``)
            and ``device`` (``'cuda'`` if available, else ``'cpu'``) are added
            before it is handed to the dataset, model and trainer.

    Raises:
        NotImplementedError: if ``args.stage`` is not a supported stage. This
            is checked before any data is loaded so a bad ``--stage`` fails fast.
    """
    if args.debug:
        _seed_everything(111)
        torch.autograd.set_detect_anomaly(True)
        print('In debug mode: numpy and torch seeded')

    args.total_iter_train = args.minibatch_per_epoch * args.BATCH_SIZE
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Validate before get_dataset(): loading data can be expensive, and a typo
    # in --stage should not cost a full dataset load before it is reported.
    if args.stage not in _SUPPORTED_STAGES:
        raise NotImplementedError(
            f"unsupported stage {args.stage!r}; expected one of {_SUPPORTED_STAGES}"
        )

    trainloader, validloader = get_dataset(args=args)

    model = ANPSubsetSelect(args=args).to(args.device)
    # Spread each batch across all visible GPUs when there is more than one.
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    optimizer = Adam(model.parameters(), lr=args.lr)
    # Multiplies the learning rate by 0.96 on every scheduler.step(); how often
    # that is called (presumably once per epoch) is decided by ANPTraniner — confirm there.
    scheduler = ExponentialLR(optimizer, gamma=0.96)

    trainer = ANPTraniner(model=model, optimizer=optimizer, scheduler=scheduler,
                          trainloader=trainloader, validloader=validloader, args=args)

    if args.mode == 'train':
        trainer.fit()
    else:
        # Previously any other mode exited silently; make that visible instead.
        print(f"mode {args.mode!r} has no action in this script; nothing was run")


if __name__ == '__main__':
    main(args)
