"""Train the MMTSFM multimodal forecasting model with PyTorch Lightning.

Run as a script: ``--dataset_name`` selects per-dataset presets that override
values loaded from the YAML file given by ``--config``.
"""
import argparse
import ruamel_yaml as yaml
from pathlib import Path
import os
import torch
# Permit reduced-precision (e.g. TF32) float32 matmuls on supported hardware for speed.
torch.set_float32_matmul_precision('high')
import torch.backends.cudnn as cudnn
# Let cuDNN benchmark and cache the fastest algorithms for the observed input shapes.
cudnn.benchmark = True
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from transformers import RobertaConfig, RobertaTokenizer, RobertaModel, BigBirdConfig, BigBirdModel, logging
# Only surface transformers errors; hide its warnings/info messages.
logging.set_verbosity_error()

from pl_data import MMTSFMDatasetModule
from pl_model import LitMMTSFM

# Seed the global RNGs at import time so every run is reproducible.
seed_everything(42)


def select_accelerator(is_gpu=False):
    """Choose the Lightning ``accelerator`` / ``devices`` pair.

    Args:
        is_gpu: Request GPU training. Honoured only when CUDA is available.

    Returns:
        ``('gpu', torch.cuda.device_count())`` when ``is_gpu`` is true and CUDA
        is available, otherwise ``('cpu', 1)``.
    """
    if is_gpu and torch.cuda.is_available():
        return 'gpu', torch.cuda.device_count()
    # For the CPU accelerator, ``devices`` is the number of processes Lightning
    # launches, so it must not be tied to the number of visible GPUs.
    return 'cpu', 1


def train(model_version=None,
          model_params=None,
          tokenizer=None,
          max_len=None,
          lr=None, 
          weight_decay=None,
          warmup_ratio=None,
          config=None,
          is_gpu=False,
          max_epochs=None,
          patience=None,
          checkpoint_path=None,
          ):
    """Fit ``LitMMTSFM`` on ``MMTSFMDatasetModule`` with checkpointing and early stopping.

    Args:
        model_version: Subdirectory of ``./output/forecast/model_ckpts/`` that
            receives the checkpoints.
        model_params: Model settings, forwarded unchanged to ``LitMMTSFM``.
        tokenizer: Text tokenizer, forwarded to ``LitMMTSFM``.
        max_len: Forwarded to ``LitMMTSFM``.
        lr: Forwarded to ``LitMMTSFM``.
        weight_decay: Forwarded to ``LitMMTSFM``.
        warmup_ratio: Forwarded to ``LitMMTSFM``.
        config: Config dict used to build the ``MMTSFMDatasetModule``.
        is_gpu: Train on all visible GPUs when CUDA is available; otherwise a
            single CPU process is used.
        max_epochs: Upper bound on training epochs.
        patience: Number of validation checks without ``val_loss`` improvement
            before early stopping.
        checkpoint_path: Optional checkpoint to resume from (``Trainer.fit(ckpt_path=...)``).
    """
    lit_model = LitMMTSFM(model_params=model_params, tokenizer=tokenizer, max_len=max_len, lr=lr, weight_decay=weight_decay, warmup_ratio=warmup_ratio)
    data_module = MMTSFMDatasetModule(config=config)
    logger = TensorBoardLogger('./output/forecast/MyLogs', name='MM_model')

    # Keep the 5 checkpoints with the lowest validation loss.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_loss',
        dirpath=f'./output/forecast/model_ckpts/{model_version}',
        filename='model-{epoch:02d}-{val_loss:.2f}',
        save_top_k=5,
        mode='min',
    )

    early_stopping_callback = EarlyStopping(
        monitor='val_loss',
        patience=patience,
        mode='min'
    )

    accelerator, devices = select_accelerator(is_gpu)
    print(f"Finally use {accelerator}")

    trainer = Trainer(
        max_epochs=max_epochs,
        logger=logger,
        accelerator=accelerator,
        devices=devices,
        callbacks=[checkpoint_callback, early_stopping_callback],
        val_check_interval=1.0,  # validate once per training epoch
    )
    trainer.fit(lit_model, data_module, ckpt_path=checkpoint_path)


def _public_overrides(output_size):
    """Config overrides shared by the captioned public benchmarks."""
    return {'max_len': 512, 'max_patchnum': 42, 'output_size': output_size}


def _time_mmd_overrides(max_len, max_patchnum, output_size, dim, heads):
    """Config overrides shared by the Time-MMD datasets (key order mirrors the original assignments)."""
    return {
        'max_len': max_len,
        'max_patchnum': max_patchnum,
        'output_size': output_size,
        'dim': dim,
        'unimodal_depth': 2,
        'multimodal_depth': 1,
        'heads': heads,
        'batch_size_train': 32,
        'batch_size_val': 32,
    }


# dataset_name -> (subdirectory of ./dataset, config overrides, training-hyperparameter overrides).
# Train/val files resolve to ./dataset/<subdir>/<dataset_name>_{train,val}_dataset.pth.
DATASET_SETTINGS = {
    'ETTm1': ('captioned_public', _public_overrides(96), {'max_epochs': 100}),
    'ETTm2': ('captioned_public', _public_overrides(96), {'max_epochs': 100}),
    'ETTh1': ('captioned_public', _public_overrides(96), {'max_epochs': 100}),
    'ETTh2': ('captioned_public', _public_overrides(96), {'max_epochs': 100}),
    'exchange_rate': ('captioned_public', _public_overrides(96), {'max_epochs': 100}),
    'stock': ('captioned_public', _public_overrides(21), {'max_epochs': 100}),
    'Weather_captioned': (
        'existing_mm',
        {'max_len': 512, 'max_patchnum': 36, 'output_size': 36, 'dim': 64, 'heads': 4},
        {'max_epochs': 100},
    ),
    'Time_MMD_Climate': ('existing_mm', _time_mmd_overrides(1024, 1, 8, 64, 2), {'max_epochs': 300}),
    'Time_MMD_Economy': (
        'existing_mm',
        _time_mmd_overrides(1024, 1, 8, 64, 2),
        {'max_epochs': 300, 'lr': 0.005, 'patience': 20},
    ),
    'Time_MMD_SocialGood': ('existing_mm', _time_mmd_overrides(1800, 1, 8, 64, 2), {'max_epochs': 300}),
    'Time_MMD_Traffic': ('existing_mm', _time_mmd_overrides(1280, 1, 8, 128, 4), {'max_epochs': 300}),
    'Time_MMD_Energy': ('existing_mm', _time_mmd_overrides(1280, 5, 12, 64, 2), {'max_epochs': 300}),
    'Time_MMD_Health_US': ('existing_mm', _time_mmd_overrides(1800, 5, 12, 128, 2), {'max_epochs': 300}),
    'Time_MMD_Health_AFR': ('existing_mm', _time_mmd_overrides(1440, 5, 12, 128, 2), {'max_epochs': 300}),
}


def apply_dataset_settings(config, dataset_name):
    """Apply the presets for ``dataset_name`` to ``config`` in place.

    Sets ``train_file``/``val_file`` and the dataset's config overrides.

    Returns:
        A fresh dict of training-hyperparameter overrides (keys among
        ``max_epochs``, ``lr``, ``patience``). Unknown names (such as the
        default ``'synthetic'``) leave ``config`` untouched and return ``{}``,
        so the values from the YAML config are used as-is.
    """
    settings = DATASET_SETTINGS.get(dataset_name)
    if settings is None:
        return {}
    subdir, config_overrides, hparam_overrides = settings
    prefix = f'./dataset/{subdir}/{dataset_name}'
    config['train_file'] = f'{prefix}_train_dataset.pth'
    config['val_file'] = f'{prefix}_val_dataset.pth'
    config.update(config_overrides)
    # Copy so callers cannot mutate the shared table.
    return dict(hparam_overrides)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_name', default='synthetic')
    parser.add_argument('--config', default='./configs/forecast.yaml', help='Configuration of MMTSFM')
    parser.add_argument('--output_dir', default='./output/forecast')
    parser.add_argument('--text_encoder', default='./models/roberta_base') # Roberta
    args = parser.parse_args()

    # NOTE: yaml.Loader can construct arbitrary Python objects; only load trusted config files.
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    weight_decay = 1e-2
    warmup_ratio = 0.02
    is_gpu = True
    # Defaults; individual datasets may override some of them.
    hparams = {'max_epochs': 300, 'lr': 1e-4, 'patience': 7}
    hparams.update(apply_dataset_settings(config, args.dataset_name))
    max_epochs = hparams['max_epochs']
    lr = hparams['lr']
    patience = hparams['patience']

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Persist the effective (post-override) config; `with` guarantees it is flushed and closed.
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_out:
        yaml.dump(config, config_out)

    model_version = 'test'
    if args.dataset_name.startswith('Time_MMD_'):
        # Time-MMD presets always use the BigBird encoder at this path (overrides --text_encoder).
        args.text_encoder = './models/bigbird_roberta_base'
        encoder_config_cls, encoder_cls, text_encoder_type = BigBirdConfig, BigBirdModel, 'Roberta_BigBird'
    else:
        encoder_config_cls, encoder_cls, text_encoder_type = RobertaConfig, RobertaModel, 'Roberta'

    encoder_dir = args.text_encoder + '/'
    # NOTE(review): RobertaTokenizer is used for the BigBird encoder too; confirm the
    # checkpoint directory ships a compatible RoBERTa vocabulary.
    tokenizer = RobertaTokenizer.from_pretrained(encoder_dir)
    encoder_config = encoder_config_cls.from_pretrained(encoder_dir)
    encoder_config.num_hidden_layers = config['text_encoder_layers']
    encoder_config.output_attentions = True
    encoder_config.output_hidden_states = True
    text_encoder = encoder_cls.from_pretrained(encoder_dir, config=encoder_config)

    model_params = {
        'dim': config['dim'],
        'num_vars': config['num_vars'],
        'ts_token_size': config['token_size'],
        'ts_output_size': config['output_size'],
        'unimodal_depth': config['unimodal_depth'],
        'multimodal_depth': config['multimodal_depth'],
        'text_dim': config['text_dim'],
        'num_text_queries': config['num_text_queries'],
        'heads': config['heads'],
        'dim_head': config['dim_head'],
        'ff_mult': config['ff_mult'],
        'dropout_rate_fcst': config['dropout_rate_fcst'],
        'dropout_rate_cons': config['dropout_rate_cons'],
        'textAug': config['textAug_flag'],
        'addfuture': config['addfuture_flag'],
        'text_encoder': text_encoder,
        'text_encoder_type': text_encoder_type,
        'text_encoder_frozen_flag': config['text_encoder_frozen_flag'],
        'forecast_loss_weight': config['forecast_loss_weight'],
        'contrastive_loss_weight': config['contrastive_loss_weight'],
        'temperature': config['temperature'],
    }

    max_len = config['max_len']
    train(model_version=model_version,
          model_params=model_params,
          tokenizer=tokenizer,
          max_len=max_len,
          lr=lr, 
          weight_decay=weight_decay,
          warmup_ratio=warmup_ratio,
          config=config,
          is_gpu=is_gpu,
          max_epochs=max_epochs,
          patience=patience,
          )