import argparse
import torch
import os
import numpy as np
import time
from utils import utils, cca
from models import MVAE
import data
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from tqdm import tqdm
from matplotlib.animation import FuncAnimation
import matplotlib.animation as animation


def _str2bool(value):
    """Convert a command-line token into a bool.

    ``type=bool`` maps every non-empty string (including ``'False'``) to True,
    so a boolean flag declared that way can never be switched off from the
    command line. This converter recognises the usual spellings instead.

    :param value: Raw token handed over by argparse (or an actual bool).
    :return: The parsed boolean.
    :raises argparse.ArgumentTypeError: If the token is not a known spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1', 'on'):
        return True
    # '' stays falsy: with ``type=bool`` it was the only input that gave False.
    if token in ('false', 'f', 'no', 'n', '0', 'off', ''):
        return False
    raise argparse.ArgumentTypeError('Expected a boolean value, got %r' % (value,))


# Command-line interface. Help strings that quote a default use %(default)s so
# the text cannot drift away from the value actually configured.
parser = argparse.ArgumentParser(description='Parameters for training the PTC network.')
parser.add_argument('--epochs', help='The number of epochs to train the network. (default: %(default)s)', type=int,
                    default=5)
parser.add_argument('--dir_data', type=str, default='data', help='Directory to the datasets.')
parser.add_argument('--dataset', type=str, nargs='+', default=['sarscov2_closed', 'sarscov2_partialopen', 'protease'],
                    help='The dataset to train the network on.')  # da_10906555
parser.add_argument('--dataset_old', type=str, default='ace2', help='The dataset to train the network on.')  # da_10906555
# parser.add_argument('--ref_surf', type=int, default=16000, help='The reference mesh to be used during training. '
#                                                                 '(default: [16000, 4])') #16000, 4

parser.add_argument('--n_latent', type=int, default=[2, 16, 32], nargs='+',
                    help='Latent space dimensions. (default: {[z_x], [z_c, z_n]})')
parser.add_argument('--bond_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels for in each layer for the conformal factor encoder. '
                         '(default: {[12, 24, 24, 48, _], [24, 24, 48, 48, _]})')
parser.add_argument('--dihedrals_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels for in each layer for the normal encoder. '
                         '(default: {[12, 24, 24, 48, _], [24, 24, 48, 48, _]})')
parser.add_argument('--ca_enc_krn', type=int, nargs='+', default=None,
                    help='The number of kernels for in each layer for the normal encoder. '
                         '(default: {[12, 24, 24, 48, _], [24, 24, 48, 48, _]})')
parser.add_argument('--xyz_dec_krn', type=int, nargs='+', default=None,
                    help='The number of kernels for in each layer for the xyz decoder. '
                         '(default: [64, 32, 32, 16, _])')
parser.add_argument('--n_heads', type=int, default=4, help='Number of heads of the attention mechanism in a GAT layer. '
                                                           '(default: 4)')
parser.add_argument('--conv_type', type=str, default='gat', help='The non-Euclidean convolution to use.')
parser.add_argument('--kernel_rad', type=float, default=2.5, help='The radius of the convolutional kernel on the '
                                                                    'surface. (default: %(default)sA)')  #5E-2 / 2E-2
parser.add_argument('--seed', help='The random seed for execution.', type=int, nargs='+', default=[0, 1])
parser.add_argument('--model_name', help='The name associated with the model with saving results.', type=str,
                    default=None)
# parser.add_argument('--model_vers', help='Model version to append to name', type=str, default='trial_rad')
parser.add_argument('--model_vers', help='Model version to append to name', type=str, default='fin3_reg')
parser.add_argument('--model_dir', help='The directory for saving the results.', type=str, default='./models')
# Parsed with _str2bool so that "--use_cuda False" really disables the flag.
parser.add_argument('--use_cuda', type=_str2bool, default=True, help='Train on GPU if available.')


args = parser.parse_args()

# Every tracked quantity gets one cell per (dataset, seed) combination.
_grid_shape = (len(args.dataset), len(args.seed))
train_loss = np.zeros(_grid_shape)
train_loss_mse = np.zeros(_grid_shape)
train_loss_cf = np.zeros(_grid_shape)
val_loss = np.zeros(_grid_shape)
val_loss_mse = np.zeros(_grid_shape)
val_loss_cf = np.zeros(_grid_shape)
test_loss = np.zeros(_grid_shape)
test_loss_mse = np.zeros(_grid_shape)
test_loss_cf = np.zeros(_grid_shape)
test_loss_val = np.zeros(_grid_shape)

# Archive keys whose stored array contributes only its last entry
# (presumably the final-epoch value of a training curve -- confirm against the
# script that writes the .npz files).
_curve_targets = {
    'train_loss': train_loss,
    'train_loss_mse': train_loss_mse,
    'train_loss_cf': train_loss_cf,
    'val_loss': val_loss,
    'val_loss_mse': val_loss_mse,
    'val_loss_cf': val_loss_cf,
}
# Archive keys whose stored value is written into the cell unchanged.
_scalar_targets = {
    'test_loss': test_loss,
    'test_loss_mse': test_loss_mse,
    'test_loss_cf': test_loss_cf,
    'test_loss_val': test_loss_val,
}

# Collect the transfer-learning results ("from_<dataset_old>_..." runs).
for i_data, dataset in enumerate(args.dataset):
    model_dir = args.model_dir + '/' + dataset
    for i_seed, seed in enumerate(args.seed):
        # model_name = '%02d_%s_s%02d' % (sum(args.n_latent), args.model_vers, seed)
        model_name = 'from_%s_%02d_%s_s%02d' % (args.dataset_old, sum(args.n_latent), args.model_vers, seed)
        results_path = '%s/results_transfer_%s_%03d.npz' % (model_dir, model_name, args.epochs)

        # np.load returns an NpzFile that keeps the archive open; the context
        # manager closes it so one file handle is not leaked per run.
        # NOTE(review): allow_pickle=True can execute arbitrary code while
        # loading -- only point this script at result files you trust.
        with np.load(results_path, allow_pickle=True) as results:
            for key, target in _curve_targets.items():
                target[i_data, i_seed] = results[key][-1]
            for key, target in _scalar_targets.items():
                target[i_data, i_seed] = results[key]


def get_stats(signal):
    """Summarise ``signal`` along its last axis.

    :param signal: Array-like of values; the reduction runs over axis -1.
    :return: Tuple ``(mean, std)`` where ``std`` is the sample standard
        deviation (``ddof=1``, i.e. an ``n - 1`` denominator). With a single
        entry along the last axis that denominator is zero and NumPy
        yields ``nan``.
    """
    last_axis = -1
    average = np.mean(signal, axis=last_axis)
    spread = np.std(signal, axis=last_axis, ddof=1)
    return average, spread


# Mean / sample standard deviation over the last (seed) axis for every quantity
# gathered from the transfer runs.
(
    (train_loss_mu, train_loss_std),
    (train_loss_mse_mu, train_loss_mse_std),
    (train_loss_cf_mu, train_loss_cf_std),
) = [get_stats(losses) for losses in (train_loss, train_loss_mse, train_loss_cf)]

(
    (val_loss_mu, val_loss_std),
    (val_loss_mse_mu, val_loss_mse_std),
    (val_loss_cf_mu, val_loss_cf_std),
) = [get_stats(losses) for losses in (val_loss, val_loss_mse, val_loss_cf)]

(
    (test_loss_mu, test_loss_std),
    (test_loss_mse_mu, test_loss_mse_std),
    (test_loss_cf_mu, test_loss_cf_std),
    (test_loss_val_mu, test_loss_val_std),
) = [get_stats(losses) for losses in (test_loss, test_loss_mse, test_loss_cf, test_loss_val)]

print('Transfer Results')
for i, (dataset, val_mu, val_std) in enumerate(zip(args.dataset, test_loss_val_mu, test_loss_val_std)):
    print('Dataset:', dataset)
    # Only the 'test_loss_val' summary is reported here. The train/val/test
    # (total, mse, cf) statistics are computed above but not printed.
    print('Test loss val mu/std: %.2E %.2E' % (val_mu, val_std))

# Baseline
# The result arrays already allocated in this script are reused for the
# baseline runs; every cell is overwritten below.
# Archive keys whose stored array contributes only its last entry
# (presumably the final-epoch value of a training curve -- confirm against the
# script that writes the .npz files).
_curve_targets = {
    'train_loss': train_loss,
    'train_loss_mse': train_loss_mse,
    'train_loss_cf': train_loss_cf,
    'val_loss': val_loss,
    'val_loss_mse': val_loss_mse,
    'val_loss_cf': val_loss_cf,
}
# Archive keys whose stored value is written into the cell unchanged.
_scalar_targets = {
    'test_loss': test_loss,
    'test_loss_mse': test_loss_mse,
    'test_loss_cf': test_loss_cf,
    'test_loss_val': test_loss_val,
}

for i_data, dataset in enumerate(args.dataset):
    model_dir = args.model_dir + '/' + dataset
    for i_seed, seed in enumerate(args.seed):
        # model_name = '%02d_%s_s%02d' % (sum(args.n_latent), args.model_vers, seed)
        model_name = 'baseline_%02d_%s_s%02d' % (sum(args.n_latent), args.model_vers, seed)
        results_path = '%s/results_transfer_%s_%03d.npz' % (model_dir, model_name, args.epochs)

        # np.load returns an NpzFile that keeps the archive open; the context
        # manager closes it so one file handle is not leaked per run.
        # NOTE(review): allow_pickle=True can execute arbitrary code while
        # loading -- only point this script at result files you trust.
        with np.load(results_path, allow_pickle=True) as results:
            for key, target in _curve_targets.items():
                target[i_data, i_seed] = results[key][-1]
            for key, target in _scalar_targets.items():
                target[i_data, i_seed] = results[key]

# Mean / sample standard deviation over the last (seed) axis, recomputed now
# that the result arrays hold the baseline runs.
(
    (train_loss_mu, train_loss_std),
    (train_loss_mse_mu, train_loss_mse_std),
    (train_loss_cf_mu, train_loss_cf_std),
) = [get_stats(losses) for losses in (train_loss, train_loss_mse, train_loss_cf)]

(
    (val_loss_mu, val_loss_std),
    (val_loss_mse_mu, val_loss_mse_std),
    (val_loss_cf_mu, val_loss_cf_std),
) = [get_stats(losses) for losses in (val_loss, val_loss_mse, val_loss_cf)]

(
    (test_loss_mu, test_loss_std),
    (test_loss_mse_mu, test_loss_mse_std),
    (test_loss_cf_mu, test_loss_cf_std),
) = [get_stats(losses) for losses in (test_loss, test_loss_mse, test_loss_cf)]

test_loss_val_mu, test_loss_val_std = get_stats(test_loss_val)

print('Baseline Results')
for i, (dataset, val_mu, val_std) in enumerate(zip(args.dataset, test_loss_val_mu, test_loss_val_std)):
    print('Dataset:', dataset)
    # Only the 'test_loss_val' summary is reported here. The train/val/test
    # (total, mse, cf) statistics are computed above but not printed.
    print('Test loss val mu/std: %.2E %.2E' % (val_mu, val_std))

print('Done')
