import json
import time
from datetime import timedelta

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader

from margflow.datasets.dataset_abstracts import DatasetIdentifier, DensityDataset
from margflow.marginal_flow import MarginalFlow
from margflow.utils.training_utils import check_tuple, ConditionalDataset


def norm_data(data, dim=0):
    """Standardize ``data`` in place along ``dim`` and return it.

    The mean over ``dim`` is subtracted first; the centered data is then divided
    by its standard deviation over ``dim``. Both reductions keep the reduced axis
    so the statistics broadcast against ``data`` for any ``dim`` (without this,
    only ``dim=0`` broadcasts correctly for a 2-D input).

    Args:
        data: ``torch.Tensor`` or NumPy array, expected to be floating point.
            It is modified in place.
        dim: Axis over which the mean and standard deviation are computed.

    Returns:
        The same object that was passed in, now standardized.

    Note:
        No epsilon is added to the standard deviation, so a constant feature
        produces non-finite entries.
    """
    # torch spells the flag ``keepdim`` while NumPy spells it ``keepdims``.
    keep = {"keepdim": True} if isinstance(data, torch.Tensor) else {"keepdims": True}
    data -= data.mean(dim, **keep)
    data /= data.std(dim, **keep)

    return data


def jacobian_norm(f, x):
    """Return the batch-mean squared norm of the gradient of ``f(x).sum()``.

    All entries of ``f(x)`` are summed into one scalar, which is differentiated
    with respect to ``x``. The squared gradient is summed over ``dim=1`` and
    averaged over the remaining axis. The graph of the gradient is kept
    (``create_graph=True``) so the result can itself be back-propagated.

    Side effect: gradient tracking is switched on for the caller's ``x`` in place.

    NOTE(review): for an ``f`` with several outputs per sample this differentiates
    the sum of the outputs, which is not the same as the Frobenius norm of the
    full Jacobian -- confirm which quantity is intended.
    """
    x.requires_grad_(True)
    # Collapsing the output to a scalar lets one autograd call produce the gradient.
    summed_output = f(x).sum()
    (input_grad,) = torch.autograd.grad(summed_output, x, create_graph=True)
    per_sample_sq_norm = input_grad.pow(2).sum(dim=1)
    return per_sample_sq_norm.mean()


def contractive_loss(f, x):
    """Return the batch-mean squared Frobenius norm of the Jacobian of ``f`` at ``x``.

    A detached copy of ``x`` is used, so the caller's tensor is left untouched and
    no gradient flows back into whatever produced ``x``. ``f(x)`` is indexed as
    ``y[:, i]``, i.e. it is expected to be 2-D; one autograd pass is made per
    output column and the squared gradients are accumulated per sample.
    """
    x = x.clone().detach().requires_grad_(True)
    y = f(x)  # (N, d2)

    def _column_sq_norm(col):
        # Gradient of a single output column with respect to the inputs: (N, d1).
        (col_grad,) = torch.autograd.grad(y[:, col].sum(), x, create_graph=True)
        return col_grad.pow(2).sum(dim=1)

    # Adding the per-column terms gives the squared Frobenius norm per sample.
    per_sample = sum((_column_sq_norm(col) for col in range(y.shape[1])), 0.0)
    return per_sample.mean()


def gen_cooling_schedule(T0, Tn, n_epochs, share_active_epochs=0.66, scheme="exp_mult"):
    """Build a temperature schedule that cools from ``T0`` to ``Tn``.

    The returned callable maps an epoch index ``t`` to a temperature. Cooling is
    active for the first ``n_epochs * share_active_epochs`` epochs; from then on
    the schedule returns ``Tn``.

    Args:
        T0: Temperature at ``t = 0``.
        Tn: Temperature reached at the end of the active phase.
        n_epochs: Total number of training epochs.
        share_active_epochs: Fraction of ``n_epochs`` during which cooling happens.
        scheme: One of ``"exp_mult"``, ``"lin_mult"`` or ``"quad_mult"``.

    Returns:
        A function ``cooling_schedule(t)`` returning the temperature at epoch ``t``.

    Raises:
        ValueError: If ``scheme`` is not one of the supported names. Previously an
            unknown scheme made the schedule return ``None`` during the active
            phase, which only surfaced later as an arithmetic error.
    """
    supported_schemes = ("exp_mult", "lin_mult", "quad_mult")
    if scheme not in supported_schemes:
        raise ValueError(
            f"Unknown cooling scheme {scheme!r}; expected one of {supported_schemes}"
        )

    # Loop-invariant, so compute it once instead of on every schedule call.
    n_eff_epochs = n_epochs * share_active_epochs

    def cooling_schedule(t):
        if t < n_eff_epochs:
            k = t / n_eff_epochs  # progress through the active phase, in [0, 1)
            if scheme == "exp_mult":
                alpha = Tn / T0
                return T0 * (alpha**k)
            # A logarithmic variant ("log_mult") is not implemented.
            alpha = T0 / Tn - 1
            if scheme == "lin_mult":
                return T0 / (1 + alpha * k)
            # scheme == "quad_mult"
            return T0 / (1 + alpha * (k**2))
        return Tn

    return cooling_schedule


def format_with_precision(a, precision: int = 3, list_print_max_items: int = 20):
    """Render a value for log output using a fixed number of decimals.

    Args:
        a: Value to render. Floats are formatted with ``precision`` decimals.
            Lists are rendered element-wise inside square brackets; nested lists
            are handled recursively. Anything else is passed through ``str``.
        precision: Number of digits after the decimal point.
        list_print_max_items: At most this many leading items of each list are
            rendered; the rest are dropped without any marker.

    Returns:
        The formatted string.

    List elements that cannot be formatted as fixed-point numbers (for example
    ``None`` or strings) fall back to ``str`` instead of raising.
    """

    def _format_item(item):
        if isinstance(item, list):
            return _format_list(item)
        try:
            return f"{item:.{precision}f}"
        except (TypeError, ValueError):
            # Not a number: show it as-is rather than breaking the log line.
            return str(item)

    def _format_list(items):
        shown = items[:list_print_max_items]
        return "[" + ", ".join(_format_item(x) for x in shown) + "]"

    if isinstance(a, float):
        return f"{a:.{precision}f}"
    if isinstance(a, list):
        return _format_list(a)
    return str(a)


def train_log_likelihood(
    model: MarginalFlow,
    optim: torch.optim.Optimizer,
    batch_size: int,
    n_mixtures: int,
    n_epochs: int,
    dataset: DatasetIdentifier,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    metrics: list = None,
    save_best_val: bool = True,
    fixed_datapoints: bool = True,
    n_val_steps: int = 200,
    n_val_steps_no_increase: int = 100,
):
    """Train ``model`` by minimizing the negative mean log-likelihood.

    Per-step training losses are appended to ``model.training_log["loss_train"]``,
    validation losses to ``model.training_log["loss_val"]`` and periodic snapshots
    (runtime, last-batch losses, metrics) to ``model.training_log["logging_dict"]``.

    Args:
        model: Flow being trained. This function uses its ``log_prob``,
            ``sample_and_log_prob``, ``evaluate_metrics``, ``save_trained_model``,
            ``log_sigma``, ``network`` and ``device`` members.
        optim: Optimizer stepping the trainable parameters.
        batch_size: Batch size for ``dataset.sample`` when ``fixed_datapoints`` is
            false, and the ``n_samples`` argument of ``model.evaluate_metrics``.
        n_mixtures: Forwarded to every model call.
        n_epochs: Number of epochs. With ``fixed_datapoints`` an epoch is one full
            pass over ``train_dataloader``; otherwise it is a single step.
        dataset: Source of the held-out split (``load_dataset``) used for metrics
            and of fresh samples (``sample``) when ``fixed_datapoints`` is false.
        train_dataloader: Training batches; only used with ``fixed_datapoints``.
        val_dataloader: Validation batches; only used with ``fixed_datapoints``.
        metrics: Passed to ``model.evaluate_metrics``; ``None`` disables metrics.
        save_best_val: Save the model whenever the validation loss improves and
            stop early after too many non-improving validations. When false, the
            model is saved at the end of every epoch instead.
        fixed_datapoints: Train on the data loaders (true) or on freshly drawn
            samples each step (false).
        n_val_steps: Target number of validation passes over the whole run.
        n_val_steps_no_increase: Early-stopping patience, counted in validations.
    """
    print("==> Objective function: log-likelihood")
    model.training_log = dict(loss_train=[], loss_val=[], logging_dict={})
    # +inf (rather than a large finite number) guarantees that the first validation
    # always counts as an improvement, so a checkpoint exists to reload afterwards.
    best_val_loss = float("inf")
    no_improv = 0
    # Defined up front so the metrics check below cannot hit an unbound name when
    # the training loader yields no batches.
    context = None
    start_time = time.monotonic()
    training_time = 0.0
    if metrics is not None:
        _, _, test_samples = dataset.load_dataset(overwrite=False)
        test_samples = torch.from_numpy(test_samples).float().to(model.device)

    # Validate roughly ``n_val_steps`` times; never divide by zero, never skip all.
    n_val_steps = min(n_val_steps, n_epochs)
    val_interval = max(1, n_epochs // max(1, n_val_steps))

    for i in range(n_epochs):
        loss_dict = {}
        metrics_dict = {}
        if fixed_datapoints:
            for data in train_dataloader:
                optim.zero_grad()
                data, context = check_tuple(data, move_to_torch=False, device=model.device)
                log_prob_flow = model.log_prob(n_mixtures=n_mixtures, x=data, context=context)
                log_lik = -log_prob_flow.mean()
                # An optional temperature-weighted Jacobian penalty would be added
                # here; it is currently disabled, so the total equals ``log_lik``.
                total_loss = log_lik
                total_loss_value = total_loss.item()
                # loss_dict is overwritten per batch: it holds the last batch only.
                loss_dict["total_loss"] = total_loss_value
                loss_dict["log_lik"] = log_lik.item()
                loss_dict["sigma"] = model.log_sigma.exp().detach().cpu().numpy().tolist()
                model.training_log["loss_train"].append(total_loss_value)
                total_loss.backward()
                optim.step()
        else:
            optim.zero_grad()
            data = dataset.sample(batch_size, "train")  # infinite sample regime
            data, context = check_tuple(data, move_to_torch=True, device=model.device)
            log_prob_flow = model.log_prob(n_mixtures=n_mixtures, x=data, context=context)
            log_lik = -log_prob_flow.mean()

            log_lik_value = log_lik.item()
            loss_dict["log_lik"] = log_lik_value
            loss_dict["sigma"] = model.log_sigma.exp().item()
            model.training_log["loss_train"].append(log_lik_value)
            log_lik.backward()
            optim.step()

        if i % val_interval == 0:
            # NOTE(review): the network is not switched to eval mode before
            # validation, although ``train()`` is called afterwards -- confirm.
            # Only time spent outside validation is counted as training time.
            training_time += time.monotonic() - start_time
            runtime_df = {"runtime": training_time}
            with torch.no_grad():
                if metrics is not None and context is None:
                    metrics_dict = model.evaluate_metrics(
                        metrics,
                        val_samples=test_samples,
                        n_samples=batch_size,
                        dataset=dataset,
                        n_mixtures=n_mixtures,
                    )
                logs_dict = runtime_df | loss_dict | metrics_dict
                # The key set is fixed by the first snapshot.
                if not model.training_log["logging_dict"]:
                    model.training_log["logging_dict"] = {key: [] for key in logs_dict}
                for key, value in logs_dict.items():
                    model.training_log["logging_dict"][key].append(value)

                if fixed_datapoints:
                    log_prob = []
                    for data in val_dataloader:
                        data, context = check_tuple(data)
                        _, _, log_prob_flow_test, _ = model.sample_and_log_prob(
                            n_mixtures=n_mixtures,
                            n_samples=data.shape[0],
                            x=data,
                            context=context,
                        )
                        log_prob.append(log_prob_flow_test)
                    log_prob = torch.cat(log_prob, 0)
                    log_lik_test = -log_prob.mean().item()
                    # Added after the snapshot above, so it is printed and stored in
                    # ``loss_val`` but is not part of ``logging_dict``.
                    loss_dict["log_lik_test"] = log_lik_test
                    model.training_log["loss_val"].append(log_lik_test)
                    if save_best_val:
                        if log_lik_test < best_val_loss:
                            print("saved model")
                            best_val_loss = log_lik_test
                            model.save_trained_model()
                            no_improv = 0
                        else:
                            no_improv += 1
                        if no_improv > n_val_steps_no_increase:
                            # Early stopping: leaves the epoch loop.
                            break
            start_time = time.monotonic()

            print(
                f"Epoch {i} losses: "
                + ", ".join(
                    f"{key}:{format_with_precision(value)}" for key, value in loss_dict.items()
                )
            )
            if metrics is not None:
                print(
                    f"Epoch {i} other metrics: "
                    + ", ".join(f"{key}:{value:.3f}" for key, value in metrics_dict.items())
                )
            model.network.train()

        # Without best-validation checkpointing the latest weights are saved every
        # epoch. NOTE(review): confirm per-epoch saving is intended.
        if not save_best_val:
            model.save_trained_model()


def train_kl_divergence(
    model: MarginalFlow,
    optim: torch.optim.Optimizer,
    batch_size: int,
    n_mixtures: int,
    n_epochs: int,
    dataset: DensityDataset,
    metrics: list = None,
    save_best_val: bool = True,
    fixed_datapoints: bool = True,
    n_val_steps: int = 50,
    n_val_steps_no_increase: int = 50,
):
    """Train ``model`` by minimizing a tempered KL divergence to ``dataset``'s density.

    Each step draws ``batch_size`` samples from the model, evaluates the model
    log-density and ``dataset.log_prob`` on them, and minimizes
    ``mean(log_prob - log_prob_target / T)``. ``T`` follows an exponential cooling
    schedule from 5 to 1 over the first half of the epochs.

    Training losses go to ``model.training_log["loss_train"]``, validation KL
    values to ``model.training_log["loss_val"]`` and periodic snapshots to
    ``model.training_log["logging_dict"]``.

    Args:
        model: Flow being trained. This function uses its ``sample_and_log_prob``,
            ``evaluate_metrics``, ``save_trained_model``, ``log_sigma`` and
            ``network`` members.
        optim: Optimizer stepping the trainable parameters.
        batch_size: Number of model samples per training and validation step.
        n_mixtures: Forwarded to every model call.
        n_epochs: Number of optimization steps.
        dataset: Target density; must provide ``log_prob`` and either
            ``sample_estimator`` (when it has a ``logp_estimator`` attribute) or
            ``sample``.
        metrics: Passed to ``model.evaluate_metrics``; ``None`` disables metrics.
        save_best_val: Save the model whenever the validation KL improves and stop
            early after too many non-improving validations. When false, the model
            is saved after every step instead.
        fixed_datapoints: Unused here; accepted for signature compatibility.
        n_val_steps: Target number of validation passes over the whole run.
        n_val_steps_no_increase: Early-stopping patience, counted in validations.
    """
    print("==> Objective function: kl-divergence")
    model.training_log = dict(loss_train=[], loss_val=[], logging_dict={})
    # +inf (rather than a large finite number) guarantees that the first validation
    # always counts as an improvement, so a checkpoint exists to reload afterwards.
    best_val_kl = float("inf")
    no_improv = 0

    context = None  # conditional training is not supported by this objective
    start_time = time.monotonic()
    training_time = 0.0
    temperature = gen_cooling_schedule(
        T0=5, Tn=1, n_epochs=n_epochs, share_active_epochs=0.5, scheme="exp_mult"
    )

    if metrics is not None and context is None:
        # Reference samples for the metrics are drawn once, before training.
        if hasattr(dataset, "logp_estimator"):
            val_samples = dataset.sample_estimator(n_samples=batch_size)
        else:
            val_samples = dataset.sample(n_samples=batch_size)

    # Validate roughly ``n_val_steps`` times; never divide by zero, never skip all.
    n_val_steps = min(n_val_steps, n_epochs)
    val_interval = max(1, n_epochs // max(1, n_val_steps))

    accumulation_steps = 1
    optim.zero_grad()
    for i in range(n_epochs * accumulation_steps):
        loss_dict = {}
        metrics_dict = {}

        _, samples, log_prob, _ = model.sample_and_log_prob(
            n_mixtures=n_mixtures, n_samples=batch_size, context=context
        )

        # Evaluated with gradients enabled, so they flow through ``samples``.
        log_prob_target = dataset.log_prob(x=samples)
        T = temperature(i)
        kl_div = torch.mean(log_prob - (log_prob_target / T))
        # Untempered value, tracked for reporting only.
        kl_div_orig = torch.mean(log_prob - log_prob_target).detach()
        loss_dict["kl_div_T"] = kl_div.item()
        loss_dict["kl_div"] = kl_div_orig.item()
        loss_dict["temp"] = T
        loss_dict["sigma"] = model.log_sigma.exp().item()
        kl_div = kl_div / accumulation_steps
        kl_div.backward()

        model.training_log["loss_train"].append(kl_div.item())

        if (i + 1) % accumulation_steps == 0:
            optim.step()
            optim.zero_grad()

        if i % val_interval == 0:
            model.network.eval()
            # Only time spent outside validation is counted as training time.
            training_time += time.monotonic() - start_time
            runtime_df = {"runtime": training_time}
            with torch.no_grad():
                if metrics is not None and context is None:
                    metrics_dict = model.evaluate_metrics(
                        metrics,
                        val_samples=val_samples,
                        n_samples=batch_size,
                        dataset=dataset,
                        n_mixtures=n_mixtures,
                    )
                logs_dict = runtime_df | loss_dict | metrics_dict
                # The key set is fixed by the first snapshot.
                if not model.training_log["logging_dict"]:
                    model.training_log["logging_dict"] = {key: [] for key in logs_dict}
                for key, value in logs_dict.items():
                    model.training_log["logging_dict"][key].append(value)

                _, samples_test, log_prob_flow_test, _ = model.sample_and_log_prob(
                    n_mixtures=n_mixtures,
                    n_samples=batch_size,
                    context=context,
                )

                log_prob_target_test = dataset.log_prob(x=samples_test)

                kl_div_test = torch.mean(log_prob_flow_test - log_prob_target_test).item()
                # Added after the snapshot above, so it is printed and stored in
                # ``loss_val`` but is not part of ``logging_dict``.
                loss_dict["kl_div_test"] = kl_div_test
                model.training_log["loss_val"].append(kl_div_test)

                if save_best_val:
                    if kl_div_test < best_val_kl:
                        print("saved model")
                        best_val_kl = kl_div_test
                        model.save_trained_model()
                        no_improv = 0
                    else:
                        no_improv += 1
                    if no_improv > n_val_steps_no_increase:
                        # Early stopping: leaves the training loop.
                        break
            start_time = time.monotonic()

            print(
                f"Epoch {i} losses: "
                + ", ".join(f"{key}:{value:.3f}" for key, value in loss_dict.items())
            )
            if metrics is not None:
                print(
                    f"Epoch {i} other metrics: "
                    + ", ".join(f"{key}:{value:.3f}" for key, value in metrics_dict.items())
                )
            model.network.train()

        # Without best-validation checkpointing the latest weights are saved every
        # step. NOTE(review): confirm per-step saving is intended.
        if not save_best_val:
            model.save_trained_model()


def train_marginal_flow(
    model: MarginalFlow,
    n_mixtures: int,
    n_epochs: int,
    batch_size: int,
    dataset: DatasetIdentifier | DensityDataset,
    training_mode: str = "log_likelihood",
    lr_network: float = 5e-4,
    lr_sigma: float = 1e-2,
    metrics: list = None,
    fixed_datapoints: bool = True,
    save_best_val: bool = True,
    overwrite: bool = False,
    normalize_data: bool = False,
    n_val_steps: int = 200,
    n_val_steps_no_increase: int = 100,
):
    """Train a marginal flow, or load it if a trained model already exists.

    If ``model.existing_trained_model(overwrite)`` is truthy the stored model is
    loaded and nothing else happens. Otherwise an Adam optimizer is built, the
    data loaders are prepared (when ``fixed_datapoints``), the selected objective
    is run, the loss curves are plotted and the logging dictionary is written to
    ``f"{model.model_path}.json"`` if the model has a ``model_path`` attribute.

    Args:
        model: Flow to train.
        n_mixtures: Forwarded to the training routine.
        n_epochs: Forwarded to the training routine.
        batch_size: Data-loader batch size; also forwarded to the training routine.
        dataset: Data source. ``load_dataset`` is called on it when
            ``fixed_datapoints`` is true.
        training_mode: ``"log_likelihood"`` or ``"kl_divergence"``.
            ``"symmetric_kl"`` and ``"score_matching"`` are reserved but not
            implemented.
        lr_network: Learning rate for ``model.network`` parameters.
        lr_sigma: Learning rate for ``model.log_sigma``.
        metrics: Forwarded to the training routine.
        fixed_datapoints: Train from data loaders built from ``dataset``. When
            false, ``save_best_val`` is forced to false and no loaders are built.
        save_best_val: Keep the best-validation checkpoint and reload it at the
            end; otherwise the final weights are saved.
        overwrite: Forwarded to ``model.existing_trained_model``.
        normalize_data: Standardize samples (and contexts, if any) with
            ``norm_data`` before building the loaders.
        n_val_steps: Forwarded to ``train_log_likelihood`` only.
        n_val_steps_no_increase: Forwarded to ``train_log_likelihood`` only.

    Raises:
        NotImplementedError: For the reserved, unimplemented training modes.
        ValueError: For any other unknown ``training_mode``. Previously such a
            value silently skipped training altogether.
    """
    if model.existing_trained_model(overwrite):
        model.load_trained_model()
        return

    # Reject unusable modes before any expensive work is done.
    if training_mode in ("symmetric_kl", "score_matching"):
        # TODO: implement these training modes as separate methods
        raise NotImplementedError(f"training_mode {training_mode!r} is not implemented")
    if training_mode not in ("log_likelihood", "kl_divergence"):
        raise ValueError(
            f"Unknown training_mode {training_mode!r}; "
            "expected 'log_likelihood' or 'kl_divergence'"
        )

    print("+++++ Training marginal flow ++++++")
    n_parameters = model.count_parameters()
    print(f"Model has {n_parameters} trainable parameters")
    parameters = [
        {"params": model.network.parameters(), "lr": lr_network},  # Model parameters
        {"params": model.log_sigma, "lr": lr_sigma},
    ]
    if model.use_trainable_means:
        parameters += [{"params": model.trainable_means, "lr": 5e-1}]
        parameters = parameters[1:]  # remove network parameters
    elif model.base_distribution == "mog" and model.n_base_means > 1:
        parameters += [{"params": model.base_means, "lr": 5e-2}]

    optim = torch.optim.Adam(parameters)

    start_time = time.monotonic()
    model.network.train()

    if fixed_datapoints:
        train_samples, val_samples, _ = dataset.load_dataset(overwrite=False)
        train_samples, train_context = check_tuple(
            train_samples, move_to_torch=True, device=model.device
        )
        val_samples, val_context = check_tuple(
            val_samples, move_to_torch=True, device=model.device
        )

        if normalize_data:
            # NOTE(review): the validation split is standardized with its own
            # statistics rather than those of the training split, and
            # ``norm_data`` modifies its argument in place -- confirm both are
            # intended.
            train_samples = norm_data(train_samples)
            val_samples = norm_data(val_samples)
            if train_context is not None:
                train_context = norm_data(train_context)
                val_context = norm_data(val_context)

        if train_context is not None:
            train_dataset = ConditionalDataset(train_samples, train_context)
            val_dataset = ConditionalDataset(val_samples, val_context)
            train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        else:
            train_dataloader = DataLoader(train_samples, batch_size=batch_size, shuffle=True)
            val_dataloader = DataLoader(val_samples, batch_size=batch_size, shuffle=False)
    else:
        # No fixed validation set exists, so best-validation checkpointing is off.
        save_best_val = False
        train_dataloader = None
        val_dataloader = None

    try:
        if training_mode == "log_likelihood":
            train_log_likelihood(
                model=model,
                optim=optim,
                n_mixtures=n_mixtures,
                n_epochs=n_epochs,
                batch_size=batch_size,
                train_dataloader=train_dataloader,
                val_dataloader=val_dataloader,
                dataset=dataset,
                metrics=metrics,
                save_best_val=save_best_val,
                fixed_datapoints=fixed_datapoints,
                n_val_steps=n_val_steps,
                n_val_steps_no_increase=n_val_steps_no_increase,
            )
        else:  # training_mode == "kl_divergence", validated above
            train_kl_divergence(
                model=model,
                optim=optim,
                n_mixtures=n_mixtures,
                n_epochs=n_epochs,
                batch_size=batch_size,
                dataset=dataset,
                metrics=metrics,
                save_best_val=save_best_val,
                fixed_datapoints=fixed_datapoints,
            )
    except KeyboardInterrupt:
        # Ctrl-C ends training early; the post-processing below still runs.
        print("finishing training early due to keyboard interrupt")

    end_time = time.monotonic()
    time_diff = timedelta(seconds=end_time - start_time)
    print(f"Training marginal flow took {time_diff} seconds")

    loss_train = np.array(model.training_log["loss_train"])
    loss_val = np.array(model.training_log["loss_val"])
    curves = [(loss_train, "train")]
    if len(loss_val) > 0:
        curves.append((loss_val, "val"))
    # One figure per curve.
    for losses, name in curves:
        plt.plot(range(len(losses)), losses)
        plt.title(name)
        plt.show()

    if hasattr(model, "model_path"):
        with open(f"{model.model_path}.json", "w") as file:
            json.dump(model.training_log["logging_dict"], file, indent=4)

    if save_best_val:
        # Restore the checkpoint written at the best validation loss.
        model.load_trained_model()
    else:
        model.save_trained_model()
