import torch
import hydra
from pathlib import Path
from omegaconf import DictConfig
import pickle
import numpy as np

from tt_sbi.ocsvm import OCSVMConfig, fit_ocsvm_detector


@hydra.main(version_base=None, config_path="../configs", config_name="gaussian_config")
def main(cfg: DictConfig):
    """Fit a one-class SVM detector on pooled training observations and pickle it.

    Loads ``train_d{dim}_n{n_obs}.pt`` from ``cfg.data_dir`` and takes the array
    stored under key ``'X'``. ``X`` is indexed as ``(M datasets, n_obs, *element_shape)``.
    It fits a detector with ``fit_ocsvm_detector`` and writes the pickled result to
    ``{cfg.models_dir}/ocsvm_d{dim}_n{n_obs}.pkl``.

    Config keys read:
        dim, n_obs, data_dir, models_dir: required.
        model_type: optional, default ``"gaussian"``. Used in messages only.
        seed: optional, default 0. Forwarded as ``random_state``.
        ocsvm: optional section. It may override ``nu``, ``calibrate_fpr``, ``kernel``,
            ``gamma``, ``standardize`` and ``max_train_elements``.

    Raises:
        FileNotFoundError: If the training data file does not exist.
        ValueError: If ``X`` has fewer than two dimensions.
    """
    model_type = cfg.get("model_type", "gaussian")
    d, n = cfg.dim, cfg.n_obs
    models_dir = Path(cfg.models_dir)
    models_dir.mkdir(parents=True, exist_ok=True)

    data_path = Path(cfg.data_dir) / f"train_d{d}_n{n}.pt"

    if not data_path.exists():
        raise FileNotFoundError(
            f"Training data not found at {data_path}. "
            f"Run 'python scripts/generate_data.py --config-name={model_type}_config' first."
        )

    # map_location="cpu" lets data saved from a GPU run load on CPU-only hosts.
    train_data = torch.load(data_path, map_location="cpu")
    X_raw = train_data['X']
    if isinstance(X_raw, torch.Tensor):
        # A bare .numpy() raises RuntimeError for tensors that require grad.
        # It raises TypeError for tensors that are not on the CPU.
        # detach().cpu() avoids both, and is a no-op view for plain CPU tensors.
        X_train = X_raw.detach().cpu().numpy()
    else:
        X_train = np.asarray(X_raw)

    if X_train.ndim < 2:
        raise ValueError(
            f"Expected 'X' with shape (M, n_obs, ...), got shape {X_train.shape}."
        )

    M, n_obs_actual = X_train.shape[0], X_train.shape[1]
    element_shape = tuple(X_train.shape[2:])
    # np.prod(()) == 1.0, so scalar elements (X of shape (M, n_obs)) give element_dim == 1.
    element_dim = int(np.prod(element_shape))
    total_elements = M * n_obs_actual

    print(f"=== OC-SVM Training for {model_type.upper()} ===")
    print(f"Config: dim={d}, n_obs={n}")
    print(f"Data loaded from: {data_path}")
    print(f"Training data shape: {X_train.shape}")
    print(f"  - M (datasets): {M}")
    print(f"  - n_obs_actual (observations per dataset): {n_obs_actual}")
    print(f"  - element_shape: {element_shape}")
    print(f"  - element_dim (flattened): {element_dim}")
    print(f"  - Total pooled elements: {total_elements:,}")

    if element_dim > 500:
        print(f"\n⚠️  WARNING: High-dimensional elements ({element_dim}D).")
        print("   OC-SVM with RBF kernel may struggle in high dimensions.")
        print("   Consider using summary statistics or reducing max_train_elements.")

    # `or {}` also covers an explicit `ocsvm: null` in the config.
    ocsvm_cfg = cfg.get("ocsvm") or {}

    # Shrink the default training subsample as element dimensionality grows.
    if element_dim <= 100:
        default_max_train = 20000
    elif element_dim <= 500:
        default_max_train = 10000
    else:
        default_max_train = 5000

    config = OCSVMConfig(
        nu=ocsvm_cfg.get("nu", 0.05),
        calibrate_fpr=ocsvm_cfg.get("calibrate_fpr", 0.05),
        kernel=ocsvm_cfg.get("kernel", "rbf"),
        gamma=ocsvm_cfg.get("gamma", "scale"),
        standardize=ocsvm_cfg.get("standardize", True),
        max_train_elements=ocsvm_cfg.get("max_train_elements", default_max_train),
        random_state=cfg.get("seed", 0),
    )

    print("\nOC-SVM Configuration:")
    print(f"  nu: {config.nu}")
    print(f"  kernel: {config.kernel}")
    print(f"  gamma: {config.gamma}")
    print(f"  standardize: {config.standardize}")
    print(f"  calibrate_fpr: {config.calibrate_fpr}")
    print(f"  max_train_elements: {config.max_train_elements}")

    print("\nFitting OC-SVM...")
    detector = fit_ocsvm_detector(X_train, config=config)

    output_path = models_dir / f"ocsvm_d{d}_n{n}.pkl"
    with open(output_path, 'wb') as f:
        pickle.dump(detector, f)

    print("\n=== Training Summary ===")
    print(f"Model type: {model_type}")
    print(f"Saved OC-SVM detector: {output_path}")
    print(f"Detector element shape: {detector.element_shape}")
    if detector.threshold is not None:
        print(f"Calibrated threshold (FPR={config.calibrate_fpr}): {detector.threshold:.6f}")
    else:
        print("Using sklearn default threshold")


if __name__ == "__main__":
    # main() is called with no arguments: the @hydra.main decorator on it
    # composes `cfg` from ../configs plus any command-line overrides.
    main()
