# finetune.py
# -*- coding: utf-8 -*-
import os
import argparse
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import RichProgressBar, ModelCheckpoint

from dataset import get_qm9_dataloaders
from model import TFNModel, HEGNNModel
from trainer import Regressor, compute_label_stats


def finetune_head(args, train_loader, val_loader, y_mean, y_std, mask_type, mask_list=None):
    encoder = {
        "HEGNN": HEGNNModel,
        "TFN": TFNModel,
    }[args.encoder]()

    model = Regressor(
        encoder=encoder,
        input_dim=encoder.output_dim,
        y_mean=y_mean,
        y_std=y_std,
        label_idx=args.label_idx,
        lr=args.lr,
        mode="finetune",
        mask_list=mask_list,
        encoder_ckpt=args.encoder_ckpt,
        hidden_dim=args.hidden_dim,
        swanlab_enabled=(not args.no_swanlab),
        swanlab_project=args.swanlab_project,
        swanlab_run_name=args.swanlab_run_name,
        swanlab_config_extra={
            "phase": "finetune",
            "dataset": "QM9",
            "mask_type": mask_type,
            "mask_list": mask_list,
            "encoder": args.encoder,
        },
    )

    if mask_type is None:
        save_dir = args.outdir
    else:
        save_dir = os.path.join(args.outdir, f"mask_{mask_type}")
    os.makedirs(save_dir, exist_ok=True)

    ckpt_callback = ModelCheckpoint(
        dirpath=save_dir,
        filename="ft-best-{epoch:02d}-{val_loss:.4f}",
        monitor="val_loss",
        mode="min",
        save_top_k=1,
        save_last=True,
    )

    trainer = pl.Trainer(
        max_epochs=args.epochs,
        accelerator="auto",
        devices="auto",
        strategy="ddp_find_unused_parameters_true",  # 和 pretrain 对齐，防止 DDP unused params 报错
        callbacks=[RichProgressBar(), ckpt_callback],
        log_every_n_steps=20,
        gradient_clip_val=1.0,
        gradient_clip_algorithm="norm",
        enable_checkpointing=True,
    )

    trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

    print(f"[Finetune][mask={mask_type}] Training complete.")
    print(f"[Finetune][mask={mask_type}] Best checkpoint: {ckpt_callback.best_model_path}")
    print(f"[Finetune][mask={mask_type}] Best val_loss: {ckpt_callback.best_model_score}")

    best_ckpt_path = ckpt_callback.best_model_path
    if best_ckpt_path and os.path.isfile(best_ckpt_path):
        ckpt = torch.load(best_ckpt_path, map_location="cpu")
        state_dict = ckpt["state_dict"]

        head_state_dict = {
            k.replace("head.", ""): v
            for k, v in state_dict.items()
            if k.startswith("head.")
        }

        head_save_path = os.path.join(
            save_dir, f"head-best-mask_0_{mask_type}.pt"
        )

        torch.save(head_state_dict, head_save_path)
        print(f"[Finetune][mask={mask_type}] Saved head weights to: {head_save_path}")
    else:
        print(f"[Finetune][mask={mask_type}] WARNING: best checkpoint not found; head not saved.")


def main():
    """Command-line entry point: fine-tune one regression head per irrep mask.

    Parses the CLI, seeds all RNGs, builds the QM9 train/validation loaders,
    computes label statistics on the training loader, and then calls
    ``finetune_head`` once for each ``l_max`` in ``0..11``. The mask for a given
    ``l_max`` is the cumulative list of irrep labels ``"0e", "1o", "2e", ...``
    up to and including degree ``l_max`` (suffix ``e`` for even degrees, ``o``
    for odd ones).
    """
    parser = argparse.ArgumentParser(
        description="Finetune regression head with a frozen encoder (train + validation split)"
    )
    add_arg = parser.add_argument

    # Model, data and checkpoint locations.
    add_arg("--encoder", type=str, default="HEGNN", choices=["HEGNN", "TFN"])
    add_arg("--root", type=str, default="./data/QM9")
    add_arg("--outdir", type=str, default="./ckpts_ft")
    add_arg("--encoder_ckpt", type=str, required=True)

    # Training hyper-parameters.
    add_arg("--batch_size", type=int, default=256)
    add_arg("--epochs", type=int, default=30)
    add_arg("--lr", type=float, default=1e-3)
    add_arg("--seed", type=int, default=42)
    add_arg("--num_workers", type=int, default=4)
    add_arg("--label_idx", type=int, default=1)
    add_arg("--hidden_dim", type=int, default=64)

    # Experiment tracking.
    add_arg("--no_swanlab", action="store_true", help="Disable SwanLab logging")
    add_arg("--swanlab_project", type=str, default="QM9", help="SwanLab project name")
    add_arg("--swanlab_run_name", type=str, default=None, help="SwanLab run name")

    args = parser.parse_args()

    os.makedirs(args.outdir, exist_ok=True)
    pl.seed_everything(args.seed, workers=True)

    # The third returned loader (presumably the test split) is unused here.
    train_loader, val_loader, _unused_loader = get_qm9_dataloaders(
        root=args.root,
        batch_size=args.batch_size,
        seed=args.seed,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    y_mean, y_std = compute_label_stats(train_loader, label_idx=args.label_idx)

    num_degrees = 12
    irrep_labels = []
    for degree in range(num_degrees):
        parity = "o" if degree % 2 == 1 else "e"
        irrep_labels.append(f"{degree}{parity}")

    # Cumulative masks: l_max -> ["0e", "1o", ..., label of degree l_max].
    # (A single-degree variant would map l_max -> [irrep_labels[l_max]].)
    masks_by_lmax = {
        l_max: irrep_labels[: l_max + 1] for l_max in range(num_degrees)
    }

    for l_max, labels in masks_by_lmax.items():
        finetune_head(
            args=args,
            train_loader=train_loader,
            val_loader=val_loader,
            y_mean=y_mean,
            y_std=y_std,
            mask_type=l_max,
            mask_list=labels,
        )


# Run the fine-tuning CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
