import math
import os
import time

import numpy as np
import torch

from pino.adam import Adam
from pino.fourier2d import FNN2d


# Network: FNN2d with 20 modes in every modes1/modes2 entry and three output
# channels, which the training loop reads as u, v and p.
model = FNN2d(
    modes1=[20] * 4,
    modes2=[20] * 4,
    fc_dim=128,
    layers=[64] * 5,
    activation='gelu',
    out_dim=3,
)
model = model.cuda()

# The learning rate starts at 0.005 and is multiplied by 0.5 at every
# milestone (500, 1000, ..., 3500); the training loop steps the scheduler
# once per epoch.
initial_lr = 0.005
lr_milestones = list(range(500, 4000, 500))
optimizer = Adam(model.parameters(), lr=initial_lr, betas=(0.9, 0.999))
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=lr_milestones, gamma=0.5)

model.train()

# Number of grid nodes along x and y.
res_size_x = 441
res_size_y = 83

# Node spacing on the 2.2 x 0.41 rectangle, both end points included.
dx = 2.2 / (res_size_x-1)
dy = 0.41 / (res_size_y-1)

# Node indices along each axis. They are kept in float64 so that
# index*spacing is evaluated in double precision and only then rounded to
# float32, which is what building the grid from Python floats would give.
idx_x = torch.arange(res_size_x, dtype=torch.float64)
idx_y = torch.arange(res_size_y, dtype=torch.float64)
# The extended axes start one node before the physical domain (index -1).
idx_x_ext = torch.arange(res_size_x + 2, dtype=torch.float64) - 1
idx_y_ext = torch.arange(res_size_y + 2, dtype=torch.float64) - 1

# Physical grid: x[i, j] = i*dx and y[i, j] = j*dy.
x = (idx_x * dx).float().unsqueeze(1).repeat(1, res_size_y).cuda()
y = (idx_y * dy).float().unsqueeze(0).repeat(res_size_x, 1).cuda()

# Extended grid with one extra layer of nodes on every side, so that
# extended node [i+1, j+1] coincides with physical node [i, j].
x_extended = (idx_x_ext * dx).float().unsqueeze(1).repeat(1, res_size_y + 2).cuda()
y_extended = (idx_y_ext * dy).float().unsqueeze(0).repeat(res_size_x + 2, 1).cuda()

# Signed distance to the circle of radius 0.05 centred at (0.2, 0.2):
# non-positive inside or on the circle, positive outside.
phi_S = torch.sqrt( (x_extended-0.2)**2 + (y_extended-0.2)**2 ) - 0.05

# Mask on the physical grid: 1 outside the circle, 0 inside or on it.
# A single boolean-mask assignment replaces a per-node Python loop, which
# would otherwise read one GPU element at a time.
interior = torch.ones((res_size_x,res_size_y)).cuda()
interior[phi_S[1:-1, 1:-1] <= 0] = 0


# Reference solution used only for the error metrics in the training loop,
# where it is indexed as [..., 0], [..., 1] and [..., 2] against u, v and the
# rescaled p. The file name suggests a (441, 83, 3) array -- TODO confirm.
reference_path = 'boundary-conditions-for-pinos-code/fenics_reference_solution/DFG_2D_1/pred_441_83_3.npy'
with open(reference_path, 'rb') as reference_file:
    reference_array = np.load(reference_file)
reference_solution = torch.tensor(reference_array).cuda()

# Collocation points on the circle of radius 0.05 centred at (0.2, 0.2).
# 0.1*pi is that circle's circumference, so the count is the number of
# dx-long arcs that fit on it.
n_col_circle = int(0.1 * np.pi / dx)

# weights_circle[n] holds the bilinear interpolation weights of collocation
# point n on the physical grid: summing weights_circle[n] * field over the
# grid axes evaluates the field at that point.
weights_circle = torch.zeros((n_col_circle, res_size_x, res_size_y)).cuda()
for point in range(n_col_circle):
    angle = 2*np.pi*point/n_col_circle
    # Position of the point in units of grid cells.
    grid_x = (0.2 + 0.05*np.cos(angle)) / dx
    grid_y = (0.2 + 0.05*np.sin(angle)) / dy
    # Lower-left node of the enclosing cell and the fractional offsets.
    i0, j0 = int(grid_x), int(grid_y)
    alpha, beta = grid_x - i0, grid_y - j0
    # Distribute the unit weight over the four corners of the cell.
    for di, w_x in ((0, 1-alpha), (1, alpha)):
        for dj, w_y in ((0, 1-beta), (1, beta)):
            weights_circle[point, i0+di, j0+dj] = w_x * w_y

# U is the maximum of the parabolic profile below (reached at y = 0.205).
U = 0.3

# g_1 lives on the extended grid and has two channels: channel 0 holds the
# parabolic profile 4*U*y*(0.41-y)/0.41**2, channel 1 stays zero. The training
# loop uses channel 0 as the target for u on the first grid column.
extended_shape = (res_size_x+2, res_size_y+2)
g_1 = torch.zeros(extended_shape + (2,)).cuda()
g_1[:, :, 0] = 4*U*y_extended*(0.41-y_extended)/0.41**2


# Network input of shape (1, res_size_x+2, res_size_y+2, 3): a constant-one
# channel followed by the extended x and y coordinates.
# NOTE: the name shadows the builtin `input`; it is kept because the training
# loop refers to it.
input = torch.stack(
    (torch.ones_like(x_extended), x_extended, y_extended), dim=-1
).unsqueeze(0)

# Per-epoch history of the total loss, its two components, and the relative
# L2 errors of u, v, p against the reference solution.
losses = []
losses_pde = []
losses_bc = []
errors_u = []
errors_v = []
errors_p = []

# Viscosity and its square root. The network's third channel is a rescaled
# pressure: it is multiplied by sqrt_nu wherever it is compared with the
# reference pressure or enters the momentum equations.
nu = 0.001
sqrt_nu = math.sqrt(nu)

# The PDE residuals live on the physical grid minus its outermost ring, so
# the mask is trimmed once to the same shape.
interior_core = interior[1:-1, 1:-1]


def _mean_square(residual):
    """Return the mean of residual**2 (MSE against a zero target)."""
    # zeros_like allocates the target directly on the residual's device.
    return torch.nn.functional.mse_loss(residual, torch.zeros_like(residual))


def _diff_x(field):
    """Central first x-difference of a (batch, x, y) field, outer ring dropped."""
    return (field[:, 2:, 1:-1] - field[:, :-2, 1:-1]) / dx / 2


def _diff_y(field):
    """Central first y-difference of a (batch, x, y) field, outer ring dropped."""
    return (field[:, 1:-1, 2:] - field[:, 1:-1, :-2]) / dy / 2


def _diff_xx(field):
    """Central second x-difference of a (batch, x, y) field, outer ring dropped."""
    return (field[:, 2:, 1:-1] - 2*field[:, 1:-1, 1:-1] + field[:, :-2, 1:-1]) / dx**2


def _diff_yy(field):
    """Central second y-difference of a (batch, x, y) field, outer ring dropped."""
    return (field[:, 1:-1, 2:] - 2*field[:, 1:-1, 1:-1] + field[:, 1:-1, :-2]) / dy**2


epochs = 4000
start_time = time.time()
for epoch in range(epochs):
    optimizer.zero_grad()
    output = model(input)

    # Output channels on the extended grid.
    ext_u = output[..., 0]
    ext_v = output[..., 1]
    ext_p = output[..., 2]

    # x-derivatives on the last physical column, taken across the extra
    # column of the extended grid; the first and last physical rows are
    # left out.
    pred_u_x_3 = (ext_u[0,-1,2:-2] - ext_u[0,-3,2:-2]) / dx / 2
    pred_v_x_3 = (ext_v[0,-1,2:-2] - ext_v[0,-3,2:-2]) / dx / 2

    # Restrict the fields to the physical grid.
    pred_u = ext_u[:,1:-1,1:-1]
    pred_v = ext_v[:,1:-1,1:-1]
    pred_p = ext_p[:,1:-1,1:-1]

    pred_u_x = _diff_x(pred_u)
    pred_u_y = _diff_y(pred_u)
    pred_v_x = _diff_x(pred_v)
    pred_v_y = _diff_y(pred_v)
    pred_u_xx = _diff_xx(pred_u)
    pred_u_yy = _diff_yy(pred_u)
    pred_v_xx = _diff_xx(pred_v)
    pred_v_yy = _diff_yy(pred_v)
    pred_p_x = _diff_x(pred_p)
    pred_p_y = _diff_y(pred_p)

    # Field values at the nodes where the differences above are defined.
    u_core = pred_u[:,1:-1,1:-1]
    v_core = pred_v[:,1:-1,1:-1]

    # Momentum (res_1, res_2) and continuity (res_3) residuals.
    res_1 = -nu * ( pred_u_xx + pred_u_yy ) + sqrt_nu*pred_p_x + u_core * pred_u_x + v_core * pred_u_y
    res_2 = -nu * ( pred_v_xx + pred_v_yy ) + sqrt_nu*pred_p_y + u_core * pred_v_x + v_core * pred_v_y
    res_3 = pred_u_x + pred_v_y

    # Zero the residuals at nodes inside or on the circle.
    res_1 = res_1 * interior_core
    res_2 = res_2 * interior_core
    res_3 = res_3 * interior_core

    loss_pde = _mean_square(res_1) + _mean_square(res_2) + _mean_square(res_3)

    # Velocity interpolated at the circle collocation points (target: zero).
    res_circle_u = torch.sum(weights_circle * pred_u, dim=[1,2])
    res_circle_v = torch.sum(weights_circle * pred_v, dim=[1,2])
    # Velocity residuals: circle, first column (u against g_1, v against 0),
    # and the first and last rows from the second column onwards.
    res_u_bc = torch.cat(( res_circle_u, pred_u[0,0,:]-g_1[0,1:-1, 0], pred_u[0,1::,0], pred_u[0,1::,-1] ))
    res_v_bc = torch.cat(( res_circle_v, pred_v[0,0,:], pred_v[0,1::,0], pred_v[0,1::,-1] ))
    # Condition on the last column: sqrt_nu*u_x - p = 0 and v_x = 0.
    res_robin = torch.cat(( sqrt_nu*pred_u_x_3-pred_p[0,-1,1:-1], pred_v_x_3 ))
    loss_bc = _mean_square(res_u_bc) + _mean_square(res_v_bc) + _mean_square(res_robin)

    loss = loss_pde + loss_bc

    # Relative L2 errors against the reference, restricted to the masked
    # region. They are monitoring-only, so they are kept out of the autograd
    # graph.
    with torch.no_grad():
        error_u = torch.norm( (pred_u[0, ...] - reference_solution[...,0])*interior, 2 ) / torch.norm(reference_solution[..., 0]*interior, 2)
        error_v = torch.norm( (pred_v[0, ...] - reference_solution[...,1])*interior, 2 ) / torch.norm(reference_solution[..., 1]*interior, 2)
        error_p = torch.norm( (pred_p[0, ...]*sqrt_nu - reference_solution[...,2])*interior, 2 ) / torch.norm(reference_solution[..., 2]*interior, 2)

    # Pull each scalar to the host once and reuse it for history and logging.
    loss_pde_value = loss_pde.item()
    loss_bc_value = loss_bc.item()
    error_u_value = error_u.item()
    error_v_value = error_v.item()
    error_p_value = error_p.item()

    losses.append(loss.item())
    losses_pde.append(loss_pde_value)
    losses_bc.append(loss_bc_value)
    errors_u.append(error_u_value)
    errors_v.append(error_v_value)
    errors_p.append(error_p_value)

    loss.backward()
    optimizer.step()
    scheduler.step()

    # The learning rate is read after scheduler.step(), so the printed value
    # is the one the next epoch will use.
    print('epoch: ', epoch, ' loss_pde: ', loss_pde_value, ' loss_bc: ', loss_bc_value, ' lr: ', scheduler.get_last_lr()[0], ' error_u:', error_u_value, ' error_v:', error_v_value, ' error_p:', error_p_value)


end_time = time.time()
print("Training time: ", end_time-start_time)

# Create the output directory up front so that a missing folder cannot
# discard the results of the finished training run.
results_dir = 'boundary-conditions-for-pinos-code/results_Navier_stokes/weak'
os.makedirs(results_dir, exist_ok=True)

# Four arrays are written back to back into one file, in this order: total
# loss, then the relative errors of u, v and p. Reading them back takes four
# consecutive np.load calls on the same file handle.
with open(os.path.join(results_dir, 'losses_errors.npy'), 'wb') as f:
    for history in (losses, errors_u, errors_v, errors_p):
        np.save(f, np.array(history))

# Fields from the last forward pass of the training loop, i.e. evaluated
# before the final optimizer step was applied.
# NOTE(review): the pressure channel is stored exactly as the network emits
# it, whereas the error metric in the training loop compares pred_p*sqrt_nu
# with the reference pressure -- confirm that consumers of this file rescale.
predicted_solution = np.zeros((res_size_x,res_size_y,3))
for channel, field in enumerate((pred_u, pred_v, pred_p)):
    # Each field has a leading batch axis of length 1; drop it explicitly.
    predicted_solution[..., channel] = field[0].detach().cpu().numpy()

with open(os.path.join(results_dir, 'prediction_441_83_3.npy'), 'wb') as f:
    np.save(f, predicted_solution)

with open(os.path.join(results_dir, 'runtime.npy'), 'wb') as f:
    np.save(f, end_time-start_time)

# The whole module object is pickled (not just its state_dict), so loading
# the file later requires the model class to be importable.
torch.save(model, os.path.join(results_dir, 'net.pt'))