"""Train (or only validate) an MDGen model with PyTorch Lightning.

Command-line options come from ``mdgen.parsing.parse_train_args``. The
``MODEL_DIR`` environment variable must be set: it is used as the Trainer's
root directory, as the checkpoint directory, and as the place searched for an
existing checkpoint to resume from.
"""
from mdgen.parsing import parse_train_args

# NOTE(review): arguments are parsed before the remaining imports. Presumably
# this lets argument errors / --help return before heavy imports (torch,
# Lightning) and/or lets mdgen.logger observe the parsed state — confirm
# before moving these lines.
args = parse_train_args()
from mdgen.logger import get_logger

logger = get_logger(__name__)

import torch, os, wandb
from mdgen.dataset import MDGenDataset
from mdgen.wrapper import NewMDGenWrapper
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary
import pytorch_lightning as pl
import numpy as np


from mdgen.utils import set_seed

# Make the run reproducible from --seed (project helper; presumably seeds the
# Python / NumPy / torch RNGs — see mdgen.utils.set_seed).
set_seed(args.seed)

# Let float32 matrix multiplications use faster reduced-precision internals
# where the hardware supports them, trading some accuracy for speed.
torch.set_float32_matmul_precision("medium")

# Optional Weights & Biases run; the parsed argument namespace is recorded as
# the run config.
if args.wandb:
    wandb_init_kwargs = dict(
        project="mdgen",
        name=args.run_name,
        config=args,
        settings=wandb.Settings(start_method="fork"),
    )
    wandb.init(**wandb_init_kwargs)


trainset = MDGenDataset(args, split=args.train_split)

# In either overfitting mode, validation reuses the training dataset object
# itself; otherwise a separate (optionally repeated) validation split is built.
validate_on_train = args.overfit or args.overfit_peptide
valset = (
    trainset
    if validate_on_train
    else MDGenDataset(args, split=args.val_split, repeat=args.val_repeat)
)

# Both loaders share batch size and worker count; only training is shuffled.
shared_loader_kwargs = dict(
    batch_size=args.batch_size,
    num_workers=args.num_workers,
)
train_loader = torch.utils.data.DataLoader(
    trainset, shuffle=True, **shared_loader_kwargs
)
val_loader = torch.utils.data.DataLoader(valset, **shared_loader_kwargs)
model = NewMDGenWrapper(args)

# MODEL_DIR serves as both Lightning's root directory and the checkpoint
# directory.
run_dir = os.environ["MODEL_DIR"]

# Validation budget: 0.0 (no validation batches) with --no_validate,
# otherwise `val_batches`, falling back to 1.0 when it is unset or zero.
if args.no_validate:
    val_batch_limit = 0.0
else:
    val_batch_limit = args.val_batches or 1.0

trainer_callbacks = [
    # save_top_k=-1 keeps every checkpoint rather than pruning to the best k.
    ModelCheckpoint(
        dirpath=run_dir,
        save_top_k=-1,
        every_n_epochs=args.ckpt_freq,
    ),
    ModelSummary(max_depth=2),
]

trainer = pl.Trainer(
    accelerator="gpu" if torch.cuda.is_available() else "auto",
    max_epochs=args.epochs,
    limit_train_batches=args.train_batches or 1.0,
    limit_val_batches=val_batch_limit,
    num_sanity_val_steps=0,
    precision=args.precision,
    # The interactive progress bar is turned off for wandb runs.
    enable_progress_bar=not args.wandb,
    gradient_clip_val=args.grad_clip,
    default_root_dir=run_dir,
    callbacks=trainer_callbacks,
    accumulate_grad_batches=args.accumulate_grad,
    val_check_interval=args.val_freq,
    check_val_every_n_epoch=args.val_epoch_freq,
    # Lightning's built-in experiment logger is disabled.
    logger=False,
)


# For multi-GPU runs, configure distribution through pl.Trainer (e.g. its
# `strategy` argument) rather than wrapping the model in DDP by hand.

# -------------------- Check for existing checkpoints --------------------


def _find_latest_checkpoint(directory):
    """Return the most recently written ``*.ckpt`` file in *directory*, or None.

    A directory that does not exist yet (e.g. a fresh run whose MODEL_DIR has
    not been created) is treated like an empty one instead of raising
    FileNotFoundError. Only regular files are considered.
    """
    if not os.path.isdir(directory):
        return None
    candidates = [
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith(".ckpt")
    ]
    candidates = [path for path in candidates if os.path.isfile(path)]
    if not candidates:
        return None
    # Use mtime (last content write), not ctime: on POSIX, ctime is the inode
    # change time. It is bumped by chmod/rename and reset when checkpoints are
    # copied or restored, so it does not reliably identify the newest save.
    return max(candidates, key=os.path.getmtime)


ckpt_path = args.ckpt
if ckpt_path is None:
    model_dir = os.environ["MODEL_DIR"]
    latest_ckpt = _find_latest_checkpoint(model_dir)
    if latest_ckpt is not None:
        logger.info(f"Found an existing checkpoint: {latest_ckpt}. Resuming from it.")
        ckpt_path = latest_ckpt
    else:
        logger.info("No existing checkpoint found. Training from scratch.")
else:
    logger.info(f"Using user-supplied checkpoint: {ckpt_path}")

# --validate only evaluates (optionally from ckpt_path); otherwise train,
# resuming from ckpt_path when one was supplied or found.
if args.validate:
    trainer.validate(model, val_loader, ckpt_path=ckpt_path)
else:
    trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)
