import torch
import numpy as np
import torch.nn.functional as F
import torch.distributions as D
import torch.nn as nn
# Module-level device identifier; 'cuda' targets the default CUDA GPU.
device = 'cuda'


def create_memory(memory_length, interval_indices, data, times, device='cuda'):
    """Build a fixed-length, right-aligned history window for every batch row.

    For row ``r`` the window ends at column ``interval_indices[r]`` and holds
    the last ``memory_length`` entries of ``data[r]`` and ``times[r]``. When
    fewer than ``memory_length`` entries are available, the missing slots on
    the left repeat the data value at the window start, and the matching time
    slots are set to zero.

    Args:
        memory_length: Number of slots in each window.
        interval_indices: 1-D integer tensor giving, per row, the last column
            to include.
        data: 2-D tensor ``(batch, seq_len)`` of values.
        times: Tensor indexed with the same (row, column) pairs as ``data``.
        device: Device on which the index tensors are created.

    Returns:
        Tensor of shape ``(batch, 2 * memory_length)``: the data window
        followed by the time window.
    """
    n_rows, _ = data.shape
    last_col = interval_indices.to(device)

    # First column that really belongs to the window, and how many exist.
    first_col = (last_col - memory_length + 1).clamp(min=0)
    n_available = last_col - first_col + 1

    first_col_2d = first_col.unsqueeze(1)
    n_pad_2d = (memory_length - n_available).unsqueeze(1)

    # Slot positions 0..M-1, one copy per row.
    slots = torch.arange(memory_length, device=device).unsqueeze(0).expand(n_rows, -1)

    # Shift so the newest entry lands in the last slot; anything that would
    # fall before the window start is redirected to the window start.
    cols = torch.maximum(first_col_2d + slots - n_pad_2d, first_col_2d)
    rows = torch.arange(n_rows, device=device).unsqueeze(1).expand_as(cols)

    values = data[rows, cols]
    stamps = times[rows, cols]

    # Left-pad slots carry a zero time stamp instead of a repeated one.
    is_pad = (cols == first_col_2d) & (slots < n_pad_2d)
    stamps = torch.where(is_pad, torch.zeros_like(stamps), stamps)

    return torch.cat((values, stamps), dim=1)



def create_memory_sampling(index, times, memory_length, traj, disc_steps):
    """Assemble the memory vector used as network input during sampling.

    The last ``memory_length`` entries of ``times`` that are ``<= index`` are
    selected, and the trajectory (sub-sampled every ``disc_steps`` steps) is
    read at those positions. If fewer than ``memory_length`` entries exist,
    the state window is left-padded with its first state and the index window
    is left-padded with zeros.

    Args:
        index: Current (coarse) time index.
        times: 1-D integer tensor of time indices; it is used both as a mask
            source and as an index into the sub-sampled trajectory.
        memory_length: Number of memory slots.
        traj: Tensor ``(n_samples, n_steps, dim)`` of simulated states.
        disc_steps: Stride used to sub-sample ``traj`` along the step axis.

    Returns:
        Tensor ``(n_samples, memory_length * dim + memory_length)``: the
        flattened state window followed by the (padded) time indices.

    Raises:
        ValueError: If no entry of ``times`` is ``<= index``, so there is no
            state to build the memory from.
    """
    observed = times[times <= index][-memory_length:]
    coarse = traj[:, ::disc_steps, :]  # one state per coarse time step
    mem = coarse[:, observed]
    if mem.shape[1] == 0:
        raise ValueError("create_memory_sampling: no entry of `times` is <= index")

    n_samples = traj.shape[0]
    # Make the dtype of the index columns explicit instead of relying on
    # implicit promotion inside torch.cat.
    stamps = observed.to(mem.dtype)

    missing = memory_length - mem.shape[1]
    if missing > 0:
        # Repeat the oldest available state on the left; pad indices with 0.
        mem = torch.cat((mem[:, :1, :].expand(-1, missing, -1), mem), dim=1)
        stamps = torch.cat((stamps.new_zeros(missing), stamps), dim=0)

    # Flatten only the trailing axes. A bare squeeze() would also drop the
    # batch axis when n_samples == 1 (or the slot axis when memory_length == 1)
    # and make the concatenation below fail.
    return torch.cat((mem.flatten(1), stamps.repeat(n_samples, 1)), 1)
# Loss comparing lambda, mu, sigma
class loss_naive(nn.Module):
    """Sum of three mean-squared errors: intensity, mean and standard deviation."""

    def __init__(self):
        super().__init__()

    def forward(self, lambd_t_cond, lambd_t_net, mean_gt, mean_net, sig_gt, sig_net):
        """Return MSE(intensity) + MSE(mean) + MSE(sigma).

        Each ``*_net`` tensor is compared with its target counterpart
        (``lambd_t_cond``, ``mean_gt``, ``sig_gt``); every squared difference
        is averaged over all of its elements before the three are summed.
        """
        total = torch.square(lambd_t_cond - lambd_t_net).mean()
        total += torch.square(mean_net - mean_gt).mean()
        total += torch.square(sig_net - sig_gt).mean()
        return total


#IKL Loss
class loss_ikl(nn.Module):
    """Intensity-weighted divergence between target and network (lambda, mean, sigma).

    Three averaged terms are added:

    * ``-l + l_net + l * (log l - log l_net)`` for the intensities,
    * ``l/2 * (log v_net - log v + (v - v_net) / v_net)`` for the variances,
    * ``l/2 * (m - m_net)**2 / v_net`` for the means,

    where ``l`` is the target intensity and ``v = sigma**2``. ``tol`` is the
    lower clamp applied inside the logarithms and the offset added to the
    denominators.
    """

    def __init__(self, tol=1e-6):
        super().__init__()
        self.tol = tol

    def forward(self, lambda_t_cond, lambd_t_net, mean_gt, mean_net, sig_gt, sig_net):
        """Return the scalar loss for the given targets and network outputs."""
        eps = self.tol
        target_var = sig_gt**2
        model_var = sig_net**2
        safe_model_var = model_var + eps
        half_rate = lambda_t_cond * 0.5

        # Intensity term.
        log_rate_gap = torch.log(lambda_t_cond.clamp(min=eps)) - torch.log(lambd_t_net.clamp(min=eps))
        rate_term = (-lambda_t_cond + lambd_t_net + lambda_t_cond * log_rate_gap).mean()

        # Variance term, weighted by the target intensity.
        log_var_gap = torch.log(model_var.clamp(min=eps)) - torch.log(target_var.clamp(min=eps))
        var_term = (half_rate * (log_var_gap + (target_var - model_var) / safe_model_var)).mean()

        # Mean term, weighted by the target intensity.
        mean_term = (half_rate * ((mean_gt - mean_net)**2 / safe_model_var)).mean()

        return (rate_term + var_term + mean_term).mean()



def loss_calc_jump(data, times, net, memory_length, sigma=0.1, rho=0.1, stack_number=50, start=-40, end=40, loss_function='naive'):
    """Compute the training loss at one randomly drawn time per sequence.

    For every row an interval ``[t1, t2]`` between two consecutive entries of
    ``times`` is drawn, together with a time ``t`` inside it and a noisy state
    ``xt`` around the linear interpolation of the two neighbouring data
    values. The network is evaluated on ``(t2, t, xt, memory)`` and its first
    three output columns (log-intensity, mean, log-sigma) are compared with
    closed-form targets through the selected loss module.

    Args:
        data: 2-D tensor ``(batch, seq_len)`` of observations.
        times: 2-D tensor of observation times, indexed like ``data`` and
            located on the same device.
        net: Callable mapping a ``(batch, 3 + 2 * memory_length)`` tensor to a
            tensor with at least three columns.
        memory_length: Number of past observations passed as memory.
        sigma: Scale of the time-dependent part of the noise variance.
        rho: Constant part of the noise variance. The variance used is
            ``sigma * (t - t1) * (t2 - t) / (t2 - t1)**2 + rho``.
        stack_number: Unused; kept for interface compatibility.
        start: Unused; kept for interface compatibility.
        end: Unused; kept for interface compatibility.
        loss_function: ``'naive'`` or ``'ikl'``.

    Returns:
        The value returned by the selected loss module.

    Raises:
        ValueError: If ``loss_function`` is not one of the supported names.
    """
    if loss_function == 'naive':
        loss_func = loss_naive()
    elif loss_function == 'ikl':
        loss_func = loss_ikl()
    else:
        # Previously an unknown name fell through and failed much later with
        # an UnboundLocalError.
        raise ValueError(
            f"Unknown loss_function {loss_function!r}; expected 'naive' or 'ikl'."
        )

    # Work on the device of the inputs instead of a hard-coded global.
    dev = data.device
    no_timesteps = times.shape[1]
    batch_size = data.shape[0]
    rows = torch.arange(batch_size, device=dev)

    # Draw a position in [0, no_timesteps): the integer part selects the
    # interval, the fractional part the location inside it.
    draw = torch.rand(batch_size, 1, device=dev) * (no_timesteps)
    interval_indices = torch.floor(draw).long()
    interval_indices = torch.clamp(interval_indices, 0, no_timesteps - 2)
    interval_indices = interval_indices.squeeze(1)

    a = data[rows, interval_indices].unsqueeze(1)
    b = data[rows, interval_indices + 1].unsqueeze(1)
    t1 = times[rows, interval_indices].unsqueeze(1)
    t2 = times[rows, interval_indices + 1].unsqueeze(1)
    t = (draw % 1)*(t2-t1)+t1  # map the fractional part into [t1, t2)
    mem = create_memory(memory_length, interval_indices, data, times, device=dev)

    time_diff = t2-t1

    # Auxiliary values: noise variance tau and interpolated mean m at time t.
    tau = sigma*(t-t1)*(t2-t)/(t2-t1)**2+rho
    m = (t2-t)/(t2-t1)*a + (t-t1)/(t2-t1)*b

    # Noisy state around the interpolation.
    xt = m + torch.sqrt(tau)*torch.randn_like(a)
    sq_two_pi = torch.sqrt(torch.tensor(2.0 * torch.pi, dtype=torch.float32))

    # Create network output
    inpu = torch.cat((t2, t, xt, mem), 1)
    out = net(inpu)
    lambd_t = torch.exp(out[:, 0])  # jump intensity
    mean = out[:, 1]  # mean
    sig = torch.exp(out[:, 2])

    ##############################
    # Create gt lambda, mu, sigma

    # Compute the Is
    zw = (b-a)*(time_diff)*torch.sqrt(tau)/(sigma*(2*t-(t2+t1))+1e-8)

    # Compute u_p and u_n
    u_p = zw + torch.sqrt(zw**2 + 1)
    u_n = zw - torch.sqrt(zw**2 + 1)

    # Compute I_0 .. I_4
    I_0 = 0.5 * (torch.erf(u_p / 2**0.5) - torch.erf(u_n / 2**0.5))
    I_1 = 1.0 / (sq_two_pi) * (torch.exp(-u_n**2 / 2) - torch.exp(-u_p**2 / 2))
    I_2 = I_0 - 1.0 / (sq_two_pi) * (u_p * torch.exp(-u_p**2 / 2) - u_n * torch.exp(-u_n**2 / 2))
    I_3 = 2*I_1 - 1.0 / (sq_two_pi) * (u_p**2 * torch.exp(-u_p**2 / 2) - u_n**2 * torch.exp(-u_n**2 / 2))
    I_4 = 3*I_2 - 1.0 / (sq_two_pi) * (u_p**3 * torch.exp(-u_p**2 / 2) - u_n**3 * torch.exp(-u_n**2 / 2))

    # Apply conditional transformations: before the interval midpoint the
    # complementary values are used.
    mid_time = (t2+t1)/2
    I_0 = torch.where(t > mid_time, I_0, 1 - I_0)
    I_1 = torch.where(t > mid_time, I_1, -I_1)
    I_2 = torch.where(t > mid_time, I_2, 1 - I_2)
    I_3 = torch.where(t > mid_time, I_3, -I_3)
    I_4 = torch.where(t > mid_time, I_4, 3 - I_4)

    # Compute gt mu, sigma, i.e. mean_gt and var_gt
    sig_term = sigma*(2*t-(t2+t1))/((t2-t1)*2)

    nenner = (b-a)*torch.sqrt(tau)*I_1 - sig_term*(I_2-I_0)
    if (nenner < 1e-8).any():
        print("\33[31m OMG nenner<1e-8", "\33[0m", end="")
        # Note: the offset is applied to every entry, not only the small ones.
        nenner += 1e-8
    mean_gt = m + torch.sqrt(tau) * ((b-a)*torch.sqrt(tau)*I_2 - sig_term*(I_3-I_1))/nenner

    var_gt = tau * (((b-a)*torch.sqrt(tau)*I_3 - sig_term*(I_4-I_2))/nenner) - (m-mean_gt)**2
    var_gt_clamp = var_gt.squeeze(1).clamp(0, 1000)
    sig_gt_clamp = torch.sqrt(var_gt_clamp)

    # Target jump intensity: positive part of `term`, scaled by 1/tau.
    term = sigma*(t2+t1-2*t)/(2*(t2-t1)**2) - .5*sigma*(t2+t1-2*t)*(xt-m)**2/((t2-t1)**2*tau) - (xt-m)*(b-a)/(t2-t1)
    lambd_t_cond = F.relu(term)/(tau) + 1e-10

    loss = loss_func(lambd_t_cond.squeeze(1), lambd_t, mean_gt.squeeze(1), mean, sig_gt_clamp, sig)

    return loss


def euler(net, no_samples, no_timesteps, times, disc_steps, memory_length, sigma=0.1, rho=0.1, initial_std=1., dimension=1):
    """Simulate trajectories by stepping through time and sampling jumps.

    Each unit time step is split into ``disc_steps`` sub-steps of size
    ``h = 1 / disc_steps``. At every sub-step the network is evaluated on
    ``(t2, time, x, memory)``; with probability ``1 - exp(-lambda * h)`` the
    state is replaced by a draw from ``N(mean, sig**2)``, otherwise it is
    kept.

    Args:
        net: Callable returning a tensor whose first three columns are read as
            log-intensity, mean and log-sigma.
        no_samples: Number of trajectories simulated in parallel.
        no_timesteps: Number of unit time steps.
        times: 1-D tensor of observation times. ``0`` and ``no_timesteps`` are
            added at the ends and the result is cast to int64. All tensors
            created here live on ``times.device``.
        disc_steps: Number of sub-steps per unit time step.
        memory_length: Number of memory slots passed to the network.
        sigma: Not used by the current implementation.
        rho: Not used by the current implementation.
        initial_std: Not used by the current implementation; the initial state
            is drawn from a standard normal.
        dimension: State dimension.
            NOTE(review): the jump proposal is drawn with shape
            ``(no_samples, 1)``, so values other than 1 are presumably not
            fully supported — confirm.

    Returns:
        Tensor ``(no_samples, disc_steps * no_timesteps, dimension)`` holding
        the state recorded at the start of every sub-step.
    """
    # `times` already had to share a device with the tensors it is
    # concatenated with, so use its device rather than a hard-coded global.
    dev = times.device
    x = torch.randn(no_samples, dimension, device=dev)
    traj = torch.zeros((no_samples, disc_steps * no_timesteps, dimension), device=dev)

    times = torch.cat((torch.zeros(1, device=dev), times, torch.ones(1, device=dev)*no_timesteps)).to(torch.int64)

    # Loop-invariant helpers.
    ones = torch.ones(no_samples, 1, device=dev)
    h = 1 / disc_steps * ones

    for j in range(no_timesteps):
        # Next entry of `times` strictly after j (exists because
        # no_timesteps was appended above).
        t2 = times[(j < times)][0]*ones
        for i in range(disc_steps):
            traj[:, j * disc_steps + i, :] = x

            t = (i / disc_steps) * ones

            # Get memory from previous steps
            mem = create_memory_sampling(j, times, memory_length, traj, disc_steps)

            current_time = ones * j + t
            inpu = torch.cat((t2, current_time, x, mem.reshape(no_samples, -1)), 1)
            out = net(inpu)
            lambd_t = torch.exp(out[:, 0].unsqueeze(1)).clamp(0., 1000)  # jump intensity

            mean = out[:, 1]  # mean
            sig = torch.exp(out[:, 2])

            # Proposed post-jump state.
            z = mean.unsqueeze(1) + sig.unsqueeze(1)*torch.randn(no_samples, 1, device=dev)
            # Probability of no jump during a sub-step of length h.
            rt = torch.exp(-lambd_t*h)

            jump = torch.bernoulli(1 - rt)
            x = (1 - jump) * x + jump * z

    return traj

def random_times(no_timesteps, subsample_time, random_seed=None, device="cuda", equidist="False"):
    """Choose a sorted subset of the time indices ``0 .. no_timesteps - 1``.

    Args:
        no_timesteps: Size of the index range to pick from.
        subsample_time: Controls how many indices are kept. In the random
            branch ``subsample_time - 1`` indices are returned; in the
            equidistant branch every ``no_timesteps // subsample_time``-th
            index is returned (which can yield more than ``subsample_time``
            indices when the division is not exact).
        random_seed: If given, ``torch.manual_seed`` is called with it before
            sampling. This reseeds the global torch RNG.
        device: Device of the returned tensor.
        equidist: ``"True"`` (or the boolean ``True``) selects the
            equidistant grid; anything else selects random sampling.

    Returns:
        1-D int64 tensor of ascending indices.
    """
    # Accept both the legacy string flag and a real boolean; previously
    # `equidist=True` silently fell into the random branch.
    if str(equidist) == "True":
        # Guard against a zero slice step when subsample_time > no_timesteps.
        jump_every = max(no_timesteps // subsample_time, 1)
        return torch.arange(0, no_timesteps, device=device)[::jump_every]

    if random_seed is not None:
        torch.manual_seed(random_seed)
    # Clamp so that subsample_time == 0 does not turn into the slice [:-1].
    n_keep = max(subsample_time - 1, 0)
    perm = torch.randperm(no_timesteps, device=device)[:n_keep]
    return perm.sort().values
