import argparse
import json
from pathlib import Path

import torch

from margflow.datasets.datasets import create_dataset
from margflow.marginal_flow import MarginalFlow
from margflow.other_models.flowmatching import FlowMatching
from margflow.other_models.free_form_flows import FreeFormFlow
from margflow.other_models.normalizing_flow import NormalizingFlow
from margflow.trainer import train_marginal_flow
from margflow.utils.plot_utils import (
    plot_samples_2D,
    plot_metrics_runtime,
    plot_likelihood_2D,
)
from margflow.utils.training_utils import set_random_seed, create_directories, model_signature

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` does not work with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so a flag declared that way can
    never be switched off by writing ``False`` on the command line.

    Args:
        value: the raw argument string (an actual ``bool`` is passed through).

    Returns:
        The boolean spelled out by ``value``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "1"}:
        return True
    # The empty string stays "false": it was the only value that produced
    # False under the previous ``type=bool`` declaration.
    if token in {"false", "f", "no", "n", "0", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface of the experiment script. Option names and defaults
# are unchanged; only the boolean-valued options use a proper converter.
parser = argparse.ArgumentParser(
    description="Train a marginal flow and baseline models on a dataset and compare them."
)

# TRAINING PARAMETERS
parser.add_argument("--seed", type=int, default=1234, help="random seed")
parser.add_argument("--device", type=str, default="cuda", help="device for training the model")
parser.add_argument("--n_epochs", type=int, default=2000, help="number of epochs")
parser.add_argument("--n_val_steps", type=int, default=100, help="number of validation steps")
parser.add_argument(
    "--n_val_steps_no_increase",
    type=int,
    default=100,
    help="number of validation steps without improvement before stopping training",
)
parser.add_argument(
    "--training_mode",
    type=str,
    default="log_likelihood",
    choices=[
        "kl_divergence",  # variational inference: known (unnormalized) target and no training data
        "log_likelihood",  # density estimation: unknown target but training data
        "symmetric_kl",  # mixed setting: known (unnormalized) target and training data
        "score_matching",  # TODO: implement score matching both for variational inference and density estimation
    ],
)

# simulated annealing parameters
parser.add_argument("--T0", type=float, default=4.0, help="initial temperature")
parser.add_argument("--Tn", type=float, default=1.0, help="final temperature")
parser.add_argument(
    "--iter_per_cool_step",
    type=int,
    default=50,
    help="iterations per cooling step in simulated annealing",
)
parser.add_argument(
    "--temp_min_it_ratio", type=float, default=0.8, help="minimum temperature ratio"
)
parser.add_argument(
    "--base_distribution",
    type=str,
    default="mog",
    choices=["mog", "uniform", "betas", "quasi_random"],
)

parser.add_argument(
    "--n_base_means",
    type=int,
    default=5,
    help="number of (learnable) means in the marginal flow base distribution",
)

# DATASETS PARAMETERS
parser.add_argument("--x_dim", type=int, default=2, help="number of dimensions")
parser.add_argument(
    "--dataset",
    type=str,
    default="mog",
    choices=[
        "mog",
    ],
)
parser.add_argument(
    "--use_dataset_dim",
    type=_str2bool,
    default=True,
    help="use dataset dimension (if applicable) instead of provided",
)
parser.add_argument(
    "--logp_estimator",
    type=str,
    default="parzen",
    choices=["parzen", "kde"],
    help="estimate log prob by fitting kde on samples or by using a parzen window on samples. Ignored if log prob is available analytically",
)

parser.add_argument(
    "--n_mog", type=int, default=12, help="number of mixture components of Gaussian mixture model"
)
parser.add_argument(
    "--mog_sigma", type=float, default=0.2, help="standard deviation of Gaussian mixture"
)
parser.add_argument(
    "--manifold", action="store_true", help="data lives in a lower dimensional space"
)
parser.add_argument(
    "--manifold_type", type=str, default="spiral", choices=["line", "sin", "circle", "spiral"]
)
parser.add_argument(
    "--epsilon", type=float, default=0.00, help="std of the isotropic noise in the data"
)
parser.add_argument(
    "--n_samples_dataset", type=int, default=10_000, help="number of data points in the dataset"
)
parser.add_argument("--bounds", type=float, default=2.0, help="domain bound")
parser.add_argument(
    "--n_samples_dataset_test",
    type=int,
    default=10_000,
    help="number of data points used for test",
)
parser.add_argument("--batch_size", type=int, default=10_000, help="batch size dimension")
parser.add_argument(
    "--accept_less_datapoints",
    type=_str2bool,
    default=True,
    help="if n_samples_dataset is larger than the available number of datapoints just take all of them",
)
parser.add_argument(
    "--shuffle_dataset",
    type=_str2bool,
    default=True,
    help="if the data should be shuffled before training",
)
parser.add_argument(
    "--normalize_dataset_type", type=int, default=1, help="0 -> [0,1], 1 -> N(0,1)"
)
parser.add_argument(
    "--keep_datapoints",
    action="store_true",
    help="if the samples should be kept constant for all iterations",
)
# von mises fisher parameters
parser.add_argument("--mu", type=float, default=None, help="mean of von mises distribution")
parser.add_argument(
    "--kappa", type=float, default=5.0, help="concentration parameter of von mises distribution"
)

parser.add_argument(
    "--n_turns_spiral",
    type=int,
    default=4,
    help="number of spiral turns for sphere spiral distribution",
)

# Parse the command line at import time and attach the derived settings
# (tensor dtype and the directory containing this script) to the namespace.
dtype = torch.float32
args = parser.parse_args()
args.dtype = dtype
args.script_path = Path(__file__).resolve().parent


def main():
    """Train the marginal flow and three baseline models, then plot and score them.

    Configuration is read from the module-level ``args`` namespace and
    ``dtype``. In order, the function: sets the random seed and creates the
    output directories, builds the dataset, trains flow matching, the marginal
    flow, a normalizing flow and a free-form flow, draws the 2D comparison
    plots when ``args.n_dim == 2``, and finally evaluates ``metrics`` for
    every trained model, printing the result and writing it to
    ``<model_path>_metrics_result.json`` for models that expose ``model_path``.

    Raises:
        ValueError: if the dataset dimensionality differs from ``args.n_dim``.
    """
    set_random_seed(seed=args.seed)
    create_directories(args.script_path)
    dataset = create_dataset(args)
    if args.use_dataset_dim:
        args.n_dim = dataset.D
    elif not hasattr(args, "n_dim"):
        # The CLI defines --x_dim, not --n_dim. Without this fallback the check
        # below dies with AttributeError whenever --use_dataset_dim is off
        # (unless create_dataset attached n_dim itself, which is not visible here).
        args.n_dim = args.x_dim
    # Explicit raise instead of assert so the check survives ``python -O``.
    if dataset.D != args.n_dim:
        raise ValueError(
            f"dimensionality mismatch: dataset has {dataset.D} dims, expected {args.n_dim}"
        )
    signature = model_signature(args=args, dataset=dataset.dataset_suffix)

    metrics = ["kl_for", "log_lik"]
    save_best_val = True

    # Quick visual check of the first two coordinates of the training samples.
    samples, _, _ = dataset.load_dataset()
    import matplotlib.pyplot as plt

    plt.scatter(samples[:, 0], samples[:, 1], alpha=0.1)
    plt.show()

    # =#=#=# train flow matching #=#=#=#
    n_hidden_dim = 256
    n_layers = 5
    lr_flowmatch = 1e-3
    overwrite_flowmatch = False
    flowmatch = FlowMatching(
        x_dim=args.n_dim,
        hid_dim=n_hidden_dim,
        n_layers=n_layers,
        script_path=args.script_path,
        signature=signature,
        device=args.device,
    )
    flowmatch.train_samples(
        n_epochs=args.n_epochs * 2,  # twice the epoch budget of the marginal flow
        batch_size=args.batch_size,
        dataset=dataset,
        metrics=metrics,
        overwrite=overwrite_flowmatch,
        lr=lr_flowmatch,
        save_best_val=False,
    )
    flowmatch.solve_ode(plot=True)
    flowmatch.plot_log_likelihood(bound=dataset.domain_bound)

    # =#=#=# train marginal flow #=#=#=#
    n_mixtures = args.n_samples_dataset // 5 * 3  # 2048
    dropout = 0.00
    base_dim = 2  # 1 if args.manifold else args.n_dim
    log_sigma_init = 0
    marginal_flow = MarginalFlow(
        x_dim=args.n_dim,
        z_dim=base_dim,
        log_sigma_init=log_sigma_init,
        dropout=dropout,
        base_distribution=args.base_distribution,
        n_base_means=args.n_base_means,
        n_layers=5,
        hid_dim=256,
        device=args.device,
        dtype=dtype,
        signature=signature,
        script_path=args.script_path,
    )

    lr_network = 1e-4
    lr_sigma = 1e-2
    plotting_repeats = 5  # how many times to sample from marginal flow after training
    overwrite_marg_flow = False
    train_marginal_flow(
        model=marginal_flow,
        n_mixtures=n_mixtures,
        n_epochs=args.n_epochs,
        batch_size=args.batch_size,
        training_mode=args.training_mode,
        dataset=dataset,
        lr_network=lr_network,
        lr_sigma=lr_sigma,
        metrics=metrics,
        fixed_datapoints=args.keep_datapoints,
        overwrite=overwrite_marg_flow,
        save_best_val=save_best_val,
        n_val_steps=args.n_val_steps,
    )
    marginal_flow.eval()

    # =#=#=#  train a normalizing flow #=#=#=#
    lr_normflow = 1e-3
    n_layers_normflow = 5
    overwrite_normflow = False
    normalizing_flow = NormalizingFlow(
        x_dim=args.n_dim,
        direction="forward",
        n_layers=n_layers_normflow,
        signature=signature,
        script_path=args.script_path,
    )
    normalizing_flow.train_forward(
        dataset=dataset,
        batch_size=args.batch_size,
        n_epochs=args.n_epochs,
        lr=lr_normflow,
        metrics=metrics,
        fixed_datapoints=args.keep_datapoints,
        overwrite=overwrite_normflow,
        save_best_val=False,
    )
    normalizing_flow.flow.eval()

    # =#=#=#  train a free form flow  #=#=#=#
    n_layers = 5
    hidden_dims = 256
    base_dim = 1 if args.manifold else args.n_dim
    fff_model = FreeFormFlow(
        x_dim=args.n_dim,
        z_dim=base_dim,
        n_layers=n_layers,
        signature=signature,
        script_path=args.script_path,
        hid_dim=hidden_dims,
    )

    overwrite_fff = False
    beta = 100
    fff_lr = 2e-4
    # NOTE(review): this overwrites the CLI batch size on the shared namespace;
    # kept as-is because objects built from ``args`` may read it later - confirm.
    args.batch_size = 128
    fff_model.train_forward(
        dataset=dataset,
        beta=beta,
        n_epochs=args.n_epochs * 2,
        batch_size=args.batch_size,
        lr=fff_lr,
        fixed_datapoints=args.keep_datapoints,
        overwrite=overwrite_fff,
        metrics=metrics,
        save_best_val=True,
    )
    other_models = {
        "normalizing-flow": normalizing_flow,
        "free-form-flow": fff_model,
        "flow-matching": flowmatch,
    }
    # The comparison plots are only produced for two-dimensional data.
    if args.n_dim == 2:
        for i in range(plotting_repeats):
            plot_samples_2D(
                model=marginal_flow,
                target=dataset,
                n_outer_samples=n_mixtures,
                n_samples=10_000,
                other_models=other_models,
                bound=dataset.domain_bound,
                signature=signature,
                manifold=bool(args.manifold),
                repeat=i,
            )
            plot_likelihood_2D(
                model=marginal_flow,
                target=dataset,
                grid_size=100,
                n_outer_samples=1_000,
                other_models=other_models,
                bound=dataset.domain_bound,
                signature=signature,
            )

    trained_models = {
        "marginal-flow": marginal_flow,
        "normalizing-flow": normalizing_flow,
        "free-form-flow": fff_model,
        "flow-matching": flowmatch,
    }
    if metrics is not None:
        n_samples = 10_000
        # Datasets carrying a ``logp_estimator`` attribute are sampled through
        # ``sample_estimator``; all others through ``sample``.
        if hasattr(dataset, "logp_estimator"):
            val_samples = dataset.sample_estimator(n_samples=n_samples)
        else:
            val_samples = dataset.sample(n_samples=n_samples)
        for model_name, model in trained_models.items():
            metrics_dict = model.evaluate_metrics(
                metrics,
                val_samples=val_samples,
                n_samples=n_samples,
                dataset=dataset,
            )
            print(f"results for {model_name}: ", metrics_dict)

            # Only models that expose a ``model_path`` get their metrics persisted.
            if hasattr(model, "model_path"):
                with open(
                    f"{model.model_path}_metrics_result.json", "w", encoding="utf-8"
                ) as file:
                    json.dump(metrics_dict, file, indent=4)

    # plot_metrics_runtime(trained_models, metrics=metrics)


# Run the experiment only when this file is executed as a script.
if __name__ == "__main__":
    main()
