import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import seaborn as sns
import pysindy as ps
import torch
import torch.nn as nn

import sys
sys.path.append("../")
import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"


from HyperSINDy import Net
from baseline import Trainer
from library_utils import Library
from Datasets import SyntheticDataset
from other import init_weights, make_folder, set_random_seed
from exp_utils import get_equations, log_equations
from kramersmoyal import km

"""
Generate the results for Figure 5 (Kramers-Moyal coefficients of the Lotka-Volterra system).
Kramers-Moyal coefficients are estimated with the code from https://github.com/LRydin/KramersMoyal
"""

def load_model(device, z_dim, poly_order, include_constant,
               noise_dim, hidden_dim, stat_size, batch_size,
               cp_path, num_hidden):
    """Build a HyperSINDy ``Net`` on a CUDA device and restore it from a checkpoint.

    Args:
        device: CUDA device index handed to ``torch.cuda.set_device``.
        z_dim: state dimension passed to ``Library`` as ``n``.
        poly_order: ``poly_order`` argument of ``Library``.
        include_constant: ``include_constant`` argument of ``Library``.
        noise_dim: ``noise_dim`` argument of ``Net``.
        hidden_dim: ``hidden_dim`` argument of ``Net``.
        stat_size: forwarded to ``Net`` as ``statistic_batch_size``.
        batch_size: batch size of the single coefficient draw made before
            the checkpoint is loaded.
        cp_path: path to a checkpoint whose ``'model'`` entry is a state dict.
        num_hidden: ``num_hidden`` argument of ``Net``.

    Returns:
        ``(net, library, device)``: the network in eval mode, the ``Library``
        it was built with, and the current CUDA device index as reported by
        ``torch.cuda.current_device()``.
    """
    # Select the requested GPU, then read back the index that is now current.
    torch.cuda.set_device(device=device)
    gpu_index = torch.cuda.current_device()

    sindy_library = Library(n=z_dim, poly_order=poly_order,
                            include_constant=include_constant)
    model = Net(
        sindy_library,
        noise_dim=noise_dim,
        hidden_dim=hidden_dim,
        statistic_batch_size=stat_size,
        num_hidden=num_hidden,
    )
    model = model.to(gpu_index)
    # One coefficient draw before restoring weights.
    # NOTE(review): presumably this initialises state that load_state_dict
    # expects to exist — confirm against HyperSINDy.Net.
    model.get_masked_coefficients(batch_size=batch_size, device=gpu_index)

    # Map the checkpoint tensors straight onto the selected GPU.
    checkpoint = torch.load(cp_path, map_location=f"cuda:{gpu_index}")
    model.load_state_dict(checkpoint['model'])
    model.to(gpu_index)
    return model.eval(), sindy_library, gpu_index

def sample_trajectory(net, library, device, x0, seed, batch_size=10, dt=1e-2, ts=10000):
    """Simulate a batch of stochastic HyperSINDy trajectories with forward Euler.

    At every step a fresh set of masked coefficients is drawn from ``net`` and
    the state is advanced as ``z <- z + (Theta(z) @ Xi) * dt``, where
    ``Theta`` is ``net.library.transform``.

    Args:
        net: trained network exposing ``get_masked_coefficients`` and a
            ``library`` attribute with a ``transform`` method.
        library: unused; ``net.library`` is used for the feature transform.
            Kept so existing callers keep working.
        device: torch device (or CUDA index) the simulation runs on.
        x0: initial condition; assumed to be a 1-D array of length z_dim —
            TODO confirm for other callers.
        seed: passed to ``set_random_seed`` before sampling, for reproducibility.
        batch_size: number of trajectories simulated in parallel.
        dt: Euler step size.
        ts: number of time points returned, including ``x0``.

    Returns:
        numpy array of shape ``(batch_size, ts, len(x0))`` for a 1-D ``x0``.
    """
    set_random_seed(seed)
    # Pure inference: without no_grad the autograd graph of every step (and of
    # every hypernetwork call) would be kept alive through ``zs`` until the end,
    # so memory would grow linearly with ``ts``. Results are unchanged.
    with torch.no_grad():
        zc = torch.as_tensor(x0, dtype=torch.float32, device=device)
        # Same initial state for every trajectory in the batch.
        zc = zc.unsqueeze(0).expand(batch_size, -1)
        zs = [zc]
        for _ in range(ts - 1):
            coefs = net.get_masked_coefficients(batch_size=batch_size, device=device)
            # (batch, 1, n_features) @ coefs -> (batch, 1, z_dim) per bmm semantics.
            lib = net.library.transform(zc).unsqueeze(1)
            zc = zc + torch.bmm(lib, coefs).squeeze(1) * dt
            zs.append(zc)
        # (ts, batch, z_dim) -> (batch, ts, z_dim)
        trajectories = torch.stack(zs, dim=0).transpose(0, 1)
    return trajectories.detach().cpu().numpy()

def plot(c1, c2, c3, vmin, vmax, fname, cmap='turbo'):
    """Save three images side by side that share one colour scale and colourbar.

    Args:
        c1, c2, c3: 2-D arrays (masked arrays allowed), drawn left to right.
        vmin, vmax: colour limits shared by all three panels. The colourbar is
            ticked at ``vmin``, the midpoint, and ``vmax``.
        fname: output path passed to ``Figure.savefig``.
        cmap: matplotlib colormap name (default ``'turbo'``, as before).
    """
    fig, axes = plt.subplots(1, 3, figsize=(20, 20), dpi=300)
    im = None
    for ax, data in zip(axes, (c1, c2, c3)):
        # All panels use identical vmin/vmax, so any image can drive the colourbar.
        im = ax.imshow(data, cmap=cmap, vmin=vmin, vmax=vmax)
        ax.axis('off')
    fig.subplots_adjust(bottom=0.1, top=0.9, left=0.1, right=0.8)
    # Dedicated colourbar axes to the right of the panels (figure-fraction coords).
    cbar_ax = fig.add_axes([0.85, 0.4, 0.0125, 0.2])
    # `aspect` is dropped by Figure.colorbar when `cax` is supplied, so it is omitted.
    cbar = fig.colorbar(im, cax=cbar_ax, ticks=[vmin, (vmin + vmax) / 2, vmax])
    cbar.ax.tick_params(size=0, labelsize=30)
    # Save and close this figure explicitly; the pyplot-level calls act on
    # whichever figure happens to be "current".
    fig.savefig(fname)
    plt.show()
    plt.close(fig)

def main():
    """Produce the Kramers-Moyal coefficient comparison plots.

    Loads a trained HyperSINDy checkpoint and simulates trajectories from a fixed
    initial condition. It then estimates Kramers-Moyal coefficients (KMC) for three
    datasets: the Lotka-Volterra test data, one HyperSINDy sample, and the
    PyDaddy data. For each drift/diffusion term it saves a three-panel figure
    under ``../results/``, then prints the learned equations.
    """
    # Settings
    sns.set()
    SEED = 5281998
    set_random_seed(SEED)
    device = 2  # CUDA device index
    noise_dim = 4
    batch_size = 250
    hidden_dim = 64
    stat_size = 250
    num_hidden = 5
    z_dim = 2
    poly_order = 3
    include_constant = True

    plt.rcParams['mathtext.rm'] = 'Arial'

    # Load model
    cp_path = "../runs/lotkavolterra/cp_1.pt"
    net, library, device = load_model(device, z_dim, poly_order, include_constant,
                                      noise_dim, hidden_dim, stat_size, batch_size,
                                      cp_path, num_hidden)

    # Generate 10 HyperSINDy trajectories and analyse the one at index 6.
    x0_test = np.array([2.1, 1.0])
    dt = 0.01
    samples = sample_trajectory(net, library, device, x0_test, SEED, 10, dt=dt, ts=10000)
    hsample = samples[6]

    # KMC settings. Coefficient rows returned by km() follow the order of `powers`.
    powers = np.array([[0, 0], [1, 0], [0, 1], [1, 1], [2, 0], [0, 2], [2, 2]])
    bw = 0.05

    # Bin edges come from the test trajectory and are reused for all three
    # datasets, so the coefficient maps share one grid and can be compared.
    xt = np.load("../data/lotkavolterra/x_test.npy")
    bin1 = np.linspace(np.min(xt[:, 0]), np.max(xt[:, 0]), 300)
    bin2 = np.linspace(np.min(xt[:, 1]), np.max(xt[:, 1]), 300)
    bins = np.array((bin1, bin2))

    # KMC for the test trajectory
    kmc, _ = km(xt, bw=bw, bins=bins, powers=powers)

    # KMC for the HyperSINDy trajectory
    kmc2, _ = km(hsample, bw=bw, bins=bins, powers=powers)

    # KMC for the PyDaddy trajectory
    xp = np.load("../data/gen_ss/x_train.npy")
    kmc3, _ = km(xp, bw=bw, bins=bins, powers=powers)

    # Mask out zeros so empty bins are left blank in the images.
    m1 = np.ma.masked_where(kmc == 0, kmc)
    m2 = np.ma.masked_where(kmc2 == 0, kmc2)
    m3 = np.ma.masked_where(kmc3 == 0, kmc3)

    # (row of kmc / index into `powers`, lower percentile for vmin, output path)
    plot_specs = [
        (1, 10, "../results/drift_x.png"),
        (2, 10, "../results/drift_y.png"),
        (4, 5, "../results/diff_x.png"),
        (5, 5, "../results/diff_y.png"),
    ]
    # savefig does not create directories; make sure the output folder exists.
    os.makedirs("../results", exist_ok=True)
    for row, low_pct, out_path in plot_specs:
        # NOTE(review): colour limits come from the unmasked test-data KMC
        # (zeros included), as in the original script.
        vmin = np.percentile(kmc[row], low_pct)
        vmax = np.percentile(kmc[row], 90)
        plot(m1[row], m2[row], m3[row], vmin, vmax, out_path)

    # Print equations
    print(get_equations(net, library, "HyperSINDy", device, round_eq=True, seed=SEED))


# Run the figure-generation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()