# -*- coding: utf-8 -*-
import torch
import numpy as np
from scipy.integrate import quad
import torch
import time
import pickle
import pandas as pd
import xlsxwriter
import openpyxl
from config import M, N, n0, n1, n2, numpoints_f2_minus_z, device,KNN


def a(u):
    """Return the constant coefficient a(u) = 1/2.

    The argument is accepted so call sites can write ``a(u)`` uniformly, but
    it is ignored. The value is used as a divisor (``exp(c(u)) / a(u)``) in
    the integrands throughout this module.
    """
    coefficient = 0.5
    return coefficient

def c(u_values, M=3, N=3, num_points=100):
    """Approximate c(u) = integral from 0 to u of b(t) / a dt, with b(t) = -t/2, a = 1/2.

    The integral is evaluated element-wise as a Riemann sum on the fixed grid
    ``linspace(-M, N, num_points)``:

    * for u >= 0 it sums the integrand over grid points t in (0, u];
    * for u < 0 it sums over grid points t in (u, 0] and negates the result.

    With these coefficients the exact value is -u**2 / 2. Values of ``u``
    outside [-M, N] are effectively clamped to the grid ends.

    Args:
        u_values: tensor of any shape (0-d and empty tensors are supported).
        M, N: bounds of the quadrature grid.
        num_points: number of grid points; must be at least 2.

    Returns:
        Tensor with the same shape as ``u_values`` on the same device.

    Raises:
        ValueError: if ``num_points < 2``.
    """
    if num_points < 2:
        raise ValueError("num_points must be at least 2")

    device = u_values.device
    original_shape = u_values.shape

    t_values = torch.linspace(-M, N, num_points, device=device)
    # Weight each sample by the actual grid spacing of linspace. The previous
    # (M + N) / num_points under-weighted every sample by a factor of
    # (num_points - 1) / num_points.
    dt = (M + N) / (num_points - 1)

    b = -t_values / 2
    a_coeff = 1 / 2
    integrand = b / a_coeff  # == -t

    # One row per input element, broadcast against the grid.
    u_col = u_values.reshape(-1, 1)
    mask_positive = (t_values > 0) & (t_values <= u_col)
    mask_negative = (t_values <= 0) & (t_values > u_col)

    # A plain sum equals the last entry of a cumulative sum, without
    # materialising the cumulative tensor.
    integral_positive = torch.sum(integrand * mask_positive, dim=-1) * dt
    integral_negative = -torch.sum(integrand * mask_negative, dim=-1) * dt

    result = torch.where(u_col.squeeze(-1) >= 0, integral_positive, integral_negative)
    return result.reshape(original_shape)

def f1(t, x, y, M, N, num_points=200):
    """Evaluate the piecewise kernel f1(t, x, y) element-wise.

    Let I(p, q) be the Riemann sum of exp(c(u)) / a(u) over the grid points
    u of ``linspace(-M, N, num_points)`` with p <= u <= q. Then:

    * t <= x:      sqrt(I(y, N) / I(-M, x)) * sqrt(I(-M, t))
    * x < t < y:   sqrt(I(y, N))
    * y <= t:      sqrt(I(t, N))

    The cases are written in that order, so where both ``t <= x`` and
    ``y <= t`` hold the last case wins. Elements that match no case stay 0.
    If I(-M, x) is 0 (x below the grid), the first case divides by zero.

    ``t``, ``x`` and ``y`` must share one shape, because boolean-mask
    indexing of the integrals requires it.

    NOTE(review): ``c`` is called with its own default grid, not with this
    function's M/N; confirm that is intended.

    Returns a tensor like ``t``. Also calls ``torch.cuda.empty_cache()``.
    """
    dev = t.device
    grid = torch.linspace(-M, N, num_points, device=dev)
    step = grid[1] - grid[0]
    weight = torch.exp(c(grid)) / a(grid)

    def _integrate(lo, hi):
        # Inclusive window [lo, hi] over the grid, evaluated per element.
        inside = (grid >= lo.unsqueeze(-1)) & (grid <= hi.unsqueeze(-1))
        return torch.sum(weight * inside, dim=-1) * step

    lower_end = torch.tensor(-M, device=dev)
    upper_end = torch.tensor(N, device=dev)
    y_to_upper = _integrate(y, upper_end)
    lower_to_x = _integrate(lower_end, x)
    lower_to_t = _integrate(lower_end, t)
    t_to_upper = _integrate(t, upper_end)

    out = torch.zeros_like(t)
    left = t <= x
    middle = (x < t) & (t < y)
    right = y <= t

    out[left] = torch.sqrt(y_to_upper[left] / lower_to_x[left]) * torch.sqrt(lower_to_t[left])
    out[middle] = torch.sqrt(y_to_upper[middle])
    out[right] = torch.sqrt(t_to_upper[right])

    del y_to_upper, lower_to_x, lower_to_t, t_to_upper
    torch.cuda.empty_cache()
    return out

def _f2_double_integral(tensor, M, N, lower, upper, t_below_theta,
                        num_points, chunk_size, u_chunk_size):
    """Chunked double Riemann sum shared by f2_minus and f2_plus.

    For every (x, y, theta) triple in ``tensor[..., :3]`` this approximates

        sum_{u, t} exp(c(u)) / a(u) * exp(-c(t)) * f1(t, x, y) * du * dt

    with u and t both running over ``linspace(lower, upper, num_points)``.
    The sum is restricted to u <= t <= theta when ``t_below_theta`` is True,
    and to theta <= t <= u otherwise.

    The leading axis of ``tensor`` is processed ``chunk_size`` rows at a time.
    The (u, t) grid is processed in blocks of at most
    ``u_chunk_size`` x ``u_chunk_size`` to bound peak memory. Block shapes
    are taken from the actual slices, so the final (possibly smaller) block
    works when ``num_points`` is not a multiple of ``u_chunk_size``.

    Returns a tensor shaped like ``tensor[..., 0]``.
    """
    x_tensor, y_tensor, theta_tensor = tensor[..., 0], tensor[..., 1], tensor[..., 2]
    device = tensor.device
    # u and t share the same grid, hence the same spacing.
    grid = torch.linspace(lower, upper, num_points, device=device)
    delta = grid[1] - grid[0]

    results = []
    for start in range(0, x_tensor.shape[0], chunk_size):
        x_chunk = x_tensor[start:start + chunk_size]
        y_chunk = y_tensor[start:start + chunk_size]
        theta_chunk = theta_tensor[start:start + chunk_size][..., None, None]
        lead_shape = x_chunk.shape
        total = torch.zeros(lead_shape, dtype=grid.dtype, device=device)

        for u_start in range(0, num_points, u_chunk_size):
            u_chunk = grid[u_start:u_start + u_chunk_size]
            weight_u = torch.exp(c(u_chunk)) / a(u_chunk)
            for t_start in range(0, num_points, u_chunk_size):
                t_chunk = grid[t_start:t_start + u_chunk_size]
                weight_t = torch.exp(-c(t_chunk))

                u_grid, t_grid = torch.meshgrid(u_chunk, t_chunk, indexing='ij')
                if t_below_theta:
                    mask = (t_grid >= u_grid) & (t_grid <= theta_chunk)
                else:
                    mask = (t_grid >= theta_chunk) & (t_grid <= u_grid)
                # (u, t) weights broadcast against the (..., u, t) mask.
                integrand = (weight_u[:, None] * weight_t[None, :]) * mask

                # f1 depends on t but not on u: evaluate it once per t value
                # (a single row along the u axis) and let broadcasting spread it.
                t_row = t_chunk.expand(*lead_shape, 1, t_chunk.numel())
                f1_row = f1(t_row,
                            x_chunk[..., None, None].expand_as(t_row),
                            y_chunk[..., None, None].expand_as(t_row),
                            M, N, num_points)

                total = total + torch.sum(integrand * f1_row, dim=(-2, -1)) * delta * delta

        results.append(total)
        torch.cuda.empty_cache()

    return torch.cat(results, dim=0)


def f2_minus(tensor, M, N, num_points=200, chunk_size=10, u_chunk_size=50):
    """Lower-side f2: double sum over -M <= u <= t <= theta.

    ``tensor[..., :3]`` holds (x, y, theta). The quadrature grid runs from -M
    to the largest theta in ``tensor``. See ``_f2_double_integral`` for the
    integrand and the chunking parameters.

    Returns a tensor shaped like ``tensor[..., 0]``.
    """
    theta_max = tensor[..., 2].max().item()
    return _f2_double_integral(tensor, M, N, -M, theta_max, True,
                               num_points, chunk_size, u_chunk_size)


def f2_plus(tensor, M, N, num_points=200, chunk_size=5, u_chunk_size=10):
    """Upper-side f2: double sum over theta <= t <= u <= N.

    ``tensor[..., :3]`` holds (x, y, theta). The quadrature grid runs from the
    smallest theta in ``tensor`` to N. See ``_f2_double_integral`` for the
    integrand and the chunking parameters.

    Returns a tensor shaped like ``tensor[..., 0]``, except that a trailing
    axis of size 1 is squeezed away (kept for backward compatibility).
    """
    theta_min = tensor[..., 2].min().item()
    return _f2_double_integral(tensor, M, N, theta_min, N, False,
                               num_points, chunk_size, u_chunk_size).squeeze(-1)



def _f2_z_density_table(result_tensor, M, N, step_size, lower_side, num_intervals, chunk_size):
    """Append f_u and the step sizes to every (x, y, theta, u) point.

    ``result_tensor`` is assumed to carry exactly four channels
    (x, y, theta, u) on its last axis. For each point the t-integral is a sum
    of ``num_intervals`` equally spaced samples, endpoints included. The
    samples run from u to theta when ``lower_side`` is True (f2_minus_z) and
    from theta to u otherwise (f2_plus_z):

        f_u = exp(c(u)) / a(u) * sum_t exp(-c(t)) * f1(t, x, y) * dt
        dt  = (end - start) / num_intervals

    Returns ``result_tensor`` with f_u (channel 4) and ``step_size``
    (channel 5) appended on the last axis.
    """
    x, y, theta, u = (result_tensor[..., k] for k in range(4))
    device = result_tensor.device
    start, end = (u, theta) if lower_side else (theta, u)

    # Sample positions start + (end - start) * s, s in linspace(0, 1).
    fractions = torch.linspace(0, 1, num_intervals, device=device)
    fractions = fractions.view(-1, *([1] * u.dim()))
    t_values = start.unsqueeze(0) + (end - start).unsqueeze(0) * fractions
    dt = (end - start) / num_intervals

    # Accumulate the t-sum chunk by chunk instead of materialising the whole
    # (num_intervals, ...) integrand.
    t_sum = torch.zeros_like(dt)
    with torch.no_grad():
        for t_chunk in torch.split(t_values, chunk_size, dim=0):
            f1_chunk = f1(t_chunk,
                          x=x.unsqueeze(0).expand(t_chunk.shape),
                          y=y.unsqueeze(0).expand(t_chunk.shape),
                          M=M, N=N)
            t_sum = t_sum + torch.sum(torch.exp(-c(t_chunk)) * f1_chunk, dim=0)

    f_u = torch.exp(c(u)) / a(u) * (t_sum * dt)
    return torch.cat((result_tensor, f_u.unsqueeze(-1), step_size.unsqueeze(-1)), dim=-1)


def _f2_z_accumulate(table, grid, keep_below):
    """Sum f_u * step over matching table rows for every grid point.

    ``table`` channels are (x, y, theta, u, f_u, step) and ``grid`` channels
    are (x, y, theta, z). For each grid point, the selected rows are those
    whose x, y and theta compare exactly equal to the point's values and
    whose u is strictly below z (``keep_below``) or strictly above z.

    Returns ``grid`` with the sums appended as a new last channel.
    """
    rows = table.reshape(-1, table.shape[-1])
    row_x, row_y, row_theta, row_u = rows[:, 0], rows[:, 1], rows[:, 2], rows[:, 3]
    weighted = rows[:, 4] * rows[:, 5]

    points = grid.reshape(-1, grid.shape[-1])
    sums = torch.zeros(points.shape[0], device=grid.device)
    for idx in range(points.shape[0]):
        px, py, ptheta, pz = points[idx, 0], points[idx, 1], points[idx, 2], points[idx, 3]
        on_side = row_u < pz if keep_below else row_u > pz
        # Exact float equality: assumes grid and table were built from the same values.
        selected = (row_x == px) & (row_y == py) & (row_theta == ptheta) & on_side
        sums[idx] = torch.sum(weighted[selected])

    # Concatenate once, after all sums are known.
    return torch.cat((grid, sums.reshape(grid.shape[:-1]).unsqueeze(-1)), dim=-1)


def f2_minus_z(result_tensor, grid_1, M, N, step_size1, num_intervals=1000, chunk_size=10):
    """Lower-side f2 on a z-grid.

    f_u is computed on ``result_tensor`` (channels x, y, theta, u) with the
    t-integral running from u to theta. Then, for each (x, y, theta, z) in
    ``grid_1``, f_u * ``step_size1`` is summed over the matching points with
    u < z.

    Returns ``grid_1`` with the result appended as the last channel.
    """
    table = _f2_z_density_table(result_tensor, M, N, step_size1, True, num_intervals, chunk_size)
    return _f2_z_accumulate(table, grid_1, keep_below=True)


def f2_plus_z(result_tensor, grid_2, M, N, step_size2, num_intervals=1000, chunk_size=10):
    """Upper-side f2 on a z-grid.

    f_u is computed on ``result_tensor`` (channels x, y, theta, u) with the
    t-integral running from theta to u. Then, for each (x, y, theta, z) in
    ``grid_2``, f_u * ``step_size2`` is summed over the matching points with
    u > z.

    Returns ``grid_2`` with the result appended as the last channel.
    """
    table = _f2_z_density_table(result_tensor, M, N, step_size2, False, num_intervals, chunk_size)
    return _f2_z_accumulate(table, grid_2, keep_below=False)
    
def f3_minus(tensor, M, N, f2_z_values_expanded, num_points=200):
    """Lower-side double sum of f2 for every (x, y, theta) triple.

    For each theta = ``tensor[..., 2]`` this evaluates a Riemann sum on the
    grid ``linspace(-M, N, num_points)``, shared by u and t:

        sum_{-M <= u <= theta} sum_{u <= t <= theta}
            exp(c(u)) / a(u) * exp(-c(t)) * f2[..., u, t] * dt * du

    ``tensor[..., 0]`` must be at least 3-D (X, Y, T, ...).
    ``f2_z_values_expanded`` is broadcast over the theta axis and must end
    in (num_points, num_points).

    Returns a tensor of shape (X, Y, T).
    """
    dev = tensor.device
    theta = tensor[..., 2]
    x_shape = tensor[..., 0].shape

    nodes_u = torch.linspace(-M, N, num_points, device=dev)
    nodes_t = torch.linspace(-M, N, num_points, device=dev)
    du = nodes_u[1] - nodes_u[0]
    dt = nodes_t[1] - nodes_t[0]
    u_mesh, t_mesh = torch.meshgrid(nodes_u, nodes_t, indexing='ij')

    # Same f2 values for every theta slice.
    f2_per_theta = f2_z_values_expanded.unsqueeze(2).expand(-1, -1, theta.shape[2], -1, -1)

    kernel = (torch.exp(c(u_mesh)) / a(u_mesh)) * torch.exp(-c(t_mesh))
    inside = (t_mesh >= u_mesh) & (t_mesh <= theta[..., None, None])
    inner = (kernel * inside * f2_per_theta).sum(dim=-1) * dt

    u_window = (nodes_u >= -M) & (nodes_u <= theta[..., None])
    outer = (inner * u_window).sum(dim=-1) * du

    out = torch.zeros(x_shape[0], x_shape[1], theta.shape[2], device=dev)
    out += outer
    return out

def f3_plus(tensor, M, N, f2_z_values_expanded, num_points=200):
    """Upper-side double sum of f2 for every (x, y, theta) triple.

    For each theta = ``tensor[..., 2]`` this evaluates a Riemann sum on the
    grid ``linspace(-M, N, num_points)``, shared by u and t:

        sum_{theta < u < N} sum_{theta < t <= u}
            exp(c(u)) / a(u) * exp(-c(t)) * f2[..., u, t] * dt * du

    Both bounds at theta are strict, and the grid end u = N is excluded.
    ``tensor[..., 0]`` must be at least 3-D (X, Y, T, ...).
    ``f2_z_values_expanded`` is broadcast over the theta axis and must end
    in (num_points, num_points).

    Returns a tensor of shape (X, Y, T).
    """
    dev = tensor.device
    theta = tensor[..., 2]
    x_shape = tensor[..., 0].shape

    nodes_u = torch.linspace(-M, N, num_points, device=dev)
    nodes_t = torch.linspace(-M, N, num_points, device=dev)
    du = nodes_u[1] - nodes_u[0]
    dt = nodes_t[1] - nodes_t[0]
    u_mesh, t_mesh = torch.meshgrid(nodes_u, nodes_t, indexing='ij')

    # Same f2 values for every theta slice.
    f2_per_theta = f2_z_values_expanded.unsqueeze(2).expand(-1, -1, theta.shape[2], -1, -1)

    kernel = (torch.exp(c(u_mesh)) / a(u_mesh)) * torch.exp(-c(t_mesh))
    inside = (t_mesh <= u_mesh) & (t_mesh > theta[..., None, None])
    inner = (kernel * inside * f2_per_theta).sum(dim=-1) * dt

    u_window = (nodes_u < N) & (nodes_u > theta[..., None])
    outer = (inner * u_window).sum(dim=-1) * du

    out = torch.zeros(x_shape[0], x_shape[1], theta.shape[2], device=dev)
    out += outer
    return out


#-----------------------------------
def f3_u_minus(tensor, M, N, f2_z_values_expanded, num_points=200):
    """Lower-side f3 as a function of z, for every (x, y) with fixed theta.

    ``tensor`` has channels (x, y, theta) on its last axis and leading shape
    (X, Y). u, t and z all live on the grid ``linspace(-M, N, num_points)``.
    For each z with z <= theta the result is

        sum_{-M <= u <= z} sum_{u <= t <= theta}
            exp(c(u)) / a(u) * exp(-c(t)) * f2[..., u, t] * dt * du

    and +inf where z > theta.

    ``f2_z_values_expanded`` must broadcast against (X, Y, num_points,
    num_points). Shapes are taken from the inputs rather than from the global
    ``n0``.

    Returns a tensor of shape (X, Y, num_points), indexed by z.
    """
    theta_tensor = tensor[..., 2]
    device = tensor.device

    # u, t and z share one grid, hence one spacing.
    grid = torch.linspace(-M, N, num_points, device=device)
    delta = grid[1] - grid[0]
    u_grid, t_grid = torch.meshgrid(grid, grid, indexing='ij')

    inner_mask = (t_grid >= u_grid) & (t_grid <= theta_tensor.unsqueeze(-1).unsqueeze(-1))
    integrand = (torch.exp(c(u_grid)) / a(u_grid)) * torch.exp(-c(t_grid)) * inner_mask
    # The inner t-sum does not depend on z: compute it once, shape (..., u).
    inner_sum = torch.sum(integrand * f2_z_values_expanded, dim=-1) * delta

    # outer_mask[z, u] is True where -M <= u <= z. All z handled at once.
    outer_mask = (grid >= -M).unsqueeze(0) & (grid.unsqueeze(0) <= grid.unsqueeze(1))
    result = torch.sum(inner_sum.unsqueeze(-2) * outer_mask, dim=-1) * delta  # (..., z)

    z = grid.expand_as(result)
    keep = z <= theta_tensor.unsqueeze(-1)
    result0 = torch.full_like(z, float('inf'))
    result0[keep] = result[keep]
    return result0



def f3_u_plus(tensor, M, N, f2_z_values_expanded, num_points=200):
    """Upper-side f3 as a function of z, for every (x, y) with fixed theta.

    ``tensor`` has channels (x, y, theta) on its last axis and leading shape
    (X, Y). u, t and z all live on the grid ``linspace(-M, N, num_points)``.
    For each z with z > theta the result is

        sum_{z <= u <= N} sum_{theta <= t <= u}
            exp(c(u)) / a(u) * exp(-c(t)) * f2[..., u, t] * dt * du

    and +inf where z <= theta.

    ``f2_z_values_expanded`` must broadcast against (X, Y, num_points,
    num_points). Shapes are taken from the inputs rather than from the global
    ``n0``.

    Returns a tensor of shape (X, Y, num_points), indexed by z.
    """
    theta_tensor = tensor[..., 2]
    device = tensor.device

    # u, t and z share one grid, hence one spacing.
    grid = torch.linspace(-M, N, num_points, device=device)
    delta = grid[1] - grid[0]
    u_grid, t_grid = torch.meshgrid(grid, grid, indexing='ij')

    inner_mask = (t_grid <= u_grid) & (t_grid >= theta_tensor.unsqueeze(-1).unsqueeze(-1))
    integrand = (torch.exp(c(u_grid)) / a(u_grid)) * torch.exp(-c(t_grid)) * inner_mask
    # The inner t-sum does not depend on z: compute it once, shape (..., u).
    inner_sum = torch.sum(integrand * f2_z_values_expanded, dim=-1) * delta

    # outer_mask[z, u] is True where z <= u <= N. All z handled at once.
    outer_mask = (grid <= N).unsqueeze(0) & (grid.unsqueeze(0) >= grid.unsqueeze(1))
    result = torch.sum(inner_sum.unsqueeze(-2) * outer_mask, dim=-1) * delta  # (..., z)

    z = grid.expand_as(result)
    keep = z > theta_tensor.unsqueeze(-1)
    result0 = torch.full_like(z, float('inf'))
    result0[keep] = result[keep]
    return result0


def calculate_knn(f3_z, f2_z):
    """Return 1 / inf_{x < y} sup_z (f3(z) / f2(z)).

    Both inputs hold the function value in channel 4 of the last axis.
    ``f3_z`` also holds x in channel 0 and y in channel 1. The second-to-last
    axis is reduced by the sup; x and y are read from its first entry only.

    Ratio handling:
    * 0/0 (NaN) is mapped to -inf, so it never wins the sup.
    * ``torch.nan_to_num`` also replaces +/-inf by the dtype's finite
      extremes.
    * If no (x, y) pair has x < y, the infimum is +inf and the result is 0.

    Returns a 0-d tensor.
    """
    x = f3_z[..., 0]
    y = f3_z[..., 1]

    ratio = f3_z[..., 4] / f2_z[..., 4]
    ratio = torch.nan_to_num(ratio, nan=float('-inf'))

    sup_over_z = torch.max(ratio, dim=-1).values

    x_less_than_y = x[..., 0] < y[..., 0]
    candidates = torch.where(x_less_than_y, sup_over_z,
                             torch.tensor(float('inf'), device=sup_over_z.device))

    return 1 / torch.min(candidates)
#-------------------------------------------------------------
def calculate_knn_series(KNN, n0, n1, n2, M, N, tensor, x_values, y_values, theta_values, numpoints_f2_minus_z):
    """Iterate the f2 -> f3 recursion and record knn-2 ... knn-KNN.

    Starts from the f2 table saved in 'f2_z.pt'. Each step i:
      1. For every (x, y), picks the theta value (from ``theta_values``) that
         minimises |f3_plus - f3_minus| over ``tensor`` (channels x, y, theta).
      2. Evaluates f3 on the z-grid at that theta, using f3_u_minus for
         z <= theta and f3_u_plus for z > theta.
      3. Records knn-i = calculate_knn(f_{i+1}, f_i), prints it, and uses f3
         as the next f.

    All f3 quadratures use ``numpoints_f2_minus_z`` points. That way the f3
    grid matches the z grid stacked below, and the last axes of the loaded f2
    table.

    ``n1`` and ``n2`` are unused (kept for interface compatibility).

    Returns a dict mapping 'knn-i' to a Python float.
    """
    device = tensor.device
    knn_results = {}
    num_z = numpoints_f2_minus_z

    # map_location puts the saved tensor on the working device regardless of
    # where it was saved. NOTE(review): torch.load unpickles; only load
    # trusted files.
    f_i = torch.load('f2_z.pt', map_location=device)

    x_values = x_values.reshape(-1)
    y_values = y_values.reshape(-1)
    xy_tensor = torch.stack(torch.meshgrid(x_values, y_values, indexing='ij'), dim=-1)

    # Loop-invariant coordinate grids of shape (n0, n0, num_z):
    # x varies along axis 0, y along axis 1, z along axis 2.
    z_values = torch.linspace(-M, N, num_z, device=device)
    z_values_expanded = z_values.view(1, 1, -1).expand(n0, n0, num_z)
    x_values_expanded = x_values.view(-1, 1, 1).expand(n0, n0, num_z)
    y_values_expanded = y_values.view(1, -1, 1).expand(n0, n0, num_z)

    for i in range(2, KNN + 1):
        # NOTE(review): this expand makes f vary along axis -2 (the u axis in
        # the f3 integrals) and stay constant along t; confirm that is intended.
        f_z_values = f_i[..., -1]
        f_z_values_expanded = f_z_values.unsqueeze(-1).expand(-1, -1, -1, f_z_values.shape[-1])

        f_minus_values = f3_minus(tensor, M, N, f_z_values_expanded, num_points=num_z).to(device)
        f_plus_values = f3_plus(tensor, M, N, f_z_values_expanded, num_points=num_z).to(device)

        f_theta = torch.abs(f_plus_values - f_minus_values)
        _, min_theta_indices = torch.min(f_theta, dim=-1)
        min_theta_values = theta_values[min_theta_indices].to(device)

        tensor0 = torch.cat((xy_tensor, min_theta_values.unsqueeze(-1)), dim=-1)
        f3_minus_z = f3_u_minus(tensor0, M, N, f_z_values_expanded, num_points=num_z)
        f3_plus_z = f3_u_plus(tensor0, M, N, f_z_values_expanded, num_points=num_z)

        # Each half marks the z values it does not cover with +inf; merge them.
        mask_minus = f3_minus_z != float('inf')
        mask_plus = f3_plus_z != float('inf')
        f3_z = torch.full_like(f3_minus_z, float('inf'))
        f3_z[mask_minus] = f3_minus_z[mask_minus]
        f3_z[mask_plus] = f3_plus_z[mask_plus]

        min_theta_expanded = min_theta_values.unsqueeze(2).expand(n0, n0, num_z)
        f_i_plus_1 = torch.stack((x_values_expanded, y_values_expanded, min_theta_expanded,
                                  z_values_expanded, f3_z), dim=-1)

        knn_value = calculate_knn(f_i_plus_1, f_i)
        knn_results[f'knn-{i}'] = knn_value.item()
        print(f'knn-{i}: {knn_value}')

        f_i = f_i_plus_1

    return knn_results

############################################################################################################################################################3



# Define x, y, and theta values
x_values_0 = torch.linspace(-M, N, n0).to(device)
y_values_0 = torch.linspace(-M, N, n0).to(device)
theta_values_0 = torch.linspace(-M, N, n1).to(device)

# Grid of (x, y, theta) triples, shape (n0, n0, n1, 3): x varies along
# axis 0, y along axis 1, theta along axis 2.
tensor_0 = torch.stack([
    x_values_0.unsqueeze(-1).unsqueeze(-1).expand(n0, n0, n1),
    y_values_0.unsqueeze(0).unsqueeze(-1).expand(n0, n0, n1),
    theta_values_0.unsqueeze(0).unsqueeze(0).expand(n0, n0, n1)
], dim=-1).to(device)


knn_results = calculate_knn_series(KNN, n0, n1, n2, M, N, tensor_0, x_values_0, y_values_0, theta_values_0, numpoints_f2_minus_z)


# Append this run's results to the Excel workbook, creating it on first use.
file_path = 'knn_results.xlsx'
try:
    workbook = openpyxl.load_workbook(file_path)
except FileNotFoundError:
    workbook = openpyxl.Workbook()
worksheet = workbook.active


def find_first_empty_column(sheet):
    """Return the 1-based index of the first column with no values in rows 1..max_row."""
    for col in range(1, sheet.max_column + 1):
        if all(sheet.cell(row=row, column=col).value is None for row in range(1, sheet.max_row + 1)):
            return col
    return sheet.max_column + 1


empty_col = find_first_empty_column(worksheet)

# One column per knn-i: header in row 1, value in row 2. This script does not
# compute knn-1 (no `knn1` is defined in this module), so no KNN1 column is
# written. Writing starts at the first empty column so that no blank column is
# left for the next run to find and overwrite.
for offset, (key, value) in enumerate(knn_results.items()):
    worksheet.cell(row=1, column=empty_col + offset, value=f"KNN{key}")
    worksheet.cell(row=2, column=empty_col + offset, value=value)

# Save the workbook
workbook.save(file_path)
