import torch
from torch.utils.data import DataLoader, TensorDataset
import mcmc
import utils
from tqdm import *
import zuko




class FlowProposal:
    """
    Adapter exposing a zuko flow through the proposal interface used by E2MC.

    A proposal must provide ``sample(N)`` and ``log_prob(x)``. zuko flows are
    callables that return a torch distribution, so every method first builds
    that distribution (by calling the flow with no arguments) and then
    delegates to it.
    """

    def __init__(self, nf):
        # The wrapped zuko flow (e.g. NSF, MAF, RealNVP).
        self.nf = nf

    def _distribution(self):
        """Instantiate the distribution described by the wrapped flow."""
        return self.nf()

    def sample(self, N):
        """Draw ``N`` samples; the leading dimension of the result is ``N``."""
        dist = self._distribution()
        return dist.sample((N,))

    def log_prob(self, x):
        """Evaluate the flow's log-density at ``x``."""
        dist = self._distribution()
        return dist.log_prob(x)





class E2MCSampler:
    """
    Front-end for the E2MC algorithm.

    Stores the hyper-parameters, builds the two MCMC configurations and
    delegates the actual run to ``e2mc_flow``.

    - Kernel K reuses ``kernel_fn``, ``kernel_params`` and ``grad_logpdf_fn``
      from the configuration shared with RW / AIS, but runs for the
      E2MC-specific ``K_steps``.
    - Kernel L is an E2MC-specific random-walk kernel (``mcmc.rw_kernel``)
      driven by ``Normal(0, L_sigma)`` noise and run for ``L_steps``.
    """

    def __init__(
        self,
        base_mcmc_config: mcmc.MCMCConfig,
        K_steps: int,
        L_steps: int,
        L_sigma: float,
        T: int,
        eps: float,
        lamda: float,
        flow_kwargs: dict,
        device="cpu",
        show_progress=True,
    ):
        """
        Args:
            base_mcmc_config : MCMCConfig shared with RW / AIS
                               (kernel + step size)
            K_steps          : number of kernel-K steps per E2MC iteration
            L_steps          : number of kernel-L steps
            L_sigma          : RW step size (noise scale) for kernel L
            T                : number of E2MC outer iterations
            eps              : importance tempering parameter
            lamda            : mixture coefficient
            flow_kwargs      : kwargs passed to flow training
            device           : cpu / cuda
            show_progress    : show a progress bar over outer iterations
        """
        # Kernel settings.
        self.base_mcmc_config = base_mcmc_config
        self.K_steps, self.L_steps, self.L_sigma = K_steps, L_steps, L_sigma

        # Outer loop / importance weighting settings.
        self.T, self.eps, self.lamda = T, eps, lamda

        # Flow fitting and runtime settings.
        self.flow_kwargs = flow_kwargs
        self.device = device
        self.show_progress = show_progress

    def _kernel_K_config(self):
        """Kernel K: the shared kernel and parameters, with K_steps iterations."""
        shared = self.base_mcmc_config
        return mcmc.MCMCConfig(
            n_iter=self.K_steps,
            kernel_fn=shared.kernel_fn,
            kernel_params=shared.kernel_params,
            grad_logpdf_fn=shared.grad_logpdf_fn,
        )

    def _kernel_L_config(self):
        """Kernel L: gradient-free random walk with Normal(0, L_sigma) noise."""
        # NOTE(review): this Normal is built from Python floats, so its
        # parameters live on the CPU; confirm mcmc.rw_kernel moves draws to
        # the particles' device when running on cuda.
        gaussian_noise = torch.distributions.Normal(
            loc=0.0,
            scale=self.L_sigma,
        )
        return mcmc.MCMCConfig(
            n_iter=self.L_steps,
            kernel_fn=mcmc.rw_kernel,
            kernel_params={"noise_dist": gaussian_noise},
            grad_logpdf_fn=None,
        )

    def run(
        self,
        target,
        init_proposal,
        N: int,
    ):
        """
        Run E2MC and return final resampled particles.

        Returns:
            x_final        : Tensor [N, d] (raw samples)
            x_resampled    : Tensor [N, d] (importance-resampled)
            final_proposal : FlowProposal
            proposal_hist  : list of FlowProposal
        """
        K_config = self._kernel_K_config()
        L_config = self._kernel_L_config()

        return e2mc_flow(
            n_iter=self.T,
            target=target,
            init_proposal=init_proposal,
            K_config=K_config,
            L_config=L_config,
            eps=self.eps,
            lamda=self.lamda,
            N=N,
            flow_kwargs=self.flow_kwargs,
            show_progress=self.show_progress,
            device=self.device,
        )



def e2mc_flow(
    n_iter,
    target,
    init_proposal,
    K_config,
    L_config,
    eps,
    lamda,
    N,
    flow_kwargs,
    show_progress=True,
    device="cpu",
):
    """
    Core E2MC algorithm.

    Starting from ``N`` samples of ``init_proposal``, each of the ``n_iter``
    outer iterations:

      1. explores with kernel K (``K_config``), producing ``y`` from ``x``;
      2. computes importance weights of ``x`` and ``y`` against the current
         proposal via ``utils.compute_weights`` (tempering ``eps``);
      3. draws a mixture resample ``Z`` of ``x`` / ``y`` (coefficient
         ``lamda``) via ``utils.sample_from_mixture``;
      4. moves ``Z`` with kernel L (``L_config``);
      5. fits a new normalizing flow to ``Z`` and draws the next ``x`` from it.

    Args:
        n_iter        : number of outer iterations.
        target        : object exposing ``logpi(x)``.
        init_proposal : object exposing ``sample(N)`` and ``log_prob(x)``.
        K_config      : MCMC configuration for kernel K.
        L_config      : MCMC configuration for kernel L.
        eps           : importance tempering parameter.
        lamda         : mixture coefficient.
        N             : number of particles.
        flow_kwargs   : extra keyword arguments for ``fit_flow_with_dataloader``
                        (may be None). ``show_progress`` defaults to False
                        here but may be overridden; ``device`` is always
                        taken from this function's ``device`` argument.
        show_progress : show a tqdm bar over the outer iterations.
        device        : device that samples are moved to and flows are fit on.

    Returns:
        x                : raw samples from the final proposal.
        x_resampled      : ``N`` samples importance-resampled from ``x``.
        proposal         : final proposal (``init_proposal`` if n_iter == 0).
        proposal_history : [init_proposal, flow_1, ..., flow_{n_iter}].
    """
    # Build the flow-fitting kwargs once. Passing ``show_progress=False,
    # device=device, **flow_kwargs`` directly raises "got multiple values for
    # keyword argument" whenever the caller's flow_kwargs also sets one of
    # these keys, so merge them into a single dict instead.
    fit_kwargs = {"show_progress": False, **(flow_kwargs or {})}
    fit_kwargs["device"] = device

    # Initial proposal
    proposal = init_proposal
    x = proposal.sample(N).to(device)

    proposal_history = [proposal]

    iterator = range(n_iter)
    if show_progress:
        iterator = tqdm(iterator, desc="E2MC", leave=True)

    for _ in iterator:

        # Exploration (kernel K); [-1] presumably selects the chain's last
        # state -- confirm against mcmc.mcmc's return layout.
        y = mcmc.mcmc(x, target.logpi, K_config)[-1]

        # Importance weights of both populations w.r.t. the current proposal.
        w_x = utils.compute_weights(x, target.logpi, proposal.log_prob, eps)
        w_y = utils.compute_weights(y, target.logpi, proposal.log_prob, eps)

        # Mixture resampling.
        Z = utils.sample_from_mixture(x, w_x, y, w_y, lamda)

        # Move step (kernel L); detach from any autograd history before
        # using Z as flow training data.
        Z = mcmc.mcmc(Z, target.logpi, L_config)[-1].detach()

        # Flow projection: fit a fresh flow to Z and use it as next proposal.
        nf, _ = fit_flow_with_dataloader(X=Z, d=Z.shape[1], **fit_kwargs)

        proposal = FlowProposal(nf)
        x = proposal.sample(N).to(device)

        proposal_history.append(proposal)

    # Final importance resampling against the last proposal.
    w_final = compute_final_importance_weights(x, target, proposal)
    x_resampled = resample_from_weights(x, w_final, N)

    return x, x_resampled, proposal, proposal_history



# Utilities

def compute_final_importance_weights(x, target, proposal):
    """
    Return self-normalized importance weights of ``x`` for ``target``.

    The log-ratio ``target.logpi(x) - proposal.log_prob(x)`` is shifted by its
    maximum before exponentiation so that ``exp`` cannot overflow. The shift
    cancels in the normalization. The returned weights sum to one.
    """
    with torch.no_grad():
        log_target = target.logpi(x)
        log_proposal = proposal.log_prob(x)
        log_ratio = log_target - log_proposal
        unnormalized = torch.exp(log_ratio - log_ratio.max())
    total = unnormalized.sum()
    return unnormalized / total


def resample_from_weights(x, w, N):
    """
    Multinomial resampling: draw ``N`` rows of ``x`` with replacement, each
    row chosen with probability proportional to its weight in ``w``.
    """
    chosen = torch.multinomial(w, num_samples=N, replacement=True)
    return x[chosen]


def fit_flow_with_dataloader(
    X,
    d,
    flow_type="NSF",
    n_epochs=15,
    batch_size=1024,
    lr=1e-3,
    hidden_features=(128, 128),
    n_transforms=6,
    bins=16,
    shuffle=True,
    device="cpu",
    show_progress=True,
):
    """
    Fit a normalizing flow to samples X by maximum likelihood.

    The flow is trained with Adam on mini-batches of X, minimizing the mean
    negative log-likelihood.

    Args:
        X               : samples to fit; the first dimension indexes samples.
        d               : number of features, passed to zuko as ``features``.
        flow_type       : "NSF", "MAF" or "RealNVP".
        n_epochs        : number of passes over X.
        batch_size      : mini-batch size. It is capped at the number of
                          samples, so small sample sets still produce a
                          training batch.
        lr              : Adam learning rate.
        hidden_features : passed to zuko as ``hidden_features``.
        n_transforms    : passed to zuko as ``transforms``.
        bins            : passed to zuko as ``bins`` (NSF only).
        shuffle         : reshuffle X every epoch.
        device          : device that X and the flow are moved to.
        show_progress   : show a tqdm bar over epochs.

    Returns:
        (nf, None): the trained zuko flow, plus a placeholder kept so that
        existing ``nf, _ = ...`` call sites keep working.

    Raises:
        ValueError: if X is empty or flow_type is unknown.
    """

    X = X.to(device)

    n_samples = X.shape[0]
    if n_samples == 0:
        raise ValueError("Cannot fit a flow to an empty sample set")

    dataset = TensorDataset(X)
    # drop_last=True avoids a noisy trailing partial batch. Without the
    # min(...) cap, a batch_size larger than the dataset (e.g. N < 1024)
    # would yield ZERO batches, and the flow would silently be returned
    # untrained. The cap guarantees at least one optimizer step per epoch.
    loader = DataLoader(
        dataset,
        batch_size=min(batch_size, n_samples),
        shuffle=shuffle,
        drop_last=True,
    )

    if flow_type == "NSF":
        nf = zuko.flows.NSF(
            features=d,
            hidden_features=hidden_features,
            transforms=n_transforms,
            bins=bins,
        ).to(device)

    elif flow_type == "MAF":
        nf = zuko.flows.MAF(
            features=d,
            hidden_features=hidden_features,
            transforms=n_transforms,
        ).to(device)

    elif flow_type == "RealNVP":
        nf = zuko.flows.RealNVP(
            features=d,
            hidden_features=hidden_features,
            transforms=n_transforms,
        ).to(device)

    else:
        raise ValueError(f"Unknown flow type: {flow_type}")

    optimizer = torch.optim.Adam(nf.parameters(), lr=lr)

    epoch_iter = range(n_epochs)
    if show_progress:
        epoch_iter = tqdm(epoch_iter, desc=f"Flow training ({flow_type})")

    for _ in epoch_iter:
        for (batch,) in loader:
            # Maximum likelihood: minimize the mean negative log-density.
            loss = -nf().log_prob(batch).mean()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    return nf, None