'''
=====
Transformer-PhysX
- Associated publication:
url: 
doi: 
github: 
=====
'''
import sys
import logging
import torch
from transformers import HfArgumentParser
from config.args import ModelArguments, TrainingArguments, DataArguments, ArgUtils
from config.configuration_auto import AutoPhysConfig
from phys_transformer.phys_transformer_heads import PhysformerTrainer
from embedding.embedding_auto import AutoEmbeddingModel
from viz.viz_auto import AutoViz
from data_utils.dataset_auto import AutoDataset, AutoPredictionDataset
from utils.trainer import PhysTrainer

logger = logging.getLogger(__name__)

if __name__ == "__main__":

    # Default arguments selecting the numerical example to train. Each example
    # requires its pre-trained embedding model. To switch examples, uncomment one
    # of the alternative `example_args` lists below and comment out the active
    # Gray-Scott list.
    #
    # The defaults are inserted BEFORE any user-supplied command-line arguments.
    # argparse keeps the last occurrence of a repeated option, so options given
    # on the command line override these defaults. (Appending them to the end of
    # sys.argv would make the hard-coded values silently win over the user's.)

    # Lorenz system
    # example_args = [
    #     "--init_name", "lorenz",
    #     "--embedding_file_or_path", "./embedding/pretrained/lorenz/embedding_lorenz300.pth",
    #     "--training_h5_file", "./data/lorenz.hdf5",
    #     "--eval_h5_file", "./data/lorenz_valid.hdf5",
    #     "--train_batch_size", "16",
    #     "--stride", "16",
    #     "--max_grad_norm", "0.1",
    # ]

    # Flow around a cylinder. Three pre-trained embeddings are available; keep
    # exactly one of the embedding option groups uncommented.
    # example_args = [
    #     "--init_name", "cylinder",
    #     # Option 1: default cylinder embedding
    #     "--embedding_file_or_path", "./embedding/pretrained/cylinder/embedding_cylinder300.pth",
    #     # Option 2: "auto" embedding
    #     # "--embedding_file_or_path", "./embedding/pretrained/cylinder_auto/embedding_cylinder300.pth",
    #     # "--notes", "auto",
    #     # Option 3: "cylinder_pca" embedding
    #     # "--embedding_file_or_path", "./embedding/pretrained/cylinder_pca/embedding_cylinder_pca0.pth",
    #     # "--embedding_name", "cylinder_pca",
    #     # "--notes", "pca",
    #     "--training_h5_file", "./data/cylinder.hdf5",
    #     "--eval_h5_file", "./data/cylinder_valid.hdf5",
    #     "--train_batch_size", "16",
    #     "--n_train", "27",
    #     "--n_eval", "13",
    #     "--stride", "4",
    #     "--max_grad_norm", "0.001",
    #     "--epochs", "200",
    # ]

    # Gray-Scott system
    example_args = [
        "--init_name", "grayscott",
        "--embedding_file_or_path", "./embedding/pretrained/grayscott/embedding_grayscott200.pth",
        "--training_h5_file", "./data/grayscott.hdf5",
        "--eval_h5_file", "./data/grayscott_valid.hdf5",
        "--train_batch_size", "8",
        "--eval_batch_size", "8",
        "--n_train", "512",
        "--n_eval", "8",
        "--stride", "4",
        "--max_grad_norm", "0.1",
        "--epochs", "200",
    ]
    # Keep the program name first, then the defaults, then the user's arguments
    # (last occurrence wins, so the user's arguments take precedence).
    sys.argv = sys.argv[:1] + example_args + sys.argv[1:]

    # Parse arguments using the Hugging Face argument parser
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging: INFO when local_rank is -1 or 0 (presumably the main
    # process of a distributed run), WARNING and above otherwise.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in (-1, 0) else logging.WARNING,
    )
    # Configure arguments after initialization
    model_args, data_args, training_args = ArgUtils.config(model_args, data_args, training_args)

    # Load model configuration
    config = AutoPhysConfig.load_config(model_args.config_name)
    # Load the pre-trained embedding model and move it to the source device
    embedding_model = AutoEmbeddingModel.load_model(
        model_args.embedding_name,
        config,
        model_args.embedding_file_or_path,
    ).to(training_args.src_device)

    # Load visualization utility class
    viz = AutoViz.init_viz(model_args.viz_name)(training_args.plot_dir)

    # Init transformer model, optionally restoring a checkpoint from epoch_start
    model = PhysformerTrainer(config, embedding_model, model_args.model_name)
    if training_args.epoch_start > 0:
        model.load_model(training_args.ckpt_dir, epoch=training_args.epoch_start)
    # NOTE(review): when both options are set, this load runs after the
    # checkpoint load above and presumably replaces its weights — confirm intent.
    if model_args.transformer_file_or_path:
        model.load_model(model_args.transformer_file_or_path)

    # Initialize training dataset (context length taken from the model config)
    training_data = AutoDataset.create_dataset(
        model_args.model_name,
        embedding_model,
        data_args.training_h5_file,
        block_size=config.n_ctx,
        stride=data_args.stride,
        ndata=data_args.n_train,
        overwrite_cache=data_args.overwrite_cache,
    )

    # Initialize evaluation dataset. block_size=150 is hard-coded here
    # (presumably the length of each evaluation sequence — TODO confirm).
    eval_data = AutoPredictionDataset.create_dataset(
        model_args.model_name,
        embedding_model,
        data_args.eval_h5_file,
        block_size=150,
        neval=data_args.n_eval,
        overwrite_cache=data_args.overwrite_cache,
    )

    # Optimizer: Adam plus cosine annealing with warm restarts. The first
    # restart comes after T_0=14 scheduler steps, and each later cycle is
    # T_mult=2 times longer than the one before.
    optimizer = torch.optim.Adam(model.parameters(), lr=training_args.lr, weight_decay=1e-8)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=14, T_mult=2, eta_min=1e-8)
    trainer = PhysTrainer(
        model,
        training_args,
        (optimizer, scheduler),
        train_dataset=training_data,
        eval_dataset=eval_data,
        viz=viz,
    )

    trainer.train()
