#!/usr/bin/env python
# -*- coding:utf-8 _*-
import os
import sys

sys.path.append('..')

import numpy as np
import torch
import time
import pickle
import argparse

def _str2bool(value):
    """Convert a command-line string into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"``) becomes ``True``. This converter
    recognises the usual spellings instead.

    Args:
        value: the raw argument string (or an actual ``bool``).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string was the only falsy input under ``type=bool``; keep it falsy.
    if lowered in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser()
parser.add_argument('--start_epochs', default=8000, type=int)
parser.add_argument("--epochs", default=500, type=int)
parser.add_argument('--ft_steps', default=128, type=int)
parser.add_argument('--lr', default=0.01, type=float)
parser.add_argument('--log_path', default=None, type=str)

parser.add_argument('--ift_method', default='broyden', type=str)
# neumann sgd parameters
parser.add_argument('--neumann_iter', default=16, type=int)
parser.add_argument('--alpha', default=0.001, type=float)
# broyden parameters
parser.add_argument('--threshold', default=40, type=int)
parser.add_argument('--eps', default=1e-4, type=float)
# Parsed with _str2bool so that "--ls False" really disables the option.
parser.add_argument('--ls', default=True, type=_str2bool)
parser.add_argument('--beta', default=0.01, type=float)
parser.add_argument('--max_iter', default=-1, type=int)

# search params
parser.add_argument("--re0", default=100, type=float)
parser.add_argument("--scale", default=1.0, type=float)
parser.add_argument("--ref_ratio", default=3.0, type=float)
parser.add_argument("--rad_steps", default=5, type=int)
parser.add_argument('--gpu', default=0, type=int)

args = parser.parse_args()
# Echo the parsed configuration, then pin the process to the requested device.
print(args)
if args.gpu is None:
    # No GPU requested: hide every CUDA device from this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = ''
else:
    # Expose only the selected GPU and make new tensors default to CUDA floats.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    print('Running on GPU ' + str(args.gpu))
# To run on a specific CUDA device, the default device must be set before importing deepxde (dde).
# To run on CPU, either set it after the import or set CUDA_VISIBLE_DEVICES to ''.
import deepxde as dde
from deepxde.icbcs import DirichletBC, PeriodicBC, IC
from bipinn.pdeco import PDEOptimizerModel
from xpinn.dde_utils import ParametricDirichletBC, ParametricPointSetBC
from xpinn.dde_utils import ResampleCallback
import torch.nn.functional as F

pi = np.pi

# Unscaled channel dimensions; the geometry divides them by scale_ratio.
width, height = 2.0, 1.0

# Values taken from the command line.
scale_ratio, Re0, ref_ratio = args.scale, args.re0, args.ref_ratio

# Reynolds number used in the PDE residual.
Re = scale_ratio * Re0


'''
    Use a FEM solver to estimate the result formally
'''

class NSControler(PDEOptimizerModel):
    """Inlet-velocity control problem for steady 2D Navier-Stokes on a stepped channel.

    The state network maps ``(x, y)`` to ``(u, v, p)`` and is constrained by the
    steady incompressible Navier-Stokes residual with Reynolds number ``Re``.
    The control ``net_c`` sets the horizontal velocity at sampled inlet points on
    ``x = 0``. Geometry and target are configured by the module-level values
    ``width``, ``height``, ``scale_ratio``, ``Re`` and ``ref_ratio``.
    """

    def __init__(self, net_u, net_c, loss=None, log_path=None, n_domain=256, n_boundary=64, n_initial=64):
        """Build the PDE residual, geometry and boundary conditions.

        Args:
            net_u: state network, forwarded to the base class.
            net_c: control parameters; ``len(net_c)`` fixes the number of inlet
                points and the object is forwarded to the base class as ``theta``.
            loss: accepted for interface compatibility but not forwarded (the
                base class always receives ``loss=None``).
            log_path: forwarded to the base class.
            n_domain: number of domain training points for ``dde.data.PDE``.
            n_boundary: number of boundary training points for ``dde.data.PDE``.
            n_initial: unused in this constructor.
        """

        def ns_pde(x, u):
            # Residuals of the two momentum equations and the continuity equation.
            u_vel, v_vel, p = u[:, 0:1], u[:, 1:2], u[:, 2:]
            u_vel_x = dde.grad.jacobian(u, x, i=0, j=0)
            u_vel_y = dde.grad.jacobian(u, x, i=0, j=1)
            u_vel_xx = dde.grad.hessian(u, x, component=0, i=0, j=0)
            u_vel_yy = dde.grad.hessian(u, x, component=0, i=1, j=1)

            v_vel_x = dde.grad.jacobian(u, x, i=1, j=0)
            v_vel_y = dde.grad.jacobian(u, x, i=1, j=1)
            v_vel_xx = dde.grad.hessian(u, x, component=1, i=0, j=0)
            v_vel_yy = dde.grad.hessian(u, x, component=1, i=1, j=1)

            p_x = dde.grad.jacobian(u, x, i=2, j=0)
            p_y = dde.grad.jacobian(u, x, i=2, j=1)

            momentum_x = (
                     (u_vel * u_vel_x + v_vel * u_vel_y) + p_x - 1/Re * (u_vel_xx + u_vel_yy)
            )
            momentum_y = (
                     (u_vel * v_vel_x + v_vel * v_vel_y) + p_y - 1/Re * (v_vel_xx + v_vel_yy)
            )
            continuity = u_vel_x + v_vel_y

            return [momentum_x, momentum_y, continuity]

        self.pde = ns_pde

        # Six-vertex polygon: the full rectangle minus its upper-left quarter,
        # with every coordinate divided by scale_ratio.
        vertices = np.array(
            [[0., 0], [width/scale_ratio, 0], [width/scale_ratio, height/scale_ratio], [width / (2*scale_ratio), height/scale_ratio], [width / (2*scale_ratio), height / (2*scale_ratio)], [0, height / (2*scale_ratio)]])
        geom = dde.geometry.Polygon(
            vertices=vertices
        )

        # train domain
        self.geom = geom

        def boundary_left(x, on_boundary):
            return on_boundary and np.isclose(x[0], 0)

        def boundary_right(x, on_boundary):
            return on_boundary and np.isclose(x[0], width/scale_ratio)

        def boundary_wall(x, on_boundary):
            # Every boundary point that is on neither the left nor the right edge.
            return on_boundary and (not (np.isclose(x[0], 0) or np.isclose(x[0], width/scale_ratio)))

        def u_target(x):
            # Target profile for the first output component, used by evaluate_J.
            return scale_ratio*x[:,1:2]*(height - scale_ratio*x[:, 1: 2]) *ref_ratio

        # BCs
        def get_inlet_points(n):
            # n points on x = 0 with y evenly spaced in [0, height/(2*scale_ratio)].
            return np.concatenate([np.zeros([n])[:, None], height/(2*scale_ratio) *np.linspace(0, 1, n)[:, None]], axis=1)

        def aux_constraint(x, theta):
            # Multiplies the control by a quadratic in y that is zero at both
            # ends of the inlet segment.
            return theta * scale_ratio*x[:, 1:2] * (height/2 - scale_ratio*x[:, 1:2])*64 / (height**3)/2

        inlet_points = get_inlet_points(len(net_c))
        bc_l_u = ParametricPointSetBC(inlet_points, net_c,component=0,aux_fun= aux_constraint)
        bc_l_v = DirichletBC(geom, lambda _: 0, boundary_left, component=1)
        bc_r = DirichletBC(geom, lambda _: 0, boundary_right, component=2)

        bc_wall_u = DirichletBC(geom, lambda _: 0, boundary_wall, component=0)
        bc_wall_v = DirichletBC(geom, lambda _: 0, boundary_wall, component=1)

        self.bcs = [bc_l_u, bc_l_v, bc_r, bc_wall_u, bc_wall_v]

        data = dde.data.PDE(geom, self.pde, self.bcs, n_domain, n_boundary, solution=None)

        self.theta_loss_ref = None
        self.u_target = u_target
        # NOTE(review): the `loss` argument is not forwarded; confirm that
        # passing loss=None here is intentional.
        super(NSControler, self).__init__(data, net_u, loss=None, theta=net_c, log_path=log_path)

    # calculate loss for reference control parameters
    def theta_loss(self, theta):
        """Placeholder; currently does nothing and returns ``None``."""
        pass

    def evaluate_J(self, n):
        """Evaluate the control objective plus a penalty on the control.

        Args:
            n: number of evenly spaced sample points on the line
                ``x = width/scale_ratio``, with ``y`` in ``[0, height/scale_ratio]``.

        Returns:
            Scalar tensor: the scaled mean squared mismatch between the first
            network output and ``u_target`` on that line, plus the penalty term.
        """
        x = torch.cat([width/scale_ratio * torch.ones([n, 1]), torch.linspace(0, height/scale_ratio, n).unsqueeze(-1) ], dim=1)
        x.requires_grad = True
        u = self.net(x)
        J = height * torch.mean((u[:,0:1] - self.u_target(x)) ** 2)*scale_ratio

        ## penalize the source function
        z = torch.linspace(0,height/2,len(self.theta))[:,None]
        # NOTE(review): this profile uses `width/2` and has no trailing `/2`,
        # while aux_constraint in __init__ uses `height/2` and divides by 2 --
        # confirm the difference is intended.
        u_in = self.theta * z *scale_ratio  *( width/2-z*scale_ratio)*64/(height**3)
        # Penalise u_in outside [0, 20] and a mean of theta above 4.
        reg = 100*(F.relu(-u_in)**2 + F.relu(u_in - 20)**2).mean() + 100*F.relu(self.theta.mean()-4)**2

        return (J + reg)

    def save_source(self):
        """Write the current control to disk and periodically append it to a history.

        The control is saved as two text columns (a grid on ``[0, 1]`` and the
        values of ``self.theta``) to ``<log_path>_source.txt``, or to
        ``ns2d_backstep_source.txt`` when ``log_path`` is ``None``. When
        ``(finetune_epochs * epoch // 100) % 10 == 0`` the snapshot is appended
        to ``self.sources`` and the stacked history is pickled.
        """
        if self.log_path is not None:
            source_path = self.log_path + '_source.txt'
        else:
            source_path = 'ns2d_backstep_source.txt'

        with torch.no_grad():
            x = torch.linspace(0, 1, len(self.theta))[:,None]
            source = self.theta.clone()
            x, source = x.cpu().numpy(), source.cpu().numpy()
            f_source = np.concatenate([x,source],axis=1)
            np.savetxt(source_path, f_source)
        # Reset the history at epoch 0; also create it if the first call happens
        # at a later epoch, which previously raised AttributeError on append.
        if self.epoch == 0 or not hasattr(self, 'sources'):
            self.sources = []
        if ((self.finetune_epochs * self.epoch // 100) % 10 == 0):
            self.sources.append(f_source)
            # Context manager so the pickle file is flushed and closed every time.
            with open('./data/exp/ns_backstep_source_bpn.pkl', 'wb') as history_file:
                pickle.dump(np.stack(self.sources, axis=0), history_file)

        return



def main():
    """Pre-train the state network, then run the PDE-constrained optimisation.

    When ``--log_path`` is given, standard output is redirected to a
    timestamped file under ``./data/logs/`` for the rest of the process.
    Hyper-parameters are read from the module-level ``args``.
    """
    if args.log_path:
        # log folder by day
        log_dir = './data/models/' + str(args.log_path) + time.strftime('_%m%d/%H_%M_%S')
    else:
        log_dir = None
    if log_dir:
        log_file = './data/logs/' + str(args.log_path) + time.strftime('_%m%d_%H_%M_%S') + '.txt'
        # Create the log directory first so that open() cannot fail on a fresh checkout.
        os.makedirs(os.path.dirname(log_file), exist_ok=True)
        # Deliberately left open: every later print() goes to this file.
        sys.stdout = open(log_file, 'w')
    print(args)

    load_path = None

    layer_size = [2] + [64] * 4 + [3]

    activation = "tanh"
    initializer = "Glorot normal"

    net_u = dde.maps.FNN(layer_size, activation, initializer)
    net_c = torch.ones([96]).unsqueeze(-1).requires_grad_(True)

    model = NSControler(net_u, net_c, log_path=log_dir, n_domain=96 *96 , n_boundary = 8 * 96)

    if load_path is None:
        model.compile("adam", lr=0.001, metrics=None)
        losshistory, train_state = model.train(epochs=args.start_epochs, display_every=2000)

    else:
        model.compile("adam", lr=0.001)
        # GPU index 0 is falsy, so test against None (as the device setup at
        # the top of the file does) rather than relying on truthiness.
        device = 'cuda' if args.gpu is not None else "cpu"
        # Read the checkpoint once and reuse it for both state dicts.
        checkpoint = torch.load(load_path, map_location=device)
        net_u.load_state_dict(checkpoint['net_u'])
        # NOTE(review): net_c is a plain tensor here, which has no
        # load_state_dict(); this line needs revisiting before load_path is used.
        net_c.load_state_dict(checkpoint['net_c'])
        model.train(epochs=1000, display_every=1000)

    model.train_pdeco(epochs=args.epochs,
                      finetune_epochs=args.ft_steps,
                      num_val=64,
                      lr=args.lr,
                      ifd_method=args.ift_method,
                      threshold=args.threshold,
                      max_iter=args.max_iter,
                      ls=args.ls,
                      eps=args.eps,
                      beta=args.beta,
                      neumann_iter=args.neumann_iter,
                      alpha=args.alpha,
                      rad_steps=args.rad_steps,
                      grad_ref=None
                      )


if __name__ == "__main__":
    main()