
import torch
import mcmc

from targets.gmm import build_target
from algorithms.e2mc_gmm import FrozenGMM

# ----------------------------------------------------------------------
# Hyper-parameter table, keyed by (model_name, d).
#
# Each entry holds the outer-loop settings ("n_iter_e2mc", "eps", "lamda",
# "N") plus one sub-dict of kernel settings per kernel type under "kernel".
# get_gmm_experiment() reads every field except "lamda" (presumably consumed
# elsewhere -- verify before removing it).
# ----------------------------------------------------------------------


def _ula(h_K, h_L, steps_K, steps_L):
    """Return a fresh ULA settings dict: step sizes and step counts for K and L."""
    return {"h_K": h_K, "h_L": h_L, "steps_K": steps_K, "steps_L": steps_L}


def _rw(sigma_K, sigma_L, steps_K, steps_L):
    """Return a fresh random-walk settings dict: noise scales and step counts for K and L."""
    return {
        "sigma_K": sigma_K,
        "sigma_L": sigma_L,
        "steps_K": steps_K,
        "steps_L": steps_L,
    }


def _entry(n_iter_e2mc, lamda, ula, rw):
    """Assemble one experiment entry; "eps" and "N" are shared by every experiment."""
    return {
        "n_iter_e2mc": n_iter_e2mc,
        "eps": 0.8,
        "lamda": lamda,
        "N": 2000,
        "kernel": {"ULA": ula, "RW": rw},
    }


EXPERIMENT_CONFIGS = {
    # ===== GM2 =====
    ("GM2", 2): _entry(25, 0.5, _ula(2.3, 0.1, 15, 10), _rw(6, 0.1, 20, 1)),
    ("GM2", 4): _entry(25, 0.5, _ula(2.3, 0.1, 15, 10), _rw(6, 0.1, 20, 1)),
    ("GM2", 10): _entry(30, 0.5, _ula(2.3, 0.1, 15, 10), _rw(8.0, 0.1, 20, 1)),
    ("GM2", 20): _entry(30, 0.5, _ula(2.3, 0.1, 15, 10), _rw(7, 0.1, 20, 1)),

    # ===== GM4 =====
    ("GM4", 2): _entry(25, 0.8, _ula(2.0, 0.01, 10, 1), _rw(4.5, 0.1, 15, 1)),
    ("GM4", 4): _entry(25, 0.8, _ula(2.0, 0.01, 10, 1), _rw(4.5, 0.1, 15, 1)),
    # NOTE(review): the ULA settings of this entry carried a bare marker
    # comment in the original -- presumably still being tuned; verify.
    ("GM4", 10): _entry(25, 1.0, _ula(2.0, 0.05, 10, 1), _rw(4.5, 0.5, 15, 1)),
    ("GM4", 20): _entry(25, 0.5, _ula(2.0, 0.05, 10, 1), _rw(5.0, 0.5, 15, 1)),

    # ===== GM25 =====
    ("GM25", 2): _entry(15, 0.8, _ula(0.3, 0.01, 10, 1), _rw(1.5, 0.1, 10, 1)),
    ("GM25", 4): _entry(15, 0.8, _ula(0.3, 0.01, 10, 1), _rw(1.5, 0.1, 10, 1)),
    ("GM25", 10): _entry(20, 0.8, _ula(0.3, 0.01, 10, 1), _rw(2.0, 0.1, 15, 1)),
    ("GM25", 20): _entry(25, 0.8, _ula(0.35, 0.01, 10, 1), _rw(2.5, 0.1, 20, 1)),
}


def get_gmm_experiment(
    model_name: str,
    d: int,
    kernel_type: str,
    device="cpu",
) -> tuple:
    """
    Build everything needed to run E2MC on tensorized GMMs.

    Args:
        model_name: Target name; together with ``d`` it forms the
            ``(model_name, d)`` key looked up in ``EXPERIMENT_CONFIGS``.
        d: Dimension of the problem (second half of the config key and the
            width of the initial proposal's mean / log-std tensors).
        kernel_type: ``"ULA"`` or ``"RW"``; selects both the kernel function
            and the matching hyper-parameter sub-dict.
        device: Device on which the initial proposal's tensors are created.

    Returns:
        target
        init_proposal
        K_config
        L_config
        e2mc_params : dict
        metadata    : dict (for plotting)

    Raises:
        ValueError: if there is no config for ``(model_name, d)`` or if
            ``kernel_type`` is not a supported kernel.
    """

    key = (model_name, d)
    if key not in EXPERIMENT_CONFIGS:
        raise ValueError(f"No config for model={model_name}, d={d}")

    cfg = EXPERIMENT_CONFIGS[key]

    # Validate the kernel name *before* indexing, so an unknown kernel yields
    # the documented ValueError instead of a bare KeyError from the lookup.
    available_kernels = cfg["kernel"]
    if kernel_type not in available_kernels:
        raise ValueError(f"Unknown kernel_type={kernel_type}")
    kernel_cfg = available_kernels[kernel_type]
    print("kernel config ", kernel_cfg)

    # --------------------------------------------------
    # Target + base mixture size
    # --------------------------------------------------
    target, K0 = build_target(model_name, d)

    # --------------------------------------------------
    # Initial proposal μ0 = N(30·1_d, I_d)
    # (single component: unit weight, mean 30 in every coordinate,
    #  zero log-std)
    # --------------------------------------------------
    init_proposal = FrozenGMM(
        weights=torch.ones(1, device=device),
        means=30.0 * torch.ones(1, d, device=device),
        log_stds=torch.zeros(1, d, device=device),
    )

    # --------------------------------------------------
    # Kernel K and L
    # --------------------------------------------------
    if kernel_type == "ULA":
        # Both kernels share the target's gradient; they differ only in
        # step size and number of steps.
        K_config = mcmc.MCMCConfig(
            n_iter=kernel_cfg["steps_K"],
            kernel_fn=mcmc.ula_kernel,
            kernel_params={"step_size": kernel_cfg["h_K"]},
            grad_logpdf_fn=target.grad_logpi,
        )
        L_config = mcmc.MCMCConfig(
            n_iter=kernel_cfg["steps_L"],
            kernel_fn=mcmc.ula_kernel,
            kernel_params={"step_size": kernel_cfg["h_L"]},
            grad_logpdf_fn=target.grad_logpi,
        )

    elif kernel_type == "RW":
        # NOTE(review): these Normal distributions are built from Python
        # numbers, without `device` -- confirm that mcmc.rw_kernel moves the
        # sampled noise to the particles' device when device != "cpu".
        K_config = mcmc.MCMCConfig(
            n_iter=kernel_cfg["steps_K"],
            kernel_fn=mcmc.rw_kernel,
            kernel_params={
                "noise_dist": torch.distributions.Normal(
                    0.0, kernel_cfg["sigma_K"]
                )
            },
        )
        L_config = mcmc.MCMCConfig(
            n_iter=kernel_cfg["steps_L"],
            kernel_fn=mcmc.rw_kernel,
            kernel_params={
                "noise_dist": torch.distributions.Normal(
                    0.0, kernel_cfg["sigma_L"]
                )
            },
        )

    else:
        # Reached only if the config table lists a kernel this function does
        # not know how to build.
        raise ValueError(f"Unknown kernel_type={kernel_type}")

    # --------------------------------------------------
    # E2MC parameters
    # NOTE(review): cfg["lamda"] is not forwarded here -- confirm whether
    # that is intentional.
    # --------------------------------------------------
    e2mc_params = {
        "N": cfg["N"],
        "T": cfg["n_iter_e2mc"],
        "eps": cfg["eps"],
        "K0": K0,
    }

    # --------------------------------------------------
    # Metadata for plotting
    # --------------------------------------------------
    metadata = {
        "model": model_name,
        "d": d,
        "kernel": kernel_type,
        # Copy so that callers editing the metadata cannot mutate the
        # module-level EXPERIMENT_CONFIGS table.
        "kernel_params": dict(kernel_cfg),
        "n_iter": cfg["n_iter_e2mc"],
        "eps": cfg["eps"],
        "N": cfg["N"],
    }

    return target, init_proposal, K_config, L_config, e2mc_params, metadata



