import argparse
import json
from pathlib import Path

import torch

from margflow.datasets.datasets import create_dataset
from margflow.marginal_flow import MarginalFlow
from margflow.other_models.normalizing_flow import NormalizingFlow
from margflow.trainer import train_marginal_flow
from margflow.utils.plot_utils import (
    plot_samples_2D,
    plot_metrics_runtime,
    plot_likelihood_2D,
)
from margflow.utils.training_utils import set_random_seed, create_directories, model_signature

def _str2bool(value):
    """Convert a command-line token such as ``"True"`` or ``"false"`` into a real ``bool``.

    ``argparse`` applies ``type`` to the raw string, and ``bool("False")`` is ``True`` because any
    non-empty string is truthy. Using ``type=bool`` therefore makes it impossible to switch a flag
    off; this converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "1", "on"}:
        return True
    # The empty string is kept as False: it was the only falsy input under ``type=bool``.
    if token in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (true/false), got {value!r}")


parser = argparse.ArgumentParser(
    description="Train a marginal flow and a normalizing-flow baseline on a dataset and compare them."
)

# TRAINING PARAMETERS
parser.add_argument("--seed", type=int, default=1234, help="random seed")
parser.add_argument("--device", type=str, default="cuda", help="device for training the model")
parser.add_argument("--n_epochs", type=int, default=2000, help="number of epochs")
# parser.add_argument('--lr', type=float, default=1e-3, help='learning rate')
parser.add_argument(
    "--training_mode",
    type=str,
    default="kl_divergence",
    choices=[
        "kl_divergence",  # variational inference: known (unnormalized) target and no training data
        "log_likelihood",  # density estimation: unknown target but training data
        "symmetric_kl",  # mixed setting: known (unnormalized) target and training data
        "score_matching",  # TODO: implement score matching both for variational inference and density estimation
    ],
)

# simulated annealing parameters
parser.add_argument("--T0", type=float, default=4.0, help="initial temperature")
parser.add_argument("--Tn", type=float, default=1.0, help="final temperature")
parser.add_argument(
    "--iter_per_cool_step",
    type=int,
    default=50,
    help="iterations per cooling step in simulated annealing",
)
parser.add_argument(
    "--temp_min_it_ratio", type=float, default=0.8, help="minimum temperature ratio"
)
parser.add_argument(
    "--base_distribution", type=str, default="mog", choices=["mog", "uniform", "betas"]
)

parser.add_argument(
    "--n_base_means",
    type=int,
    default=5,
    help="number of (learnable) means in the marginal flow base distribution",
)

# DATASETS PARAMETERS
parser.add_argument("--x_dim", type=int, default=2, help="number of dimensions")
parser.add_argument(
    "--dataset",
    type=str,
    default="mog",
    choices=[
        "uniform_square",
        "mog",
        "mog_manifold",
        "power",
        "gas",
        "hepmass",
        "miniboone",
        "two_moons",
        "sbi_two_moons",
        "two_circles",
        "swiss_roll",
        "checkerboard",
        "pinwheel",
    ],
)
parser.add_argument(
    "--use_dataset_dim",
    type=_str2bool,  # not ``bool``: see _str2bool
    default=True,
    help="use dataset dimension (if applicable) instead of provided",
)
parser.add_argument(
    "--logp_estimator",
    type=str,
    default="parzen",
    choices=["parzen", "kde"],
    help="estimate log prob by fitting kde on samples or by using a parzen window on samples. Ignored if log prob is available analytically",
)

parser.add_argument(
    "--n_mog", type=int, default=12, help="number of mixture components of Gaussian mixture model"
)
parser.add_argument(
    "--mog_sigma", type=float, default=0.2, help="standard deviation of Gaussian mixture"
)
parser.add_argument(
    "--manifold", action="store_true", help="data lives in a lower dimensional space"
)
parser.add_argument(
    "--manifold_type", type=str, default="spiral", choices=["line", "sin", "circle", "spiral"]
)
parser.add_argument(
    "--epsilon", type=float, default=0.00, help="std of the isotropic noise in the data"
)
parser.add_argument(
    "--n_samples_dataset", type=int, default=10_000, help="number of data points in the dataset"
)
parser.add_argument("--bounds", type=float, default=2.0, help="domain bound")
parser.add_argument(
    "--n_samples_dataset_test",
    type=int,
    default=10_000,
    help="number of data points used for test",
)
parser.add_argument("--batch_size", type=int, default=10_000, help="batch size dimension")
parser.add_argument(
    "--accept_less_datapoints",
    type=_str2bool,  # not ``bool``: see _str2bool
    default=True,
    help="if n_samples_dataset is larger than the available number of datapoints just take all of them",
)
parser.add_argument(
    "--shuffle_dataset",
    type=_str2bool,  # not ``bool``: see _str2bool
    default=True,
    help="if the data should be shuffled before training",
)
parser.add_argument(
    "--normalize_dataset_type", type=int, default=1, help="0 -> [0,1], 1 -> N(0,1)"
)
parser.add_argument(
    "--keep_datapoints",
    action="store_true",
    help="if the samples should be kept constant for all iterations",
)
# von mises fisher parameters
parser.add_argument("--mu", type=float, default=None, help="mean of von mises distribution")
parser.add_argument(
    "--kappa", type=float, default=5.0, help="concentration parameter of von mises distribution"
)

parser.add_argument(
    "--n_turns_spiral",
    type=int,
    default=4,
    help="number of spiral turns for sphere spiral distribution",
)
# Parse the command line once, at import time; main() reads the result via the module-level `args`.
args = parser.parse_args()

# Attach run-wide settings that are not command-line options to the same namespace.
dtype = torch.float32  # floating-point dtype handed to the marginal flow model
args.dtype = dtype
args.script_path = Path(__file__).resolve().parent  # directory that contains this script


def _resolve_n_dim(dataset):
    """Store the working data dimension on ``args.n_dim`` and check it against the dataset.

    With ``--use_dataset_dim`` enabled the dataset's own dimension ``dataset.D`` is used.
    Otherwise an ``n_dim`` already present on ``args`` is kept, and if there is none the value of
    ``--x_dim`` is used (presumably the "provided" dimension that the ``--use_dataset_dim`` help
    text refers to -- confirm against ``create_dataset``).

    Raises:
        ValueError: if the resolved dimension differs from ``dataset.D``.
    """
    if args.use_dataset_dim:
        args.n_dim = dataset.D
    elif not hasattr(args, "n_dim"):
        # The parser defines --x_dim but no --n_dim, so without this fallback the comparison
        # below would die with AttributeError whenever --use_dataset_dim is switched off.
        args.n_dim = args.x_dim
    # Explicit raise instead of ``assert``: asserts are stripped when Python runs with -O.
    if dataset.D != args.n_dim:
        raise ValueError(
            f"dimensionality mismatch: dataset has D={dataset.D} but n_dim={args.n_dim}"
        )


def _train_marginal_flow_model(dataset, signature, metrics):
    """Build a ``MarginalFlow`` for ``dataset``, train it and switch it to eval mode.

    Reads the module-level ``args`` and ``dtype``. Hyper-parameters that are not exposed on the
    command line (number of mixtures, network size, learning rates) are fixed here.

    Returns:
        The trained ``MarginalFlow`` instance.
    """
    n_mixtures = 10_000  # args.n_samples_dataset  # 2048
    dropout = 0.0
    log_sigma_init = 0
    # With --manifold the base space is one-dimensional; otherwise it matches the data dimension.
    base_dim = 1 if args.manifold else args.n_dim
    marginal_flow = MarginalFlow(
        x_dim=args.n_dim,
        z_dim=base_dim,
        log_sigma_init=log_sigma_init,
        dropout=dropout,
        n_base_means=args.n_base_means,
        n_layers=5,
        hid_dim=256,
        device=args.device,
        dtype=dtype,
        signature=signature,
        script_path=args.script_path,
    )

    lr_network = 1e-4
    lr_sigma = 1e-2
    overwrite_marg_flow = True
    train_marginal_flow(
        model=marginal_flow,
        n_mixtures=n_mixtures,
        n_epochs=args.n_epochs,
        batch_size=args.batch_size,
        training_mode=args.training_mode,
        dataset=dataset,
        lr_network=lr_network,
        lr_sigma=lr_sigma,
        metrics=metrics,
        fixed_datapoints=args.keep_datapoints,
        overwrite=overwrite_marg_flow,
        save_best_val=False,
    )
    marginal_flow.eval()
    return marginal_flow


def _train_normalizing_flow_model(dataset, signature, metrics):
    """Build the ``NormalizingFlow`` baseline, train it and switch its flow to eval mode.

    Reads the module-level ``args``. The learning rate and number of layers are fixed here.

    Returns:
        The trained ``NormalizingFlow`` instance.
    """
    lr_normflow = 1e-3
    n_layers_normflow = 6
    overwrite_normflow = True
    normalizing_flow = NormalizingFlow(
        x_dim=args.n_dim,
        direction="reverse",
        n_layers=n_layers_normflow,
        signature=signature,
        script_path=args.script_path,
    )
    normalizing_flow.train_reverse(
        dataset=dataset,
        batch_size=args.batch_size,
        n_epochs=args.n_epochs,
        lr=lr_normflow,
        metrics=metrics,
        overwrite=overwrite_normflow,
    )
    normalizing_flow.flow.eval()
    return normalizing_flow


def _evaluate_and_save_metrics(trained_models, dataset, metrics, n_samples=10_000):
    """Evaluate ``metrics`` for every model on one shared validation sample and report them.

    Results are printed for each model. For models that expose a ``model_path`` attribute they
    are also written to ``<model_path>_metrics_result.json``. Does nothing if ``metrics`` is None.
    """
    if metrics is None:
        return

    # Draw the validation samples once so that all models are scored on the same points.
    if hasattr(dataset, "logp_estimator"):
        val_samples = dataset.sample_estimator(n_samples=n_samples)
    else:
        val_samples = dataset.sample(n_samples=n_samples)

    for model_name, model in trained_models.items():
        metrics_dict = model.evaluate_metrics(
            metrics,
            val_samples=val_samples,
            n_samples=n_samples,
            dataset=dataset,
        )
        print(f"results for {model_name}: ", metrics_dict)

        if hasattr(model, "model_path"):
            with open(
                f"{model.model_path}_metrics_result.json", "w", encoding="utf-8"
            ) as file:
                json.dump(metrics_dict, file, indent=4)


def _plot_2d_comparison(marginal_flow, normalizing_flow, dataset, signature, repeats):
    """Plot samples and likelihoods of the marginal flow next to the baseline, ``repeats`` times."""
    for _ in range(repeats):
        # Rebuilt on every pass, as before, so each round of plots receives a fresh dict.
        other_models = {
            "normalizing-flow": normalizing_flow,
        }
        plot_samples_2D(
            model=marginal_flow,
            target=dataset,
            n_outer_samples=1_000,
            n_samples=10_000,
            other_models=other_models,
            bound=dataset.domain_bound,
            signature=signature,
            manifold=bool(args.manifold),
        )
        plot_likelihood_2D(
            model=marginal_flow,
            target=dataset,
            grid_size=100,
            n_outer_samples=1_000,
            other_models=other_models,
            bound=dataset.domain_bound,
            signature=signature,
        )


def main():
    """Train a marginal flow and a normalizing-flow baseline on one dataset and compare them.

    Steps, all driven by the module-level ``args``:

    1. seed the RNGs, create the output directories and build the dataset;
    2. resolve the data dimension and derive the model signature;
    3. train the marginal flow, then the normalizing flow;
    4. evaluate both on shared validation samples and store the metrics as JSON;
    5. for 2-D data, plot samples and likelihoods; finally plot metrics against runtime.

    Raises:
        ValueError: if the configured dimension does not match the dataset dimension.
    """
    set_random_seed(seed=args.seed)
    create_directories(args.script_path)
    dataset = create_dataset(args)
    _resolve_n_dim(dataset)
    signature = model_signature(args=args, dataset=dataset.dataset_suffix)

    metrics = ["mse_logp", "kl_for", "kl_rev"]  # , "sinkhorn", "energy", "gaussian", "laplacian"]
    plotting_repeats = 2  # how many times to sample from marginal flow after training

    # =#=#=# train marginal flow #=#=#=#
    marginal_flow = _train_marginal_flow_model(dataset, signature, metrics)

    # =#=#=#  train a normalizing flow #=#=#=#
    normalizing_flow = _train_normalizing_flow_model(dataset, signature, metrics)

    # evaluation phase
    trained_models = {
        "marginal-flow": marginal_flow,
        "normalizing-flow": normalizing_flow,
    }
    _evaluate_and_save_metrics(trained_models, dataset, metrics)

    if args.n_dim == 2:
        _plot_2d_comparison(
            marginal_flow, normalizing_flow, dataset, signature, repeats=plotting_repeats
        )

    plot_metrics_runtime(
        trained_models,
        metrics=metrics,
    )


# Run the experiment only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
