"""Script for running training experiments."""
from functools import partial
import os

from click import command, option
import torch
import pyro
import wandb

from calnf.datasets import (
    FewShotMNIST,
    FewShotCIFAR,
)
from calnf.guides import (
    VanillaFlow,
    Glow,
    KLRegularizedFlow,
    W2RegularizedFlow,
    CalibratedFlow,
    BalancedCalibratedFlow,
    BalancedGlow,
    BaggedFlow,
    BaggedGlow,
)
from calnf.metrics import (
    ELBOMetric,
    ClassifierMetrics,
)
from calnf.trainer import train


def _image_shape(problem):
    """Return the ``(channels, height, width)`` input shape for Glow-style guides.

    Every CIFAR variant (``"cifar"`` and ``"cifar_all"``) gets 3 channels; all
    other problems get 1 channel. The 32x32 spatial size is the value this
    script uses for both dataset families.
    """
    channels = 3 if problem.startswith("cifar") else 1
    return (channels, 32, 32)


def _glow_kwargs(problem):
    """Return a fresh dict of the Glow architecture kwargs shared by Glow guides.

    A new dict is built on every call because ``main`` mutates the selected
    kwargs in place (it adds ``grad_clip``), so entries must never share one.
    """
    return {
        "hidden_channels": 256,
        "L": 2,
        "K": 8,
        "input_shape": _image_shape(problem),
    }


@command()
@option("--epochs", default=1000, help="Number of epochs.")
@option("--device", default="cpu", help="Device to use for training.")
@option("--seed", default=0, help="Random seed.")
@option("--method", default="vanilla_nf", help="Method to use.")
@option("--lr", default=1e-4, help="Learning rate.")
@option("--grad-clip", default=1.0, help="Gradient clip for guide.")
@option("--subsamples", default=4, help="Number of subsamples to use.")
@option("--problem", default="mnist", help="Dataset to use.")
@option("--reg_penalty", default=1.0, help="Regularization penalty.")
def main(epochs, device, seed, method, lr, grad_clip, subsamples, problem, reg_penalty):
    """Train a guide on a few-shot dataset and log the run to wandb.
    \f
    ``problem`` selects a dataset constructor from ``dataset_options``.
    ``method`` selects a guide class, its wandb run name and group, and its
    constructor kwargs from ``method_options``. Both keys are validated before
    any dataset is built. The selected kwargs (plus ``grad_clip``) are passed
    to the guide constructor and merged into the wandb run config.

    Raises:
        ValueError: If ``problem`` or ``method`` is not a recognized key.
    """
    # Dataset constructors, keyed by --problem. Construction is deferred until
    # the key has been validated.
    dataset_options = {
        "mnist": FewShotMNIST,
        "mnist_all": partial(FewShotMNIST, include_all_nominal=True),
        "cifar": FewShotCIFAR,
        "cifar_all": partial(FewShotCIFAR, include_all_nominal=True),
    }
    method_options = {
        # key: (guide class, run name, wandb group, constructor kwargs)
        "vanilla_nf": (VanillaFlow, "vanilla_nf", "vanilla_nf", {}),
        "kl_nf": (KLRegularizedFlow, "kl", "kl", {"reg_penalty": reg_penalty}),
        "w2_nf": (W2RegularizedFlow, "w2", "w2", {"reg_penalty": reg_penalty}),
        "bagged_nf": (
            BaggedFlow,
            "bagged_nf",
            "bagged_nf",
            {"num_subsamples": subsamples},
        ),
        "bagged_glow": (
            BaggedGlow,
            "bagged_glow",
            "bagged_glow",
            {"num_subsamples": subsamples, **_glow_kwargs(problem)},
        ),
        "calnf": (CalibratedFlow, "calnf", "calnf", {"num_subsamples": subsamples}),
        "balcalnf": (
            BalancedCalibratedFlow,
            "calnf_balanced",
            "calnf_balanced",
            {"num_subsamples": subsamples, "reg_penalty": reg_penalty},
        ),
        "glow": (Glow, "glow", "glow", _glow_kwargs(problem)),
        "balcalnf_glow": (
            BalancedGlow,
            "balanced_glow",
            "balanced_glow",
            {
                # Honors --subsamples like every other subsampling method.
                "num_subsamples": subsamples,
                "reg_penalty": reg_penalty,
                **_glow_kwargs(problem),
            },
        ),
    }

    # Validate both keys before doing any (potentially slow) dataset work.
    if problem not in dataset_options:
        available_datasets = list(dataset_options.keys())
        raise ValueError(
            f"Problem {problem} not recognized (must be one of {available_datasets})."
        )
    if method not in method_options:
        available_methods = list(method_options.keys())
        raise ValueError(
            f"Method {method} not recognized (must be one of {available_methods})."
        )

    # Seed before building the dataset, in case dataset construction draws
    # random numbers.
    torch.manual_seed(seed)
    pyro.set_rng_seed(seed)

    dataset = dataset_options[problem]()

    guide_cls, method_name, method_group, method_kwargs = method_options[method]
    # Mutating is safe: each entry's kwargs dict is freshly built above.
    method_kwargs["grad_clip"] = grad_clip

    guide = guide_cls(device=device, dataset=dataset, **method_kwargs)

    # NOTE(review): the second tuple element is interpreted by
    # calnf.trainer.train (presumably an evaluation frequency or flag) — confirm.
    metrics = [
        (ELBOMetric(dataset=dataset, n_elbo_particles=10), 1),
        (ClassifierMetrics(dataset=dataset, n_particles=10), 0),
    ]

    # One wandb project per problem. The run config records the training
    # hyperparameters plus every guide constructor kwarg.
    wandb.init(
        project=f"calnf-{problem}",
        group=method_group,
        name=method_name,
        config={
            "epochs": epochs,
            "seed": seed,
            "lr": lr,
        }
        | method_kwargs,
    )

    train(
        name=os.path.join(problem, method_name),
        device=device,
        dataset=dataset,
        guide=guide,
        epochs=epochs,
        metrics=metrics,
        visualize_every_n=5,
        lr=lr,
    )


if __name__ == "__main__":
    # Switch matplotlib to the non-interactive "Agg" backend so any figures
    # produced during training can be rendered without a display.
    import matplotlib as mpl

    mpl.use("Agg")
    main()
