import argparse
from experiment import *
from main_utils import *
from utils import *

# Command-line interface of the training launcher.
#   --config : config name; resolved to "configs/<name>.jsonc" and also used
#              to name the "../runs/<name>/" output directories.
#   --device : value assigned verbatim to CUDA_VISIBLE_DEVICES.
#   --seed   : integer handed to set_seed (defaults to 0).
parser = argparse.ArgumentParser(description='Inverse problem of PDE')
parser.add_argument('--config', type=str, default=None, required=True)
parser.add_argument('--device', type=str, default=None, required=True)
parser.add_argument('--seed', type=int, default=0)

# Parsed at import time; the entry point below reads this module-level name.
args = parser.parse_args()


def main():
    """Run distributed training for the role selected by the config file.

    Reads the module-level ``args`` (``--config``, ``--device``, ``--seed``),
    loads ``configs/<config>.jsonc``, builds data/model via ``get_data_model``
    and dispatches to ``train_propagator`` or ``train_completer`` according to
    ``config.role``. Logs and checkpoints go under ``../runs/<config>/``.

    ``init_process_group`` is called with its default (``env://``) rendezvous,
    so the launcher (e.g. ``torchrun``) must export the rank/world-size/master
    address variables for every process.

    Raises:
        ValueError: if ``config.role`` is not ``"propagator"`` or ``"completer"``.
    """
    import os

    import torch
    import torch.distributed as dist

    # Restrict the visible GPUs before anything below (including seeding) can
    # touch CUDA; the variable has no effect once CUDA is initialised.
    # NOTE(review): if a star-imported module initialises CUDA at import time,
    # this is already too late — confirm.
    os.environ['CUDA_VISIBLE_DEVICES'] = args.device
    set_seed(args.seed)

    dist.init_process_group("nccl")
    try:
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        # get_rank() is the *global* rank, but the CUDA device index must be
        # the rank within this node. torchrun exports LOCAL_RANK; fall back to
        # the global rank (identical on a single node) when it is absent.
        device = int(os.environ.get("LOCAL_RANK", rank))
        torch.cuda.set_device(device)

        config_file = "configs/" + args.config + ".jsonc"
        config = Configuration(config_file)

        # Both trainers share one positional signature, so dispatch by role.
        # Validate before the (potentially expensive) data/model construction.
        trainers = {
            "propagator": train_propagator,
            "completer": train_completer,
        }
        trainer = trainers.get(config.role)
        if trainer is None:
            raise ValueError(
                f"Unknown config.role {config.role!r}; "
                f"expected one of {sorted(trainers)}"
            )

        (
            _train_dataset, _train_sampler, train_dataloader,
            _val_dataset, _val_sampler, val_dataloader,
            _test_dataset, _test_sampler, test_dataloader,
            transformer, masker, poser,
            model, loss, optimizer, scheduler,
        ) = get_data_model(config, device)

        # Trailing slashes kept: downstream code may append file names directly.
        log_dir = "../runs/" + args.config + "/log/"
        checkpoint_dir = "../runs/" + args.config + "/checkpoint/"

        # The trainer's rank argument receives the global rank, as before.
        # NOTE(review): on multi-node runs, confirm the trainer does not also
        # use that argument as a CUDA device index.
        trainer(
            train_dataloader,
            val_dataloader,
            test_dataloader,
            transformer,
            masker,
            poser,
            model,
            config.model.name,
            loss,
            optimizer,
            scheduler,
            rank,
            world_size,
            config.train.grad_clip,
            config.train.epoch,
            config.train.log_print_interval_epoch,
            config.train.model_save_interval_epoch,
            log_dir,
            checkpoint_dir,
        )
    finally:
        # Always tear down the process group, even if training raised.
        dist.destroy_process_group()


if __name__ == "__main__":
    main()
