import pickle

import numpy as np
import plotly.graph_objects as go
import matplotlib.pyplot as plt

from scripts.error_landscape_runner import *
from scripts.error_landscape_utils import *
from scipy.interpolate import RegularGridInterpolator


def run(runner, save_path):
    """
    Train the specified model, gather the model parameters during training,
    and evaluate the error landscape in the plane of the PCA directions.

    The runner's "Log Interval" config is forced to 100 before training.
    After training, ``runner.eval_error()`` is evaluated on a 101 x 101
    grid of offsets along the two PCA directions, anchored at the final
    weight snapshot.

    Args:
        runner: experiment runner (e.g. ``PCPINNRunner`` / ``PINNRunner``)
            exposing ``load_config``/``get_config``/``set_config``/
            ``setup``/``run``/``eval_error`` plus ``logger``, ``weights``
            and ``model`` attributes.
        save_path: file to pickle
            ``[X, Y, Z, coords, loss_history, l2re_history]`` into, where
            ``X``/``Y`` are the ``'ij'``-indexed meshgrids and ``Z`` holds
            the evaluated errors on that grid.
    """
    runner.load_config()
    config = runner.get_config()

    # Modify the config if needed
    config["Log Interval"] = 100

    runner.set_config(config)
    runner.setup()
    runner.run()

    loss_history = runner.logger.loss_history
    l2re_history = runner.logger.l2re_history
    # Parameter snapshots collected by the runner during training; the last
    # entry is used as the anchor of the landscape.
    weights = runner.weights
    # Calculate the PCA directions
    directions = setup_PCA_directions(weights)
    # Project the model parameters to the PCA directions
    coords = project_trajectory(weights, directions)

    # Build the surface grid: extend 20% beyond the farthest projected
    # trajectory point along each axis so the whole path is covered.
    final_weight = weights[-1]
    x_len = np.max(np.abs(coords[0])) * 1.2
    y_len = np.max(np.abs(coords[1])) * 1.2
    X, Y = np.meshgrid(
        np.linspace(-x_len, x_len, num=101),
        np.linspace(-y_len, y_len, num=101),
        indexing='ij',
    )
    Z = np.zeros_like(X)
    # np.ndindex covers the full 2-D shape (row-major, same order as a
    # nested i/j loop), so this stays correct even if the two axes ever use
    # different resolutions.  The previous `len(Y)` bound was the FIRST
    # dimension of the 2-D meshgrid and only worked for square grids.
    for i, j in np.ndindex(X.shape):
        # NOTE(review): presumably sets the model to
        # final_weight + X*d0 + Y*d1 -- see set_weights to confirm.
        set_weights(runner.model, final_weight,
            directions, [X[i, j], Y[i, j]])
        Z[i, j] = runner.eval_error()

    # Save X, Y, Z, coords, loss_history, l2re_history
    with open(save_path, "wb") as f:
        pickle.dump([X, Y, Z, coords,
            loss_history, l2re_history], f)


def simple_run(runner, save_path):
    """
    Train the specified model only (no error-landscape evaluation).

    The runner's "Log Interval" config is overridden to 100 before
    training. Afterwards the logger's loss and L2RE histories are pickled
    to ``save_path`` as the list ``[loss_history, l2re_history]``.
    """
    runner.load_config()
    cfg = runner.get_config()
    cfg["Log Interval"] = 100  # override whatever the loaded config had
    runner.set_config(cfg)

    runner.setup()
    runner.run()

    logger = runner.logger
    histories = [logger.loss_history, logger.l2re_history]
    with open(save_path, "wb") as out:
        pickle.dump(histories, out)


def plot_errorlandscape(load_path, plot_path, colorbar=True,
    c_range=None, z_range=None):
    """
    Render a saved error landscape plus training trajectory as an HTML plot.

    Args:
        load_path: pickle written by ``run`` containing
            ``[X, Y, Z, coords, loss_history, l2re_history]``.
        plot_path: output path for the interactive plotly HTML figure.
        colorbar: whether to show the surface colour bar.
        c_range: optional ``(cmin, cmax)`` for the surface colour scale,
            in log10 units.
        z_range: optional ``(zmin, zmax)`` for the z axis, in log10 units.
    """
    # Load X, Y, Z, coords (the histories are not needed here)
    with open(load_path, "rb") as f:
        X, Y, Z, coords, _, _ = pickle.load(f)

    # NOTE(review): assumes coords is (2, n_steps) -- row 0/1 are the
    # PCA-1/PCA-2 projections; confirm against project_trajectory.
    coords = np.asarray(coords)
    # Recover the 1-D axes from the saved 'ij' meshgrid (as built in `run`)
    # instead of re-deriving them from coords, so the plot always matches
    # the grid Z was actually evaluated on.
    x_axis = np.asarray(X)[:, 0]
    y_axis = np.asarray(Y)[0, :]
    log_z = np.log10(Z)

    # Height of each trajectory point = log10 error at its nearest grid node.
    interp = RegularGridInterpolator((x_axis, y_axis), log_z, method='nearest')
    values = interp(coords.T)

    # plotly indexes Surface z as [y][x]; Z is [x][y], hence the transpose.
    surface_kwargs = dict(z=log_z.T, x=x_axis, y=y_axis,
        colorscale='rdbu_r', showscale=colorbar)
    if c_range is not None:
        surface_kwargs.update(cmin=c_range[0], cmax=c_range[1])
    surface = go.Surface(**surface_kwargs)
    surface.colorbar.tickfont = dict(size=60)

    fig = go.Figure(data=[surface,
        go.Scatter3d(x=coords[0], y=coords[1],
            z=values, marker=dict(size=7), line=dict(width=10))])
    fig.update_layout(autosize=True,
        width=800, height=800)
    fig.update_layout(
        font=dict(
            family="Times New Roman",
            color="black"
        )
    )
    fig.update_layout(
        scene=dict(
            xaxis_title="PCA 1",
            yaxis_title="PCA 2",
            zaxis_title="Log10 L2RE",
            xaxis=dict(title_font=dict(size=60), showticklabels=False),
            yaxis=dict(title_font=dict(size=60), showticklabels=False),
            zaxis=dict(title_font=dict(size=60), showticklabels=False)
        )
    )
    fig.update_layout(scene=dict(
        xaxis=dict(gridcolor='grey', backgroundcolor='white', gridwidth=4),
        yaxis=dict(gridcolor='grey', backgroundcolor='white', gridwidth=4),
        zaxis=dict(gridcolor='grey', backgroundcolor='white', gridwidth=4),
        bgcolor='white'  # background color of the whole scene
    ))

    if z_range is not None:
        fig.update_layout(scene=dict(zaxis=dict(range=[
            z_range[0], z_range[1]])))

    fig.update_traces(opacity=0.9)
    fig.update_layout(
        margin=dict(l=0, r=0, t=0, b=0)
    )
    fig.write_html(plot_path)


def plot_l2re_history(l2re_history, log_interval=100,
        save_path='figs/history_error_landscape.pdf'):
    """
    Plot mean L2RE curves with +/- one-std bands for every method.

    Args:
        l2re_history: mapping from method name to ``[mean, std]`` (indexable
            as ``v[0]``/``v[1]``), one entry per logged iteration. Must
            contain every method in the plotting order below.
        log_interval: iterations between consecutive history entries; used
            to build the x axis. The default reproduces the previous
            hard-coded ``np.arange(0, 20001, 100)`` for 201-entry histories.
        save_path: where the figure is saved.
    """
    colors = {
        'MultiAdam': '#8E6A2A',
        'PINN': '#DEA01E',
        'LAAF': '#2474B5',
        'LRA': '#766BBB',
        'NTK': '#008575',
        'PCPINN': '#B85029',
    }
    # Plotting order; later methods get a higher zorder (drawn on top).
    order = [
        'MultiAdam', 'PINN', 'LAAF', 'LRA', 'NTK', 'PCPINN',
    ]

    # Draw on a fresh figure so figures created earlier in the session (or
    # a previous call) do not bleed into the saved PDF.
    fig = plt.figure()
    plt.grid()
    for zorder, name in enumerate(order, start=1):
        v = l2re_history[name]
        mean, std = v[0], v[1]
        epochs = np.arange(len(mean)) * log_interval
        plt.plot(epochs, mean,
            color=colors[name], label=name,
            linewidth=4.0, zorder=zorder)
        plt.fill_between(epochs,
            mean - std, mean + std,
            color=colors[name], alpha=0.2)

    plt.rc('font', family='Times New Roman')
    plt.rcParams['font.sans-serif'] = 'times new roman'
    # NOTE: tick positions assume a 20000-iteration run.
    plt.xticks([0, 10000, 20000], fontproperties='times new roman', size=30)
    plt.yticks(fontproperties='times new roman', size=30)
    plt.xlabel('Iterations', fontsize=38)
    plt.ylabel(r'L2RE', fontsize=38)
    plt.legend(loc='center left', fontsize=28, bbox_to_anchor=(1.02, 0.5))
    # Log scale
    plt.yscale('log')

    plt.savefig(save_path, bbox_inches='tight')
    # Release the figure; it has been written to disk.
    plt.close(fig)


if __name__ == "__main__":
    import os

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Extra training-only trials per method (on top of the full run) used
    # for the mean/std bands in the L2RE history plot.
    n_extra_trials = 2

    # Output directories written to below; create them so a fresh checkout
    # does not fail with FileNotFoundError.
    os.makedirs("temp", exist_ok=True)
    os.makedirs("figs", exist_ok=True)

    # Full runs (training + error landscape): PC-PINN (ours), then PINN.
    run(PCPINNRunner(device), "temp/pcpinn.pkl")
    run(PINNRunner(device), "temp/pinn.pkl")

    # Shared log10 colour range so both landscape plots use the same scale.
    log_zs = []
    for path in ("temp/pcpinn.pkl", "temp/pinn.pkl"):
        with open(path, "rb") as f:
            log_zs.append(np.log10(pickle.load(f)[2]))
    vmin = min(np.min(z) for z in log_zs)
    vmax = max(np.max(z) for z in log_zs)

    plot_errorlandscape("temp/pcpinn.pkl",
        "figs/pcpinn_error_landscape.html",
        c_range=[vmin, vmax])
    plot_errorlandscape("temp/pinn.pkl",
        "figs/pinn_error_landscape.html",
        colorbar=False, c_range=[vmin, vmax])

    # Run more trials (training only, no landscape)
    for i in range(n_extra_trials):
        simple_run(PCPINNRunner(device), "temp/pcpinn_%d.pkl" % i)
        simple_run(PINNRunner(device), "temp/pinn_%d.pkl" % i)

    # Plot the L2RE history. The L2RE history is the last element of both
    # the `run` and the `simple_run` pickles.
    def _load_l2re(path):
        with open(path, "rb") as f:
            return pickle.load(f)[-1]

    l2re_history_pcpinn = [_load_l2re("temp/pcpinn.pkl")] + [
        _load_l2re("temp/pcpinn_%d.pkl" % i) for i in range(n_extra_trials)]
    l2re_history_pinn = [_load_l2re("temp/pinn.pkl")] + [
        _load_l2re("temp/pinn_%d.pkl" % i) for i in range(n_extra_trials)]

    # Baseline curves; the context manager closes the underlying .npz file
    # once the arrays have been copied into a plain dict.
    with np.load('scripts/l2re_baselines.npz') as baselines:
        l2re_history = dict(baselines)

    l2re_history['PCPINN'] = [
        np.mean(l2re_history_pcpinn, axis=0),
        np.std(l2re_history_pcpinn, axis=0)]

    # NOTE(review): only the PINN curves are resampled (first entry, then
    # every 100th entry starting at index 99); presumably PINNRunner logs
    # at a finer interval than PCPINNRunner -- confirm against the runners.
    def _resample(series):
        return np.concatenate((series[0:1], series[99::100]))

    l2re_history['PINN'] = [
        _resample(np.mean(l2re_history_pinn, axis=0)),
        _resample(np.std(l2re_history_pinn, axis=0))]

    plot_l2re_history(l2re_history)
