import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tick.hawkes import SimuHawkesExpKernels
from tick.plot import plot_point_process




def create_custom_alpha(l, non_zero_indices, alpha_min=0.4, alpha_max=0.8):
    """Build an ``l x l`` excitation matrix supported on a given causal graph.

    Parameters
    ----------
    l : int
        Number of component processes; the result has shape ``(l, l)``.
    non_zero_indices : iterable of (int, int)
        Pairs ``(i, j)`` whose entry ``alpha[i, j]`` is made non-zero, i.e.
        the edge ``j -> i`` exists. Every other entry stays zero.
    alpha_min, alpha_max : float
        Bounds of the uniform distribution each non-zero entry is drawn from.
        Draws use NumPy's global RNG and are rounded to 3 decimals.

    Returns
    -------
    numpy.ndarray
        The ``(l, l)`` coefficient matrix. Each assigned coefficient is also
        printed as it is drawn.
    """
    matrix = np.zeros((l, l))
    for row, col in non_zero_indices:
        draw = np.random.uniform(alpha_min, alpha_max)
        matrix[row, col] = np.round(draw, 3)
        print(f"The coefficient alpha[{row},{col}] is", str(matrix[row, col]))
    return matrix


def compute_Phi_and_check(alpha, beta):
    """Return the matrix ``alpha / beta`` together with its spectral radius.

    The spectral radius (largest eigenvalue modulus) is compared against 1 by
    callers to decide whether a drawn ``alpha`` is acceptable.

    NOTE(review): ``simulate_hawkes_process`` passes ``alpha`` to tick's
    ``SimuHawkesExpKernels`` as ``adjacency``. tick documents that kernel as
    ``adjacency * decay * exp(-decay * t)``, whose integral is ``adjacency``
    itself. So ``alpha / beta`` equals the simulated process's branching
    matrix only when ``beta == 1``, which holds for every case defined in
    this module. Confirm before using other decay values.

    Parameters
    ----------
    alpha : numpy.ndarray
        Square coefficient matrix.
    beta : float or numpy.ndarray
        Decay rate(s); divides ``alpha`` element-wise.

    Returns
    -------
    tuple
        ``(Phi, spectral_radius)``.
    """
    branching = alpha / beta
    moduli = np.abs(np.linalg.eigvals(branching))
    return branching, moduli.max()



def simulate_hawkes_process(mu, alpha, beta, l, run_time=100, seed=None):
    """Simulate an ``l``-dimensional Hawkes process with exponential kernels.

    Parameters
    ----------
    mu : array-like
        Baseline intensities, forwarded to tick as ``baseline`` (one per
        component).
    alpha : numpy.ndarray
        ``(l, l)`` excitation matrix, forwarded to tick as ``adjacency``.
    beta : float or array broadcastable to ``(l, l)``
        Decay rate(s); ``beta * np.ones((l, l))`` is forwarded as ``decays``.
    l : int
        Dimension of the process.
    run_time : float, optional
        Simulation horizon (tick's ``end_time``).
    seed : int or None, optional
        Forwarded to ``SimuHawkesExpKernels``. ``None`` keeps tick's default
        (non-reproducible) seeding.

    Returns
    -------
    The simulator's ``timestamps`` attribute. Per tick's documentation this is
    a list with one array of event times per component.
    """
    decays = beta * np.ones((l, l))
    hawkes = SimuHawkesExpKernels(adjacency=alpha, decays=decays,
                                  baseline=mu, seed=seed, verbose=False)
    hawkes.end_time = run_time
    # Intensity is deliberately not tracked: only the timestamps are returned.
    # A 0.01 tracking grid would cost O(l * run_time / 0.01) memory for nothing.
    hawkes.simulate()
    return hawkes.timestamps





def get_discrete_time_series_from_hawkes(hawkes_timestamps, time_interval, total_time):
    """
    Discretize multi-dimensional continuous-time point data into event counts.

    Parameters:
    hawkes_timestamps (list of array-likes): One sequence of event times per dimension.
    time_interval (float): Width of each counting bin; must be positive.
    total_time (float): Duration covered by the series. Events at or after
        ``total_time`` are ignored, and so are negative event times.

    Returns:
    pandas.DataFrame: Shape ``(ceil(total_time / time_interval), n_dims)``.
        Row ``k`` counts the events in ``[k * time_interval, (k + 1) * time_interval)``.
        Column ``O_{d+1}`` holds the counts of the ``d``-th sequence in
        ``hawkes_timestamps`` (by position, not by original process index).

    Raises:
    ValueError: If ``time_interval`` is not positive.
    """
    if time_interval <= 0:
        raise ValueError("time_interval must be positive")

    num_intervals = int(np.ceil(total_time / time_interval))
    num_dimensions = len(hawkes_timestamps)
    discrete_time_series = np.zeros((num_dimensions, num_intervals), dtype=int)

    for dim, timestamps in enumerate(hawkes_timestamps):
        ts = np.asarray(timestamps, dtype=float)
        # Negative times would otherwise index from the end and land in the last bins.
        ts = ts[(ts >= 0) & (ts < total_time)]
        if ts.size == 0:
            continue
        # Same floor-division binning as the scalar ``timestamp // time_interval``.
        idx = np.floor_divide(ts, time_interval).astype(np.intp)
        # Guard against a float edge case pushing an in-range event past the last bin.
        idx = np.minimum(idx, num_intervals - 1)
        discrete_time_series[dim] = np.bincount(idx, minlength=num_intervals)

    columns = [f'O_{i+1}' for i in range(num_dimensions)]
    return pd.DataFrame(discrete_time_series.T, columns=columns)







def generate_simulation(run_time, time_interval, l, beta, mu, non_zero_indices, observed_indices,
                        max_attempts=10000):
    """Simulate a Hawkes process on a fixed causal graph and return observed counts.

    Parameters
    ----------
    run_time : float
        Simulation horizon; also the total time passed to the discretizer.
    time_interval : float
        Bin width used to turn event times into counts.
    l : int
        Dimension of the Hawkes process (number of component processes).
    beta : float
        Decay rate used by both the stability check and the simulator.
    mu : array-like
        Baseline intensities, one per component.
    non_zero_indices : list of (int, int)
        Support of alpha; ``(i, j)`` means the edge ``j -> i``.
    observed_indices : list of int
        Components kept in the output, in this order.
    max_attempts : int, optional
        Maximum number of alpha redraws while searching for a draw whose
        spectral radius (as computed by ``compute_Phi_and_check``) is below 1.

    Returns
    -------
    pandas.DataFrame
        One column ``O_k`` per entry of ``observed_indices`` (``O_1`` is
        ``observed_indices[0]``), one row per time bin. Each column is z-scored
        with its sample mean and standard deviation; a constant column
        becomes NaN.

    Raises
    ------
    RuntimeError
        If no stable alpha is found within ``max_attempts`` draws.
    """
    # Rejection-sample alpha on the fixed graph until the stability condition holds.
    # NOTE(review): the check uses alpha / beta while tick's adjacency is alpha itself;
    # these agree only for beta == 1 — confirm before using other decays.
    for _ in range(max_attempts):
        alpha = create_custom_alpha(l, non_zero_indices)
        _, spectral_radius = compute_Phi_and_check(alpha, beta)
        if spectral_radius < 1:
            break
    else:
        raise RuntimeError(
            f"no alpha with spectral radius < 1 found in {max_attempts} attempts; "
            "the graph/alpha range may not admit a stable process"
        )

    hawkes_timestamps = simulate_hawkes_process(mu, alpha, beta, l, run_time)
    observed_data = [hawkes_timestamps[i] for i in observed_indices]

    data = get_discrete_time_series_from_hawkes(observed_data, time_interval, run_time)
    return (data - data.mean()) / data.std()
  
  
  
def Case1(run_time=5000):
    """Three-process graph in which process 1 excites processes 0 and 2.

    Edges (``(i, j)`` means ``j -> i``): 1 -> 0, 1 -> 2 and the self-excitation
    1 -> 1. All three processes are observed.

    Parameters
    ----------
    run_time : float
        Simulation horizon.

    Returns
    -------
    pandas.DataFrame
        Output of ``generate_simulation`` for this configuration.
    """
    n_processes = 3
    baseline = np.random.uniform(15, 25, n_processes)  # background intensities
    return generate_simulation(
        run_time=run_time,
        time_interval=0.1,  # width of each counting bin
        l=n_processes,
        beta=1.0,  # decay rate of the exponential kernel
        mu=baseline,
        non_zero_indices=[(0, 1), (2, 1), (1, 1)],
        observed_indices=[0, 1, 2],
    )

  

def Case2(run_time=5000):
    """Five-process star in which latent process 2 excites all the others.

    Edges (``(i, j)`` means ``j -> i``): 2 -> 0, 2 -> 1, 2 -> 3, 2 -> 4, plus
    a self-excitation on every process. Process 2 is not observed.

    Parameters
    ----------
    run_time : float
        Simulation horizon.

    Returns
    -------
    pandas.DataFrame
        Output of ``generate_simulation`` for this configuration.
    """
    n_processes = 5
    baseline = np.random.uniform(15, 25, n_processes)  # background intensities
    edges = [(0, 2), (0, 0), (1, 2), (1, 1), (2, 2), (3, 2), (3, 3), (4, 2), (4, 4)]
    return generate_simulation(
        run_time=run_time,
        time_interval=0.1,  # width of each counting bin
        l=n_processes,
        beta=1.0,  # decay rate of the exponential kernel
        mu=baseline,
        non_zero_indices=edges,
        observed_indices=[0, 1, 3, 4],
    )
  






def Case6(run_time=5000):
    """Five-process graph with a latent mediator: 3 -> 2 -> {0, 1} and 3 -> 4.

    Edges (``(i, j)`` means ``j -> i``): 2 -> 0, 2 -> 1, 3 -> 2, 3 -> 4, plus
    a self-excitation on every process. Process 2 is not observed.

    Parameters
    ----------
    run_time : float
        Simulation horizon.

    Returns
    -------
    pandas.DataFrame
        Output of ``generate_simulation`` for this configuration.
    """
    n_processes = 5
    baseline = np.random.uniform(15, 25, n_processes)  # background intensities
    edges = [(0, 2), (0, 0), (1, 2), (1, 1), (2, 2), (2, 3), (3, 3), (4, 3), (4, 4)]
    return generate_simulation(
        run_time=run_time,
        time_interval=0.1,  # width of each counting bin
        l=n_processes,
        beta=1.0,  # decay rate of the exponential kernel
        mu=baseline,
        non_zero_indices=edges,
        observed_indices=[0, 1, 3, 4],
    )



def Case7(run_time=5000):
    """Five-process graph: latent 2 -> {0, 1, 4} and observed 3 -> 4.

    Edges (``(i, j)`` means ``j -> i``): 2 -> 0, 2 -> 1, 2 -> 4, 3 -> 4, plus
    a self-excitation on every process. Process 2 is not observed.

    Parameters
    ----------
    run_time : float
        Simulation horizon.

    Returns
    -------
    pandas.DataFrame
        Output of ``generate_simulation`` for this configuration.
    """
    n_processes = 5
    baseline = np.random.uniform(15, 25, n_processes)  # background intensities
    edges = [(0, 2), (0, 0), (1, 2), (1, 1), (2, 2), (4, 2), (3, 3), (4, 3), (4, 4)]
    return generate_simulation(
        run_time=run_time,
        time_interval=0.1,  # width of each counting bin
        l=n_processes,
        beta=1.0,  # decay rate of the exponential kernel
        mu=baseline,
        non_zero_indices=edges,
        observed_indices=[0, 1, 3, 4],
    )




def CaseFig1b(run_time=5000):
    """Three-process graph with a feedback loop between processes 1 and 2.

    Edges (``(i, j)`` means ``j -> i``): 1 -> 0, 2 -> 1, 1 -> 2, plus a
    self-excitation on every process. All three processes are observed.

    Parameters
    ----------
    run_time : float
        Simulation horizon.

    Returns
    -------
    pandas.DataFrame
        Output of ``generate_simulation`` for this configuration.
    """
    n_processes = 3
    baseline = np.random.uniform(15, 25, n_processes)  # background intensities
    edges = [(0, 0), (0, 1), (1, 1), (1, 2), (2, 1), (2, 2)]
    return generate_simulation(
        run_time=run_time,
        time_interval=0.1,  # width of each counting bin
        l=n_processes,
        beta=1.0,  # decay rate of the exponential kernel
        mu=baseline,
        non_zero_indices=edges,
        observed_indices=[0, 1, 2],
    )



def CaseFig2a(run_time=5000):
    """Five-process star in which latent process 2 excites all the others.

    Edges (``(i, j)`` means ``j -> i``): 2 -> 0, 2 -> 1, 2 -> 3, 2 -> 4, with
    self-excitation only on processes 2, 3 and 4. Process 2 is not observed.

    Parameters
    ----------
    run_time : float
        Simulation horizon.

    Returns
    -------
    pandas.DataFrame
        Output of ``generate_simulation`` for this configuration.
    """
    n_processes = 5
    baseline = np.random.uniform(15, 25, n_processes)  # background intensities
    edges = [(0, 2), (1, 2), (2, 2), (3, 2), (3, 3), (4, 2), (4, 4)]
    return generate_simulation(
        run_time=run_time,
        time_interval=0.1,  # width of each counting bin
        l=n_processes,
        beta=1.0,  # decay rate of the exponential kernel
        mu=baseline,
        non_zero_indices=edges,
        observed_indices=[0, 1, 3, 4],
    )



def CaseFig4a(run_time=5000):
    """Five-process graph with a latent mediator: 3 -> 2 -> {0, 1} and 3 -> 4.

    Edges (``(i, j)`` means ``j -> i``): 2 -> 0, 2 -> 1, 3 -> 2, 3 -> 4, plus
    a self-excitation on every process. Process 2 is not observed.

    Parameters
    ----------
    run_time : float
        Simulation horizon.

    Returns
    -------
    pandas.DataFrame
        Output of ``generate_simulation`` for this configuration.
    """
    n_processes = 5
    baseline = np.random.uniform(15, 25, n_processes)  # background intensities
    edges = [(0, 2), (0, 0), (1, 2), (1, 1), (2, 2), (2, 3), (3, 3), (4, 3), (4, 4)]
    return generate_simulation(
        run_time=run_time,
        time_interval=0.1,  # width of each counting bin
        l=n_processes,
        beta=1.0,  # decay rate of the exponential kernel
        mu=baseline,
        non_zero_indices=edges,
        observed_indices=[0, 1, 3, 4],
    )





def CaseFig4b(run_time=5000):
    """Five-process graph: latent 2 -> {0, 1, 4} and observed 3 -> 4.

    Edges (``(i, j)`` means ``j -> i``): 2 -> 0, 2 -> 1, 2 -> 4, 3 -> 4, plus
    a self-excitation on every process. Process 2 is not observed.

    Parameters
    ----------
    run_time : float
        Simulation horizon.

    Returns
    -------
    pandas.DataFrame
        Output of ``generate_simulation`` for this configuration.
    """
    n_processes = 5
    baseline = np.random.uniform(15, 25, n_processes)  # background intensities
    edges = [(0, 2), (0, 0), (1, 2), (1, 1), (2, 2), (4, 2), (3, 3), (4, 3), (4, 4)]
    return generate_simulation(
        run_time=run_time,
        time_interval=0.1,  # width of each counting bin
        l=n_processes,
        beta=1.0,  # decay rate of the exponential kernel
        mu=baseline,
        non_zero_indices=edges,
        observed_indices=[0, 1, 3, 4],
    )




def CaseFig4c(run_time=5000):
    """Eight-process graph driven by the latent root process 4.

    Edges (``(i, j)`` means ``j -> i``): 4 -> 2 -> 0, 4 -> 3 -> 1 and
    4 -> 5 -> {6, 7}. Processes 0, 1, 4, 6 and 7 self-excite. Only processes
    0, 1, 6 and 7 are observed.

    Parameters
    ----------
    run_time : float
        Simulation horizon.

    Returns
    -------
    pandas.DataFrame
        Output of ``generate_simulation`` for this configuration.
    """
    time_interval = 0.1
    l = 8  # Number of processes
    beta = 1.0  # Decay rate of the exponential excitation function
    mu = np.random.uniform(15, 25, l)  # Background intensities
    # NOTE(review): this list used to contain (4, 4) twice. The repeat only redrew
    # alpha[4, 4] (same distribution, one extra RNG draw and print line), so it was
    # dropped. If a different edge was intended in its place, add it here.
    non_zero_indices = [(0, 0), (0, 2), (2, 4), (4, 4), (1, 1), (1, 3), (3, 4),
                        (6, 5), (6, 6), (7, 7), (7, 5), (5, 4)]
    observed_indices = [0, 1, 6, 7]
    data = generate_simulation(run_time, time_interval, l, beta, mu, non_zero_indices, observed_indices)
    return data




def CaseFig4d(run_time=5000):
    """Eight-process graph with the latent chain 4 -> 3 -> 2 and 3 -> 5.

    Edges (``(i, j)`` means ``j -> i``): 4 -> 3 -> 2 -> {0, 1} and
    3 -> 5 -> {6, 7}. Processes 0, 1, 3, 4, 6 and 7 self-excite. Only
    processes 0, 1, 6 and 7 are observed.

    Parameters
    ----------
    run_time : float
        Simulation horizon.

    Returns
    -------
    pandas.DataFrame
        Output of ``generate_simulation`` for this configuration.
    """
    n_processes = 8
    baseline = np.random.uniform(15, 25, n_processes)  # background intensities
    edges = [(0, 0), (0, 2), (1, 1), (1, 2), (2, 3), (3, 3), (3, 4), (4, 4),
             (5, 3), (6, 5), (6, 6), (7, 5), (7, 7)]
    return generate_simulation(
        run_time=run_time,
        time_interval=0.1,  # width of each counting bin
        l=n_processes,
        beta=1.0,  # decay rate of the exponential kernel
        mu=baseline,
        non_zero_indices=edges,
        observed_indices=[0, 1, 6, 7],
    )











# # Example usage
# if __name__ == "__main__":
#     hawkes_timestamps = [
#         [0.1, 0.3, 1.5, 2.0],  # Timestamps for dimension 1
#         [0.2, 0.4, 1.7],       # Timestamps for dimension 2
#     ]
#     time_interval = 0.5
#     total_time = 3.0

#     discrete_time_series = get_discrete_time_series_from_hawkes(hawkes_timestamps, time_interval, total_time)
#     print(discrete_time_series)








