import os
import pdb
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

def helmholtz_2d_exact_u(y, x, a1, a2):
    """Evaluate u(y, x) = sin(a1 * pi * y) * sin(a2 * pi * x) elementwise.

    Args:
        y: Tensor of y coordinates.
        x: Tensor of x coordinates (combined with ``y`` under torch
            broadcasting rules).
        a1: Frequency multiplier applied to ``y``.
        a2: Frequency multiplier applied to ``x``.

    Returns:
        Tensor holding the product of the two sine factors.
    """
    pi = torch.pi
    # Keep the original evaluation order, (a * pi) * coord, so results match bit-for-bit.
    sin_y = torch.sin(a1 * pi * y)
    sin_x = torch.sin(a2 * pi * x)
    return sin_y * sin_x

# Generate the 2-D evaluation grid shared by all Helmholtz samples.
num_data_per_row = 256

# Uniform num_data_per_row x num_data_per_row grid over [-1, 1]^2. With
# indexing='ij', y varies along dim 0 (rows) and x along dim 1 (columns).
y = torch.linspace(-1, 1, num_data_per_row)
x = torch.linspace(-1, 1, num_data_per_row)
y, x = torch.meshgrid([y, x], indexing='ij')
# Flatten to (num_data_per_row**2, 1) column vectors. reshape is row-major,
# so y is the slow-varying index and x the fast-varying one.
y_data = y.reshape(-1, 1)
x_data = x.reshape(-1, 1)
# Shape (num_data_per_row**2, 2): column 0 holds y, column 1 holds x.
grid = torch.cat([y_data, x_data], dim=1)

# open(..., 'wb') does not create parent directories; make sure it exists.
os.makedirs('helmholtz', exist_ok=True)
with open('helmholtz/helmholtz_grid.pkl', 'wb') as f:
    pickle.dump(grid, f)

# start_a = 2
# end_a = 20

# total_data = None
# for a1 in range(start_a, end_a + 1):
#     for a2 in range(start_a, end_a + 1):
#         u_data = helmholtz_2d_exact_u(y_data, x_data, a1*0.5, a2*0.5)
#         if total_data is None:
#             total_data = u_data
#         else:
#             total_data = torch.cat([total_data, u_data], axis=1)

# total_data = total_data.T
# with open(f'helmholtz/helmholtz_data{start_a//2}to{end_a//2}.pkl', 'wb') as f:
#     pickle.dump(total_data, f)


# Integer frequency multipliers a1, a2 in [start_a, end_a] (step 1).
start_a = 1
end_a = 10

# Evaluate every (a1, a2) combination on the flattened grid, a1 outer and a2
# inner. Collect the (N, 1) columns and concatenate once: calling torch.cat
# inside the loop re-copies the ever-growing tensor each iteration
# (quadratic copying).
u_columns = [
    helmholtz_2d_exact_u(y_data, x_data, a1, a2)
    for a1 in range(start_a, end_a + 1)
    for a2 in range(start_a, end_a + 1)
]
# After the transpose, each row is one (a1, a2) sample and each column is one
# grid point.
total_data = torch.cat(u_columns, dim=1).T

# open(..., 'wb') does not create parent directories; make sure it exists.
os.makedirs('helmholtz', exist_ok=True)
with open(f'helmholtz/helmholtz_data{start_a}to{end_a}_interval1.pkl', 'wb') as f:
    pickle.dump(total_data, f)


