"""
Utilities for the SDE-based model.

Contains the loss function and Euler-Maruyama sampler for a continuous
stochastic differential equation process, adapted for image data.
"""
import torch
import torch.nn.functional as F
import math

# Memory creation is generic and can be re-used.
# In a real project, this might be moved to a shared 'utils.py'.
def create_memory(memory_length, interval_indices, data, device='cuda'):
    """
    Gather, for each sequence in the batch, the window of `memory_length`
    consecutive frames that ends at that sequence's `interval_indices` entry.

    Windows that would start before frame 0 are left-padded by repeating
    frame 0 (negative frame positions are clamped to 0).

    Args:
        memory_length: Number of frames in every window.
        interval_indices: Integer tensor with one entry per batch element,
            giving the index of the last frame of each window.
        data: 5-D tensor indexed as (batch, frame, ...).
        device: Device on which the helper index tensors are created.

    Returns:
        Contiguous tensor of shape (batch, memory_length, *data.shape[2:]).
    """
    # The 5-way unpack doubles as a rank check on `data`.
    n_batch, _, _, _, _ = data.shape

    offsets = torch.arange(memory_length, device=device).unsqueeze(0)
    rows = torch.arange(n_batch, device=device).unsqueeze(1)

    # Frame positions idx-(L-1) .. idx for every batch row; positions that
    # fall before the start of the sequence collapse onto frame 0.
    window_start = interval_indices.unsqueeze(1) - (memory_length - 1)
    frame_ids = torch.clamp(window_start + offsets, min=0)

    return data[rows, frame_ids].contiguous()

def loss_calc_sde(data, times, net, memory_length, sigma=0.1, rho=0.01, device='cuda'):
    """
    Calculates the training loss for the SDE model on image data.

    For every sequence a random interval [t1, t2] between two consecutive
    frames (a, b) is drawn, together with a random time t inside it. A noisy
    interpolant xt between a and b is sampled, and the network is asked to
    predict the interval endpoint b from xt, the preceding memory frames and
    the pair (t2, t). The loss is the mean squared error between that
    prediction and b.

    Args:
        data: Tensor of shape (batch, frames, channels, H, W).
        times: Tensor indexed as times[batch, frame, 0, 0, 0], holding the
            time stamp of every frame.
        net: Callable invoked as net(inpu, time_in). `inpu` is xt
            concatenated on the channel axis with the flattened memory
            frames; `time_in` is the 1-D concatenation of all t2 values
            followed by all t values (length 2 * batch). Its output is
            compared against b with `F.mse_loss`.
        memory_length: Number of past frames passed as conditioning.
        sigma: Scale of the time-dependent part of the interpolant variance.
        rho: Constant variance floor added to the interpolant.
        device: Device used for the random draws and index tensors.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If `data` holds fewer than two frames per sequence, in
            which case no interval between consecutive frames exists.
    """
    batch_size, no_sub_timesteps, _, H, W = data.shape
    if no_sub_timesteps < 2:
        raise ValueError(
            "loss_calc_sde needs at least 2 frames per sequence, got "
            f"{no_sub_timesteps}"
        )

    # Continuous position along the sequence: the integer part selects the
    # interval, the fractional part selects the time inside it.
    random_times_ = torch.rand(batch_size, 1, 1, 1, device=device) * (no_sub_timesteps - 1)
    # view(-1) rather than squeeze(): squeeze() would collapse a batch of
    # size 1 to a 0-d tensor and break the per-sample indexing below.
    interval_indices = torch.clamp(
        torch.floor(random_times_).long().view(-1), 0, no_sub_timesteps - 2
    )

    idx_range = torch.arange(batch_size, device=device)
    a = data[idx_range, interval_indices]
    b = data[idx_range, interval_indices + 1]

    t1 = times[idx_range, interval_indices, 0, 0, 0].view(-1, 1, 1, 1)
    t2 = times[idx_range, interval_indices + 1, 0, 0, 0].view(-1, 1, 1, 1)
    dt = t2 - t1

    t = (random_times_ % 1) * dt + t1

    mem = create_memory(memory_length, interval_indices, data, device=device)

    # Noisy interpolant: linear interpolation between a and b, plus Gaussian
    # noise whose variance is sigma * (t - t1) * (t2 - t) / dt^2 + rho.
    noise_std = torch.sqrt(sigma * (t - t1) * (t2 - t) / dt**2 + rho)
    xt = (t2 - t) / dt * a + (t - t1) / dt * b + noise_std * torch.randn_like(a)

    # Prepare inputs for the U-Net
    mem_reshaped = mem.reshape(batch_size, -1, H, W)
    inpu = torch.cat((xt, mem_reshaped), 1)
    # reshape(-1) keeps both pieces 1-D even when batch_size == 1.
    time_in = torch.cat((t2.reshape(-1), t.reshape(-1)), dim=0)

    # The network predicts the interval endpoint directly. (An earlier
    # variant regressed the interpolant's velocity instead.)
    x1_pred = net(inpu, time_in)

    return F.mse_loss(x1_pred, b)

def euler_sde(net, no_samples, no_timesteps, times_eval, disc_steps, memory_length, image_size, sigma, device='cuda'):
    """
    Euler-Maruyama sampler for the SDE model on image data.

    Starting from images that contain one random 3x3 block of ones, the
    state is advanced over `no_timesteps * disc_steps` steps of size
    h = 1 / disc_steps using

        x <- x + h * (net(...) - x) / (t2 - t) + sqrt(h * sigma) * N(0, I)

    where t is the current time and t2 = no_timesteps is the final time.

    Args:
        net: Callable invoked as net(inpu, time_in). `inpu` is the current
            state concatenated on the channel axis with `memory_length`
            memory frames; `time_in` is the 1-D concatenation of t2 and t
            for all samples (length 2 * no_samples). It is expected to
            return a tensor shaped like the state.
        no_samples: Number of trajectories to simulate.
        no_timesteps: Length of the simulated time horizon.
        times_eval: Sequence of increasing times at which the state is
            recorded.
        disc_steps: Number of solver steps per unit of time.
        memory_length: Number of memory frames given to the network.
        image_size: Side length of the square, single-channel images.
        sigma: Diffusion strength; per-step noise has variance h * sigma.
        device: Device on which all tensors are created.

    Returns:
        Tensor of shape (no_samples, len(times_eval), 1, image_size,
        image_size). Entry k is the state at the first solver time
        >= times_eval[k]; entries whose time lies beyond `no_timesteps`
        stay zero.

    NOTE(review): gradients are not disabled here; callers that only sample
    presumably want to wrap the call in `torch.no_grad()`.
    NOTE(review): the memory window below is built from the previous solver
    steps (current state excluded, zero-padded), and t2 is always the final
    time. loss_calc_sde conditions on consecutive data frames via
    create_memory and on the next frame's time. Confirm that this
    difference between training and sampling is intended.
    """
    # Initialize with the same "random cube" distribution as the data
    cube_size = 3
    x = torch.zeros(no_samples, 1, image_size, image_size, device=device)
    for i in range(no_samples):
        rand_x = torch.randint(0, image_size - cube_size, (1,)).item()
        rand_y = torch.randint(0, image_size - cube_size, (1,)).item()
        x[i, 0, rand_y:rand_y + cube_size, rand_x:rand_x + cube_size] = 1.0

    total_steps = no_timesteps * disc_steps
    n_eval = len(times_eval)
    traj = torch.zeros((no_samples, n_eval, 1, image_size, image_size), device=device)
    full_history = torch.zeros((no_samples, total_steps, 1, image_size, image_size), device=device)

    # Loop invariants: step size, final time and per-step noise scale.
    h = 1 / disc_steps
    t2_val = no_timesteps  # The final endpoint
    t2 = t2_val * torch.ones(no_samples, device=device)
    noise_scale = math.sqrt(h * sigma)

    observed_idx = 0
    for j in range(total_steps):
        # Store current state in full history
        full_history[:, j] = x

        t1_val = j / disc_steps

        # Record the state for every evaluation time reached by now; a
        # `while` (not `if`) so that several evaluation times falling into
        # the same step are all filled with the current state.
        while observed_idx < n_eval and t1_val >= times_eval[observed_idx]:
            traj[:, observed_idx] = x
            observed_idx += 1

        t = t1_val * torch.ones(no_samples, device=device)

        # Memory: the up-to-`memory_length` states preceding step j,
        # left-padded with zero frames while the history is still short.
        mem_frames = full_history[:, max(0, j - memory_length):j]
        missing = memory_length - mem_frames.shape[1]
        if missing > 0:
            padding = torch.zeros(no_samples, missing, 1, image_size, image_size, device=device)
            mem_frames = torch.cat([padding, mem_frames], dim=1)

        mem_reshaped = mem_frames.reshape(no_samples, -1, image_size, image_size)
        time_in = torch.cat([t2, t], dim=0)
        inpu = torch.cat([x, mem_reshaped], dim=1)

        # SDE Euler-Maruyama step
        x1_pred = net(inpu, time_in)
        # Reshape the per-sample remaining time to (N, 1, 1, 1): a bare
        # (N,) tensor would broadcast against the image width axis and
        # fail whenever no_samples != image_size.
        time_left = torch.abs(t2 - t).view(-1, 1, 1, 1)
        drift = (x1_pred - x) / time_left
        noise = noise_scale * torch.randn_like(x)
        x = x + h * drift + noise

    # The state after the last step corresponds to time `no_timesteps`;
    # record it for evaluation times that are only reached at the end.
    while observed_idx < n_eval and no_timesteps >= times_eval[observed_idx]:
        traj[:, observed_idx] = x
        observed_idx += 1

    return traj

