# Star import kept first so the explicit imports below take precedence over
# any names it re-exports.
from kan import *

import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import autograd
from tqdm import tqdm

from experiments.baselines.MLP import MLP_RFF


# Darcy flow experiment; the network input is the point (x, y).

# Each entry is passed to MLP_RFF as `width` (presumably the layer sizes,
# from the 2-D input to a scalar output) -- one run per entry.
widths = [[2, 128, 128, 128, 1]]

# Each value is passed to MLP_RFF as `s`; one run per value.
ss = [3, 30]

# Number of optimizer iterations per run.
steps = 1000
for width in widths:
    for s in ss:

        seed = 0
        torch.manual_seed(seed)  # same data draw for every (width, s) configuration

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(device)

        dim = 2
        np_i = 51  # number of interior points (along each dimension)
        np_b = 51  # number of boundary points (along each side of the square)
        ranges = [-1, 1]

        # a(x, y) and u(x, y) are each a sum of Gaussian bumps whose centres are
        # drawn uniformly in [-1, 1]^2 and whose widths are drawn uniformly in
        # [0.1, 0.3]. The draw order fixes the RNG stream, so keep it unchanged.
        N_a = 20
        N_u = 20

        x_a = torch.rand(N_a).to(device) * 2 - 1
        y_a = torch.rand(N_a).to(device) * 2 - 1
        sigma_a = 0.2 * torch.rand(N_a).to(device) + 0.1
        x_u = torch.rand(N_u).to(device) * 2 - 1
        y_u = torch.rand(N_u).to(device) * 2 - 1
        sigma_u = 0.2 * torch.rand(N_u).to(device) + 0.1

        def a(x):
            """Coefficient field 1 + sum_k exp(-|p - c_k|^2 / (2 sigma_a[k]^2)).

            x: (B, d) points; only columns 0 and 1 are read. The bump centres are
            c_k = (x_a[k], y_a[k]), each of shape (N_a,). Returns a (B, 1) tensor.
            """
            return 1 + torch.sum(torch.exp(-((x[:,[0],None] - x_a[None,None,:])**2+(x[:,[1],None] - y_a[None,None,:])**2)/(2*sigma_a[None,None,:]**2)), dim=-1)

        def u(x):
            """Reference solution sum_k exp(-|p - c_k|^2 / (2 sigma_u[k]^2)).

            x: (B, d) points; only columns 0 and 1 are read. The bump centres are
            c_k = (x_u[k], y_u[k]), each of shape (N_u,). Returns a (B, 1) tensor.
            """
            return torch.sum(torch.exp(-((x[:,[0],None] - x_u[None,None,:])**2+(x[:,[1],None] - y_u[None,None,:])**2)/(2*sigma_u[None,None,:]**2)), dim=-1)

        def batch_jacobian(func, x, create_graph=False):
            """Per-sample Jacobian of `func` evaluated at the rows of `x`.

            x: (B, d_in); func(x): (B, d_out). Returns (B, d_out, d_in).
            Summing over the batch before differentiating is only valid when row i
            of func(x) depends on row i of x alone (no cross-sample coupling).
            """
            def _func_sum(x):
                return func(x).sum(dim=0)
            return autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1,0,2)

        # Source term of the PDE: f = div(a * grad u), with u the reference solution.
        def source_fun(x):
            """Return div(a * grad u) at the rows of `x` as a (B, 1) tensor."""
            def adu(x):
                du = batch_jacobian(u, x, create_graph=True)[:,0,:]  # grad u, (B, 2)
                adu_ = a(x) * du  # flux a * grad u
                return adu_

            lap = batch_jacobian(adu, x)  # (B, 2, 2); its trace is the divergence
            lhs = torch.sum(torch.diagonal(lap, dim1=1, dim2=2),dim=1,keepdim=True)
            return lhs

        # interior
        sampling_mode = 'mesh'  # 'random' or 'mesh' (any value other than 'mesh' samples randomly)

        x_mesh = torch.linspace(ranges[0],ranges[1],steps=np_i)
        y_mesh = torch.linspace(ranges[0],ranges[1],steps=np_i)
        X, Y = torch.meshgrid(x_mesh, y_mesh, indexing="ij")
        if sampling_mode == 'mesh':
            # regular np_i x np_i grid, flattened to (np_i**2, 2)
            x_i = torch.stack([X.reshape(-1,), Y.reshape(-1,)]).permute(1,0)
        else:
            # np_i**2 uniform samples in the square spanned by `ranges`
            x_i = torch.rand((np_i**2,2)) * (ranges[1] - ranges[0]) + ranges[0]

        x_i = x_i.to(device)

        # Boundary: np_b points on each of the 4 sides (the corners appear twice).
        b_mesh = torch.linspace(ranges[0], ranges[1], steps=np_b)
        Xb, Yb = torch.meshgrid(b_mesh, b_mesh, indexing="ij")
        helper = lambda X, Y: torch.stack([X.reshape(-1,), Y.reshape(-1,)]).permute(1,0)
        xb1 = helper(Xb[0], Yb[0])        # x = ranges[0]
        xb2 = helper(Xb[-1], Yb[0])       # x = ranges[1]
        xb3 = helper(Xb[:,0], Yb[:,0])    # y = ranges[0]
        xb4 = helper(Xb[:,0], Yb[:,-1])   # y = ranges[1]
        x_b = torch.cat([xb1, xb2, xb3, xb4], dim=0)

        x_b = x_b.to(device)

        alpha = 0.01  # weight of the PDE residual relative to the boundary loss
        log = 1       # refresh the progress-bar text every `log` steps

        pde_losses = []
        bc_losses = []
        l2_losses = []

        # NOTE(review): the model seed (1) differs from the data seed (0) that is
        # used in the result file names -- confirm this is intended.
        model = MLP_RFF(width=width, seed=1, device=device, s=s)

        def train():
            """Run `steps` Adam iterations on `model`, appending per-step metrics.

            After each step, appends the PDE residual loss, the boundary loss and the
            interior MSE against u on x_i to `pde_losses`, `bc_losses` and `l2_losses`.
            """
            #optimizer = LBFGS(model.parameters(), lr=1, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
            #optimizer = torch.optim.LBFGS(model.parameters(), lr=1, history_size=10, line_search_fn="strong_wolfe")

            # The targets do not depend on the model, so compute them once here
            # rather than on every closure call.
            source = source_fun(x_i)  # div(a * grad u) at the interior points
            bc_true = u(x_b)
            sol = u(x_i)

            # Loss components from the most recent closure evaluation.
            latest = {}

            def closure():
                optimizer.zero_grad()
                # interior loss: div(a * grad model) should match the source term
                sol_D1_fun = lambda x: batch_jacobian(model, x, create_graph=True)[:,0,:]
                adu = lambda x: a(x) * sol_D1_fun(x)
                sol_D2 = batch_jacobian(adu, x_i, create_graph=True)
                lap = torch.sum(torch.diagonal(sol_D2, dim1=1, dim2=2), dim=1, keepdim=True)
                pde_loss = torch.mean((lap - source)**2)

                # boundary loss: Dirichlet data taken from the reference solution
                bc_pred = model(x_b)
                bc_loss = torch.mean((bc_pred-bc_true)**2)

                loss = alpha * pde_loss + bc_loss
                loss.backward()
                latest['pde'] = pde_loss.detach()
                latest['bc'] = bc_loss.detach()
                return loss

            pbar = tqdm(range(steps), desc='description', ncols=100)

            for it in pbar:
                optimizer.step(closure)
                # Evaluate again at the updated parameters so the logged losses
                # belong to the post-step model. This also back-propagates; those
                # gradients are cleared by the next closure's zero_grad().
                closure()
                pde_loss = latest['pde'].item()
                bc_loss = latest['bc'].item()
                with torch.no_grad():
                    l2 = torch.mean((model(x_i) - sol)**2).item()

                if it % log == 0:
                    pbar.set_description("pde loss: %.2e | bc loss: %.2e | l2: %.2e " % (pde_loss, bc_loss, l2))

                pde_losses.append(pde_loss)
                bc_losses.append(bc_loss)
                l2_losses.append(l2)

        start_time = time.time()
        train()
        end_time = time.time()

        wall_time = end_time - start_time

        # Saved outputs: l2 trajectory, wall time, final model profile, and the
        # true a / u profiles.
        os.makedirs('./results', exist_ok=True)  # np.savetxt does not create directories
        np.savetxt(f'./results/mlprff_l2_width_{width}_seed_{seed}_steps_{steps}_s_{s}', np.array(l2_losses))
        np.savetxt(f'./results/mlprff_walltime_width_{width}_seed_{seed}_steps_{steps}_s_{s}', np.array([wall_time]))

        # Evaluate on a 101 x 101 grid over [-1, 1]^2.
        x = torch.linspace(-1,1,steps=101)
        y = torch.linspace(-1,1,steps=101)
        xx, yy = torch.meshgrid(x, y, indexing='ij')
        inputs = torch.stack([xx.reshape(-1,), yy.reshape(-1,)]).permute(1,0).to(device)
        with torch.no_grad():
            np.savetxt(f'./results/mlprff_sol_width_{width}_seed_{seed}_steps_{steps}_s_{s}', model(inputs).reshape(101,101).cpu().numpy())
            # a and u do not depend on width or s, so every configuration rewrites
            # these two files with identical contents.
            np.savetxt(f'./results/a_sol_seed_{seed}_steps_{steps}', a(inputs).reshape(101,101).cpu().numpy())
            np.savetxt(f'./results/u_sol_seed_{seed}_steps_{steps}', u(inputs).reshape(101,101).cpu().numpy())
