import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

seed = 47  # fixed RNG seed so repeated runs start from identical random draws
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed the CPU generator and the CUDA generator; PyTorch silently ignores the
# CUDA call when no GPU is available.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)


def initialize_T(N, scale=1.0):
    """Build the spatial coupling tensor between the N*N astrocyte processes.

    One grid point per index pair (a, b) is placed on a regular grid in the
    unit square, at x = linspace(0, 1, N)[b], y = linspace(0, 1, N)[a].
    The coupling decays exponentially with the Euclidean distance between
    grid points.

    Args:
        N: Grid side length (the script uses the number of neurons).
        scale: Decay rate of the coupling with distance.

    Returns:
        Tensor of shape (N, N, N, N) on the module-level ``device`` with
        ``T[a, b, c, d] = exp(-scale * dist(point(a, b), point(c, d)))``.
    """
    axis = torch.linspace(0, 1, N)
    # 'ij' indexing: grid_y varies along the first index (a), grid_x along the second (b).
    grid_y, grid_x = torch.meshgrid(axis, axis, indexing="ij")
    # Row a * N + b of `points` is the (x, y) position of grid point (a, b).
    points = torch.stack((grid_x.reshape(-1), grid_y.reshape(-1)), dim=1)

    # (N*N, N*N) matrix of pairwise Euclidean distances.
    pairwise = torch.cdist(points, points, p=2).to(device)

    # Row-major reshape maps (a * N + b, c * N + d) -> [a, b, c, d].
    return torch.exp(-pairwise * scale).reshape(N, N, N, N)


def synaptic_dynamics(s, tau_s, alpha, phi_i, phi_j, psi_ij, c_ij, p):
    """Return the time derivative ds/dt of the synaptic facilitation.

    ds/dt = (-alpha * s + outer(phi_i, phi_j) + psi_ij(p) + c_ij) / tau_s

    Args:
        s: Current synaptic facilitation.
        tau_s: Time constant of the synaptic dynamics.
        alpha: Leak rate of ``s``.
        phi_i, phi_j: Already-evaluated activations. Their outer product
            (``torch.tensordot(..., dims=0)``) is the Hebbian drive.
        psi_ij: Callable applied to the astrocyte state ``p``.
        c_ij: Additive bias on the drive.
        p: Astrocyte process state passed to ``psi_ij``.

    Returns:
        Tensor with the broadcast shape of the four summed terms.
    """
    leak = -alpha * s
    hebbian = torch.tensordot(phi_i, phi_j, dims=0)
    astro_drive = psi_ij(p)
    return (leak + hebbian + astro_drive + c_ij) / tau_s


def neural_dynamics_torch(x, t, tau_n, lam, g, b, phi, threshold, reset_voltage, s):
    """Compute dx/dt for every neuron and apply the spike-and-reset rule.

    dx_i/dt = (-lam * x_i + sum_j g(s[i, j]) * phi(x_j) + b_i) / tau_n

    All derivatives are evaluated from the same snapshot of ``x``, taken
    before any reset, so the result does not depend on neuron ordering.

    Side effect: ``x`` is modified IN PLACE. Every entry with
    ``x_i >= threshold`` is overwritten with ``reset_voltage``. The
    integration loop in this file passes a view of its voltage history and
    then adds ``dxdt * dt`` to that same view, so this in-place reset is what
    restarts a spiking neuron from ``reset_voltage``.

    Args:
        x: 1-D tensor of membrane voltages, one per neuron (mutated in place).
        t: Current time. Unused; kept for interface compatibility.
        tau_n: Membrane time constant.
        lam: Leak rate.
        g: Elementwise map applied to the synaptic facilitation ``s``.
        b: Per-neuron input bias (broadcast against ``x``).
        phi: Elementwise nonlinearity applied to the voltages.
        threshold: Spike threshold (inclusive).
        reset_voltage: Voltage written into ``x`` for neurons that spiked.
        s: 2-D tensor whose row i weights the inputs ``phi(x_j)`` to neuron i.

    Returns:
        ``(dxdt, spikes)``. ``dxdt`` is the derivative, computed from the
        pre-reset voltages. ``spikes`` is a tensor like ``x`` with 1 where the
        neuron spiked and 0 elsewhere.
    """
    # Recurrent input for every neuron at once: row i sums g(s[i, j]) * phi(x[j]).
    # Everything is computed before the in-place reset below. A per-neuron loop
    # that reset as it went let higher-index neurons see a lower-index
    # spiker's reset voltage while lower-index ones saw its pre-reset voltage.
    recurrent_input = (g(s) * phi(x)).sum(dim=1)
    dxdt = (-lam * x + recurrent_input + b) / tau_n

    # Detect spikes on the same snapshot, then reset the spiking neurons in place.
    fired = x >= threshold
    spikes = fired.to(x.dtype)
    x.masked_fill_(fired, reset_voltage)
    return dxdt, spikes


def astrocyte_process_dynamics_stp(p, tau_p, gamma, T, eta, s, kappa, d_ij, spikes):
    """Return dp/dt for the short-term (STP) astrocyte process state.

    dp/dt = (-gamma * p + tensordot(T, kappa(p), dims=2) * spikes + d_ij) / tau_p

    The contraction sums ``T`` against ``kappa(p)`` over ``T``'s last two
    axes. The result is multiplied elementwise by ``spikes``. With a 1-D
    ``spikes`` vector this broadcasts along the last axis, so column b is
    scaled by ``spikes[b]``.

    ``eta`` and ``s`` are accepted but unused; the signature matches
    ``astrocyte_process_dynamics_ltp``.
    """
    decay = -gamma * p
    spatial_input = torch.tensordot(T, kappa(p), dims=2)
    gated_input = spatial_input * spikes
    return (decay + gated_input + d_ij) / tau_p


def astrocyte_process_dynamics_ltp(p, tau_p, gamma, T, eta, s, kappa, d_ij, spikes):
    """Return dp/dt for the long-term (LTP) astrocyte process state.

    dp/dt = (-gamma * p + eta(s) + d_ij) / tau_p

    The process is driven by the synaptic facilitation ``s`` through ``eta``.
    ``T``, ``kappa`` and ``spikes`` are accepted but unused; the signature
    matches ``astrocyte_process_dynamics_stp``.
    """
    influx = eta(s)
    decay = -gamma * p
    return (decay + influx + d_ij) / tau_p


# ---- Simulation size and timing ----
N = 3  # number of neurons
seg = 6  # number of equal-length segments the run is divided into
timescale = 300  # total simulated time
dt = 0.04  # Euler step size

# Spatial coupling between astrocyte processes, decaying with grid distance.
T = initialize_T(N, scale=2)
# Alternative without spatial structure: T = 0.1 * torch.ones(N, N, N, N, device=device)

# ---- Neuron parameters (tensors are placed on `device`, GPU or CPU) ----
tau_n = torch.tensor(0.5, device=device)  # membrane time constant
lam = torch.tensor(0.2, device=device)  # leak rate
b = torch.zeros(N, device=device)  # constant input bias per neuron (currently zero)
g = torch.relu  # maps synaptic facilitation s_ij to a non-negative weight
threshold = 1  # a neuron spikes when its voltage reaches this value
reset_voltage = -1  # voltage a neuron is set to after a spike
phi = torch.sigmoid  # nonlinearity applied to the voltages x_j feeding each neuron

# ---- Synaptic facilitation parameters (tensors are placed on `device`) ----
tau_s = torch.tensor(0.75, device=device)  # time constant of the synaptic dynamics
alpha = torch.tensor(0.25, device=device)  # leak rate of synaptic facilitation
phi_i = phi_j = torch.tanh  # nonlinearities whose outer product drives s
psi_ij = torch.tanh  # maps the astrocyte process state p (intracellular Ca2+) onto s
c_ij = torch.zeros(N, N, device=device)  # baseline bias of synaptic facilitation

# ---- Astrocyte process parameters (short- and long-term) ----
# The two time constants are integer tensors; dividing a float tensor by them
# still produces a float result.
tau_p_stp = torch.tensor(1, device=device)  # time scale of the STP process
tau_p_ltp = torch.tensor(6, device=device)  # time scale of the LTP process
gamma_stp = torch.tensor(0.2, device=device)  # decay rate of process Ca2+ in STP
gamma_ltp = torch.tensor(0.1, device=device)  # decay rate of process Ca2+ in LTP
eta = torch.tanh  # Ca2+ influx from the synaptic cleft, as a function of s
kappa = torch.sigmoid  # nonlinearity applied to the astrocyte process state
# Constant bias ("tone") of the astrocyte processes (neuromodulation), currently
# scaled to zero. The random draw is kept on purpose: it advances the global RNG,
# so the initial conditions drawn after it stay the same as before.
d_ij = 0.0 * torch.randn(N, N, device=device)

# ---- Time grid ----
t_span = torch.tensor([0.0, timescale], device=device)  # simulation start and end time
reset_time = timescale / seg  # length of one segment
# round() rather than int(): truncating a quotient such as 7499.9999... would
# silently drop the last step.
num_steps = round(float(t_span[1] - t_span[0]) / dt)

# ---- Initial conditions ----
# No requires_grad: nothing differentiates through the simulation. Tracking
# gradients here would make autograd record a graph over every Euler step.
x0 = 0.1 * torch.randn(N, device=device)  # initial membrane voltages
s0 = 0.1 * torch.randn(N, N, device=device)  # initial level of synaptic facilitation
p0 = 0.1 * torch.randn(N, N, device=device)  # initial intracellular Ca2+ concentration of the processes

# Integrate synaptic, neural, and astrocyte process dynamics with forward Euler.
# The whole integration runs under torch.no_grad(). Nothing here is
# differentiated, and with autograd enabled a requires_grad initial state
# would make PyTorch record a graph spanning every step, so memory would grow
# with num_steps.
with torch.no_grad():
    # Trajectory buffers: row k holds the state at step k (time t_span[0] + k * dt).
    s_solution = torch.zeros((num_steps, N, N), device=device)
    x_solution = torch.zeros((num_steps, N), device=device)
    # Every row is written below (row 0 from p0, the others by the integrator),
    # so the buffers do not need random initial contents.
    p_solution_stp = torch.zeros((num_steps, N, N), device=device)
    p_solution_ltp = torch.zeros((num_steps, N, N), device=device)
    spike_counts = torch.zeros(N, device=device)  # spike counts for each neuron

    s_solution[0] = s0
    x_solution[0] = x0
    p_solution_stp[0] = p0
    p_solution_ltp[0] = p0

    # Segment boundaries t = k * reset_time (k = 1 .. seg - 1) as integer step
    # indices. Testing the float time for exact membership in a float tensor of
    # boundaries only worked when i * dt happened to round exactly onto them.
    steps_per_segment = max(1, round(reset_time / dt))

    for i in range(1, num_steps):
        t = t_span[0] + i * dt

        segment, offset = divmod(i, steps_per_segment)
        if offset == 0 and 1 <= segment < seg:
            # New segment: re-draw the short-term state (s, STP process, voltages)
            # from a seed derived from the boundary time, overwriting row i - 1.
            # The LTP process is deliberately carried over unchanged.
            reset_seed = int(segment * reset_time)
            torch.manual_seed(reset_seed)
            torch.cuda.manual_seed(reset_seed)
            s_solution[i - 1] = 0.1 * torch.randn(N, N, device=device)
            p_solution_stp[i - 1] = 0.1 * torch.randn(N, N, device=device)
            x_solution[i - 1] = 0.1 * torch.randn(N, device=device)

        # Synaptic dynamics. Only the STP process feeds psi_ij here; to also
        # couple the LTP process, pass p_solution_stp[i - 1] + p_solution_ltp[i - 1].
        dsdt = synaptic_dynamics(s_solution[i - 1], tau_s, alpha, phi_i(x_solution[i - 1]), phi_j(x_solution[i - 1]),
                                 psi_ij, c_ij, p_solution_stp[i - 1])
        s_solution[i] = s_solution[i - 1] + dsdt * dt

        # Neural dynamics. neural_dynamics_torch resets spiking entries of
        # x_solution[i - 1] in place, so the update below starts them from
        # reset_voltage.
        dxdt, spikes = neural_dynamics_torch(x_solution[i - 1], t, tau_n, lam, g, b, phi, threshold,
                                             reset_voltage, s_solution[i])
        x_solution[i] = x_solution[i - 1] + dxdt * dt
        spike_counts += spikes

        # Astrocyte process dynamics (STP): spike-gated spatial coupling through T.
        dpdt_stp = astrocyte_process_dynamics_stp(p_solution_stp[i - 1], tau_p_stp, gamma_stp, T, eta, s_solution[i],
                                                  kappa, d_ij, spikes)
        p_solution_stp[i] = p_solution_stp[i - 1] + dpdt_stp * dt

        # Astrocyte process dynamics (LTP): driven by the synaptic facilitation.
        dpdt_ltp = astrocyte_process_dynamics_ltp(p_solution_ltp[i - 1], tau_p_ltp, gamma_ltp, T, eta, s_solution[i],
                                                  kappa, d_ij, spikes)
        p_solution_ltp[i] = p_solution_ltp[i - 1] + dpdt_ltp * dt

# Average Spiking Rate (ASR) per neuron: spike count divided by the simulated duration.
ASR = spike_counts / (t_span[1] - t_span[0])

# Sample times of the stored trajectories. Row k of every *_solution tensor
# holds the state at t_span[0] + k * dt. torch.linspace(t_span[0], t_span[1],
# num_steps) would stretch that grid so that the last sample lands on t_span[1].
time_axis = float(t_span[0]) + dt * torch.arange(num_steps, dtype=torch.float64)

plt.figure(figsize=(10, 6))

# Synaptic facilitation dynamics
plt.subplot(3, 2, 1)
for i in range(N):
    for j in range(N):
        plt.plot(time_axis, s_solution[:, i, j].cpu().detach(), label=f'Synapse ({i}, {j})')
plt.title('Synaptic Facilitation Dynamics')
plt.xlabel('Time')
plt.ylabel('Synaptic Facilitation')

# Membrane voltage dynamics
plt.subplot(3, 2, 2)
for i in range(N):
    plt.plot(time_axis, x_solution[:, i].cpu().detach(), label=f'Neuron {i}')
plt.title('Membrane Voltage Dynamics of Neurons')
plt.xlabel('Time')
plt.ylabel('Membrane Voltage')
plt.legend()

# STP astrocyte process dynamics
plt.subplot(3, 2, 3)
for i in range(N):
    for j in range(N):
        plt.plot(time_axis, p_solution_stp[:, i, j].cpu().detach(), label=f'Process ({i}, {j})')
plt.title('STP Astrocyte Processes')
plt.xlabel('Time')
plt.ylabel('STP Response')

# LTP astrocyte process dynamics
plt.subplot(3, 2, 4)
for i in range(N):
    for j in range(N):
        plt.plot(time_axis, p_solution_ltp[:, i, j].cpu().detach(), label=f'Process ({i}, {j})')
plt.title('LTP Astrocyte Processes')
plt.xlabel('Time')
plt.ylabel('LTP Response')

# Spike counts
plt.subplot(3, 2, 5)
plt.bar(torch.arange(N).cpu(), spike_counts.cpu().detach(), label='Spike Counts')
plt.title('Spike Counts of Neurons')
plt.xlabel('Neuron Index')
plt.ylabel('Spike Count')
# Label every neuron when N < 20 and thin the ticks for larger N. A fixed step
# of 10 would leave only neuron 0 labelled whenever N < 10.
plt.xticks(range(0, N, max(1, N // 10)))
# The bars are counts, so the reference line is the mean count, not a rate.
plt.axhline(y=spike_counts.mean().item(), color='r', linestyle='--', label='Mean Spike Count')
plt.legend()

# Five subplots in a 3x2 grid on a 10x6 figure overlap without this.
plt.tight_layout()
plt.show()

# NOTE(review): the triple-quoted string below is disabled plotting code kept as
# a no-op string expression; it is never executed. If re-enabled, it would draw
# s_ij (solid) and the LTP process p_ij^l (dashed) for every (i, j) on one
# figure with default colours. Consider deleting it.
'''# Combined Plot for Synaptic Facilitation and LTP Astrocyte Dynamics
plt.figure(figsize=(10, 6))
for i in range(N):
    for j in range(N):
        plt.plot(torch.linspace(t_span[0], t_span[1], num_steps, device=device).cpu(),
                 s_solution[:, i, j].cpu().detach(), linestyle='-', label=f'Synapse ({i}, {j})')
        plt.plot(torch.linspace(t_span[0], t_span[1], num_steps, device=device).cpu(),
                 p_solution_ltp[:, i, j].cpu().detach(), linestyle='--', label=f'LTP Astrocyte Process ({i}, {j})')
plt.title('Synaptic Facilitation and LTP Astrocyte Dynamics', fontsize=16)
plt.xlabel('Time', fontsize=14)
plt.ylabel('Dynamics', fontsize=14)
plt.legend(fontsize=10)
plt.grid(True)

plt.tight_layout()
plt.show()'''

import matplotlib.pyplot as plt

# Combined plot of synaptic facilitation s_ij and the LTP astrocyte processes p_ij^l.
plt.figure(figsize=(10, 6))

# Sample times of the stored trajectories (row k <-> t_span[0] + k * dt),
# computed once instead of once per curve.
time_axis = float(t_span[0]) + dt * torch.arange(num_steps, dtype=torch.float64)
t_start, t_end = float(t_span[0]), float(t_span[1])

# One viridis colour per (i, j) pair; s_ij and p_ij^l share it.
colors = plt.cm.viridis(torch.linspace(0, 1, N * N))

for i in range(N):
    for j in range(N):
        color = colors[i * N + j]

        # Synaptic facilitation s_ij (solid)
        plt.plot(time_axis, s_solution[:, i, j].cpu().detach(),
                 color=color, linestyle='-', linewidth=2, label=fr'$s_{{{i}{j}}}$')

        # LTP astrocyte process p_ij^l (dashed)
        plt.plot(time_axis, p_solution_ltp[:, i, j].cpu().detach(),
                 color=color, linestyle='--', linewidth=2, label=fr'$p_{{{i}{j}}}^l$')

# Titles and labels with bigger fonts
plt.title(r'Synaptic Facilitation $s_{ij}$ and LTP Astrocyte Dynamics $p_{ij}^l$', fontsize=18, fontweight='bold')
plt.xlabel('Time (s)', fontsize=14, fontweight='bold')
plt.ylabel('Arbitrary unit', fontsize=14, fontweight='bold')

# Bold, larger tick labels
plt.xticks(fontsize=12, fontweight='bold')
plt.yticks(fontsize=12, fontweight='bold')

plt.grid(True, linestyle='--', alpha=0.6)

# LTP spans the whole run: a double-headed arrow from t_start to t_end drawn
# above the curves at y = 9 (hand-tuned for the current data range).
plt.annotate('', xy=(t_end, 9), xytext=(t_start, 9),
             arrowprops=dict(arrowstyle='<->', color='red', linewidth=2, shrinkA=5, shrinkB=5))
plt.text((t_start + t_end) / 2, 9.2, 'LTP', fontsize=14, fontweight='bold', color='red', ha='center')

# STP: one arrow per reset segment of length reset_time, drawn below the
# curves at y = -0.5 (hand-tuned for the current data range).
for k in range(seg):
    start = t_start + k * reset_time
    end = start + reset_time
    plt.annotate('', xy=(end, -0.5), xytext=(start, -0.5),
                 arrowprops=dict(arrowstyle='<->', color='blue', linewidth=2, shrinkA=5, shrinkB=5))
    plt.text((start + end) / 2, -0.3, 'STP', fontsize=12, fontweight='bold', color='blue', ha='center')

# Legend anchored slightly outside the right edge of the axes
plt.legend(fontsize=10, loc='upper right', bbox_to_anchor=(1.1, 1))

plt.tight_layout()

# Save the figure as a PDF, then show it
plt.savefig('synaptic-and-ltp-astrocyte-dynamics-1.pdf', format='pdf')
plt.show()

import matplotlib.pyplot as plt

# Sample times of the stored trajectory: row k of x_solution is at t_span[0] + k * dt.
time_vector = float(t_span[0]) + dt * torch.arange(num_steps, dtype=torch.float64)

# Time window to zoom into, in the same units as t_span.
start_time = 26
end_time = 30
time_indices = (time_vector >= start_time) & (time_vector <= end_time)  # boolean mask over rows of x_solution

# Move the voltages to the CPU once, so the CPU mask indexes a CPU tensor.
voltages = x_solution.detach().cpu()

plt.figure(figsize=(10, 6))

# Plot each neuron's membrane voltage inside [start_time, end_time]
for i in range(N):
    plt.plot(time_vector[time_indices], voltages[time_indices, i], label=f'Neuron {i}')

# Titles and labels
plt.title('Neural Dynamics', fontsize=18, fontweight='bold')
plt.xlabel('Time (s)', fontsize=14, fontweight='bold')
plt.ylabel('Membrane Voltage', fontsize=14, fontweight='bold')

# Bold, larger tick labels
plt.xticks(fontsize=12, fontweight='bold')
plt.yticks(fontsize=12, fontweight='bold')

plt.grid(True, linestyle='--', alpha=0.6)
plt.legend(fontsize=10)
plt.tight_layout()

# NOTE(review): plt.show() is disabled and there is no savefig, so this code
# never displays or saves the zoomed figure. Re-enable one of them if the view
# is wanted.
# plt.show()
