import argparse
from pathlib import Path

import torch

from margflow.datasets.dataset_abstracts import DiscreteSampleDataset, DensityDataset
from margflow.datasets.datasets import create_dataset
from margflow.marginal_flow import MarginalFlow  # , ParzenFlowWithSpikedCov
from margflow.other_models.normalizing_flow import NormalizingFlow
from margflow.trainer import train_marginal_flow
from margflow.utils.plot_utils import (
    plot_samples,
    plot_samples_2D,
    plot_metrics_runtime,
    plot_likelihood_2D,
)
from margflow.utils.training_utils import (
    set_random_seed,
    create_directories,
    model_signature,
    TemperatureSchedule,
)

def _str2bool(value):
    """Convert a command-line string such as ``True``/``false``/``1``/``no`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``--flag False`` would silently yield True. This converter
    accepts the usual spellings and rejects anything else with a clear error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"true", "t", "yes", "y", "1"}:
        return True
    if lowered in {"false", "f", "no", "n", "0"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(
    description="Train a marginal flow and a normalizing-flow baseline on a selected dataset."
)

# TRAINING PARAMETERS
parser.add_argument("--seed", type=int, default=1234, help="random seed")
parser.add_argument("--device", type=str, default="cuda", help="device for training the model")
parser.add_argument("--n_epochs", type=int, default=2000, help="number of epochs")
# parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
parser.add_argument(
    "--training_mode",
    type=str,
    default="log_likelihood",
    choices=[
        "kl_divergence",  # variational inference: known (unnormalized) target and no training data
        "log_likelihood",  # density estimation: unknown target but training data
        "symmetric_kl",  # mixed setting: known (unnormalized) target and training data
        "score_matching",  # TODO: implement score matching both for variational inference and density estimation
    ],
)

# simulated annealing parameters
parser.add_argument("--T0", type=float, default=4.0, help="initial temperature")
parser.add_argument("--Tn", type=float, default=1.0, help="final temperature")
parser.add_argument(
    "--iter_per_cool_step",
    type=int,
    default=50,
    help="iterations per cooling step in simulated annealing",
)
parser.add_argument(
    "--temp_min_it_ratio", type=float, default=0.8, help="minimum temperature ratio"
)

# DATASETS PARAMETERS
parser.add_argument("--x_dim", type=int, default=2, help="number of dimensions")
parser.add_argument(
    "--dataset",
    type=str,
    default="miniboone",
    choices=[
        "uniform_square",
        "mog",
        "mog_manifold",
        "power",
        "gas",
        "hepmass",
        "miniboone",
        "two_moons",
        "sbi_two_moons",
    ],
)
parser.add_argument(
    "--use_dataset_dim",
    type=_str2bool,  # type=bool would turn "--use_dataset_dim False" into True
    default=True,
    help="use dataset dimension (if applicable) instead of provided",
)
parser.add_argument(
    "--logp_estimator",
    type=str,
    default="parzen",
    choices=["parzen", "kde"],
    help="estimate log prob by fitting kde on samples or by using a parzen window on samples. Ignored if log prob is available analytically",
)

# sbi datasets
parser.add_argument(
    "--n_simulations",
    type=int,
    default=1000,
    help="number of simulations in simulation based inference",
)
# mixture of Gaussians (MoG)
parser.add_argument(
    "--bounds",
    type=float,
    default=2.0,
    help="data/target distribution lives in [-bound,bound]^n_dim",
)
parser.add_argument(
    "--n_mog", type=int, default=12, help="number of mixture components of Gaussian mixture model"
)
parser.add_argument(
    "--mog_sigma", type=float, default=0.2, help="standard deviation of Gaussian mixture"
)
parser.add_argument(
    "--manifold", action="store_true", help="data lives in a lower dimensional space"
)
parser.add_argument(
    "--manifold_type", type=str, default="spiral", choices=["line", "sin", "circle", "spiral"]
)
parser.add_argument(
    "--epsilon", type=float, default=0.00, help="std of the isotropic noise in the data"
)
parser.add_argument(
    "--n_samples_dataset", type=int, default=10_000, help="number of data points in the dataset"
)
parser.add_argument(
    "--n_samples_dataset_test",
    type=int,
    default=10_000,
    help="number of data points used for test",
)
parser.add_argument("--batch_size", type=int, default=10_000, help="batch size dimension")
parser.add_argument(
    "--accept_less_datapoints",
    type=_str2bool,  # see _str2bool: type=bool cannot parse "False"
    default=True,
    help="if n_samples_dataset is larger than the available number of datapoints just take all of them",
)
parser.add_argument(
    "--shuffle_dataset",
    type=_str2bool,  # see _str2bool: type=bool cannot parse "False"
    default=True,
    help="if the data should be shuffled before training",
)
parser.add_argument(
    "--normalize_dataset_type", type=int, default=1, help="0 -> [0,1], 1 -> N(0,1)"
)
parser.add_argument(
    "--keep_datapoints",
    action="store_true",
    help="if the samples should be kept constant for all iterations",
)
# von mises fisher parameters
parser.add_argument("--mu", type=float, default=None, help="mean of von mises distribution")
parser.add_argument(
    "--kappa", type=float, default=5.0, help="concentration parameter of von mises distribution"
)
# von mises fisher mixture parameters
parser.add_argument(
    "--n_mix",
    type=int,
    default=50,
    help="number of mixture components for mixture of von mises fisher distribution",
)
parser.add_argument(
    "--kappa_mix",
    type=float,
    default=30.0,
    help="concentration parameter of mixture of von mises distribution",
)
parser.add_argument(
    "--alpha_mix",
    type=float,
    default=0.3,
    help="alpha parameter of mixture of von mises distribution",
)
parser.add_argument(
    "--n_turns_spiral",
    type=int,
    default=4,
    help="number of spiral turns for sphere spiral distribution",
)
# uniform checkerboard parameters
parser.add_argument(
    "--n_theta", type=int, default=6, help="number of rows in the checkerboard (n_theta>0)"
)
parser.add_argument(
    "--n_phi",
    type=int,
    default=6,
    help="number of columns in the checkerboard (n_phi>0 and must be even)",
)
# lp uniform parameters
parser.add_argument("--beta", type=float, default=1.0, help="p of the lp norm")
parser.add_argument("--radius", type=float, default=1.0, help="radius of manifold")

# Parse the command line once at import time, then attach settings that are
# derived rather than exposed as CLI options.
args = parser.parse_args()
dtype = torch.float32  # working floating-point dtype, also kept on args for code that receives args
vars(args).update(
    script_path=Path(__file__).resolve().parent,  # directory containing this script
    dtype=dtype,
)


def _check_training_mode(dataset, training_mode):
    """Check that ``dataset`` can supply what ``training_mode`` needs.

    Args:
        dataset: dataset returned by ``create_dataset``.
        training_mode: value of ``--training_mode``.

    Returns:
        Tuple ``(sample_training, density_training)`` telling whether the mode
        trains on dataset samples and/or on the (unnormalized) target density.

    Raises:
        ValueError: if the dataset provides neither samples nor a density, or
            lacks the kind of data required by the training mode.
    """
    has_samples = isinstance(dataset, DiscreteSampleDataset)
    has_density = isinstance(dataset, DensityDataset)
    sample_training = training_mode in ("log_likelihood", "symmetric_kl")
    density_training = training_mode in ("kl_divergence", "symmetric_kl")
    # Explicit raises instead of asserts so the checks survive `python -O`.
    if not (has_samples or has_density):
        raise ValueError("dataset provides neither samples nor a density")
    if sample_training and not has_samples:
        raise ValueError("wants to perform training on samples but dataset cannot provide samples")
    if density_training and not has_density:
        raise ValueError("wants to perform training on density but dataset does not provide density")
    return sample_training, density_training


def _train_normalizing_flow(
    dataset,
    signature,
    sample_training,
    density_training,
    temperature,
    reverse_batch_size,
    metrics,
):
    """Train the normalizing-flow baseline that matches the training mode.

    Density-based modes train a "reverse" flow against ``dataset.log_prob``
    with the simulated-annealing ``temperature`` schedule. Pure sample training
    ("log_likelihood") trains a "forward" flow on the dataset with ``args.batch_size``.

    Returns:
        The trained ``NormalizingFlow`` with its ``.flow`` put in eval mode, or
        ``None`` when the training mode uses neither samples nor a density
        (e.g. "score_matching"). In that case no baseline is trained.
    """
    lr_normflow = 1e-3
    n_layers_normflow = 6
    overwrite_normflow = False

    if density_training:
        normalizing_flow = NormalizingFlow(
            x_dim=args.n_dim,
            direction="reverse",
            n_layers=n_layers_normflow,
            signature=signature,
            script_path=args.script_path,
        )
        normalizing_flow.train_reverse(
            log_target=dataset.log_prob,
            lr=lr_normflow,
            batch_size=reverse_batch_size,
            n_epochs=args.n_epochs,
            temperature=temperature,
            overwrite=overwrite_normflow,
        )
    elif sample_training:  # pure sample training (density_training is False here)
        normalizing_flow = NormalizingFlow(
            x_dim=args.n_dim,
            direction="forward",
            n_layers=n_layers_normflow,
            signature=signature,
            script_path=args.script_path,
        )
        normalizing_flow.train_forward(
            dataset=dataset,
            batch_size=args.batch_size,
            n_epochs=args.n_epochs,
            lr=lr_normflow,
            metrics=metrics,
            overwrite=overwrite_normflow,
        )
    else:
        return None
    normalizing_flow.flow.eval()
    return normalizing_flow


def _plot_2d_comparison(marginal_flow, dataset, other_models, n_repeats):
    """Plot samples and likelihood of ``marginal_flow`` next to ``other_models``.

    The pair of plots is produced ``n_repeats`` times. Only meaningful for
    2-D data.
    """
    for _ in range(n_repeats):
        # plot_mog(model=marginal_flow, flow=normalizing_flow, target=dataset, n_outer_samples=n_samples_from_mz,
        #          n_gridpoints=150, plot_surface=True, bound=args.bounds, idxs=str(i))
        plot_samples_2D(
            model=marginal_flow,
            target=dataset,
            n_outer_samples=1_000,
            n_samples=10_000,
            other_models=other_models,
            bound=dataset.domain_bound,
        )
        plot_likelihood_2D(
            model=marginal_flow,
            target=dataset,
            grid_size=100,
            n_outer_samples=1_000,
            other_models=other_models,
            bound=dataset.domain_bound,
        )


def main():
    """Run the experiment configured by the module-level ``args``.

    Builds the dataset, trains a MarginalFlow and (when the training mode allows)
    a normalizing-flow baseline, plots 2-D comparisons for two-dimensional data,
    and plots metrics/runtime for all trained models.

    Raises:
        ValueError: if the dataset dimensionality differs from ``args.n_dim`` or
            the dataset cannot supply what ``--training_mode`` needs.
    """
    set_random_seed(seed=args.seed)
    create_directories(args.script_path)
    dataset = create_dataset(args)
    if args.use_dataset_dim:
        args.n_dim = dataset.D
    elif not hasattr(args, "n_dim"):
        # The CLI only defines --x_dim; without this fallback args.n_dim would be unset here.
        args.n_dim = args.x_dim
    if dataset.D != args.n_dim:
        raise ValueError(f"dimensionality mismatch: dataset has D={dataset.D}, but n_dim={args.n_dim}")

    samples_to_plot = dataset.sample(args.n_samples_dataset, data_type="train")
    if not isinstance(samples_to_plot, tuple):  # tuple-valued samples are skipped here
        plot_samples(samples_to_plot.detach().cpu().numpy(), bounds=dataset.domain_bound, alpha=1)
    signature = model_signature(args=args, dataset=dataset.dataset_suffix)
    sample_training, density_training = _check_training_mode(dataset, args.training_mode)

    # Simulated-annealing schedule; only consumed by the reverse (density-trained) normalizing flow.
    temperature = TemperatureSchedule(
        temp_max=args.T0,
        temp_min=args.Tn,
        num_iter=args.n_epochs,
        temp_min_it_ratio=args.temp_min_it_ratio,
    )

    metrics = ["mse_logp", "kl_for", "kl_rev", "sinkhorn", "energy", "gaussian", "laplacian"]

    # =#=#=# train flow matching #=#=#=#
    # n_hidden_dim = 512
    # n_layers = 5
    # lr_flowmatch = 1e-3
    # overwrite_flowmatch = False
    # flowmatch = FlowMatching(
    #     x_dim=args.n_dim,
    #     hid_dim=n_hidden_dim,
    #     n_layers=n_layers,
    #     script_path=args.script_path,
    #     signature=signature,
    #     device=args.device,
    # )
    # flowmatch.train_samples(
    #     n_epochs=args.n_epochs,
    #     dataset=dataset,
    #     args=args,
    #     metrics=metrics,
    #     overwrite=overwrite_flowmatch,
    #     lr=lr_flowmatch,
    # )
    # flowmatch.solve_ode(plot=True)
    # flowmatch.plot_log_likelihood(bound=dataset.domain_bound)

    # =#=#=# train marginal flow #=#=#=#
    n_mapped_z = 512  # 32768 # TODO discuss names -> n_mixture_components / n_mixture_samples
    n_samples_from_mz = 2048  # 2000
    n_mixture_base = 20
    dropout = 0
    base_dim = 1 if args.manifold else args.n_dim
    log_sigma_init = 0
    marginal_flow = MarginalFlow(
        x_dim=args.n_dim,
        z_dim=base_dim,
        log_sigma_init=log_sigma_init,
        dropout=dropout,
        n_base_means=n_mixture_base,
        device=args.device,
        dtype=dtype,
        signature=signature,
        script_path=args.script_path,
    )

    lr_network = 5e-4
    lr_sigma = 5e-2
    plotting_repeats = 2  # how many times to sample from marginal flow after training
    overwrite_marg_flow = False
    train_marginal_flow(
        model=marginal_flow,
        n_mixtures=n_mapped_z,
        batch_size=n_samples_from_mz,
        n_epochs=args.n_epochs,
        training_mode=args.training_mode,
        dataset=dataset,
        lr_network=lr_network,
        lr_sigma=lr_sigma,
        metrics=metrics,
        overwrite=overwrite_marg_flow,
    )
    marginal_flow.eval()

    # =#=#=#  train a normalizing flow #=#=#=#
    normalizing_flow = _train_normalizing_flow(
        dataset=dataset,
        signature=signature,
        sample_training=sample_training,
        density_training=density_training,
        temperature=temperature,
        reverse_batch_size=n_samples_from_mz,
        metrics=metrics,
    )

    # Baselines compared against the marginal flow. Only models that were actually
    # trained are included; when the flow-matching / free-form-flow sections are
    # re-enabled, add e.g. "flow-matching": flowmatch and "free-form-flow": fff_model here.
    other_models = {}
    if normalizing_flow is not None:
        other_models["normalizing-flow"] = normalizing_flow

    if args.n_dim == 2:
        _plot_2d_comparison(marginal_flow, dataset, other_models, plotting_repeats)

    plot_metrics_runtime({"marginal-flow": marginal_flow, **other_models}, metrics=metrics)

    # =#=#=#  train a free form flow (FreeFormFlow wrapper) #=#=#=#
    # NOTE: when re-enabling, train it before the plotting above and add it to other_models.

    # n_layers = 5
    # hidden_dims = 128
    # fff_model = FreeFormFlow(
    #     x_dim=args.n_dim,
    #     z_dim=base_dim,
    #     n_layers=n_layers,
    #     signature=signature,
    #     script_path=args.script_path,
    #     hid_dim=hidden_dims,
    # )
    #
    # overwrite_fff = False
    # beta = 100
    # fff_lr = 1e-3
    # args.keep_datapoints = True
    # fff_model.train_forward(
    #     dataset=dataset,
    #     beta=beta,
    #     n_epochs=args.n_epochs * 2,
    #     lr=fff_lr,
    #     batch_size=args.batch_size,
    #     fixed_datapoints=args.keep_datapoints,
    #     overwrite=overwrite_fff,
    #     metrics=metrics,
    # )

    # =#=#=#  train a free form flow (using their original code) #=#=#=#
    # model_type = "fff" # alternatively "fif" if you wish the latent space to be lower dimensional than the ambient space
    # fff_directory = (args.script_path / "../../FFF/").resolve() # to be changed depending where the FFF repo has been cloned
    # model_version = "version_0"
    # trained_model_name = fff_directory / "lightning_logs" / args.dataset / model_version
    # overwrite_fff = False
    # if not os.path.isdir(trained_model_name) or overwrite_fff:
    #     define_fff_config(args, model=model_type, fff_directory=fff_directory) # write config file to run free form flow
    #     os.chdir(fff_directory)
    #     config_file_path = f"configs/fff/margflow-{args.dataset}.yaml"
    #     command = "python -m lightning_trainable.launcher.fit " + config_file_path + " --name {data_set[name]}"
    #     os.system(command) # run free form flow code as shell command (I know it's ugly. Alternatives?)
    #     os.chdir(args.script_path)
    # fff_model = load_fff_model_and_nll(args, fff_directory=str(fff_directory), version=model_version)
    # _, _, test_data = dataset.load_dataset(overwrite=False)
    # test_data = torch.from_numpy(test_data).float().to(args.device)
    # fff_samples, fff_nll = fff_sample_and_logprob(fff_model, samples=test_data)
    #
    # for i in range(plotting_repeats):
    #     # plot_mog(model=marginal_flow, flow=fff_model, target=dataset, n_outer_samples=n_samples_from_mz,
    #     #          n_gridpoints=150, plot_surface=True, bound=args.bounds, idxs=str(i))
    #     other_models = {"normalizing-flow": normalizing_flow, "flow-matching": flowmatch,  "free-form-flows": fff_model}
    #     plot_samples_2D(model=marginal_flow, target=dataset, n_outer_samples=n_samples_from_mz, n_samples=10_000,
    #                     other_models=other_models, bound=dataset.domain_bound)
    #     plot_likelihood_2D(model=marginal_flow, target=dataset, grid_size=200, n_outer_samples=1_000,
    #                        other_models=other_models, bound=dataset.domain_bound)


if __name__ == "__main__":
    main()
