import argparse
import ruamel_yaml as yaml
from pathlib import Path
import os
import torch
# Allow float32 matmuls to use faster, reduced-precision internal kernels
# where the hardware supports them ('high' mode).
torch.set_float32_matmul_precision('high')
import torch.backends.cudnn as cudnn
# Let cuDNN benchmark candidate algorithms and cache the fastest one per input
# shape; this pays off when shapes stay fixed across batches.
cudnn.benchmark = True
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

# NOTE: `logging` here is transformers' logging module, not the stdlib one.
from transformers import RobertaConfig, RobertaTokenizer, RobertaModel, logging
# Suppress transformers warnings/info; only errors are reported.
logging.set_verbosity_error()

# Project-local Lightning data module and model wrapper.
from pl_data import MMTSFMDatasetModule
from pl_model import LitMMTSFM

# Fix the global RNG seeds once, at import time (so importing this module
# reseeds as a side effect). NOTE(review): cudnn.benchmark above may still make
# GPU results non-deterministic across runs.
seed_everything(42)


def train(model_version=None,
          model_params=None,
          tokenizer=None,
          max_len=None,
          lr=None,
          weight_decay=None,
          warmup_ratio=None,
          config=None,
          is_gpu=False,
          max_epochs=None,
          patience=None,
          checkpoint_path=None,
          output_dir='./output/forecast',
          ):
    """Build the Lightning model, data module and trainer, then run ``fit``.

    Args:
        model_version: Sub-directory name under ``{output_dir}/model_ckpts``
            where checkpoints are written.
        model_params: Model hyper-parameters, forwarded to ``LitMMTSFM`` as
            ``model_params``.
        tokenizer: Text tokenizer forwarded to ``LitMMTSFM``.
        max_len: Maximum text length forwarded to ``LitMMTSFM``.
        lr: Learning rate forwarded to ``LitMMTSFM``.
        weight_decay: Weight decay forwarded to ``LitMMTSFM``.
        warmup_ratio: Warm-up ratio forwarded to ``LitMMTSFM``.
        config: Config dict used to build ``MMTSFMDatasetModule``.
        is_gpu: Request GPU training on all visible CUDA devices. Falls back
            to a single CPU device when CUDA is unavailable.
        max_epochs: Upper bound on the number of training epochs.
        patience: Number of validation checks without ``val_loss``
            improvement before early stopping.
        checkpoint_path: Optional checkpoint passed to ``Trainer.fit`` as
            ``ckpt_path``; ``None`` starts from scratch.
        output_dir: Root directory for TensorBoard logs (``MyLogs``) and
            checkpoints (``model_ckpts``). The default reproduces the
            previously hard-coded location.
    """
    lit_model = LitMMTSFM(model_params=model_params, tokenizer=tokenizer, max_len=max_len,
                          lr=lr, weight_decay=weight_decay, warmup_ratio=warmup_ratio)
    data_module = MMTSFMDatasetModule(config=config)
    logger = TensorBoardLogger(f'{output_dir}/MyLogs', name='MM_model')

    # Keep the 5 best checkpoints, ranked by lowest validation loss.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=f'{output_dir}/model_ckpts/{model_version}',
        filename='model-{epoch:02d}-{val_loss:.2f}',
        save_top_k=5,
        mode='min',
    )

    # Stop once val_loss has not decreased for `patience` validation checks.
    early_stopping_callback = EarlyStopping(
        monitor='val_loss',
        patience=patience,
        mode='min',
    )

    if is_gpu and torch.cuda.is_available():
        accelerator, devices = 'gpu', torch.cuda.device_count()
    else:
        # CPU run (explicitly requested, or CUDA unavailable): use exactly one
        # device. Lightning treats devices=N on the CPU accelerator as N
        # processes, so reusing the CUDA device count here would spawn
        # redundant CPU workers on a multi-GPU machine.
        accelerator, devices = 'cpu', 1

    print(f"Finally use {accelerator}")

    trainer = Trainer(
        max_epochs=max_epochs,
        logger=logger,
        accelerator=accelerator,
        devices=devices,
        callbacks=[checkpoint_callback, early_stopping_callback],
        val_check_interval=1.0,  # validate once per training epoch
    )
    trainer.fit(lit_model, data_module, ckpt_path=checkpoint_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_name', default='synthetic')
    parser.add_argument('--config', default='./configs/forecast.yaml', help='Configuration of MMTSFM')
    parser.add_argument('--output_dir', default='./output/forecast')
    parser.add_argument('--text_encoder', default='./models/roberta_base')  # Roberta
    args = parser.parse_args()

    # Load the experiment config. NOTE(review): yaml.Loader is the full
    # (non-safe) loader and can construct arbitrary Python objects; only point
    # --config at trusted files.
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    max_epochs = 300

    # The public ETT benchmarks share one set of overrides; only the data file
    # names differ. Any other dataset name keeps the values from the config.
    ett_datasets = ('ETTm1', 'ETTm2', 'ETTh1', 'ETTh2')
    if args.dataset_name in ett_datasets:
        data_prefix = f'./dataset/captioned_public/{args.dataset_name}'
        config['train_file'] = f'{data_prefix}_train_dataset.pth'
        config['val_file'] = f'{data_prefix}_val_dataset.pth'
        max_epochs = 100
        config['max_len'] = 512
        config['max_patchnum'] = 42
        config['output_size'] = 96
    elif args.dataset_name != 'synthetic':
        # Surface likely typos instead of silently training on the config's files.
        print(f"Unrecognized dataset_name {args.dataset_name!r}; "
              f"using train/val files from {args.config}")

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Snapshot the effective (post-override) config alongside the run outputs;
    # the `with` block guarantees the file is flushed and closed.
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_out:
        yaml.dump(config, config_out)

    model_version = 'test'

    # Build the RoBERTa tokenizer and text encoder, truncating the encoder to
    # the configured number of hidden layers.
    text_encoder_path = args.text_encoder + '/'
    tokenizer = RobertaTokenizer.from_pretrained(text_encoder_path)
    roberta_config = RobertaConfig.from_pretrained(text_encoder_path)
    roberta_config.num_hidden_layers = config['text_encoder_layers']
    roberta_config.output_attentions = True
    roberta_config.output_hidden_states = True
    text_encoder = RobertaModel.from_pretrained(text_encoder_path, config=roberta_config)
    text_encoder_type = 'Roberta'

    # Map config keys onto the keyword names LitMMTSFM expects.
    model_params = {
        'dim': config['dim'],
        'num_vars': config['num_vars'],
        'ts_token_size': config['token_size'],
        'ts_output_size': config['output_size'],
        'unimodal_depth': config['unimodal_depth'],
        'multimodal_depth': config['multimodal_depth'],
        'text_dim': config['text_dim'],
        'num_text_queries': config['num_text_queries'],
        'heads': config['heads'],
        'dim_head': config['dim_head'],
        'ff_mult': config['ff_mult'],
        'dropout_rate_fcst': config['dropout_rate_fcst'],
        'dropout_rate_cons': config['dropout_rate_cons'],
        'textAug': config['textAug_flag'],
        'addfuture': config['addfuture_flag'],
        'text_encoder': text_encoder,
        'text_encoder_type': text_encoder_type,
        'text_encoder_frozen_flag': config['text_encoder_frozen_flag'],
        'forecast_loss_weight': config['forecast_loss_weight'],
        'contrastive_loss_weight': config['contrastive_loss_weight'],
        'temperature': config['temperature'],
    }
    lr = 1e-4
    weight_decay = 1e-2
    warmup_ratio = 0.02
    patience = 7
    is_gpu = True

    max_len = config['max_len']
    train(model_version=model_version,
          model_params=model_params,
          tokenizer=tokenizer,
          max_len=max_len,
          lr=lr,
          weight_decay=weight_decay,
          warmup_ratio=warmup_ratio,
          config=config,
          is_gpu=is_gpu,
          max_epochs=max_epochs,
          patience=patience,
          )