import numpy as np
from scipy.sparse import diags, identity
from scipy.sparse.linalg import spsolve


def _five_point_stencil(N):
    """Assemble the ``N**2 x N**2`` five-point stencil matrix in CSR format.

    Unknowns are numbered row by row on an ``N x N`` grid.  The matrix has 4
    on the main diagonal and -1 for every pair of horizontally or vertically
    adjacent grid points.
    """
    n_unknowns = N * N
    main = np.full(n_unknowns, 4.0)
    if N == 1:
        # A single unknown has no neighbours.  The offsets -N/-1 and 1/N would
        # also coincide here, and duplicate offsets are not accepted.
        return diags([main], [0], shape=(1, 1), format='csr')

    # Coupling between unknowns k and k+1; zeroed where k is the last point
    # of a grid row so that consecutive rows are not linked to each other.
    within_row = np.ones(n_unknowns - 1)
    within_row[N - 1::N] = 0.0
    # Coupling between unknowns k and k+N (neighbouring grid rows).
    between_rows = np.ones(n_unknowns - N)

    return diags(
        [-between_rows, -within_row, main, -within_row, -between_rows],
        [-N, -1, 0, 1, N],
        shape=(n_unknowns, n_unknowns),
        format='csr',
    )


def solve_Eikonal(N, epsilon):
    """Finite-difference solve on an ``N x N`` interior grid of the unit square.

    With ``h = 1 / (N + 1)`` the function solves the sparse linear system
    ``(I + (epsilon / h)**2 * A) v = f``, where ``A`` is the five-point
    stencil matrix and ``f`` adds ``(epsilon / h)**2`` once for every
    neighbour of a grid point that lies on the boundary (so corner points
    receive twice that amount).  This is the discrete form of
    ``v - epsilon**2 * laplace(v) = 0`` with ``v = 1`` on the boundary.  The
    result is then transformed as ``u = -epsilon * log(v)``.

    Parameters
    ----------
    N : int
        Number of interior grid points per direction; must be at least 1.
    epsilon : float
        Regularisation parameter; must be strictly positive.

    Returns
    -------
    numpy.ndarray
        One-dimensional array of length ``N**2`` holding ``u`` at the
        interior grid points in row-major order (use ``.reshape(N, N)`` to
        recover the grid).

    Raises
    ------
    ValueError
        If ``N < 1`` or ``epsilon`` is not strictly positive.
    """
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N!r}")
    # ``not epsilon > 0`` also rejects NaN; epsilon == 0 would give log(0).
    if not epsilon > 0:
        raise ValueError(f"epsilon must be strictly positive, got {epsilon!r}")

    hg = 1.0 / (N + 1)
    ratio = epsilon**2 / hg**2

    # Boundary contributions: each side of the square adds `ratio` to the
    # grid points adjacent to it.  For N == 1 all four hit the same point.
    rhs = np.zeros((N, N))
    rhs[0, :] += ratio
    rhs[-1, :] += ratio
    rhs[:, 0] += ratio
    rhs[:, -1] += ratio

    system = identity(N * N, format='csr') + ratio * _five_point_stencil(N)
    sol_v = spsolve(system, rhs.ravel())
    return -epsilon * np.log(sol_v)


def _main():
    """Solve on a 24 x 24 interior grid and save a filled contour plot.

    The image is written to ``./results/eikonal_FD.png``; the directory is
    created if it does not exist.
    """
    import os
    import matplotlib.pyplot as plt

    N = 24
    # Coordinates of the N interior points plus one boundary point per side.
    axis_pts = np.linspace(0, 1, N + 2)

    # Surround the interior solution with a one-cell frame of zeros, which
    # stands in for the boundary values in the plot.
    interior = solve_Eikonal(N, 0.1).reshape(N, N)
    field = np.pad(interior, 1, mode='constant')

    font_ticks = 15
    plt.figure(figsize=(6, 5))
    plt.tick_params(axis='both', which='major', labelsize=font_ticks)
    plt.contourf(axis_pts, axis_pts, field)
    colorbar = plt.colorbar()
    colorbar.ax.tick_params(labelsize=font_ticks)
    plt.title("FD Solution", fontsize=20)

    out_dir = './results'
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, 'eikonal_FD.png'), bbox_inches='tight')


if __name__ == '__main__':
    _main()