import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from argparse import ArgumentParser
from rdkit import Chem

import utils
from model import SmilesTransformerDistillationModel, SmilesTransformerFinetuneModel
from data import SmilesDataModule, DistillationDataModule
from args import add_model_args, add_data_args, to_int_list
from pytorch_lightning.loggers import WandbLogger
import wandb
import os
from time import time

def _build_model_and_datamodule(args):
    """Create the model and datamodule for the mode selected in ``args``.

    Args:
        args: Parsed command-line namespace. ``args.online_distillation`` is
            checked first, then ``args.finetune``.

    Returns:
        A ``(model, datamodule, dataset_dict)`` tuple. ``dataset_dict`` is read
        from the datamodule and is later used for its ``'metric'`` and
        ``'metric_mode'`` entries.

    Raises:
        ValueError: If neither mode is enabled.
    """
    if args.online_distillation:
        datamodule = DistillationDataModule.from_argparse_args(args)
        # The distillation datamodule exposes its dataset description under a
        # different attribute name than the plain SMILES datamodule.
        dataset_dict = datamodule.smiles_dataset_dict
        model = SmilesTransformerDistillationModel(
            dataset_name=args.dataset_name,
            vocab_size=dataset_dict['vocab_size'],
            d_model=args.d_model,
            nhead=args.nhead,
            dim_feedforward=args.dim_feedforward,
            dropout=args.dropout,
            num_layers=args.num_layers,
            max_len=dataset_dict['max_len'],
            pe_type=args.pe_type,
            pe_scale_factor=args.pe_scale_factor,
            teacher_save_path=args.teacher_checkpoint_save_path,
            feat_dist_layers_s=to_int_list(args.student_feature_distillation_layers),
            feat_dist_layers_t=to_int_list(args.teacher_feature_distillation_layers),
            feat_dist_loss_weight=args.feature_distillation_loss_weight,
            attnw_dist_layers_s=to_int_list(args.student_attention_weight_distillation_layers),
            attnw_dist_layers_t=to_int_list(args.teacher_attention_weight_distillation_layers),
            attnw_dist_loss_weight=args.attention_weight_distillation_loss_weight,
            warmup_epochs=args.warmup_epochs,
            warmup_task_loss_weight=args.warmup_task_loss_weight,
            warmup_feat_dist_loss_weight=args.warmup_feature_distillation_loss_weight,
            warmup_attnw_dist_loss_weight=args.warmup_attention_weight_distillation_loss_weight,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
        )
        return model, datamodule, dataset_dict

    if args.finetune:
        datamodule = SmilesDataModule.from_argparse_args(args)
        dataset_dict = datamodule.dataset_dict
        model = SmilesTransformerFinetuneModel(
            checkpoint_path=args.finetune_checkpoint_save_path,
            dataset_name=args.dataset_name,
            change_dropout=args.change_dropout,
            freeze_layers=to_int_list(args.freeze_layers),
            d_model=args.d_model,
            learning_rate=args.learning_rate,
            weight_decay=args.weight_decay,
        )
        return model, datamodule, dataset_dict

    raise ValueError(
        'no training mode selected: enable online_distillation or finetune'
    )


def main():
    """Parse the command line, build the model, and run validate/test/fit.

    The mode is chosen from the parsed arguments: ``args.validate`` runs
    validation only, ``args.test`` runs testing only, otherwise the model is
    fitted and then tested with ``ckpt_path="best"``. The elapsed wall-clock
    time of that stage is printed at the end.
    """
    parser = ArgumentParser()
    parser = pl.Trainer.add_argparse_args(parser)
    parser = add_model_args(parser)
    parser = add_data_args(parser)
    args = parser.parse_args()
    print(args)

    # Fail before any wandb run is started or data is loaded; otherwise the
    # script would die later with a NameError on the unbound `model`.
    if not (args.online_distillation or args.finetune):
        parser.error(
            'no training mode selected: enable online_distillation or finetune'
        )

    # Environment variables are strings, so convert before comparing with 0;
    # comparing the raw value would never match when LOCAL_RANK is set.
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    if local_rank == 0:
        wandb.init(
            project='Smiles Transformer Training on {}'.format(args.dataset_name),
            mode='offline' if args.wandb_offline else 'online',
        )
    # Seed before the model is constructed.
    pl.seed_everything(args.seed)

    model, datamodule, dataset_dict = _build_model_and_datamodule(args)
    print(model)

    wandb_logger = WandbLogger(
        project='Smiles Transformer on {}'.format(args.dataset_name),
        offline=args.wandb_offline,
    )

    # Checkpoint on the dataset's validation metric.
    metric = 'valid_{}'.format(dataset_dict['metric'])
    mode = dataset_dict['metric_mode']
    checkpoint_callback = ModelCheckpoint(
        monitor=metric,
        filename='{epoch}_{' + metric + ':.4f}',
        save_top_k=args.save_top_k,
        mode=mode,
        save_last=True,
    )

    # Hand the logger and checkpoint callback to the Trainer at construction
    # time rather than patching them in afterwards.
    # NOTE(review): a callback appended after construction would sit behind any
    # checkpoint callback the Trainer set up itself, so `ckpt_path="best"`
    # would presumably not resolve to the monitored checkpoint -- confirm
    # against the installed Lightning version.
    trainer = pl.Trainer.from_argparse_args(
        args,
        logger=wandb_logger,
        callbacks=[checkpoint_callback],
    )

    start_time = time()
    if args.validate:
        trainer.validate(model=model, datamodule=datamodule)
    elif args.test:
        trainer.test(model=model, datamodule=datamodule)
    else:
        trainer.fit(model=model, datamodule=datamodule)
        trainer.test(ckpt_path="best")
        if args.dataset_name == 'qm9':
            print(utils.multimae_results, torch.sum(utils.multimae_results))
    print('time: {}s'.format(time() - start_time))

# Run the entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()