import torch
import torch.optim as optim
from torch_geometric.loader import DataLoader
from models import FRGNN
from dataset import CFDGraphsDataset
from box import Box
import yaml
import os, argparse
from shutil import copyfile
from tqdm import tqdm
from names_generator import generate_name

# Command-line interface: both the config file and the run folder are mandatory.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--yaml_config', '-c',
    type=str,
    required=True,
    help="path to the yaml config file",
)
parser.add_argument(
    '--run_folder', '-r',
    type=str,
    required=True,
    help="path to the run folder where the training run should be saved",
)

# Read the YAML configuration into a Box (parsed with yaml.FullLoader); the
# rest of the script reads its settings through attribute access on `config`.
_config_path = parser.parse_args().yaml_config
config = Box.from_yaml(filename=_config_path, Loader=yaml.FullLoader)

# Build the training dataset and its shuffled loader from the config values.
# NOTE(review): this runs at import time, outside the `__main__` guard; if
# loader workers are started with the "spawn" method the module may be
# re-imported per worker -- confirm this is acceptable on the target platform.
train_dataset = CFDGraphsDataset(
    zip_path=config.io_settings.train_dataset_path,
    random_masking=config.hyperparameters.random_masking,
    farfield_mag_aoa=config.hyperparameters.farfield_mag_aoa,
    one_hot_node_type=config.hyperparameters.one_hot_node_type,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=config.hyperparameters.batch_size,
    shuffle=True,
    exclude_keys=['triangles', 'triangle_points'],
    num_workers=config.run_settings.num_t_workers,
    pin_memory=False,
    # persistent workers are requested only when worker processes are configured
    persistent_workers=config.run_settings.num_t_workers != 0,
)

# The validation dataset/loader only exist when validation is enabled in the
# config: random masking is switched off, samples are served one at a time and
# in a fixed order.
if config.run_settings.validate:
    validate_dataset = CFDGraphsDataset(
        zip_path=config.io_settings.valid_dataset_path,
        random_masking=False,
        farfield_mag_aoa=config.hyperparameters.farfield_mag_aoa,
        one_hot_node_type=config.hyperparameters.one_hot_node_type,
    )
    validate_loader = DataLoader(
        validate_dataset,
        batch_size=1,
        shuffle=False,
        exclude_keys=['triangles', 'triangle_points'],
        num_workers=config.run_settings.num_v_workers,
        pin_memory=False,
        persistent_workers=config.run_settings.num_v_workers != 0,
    )

if __name__ == "__main__":
    # Training entry point.
    #
    # Creates a uniquely named run directory, copies the config file into it,
    # builds the FRGNN model (optionally starting from a pretrained state
    # dict), and trains it with Adam and an exponentially decaying learning
    # rate. A checkpoint is written every `io_settings.save_epochs` epochs;
    # when validation is enabled, the weights with the lowest validation loss
    # seen so far are additionally kept in 'best.pt'.

    # parse the command line once and reuse the result
    args = parser.parse_args()

    # create model saving dir and copy config file to run dir
    run_name = generate_name()
    current_run_dir = os.path.join(args.run_folder, run_name)
    models_dir = os.path.join(current_run_dir, 'trained_models')
    # intentionally no exist_ok: an already existing run directory raises
    # instead of silently mixing checkpoints of two runs
    os.makedirs(models_dir)
    copyfile(args.yaml_config, os.path.join(current_run_dir, 'config.yml'))

    # use gpu if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # initialize the model; feature dimensions come from the training dataset
    model = FRGNN(
        node_feature_dim=train_dataset.num_node_features,
        edge_feature_dim=train_dataset.num_edge_features,
        glob_feature_dim=train_dataset.num_glob_features,
        node_out_dim=train_dataset.num_node_output_features,
        glob_out_dim=train_dataset.num_glob_output_features,
        glob_loss_factor=config.hyperparameters.glob_loss_factor,
        div_loss_factor=config.hyperparameters.div_loss_factor,
        **config.model_settings,
    )

    # if using a pretrained model, load it here
    if config.io_settings.pretrained_model:
        # NOTE(review): torch.load unpickles the file; only point
        # `pretrained_model` at checkpoints from a trusted source.
        # The weights are loaded on the CPU and moved to `device` below.
        state_dict = torch.load(config.io_settings.pretrained_model,
                                map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)

    # send the model to the gpu if available
    model.to(device)

    # define optimizer (start_lr is cast explicitly because the YAML value may
    # not be parsed as a float -- e.g. when written like 1e-3)
    optimizer = optim.Adam(model.parameters(), lr=float(config.hyperparameters.start_lr))

    # define the learning rate scheduler (stepped once per epoch)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(
        optimizer, gamma=config.hyperparameters.lr_decay_gamma)

    print('Starting run {} on {}'.format(run_name, next(model.parameters()).device))

    # Start from +inf so that the very first validated epoch already counts as
    # an improvement and 'best.pt' exists after any validated run.
    best_validation_loss = float('inf')

    # the context manager closes the progress bar even if training raises
    with tqdm(total=config.hyperparameters.epochs) as pbar:
        pbar.set_description('Training')

        for epoch in range(config.hyperparameters.epochs):
            train_loss = 0.0
            model.train()

            # mini-batch loop
            for data in train_loader:
                # send the batch to the right device
                data = data.to(device)

                # reset the gradients back to zero
                optimizer.zero_grad()

                # run forward pass and compute the batch training loss
                data = model(data)
                batch_loss = model.compute_loss(data)

                # accumulate as a Python float so no graph is kept alive
                train_loss += batch_loss.item()

                # backpropagate and perform the optimizer (Adam) update
                batch_loss.backward()
                optimizer.step()

            # mean batch loss over the epoch
            train_loss = train_loss / len(train_loader)

            # step the scheduler
            scheduler.step()

            # save the trained model every n epochs
            if (epoch + 1) % config.io_settings.save_epochs == 0:
                torch.save(model.state_dict(),
                           os.path.join(models_dir, 'e{}.pt'.format(epoch + 1)))

            postfix = {'Train Loss': f'{train_loss:.8f}'}

            if config.run_settings.validate:
                # compute validation loss without tracking gradients
                validation_loss = 0.0
                model.eval()
                with torch.no_grad():
                    for data in validate_loader:
                        data = data.to(device)
                        data = model(data)
                        # .item() keeps the running sum a plain float, like
                        # the training loss, instead of a device tensor
                        validation_loss += model.compute_loss(data).item()

                # mean validation loss over the whole validation set
                validation_loss = validation_loss / len(validate_loader)
                postfix['Validation Loss'] = f'{validation_loss:.8f}'

                # save the model with the best validation loss
                if validation_loss < best_validation_loss:
                    best_validation_loss = validation_loss
                    torch.save(model.state_dict(), os.path.join(models_dir, 'best.pt'))

            # display losses and advance the progress bar
            pbar.set_postfix(postfix)
            pbar.update(1)