from dataclasses import dataclass
import torch
from targets import RingsDistribution, NFDistributionWrapper
from algorithms.proposals import GaussianProposal
import normflows as nf
import mcmc


@dataclass
class ExperimentConfig:
    """All settings needed to run and plot one sampling experiment.

    The fields fall into groups: the problem definition (target and
    initial proposal), global sampling settings, the random-walk kernel
    shared by RW / AIS / E2MC-K, AIS-specific and E2MC-specific
    hyper-parameters, and plot limits.

    Note: no field has a default and the declaration order below is the
    positional-argument order of the generated ``__init__``, so do not
    reorder fields without updating positional callers.
    """

    # -- what is being sampled -------------------------------------------
    name: str  # short experiment label, e.g. "rings" / "moons"
    target: object  # target distribution object
    init_proposal: object  # initial proposal distribution object

    # -- global sampling settings -----------------------------------------
    N: int  # presumably the number of samples/particles -- confirm in sampler code
    device: str  # device string, e.g. "cpu"

    # -- random-walk kernel shared by RW, AIS and E2MC-K -------------------
    rw_sigma: float  # presumably the random-walk step scale -- confirm
    rw_steps: int  # number of random-walk steps

    # -- AIS ----------------------------------------------------------------
    ais_intermediate: int  # presumably the number of intermediate distributions -- confirm

    # -- E2MC ---------------------------------------------------------------
    K_steps: int  # number of steps of kernel K
    L_steps: int  # number of steps of kernel L
    L_sigma: float  # random-walk scale for kernel L (set smaller for local moves)
    T: int  # number of E2MC outer iterations
    eps: float  # NOTE(review): meaning not visible in this file -- confirm in E2MC code
    lamda: float  # spelled without "b" because `lambda` is a Python keyword; meaning: confirm
    flow_kwargs: dict  # flow settings (flow_type, n_epochs, hidden_features, ...)

    # -- plotting -------------------------------------------------------------
    xlim: tuple  # (min, max) range of the x axis
    ylim: tuple  # (min, max) range of the y axis




def rings_config(device="cpu"):
    """Return the ``ExperimentConfig`` for the two-ring experiment.

    The target is ``RingsDistribution(d=2, radii=(1.0, 4.0), sigma=0.1)``
    and sampling starts from a Gaussian proposal centred at the origin
    with ``std=0.2``. Plots cover the square [-5, 5] x [-5, 5].

    Args:
        device: Device string passed to the target, the initial proposal
            and stored on the config. Defaults to ``"cpu"``.

    Returns:
        An ``ExperimentConfig`` whose ``name`` is ``"rings"``.
    """
    target = RingsDistribution(
        d=2,
        radii=(1.0, 4.0),
        sigma=0.1,
        device=device,
    )

    # Initial Gaussian proposal centred at (0, 0).
    init_proposal = GaussianProposal(
        mean=torch.zeros(2),
        std=0.2,
        device=device,
    )

    # Normalizing-flow settings.
    flow_kwargs = dict(
        flow_type="NSF",
        n_epochs=8,
        hidden_features=(64, 64),
        n_transforms=3,
        bins=8,
        batch_size=256,
        lr=1e-3,
    )

    return ExperimentConfig(
        name="rings",
        target=target,
        init_proposal=init_proposal,
        N=10_000,
        device=device,
        # Random-walk kernel shared by RW, AIS and E2MC-K.
        rw_sigma=0.9,
        rw_steps=1000,
        ais_intermediate=40,
        # E2MC: steps of kernels K and L, a narrower local random walk for L,
        # and the number of outer iterations T.
        K_steps=10,
        L_steps=5,
        L_sigma=0.1,
        T=6,
        eps=0.8,
        lamda=0.8,
        flow_kwargs=flow_kwargs,
        xlim=(-5.0, 5.0),
        ylim=(-5.0, 5.0),
    )



def moons_config(device="cpu"):
    """Return the ``ExperimentConfig`` for the two-moons experiment.

    The target is normflows' ``TwoMoons`` distribution wrapped in
    ``NFDistributionWrapper``. Sampling starts from a Gaussian proposal
    centred at (1, 1) with ``std=0.2``. Plots cover [-3, 3] x [-3, 3].

    Args:
        device: Device string passed to the wrapper, the initial proposal
            and stored on the config. Defaults to ``"cpu"``.

    Returns:
        An ``ExperimentConfig`` whose ``name`` is ``"moons"``.
    """
    target = NFDistributionWrapper(nf.distributions.TwoMoons(), device)

    # Initial Gaussian proposal centred at (1, 1).
    init_proposal = GaussianProposal(
        mean=torch.ones(2),
        std=0.2,
        device=device,
    )

    # Normalizing-flow settings (smaller batch than the rings experiment).
    flow_kwargs = dict(
        flow_type="NSF",
        n_epochs=8,
        hidden_features=(64, 64),
        n_transforms=3,
        bins=8,
        batch_size=64,
        lr=1e-3,
    )

    return ExperimentConfig(
        name="moons",
        target=target,
        init_proposal=init_proposal,
        N=10_000,
        device=device,
        # Random-walk kernel shared by RW, AIS and E2MC-K.
        rw_sigma=1.0,
        rw_steps=1000,
        ais_intermediate=40,
        # E2MC hyper-parameters.
        K_steps=10,
        L_steps=5,
        L_sigma=0.1,
        T=6,
        eps=0.8,
        lamda=0.8,
        flow_kwargs=flow_kwargs,
        xlim=(-3, 3),
        ylim=(-3, 3),
    )
