import os
from pprint import pprint

os.environ["DDE_BACKEND"] = "tensorflow"

import deepxde as dde
# import fixed_dde_model, fixed_dde_data
import numpy as np
import tensorflow as tf
import argparse
from FokkerPlanck import Model
from scipy.stats import norm, multivariate_normal
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns

import pickle

from util import kl_divergence, mse
from samplers import *
from SIREN import SIREN, clip_output

import logging, sys

from munch import DefaultMunch
from mpl_toolkits.axes_grid1 import make_axes_locatable

import optuna


def get_pde_loss(pde_weight=1.) -> Callable:
    """
    Return the Fokker-Planck PDE residual as a deepxde ``pde(x, y)`` callable.

    The network output ``y`` is treated as a log-density, so the residual is built
    on ``p = exp(y)``:

        residual = pde_weight * (dp/dt + sum_j d(drift * p)/dx_j - sum_j d^2(noise * p)/dx_j^2)

    where ``drift = Model.drift_amplitude * sin(t * Model.sin_scale)`` (the same in
    every spatial direction) and ``noise = dim * Model.noise_sigma ** 2 / 2``.

    :param pde_weight: scalar multiplier applied to the whole residual.
    :return: function ``pde(x, y)``; the last column of ``x`` is time and the
        preceding columns are the spatial coordinates.
    """

    def pde(x, y):
        dim = x.shape[1] - 1
        time_idx = dim
        space_ids = list(range(dim))

        t_ = x[..., time_idx]

        density = tf.math.exp(y)
        # The drift depends on t only; stop_gradient keeps autodiff from tracing through it.
        drift = (Model.drift_amplitude * tf.sin(tf.stop_gradient(t_) * Model.sin_scale))[..., tf.newaxis]
        noise = dim * Model.noise_sigma ** 2 / 2

        dp_dt = dde.grad.jacobian(density, x, i=0, j=time_idx)

        # Advection term: divergence of the probability flux.
        flux = drift * density
        divergence = tf.zeros_like(y)
        for space_idx in space_ids:
            divergence = divergence + dde.grad.jacobian(flux, x, i=0, j=space_idx)

        # Diffusion term. Build `noise * density` once: deepxde caches derivatives
        # per output tensor, so reusing the same tensor lets every diagonal Hessian
        # entry share the first derivative instead of recomputing it.
        diffused = noise * density
        laplacian = tf.zeros_like(y)
        for space_idx in space_ids:
            laplacian = laplacian + dde.grad.hessian(diffused, x, i=space_idx, j=space_idx)

        tf.debugging.assert_shapes([(density, ('N', 1)),
                                    (flux, ('N', 1)),
                                    (diffused, ('N', 1)),
                                    (dp_dt, ('N', 1)),
                                    (divergence, ('N', 1)),
                                    (laplacian, ('N', 1)),
                                    ])
        # Will be passed into mse(pde_term, 0), so no squaring is needed here.
        return pde_weight * (dp_dt + divergence - laplacian)

    return pde


def solution(x) -> np.array:
    """
    Analytical log-density of the Fokker-Planck solution.

    Every spatial coordinate is modelled as an independent normal with mean
    ``Model.mean_numpy(t)`` and variance ``Model.cov_numpy(t) + 1e-6``. The
    per-coordinate log-pdfs are summed.

    :param x: array whose last column is time and whose other columns are space.
    :return: float32 log-density, squeezed to one value per row of ``x``.
    """
    coords = x[..., :-1]
    times = x[..., -1:]
    n_space = coords.shape[-1]

    loc = np.squeeze(Model.mean_numpy(times))
    # Small jitter keeps the standard deviation strictly positive.
    scale = np.squeeze(Model.cov_numpy(times) + 1e-6) ** .5

    if n_space <= 1:
        log_prob = norm.logpdf(np.squeeze(coords), loc, scale)
    else:
        log_prob = sum(
            norm.logpdf(np.squeeze(coords[..., k]), loc, scale) for k in range(n_space)
        )
    return log_prob.astype(np.float32)


def get_ic_fun(dim) -> Callable:
    """
    Build the initial-condition function for deepxde's ``IC``.

    :param dim: number of spatial dimensions.
    :return: ``initial_condition(x)``. It returns the log-density of a zero-mean
        isotropic Gaussian with standard deviation ``Model.initial_noise_scale``,
        evaluated on all but the last (time) column of ``x``, with a trailing
        singleton axis appended.
    """

    def initial_condition(x):
        spatial = x[..., :-1]
        covariance = np.eye(dim) * Model.initial_noise_scale ** 2
        log_density = multivariate_normal.logpdf(spatial, mean=[0.] * dim, cov=covariance)
        return log_density[..., np.newaxis]

    return initial_condition


def sample_true_dist(data, num_samples, dt=None):
    """
    Draw points from the analytical Fokker-Planck solution p(x, t).

    Times come from ``data.geom.random_points``. When ``dt`` is given, they are
    instead drawn uniformly from [-1, -1 + dt). Each spatial coordinate is then
    sampled independently from N(Model.mean_numpy(t), Model.cov_numpy(t)).

    :param data: deepxde data object exposing ``geom.random_points``.
    :param num_samples: number of points to draw.
    :param dt: optional width of the time window starting at t = -1.
    :return: float32 array of ``[x_1..x_d, t]`` rows.
    """
    points = data.geom.random_points(num_samples)
    space = points[..., :-1]
    times = points[..., -1:]
    if dt is not None:
        times = np.random.uniform(-1., -1 + dt, times.shape)

    loc = np.squeeze(Model.mean_numpy(times))
    scale = np.squeeze(Model.cov_numpy(times)) ** .5

    samples = np.zeros_like(space)
    # One draw per coordinate, column by column (keeps the RNG call order).
    for d in range(space.shape[-1]):
        samples[..., d] = norm.rvs(loc, scale)
    return np.concatenate([samples, times], axis=-1).astype(np.float32)


def sample_true_dist_init(data, num_samples):
    """
    Draw points from the analytical solution with every time coordinate at zero.

    Spatial coordinates are sampled independently from
    N(Model.mean_numpy(0), Model.cov_numpy(0)).

    :param data: deepxde data object exposing ``geom.random_points``.
    :param num_samples: number of points to draw.
    :return: float32 array of ``[x_1..x_d, 0]`` rows.
    """
    points = data.geom.random_points(num_samples)
    space = points[..., :-1]
    times = points[..., -1:] * 0  # zero the time column, keeping its shape and dtype

    loc = np.squeeze(Model.mean_numpy(times))
    scale = np.squeeze(Model.cov_numpy(times)) ** .5

    samples = np.zeros_like(space)
    for d in range(space.shape[-1]):
        samples[..., d] = norm.rvs(loc, scale)
    return np.concatenate([samples, times], axis=-1).astype(np.float32)



class CTimePDE(dde.data.TimePDE):  # (fixed_dde_data.TimePDE):
    """
    ``dde.data.TimePDE`` whose test set is drawn from the analytical density.

    The test points are not spread uniformly. They are ``self.num_test`` samples
    from p(x, t), the closed-form Fokker-Planck solution, obtained via
    ``sample_true_dist``.
    """

    def test_points(self):
        """Return ``self.num_test`` test inputs sampled from the true density."""
        n_test = self.num_test
        return sample_true_dist(data=self, num_samples=n_test)


class CustomModel(dde.Model):
    """``dde.Model`` with a compiled helper that evaluates the PDE residual."""

    def __init__(self, data: CTimePDE, net):
        super().__init__(data, net)
        # Per-sample loss weights are not used; the attribute is explicitly reset to None.
        self.loss_weights = None

    @tf.function
    def evaluate_pde_loss(self, inputs):
        """Run the network on ``inputs`` and return ``self.data.pde`` of the result."""
        outputs = self.net(inputs)
        return self.data.pde(inputs, outputs)

# Resampling callbacks selectable via --sampling-method. Keys are stored in lower
# case so lookups with `args.sampling_method.lower()` ignore capitalisation.
# NOTE(review): "mh-pdPINN-RAR" maps to HMCPdpinnRarResampler — confirm MH vs HMC is intended.
sampler_dict = {
    name.lower(): resampler
    for name, resampler in (
        ("uniform", UniformResampler),
        ("it-pdPINN", InverseTransformResampler),
        ("MH-pdPINN", RandomWalkMHResampler),
        ("hmc-pdpinn", HamiltonianMCResampler),
        ("RAR", RarResampler),
        ("it-is", InverseTransformImportanceResampler),
        ("pdPINN-RAR", PdpinnRarResampler),
        ("mh-pdPINN-RAR", HMCPdpinnRarResampler),
        ("NUTS-pdPINN", NUTSResampler),
        ("true", TrueDensityResampler),
        ("true-RAR", TrueRarResampler),
    )
}


def _build_net(args, dim):
    """Construct the SIREN network: fixed architectures for dim 1 and 2, ``args``-driven otherwise."""
    if dim == 1:
        w0 = 10
        return SIREN(output_units=1,
                     units=64,
                     num_layers=5,
                     include_first_layer=True,
                     w0_initial=w0,
                     w0=w0)
    if dim == 2:
        w0 = 15
        return SIREN(output_units=1,
                     units=128,
                     num_layers=9,
                     include_first_layer=True,
                     w0_initial=w0,
                     w0=w0)
    return SIREN(output_units=1,
                 units=args.units,
                 num_layers=args.num_layers,
                 include_first_layer=True,
                 w0_initial=args.w0,
                 w0=args.w0)


def _metrics_filename(args, dim):
    """Return the results path prefix (no extension) for the metrics plot and pickle."""
    prefix = (f"results/fp_metrics_{dim}d_{args.sampling_method}_"
              f"{args.n_samples // 1_000}ks_{args.n_iterations // 1_000}kit")
    if args.runtime:
        return f"{prefix}_{args.seed}"
    return f"{prefix}_pde{args.pde_weight}_{args.seed}"


def main(args, trial: optuna.Trial = None):
    """
    Train a SIREN PINN on the Fokker-Planck problem and report its metrics.

    :param args: attribute-style config; the script passes a ``DefaultMunch`` with
        seed, dim, sampling_method, n_samples, n_iterations, fraction_background,
        num_initial, pde_loss, pde_weight, units, num_layers, w0, runtime,
        no_plots and save_pickle.
    :param trial: optional Optuna trial, forwarded to the resampling callback.
    :return: the last recorded test KL divergence.
    :raises ValueError: if ``args.sampling_method`` (case-insensitive) is not in ``sampler_dict``.
    """
    dde.utils.config.set_random_seed(args.seed)

    dim = args.dim
    method = args.sampling_method.lower()
    # Fail fast with a readable message instead of calling None(**kwargs) later.
    if method not in sampler_dict:
        raise ValueError(f"Unknown sampling method {args.sampling_method!r}; "
                         f"expected one of {sorted(sampler_dict)}")
    sampler_cls = sampler_dict[method]

    geom = dde.geometry.Hypercube([-1.5] * dim, [1.5] * dim)
    timedomain = dde.geometry.TimeDomain(-1, 1)
    geomtime = dde.geometry.GeometryXTime(geom, timedomain)

    ic = dde.icbc.IC(geomtime, get_ic_fun(dim), lambda _, on_initial: on_initial)

    n_total_col_points = args.n_samples
    n_background_col_points = int(np.round(args.fraction_background * n_total_col_points))
    n_density_col_points = n_total_col_points - n_background_col_points

    data = CTimePDE(
        geomtime,
        args.pde_loss,
        [ic],
        # Uniform sampling keeps every point here; the other methods keep only the
        # background share, and n_density_col_points is handed to the resampler.
        num_domain=n_total_col_points if method == "uniform" else n_background_col_points,
        num_initial=args.num_initial,
        train_distribution="pseudo",
        num_test=11421,
        solution=solution
    )

    net = _build_net(args, dim)
    net.apply_output_transform(
        # clipping and gaussian multiplication in log space.
        lambda X, y: clip_output(y) + tf.stop_gradient(
            -tf.reduce_sum(tf.square(X[..., :-1]), keepdims=True, axis=-1) / .3)
    )

    model = CustomModel(data, net)
    model.sample_true_dist = sample_true_dist
    resample_each = 100 if dim == 1 else 200

    sampler_args = dict(period=resample_each,
                        n_samples=n_density_col_points,
                        n_background=n_background_col_points,
                        n_proposals=n_density_col_points * 2,
                        model=model,
                        dim=dim,
                        n_iterations=args.n_iterations,
                        plot_samples=not args.no_plots,
                        optuna_trial=trial)
    pprint(sampler_args)

    cb_resample = sampler_cls(**sampler_args)

    model.compile("adam", lr=tf.keras.optimizers.schedules.InverseTimeDecay(1e-04, 40000, .9998),
                  metrics=[kl_divergence, mse]
                  )
    filename = _metrics_filename(args, dim)
    # Metrics plot, pickle and marginal plots are all written under results/.
    os.makedirs("results", exist_ok=True)
    try:
        losshistory, train_state = model.train(iterations=args.n_iterations, display_every=200,
                                               callbacks=[cb_resample])
    except KeyboardInterrupt:
        losshistory, train_state = model.losshistory, model.train_state
        print("Interrupted training..")

    val_x = sample_true_dist(data=data, num_samples=512)
    val_y = solution(val_x)
    pred_y = model.predict(val_x)
    val_loss = mse(val_y, pred_y)
    print(f"Validation Loss: {val_loss:.3f}")
    metrics = np.array(losshistory.metrics_test)
    results = dict(steps=losshistory.steps,
                   custom_steps=np.arange(len(cb_resample.full_times)) * cb_resample.display_every,
                   KL=metrics[:, 0],
                   MSE=metrics[:, 1],
                   times=np.array(cb_resample.times),
                   full_times=np.array(cb_resample.full_times),
                   val_loss=val_loss
                   )

    plot_losses(results, losshistory, filename)

    if args.save_pickle:
        with open(f'{filename}.pickle', 'wb') as handle:
            pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

    if not args.no_plots:
        plot_marginals(args, dim, geom, geomtime, model, timedomain)

    return results["KL"][-1]

from matplotlib.colors import LogNorm, Normalize

def _add_colorbar(fig, ax, im):
    """Attach a slim vertical colorbar for ``im`` on the right of ``ax``."""
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)
    fig.colorbar(im, cax=cax, orientation='vertical')


def plot_marginals(args, dim, geom, geomtime, model, timedomain):
    """
    Plot the learned density over (t, x_i) for every spatial dimension, then the ground truth.

    For ``dim == 1`` the raw ``model.predict`` output on a 500x500 (x, t) grid is
    shown. For ``dim > 1`` each panel is a Monte-Carlo estimate. Coordinate ``i``
    and time are pinned to the grid, and the remaining coordinates come from
    ``geomtime.random_points``. ``exp(model.predict(...))`` is then averaged over
    ``n_mc * n_reps`` draws per grid point.

    Writes ``results/fp_*.png``, ``results/fp_*.npy`` (grid of the last panel),
    ``results/groundtruth.npy`` and ``groundtruth.png``.

    NOTE(review): the ``dim > 1`` panels show an averaged density, while the
    ``dim == 1`` panel and the ground truth show ``model.predict`` / ``solution``
    directly (log-densities elsewhere in this file). All titles nevertheless read
    "log p". The average also omits the volume factor of the integrated-out
    coordinates. Confirm which quantity is intended.

    :param args: run config providing runtime, sampling_method, n_samples,
        n_iterations, pde_weight and seed (used in the file names).
    """
    tt = np.linspace(timedomain.t0, timedomain.t1, 500)
    xx = np.linspace(geom.xmin[0], geom.xmax[0], 500)
    n_reps = 20
    n_mc = 200
    grid_shape = (xx.shape[0], tt.shape[0])
    XT = np.stack(np.meshgrid(xx, tt, indexing='ij'), -1).reshape(-1, 2)
    # np.repeat places the n_reps copies of each grid point in consecutive rows.
    XT_repeated = np.repeat(XT, n_reps, axis=0)
    fig, axs = plt.subplots(dim, 1, sharex='col')
    axs = axs.flatten() if dim > 1 else [axs]

    for i, ax in enumerate(axs):
        if dim > 1:
            pred = None
            for _ in tqdm(range(n_mc)):
                samples = geomtime.random_points(XT_repeated.shape[0])
                samples[:, [i, -1]] = XT_repeated
                density = np.exp(model.predict(samples))
                pred = density if pred is None else pred + density
            # The network has a single output column, so fold the n_reps consecutive
            # copies of each grid point together. Grouping into (..., n_reps, 2)
            # would mix draws from neighbouring grid points.
            pred = pred.reshape(-1, n_reps).sum(1)
            pred = (pred / (n_mc * n_reps)).reshape(grid_shape)
        else:
            pred = model.predict(XT).reshape(grid_shape)

        im = ax.imshow(pred, origin='lower', cmap="Reds",
                       extent=[timedomain.t0, timedomain.t1, geom.xmin[i], geom.xmax[i]],
                       vmin=1e-5,
                       vmax=3.
                       )
        ax.set_ylim(-.5, .5)
        ax.set_title(f"$log$$p(t, x_{i})$")
        ax.set_xlabel("Time")
        ax.set_ylabel(f"$x_{i}$")
        _add_colorbar(fig, ax, im)

    plt.tight_layout()
    stem = f"{dim}d_{args.sampling_method}_{args.n_samples // 1_000}ks_{args.n_iterations // 1_000}kit"
    if args.runtime:
        plt.savefig(f"results/fp_{stem}_{args.seed}.png")
    else:
        plt.savefig(f"results/fp_{stem}_pde{args.pde_weight}_{args.seed}.png")

    with open(f"results/fp_{stem}_{args.seed}.npy", 'wb') as f:
        np.save(f, pred)
    plt.close()

    # Ground truth: `solution` on the same grid, whose spatial axis is dimension 0
    # (see `xx`). Hence index 0 here, not the loop variable left over from above.
    fig, ax = plt.subplots(1, 1, sharex='col')
    truth = solution(XT).reshape(grid_shape)

    with open("results/groundtruth.npy", 'wb') as f:
        np.save(f, truth)
    im = ax.imshow(truth + 1e-10, origin='lower', cmap="Reds",
                   extent=[timedomain.t0, timedomain.t1, geom.xmin[0], geom.xmax[0]],
                   vmin=1e-5,
                   vmax=3.
                   )
    ax.set_ylim(-.5, .5)
    ax.set_title("$log$$p(t, x_0)$")
    ax.set_xlabel("Time")
    ax.set_ylabel("$x_0$")
    _add_colorbar(fig, ax, im)
    plt.savefig("groundtruth.png")

    plt.close()


def plot_losses(results, losshistory, filename):
    """
    Plot the test KL divergence per step on a log scale and save it as ``{filename}.png``.

    A twin y-axis with limits [4, 7] is also created. Nothing is plotted on it at present.

    :param results: dict holding a ``"KL"`` sequence aligned with ``losshistory.steps``.
    :param losshistory: deepxde loss history providing ``steps``.
    :param filename: output path without the ``.png`` extension.
    """
    sns.set_style("white")
    fig, kl_ax = plt.subplots()
    fig.subplots_adjust(right=0.75)
    secondary_ax = kl_ax.twinx()

    kl_ax.plot(losshistory.steps, results["KL"], 'g-')
    kl_ax.set_xlabel('Epoch')
    kl_ax.set_ylabel('KL(p|q)', color='g')
    kl_ax.set_yscale('log')
    secondary_ax.set_ylim(4, 7)

    out_path = f"{filename}.png"
    plt.savefig(out_path)
    print(f"Saved to {out_path}")
    plt.close()




if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--silent", action="store_true")

    parser.add_argument("--no-plots", action="store_true")
    # Pickling is on by default. `--save-pickle` is kept for compatibility; it
    # previously used store_false and so switched saving *off*, the opposite of
    # its name. Use `--no-save-pickle` to skip writing the pickle.
    parser.add_argument("--save-pickle", dest="save_pickle", action="store_true", default=True)
    parser.add_argument("--no-save-pickle", dest="save_pickle", action="store_false")
    parser.add_argument("--runtime", action="store_true")

    parser.add_argument("--time-it", action="store_true", dest="timeit")
    parser.add_argument("--timeit", action="store_true", dest="timeit")
    parser.add_argument("--sampling-method",
                        choices=("uniform", "it-pdPINN", "MH-pdPINN", "HMC-pdPINN", "it-is", "RAR", "pdPINN-RAR",
                                 "mh-pdPINN-RAR",
                                 "NUTS-pdPINN", "true", "true-RAR"),
                        default="uniform")
    parser.add_argument("--n-samples", default=5_000, dest="n_samples", type=int)
    parser.add_argument("--n-iterations", default=30_000, dest="n_iterations", type=int)
    parser.add_argument("--dim", default=1, dest="dim", type=int)
    parser.add_argument("--seed", default=1234, dest="seed", type=int)
    parser.add_argument("--fraction-background", default=0.2, dest="fraction_background", type=float)
    # A negative value means "use the per-method default from pde_weight_dict".
    parser.add_argument("--pde-weight", default=-1., type=float)

    parser.add_argument("--optuna", action="store_true")

    args = parser.parse_args()

    if args.optuna:
        # NOTE(review): this entry point has no Optuna driver; main(args, trial)
        # must be invoked from elsewhere for hyper-parameter search.
        logging.warning("--optuna was given, but this script has no Optuna driver; nothing to run.")
    else:
        pde_weight_dict = {
            "uniform": 1.,
            # Density Weighted loss needs a different loss weight.
            "it-pdPINN": .1,
            "MH-pdPINN": .05,
            "HMC-pdPINN": .1,
            "RAR": 2,
            "pdPINN-RAR": 1. if args.dim == 2 else .15,
            "it-is": .1,
            "mh-pdPINN-RAR": 1. if args.dim == 2 else .1,
            "NUTS-pdPINN": .2,
            "true": 1.0,
            "true-RAR": .2
        }
        pde_weight_dict = {key.lower(): val for key, val in pde_weight_dict.items()}

        # Resolve the PDE weight once so the loss and the reported weight agree.
        if args.pde_weight < 0.:
            pde_weight = pde_weight_dict.get(args.sampling_method.lower())
        else:
            pde_weight = args.pde_weight

        cargs = dict(
            silent=args.silent,
            no_plots=args.no_plots,
            timeit=args.timeit,
            sampling_method=args.sampling_method,
            n_samples=args.n_samples,
            n_iterations=args.n_iterations,
            dim=args.dim,
            fraction_background=args.fraction_background,

            save_pickle=args.save_pickle,
            num_initial=10_000,

            units=64,
            num_layers=10,
            w0=10,
            pde_loss=get_pde_loss(pde_weight),
            pde_weight=pde_weight,
            seed=args.seed,
            runtime=args.runtime
        )

        main(DefaultMunch.fromDict(cargs))
