"""
Utilities for the jump-diffusion model, adapted for image data.

This file contains the core mathematical logic for:
1.  IKL Loss Calculation: Computes the loss between the network's predictions
    and the ground truth distribution parameters, all on a per-pixel basis.
2.  Memory Creation: Constructs the memory tensor by stacking previous frames.
3.  Euler-Maruyama Sampler: Generates new video trajectories from the model.
"""
import torch
import torch.nn.functional as F
import torch.nn as nn

class LossIKL(nn.Module):
    """Entrywise IKL loss between predicted and target jump parameters.

    The loss is the mean over all entries of three terms: a jump-rate term,
    a variance term and a mean term. The variance and mean terms are weighted
    by the target rate ``lambda_cond``.
    """

    def __init__(self, tol=1e-6):
        super().__init__()
        # Floor applied inside logarithms and added to variance denominators.
        self.tol = tol

    def forward(self, lambda_cond, lambda_net, mean_gt, mean_net, sig_gt, sig_net):
        """Return the scalar IKL loss.

        Args:
            lambda_cond: Target jump rate.
            lambda_net: Predicted jump rate.
            mean_gt: Target mean.
            mean_net: Predicted mean.
            sig_gt: Target standard deviation (squared internally).
            sig_net: Predicted standard deviation (squared internally).

        Returns:
            The mean of the summed rate, variance and mean terms.
        """
        eps = self.tol
        target_var = sig_gt ** 2
        pred_var = sig_net ** 2
        # Shared, regularised denominator of the variance and mean terms.
        safe_pred_var = pred_var + eps

        # Rate term: -l_c + l_n + l_c * (log l_c - log l_n), logs floored at eps.
        log_rate_gap = torch.log(lambda_cond.clamp(min=eps)) - torch.log(lambda_net.clamp(min=eps))
        rate_term = -lambda_cond + lambda_net + lambda_cond * log_rate_gap

        # Variance term, weighted by the target rate.
        log_var_gap = torch.log(pred_var.clamp(min=eps)) - torch.log(target_var.clamp(min=eps))
        var_term = lambda_cond * 0.5 * (log_var_gap + (target_var - pred_var) / safe_pred_var)

        # Squared mean error scaled by the predicted variance, weighted by the target rate.
        mean_term = lambda_cond * 0.5 * ((mean_gt - mean_net) ** 2 / safe_pred_var)

        return (rate_term + var_term + mean_term).mean()

class LossNaive(nn.Module):
    """Baseline loss: sum of three mean-squared errors (rate, mean, std)."""

    def __init__(self):
        super().__init__()

    def forward(self, lambda_cond, lambda_net, mean_gt, mean_net, sig_gt, sig_net):
        """Return MSE(rate) + MSE(mean) + MSE(std) as a scalar tensor.

        Each ``*_net`` argument is the prediction and the matching
        ``*_cond`` / ``*_gt`` argument is its target.
        """
        rate_err = F.mse_loss(lambda_net, lambda_cond)
        mean_err = F.mse_loss(mean_net, mean_gt)
        scale_err = F.mse_loss(sig_net, sig_gt)
        return rate_err + mean_err + scale_err

def create_memory(memory_length, interval_indices, data, times, device='cuda'):
    """Gather the last ``memory_length`` frames up to each sample's index.

    For sample ``b`` the result holds the frames
    ``interval_indices[b] - memory_length + 1 ... interval_indices[b]`` taken
    along dimension 1 of ``data``. Positions that would fall before frame 0
    are filled by repeating frame 0 (left padding).

    Args:
        memory_length: Number of frames in each memory window.
        interval_indices: Integer index of the most recent frame per sample.
            A 0-dim tensor (e.g. the result of ``squeeze()`` on a batch of
            one) is accepted and treated as a single-element vector.
        data: Tensor whose first two dimensions are batch and time; the
            caller in this file passes [batch, time, channels, H, W].
        times: Unused; kept so existing callers do not break.
        device: Unused; kept so existing callers do not break. Index tensors
            are created on ``data.device`` so indexing can never fail because
            of a device mismatch.

    Returns:
        A contiguous tensor of shape
        [batch, memory_length, *data.shape[2:]] on the device of ``data``.
    """
    batch_size = data.shape[0]
    index_device = data.device

    # reshape(-1) turns a 0-dim index into a 1-element vector so the
    # unsqueeze below also works for a batch of one.
    last_idx = torch.as_tensor(interval_indices, device=index_device).reshape(-1).long()

    # Offsets [0, 1, ..., memory_length - 1] within the window.
    offsets = torch.arange(memory_length, device=index_device).unsqueeze(0)

    # Window ends at last_idx; clamping negative positions to 0 implements
    # the left padding by repeating the first frame.
    gather_indices = (last_idx.unsqueeze(1) - (memory_length - 1) + offsets).clamp(min=0)

    # Column of batch indices, broadcast against gather_indices.
    batch_indices = torch.arange(batch_size, device=index_device).unsqueeze(1)

    # Advanced indexing may return a non-contiguous tensor; callers reshape
    # it, so hand back a contiguous copy.
    return data[batch_indices, gather_indices].contiguous()


def loss_calc_jump(data, times, net, memory_length, sigma=0.1, rho=0.1, loss_function='ikl', device='cuda', return_debug_tensors=False):
    """Calculate the training loss for a batch of image jump data.

    For every sample a random pair of consecutive frames (``a`` at ``t1``,
    ``b`` at ``t2``) and a random time ``t`` between them are drawn. A noisy
    linear interpolation ``xt`` of the two frames is fed to the network
    together with the memory frames, and the network's three outputs
    (log jump rate, mean, log standard deviation) are compared with
    closed-form per-pixel targets.

    Args:
        data: 5-D tensor unpacked as [batch, time, channels, H, W]. The
            network output is split into single-channel chunks, so this
            presumably expects channels == 1 — TODO confirm.
        times: Tensor read as ``times[b, k, 0, 0, 0]`` to obtain the time
            stamp of frame ``k`` of sample ``b``.
        net: Callable ``net(inpu, time_in)``. ``inpu`` is ``xt`` concatenated
            with the flattened memory along dim 1; ``time_in`` is a 1-D
            tensor holding all ``t2`` values followed by all ``t`` values.
            Its output must split into exactly three 1-channel tensors along
            dim 1.
        memory_length: Number of past frames passed to ``create_memory``.
        sigma: Scale of the time-dependent part of the noise variance.
        rho: Constant offset of the noise variance.
        loss_function: ``'ikl'`` selects ``LossIKL``; any other value selects
            ``LossNaive``.
        device: Device used for the random draws and index tensors.
            NOTE(review): it has to match the device of ``data`` and
            ``times`` — confirm callers always pass it.
        return_debug_tensors: If true, also return the memory tensor and the
            right end-point frames ``b``.

    Returns:
        The scalar loss, or ``(loss, mem, b)`` when ``return_debug_tensors``
        is true.
    """
    loss_func = LossIKL() if loss_function == 'ikl' else LossNaive()

    batch_size, no_sub_timesteps, _, H, W = data.shape

    # One continuous position in [0, T-1) per sample: the integer part picks
    # the interval, the fractional part the location inside it.
    random_times_ = torch.rand(batch_size, 1, 1, 1, device=device) * (no_sub_timesteps - 1)
    # reshape(-1) instead of squeeze(): squeeze() would yield a 0-dim tensor
    # for batch_size == 1 and break the per-sample indexing below.
    interval_indices = torch.clamp(torch.floor(random_times_).long().reshape(-1), 0, no_sub_timesteps - 2)

    idx_range = torch.arange(batch_size, device=device)
    a = data[idx_range, interval_indices]
    b = data[idx_range, interval_indices + 1]

    # Scalar time stamps of both end points, shaped for broadcasting.
    t1 = times[idx_range, interval_indices, 0, 0, 0].view(-1, 1, 1, 1)
    t2 = times[idx_range, interval_indices + 1, 0, 0, 0].view(-1, 1, 1, 1)

    t = (random_times_.view(-1, 1, 1, 1) % 1) * (t2 - t1) + t1

    mem = create_memory(memory_length, interval_indices, data, times, device)

    # Noise variance tau and linear interpolation m of the two frames; both
    # are reused for the ground-truth targets further down.
    tau = sigma * (t - t1) * (t2 - t) / (t2 - t1)**2 + rho
    m = (t2 - t) / (t2 - t1) * a + (t - t1) / (t2 - t1) * b
    xt = m + torch.sqrt(tau) * torch.randn_like(a)

    # Collapse memory and channel dimensions:
    # [B, T_mem, C, H, W] -> [B, T_mem * C, H, W].
    mem_reshaped = mem.reshape(batch_size, -1, H, W)
    inpu = torch.cat((xt, mem_reshaped), 1)
    # reshape(-1) keeps both pieces 1-D even for batch_size == 1.
    time_in = torch.cat((t2.reshape(-1), t.reshape(-1)), dim=0)
    out = net(inpu, time_in)
    lambda_net_raw, mean_net, sig_net_raw = out.split(1, dim=1)
    # Rate and standard deviation are predicted in log space.
    lambda_net = torch.exp(lambda_net_raw)
    sig_net = torch.exp(sig_net_raw)

    # --- Ground truth calculations (entrywise) ---
    zw = (b - a) * (t2 - t1) * torch.sqrt(tau) / (sigma * (2 * t - (t2 + t1)) + 1e-8)
    sq_zw_1 = torch.sqrt(zw**2 + 1)
    u_p, u_n = zw + sq_zw_1, zw - sq_zw_1

    # I_k is the integral of u**k times the standard normal density over
    # [u_n, u_p]; I_2..I_4 follow from I_0 and I_1 by integration by parts.
    sq_two_pi = (2.0 * torch.pi) ** 0.5
    I_0 = 0.5 * (torch.erf(u_p / 2**0.5) - torch.erf(u_n / 2**0.5))
    I_1 = 1.0 / sq_two_pi * (torch.exp(-u_n**2 / 2) - torch.exp(-u_p**2 / 2))
    I_2 = I_0 - 1.0 / sq_two_pi * (u_p * torch.exp(-u_p**2 / 2) - u_n * torch.exp(-u_n**2 / 2))
    I_3 = 2 * I_1 - 1.0 / sq_two_pi * (u_p**3 / u_p * torch.exp(-u_p**2 / 2) - u_n**2 * torch.exp(-u_n**2 / 2)) if False else 2 * I_1 - 1.0 / sq_two_pi * (u_p**2 * torch.exp(-u_p**2 / 2) - u_n**2 * torch.exp(-u_n**2 / 2))
    I_4 = 3 * I_2 - 1.0 / sq_two_pi * (u_p**3 * torch.exp(-u_p**2 / 2) - u_n**3 * torch.exp(-u_n**2 / 2))

    # In the first half of the interval use the complement of [u_n, u_p]:
    # full-line moments (1, 0, 1, 0, 3) minus the partial ones.
    mid_time = (t2 + t1) / 2
    second_half = t > mid_time
    I_0 = torch.where(second_half, I_0, 1 - I_0)
    I_1 = torch.where(second_half, I_1, -I_1)
    I_2 = torch.where(second_half, I_2, 1 - I_2)
    I_3 = torch.where(second_half, I_3, -I_3)
    I_4 = torch.where(second_half, I_4, 3 - I_4)
    sig_term = sigma * (2 * t - (t2 + t1)) / ((t2 - t1) * 2)

    # Ground truth jump rate: positive part of `term`, scaled by 1 / tau.
    term = sigma * (t2 + t1 - 2 * t) / (2 * (t2 - t1)**2) - 0.5 * sigma * (t2 + t1 - 2 * t) * (xt - m)**2 / ((t2 - t1)**2 * tau) - (xt - m) * (b - a) / (t2 - t1)
    lambda_cond = F.relu(term) / tau + 1e-10

    # Common denominator ("Nenner") of the mean and variance targets.
    nenner = (b - a) * torch.sqrt(tau) * I_1 - sig_term * (I_2 - I_0)
    # NOTE(review): as soon as one entry is below 1e-8 the offset is added
    # to every entry (and a warning is printed); confirm this is intended.
    if (nenner < 1e-8).any():
        print("\33[31m OMG nenner<1e-8", "\33[0m",end="")
        nenner += 1e-8
    mean_gt = m + torch.sqrt(tau) * ((b - a) * torch.sqrt(tau) * I_2 - sig_term * (I_3 - I_1)) / nenner

    # Second moment minus squared mean offset, clamped to [0, 1000] so the
    # square root is defined.
    var_gt = tau * (((b - a) * torch.sqrt(tau) * I_3 - sig_term * (I_4 - I_2)) / nenner) - (m - mean_gt)**2
    # Keep the channel dimension: dropping it would turn sig_gt into
    # [B, H, W], which broadcasts against the [B, 1, H, W] network output
    # to [B, B, H, W] and mixes samples in the loss.
    sig_gt = torch.sqrt(var_gt.clamp(0, 1000))
    loss = loss_func(lambda_cond, lambda_net, mean_gt, mean_net, sig_gt, sig_net)

    if return_debug_tensors:
        return loss, mem, b

    return loss


def euler(net, no_samples, no_timesteps, times, disc_steps, memory_length, image_size, sigma=0.1, rho=0.1,rounding = False, device='cuda'):
    """Sample image trajectories by stepping the learned jump process.

    Each interval between consecutive entries of ``times`` (plus one final
    interval ending at ``len(times)``) is simulated with ``disc_steps``
    sub-steps. The memory fed to the network is built once per interval and
    stays fixed while that interval is simulated.

    Args:
        net: Callable ``net(inpu, time_in)`` whose output splits into three
            1-channel tensors along dim 1 (log jump rate, mean, log std).
            ``inpu`` is the current state concatenated with the memory along
            dim 1; ``time_in`` holds ``no_samples`` copies of the interval
            end time followed by ``no_samples`` copies of the current time.
        no_samples: Number of trajectories simulated in parallel.
        no_timesteps: Unused; kept for interface compatibility.
        times: 1-D tensor of observed time points on ``device``.
        disc_steps: Number of simulation sub-steps per interval.
        memory_length: Number of past frames given to the network.
        image_size: Side length of the square single-channel images.
        sigma: Unused; kept for interface compatibility.
        rho: Unused; kept for interface compatibility.
        rounding: If true, binarise every memory frame by keeping its
            ``cube_size**2`` largest pixels (ties included).
        device: Device on which all tensors are created.

    Returns:
        Tensor of shape
        [no_samples, len(times) * disc_steps, 1, image_size, image_size]
        holding the state before every simulation sub-step.

    Note:
        Autograd is not disabled here; wrap the call in ``torch.no_grad()``
        when gradients are not needed.
    """
    # Initial frame: a cube_size x cube_size square of ones at a random spot.
    cube_size = 3  # As defined in config.yaml
    x = torch.zeros(no_samples, 1, image_size, image_size, device=device)
    for sample in range(no_samples):
        rand_x = torch.randint(0, image_size - cube_size, (1,)).item()
        rand_y = torch.randint(0, image_size - cube_size, (1,)).item()
        x[sample, 0, rand_y:rand_y + cube_size, rand_x:rand_x + cube_size] = 1.0

    # Trajectory tensor storing every intermediate simulation state.
    total_sim_steps = len(times) * disc_steps
    traj = torch.zeros((no_samples, total_sim_steps, 1, image_size, image_size), device=device)

    # Append len(times) as the end point of the last interval.
    subsampled_times = torch.cat((times, torch.tensor([len(times)], device=device)))

    # Step size in units of one interval; identical for every sub-step.
    h = 1 / disc_steps

    # Outer loop over the observed time intervals.
    for j in range(len(subsampled_times) - 1):
        t2_val = subsampled_times[j + 1]
        # The first interval always starts at time 0.
        t1_val = subsampled_times[j] if j > 0 else 0

        t2 = t2_val * torch.ones(no_samples, device=device)

        # Memory is created once per interval. The history consists of the
        # states stored at the last sub-step of every finished interval
        # (traj holds the state *before* each update).
        history_observed = traj[:, disc_steps-1:j*disc_steps:disc_steps]
        num_history = history_observed.shape[1]

        # Keep only the most recent `memory_length` history frames.
        start_idx = max(0, num_history - memory_length)
        mem = history_observed[:, start_idx:]

        # Left-pad short histories: with the current state when there is no
        # history at all, otherwise with the oldest history frame.
        num_mem_frames = mem.shape[1]
        if num_mem_frames < memory_length:
            padding_frame = x.unsqueeze(1) if num_history == 0 else history_observed[:, 0:1]
            padding = padding_frame.repeat(1, memory_length - num_mem_frames, 1, 1, 1)
            mem = torch.cat([padding, mem], dim=1)

        mem_reshaped = mem.reshape(no_samples, -1, image_size, image_size)
        if rounding:
            # Threshold every memory frame at its (cube_size**2)-th largest
            # pixel value and binarise.
            reshaped = mem_reshaped.reshape(mem_reshaped.shape[0], memory_length, -1)
            topk = torch.topk(reshaped, cube_size * cube_size, dim=2).values
            thresholds = topk[:, :, -1]
            thresholds_reshaped = thresholds.reshape(mem_reshaped.shape[0], memory_length, 1, 1)
            mem_reshaped = (mem_reshaped >= thresholds_reshaped).float()

        # Inner loop simulating the path within the interval.
        for i in range(disc_steps):
            idx = j * disc_steps + i
            # Store the current state before updating it.
            traj[:, idx] = x

            # Parenthesised on purpose: `i + 1 / disc_steps` would push t
            # beyond t2 for every i >= 1.
            t_val = t1_val + ((i + 1) / disc_steps) * (t2_val - t1_val)
            t = t_val * torch.ones(no_samples, device=device)

            # The same fixed memory is used for all steps of this interval.
            inpu = torch.cat([x, mem_reshaped], dim=1)
            time_in = torch.cat([t2, t], dim=0)

            out = net(inpu, time_in)
            lambda_net_raw, mean_net, sig_net_raw = out.split(1, dim=1)
            lambda_net = torch.exp(lambda_net_raw)
            sig_net = torch.exp(sig_net_raw)

            # Proposed post-jump value drawn from the predicted Gaussian.
            z = mean_net + sig_net * torch.randn_like(mean_net)
            # Per-pixel probability of no jump during a step of size h.
            rt = torch.exp(-lambda_net * h).clamp(0, 1)
            m = torch.bernoulli(1 - rt)
            # Pixels that jump take the proposal, the others keep their value.
            x = (1 - m) * x + m * z

    # Return the full trajectory, including all intermediate sub-steps.
    return traj


