import torch
import hydra
from pathlib import Path
from omegaconf import DictConfig

from sbi.neural_nets import likelihood_nn
from sbi.inference import NLE

from ttt_sbi.npe import (
    load_npe_model,
    get_embedding_net,
    compute_summary_statistics,
)


@hydra.main(version_base=None, config_path="../configs", config_name="gaussian_config")
def main(cfg: DictConfig):
    """Train a neural conditional likelihood (NCPP) on NPE summary statistics.

    Pipeline:
      1. Load ``train_d{dim}_n{n_obs}.pt`` from ``cfg.data_dir`` (dict with
         keys ``"thetas"`` and ``"X"``), optionally truncated to the first
         ``cfg.n_train`` datasets.
      2. Load the pretrained NPE model ``npe_d{dim}_n{n_obs}.pt`` from
         ``models_dir``, extract its embedding network and map every flattened
         dataset to a summary statistic ``S``.
      3. Fit an sbi ``NLE`` flow for ``p(x | S)``:
         - full mode: one training pair per dataset, ``x`` is the whole
           flattened dataset (``n_obs * dim`` values);
         - composite mode (``cfg.composite_likelihood``): one training pair per
           observation, ``x`` is a single ``dim``-vector paired with the summary
           statistic of the dataset it came from.
      4. Save the trained estimator's ``state_dict`` together with its
         configuration to ``models_dir / f"ncpp_{mode}_d{dim}_n{n_obs}.pt"``.

    Config keys read: ``data_dir``, ``output_dir``, ``dim``, ``n_obs``
    (required); ``models_dir``, ``device``, ``composite_likelihood``,
    ``n_train``, ``ncpp_model``, ``ncpp_hidden_features``,
    ``ncpp_num_transforms``, ``ncpp_num_bins``, ``ncpp_batch_size`` (optional).

    Raises:
        ValueError: if ``n_train`` is not positive, if ``thetas`` and ``X``
            hold different numbers of datasets, or if a dataset in ``X`` does
            not contain exactly ``n_obs * dim`` values.
    """
    data_dir = Path(cfg.data_dir)
    models_dir = Path(cfg.get("models_dir", "models/gaussian_linear"))
    # NOTE: nothing in this function writes to output_dir (the checkpoint goes
    # to models_dir); the directory is only created.
    output_dir = Path(cfg.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    d, n = cfg.dim, cfg.n_obs
    device = cfg.get("device", "cuda" if torch.cuda.is_available() else "cpu")

    composite_likelihood = cfg.get("composite_likelihood", False)
    mode_str = "composite" if composite_likelihood else "full"

    print(f"Training NCPP with {mode_str} likelihood mode")

    # --- Training data -------------------------------------------------------
    data_path = data_dir / f"train_d{d}_n{n}.pt"
    data = torch.load(data_path)
    thetas = data["thetas"]
    X = data["X"]

    if len(X) != len(thetas):
        raise ValueError(
            f"{data_path}: 'thetas' has {len(thetas)} entries but 'X' has {len(X)}"
        )

    N_TRAIN, N_OBS = len(thetas), n

    n_train = cfg.get("n_train", None)
    if n_train is not None:
        if n_train <= 0:
            raise ValueError(f"n_train must be positive, got {n_train}")
        if n_train < N_TRAIN:
            thetas = thetas[:n_train]
            X = X[:n_train]
            N_TRAIN = n_train
            print(f"Using first {n_train} training samples")
        elif n_train > N_TRAIN:
            print(
                f"Requested n_train={n_train} but only {N_TRAIN} samples are "
                f"available; using all of them"
            )

    # reshape (not view) so a non-contiguous tensor loaded from disk still works.
    X_flat = X.reshape(N_TRAIN, -1)
    if X_flat.shape[1] != N_OBS * d:
        raise ValueError(
            f"Expected each dataset in {data_path} to hold n_obs * dim = "
            f"{N_OBS} * {d} = {N_OBS * d} values, got {X_flat.shape[1]} "
            f"(X shape {tuple(X.shape)})"
        )
    print(f"Loaded training data: thetas {thetas.shape}, X {X.shape}")

    # --- Summary statistics from the pretrained NPE embedding -----------------
    npe_model_path = models_dir / f"npe_d{d}_n{n}.pt"
    print(f"Loading NPE model from {npe_model_path}...")

    npe_model, npe_config = load_npe_model(npe_model_path, thetas, X_flat, device)
    embedding_net = get_embedding_net(npe_model)
    print(f"Extracted embedding network (embedding_dim={npe_config['embedding_dim']})")

    print("Generating summary statistics...")
    S = compute_summary_statistics(embedding_net, X_flat, device)
    print(f"Summary statistics shape: {S.shape}")

    # --- NCPP density estimator -------------------------------------------------
    ncpp_model_type = cfg.get("ncpp_model", "nsf")
    ncpp_hidden = cfg.get("ncpp_hidden_features", 64)
    ncpp_transforms = cfg.get("ncpp_num_transforms", 5)
    ncpp_bins = cfg.get("ncpp_num_bins", 10)
    batch_size = cfg.get("ncpp_batch_size", 256)

    builder = likelihood_nn(
        model=ncpp_model_type,
        hidden_features=ncpp_hidden,
        num_transforms=ncpp_transforms,
        num_bins=ncpp_bins,
    )

    nle = NLE(density_estimator=builder, device=device)

    # NLE models p(x | theta); the conditioning variable passed as ``theta``
    # here is the summary statistic S, not the true parameters.
    if composite_likelihood:
        # reshape keeps dataset i's values on rows i*N_OBS .. i*N_OBS+N_OBS-1,
        # and repeat_interleave places S[i] on exactly those rows. Assumes
        # X[i] is laid out as (n_obs, dim) — TODO confirm for stored data.
        xs = X.reshape(N_TRAIN * N_OBS, d)
        ss = S.repeat_interleave(N_OBS, dim=0)
        x_dim = d
        print(f"Composite mode: {N_TRAIN * N_OBS} effective samples")
        nle.append_simulations(theta=ss, x=xs)
    else:
        x_dim = N_OBS * d
        print(f"Full mode: {N_TRAIN} effective samples")
        nle.append_simulations(theta=S, x=X_flat)

    print(f"Training NCPP on {device}...")
    # Use the trained estimator returned by train() rather than reaching into
    # the private ``_neural_net`` attribute.
    ncpp_model = nle.train(training_batch_size=batch_size)

    # --- Save ------------------------------------------------------------------
    model_path = models_dir / f"ncpp_{mode_str}_d{d}_n{n}.pt"
    torch.save({
        "model_state_dict": ncpp_model.state_dict(),
        "config": {
            "x_dim": x_dim,
            "n_obs": N_OBS,
            "theta_dim": d,
            "embedding_dim": npe_config["embedding_dim"],
            # Flow type and bin count are stored so the estimator can be rebuilt.
            "ncpp_model": ncpp_model_type,
            "ncpp_hidden_features": ncpp_hidden,
            "ncpp_num_transforms": ncpp_transforms,
            "ncpp_num_bins": ncpp_bins,
            "composite_likelihood": composite_likelihood,
        },
        "npe_config": npe_config,
    }, model_path)
    print(f"Saved model: {model_path}")
    print(f"\n=== Training Complete ({mode_str} likelihood) ===")


if __name__ == "__main__":
    # Called with no arguments: the @hydra.main decorator on main() composes
    # cfg from ../configs/gaussian_config plus any command-line overrides.
    main()