import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.func import vmap, jacrev, hessian
import os
from models.ns import NSRestrictedNet
 
from matplotlib.colors import Normalize

# Global Matplotlib font settings applied to every figure produced by this module.
font = dict(size=16)
import matplotlib
matplotlib.rc('font', **font)

# NOTE(review): not referenced anywhere in this part of the file; confirm it is
# used elsewhere before removing it.
plot_downsample = 2

class MidpointNormalize(Normalize):
    """Linear colormap normalization that maps ``midpoint`` to 0.5.

    The same slope is used on both sides of ``midpoint``. Whichever of
    ``vmin``/``vmax`` lies farther from the midpoint is mapped to 0 or 1, and
    the nearer one lands proportionally closer to 0.5. With a diverging
    colormap, equal distances from the midpoint therefore get equal colour
    intensity. Values outside ``[vmin, vmax]`` are clamped to the endpoints'
    colours; the ``clip`` argument has no further effect.
    """

    def __init__(self, vmin=None, vmax=None, midpoint=0, clip=False):
        self.midpoint = midpoint
        super().__init__(vmin, vmax, clip)

    def _half_range(self):
        """Distance from ``midpoint`` to the farther of ``vmin`` and ``vmax``."""
        mid = float(self.midpoint)
        return max(abs(float(self.vmax) - mid), abs(mid - float(self.vmin)))

    def __call__(self, value, clip=None):
        data = np.ma.asarray(value, dtype=float)
        # Fill in vmin/vmax from the data when they were not given explicitly.
        self.autoscale_None(data)
        half_range = self._half_range()
        if half_range == 0:
            # vmin == vmax == midpoint: every value sits at the centre colour.
            result = np.full(data.shape, 0.5)
        else:
            # A single linear map through (midpoint, 0.5). This equals the
            # piecewise interpolation between (vmin, midpoint, vmax) and their
            # normalized images, but it never divides by zero when an endpoint
            # equals the midpoint. It also stays monotonic when the midpoint
            # lies outside [vmin, vmax].
            clamped = np.clip(np.ma.getdata(data), float(self.vmin), float(self.vmax))
            result = 0.5 + 0.5 * (clamped - float(self.midpoint)) / half_range
        # Keep the caller's mask so masked cells stay masked (e.g. drawn as "bad").
        return np.ma.masked_array(result, mask=np.ma.getmask(data))

    def inverse(self, value):
        """Map normalized values back to data values.

        This replaces ``Normalize.inverse``, which assumes a plain linear map
        from ``vmin`` to ``vmax``.
        """
        if not self.scaled():
            raise ValueError('Not invertible until both vmin and vmax are set')
        return float(self.midpoint) + (np.asarray(value, dtype=float) - 0.5) * 2.0 * self._half_range()
   
   
def ns_errorplot(to_plot:np.ndarray, model_names:list[str], X,Y, path:str, t:float, name:str, curr_pde:np.ndarray, apx:str='', x_min=0., xmax=1.7, y_min=0., y_max=0.41):
    """Save a column of filled-contour panels, one per leading slice of ``to_plot``.

    All panels share one ``jet`` colour scale spanning the data range and a
    single colorbar. The figure is written to ``{path}/{name}{t}{apx}.pdf``.
    Nothing is written when ``to_plot`` is empty.

    Args:
        to_plot: Stack of 2-D fields. ``to_plot[i]`` is contoured over the
            ``(X, Y)`` grid.
        model_names: Panel titles, indexed in step with ``to_plot``.
        X, Y: Grid coordinates passed to ``contourf``.
        path: Output directory (must already exist).
        t: Time value, embedded in the file name.
        name: File-name prefix.
        curr_pde: Unused; kept so existing call sites keep working.
        apx: Optional file-name suffix.
        x_min, xmax, y_min, y_max: Axis limits applied to every panel.
    """
    titles = [f'{model_names[i]}' for i in range(len(to_plot))]
    _ns_contour_panels(to_plot, titles, X, Y, f'{path}/{name}{t}{apx}.pdf',
                       (x_min, xmax, y_min, y_max), cmap='jet')


def ns_compareplot(to_plot:np.ndarray, model_names:list[str], X,Y, path:str, t:float, name:str, curr_pde:np.ndarray, apx:str='', x_min=0., xmax=1.7, y_min=0., y_max=0.41, compare_to='DERL'):
    """Save signed-difference panels on a diverging colormap centred on zero.

    Intended for ``quantity[model] - quantity[baseline]`` stacks. Panels are
    titled ``'{model_names[i]} - {compare_to}'``, use ``seismic_r`` with zero
    pinned to the colormap centre (``MidpointNormalize``), and share one
    colorbar. The figure is written to ``{path}/{name}{t}{apx}.pdf``. Nothing
    is written when ``to_plot`` is empty.

    Args:
        to_plot: Stack of 2-D difference fields over the ``(X, Y)`` grid.
        model_names: Names of the compared models, in step with ``to_plot``.
        X, Y: Grid coordinates passed to ``contourf``.
        path: Output directory (must already exist).
        t: Time value, embedded in the file name.
        name: File-name prefix.
        curr_pde: Unused; kept so existing call sites keep working.
        apx: Optional file-name suffix.
        x_min, xmax, y_min, y_max: Axis limits applied to every panel.
        compare_to: Label of the baseline shown in each title.
    """
    titles = [f'{model_names[i]} - {compare_to}' for i in range(len(to_plot))]
    _ns_contour_panels(to_plot, titles, X, Y, f'{path}/{name}{t}{apx}.pdf',
                       (x_min, xmax, y_min, y_max), cmap='seismic_r', midpoint=0)


def _ns_contour_panels(to_plot, titles, X, Y, filename, limits, cmap, midpoint=None):
    """Draw one shared-scale contourf panel per field of ``to_plot`` and save the figure to ``filename``.

    With ``midpoint=None`` the colour scale is linear over the data range.
    Otherwise ``MidpointNormalize`` centres the colormap on ``midpoint``.
    """
    to_plot = np.asarray(to_plot)
    n_rows = to_plot.shape[0]
    if n_rows == 0:
        # e.g. a baseline comparison with a single model: nothing to draw.
        return

    # squeeze=False always returns a 2-D Axes array. Without it, a single row
    # gives a bare Axes object that cannot be indexed.
    fig, ax = plt.subplots(nrows=n_rows, ncols=1, figsize=(11, 4 * n_rows), layout='compressed',
                           sharex=True, sharey=True, squeeze=False)
    axes = ax[:, 0]

    vmin, vmax = _ns_value_range(to_plot)
    levels = np.linspace(vmin, vmax, 100)
    if midpoint is None:
        color_kwargs = {'vmin': vmin, 'vmax': vmax}
    else:
        # The norm carries vmin/vmax itself; passing them again is redundant.
        color_kwargs = {'norm': MidpointNormalize(vmin=vmin, vmax=vmax, midpoint=midpoint)}

    x_min, xmax, y_min, y_max = limits
    for axis, field, title in zip(axes, to_plot, titles):
        im = axis.contourf(X, Y, field, levels=levels, cmap=cmap, **color_kwargs)
        # Rasterize the filled contours to keep the PDF small.
        _ns_rasterize(im)
        axis.set_xlabel('x')
        axis.set_ylabel('y')
        axis.set_xlim(x_min, xmax)
        axis.set_ylim(y_min, y_max)
        axis.set_title(title)
    fig.colorbar(im, ax=axes, orientation='vertical')
    fig.savefig(filename)
    plt.close(fig)


def _ns_value_range(to_plot):
    """Return a finite, strictly increasing ``(vmin, vmax)`` covering the non-NaN values of ``to_plot``."""
    from matplotlib.transforms import nonsingular
    # A constant field would produce 100 identical contour levels, which
    # contourf rejects. nonsingular widens such a degenerate range slightly
    # and returns ordinary ranges unchanged.
    return nonsingular(float(np.nanmin(to_plot)), float(np.nanmax(to_plot)))


def _ns_rasterize(contour_set):
    """Mark a ContourSet's filled regions as rasterized, across Matplotlib versions."""
    if hasattr(contour_set, 'set_rasterized'):
        # Matplotlib >= 3.8: the ContourSet is itself a single Collection.
        contour_set.set_rasterized(True)
    else:
        # Older Matplotlib: one PathCollection per contour level.
        for collection in contour_set.collections:
            collection.set_rasterized(True)
    
    
def plot_errors(model_list:list[NSRestrictedNet], model_names:list[str], path:str, t:float, curr_pde:np.ndarray, apx:str='', plot_component_errors:bool=False):
    """Evaluate models against reference data at time ``t``, log error statistics and save figures.

    The first model is the baseline. Every ``*_compare`` figure shows
    ``quantity[model] - quantity[model_list[0]]`` for each other model.

    Files written:
        - ``{path}/losses.txt`` and ``{path}/consistencies.txt``. These are
          truncated when ``t == 0`` and appended to otherwise.
        - PDF figures in ``{path}/errors``, ``{path}/errors_compare``,
          ``{path}/consistencies`` and ``{path}/consistencies_compare``.
          These folders are created if missing.

    Args:
        model_list: Models exposing ``forward``, ``forward_single``,
            ``get_consistencies`` and ``device``.
        model_names: Display names, one per model.
        path: Output directory.
        t: Time at which all models are evaluated.
        curr_pde: Reference data, one row per grid point. Columns read here:
            1 = x, 2 = y; 3-5 = reference values of the three model outputs
            (labelled X velocity, Y velocity, pressure); 6 onward = reference
            Jacobian, reshaped to ``(-1, 3, 3)``. Rows are assumed to be
            ordered as a full x-by-y grid (TODO confirm with the data loader).
        apx: Optional file-name suffix for every figure.
        plot_component_errors: Also plot the per-component (X velocity,
            Y velocity, pressure) error fields and their baseline comparisons.

    Raises:
        ValueError: If ``model_list`` is empty.
    """
    num_models = len(model_list)
    print(f'Number of models: {num_models}')
    if num_models == 0:
        raise ValueError('plot_errors needs at least one model')
    model_names = list(model_names)

    err_dir = f'{path}/errors'
    err_cmp_dir = f'{path}/errors_compare'
    cons_dir = f'{path}/consistencies'
    cons_cmp_dir = f'{path}/consistencies_compare'
    # Create every output folder first: figures are saved into err_dir
    # immediately below.
    for folder in (err_dir, err_cmp_dir, cons_dir, cons_cmp_dir):
        os.makedirs(folder, exist_ok=True)

    X, Y, pts = _ns_grid_points(curr_pde, t)
    print(pts.shape)
    grid_shape = X.shape
    field_shape = (num_models,) + grid_shape

    # Reference fields.
    xvel_true = curr_pde[:, 3].reshape(grid_shape)
    yvel_true = curr_pde[:, 4].reshape(grid_shape)
    p_true = curr_pde[:, 5].reshape(grid_shape)
    ns_errorplot(np.array([xvel_true, yvel_true, p_true]), ['True X velocity', 'True Y velocity', 'True Pressure'], X, Y, err_dir, t, 'true_fields', curr_pde, apx)

    # The reference Jacobian is compared entry-by-entry with the model
    # Jacobians below, so it is assumed to share their (output, input) layout.
    der_true = curr_pde[:, 6:].reshape((-1, 3, 3))
    # Reference vorticity dv/dx - du/dy (columns 10 and 8 of curr_pde).
    vort_true = der_true[:, 1, 1] - der_true[:, 0, 2]

    outs, ders, vorts, nablas = _ns_evaluate_models(model_list, pts, grid_shape)
    xvels = outs[:, :, 0].reshape(field_shape)
    yvels = outs[:, :, 1].reshape(field_shape)
    ps = outs[:, :, 2].reshape(field_shape)

    for i in range(num_models):
        ns_errorplot(np.array([xvels[i], yvels[i], ps[i]]), ['X velocity', 'Y velocity', 'Pressure'], X, Y, err_dir, t, f'{model_names[i]}', curr_pde, apx)

    # Pointwise error fields, one grid-shaped field per model.
    xvel_errors = np.abs(xvels - xvel_true)
    yvel_errors = np.abs(yvels - yvel_true)
    ps_errors = np.abs(ps - p_true)
    vort_errors = np.abs(vorts - vort_true).reshape(field_shape)
    vorts = vorts.reshape(field_shape)
    global_errors = np.linalg.norm(outs - curr_pde[:, 3:6], axis=2).reshape(field_shape)
    der_errors = np.linalg.norm(ders - der_true, axis=(2, 3)).reshape(field_shape)

    # Start fresh report files at t == 0; append for later time steps.
    mode = 'w' if t == 0. else 'a'
    with open(f'{path}/losses.txt', mode) as f:
        for title, errors in (
            (f'\nGlobal error averaged over the domain, t = {t}', global_errors),
            ('\nX velocity error averaged over the domain', xvel_errors),
            ('\nY velocity error averaged over the domain', yvel_errors),
            ('\nPressure error averaged over the domain', ps_errors),
            ('\nDerivative error averaged over the domain', der_errors),
            ('\nVorticity error averaged over the domain', vort_errors),
        ):
            _ns_write_summary(f, title, errors, model_names)

    for errors, name in ((global_errors, 'global_error'),
                         (der_errors, 'derivative_error'),
                         (vort_errors, 'vorticity_error')):
        _ns_plot_with_baseline(errors, model_names, X, Y, err_dir, err_cmp_dir, t, name, f'{name}_comp', curr_pde, apx)
    ns_errorplot(vorts, model_names, X, Y, err_dir, t, 'vorticity', curr_pde, apx)
    ns_errorplot(nablas, model_names, X, Y, err_dir, t, 'nablas', curr_pde, apx)

    if plot_component_errors:
        for errors, name in ((xvel_errors, 'xvel_error'),
                             (yvel_errors, 'yvel_error'),
                             (ps_errors, 'p_error')):
            _ns_plot_with_baseline(errors, model_names, X, Y, err_dir, err_cmp_dir, t, name, f'{name}_comp', curr_pde, apx)

    # Per-model consistency fields returned by model.get_consistencies,
    # reported below as momentum and divergence consistency.
    moms, divs = [], []
    for model in model_list:
        mom, div = model.get_consistencies(torch.from_numpy(pts).to(model.device).float())
        moms.append(mom.detach().cpu().numpy().reshape(grid_shape))
        divs.append(div.detach().cpu().numpy().reshape(grid_shape))
    moms = np.array(moms)
    divs = np.array(divs)

    with open(f'{path}/consistencies.txt', mode) as f:
        _ns_write_summary(f, f'Momentum consistency averaged over the domain, t = {t}', moms, model_names, include_l2=False)
        _ns_write_summary(f, 'Divergence consistency averaged over the domain', divs, model_names, include_l2=False)

    for values, name in ((moms, 'momentum_consistency'), (divs, 'divergence_consistency')):
        _ns_plot_with_baseline(values, model_names, X, Y, cons_dir, cons_cmp_dir, t, name, name, curr_pde, apx)


def _ns_grid_points(curr_pde, t):
    """Return the ``(X, Y)`` grid from columns 1-2 of ``curr_pde`` and the matching ``(N, 3)`` array of ``(t, x, y)`` points."""
    len_x = len(np.unique(curr_pde[:, 1]))
    len_y = len(np.unique(curr_pde[:, 2]))
    # Assumes the rows are already ordered as a full len_x-by-len_y grid.
    X = curr_pde[:, 1].reshape((len_x, len_y))
    Y = curr_pde[:, 2].reshape((len_x, len_y))
    T = t * np.ones_like(X.reshape(-1))
    pts = np.vstack([T, X.reshape(-1), Y.reshape(-1)]).T
    return X, Y, pts


def _ns_evaluate_models(model_list, pts, grid_shape):
    """Evaluate every model at ``pts`` and return stacked ``(outs, ders, vorts, nablas)`` numpy arrays.

    The arrays are ``outs`` with shape ``(M, N, n_out)``, ``ders`` with shape
    ``(M, N, n_out, 3)``, ``vorts`` with shape ``(M, N)`` and ``nablas`` with
    shape ``(M, *grid_shape)``, for ``M`` models and ``N`` points.
    """
    outs, ders, vorts, nablas = [], [], [], []
    for model in model_list:
        inputs = torch.from_numpy(pts).to(model.device).float()
        outs.append(model.forward(inputs).detach().cpu().numpy())
        jac = vmap(jacrev(model.forward_single))(inputs)
        hess = vmap(hessian(model.forward_single))(inputs)
        # Trace of the (x, y) block of the Hessian for the first two outputs,
        # i.e. the spatial Laplacian of each velocity component. The plotted
        # quantity is its Euclidean norm over the two components.
        lap = torch.diagonal(hess[:, :2, 1:, 1:], dim1=2, dim2=3).sum(dim=2)
        nablas.append(np.linalg.norm(lap.detach().cpu().numpy(), axis=1).reshape(grid_shape))
        ders.append(jac.detach().cpu().numpy())
        # jac[:, i, j] = d(output i)/d(input j), with inputs ordered (t, x, y):
        # vorticity = dv/dx - du/dy.
        vorts.append((jac[:, 1, 1] - jac[:, 0, 2]).detach().cpu().numpy())
    return np.array(outs), np.array(ders), np.array(vorts), np.array(nablas)


def _ns_write_summary(f, title, values, model_names, include_l2=True):
    """Write ``title``, then per model the RMS ("L2 loss", optional), mean and std of its field, to ``f``."""
    print(title, file=f)
    for name, field in zip(model_names, values):
        if include_l2:
            print(f'{name}: L2 loss {np.sqrt(np.mean(field**2))}', file=f)
        print(f'{name}: mean {np.mean(field)}, std {np.std(field)}', file=f)


def _ns_plot_with_baseline(values, model_names, X, Y, out_dir, compare_dir, t, name, compare_name, curr_pde, apx):
    """Plot ``values`` per model, then each other model's difference from the first (baseline) model."""
    ns_errorplot(values, model_names, X, Y, out_dir, t, name, curr_pde, apx)
    # A difference plot needs at least one model besides the baseline.
    if len(values) > 1:
        ns_compareplot(values[1:] - values[0], model_names[1:], X, Y, compare_dir, t, compare_name, curr_pde, apx,
                       compare_to=model_names[0])
    
    
    
    
            
    