# -*- coding: utf-8 -*-
import argparse
import torch
from pathlib import Path

from margflow.datasets.datasets import create_dataset
from margflow.datasets.dataset_abstracts import DiscreteSampleDataset, DensityDataset
from margflow.marginal_flow import MarginalFlow
from margflow.trainer import train_marginal_flow
from margflow.utils.training_utils import (
    set_random_seed,
    create_directories,
    model_signature,
    check_tuple,
)
from margflow.utils.plot_utils import sample_animation, plot_samples

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` must not be used with argparse: ``bool("False")`` is ``True`` because every
    non-empty string is truthy, so such a flag could never be switched off from the command line.

    Args:
        value: the raw token from the command line (or an already-converted bool).

    Returns:
        ``True`` or ``False``. The empty string maps to ``False``, as ``bool("")`` did before.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if normalized in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


# Command-line interface of the script. Every option has a default, so the script can be
# launched without any argument.
parser = argparse.ArgumentParser(description="Train and evaluate a marginal flow model.")

# TRAINING PARAMETERS
parser.add_argument("--seed", type=int, default=1234, help="random seed")
parser.add_argument("--device", type=str, default="cuda", help="device for training the model")
parser.add_argument("--n_epochs", type=int, default=2000, help="number of epochs")
parser.add_argument(
    "--training_mode",
    type=str,
    default="log_likelihood",
    choices=[
        "kl_divergence",  # variational inference: known (unnormalized) target and no training data
        "log_likelihood",  # density estimation: unknown target but training data
        "symmetric_kl",  # mixed setting: known (unnormalized) target and training data
        "score_matching",  # TODO: implement score matching both for variational inference and density estimation
    ],
)

# model parameters
# parser.add_argument("--overwrite", action="store_true", help="re-train and overwrite flow model")
# parser.add_argument("--n_layers", type=int, default=10, help='number of layers in the flow model')
# parser.add_argument("--n_hidden_features", type=int, default=256, help='number of hidden features in the embedding space of the flow model')

# simulated annealing parameters
parser.add_argument("--T0", type=float, default=4.0, help="initial temperature")
parser.add_argument("--Tn", type=float, default=1.0, help="final temperature")
parser.add_argument(
    "--iter_per_cool_step",
    type=int,
    default=50,
    help="iterations per cooling step in simulated annealing",
)
parser.add_argument(
    "--temp_min_it_ratio", type=float, default=0.8, help="minimum temperature ratio"
)
parser.add_argument(
    "--n_base_means",
    type=int,
    default=10,
    help="number of (learnable) means in the marginal flow base distribution",
)
parser.add_argument(
    "--fourier_sigma",
    type=float,
    default=0.01,
    help="std dev of Gaussian vectors sampled when using Fourier features in conditional networks",
)
parser.add_argument("--dropout", type=float, default=0.0, help="dropout of marginal flow network")


# DATASETS PARAMETERS
parser.add_argument("--x_dim", type=int, default=2, help="number of dimensions")
parser.add_argument(
    "--dataset",
    type=str,
    default="sbi_two_moons",
    choices=[
        "sbi_two_moons",
        "sbi_gaussian_linear",
        "sbi_gaussian_linear_uniform",
        "sbi_gaussian_mixture",
        "mog_time",
        "sbi_slcp",
        "sbi_slcp_distractors",
        "sbi_sir",
        "sbi_bernoulli_glm",
        "sbi_bernoulli_glm_raw",
        "sbi_lotka_volterra",
    ],
)
parser.add_argument(
    "--use_dataset_dim",
    type=_str2bool,  # not type=bool: bool("False") would be True
    default=True,
    help="use dataset dimension (if applicable) instead of provided",
)
# for all datasets
parser.add_argument(
    "--manifold", action="store_true", help="data lives in a lower dimensional space"
)
parser.add_argument(
    "--manifold_type", type=str, default="spiral", choices=["line", "sin", "circle", "spiral"]
)
parser.add_argument(
    "--epsilon", type=float, default=0.00, help="std of the isotropic noise in the data"
)
parser.add_argument(
    "--n_samples_dataset", type=int, default=1_000, help="number of data points in the dataset"
)
parser.add_argument(
    "--n_samples_dataset_test", type=int, default=1_000, help="number of data points used for test"
)
parser.add_argument("--batch_size", type=int, default=100, help="batch size dimension")
parser.add_argument(
    "--accept_less_datapoints",
    type=_str2bool,  # not type=bool: bool("False") would be True
    default=True,
    help="if n_samples_dataset is larger than the available number of datapoints just take all of them",
)
parser.add_argument(
    "--shuffle_dataset",
    type=_str2bool,  # not type=bool: bool("False") would be True
    default=True,
    help="if the data should be shuffled before training",
)
parser.add_argument(
    "--normalize_dataset_type", type=int, default=1, help="0 -> [0,1], 1 -> N(0,1)"
)
parser.add_argument(
    "--keep_datapoints",
    action="store_true",
    help="if the samples should be kept constant for all iterations",
)
# sbi datasets
parser.add_argument(
    "--n_simulations",
    type=int,
    default=1000,
    help="number of simulations in simulation based inference",
)
parser.add_argument(
    "--normalize_sbi_data",
    action="store_true",
    help="normalize paired data (x, theta) in sbi datasets (relevant primarily for SLCP/SLCP_distractors tasks)",
)
# mixture of Gaussians (MoG)
parser.add_argument(
    "--bounds",
    type=float,
    default=2.0,
    help="data/target distribution lives in [-bound,bound]^n_dim",
)
parser.add_argument(
    "--n_mog", type=int, default=12, help="number of mixture components of Gaussian mixture model"
)
parser.add_argument(
    "--mog_sigma", type=float, default=0.2, help="standard deviation of Gaussian mixture"
)
# von mises fisher parameters
parser.add_argument("--mu", type=float, default=None, help="mean of von mises distribution")
parser.add_argument(
    "--kappa", type=float, default=5.0, help="concentration parameter of von mises distribution"
)
# von mises fisher mixture parameters
parser.add_argument(
    "--n_mix",
    type=int,
    default=50,
    help="number of mixture components for mixture of von mises fisher distribution",
)
parser.add_argument(
    "--kappa_mix",
    type=float,
    default=30.0,
    help="concentration parameter of mixture of von mises distribution",
)
parser.add_argument(
    "--alpha_mix",
    type=float,
    default=0.3,
    help="alpha parameter of mixture of von mises distribution",
)
parser.add_argument(
    "--n_turns_spiral",
    type=int,
    default=4,
    help="number of spiral turns for sphere spiral distribution",
)
# uniform checkerboard parameters
parser.add_argument(
    "--n_theta", type=int, default=6, help="number of rows in the checkerboard (n_theta>0)"
)
parser.add_argument(
    "--n_phi",
    type=int,
    default=6,
    help="number of columns in the checkerboard (n_phi>0 and must be even)",
)
parser.add_argument("--base_distribution", type=str, default="mog", choices=["mog", "uniform"])

# lp uniform parameters
parser.add_argument("--beta", type=float, default=1.0, help="p of the lp norm")
parser.add_argument("--radius", type=float, default=1.0, help="radius of manifold")

# Tensor dtype used throughout the script; also exposed on the namespace as ``args.dtype``.
dtype = torch.float32

# Parse the command line once at import time; main() reads this module-level namespace.
args = parser.parse_args()
# Absolute path of the directory that contains this script.
args.script_path = Path(__file__).resolve().parent
args.dtype = dtype


def main():
    """Train a marginal flow on the selected dataset, then evaluate or visualise it.

    The configuration is read from the module-level ``args`` namespace. The function builds the
    dataset, plots one batch of training samples together with their context, trains a
    ``MarginalFlow`` with ``train_marginal_flow`` and finally either prints the metrics returned
    by ``dataset.evaluate_metric`` (dataset names of the form ``sbi_<task>``) or calls
    ``sample_animation`` (all other datasets).

    Raises:
        ValueError: if the sampled training data carries no context; the model below is always
            built with a conditional network and needs the context dimension.
        AssertionError: if the dataset dimension disagrees with the requested one, or if the
            dataset cannot provide what the chosen training mode needs (samples and/or density).
    """
    set_random_seed(seed=args.seed)
    create_directories(script_dir=args.script_path)
    dataset = create_dataset(args)
    # NOTE(review): this overrides the --keep_datapoints flag unconditionally once the dataset
    # has been created; confirm that this is intended.
    args.keep_datapoints = True
    # Without the fallback, ``args.n_dim`` would not exist when --use_dataset_dim is off and the
    # assert below would raise AttributeError. Presumably --x_dim is the "provided" dimension
    # mentioned in the --use_dataset_dim help text.
    args.n_dim = dataset.D if args.use_dataset_dim else args.x_dim
    assert dataset.D == args.n_dim, "dimensionality mismatch"

    samples_to_plot, context = check_tuple(dataset.sample(args.n_simulations, data_type="train"))
    if context is None:
        # Fail with a clear message instead of an AttributeError on ``None.shape``.
        raise ValueError(
            "dataset samples carry no context, but the marginal flow is built with a "
            "conditional network"
        )
    args.cond_dim = context.shape[-1]
    args.n_dim = samples_to_plot.shape[-1]
    plot_samples(samples_to_plot.reshape(-1, args.n_dim).detach().cpu().numpy())
    plot_samples(context.detach().cpu().numpy(), bounds=3.0, alpha=1)

    signature = model_signature(args=args, dataset=dataset.dataset_suffix)
    has_samples = isinstance(dataset, DiscreteSampleDataset)
    has_density = isinstance(dataset, DensityDataset)
    sample_training = args.training_mode in ["log_likelihood", "symmetric_kl"]
    density_training = args.training_mode in ["kl_divergence", "symmetric_kl"]
    assert has_samples or has_density
    assert (
        not sample_training or has_samples
    ), "wants to perform training on samples but dataset cannot provide samples"
    assert (
        not density_training or has_density
    ), "wants to perform training on density but dataset does not provide density"

    # =#=#=# train marginal flow #=#=#=#
    n_mixtures = 2048  # 32768
    base_dim = 1 if args.manifold else args.n_dim
    print("base_dim: ", base_dim)
    log_sigma_init = 0
    marginal_flow = MarginalFlow(
        x_dim=args.n_dim,
        z_dim=base_dim,
        log_sigma_init=log_sigma_init,
        dropout=args.dropout,
        base_distribution=args.base_distribution,
        n_base_means=args.n_base_means,
        fourier_sigma=args.fourier_sigma,
        device=args.device,
        n_layers=4,
        hid_dim=256,
        conditional_network="cond_fourier",
        cond_dim=args.cond_dim,
        dtype=dtype,
        signature=signature,
        script_path=args.script_path,
    )

    lr_network = 1e-4
    lr_sigma = 1e-2
    overwrite_marg_flow = True
    train_marginal_flow(
        model=marginal_flow,
        n_mixtures=n_mixtures,
        n_epochs=args.n_epochs,
        batch_size=args.batch_size,
        training_mode=args.training_mode,
        dataset=dataset,
        lr_network=lr_network,
        lr_sigma=lr_sigma,
        overwrite=overwrite_marg_flow,
        fixed_datapoints=args.keep_datapoints,
        normalize_data=args.normalize_sbi_data,
        save_best_val=True,
    )
    marginal_flow.eval()

    # Dataset names look like "<family>_<task>"; only the "sbi" family is scored with metrics.
    dataset_family, _, task_name = args.dataset.partition("_")
    if dataset_family == "sbi":
        from functools import partial

        c2st, mmd, med_dist = dataset.evaluate_metric(
            partial(marginal_flow.sample_all, n_mixtures=n_mixtures),
            task=task_name,
            n_sim=args.n_simulations,
            fourier_sigma=args.fourier_sigma,
            seed=args.seed,
            normalized_data=args.normalize_sbi_data,
        )
        # Each metric is reported with its own spread (the MMD spread used to print c2st.std()).
        print(
            f"c2st: {c2st.mean()}(+/-{c2st.std()}) "
            f"mmd: {mmd.mean()}(+/-{mmd.std()}), "
            f"med_dist: {med_dist.mean()}(+/-{med_dist.std()})"
        )
    else:
        n_timesteps = 25
        n_samples_per_timestep = 1000
        sample_animation(
            dataset=dataset,
            model=marginal_flow,
            model_name="marginal_flow",
            n_samples=n_timesteps * n_samples_per_timestep,
            n_timesteps=n_timesteps,
            n_outer_samples=n_samples_per_timestep,
        )
    # =#=#=#  train a normalizing flow #=#=#=#
    # density_training = False
    # lr_normflow = 5e-4
    # n_layers_normflow = 3
    # n_hidden_features = 256  # ignored with iResBlock layers
    # overwrite_normflow = False
    #
    # normalizing_flow = NormalizingFlow(n_dim=args.n_dim, direction="forward", n_layers=n_layers_normflow, cond_dim=args.cond_dim,
    #                                    signature=signature, directory_path=args.directory_path, n_hidden_features=n_hidden_features)
    # normalizing_flow.train_forward(args, dataset=dataset, n_epochs=args.n_epochs, lr=lr_normflow,
    #                                early_stopping=False, overwrite=overwrite_normflow)
    # normalizing_flow.flow.eval()

    # sample_animation(dataset=dataset, model=normalizing_flow, model_name="normalizing_flow", n_samples=5000, n_timesteps=50)


# Call main() only when this file is executed as a script.
if __name__ == "__main__":
    main()
