import hydra
import torch
import logging
from pathlib import Path
from omegaconf import DictConfig, OmegaConf

from tt_sbi.inference import (
    build_npe_model,
    train_NPE_estimator,
    train_NPE_estimator_noisy,
    train_NPE_estimator_mmd,
    NPE_TrainConfig,
    NoisyNPE_TrainConfig,
    NPE_MMD_TrainConfig,
)
from tt_sbi.utils.misc import get_test_data_path, get_misspec_suffix

log = logging.getLogger(__name__)

def _misspec_settings(cfg):
    """Read the optional ``misspec`` config node.

    Returns a ``(misspec_type, params)`` pair, where ``params`` holds the
    keyword arguments shared by ``get_test_data_path`` and
    ``get_misspec_suffix``. A missing or null ``misspec`` node yields type
    ``"none"`` and the default parameter values instead of raising.
    """
    # `or {}` covers both an absent key and an explicit `misspec: null`.
    misspec_cfg = cfg.get("misspec") or {}
    params = {
        "prior_location_shift": misspec_cfg.get("prior_location_shift", 0.0),
        "prior_scale_factor": misspec_cfg.get("prior_scale_factor", 1.0),
        "likelihood_scale_factor": misspec_cfg.get("likelihood_scale_factor", 1.0),
        "contamination_eps": misspec_cfg.get("contamination_eps", 0.0),
        "contamination_shift": misspec_cfg.get("contamination_shift", 0.0),
    }
    return misspec_cfg.get("type", "none"), params


@hydra.main(version_base=None, config_path="../configs", config_name="gaussian_config")
def main(cfg: DictConfig):
    """Train an NPE model as described by ``cfg`` and save a checkpoint.

    Steps:
      1. Load ``thetas`` / ``X`` (and optionally ``S``) from
         ``cfg.data_dir / data_filename``; optionally keep only the first
         ``n_train`` samples.
      2. Use either the pre-computed summaries ``S`` (no embedding) or the
         flattened raw data ``X`` (with the configured embedding) as inputs.
      3. Build the model with ``build_npe_model`` and train it with the
         routine selected by ``cfg.method``: ``"npe"``, ``"npe_noisy"`` or
         ``"npe_rs"``. The latter additionally loads ``X_obs`` from the test
         data file located via ``get_test_data_path``.
      4. Save ``model_state_dict``, the resolved config and the training
         ``history`` under ``models_dir`` (falling back to ``output_dir``).

    Raises:
        ValueError: if ``use_summary_stats`` is set but the data file has no
            ``"S"`` entry, or if ``cfg.method`` is not a known method.
    """
    log.info(f"Configuration:\n{OmegaConf.to_yaml(cfg)}")

    data_dir = Path(cfg.data_dir)
    output_dir = Path(cfg.get("models_dir", cfg.output_dir))
    output_dir.mkdir(parents=True, exist_ok=True)

    device = cfg.get("device", "cuda" if torch.cuda.is_available() else "cpu")

    d, n = cfg.dim, cfg.n_obs
    data_filename = cfg.get("data_filename", f"train_d{d}_n{n}.pt")
    data_path = data_dir / data_filename

    log.info(f"Loading training data from {data_path}")
    data = torch.load(data_path)

    thetas = data["thetas"]
    X = data["X"]

    n_train = cfg.get("n_train", None)
    if n_train is not None and n_train < len(thetas):
        thetas = thetas[:n_train]
        X = X[:n_train]
        log.info(f"Using first {n_train} training samples")

    use_summary_stats = cfg.get("use_summary_stats", False)
    if use_summary_stats:
        if "S" not in data:
            raise ValueError("use_summary_stats=True but 'S' not found in data file.")
        xs = data["S"]
        if n_train is not None:
            xs = xs[:n_train]
        embedding_type = "none"
        embedding_dim = None
        log.info("Using pre-computed summary statistics")
    else:
        # reshape (not view) so non-contiguous stored tensors also flatten.
        xs = X.reshape(X.shape[0], -1)
        embedding_type = cfg.get("embedding_type", "fc")
        embedding_dim = cfg.get("embedding_dim")
        log.info(f"Using {embedding_type} embedding on raw data")

    method = cfg.get("method", "npe")
    seed = cfg.get("seed")

    model_kwargs = {
        "theta_sample": thetas,
        "x_sample": xs,
        "n_obs": n,
        "dim": d,
        "embedding_type": embedding_type,
        "embedding_dim": embedding_dim,
        "num_transforms": cfg.get("num_transforms"),
        "hidden_features": cfg.get("hidden_features"),
        "embedding_hidden": cfg.get("embedding_hidden"),
        "embedding_layers": cfg.get("embedding_layers"),
    }

    # Optimiser / early-stopping settings shared by all three train configs.
    common_train_kwargs = {
        "lr": cfg.get("lr", 5e-4),
        "batch_size": cfg.get("batch_size", 256),
        "val_frac": cfg.get("val_frac", 0.1),
        "stop_after_epochs": cfg.get("stop_after_epochs", 20),
        "max_epochs": cfg.get("max_epochs", 10000),
    }

    # Resolved once so the data path and the file-name suffix always agree.
    misspec_type, misspec_params = _misspec_settings(cfg)

    if method == "npe":
        log.info("Training Standard NPE")
        config = NPE_TrainConfig(
            log_every=cfg.get("log_every", 50),
            **common_train_kwargs,
        )
        model = build_npe_model(**model_kwargs)
        model, history = train_NPE_estimator(
            model, thetas, xs, config=config, device=device, seed=seed
        )

    elif method == "npe_noisy":
        log.info("Training Noisy NPE")
        config = NoisyNPE_TrainConfig(
            slab_scale=cfg.get("slab_scale", 0.1),
            spike_scale=cfg.get("spike_scale", 0.01),
            spike_prob=cfg.get("spike_prob", 0.5),
            noise_on_val=cfg.get("noise_on_val", False),
            **common_train_kwargs,
        )
        model = build_npe_model(**model_kwargs)
        model, history = train_NPE_estimator_noisy(
            model, thetas, xs, config=config, device=device, seed=seed
        )

    elif method == "npe_rs":
        log.info("Training NPE-RS (MMD Regularized)")

        test_data_dir = Path(cfg.get("test_data_dir", cfg.data_dir))
        test_data_path = get_test_data_path(
            test_data_dir, d, n, misspec_type, **misspec_params
        )

        log.info(f"Loading target data from {test_data_path}")
        test_data = torch.load(test_data_path)

        obs_target = test_data["X_obs"]
        obs_target_flat = obs_target.reshape(obs_target.shape[0], -1)

        train_config = NPE_MMD_TrainConfig(
            lambda_reg=cfg.get("lambda_reg", 1.0),
            **common_train_kwargs,
        )

        model = build_npe_model(**model_kwargs)
        model, history = train_NPE_estimator_mmd(
            model, thetas, xs, obs_target_flat,
            config=train_config, device=device, seed=seed,
        )

    else:
        raise ValueError(f"Unknown method: {method}")

    # File-name suffix identifying the training variant.
    if method == "npe":
        suffix = ""
    elif method == "npe_noisy":
        suffix = "_noisy"
    else:  # "npe_rs" -- any other value already raised above
        lambda_val = history.get('lambda_reg', cfg.get('lambda_reg', 1.0))
        if misspec_type == "none":
            suffix = f"_rs_lambda_{lambda_val}"
        else:
            ms_suffix = get_misspec_suffix(misspec_type, **misspec_params)
            suffix = f"_rs_lambda_{lambda_val}_{ms_suffix}"

    model_name = cfg.get("model_name", f"npe_d{d}_n{n}{suffix}.pt")
    save_path = output_dir / model_name

    torch.save({
        "model_state_dict": model.state_dict(),
        "config": OmegaConf.to_container(cfg, resolve=True),
        "history": history
    }, save_path)
    log.info(f"Saved model to {save_path}")

# Script entry point: `main` is wrapped by `@hydra.main`, so it is invoked
# here without arguments.
if __name__ == "__main__":
    main()