"""
Modules containing class that
    1. the most abstract circuit connection information
"""


class AbsSegment:
    """class for abstract segment

    Attributes:
        cell: AbsCell class, the cell that contains the segment
        index: int, segment index in the cell, starts from 0 (only unique within its cell)
        name: str, segment name generated by Neuron
        pre_connections: list of AbsConnection, connections whose post_segment is this segment
        post_connections: list of AbsConnection, connections whose pre_segment is this segment
    """
    def __init__(self, index, cell, name):
        assert isinstance(cell, AbsCell)

        self.cell, self.index, self.name = cell, int(index), name
        self.pre_connections, self.post_connections = [], []

    def _is_same_segment(self, other):
        """Return True if `other` denotes this segment.

        Segment indices are only unique inside one cell, so both the segment index and the
        owning cell's index must match. A missing endpoint (None) never matches.
        """
        return (other is not None
                and other.index == self.index
                and other.cell.index == self.cell.index)

    def add_pre_connection(self, new_connection):
        """the new connection points to self segment

        Raises:
            ValueError: if the connection's post_segment is None or is not this segment
                (different segment index or different cell).
        """
        assert isinstance(new_connection, AbsConnection)

        if not self._is_same_segment(new_connection.post_segment):
            raise ValueError("AbsSegment Index Incompatible!")
        self.pre_connections.append(new_connection)

    def add_post_connection(self, new_connection):
        """the new connection points from self segment

        Raises:
            ValueError: if the connection's pre_segment is None or is not this segment
                (different segment index or different cell).
        """
        assert isinstance(new_connection, AbsConnection)

        if not self._is_same_segment(new_connection.pre_segment):
            raise ValueError("AbsSegment Index Incompatible!")
        self.post_connections.append(new_connection)


class AbsConnection:
    """class for abstract connection
    either a synapse (positive weight for excitatory, negative for inhibitory) or a gap junction

    Attributes:
        pre_segment, post_segment: AbsSegment or None, pre-synaptic or post-synaptic segment
            (AbstractCircuit treats a missing pre_segment as an input connection and a
            missing post_segment as an output connection)
        pre_cell, post_cell: AbsCell or None, the cell of pre_segment / post_segment
            (None when the corresponding segment is None)
        category: str, either 'syn' (synapse) or 'gj' (gap junction)
        weight: float
        pair_key: int or None, the two gap-junction connections of one pair share the same
            unique key in the circuit; a synapse's key is None
    """
    def __init__(self, pre_segment, post_segment, category, weight, pair_key=None):
        assert category in ('syn', 'gj')
        assert isinstance(pre_segment, (AbsSegment, type(None))) and isinstance(post_segment, (AbsSegment, type(None)))

        self.pre_segment, self.post_segment = pre_segment, post_segment
        self.pre_cell = pre_segment.cell if pre_segment is not None else None
        self.post_cell = post_segment.cell if post_segment is not None else None
        self.category, self.weight, self.pair_key = category, weight, pair_key

    def __repr__(self):
        """Readable summary naming the endpoint segments (None for a missing endpoint)."""
        pre_name = self.pre_segment.name if self.pre_segment is not None else None
        post_name = self.post_segment.name if self.post_segment is not None else None
        return "{}(pre={!r}, post={!r}, category={!r}, weight={!r}, pair_key={!r})".format(
            type(self).__name__, pre_name, post_name, self.category, self.weight, self.pair_key)

    def update_info(self, new_weight=None):
        """change the weight of the connection (the category is left unchanged)

        Args:
            new_weight: new value for self.weight; passing None (the default) is rejected
                by the assertion below.
        """
        assert new_weight is not None
        self.weight = new_weight


class AbsCell:
    """Abstract representation of a single cell.

    Attributes:
        index: int, cell index (coerced with int() on construction)
        name: str, cell name
        segments: list of AbsSegment objects belonging to this cell
        pre_connections: list of AbsConnection whose post-synaptic cell is this cell (incoming)
        post_connections: list of AbsConnection whose pre-synaptic cell is this cell (outgoing)
    """
    def __init__(self, index, name):
        self.index = int(index)
        self.name = name
        self.segments = []
        self.pre_connections = []
        self.post_connections = []

    def add_segment(self, new_segment):
        """Append `new_segment` to `self.segments`.

        Raises:
            ValueError: if a segment with the same index is already present.
        """
        assert isinstance(new_segment, AbsSegment)
        assert new_segment.cell.index == self.index

        already_present = any(existing.index == new_segment.index for existing in self.segments)
        if already_present:
            raise ValueError("Segment already inserted!")
        self.segments.append(new_segment)

    def add_pre_connection(self, new_connection):
        """Record an incoming connection, i.e. one whose post-synaptic cell is this cell.

        Raises:
            ValueError: if the connection's post_cell index differs from this cell's index.
        """
        assert isinstance(new_connection, AbsConnection)

        if new_connection.post_cell.index != self.index:
            raise ValueError("Connection post-synaptic cell incompatible!")
        self.pre_connections.append(new_connection)

    def add_post_connection(self, new_connection):
        """Record an outgoing connection, i.e. one whose pre-synaptic cell is this cell.

        Raises:
            ValueError: if the connection's pre_cell index differs from this cell's index.
        """
        assert isinstance(new_connection, AbsConnection)

        if new_connection.pre_cell.index != self.index:
            raise ValueError("Connection pre-synaptic cell incompatible!")
        self.post_connections.append(new_connection)

    def segment(self, segment_index=None, segment_name=None):
        """Return the first segment matching every criterion that was supplied.

        Args:
            segment_index: int or None, index to match; None matches any index
            segment_name: str or None, name to match; None matches any name

        Raises:
            AssertionError: if neither criterion is supplied.
            ValueError: if no segment matches.
        """
        assert (segment_index is not None) or (segment_name is not None)
        candidates = (
            seg
            for seg in self.segments
            if (seg.index == segment_index or segment_index is None)
            and (seg.name == segment_name or segment_name is None)
        )
        found = next(candidates, None)
        if found is None:
            raise ValueError("Segment not found!")
        return found

    def update_connections(self, new_weights):
        """Overwrite the weight of each incoming connection, in order.

        `new_weights` must hold exactly one value per entry of `self.pre_connections`.
        """
        assert len(self.pre_connections) == len(new_weights)
        for position, weight in enumerate(new_weights):
            self.pre_connections[position].update_info(weight)


class AbstractCircuit:
    """class for abstract circuit

    Attributes:
        cells: list of all AbsCell objects in the circuit
        connections: list of AbsConnection, all connections including input and output connections
        input_connections: list of AbsConnection, connections whose pre_segment is None
        output_connections: list of AbsConnection, connections whose post_segment is None
    """
    def __init__(self):
        self.cells, self.connections = [], []
        self.input_connections, self.output_connections = [], []

    def add_cell(self, new_cell):
        """cell append interface

        Raises:
            ValueError: if a cell with the same index is already in the circuit. Without this
                check the later cell could never be returned by cell(cell_index=...), which
                returns the first match (same rule as AbsCell.add_segment for segments).
        """
        assert isinstance(new_cell, AbsCell)
        if any(existing.index == new_cell.index for existing in self.cells):
            raise ValueError("Cell already inserted!")
        self.cells.append(new_cell)

    def add_connection(self, new_connection):
        """connection append interface

        The connection is always appended to self.connections. A connection without a
        pre_segment is recorded as an input connection; otherwise it is registered as an
        outgoing connection on its pre-synaptic segment and cell. Likewise, a connection
        without a post_segment is recorded as an output connection; otherwise it is
        registered as an incoming connection on its post-synaptic segment and cell.
        """
        assert isinstance(new_connection, AbsConnection)
        self.connections.append(new_connection)
        if new_connection.pre_segment is None:
            self.input_connections.append(new_connection)
        else:
            new_connection.pre_segment.add_post_connection(new_connection)
            new_connection.pre_cell.add_post_connection(new_connection)

        if new_connection.post_segment is None:
            self.output_connections.append(new_connection)
        else:
            new_connection.post_segment.add_pre_connection(new_connection)
            new_connection.post_cell.add_pre_connection(new_connection)

    def cell(self, cell_index=None, cell_name=None):
        """obtain the first cell object matching its index and/or name

        Args:
            cell_index: int or None, index to match; None matches any index
            cell_name: str or None, name to match; None matches any name

        Raises:
            ValueError: if no cell matches.
        """
        assert (cell_index is not None) or (cell_name is not None)
        for cell in self.cells:
            if ((cell.index == cell_index) or (cell_index is None)) and \
                    ((cell.name == cell_name) or (cell_name is None)):
                return cell
        raise ValueError("Cell not found!")

    def update_connections(self, new_weights):
        """update the weight of every connection, in the order of self.connections

        `new_weights` must hold exactly one value per entry of self.connections.
        """
        assert len(self.connections) == len(new_weights)
        for connection, new_weight in zip(self.connections, new_weights):
            connection.update_info(new_weight)
