import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from lorenz_utils import *

torch.manual_seed(0)
np.random.seed(0)

_AXIS_NAMES = 'xyz'


def select_device():
    """Return the CUDA device when available, otherwise the CPU."""
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def evaluate_displacement(model, x, y, z, device, out_dim=3):
    """Evaluate ``model(p) - p`` at every grid point.

    Args:
        model: callable taking a length-3 float32 tensor and returning a tensor
            with ``out_dim`` elements; the first three are used as the
            predicted (x, y, z).
        x, y, z: equally shaped float arrays of grid coordinates.
        device: torch device the per-point input tensor is created on.
        out_dim: number of elements the model output is reshaped to.

    Returns:
        Tuple ``(dx, dy, dz)`` of float64 arrays shaped like ``x``.
    """
    dx, dy, dz = (np.empty(x.shape) for _ in range(3))
    with torch.no_grad():
        # One point per call, exactly as the model was previously queried.
        # NOTE(review): batching all points into one (N, 3) call would be much
        # faster if the model's forward supports batched input -- confirm first.
        for idx in np.ndindex(x.shape):
            point = torch.tensor([x[idx], y[idx], z[idx]], dtype=torch.float, device=device)
            pred = model(point).reshape(out_dim).cpu().numpy()
            dx[idx] = pred[0] - x[idx]
            dy[idx] = pred[1] - y[idx]
            dz[idx] = pred[2] - z[idx]
    return dx, dy, dz


def plane_slice(grids, fields, fixed_axis, index):
    """Slice 3-D 'ij'-indexed grids and vector fields at one plane.

    Positions and components are taken from the SAME slice, so element [r, c]
    of every returned array refers to the same physical point. Matplotlib's
    quiver therefore draws each arrow at the point it was evaluated at; no
    transpose or component swap is needed.

    Args:
        grids: (X, Y, Z) coordinate arrays built with ``indexing='ij'``.
        fields: (U, V, W) vector components on the same grid.
        fixed_axis: axis (0=x, 1=y, 2=z) held constant.
        index: grid index along ``fixed_axis``.

    Returns:
        ``(pos_a, pos_b, vec_a, vec_b)`` for the two remaining axes, in
        increasing axis order.
    """
    a, b = (axis for axis in range(3) if axis != fixed_axis)
    selector = tuple(index if axis == fixed_axis else slice(None) for axis in range(3))
    return grids[a][selector], grids[b][selector], fields[a][selector], fields[b][selector]


def plot_plane(pos_a, pos_b, vec_a, vec_b, traj_a, traj_b, *, labels, title,
               scale, shell_radii, out_file):
    """Draw one quiver slice with trajectory points, save it, and close it.

    Args:
        pos_a, pos_b: arrow positions; vec_a, vec_b: arrow components.
        traj_a, traj_b: trajectory coordinates scattered in red.
        labels: (xlabel, ylabel).
        title: plot title.
        scale: quiver ``scale`` in data units per inch.
        shell_radii: radii of origin-centred circles to draw (may be empty).
        out_file: path the PNG is written to.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.quiver(pos_a, pos_b, vec_a, vec_b, scale=scale, scale_units='inches')
    ax.scatter(traj_a, traj_b, s=0.1, c='r')
    for radius in shell_radii:
        ax.add_patch(plt.Circle((0, 0), radius, color='b', fill=False))
    ax.set_xlabel(labels[0], fontsize=24)
    ax.set_ylabel(labels[1], fontsize=24)
    ax.set_title(title)
    fig.savefig(out_file)
    plt.close(fig)  # release the figure; three are produced per run
    print("Plot saved to", out_file)


def main():
    """Plot xy/xz/yz slices of the model's displacement field over the data."""
    # Grid params
    grid_sub = 17  # points per axis; odd so the middle slice lies exactly at 0
    upper = 500.
    lower = -500.

    plot_diss_shell = True
    inner_radius = 90
    outer_radius = inner_radius + 40

    out_dim = 3

    # Data: every `sub`-th sample; columns 0, 1, 2 are plotted as x, y, z.
    sub = 128 * 10
    predloader = MatReader('L63T10000.mat')
    data = predloader.read_field('u')[::sub]  # subsample for plotting
    data = data.numpy()

    # Model
    path = 'lorenz_dissipative_densenet_inner_rad_90_relu_0_05_time160000_ep1000_lr0_0005_schedstep100_relLpTrue_layers3_150_150_150_150_150_150_3'
    # Alternative checkpoints:
    # path = 'lorenz_densenet_relu_time_0_05160000_ep1000_lr0_0005_schedstep100_relLpTrue_layers3_150_150_150_150_150_150_3'
    # path = 'lorenz_norm_dissipative_densenet_inner_rad_90_relu_0_05_time_test1160000_ep1000_lr0_0005_schedstep100_relLpTrue_layers3_150_150_150_150_150_150_3'
    device = select_device()
    # NOTE(review): torch>=2.6 defaults to weights_only=True, which rejects
    # full-module pickles; pass weights_only=False there (trusted files only).
    model = torch.load('model/' + path, map_location=device)
    model.eval()  # inference mode (matters if the net has dropout/batch-norm)

    # 'ij' indexing: array axis 0 <-> x, 1 <-> y, 2 <-> z.
    coords = np.linspace(lower, upper, grid_sub)
    grids = np.meshgrid(coords, coords, coords, indexing='ij')
    fields = evaluate_displacement(model, *grids, device, out_dim)

    mid = grid_sub // 2
    shell_radii = (inner_radius, outer_radius) if plot_diss_shell else ()
    os.makedirs('plots', exist_ok=True)

    # (axis held fixed, quiver scale): xy at z, xz at y, yz at x.
    planes = ((2, 1300), (1, 1300), (0, 1000))
    for n, (fixed_axis, scale) in enumerate(planes):
        a, b = (axis for axis in range(3) if axis != fixed_axis)
        tag = _AXIS_NAMES[a] + _AXIS_NAMES[b]
        if n:
            print()
        plot_plane(
            *plane_slice(grids, fields, fixed_axis, mid),
            data[:, a], data[:, b],
            labels=(_AXIS_NAMES[a], _AXIS_NAMES[b]),
            title=f'{tag} vector field at {_AXIS_NAMES[fixed_axis]}={coords[mid]:g}',
            scale=scale,
            shell_radii=shell_radii,
            out_file='plots/' + path + '_' + tag + '_vectorfield.png',
        )


if __name__ == '__main__':
    main()
