import torch
from torch import nn
from torch_geometric.loader import DataLoader
from models import FRGNN
from dataset import CFDGraphsDataset
from plotter import plot_mesh
from box import Box
import yaml
import os
import argparse

# Command-line interface of the evaluation script: three required string options
# (model directory, model version, test dataset) and an optional plotting switch.
parser = argparse.ArgumentParser()

# (long flag, short flag, help text) for every required string option, in registration order.
_REQUIRED_STR_ARGS = (
    ('--trained_model_dir', '-d', "path to the trained model directory"),
    ('--model_version', '-v', "trained model version to load, e.g. best or e100"),
    ('--test_dataset_path', '-t', "path to the testing dataset"),
)
for _long_flag, _short_flag, _help_text in _REQUIRED_STR_ARGS:
    parser.add_argument(_long_flag, _short_flag, type=str, required=True, help=_help_text)

parser.add_argument('--plot_sample', '-p', action='store_true', help="if the last sample should be plotted")

def _rmse(pred, target):
    """Return the root-mean-squared error between ``pred`` and ``target`` as a 0-dim tensor."""
    return torch.sqrt(nn.functional.mse_loss(pred, target))


def _max_abs_error(pred, target):
    """Return the largest element-wise absolute difference between ``pred`` and ``target``."""
    return nn.functional.l1_loss(pred, target, reduction='none').max()


if __name__ == "__main__":
    args = parser.parse_args()

    # load the config file stored next to the trained model
    config = Box.from_yaml(filename=os.path.join(args.trained_model_dir, 'config.yml'), Loader=yaml.FullLoader)

    # init dataset and dataloader
    dataset = CFDGraphsDataset(zip_path=args.test_dataset_path, random_masking=False,
                               farfield_mag_aoa=config.hyperparameters.farfield_mag_aoa,
                               one_hot_node_type=config.hyperparameters.one_hot_node_type)

    # Evaluate exactly one graph per batch. The code below treats every batch as a single
    # sample: node metrics are averaged per batch, graph-level outputs are gathered as one
    # row per batch, and --plot_sample plots "the last sample". With the training batch size
    # (> 1) the globals would be mis-indexed, torch.stack would fail on a smaller final
    # batch, and RMSE/MaxAE would silently become per-batch instead of per-sample values.
    test_loader = DataLoader(dataset, batch_size=1, shuffle=False)

    model = FRGNN(node_feature_dim=dataset.num_node_features, edge_feature_dim=dataset.num_edge_features,
                  glob_feature_dim=dataset.num_glob_features, node_out_dim=dataset.num_node_output_features,
                  glob_out_dim=dataset.num_glob_output_features, glob_loss_factor=config.hyperparameters.glob_loss_factor,
                  div_loss_factor=config.hyperparameters.div_loss_factor, **config.model_settings)

    # weights are loaded onto the CPU first and moved to the target device afterwards
    model.load_state_dict(torch.load(
        os.path.join(args.trained_model_dir, 'trained_models', '{}.pt'.format(args.model_version)),
        map_location=torch.device('cpu')))

    # use gpu if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    # node output channels, in column order of data.x / data.y
    node_fields = ('p', 'u_x', 'u_y')
    # graph-level output columns of data.globals / data.globals_y and their printed names
    glob_fields = (('uinf', 0), ('aoa', 1), ('ti', 2))

    # per-sample error values, one list per node field
    node_rmse = {name: [] for name in node_fields}
    node_maxae = {name: [] for name in node_fields}

    # one row of graph-level values per sample
    glob_pred = []
    glob_y = []

    with torch.no_grad():
        print('Evaluating model {} ({}) on test set'.format(args.trained_model_dir, args.model_version))
        for data in test_loader:
            data = data.to(device)

            # forward pass; the returned graph is expected to carry the predictions in
            # data.x / data.globals, which are compared with data.y / data.globals_y below
            data = model(data)

            # undo the normalization of prediction and target to get physically meaningful
            # values (column 1 of normalization_values is the scale, column 0 the offset)
            data.x = data.x * data.normalization_values[:, 1] + data.normalization_values[:, 0]
            data.y = data.y * data.normalization_values[:, 1] + data.normalization_values[:, 0]

            # gather the graph-level outputs of this (single-graph) batch as one flat row
            glob_pred.append(data.globals.reshape(-1))
            glob_y.append(data.globals_y.reshape(-1))

            # per-sample RMSE and max absolute error of every node field
            for channel, name in enumerate(node_fields):
                node_rmse[name].append(_rmse(data.x[:, channel], data.y[:, channel]))
                node_maxae[name].append(_max_abs_error(data.x[:, channel], data.y[:, channel]))

        if not glob_pred:
            raise SystemExit('Test dataset {} contains no samples'.format(args.test_dataset_path))

        glob_pred = torch.stack(glob_pred)
        glob_y = torch.stack(glob_y)

        # Node fields: mean over samples of the per-sample metric.
        # Globals: RMSE over all samples, and the maximum absolute error over the whole test set.
        print('---------Root Mean Squared Error---------')
        for name in node_fields:
            print('Average {} RMSE on test set: {}'.format(name, torch.stack(node_rmse[name]).mean().item()))
        for name, col in glob_fields:
            print('Average {} RMSE on test set: {}'.format(name, _rmse(glob_pred[:, col], glob_y[:, col]).item()))

        print('---------Max Absolute Error---------')
        for name in node_fields:
            print('Average {} MaxAE on test set: {}'.format(name, torch.stack(node_maxae[name]).mean().item()))
        for name, col in glob_fields:
            print('Average {} MaxAE on test set: {}'.format(name, _max_abs_error(glob_pred[:, col], glob_y[:, col]).item()))

        if args.plot_sample:
            # the last evaluated (already denormalized) sample, moved to the CPU for plotting
            data = data.to(torch.device('cpu'))

            print('---------Plotting sample---------')
            print('Globals True: ', data.globals_y)
            print('Globals Pred: ', data.globals)

            xlims = (-0.1, 1.1)
            ylims = (-0.08, 0.12)
            figsize = (12, 4)

            plot_fields = ('pressure', 'velocity_mag', 'velocity_x', 'velocity_y')
            for field in plot_fields:
                # only the final call shows the figures, so all of them appear together
                plot_mesh(data, field, plot_predicted=True, xlimits=xlims, ylimits=ylims, fig_size=figsize,
                          tile_vertical=True, show=(field == plot_fields[-1]))