from kan import *  # kept first: the explicit imports below must take precedence over its names

import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import autograd
from tqdm import tqdm

from experiments.baselines.MLP import MLP, MLP_RFF

# Allen-Cahn equation in 1D with network input (x, t), solved by time marching:
# the time axis is split into N_deltat windows, each fitted by its own network.
N_deltat = 10  # number of time windows
steps = 1000  # optimizer steps per learning-rate phase (commented-out original value: 10000)

# np.savetxt does not create missing directories; make sure ./results exists
# before any training time is spent rather than crashing after the first window.
os.makedirs('./results', exist_ok=True)

start_time = time.time()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# --- discretisation / loss settings (identical for every time window) ---
np_i = 51  # number of collocation points along x
np_t = 51  # number of collocation points along t
dt = 0.1  # length of one time window; every window is trained on local t in [0, dt]
sampling_mode = 'mesh'  # 'random' or 'mesh'

lamb_r = 1  # weight of the PDE-residual loss
lamb_b = 1  # weight of the periodic boundary loss
lamb_i = 100  # weight of the initial-condition loss

log = 1  # refresh the progress-bar description every `log` steps

np_i_test = 1001  # number of x points at which solution snapshots are saved


def batch_jacobian(func, x, create_graph=False):
    """Return the per-sample Jacobian of ``func`` at ``x``.

    ``x`` has shape (Batch, Length). ``func(x)`` is summed over the batch before
    differentiating, so all per-sample Jacobians come out of one call; this relies
    on each output row depending only on its own input row (assumed for the
    models used here -- TODO confirm). Result shape: (Batch, out_dim, Length).
    """
    def _func_sum(x):
        return func(x).sum(dim=0)
    return autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1, 0, 2)


def helper(X, Y):
    """Flatten two equally-shaped tensors into an (N, 2) tensor of (x, t) points."""
    return torch.stack([X.reshape(-1,), Y.reshape(-1,)]).permute(1, 0)


def u0_true(x):
    """Initial condition of the first window: u(x, 0) = x^2 * cos(pi * x)."""
    x1 = x[:, [0]]
    return x1**2 * torch.cos(torch.pi * x1)


def train(model, x_i, x_left, x_right, x_init, true_init):
    """Fit ``model`` on one time window; return per-step (pde, bc, init) loss lists.

    Runs three phases of ``steps`` Adam steps each, starting at lr=1e-3 and
    dividing the learning rate by 10 between phases. The loss is
        lamb_r * mean(residual^2) + lamb_b * periodic-BC loss + lamb_i * IC loss
    with residual u_t - 1e-4 * u_xx + 5 * (u^3 - u).
    ``true_init`` is a precomputed, gradient-free target for u at ``x_init``.
    """
    pde_losses, bc_losses, init_losses = [], [], []
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    last = {}  # loss terms from the most recent closure evaluation

    def grad_u(x):
        # (B, 2) gradient [u_x, u_t]; graph kept so it can be differentiated again.
        return batch_jacobian(model, x, create_graph=True)[:, 0, :]

    def closure():
        optimizer.zero_grad()

        # interior (PDE residual) loss
        u = model(x_i)
        sol_D1 = grad_u(x_i)
        ut = sol_D1[:, [1]]
        sol_D2 = batch_jacobian(grad_u, x_i, create_graph=True)  # (B, 2, 2) Hessian
        uxx = sol_D2[:, 0, [0]]
        pde = ut - 0.0001 * uxx + 5 * (u**3 - u)
        pde_loss = torch.mean(pde**2)

        # periodic boundary loss: match u and u_x at x = -1 and x = +1
        bc_value = torch.mean((model(x_left) - model(x_right))**2)
        bc_deriv = torch.mean((grad_u(x_left)[:, [0]] - grad_u(x_right)[:, [0]])**2)
        bc_loss = (bc_value + bc_deriv) / 2

        # initial-condition loss
        init_loss = torch.mean((true_init - model(x_init))**2)

        loss = lamb_r * pde_loss + lamb_b * bc_loss + lamb_i * init_loss
        loss.backward()
        last.update(pde=pde_loss.item(), bc=bc_loss.item(), init=init_loss.item())
        return loss

    for phase in range(3):
        if phase > 0:
            for g in optimizer.param_groups:
                g['lr'] *= 0.1

        pbar = tqdm(range(steps), desc='description', ncols=100)
        for step in pbar:
            optimizer.step(closure)

            if step % log == 0:
                pbar.set_description("pde loss: %.2e | bc loss: %.2e | init loss : %.2e " % (last['pde'], last['bc'], last['init']))

            pde_losses.append(last['pde'])
            bc_losses.append(last['bc'])
            init_losses.append(last['init'])

    return pde_losses, bc_losses, init_losses


# Training grid on [-1, 1] x [0, dt]; boundary and initial points are always taken from it.
x_mesh = torch.linspace(-1, 1, steps=np_i, device=device)
t_mesh = torch.linspace(0., dt, steps=np_t, device=device)
X, T = torch.meshgrid(x_mesh, t_mesh, indexing='ij')

x_left = helper(X[0, :], T[0, :])  # x = -1, all t
x_right = helper(X[-1, :], T[-1, :])  # x = +1, all t
x_init = helper(X[:, 0], T[:, 0])  # all x, t = 0

# Snapshot points, kept separate from the training points above.
x_test = torch.linspace(-1, 1, steps=np_i_test, device=device)
x_test_start = helper(x_test, torch.zeros_like(x_test))  # t = 0
x_test_end = helper(x_test, torch.full_like(x_test, dt))  # t = dt (end of window)

for i_deltat in range(N_deltat):
    # interior collocation points
    if sampling_mode == 'mesh':
        x_i = helper(X, T)
    else:
        # uniform random points, rescaled from [0, 1]^2 to [-1, 1] x [0, dt]
        u_rand = torch.rand((np_i * np_t, 2), device=device)
        x_i = torch.stack([2 * u_rand[:, 0] - 1, dt * u_rand[:, 1]], dim=1)

    # Initial-condition target for this window. It is constant during training,
    # so compute it once without building an autograd graph.
    if i_deltat == 0:
        true_init = u0_true(x_init)
    else:
        # previous window's network evaluated at its final local time t = dt
        with torch.no_grad():
            true_init = model(helper(x_init[:, 0], torch.full_like(x_init[:, 0], dt)))

    # each window trains a freshly initialised network; only the IC is carried over
    model = MLP_RFF(width=[2,128,128,128,1], act='silu', seed=1, device=device, s=30)
    pde_losses, bc_losses, init_losses = train(model, x_i, x_left, x_right, x_init, true_init)

    with torch.no_grad():
        if i_deltat == 0:
            np.savetxt('./results/mlprff_t_0.0_width_[2,128,128,128,1]_steps_%d.txt'%steps, model(x_test_start).cpu().numpy())
        np.savetxt('./results/mlprff_t_%.1f_width_[2,128,128,128,1]_steps_%d.txt'%(dt*(i_deltat+1), steps), model(x_test_end).cpu().numpy())
    
# Record the total wall-clock time (seconds) spent on all time windows.
end_time = time.time()
walltime_path = './results/mlprff_walltime_width_[2,128,128,128,1]_steps_%d.txt' % steps
np.savetxt(walltime_path, [end_time - start_time])
