import torch
import numpy as np
import pickle
from scipy.interpolate import lagrange
from config import M, N, n0, n1, n2, numpoints_f2_minus_z, device,KNN
import torch
import numpy as np
from scipy.integrate import quad


# Define functions c(u) and a(u)
#def c(u):
#    return -u**2 / 2
def a(u):
    """Return the constant coefficient a(u) = 1/2.

    The argument ``u`` is accepted for interface symmetry with ``c`` but is
    never read; the result is always the Python float ``0.5``.
    """
    return 0.5
    

def c(u_values, M=3, N=3, num_points=100):
    """Riemann-sum approximation of the integral of b(t)/a from 0 to u.

    The integrand is ``b(t) / a`` with ``b(t) = -3 t / 2`` and ``a = 1/2``,
    sampled on a uniform grid of ``num_points`` points spanning ``[-M, N]``.

    * For ``u >= 0`` the grid points with ``0 < t <= u`` are summed.
    * For ``u < 0`` the grid points with ``u < t <= 0`` are summed and the
      result is negated (integration towards the left).

    Args:
        u_values: Tensor with at least one dimension holding the upper
            integration limits.
        M: The grid starts at ``-M``.
        N: The grid ends at ``N``.
        num_points: Number of grid points (both end points included).

    Returns:
        Tensor with the same shape as ``u_values`` holding the approximated
        integral for every entry.
    """
    device = u_values.device
    original_shape = u_values.shape

    # reshape (unlike view) also accepts non-contiguous inputs.
    u_flat = u_values.reshape(-1, u_values.size(-1))  # [num_u_values, n1]
    u_column = u_flat.unsqueeze(2)  # [num_u_values, n1, 1]

    # Build the grid directly on the target device.
    t_values = torch.linspace(-M, N, num_points, device=device).unsqueeze(0)  # [1, num_points]
    b = -(3 * t_values) / 2
    a_const = 1 / 2
    integrand = b / a_const  # [1, num_points]

    # linspace places num_points samples on [-M, N] including both ends, so
    # neighbouring samples are (M + N) / (num_points - 1) apart.  The max()
    # guard keeps the degenerate single-point grid from dividing by zero.
    dt = (M + N) / max(num_points - 1, 1)

    in_positive_range = (t_values > 0) & (t_values <= u_column)  # [num_u_values, n1, num_points]
    in_negative_range = (t_values <= 0) & (t_values > u_column)  # [num_u_values, n1, num_points]

    # Only the total over the grid is needed, so sum directly instead of
    # building a cumulative sum and reading its last column.
    result_positive = torch.sum(integrand * in_positive_range, dim=-1) * dt
    result_negative = -torch.sum(integrand * in_negative_range, dim=-1) * dt

    result = torch.where(u_flat >= 0, result_positive, result_negative)
    return result.reshape(original_shape)



def f1(t, x, y, M, N, num_points=200):#100->200
    """Evaluate a three-branch piecewise function of ``t`` from Riemann sums.

    With ``w(u) = exp(c(u)) / a(u)`` sampled on ``num_points`` uniform grid
    points over ``[-M, N]`` and ``I(s, e)`` the sum of ``w`` over the grid
    points in ``[s, e]`` times the grid spacing, the result is

    * ``sqrt(I(y, N) / I(-M, x)) * sqrt(I(-M, t))`` where ``t <= x``
    * ``sqrt(I(y, N))``                             where ``x < t < y``
    * ``sqrt(I(t, N))``                             where ``y <= t``

    The branches are written in that order, so where they overlap the later
    one wins.  Entries matching none of them stay zero.

    Args:
        t, x, y: Tensors of identical shape (the boolean masks derived from
            them index one another's integrals).
        M: The grid starts at ``-M``.
        N: The grid ends at ``N``.
        num_points: Number of grid points.

    Returns:
        Tensor shaped and typed like ``t``.
    """
    device = t.device
    grid = torch.linspace(-M, N, num_points, device=device)
    step = grid[1] - grid[0]
    # NOTE(review): c() is evaluated with its own default M, N and
    # num_points, not the M and N given to this function - confirm intended.
    weights = torch.exp(c(grid)) / a(grid)

    def riemann_sum(lower, upper):
        # A trailing axis is added so every entry is compared with the grid.
        inside = (grid >= lower.unsqueeze(-1)) & (grid <= upper.unsqueeze(-1))
        return (weights * inside).sum(dim=-1) * step

    left_edge = torch.tensor(-M, device=device)
    right_edge = torch.tensor(N, device=device)

    tail_from_y = riemann_sum(y, right_edge)
    head_to_x = riemann_sum(left_edge, x)
    head_to_t = riemann_sum(left_edge, t)
    tail_from_t = riemann_sum(t, right_edge)

    below_x = t <= x
    between = (x < t) & (t < y)
    above_y = y <= t

    out = torch.zeros_like(t)
    ratio_root = torch.sqrt(tail_from_y[below_x] / head_to_x[below_x])
    out[below_x] = ratio_root * torch.sqrt(head_to_t[below_x])
    out[between] = torch.sqrt(tail_from_y[between])
    out[above_y] = torch.sqrt(tail_from_t[above_y])

    # Drop the intermediates before asking the CUDA allocator to release
    # its cached blocks.
    del tail_from_y, head_to_x, head_to_t, tail_from_t, ratio_root
    torch.cuda.empty_cache()

    return out


def f2_minus(tensor, M, N, num_points=200, chunk_size=10, u_chunk_size=50):#numpoints100->200
    """Approximate a masked double sum over a (u, t) grid for every entry.

    The last axis of ``tensor`` holds ``(x, y, theta)``.  Both ``u`` and ``t``
    run over ``num_points`` uniform grid points from ``-M`` to the largest
    ``theta`` in ``tensor``.  For every entry the function returns::

        sum over grid points with u <= t <= theta of
            exp(c(u)) / a(u) * exp(-c(t)) * f1(t, x, y)   * delta_u * delta_t

    The work is split into slabs of ``chunk_size`` along the first axis and
    tiles of ``u_chunk_size`` x ``u_chunk_size`` grid points to bound memory.

    Args:
        tensor: Tensor of shape ``(..., 3)`` with at least one leading axis.
        M: The grid starts at ``-M``; also forwarded to ``f1``.
        N: Forwarded to ``f1``.
        num_points: Number of grid points per axis; also forwarded to ``f1``.
        chunk_size: Number of leading-axis rows processed together.
        u_chunk_size: Tile edge length on the u and t axes.  ``num_points``
            does not have to be a multiple of it.

    Returns:
        Tensor of shape ``tensor.shape[:-1]``.
    """
    x_tensor, y_tensor, theta_tensor = tensor[..., 0], tensor[..., 1], tensor[..., 2]
    device = tensor.device
    upper = theta_tensor.max().item()
    u_values = torch.linspace(-M, upper, num_points, device=device)
    t_values = torch.linspace(-M, upper, num_points, device=device)

    delta_u = u_values[1] - u_values[0]
    delta_t = t_values[1] - t_values[0]

    # These factors depend on the grids only, so compute them once instead
    # of once per tile.
    u_weights = torch.exp(c(u_values)) / a(u_values)
    t_weights = torch.exp(-c(t_values))

    tile_starts = range(0, num_points, u_chunk_size)
    results = []

    for i in range(0, x_tensor.shape[0], chunk_size):
        x_chunk = x_tensor[i:i + chunk_size]
        y_chunk = y_tensor[i:i + chunk_size]
        theta_chunk = theta_tensor[i:i + chunk_size]
        theta_limit = theta_chunk.unsqueeze(-1).unsqueeze(-1)  # [..., 1, 1]

        # f1 depends on (t, x, y) but not on u: evaluate it once per t-tile
        # and broadcast over the u axis.  Shapes come from the actual slice
        # length so a shorter final tile is handled.
        f1_by_t_tile = []
        for t_start in tile_starts:
            t_chunk = t_values[t_start:t_start + u_chunk_size]
            t_shape = (*x_chunk.shape, t_chunk.shape[0])
            f1_tile = f1(t_chunk.expand(t_shape),
                         x_chunk.unsqueeze(-1).expand(t_shape),
                         y_chunk.unsqueeze(-1).expand(t_shape), M, N, num_points)
            f1_by_t_tile.append(f1_tile.unsqueeze(-2))  # [..., 1, n_t]

        chunk_results = []
        for u_start in tile_starts:
            u_chunk = u_values[u_start:u_start + u_chunk_size]
            u_weight = u_weights[u_start:u_start + u_chunk_size]
            for t_start, f1_values in zip(tile_starts, f1_by_t_tile):
                t_chunk = t_values[t_start:t_start + u_chunk_size]
                t_weight = t_weights[t_start:t_start + u_chunk_size]

                u_grid, t_grid = torch.meshgrid(u_chunk, t_chunk, indexing='ij')
                weight_grid = u_weight.unsqueeze(-1) * t_weight.unsqueeze(0)  # [n_u, n_t]

                # Keep only grid points with u <= t <= theta.
                inner_integral_mask = (t_grid >= u_grid) & (t_grid <= theta_limit)
                integrand_inner_masked = weight_grid * inner_integral_mask

                # Sum over both grid axes of this tile.
                inner_sum = torch.sum(integrand_inner_masked * f1_values, dim=(-2, -1)) * delta_u * delta_t
                chunk_results.append(inner_sum)

        # Add up the contributions of all tiles for this slab.
        results.append(torch.sum(torch.stack(chunk_results), dim=0))

        # Release this slab's intermediates before starting the next one.
        del x_chunk, y_chunk, theta_chunk, theta_limit, f1_by_t_tile, chunk_results
        torch.cuda.empty_cache()

    return torch.cat(results, dim=0)
    
    
def f2_minus_re(tensor, M, N, num_points=200, chunk_size=10, u_chunk_size=50):#numpoints100->200
    """Approximate a masked double sum over a (u, t) grid, summing t first.

    The last axis of ``tensor`` holds ``(x, y, theta)``.  Both ``u`` and ``t``
    run over ``num_points`` uniform grid points from ``-M`` to the largest
    ``theta`` in ``tensor``.  For every entry the function returns::

        sum over u of exp(c(u)) / a(u) * delta_u *
            ( sum over t with u <= t <= theta of
                  exp(-c(t)) * f1(t, x, y) * delta_t )

    The inner sum over ``t`` is reduced first and the ``u``-dependent factor
    is applied to that reduced result.  The work is split into slabs of
    ``chunk_size`` along the first axis and tiles of ``u_chunk_size`` x
    ``u_chunk_size`` grid points to bound memory.

    Args:
        tensor: Tensor of shape ``(..., 3)`` with at least one leading axis.
        M: The grid starts at ``-M``; also forwarded to ``f1``.
        N: Forwarded to ``f1``.
        num_points: Number of grid points per axis; also forwarded to ``f1``.
        chunk_size: Number of leading-axis rows processed together.
        u_chunk_size: Tile edge length on the u and t axes.  ``num_points``
            does not have to be a multiple of it.

    Returns:
        Tensor of shape ``tensor.shape[:-1]``.
    """
    x_tensor, y_tensor, theta_tensor = tensor[..., 0], tensor[..., 1], tensor[..., 2]
    device = tensor.device
    upper = theta_tensor.max().item()
    u_values = torch.linspace(-M, upper, num_points, device=device)
    t_values = torch.linspace(-M, upper, num_points, device=device)

    delta_u = u_values[1] - u_values[0]
    delta_t = t_values[1] - t_values[0]

    # These factors depend on the grids only, so compute them once instead
    # of once per tile.
    u_weights = torch.exp(c(u_values)) / a(u_values)
    t_weights = torch.exp(-c(t_values))

    tile_starts = range(0, num_points, u_chunk_size)
    results = []

    for i in range(0, x_tensor.shape[0], chunk_size):
        x_chunk = x_tensor[i:i + chunk_size]
        y_chunk = y_tensor[i:i + chunk_size]
        theta_chunk = theta_tensor[i:i + chunk_size]
        theta_limit = theta_chunk.unsqueeze(-1).unsqueeze(-1)  # [..., 1, 1]

        # f1 depends on (t, x, y) but not on u: evaluate it once per t-tile
        # and broadcast over the u axis.  Shapes come from the actual slice
        # length so a shorter final tile is handled.
        f1_by_t_tile = []
        for t_start in tile_starts:
            t_chunk = t_values[t_start:t_start + u_chunk_size]
            t_shape = (*x_chunk.shape, t_chunk.shape[0])
            f1_tile = f1(t_chunk.expand(t_shape),
                         x_chunk.unsqueeze(-1).expand(t_shape),
                         y_chunk.unsqueeze(-1).expand(t_shape), M, N, num_points)
            f1_by_t_tile.append(f1_tile.unsqueeze(-2))  # [..., 1, n_t]

        chunk_results = []
        for u_start in tile_starts:
            u_chunk = u_values[u_start:u_start + u_chunk_size]
            u_weight = u_weights[u_start:u_start + u_chunk_size]  # [n_u]
            for t_start, f1_values in zip(tile_starts, f1_by_t_tile):
                t_chunk = t_values[t_start:t_start + u_chunk_size]
                t_weight = t_weights[t_start:t_start + u_chunk_size]  # [n_t]

                u_grid, t_grid = torch.meshgrid(u_chunk, t_chunk, indexing='ij')

                # Keep only grid points with u <= t <= theta.
                inner_integral_mask = (t_grid >= u_grid) & (t_grid <= theta_limit)
                integrand_inner_masked = t_weight * inner_integral_mask

                # Reduce the t axis first ...
                inner_sum = torch.sum(integrand_inner_masked * f1_values, dim=-1) * delta_t

                # ... then weight by the u-dependent factor and reduce u.
                inner_sum_result = torch.sum(inner_sum * u_weight, dim=-1) * delta_u
                chunk_results.append(inner_sum_result)

        # Add up the contributions of all tiles for this slab.
        results.append(torch.sum(torch.stack(chunk_results), dim=0))

        # Release this slab's intermediates before starting the next one.
        del x_chunk, y_chunk, theta_chunk, theta_limit, f1_by_t_tile, chunk_results
        torch.cuda.empty_cache()

    return torch.cat(results, dim=0)
    
    
    
# GPU CHOOSE
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _sync_if_cuda():
    """Wait for queued CUDA work to finish; do nothing when running on CPU.

    torch.cuda.synchronize() raises on machines without CUDA, which would
    defeat the CPU fallback selected above.
    """
    if device.type == "cuda":
        torch.cuda.synchronize()


#####################################################################################################################################################   
# Define x, y, and theta values
x_values = torch.linspace(-M, N, n0).to(device)
y_values = torch.linspace(-M, N, n0).to(device)
theta_values = torch.linspace(-M, N, n1).to(device)

# Create the (n0, n0, n1, 3) input whose last axis holds (x, y, theta):
# x varies along axis 0, y along axis 1 and theta along axis 2.
tensor = torch.stack([
    x_values.unsqueeze(-1).unsqueeze(-1).expand(n0, n0, n1),
    y_values.unsqueeze(0).unsqueeze(-1).expand(n0, n0, n1),
    theta_values.unsqueeze(0).unsqueeze(0).expand(n0, n0, n1)
], dim=-1).to(device)

import time

# Make sure no earlier GPU work is still pending before timing starts
_sync_if_cuda()

# Time f2_minus; perf_counter is monotonic and meant for measuring intervals
start_time = time.perf_counter()
old = f2_minus(tensor, M, N).to(device)
_sync_if_cuda()  # Ensure all CUDA operations are complete
end_time = time.perf_counter()

# Print f2_minus computation time
print(f"f2_minus computation time: {end_time - start_time} seconds")

# Time f2_minus_re
start_time = time.perf_counter()
new = f2_minus_re(tensor, M, N).to(device)
_sync_if_cuda()  # Ensure all CUDA operations are complete
end_time = time.perf_counter()

# Print f2_minus_re computation time
print(f"f2_minus_re computation time: {end_time - start_time} seconds")

# The two implementations sum in a different order, so compare with
# floating-point tolerance rather than exact equality.
are_equal = torch.allclose(old,new)
print(f"Are the results equal? {are_equal}")