import argparse
import datetime
import os
import sys

import numpy as np
import matplotlib.pyplot as plt

# Append the parent of this script's directory to sys.path so the local imports
# below (mnist_loader, ae_model) can also be resolved from there. Presumably that
# is where those modules live — confirm against the repository layout.
current = os.path.dirname(os.path.realpath(__file__))
parent = os.path.dirname(current)
sys.path.append(parent)

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as data_utils
import torchvision

from mnist_loader import MnistRotated
from ae_model import AE




def model_training(autoencoder, train_loader, epoch, optimizer=None):
    """Train ``autoencoder`` for one epoch with an MSE reconstruction loss.

    Reads the module-level ``args`` (``lr``, ``weight_decay``, ``cuda``,
    ``log_interval``, ``epochs``) and, when ``args.cuda`` is set, ``device``.

    Args:
        autoencoder: model mapping a batch of images to reconstructions of
            the same shape (the loss compares outputs against the inputs).
        train_loader: DataLoader yielding ``(images, _, _)`` triples; only the
            first element is used.
        epoch: zero-based epoch index, used only for logging.
        optimizer: optional optimizer to step. When omitted, an Adam optimizer
            is created on the first call for this model and reused on every
            later call, so its state persists across epochs.

    Returns:
        The optimizer that was stepped.
    """
    loss_metric = nn.MSELoss()
    if optimizer is None:
        # Re-creating Adam each epoch would reset its running moment estimates
        # and step count, so keep a single optimizer attached to the model.
        optimizer = getattr(autoencoder, '_pretrain_optimizer', None)
        if optimizer is None:
            optimizer = torch.optim.Adam(autoencoder.parameters(), lr=args.lr,
                                         weight_decay=args.weight_decay)
            autoencoder._pretrain_optimizer = optimizer

    # len(loader) counts the final partial batch, unlike len(dataset) // batch_size.
    num_batches = len(train_loader)
    autoencoder.train()
    for i, (images, _, _) in enumerate(train_loader):
        if args.cuda:
            images = images.to(device)
        optimizer.zero_grad()
        outputs = autoencoder(images)
        loss = loss_metric(outputs, images)
        loss.backward()
        optimizer.step()
        if (i + 1) % args.log_interval == 0:
            print('Epoch [{}/{}] - Iter[{}/{}], MSE loss:{:.4f}'.format(
                epoch + 1, args.epochs, i + 1, num_batches, loss.item()
            ))
    return optimizer


def evaluation(autoencoder, test_loader, config):
    """Compute the average MSE over ``test_loader`` and checkpoint on improvement.

    Each batch's loss is weighted by its size, so the result is the mean over
    the whole dataset even when the last batch is smaller. When the average
    beats the module-level ``BEST_VAL``, ``BEST_VAL`` is updated and the
    model's ``state_dict`` is saved to
    ``<config.save_dir>/ae1-<config.train_domain>.pt``.

    Args:
        autoencoder: model mapping a batch of images to reconstructions.
        test_loader: DataLoader yielding ``(images, _, _)`` triples.
        config: namespace providing ``cuda``, ``save_dir`` and
            ``train_domain`` (``device`` is read from module globals when
            ``config.cuda`` is true).

    Returns:
        The average MSE loss as a Python float.
    """
    global BEST_VAL
    loss_metric = nn.MSELoss()
    total_loss = 0.0
    autoencoder.eval()
    # Gradients are never needed here; don't rely on the caller disabling them.
    with torch.no_grad():
        for images, _, _ in test_loader:
            if config.cuda:
                images = images.to(device)
            outputs = autoencoder(images)
            total_loss += loss_metric(outputs, images).item() * len(images)
    avg_loss = total_loss / len(test_loader.dataset)

    print('\nAverage MSE Loss on Test set: {:.4f}'.format(avg_loss))

    if avg_loss < BEST_VAL:
        BEST_VAL = avg_loss
        # os.path.join works whether or not save_dir ends with a separator.
        torch.save(autoencoder.state_dict(),
                   os.path.join(config.save_dir, f'ae1-{config.train_domain}.pt'))
        print('Save Best Model in HISTORY\n')
    return avg_loss

if __name__ == "__main__":
    # Pretrain an autoencoder on a single rotated-MNIST domain and keep the
    # checkpoint with the lowest reconstruction loss.
    # NOTE: args, device and BEST_VAL must stay module-level globals because
    # model_training() and evaluation() read them.
    import random

    parser = argparse.ArgumentParser(description='Pretrain AE')

    # training
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=0,
                        help='random seed (default: 0)')
    parser.add_argument('--mnist_subset', type=str, default='0')
    parser.add_argument('--batch-size', type=int, default=128,
                        help='input batch size for training (default: 128)')
    parser.add_argument('--epochs', type=int, default=100,
                        help='number of epochs to train (default: 100)')
    parser.add_argument('--lr', type=float, default=0.0001,
                        help='learning rate (default: 0.0001)')
    parser.add_argument('--weight-decay', type=float, default=0.00001)
    parser.add_argument('--log-interval', type=int, default=100)

    # data
    parser.add_argument('--all-data', action='store_true', default=False)
    parser.add_argument('--data-dir', default='', type=str)
    parser.add_argument('--num-supervised', default=60000, type=int,
                        help="number of supervised examples, /10 = samples per class")
    # nargs='+' accepts e.g. "--list_train_domains 0 15 30"; the previous
    # type=list split a single string into individual characters.
    parser.add_argument('--list_train_domains', type=str, nargs='+',
                        default=['0', '15', '30', '45', '60', '75'],
                        help='domains used during training')
    parser.add_argument('--train_domain', type=str, default='0',
                        help='domain used for both training and evaluation')

    # model
    parser.add_argument('--activation', default='sigmoid',
                        choices=['tanh', 'sigmoid'])

    # log
    parser.add_argument('--run_name', default='indae')
    # NOTE(review): the default writes checkpoints to the filesystem root.
    parser.add_argument('--save_dir', default='/')

    args = parser.parse_args()

    os.makedirs(args.save_dir, exist_ok=True)

    seed = args.seed
    # Only affects child processes; this interpreter's hash seed is already fixed.
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Torch RNGs (CPU and every CUDA device)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # NumPy RNG
    np.random.seed(seed)
    # Python RNG
    random.seed(seed)

    print(args.list_train_domains)
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if args.cuda else "cpu")
    kwargs = {'num_workers': 1, 'pin_memory': False} if args.cuda else {}

    args.run_name = args.run_name + '-' + args.train_domain

    train_set = MnistRotated([args.train_domain], [args.train_domain], args.data_dir,
                             train=True, mnist_subset=args.mnist_subset,
                             all_data=args.all_data)

    # The test set is the training split again: the AE only has to reconstruct
    # this domain, not generalize beyond it.
    test_set = MnistRotated([args.train_domain], [args.train_domain], args.data_dir,
                            train=True, mnist_subset=args.mnist_subset,
                            all_data=args.all_data)

    train_loader = data_utils.DataLoader(train_set,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         **kwargs)
    # Evaluation only averages a loss over all samples, so order is irrelevant.
    test_loader = data_utils.DataLoader(test_set,
                                        batch_size=args.batch_size,
                                        shuffle=False,
                                        **kwargs)

    activations = {'tanh': nn.Tanh(),
                   'sigmoid': nn.Sigmoid()}
    args.activation = activations[args.activation]

    BEST_VAL = float('inf')
    # AE is the model class imported from ae_model (AE1 was never defined).
    # NOTE(review): checkpoints are still named "ae1-*.pt"; confirm AE is the
    # intended architecture.
    autoencoder = AE(args)
    if args.cuda:
        autoencoder.to(device)

    for epoch in range(args.epochs):
        starttime = datetime.datetime.now()
        model_training(autoencoder, train_loader, epoch)
        endtime = datetime.datetime.now()
        print(f'Train a epoch in {(endtime - starttime).seconds} seconds')
        # evaluate on test set and save best model
        with torch.no_grad():
            evaluation(autoencoder, test_loader, args)
    print('Training Complete with best validation loss {:.4f}'.format(BEST_VAL))