# -*- coding: utf-8 -*-
"""modification of sdeflow_equivalent_sdes.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1Tx_Yt90NRgHve--ocIXi6SGR-0ebwH0N
    associated to https://github.com/CW-Huang/sdeflow-light
"""


import time
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.ticker as mticker 
import pandas as pd
from quantitative_comparison import compute_mmd

### 4.1. Define plotting tools
@torch.no_grad()
def get_2d_histogram_plot(data, val=3, num=64, vmin = 0, vmax=10, use_grid=False, origin='lower',logscale=True):
    """Render a 2-D histogram of ``data`` and return it as an RGB image array.

    Args:
        data: 2-D array of samples. Columns 0 and 1 are binned when it has fewer than
            3 columns; otherwise columns 0 and 2 are binned and ``val`` is halved.
        val: half-width of the square window ``[-val, val]^2`` that is binned.
        num: number of bins per axis.
        vmin, vmax: colour limits passed to ``imshow``. With ``logscale`` they are
            recomputed from the histogram (``vmin`` from half the smallest bin count
            above the minimum, ``vmax`` from the largest log-count).
        use_grid: show integer ticks from ``-val`` to ``val`` instead of hiding ticks.
        origin: ``imshow`` origin.
        logscale: display ``log(counts + 1e-10)`` instead of raw counts.

    Returns:
        uint8 array of shape (H, W, 3) holding the rendered figure's RGB pixels.
    """
    x = data[:, 0]
    if data.shape[1] < 3:
        y = data[:, 1]
    else:
        y = data[:, 2]
        val = val / 2

    heatmap, xedges, yedges = np.histogram2d(x, y, range=[[-val, val], [-val, val]], bins=num)
    if logscale:
        above_min = heatmap > heatmap.min()
        if above_min.any():
            # half the smallest non-minimal count, so those bins stay visible
            vmin = heatmap[above_min].min() / 2
        heatmap = np.log(heatmap + 1e-10)  # log scale for better visibility
        vmax = heatmap.max()
        # A uniform histogram leaves vmin at its default 0, and log(0) = -inf
        # (plus a RuntimeWarning). Fall back to the smallest log value instead.
        vmin = np.log(vmin) if vmin > 0 else heatmap.min()
    extent = [xedges[0], xedges[-1], yedges[0], yedges[-1]]

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(heatmap.T, extent=extent, origin=origin, vmin=vmin, vmax=vmax)
    ax.grid(False)
    if use_grid:
        ticks = np.arange(-val, val + 1, step=1)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
    else:
        ax.set_xticks([])
        ax.set_yticks([])
    fig.tight_layout()

    # Render, then copy the pixels. buffer_rgba() already has shape (H, W, 4) at the
    # canvas's physical resolution, so no device-pixel-ratio guess is needed (the
    # previous reshape only handled ratios of exactly 1 or 2).
    fig.canvas.draw()
    image = np.array(fig.canvas.buffer_rgba())[:, :, :3]  # drop the alpha channel

    plt.close(fig)
    return image

@torch.no_grad()
def plot_selected_inds(xs, inds, use_xticks=True, use_yticks=True, lmbd = 0.,include_t0=False, backward=True, plt_show=True, val=3):
    """Show 2-D histograms of the snapshots ``xs[ind]`` side by side in one figure.

    Args:
        xs: indexable collection of sample tensors; each ``xs[ind]`` is moved to the CPU,
            converted to NumPy and rendered with ``get_2d_histogram_plot``.
        inds: indices of the snapshots to show (any iterable of ints).
        use_xticks: label each panel with ``i=ind`` (``include_t0``) or ``i=ind+1``.
        use_yticks: add a single ``lambda=lmbd`` label on the y-axis.
        lmbd: value shown in the y-axis label.
        include_t0: choose between the ``ind`` and ``ind + 1`` panel labels.
        backward: show the panels in reverse order of ``inds``.
        plt_show: call ``plt.show(block=False)`` at the end.
        val: histogram half-width, forwarded to ``get_2d_histogram_plot``.

    The figure is left open so that the caller can save it (e.g. ``plt.savefig``).
    """
    # Materialize once. ``reversed()`` returns a one-shot iterator: the image loop used
    # to exhaust it, so the tick-label comprehension below got an empty list and the
    # panel labels silently disappeared whenever ``backward`` was true.
    inds = list(inds)
    if backward:
        inds.reverse()
    l_inds = len(inds)

    imgs_ = [get_2d_histogram_plot(xs[ind].to('cpu').numpy(), val) for ind in inds]
    img_ = np.concatenate(imgs_, axis=1)

    height, width, _ = img_.shape
    # each panel is assumed square, i.e. `height` pixels wide
    height_per_img = width_per_img = height
    figwidth = 25
    fontsize = 15
    if use_xticks:
        xticks = [0.5*width_per_img + width_per_img*i for i in range(l_inds)]
        offset = 0 if include_t0 else 1
        xticklabels = [r'$i={:d}$'.format(ind + offset) for ind in inds]
    else:
        xticks, xticklabels = [], []
    if use_yticks:
        yticks = [0.5*height_per_img]
        yticklabels = [r'$\lambda={:.2g}$'.format(lmbd)]
    else:
        yticks, yticklabels = [], []

    fig = plt.figure(figsize=(figwidth, figwidth*height/width))
    ax = fig.add_subplot(111)
    ax.imshow(img_)
    axis_color = 'white'
    for side in ('bottom', 'top', 'left', 'right'):
        ax.spines[side].set_color(axis_color)
    ax.tick_params(axis='x', colors=axis_color)
    ax.tick_params(axis='y', colors=axis_color)
    plt.xticks(xticks, xticklabels, color='black', fontsize=fontsize)
    plt.yticks(yticks, yticklabels, color='black', fontsize=fontsize)
    if plt_show:
        plt.show(block=False)


@torch.no_grad()
def def_pd(xgen, std_norm, std_test_plot, datatype, dimplot=2,
           crop_data_plot=False, plot_crop=3, columns_plot=None):
    """Return a DataFrame with the first ``dimplot`` columns of ``std_norm * xgen``.

    When ``crop_data_plot`` is true, only the rows whose every component satisfies
    ``|value| < plot_crop * std_norm * std_test_plot`` are kept, and the percentage of
    discarded rows is printed. ``datatype`` is accepted but not used.
    """
    scaled = std_norm * xgen

    if crop_data_plot:
        limit = plot_crop * std_norm * std_test_plot
        inside = (scaled.abs() < limit).all(axis=1)
        outside_fraction = (1 - inside.sum() / len(inside)).item()
        print(f"{outside_fraction * 100} % of samples outside plot limits")
        scaled = scaled[inside, :]

    # pandas converts the CPU tensor through its __array__ interface
    return pd.DataFrame(scaled[:, :dimplot].to('cpu'), columns=columns_plot)


@torch.no_grad()
def pairplots(xgen, xtest, std_norm, std_test_plot, datatype, name_simu, dimplot=2, \
              crop_data_plot=False, plot_crop=3, plot_xlim=3, plot_ref_pdf=False, \
              pdf_theor=None, log_scale_pdf=False, columns_plot=None, \
              plt_show=False, dpi=200, height_seaborn=2.5, ssize=10):
    """Save a corner pair plot comparing test samples with generated samples.

    Both sample sets are turned into DataFrames by ``def_pd`` (rescaled by ``std_norm``,
    optionally cropped). The lower triangle shows scatter plots of both sets. Each
    diagonal shows a step histogram (density) of the test samples and a KDE of the
    generated samples, optionally overlaid with a reference density. The figure is
    saved to ``name_simu + "_multDim.png"`` and closed.

    Args:
        xgen, xtest: generated and test samples.
        std_norm, std_test_plot: per-dimension scales. The axis limits of dimension
            ``i`` are ``+/- plot_xlim * std_norm[i] * std_test_plot[i]``.
        datatype: forwarded to ``def_pd``.
        name_simu: output file prefix.
        dimplot: number of leading dimensions to plot.
        crop_data_plot, plot_crop: forwarded to ``def_pd``.
        plot_xlim: half-width of the axis limits, in units of the per-dimension scale.
        plot_ref_pdf, pdf_theor: when ``plot_ref_pdf`` is true, ``exp(pdf_theor.log_prob)``
            is evaluated on a grid spanning dimension 0's limits, renormalized, and
            drawn dotted on every diagonal.
        log_scale_pdf: use a log y-scale on the diagonal densities.
        columns_plot: DataFrame column names (used as axis labels).
        plt_show: display the figure non-blockingly before saving.
        dpi: resolution of the saved image.
        height_seaborn: facet height passed to ``sns.PairGrid``.
        ssize: scatter marker size (also used as the legend ``markerscale``).
    """
    pddatatest = def_pd(xtest, std_norm, std_test_plot, datatype, dimplot=dimplot, \
              crop_data_plot=crop_data_plot, plot_crop=plot_crop, columns_plot=columns_plot)
    pddatagen = def_pd(xgen, std_norm, std_test_plot, datatype, dimplot=dimplot, \
              crop_data_plot=crop_data_plot, plot_crop=plot_crop, columns_plot=columns_plot)

    pddata = pd.concat([pddatatest.assign(samples="test"),
                        pddatagen.assign(samples="gen.")])

    palette = {"test": sns.color_palette()[0], "gen.": sns.color_palette()[1]}
    plot_kws = {'alpha': 0.1, "s": ssize, "edgecolor": "none", "rasterized": True}

    # The reference pdf does not depend on the plotted data, so evaluate it once
    # instead of once per diagonal and per hue level.
    ref_xx = ref_pdf = None
    if plot_ref_pdf:
        # NOTE(review): the range uses dimension 0's scale for every diagonal
        ref_xlim = plot_xlim * std_norm[0] * std_test_plot[0]
        ref_xx = torch.linspace(-ref_xlim, ref_xlim, 2000)
        ref_pdf = pdf_theor.log_prob(ref_xx).exp()
        ref_pdf /= (ref_pdf.sum() * (ref_xx[1] - ref_xx[0]))  # normalize like a density

    g = sns.PairGrid(pddata, hue="samples",
                    corner=True, height=height_seaborn, aspect=1,
                    palette=palette, diag_sharey=False)

    # lower triangle: translucent scatter of both sample sets
    g.map_lower(sns.scatterplot, **plot_kws)

    def diag_plot(x, color=None, label=None, **kws):
        # called by map_diag once per hue level ("test", then "gen.")
        ax = plt.gca()

        if label == "test":
            # peak density of the TEST values only (NumPy, not torch), for the y-limit
            x_np = np.asarray(x, dtype=np.float64)
            x_np = x_np[np.isfinite(x_np)]
            counts, _ = np.histogram(x_np, bins=80, density=True)
            ymax = float(counts.max()) if counts.size else 0.0

            sns.histplot(
                x=x, bins=80, stat="density",
                element="step", fill=True, alpha=0.25,
                color=palette["test"], **kws
            )

            # y-limit for this diagonal axis only. On a log scale, start at the
            # smallest non-empty bin.
            if log_scale_pdf and (counts > 0).any():
                ymin = counts[counts > 0].min()
            else:
                ymin = 0
            if ymax > 0:
                ax.set_ylim(ymin, 1.05 * ymax)

            if ref_pdf is not None:
                # Drawn once per diagonal. zorder=3 keeps it above the KDE drawn next,
                # as the previous duplicate second draw did.
                plt.plot(ref_xx, ref_pdf, color=palette["test"], linestyle=':',
                         lw=1.5, zorder=3)

        elif label == "gen.":
            sns.kdeplot(x=x, color=palette["gen."], lw=1.5, **kws)

        if log_scale_pdf:
            ax.set_yscale('log')

    g.map_diag(diag_plot)

    handles = [plt.Line2D([], [], marker='o', linestyle='',
                        color=palette[k], markersize=8, alpha=0.6) for k in ["test", "gen."]]
    labels = ["test", "gen."]
    g.figure.legend(handles=handles, labels=labels, loc='upper right', markerscale=ssize)

    # Pass 1: axis limits on the lower triangle
    for i, row in enumerate(g.axes):
        plot_ylim_row = plot_xlim * std_norm[i] * std_test_plot[i]
        for j, ax in enumerate(row):
            if ax is None:
                continue
            plot_xlim_col = plot_xlim * std_norm[j] * std_test_plot[j]
            if j < i:
                ax.set_xlim((-plot_xlim_col, plot_xlim_col))
                ax.set_ylim((-plot_ylim_row, plot_ylim_row))

    # Pass 2: x-limits on the diagonals
    for i in range(len(g.diag_vars)):
        ax = g.axes[i, i]
        if ax is None:
            continue
        plot_xlim_col = plot_xlim * std_norm[i] * std_test_plot[i]
        ax.set_xlim((-plot_xlim_col, plot_xlim_col))

    def fmt_tick(val, pos):
        # hide the "0" label but keep the tick itself
        if abs(val) < 1e-8:
            return ""
        return f"{val:g}"

    # Pass 3: sparse ticks (MaxNLocator(nbins=2) allows at most 3 ticks per axis).
    # Locators/formatters are created per axis because matplotlib binds them to one axis.
    nbins = 2
    for row in g.axes:
        for ax in row:
            if ax is None:
                continue
            ax.xaxis.set_major_locator(mticker.MaxNLocator(nbins=nbins))
            ax.yaxis.set_major_locator(mticker.MaxNLocator(nbins=nbins))
            ax.xaxis.set_major_formatter(mticker.FuncFormatter(fmt_tick))
            ax.yaxis.set_major_formatter(mticker.FuncFormatter(fmt_tick))

    plt.tight_layout()
    if plt_show:
        plt.show(block=False); plt.pause(1)
    name_fig = name_simu + "_multDim.png"
    plt.savefig(name_fig, dpi=dpi)
    if plt_show:
        plt.pause(1)
    plt.close()


@torch.no_grad()
def pairplots_single( xtest, std_norm, std_test_plot, datatype, name_simu, dimplot=2, \
            crop_data_plot=False, plot_crop=3, plot_xlim=3, plot_ref_pdf=False, \
            pdf_theor=None, log_scale_pdf=False, columns_plot=None, \
            plt_show=False, dpi=200, height_seaborn=2.5, ssize=10):
    """Save a corner pair plot of a single sample set to ``"results/" + name_simu + ".png"``.

    ``xtest`` is converted by ``def_pd`` (rescaled by ``std_norm``, optionally cropped).
    Axis limits of dimension ``i`` are ``+/- plot_xlim * std_norm[i] * std_test_plot[i]``.
    ``plot_ref_pdf``, ``pdf_theor`` and ``log_scale_pdf`` are accepted for signature
    parity with ``pairplots`` but are not used here.
    """
    pddatatest = def_pd(xtest, std_norm, std_test_plot, datatype, dimplot=dimplot, \
            crop_data_plot=crop_data_plot, plot_crop=plot_crop, columns_plot=columns_plot)
    plot_kws = {"s": ssize}
    scatter = sns.pairplot(pddatatest, aspect=1, height=height_seaborn, corner=True, plot_kws=plot_kws)

    # corner=True: entries above the diagonal are None
    for i, row in enumerate(scatter.axes):
        plot_ylim_row = plot_xlim * std_norm[i] * std_test_plot[i]
        for j, ax in enumerate(row):
            if ax is None or j > i:
                continue
            plot_xlim_col = plot_xlim * std_norm[j] * std_test_plot[j]
            ax.set_xlim((-plot_xlim_col, plot_xlim_col))
            if j < i:  # off-diagonal: also bound the y-axis
                ax.set_ylim((-plot_ylim_row, plot_ylim_row))

    plt.tight_layout()
    if plt_show:
        plt.show(block=False)
        plt.pause(0.1)
    plt.savefig("results/" + name_simu + ".png", dpi=dpi)
    plt.close()
    if plt_show:
        # Only pause when showing. Otherwise plt.pause() would just sleep, and it
        # would also show/draw any other figure that is still open.
        plt.pause(0.1)
    plt.close('all')


def preprocessing(xtest, xs_forward, num_steps_forward, name_simu_root, \
                  noising_plots, plt_show, folder_results, val_hist, std_test_plot, device):
    """Print convergence diagnostics of the forward SDE and optionally plot its trajectory.

    The last step ``xs_forward[-1]`` is compared with the test samples ``xtest`` by:
      * the relative Frobenius distances of ``cov(xtest)`` and ``cov(xgen_forward)`` to
        ``mean(var(xtest)) * I`` (the trace-preserving white-noise limit);
      * the relative distance of ``cov(xgen_forward)`` to ``mean(var(xgen_forward)) * I``;
      * the mean squared norms ("energies") of both sets and their ratio.

    When ``noising_plots`` is true, about 8 evenly spaced steps are plotted with
    ``plot_selected_inds`` and saved to ``folder_results/<name_simu_root>_Forward.png``.

    Args:
        xtest: test samples, indexed as (samples, dims).
        xs_forward: forward trajectory; ``xs_forward[-1, :, :]`` is the final step.
        num_steps_forward: number of forward steps (indices 0..num_steps_forward are plotted).
        name_simu_root, folder_results: output file name pieces.
        noising_plots, plt_show: enable plotting / interactive display.
        val_hist, std_test_plot: histogram half-width is ``val_hist * std_test_plot[0]``.
        device: device on which the diagnostics are computed.
    """
    xgen_forward = xs_forward[-1, :, :].to(device)
    # Put the test samples on the same device. The identity matrices below used to be
    # hard-coded to the CPU, so any non-CPU `device` raised a device-mismatch error.
    xtest = xtest.to(device)
    identity = torch.eye(xtest.shape[1], device=xtest.device)

    # metrics of convergence for the forward SDE
    cov_xtest = torch.cov(xtest.T)
    cov_xgen_forward = torch.cov(xgen_forward.T)
    xgen_forward_var_mean = torch.var(xgen_forward.T, dim=1).mean()
    xtest_var_mean = torch.var(xtest.T, dim=1).mean()

    # comparison to cov of X_inf: tr(cov) = E||X||^2 is theoretically conserved
    cov_xgen_forward_converged = xtest_var_mean * identity
    d_cov_xtest = torch.norm(cov_xtest - cov_xgen_forward_converged)/torch.norm(cov_xgen_forward_converged)
    d_cov_xgen_forward = torch.norm(cov_xgen_forward - cov_xgen_forward_converged)/torch.norm(cov_xgen_forward_converged)
    print("dist cov_xtest to  cov_xgen_forward_converged (dist to  weak white noise)= " + str(d_cov_xtest.item()))
    print("dist cov_xgen_forward  to  cov_xgen_forward_converged = " + str(d_cov_xgen_forward.item()))

    # comparison to cov of weak white noise (with the same variance)
    cov_wwn = xgen_forward_var_mean * identity
    d_cov_xgen_forward = torch.norm(cov_xgen_forward - cov_wwn)/torch.norm(cov_wwn)
    print("dist cov_xgen_forward  to  weak white noise (w. same var.)= " + str(d_cov_xgen_forward.item()))

    # energies (mean squared norm per sample)
    energy_xtest = torch.sum((xtest**2), dim=1).mean()
    energy_xgen_forward = torch.sum((xgen_forward**2), dim=1).mean()
    print("energy_xtest = " + str(energy_xtest.item()))
    print("energy_xgen_forward = " + str(energy_xgen_forward.item()))
    print("energy_xgen_forward / energy_xtest = " + str(energy_xgen_forward.item()/energy_xtest.item()))

    # about 8 evenly spaced steps to visualize (at least every step)
    fig_step = max(1, int(num_steps_forward / 8))
    inds_forward = range(0, num_steps_forward + 1, fig_step)
    if noising_plots:
        plot_selected_inds(xs_forward, inds_forward, \
            use_xticks= True, use_yticks=False, lmbd = 0., \
            include_t0=True, backward=False,
            plt_show=plt_show,
            val=val_hist* std_test_plot[0])
        time.sleep(0.5)
        if plt_show:
            plt.show(block=False)
        name_fig = folder_results + "/" + name_simu_root + "_Forward.png"
        plt.savefig(name_fig)
        if plt_show:
            plt.pause(1)
        plt.close()
        plt.close('all')

def postprocessing(inds, i_dims, i_Res, i_num_stepss_backward, i_iterations, i_run, MSGM, sampler, \
                   xs, xtest, std_norm, std_test_plot, datatype, name_simu, dimplot, \
                   crop_data_plot, plot_crop, plot_xlim, plot_ref_pdf, \
                   pdf_theor, log_scale_pdf, columns_plot, \
                   scatter_plots, denoising_plots, include_t0_reverse, plt_show, dpi, height_seaborn, ssize, \
                   evalmmmd, justLoadmmmd, justLoad, save_results, lmbd, val_hist, device, \
                   mmd_ref, mmd_MSGM,mmd_SGM,max_num_samples_for_mmd):
    """Post-process generated samples: optionally save, filter, plot, and evaluate MMD.

    Steps:
      1. ``xgen = xs[-1]`` (final step) is moved to ``device``. If ``save_results`` and
         not ``justLoad``, it is saved with ``np.save``.
      2. Rows containing a NaN or any component with ``|x| > 1e3`` are dropped.
      3. On the first run (``i_run == 0``), ``pairplots`` and/or ``plot_selected_inds``
         are drawn when ``scatter_plots`` / ``denoising_plots`` are set.
      4. If ``evalmmmd`` and not ``justLoadmmmd``, up to ``max_num_samples_for_mmd``
         samples are used to compute ``compute_mmd`` between ``sampler`` draws and
         ``xtest`` (reference) and between ``xgen`` and ``xtest``. Both are rescaled by
         ``std_norm``. The values are stored at index
         ``[i_dims, i_Res, i_num_stepss_backward, i_iterations, i_run]`` of ``mmd_ref``
         and of ``mmd_MSGM`` (if ``MSGM``) or ``mmd_SGM``.
    The remaining arguments are forwarded to the plotting helpers.
    """
    xgen = xs[-1, :, :].to(device)

    if save_results and not justLoad:
        # np.save appends ".npy" to names without that suffix, so the file on disk
        # is name_simu + ".pt.npy"
        np.save(name_simu + ".pt", xgen.clone().detach().cpu().numpy())

    # Drop rows with a NaN or a large component (the magnitude test also catches +/-inf)
    bad_rows = (torch.isnan(xgen) | (torch.abs(xgen) > 1e3)).any(dim=1)
    nan_count = bad_rows.sum().item()
    if nan_count > 0:
        print(f"Number of rows with NaN or large value: {nan_count}")
    xgen = xgen[~bad_rows, :]
    del bad_rows

    if scatter_plots and i_run == 0:
        pairplots(xgen, xtest, std_norm, std_test_plot, datatype, name_simu, dimplot=dimplot, \
                    crop_data_plot=crop_data_plot, plot_crop=plot_crop, plot_xlim=plot_xlim, plot_ref_pdf=plot_ref_pdf, \
                    pdf_theor=pdf_theor, log_scale_pdf=log_scale_pdf, columns_plot=columns_plot, \
                    plt_show=plt_show, dpi=dpi, height_seaborn=height_seaborn, ssize=ssize)

    if denoising_plots and i_run == 0:
        plot_selected_inds(xs, inds, True, False, lmbd,
                            include_t0=include_t0_reverse,
                            plt_show=plt_show,
                            val=val_hist * std_test_plot[0])
        time.sleep(0.5)
        if plt_show:
            plt.show(block=False)
        name_fig = name_simu + ".png"
        plt.savefig(name_fig)
        if plt_show:
            plt.pause(1)
        plt.close()
        plt.close('all')

    # MMD
    if evalmmmd and not justLoadmmmd:
        num_samples_for_mmd = min(xtest.shape[0], max_num_samples_for_mmd)
        # Keep exactly num_samples_for_mmd samples (the old `[0:n-1]` slice dropped one).
        xtest = xtest[:num_samples_for_mmd, :]
        xgen = xgen[:num_samples_for_mmd, :]
        with torch.no_grad():
            x_mmd1 = sampler.sample(xtest.shape[0]).to(device)
            dist_train_to_test = compute_mmd(std_norm * x_mmd1, std_norm * xtest)
            dist = compute_mmd(std_norm * xgen, std_norm * xtest)
        mmd_ref[i_dims, i_Res, i_num_stepss_backward, i_iterations, i_run] = dist_train_to_test
        print("MMD train to test = " + str(dist_train_to_test.sqrt().item()))
        print("MSGM = " + str(MSGM))
        print("MMD gen. to test = " + str(dist.sqrt().item()))
        if MSGM:
            mmd_MSGM[i_dims, i_Res, i_num_stepss_backward, i_iterations, i_run] = dist
        else:
            mmd_SGM[i_dims, i_Res, i_num_stepss_backward, i_iterations, i_run] = dist
