# pretrain.py
# -*- coding: utf-8 -*-
import os
import argparse
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar

from dataset import get_qm9_dataloaders
from model import TFNModel, HEGNNModel
from trainer import Regressor, compute_label_stats


def _parse_args(argv=None):
    """Parse the command-line options of the pretraining script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Pretrain encoder + head on QM9 (train split only)"
    )
    # Restrict to the encoders main() knows how to build, so a typo produces a
    # usage error here instead of a bare KeyError later on.
    # Keep this tuple in sync with the encoder table in main().
    parser.add_argument("--encoder", type=str, default="HEGNN", choices=("HEGNN", "TFN"))
    parser.add_argument("--root", type=str, default="./data/QM9")
    parser.add_argument("--outdir", type=str, default="./ckpts_pre")
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=300)
    parser.add_argument("--lr", type=float, default=1e-3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--label_idx", type=int, default=1, help="QM9 target index")
    parser.add_argument("--hidden_dim", type=int, default=64, help="Hidden dimension of MLP head")
    parser.add_argument("--no_swanlab", action="store_true", help="Disable SwanLab logging")
    parser.add_argument("--swanlab_project", type=str, default="QM9", help="SwanLab project name")
    parser.add_argument("--swanlab_run_name", type=str, default=None, help="SwanLab run name")
    return parser.parse_args(argv)


def main():
    """Pretrain an encoder plus regression head on the QM9 training split.

    Steps: parse CLI options, seed everything, build the training dataloader,
    compute label statistics on it, construct the selected encoder wrapped in
    a ``Regressor`` in ``"pretrain"`` mode, fit with a Lightning ``Trainer``,
    then write two artifacts into ``--outdir``:

    * ``encoder_pretrained.pt`` -- the encoder's ``state_dict``.
    * ``pretrain_full.ckpt``   -- a full Lightning checkpoint.
    """
    args = _parse_args()

    os.makedirs(args.outdir, exist_ok=True)

    pl.seed_everything(args.seed, workers=True)

    # Only the first of the three returned loaders (the training one) is used.
    train_loader, _, _ = get_qm9_dataloaders(
        root=args.root,
        batch_size=args.batch_size,
        seed=args.seed,
        num_workers=args.num_workers,
        pin_memory=True,
    )

    y_mean, y_std = compute_label_stats(train_loader, label_idx=args.label_idx)

    # Keep in sync with the ``choices`` of --encoder in _parse_args().
    encoder_classes = {
        'HEGNN': HEGNNModel,
        'TFN': TFNModel,
    }
    encoder = encoder_classes[args.encoder]()
    model = Regressor(
        encoder=encoder,
        input_dim=encoder.output_dim,
        y_mean=y_mean,
        y_std=y_std,
        label_idx=args.label_idx,
        lr=args.lr,
        mode="pretrain",
        hidden_dim=args.hidden_dim,
        swanlab_enabled=(not args.no_swanlab),
        swanlab_project=args.swanlab_project,
        swanlab_run_name=args.swanlab_run_name,
        swanlab_config_extra={
            "phase": "pretrain",
            "dataset": "QM9",
        },
    )

    trainer = pl.Trainer(
        max_epochs=args.epochs,
        accelerator="auto",
        devices="auto",
        strategy="ddp_find_unused_parameters_true",
        callbacks=[RichProgressBar()],
        log_every_n_steps=20,
        gradient_clip_val=1.0,
        gradient_clip_algorithm="norm",
        enable_checkpointing=False,
    )

    trainer.fit(model, train_dataloaders=train_loader)

    # Plain torch.save is not rank-aware, so restrict it to the global-zero
    # process to avoid several ranks writing the same file.
    if trainer.is_global_zero:
        enc_path = os.path.join(args.outdir, "encoder_pretrained.pt")
        torch.save(model.encoder.state_dict(), enc_path)
        print(f"[Pretrain] encoder saved to: {enc_path}")

    # Save the full checkpoint for resuming. Trainer.save_checkpoint must be
    # called on EVERY rank: per the Lightning docs it has to run on all
    # processes and it synchronizes them with a barrier, while the strategy
    # itself limits the actual file write to rank zero. Calling it only under
    # ``is_global_zero`` can leave rank zero waiting on that barrier forever
    # in a multi-device DDP run.
    ckpt_path = os.path.join(args.outdir, "pretrain_full.ckpt")
    trainer.save_checkpoint(ckpt_path)
    if trainer.is_global_zero:
        print(f"[Pretrain] full checkpoint saved to: {ckpt_path}")

# Run the pretraining entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
