import train
import os
import torch
import argparse
import yaml

from geom.pc_encoder import load_geom_encoder
from load_dataset import load_train_val_fold
from dataset import GraphDataset
#from model import GeomCFD, GeoCA3D
from model import GeoCA3D
from cfd.models.utils import MLP

# Command-line interface for this training run.
parser = argparse.ArgumentParser()

# Data / run options. `args` as a whole is forwarded to load_train_val_fold,
# which presumably reads --data_dir / --save_dir / --fold_id (TODO confirm).
parser.add_argument('--data_dir', default='data/mlcfd_data/training_data')
parser.add_argument('--save_dir', default='data/mlcfd_data/preprocessed_data')
parser.add_argument('--fold_id', default=0, type=int)
parser.add_argument('--gpu', default=0, type=int)
parser.add_argument('--val_iter', default=10, type=int)

# GeoCA3D / geometry-encoder options. --config_dir holds the 'GeoCA3D'
# section; --ulip_model 'none' disables the geometry encoder.
parser.add_argument('--config_dir', default='params.yaml')
parser.add_argument('--ulip_model', default='none')
parser.add_argument('--ulip_ckpt', default='')
parser.add_argument('--frozen', action='store_true')

# CFD-backbone options.
parser.add_argument('--cfd_config_dir', default='cfd/cfd_params.yaml')
# The backbone name selects both the section read from --cfd_config_dir and
# the model class that gets built, so it must be given and must be one of the
# backbones this script knows how to construct. Validating here fails fast,
# before any data is loaded.
parser.add_argument('--cfd_model', required=True,
                    choices=['GraphSAGE', 'MLP', 'GAT', 'GNO'])
parser.add_argument('--cfd_mesh', action='store_true')
# Forwarded to train.main as `reg`.
parser.add_argument('--weight', default=0.5, type=float)

args = parser.parse_args()
print(args)

# Hyperparameters of the selected CFD backbone: the section of the CFD config
# file keyed by the backbone name.
with open(args.cfd_config_dir, 'r', encoding='utf-8') as f:
    cfd_hparams = yaml.safe_load(f)[args.cfd_model]
print(cfd_hparams)

# Keyword arguments for the GeoCA3D wrapper. They are always used to build
# the model, so they are always logged (regardless of --ulip_model).
with open(args.config_dir, 'r', encoding='utf-8') as f:
    hparams = yaml.safe_load(f)['GeoCA3D']
print(hparams)

# Use the requested GPU only if that index exists and CUDA is usable;
# otherwise fall back to CPU.
n_gpu = torch.cuda.device_count()
use_cuda = 0 <= args.gpu < n_gpu and torch.cuda.is_available()
device = torch.device(f'cuda:{args.gpu}' if use_cuda else 'cpu')

# --ulip_model 'none' disables the geometry branch entirely. Otherwise the
# encoder and projection are loaded; the `pretrained` flag is set iff a
# checkpoint path was supplied.
if args.ulip_model == 'none':
    g_encoder, g_proj = None, None
else:
    g_encoder, g_proj = load_geom_encoder(args, pretrained=bool(args.ulip_ckpt), frozen=args.frozen)

# Train/validation split for the requested fold, plus normalization
# coefficients that are later forwarded to train.main.
train_data, val_data, coef_norm = load_train_val_fold(args, preprocessed=True)

# use_height is enabled only for the ULIP_PN_NEXT geometry encoder.
use_height = args.ulip_model == 'ULIP_PN_NEXT'
# Optional 'r' entry of the backbone config; None when the backbone does not
# define it (presumably a neighbourhood radius for graph construction —
# TODO confirm against GraphDataset).
r = cfd_hparams.get('r')
train_ds = GraphDataset(train_data, use_height=use_height, use_cfd_mesh=args.cfd_mesh, r=r)
val_ds = GraphDataset(val_data, use_height=use_height, use_cfd_mesh=args.cfd_mesh, r=r)

# Encoder/decoder MLPs handed to every backbone, built from the 'encoder'
# and 'decoder' entries of the backbone config.
encoder = MLP(cfd_hparams['encoder'], batch_norm=False)
decoder = MLP(cfd_hparams['decoder'], batch_norm=False)

# Backbone modules are imported inside each branch so only the selected one
# is loaded.
if args.cfd_model == 'GraphSAGE':
    from cfd.models.GraphSAGE import GraphSAGE
    model = GraphSAGE(cfd_hparams, encoder, decoder)
elif args.cfd_model == 'MLP':
    from cfd.models.NN import NN
    model = NN(cfd_hparams, encoder, decoder)
elif args.cfd_model == 'GAT':
    from cfd.models.GAT import GAT
    model = GAT(cfd_hparams, encoder, decoder)
elif args.cfd_model == 'GNO':
    from cfd.models.GNO import GNO
    model = GNO(cfd_hparams, encoder, decoder)
else:
    # Without this, an unknown name would leave `model` undefined and fail
    # with a NameError on the next line.
    raise ValueError(
        f'Unsupported --cfd_model {args.cfd_model!r}; '
        'expected one of GraphSAGE, MLP, GAT, GNO'
    )

# Wrap the backbone with the (possibly None) geometry encoder/projection.
model = GeoCA3D(model, geom_encoder=g_encoder, geom_proj=g_proj, **hparams)

# Metrics directory keyed by backbone, fold and geometry/regularization settings.
# NOTE(review): --cfd_mesh is not part of this path, so runs that differ only
# in that flag write to the same directory.
path = f'metrics/{args.cfd_model}/{args.fold_id}/{args.ulip_model}_{bool(args.ulip_ckpt)}_{args.frozen}_{args.weight}'
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(path, exist_ok=True)

# Run training with validation every --val_iter steps (as passed to
# train.main); --weight is forwarded as the `reg` coefficient. The return
# value is stored back into `model`.
model = train.main(
    device,
    train_ds,
    val_ds,
    model,
    cfd_hparams,
    path,
    val_iter=args.val_iter,
    reg=args.weight,
    coef_norm=coef_norm,
    frozen=args.frozen,
)
