"""Online training using TabICL/MLP-SCM sampler (no offline files).

This mirrors train.py logging and training behavior while replacing the
data loading with an online generator that yields pre-batched DataAttr.

Both train and (optional) val splits are generated on-the-fly using the
same sampler with configurable Nc/D combos and normalization of x and y.
"""

import os

# Inductor / ROCm environment overrides, kept consistent with train.py defaults.
# These are unconditional: any value already present in the process
# environment for one of these keys is replaced. The block sits ahead of the
# torch imports that follow it in this file.
os.environ.update(
    {
        "TORCHINDUCTOR_MAX_AUTOTUNE_GEMM_BACKENDS": "TRITON",
        "TORCHINDUCTOR_CPP_WRAPPER": "0",
        "PYTORCH_MIOPEN_SUGGEST_NHWC": "1",
        "MIOPEN_FIND_MODE": "1",
        "PYTORCH_HIP_ALLOC_CONF": "max_split_size_mb:512",
        "TORCHINDUCTOR_DISABLE_CUDAGRAPHS": "1",
    }
)

import torch._inductor.config as config

# Global TorchInductor settings, applied once at import time so they are in
# place before the trainer compiles anything.
# NOTE(review): the per-option remarks below are read off the option names;
# confirm exact semantics against the installed torch._inductor.config.

# Presumably turns on autotuning for GEMM kernels.
config.max_autotune_gemm = True
# ROCm-specific cap on profiled configurations (name suggests a limit of 10).
config.rocm.n_max_profiling_configs = 10
# Presumably the number of parallel compile workers.
config.compile_threads = 8
config.triton.unique_kernel_names = True
# CUDA graphs are switched off here as well as through the
# TORCHINDUCTOR_DISABLE_CUDAGRAPHS environment variable set earlier.
config.triton.cudagraphs = False
config.coordinate_descent_tuning = True
config.triton.persistent_reductions = True

import torch
from pathlib import Path
import torch.multiprocessing as mp
from torch.utils.data import DataLoader

import hydra
from omegaconf import DictConfig

from train import CompiledTrainer  # reuse logging + training loop
from src.data.online_tabular_dataset import OnlineTabularDataset


class OnlineCompiledTrainer(CompiledTrainer):
    """Trainer that swaps dataloaders for online SCM tabular generation.

    Only ``_build_dataloader`` is overridden here; the training loop and
    logging are inherited unchanged from ``CompiledTrainer``.
    """

    def _build_dataloader(self, split: str) -> "DataLoader | None":
        """Build the on-the-fly dataloader for ``split``.

        All settings are read from ``self.cfg.data``.

        Args:
            split: ``"train"`` takes its batch count from
                ``num_batches_per_epoch``. Any other value is treated as a
                validation split and takes it from ``val_num_batches``.

        Returns:
            A ``DataLoader`` wrapping an ``OnlineTabularDataset``, or ``None``
            for a non-train split when ``val_num_batches`` is missing, null,
            or not positive (validation disabled).

        Raises:
            ValueError: If ``dtype`` in the data config does not name a
                ``torch.dtype``.
        """
        dcfg = self.cfg.data

        # Determine per-split batch count
        if split == "train":
            num_batches = int(dcfg.num_batches_per_epoch)
        else:
            # ``or 0`` also covers an explicit null in the config.
            num_batches = int(dcfg.get("val_num_batches", 0) or 0)
            if num_batches <= 0:
                return None

        # Resolve the configured dtype name (default float32) and reject
        # anything that is not an actual torch dtype, so a typo fails here
        # with a clear message instead of deep inside the dataset.
        dtype_name = str(dcfg.get("dtype", "float32") or "float32")
        dtype = getattr(torch, dtype_name, None)
        if not isinstance(dtype, torch.dtype):
            raise ValueError(
                "data.dtype must name a torch dtype "
                f"(e.g. 'float32', 'bfloat16'), got {dtype_name!r}"
            )

        # NOTE(review): train and val receive the same seed. If
        # OnlineTabularDataset derives its stream purely from the seed,
        # validation batches may replicate training batches -- confirm.
        dataset = OnlineTabularDataset(
            batch_size=int(dcfg.batch_size),
            num_batches=num_batches,
            d_list=[int(d) for d in dcfg.d_list],
            nc_list=[int(nc) for nc in dcfg.nc_list],
            num_buffer=int(dcfg.num_buffer),
            num_target=int(dcfg.num_target),
            normalize_x=bool(dcfg.normalize_x),
            x_norm_method=str(dcfg.x_norm_method),
            x_outlier_threshold=float(dcfg.get("x_outlier_threshold", 4.0)),
            normalize_y=bool(dcfg.normalize_y),
            dtype=dtype,
            device=str(self.device),
            seed=int(dcfg.get("seed", 123)),
        )

        # batch_size=None disables automatic batching: the dataset already
        # yields pre-batched DataAttr objects.
        # NOTE(review): pin_memory is requested on CUDA while the dataset is
        # also handed the device string; verify the yielded batches are
        # compatible with pinning.
        # Intentionally no extra prints to keep logging identical to train.py
        return DataLoader(
            dataset,
            batch_size=None,
            shuffle=False,  # generation itself is randomized per batch
            num_workers=int(dcfg.num_workers),
            pin_memory=(self.device.type == "cuda"),
        )


@hydra.main(version_base=None, config_path="configs", config_name="train_tabular_online")
def main(cfg: DictConfig):
    """Hydra entry point: create the checkpoint directory, then train."""
    # The directory (and any missing parents) is created before the trainer
    # is constructed; an existing directory is left as is.
    checkpoint_dir = Path(cfg.checkpoint.save_dir)
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    OnlineCompiledTrainer(cfg).train()


if __name__ == "__main__":
    # Select the "spawn" start method before Hydra invokes main().
    # NOTE(review): presumably needed because the dataset is built with the
    # trainer's device string, so DataLoader workers may touch CUDA, which
    # does not work in forked children -- confirm.
    try:
        mp.set_start_method("spawn", force=True)
    except RuntimeError:
        # Best effort: a RuntimeError here is ignored and training proceeds
        # with whatever start method is already in effect.
        pass
    main()
