from targets.base import BaseTarget, GrayCodedTarget
from targets.gmm import DiscretisedGMM
from targets.manywell import DiscretisedManyWell
from targets.ising2d import Ising2D
from targets.potts2d import Potts2D


from typing import TYPE_CHECKING

if TYPE_CHECKING:
    import torch
    from omegaconf import DictConfig


def _mcmc_configs_for_coupling(J: float) -> "dict[str, int]":
    """Return the MCMC settings for an Ising/Potts target with coupling ``J``.

    Negative couplings get the Metropolis-Hastings settings (longer burn-in and
    sparser sample collection); non-negative couplings get the Swendsen-Wang
    settings. A new dict is built on every call, so two targets never share the
    same configuration object.

    Args:
        J: Coupling constant of the lattice model.

    Returns:
        Dict with keys ``"B"``, ``"burn_in"`` and ``"collect_every"``.
    """
    if J < 0:
        # Metropolis-Hastings sampling
        return {"B": 128, "burn_in": 2**20, "collect_every": 2**16}
    # Swendsen-Wang sampling
    return {"B": 128, "burn_in": 2**16, "collect_every": 2**10}


def create_target(cfg: "DictConfig", device: "torch.device") -> "BaseTarget":
    """Create a target distribution based on the configuration.

    The target is selected by ``cfg.target.name``. Supported names are
    ``"gmm"``, ``"manywell"``, ``"ising"`` and ``"potts"``. Model-specific
    hyper-parameters are read from ``cfg.target``. ``cfg.seed`` is forwarded
    to every target as ``seed``.

    For the lattice models (``"ising"`` and ``"potts"``), the MCMC settings
    depend on the sign of the coupling ``J``.

    Args:
        cfg: Hydra configuration.
        device: Device to place tensors on.

    Returns:
        A target distribution instance.

    Raises:
        ValueError: If ``cfg.target.name`` is not one of the supported names.
    """
    if cfg.target.name == "gmm":
        target = DiscretisedGMM(
            device=device,
            spatial_dim=cfg.target.spatial_dim,
            n_bits=cfg.target.n_bits,
            translate=cfg.target.translate,
            scale=cfg.target.scale,
            n_centres=cfg.target.n_centres,
            variance=cfg.target.variance,
            seed=cfg.seed,
        )
    elif cfg.target.name == "manywell":
        target = DiscretisedManyWell(
            device=device,
            spatial_dim=cfg.target.spatial_dim,
            rotated=cfg.target.rotated,
            beta=cfg.target.beta,
            n_bits=cfg.target.n_bits,
            translate=cfg.target.translate,
            scale=cfg.target.scale,
            seed=cfg.seed,
        )
    elif cfg.target.name == "ising":
        target = Ising2D(
            device=device,
            L=cfg.target.ising_L,
            beta=cfg.target.ising_beta,
            J=cfg.target.ising_J,
            h=cfg.target.ising_h,
            mcmc_configs=_mcmc_configs_for_coupling(cfg.target.ising_J),
            seed=cfg.seed,
        )
    elif cfg.target.name == "potts":
        target = Potts2D(
            device=device,
            L=cfg.target.potts_L,
            q=cfg.target.potts_q,
            beta=cfg.target.potts_beta,
            J=cfg.target.potts_J,
            mcmc_configs=_mcmc_configs_for_coupling(cfg.target.potts_J),
            seed=cfg.seed,
        )
    else:
        raise ValueError(f"Unknown target: {cfg.target.name}")

    return target
