# -*- coding: utf-8 -*-
import argparse
import json
import time
import gc
from functools import partial
from pathlib import Path

import pandas as pd
import matplotlib.pyplot as plt
import torch
from scipy import linalg

from margflow.datasets.datasets import create_dataset
from margflow.marginal_flow import MarginalFlow
from margflow.other_models.flowmatching import FlowMatching
from margflow.trainer import train_marginal_flow
from margflow.utils.training_utils import (
    set_random_seed,
    create_directories,
    model_signature,
    check_tuple,
    batched_evaluation,
    batched_sampling,
)

def _str2bool(value):
    """Parse a command-line string into a bool.

    ``argparse`` with ``type=bool`` calls ``bool(<string>)``, which is True for
    every non-empty string (including "False" and "0"). Boolean options that
    take a value therefore need an explicit parser.

    Args:
        value: the raw option string (a bool is returned unchanged).

    Returns:
        True for "true"/"t"/"yes"/"y"/"1"/"on" and False for
        "false"/"f"/"no"/"n"/"0"/"off" (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: for any other string, so argparse reports
            a normal usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"true", "t", "yes", "y", "1", "on"}:
        return True
    if normalized in {"false", "f", "no", "n", "0", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(
    description="Train a marginal flow and evaluate its log-likelihood and FID on a benchmark dataset."
)

# TRAINING PARAMETERS
parser.add_argument("--seed", type=int, default=1234, help="random seed")
parser.add_argument("--device", type=str, default="cuda", help="device for training the model")
parser.add_argument("--n_epochs", type=int, default=2000, help="number of epochs")
parser.add_argument("--n_val_steps", type=int, default=100, help="number of validation steps")
parser.add_argument(
    "--n_val_steps_no_increase",
    type=int,
    default=100,
    help="number of validation steps without improvement before stopping training",
)
parser.add_argument(
    "--training_mode",
    type=str,
    default="log_likelihood",
    choices=[
        "kl_divergence",  # variational inference: known (unnormalized) target and no training data
        "log_likelihood",  # density estimation: unknown target but training data
        "symmetric_kl",  # mixed setting: known (unnormalized) target and training data
        "score_matching",  # TODO: implement score matching both for variational inference and density estimation
    ],
)

# model parameters
# parser.add_argument("--overwrite", action="store_true", help="re-train and overwrite flow model")
# parser.add_argument("--n_layers", type=int, default=10, help='number of layers in the flow model')
# parser.add_argument("--n_hidden_features", type=int, default=256, help='number of hidden features in the embedding space of the flow model')

# simulated annealing parameters
parser.add_argument("--T0", type=float, default=4.0, help="initial temperature")
parser.add_argument("--Tn", type=float, default=1.0, help="final temperature")
parser.add_argument(
    "--iter_per_cool_step",
    type=int,
    default=50,
    help="iterations per cooling step in simulated annealing",
)
parser.add_argument(
    "--temp_min_it_ratio", type=float, default=0.8, help="minimum temperature ratio"
)
parser.add_argument("--base_dim", type=int, default=5, help="dimensionality of base distribution")
parser.add_argument(
    "--n_base_means",
    type=int,
    default=5,
    help="number of (learnable) means in the marginal flow base distribution",
)
parser.add_argument(
    "--fourier_sigma",
    type=float,
    default=0.01,
    help="std dev of Gaussian vectors sampled when using Fourier features in conditional networks",
)
parser.add_argument("--dropout", type=float, default=0.0, help="dropout of marginal flow network")


# DATASETS PARAMETERS
parser.add_argument("--x_dim", type=int, default=2, help="number of dimensions")
parser.add_argument(
    "--dataset",
    type=str,
    default="gas",
    choices=[
        "gas",
        "power",
        "hepmass",
        "miniboone",
    ],
)
# Boolean options below use _str2bool: with type=bool, "--use_dataset_dim False"
# would have silently evaluated to True.
parser.add_argument(
    "--use_dataset_dim",
    type=_str2bool,
    default=True,
    help="use dataset dimension (if applicable) instead of provided",
)
parser.add_argument(
    "--diagonal_sigma",
    action="store_true",
    help="whether we assume diagonal sigma in the conditional of marginal flow",
)
# for all datasets
parser.add_argument(
    "--manifold", action="store_true", help="data lives in a lower dimensional space"
)
parser.add_argument(
    "--manifold_type", type=str, default="spiral", choices=["line", "sin", "circle", "spiral"]
)
parser.add_argument(
    "--epsilon", type=float, default=0.00, help="std of the isotropic noise in the data"
)
parser.add_argument(
    "--n_samples_dataset", type=int, default=1_000, help="number of data points in the dataset"
)
parser.add_argument(
    "--n_samples_dataset_test", type=int, default=1_000, help="number of data points used for test"
)
parser.add_argument("--batch_size", type=int, default=10_000, help="batch size dimension")
parser.add_argument(
    "--accept_less_datapoints",
    type=_str2bool,
    default=True,
    help="if n_samples_dataset is larger than the available number of datapoints just take all of them",
)
parser.add_argument(
    "--shuffle_dataset",
    type=_str2bool,
    default=True,
    help="if the data should be shuffled before training",
)
parser.add_argument(
    "--normalize_dataset_type", type=int, default=1, help="0 -> [0,1], 1 -> N(0,1)"
)
parser.add_argument(
    "--keep_datapoints",
    action="store_true",
    help="if the samples should be kept constant for all iterations",
)
# sbi datasets
parser.add_argument(
    "--n_simulations",
    type=int,
    default=1000,
    help="number of simulations in simulation based inference",
)
parser.add_argument(
    "--normalize_sbi_data",
    action="store_true",
    help="normalize paired data (x, theta) in sbi datasets (relevant primarily for SLCP/SLCP_distractors tasks)",
)
# mixture of Gaussians (MoG)
parser.add_argument(
    "--bounds",
    type=float,
    default=2.0,
    help="data/target distribution lives in [-bound,bound]^n_dim",
)
parser.add_argument(
    "--n_mog", type=int, default=12, help="number of mixture components of Gaussian mixture model"
)
parser.add_argument(
    "--mog_sigma", type=float, default=0.2, help="standard deviation of Gaussian mixture"
)
# von mises fisher parameters
parser.add_argument("--mu", type=float, default=None, help="mean of von mises distribution")
parser.add_argument(
    "--kappa", type=float, default=5.0, help="concentration parameter of von mises distribution"
)
# von mises fisher mixture parameters
parser.add_argument(
    "--n_mix",
    type=int,
    default=50,
    help="number of mixture components for mixture of von mises fisher distribution",
)
parser.add_argument(
    "--kappa_mix",
    type=float,
    default=30.0,
    help="concentration parameter of mixture of von mises distribution",
)
parser.add_argument(
    "--alpha_mix",
    type=float,
    default=0.3,
    help="alpha parameter of mixture of von mises distribution",
)
parser.add_argument(
    "--n_turns_spiral",
    type=int,
    default=4,
    help="number of spiral turns for sphere spiral distribution",
)
# uniform checkerboard parameters
parser.add_argument(
    "--n_theta", type=int, default=6, help="number of rows in the checkerboard (n_theta>0)"
)
parser.add_argument(
    "--n_phi",
    type=int,
    default=6,
    help="number of columns in the checkerboard (n_phi>0 and must be even)",
)
parser.add_argument(
    "--mixture_distribution",
    type=str,
    default="gaussian",
    choices=["gaussian", "beta"],
)

parser.add_argument(
    "--base_distribution",
    type=str,
    default="mog",
    choices=["mog", "uniform", "betas", "quasi_random"],
)

# lp uniform parameters
parser.add_argument("--beta", type=float, default=1.0, help="p of the lp norm")
parser.add_argument("--radius", type=float, default=1.0, help="radius of manifold")

# NOTE: parsing happens at import time, so importing this module reads sys.argv.
args = parser.parse_args()
# Absolute path of the directory containing this script (used as the root for
# model checkpoints and output directories).
args.script_path = Path(__file__).resolve().parent
dtype = torch.float32
args.dtype = dtype

import numpy as np


def get_statistics_numpy(numpy_data):
    """Return the empirical mean and covariance of a sample matrix.

    Rows are treated as observations and columns as variables
    (``rowvar=False``), so for an ``(n_samples, n_features)`` array the
    result is a length-``n_features`` mean and an
    ``(n_features, n_features)`` covariance.

    Args:
        numpy_data: array-like with one observation per row.

    Returns:
        Tuple ``(mu, cov)``: the column-wise mean from ``np.mean`` and the
        covariance from ``np.cov`` (default normalization by N - 1).
    """
    covariance = np.cov(numpy_data, rowvar=False)
    mean_vector = np.mean(numpy_data, axis=0)
    return mean_vector, covariance


def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Return the squared Frechet distance between two multivariate Gaussians.

    For X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) the distance is
            d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
    This is the numerically stabilised formulation by Dougal J. Sutherland
    (the formula behind FID). In this script it compares the Gaussian fit of
    generated samples against that of the test data.

    Params:
    -- mu1   : mean vector of the first distribution (e.g. generated samples).
    -- sigma1: covariance matrix of the first distribution.
    -- mu2   : mean vector of the second distribution (e.g. reference data).
    -- sigma2: covariance matrix of the second distribution.
    -- eps   : jitter added to both covariance diagonals when the matrix
               square root of their product is not finite.

    Returns:
    --   : the squared Frechet distance.

    Raises:
    -- AssertionError: if the means or the covariances differ in shape.
    -- ValueError: if the matrix square root has a non-negligible imaginary
                   part on its diagonal.
    """
    # Promote scalars so 1-D data (scalar mean, 0-d covariance) also works.
    mean_a, mean_b = np.atleast_1d(mu1), np.atleast_1d(mu2)
    cov_a, cov_b = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    assert mean_a.shape == mean_b.shape, "Training and test mean vectors have different lengths"
    assert cov_a.shape == cov_b.shape, "Training and test covariances have different dimensions"

    delta = mean_a - mean_b

    # The product of the two covariances may be (nearly) singular.
    sqrt_prod, _ = linalg.sqrtm(cov_a.dot(cov_b), disp=False)
    if not np.isfinite(sqrt_prod).all():
        print(
            (
                "fid calculation produces singular product; "
                "adding %s to diagonal of cov estimates"
            )
            % eps
        )
        jitter = np.eye(cov_a.shape[0]) * eps
        sqrt_prod = linalg.sqrtm((cov_a + jitter).dot(cov_b + jitter))

    # Round-off can leave a tiny imaginary part; accept it only if it is small.
    if np.iscomplexobj(sqrt_prod):
        if not np.allclose(np.diagonal(sqrt_prod).imag, 0, atol=1e-3):
            largest_imag = np.max(np.abs(sqrt_prod.imag))
            raise ValueError(f"Imaginary component {largest_imag}")
        sqrt_prod = sqrt_prod.real

    return delta.dot(delta) + np.trace(cov_a) + np.trace(cov_b) - 2 * np.trace(sqrt_prod)


def main():
    """Train a marginal flow on ``args.dataset`` and report log-likelihood and FID.

    Pipeline (all configuration comes from the module-level ``args``):
      1. seed the RNGs, create output directories and build the dataset;
      2. call ``train_marginal_flow`` with ``overwrite=False``
         (NOTE(review): presumably this reuses a saved checkpoint when one
         exists -- confirm in ``margflow.trainer``);
      3. compute the mean log-likelihood of the validation and test splits;
      4. draw as many model samples as there are test points and compute the
         Frechet distance between Gaussian fits of the samples and the test data;
      5. write the metrics to ``./plots/margflow_*.json`` and save per-feature
         histograms of the test data and of the model samples to ``./plots``.

    Side effects: mutates ``args.keep_datapoints`` and ``args.n_dim``.
    """
    set_random_seed(seed=args.seed)
    create_directories(script_dir=args.script_path)
    dataset = create_dataset(args)
    # NOTE(review): this unconditionally overrides the --keep_datapoints CLI flag.
    args.keep_datapoints = True
    if args.use_dataset_dim:
        args.n_dim = dataset.D
    # NOTE(review): with --use_dataset_dim False, args.n_dim must have been set
    # elsewhere (presumably by create_dataset); otherwise this raises AttributeError.
    assert dataset.D == args.n_dim, "dimensionality mismatch"
    # Only the feature dimension of these samples is used below. The call is
    # kept because removing it could shift the random stream.
    samples_to_plot = dataset.sample(args.n_simulations, data_type="train")
    samples_to_plot, _context = check_tuple(samples_to_plot)
    args.n_dim = samples_to_plot.shape[-1]
    signature = model_signature(args=args, dataset=dataset.dataset_suffix)

    # load all datasets for evaluation
    train_samples, val_samples, test_samples = dataset.load_dataset(overwrite=False)
    train_samples = torch.from_numpy(train_samples).float().to(args.device)
    val_samples = torch.from_numpy(val_samples).float().to(args.device)
    test_samples = torch.from_numpy(test_samples).float().to(args.device)

    # Reference statistics for the FID computed after sampling.
    data_mean, data_cov = get_statistics_numpy(test_samples.cpu().numpy())

    print(f"Train dataset size: {train_samples.shape}")
    print(f"Validation dataset size: {val_samples.shape}")
    print(f"Test dataset size: {test_samples.shape}")

    metrics = ["log_lik"]

    # =#=#=# train marginal flow #=#=#=#
    use_trainable_means = False
    # Number of mixture components used consistently for training, evaluation
    # and sampling.
    n_mixtures = 32768 * 2
    # The latent dimension equals the data dimension unless --manifold is set.
    base_dim = args.base_dim if args.manifold else args.n_dim
    print("base_dim: ", base_dim)
    log_sigma_init = 0
    marginal_flow = MarginalFlow(
        x_dim=args.n_dim,
        z_dim=base_dim,
        log_sigma_init=log_sigma_init,
        dropout=args.dropout,
        use_trainable_means=use_trainable_means,
        mixture_distribution=args.mixture_distribution,
        base_distribution=args.base_distribution,
        n_base_means=args.n_base_means,
        fourier_sigma=args.fourier_sigma,
        device=args.device,
        isotropic_sigma=not args.diagonal_sigma,
        n_layers=5,
        hid_dim=512,
        dtype=dtype,
        signature=signature,
        script_path=args.script_path,
    )

    lr_network = 5e-4
    lr_sigma = 1e-2
    overwrite_marg_flow = False
    train_marginal_flow(
        model=marginal_flow,
        n_mixtures=n_mixtures,
        n_epochs=args.n_epochs,
        batch_size=args.batch_size,
        training_mode=args.training_mode,
        dataset=dataset,
        lr_network=lr_network,
        lr_sigma=lr_sigma,
        overwrite=overwrite_marg_flow,
        fixed_datapoints=args.keep_datapoints,
        normalize_data=args.normalize_sbi_data,
        save_best_val=True,
        metrics=metrics,
        n_val_steps=args.n_val_steps,
        n_val_steps_no_increase=args.n_val_steps_no_increase,
    )
    marginal_flow.eval()

    # Timed: log-likelihood evaluation of BOTH validation and test splits
    # (reported as "val_time" for backward compatibility of the JSON schema).
    start_time = time.monotonic()
    with torch.no_grad():
        log_lik_func = partial(marginal_flow.log_prob, n_mixtures=n_mixtures)
        val_log_lik_mf = batched_evaluation(
            data=val_samples, batch_size=args.batch_size, function=log_lik_func
        )
        test_log_lik_mf = batched_evaluation(
            data=test_samples, batch_size=args.batch_size, function=log_lik_func
        )
    val_time = time.monotonic() - start_time
    print(val_log_lik_mf.mean(), test_log_lik_mf.mean())

    with torch.no_grad():
        sampling_function = partial(marginal_flow.sample, n_mixtures=n_mixtures)
        start_time = time.monotonic()
        mf_samples_np = batched_sampling(
            n_points=test_samples.shape[0],
            batch_size=args.batch_size,
            sampling_function=sampling_function,
        )
        sampling_time = time.monotonic() - start_time
        mf_mean, mf_cov = get_statistics_numpy(mf_samples_np)
        mf_fid = calculate_frechet_distance(
            mu1=mf_mean, sigma1=mf_cov, mu2=data_mean, sigma2=data_cov
        )
        print(f"FID: {mf_fid}")

    results_mf = dict(
        val_logl=float(val_log_lik_mf.mean()),
        test_logl=float(test_log_lik_mf.mean()),
        fid=float(mf_fid),
        val_time=val_time,
        sampling_time=sampling_time,
    )

    # Outputs go to ./plots relative to the working directory. Create it up front
    # so that a finished (possibly very long) run is not lost to a
    # FileNotFoundError when the results are written.
    plots_dir = Path("./plots")
    plots_dir.mkdir(parents=True, exist_ok=True)

    results_path = plots_dir / (
        f"margflow_{args.dataset}_{args.base_distribution}"
        f"_diagsigma{args.diagonal_sigma}_norm{args.normalize_dataset_type}_seed{args.seed}.json"
    )
    with open(results_path, "w") as file:
        json.dump(results_mf, file, indent=4)

    # Per-feature histograms: ground-truth test data vs. marginal-flow samples.
    test_df = pd.DataFrame(test_samples.detach().cpu().numpy())
    test_df.hist(figsize=(15, 15))
    plt.savefig(plots_dir / f"{args.dataset}_ground_truth.png", dpi=300)
    mf_df = pd.DataFrame(mf_samples_np)
    mf_df.hist(figsize=(15, 15))
    plt.savefig(plots_dir / f"{args.dataset}_marginal_flow.png", dpi=300)
    plt.show()

    # ------------------------------------------------------------------
    # Disabled baseline experiments kept for reference (flow matching, KDE,
    # seaborn comparison plot, normalizing flow). Re-enable manually.
    # ------------------------------------------------------------------

    # =#=#=# train flow matching #=#=#=#
    # n_hidden_dim = 512
    # n_layers = 5
    # lr_flowmatch = 1e-3
    # overwrite_flowmatch = True
    # flowmatch = FlowMatching(
    #    x_dim=args.n_dim,
    #    hid_dim=n_hidden_dim,
    #    n_layers=n_layers,
    #    script_path=args.script_path,
    #    signature=signature,
    #    device=args.device,
    # )
    # flowmatch.train_samples(
    #     n_epochs=args.n_epochs,
    #     batch_size=args.batch_size,
    #     dataset=dataset,
    #     overwrite=overwrite_flowmatch,
    #     lr=lr_flowmatch,
    #     save_best_val=True,
    # )
    # flowmatch.eval()
    # start_time = time.monotonic()
    # with torch.no_grad():
    #     val_log_lik_fm = batched_evaluation(
    #         data=val_samples, batch_size=args.batch_size, function=flowmatch.log_likelihood
    #     )
    #     test_log_lik_fm = batched_evaluation(
    #         data=test_samples, batch_size=args.batch_size, function=flowmatch.log_likelihood
    #     )
    # val_time = time.monotonic() - start_time
    # print(val_log_lik_fm.mean(), test_log_lik_fm.mean())

    # with torch.no_grad():
    #     start_time = time.monotonic()
    #     fm_samples_np = batched_sampling(
    #         n_points=test_samples.shape[0],
    #         batch_size=args.batch_size,
    #         sampling_function=flowmatch.sample,
    #     )
    #     sampling_time = time.monotonic() - start_time
    #     fm_mean, fm_cov = get_statistics_numpy(fm_samples_np)
    #     fm_fid = calculate_frechet_distance(
    #          mu1=fm_mean, sigma1=fm_cov, mu2=data_mean, sigma2=data_cov
    #     )
    #     print(f"FID: {fm_fid}")
    # results_fm = dict(
    #     val_logl=float(val_log_lik_fm.mean()),
    #     test_logl=float(test_log_lik_fm.mean()),
    #     fid=float(fm_fid),
    #     val_time=val_time,
    #     sampling_time=sampling_time,
    # )

    # with open(
    #     f"./plots/flowmatch_{args.dataset}_norm{args.normalize_dataset_type}_seed{args.seed}.json",
    #     "w",
    # ) as file:
    #     json.dump(results_fm, file, indent=4)

    # mf_df = pd.DataFrame(fm_samples_np)
    # mf_df.hist(figsize=(15, 15))
    # plt.savefig(f"./plots/{args.dataset}_flow_matching.png", dpi=300)
    # plt.show()

    # kde = KernelDensity(kernel="gaussian", bandwidth=0.025)
    # kde.fit(train_samples.cpu().numpy())

    # Step 4: Evaluate on test data using log-likelihood
    # log_probs = kde.score_samples(test_samples[::10].cpu().numpy())  # log density estimates
    # avg_log_likelihood = np.mean(log_probs)
    # print(f"Average log likelihood: {avg_log_likelihood}")

    # import seaborn as sns
    #
    # test_df["model"] = "gt"
    # mf_df["model"] = "mf"
    # df_all = pd.concat([test_df, mf_df])
    # df_melted = df_all.melt(id_vars="model", var_name="Feature", value_name="Value")
    # n_cols = int(np.sqrt(mf_samples_np.shape[-1])) + 1
    # sns.displot(
    #     df_melted,
    #     x="Value",
    #     hue="model",
    #     col="Feature",
    #     kind="hist",
    #     bins=50,
    #     alpha=0.5,
    #     col_wrap=n_cols,
    #     facet_kws={"sharex": False, "sharey": False},
    # )
    # plt.savefig(f"./plots/{args.dataset}_marginal_flow.png", dpi=300)
    # plt.show()

    # =#=#=#  train a normalizing flow #=#=#=#
    # lr_normflow = 1e-3
    # n_layers_normflow = 6
    # overwrite_normflow = False
    # normalizing_flow = NormalizingFlow(
    #     x_dim=args.n_dim,
    #     direction="forward",
    #     n_layers=n_layers_normflow,
    #     signature=signature,
    #     script_path=args.script_path,
    # )
    # normalizing_flow.train_forward(
    #     dataset=dataset,
    #     batch_size=args.batch_size,
    #     n_epochs=args.n_epochs,
    #     lr=lr_normflow,
    #     fixed_datapoints=args.keep_datapoints,
    #     overwrite=overwrite_normflow,
    #     save_best_val=False,
    # )
    # normalizing_flow.flow.eval()
    #
    # train_log_lik_nf = normalizing_flow.log_prob(x=train_samples)
    # val_log_lik_nf = normalizing_flow.log_prob(x=val_samples)
    # test_log_lik_nf = normalizing_flow.log_prob(x=test_samples)
    # print(train_log_lik_nf.mean(), val_log_lik_nf.mean(), test_log_lik_nf.mean())


if __name__ == "__main__":
    # Run the experiment only when executed as a script. Note that the CLI
    # arguments are already parsed at module level (parser.parse_args()), so
    # importing this module still reads sys.argv.
    main()
