

import torch
import numpy as np
import pandas as pd
from config import M, N, n0, n1, n2, numpoints_f2_minus_z
from config import device
from functions import calculate_knn_series


################################################################################-----init---######################################################################

# Define x, y, and theta values
x_values = torch.linspace(-M, N, n0).to(device)
y_values = torch.linspace(-M, N, n0).to(device)
theta_values = torch.linspace(-M, N, n1).to(device)

# Create tensor for calculating f1, f2_minus, and f2_plus values
tensor = torch.stack([
    x_values.unsqueeze(-1).unsqueeze(-1).expand(n0, n0, n1),
    y_values.unsqueeze(0).unsqueeze(-1).expand(n0, n0, n1),
    theta_values.unsqueeze(0).unsqueeze(0).expand(n0, n0, n1)
], dim=-1).to(device)


KNN = 10

###################################################################################caculate_knn##################################################################
knn_results = calculate_knn_series(KNN, n0, n1, n2, M, N, tensor, x_values, y_values,theta_values, numpoints_f2_minus_z)

print(knn_results)