import argparse
from pathlib import Path

import numpy as np
import torch

from margflow.datasets.datasets import create_dataset
from margflow.marginal_flow import MarginalFlow
from margflow.trainer import train_marginal_flow
from margflow.utils.plot_utils import (
    plot_samples_2D,
    plot_likelihood_2D,
)
from margflow.utils.training_utils import set_random_seed, create_directories, model_signature

def _str2bool(value):
    """Parse a command-line string into a bool for argparse.

    ``type=bool`` is an argparse pitfall: argparse calls ``bool(raw_string)``,
    and every non-empty string (including ``"False"``) is truthy, so the option
    could never be switched off. This converter accepts the usual spellings.

    Args:
        value: Raw command-line string (a bool is passed through unchanged).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    # "" is accepted as False for backward compatibility: with type=bool an
    # empty string was the only way to obtain False.
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(
    description="Train a marginal flow and a learnable-means baseline on a chosen dataset."
)

# TRAINING PARAMETERS
parser.add_argument("--seed", type=int, default=1234, help="random seed")
parser.add_argument("--device", type=str, default="cuda", help="device for training the model")
parser.add_argument("--n_epochs", type=int, default=2000, help="number of epochs")
# parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
parser.add_argument(
    "--training_mode",
    type=str,
    default="log_likelihood",
    choices=[
        "kl_divergence",  # variational inference: known (unnormalized) target and no training data
        "log_likelihood",  # density estimation: unknown target but training data
        "symmetric_kl",  # mixed setting: known (unnormalized) target and training data
        "score_matching",  # TODO: implement score matching both for variational inference and density estimation
    ],
    help="training objective",
)

# simulated annealing parameters
parser.add_argument("--T0", type=float, default=4.0, help="initial temperature")
parser.add_argument("--Tn", type=float, default=1.0, help="final temperature")
parser.add_argument(
    "--iter_per_cool_step",
    type=int,
    default=50,
    help="iterations per cooling step in simulated annealing",
)
parser.add_argument(
    "--temp_min_it_ratio", type=float, default=0.8, help="minimum temperature ratio"
)

# DATASETS PARAMETERS
parser.add_argument(
    "--x_dim",
    type=int,
    default=2,
    help="number of dimensions (replaced by the dataset dimension when --use_dataset_dim is true)",
)
parser.add_argument(
    "--dataset",
    type=str,
    default="mog",
    choices=[
        "uniform_square",
        "mog",
        "mog_manifold",
        "power",
        "gas",
        "hepmass",
        "miniboone",
        "two_moons",
        "sbi_two_moons",
        "two_circles",
        "swiss_roll",
        "checkerboard",
        "pinwheel",
    ],
    help="dataset to train on",
)
parser.add_argument(
    "--use_dataset_dim",
    type=_str2bool,
    default=True,
    help="use dataset dimension (if applicable) instead of provided",
)
parser.add_argument(
    "--logp_estimator",
    type=str,
    default="parzen",
    choices=["parzen", "kde"],
    help="estimate log prob by fitting kde on samples or by using a parzen window on samples. Ignored if log prob is available analytically",
)

parser.add_argument(
    "--n_mog", type=int, default=12, help="number of mixture components of Gaussian mixture model"
)
parser.add_argument(
    "--mog_sigma", type=float, default=0.2, help="standard deviation of Gaussian mixture"
)
parser.add_argument(
    "--manifold", action="store_true", help="data lives in a lower dimensional space"
)
parser.add_argument(
    "--manifold_type",
    type=str,
    default="spiral",
    choices=["line", "sin", "circle", "spiral"],
    help="type of manifold the data lives on",
)
parser.add_argument(
    "--epsilon", type=float, default=0.00, help="std of the isotropic noise in the data"
)
parser.add_argument(
    "--n_samples_dataset", type=int, default=10_000, help="number of data points in the dataset"
)
parser.add_argument("--bounds", type=float, default=1.0, help="domain bound")
parser.add_argument(
    "--n_samples_dataset_test",
    type=int,
    default=10_000,
    help="number of data points used for test",
)
parser.add_argument("--batch_size", type=int, default=10_000, help="batch size dimension")
parser.add_argument(
    "--accept_less_datapoints",
    type=_str2bool,
    default=True,
    help="if n_samples_dataset is larger than the available number of datapoints just take all of them",
)
parser.add_argument(
    "--shuffle_dataset",
    type=_str2bool,
    default=True,
    help="if the data should be shuffled before training",
)
parser.add_argument(
    "--normalize_dataset_type", type=int, default=1, help="0 -> [0,1], 1 -> N(0,1)"
)
parser.add_argument(
    "--keep_datapoints",
    action="store_true",
    help="if the samples should be kept constant for all iterations",
)
# von mises fisher parameters
parser.add_argument("--mu", type=float, default=None, help="mean of von mises distribution")
parser.add_argument(
    "--kappa", type=float, default=5.0, help="concentration parameter of von mises distribution"
)

parser.add_argument(
    "--n_turns_spiral",
    type=int,
    default=4,
    help="number of spiral turns for sphere spiral distribution",
)

parser.add_argument(
    "--base_distribution",
    type=str,
    default="mog",
    choices=["mog", "uniform"],
    help="base distribution of the marginal flow",
)


# Parsed at import time: main() reads this module-level namespace.
args = parser.parse_args()
args.script_path = Path(__file__).resolve().parent
dtype = torch.float32
args.dtype = dtype


def main():
    """Train a marginal flow and a learnable-means baseline, then plot both.

    Reads the module-level ``args`` namespace. Steps:
      1. seed RNGs, create output directories and build the dataset;
      2. train a ``MarginalFlow`` with ``use_trainable_means=False``;
      3. train a ``MarginalFlow`` with directly learnable means ("gmm");
      4. for 2-D data, plot samples and likelihoods of both models.

    Raises:
        ValueError: If the dataset dimensionality differs from ``args.n_dim``.
    """
    set_random_seed(seed=args.seed)
    create_directories(args.script_path)
    dataset = create_dataset(args)
    if args.use_dataset_dim:
        args.n_dim = dataset.D
    elif not hasattr(args, "n_dim"):
        # Without this fallback ``args.n_dim`` was undefined whenever
        # --use_dataset_dim was disabled (unless create_dataset sets it).
        args.n_dim = args.x_dim
    # Explicit check instead of ``assert`` so it also runs under ``python -O``.
    if dataset.D != args.n_dim:
        raise ValueError(
            f"dimensionality mismatch: dataset has {dataset.D} dims, expected {args.n_dim}"
        )
    signature = model_signature(args=args, dataset=dataset.dataset_suffix)

    metrics = ["mse_logp", "kl_for", "kl_rev"]  # , "sinkhorn", "energy", "gaussian", "laplacian"]

    # =#=#=# train marginal flow with neural network transformed means #=#=#=#
    use_trainable_means = False
    n_mixtures = 10  # args.n_samples_dataset  # 2048
    n_mixture_base = 10
    dropout = 0.0
    base_dim = 1 if args.manifold else args.n_dim
    log_sigma_init = np.log(0.1)
    marginal_flow = MarginalFlow(
        x_dim=args.n_dim,
        z_dim=base_dim,
        log_sigma_init=log_sigma_init,
        use_trainable_means=use_trainable_means,
        base_distribution=args.base_distribution,
        dropout=dropout,
        n_base_means=n_mixture_base,
        device=args.device,
        dtype=dtype,
        signature=signature,
        script_path=args.script_path,
    )

    overwrite_margflow = False
    lr_network = 5e-5
    # Effectively zero: sigma stays (numerically) at log_sigma_init.
    # NOTE(review): presumably a deliberate way to freeze sigma — confirm.
    lr_sigma = 5e-20
    train_marginal_flow(
        model=marginal_flow,
        n_mixtures=n_mixtures,
        n_epochs=args.n_epochs,
        batch_size=args.batch_size,
        training_mode=args.training_mode,
        dataset=dataset,
        lr_network=lr_network,
        lr_sigma=lr_sigma,
        metrics=metrics,
        fixed_datapoints=args.keep_datapoints,
        save_best_val=False,
        overwrite=overwrite_margflow,
    )
    marginal_flow.eval()

    # =#=#=# train marginal flow with directly learnable means (no neural network) #=#=#=#
    use_trainable_means = True
    # NOTE(review): shares ``signature`` with marginal_flow above; if the
    # signature is used for checkpoint names, confirm the two models don't collide.
    gmm = MarginalFlow(
        x_dim=args.n_dim,
        z_dim=base_dim,
        log_sigma_init=log_sigma_init,
        use_trainable_means=use_trainable_means,
        # Pass the same CLI-controlled settings as the first model for consistency.
        base_distribution=args.base_distribution,
        dropout=dropout,
        n_base_means=n_mixtures,
        device=args.device,
        dtype=dtype,
        signature=signature,
        script_path=args.script_path,
    )
    train_marginal_flow(
        model=gmm,
        n_mixtures=n_mixtures,
        n_epochs=args.n_epochs,
        batch_size=args.batch_size,
        training_mode=args.training_mode,
        dataset=dataset,
        lr_network=lr_network,
        lr_sigma=lr_sigma,
        metrics=metrics,
        fixed_datapoints=args.keep_datapoints,
        save_best_val=False,
    )
    gmm.eval()

    plotting_repeats = 2
    if args.n_dim == 2:
        plot_bound = args.bounds + 0.5  # margin around the data domain
        n_outer_samples = n_mixtures * 1000
        for _ in range(plotting_repeats):
            other_models = {
                "gmm": gmm,
            }
            plot_samples_2D(
                model=marginal_flow,
                target=dataset,
                n_outer_samples=n_outer_samples,
                n_samples=10_000,
                other_models=other_models,
                bound=plot_bound,
            )
            plot_likelihood_2D(
                model=marginal_flow,
                target=dataset,
                grid_size=100,
                n_outer_samples=n_outer_samples,
                other_models=other_models,
                bound=plot_bound,
            )
    # plot_metrics_runtime(
    #     {
    #         "marginal-flow": marginal_flow,
    #         "gmm": gmm,
    #         # "normalizing-flow": normalizing_flow,
    #         # "free-form-flow": fff_model,
    #         # "flow-matching": flowmatch,
    #     },
    #     metrics=metrics,
    # )


if __name__ == "__main__":
    main()
