# Lightning entry point that trains (or only validates) NewMDGenWrapper.
from mdgen.parsing_ti import parse_train_args
import numpy as np  # NOTE(review): np appears unused in this script; confirm before removing.

# Arguments are parsed before the remaining imports because the dataset
# implementation imported below is chosen from args.mdcath.
args = parse_train_args()
from mdgen.logger import get_logger

logger = get_logger(__name__)

import torch, os, wandb

# Pick the dataset class: the mdCATH variant when args.mdcath is set,
# otherwise the default dataset_msm implementation. Both expose MDGenDataset.
if args.mdcath:
    from mdgen.dataset_msm_mdcath import MDGenDataset
else:
    from mdgen.dataset_msm import MDGenDataset
from mdgen.wrapper_st import NewMDGenWrapper
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary
import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_only

# Trade some float32 matmul precision for speed (see torch docs for "medium").
torch.set_float32_matmul_precision("medium")


# Experiment tracking: only the rank-0 process opens a wandb run, so
# multi-process launches do not create duplicate runs.
if args.wandb and rank_zero_only.rank == 0:
    wandb.init(
        settings=wandb.Settings(start_method="fork"),
        project="mdgen_msm",
        name=args.run_name,
        config=args,
    )

from mdgen.utils import set_seed

# Seed before any dataset/model construction below.
set_seed(args.seed)

# Training split (shuffled by its loader below).
trainset = MDGenDataset(args, split=args.train_split)

# In overfit mode validation reuses the exact training dataset object;
# otherwise load the validation split, repeated args.val_repeat times.
valset = (
    trainset
    if args.overfit
    else MDGenDataset(args, split=args.val_split, repeat=args.val_repeat)
)

# Options shared by both loaders. pin_memory is not enabled here.
_loader_opts = dict(batch_size=args.batch_size, num_workers=args.num_workers)

train_loader = torch.utils.data.DataLoader(trainset, shuffle=True, **_loader_opts)
# Validation order is left deterministic (DataLoader's default shuffle=False).
val_loader = torch.utils.data.DataLoader(valset, **_loader_opts)
model = NewMDGenWrapper(args)

# Prefer an explicit GPU accelerator when CUDA is present; otherwise let
# Lightning decide.
accelerator = "gpu" if torch.cuda.is_available() else "auto"

# All run artifacts (Lightning root dir and checkpoints) go under MODEL_DIR.
run_dir = os.environ["MODEL_DIR"]

# args.no_validate disables validation entirely; otherwise cap it at
# args.val_batches (falsy -> the full validation set).
val_batch_limit = 0.0 if args.no_validate else (args.val_batches or 1.0)

# Keep every checkpoint (save_top_k=-1), written every args.ckpt_freq epochs,
# and print a two-level module summary at startup.
run_callbacks = [
    ModelCheckpoint(
        dirpath=run_dir,
        save_top_k=-1,
        every_n_epochs=args.ckpt_freq,
    ),
    ModelSummary(max_depth=2),
]

trainer = pl.Trainer(
    accelerator=accelerator,
    max_epochs=args.epochs,
    limit_train_batches=args.train_batches or 1.0,
    limit_val_batches=val_batch_limit,
    num_sanity_val_steps=0,
    precision=args.precision,
    # The progress bar is hidden when metrics are streamed to wandb.
    enable_progress_bar=not args.wandb,
    gradient_clip_val=args.grad_clip,
    default_root_dir=run_dir,
    callbacks=run_callbacks,
    accumulate_grad_batches=args.accumulate_grad,
    val_check_interval=args.val_freq,
    check_val_every_n_epoch=args.val_epoch_freq,
    logger=False,
)

# -------------------- Check for existing checkpoints --------------------
# Decide which checkpoint to resume from: an explicit args.ckpt always wins;
# otherwise fall back to the most recently modified "*.ckpt" entry in
# MODEL_DIR, and train from scratch if there is none.
ckpt_path = args.ckpt
if ckpt_path is None:
    model_dir = os.environ["MODEL_DIR"]
    # On a fresh run MODEL_DIR may not exist yet; os.listdir would raise
    # FileNotFoundError, so a missing directory means "no checkpoints".
    if os.path.isdir(model_dir):
        ckpt_files = [
            os.path.join(model_dir, f)
            for f in os.listdir(model_dir)
            if f.endswith(".ckpt")
        ]
    else:
        ckpt_files = []
    if ckpt_files:
        # Rank by modification time. On POSIX getctime is the inode-change
        # time (not creation time) and is reset when files are copied/synced
        # (e.g. cp -p / rsync -a preserve mtime but not ctime), which could
        # make an older checkpoint look like the latest one.
        latest_ckpt = max(ckpt_files, key=os.path.getmtime)
        logger.info(f"Found an existing checkpoint: {latest_ckpt}. Resuming from it.")
        ckpt_path = latest_ckpt
    else:
        logger.info("No existing checkpoint found. Training from scratch.")
else:
    logger.info(f"Using user-supplied checkpoint: {ckpt_path}")

# -------------------- Train or Validate --------------------
# Default path: fit, resuming from ckpt_path when one was supplied or found.
# With args.validate set, run a validation pass only (weights from ckpt_path).
if not args.validate:
    trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path)
else:
    trainer.validate(model, val_loader, ckpt_path=ckpt_path)
