import argparse
import sys
sys.path.append('./')
# Import model util
from utils.model import Model
# Import global configs
import global_config as gc

import torch
import numpy as np

# Get args


def _str2bool(value):
    """Convert a command-line string into a bool.

    argparse with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (so ``--r False`` used to mean True). This converter
    accepts the common spellings explicitly and rejects anything else.

    Args:
        value: Raw string from the command line (or an already-bool value).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        "Expected a boolean value (true/false), got {!r}.".format(value))


parser = argparse.ArgumentParser(
    description='Train a model on the given features/pairs.')

parser.add_argument(
    '--a',
    dest='alpha',
    type=float,
    help='Alpha value.',
    default=.05)
parser.add_argument(
    '--x',
    dest='experiment',
    type=str,
    help='Available experiments: {}.'.format(
        gc.available_experiments),
    default="features_ae")
parser.add_argument(
    '--d1',
    dest='domain1',
    type=str,
    help='Domain 1',
    default='amazon')
parser.add_argument(
    '--d2',
    dest='domain2',
    type=str,
    help='Domain 2.',
    default='amazon')
parser.add_argument(
    '--f1',
    dest='features1',
    type=str,
    help='Features 1.',
    default='CaffeNet4096')
parser.add_argument(
    '--f2',
    dest='features2',
    type=str,
    help='Features 2.',
    default='CaffeNet4096')
parser.add_argument(
    '--lr',
    dest='learning_rate',
    type=float,
    help='Learning Rate',
    default=5e-5)
parser.add_argument(
    '--z',
    dest='latent_dim',
    type=int,
    help='Dimensions of latent variable z.',
    default=256)
parser.add_argument(
    '--e',
    dest='epochs',
    type=int,
    help='Epochs for coupled AE training. In iterations through the full dataset specified in aeconfig/dataset_params',
    default=200)
parser.add_argument(
    '--b1',
    dest='batch_size1',
    type=int,
    help='Batch size for Domain 1',
    default=128)
parser.add_argument(
    '--b2',
    dest='batch_size2',
    type=int,
    help='Batch size for Domain 2',
    default=64)
parser.add_argument(
    '--pc',
    dest='per_class',
    type=int,
    help='Supervision per class',
    default=3)
parser.add_argument(
    '--ths',
    dest='threshold',
    type=int,
    help='Threshold for supervision per batch',
    default=1)
parser.add_argument(
    '--hid1',
    dest='hidden1',
    type=int,
    help='Size of the hidden layer AE 1',
    default=512)
parser.add_argument(
    '--hid2',
    dest='hidden2',
    type=int,
    help='Size of the hidden layer AE 2',
    default=2048)
parser.add_argument(
    '--sample',
    dest='sample',
    type=int,
    help='Number of samples to use per class (0=all data)',
    default=0)

# Boolean options use _str2bool so that e.g. "--supervision False" really
# yields False (type=bool would turn any non-empty string into True).
parser.add_argument(
    '--supervision',
    dest='supervision',
    type=_str2bool,
    help='If to use a classifier in source domain',
    default=True)

parser.add_argument(
    '--use_target',
    dest='use_target',
    type=_str2bool,
    help='If to use target supervision',
    default=False)

parser.add_argument(
    '--r',
    dest='restore',
    type=_str2bool,
    help='Should a model with the given parameters be restored',
    default=False)
parser.add_argument(
    '--torch_seed',
    dest='torch_seed',
    type=int,
    help='Torch seed',
    default=0)
parser.add_argument(
    '--np_seed',
    dest='np_seed',
    type=int,
    help='Numpy seed',
    default=0)

args = parser.parse_args()

# Import experiment-specific configs. Each experiment directory provides its
# own ``aeconfig`` module, so that directory is put on sys.path first.
if args.experiment not in gc.available_experiments:
    # Without a known experiment there is no aeconfig to import; stop here
    # (sys.exit with a string prints it to stderr and exits with status 1)
    # instead of crashing below with a NameError on ``aeconfig``.
    sys.exit(
        "Error, unknown experiment {}. If you created a new experiment you have to add it to the available_experiments list in the global config.".format(
            args.experiment))
sys.path.append(args.experiment)
import aeconfig

vc = aeconfig.AEConfig()
# Seed the torch and numpy RNGs from the command-line seeds.
torch.manual_seed(args.torch_seed)
np.random.seed(args.np_seed)

# Select the 'ae' method for both models.
vc.model_params_1['method'] = 'ae'
vc.model_params_2['method'] = 'ae'

if not args.restore:
    # Create a new model and train it from scratch: copy the CLI values into
    # the aeconfig before building the model from it.
    vc.model_params_1['latent_dim'] = args.latent_dim
    vc.model_params_2['latent_dim'] = args.latent_dim

    vc.training_params['domain1'] = args.domain1
    vc.training_params['domain2'] = args.domain2
    vc.training_params['features1'] = args.features1
    vc.training_params['features2'] = args.features2

    vc.training_params['learning_rate'] = args.learning_rate
    vc.training_params['train_epochs'] = args.epochs
    vc.training_params['latent_dim'] = args.latent_dim
    vc.training_params['alpha'] = args.alpha
    vc.training_params['per_class'] = args.per_class

    # Target supervision is switched on whenever per-class supervision is
    # requested.
    # NOTE(review): the --use_target CLI option is parsed but not consulted
    # here; confirm whether it should override this derived value.
    vc.training_params['use_target'] = vc.training_params['per_class'] > 0

    vc.training_params['sample'] = args.sample
    vc.training_params['threshold'] = args.threshold
    vc.training_params['batch_size1'] = args.batch_size1
    vc.training_params['batch_size2'] = args.batch_size2
    vc.training_params['hidden1'] = args.hidden1
    vc.training_params['hidden2'] = args.hidden2
    vc.training_params['torch_seed'] = args.torch_seed
    vc.training_params['np_seed'] = args.np_seed
    vc.training_params['supervision'] = args.supervision

    # Shuffled training loaders for both domains; incomplete final batches
    # are dropped (drop_last=True).
    vc.dataset_params_train_1 = vc.get_dataset(
        args.domain1,
        args.features1,
        shuffle=True,
        batch_size=args.batch_size1,
        drop_last=True)
    vc.dataset_params_train_2 = vc.get_dataset(
        args.domain2,
        args.features2,
        shuffle=True,
        batch_size=args.batch_size2,
        drop_last=True)

    model = Model(vc)
else:
    # Load an existing model and continue training it.
    model = Model(None)
    # The original code formatted this path with ``args.beta``, which is not a
    # defined CLI option, so every restore raised AttributeError. It now falls
    # back to ``alpha`` unless a --beta option is added.
    # NOTE(review): confirm which value the saved-model directory name encodes.
    restore_path = gc.cfg['results_path'] + '/{}dim_b{}/model.pkl'.format(
        args.latent_dim, getattr(args, 'beta', args.alpha))
    model = model.load_model(restore_path)

    # Refresh only the epoch budget from the CLI; the other params should not
    # change. It is written to training_params['train_epochs'], the same key
    # the fresh-model branch above uses (the old 'model_params' target is not
    # an attribute this script ever sets on the config).
    model.aeconfig.training_params['train_epochs'] = args.epochs

def _stop_dataloaders(config):
    """Best-effort stop of the training dataloaders stored on ``config``.

    Looks up the two training datasets this script assigns
    (``dataset_params_train_1`` / ``dataset_params_train_2``) and calls their
    ``stop()`` method when one exists.

    Args:
        config: The aeconfig object holding the datasets (may lack them).

    Returns:
        Number of loaders on which ``stop()`` was called.
    """
    stopped = 0
    for attr in ('dataset_params_train_1', 'dataset_params_train_2'):
        loader = getattr(config, attr, None)
        stop = getattr(loader, 'stop', None)
        # NOTE(review): assumes loaders needing shutdown expose ``stop()``,
        # as the original handler expected; confirm against get_dataset.
        if callable(stop):
            stop()
            stopped += 1
    return stopped


try:
    model.train()
except KeyboardInterrupt:
    print("Keyboard interrupt detected, trying to shutdown dataloaders...")
    # Prefer the config attached to the model (set for restored models);
    # fall back to the config built by this script.
    config = getattr(model, 'aeconfig', None) or vc
    try:
        _stop_dataloaders(config)
        print("Dataloaders stopped")
    except Exception:
        print("Could not stop dataloaders after being done. Probably they already exited naturally")
    # sys.exit() instead of the site-provided exit() builtin, which is not
    # guaranteed to exist (e.g. under ``python -S``). Status 0 as before.
    sys.exit()
