import torch
import torch.optim as optim
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import Cartesian, Distance
from models import GenGeomAutoencoder
from datatools import WindTerrainDataset, compute_dataset_stats, norm_data
from box import Box
import yaml
import os
from tqdm import tqdm
import string
import random

# Share tensors between processes (e.g. DataLoader workers) through the file
# system instead of passing file descriptors. Presumably chosen to avoid
# "too many open files" errors when many workers are used — confirm.
torch.multiprocessing.set_sharing_strategy('file_system')

def _build_transform(name):
    """Return the edge-attribute transform selected by ``name``.

    ``'Cartesian'`` and ``'Distance'`` map to the corresponding
    torch_geometric transform with ``norm=False``; any other value yields
    ``None`` (no transform).
    """
    if name == 'Cartesian':
        return Cartesian(norm=False)
    if name == 'Distance':
        return Distance(norm=False)
    return None


def _save_checkpoint(model, path):
    """Save the model weights together with its training-set normalisation stats to ``path``."""
    torch.save({'model_state_dict': model.state_dict(), 'trainset_stats': model.trainset_stats}, path)


def _run_epoch(model, loader, device, optimizer=None):
    """Run one full pass over ``loader`` and return the mean (total, recon, mmd) losses.

    When ``optimizer`` is given, the model is put in train mode and a parameter
    update is performed after every batch. Otherwise the model is put in eval
    mode and the pass runs without gradient tracking.

    Returns:
        tuple of three floats: mean total loss, mean reconstruction loss and
        mean MMD loss over all batches.

    Raises:
        ValueError: if ``loader`` yields no batches (the mean would be undefined).
    """
    training = optimizer is not None
    if training:
        model.train()
    else:
        model.eval()

    total_sum = recon_sum = mmd_sum = 0.0
    num_batches = 0
    with torch.set_grad_enabled(training):
        for data in loader:
            # normalise with the training-set statistics, then move to the target device
            data = norm_data(data, model.trainset_stats)
            data = data.to(device)

            if training:
                optimizer.zero_grad()

            # forward pass; compute_loss returns (total, recon, mmd)
            data = model(data)
            loss, recon_loss, mmd_loss = model.compute_loss(data)

            if training:
                loss.backward()
                optimizer.step()

            # accumulate Python floats so no device tensors are kept alive
            total_sum += loss.item()
            recon_sum += recon_loss.item()
            mmd_sum += mmd_loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError('data loader yielded no batches')
    return total_sum / num_batches, recon_sum / num_batches, mmd_sum / num_batches


def train(config_path: str):
    """Train a ``GenGeomAutoencoder`` as described by the YAML config at ``config_path``.

    A run directory ``<io_settings.run_dir>/<model_type>_dim_<latent_dim>_uid_<4 chars>``
    is created. The resolved config (including the data dimensions) is written
    there as ``config.yml``. Checkpoints go to its ``trained_models``
    subdirectory:

    - ``e<N>.pt`` every ``io_settings.save_epochs`` epochs;
    - ``best.pt`` whenever the validation loss improves, but only when
      ``run_settings.validate`` is set.

    Args:
        config_path: path to the YAML configuration file.
    """
    # load the config file
    config = Box.from_yaml(filename=config_path, Loader=yaml.FullLoader)

    transform = _build_transform(config.data_settings.transform)

    # initialize the datasets and dataloaders
    train_dataset = WindTerrainDataset(filename=config.io_settings.train_dataset_path, transform=transform,
                                       channels=config.data_settings.channels,
                                       max_cells_above_terrain=config.data_settings.max_cells_above_terrain,
                                       mode='train')
    train_loader = DataLoader(train_dataset, batch_size=config.hyperparameters.batch_size, shuffle=True,
                              exclude_keys=['terrain_mask', 'fluid_indices'],
                              num_workers=config.run_settings.num_t_workers, pin_memory=False,
                              persistent_workers=config.run_settings.num_t_workers != 0)
    validate_loader = None
    if config.run_settings.validate:
        validate_dataset = WindTerrainDataset(filename=config.io_settings.valid_dataset_path, transform=transform,
                                              channels=config.data_settings.channels,
                                              max_cells_above_terrain=config.data_settings.max_cells_above_terrain,
                                              mode='eval')
        validate_loader = DataLoader(validate_dataset, batch_size=config.hyperparameters.batch_size, shuffle=False,
                                     num_workers=config.run_settings.num_v_workers, pin_memory=False,
                                     exclude_keys=['terrain_mask', 'fluid_indices'],
                                     persistent_workers=config.run_settings.num_v_workers != 0)

    # get the dimensions of the data and add them to the config dict
    config.data_dims = train_dataset.get_data_dims_dict()

    # build a unique run name from the model type, latent dim and a random 4-character id
    uid = ''.join(random.choices(string.ascii_letters + string.digits, k=4))
    run_name = '{}_dim_{}_uid_{}'.format(config.model_settings.model_type, config.model_settings.latent_dim, uid)

    # create the model-saving dir and save the config to the run dir to keep a record of the run
    current_run_dir = os.path.join(config.io_settings.run_dir, run_name)
    models_dir = os.path.join(current_run_dir, 'trained_models')
    os.makedirs(models_dir)
    config.to_yaml(filename=os.path.join(current_run_dir, 'config.yml'))

    # use gpu if available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # initialize the model, optionally from a pretrained checkpoint
    model = GenGeomAutoencoder(**config.data_dims, **config.hyperparameters, **config.model_settings)
    if config.io_settings.pretrained_model:
        checkpoint = torch.load(config.io_settings.pretrained_model, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['model_state_dict'])

    # NOTE(review): normalisation stats are always recomputed from the training set,
    # even when the pretrained checkpoint carries its own 'trainset_stats' — confirm intended.
    model.trainset_stats = compute_dataset_stats(train_loader, device)
    model.to(device)

    optimizer = optim.AdamW(model.parameters(), lr=float(config.hyperparameters.start_lr),
                            weight_decay=float(config.hyperparameters.weight_decay))
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=float(config.hyperparameters.lr_decay))

    print('Starting run {} on {}'.format(run_name, next(model.parameters()).device))
    # start at +inf so the first validated epoch is always saved as best.pt
    best_validation_loss = float('inf')
    with tqdm(total=config.hyperparameters.epochs, desc='Training') as pbar:
        for epoch in range(config.hyperparameters.epochs):
            train_loss, _, _ = _run_epoch(model, train_loader, device, optimizer)

            # decay the learning rate once per epoch
            scheduler.step()

            # save the trained model every n epochs
            if (epoch + 1) % config.io_settings.save_epochs == 0:
                _save_checkpoint(model, os.path.join(models_dir, 'e{}.pt'.format(epoch + 1)))

            postfix = {'Train Loss': f'{train_loss:.8f}'}
            if validate_loader is not None:
                validation_loss, _, _ = _run_epoch(model, validate_loader, device)
                postfix['Validation Loss'] = f'{validation_loss:.8f}'

                # keep the model with the lowest validation loss seen so far
                if validation_loss < best_validation_loss:
                    best_validation_loss = validation_loss
                    _save_checkpoint(model, os.path.join(models_dir, 'best.pt'))

            # display losses and advance the progress bar
            pbar.set_postfix(postfix)
            pbar.update(1)
            
if __name__ == "__main__":
    import argparse

    # Command-line entry point: the only (required) argument is the run config path.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config', '-c',
        type=str,
        required=True,
        help="path to the yaml config file",
    )
    cli_args = parser.parse_args()
    train(config_path=cli_args.config)