import time
from datetime import timedelta
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np
import json

import torch
from torch.cuda import temperature
from torch.distributions import Independent, Normal, Uniform
from torch.utils.data import DataLoader

from flow_matching.path import AffineProbPath
from flow_matching.path.scheduler import CondOTScheduler
from flow_matching.utils import ModelWrapper
from flow_matching.solver import ODESolver

from margflow.nn.mlp import MLPSwish
from margflow.trainer import gen_cooling_schedule
from margflow.utils.io_utils import add_n_param_in_signature
from margflow.utils.training_utils import check_tuple
from margflow.abstract_model import AbstractModel


class WrappedModel(ModelWrapper):
    """Adapter exposing a ``(x, t)`` network through the ``ModelWrapper`` interface.

    ``ODESolver`` calls the velocity model with keyword arguments (and possibly
    extra solver-supplied extras); this wrapper forwards only ``x`` and ``t``.
    """

    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras):
        """Evaluate the wrapped network (``self.model``) at state ``x`` and time ``t``.

        Any keyword ``extras`` are accepted for interface compatibility and ignored.
        """
        network = self.model
        velocity = network(x, t)
        return velocity


class GaussBaseDistr:
    """Isotropic, zero-mean Gaussian base distribution with per-dimension std ``sigma``.

    Note: despite its name, ``cov`` stores the per-dimension *standard deviation*
    (a vector filled with ``sigma``); it is passed to ``Normal`` as its scale.
    """

    def __init__(self, dim, device, sigma=1.0):
        self.dim, self.device, self.sigma = dim, device, sigma
        self.mean = torch.zeros(dim, device=device)
        # Per-dimension scale (std), used as ``Normal``'s second argument in ``log_p``.
        self.cov = torch.ones(dim, device=device) * sigma

    def sample(self, n_samples):
        """Return an ``(n_samples, dim)`` draw, created without an explicit device.

        Callers are expected to move the result to the target device themselves.
        """
        standard_normal = torch.randn((n_samples, self.dim))
        return standard_normal * self.sigma

    def log_p(self, x):
        """Log-density of ``x``, summed over the last (event) dimension."""
        per_dim = Normal(self.mean, self.cov)
        joint = Independent(per_dim, 1)
        return joint.log_prob(x)


class UniformBaseDistr:
    """Uniform base distribution on the hypercube ``[-bound, bound]^dim``."""

    def __init__(self, dim, device, bound=1):
        self.dim = dim
        self.device = device
        self.bound = bound
        # Per-dimension half-width; used as the (low, high) limits in ``log_p``.
        self.bound_tensor = torch.ones(dim, device=device) * bound

    def sample(self, n_samples):
        """Return an ``(n_samples, dim)`` uniform draw (created without an explicit device)."""
        return torch.rand((n_samples, self.dim)) * 2 * self.bound - self.bound

    def log_p(self, x):
        """Log-density of ``x``, summed over the last (event) dimension.

        Points with every coordinate in ``[-bound, bound)`` get ``-dim * log(2 * bound)``;
        any other point gets ``-inf``. Argument validation is disabled on purpose:
        with PyTorch's default validation, an out-of-support value raises
        ``ValueError``, which would abort likelihood evaluation whenever an ODE
        trajectory leaves the cube.
        """
        uniform = Uniform(
            low=-self.bound_tensor, high=self.bound_tensor, validate_args=False
        )
        return Independent(uniform, 1, validate_args=False).log_prob(x)


def target_score(x, dataset, temp=1):
    """Return the tempered score ``d/dx [dataset.log_prob(x) / temp]``.

    ``x`` is cloned into a fresh leaf tensor, so the caller's tensor is not
    modified. The per-row log-densities are summed before differentiating, so
    each row receives its own gradient, assuming ``dataset.log_prob`` treats rows
    independently (not checked here). ``create_graph=True`` keeps the result
    differentiable for higher-order gradients; callers that only need the value
    can ``.detach()`` it.
    """
    x_leaf = x.clone().requires_grad_(True)
    tempered_log_density = dataset.log_prob(x_leaf) / temp
    (score,) = torch.autograd.grad(
        tempered_log_density.sum(), x_leaf, create_graph=True
    )
    return score


def velocity_field(x, t, dataset, temp=1, clip=3.0):
    """Score-based target velocity ``(1 - t) * clamp(score(x), -clip, clip)``.

    Args:
        x: points at which to evaluate the target.
        t: times; reshaped to a column ``(-1, 1)`` so it broadcasts over the
            coordinates of ``x``.
        dataset: object exposing ``log_prob``; its tempered score is computed by
            ``target_score``.
        temp: temperature dividing the log-density.
        clip: symmetric bound applied to every score component (default 3.0,
            the bound the original code intended).

    Returns:
        Tensor with the same shape as the score of ``x``.
    """
    score = target_score(x, dataset, temp)
    # Symmetric clipping bounds the regression target where the log-density is steep.
    # (Previously min=-3, max=-3, which collapsed every component to -3.)
    score = score.clamp(min=-clip, max=clip)
    return (1 - t.reshape(-1, 1)) * score


class FlowMatching(AbstractModel):
    """Flow-matching model whose velocity field is an ``MLPSwish`` network.

    Samples are produced by integrating the learned ODE with ``ODESolver`` from
    ``self.base_distr`` at t=0 to t=1. Likelihoods come from
    ``ODESolver.compute_likelihood`` starting at the data point ``x_1``.

    Two training modes are provided:
    * ``train_samples``: conditional flow matching on data samples along an
      affine conditional-OT path.
    * ``train_score``: regression onto a score-based target velocity computed
      from ``dataset.log_prob``.
    """

    def __init__(
        self,
        x_dim,
        hid_dim,
        n_layers,
        script_path=None,
        signature=None,
        device="cuda",
        dtype=torch.float32,
    ):
        """Build the velocity network, probability path and base distribution.

        Args:
            x_dim: data dimensionality.
            hid_dim: hidden width passed to ``MLPSwish``.
            n_layers: number of layers passed to ``MLPSwish``.
            script_path, signature, device, dtype: forwarded to ``AbstractModel``.
        """
        super(FlowMatching, self).__init__(
            model_name="flow_matching",
            x_dim=x_dim,
            script_path=script_path,
            signature=signature,
            device=device,
            dtype=dtype,
        )

        self.hid_dim = hid_dim
        self.network = MLPSwish(
            input_dim=x_dim, n_layers=n_layers, time_dim=1, hidden_dim=hid_dim
        ).to(self.device)
        self.wrapped_network = WrappedModel(self.network)
        self.path = AffineProbPath(scheduler=CondOTScheduler())
        self.base_distr = GaussBaseDistr(dim=x_dim, device=device, sigma=0.5)
        # Alternative base: UniformBaseDistr(dim=x_dim, device=device, bound=2)

        self.trainable_params = {"network": self.network}
        self.set_model_signature()

    def sample_x0_gaussian(self, n_samples):
        """Draw ``n_samples`` initial points from ``self.base_distr``, moved to ``self.device``.

        The name says Gaussian, but the draw comes from whatever ``self.base_distr`` is.
        """
        return self.base_distr.sample(n_samples).to(self.device)

    def sample_t_uniform(self, n_samples):
        """Draw ``n_samples`` times uniformly from [0, 1) on ``self.device``."""
        return torch.rand(n_samples).to(self.device)

    def log_likelihood(self, x1, step_size=0.05, num_acc=10, exact=True):
        """Log-likelihood of ``x1`` via ``ODESolver.compute_likelihood``.

        Args:
            x1: points to evaluate.
            step_size: step size of the midpoint solver.
            num_acc: number of stochastic estimates averaged when ``exact`` is False.
            exact: use the exact divergence instead of the stochastic estimator.

        Returns:
            The log-likelihood tensor returned by ``compute_likelihood`` (averaged
            over ``num_acc`` runs in the stochastic case).
        """
        self.solver = ODESolver(velocity_model=self.wrapped_network)

        def _run(exact_divergence):
            _, log_p = self.solver.compute_likelihood(
                x_1=x1,
                method="midpoint",
                step_size=step_size,
                exact_divergence=exact_divergence,
                log_p0=self.base_distr.log_p,
            )
            return log_p

        if exact:
            return _run(True)

        # The divergence estimator is stochastic; average several runs to reduce variance.
        log_lik = 0
        for _ in range(num_acc):
            log_lik += _run(False)
        return log_lik / num_acc

    def solve_ode(self, n_samples=10000, step_size=0.05, n_time_steps=9, plot=False):
        """Integrate the learned ODE from ``n_samples`` base draws over ``n_time_steps`` times in [0, 1].

        Returns:
            The solver output with ``return_intermediates=True``: states at every
            time of the grid, indexed time-first (as ``plot_ode_solution`` assumes).
        """
        time_grid = torch.linspace(0, 1, n_time_steps).to(device=self.device)

        x_init = self.sample_x0_gaussian(n_samples)
        self.solver = ODESolver(velocity_model=self.wrapped_network)
        sol = self.solver.sample(
            time_grid=time_grid,
            x_init=x_init,
            method="midpoint",
            step_size=step_size,
            return_intermediates=True,
        )

        if plot:
            self.plot_ode_solution(sol=sol.cpu().numpy(), n_time_steps=n_time_steps)

        return sol

    def sample(self, n_samples, context=None, **kwargs):
        """Draw ``n_samples`` model samples (ODE state at t=1); ``context``/kwargs are ignored."""
        sol = self.solve_ode(n_samples=n_samples, step_size=0.05, n_time_steps=2, plot=False)
        return sol[-1]

    def log_prob(self, x, context=None, exact=False, **kwargs):
        """Log-likelihood of ``x``; stochastic divergence by default. ``context``/kwargs are ignored."""
        return self.log_likelihood(x, step_size=0.05, num_acc=10, exact=exact)

    def sample_and_log_prob(self, n_samples, context=None, **kwargs):
        """Draw samples and evaluate their log-likelihood; kwargs are forwarded to both calls."""
        samples = self.sample(n_samples=n_samples, context=context, **kwargs)
        log_prob = self.log_prob(x=samples, context=context, **kwargs)

        return samples, log_prob

    # ------------------------------------------------------------------ training

    def train_samples(
        self,
        n_epochs,
        batch_size,
        dataset,
        lr,
        fixed_datapoints=True,
        metrics=None,
        overwrite=False,
        save_best_val=False,
    ):
        """Train the velocity field with the conditional flow-matching (L2) loss.

        Args:
            n_epochs: number of epochs. With ``fixed_datapoints`` an epoch is a full
                pass over the stored train split; otherwise it is one fresh batch.
            batch_size: mini-batch size.
            dataset: provides ``load_dataset`` / ``sample`` and is passed to
                ``evaluate_metrics``.
            lr: Adam learning rate.
            fixed_datapoints: train on the stored train split and validate on the
                stored validation split, instead of sampling new batches.
            metrics: optional metrics for ``evaluate_metrics``, computed at logging
                steps when the data has no context.
            overwrite: retrain even if a trained model already exists.
            save_best_val: checkpoint whenever the validation NLL improves. Training
                stops after more than 50 logging steps without improvement, and the
                best checkpoint is reloaded at the end. Validation only runs with
                ``fixed_datapoints``.
        """
        if self.existing_trained_model(overwrite=overwrite):
            self.load_trained_model()
            return

        train_loader = val_loader = None
        if fixed_datapoints:
            train_data, val_data, _ = dataset.load_dataset(overwrite=False)
            train_data, _ = check_tuple(train_data, move_to_torch=True, device=self.device)
            val_data, _ = check_tuple(val_data, move_to_torch=True, device=self.device)
            train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
            val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False)

        print("+++++ Training flow matching ++++++")
        print(f"Model has {self.count_parameters()} trainable parameters")
        optim = torch.optim.Adam(self.network.parameters(), lr=lr)
        self.training_log = dict(loss_train=[], loss_val=[], logging_dict={})
        start_time = time.monotonic()
        start_time_fixed = start_time
        training_time = 0.0
        # Log (and validate) about 20 times per run; every epoch for shorter runs.
        log_every = max(1, n_epochs // max(1, min(20, n_epochs)))
        best_val_nll = float("inf")
        best_saved = False  # True once a best-validation checkpoint has been written
        no_improv = 0
        context = None
        test_samples = None  # loaded lazily, once, when metrics are requested

        try:
            for i in range(n_epochs):
                loss_dict = {}
                metrics_dict = {}

                if fixed_datapoints:
                    for batch in train_loader:
                        batch, context = check_tuple(
                            batch, move_to_torch=True, device=self.device
                        )
                        loss_dict["loss"] = self._fm_train_step(optim, batch)
                else:
                    batch = dataset.sample(batch_size, "train")  # infinite sample regime
                    batch, context = check_tuple(batch, move_to_torch=True, device=self.device)
                    loss_dict["loss"] = self._fm_train_step(optim, batch)

                if i % log_every != 0:
                    continue

                training_time += time.monotonic() - start_time
                runtime_df = {"runtime": training_time}
                if metrics is not None and context is None:
                    if test_samples is None:
                        test_samples = self._fm_load_test_samples(dataset)
                    metrics_dict = self.evaluate_metrics(
                        metrics,
                        n_samples=batch_size,
                        val_samples=test_samples,
                        dataset=dataset,
                    )

                stop_early = False
                if fixed_datapoints:
                    # Validate before logging so the value is part of this step's log.
                    val_nll = self._fm_validation_nll(val_loader)
                    loss_dict["log_lik_test"] = val_nll
                    self.training_log["loss_val"].append(val_nll)
                    if save_best_val:
                        if val_nll < best_val_nll:
                            print("saved model")
                            best_val_nll = val_nll
                            self.save_trained_model()
                            best_saved = True
                            no_improv = 0
                        else:
                            no_improv += 1
                        stop_early = no_improv > 50

                self._fm_log_step(i, runtime_df, loss_dict, metrics_dict, metrics)
                if stop_early:
                    break
                start_time = time.monotonic()

        except KeyboardInterrupt:
            print("finishing training early due to keyboard interrupt")

        # With save_best_val the best checkpoint is already on disk; saving the final
        # weights here would overwrite it before it is reloaded below.
        if not best_saved:
            self.save_trained_model()

        self._fm_report_training(start_time_fixed)

        if best_saved:
            self.load_trained_model()

    def train_score(
        self,
        n_epochs,
        batch_size,
        dataset,
        lr,
        metrics=None,
        overwrite=False,
    ):
        """Train the network to regress the score-based target from ``velocity_field``.

        Each epoch draws one batch of base points and times. The target velocity is
        computed from ``dataset.log_prob`` at the temperature given by the cooling
        schedule.

        Args:
            n_epochs: number of optimisation steps.
            batch_size: number of base points per step.
            dataset: provides ``log_prob`` and ``load_dataset`` and is passed to
                ``evaluate_metrics``.
            lr: Adam learning rate.
            metrics: optional metrics for ``evaluate_metrics`` at logging steps.
            overwrite: retrain even if a trained model already exists.
        """
        if self.existing_trained_model(overwrite=overwrite):
            self.load_trained_model()
            return

        print("+++++ Training flow matching ++++++")
        print(f"Model has {self.count_parameters()} trainable parameters")
        optim = torch.optim.Adam(self.network.parameters(), lr=lr)
        self.training_log = dict(loss_train=[], loss_val=[], logging_dict={})
        start_time = time.monotonic()
        start_time_fixed = start_time
        training_time = 0.0
        temperature = gen_cooling_schedule(
            T0=1, Tn=1, n_epochs=n_epochs, share_active_epochs=0.5, scheme="exp_mult"
        )
        # Up to 50 logging steps (about one per 20 epochs). The max(1, ...) guards
        # keep runs with n_epochs < 20 from taking a modulo by zero.
        n_val_steps = max(1, min(50, n_epochs // 20))
        log_every = max(1, n_epochs // n_val_steps)
        test_samples = None  # loaded lazily, once, when metrics are requested

        try:
            for i in range(n_epochs):
                loss_dict = {}
                metrics_dict = {}
                optim.zero_grad()

                x0 = self.sample_x0_gaussian(n_samples=batch_size)
                t = self.sample_t_uniform(n_samples=batch_size)

                # The target is detached so gradients flow only through the network.
                velocity_gt = velocity_field(
                    x=x0, t=t, dataset=dataset, temp=temperature(i)
                ).detach()
                velocity_pred = self.network(x0, t)
                loss = ((velocity_gt - velocity_pred) ** 2).mean()

                loss_value = loss.item()
                self.training_log["loss_train"].append(loss_value)
                loss_dict["loss"] = loss_value
                loss.backward()
                optim.step()

                if i % log_every != 0:
                    continue

                score = target_score(x0, dataset, temp=temperature(i))
                print(f"Score norm: {score.norm(dim=1).mean().item():.4f}")
                print(f"Velocity_gt norm: {velocity_gt.norm(dim=1).mean().item():.4f}")
                print(f"Velocity_pred norm: {velocity_pred.norm(dim=1).mean().item():.4f}")
                print(f"Loss: {loss_value:.4f}")

                training_time += time.monotonic() - start_time
                runtime_df = {"runtime": training_time}
                if metrics is not None:
                    if test_samples is None:
                        test_samples = self._fm_load_test_samples(dataset)
                    metrics_dict = self.evaluate_metrics(
                        metrics,
                        n_samples=batch_size,
                        val_samples=test_samples,
                        dataset=dataset,
                    )

                self._fm_log_step(i, runtime_df, loss_dict, metrics_dict, metrics)
                start_time = time.monotonic()

        except KeyboardInterrupt:
            print("finishing training early due to keyboard interrupt")

        self.save_trained_model()
        self._fm_report_training(start_time_fixed)

    def _fm_train_step(self, optim, x1):
        """Take one Adam step on the flow-matching L2 loss for targets ``x1``; return the loss value."""
        optim.zero_grad()
        x0 = self.sample_x0_gaussian(n_samples=x1.shape[0])
        t = self.sample_t_uniform(n_samples=x1.shape[0])
        path_sample = self.path.sample(t=t, x_0=x0, x_1=x1)

        # Regress the network onto the path's conditional velocity dx_t.
        loss = torch.pow(
            self.network(path_sample.x_t, path_sample.t) - path_sample.dx_t, 2
        ).mean()
        loss_value = loss.item()
        self.training_log["loss_train"].append(loss_value)
        loss.backward()
        optim.step()
        return loss_value

    def _fm_validation_nll(self, val_loader):
        """Return the mean negative log-likelihood (float) of the model over ``val_loader``."""
        log_probs = []
        for batch in val_loader:
            batch, context = check_tuple(batch, move_to_torch=True, device=self.device)
            log_probs.append(self.log_prob(batch, context=context))
        return -torch.cat(log_probs, -1).mean().item()

    def _fm_load_test_samples(self, dataset):
        """Load the stored test split as a float tensor on ``self.device``."""
        _, _, test_samples = dataset.load_dataset(overwrite=False)
        return torch.from_numpy(test_samples).float().to(self.device)

    def _fm_log_step(self, epoch, runtime_df, loss_dict, metrics_dict, metrics):
        """Append one logging step to ``training_log["logging_dict"]`` and print it."""
        logs_dict = runtime_df | loss_dict | metrics_dict
        logging_dict = self.training_log["logging_dict"]
        for key, value in logs_dict.items():
            logging_dict.setdefault(key, []).append(value)
        print(
            f"Epoch {epoch} losses: "
            + ", ".join(f"{key}:{value:.3f}" for key, value in loss_dict.items())
        )
        if metrics is not None:
            print(
                f"Epoch {epoch} other metrics: "
                + ", ".join(f"{key}:{value:.3f}" for key, value in metrics_dict.items())
            )

    def _fm_report_training(self, start_time_fixed):
        """Print total training time, plot loss curves and dump the logging dict to JSON."""
        time_diff = timedelta(seconds=time.monotonic() - start_time_fixed)
        print(f"Training flow matching took {time_diff} seconds")

        curves = [(np.array(self.training_log["loss_train"]), "train")]
        loss_val = np.array(self.training_log["loss_val"])
        if len(loss_val) > 0:
            curves.append((loss_val, "val"))
        for losses, name in curves:
            plt.plot(range(len(losses)), losses)
            plt.title(name)
            plt.show()

        if hasattr(self, "model_path"):
            with open(f"{self.model_path}.json", "w") as file:
                json.dump(self.training_log["logging_dict"], file, indent=4)

    # ------------------------------------------------------------------ plotting

    def plot_ode_solution(self, sol, n_time_steps=9):
        """Plot 2-D histograms of the first two coordinates of ``sol`` at each time step.

        Args:
            sol: array indexed ``sol[time, sample, coord]`` (e.g. ``solve_ode`` output
                converted to numpy).
            n_time_steps: number of panels to lay out on a square grid. Panels beyond
                ``len(sol)`` are left blank.
        """
        n_sqrt = int(np.sqrt(n_time_steps))
        n_sqrt = n_sqrt if n_sqrt**2 == n_time_steps else n_sqrt + 1
        hist_range = ((-5, 5), (-5, 5))

        # squeeze=False keeps ``axs`` 2-D even for a 1x1 grid.
        _, axs = plt.subplots(n_sqrt, n_sqrt, figsize=(20, 20), squeeze=False)
        for index, ax in enumerate(axs.flat):
            if index >= len(sol):
                ax.axis("off")
                continue
            xs, ys = sol[index, :, 0], sol[index, :, 1]
            counts, _, _ = np.histogram2d(xs, ys, bins=300, range=hist_range)
            # Cap the colour scale at the 99th percentile so a few dense bins don't wash out the rest.
            norm = cm.colors.Normalize(vmin=0.0, vmax=np.quantile(counts, 0.99))
            ax.hist2d(xs, ys, 300, range=hist_range, norm=norm)
            ax.set_aspect("equal")
            ax.axis("off")
        plt.tight_layout()
        plt.show()

    def plot_log_likelihood(self, grid_size=200, bound=2):
        """Plot the model density (exact likelihood) on a ``grid_size`` x ``grid_size`` grid over ``[-bound, bound]^2``."""
        axis = torch.linspace(-bound, bound, grid_size)
        # indexing="xy": rows follow y and columns follow x, matching imshow(origin="lower").
        grid_x, grid_y = torch.meshgrid(axis, axis, indexing="xy")
        x_1 = torch.stack([grid_x.flatten(), grid_y.flatten()], dim=1).to(self.device)
        log_p = self.log_prob(x_1, exact=True)

        likelihood = torch.exp(log_p).detach().cpu().reshape(grid_size, grid_size).numpy()
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
        image = ax.imshow(
            likelihood, extent=(-bound, bound, -bound, bound), origin="lower", cmap="viridis"
        )
        # Bind the colorbar to the image so its scale matches the plotted density.
        fig.colorbar(image, ax=ax, orientation="horizontal", label="density")
        plt.show()
