# trainer.py
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from typing import Tuple, Optional, List, Dict, Any

from model import BaseMLP, IrrepsMask, Scalarization


class Regressor(pl.LightningModule):
    """
    Two-stage regression:
      - mode='pretrain': train encoder + head (MAE/L1 in normalized space).
      - mode='finetune': load and freeze encoder, train a fresh head.
    Metrics (MAE/RMSE) are reported in original units.
    SwanLab: optional experiment logger; only logs on global rank 0.

    Forward pipeline: ``encoder(batch)`` -> ``IrrepsMask`` (finetune only)
    -> ``Scalarization`` -> ``BaseMLP`` head -> ``squeeze(-1)``.

    Args:
        encoder: module called as ``encoder(batch)``; must expose
            ``hidden_irreps`` (used to build the mask and scalarization).
        input_dim: not read by this class (the head input size is taken from
            ``Scalarization.output_dim``); kept for interface compatibility
            and recorded in ``hparams``.
        y_mean: target mean used for normalization (stored as a buffer).
        y_std: target std used for normalization (stored as a buffer).
        label_idx: column of ``batch.y`` that is regressed.
        lr: learning rate passed to AdamW.
        mode: ``"pretrain"`` or ``"finetune"``.
        mask_list: forwarded to ``IrrepsMask``; ``None`` means ``[]``.
        encoder_ckpt: checkpoint to load the encoder from; required in
            finetune mode, not read in pretrain mode.
        head_ckpt: optional checkpoint to initialise the head from.
        hidden_dim: hidden width forwarded to ``BaseMLP``.
        swanlab_enabled: turn SwanLab logging on/off.
        swanlab_project: SwanLab project name.
        swanlab_run_name: SwanLab experiment name.
        swanlab_config_extra: extra entries merged into the SwanLab config.

    Raises:
        ValueError: if ``mode`` is unknown, or if ``mode == "finetune"`` and
            ``encoder_ckpt`` is ``None``.
    """
    def __init__(
        self,
        encoder: nn.Module,
        input_dim: int,
        y_mean: float,
        y_std: float,
        label_idx: int,
        lr: float = 1e-3,
        mode: str = "pretrain",
        mask_list: Optional[List] = None,
        encoder_ckpt: Optional[str] = None,
        head_ckpt: Optional[str] = None,
        hidden_dim: int = 64,
        # ---- SwanLab options ----
        swanlab_enabled: bool = True,
        swanlab_project: Optional[str] = "QM9",
        swanlab_run_name: Optional[str] = None,
        swanlab_config_extra: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        # Explicit check instead of `assert`: asserts are stripped under `python -O`.
        if mode not in ("pretrain", "finetune"):
            raise ValueError(f"`mode` must be 'pretrain' or 'finetune', got {mode!r}.")
        # Must be called directly from __init__ so the init arguments are captured.
        self.save_hyperparameters(ignore=["encoder"])

        self.encoder = encoder
        self.mode = mode
        self.lr = lr

        # mask / scalarization
        mask_list = [] if mask_list is None else mask_list
        self.mask = IrrepsMask(encoder.hidden_irreps, mask_list)
        self.scalarization = Scalarization(encoder.hidden_irreps)

        # head
        self.head = BaseMLP(input_dim=self.scalarization.output_dim, hidden_dim=hidden_dim, output_dim=1)
        if head_ckpt is not None:
            self._load_submodule(self.head, head_ckpt, "head")

        # normalization and target selection
        self.label_idx = int(label_idx)
        self.register_buffer("y_mean", torch.tensor([y_mean], dtype=torch.float32))
        self.register_buffer("y_std", torch.tensor([y_std], dtype=torch.float32))

        # runtime flag: are we freezing the encoder?
        self._freeze_encoder = (mode == "finetune")

        # SwanLab related
        self._swanlab_enabled = bool(swanlab_enabled)
        self._swanlab_project = swanlab_project
        self._swanlab_run_name = swanlab_run_name
        self._swanlab_config_extra = swanlab_config_extra or {}
        self._swanlab = None  # lazy import / init at on_fit_start

        # finetune: load and freeze encoder
        if self._freeze_encoder:
            if encoder_ckpt is None:
                raise ValueError("`encoder_ckpt` must be provided for finetune mode.")
            self._load_submodule(self.encoder, encoder_ckpt, "encoder")

            # hard freeze: no grads
            for p in self.encoder.parameters():
                p.requires_grad = False

            # also stop BN/Dropout updates/stochasticity
            self.encoder.eval()

    # ---------------- Checkpoint helpers ----------------
    def _print_safe(self, *args, **kwargs):
        """Print via Lightning when a trainer is attached, else via builtin print.

        ``LightningModule.print`` dereferences ``self.trainer``, which is not
        available inside ``__init__``; without this fallback a checkpoint-key
        warning would crash the constructor instead of being shown.
        """
        try:
            self.print(*args, **kwargs)
        except (RuntimeError, AttributeError):
            print(*args, **kwargs)

    @staticmethod
    def _extract_state_dict(ckpt_path: str, name: str):
        """Read ``ckpt_path`` and return the state dict belonging to submodule ``name``.

        Accepted layouts, tried in this order:
          1. dict with ``"state_dict"``: keys starting with ``"<name>."`` are
             kept with the prefix stripped; if none match and the checkpoint
             has a top-level ``"<name>"`` entry, that entry is used instead.
          2. dict with a top-level ``"<name>"`` entry.
          3. anything else is assumed to be a raw state dict of the submodule.
        """
        # NOTE: torch.load may unpickle arbitrary objects -- only point this
        # at trusted checkpoint files.
        state = torch.load(ckpt_path, map_location="cpu")
        if isinstance(state, dict) and "state_dict" in state:
            prefix = f"{name}."
            sub_sd = {
                k[len(prefix):]: v
                for k, v in state["state_dict"].items()
                if k.startswith(prefix)
            }
            if not sub_sd and name in state:
                sub_sd = state[name]
            return sub_sd
        if isinstance(state, dict) and name in state:
            return state[name]
        return state

    def _load_submodule(self, module: nn.Module, ckpt_path: str, name: str) -> None:
        """Load ``ckpt_path`` into ``module`` non-strictly and report key mismatches."""
        sub_sd = self._extract_state_dict(ckpt_path, name)
        missing, unexpected = module.load_state_dict(sub_sd, strict=False)
        if unexpected:
            self._print_safe(f"[warn] unexpected keys in {name} ckpt: {unexpected}")
        if missing:
            self._print_safe(f"[warn] missing keys when loading {name} ckpt: {missing}")

    # ---------------- SwanLab helpers ----------------
    def _is_rank0(self) -> bool:
        """Return True on global rank 0, or when no trainer is attached yet."""
        try:
            return self.trainer is None or self.trainer.is_global_zero
        except Exception:
            # fallback to assume rank 0 if no trainer yet
            return True

    def _swanlab_try_init(self):
        """Initialize SwanLab on rank 0 only; keep no-op on others or if not installed."""
        if not self._swanlab_enabled or self._swanlab is not None or not self._is_rank0():
            return
        try:
            import swanlab
            self._swanlab = swanlab
            # build config from hyperparameters
            cfg = dict(self.hparams)
            cfg.update(self._swanlab_config_extra)
            self._swanlab.init(
                project=self._swanlab_project,
                experiment_name=self._swanlab_run_name,
                config=cfg,
            )
            self.print("[SwanLab] initialized (rank 0).")
        except Exception as e:
            # Best effort: training must continue without the logger.
            self.print(f"[SwanLab] init failed or not installed: {e}")
            self._swanlab = None

    def _swanlab_log(self, metrics: Dict[str, float], step: Optional[int] = None):
        """Send a flat metrics dict to SwanLab; no-op when inactive or not rank 0."""
        if self._swanlab is None or not self._is_rank0():
            return
        try:
            self._swanlab.log(metrics, step=step)
        except Exception as e:
            self.print(f"[SwanLab] log failed: {e}")

    def _swanlab_close(self):
        """Finish the SwanLab run (rank 0 only) and drop the handle."""
        if self._swanlab is None or not self._is_rank0():
            return
        try:
            self._swanlab.finish()
            self.print("[SwanLab] finished.")
        except Exception as e:
            self.print(f"[SwanLab] finish failed: {e}")
        finally:
            # Reset so later steps (e.g. a test run after fit) do not log to a
            # finished run, and so a subsequent fit can initialise again.
            self._swanlab = None

    # ---------------- PL lifecycle ----------------
    def train(self, mode: bool = True):
        """Switch train/eval mode while keeping a frozen encoder in eval mode.

        ``nn.Module.train`` recurses into every child, which would otherwise
        re-enable BN/Dropout updates inside the frozen encoder whenever the
        module is put back into training mode.
        """
        super().train(mode)
        if getattr(self, "_freeze_encoder", False):
            self.encoder.eval()
        return self

    def on_fit_start(self):
        """Set the encoder mode, report the trainable parameter count, start SwanLab."""
        # Safety: ensure eval mode for encoder in finetune; train mode in pretrain.
        if self._freeze_encoder:
            self.encoder.eval()
        else:
            self.encoder.train()

        # Log trainable parameter count to verify freezing
        trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.parameters())
        self.print(f"[Regressor] trainable params: {trainable} / {total}")

        # Init SwanLab (rank 0 only)
        self._swanlab_try_init()

    def on_fit_end(self):
        """Close the SwanLab run when fitting ends."""
        self._swanlab_close()

    def on_save_checkpoint(self, checkpoint):
        """Also store the encoder weights under a top-level ``"encoder"`` key."""
        checkpoint["encoder"] = self.encoder.state_dict()

    # ---------------- Forward ----------------
    def forward(self, batch):
        """Return the prediction in normalized space (last dim squeezed)."""
        # Encoder forward: block grad & graph in finetune to save memory
        if self._freeze_encoder:
            with torch.no_grad():
                graph_feat = self.encoder(batch)
        else:
            graph_feat = self.encoder(batch)

        # The irreps mask is applied in finetune mode only.
        if self.mode == "finetune":
            graph_feat = self.mask(graph_feat)

        graph_feat = self.scalarization(graph_feat)

        pred_norm = self.head(graph_feat).squeeze(-1)
        return pred_norm

    @torch.no_grad()
    def _denorm(self, y_norm: torch.Tensor) -> torch.Tensor:
        """Map normalized values back to original units (no gradient tracking)."""
        return y_norm * (self.y_std + 1e-8) + self.y_mean

    # ---------------- Train / Val / Test ----------------
    def _step_and_log(self, batch, stage: str, pl_loss_key: str) -> torch.Tensor:
        """Shared train/val/test step.

        Computes the L1 loss in normalized space, MAE/RMSE in original units,
        logs them through Lightning and (when active) SwanLab.

        Args:
            batch: object exposing ``y`` indexable as ``y[:, label_idx]``.
            stage: ``"train"``, ``"val"`` or ``"test"``; prefix of metric keys.
            pl_loss_key: Lightning key for the loss (``"val_loss"`` for
                validation so checkpoint monitors keep working).

        Returns:
            The normalized-space L1 loss.
        """
        # target in original units
        y = batch.y[:, self.label_idx].view(-1)
        # normalize target
        y_norm = (y - self.y_mean) / (self.y_std + 1e-8)
        # predict normalized
        pred_norm = self(batch)
        loss = F.l1_loss(pred_norm, y_norm, reduction="mean")  # MAE in normalized space

        # metrics in original units
        pred = self._denorm(pred_norm)
        err = pred - y
        mae = err.abs().mean()
        rmse = torch.sqrt((err ** 2).mean())
        bs = y.size(0)

        # PL logs (sync to avoid multi-process drift; but only one progress bar shows)
        self.log(pl_loss_key, loss, on_epoch=True, prog_bar=True, batch_size=bs, sync_dist=True)
        self.log(f"{stage}/mae", mae, on_epoch=True, prog_bar=True, batch_size=bs, sync_dist=True)
        self.log(f"{stage}/rmse", rmse, on_epoch=True, batch_size=bs, sync_dist=True)

        # SwanLab: per-step logging on rank 0 only. The float() conversions
        # force a host sync, so skip them entirely when SwanLab is inactive.
        # NOTE(review): global_step does not advance during validation/test,
        # so all batches of one such run log at the same step -- confirm
        # SwanLab handles repeated steps as intended.
        if self._swanlab is not None:
            self._swanlab_log(
                {f"{stage}/loss": float(loss.detach()),
                 f"{stage}/mae": float(mae.detach()),
                 f"{stage}/rmse": float(rmse.detach())},
                step=int(self.global_step),
            )
        return loss

    def training_step(self, batch, batch_idx):
        """Training step; logs ``train/loss``, ``train/mae``, ``train/rmse``."""
        return self._step_and_log(batch, "train", "train/loss")

    def validation_step(self, batch, batch_idx):
        """Validation step; logs ``val_loss``, ``val/mae``, ``val/rmse``."""
        return self._step_and_log(batch, "val", "val_loss")

    def test_step(self, batch, batch_idx):
        """Test step; logs ``test/loss``, ``test/mae``, ``test/rmse``."""
        return self._step_and_log(batch, "test", "test/loss")

    def on_train_epoch_end(self):
        """Mirror float-convertible callback metrics to SwanLab as ``epoch/<key>``."""
        metrics = {}
        for k, v in self.trainer.callback_metrics.items():
            # take simple float-able ones; skip anything that cannot be converted
            try:
                metrics[f"epoch/{k}"] = float(v)
            except (TypeError, ValueError, RuntimeError):
                continue
        if metrics:
            self._swanlab_log(metrics, step=int(self.current_epoch))

    # ---------------- Optimizer ----------------
    def configure_optimizers(self):
        """Build AdamW over the head (+ trainable readout params) or all trainable params."""
        if self._freeze_encoder:
            # In finetune, only train head (and any other trainable readout params).
            params = list(self.head.parameters())
            # If IrrepsMask / Scalarization have trainable params, include them.
            for readout in (self.mask, self.scalarization):
                get_params = getattr(readout, "parameters", None)
                if callable(get_params):
                    params.extend(p for p in get_params() if p.requires_grad)
        else:
            params = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.AdamW(params, lr=self.lr)

@torch.no_grad()
def compute_label_stats(train_loader, label_idx: int) -> Tuple[float, float]:
    """Return ``(mean, std)`` of column ``label_idx`` of ``batch.y`` over a loader.

    Every batch's selected column is moved to CPU as float32 and all of them
    are concatenated before the statistics are taken. The std is the
    population std (``unbiased=False``) plus ``1e-8``.
    """
    columns = [
        batch.y[:, label_idx].cpu().float()
        for batch in train_loader
    ]
    targets = torch.cat(columns, dim=0)
    center = float(targets.mean())
    spread = float(targets.std(unbiased=False) + 1e-8)
    return center, spread
