import os
import argparse
import torch
import pytorch_lightning as pl
from dataloader import DataModule, build_datamodule
from utils import seed_everything
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.loggers import CSVLogger
from trainer import LitGATCausalRegressor
import warnings
# Ignore FutureWarning-category warnings raised from this point on, to keep the training log readable.
warnings.filterwarnings("ignore", category=FutureWarning)


def _build_parser():
    """Create the command-line parser for the training script.

    Returns:
        argparse.ArgumentParser: parser with the data, optimisation, loss,
        architecture and checkpoint options registered.
    """
    parser = argparse.ArgumentParser(description='Lightning Causal Attention Regression')

    # ----------------------- core parameters ------------------------- #
    parser.add_argument("--mode", required=True,
                        choices=["SMILES", "PEPTIDE", "GEOMETRY", "FUSION"],
                        help="Which modality or fusion model to train.")
    parser.add_argument("--data_dir", required=True, help="CSV directory.")
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--num_workers", type=int, default=4)

    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--lr_smiles", type=float, default=5e-4,
                        help="learning rate for the SMILES backbone optimizer")
    parser.add_argument("--lr_peptide", type=float, default=5e-4,
                        help="learning rate for the Peptide backbone optimizer")
    parser.add_argument("--lr_geometry", type=float, default=5e-4,
                        help="learning rate for the Geometry backbone optimizer")

    # NOTE(review): --explain_fusion and --ckpt_path are not read anywhere in
    # this script; they only reach the model through the forwarded kwargs.
    parser.add_argument("--explain_fusion", action="store_true",
                        help="Load checkpoint and run fusion visualisation only.")
    parser.add_argument("--ckpt_path", default="best.ckpt")

    parser.add_argument("--dropout", type=float, default=0.3)
    parser.add_argument('--w_cons_loss', type=float, default=0)
    parser.add_argument("--seed", type=int, default=44)
    parser.add_argument("--gpus", type=int, default=1,
                        help="0 = CPU, >1 = multi-GPU DDP")

    # -------------------- loss & scheduler --------------------------- #
    parser.add_argument("--lambda_unif", type=float, default=1.0)
    parser.add_argument("--lambda_caus", type=float, default=1.0)
    parser.add_argument("--lambda_mono", type=float, default=0.5)

    parser.add_argument("--scheduler_type", default="plateau",
                        choices=["plateau", "cosine", "warm"])

    # -------------------- architecture ------------------------------ #
    parser.add_argument("--num_causal_blocks", type=int, default=3)

    parser.add_argument("--hidden_dim_projector", type=int, default=64)
    parser.add_argument("--emb_dim_smiles", type=int, default=32)
    parser.add_argument("--emb_dim_peptide", type=int, default=32)
    parser.add_argument("--emb_dim_geometry", type=int, default=64)

    parser.add_argument("--hidden_dim_smiles", type=int, default=64)
    parser.add_argument("--hidden_dim_peptide", type=int, default=64)
    parser.add_argument("--hidden_dim_geometry", type=int, default=64)

    parser.add_argument("--num_gc_layers_smiles", type=int, default=3)
    parser.add_argument("--num_gc_layers_peptide", type=int, default=3)
    parser.add_argument("--num_gc_layers_geometry", type=int, default=3)

    parser.add_argument("--heads_smiles", type=int, default=4)
    parser.add_argument("--heads_peptide", type=int, default=4)
    parser.add_argument("--heads_geometry", type=int, default=4)

    # ------------------- misc --------------------------------------- #
    # Only the base name (without extension) is used, as the prefix of the
    # checkpoint file written into --data_dir.
    parser.add_argument("--checkpoint_path", type=str, default="model.ckpt")

    return parser


def _select_hardware(num_gpus):
    """Translate ``--gpus`` into the Trainer's accelerator/devices/strategy.

    Args:
        num_gpus: value of ``--gpus``; 0 selects the CPU, any other value is
            handed to Lightning as the GPU ``devices`` setting.

    Returns:
        tuple: ``(accelerator, devices, strategy)`` to pass to ``Trainer``.
    """
    if num_gpus == 0:
        # Lightning 2.x rejects ``devices=None`` for the CPU accelerator, so
        # request one CPU process explicitly.
        return "cpu", 1, "auto"
    if num_gpus == 1:
        # A single GPU needs no DDP wrapper; let Lightning choose the strategy.
        return "gpu", 1, "auto"
    # Several GPUs (or a value such as -1 passed straight through): keep DDP.
    return "gpu", num_gpus, DDPStrategy(find_unused_parameters=True)


def main():
    """Parse the CLI options, build datamodule and model, then train and test.

    Checkpoints and CSV logs are written under ``--data_dir``. After fitting,
    the test loop runs on the checkpoint with the lowest ``val_loss``; if no
    checkpoint was saved, the in-memory weights are tested instead.
    """
    args = _build_parser().parse_args()
    seed_everything(args.seed)

    # data & model
    dm_main = build_datamodule(args.mode, args)
    # Every CLI option is forwarded to the model as a keyword argument.
    model_kwargs = vars(args).copy()
    print('number of causal blocks:', model_kwargs["num_causal_blocks"])

    if args.mode in ["FUSION", "PEPTIDE"]:
        # These modes additionally pass the size of the datamodule vocabulary.
        model_kwargs["n_peptide_types"] = len(dm_main.vocab)
    model = LitGATCausalRegressor(**model_kwargs)

    # callbacks and logger
    ckpt_prefix = os.path.splitext(os.path.basename(args.checkpoint_path))[0]
    checkpoint_cb = ModelCheckpoint(
        dirpath=args.data_dir,
        filename=ckpt_prefix + '-{epoch:02d}-{val_loss:.4f}',
        save_top_k=1,
        monitor='val_loss',
        mode='min'
    )
    lr_monitor = LearningRateMonitor(logging_interval='step')
    logger = CSVLogger(save_dir=args.data_dir, name='logs')

    # determine accelerator, devices and strategy
    accelerator, devices, strategy = _select_hardware(args.gpus)

    # trainer
    trainer = Trainer(
        max_epochs=args.epochs,
        accelerator=accelerator,
        devices=devices,
        precision="bf16-mixed",
        strategy=strategy,
        callbacks=[checkpoint_cb, lr_monitor],
        logger=logger,
        deterministic=True
    )

    # run
    trainer.fit(model, dm_main)

    # Evaluate the best checkpoint rather than whatever the last epoch left in
    # memory; an empty path (nothing saved) falls back to the current weights.
    trainer.test(model, dm_main, ckpt_path=checkpoint_cb.best_model_path or None)

# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
