"""base.py — reusable scaffolding for tabular generators / imputers
================================================================
This module defines :class:`Base`, a thin wrapper around
:class:`lightning.pytorch.LightningModule` that removes boilerplate while
ensuring fully reproducible initialisation.

Highlights
----------
* **Automatic seed locking** – every subclass constructor runs inside
  :class:`SeedContext`; global RNG is restored afterwards.
* **Zero‑boilerplate super()** – subclasses call **only** ``super().__init__()``;
  the local variables of the calling constructor are captured and stored as
  hyperparameters. ``cfg`` is kept on the instance but left out of the stored
  hyperparameters, because it can be large and would bloat checkpoints.
* **Centralised data / trainer plumbing** – common PyTorch‑Lightning chores
  handled once here.

Example
-------
```python
class MyAutoEncoder(Base):
    def __init__(self, latent=32):
        super().__init__()
        self.encoder = nn.Linear(self.column_dim, latent)
        self.decoder = nn.Linear(latent, self.column_dim)
```
"""
from __future__ import annotations

import functools
import inspect
import os
from abc import ABC
from typing import Any, Callable, Dict, Optional

import lightning.pytorch as pl
import pandas as pd
import torch
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.model_summary import ModelSummary
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader, TensorDataset

import hashlib
import json
import stat
from pathlib import Path

try:
    from safetensors.torch import save_file as _st_save, load_file as _st_load
    from safetensors.torch import load_model, save_model

    _HAVE_SAFETENSORS = True
except Exception:
    _HAVE_SAFETENSORS = False

from ..transform.transform import TabularTransformSet
from ..utils import SeedContext, train_test_split
from ..utils import rank_zero_print
from ..utils.io import (
    maybe_load,
    setup_logger,
    create_latest_symlink,
    save_model_no_lightning,
)

# -----------------------------------------------------------------------------
# Helpers & defaults
# -----------------------------------------------------------------------------
# Default "scenario" hook: returns the DataFrame it is given, unchanged.
_IDENTITY: Callable[[pd.DataFrame], pd.DataFrame] = lambda df: df  # noqa: E731
# Baseline model flags. ``Base.__init__`` overlays caller-supplied
# ``model_flags`` on top of this mapping. Values are mixed-type ("scaler" is a
# string), hence ``Any`` rather than ``bool``.
_DEFAULT_FLAGS: Dict[str, Any] = {
    "onehot": True,  # forwarded to the TabularTransformSet constructor
    "scaler": "standard",  # forwarded to the TabularTransformSet constructor
    "in_sample_only": False,  # not read in this file; presumably used by subclasses
    "allow_missing_on_dataset": False,  # False -> training rows containing NaN are dropped
    "require_validation_split": False,  # True -> validation data is split off the training frame
    "drop_target": False,  # True -> use the transform's ``no_target`` variant
}
# Snapshot of the global torch determinism/cuDNN settings taken at import time;
# RestoreTorchDefaults writes these values back when fitting ends.
_DEFAULTS = dict(
    cudnn_benchmark=torch.backends.cudnn.benchmark,
    cudnn_deterministic=torch.backends.cudnn.deterministic,
    algo_state=torch.are_deterministic_algorithms_enabled(),
)
# Stored under "version" in model_hparams.json by ``save_weights_secure``.
_HP_FMT_VERSION = 1


class RestoreTorchDefaults(Callback):
    """Callback that resets global torch backend settings once fitting ends.

    The values written back are the ones recorded in ``_DEFAULTS`` when this
    module was imported.
    """

    def on_fit_end(self, trainer, pl_module):
        """Restore cuDNN/determinism settings and drop the cuBLAS workspace variable."""
        cudnn = torch.backends.cudnn
        cudnn.benchmark = _DEFAULTS["cudnn_benchmark"]
        cudnn.deterministic = _DEFAULTS["cudnn_deterministic"]
        torch.use_deterministic_algorithms(_DEFAULTS["algo_state"])
        # Remove the variable if it is set; a missing key is not an error.
        os.environ.pop("CUBLAS_WORKSPACE_CONFIG", None)


class IO:

    # --------------------- model-hparam snapshot helpers ---------------------
    @staticmethod
    def _snapshot_cfg_model(cfg: Optional[DictConfig]) -> Dict[str, Any]:
        if cfg is None:
            return {}
        try:
            node = cfg.get("model") if hasattr(cfg, "get") else cfg.model  # type: ignore[attr-defined]
        except Exception:
            node = None
        if node is None:
            return {}
        cont = OmegaConf.to_container(node, resolve=False)
        if not isinstance(cont, dict):
            return {}
        out: Dict[str, Any] = {}
        for k, v in cont.items():
            if isinstance(v, (int, float, bool, str)) or v is None:
                out[k] = v
            elif isinstance(v, (list, tuple)) and len(v) <= 256 and all(
                    isinstance(x, (int, float, bool, str)) or x is None for x in v):
                out[k] = list(v)
        return out

    def _pick_from_hparams(self, keys: list[str]) -> Dict[str, Any]:
        """self.hparams에서 keys에 해당하는 값만 JSON-safe로 추출."""
        basic = (int, float, bool, str, type(None))
        out: Dict[str, Any] = {}
        hp = getattr(self, "hparams", {}) or {}
        if not isinstance(hp, dict):
            return out
        for k in keys:
            if k not in hp:
                continue
            v = hp[k]
            if isinstance(v, basic):
                out[k] = v
            elif isinstance(v, (list, tuple)) and all(isinstance(x, basic) for x in v) and len(v) <= 256:
                out[k] = list(v)
            # 그 외 타입은 보안상/재현성상 스킵
        return out

    @staticmethod
    def _is_dir_like(p: Path) -> bool:
        return (not p.suffix) or p.is_dir()

    @staticmethod
    def _is_regular_file(path: str | os.PathLike) -> None:
        p = Path(path)
        if not p.exists():
            raise FileNotFoundError(f"Checkpoint not found: {p}")
        # Avoid following symlinks to reduce TOCTOU/symlink attacks.
        if p.is_symlink():
            raise ValueError(f"Refusing to load symlinked file: {p}")
        if not p.is_file():
            raise ValueError(f"Not a regular file: {p}")

    def _expected_state_spec(self) -> Dict[str, tuple[torch.Size, torch.dtype]]:
        spec = {}
        for k, v in self.state_dict().items():
            spec[k] = (v.shape, v.dtype)
        return spec

    def _validate_state_dict(
            self,
            loaded: Dict[str, torch.Tensor],
            *,
            max_size_factor: float = 1.05,  # allow tiny overhead
            strict: bool = True,
    ) -> None:
        exp = self._expected_state_spec()
        # Key set check
        if strict:
            missing = sorted(set(exp.keys()) - set(loaded.keys()))
            unexpected = sorted(set(loaded.keys()) - set(exp.keys()))
            if missing or unexpected:
                raise ValueError(
                    f"State dict key mismatch. Missing: {missing[:5]}{'...' if len(missing) > 5 else ''} "
                    f"Unexpected: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}"
                )

        # Shape/dtype check + total bytes bound
        def _nbytes(sd: Dict[str, torch.Tensor]) -> int:
            return sum(int(t.numel()) * int(t.element_size()) for t in sd.values())

        exp_bytes = _nbytes(self.state_dict())
        got_bytes = _nbytes(loaded)
        if got_bytes > int(exp_bytes * max_size_factor):
            raise ValueError(
                f"Loaded tensors ({got_bytes / 1e6:.1f} MB) exceed expected size "
                f"({exp_bytes / 1e6:.1f} MB) by factor>{max_size_factor}."
            )

        # Per-param validation (only for intersecting keys to give nice errors)
        for k, t in loaded.items():
            if k not in exp:
                continue
            exp_shape, exp_dtype = exp[k]
            if t.shape != exp_shape:
                raise ValueError(f"Tensor shape mismatch for '{k}': got {tuple(t.shape)}, expected {tuple(exp_shape)}")
            if t.dtype != exp_dtype:
                raise ValueError(f"Dtype mismatch for '{k}': got {t.dtype}, expected {exp_dtype}")

        # Optional: NaN/Inf sanity (defensive)
        for k, t in loaded.items():
            if not torch.is_floating_point(t):
                continue
            if torch.isnan(t).any() or torch.isinf(t).any():
                raise ValueError(f"Non-finite values detected in '{k}'")

    @staticmethod
    def _sha256(path: str | os.PathLike, chunk: int = 1024 * 1024) -> str:
        h = hashlib.sha256()
        with open(path, "rb") as f:
            while True:
                b = f.read(chunk)
                if not b:
                    break
                h.update(b)
        return h.hexdigest()

    # --------------------- SAVE: weights + transform + hparams ---------------
    def save_weights_secure(
            self,
            path: str | os.PathLike,
            *,
            prefer_safetensors: bool | None = None,
            overwrite: bool = True,
            metadata: Optional[Dict[str, Any]] = None,
            hparams_allowlist: Optional[list[str]] = None,
    ) -> str:
        """
        디렉터리(권장) 또는 파일 경로를 받아
          - weights.(safetensors|pt) (+ .json sidecar sha256)
          - transform/  (TabularTransformSet JSON)
          - model_hparams.json  (self._cfg.model 스냅샷)
        를 저장합니다.
        """
        prefer_safetensors = (_HAVE_SAFETENSORS if prefer_safetensors is None else prefer_safetensors)
        target = Path(path)
        if self._cfg is None:
            rank_zero_print("[save_weights_secure] warning: self._cfg is None; model hparams snapshot may be empty.")

        # 목적지 디렉터리/파일 결정
        if self._is_dir_like(target):
            d = target
            d.mkdir(parents=True, exist_ok=True)
            weights_path = d / ("weights.safetensors" if prefer_safetensors else "weights.pt")
        else:
            d = target.parent
            d.mkdir(parents=True, exist_ok=True)
            weights_path = target

        if weights_path.exists() and not overwrite:
            raise FileExistsError(f"Target exists: {weights_path}")

        # 1) weights 저장 (CPU, data-only)
        cpu_sd = {k: v.detach().cpu() for k, v in self.state_dict().items()}
        if prefer_safetensors and weights_path.suffix == ".safetensors":
            save_model(self, str(weights_path))
        else:
            torch.save(cpu_sd, str(weights_path))
        try:
            os.chmod(weights_path, stat.S_IRUSR | stat.S_IWUSR)
        except Exception:
            pass

        # sidecar meta (sha256 + spec)
        sha = self._sha256(weights_path)
        sidecar = weights_path.with_suffix(weights_path.suffix + ".json")
        meta = {
            "format": "safetensors" if weights_path.suffix == ".safetensors" else "torch.pt",
            "sha256": sha,
            "model_class": self.__class__.__name__,
            "torch_version": torch.__version__,
            "pl_version": getattr(pl, "__version__", None),
            "state_spec": {k: {"shape": tuple(t.shape), "dtype": str(t.dtype)} for k, t in cpu_sd.items()},
        }
        if metadata:
            meta["user_metadata"] = metadata
        with open(sidecar, "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2)
        try:
            os.chmod(sidecar, stat.S_IRUSR | stat.S_IWUSR)
        except Exception:
            pass

        # 2) transform 저장
        tdir = d / "transform"
        tdir.mkdir(exist_ok=True)
        self._transform.save_secure(tdir)

        # 3) 하이퍼파라미터 번들 저장
        hp_payload = {
            "version": _HP_FMT_VERSION,
            "cfg_model": Base._snapshot_cfg_model(self._cfg),
            "model_flags": dict(getattr(self, "model_flags", {})),
        }
        if hparams_allowlist:
            hp_payload["extra"] = self._pick_from_hparams(hparams_allowlist)

        hp_file = d / "model_hparams.json"
        with open(hp_file, "w", encoding="utf-8") as f:
            json.dump(hp_payload, f, indent=2)
        try:
            os.chmod(hp_file, stat.S_IRUSR | stat.S_IWUSR)
        except Exception:
            pass

        return str(weights_path)

    # --------------------- LOAD (classmethod): new instance -------------------
    @classmethod
    def load_weights_secure(
            cls,
            path: str | os.PathLike,
            *,
            map_location: str | torch.device = "cpu",
            strict: bool = True,
            verify_hash: bool = True,
            allow_partial: bool = False,
            max_size_factor: float = 1.05,
    ) -> "Base":
        """
        저장된 번들에서
          - model_hparams.json 읽어 하이퍼파라미터 획득
          - transform/ 로드
          - 새 인스턴스 생성: Model(tabular_transform_set=..., **model_hparams)
          - weights.(safetensors|pt) CPU 로드/검증 후 주입
        을 수행합니다.
        """
        p = Path(path)
        if cls._is_dir_like(p):
            d = p
            # weights 탐색 우선순위
            candidates = [
                d / "weights.safetensors",
                d / "weights.pt",
                *d.glob("*.safetensors"),
                *d.glob("*.pt"),
                *d.glob("*.pth"),
            ]
            weights_path = next((c for c in candidates if c.exists()), None)
            if weights_path is None:
                raise FileNotFoundError(f"No weights file found under: {d}")
        else:
            weights_path = p
            d = p.parent

        # 1) hparams 로드
        hp_file = d / "model_hparams.json"
        if not hp_file.exists():
            raise FileNotFoundError(f"Missing model_hparams.json in: {d}")
        with open(hp_file, "r", encoding="utf-8") as f:
            hp_payload = json.load(f)
        if not isinstance(hp_payload, dict):
            raise ValueError("model_hparams.json must be a JSON object.")

        cfg_model = dict(hp_payload.get("cfg_model", {}))
        model_flags = dict(hp_payload.get("model_flags", {}))
        extra = dict(hp_payload.get("extra", {}))

        # 2) transform 로드
        tmeta = d / "transform" / "meta.json"
        if not tmeta.exists():
            raise FileNotFoundError(f"Missing transform/ in bundle: {d}")
        tset = TabularTransformSet.load_secure(d / "transform")
        tset.scaler_key = model_flags['scaler']

        # 3) 새 인스턴스 생성 (요구사항대로 cfg 없이)
        init_kwargs = {**cfg_model, **extra}
        model = cls(tabular_transform_set=tset, model_flags=model_flags, **init_kwargs)

        # 4) weights 로드(안전)
        wp = Path(weights_path)
        cls._is_regular_file(wp)

        # sidecar 검증
        sidecar = wp.with_suffix(wp.suffix + ".json")
        if verify_hash and sidecar.exists():
            with open(sidecar, "r", encoding="utf-8") as f:
                meta = json.load(f)
            expected = meta.get("sha256")
            if expected:
                actual = cls._sha256(wp)
                if actual != expected:
                    raise ValueError(f"SHA256 mismatch for {wp.name}: {actual} != {expected}")

        # CPU로 state_dict 읽기
        if wp.suffix == ".safetensors":
            if not _HAVE_SAFETENSORS:
                raise RuntimeError("Install `safetensors` to load .safetensors")
            loaded_sd = _st_load(str(wp))
        else:
            loaded = torch.load(str(wp), map_location="cpu", weights_only=True)
            if not isinstance(loaded, dict) or not all(isinstance(v, torch.Tensor) for v in loaded.values()):
                raise ValueError("Expected dict[str, Tensor] in weights file.")
            loaded_sd = loaded

        # 검증 + 주입
        if allow_partial:
            current = model.state_dict()
            filtered = {k: t for k, t in loaded_sd.items()
                        if k in current and t.shape == current[k].shape and t.dtype == current[k].dtype}

            def _nbytes(sd):
                return sum(int(x.numel()) * int(x.element_size()) for x in sd.values())

            if _nbytes(filtered) > int(_nbytes(current) * max_size_factor):
                raise ValueError("Partial load exceeds allowed size factor.")
            load_model(model, wp)
            # missing, unexpected = model.load_state_dict(filtered, strict=False)
        else:
            load_model(model, wp)
            # model._validate_state_dict(loaded_sd, max_size_factor=max_size_factor, strict=strict)
            # missing, unexpected = model.load_state_dict(loaded_sd, strict=strict)

        # 디바이스 이동
        if map_location != "cpu":
            model.to(map_location)

        # 로드 결과 로깅(선택)
        # if missing or unexpected:
        #     rank_zero_print(f"[secure-load] missing={missing}, unexpected={unexpected}")

        return model


# -----------------------------------------------------------------------------
# Base class
# -----------------------------------------------------------------------------
class Base(ABC, IO, pl.LightningModule):
    """Abstract LightningModule with automatic seed and kwargs capture."""

    # ---------------------- metaclass wrapper ----------------------
    def __init_subclass__(cls, **kwargs):  # noqa: D401
        super().__init_subclass__(**kwargs)
        orig_init = cls.__init__
        if getattr(orig_init, "_seed_wrapped", False):
            return

        @functools.wraps(orig_init)
        def wrapped_init(self, *args, **kw):
            cfg_candidate = kw.get("cfg") or (args[0] if args and isinstance(args[0], DictConfig) else None)
            seed = getattr(cfg_candidate, "seed", None) if cfg_candidate is not None else None
            with SeedContext(seed):
                orig_init(self, *args, **kw)

        wrapped_init._seed_wrapped = True  # type: ignore[attr-defined]
        cls.__init__ = wrapped_init  # type: ignore[method-assign]

    # ---------------------------- constructor ----------------------------
    def __init__(
            self,
            cfg: Optional[DictConfig] = None,
            model_flags: Optional[Dict[str, Any]] = None,
            tabular_transform_set: Optional[TabularTransformSet] = None,
            **kwargs: Any,
    ) -> None:
        """Set up config, flags, transform and hyperparameters.

        Arguments that are not passed explicitly are looked up in the local
        variables of the calling frame (and in that frame's ``kwargs`` dict,
        if it has one), which is what lets subclasses call a bare
        ``super().__init__()``.

        Args:
            cfg: Experiment configuration; may be ``None``.
            model_flags: Overrides merged on top of ``_DEFAULT_FLAGS``.
            tabular_transform_set: Ready-made transform set; when ``None`` a
                new ``TabularTransformSet`` is built from ``cfg`` and the flags.
            **kwargs: Extra values; merged into the captured locals.

        Raises:
            TypeError: ``cfg`` resolves to something other than a
                ``DictConfig`` or ``None``.
        """
        super().__init__()

        # Capture the locals of whoever called this constructor.
        # NOTE(review): this is presumably the subclass ``__init__`` frame; if
        # Base is instantiated from arbitrary code, that code's locals are
        # captured instead - confirm this is acceptable.
        caller = inspect.currentframe().f_back  # type: ignore[arg-type]
        caller_locals = caller.f_locals if caller else {}
        captured: Dict[str, Any] = {k: v for k, v in caller_locals.items() if k not in {"self", "__class__"}}
        captured.update(kwargs)

        # Resolve cfg precedence: explicit argument, caller local, caller kwargs.
        # NOTE(review): ``kwargs`` can never contain "cfg", "model_flags" or
        # "tabular_transform_set" because those are named parameters, so the
        # ``kwargs.get(...)`` lookups below always return ``None``.
        if cfg is None:
            cfg = kwargs.get("cfg") or caller_locals.get("cfg")
            if cfg is None and isinstance(caller_locals.get("kwargs"), dict):
                cfg = caller_locals["kwargs"].get("cfg")
        if cfg is not None and not isinstance(cfg, DictConfig):
            raise TypeError("cfg must be an OmegaConf DictConfig or None")

        # Resolve model_flags with the same precedence.
        if model_flags is None:
            model_flags = kwargs.get("model_flags") or caller_locals.get("model_flags")
            if model_flags is None and isinstance(caller_locals.get("kwargs"), dict):
                model_flags = caller_locals["kwargs"].get("model_flags")

        # Resolve tabular_transform_set with the same precedence.
        if tabular_transform_set is None:
            tabular_transform_set = kwargs.get("tabular_transform_set") or caller_locals.get("tabular_transform_set")
            if tabular_transform_set is None and isinstance(caller_locals.get("kwargs"), dict):
                tabular_transform_set = caller_locals["kwargs"].get("tabular_transform_set")

        self._cfg: Optional[DictConfig] = cfg
        # Model flags: defaults first, caller overrides on top.
        self.model_flags: Dict[str, Any] = {**_DEFAULT_FLAGS, **(model_flags or {})}

        self._transform: TabularTransformSet = (
            tabular_transform_set
            if tabular_transform_set is not None
            else TabularTransformSet(cfg, self.model_flags["onehot"], self.model_flags["scaler"])
        )

        # Save minimal hparams (exclude cfg).
        captured.pop("cfg", None)
        if isinstance(captured.get("kwargs"), dict):
            # This is the caller's own ``kwargs`` dict, so the pop is visible
            # to the caller as well.
            captured["kwargs"].pop("cfg", None)
        # Hyperparameters are only recorded when a config is available,
        # because ``_target_`` is read from ``cfg.model``.
        if cfg is not None:
            self.save_hyperparameters(
                {**captured, "tabular_transform_set": self._transform, "cfg": None, '_target_': cfg.model._target_,
                 'model_flags': self.model_flags},
                logger=False)

        # Schema attributes
        self._refresh_schema()
        self._was_training = self.training
        self.log_dir: Optional[str] = None
        # Subclass hook; receives the captured values as keyword arguments.
        self._after_init(**captured)

    def _after_init(self, *args, **kwargs):
        pass

    # ------------------------ schema helper ------------------------
    def _refresh_schema(self) -> None:
        """Populate convenience attributes based on TabularTransformSet."""
        self.tabular_transform = (
            self._transform.no_target if self.model_flags["drop_target"] else self._transform.target
        )
        self.n_numerical_columns = self.tabular_transform.numerical_dim
        self.n_categorical_columns = len(self.tabular_transform.n_categories_per_columns)
        self.n_columns = self.n_numerical_columns + self.n_categorical_columns
        self.categorical_dim = self.tabular_transform.categorical_dim
        self.n_categories_per_columns = self.tabular_transform.n_categories_per_columns
        self.n_categorical_dim_per_columns = self.tabular_transform.n_categorical_dim_per_columns
        self.numerical_dim = self.tabular_transform.numerical_dim
        self.column_dim = self.categorical_dim + self.numerical_dim

    # --------------------- utility wrappers ---------------------
    @property
    def name(self) -> str:  # noqa: D401
        return self.__class__.__name__

    @property
    def tgt(self) -> str:
        return self._transform.target_column

    @property
    def tgt_index(self):
        if self.tgt in self.tabular_transform.numerical_columns:
            return self.tabular_transform.numerical_columns.index(self.tgt)
        elif self.tgt in self.tabular_transform.categorical_columns:
            return self.numerical_dim + self.tabular_transform.categorical_columns.index(self.tgt)
        else:
            raise IndexError(self.tgt)

    def save_model_mode(self) -> None:
        self._was_training = self.training

    def load_model_mode(self) -> None:
        if self._was_training:
            self.train()
        else:
            self.eval()

    def dataframe_to_dataset(self, df: pd.DataFrame) -> DataSet:
        return DataSet(self.tabular_transform.transform(df, return_as_tensor=True))

    def dataset_to_dataloader(self, dataset: DataSet, **kw) -> DataLoader:
        batch_size = self._cfg.model.batch_size if self._cfg else 32
        return DataLoader(dataset, batch_size=batch_size, **kw)

    # --------------------- dataloader builder ---------------------
    def prepare_dataloader(self, scenario: Callable[[pd.DataFrame], pd.DataFrame] = _IDENTITY) -> dict[str, DataLoader]:
        if self._cfg is None:
            return {}
        cfg = self._cfg
        loaders: dict[str, DataLoader] = {}

        train_df = scenario(pd.read_csv(cfg.dataset.train_path))
        if not self.model_flags["allow_missing_on_dataset"]:
            train_df = train_df.dropna(subset=self.tabular_transform.columns, how="any")

        # validation / external test
        if self.model_flags["require_validation_split"]:
            train_df, val_df = train_test_split(train_df, 0.9, cfg.seed)
            loaders["val_dataloaders"] = self.dataset_to_dataloader(self.dataframe_to_dataset(val_df), shuffle=False)
        else:
            test_df = maybe_load(cfg.dataset.get("test_path"))
            if test_df is not None:
                loaders["val_dataloaders"] = self.dataset_to_dataloader(
                    self.dataframe_to_dataset(scenario(test_df)), shuffle=False
                )

        loaders["train_dataloaders"] = self.dataset_to_dataloader(
            self.dataframe_to_dataset(train_df), shuffle=True, drop_last=len(train_df) > cfg.model.batch_size
        )
        return loaders

    # --------------------- trainer builder ---------------------
    def prepare_trainer(self, logger=None) -> pl.Trainer:
        cfg = self._cfg
        logger = setup_logger(self, cfg) if logger is None else logger
        return pl.Trainer(
            max_epochs=cfg.model.get("max_epochs"),
            max_steps=cfg.model.get("max_steps", -1),
            precision=cfg.model.get("precision", 32),
            deterministic=cfg.model.get("deterministic", False),
            gradient_clip_val=cfg.model.get("gradient_clip_val"),
            logger=logger,
            accelerator=cfg.device,
            enable_checkpointing=cfg.model.get("enable_checkpointing", True),
            callbacks=[RestoreTorchDefaults()],
        )

    # ------------------------------------------------------------------
    # Fit / Evaluate
    # ------------------------------------------------------------------
    def fit(self, scenario: Callable[[pd.DataFrame], pd.DataFrame] = _IDENTITY):  # type: ignore[override]
        if self._cfg is None:
            raise ValueError("Configuration missing. Cannot proceed with training.")
        cfg = self._cfg
        logger = setup_logger(self, cfg)
        self.log_dir = logger.log_dir
        os.makedirs(self.log_dir, exist_ok=True)
        create_latest_symlink(self.log_dir)
        OmegaConf.save(cfg, os.path.join(self.log_dir, "config.yaml"))
        rank_zero_print(OmegaConf.to_yaml(cfg))
        rank_zero_print(ModelSummary(self))
        self._transform.fit(scenario(pd.read_csv(cfg.dataset.train_path)))

        if not is_overridden("training_step", self, parent=pl.LightningModule):
            save_model_no_lightning(self, os.path.join(self.log_dir, f"{self.name}.ckpt"))
            return self

        with SeedContext(cfg.seed):
            trainer = self.prepare_trainer(logger)
            trainer.fit(self, **self.prepare_dataloader(scenario))
        trainer.save_checkpoint(os.path.join(self.log_dir, f"{self.name}.ckpt"), weights_only=True)
        return self

    def evaluation(self, cfg):
        p = os.path.join(self.log_dir, "report.csv")
        return pd.read_csv(p) if os.path.isfile(p) else pd.DataFrame()

    def change_column_names(self, mapping: dict):
        self._transform.change_column_names(mapping)

class DataSet(TensorDataset):
    """``TensorDataset`` whose items come from the first tensor only.

    Indexing returns ``tensors[0][index]`` directly rather than a tuple with
    one entry per wrapped tensor.
    """

    def __getitem__(self, index):
        primary = self.tensors[0]
        return primary[index]
