import os
import torch
import argparse
import yaml
import numpy as np
import time

from torch import nn
from torch_geometric.loader import DataLoader

from load_dataset import load_train_val_fold
from dataset import GraphDataset
#from model import GeomCFD, GeoCA3D

parser = argparse.ArgumentParser(
    description='Evaluate a trained checkpoint on the validation split of one fold.'
)
# Data location / fold selection (args is passed whole to load_train_val_fold).
parser.add_argument('--data_dir', default='data/mlcfd_data/training_data')
parser.add_argument('--save_dir', default='data/mlcfd_data/preprocessed_data')
parser.add_argument('--fold_id', default=0, type=int)
parser.add_argument('--gpu', default=0, type=int,
                    help='CUDA device index; the CPU is used when CUDA is '
                         'unavailable or the index is out of range.')

# Geometry-encoder (ULIP) settings. In this file they build the checkpoint
# directory name and, for ULIP_PN_NEXT, extra import paths / use_height.
parser.add_argument('--config_dir', default='params.yaml')
parser.add_argument('--ulip_model', default='none')
parser.add_argument('--ulip_ckpt', default='')
parser.add_argument('--frozen', action='store_true')

# CFD model settings.
parser.add_argument('--cfd_config_dir', default='cfd/cfd_params.yaml')
# Required: it indexes the YAML config and names the metrics/ sub-directory.
# Without it the config lookup fails with an unhelpful KeyError: None.
parser.add_argument('--cfd_model', required=True,
                    help='Top-level key of --cfd_config_dir holding the model '
                         'hyper-parameters.')
parser.add_argument('--cfd_mesh', action='store_true')
parser.add_argument('--weight', default=0.5, type=float)

args = parser.parse_args()
print(args)

# Load the hyper-parameters of the selected model from the YAML config.
# An empty file loads as None; treat it as an empty mapping so the lookup
# below reports a clear error instead of a TypeError.
with open(args.cfd_config_dir, 'r', encoding='utf-8') as f:
    all_hparams = yaml.safe_load(f) or {}
try:
    hparams = all_hparams[args.cfd_model]
except KeyError:
    raise KeyError(
        f'model {args.cfd_model!r} not found in {args.cfd_config_dir}; '
        f'available: {list(all_hparams)}'
    ) from None
print(hparams)

# Pick the requested CUDA device when it exists; otherwise run on the CPU.
n_gpu = torch.cuda.device_count()
gpu_in_range = 0 <= args.gpu < n_gpu
use_cuda = gpu_in_range and torch.cuda.is_available()
if use_cuda:
    device = torch.device(f'cuda:{args.gpu}')
else:
    device = torch.device('cpu')

if args.ulip_model == 'ULIP_PN_NEXT':
    import sys
    pointnext_dir = './geom/models/pointnext/PointNeXt'

    def add_path_recursive(directory):
        """Append ``directory`` and every directory beneath it to ``sys.path``.

        ``os.walk`` already visits the whole tree top-down (parents before
        children), so a single walk is enough. Recursing on each subdirectory
        as well would re-walk every subtree and append the same paths many
        times. Entries already on ``sys.path`` are not appended again.
        """
        # Appended explicitly so the directory is added even if os.walk
        # yields nothing (e.g. the directory does not exist).
        if directory not in sys.path:
            sys.path.append(directory)
        for root, _dirs, _files in os.walk(directory):
            if root not in sys.path:
                sys.path.append(root)

    add_path_recursive(pointnext_dir)
    
train_data, val_data, coef_norm = load_train_val_fold(args, preprocessed=True)

# GraphDataset's use_height flag is enabled only for the ULIP_PN_NEXT encoder.
use_height = args.ulip_model == 'ULIP_PN_NEXT'
# Optional 'r' hyper-parameter forwarded to GraphDataset; None when the
# config does not define it.
r = hparams.get('r')
val_ds = GraphDataset(val_data, use_height=use_height, use_cfd_mesh=args.cfd_mesh, r=r)

# Checkpoint directory encodes the run configuration:
# metrics/<cfd_model>/<fold_id>/<ulip_model>_<has_ulip_ckpt>_<frozen>_<weight>
path = f'metrics/{args.cfd_model}/{args.fold_id}/{args.ulip_model}_{bool(args.ulip_ckpt)}_{args.frozen}_{args.weight}'
# map_location remaps tensors saved on any device onto `device`, so a
# checkpoint trained on a GPU also loads on a CPU-only machine or another GPU.
# NOTE(review): this unpickles a whole nn.Module. From PyTorch 2.6, torch.load
# defaults to weights_only=True, which rejects such files. On those versions
# pass weights_only=False, but only for checkpoints you trust.
model = torch.load(os.path.join(path, 'model_200.pth'), map_location=device).to(device)

# batch_size=1: each iteration yields a single (cfd_data, geom) sample.
test_loader = DataLoader(val_ds, batch_size=1)

# Probability of keeping each geometry point when the point cloud is randomly
# subsampled before inference.
GEOM_KEEP_PROB = 0.8


def _sync_device():
    """Wait for queued CUDA work on ``device`` to finish; no-op on CPU.

    CUDA kernels are launched asynchronously. Timing a forward pass without
    synchronizing would measure launch overhead rather than inference time.
    """
    if device.type == 'cuda':
        torch.cuda.synchronize(device)


# Evaluate on the validation loader and report relative L2 errors,
# per-channel RMSEs and the mean forward-pass time.
with torch.no_grad():
    model.eval()

    criterion_func = nn.MSELoss(reduction='none')
    l2errs_press = []
    l2errs_velo = []
    mses_press = []
    mses_velo_var = []
    times = []

    # Normalization statistics are constant, so build the tensors once.
    # coef_norm[2] / coef_norm[3] are used as per-channel mean / std.
    if coef_norm is not None:
        mean = torch.as_tensor(coef_norm[2], device=device)
        std = torch.as_tensor(coef_norm[3], device=device)

    for cfd_data, geom in test_loader:
        cfd_data = cfd_data.to(device)

        # Keep each index along dim 1 of `geom` with probability
        # GEOM_KEEP_PROB (assumed to be the point dimension — TODO confirm).
        # NOTE(review): the mask is unseeded, so metrics vary between runs.
        keep = torch.bernoulli(torch.full((geom.shape[1],), GEOM_KEEP_PROB)).bool()
        geom = geom[:, keep].to(device)

        _sync_device()
        tic = time.perf_counter()
        out = model((cfd_data, geom))
        _sync_device()
        toc = time.perf_counter()
        targets = cfd_data.y

        # Surface nodes: last channel ("press"); remaining nodes: every
        # channel except the last ("velo").
        surf = cfd_data.surf
        pred_press, gt_press = out[surf, -1], targets[surf, -1]
        pred_velo, gt_velo = out[~surf, :-1], targets[~surf, :-1]

        # MSEs use the raw model outputs/targets; the aggregated RMSEs are
        # rescaled by coef_norm[3] after the loop when it is available.
        mse_press = criterion_func(pred_press, gt_press).mean(dim=0)
        mse_velo_var = criterion_func(pred_velo, gt_velo).mean(dim=0)

        # Relative L2 errors are computed after undoing the normalization.
        # Without coef_norm the values are used as-is.
        if coef_norm is not None:
            pred_press = pred_press * std[-1] + mean[-1]
            gt_press = gt_press * std[-1] + mean[-1]
            pred_velo = pred_velo * std[:-1] + mean[:-1]
            gt_velo = gt_velo * std[:-1] + mean[:-1]

        l2err_press = torch.norm(pred_press - gt_press) / torch.norm(gt_press)
        l2err_velo = torch.norm(pred_velo - gt_velo) / torch.norm(gt_velo)

        l2errs_press.append(l2err_press.cpu().numpy())
        l2errs_velo.append(l2err_velo.cpu().numpy())
        mses_press.append(mse_press.cpu().numpy())
        mses_velo_var.append(mse_velo_var.cpu().numpy())
        times.append(toc - tic)

    l2err_press = np.mean(l2errs_press)
    l2err_velo = np.mean(l2errs_velo)
    rmse_press = np.sqrt(np.mean(mses_press))
    # Per-channel RMSE over the velocity channels (mean over samples first).
    rmse_velo_var = np.sqrt(np.mean(mses_velo_var, axis=0))
    if coef_norm is not None:
        # Rescale RMSEs from normalized space by the per-channel std.
        rmse_press *= coef_norm[3][-1]
        rmse_velo_var *= coef_norm[3][:-1]
    print('relative l2 error press:', l2err_press)
    print('relative l2 error velo:', l2err_velo)
    print('press:', rmse_press)
    print('velo:', rmse_velo_var, np.sqrt(np.mean(np.square(rmse_velo_var))))
    print('time:', np.mean(times))