import os
import pickle
import tempfile

import numpy as np
import torch

from models.libs.utils import UnitGaussianNormalizer, get_grids

# Dataset that this script loads, augments with a 'sparse_idx' entry, and writes back.
path = '/usr/commondata/public/Neural_Dynamics/CTmixer/dataset/ns_equation/dataset_sr128.pkl'

# NOTE(review): unpickling can execute code embedded in the file; the path looks
# like a shared location, so only point this at a dataset you trust.
with open(path, mode='rb') as f:
    data = pickle.Unpickler(f).load()


# data['u'] must be a NumPy array: torch.from_numpy accepts nothing else.
u = torch.from_numpy(data['u'])

def euclid_distance(X, Y=None):
    """Return the pairwise Euclidean distances between two point sets.

    Thin wrapper around ``torch.cdist`` that validates the feature dimension.

    Args:
        X: tensor of points, shape ``(..., P, D)`` (``(P, D)`` in this script).
        Y: tensor of points, shape ``(..., R, D)``. Defaults to ``X``, which
            yields the self-distance matrix.

    Returns:
        Tensor of shape ``(..., P, R)`` whose ``[..., i, j]`` entry is the
        distance between ``X[..., i, :]`` and ``Y[..., j, :]``.

    Raises:
        ValueError: if ``X`` and ``Y`` differ in their last (feature) dimension.
    """
    if Y is None:
        Y = X
    # The coordinate axis torch.cdist needs to match is the *last* one. The
    # previous check used shape[1], which for batched (B, P, D) input is the
    # point axis and wrongly rejected point sets of different sizes. A real
    # exception is used instead of `assert`, which is stripped under `python -O`.
    if X.shape[-1] != Y.shape[-1]:
        raise ValueError(
            f"X and Y must have the same feature dimension, "
            f"got {X.shape[-1]} and {Y.shape[-1]}"
        )
    return torch.cdist(X, Y)

def get_connectivity(X, Y, eps, chunk_size=1024):
    """Return the index pairs of points in ``X`` and ``Y`` within distance ``eps``.

    Args:
        X: point tensor passed to ``euclid_distance``; typically ``(P, D)``.
            Only 2-D input is processed in chunks.
        Y: candidate-neighbour tensor, typically ``(R, D)``. ``None`` means
            ``X``, mirroring ``euclid_distance``.
        eps: inclusive distance threshold. Pair ``(i, j)`` is kept when
            ``dist(X[i], Y[j]) <= eps``.
        chunk_size: maximum number of rows of ``X`` handled per distance
            computation. At most ``chunk_size * R`` distances are held in
            memory instead of the full ``P * R`` matrix. ``None`` disables
            chunking (single pass, as before).

    Returns:
        For 2-D ``X``, an integer tensor of shape ``(2, E)``. Row 0 indexes
        ``X`` and row 1 indexes ``Y``. Pairs are in the same row-major order
        that ``torch.where`` over the full distance mask produces.

    Raises:
        ValueError: if ``chunk_size`` is not ``None`` and is smaller than 1.
    """
    if chunk_size is not None and chunk_size < 1:
        raise ValueError(
            f"chunk_size must be a positive integer or None, got {chunk_size!r}"
        )
    if Y is None:
        # Resolve here: otherwise each chunk would be compared only against itself.
        Y = X

    if chunk_size is None or X.dim() != 2 or X.shape[0] <= chunk_size:
        # Single pass over the full distance matrix. Batched input also lands
        # here because chunking would not preserve its index ordering.
        distance = euclid_distance(X, Y)
        return torch.vstack(torch.where(distance <= eps))

    row_parts = []
    col_parts = []
    for start in range(0, X.shape[0], chunk_size):
        # NOTE(review): torch.cdist may choose a different internal algorithm
        # for small inputs, so a pair lying exactly at eps could in principle
        # round differently than in the single-pass computation.
        distance = euclid_distance(X[start:start + chunk_size], Y)
        rows, cols = torch.where(distance <= eps)
        # Convert chunk-local row indices to global indices into X.
        row_parts.append(rows + start)
        col_parts.append(cols)
    # Chunks are visited in ascending row order, so concatenating them
    # reproduces the row-major order of a single torch.where over all of X.
    return torch.vstack((torch.cat(row_parts), torch.cat(col_parts)))

# Spatial resolution of the stored field. NOTE(review): assumes u is laid out as
# (samples, *spatial, last_axis); axis 0 and the last axis are excluded here.
# Confirm against the dataset producer.
original_resolution = u.shape[1:-1]

# Number of spatial dimensions (not used further in this script).
dimension = len(original_resolution)

# Grid coordinates from get_grids, flattened to one row per grid node:
# shape (num_nodes, x.shape[-1]).
x = get_grids(original_resolution)
x = x.reshape(-1, x.shape[-1])

# Index pairs (i, j) of grid nodes within distance 0.05 of each other. Self-pairs
# (i, i) are included because a node's distance to itself is <= eps.
sparse_idx = get_connectivity(x, x, eps=0.05)

data['sparse_idx'] = sparse_idx

# Write the augmented dataset back to `path` atomically. Opening `path` with "wb"
# would truncate the only copy of the dataset before pickle.dump runs, so any
# failure mid-dump (disk full, interrupt, unpicklable value) would destroy it.
# Instead, dump into a temp file in the same directory, so os.replace stays on one
# filesystem, and swap it in only after the dump has fully succeeded.
# NOTE(review): this needs write permission on the containing directory, not just
# on the dataset file itself.
target_path = os.path.realpath(path)  # if `path` is a symlink, replace its target
fd, tmp_path = tempfile.mkstemp(
    dir=os.path.dirname(target_path), prefix='.sparse_idx_', suffix='.pkl.tmp'
)
try:
    with os.fdopen(fd, "wb") as f:
        pickle.dump(data, f, protocol=4)
    # mkstemp creates the file as owner-only (0o600); keep the dataset's original
    # permission bits so other users can still read it.
    os.chmod(tmp_path, os.stat(target_path).st_mode & 0o7777)
    os.replace(tmp_path, target_path)
except BaseException:
    # The original dataset is untouched at this point; remove the partial temp
    # file on a best-effort basis, then propagate the real error.
    try:
        os.unlink(tmp_path)
    except OSError:
        pass
    raise