import torch
import numpy as np
import math
# Module-level torch device name, hard-coded to CUDA.
device = 'cuda'

def create_memory(memory_length, interval_indices, data, times, device=None):
    """
    Build a fixed-length, right-aligned history window for every batch row.

    For row ``b`` the window ends at ``interval_indices[b]`` (inclusive) and
    nominally starts at ``interval_indices[b] - memory_length + 1``. When that
    start would fall before index 0, the missing leftmost slots are padded:
    data slots repeat ``data[b, 0]`` and time slots are set to 0.

    Args:
        memory_length: Number of history slots per row.
        interval_indices: Integer tensor with one end index per batch row.
        data: 2-D tensor ``(B, T)`` of observed values.
        times: Tensor indexed exactly like ``data`` (assumed to have the same
            shape and device -- confirm against callers).
        device: Device for the index tensors built here. ``None`` (default)
            uses ``data.device`` so the function also works for CPU tensors;
            passing an explicit device keeps the previous behaviour.

    Returns:
        Tensor of shape ``(B, 2 * memory_length)``: the data window followed by
        the matching time window.
    """
    if device is None:
        # Indexing requires the indices to live with the indexed tensor.
        device = data.device

    batch_size, _ = data.shape  # also enforces that `data` is 2-D
    interval_indices = interval_indices.to(device)

    # First index that really exists in each row's window, and how many
    # genuine entries the window therefore contains.
    window_start = torch.clamp(interval_indices - memory_length + 1, min=0).unsqueeze(1)  # (B, 1)
    available = interval_indices.unsqueeze(1) - window_start + 1                          # (B, 1)
    pad_left = memory_length - available                                                   # (B, 1)

    slot = torch.arange(memory_length, device=device).unsqueeze(0).expand(batch_size, -1)  # (B, M)
    is_pad = slot < pad_left  # True for the leftmost slots that have no real history

    # Right-align the real entries; padded slots would point before the window
    # start, so clamp them onto the first available entry.
    gather_indices = torch.maximum(window_start + slot - pad_left, window_start)

    rows = torch.arange(batch_size, device=device).unsqueeze(1).expand_as(gather_indices)

    data_mem = data[rows, gather_indices]
    # Padded time slots are zeroed instead of repeating the first time stamp.
    time_mem = times[rows, gather_indices].masked_fill(is_pad, 0)

    return torch.cat((data_mem, time_mem), dim=1)  # (B, 2 * M)



def create_memory_sampling(index, times, memory_length, traj, disc_steps):
    """
    Build the memory window used while sampling a trajectory.

    Selects the last ``memory_length`` entries of ``times`` that are
    ``<= index``, reads the trajectory at those coarse steps (every
    ``disc_steps``-th stored state) and appends the selected time stamps.
    If fewer than ``memory_length`` entries exist, the window is left-padded
    with the earliest selected state and with time stamp 0.

    Args:
        index: Current coarse step; only ``times`` entries ``<= index`` are used.
        times: 1-D integer tensor of coarse time indices (used to index the
            sub-sampled trajectory, so it must hold valid integer positions).
        memory_length: Number of slots in the returned window.
        traj: Tensor ``(N, steps, 1)``; the trailing state axis must have size 1,
            otherwise the final concatenation with the 2-D time block fails.
        disc_steps: Stride between coarse steps inside ``traj``.

    Returns:
        Tensor ``(N, 2 * memory_length)``: states followed by time stamps.

    Raises:
        ValueError: If no entry of ``times`` is ``<= index``.
    """
    indices = times[times <= index][-memory_length:]
    if indices.numel() == 0:
        raise ValueError("create_memory_sampling: no entry of `times` is <= index")

    # Coarse view of the trajectory, then pick the remembered steps.
    mem = traj[:, ::disc_steps, :][:, indices]  # (N, k, 1) with k <= memory_length
    stamps = indices.to(mem.dtype)

    pad_length = memory_length - mem.shape[1]
    if pad_length > 0:
        # Repeat the earliest remembered state on the left; padded stamps are 0.
        mem = torch.cat((mem[:, :1, :].expand(-1, pad_length, -1), mem), dim=1)
        zero_stamps = torch.zeros(pad_length, dtype=mem.dtype, device=traj.device)
        stamps = torch.cat((zero_stamps, stamps), dim=0)

    # Drop only the trailing state axis. A bare squeeze() would also remove the
    # batch axis when N == 1 (or the window axis when memory_length == 1) and
    # break the concatenation below.
    return torch.cat((mem.squeeze(-1), stamps.repeat(traj.shape[0], 1)), dim=1)


def loss_calc_sde(data, times, net, memory_length, sigma = 0.1, rho = 0.01):
    """
    Monte-Carlo regression loss for the drift network on random bridge points.

    For every batch row one interval ``[t1, t2]`` between two consecutive
    observation times is drawn uniformly, then a time ``t`` uniformly inside
    it. A noisy interpolant ``xt`` between the two observations is sampled and
    the network output is regressed onto the closed-form target velocity.

    Args:
        data: 2-D tensor ``(B, T)`` of observations.
        times: Tensor ``(B, T)`` of observation times, indexed like ``data``
            (consecutive times are assumed to differ, since the code divides
            by ``t2 - t1``).
        net: Callable mapping the concatenated input
            ``(t2, t, xt, memory)`` of shape ``(B, 3 + 2 * memory_length)`` to
            a velocity of the same shape as ``xt``.
        memory_length: Number of history slots handed to ``create_memory``.
        sigma: Scale of the bridge noise / diffusion term.
        rho: Constant added to the interpolant variance.

    Returns:
        Scalar tensor: summed squared error divided by the batch size.

    Raises:
        ValueError: If ``times`` has fewer than two time steps.
    """
    batch_size = data.shape[0]
    no_timesteps = times.shape[1]
    if no_timesteps < 2:
        raise ValueError("loss_calc_sde needs at least two time steps to form an interval")

    # Allocate everything next to the data instead of on a hard-coded device.
    dev = data.device

    # There are only `no_timesteps - 1` intervals. Scaling by `no_timesteps`
    # and clamping afterwards would sample the last interval twice as often.
    u = torch.rand(batch_size, 1, device=dev) * (no_timesteps - 1)
    interval_indices = torch.clamp(torch.floor(u).long(), 0, no_timesteps - 2)
    frac = u - interval_indices  # position inside the interval, in [0, 1]
    interval_indices = interval_indices.squeeze(1)

    rows = torch.arange(batch_size, device=dev)
    datap1 = data[rows, interval_indices].unsqueeze(1)
    datap2 = data[rows, interval_indices + 1].unsqueeze(1)
    t1 = times[rows, interval_indices].unsqueeze(1)
    t2 = times[rows, interval_indices + 1].unsqueeze(1)
    t = frac * (t2 - t1) + t1

    mem = create_memory(memory_length, interval_indices, data, times, device=dev)

    # Mean and variance of the noisy interpolant at time t.
    mt = (t2-t)/(t2-t1)*datap1 + (t-t1)/(t2-t1)*datap2
    taut = sigma*(t-t1)*(t2-t)/(t2-t1)**2+rho
    xt = mt + torch.sqrt(taut)*torch.randn_like(datap1)

    vel_out = net(torch.cat((t2, t, xt, mem), 1))

    # Target: slope of the mean plus a correction proportional to (xt - mt);
    # sigma*(t1+t2-2t)/(t2-t1)**2 is d(taut)/dt, and the extra -sigma matches
    # the sqrt(sigma) noise injected by the sampler.
    factor_vel = sigma * ((t1+t2-2*t)/(t2-t1)**2-1)/(2*taut)
    vel = (datap2-datap1)/(t2-t1) + factor_vel * (xt-mt)

    return torch.sum((vel_out - vel)**2)/batch_size

def euler(net, no_samples, no_timesteps, times,disc_steps, memory_length, sigma, rho, initial_std = 1., dimension = 1):
    """
    Sample trajectories with an Euler-Maruyama scheme driven by ``net``.

    Each of the ``no_timesteps`` unit intervals is split into ``disc_steps``
    sub-steps of size ``1 / disc_steps``. At every sub-step the network is fed
    ``(t2, time, x, memory)`` where ``t2`` is the next entry of the observation
    grid strictly after the current interval index, ``time`` is the current
    absolute time and ``memory`` comes from ``create_memory_sampling``.

    Args:
        net: Callable returning the drift for the concatenated input.
        no_samples: Number of trajectories simulated in parallel.
        no_timesteps: Number of unit intervals to simulate.
        times: 1-D tensor of observation indices; 0 and ``no_timesteps`` are
            added as end points and the result is cast to int64. Its device is
            used for every tensor created here.
        disc_steps: Euler sub-steps per unit interval.
        memory_length: Window length passed to ``create_memory_sampling``.
        sigma: Diffusion scale; the injected noise is ``sqrt(h * sigma)``.
        rho: Unused; kept so existing call sites keep working.
        initial_std: Standard deviation of the Gaussian initial state.
        dimension: State dimension.

    Returns:
        Tensor ``(no_samples, disc_steps * no_timesteps, dimension)`` holding
        the state *before* each sub-step (the state after the final update is
        not stored).
    """
    # Follow the device of the supplied grid instead of a hard-coded global.
    dev = times.device

    x = torch.randn(no_samples, dimension, device=dev)*initial_std
    traj = torch.zeros((no_samples, disc_steps * no_timesteps, dimension), device=dev)
    times = torch.cat((torch.zeros(1, device=dev), times, torch.ones(1, device=dev)*no_timesteps)).to(torch.int64)

    # Loop invariants: step size and noise scaling.
    h = 1 / disc_steps * torch.ones(no_samples, 1, device=dev)
    sqrt_h = torch.sqrt(h)
    sqrt_sigma = math.sqrt(sigma)

    for j in range(no_timesteps):
        # Next grid point strictly after the current interval index.
        t2 = times[(j<times)][0]*torch.ones(no_samples, 1, device=dev)
        mem = None
        for i in range(disc_steps):
            traj[:, j * disc_steps + i, :] = x  # record the state before stepping

            if mem is None:
                # The memory only reads coarse steps <= j, all of which are in
                # `traj` once the state above is stored, so it is the same for
                # every sub-step of this interval.
                mem = create_memory_sampling(j, times, memory_length, traj, disc_steps).reshape(no_samples, -1)

            t = (i / disc_steps) * torch.ones(no_samples, 1, device=dev)
            abs_time = torch.ones_like(t) * j + t
            inpu = torch.cat((t2, abs_time, x, mem), 1)
            x = x + h * net(inpu) + sqrt_h*sqrt_sigma*torch.randn_like(x)

    return traj




def random_times(no_timesteps, subsample_time, random_seed = None, device="cuda", equidist="False"):
    """
    Pick a sorted subset of time indices from ``range(no_timesteps)``.

    Args:
        no_timesteps: Size of the index range to choose from.
        subsample_time: Controls the subset size (see Returns).
        random_seed: If given, seeds torch's *global* RNG via
            ``torch.manual_seed`` before drawing (random mode only).
        device: Device of the returned tensor.
        equidist: ``True`` or the string ``"True"`` selects evenly spaced
            indices; anything else selects a random subset. The string form is
            kept for backward compatibility.

    Returns:
        1-D int64 tensor of ascending indices. In equidistant mode every
        ``no_timesteps // subsample_time``-th index starting at 0 (this can
        yield more than ``subsample_time`` entries); in random mode
        ``subsample_time - 1`` distinct indices.

    Raises:
        ValueError: In equidistant mode when the stride
            ``no_timesteps // subsample_time`` is not positive.
    """
    # Accept a real boolean too: previously `equidist=True` silently fell
    # through to the random branch because only the string was recognised.
    if equidist is True or equidist == "True":
        jump_every = no_timesteps // subsample_time
        if jump_every < 1:
            raise ValueError(
                "equidistant sampling needs 0 < subsample_time <= no_timesteps"
            )
        return torch.arange(0, no_timesteps, jump_every, device=device)

    if random_seed is not None:
        torch.manual_seed(random_seed)
    perm = torch.randperm(no_timesteps, device=device)[:subsample_time-1]
    return perm.sort().values
