#!/usr/bin/env python

import heapq
from typing import List, Optional, Any
from dataclasses import dataclass, field
from enum import Enum


class NeuronState(Enum):
    """States of a neuron's state machine (driven by ``Neuron.handle_state_change``)."""

    RESTING = "resting"
    INHIBITED = "inhibited"
    PROCESSING = "processing"
    REFRACTORY = "refractory"
    SPIKING = "spiking"

class TagState(Enum):
    """Whether a neuron has been tagged by the early-feedback mechanism."""

    NORMAL = "normal"
    TAGGED = "tagged"

class MessageType(Enum):
    """Kind of message a neuron sends to another neuron."""

    INHIBITORY = "inhibitory"
    EXCITATORY = "excitatory"

@dataclass
class Message:
    """Payload of a RECV_MSG event."""

    source: int         # id of the sending neuron
    type: MessageType   # excitatory or inhibitory


class EventType(Enum):
    """Kinds of simulation events dispatched by ``Neuron.handle_event``."""

    RECV_MSG = "message_received"
    STATE_CHANGE = "state_change"


@dataclass
class Event:
    """A scheduled simulation event addressed to a single neuron.

    The meaning of ``data`` depends on ``type``. The Neuron code in this module
    stores a ``Message`` for RECV_MSG and a ``{'from': ..., 'to': ...}`` dict
    for STATE_CHANGE.
    """

    neuron_id: int           # neuron that handles the event
    type: EventType
    timestamp: float         # simulation time at which the event fires
    data: Optional[Any]

    def __lt__(self, other):
        # Events are ordered by time only; heapq relies on this ``<``.
        # Equal timestamps are not tie-broken, so heapq does not necessarily
        # pop simultaneous events in the order they were scheduled.
        mine, theirs = self.timestamp, other.timestamp
        return mine < theirs


# NOTE: keep simulation event types (EventType) separate from record event
# types (RecordEventType).
# TODO: at a later stage, consider merging them if suitable.
class RecordEventType(Enum):
    """Kinds of entries written to the ``SimulationRecorder``."""

    STATE_CHANGE = "state_change"
    SPIKE = "spike"


@dataclass
class SimulationRecord:
    """One logged entry produced by ``SimulationRecorder.record_event``."""

    timestamp: float
    neuron_id: int
    event_type: RecordEventType
    data: dict   # the extra keyword arguments given to record_event


class SimulationRecorder:
    """Append-only log of ``SimulationRecord`` entries plus simple queries on it."""

    def __init__(self):
        self.records = []

    def record_event(self, timestamp, neuron_id, event_type, **kwargs) -> None:
        """Append a record; any keyword arguments become the record's ``data`` dict."""
        self.records.append(SimulationRecord(timestamp, neuron_id, event_type, kwargs))

    def _first_times(self, predicate) -> dict:
        """Map each neuron id to the earliest timestamp of a record matching ``predicate``.

        Records are visited in timestamp order. ``sorted`` is stable, so among
        equal timestamps the earliest-recorded entry wins. The returned dict is
        ordered by first-match time.
        """
        first = {}
        for record in sorted(self.records, key=lambda r: r.timestamp):
            if predicate(record):
                first.setdefault(record.neuron_id, record.timestamp)
        return first

    def get_first_activation_times(self, event_filter=RecordEventType.SPIKE):
        """Extract first activation time for each neuron.

        Returns ``{neuron_id: timestamp}`` of the earliest record whose
        ``event_type`` equals ``event_filter`` (spikes by default).
        """
        return self._first_times(lambda r: r.event_type == event_filter)

    def get_state_change_times(self, target_state=NeuronState.PROCESSING):
        """Extract first time each neuron enters a specific state.

        Returns ``{neuron_id: timestamp}`` of the earliest STATE_CHANGE record
        whose ``data['to_state']`` equals ``target_state``.
        """
        # RecordEventType is a plain Enum, so its members never compare equal
        # to their string values. Match the enum member (which is what gets
        # recorded); the raw string is still accepted for records that were
        # created with it.
        state_change_types = (RecordEventType.STATE_CHANGE, RecordEventType.STATE_CHANGE.value)
        return self._first_times(
            lambda r: r.event_type in state_change_types
            and r.data.get('to_state') == target_state
        )



@dataclass
class Simulation:
    """Discrete-event simulation state: pending-event heap, current time and recorder."""

    event_queue : List  = field(default_factory=list)
    timestamp   : float = 0.0
    recorder    : SimulationRecorder = field(default_factory=lambda: SimulationRecorder())

    def schedule_event(self, event):
        """Push ``event`` onto the heap (heapq orders items with ``<``)."""
        heapq.heappush(self.event_queue, event)

    def next_event(self):
        """Pop and return the earliest pending event, or ``None`` if none is left.

        This does not advance ``timestamp``; use ``update_timestamp`` for that.
        """
        return heapq.heappop(self.event_queue) if self.event_queue else None

    def update_timestamp(self, timestamp):
        """Set the simulation clock."""
        self.timestamp = timestamp




@dataclass
class Coord:
    """A 2-D position."""

    x: float
    y: float



@dataclass
class NeuralTiming:
    """Timing constants of the neuron model (all times in milliseconds).

    The ``tau_*`` fields are derived in ``__post_init__``. They are not
    constructor arguments, and they are not recomputed if a delay is
    modified after construction.
    """

    # --- configurable delays and durations ---
    dendritic_delay: float = 1.0

    axonal_delay_exc: float = 5.0
    axonal_delay_inh: float = 2.0

    processing_delay_normal: float = 10.0
    processing_delay_tagged: float = 5.0

    spiking_duration: float = 0.1
    refractory_duration: float = 2.0
    inhibition_duration: float = 10.0

    # --- derived feedback round-trip times ---
    # Total delay expected between a neuron's spike and the I/E feedback pair
    # it triggers. The spike's message is delivered (axonal_delay_exc) and
    # passes the receiver's dendrite (dendritic_delay). The receiver then
    # processes it (normal or tagged processing delay), and its reply travels
    # back (axonal_delay_exc for E, axonal_delay_inh for I). Feedback that
    # arrives earlier than the *_normal value means the message reached a
    # tagged neuron, which responds faster.
    tau_exc_normal: float = field(init=False)
    tau_inh_normal: float = field(init=False)
    tau_exc_tagged: float = field(init=False)
    tau_inh_tagged: float = field(init=False)

    def __post_init__(self):
        # Outbound leg shared by all four round trips. The summation order is
        # kept as (((a + d) + p) + back) so float results stay identical.
        outbound = self.axonal_delay_exc + self.dendritic_delay
        reply_normal = outbound + self.processing_delay_normal
        reply_tagged = outbound + self.processing_delay_tagged

        self.tau_exc_normal = reply_normal + self.axonal_delay_exc
        self.tau_inh_normal = reply_normal + self.axonal_delay_inh
        self.tau_exc_tagged = reply_tagged + self.axonal_delay_exc
        self.tau_inh_tagged = reply_tagged + self.axonal_delay_inh



@dataclass
class Neuron:
    # state machine
    id:                int
    state:             NeuronState = NeuronState.RESTING
    tag:               TagState = TagState.NORMAL

    # connectivity
    nbrs:              List = field(default_factory=lambda:list())
    nbrs_inh_local:    List = field(default_factory=lambda:list())
    nbrs_inh_global:   List = field(default_factory=lambda:list())

    # timing and tagging dynamics
    last_spike_time:   Optional[float] = None

    expected_inh_feedback: Optional[float] = None
    expected_exc_feedback: Optional[float] = None
    received_inh_feedback: Optional[float] = None
    received_exc_feedback: Optional[float] = None

    # inhibitory control
    use_local_inhibition : bool = False
    use_global_inhibition: bool = True

    # timings
    timing: NeuralTiming = field(default_factory=NeuralTiming)

    # other information, only used for plotting
    coord:             Coord = field(default_factory=lambda: Coord(0.0, 0.0))

    # debugging stuff
    verbose:           bool = False


    @property
    def is_tagged(self):
        return self.tag == TagState.TAGGED


    def schedule_state_change(self, sim, delay, s_from, s_to):
        ev = Event(self.id, EventType.STATE_CHANGE, sim.timestamp + delay, {'from': s_from, 'to': s_to})
        sim.schedule_event(ev)


    def send_msg(self, sim, target_id, delay, msg_type):
        ev = Event(target_id, EventType.RECV_MSG, sim.timestamp + delay, Message(self.id, msg_type))
        sim.schedule_event(ev)


    # handle arriving messages
    def recv(self, event, sim):
        msg = event.data

        # The following is the tagging dynamics
        if self.expected_inh_feedback is not None and self.expected_exc_feedback is not None:
            if msg.type == MessageType.INHIBITORY:
                self.received_inh_feedback = sim.timestamp
            if msg.type == MessageType.EXCITATORY:
                self.received_exc_feedback = sim.timestamp

            if self.received_inh_feedback is not None and self.received_exc_feedback is not None:
                inh_diff = self.expected_inh_feedback - self.received_inh_feedback
                exc_diff = self.expected_exc_feedback - self.received_exc_feedback
                early_I_E_pair = inh_diff > 0 and inh_diff <= (self.timing.tau_inh_normal - self.timing.tau_inh_tagged) and \
                                 exc_diff > 0 and exc_diff <= (self.timing.tau_exc_normal - self.timing.tau_exc_tagged)
                if early_I_E_pair:
                    self.tag = TagState.TAGGED
                    # reset all expectation variables / eligbility trace stuff
                    self.expected_inh_feedback = None
                    self.expected_exc_feedback = None
                    self.received_inh_feedback = None
                    self.received_exc_feedback = None


        # if we're in either of the following states, we cannot do more
        if self.state in [NeuronState.REFRACTORY, NeuronState.SPIKING]:
            return

        if self.state is NeuronState.INHIBITED and self.is_tagged:
            if msg.type == MessageType.EXCITATORY:
                self.schedule_state_change(sim, self.timing.dendritic_delay, self.state, NeuronState.PROCESSING)

        elif self.state == NeuronState.PROCESSING:
            # go into inhibited state if we receive inhibition. otherwise don't
            # care for now. Maybe we need to tighten the dendritic delay a bit
            if msg.type == MessageType.INHIBITORY:
                self.schedule_state_change(sim, self.timing.dendritic_delay, self.state, NeuronState.INHIBITED)

        elif self.state == NeuronState.RESTING:
            if msg.type == MessageType.EXCITATORY:
                self.schedule_state_change(sim, self.timing.dendritic_delay, self.state, NeuronState.PROCESSING)

            elif msg.type == MessageType.INHIBITORY:
                self.schedule_state_change(sim, self.timing.dendritic_delay, self.state, NeuronState.INHIBITED)


    # handle particular state changes
    def handle_state_change(self, event, sim):
        state_from = event.data['from']
        state_to   = event.data['to']
        if self.verbose:
            print(f"neuron {self.id} state change: {state_from} -> {state_to}")

        # TODO: handle state change mismatch properly. this might happen when we
        # got an inhibition while we're in the middle of processing. but also if
        # we receive multiple excitatory messages:
        #   resting neuron -> dendritic delay -> processing neuron
        # in the brief delay we might receive several other excitatory messages,
        # which then also invoke a state change from resting -> processing.
        # For now, we ignore this.
        if self.state != state_from:
            if self.verbose:
                print(f'possible state mismatch: is {self.state}, expected {state_from}')
            return

        match state_to:
            case NeuronState.INHIBITED:
                self.state = state_to
                self.schedule_state_change(sim, self.timing.inhibition_duration, self.state, NeuronState.RESTING)

            case NeuronState.PROCESSING:
                self.state = state_to
                delay = self.timing.processing_delay_normal if self.tag == TagState.NORMAL else self.timing.processing_delay_tagged
                self.schedule_state_change(sim, delay, self.state, NeuronState.SPIKING)

            case NeuronState.SPIKING:
                # we might have gotten an inhibition in the meantime. if that's the
                # case, don't do anything
                if self.state == NeuronState.INHIBITED and not self.is_tagged:
                    if self.verbose:
                        print("WW: won't send spike when inhibited and not tagged")

                else:
                    sim.recorder.record_event(sim.timestamp, self.id, RecordEventType.SPIKE)
                    self.state = NeuronState.SPIKING
                    self.last_spike_time = sim.timestamp

                    # NOTE: maybe in the future, give the message sending to the
                    # simulation itself, so that we don't know anything about
                    # the true connectivity in here - also this allows to send
                    # local and global inhibitory signals more easily

                    # send E messages to local neighborhood
                    for j in self.nbrs:
                        self.send_msg(sim, j, self.timing.axonal_delay_exc, MessageType.EXCITATORY)

                    # send global I message
                    if self.is_tagged:
                        if self.use_local_inhibition:
                            for j in self.nbrs_inh_local:
                                self.send_msg(sim, j, self.timing.axonal_delay_inh, MessageType.INHIBITORY)

                        if self.use_global_inhibition:
                            for j in self.nbrs_inh_global:
                                self.send_msg(sim, j, self.timing.axonal_delay_inh, MessageType.INHIBITORY)

                    # set time when we expect feedback.
                    # first line is always: delays from this neuro
                    # second line is:       delays from following neuron(s)
                    self.expected_inh_feedback = sim.timestamp + self.timing.tau_inh_normal
                    self.expected_exc_feedback = sim.timestamp + self.timing.tau_exc_normal

                    self.received_inh_feedback, self.received_exc_feedback = None, None

                    # schedule return to resting
                    self.schedule_state_change(sim, self.timing.spiking_duration, self.state, NeuronState.REFRACTORY)


            # NOTE: We should not see a state change for refractory on this
            # level, because we never schedule one (refractory is directly set
            # when we receive a SPIKING state change). Nevertheless, keep it
            # here, to cover potential future bugs that rely on it.
            case NeuronState.REFRACTORY:
                self.state = state_to
                self.schedule_state_change(sim, self.timing.refractory_duration, self.state, NeuronState.RESTING)

        sim.recorder.record_event(sim.timestamp, self.id, RecordEventType.STATE_CHANGE, from_state=state_from, to_state=self.state)


    def handle_event(self, event, sim):
        match event.type:
            case EventType.RECV_MSG:
                self.recv(event, sim)

            case EventType.STATE_CHANGE:
                self.handle_state_change(event, sim)




