
import torch
from algorithms import (
    RandomWalkSampler,
    AISSampler,
    NUTSSampler,
    E2MCSampler,
)
import mcmc


def run_all_methods(config):
    """Run each benchmark sampler on ``config.target`` and collect the outputs.

    The samplers are built and run in a fixed order: "RW", "AIS", "NUTS",
    then "E2MC". Each one's result is stored in the returned dict under that
    name.

    Args:
        config: Experiment configuration object. This function reads the
            attributes ``target``, ``init_proposal``, ``N``, ``device``,
            ``rw_sigma``, ``rw_steps``, ``ais_intermediate``, ``K_steps``,
            ``L_steps``, ``L_sigma``, ``T``, ``eps``, ``lamda`` and
            ``flow_kwargs``.

    Returns:
        dict: Maps method name to that sampler's ``run`` output. For E2MC,
        only the second element of the 4-tuple returned by ``run`` is kept.
    """
    target = config.target
    proposal = config.init_proposal
    n_samples = config.N
    device = config.device

    # One random-walk MCMC configuration, handed to both AIS and E2MC below.
    # NOTE(review): the same MCMCConfig instance is shared by both samplers;
    # if either mutates it, the other sees the change -- confirm this is safe.
    # NOTE(review): the Normal is built from Python scalars, so presumably it
    # lives on CPU regardless of config.device -- confirm mcmc.rw_kernel
    # handles device placement.
    rw_noise = torch.distributions.Normal(0.0, config.rw_sigma)
    shared_rw_config = mcmc.MCMCConfig(
        n_iter=config.rw_steps,
        kernel_fn=mcmc.rw_kernel,
        kernel_params={"noise_dist": rw_noise},
        grad_logpdf_fn=None,
    )

    outputs = {}

    # Plain random walk, configured directly from sigma / step count.
    rw_sampler = RandomWalkSampler(
        sigma=config.rw_sigma,
        n_steps=config.rw_steps,
        device=device,
    )
    outputs["RW"] = rw_sampler.run(target, proposal, n_samples)

    # AIS, using the shared random-walk kernel configuration.
    ais_sampler = AISSampler(
        mcmc_config=shared_rw_config,
        n_intermediate=config.ais_intermediate,
        device=device,
    )
    outputs["AIS"] = ais_sampler.run(target, proposal, n_samples)

    # NUTS gets the sample count through its constructor here, so its run()
    # call receives only the target and the initial proposal.
    nuts_sampler = NUTSSampler(
        num_samples=n_samples,
        device=device,
    )
    outputs["NUTS"] = nuts_sampler.run(target, proposal)

    # E2MC, also built on the shared random-walk kernel configuration.
    e2mc_sampler = E2MCSampler(
        base_mcmc_config=shared_rw_config,
        K_steps=config.K_steps,
        L_steps=config.L_steps,
        L_sigma=config.L_sigma,
        T=config.T,
        eps=config.eps,
        lamda=config.lamda,
        flow_kwargs=config.flow_kwargs,
        device=device,
    )
    # run() must return exactly four values; only the second one is kept.
    _, e2mc_output, _, _ = e2mc_sampler.run(target, proposal, n_samples)
    outputs["E2MC"] = e2mc_output

    return outputs
