# main.py

import torch
import numpy as np
import pickle
import time
import pandas as pd
import xlsxwriter
import openpyxl
from scipy.interpolate import lagrange
from config import M, N, n0, n1, n2, numpoints_f2_minus_z, device,KNN
from functions import f1, f2_minus, f2_plus, f_u_plus, f_u_minus,f3_minus, f3_plus, f3_u_minus, f3_u_plus
from functions import calculate_knn_series



# ---- Device ----------------------------------------------------------------------------------------------------------
# Use CUDA when available, otherwise fall back to the CPU.
# NOTE(review): this rebinds the `device` imported from config; confirm which one is intended.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

#####################################################################################################################################################
# 1-D sample points on [-M, N]: n0 each for x and y, n1 for theta.
x_values = torch.linspace(-M, N, n0).to(device)
y_values = torch.linspace(-M, N, n0).to(device)
theta_values = torch.linspace(-M, N, n1).to(device)

# Full (n0, n0, n1, 3) grid of (x, y, theta) triples: x varies along dim 0,
# y along dim 1, theta along dim 2.
tensor = torch.stack(
    torch.meshgrid(x_values, y_values, theta_values, indexing='ij'),
    dim=-1,
).to(device)

# Evaluate f2- and f2+ on the whole grid (helpers imported from `functions`).
f2_minus_values = f2_minus(tensor, M, N).to(device)
f2_plus_values = f2_plus(tensor, M, N).to(device)

# For each (x, y), pick the theta sample that minimises |f2+ - f2-| along the last axis.
# Assumes f2_minus / f2_plus return one value per grid point, i.e. shape (n0, n0, n1) -- TODO confirm.
f2_theta = (f2_plus_values - f2_minus_values).abs().to(device)
min_f2_theta_values, min_theta_indices = f2_theta.min(dim=-1)
min_theta_values = theta_values[min_theta_indices].to(device)

#-----------------------------------------------------------------------------------------------------------------------------------
# Build the tensors consumed by the f2-(z) / f2+(z) sums below (P = numpoints_f2_minus_z):
#   combined_tensor0: (n0, n0, 3)     -> (x, y, theta*), theta* = argmin_theta |f2+ - f2-| per (x, y)
#   combined_tensor1: (n0, n0, n2, 4) -> (x, y, theta*, z), z on n2 uniform points of [-M, N]
#   combined_tensor2: (n0, n0, P, 4)  -> (x, y, theta*, u), u on P uniform points of [-M, theta*]
#   combined_tensor4: (n0, n0, P, 4)  -> (x, y, theta*, u), u on P uniform points of [theta*, N]
x_grid, y_grid = torch.meshgrid(x_values, y_values, indexing='ij')
combined_tensor0 = torch.stack((x_grid, y_grid, min_theta_values), dim=-1)

# combined_tensor1: pair every (x, y, theta*) with each z sample.
x_grid_expanded = x_grid.unsqueeze(-1).expand(n0, n0, n2)
y_grid_expanded = y_grid.unsqueeze(-1).expand(n0, n0, n2)
min_theta_values_expanded = min_theta_values.unsqueeze(-1).expand(n0, n0, n2)
z_values = torch.linspace(-M, N, n2).to(device)
z_grid = z_values.unsqueeze(0).unsqueeze(0).expand(n0, n0, n2)
combined_tensor1 = torch.stack((x_grid_expanded, y_grid_expanded, min_theta_values_expanded, z_grid), dim=-1)

# The (x, y, theta*) channels are identical for both u-grids, so expand them only once.
theta_values = combined_tensor0[..., 2].unsqueeze(-1)  # Shape: (n0, n0, 1)
expanded_theta = theta_values.expand(-1, -1, numpoints_f2_minus_z)  # Shape: (n0, n0, P)
expanded_x = combined_tensor0[..., 0].unsqueeze(-1).expand(-1, -1, numpoints_f2_minus_z)
expanded_y = combined_tensor0[..., 1].unsqueeze(-1).expand(-1, -1, numpoints_f2_minus_z)

# Left grid (fed to f_u_minus): P points of [-M, 1] mapped affinely onto [-M, theta*].
u_values = torch.linspace(-M, 1, numpoints_f2_minus_z).to(device)
u_values_tensor = -M + (u_values - (-M)) * (expanded_theta - (-M)) / (1 - (-M))
step_size1 = (expanded_theta - (-M)) / (numpoints_f2_minus_z - 1)  # spacing of the left grid
combined_tensor2 = torch.stack((expanded_x, expanded_y, expanded_theta, u_values_tensor), dim=-1)

# Right grid (fed to f_u_plus): P points of [-M, N] mapped affinely onto [theta*, N].
u_values = torch.linspace(-M, N, numpoints_f2_minus_z).to(device)
u_values_tensor = theta_values + (u_values - (-M)) * (N - theta_values) / (N - (-M))
step_size2 = (N - expanded_theta) / (numpoints_f2_minus_z - 1)  # spacing of the right grid
combined_tensor4 = torch.stack((expanded_x, expanded_y, expanded_theta, u_values_tensor), dim=-1)



# Split combined_tensor1 (shape (n0, n0, n2, 4), channels x, y, theta*, z) by whether z <= theta*.
mask_z_leq_theta = combined_tensor1[..., 3] <= combined_tensor1[..., 2]

# grid_1 keeps the points with z <= theta*; every channel of the remaining points is set to inf.
grid_1 = combined_tensor1.clone()
grid_1[~mask_z_leq_theta] = float('inf')

# grid_2 is the complement: it keeps z > theta*, and the points with z <= theta* become inf.
grid_2 = combined_tensor1.clone()
grid_2[mask_z_leq_theta] = float('inf')

# Evaluate f_u_minus on the left u-grid and append the per-(x, y) step size as the last channel.
# The loop below reads f_u channels as (x, y, theta*, u, value, du).
# Assumes f_u_minus returns its four input channels followed by the function value (TODO confirm).
f_u = f_u_minus(combined_tensor2, M, N)
f_u = torch.cat((f_u, step_size1.unsqueeze(-1)), dim=-1)

# f2-(z): for every grid_1 point (z <= theta*), a rectangle-rule sum of value * du over the
# u-nodes with u < z. Points that grid_1 set to inf have no finite match in f_u, so they stay 0.
result_grid1 = torch.zeros(n0, n0, n2, device=device)

x_grid_expanded = grid_1[..., 0]  # [n0, n0, n2]
y_grid_expanded = grid_1[..., 1]  # [n0, n0, n2]
theta_values_expanded = grid_1[..., 2]  # [n0, n0, n2]
z_values_expanded = grid_1[..., 3]  # [n0, n0, n2]

for i in range(n0):
    for j in range(n0):
        for k in range(n2):
            x = x_grid_expanded[i, j, k]
            y = y_grid_expanded[i, j, k]
            theta = theta_values_expanded[i, j, k]
            z = z_values_expanded[i, j, k]

            # f_u rows evaluated at exactly this (x, y, theta*). This is exact float
            # equality, which relies on f_u carrying the same coordinate values.
            mask = (f_u[..., 0] == x) & (f_u[..., 1] == y) & (f_u[..., 2] == theta)
            matching_f_u = f_u[mask]

            # Keep only the nodes strictly below z.
            u_values = matching_f_u[..., 3]
            filtered_f_u = matching_f_u[u_values < z]

            # sum of value * du over u < z
            result = torch.sum(filtered_f_u[..., 4] * filtered_f_u[..., 5])
            result_grid1[i, j, k] = result

# Append the sums as a fifth channel -> (x, y, theta*, z, f2-(z)).
# This is built once, after every entry of result_grid1 has been filled in.
f2_minus_z = torch.cat((grid_1, result_grid1.unsqueeze(-1)), dim=-1)




# Evaluate f_u_plus on the right u-grid and append the per-(x, y) step size as the last channel.
# The loop below reads f_u channels as (x, y, theta*, u, value, du).
# Assumes f_u_plus returns its four input channels followed by the function value (TODO confirm).
f_u = f_u_plus(combined_tensor4, M, N)
f_u = torch.cat((f_u, step_size2.unsqueeze(-1)), dim=-1)

# f2+(z): for every grid_2 point (z > theta*), a rectangle-rule sum of value * du over the
# u-nodes with u > z. Points that grid_2 set to inf have no finite match in f_u, so they stay 0.
result_grid2 = torch.zeros(n0, n0, n2, device=device)

x_grid_expanded = grid_2[..., 0]  # [n0, n0, n2]
y_grid_expanded = grid_2[..., 1]  # [n0, n0, n2]
theta_values_expanded = grid_2[..., 2]  # [n0, n0, n2]
z_values_expanded = grid_2[..., 3]  # [n0, n0, n2]

for i in range(n0):
    for j in range(n0):
        for k in range(n2):
            x = x_grid_expanded[i, j, k]
            y = y_grid_expanded[i, j, k]
            theta = theta_values_expanded[i, j, k]
            z = z_values_expanded[i, j, k]

            # f_u rows evaluated at exactly this (x, y, theta*). This is exact float
            # equality, which relies on f_u carrying the same coordinate values.
            mask = (f_u[..., 0] == x) & (f_u[..., 1] == y) & (f_u[..., 2] == theta)
            matching_f_u = f_u[mask]

            # Keep only the nodes strictly above z.
            u_values = matching_f_u[..., 3]
            filtered_f_u = matching_f_u[u_values > z]

            # sum of value * du over u > z
            result = torch.sum(filtered_f_u[..., 4] * filtered_f_u[..., 5])
            result_grid2[i, j, k] = result

# Append the sums as a fifth channel -> (x, y, theta*, z, f2+(z)).
# This is built once, after every entry of result_grid2 has been filled in.
f2_plus_z = torch.cat((grid_2, result_grid2.unsqueeze(-1)), dim=-1)



# Merge the two halves. Points that grid_1 set to inf (z > theta*) take all five channels
# from f2_plus_z, giving (x, y, theta*, z, f2(z)) at every grid point.
inf_mask = torch.isinf(f2_minus_z[:, :, :, 0]).unsqueeze(-1)
f2_z = f2_minus_z.clone()
f2_z[inf_mask.expand_as(f2_z)] = f2_plus_z[inf_mask.expand_as(f2_z)]
torch.save(f2_z, 'f2_z.pt')

# ---- KNN1 = 1 / inf_{x<y} sup_z f2(z) / f1(z) ------------------------------------------------------------
x = f2_z[..., 0]  # [n0, n0, n2]
y = f2_z[..., 1]  # [n0, n0, n2]
z = f2_z[..., 3]  # [n0, n0, n2]
f2_z_values = f2_z[..., 4]

f1_z_values = f1(z, x, y, M, N)
f2_divide_f1 = f2_z_values / f1_z_values
# 0/0 yields NaN, which is mapped to -inf so it can never be the supremum.
# NOTE: nan_to_num also replaces +/-inf (e.g. f2/0) with the dtype's largest/smallest finite value.
f2_divide_f1 = torch.nan_to_num(f2_divide_f1, nan=float('-inf'))

# Supremum over z (last axis) for each (x, y).
sup_z = torch.max(f2_divide_f1, dim=-1).values
# x and y are constant along the z axis, so slice k = 0. Points with x >= y are set to +inf
# so that they cannot be the infimum.
x_less_than_y = x[..., 0] < y[..., 0]
filtered_sup_z = torch.where(x_less_than_y, sup_z, torch.tensor(float('inf'), device=sup_z.device))
inf_x_less_than_y = torch.min(filtered_sup_z)
knn1 = 1/inf_x_less_than_y
print(knn1)  # the first KNN value


###################################################################################caculate_knn##################################################################
# Remaining KNN values. The sample grids are rebuilt under *_0 names because theta_values
# was rebound earlier in the script.
x_values_0 = torch.linspace(-M, N, n0).to(device)
y_values_0 = torch.linspace(-M, N, n0).to(device)
theta_values_0 = torch.linspace(-M, N, n1).to(device)

# (n0, n0, n1, 3) grid of (x, y, theta) triples: x along dim 0, y along dim 1, theta along dim 2.
tensor_0 = torch.stack(
    torch.meshgrid(x_values_0, y_values_0, theta_values_0, indexing='ij'),
    dim=-1,
)

# The returned mapping is written to the workbook below (one row per key).
knn_results = calculate_knn_series(
    KNN, n0, n1, n2, M, N,
    tensor_0, x_values_0, y_values_0, theta_values_0,
    numpoints_f2_minus_z,
)


# Open the results workbook to append to. On the first run the file does not exist yet;
# load_workbook then raises FileNotFoundError, so start a fresh workbook instead.
file_path = 'knn_results.xlsx'
try:
    workbook = openpyxl.load_workbook(file_path)
except FileNotFoundError:
    workbook = openpyxl.Workbook()
worksheet = workbook.active
def find_first_empty_row(sheet):
    """Return the 1-based index of the first row of *sheet* whose cells are all empty.

    Rows 1..sheet.max_row are scanned in order. A row is empty when every cell's
    ``value`` is None. If no row qualifies, the index just past ``sheet.max_row``
    is returned.
    """
    first_blank = next(
        (
            row_idx
            for row_idx in range(1, sheet.max_row + 1)
            if not any(cell.value is not None for cell in sheet[row_idx])
        ),
        None,
    )
    if first_blank is not None:
        return first_blank
    return sheet.max_row + 1

# Append this run's results below whatever the sheet already holds.
empty_row = find_first_empty_row(worksheet)

# The KNN1 row comes first...
worksheet.cell(row=empty_row, column=1, value='KNN1')
worksheet.cell(row=empty_row, column=2, value=knn1.item())

# ...followed by one row per entry of knn_results: label in column A, value in column B.
for offset, (key, value) in enumerate(knn_results.items(), start=1):
    target_row = empty_row + offset
    worksheet.cell(row=target_row, column=1, value=f"KNN{key}")
    worksheet.cell(row=target_row, column=2, value=value)

workbook.save(file_path)


