#!/usr/bin/env python

from collections import deque
from itertools import combinations

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.gridspec import GridSpec

from scipy.spatial import Delaunay

from arena import BoxArena, CircleArena, TMaze, AMaze
from model import \
        Neuron, NeuronState, TagState, Simulation, Coord, Event, EventType, Message, \
        MessageType, SimulationRecorder

from plotting import make_contour_from_recorder


def remove_diagonal(M):
    """Drop the main diagonal of a square matrix.

    Row ``i`` of the result holds ``M[i, j]`` for every ``j != i``, in column
    order, so an ``(n, n)`` input yields an ``(n, n - 1)`` output.
    """
    n_rows = M.shape[0]
    off_diagonal = ~np.eye(n_rows, dtype=bool)
    kept = M[off_diagonal]
    return kept.reshape(n_rows, -1)


def pairwise_sqdist(xs0, ys0, xs1, ys1):
    """Return squared Euclidean distances between two point sets.

    Entry ``[i, j]`` is the squared distance between ``(xs0[i], ys0[i])`` and
    ``(xs1[j], ys1[j])``. Inputs must support ``[:, None]`` indexing (1-D
    NumPy arrays as used in this file).
    """
    dx = xs0[:, None] - xs1
    dy = ys0[:, None] - ys1
    return dx**2 + dy**2


def check_pairwise_distance(xs, ys, min_dist):
    """Flag points that keep at least ``min_dist`` from every other point.

    Returns a boolean array with one entry per point; entry ``i`` is True when
    the distance from point ``i`` to every *other* point is ``>= min_dist``.
    Squared distances are compared to avoid square roots.
    """
    dx = xs[:, np.newaxis] - xs
    dy = ys[:, np.newaxis] - ys
    # the zero self-distance on the diagonal must not count as a violation
    others_sq = remove_diagonal(dx**2 + dy**2)
    far_enough = others_sq >= min_dist**2
    return far_enough.all(axis=1)


def gen_random_coordinates(arena, N, min_dist=0.0):
    """Rejection-sample ``N`` points inside ``arena`` with minimum spacing.

    Each round draws as many candidates as are still missing, uniformly from
    ``arena.get_bbox()`` (``bbox[0]`` is the x range, ``bbox[1]`` the y range).
    A candidate is kept when ``arena.is_inside`` accepts it and it lies at
    least ``min_dist`` from every already-kept point and every other candidate
    of the same round. Two candidates that are too close to each other are both
    rejected. There is no iteration cap: if ``N`` suitably spaced points
    cannot be found, the loop does not terminate.

    Returns:
        ``(N, 2)`` array of accepted ``(x, y)`` coordinates.
    """
    bbox = arena.get_bbox()
    x_lo, x_hi = bbox[0][0], bbox[0][1]
    y_lo, y_hi = bbox[1][0], bbox[1][1]

    xs_kept = np.array([])
    ys_kept = np.array([])

    while (missing := N - len(xs_kept)) > 0:
        cand_x = np.random.uniform(low=x_lo, high=x_hi, size=(missing,))
        cand_y = np.random.uniform(low=y_lo, high=y_hi, size=(missing,))

        # spacing is checked over kept + new points; only the new rows matter
        n_kept = len(xs_kept)
        spaced = check_pairwise_distance(np.append(xs_kept, cand_x),
                                         np.append(ys_kept, cand_y),
                                         min_dist)[n_kept:]

        keep = arena.is_inside(cand_x, cand_y) & spaced

        xs_kept = np.append(xs_kept, cand_x[keep])
        ys_kept = np.append(ys_kept, cand_y[keep])

    return np.column_stack((xs_kept, ys_kept))


def get_transition_matrix(locs, center_dist, surround_dist):
    """Build a binary adjacency matrix from an annulus distance rule.

    Locations ``i`` and ``j`` are connected (``T[i, j] = 1``) when their
    Euclidean distance ``d`` satisfies ``center_dist < d < surround_dist``.
    The result is symmetric with a zero diagonal and has the dtype of the
    distance matrix.

    Args:
        locs: array of shape (N, 2) with x, y coordinates.
        center_dist: exclusive lower distance bound.
        surround_dist: exclusive upper distance bound.
    """
    xs, ys = locs[:, 0], locs[:, 1]
    dists = np.sqrt(pairwise_sqdist(xs, ys, xs, ys))

    in_ring = (dists > center_dist) & (dists < surround_dist)
    T = in_ring.astype(dists.dtype)

    # no self-transitions
    np.fill_diagonal(T, 0)
    return T


def delaunay_connectivity(coordinates, max_dist = None):
    """Create Delaunay triangulation connectivity.

    Every pair of vertices that share a simplex of the Delaunay triangulation
    becomes an undirected edge.

    Args:
        coordinates: array-like of shape (N, D) point coordinates (D=2 in this
            file; higher dimensions also work because all vertex pairs of each
            simplex are used).
        max_dist: if not None, edges longer than ``max_dist`` are dropped.
            ``None`` keeps every edge; ``0`` keeps only zero-length edges.

    Returns:
        Sorted list of unique ``(i, j)`` index pairs (plain ints) with ``i < j``.
    """
    coordinates = np.asarray(coordinates, dtype=float)
    tri = Delaunay(coordinates)

    edges = set()
    for simplex in tri.simplices:
        for a, b in combinations(simplex, 2):
            i, j = sorted((int(a), int(b)))

            # explicit None check: a cutoff of 0 must not mean "no cutoff"
            if max_dist is not None:
                d = np.sqrt(np.sum((coordinates[i] - coordinates[j])**2))
                if d > max_dist:
                    continue

            edges.add((i, j))

    # sorted for a deterministic edge order
    return sorted(edges)


def gabriel_graph(coordinates, max_dist = None):
    """Create Gabriel graph connectivity.

    Starting from the Delaunay edges, an edge ``(i, j)`` is kept when no other
    point lies strictly inside the ball whose diameter is the segment ``i-j``.
    Points exactly on that boundary do not disqualify the edge.

    Args:
        coordinates: array-like of shape (N, D) point coordinates (D=2 in this
            file).
        max_dist: if not None, edges longer than ``max_dist`` are also dropped.
            ``None`` disables the length filter.

    Returns:
        Sorted list of unique ``(i, j)`` index pairs (plain ints) with ``i < j``.
    """
    coordinates = np.asarray(coordinates, dtype=float)
    tri = Delaunay(coordinates)

    # every vertex pair of a simplex is a Delaunay edge
    delaunay_edges = set()
    for simplex in tri.simplices:
        for a, b in combinations(simplex, 2):
            i, j = sorted((int(a), int(b)))
            delaunay_edges.add((i, j))

    gabriel_edges = []
    for i, j in sorted(delaunay_edges):
        diff = coordinates[i] - coordinates[j]

        # length filter on the edge under test (i, j)
        if max_dist is not None and np.sqrt(np.sum(diff**2)) > max_dist:
            continue

        center = (coordinates[i] + coordinates[j]) / 2
        radius_sq = np.sum(diff**2) / 4

        # squared distance of every point to the center; the two endpoints
        # lie on the boundary by construction and are excluded
        dist_sq = np.sum((coordinates - center)**2, axis=1)
        dist_sq[[i, j]] = np.inf
        if not np.any(dist_sq < radius_sq):
            gabriel_edges.append((i, j))

    return gabriel_edges


def bfs_distances(T, source):
    """Hop counts from ``source`` to every vertex, via breadth-first search.

    ``T`` is read as an adjacency matrix: there is an edge ``u -> v`` whenever
    ``T[u, v] > 0``. Returns a float array; vertices not reachable from
    ``source`` keep the value ``inf``.
    """
    hops = np.full(T.shape[0], np.inf)
    hops[source] = 0
    frontier = deque((source,))

    while frontier:
        node = frontier.popleft()
        for nbr in np.flatnonzero(T[node] > 0):
            # an infinite hop count means the vertex has not been reached yet
            if np.isinf(hops[nbr]):
                hops[nbr] = hops[node] + 1
                frontier.append(nbr)
    return hops


def shortest_path_vertices(T, start, end):
    """
    Return the vertices lying on *some* shortest path from start to end.

    Uses the unweighted graph ``T`` (edge ``u -> v`` when ``T[u, v] > 0``). A
    vertex ``v`` is on a shortest path iff ``d(start, v) + d(v, end)`` equals
    ``d(start, end)``.

    Returns:
        List of vertex indices (ascending, plain ints, including ``start`` and
        ``end``). Returns an empty list when ``end`` is unreachable, so the
        return type is always a list.
    """
    T = np.asarray(T)
    dist_from_start = bfs_distances(T, start)
    # BFS on the reversed graph gives each vertex's distance *to* end. For a
    # symmetric T this equals bfs_distances(T, end); for directed graphs the
    # transpose is required.
    dist_to_end = bfs_distances(T.T, end)
    shortest_length = dist_from_start[end]

    if np.isinf(shortest_length):
        return []  # no path exists

    on_path = np.flatnonzero(dist_from_start + dist_to_end == shortest_length)
    return [int(v) for v in on_path]


def match_quality(ground_truth, predicted):
    """
    Score a predicted element collection against a ground-truth collection.

    Both inputs are reduced to sets, so duplicates and order are ignored. Any
    ratio whose denominator is zero is reported as ``0.0``.

    Returns:
        dict with keys "precision", "recall", "f1" and "jaccard".
    """
    truth = set(ground_truth)
    guess = set(predicted)
    hits = len(truth & guess)
    n_union = len(truth | guess)

    def _ratio(num, den):
        return num / den if den else 0.0

    precision = _ratio(hits, len(guess))
    recall    = _ratio(hits, len(truth))
    f1        = _ratio(2 * precision * recall, precision + recall)
    jaccard   = _ratio(hits, n_union)

    return dict(precision=precision, recall=recall, f1=f1, jaccard=jaccard)

def edges_to_transition_matrix(edges, n_nodes):
    """
    Turn an undirected edge list into a dense 0/1 adjacency matrix.

    Parameters:
    -----------
    edges : iterable of (i, j) pairs
        Edges, e.g. from delaunay_connectivity() or gabriel_graph().
    n_nodes : int
        Number of nodes, i.e. len(coordinates).

    Returns:
    --------
    T : numpy.ndarray
        ``(n_nodes, n_nodes)`` float matrix. Both ``T[i, j]`` and ``T[j, i]``
        are 1.0 for every edge, all other entries are 0.0, and the diagonal
        is forced to 0 (self-loops are discarded).
    """
    T = np.zeros((n_nodes, n_nodes))

    for src, dst in edges:
        # one fancy-index write sets both orientations (undirected graph)
        T[(src, dst), (dst, src)] = 1.0

    np.fill_diagonal(T, 0.0)
    return T


def plot_graph_edges(ax, coordinates, edges, color='blue', alpha=0.8, linewidth=1, label=None, show_points=True, point_size=50, point_color='red'):
    """
    Draw each graph edge as a straight segment on ``ax``.

    Parameters:
    -----------
    ax : matplotlib.axes.Axes
        Axis to draw on.
    coordinates : numpy.ndarray
        (N, 2) array of x, y coordinates.
    edges : iterable of (i, j) pairs
        Edges to draw, in order (one ``ax.plot`` call per edge).
    color, alpha, linewidth :
        Line style applied to every edge.
    label : str or None
        Legend label. Only the first segment receives it; every other segment
        (and all segments when ``label`` is None) gets an empty label.
    show_points : bool
        If True, scatter all coordinates afterwards (``zorder=5``).
    point_size, point_color :
        Marker size and color for the scattered points.
    """
    line_style = dict(color=color, alpha=alpha, linewidth=linewidth)

    for n, (a, b) in enumerate(edges):
        seg_label = "" if (n > 0 or label is None) else label
        xs = [coordinates[a, 0], coordinates[b, 0]]
        ys = [coordinates[a, 1], coordinates[b, 1]]
        ax.plot(xs, ys, label=seg_label, **line_style)

    if not show_points:
        return
    ax.scatter(coordinates[:, 0], coordinates[:, 1],
               c=point_color, s=point_size, zorder=5)


def plot_connectivity_scatter(ax, coords, T, selected_idx=0, figsize=(8, 6), xlims=None, ylims=None, xpad=0.1, ypad=0.1):
    """
    Scatter all locations on ``ax``, color-coded by connectivity from one location.

    The selected location is drawn red (size 80). Locations it connects to
    (``T[selected_idx, j] > 0``) are drawn blue (size 40). All others are drawn
    light gray (size 20).

    Args:
        ax: matplotlib Axes to draw on.
        coords: array of shape (N, 2) with x, y coordinates.
        T: transition matrix of shape (N, N); only row ``selected_idx`` is used.
        selected_idx: index of the location to show connections from.
        figsize: unused; kept for signature compatibility with the other
            ``plot_connectivity_*`` helpers, which create their own figure.
        xlims, ylims: optional ``(min, max)`` data ranges; default to the
            min/max of ``coords``.
        xpad, ypad: padding added on each side, as a fraction of the range.
    """
    x = coords[:, 0]
    y = coords[:, 1]

    x_min, x_max = (np.min(x), np.max(x)) if xlims is None else xlims
    y_min, y_max = (np.min(y), np.max(y)) if ylims is None else ylims

    # padding around the data, relative to the displayed range
    x_pad = (x_max - x_min) * xpad
    y_pad = (y_max - y_min) * ypad

    connected_mask = T[selected_idx, :] > 0

    # red for selected, blue for connected, gray for unconnected
    colors = np.full(len(coords), 'lightgray', dtype=object)
    colors[connected_mask] = 'blue'
    colors[selected_idx] = 'red'

    # larger markers for the selected and connected locations
    sizes = np.full(len(coords), 20)
    sizes[connected_mask] = 40
    sizes[selected_idx] = 80

    ax.scatter(x, y, c=colors, s=sizes, alpha=0.7, edgecolors='black', linewidths=0.5)
    ax.set_xlim(x_min - x_pad, x_max + x_pad)
    ax.set_ylim(y_min - y_pad, y_max + y_pad)

    ax.set_title(f'Connectivity from Location {selected_idx} (red)\nBlue: Connected, Gray: Not connected')
    ax.set_aspect('equal')


def plot_connectivity_with_lines(locs, T, selected_idx=0, figsize=(8, 6)):
    """
    Draw a new figure showing one location's connections as line segments.

    All locations are drawn faint gray. The locations connected from
    ``selected_idx`` (``T[selected_idx, j] > 0``) are drawn blue, each joined
    to the selected location by a thin blue line. The selected location itself
    is drawn red on top.

    Args:
        locs: array of shape (N, 2) with x, y coordinates
        T: transition matrix of shape (N, N)
        selected_idx: index of the location to show connections from
        figsize: figure size tuple

    Returns:
        (fig, ax) of the newly created figure.
    """
    fig, ax = plt.subplots(figsize=figsize, layout='constrained')

    neighbours = np.where(T[selected_idx, :] > 0)[0]
    sel_x, sel_y = locs[selected_idx, 0], locs[selected_idx, 1]

    # background layer: every location
    ax.scatter(locs[:, 0], locs[:, 1], c='lightgray', s=30, alpha=0.5, zorder=1)

    if neighbours.size:
        ax.scatter(locs[neighbours, 0], locs[neighbours, 1],
                   c='blue', s=50, alpha=0.8, zorder=2)
        for nb in neighbours:
            ax.plot([sel_x, locs[nb, 0]], [sel_y, locs[nb, 1]],
                    'blue', alpha=0.3, linewidth=1, zorder=0)

    # selected location drawn last, above everything else
    ax.scatter(sel_x, sel_y,
               c='red', s=100, alpha=1.0, zorder=3, edgecolors='black', linewidths=1)

    ax.set_title(f'Connectivity from Location {selected_idx}\nRed: Selected, Blue: Connected (with lines)')
    ax.set_aspect('equal')
    return fig, ax


def plot_connectivity_heatmap(locs, T, selected_idx=0, figsize=(8, 6)):
    """
    Draw a new figure coloring each location by its connection weight.

    Each location ``j`` is colored by ``T[selected_idx, j]`` on the 'Blues'
    colormap, with a labelled colorbar. The selected location is overdrawn in
    red. This is useful when ``T`` holds weighted rather than 0/1 connections.

    Args:
        locs: array of shape (N, 2) with x, y coordinates
        T: transition matrix of shape (N, N)
        selected_idx: index of the location to show connections from
        figsize: figure size tuple

    Returns:
        (fig, ax) of the newly created figure.
    """
    fig, ax = plt.subplots(figsize=figsize, layout='constrained')

    strengths = T[selected_idx, :]
    mappable = ax.scatter(locs[:, 0], locs[:, 1], c=strengths,
                          s=50, cmap='Blues', alpha=0.8, edgecolors='black', linewidths=0.5)

    ax.scatter(locs[selected_idx, 0], locs[selected_idx, 1],
               c='red', s=100, alpha=1.0, zorder=3, edgecolors='black', linewidths=1)

    colorbar = fig.colorbar(mappable, ax=ax)
    colorbar.set_label('Connection Strength')

    ax.set_title(f'Connectivity Heatmap from Location {selected_idx} (red)')
    ax.set_aspect('equal')
    return fig, ax



def pick_random(coords, unit_position, k = 10):
    """Pick a random point among the ``k`` nearest to a normalized position.

    ``unit_position`` is given in the unit square spanned by the bounding box
    of ``coords``: ``[0, 0]`` maps to (min x, min y) and ``[1, 1]`` to
    (max x, max y). One of the ``min(k, N)`` closest points is chosen
    uniformly with ``np.random.choice``, so the result depends on the global
    NumPy RNG state.

    Returns:
        ``(point, index)``: the chosen coordinate row and its index.
    """
    pts = np.array(coords)
    xs, ys = pts[:, 0], pts[:, 1]

    target_x = xs.min() + unit_position[0] * (xs.max() - xs.min())
    target_y = ys.min() + unit_position[1] * (ys.max() - ys.min())

    dist = np.sqrt((xs - target_x)**2 + (ys - target_y)**2)

    n_candidates = min(k, len(pts))
    nearest = np.argsort(dist)[:n_candidates]
    chosen = np.random.choice(nearest)

    return pts[chosen], chosen


#def angle_between_points(center, point1, point2, degrees=True):
#    """
#    Compute the angle between two points relative to a central point.
#
#    Parameters:
#    -----------
#    center : array-like, shape (2,)
#        Coordinates of the central point
#    point1 : array-like, shape (2,)
#        Coordinates of the first point
#    point2 : array-like, shape (2,)
#        Coordinates of the second point
#    degrees : bool, default True
#        If True, return angle in degrees; if False, return in radians
#
#    Returns:
#    --------
#    angle : float
#        Angle between the two vectors (0 to 180 degrees/π radians)
#    """
#
#    center = np.array(center)
#    point1 = np.array(point1)
#    point2 = np.array(point2)
#
#    # Compute vectors from center to each point
#    vector1 = point1 - center
#    vector2 = point2 - center
#
#    # Compute dot product
#    dot_product = np.dot(vector1, vector2)
#
#    # Compute magnitudes
#    magnitude1 = np.linalg.norm(vector1)
#    magnitude2 = np.linalg.norm(vector2)
#
#    # Handle edge cases
#    if magnitude1 == 0 or magnitude2 == 0:
#        return 0.0  # If any vector has zero length, angle is undefined/zero
#
#    # Compute cosine of angle
#    cos_angle = dot_product / (magnitude1 * magnitude2)
#
#    # Clamp to valid range to handle floating point errors
#    cos_angle = np.clip(cos_angle, -1.0, 1.0)
#
#    # Compute angle
#    angle_rad = np.arccos(cos_angle)
#
#    if degrees:
#        return np.degrees(angle_rad)
#    else:
#        return angle_rad


def main(arena='amaze', transition_mode = 'tss', figsize = (13, 3), **kwargs):
    match arena:
        case 'box':    arena = BoxArena([-1, 1], [-1, 1])
        case 'circle': arena = CircleArena([0.0, 0.0], 1.0)
        case 'tmaze':  arena = TMaze()
        case 'amaze':  arena = AMaze()
        case _:        raise ValueError(f"Unknown arena type '{arena}'")

    n_neurons     = 1000
    min_dist      = 0.01
    center_dist   = 0.05
    surround_dist = 0.15

    # get random locations, required to build a transition matrix, i.e. local
    # connectivity from place cell to place cell. This is effectively what the
    # TSS says would be represented by grid cells.
    coords = gen_random_coordinates(arena, n_neurons, min_dist)
    match transition_mode:
        case 'tss':
            T = get_transition_matrix(coords, center_dist, surround_dist)

        case 'delaunay':
            edges = delaunay_connectivity(coords, 2 * surround_dist)
            T = edges_to_transition_matrix(edges, n_neurons).T

        case 'gabriel':
            edges = gabriel_graph(coords, surround_dist)
            T = edges_to_transition_matrix(edges, n_neurons)

        case _:
            raise ValueError(f"Unknown transition mode {transition_mode}")


    # general neuron setup
    use_local_inhibition  = kwargs.get('use_local_inhibition',  False)
    use_global_inhibition = kwargs.get('use_global_inhibition', True)

    # construct neurons, and establish connectivity (both excitatory and
    # inhibitory

    neurons = [Neuron(i,
                      coord=Coord(coords[i,0], coords[i,1]),
                      use_local_inhibition=use_local_inhibition,
                      use_global_inhibition=use_global_inhibition) for i in range(n_neurons)]
    for i in range(n_neurons):
        for j in range(n_neurons):
            if j == i:
                continue
            if T[i, j] > 0:
                neurons[i].nbrs.append(j)
                neurons[i].nbrs_inh_local.append(j)

        neurons[i].nbrs_inh_global = list(range(n_neurons))
        neurons[i].nbrs_inh_local.append(i)

        neurons[i].nbrs = list(set(neurons[i].nbrs))
        neurons[i].nbrs_inh_local = list(set(neurons[i].nbrs_inh_local))
        neurons[i].nbrs_inh_global = list(set(neurons[i].nbrs_inh_global))

    n_targets = kwargs.get('n_targets', 1)

    # default targets
    _, start_idx  = pick_random(coords, [0.1, 0.1])
    _, target_idx0 = pick_random(coords, [0.9, 0.9])
    targets = [target_idx0]

    if n_targets >= 2:
        _, target_idx1 = pick_random(coords, [0.9, 0.2])
        targets.append(target_idx1)

    if n_targets >= 3:
        _, target_idx2 = pick_random(coords, [0.3, 0.78])
        targets.append(target_idx2)


    # NOTE: manually tag the target neuron. we hypothesize that this could be
    # achieved via PFC interactions
    # neurons[target_idx0].tag = TagState.TAGGED
    for t in targets:
        neurons[t].tag = TagState.TAGGED

    sim = Simulation()

    # prevent infinite simulation in case of rebound activity
    max_time = 10000.0

    max_rounds = 71
    n_rounds = 0
    recorders               = []
    neuron_last_spike_times = []
    neuron_tags             = []
    target_spike_times      = []
    while not neurons[start_idx].is_tagged and n_rounds < max_rounds:
        # we'll let the selected neuron fire after 10 ms of the simulation

        sim.schedule_event(Event(start_idx, EventType.RECV_MSG, 10.0, Message(-1, MessageType.EXCITATORY)))
        sim.timestamp = 0

        target_stimes = {idx: [] for idx in targets}

        # reset all neurons
        for n in neurons:
            n.state                 = NeuronState.RESTING
            n.last_spike_time       = None
            n.expected_inh_feedback = None
            n.expected_exc_feedback = None
            n.received_inh_feedback = None
            n.received_exc_feedback = None


        # TODO: while start-not-tagged or maxtime not exceeded
        while (event := sim.next_event()) and sim.timestamp <= max_time:
            # update the simulation time to the next event time
            sim.update_timestamp(event.timestamp)

            # each event is assigned to a neuron, so process that neuron
            n = neurons[event.neuron_id]
            n.handle_event(event, sim)

            # check if this is the goal neuron and if yes, initiate global
            # inhibition to all neurons, and reschedule the algorithm for the next
            # round
            if n.id in targets and n.state == NeuronState.SPIKING:
                print(f"Target (index {n.id}, coord {coords[n.id]}) active {sim.timestamp}ms after iteration start")
                target_stimes[n.id].append(sim.timestamp)

        n_rounds += 1

        # swap out the recorder for the next iteration
        recorders.append(sim.recorder)
        sim.recorder = SimulationRecorder()

        # also record neuron state information that's relevant
        neuron_last_spike_times.append([n.last_spike_time for n in neurons])
        neuron_tags.append([n.tag for n in neurons])
        target_spike_times.append(target_stimes)


    # ground truth: Dijkstra
    gt_neurons = shortest_path_vertices(T, start_idx, targets[0])
    print("Ground truth/Dijkstra neurons on shortest path to target 0: \n", gt_neurons)
    tagged_neurons = [n.id for n in neurons if n.tag == TagState.TAGGED]
    print("Algorithm neurons on shortest path (NOTE: only works for single target evaluation): \n", tagged_neurons)
    print("Match quality", match_quality(gt_neurons, tagged_neurons))


    max_plots = 5
    n_plots   = np.unique(np.round(np.linspace(0, len(recorders) - 1, max_plots))).astype(int)

    fig    = plt.figure(figsize=figsize, layout='constrained')
    gspec  = GridSpec(1, len(n_plots), figure=fig)


    for k, idx in enumerate(n_plots):

        #                start    goal    normal    tagged, not active
        colors        = ['blue', 'green', 'orange', 'red', 'gray']
        alphas        = np.empty(n_neurons)
        color_indices = np.empty(n_neurons, dtype=int)
        for i in range(n_neurons):
            # standard color behavior -> there's a neuron
            color_indices[i] = 4
            alphas[i] = 0.15

            # if this neuron was active, great, then we'll show it differently
            if neuron_last_spike_times[idx][i]:
                color_indices[i] = 2
                alphas[i] = 0.45

            # tagged colors
            if neuron_tags[idx][i] == TagState.TAGGED: #and neurons[i].last_spike_time:
                color_indices[i] = 3


        ax = fig.add_subplot(gspec[k])
        #make_contour_from_recorder2(ax, neurons, recorders[idx], coords, nlevels='auto', method='rbf_gaussian', min_levels=n_rounds, fill_in_times=True, start_idx=start_idx)
        make_contour_from_recorder(ax, neurons, recorders[idx], coords, nlevels=10, method='rbf_gaussian', min_levels=n_rounds, fill_in_times=True, start_idx=start_idx)


        # figure out which target was the actual (winning) target
        tid = -1
        k = 0
        while k < len(targets):
            if len(target_spike_times[idx][targets[k]]) > 0:
                tid = k
                break
            k = k + 1
        if tid >= 0 and tid < len(targets):
            tst = target_spike_times[idx][targets[tid]][0]
            ax.set_title(f"Iteration {idx+1}, {tst}ms TTT")
        ax.scatter(coords[:, 0], coords[:, 1], color=np.array(colors)[color_indices], alpha=alphas, s=15, linewidth=1.5, zorder=10)
        # start and stop indicators
        patch = Rectangle(tuple(coords[start_idx, :] - 0.05), 0.1, 0.1, lw=2, edgecolor='blue', facecolor='None', zorder=5)
        ax.add_patch(patch)
        for tidx in targets:
            patch = Rectangle(tuple(coords[tidx, :] - 0.05), 0.1, 0.1, lw=2, edgecolor='beige', facecolor='None', zorder=5)
            ax.add_patch(patch)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_xticks([])
        ax.set_yticks([])


    # connectivity plot and field distribution
    fig    = plt.figure(figsize=(5, 3), layout='constrained')
    gspec  = GridSpec(1, 1, figure=fig)
    ax     = fig.add_subplot(gspec[-1])
    plot_connectivity_scatter(ax, coords, T, start_idx)
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_xticks([])
    ax.set_yticks([])

    plt.show()


if __name__ == "__main__":
    # Seed the global NumPy RNG so that the figures are reproducible.
    np.random.seed(71)

    # Paper, Figure 2 (main results) ###########################################
    for arena_name, fig_size in (('box', (13, 3)), ('amaze', (13, 3.5))):
        main(arena_name, 'tss', fig_size)

    # Paper, Additional Results (Appendix) #####################################
    # Re-seed so the appendix figures do not depend on the runs above.
    np.random.seed(71)

    # other environments
    for arena_name in ('circle', 'tmaze'):
        main(arena_name, 'tss', (13, 3))

    # inhibition ablations and multi-target runs, all in the A-maze
    appendix_runs = (
        # local inhibition only
        dict(use_local_inhibition=True, use_global_inhibition=False),
        # no inhibition -> algorithm fails
        dict(use_local_inhibition=False, use_global_inhibition=False),
        # two targets, only local inhibition -> multiple candidate
        # trajectories still available
        dict(n_targets=2, use_global_inhibition=False, use_local_inhibition=True),
        # two targets, full inhibition -> only one solution
        dict(n_targets=2, use_global_inhibition=True, use_local_inhibition=True),
        # three targets, local inhibition only -> two competing solutions remain
        dict(n_targets=3, use_global_inhibition=False, use_local_inhibition=True),
    )
    for run_kwargs in appendix_runs:
        main('amaze', 'tss', (13, 3.5), **run_kwargs)
