import torch
import numpy as np
import pandas as pd
from config import M, N, n0, n1, n2, numpoints_f2_minus_z, device
from functions import f3_minus, f3_plus, f3_u_minus, f3_u_plus


def create_combined_tensor0(x_values, y_values, min_theta_values):
    """Build the base grid whose last axis holds ``[x, y, theta]``.

    ``x_values`` and ``y_values`` are turned into an ``'ij'``-indexed mesh and
    stacked with ``min_theta_values`` along a new trailing axis.

    Args:
        x_values: Tensor used as the first mesh axis.
        y_values: Tensor used as the second mesh axis.
        min_theta_values: Tensor stacked beside the mesh grids; it has to
            share their shape for ``torch.stack`` to succeed.

    Returns:
        Tuple ``(combined_tensor0, x_grid, y_grid)`` where the two grids are
        the mesh outputs for ``x_values`` and ``y_values``.
    """
    mesh_x, mesh_y = torch.meshgrid(x_values, y_values, indexing='ij')
    layers = [mesh_x, mesh_y, min_theta_values]
    return torch.stack(layers, dim=-1), mesh_x, mesh_y


def create_combined_tensor_f2(combined_tensor0, M, N, numpoints_f2, f2_type='minus'):
    """Attach a per-cell line of sample points ``u`` to the base grid.

    For every grid cell the theta stored in column 2 of ``combined_tensor0``
    bounds the sampled interval:

    * ``f2_type == 'minus'``: ``numpoints_f2`` evenly spaced points from
      ``-M`` to ``theta``.
    * any other value: ``numpoints_f2`` evenly spaced points from ``theta``
      to ``N``.

    Args:
        combined_tensor0: Tensor whose last axis holds ``[x, y, theta]``; the
            two ``expand(-1, -1, numpoints_f2)`` calls require exactly two
            leading grid axes.
        M: Lower bound of the domain is ``-M``.
        N: Upper bound of the domain.
        numpoints_f2: Number of sample points per cell. A value of 1 makes
            the step-size divisor zero.
        f2_type: ``'minus'`` selects the lower interval; everything else is
            treated as the upper interval.

    Returns:
        Tuple ``(combined_tensor, step_size)``. ``combined_tensor`` has a last
        axis of ``[x, y, theta, u]``; ``step_size`` is the spacing between
        consecutive ``u`` values of each cell, repeated for every sample.
    """
    theta_values = combined_tensor0[..., 2].unsqueeze(-1)
    expanded_theta = theta_values.expand(-1, -1, numpoints_f2)

    # Follow the input tensor's device instead of a module-level global so the
    # function also works when the caller's data lives on another device.
    target_device = combined_tensor0.device

    if f2_type == 'minus':
        # Reference samples on [-M, 1], rescaled linearly onto [-M, theta].
        u_values = torch.linspace(-M, 1, numpoints_f2).to(target_device)
        u_values_tensor = -M + (u_values - (-M)) * (expanded_theta - (-M)) / (1 - (-M))
        step_size = (expanded_theta - (-M)) / (numpoints_f2 - 1)
    else:
        # Reference samples on [-M, N], rescaled linearly onto [theta, N].
        u_values = torch.linspace(-M, N, numpoints_f2).to(target_device)
        u_values_tensor = theta_values + (u_values - (-M)) * (N - theta_values) / (N - (-M))
        step_size = (N - expanded_theta) / (numpoints_f2 - 1)

    expanded_x = combined_tensor0[..., 0].unsqueeze(-1).expand(-1, -1, numpoints_f2)
    expanded_y = combined_tensor0[..., 1].unsqueeze(-1).expand(-1, -1, numpoints_f2)

    combined_tensor = torch.stack((expanded_x, expanded_y, expanded_theta, u_values_tensor), dim=-1)
    return combined_tensor, step_size

def separate_tensor_by_z_theta(x_grid, y_grid, min_theta_values, z_values, n0, n2):
    """Split the ``(x, y, theta, z)`` grid into a ``z <= theta`` and a ``z > theta`` part.

    All inputs are broadcast to ``(n0, n0, n2)`` and stacked so the last axis
    holds ``[x, y, theta, z]``. Two copies are returned; in each copy the
    cells that belong to the other part have all four entries set to ``inf``.

    Args:
        x_grid: x coordinates, expandable to ``(n0, n0, n2)`` after a trailing
            axis is added.
        y_grid: y coordinates, same layout as ``x_grid``.
        min_theta_values: Theta per ``(x, y)`` cell, same layout as ``x_grid``.
        z_values: z samples, expandable to ``(n0, n0, n2)`` after two leading
            axes are added.
        n0: Size of each of the two leading grid axes.
        n2: Size of the z axis.

    Returns:
        Tuple ``(grid_1, grid_2)``: ``grid_1`` keeps the cells with
        ``z <= theta``, ``grid_2`` keeps the cells with ``z > theta``.
    """
    target_shape = (n0, n0, n2)
    xs = x_grid[..., None].expand(*target_shape)
    ys = y_grid[..., None].expand(*target_shape)
    thetas = min_theta_values[..., None].expand(*target_shape)
    zs = z_values[None, None].expand(*target_shape)

    stacked = torch.stack((xs, ys, thetas, zs), dim=-1)

    # One flag per cell, broadcast over the four stacked components.
    below_or_equal = (stacked[..., 3] <= stacked[..., 2]).unsqueeze(-1)

    grid_1 = stacked.masked_fill(~below_or_equal, float('inf'))
    grid_2 = stacked.masked_fill(below_or_equal, float('inf'))
    return grid_1, grid_2

def calculate_f_u(f_u_func, combined_tensor, M, N, f2_z_values_expanded, step_size):
    """Evaluate ``f_u_func`` and append the step size as an extra last column.

    Args:
        f_u_func: Callable invoked as
            ``f_u_func(combined_tensor, M, N, f2_z_values_expanded)``.
        combined_tensor: First argument forwarded to ``f_u_func``.
        M: Forwarded to ``f_u_func``.
        N: Forwarded to ``f_u_func``.
        f2_z_values_expanded: Forwarded to ``f_u_func``.
        step_size: Tensor matching the leading axes of the ``f_u_func``
            result; it becomes the final entry of the last axis.

    Returns:
        The ``f_u_func`` result concatenated with ``step_size`` on the last
        axis.
    """
    evaluated = f_u_func(combined_tensor, M, N, f2_z_values_expanded)
    step_column = step_size[..., None]
    return torch.cat([evaluated, step_column], dim=-1)

def compute_result_grid(grid, f_u, compare_func):
    """Sum the selected ``f_u`` rows for every cell of ``grid``.

    For each cell ``[x, y, theta, z]`` of ``grid`` the rows of ``f_u`` whose
    first three columns equal ``(x, y, theta)`` exactly are looked up. Among
    those, the rows for which ``compare_func(u, z)`` is true (``u`` being
    column 3) contribute ``column 4 * column 5`` to the cell's sum.

    Cells without a matching row (for example cells filled with ``inf``)
    yield 0.

    Args:
        grid: Tensor with three leading grid axes and a last axis starting
            with ``[x, y, theta, z]``.
        f_u: Tensor whose last axis has at least six columns; columns 0-2 are
            the lookup key, column 3 is ``u``, columns 4 and 5 are multiplied
            together (presumably integrand value and step size -- confirm
            against the producer of ``f_u``).
        compare_func: Callable ``(u_values, z) -> bool mask`` selecting the
            rows to include.

    Returns:
        Tensor with the three leading axes of ``grid`` on ``grid``'s device
        holding the per-cell sums.
    """
    size_a, size_b, size_c = grid.shape[:3]
    totals = torch.zeros(size_a, size_b, size_c).to(grid.device)

    # Key columns are views of f_u; slice them once and reuse for each cell.
    key_x = f_u[..., 0]
    key_y = f_u[..., 1]
    key_theta = f_u[..., 2]

    def cell_sum(cell):
        # Sum of column 4 * column 5 over the rows selected for one cell.
        x, y, theta, z = cell[0], cell[1], cell[2], cell[3]
        same_key = (key_x == x) & (key_y == y) & (key_theta == theta)
        rows = f_u[same_key]
        rows = rows[compare_func(rows[..., 3], z)]
        return torch.sum(rows[..., 4] * rows[..., 5])

    for a in range(size_a):
        for b in range(size_b):
            for c in range(size_c):
                totals[a, b, c] = cell_sum(grid[a, b, c])

    return totals

def calculate_f3_minus_z(f2_z_values_expanded, combined_tensor2, M, N, grid_1, step_size1):
    """Append the lower-part sums to ``grid_1`` as a new last column.

    ``f3_u_minus`` is evaluated through ``calculate_f_u`` and, for every cell
    of ``grid_1``, the samples with ``u < z`` are summed by
    ``compute_result_grid``.

    Args:
        f2_z_values_expanded: Forwarded to ``f3_u_minus``.
        combined_tensor2: Sample tensor forwarded to ``f3_u_minus``.
        M: Forwarded to ``f3_u_minus``.
        N: Forwarded to ``f3_u_minus``.
        grid_1: Grid whose cells are evaluated.
        step_size1: Step size appended to the samples.

    Returns:
        ``grid_1`` with the per-cell sum concatenated on the last axis.
    """
    def strictly_below(u, z):
        return u < z

    samples = calculate_f_u(f3_u_minus, combined_tensor2, M, N, f2_z_values_expanded, step_size1)
    sums = compute_result_grid(grid_1, samples, strictly_below)
    return torch.cat((grid_1, sums.unsqueeze(-1)), dim=-1)

def calculate_f3_plus_z(f2_z_values_expanded, combined_tensor4, M, N, grid_2, step_size2):
    """Append the upper-part sums to ``grid_2`` as a new last column.

    ``f3_u_plus`` is evaluated through ``calculate_f_u`` and, for every cell
    of ``grid_2``, the samples with ``u > z`` are summed by
    ``compute_result_grid``.

    Args:
        f2_z_values_expanded: Forwarded to ``f3_u_plus``.
        combined_tensor4: Sample tensor forwarded to ``f3_u_plus``.
        M: Forwarded to ``f3_u_plus``.
        N: Forwarded to ``f3_u_plus``.
        grid_2: Grid whose cells are evaluated.
        step_size2: Step size appended to the samples.

    Returns:
        ``grid_2`` with the per-cell sum concatenated on the last axis.
    """
    def strictly_above(u, z):
        return u > z

    samples = calculate_f_u(f3_u_plus, combined_tensor4, M, N, f2_z_values_expanded, step_size2)
    sums = compute_result_grid(grid_2, samples, strictly_above)
    return torch.cat((grid_2, sums.unsqueeze(-1)), dim=-1)

def get_f3_z(f3_minus_z, f3_plus_z):
    """Merge the lower and upper results into one tensor and save it.

    Every cell whose first entry in ``f3_minus_z`` is infinite is replaced,
    across its whole last axis, by the corresponding cell of ``f3_plus_z``.
    The merged tensor is written to ``'f3_z.pt'`` in the current working
    directory. Neither input is modified.

    Args:
        f3_minus_z: 4-D tensor holding the lower-part results.
        f3_plus_z: Tensor of the same shape holding the upper-part results.

    Returns:
        The merged tensor.
    """
    first_entry = f3_minus_z[:, :, :, 0]
    take_plus = torch.isinf(first_entry).unsqueeze(-1).expand_as(f3_minus_z)

    merged = f3_minus_z.clone()
    merged[take_plus] = f3_plus_z[take_plus]

    torch.save(merged, 'f3_z.pt')
    return merged

# Main function to call the above
def main_function(n0, n1, n2, combined_tensor2, combined_tensor4, M, N, f2_z_values_expanded_1, grid_1, grid_2, step_size1, step_size2):
    """Compute the lower and upper results and merge them into ``f3_z``.

    Args:
        n0: Not used by this function.
        n1: Not used by this function.
        n2: Not used by this function.
        combined_tensor2: Sample tensor for the lower part.
        combined_tensor4: Sample tensor for the upper part.
        M: Forwarded to both part computations.
        N: Forwarded to both part computations.
        f2_z_values_expanded_1: Forwarded to both part computations.
        grid_1: Grid evaluated for the lower part.
        grid_2: Grid evaluated for the upper part.
        step_size1: Step size for the lower part.
        step_size2: Step size for the upper part.

    Returns:
        The merged tensor returned by ``get_f3_z`` (which also saves it to
        disk).
    """
    lower_part = calculate_f3_minus_z(
        f2_z_values_expanded_1, combined_tensor2, M, N, grid_1, step_size1
    )
    upper_part = calculate_f3_plus_z(
        f2_z_values_expanded_1, combined_tensor4, M, N, grid_2, step_size2
    )
    return get_f3_z(lower_part, upper_part)


def calculate_knn(f3_z, f2_z):
    """Return ``1 / min_{x < y} max_z (f3 / f2)``.

    Column 4 of ``f3_z`` is divided element-wise by column 4 of ``f2_z``.
    The maximum over the last grid axis is taken per ``(x, y)`` cell, cells
    that do not satisfy ``x < y`` are excluded, and the reciprocal of the
    smallest remaining maximum is returned.

    NaN ratios (``0 / 0``) become ``-inf`` so they never win the maximum.
    ``torch.nan_to_num`` is called with its default ``posinf``/``neginf``
    handling, so infinite ratios are replaced by the largest/smallest finite
    value of the dtype.

    Args:
        f3_z: Tensor whose last axis holds x in column 0, y in column 1 and
            the numerator values in column 4.
        f2_z: Tensor whose column 4 holds the denominator values; it must be
            broadcast-compatible with ``f3_z[..., 4]``.

    Returns:
        0-d tensor with the reciprocal described above. If no cell satisfies
        ``x < y`` the minimum is ``inf`` and the result is 0.
    """
    x = f3_z[..., 0]
    y = f3_z[..., 1]

    ratio = torch.nan_to_num(f3_z[..., 4] / f2_z[..., 4], nan=float('-inf'))
    sup_z = torch.max(ratio, dim=-1).values

    # x and y are read from the first slice of the last grid axis.
    x_less_than_y = x[..., 0] < y[..., 0]
    # Cells outside x < y are pushed to +inf so the minimum ignores them.
    filtered_sup_z = sup_z.masked_fill(~x_less_than_y, float('inf'))

    return 1 / torch.min(filtered_sup_z)



def calculate_knn_series(KNN, n0, n1, n2, M, N, tensor, x_values, y_values, theta_values, numpoints_f2_minus_z, device):
    """Iterate the f_i -> f_{i+1} update and record a knn value per step.

    The starting tensor is loaded from ``'f2_z.pt'`` in the current working
    directory. For ``i = 2 .. KNN`` the function

    1. evaluates ``f3_minus`` / ``f3_plus`` on ``tensor`` and picks, per
       ``(x, y)`` cell, the theta minimising ``|f3_plus - f3_minus|``;
    2. builds the sample tensors and the two z-split grids for that theta;
    3. computes the next tensor with ``main_function`` (which also writes
       ``'f3_z.pt'``);
    4. stores ``calculate_knn(next, current)`` under the key ``'knn-{i}'``
       and prints it.

    Args:
        KNN: Last index ``i`` of the series; nothing is computed if it is
            below 2.
        n0: Size of the x and y axes.
        n1: Forwarded to ``main_function``.
        n2: Number of z samples.
        M: The domain starts at ``-M``.
        N: The domain ends at ``N``.
        tensor: Grid tensor passed to ``f3_minus`` / ``f3_plus``.
        x_values: x axis samples.
        y_values: y axis samples.
        theta_values: Candidate theta samples, indexed by the argmin result.
        numpoints_f2_minus_z: Number of sample points per cell for both
            sample tensors.
        device: Device the loaded checkpoint and intermediate tensors are
            placed on.

    Returns:
        Dict mapping ``'knn-{i}'`` to the float knn value of step ``i``.
    """
    def _expand_last_column(values):
        # Two layouts of the same values, one for each downstream consumer.
        trailing = values.unsqueeze(-1).expand(-1, -1, -1, values.shape[-1])
        leading = values.unsqueeze(0).repeat(n2, 1, 1, 1)
        return trailing, leading

    knn_results = {}

    # map_location keeps the load working when the checkpoint was saved on a
    # different device than the one this run uses.
    f_i = torch.load('f2_z.pt', map_location=device)
    f_z_values_expanded, f_z_values_expanded_1 = _expand_last_column(f_i[..., -1])

    # The z samples do not change between iterations.
    z_values = torch.linspace(-M, N, n2).to(device)

    for i in range(2, KNN + 1):
        f_minus_values = f3_minus(tensor, M, N, f_z_values_expanded).to(device)
        f_plus_values = f3_plus(tensor, M, N, f_z_values_expanded).to(device)

        # Theta with the smallest gap between the two branches, per cell.
        f_theta = torch.abs(f_plus_values - f_minus_values).to(device)
        _, min_theta_indices = torch.min(f_theta, dim=-1)
        min_theta_values = theta_values[min_theta_indices].to(device)

        combined_tensor0, x_grid, y_grid = create_combined_tensor0(x_values, y_values, min_theta_values)
        combined_tensor2, step_size1 = create_combined_tensor_f2(combined_tensor0, M, N, numpoints_f2_minus_z, f2_type='minus')
        combined_tensor4, step_size2 = create_combined_tensor_f2(combined_tensor0, M, N, numpoints_f2_minus_z, f2_type='plus')

        grid_1, grid_2 = separate_tensor_by_z_theta(x_grid, y_grid, min_theta_values, z_values, n0, n2)

        f_i_plus_1 = main_function(n0, n1, n2, combined_tensor2, combined_tensor4, M, N, f_z_values_expanded_1, grid_1, grid_2, step_size1, step_size2)

        knn_value = calculate_knn(f_i_plus_1, f_i)
        knn_results[f'knn-{i}'] = knn_value.item()

        # The freshly computed tensor becomes the input of the next step.
        f_i = f_i_plus_1
        f_z_values_expanded, f_z_values_expanded_1 = _expand_last_column(f_i[..., -1])

        print(f'knn-{i}: {knn_value}')

    return knn_results


def save_knn_to_excel(knn_results, file_path):
    """Write the knn results to an Excel file with columns ``KNN`` and ``Value``.

    Args:
        knn_results: Mapping from result name to value; each item becomes one
            row.
        file_path: Destination path handed to ``DataFrame.to_excel``.
    """
    rows = list(knn_results.items())
    table = pd.DataFrame(rows, columns=['KNN', 'Value'])
    table.to_excel(file_path, index=False)
    print(f'KNN results saved to {file_path}')

################################################################################-----init---######################################################################

# Sampling axes: x and y share the same n0-point range, theta uses n1 points.
x_values = torch.linspace(-M, N, n0).to(device)
y_values = torch.linspace(-M, N, n0).to(device)
theta_values = torch.linspace(-M, N, n1).to(device)

# Dense (n0, n0, n1) grid whose last axis holds [x, y, theta]; it is the input
# of f3_minus / f3_plus inside calculate_knn_series.
tensor = torch.stack(
    torch.meshgrid(x_values, y_values, theta_values, indexing='ij'),
    dim=-1,
).to(device)


# Last index i for which knn-i is computed (the series starts at i = 2).
KNN = 10

################################################################################### calculate_knn ###############################################################
knn_results = calculate_knn_series(
    KNN, n0, n1, n2, M, N, tensor, x_values, y_values, theta_values,
    numpoints_f2_minus_z, device,
)

file_path = 'knn_results.xlsx'
save_knn_to_excel(knn_results, file_path)
