import os
import sys
from setuptools import setup, find_packages

if sys.argv[-1] == 'publish':
    os.system('python setup.py sdist upload')
    sys.exit()

setup(
    name='simcronomicon',
    version='0.1.0',
    description='Event-driven agent-based spread simulation framework for modeling disease spread in realistic spatial environments using geographical data from OpenStreetMap.',
    long_description=open('README.md', encoding='utf-8').read(),
    long_description_content_type='text/markdown',
    author='Warisa Roongaraya',
    author_email='compund555@gmail.com',
    url='https://github.com/warisa-r/simcronomicon',
    packages=find_packages(),
    package_dir={'simcronomicon': 'simcronomicon'},
    include_package_data=True,
    install_requires=[],  # Empty since dependencies are handled by conda environment
    python_requires='>=3.12',
    license='MIT',
    zip_safe=False,
    keywords=['epidemiology', 'agent-based-modeling', 'disease-spread', 'simulation', 'openstreetmap'],
    classifiers=[
        'Development Status :: 2 - Pre-Alpha',
        'Intended Audience :: Developers',
        'Intended Audience :: Science/Research',
        'License :: OSI Approved :: MIT License',
        'Natural Language :: English',
        'Programming Language :: Python :: 3.12',
        'Topic :: Scientific/Engineering :: Medical Science Apps.',
        'Topic :: Scientific/Engineering :: Information Analysis',
    ],
)

# Configuration file for the Sphinx documentation builder.
#
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

# -- Project information -----------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

import sys
import os
project = 'simcronomicon'
copyright = '2025, Warisa Roongaraya'
author = 'Warisa Roongaraya'
release = '0.1.0'

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

html_theme = 'sphinx_rtd_theme'
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.napoleon']


sys.path.insert(0, os.path.abspath('..'))
templates_path = ['_templates']
exclude_patterns = []


# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output

html_static_path = ['_static']


import json
import random as rd

import h5py
import numpy as np

from .infection_models import EventType


class Simulation:
    """
    Agent-based simulation engine for epidemic modeling in spatial networks.

    The Simulation class implements an agent-based modeling (ABM) framework that applies
    transition rules according to user-defined infection models. Agents move through
    a spatial town network, interact with each other and their environment, and undergo
    state transitions based on the rules defined in the chosen infection model.

    Purpose
    -------
    1. **Initialize Population**: Distribute agents across the town network according to
       user-specified parameters, including initial spreader locations and population size.

    2. **Agent Movement**: Move agents through the city during simulation timesteps based
       on step events that define mobility patterns and destination preferences.

    3. **Agent Interactions**: Enable agent-to-agent and agent-to-environment interactions
       at each location according to the rules defined in the infection model.

    4. **State Transitions**: Apply infection model transition rules (e.g., S→E→I→R)
       based on agent interactions, environmental factors, and time-dependent processes.

    5. **Temporal Dynamics**: Execute simulation in discrete timesteps, where each step
       consists of multiple events, and agents return home after each complete step.

    Simulation Workflow
    -------------------
    Each simulation timestep follows this pattern:

    1. **Event Execution**: For each step event in the current timestep:
       - Reset agent locations (clear previous positions)
       - Execute event-specific movement (DISPERSE) or actions (SEND_HOME)
       - Apply infection model rules for agent interactions
       - Record population status and individual agent states

    2. **Agent Movement**: During DISPERSE events:
       - Agents move to locations within their travel distance
       - Movement considers place type preferences and priority destinations
       - Probability functions can influence destination selection

    3. **Interactions**: At each active location:
       - Agents interact according to infection model rules
       - Environmental factors (place type) influence interaction outcomes
       - State transitions occur based on model-specific probabilities

    4. **Home Reset**: After all events, agents return to their home addresses

    Parameters
    ----------
    town : Town
        The Town object representing the spatial network with nodes (locations)
        and edges (travel routes) where the simulation takes place.
    infection_model : AbstractInfectionModel
        The infection model instance (e.g., SEIRModel, SEIQRDVModel) that
        defines agent states, transition rules, and interaction behaviors.
    timesteps : int
        Number of discrete timesteps to run the simulation.
    seed : bool, optional
        Whether to set random seeds for reproducible results (default: True).
    seed_value : int, optional
        Random seed value for reproducibility (default: 5710).

    Attributes
    ----------
    folks : list
        List of AbstractFolk (agent) objects representing the population.
    town : Town
        The spatial network where agents live and move.
    model : AbstractInfectionModel
        The infection model governing agent behavior and transitions.
    step_events : list
        Sequence of events that occur in each timestep.
    active_node_indices : set
        Set of town nodes currently occupied by agents.
    status_dicts : list
        Historical record of population status counts at each timestep.

    Raises
    ------
    ValueError
        If required place types for the chosen infection model are missing
        in the town data. This ensures model-specific locations (e.g., healthcare
        facilities for medical models) are available in the spatial network.

    Examples
    --------

    >>> # Create town and model
    >>> town = Town.from_point(point=[50.7753, 6.0839], dist=1000,
    ...                        town_params=TownParameters(num_pop=1000, num_init_spreader=10))
    >>> model_params = SEIRModelParameters(max_energy=10, beta=0.3, sigma=5, gamma=7, xi=100)
    >>> model = SEIRModel(model_params)
    >>>
    >>> # Run simulation
    >>> sim = Simulation(town, model, timesteps=100)
    >>> sim.run(hdf5_path="epidemic_simulation.h5")

    Notes
    -----

    - The simulation saves detailed results to HDF5 format, including population
      summaries and individual agent trajectories.

    - Agent energy levels affect movement capability and interaction potential.

    - Movement restrictions (e.g., quarantine) can limit agent mobility while
      still allowing interactions with visiting agents.

    - The simulation automatically terminates early if no infected agents remain.
    """

    def __init__(
            self,
            town,
            infection_model,
            timesteps,
            seed=True,
            seed_value=5710):
        """
        Initialize a Simulation instance.

        Parameters
        ----------
        town : Town
            The Town object representing the simulation environment.
        infection_model : AbstractInfectionModel
            The infection model instance (e.g., SEIRModel) to use for the simulation.
        timesteps : int
            Number of timesteps to run the simulation.
        seed : bool, optional
            Whether to set the random seed for reproducibility (default: True).
        seed_value : int, optional
            The value to use for the random seed (default: 5710).

        Raises
        ------
        ValueError
            If required place types for the model are missing in the town data of the given spatial area.
        """
        self.folks = []
        self.status_dicts = []
        self.town = town
        self.num_pop = town.town_params.num_pop
        self.model = infection_model
        self.model_params = infection_model.model_params
        self.step_events = infection_model.step_events
        self.current_timestep = 0
        self.timesteps = timesteps
        self.active_node_indices = set()
        self.nodes_list = list(self.town.town_graph.nodes)

        missing = [
            ptype for ptype in self.model.required_place_types if ptype not in self.town.found_place_types]
        if missing:
            raise ValueError(
                f"Missing required place types for this model in town data: {missing}. Please increase the radius of your interested area or change it.")

        if seed:
            rd.seed(seed_value)
            np.random.seed(seed_value)

        self.folks, self.household_node_indices, status_dict_t0 = self.model.initialize_sim_population(
            town)
        self.active_node_indices = self.household_node_indices.copy()

        self.status_dicts.append(status_dict_t0)

    def _reset_population_home(self):
        # Reset every person's current address to their home address
        # And reset the town graph
        # In addition, send everyone to sleep as well
        for i in range(self.num_pop):
            self.folks[i].address = self.folks[i].home_address
            self.town.town_graph.nodes[self.folks[i].home_address]["folks"].append(
                self.folks[i])

        self.num_pop = self.model.update_population(
            self.folks, self.town, self.household_node_indices, self.status_dicts[-1])
        # Simple list -> Shallow copy
        self.active_node_indices = self.household_node_indices.copy()

    def _disperse_for_event(self, step_event):
        # Send the town population to the place they are supposed to be,
        # given that they have enough energy
        for person in self.folks:
            if person.movement_restricted == False and person.alive and person.energy > 0:
                current_node = person.address
                candidates = []
                # Get the shortest path lengths from current_node to all other
                # nodes, considering edge weights

                if person.priority_place_type == []:
                    # If this agent doesn't have a place that they prioritize to go to, send them on their normal schedule
                    # like everybody else in the town.
                    # Get the nodes where the shortest path length is less than or
                    # equal to the possible travel distance
                    candidates = [
                        neighbor for neighbor in self.town.town_graph.nodes
                        if neighbor != current_node
                        # check if an edge exists
                        and self.town.town_graph[current_node].get(neighbor)
                        and self.town.town_graph[current_node][neighbor]['weight'] <= step_event.max_distance
                        and self.town.town_graph.nodes[neighbor]['place_type'] in step_event.place_types
                    ]
                else:
                    # If the agent has prioritized place types to go to
                    # Find the closest node with one of those place types,
                    # regardless of max_distance
                    min_dist = float('inf')
                    chosen_node = None
                    chosen_place_type = None
                    for node in self.town.town_graph.nodes:
                        node_place_type = self.town.town_graph.nodes[node]['place_type']
                        if node_place_type in person.priority_place_type:
                            if self.town.town_graph.has_edge(
                                    person.address, node):
                                dist = self.town.town_graph[person.address][node]['weight']
                            else:
                                continue
                            if dist < min_dist:
                                min_dist = dist
                                chosen_node = node
                                chosen_place_type = node_place_type

                    # If there exists a precomputed shortest path from the current location to this place,
                    # move agent to the prioritized place and remove that place
                    # from the priority list.
                    if chosen_node and chosen_place_type:
                        candidates = [chosen_node]
                        # Remove the visited place type from the priority list
                        person.priority_place_type.remove(chosen_place_type)

                if candidates:
                    if step_event.probability_func is not None:
                        distances = [
                            self.town.town_graph[current_node][neighbor]['weight']
                            for neighbor in candidates
                        ]
                        probs = step_event.probability_func(distances, person)
                        new_node = np.random.choice(candidates, p=probs)
                    else:
                        new_node = rd.choice(candidates)
                    # Update person's address
                    person.address = new_node
            self.town.town_graph.nodes[person.address]["folks"].append(person)

        # Reset active_node_indices and update consistently
        self.active_node_indices = set()
        for node in self.town.town_graph.nodes:
            if len(self.town.town_graph.nodes[node]) >= 2:
                self.active_node_indices.add(node)

    def _execute_event(self, step_event):
        # Regardless of the type of events, there are always movements.
        # To consistently update the list we have to
        # reset every house to empty first and fill in the folks at the nodes
        # after their address changes
        for i in range(len(self.town.town_graph.nodes)
                       ):
            self.town.town_graph.nodes[i]["folks"] = []

        for person in self.folks:
            person.clear_previous_event_effect()

        if step_event.event_type == EventType.SEND_HOME:
            for i in range(self.num_pop):
                if not self.folks[i].alive:
                    continue
                # Dummy folks_here and current_place_type since
                # this type of event is meant to relocate people and allow them some time to pass
                # for time-sensitive transition while they do that
                step_event.folk_action(
                    self.folks[i], None, None, self.status_dicts[-1], self.model_params, rd.random())
            if step_event.name == "end_day":
                self._reset_population_home()
        elif step_event.event_type == EventType.DISPERSE:
            # Move people through the town first
            self._disperse_for_event(step_event)
            for node in self.active_node_indices:  # Only iterate through active nodes
                # A person whose movement is restricted can stil be interact with other people who come to their location
                # e.g. delivery service comes into contact with people are
                # quarantined...
                folks_here = [folk for folk in self.town.town_graph.nodes[node]
                              ["folks"] if folk.alive and folk.energy > 0]
                current_place_type = self.town.town_graph.nodes[node]['place_type']
                for folk in folks_here:
                    step_event.folk_action(folk,
                                           folks_here,
                                           current_place_type,
                                           self.status_dicts[-1],
                                           self.model_params,
                                           rd.random())

    def _step(self):
        # Advances the simulation by one timestep, executing all step events for the current round.
        # Updates population status and records each agent's state after every
        # event.
        current_timestep = self.current_timestep + 1
        status_row = None
        indiv_folk_rows = []

        for step_event in self.step_events:
            # Copy and annotate the new state
            self.status_dicts.append(self.status_dicts[-1].copy())
            self.status_dicts[-1]['timestep'] = current_timestep
            self.status_dicts[-1]['current_event'] = step_event.name

            self._execute_event(step_event)

            # Record the latest summary
            status_row = self.status_dicts[-1].copy()

            # Record each individual's state
            for folk in self.folks:
                indiv_folk_rows.append({
                    'timestep': current_timestep,
                    'event': step_event.name,
                    'folk_id': folk.id,
                    'status': folk.status,
                    'address': folk.address
                })
        self.current_timestep = current_timestep

        return status_row, indiv_folk_rows

    def run(self, hdf5_path="simulation_output.h5", silent=False):
        """
        Run the simulation for the specified number of timesteps.

        The simulation results are saved to an HDF5 file with the following structure:

        .. code-block:: text

            simulation_output.h5
            ├── config
            │   ├── simulation_config   (JSON-encoded simulation config)
            │   └── town_config         (JSON-encoded town config)
            ├── status_summary
            │   └── summary               (dataset: structured array with timestep, current_event, and statuses)
            └── individual_logs
                └── log                   (dataset: structured array with timestep, event, folk_id, status, address)

        Parameters
        ----------
        hdf5_path : str
            Path to the output HDF5 file.

        Returns
        -------
        None
        """
        try:
            with h5py.File(hdf5_path, "w") as h5file:
                # Save simulation config
                config_group = h5file.create_group("config")

                # Write simulation configuration information (without town)
                sim_config = {
                    'seed_enabled': hasattr(
                        self,
                        'seed_value'),
                    'seed_value': getattr(
                        self,
                        'seed_value',
                        None),
                    'all_statuses': self.model.all_statuses,
                    'model_parameters': self.model_params.to_config_dict(),
                    'num_locations': len(
                        self.town.town_graph.nodes),
                    'max_timesteps': self.timesteps,
                    'population': self.num_pop,
                    'step_events': [
                        {
                            'name': event.name,
                            'max_distance': event.max_distance,
                            'place_types': event.place_types,
                            'event_type': event.event_type.value,
                            'probability_func': event.probability_func.__name__ if event.probability_func else None,
                        } for event in self.step_events],
                }
                sim_config_json = json.dumps(sim_config)
                config_group.create_dataset(
                    "simulation_config", data=np.bytes_(sim_config_json))

                # Write town configuration information separately
                town_config = {
                    "origin_point": [
                        float(
                            self.town.origin_point[0]),
                        float(
                            self.town.origin_point[1])],
                    "dist": self.town.dist,
                    "epsg_code": self.town.epsg_code,
                    "accommodation_nodes": list(
                        self.town.accommodation_node_ids)}
                town_config_json = json.dumps(town_config)
                config_group.create_dataset(
                    "town_config", data=np.bytes_(town_config_json))

                # Save initial status summary
                status_group = h5file.create_group("status_summary")
                status_dtype = [("timestep", 'i4'), ("current_event", 'S32')] + [
                    (status, 'i4') for status in self.model.all_statuses]
                status_data = []
                initial_status = self.status_dicts[-1]
                row = tuple([initial_status.get("timestep", 0),
                            bytes(str(initial_status.get("current_event", "")), 'utf-8')] +
                            [initial_status[status] for status in self.model.all_statuses])
                status_data.append(row)

                # Save initial individual logs
                indiv_group = h5file.create_group("individual_logs")
                folk_dtype = [("timestep", 'i4'), ("event", 'S32'),
                              ("folk_id", 'i4'), ("status", 'S8'), ("address", 'i4')]
                indiv_data = [
                    (0, b"", folk.id, bytes(folk.status, 'utf-8'), folk.address)
                    for folk in self.folks
                ]

                # Run simulation
                for i in range(1, self.timesteps + 1):
                    status_row, indiv_rows = self._step()

                    # Collect status row
                    row = tuple([
                        status_row["timestep"],
                        bytes(str(initial_status.get("current_event", "")), 'utf-8')
                    ] + [status_row[status] for status in self.model.all_statuses])
                    status_data.append(row)

                    # Collect individual rows
                    for row in indiv_rows:
                        indiv_data.append((
                            row["timestep"],
                            bytes(row["event"], 'utf-8'),
                            row["folk_id"],
                            bytes(row["status"], 'utf-8'),
                            row["address"]
                        ))

                    if not silent:
                        print("Step has been run", i)
                        print(
                            "Status:", {
                                k: v for k, v in status_row.items() if k not in (
                                    'timestep', 'current_event')})

                    if sum(status_row[status]
                            for status in self.model.infected_statuses) == 0:
                        break

                # Store final datasets
                status_group.create_dataset(
                    "summary", data=np.array(
                        status_data, dtype=status_dtype))
                indiv_group.create_dataset(
                    "log", data=np.array(
                        indiv_data, dtype=folk_dtype))
        except IOError as e:
            print(f"Error writing simulation output: {e}")


import copy
import json
import os
import tempfile
import time
import zipfile

import igraph as ig
import networkx as nx
import numpy as np
import osmnx as ox
from scipy.spatial import KDTree
from tqdm import tqdm

# Default place type categories used throughout the simulation
PLACE_TYPES = [
    "accommodation",
    "healthcare_facility",
    "commercial",
    "workplace",
    "education",
    "religious",
    "other"
]

# Default classification criteria for OpenStreetMap tags to place types
PLACE_TYPE_CRITERIA = {
    "accommodation": {
        "building": {
            'residential', 'apartments', 'house', 'detached', 'dormitory', 'terrace',
            'allotment_house', 'bungalow', 'semidetached_house', 'hut'
        }
    },
    "healthcare_facility": {
        "building": {'hospital', 'dentist'},
        "healthcare": {'hospital', 'clinic', 'doctor', 'doctors', 'pharmacy', 'laboratory'},
        "amenity": {'hospital', 'clinic', 'doctors', 'pharmacy', 'dentist'},
        "shop": {'medical_supply'},
        "emergency": {'yes'}
    },
    "commercial": {
        "building": {'commercial', 'retail', 'supermarket', 'shop', 'service', 'sports_centre'},
        "amenity": {'restaurant', 'bar', 'cafe', 'bank', 'fast_food'},
        "landuse": {'commercial'}
    },
    "workplace": {
        "building": {'office', 'factory', 'industrial', 'government'},
        "amenity": {'office', 'factory', 'industry'},
        "landuse": {'industrial', 'office'}
    },
    "education": {
        "building": {'school', 'university', 'kindergarten'},
        "amenity": {'university', 'kindergarten'}
    },
    "religious": {
        "building": {'chapel', 'church', 'temple', 'mosque', 'synagogue'},
        "amenity": {'chapel', 'church', 'temple', 'mosque', 'synagogue'},
        "landuse": {'religious'}
    }
    # "other" is the default fallback when no criteria match
}


def classify_place(row):
    """
    Classify an OpenStreetMap point of interest into a place type category.

    This default function examines OSM tags for buildings, amenities, land use, etc.
    and determines the appropriate functional category for simulation purposes.

    Parameters
    ----------
    row : pandas.Series
        A row from a GeoDataFrame containing OpenStreetMap tags.

    Returns
    -------
    str
        Place type category from PLACE_TYPES list.

    Notes
    -----
    Classification hierarchy:

    1. accommodation*compulsory for every model): Residential buildings where agents live
       - Identified by building tags like 'residential', 'apartments', 'house', etc.

    2. healthcare_facility(compulsory for models with vaccination and/or symptom treatments):
        Hospitals, clinics, pharmacies
       - Identified by building, healthcare, amenity tags, or emergency=yes

    3. commercial: Shops, restaurants, banks
       - Identified by retail/service building types or commercial amenities

    4. workplace: Offices, factories, industrial areas
       - Identified by office/industrial buildings and land use

    5. education: Schools, universities
       - Identified by educational building and amenity types

    6. religious: Churches, mosques, temples
       - Identified by religious building types or land use

    7. other: Default fallback for unclassified locations
       - Any point that doesn't match the above criteria
        We are not going to consider these unclassified locations in general.
        You can, however, with your own custom classification, assign the
        unknown nodes to be any other plaec types randomly if you wish to also
        include them in the simulation.

    Classification uses the PLACE_TYPE_CRITERIA dictionary that maps
    place types to relevant OSM tags.

    It is also important to note that you can also customize your own classification
    function according to OSM taggings in your area of interest if you
    have more place types you want to use in your simulation or that you notice that
    our classification criteria isn't inclusive enough for your area of interest.
    """
    # Extract and normalize OSM tags
    b = str(row.get("building", "")).lower()
    a = str(row.get("amenity", "")).lower()
    l = str(row.get("landuse", "")).lower()
    h = str(row.get("healthcare", "")).lower()
    s = str(row.get("shop", "")).lower()
    e = str(row.get("emergency", "")).lower()

    # Check each place type's criteria
    for place_type, criteria in PLACE_TYPE_CRITERIA.items():
        # Check if the row matches any of the criteria for this place type
        if ("building" in criteria and b in criteria["building"]) or \
           ("amenity" in criteria and a in criteria["amenity"]) or \
           ("landuse" in criteria and l in criteria["landuse"]) or \
           ("healthcare" in criteria and h in criteria["healthcare"]) or \
           ("shop" in criteria and s in criteria["shop"]) or \
           ("emergency" in criteria and e in criteria["emergency"]):
            return place_type

    return "other"


class TownParameters():
    """
    Parameters for town network initialization and agent placement.

    Validates and stores population size, initial spreader count, and optional
    spreader node locations for the simulation set up.

    Parameters
    ----------
    num_pop : int
        Total population size for the simulation. Must be a positive integer (> 0).
        Determines the number of agents that will be created and distributed across
        accommodation nodes in the town network.

    num_init_spreader : int
        Number of initial disease spreaders at simulation start. Must be a positive
        integer (> 0) and cannot exceed num_pop. These agents begin the simulation
        in an infected state to seed the epidemic spread.

    spreader_initial_nodes : list of int, optional
        Specific node IDs where initial spreaders should be placed. This list can be:

        - **Empty** (default): All spreaders will be randomly assigned to accommodation nodes
        - **Partial**: Contains fewer nodes than num_init_spreader; remaining spreaders
          will be randomly assigned to accommodation nodes
        - **Complete**: Contains exactly num_init_spreader nodes for full control
        - **With duplicates**: Same node ID can appear multiple times to place multiple
          spreaders at the same location

        Node IDs must be integers or convertible to integers. The list length must not
        exceed num_init_spreader (len(spreader_initial_nodes) ≤ num_init_spreader).

    Attributes
    ----------
    num_pop : int
        Validated total population size.
    num_init_spreader : int
        Validated number of initial spreaders.
    spreader_initial_nodes : list
        Validated list of spreader node locations.

    Raises
    ------
    TypeError
        If parameters are not of expected types or nodes not convertible to int.
    ValueError
        If values are non-positive, spreaders exceed population, or too many
        node locations specified.

    Examples
    --------
    >>> # Basic configuration
    >>> params = TownParameters(num_pop=1000, num_init_spreader=10)

    >>> # With specific spreader locations
    >>> params = TownParameters(
    ...     num_pop=1000,
    ...     num_init_spreader=3,
    ...     spreader_initial_nodes=[5, 12, 47]
    ... )

    >>> # Partial specification with duplicates
    >>> params = TownParameters(
    ...     num_pop=1000,
    ...     num_init_spreader=5,
    ...     spreader_initial_nodes=[10, 10, 25]  # 3 of 5 spreaders specified
    ... )
    """

    def __init__(self, num_pop, num_init_spreader, spreader_initial_nodes=[]):
        # Validate num_pop
        if not isinstance(num_pop, int):
            raise TypeError(
                f"num_pop must be an integer, got {type(num_pop).__name__}")
        if num_pop <= 0:
            raise ValueError(f"num_pop must be positive, got {num_pop}")

        # Validate num_init_spreader
        if not isinstance(num_init_spreader, int):
            raise TypeError(
                f"num_init_spreader must be an integer, got {
                    type(num_init_spreader).__name__}")
        if num_init_spreader <= 0:
            raise ValueError(
                f"num_init_spreader must be positive, got {num_init_spreader}")
        if num_init_spreader > num_pop:
            raise ValueError(
                f"num_init_spreader ({num_init_spreader}) cannot exceed num_pop ({num_pop})")

        # Validate spreader_initial_nodes
        if not isinstance(spreader_initial_nodes, list):
            raise TypeError(
                f"spreader_initial_nodes must be a list, got {
                    type(spreader_initial_nodes).__name__}")

        if num_init_spreader < len(spreader_initial_nodes):
            raise ValueError(
                "There cannot be more locations of the initial spreaders than the number of initial spreaders")

        # Store validated values
        self.num_init_spreader = num_init_spreader
        self.num_pop = num_pop
        self.spreader_initial_nodes = spreader_initial_nodes


class Town():
    """
    Spatial network representation for agent-based epidemic modeling.

    The Town class represents a spatial network derived from OpenStreetMap data,
    where nodes correspond to places of interest (POIs) and edges represent
    walkable paths between locations. This network serves as the environment
    where agents move, interact, and undergo state transitions during simulation.

    Purpose
    -------
    1. **Spatial Network Creation**: Build a graph representation of urban areas
       from OpenStreetMap data, including roads, buildings, and points of interest.

    2. **Place Classification**: Categorize locations into functional types
       (accommodation, workplace, healthcare, etc.) that influence agent behavior.

    3. **Agent Housing**: Provide accommodation nodes where agents reside and
       return to after each simulation timestep.

    4. **Distance Calculation**: Maintain shortest-path distances between all
       locations to enable realistic agent movement patterns.

    5. **Data Persistence**: Save and load town networks to/from compressed
       files for reuse across multiple simulations.

    Network Structure
    -----------------
    - **Nodes**: Represent places of interest with attributes including:

      - place_type: Functional category of the location

      - x, y: Projected coordinates in the town's coordinate system

      - folks: List of agents currently at this location

    - **Edges**: Represent walkable connections with attributes:

      - weight: Shortest-path distance in meters between connected nodes

    Place Types
    -----------
    The default classification system recognizes:
    - accommodation: Residential buildings where agents live

    - healthcare_facility: Hospitals, clinics, pharmacies

    - commercial: Shops, restaurants, banks

    - workplace: Offices, factories, industrial areas

    - education: Schools, universities

    - religious: Churches, mosques, temples

    - other: Unclassified locations (filtered out by default)

    Attributes
    ----------
    town_graph : networkx.Graph
        The spatial network with nodes representing locations and edges
        representing shortest paths between them.
    town_params : TownParameters
        Configuration parameters including population size and initial spreader locations.
    epsg_code : int
        EPSG coordinate reference system code for spatial projections.
    point : tuple
        Origin point [latitude, longitude] used to center the network.
    dist : float
        Radius in meters defining the network extent from the origin point.
    all_place_types : list
        Complete list of possible place type categories.
    found_place_types : set
        Place types actually present in this town network.
    accommodation_node_ids : list
        Node IDs of all accommodation locations where agents can reside.

    Examples
    --------
    >>> # Create town from geographic coordinates (Aachen, Germany)
    >>> town_params = TownParameters(num_pop=1000, num_init_spreader=10)
    >>> town = Town.from_point(
    ...     point=[50.7753, 6.0839],  # Aachen Dom coordinates
    ...     dist=1000,  # 1km radius
    ...     town_params=town_params
    ... )
    >>>
    >>> # Load previously saved town
    >>> town = Town.from_files(
    ...     config_path="town_config.json",
    ...     town_graph_path="town_graph.graphmlz",
    ...     town_params=town_params
    ... )
    >>>
    >>> # Examine town properties
    >>> print(f"Town has {len(town.town_graph.nodes)} locations")
    >>> print(f"Place types found: {town.found_place_types}")
    >>> print(f"Accommodation nodes: {len(town.accommodation_node_ids)}")

    Notes
    -----
    - The town network uses shortest-path distances calculated from road network
      edges rather than Euclidean distances to provide realistic travel times
      between locations. These distances are computed by finding the shortest
      route along actual roads and pathways connecting places.

    - All shortest paths between every pair of places are pre-calculated during
      town creation, and the resulting simplified graph stores these distances
      as direct edge weights. This optimization dramatically reduces computational
      overhead during simulation steps, as agent movement only requires looking
      up neighboring edge weights rather than performing path searches.

    - This pre-computation approach is especially beneficial when running multiple
      simulations in the same location or simulations with many agents and timesteps,
      as the expensive shortest-path calculations are done once during town creation.

    - Building centroids are mapped to nearest road network nodes to ensure
      all locations are accessible via the street network.

    - Custom place classification functions can be provided to adapt the
      categorization system to specific research needs.

    - Town networks are automatically saved in compressed GraphMLZ format
      along with JSON config_data for efficient storage and reuse. These output
      files serve as input files for the Simulation class, enabling rapid
      simulation setup without re-downloading or re-processing OpenStreetMap data.

    Raises
    ------

    ValueError
        If the specified point coordinates are invalid, if no relevant
        locations remain after filtering, or if initial spreader nodes
        don't exist in the network.
    TypeError
        If the place classification function is not callable or if required
        parameters are missing when using custom classification.
    """

    def __init__(self):
        # Default constructor for flexibility
        pass

    def _validate_inputs(self, point, classify_place_func, all_place_types):
        # Validates the input arguments for town creation.
        # Checks classification function, place types, and geographic point
        # format.

        if not callable(classify_place_func):
            raise TypeError("`classify_place_func` must be a function.")

        if classify_place_func is not classify_place:
            if all_place_types is None:
                raise ValueError(
                    "If you pass a custom `classify_place_func`, you must also provide `all_place_types`."
                )
            elif "accommodation" not in all_place_types:
                raise ValueError(
                    "Your `all_place_types` must include 'accommodation' type buildings."
                )

        if not isinstance(point, (list, tuple)) or len(point) != 2:
            raise ValueError(
                "`point` must be a list or tuple in the format [latitude, longitude].")
        if not (-90 <= point[0] <= 90 and -180 <= point[1] <= 180):
            raise ValueError(
                "`point` values must represent valid latitude and longitude coordinates.")

    def _setup_basic_attributes(
            self,
            point,
            dist,
            town_params,
            classify_place_func,
            all_place_types):
        # Sets up core attributes for the Town object, including spatial parameters and classification settings.
        # Calculates the EPSG code for spatial projection based on the origin
        # point's latitude.

        print("[1/10] Initializing town object and parameters...")
        if all_place_types is None:
            all_place_types = [
                "accommodation", "healthcare_facility", "commercial",
                "workplace", "education", "religious", "other"
            ]

        self.all_place_types = all_place_types
        self.town_params = town_params
        self.classify_place_func = classify_place_func
        self.origin_point = point
        self.dist = dist

        print("[2/10] Calculating EPSG code...")
        utm_zone = int((point[1] + 180) / 6) + 1
        self.epsg_code = int(
            f"326{utm_zone}" if point[0] >= 0 else f"327{utm_zone}")

    def _download_osm_data(self):
        # Downloads OpenStreetMap road network and building data for the specified origin point and radius.
        # Projects the road graph and building geometries to the town's EPSG
        # coordinate system.

        print("[3/10] Downloading OSM road network and building data...")
        G_raw = ox.graph.graph_from_point(
            self.origin_point, network_type="all", dist=self.dist)
        tags = {"building": True}
        self.G_projected = ox.project_graph(G_raw)
        buildings = ox.features.features_from_point(
            self.origin_point, tags, self.dist)
        self.buildings = buildings.to_crs(epsg=self.epsg_code)

    def _process_buildings(self):
        # Processes building geometries to extract centroids and create POIs.
        # Matches each building to the nearest road node and classifies place
        # types.

        print("[4/10] Processing building geometries...")
        is_polygon = self.buildings.geometry.geom_type.isin(
            ['Polygon', 'MultiPolygon'])
        self.buildings.loc[is_polygon,
                           'geometry'] = self.buildings.loc[is_polygon,
                                                            'geometry'].centroid
        self.POI = self.buildings[self.buildings.geometry.geom_type == 'Point']

        print("[5/10] Matching building centroids to nearest road nodes...")
        self._match_buildings_to_roads()

        print("[6/10] Classifying buildings...")
        # Use the classification function passed to from_point
        self.POI['place_type'] = self.POI.apply(
            self.classify_place_func, axis=1)

        print("[7/10] Annotating road graph with place types...")
        place_type_map = self.POI.set_index(
            'nearest_node')['place_type'].to_dict()
        nx.set_node_attributes(self.G_projected, place_type_map, 'place_type')

    def _match_buildings_to_roads(self):
        # Matches each building centroid (POI) to the nearest road network node using KDTree.
        # Updates the POI DataFrame with the nearest node ID for each building.

        # Get projected coordinates of road nodes
        node_xy = {
            node: (data['x'], data['y'])
            for node, data in self.G_projected.nodes(data=True)
        }
        node_ids = list(node_xy.keys())
        node_coords = np.array([node_xy[n] for n in node_ids])

        # Build KDTree for fast nearest-neighbor queries
        tree = KDTree(node_coords)

        # Get POI coords and find nearest road nodes
        poi_coords = np.array([(geom.x, geom.y) for geom in self.POI.geometry])
        _, nearest_indices = tree.query(poi_coords)
        self.POI['nearest_node'] = [node_ids[i] for i in nearest_indices]

    def _build_spatial_network(self):
        # Filters out nodes not assigned to relevant place types and builds the spatial network.
        # Computes shortest-path distances and constructs the final town graph
        # for simulation.

        print("[8/10] Filtering out irrelevant nodes...")
        nodes_to_keep = [n for n, d in self.G_projected.nodes(data=True) if d.get(
            'place_type') is not None and d.get('place_type') != 'other']
        G_filtered = self.G_projected.subgraph(nodes_to_keep).copy()

        if len(G_filtered.nodes) == 0:
            raise ValueError(
                "No relevant nodes remain after filtering. The resulting town graph would be empty.")

        print("[9/10] Building town graph...")
        self._compute_shortest_paths(G_filtered)

    def _compute_shortest_paths(self, G_filtered):
        # Compute the shortest paths between every single pair of locations in the
        # area of interest.

        # Convert G_projected to igraph for fast distance computation
        projected_nodes = list(self.G_projected.nodes)
        node_idx_map = {node: idx for idx, node in enumerate(projected_nodes)}

        g_ig = ig.Graph(directed=False)
        g_ig.add_vertices(len(projected_nodes))

        edges = []
        weights = []

        for u, v, data in self.G_projected.edges(data=True):
            if u in node_idx_map and v in node_idx_map:
                edges.append((node_idx_map[u], node_idx_map[v]))
                weights.append(data.get("length", 1.0))

        g_ig.add_edges(edges)
        g_ig.es["weight"] = weights

        # Compute shortest paths among filtered nodes
        filtered_nodes = list(G_filtered.nodes)
        filtered_indices = [node_idx_map[n] for n in filtered_nodes]

        print("Computing shortest paths between filtered nodes...")
        dist_matrix = g_ig.distances(
            source=filtered_indices,
            target=filtered_indices,
            weights=g_ig.es["weight"])

        # Build final NetworkX town graph
        self._build_final_graph(G_filtered, filtered_nodes, dist_matrix)

    def _build_final_graph(self, G_filtered, filtered_nodes, dist_matrix):
        # Build a simplified town graph from the precalculated distances in the
        # previous step.

        self.town_graph = nx.Graph()
        id_map = {old_id: new_id for new_id,
                  old_id in enumerate(filtered_nodes)}
        self.accommodation_node_ids = []

        # Add nodes with attributes
        for old_id, new_id in id_map.items():
            place_type = G_filtered.nodes[old_id].get("place_type")
            row = self.POI[self.POI['nearest_node'] == old_id]
            x, y = (
                row.iloc[0].geometry.x, row.iloc[0].geometry.y) if not row.empty else (
                None, None)

            if place_type == "accommodation":
                self.accommodation_node_ids.append(new_id)

            self.town_graph.add_node(new_id, place_type=place_type, x=x, y=y)

        # Add edges with shortest path distances
        print("Adding edges to final town graph...")
        for i in tqdm(range(len(filtered_nodes))):
            for j in range(i + 1, len(filtered_nodes)):
                dist = dist_matrix[i][j]
                if dist != float("inf"):
                    self.town_graph.add_edge(
                        id_map[filtered_nodes[i]],
                        id_map[filtered_nodes[j]],
                        weight=dist
                    )

        self.found_place_types = set(nx.get_node_attributes(
            self.town_graph, 'place_type').values())

    def _save_files(self, file_prefix, save_dir):
        # Save the simplified town graph as a zipped file and save the configuration of
        # this area of interest as a .json file

        print("[10/10] Saving a compressed graph and config_data...")
        graphml_name = os.path.join(save_dir, f"{file_prefix}.graphml")
        graphmlz_name = os.path.join(save_dir, f"{file_prefix}.graphmlz")
        config_data_name = os.path.join(save_dir, f"{file_prefix}_config.json")

        nx.write_graphml_lxml(self.town_graph, graphml_name)

        if os.path.exists(graphmlz_name):
            overwrite = input(
                f"The file '{graphmlz_name}' already exists. Overwrite? (y/n): ").strip().lower()
            if overwrite != 'y':
                os.remove(graphml_name)
                print(
                    "Input file saving operation aborted to avoid overwriting the file. Returning town object...")
                return

        with zipfile.ZipFile(graphmlz_name, "w", zipfile.ZIP_DEFLATED) as zf:
            zf.write(graphml_name, arcname="graph.graphml")

        time.sleep(0.1)
        os.remove(graphml_name)

        # Save config_data
        config_data = {
            "origin_point": [float(self.origin_point[0]), float(self.origin_point[1])],
            "dist": self.dist,
            "epsg_code": int(self.epsg_code),
            "all_place_types": self.all_place_types,
            "found_place_types": list(self.found_place_types),
            "accommodation_nodes": list(self.accommodation_node_ids),
        }
        with open(config_data_name, "w") as f:
            json.dump(config_data, f, indent=2)

    def _finalize_town_setup(self):
        # Finalize the town object such that it is ready to be used in the
        # simulation.

        # Initialize folks list for all nodes
        for node in self.town_graph.nodes:
            self.town_graph.nodes[node]["folks"] = []

        # Validate spreader nodes
        missing_nodes = [
            node for node in self.town_params.spreader_initial_nodes
            if node not in self.town_graph.nodes
        ]
        if missing_nodes:
            raise ValueError(
                f"Some spreader_initial_nodes do not exist in the town graph: {missing_nodes}")

    @classmethod
    def from_point(
        cls,
        point,
        dist,
        town_params,
        classify_place_func=classify_place,
        all_place_types=None,
        file_prefix="town_graph",
        save_dir="."
    ):
        """
        Create a town network from OpenStreetMap data centered on a geographic point.

        Downloads road network and building data from OpenStreetMap, processes building
        geometries, classifies places by type, and constructs a simplified graph with
        pre-computed shortest-path distances between all locations.

        Parameters
        ----------
        point : list or tuple
            Geographic coordinates [latitude, longitude] defining the center point
            for data extraction.
        dist : float
            Radius in meters around the point to extract data. Defines the spatial
            extent of the town network.
        town_params : TownParameters
            Configuration object containing population size, initial spreader count,
            and spreader node locations.
        classify_place_func : callable, optional
            Function to classify building types into place categories. Must accept
            a pandas row and return a place type string (default: classify_place).
        all_place_types : list, optional
            List of all possible place type categories. Required when using custom
            classify_place_func (default: None).
        file_prefix : str, optional
            Prefix for output files (default: "town_graph").
        save_dir : str, optional
            Directory to save compressed graph and configuration files (default: ".").

        Returns
        -------
        Town
            Town object with populated spatial network and config_data.

        Raises
        ------
        ValueError
            If point coordinates are invalid, no relevant nodes remain after
            filtering, or spreader nodes don't exist in the network.
        TypeError
            If classify_place_func is not callable or required parameters are missing.

        Examples
        --------
        >>> town_params = TownParameters(num_pop=1000, num_init_spreader=5)
        >>> town = Town.from_point(
        ...     point=[50.7753, 6.0839],  # Aachen Dom
        ...     dist=1000,
        ...     town_params=town_params,
        ...     file_prefix="aachen_dom",
        ...     save_dir="./data"
        ... )
        """
        town = cls()
        town._validate_inputs(point, classify_place_func, all_place_types)
        town._setup_basic_attributes(
            point, dist, town_params, classify_place_func, all_place_types)
        town._download_osm_data()
        town._process_buildings()
        town._build_spatial_network()
        town._save_files(file_prefix, save_dir)
        town._finalize_town_setup()

        print("Town graph successfully built and saved!")
        return town

    @classmethod
    def from_files(cls, config_path, town_graph_path, town_params):
        """
        Load a previously saved town network from compressed files.

        Reconstructs a Town object from GraphMLZ and JSON configuration files created
        by a previous call to from_point(). This method enables rapid simulation
        setup without re-downloading or re-processing OpenStreetMap data.

        Parameters
        ----------
        config_path : str
            Path to the JSON configuration file containing town configuration and
            place type information.
        town_graph_path : str
            Path to the compressed GraphMLZ file containing the spatial network.
        town_params : TownParameters
            Configuration object containing population size, initial spreader count,
            and spreader node locations for the simulation.

        Returns
        -------
        Town
            Town object with loaded spatial network and config_data.

        Raises
        ------
        ValueError
            If spreader nodes specified in town_params don't exist in the
            loaded network.
        FileNotFoundError
            If the specified files don't exist.

        Examples
        --------
        >>> town_params = TownParameters(num_pop=1000, num_init_spreader=5)
        >>> town = Town.from_files(
        ...     config_path="./data/aachen_dom_config.json",
        ...     town_graph_path="./data/aachen_dom.graphmlz",
        ...     town_params=town_params
        ... )
    """
        # 1. Unzip the graphmlz to a temp folder
        print("[1/3] Decompressing the graphmlz file...")
        with tempfile.TemporaryDirectory() as tmpdirname:
            with zipfile.ZipFile(town_graph_path, 'r') as zf:
                zf.extractall(tmpdirname)
                graphml_path = os.path.join(tmpdirname, "graph.graphml")
                G = nx.read_graphml(graphml_path)
                G = nx.relabel_nodes(G, lambda x: int(x))

        # 2. Load config_data
        print("[2/3] Load the config_data...")
        with open(config_path, "r") as f:
            config_data = json.load(f)

        # 3. Rebuild Town object
        print("[3/3] Rebuild the town object...")
        town = cls()
        town.town_graph = G
        town.town_params = town_params
        town.epsg_code = config_data["epsg_code"]
        town.origin_point = config_data["origin_point"]
        town.dist = config_data["dist"]
        town.all_place_types = config_data["all_place_types"]
        town.found_place_types = config_data["found_place_types"]
        town.accommodation_node_ids = config_data["accommodation_nodes"]

        town._finalize_town_setup()

        print("Town graph successfully built from input files!")
        return town

    def save_to_files(self, file_prefix, overwrite=False):
        """
        Save this Town object to GraphML and config files.

        Parameters
        ----------
        file_prefix : str
            Prefix for the output files (will create {prefix}.graphmlz and {prefix}_config.json)
        overwrite : bool, default False
            Whether to overwrite existing files

        Returns
        -------
        tuple[str, str]
            (graphml_path, config_path) - paths to the created files
        """

        # Generate file paths
        graphml_path = f"{file_prefix}.graphmlz"
        config_path = f"{file_prefix}_config.json"

        # Check if files exist and handle overwrite
        if not overwrite:
            if os.path.exists(graphml_path):
                raise FileExistsError(
                    f"GraphMLZ file already exists: {graphml_path}. Set overwrite=True to replace.")
            if os.path.exists(config_path):
                raise FileExistsError(
                    f"Config file already exists: {config_path}. Set overwrite=True to replace.")

        # Save as .graphml first
        temp_graphml_path = f"{file_prefix}.graphml"
        town_graph_copy = copy.deepcopy(self.town_graph)

        # Since we can't save list in GraphML form, we need to remove
        for node in town_graph_copy.nodes:
            del town_graph_copy.nodes[node]["folks"]

        nx.write_graphml(town_graph_copy, temp_graphml_path)

        # Now compress it into a .zip file with .graphmlz extension
        with zipfile.ZipFile(graphml_path, "w", zipfile.ZIP_DEFLATED) as zf:
            zf.write(temp_graphml_path, arcname="graph.graphml")

        # Remove the uncompressed .graphml file
        os.remove(temp_graphml_path)

        config_data = {
            "origin_point": [float(self.origin_point[0]), float(self.origin_point[1])],
            "dist": self.dist,
            "epsg_code": int(self.epsg_code),
            "all_place_types": self.all_place_types,
            "found_place_types": list(self.found_place_types),
            "accommodation_nodes": list(self.accommodation_node_ids),
        }

        # Save config file
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(config_data, f, indent=2, ensure_ascii=False)

        return graphml_path, config_path


__author__ = 'Warisa Roongaraya'
__email__ = 'compund555@gmail.com'
__version__ = '0.1.0'

from .sim import Simulation
from .town import Town, TownParameters
from . import infection_models
from . import visualization


import random as rd

from .step_event import EventType, StepEvent


class AbstractModelParameters:
    """
    Base class for infection model parameters.

    This abstract class defines the common interface for all infection model
    parameter classes. It provides basic energy management and requires subclasses
    to implement configuration serialization for simulation persistence.

    Parameters
    ----------
    max_energy : int
        The maximum energy for an agent. This number limits the maximum number
        of events an agent can attend in a day.

    Attributes
    ----------
    max_energy : int
        Maximum social energy value for agents in the simulation.
    """

    def __init__(self, max_energy):
        """
        Initialize model parameters.

        Parameters
        ----------
        max_energy : int
            The maximum energy for an agent. This number limits the maximum number of events an agent can attend in a day.

        Raises
        ------
        AssertionError
            If max_energy is not a positive integer.
        """
        assert isinstance(
            max_energy, int) and max_energy > 0, "max_energy must be a positive integer!"
        self.max_energy = max_energy

    def to_config_dict(self):
        """
        Convert model parameters to a dictionary for configuration serialization.

        This abstract method must be implemented by subclasses to enable saving
        and loading of simulation configurations. The returned dictionary should
        contain all parameter values needed to reconstruct the model.

        Raises
        ------
        NotImplementedError
            Always raised in the base class. Subclasses must override this method.
        """
        raise NotImplementedError(
            "Subclasses must implement to_config_dict()")


class AbstractFolk:
    """
    Agent class representing individuals in the simulation.

    AbstractFolk objects represent individual agents that move through the town network,
    interact with other agents, and undergo status transitions according to
    infection model rules. Each agent has energy, status, location, and
    behavioral attributes that influence their participation in simulation events.

    Parameters
    ----------
    id : int
        Unique identifier for the agent.
    home_address : int
        Node index of the agent's home location in the town network.
    max_energy : int
        Maximum social energy. Limits the number of events an agent can attend daily.
    status : str
        Initial infection status of the agent (e.g., 'S', 'I', 'R').

    Attributes
    ----------
    id : int
        Unique agent identifier.
    home_address : int
        Home node index in the town network.
    address : int
        Current location node index (initially set to home_address).
    max_energy : int
        Maximum daily social energy.
    energy : int
        Current social energy (randomly initialized between 0 and max_energy).
    status : str
        Current infection status.
    status_step_streak : int
        Number of consecutive timesteps in current status.
    movement_restricted : bool
        Whether agent movement is restricted (e.g., quarantine).
    alive : bool
        Whether the agent is alive and active in the simulation.
    priority_place_type : list
        List of place types the agent prioritizes for visits.
    """

    def __init__(self, id, home_address, max_energy, status):
        """Initialize a AbstractFolk agent."""
        self.id = id
        self.home_address = home_address
        self.address = self.home_address
        self.max_energy = max_energy
        self.energy = rd.randint(0, max_energy)
        self.status = status
        self.status_step_streak = 0
        self.movement_restricted = False
        self.alive = True
        self.priority_place_type = []

    def convert(self, new_stat, status_dict_t):
        """
        Change the agent's status and update the status counts.

        Parameters
        ----------
        new_stat : str
            The new status to assign.
        status_dict_t : dict
            Dictionary tracking the count of each status at the current timestep.
        """
        assert self.status != new_stat, f"New status cannot be the same as the old status({new_stat})! Please review your transition rules!"
        assert status_dict_t[
            self.status] > 0, f"Attempting to decrement {
            self.status} below zero!"
        status_dict_t[self.status] -= 1
        status_dict_t[new_stat] += 1
        self.status = new_stat
        self.status_step_streak = 0

    def inverse_bernoulli(self, contact_possibility, conversion_prob):
        """
        Calculate the probability of status transition given contact possibility and conversion probability.

        This function is adapted from section 2.2 of
        Eden, M., Castonguay, R., Munkhbat, B., Balasubramanian, H., & Gopalappa, C. (2021).
        Agent-based evolving network modeling: A new simulation method for modeling low prevalence infectious diseases.
        Health Care Management Science, 24, 623–639. https://link.springer.com/article/10.1007/s10729-021-09558-0

        Parameters
        ----------
        contact_possibility : int
            Number of possible contacts.
        conversion_prob : float
            Probability of conversion per contact.

        Returns
        -------
        float
            Probability of at least one successful conversion.
        """
        return 1 - (1 - conversion_prob)**(contact_possibility)

    def sleep(self):
        """Reset the agent's energy and increment the status streak (called at the end of a day)."""
        self.status_step_streak += 1
        self.energy = rd.randint(0, self.max_energy)  # Reset social energy

    def clear_previous_event_effect(self):
        """
        Reset or update agent attributes following step events.

        This method is called at the end of each simulation step to clean up
        temporary state changes caused by events and ensure the agent is
        properly prepared for the next step. Subclasses should implement this
        method to handle model-specific attribute resets.

        Returns
        -------
        None
        """
        pass


class AbstractInfectionModel:
    """
    Abstract base class for all infection epidemic models.

    This class provides the foundation for implementing infection models
    (e.g., SIR, SEIR, SEIQRDV) in agent-based simulations. It handles agent
    creation, step event management, population initialization, and defines
    the interface that all infection models must implement.

    Parameters
    ----------

    model_params : AbstractModelParameters
        Model-specific parameters object containing simulation configuration.

    Attributes
    ----------
    model_params : AbstractModelParameters
        Configuration parameters for the model.
    step_events : list of StepEvent
        Sequence of events that occur during each simulation timestep.
    infected_statuses : list
        List of status strings considered infectious (must be defined by subclasses).
    all_statuses : list
        Complete list of all possible agent statuses (must be defined by subclasses).
    required_place_types : set
        Set of place types required by the model (includes 'accommodation', 'commercial').
    folk_class : class
        The Folk class or subclass used to create agents (must be defined by subclasses).

    Notes
    -----
    - Subclasses must define 'infected_statuses', 'all_statuses', and 'folk_class'
      before calling the parent constructor.
    - An 'end_day' event is automatically appended to step_events if not provided.
    - Default step_events include neighborhood greeting and commercial activities.
    """

    def __init__(self, model_params):
        """
        Initialize the abstract infection model.

        Parameters
        ----------

        model_params : AbstractModelParameters
            Model parameters object.
            Raises

        Raise
        ------

        NotImplementedError
            If subclass doesn't define required attributes (infected_statuses, all_statuses).
        TypeError
            If step_events contains invalid objects or folk_action methods are not callable.
        ValueError
            If the model has fewer than 3 different statuses.
        """
        self.model_params = model_params

        # If step_events is not set, use default events
        if not hasattr(self, "step_events") or self.step_events is None:
            self.step_events = [
                StepEvent(
                    "greet_neighbors",
                    self.folk_class.interact,
                    EventType.DISPERSE,
                    5000,
                    ['accommodation']),
                StepEvent(
                    "chore",
                    self.folk_class.interact,
                    EventType.DISPERSE,
                    19000,
                    [
                        'commercial'])
            ]
        else:
            # Check that step_events is a StepEvent or list of StepEvent
            # objects
            if isinstance(self.step_events, StepEvent):
                self.step_events = [self.step_events]
            elif isinstance(self.step_events, list):
                if not all(isinstance(ev, StepEvent)
                           for ev in self.step_events):
                    raise TypeError(
                        "step_events must be a StepEvent or a list of StepEvent objects")
            else:
                raise TypeError(
                    "step_events must be a StepEvent or a list of StepEvent objects")

            for event in self.step_events:
                if not callable(event.folk_action):
                    raise TypeError(
                        f"folk_action in StepEvent '{
                            event.name}' must be callable")
                # Print folk_class and the class of event.folk_action for debugging
                # Check if the function is a method of self.folk_class
                if not any(
                    event.folk_action is func for name, func in vars(
                        self.folk_class).items() if callable(func)):
                    raise TypeError(
                        f"folk_action in StepEvent '{
                            event.name}' must be a method of the folk_class '{
                            self.folk_class.__name__}'")

        # This is an important check and it will ONLY work when you define
        # some of the attributes before calling the abstract level constructor
        # See SEIsIrR for an example of how to write a constructor.
        required_attrs = {
            'infected_statuses': "Subclasses must define 'infected_statuses'.",
            'all_statuses': "Subclasses must define 'all_statuses' with at least 3 statuses."}

        for attr, message in required_attrs.items():
            if not hasattr(self, attr):
                raise NotImplementedError(message)

        if not hasattr(self, 'required_place_types'):
            self.required_place_types = set()
        self.required_place_types.update(['accommodation', 'commercial'])

        # Status is also actually plural of a status but for clarity that this is plural,
        # the software will stick with the commonly used statuses
        if len(self.all_statuses) < 3:
            raise ValueError(
                "A infection model must consist of at least 3 different statuses.")

        # Append end_day event to the existing day events given by the user
        end_day = StepEvent("end_day", self.folk_class.sleep)
        if not any(
            isinstance(ev, StepEvent) and getattr(
                ev, "name", None) == "end_day"
            for ev in self.step_events
        ):
            self.step_events.append(end_day)

    def create_folk(self, *args, **kwargs):
        """
        Create a new AbstractFolk agent using the model's folk_class.

        Returns
        -------

        AbstractFolk
            A new AbstractFolk agent instance of a given folk_class.
        """
        return self.folk_class(*args, **kwargs)

    def initialize_sim_population(self, town):
        """
        Initialize simulation population data structures and validate spreader configuration.

        This method sets up the basic data structures needed for population initialization
        and validates that the spreader configuration is valid. It prepares containers
        for agent creation and household assignment.

        Parameters
        ----------

        town : Town
            The town network where agents will be placed.

        Returns
        -------
        tuple
            Contains (num_pop, num_init_spreader, num_init_spreader_rd, folks,
            household_node_indices, assignments) where:
            - num_pop : int - Total population size
            - num_init_spreader : int - Total number of initial spreaders
            - num_init_spreader_rd : int - Number of randomly placed spreaders
            - folks : list - Empty list for agent objects
            - household_node_indices : set - Empty set for household node tracking
            - assignments : list - Empty list for agent assignments

        Notes
        -----
        This method only initializes data structures and validates configuration.
        Actual agent creation and placement is handled by the Simulation class.
        """
        num_init_spreader_nodes = len(town.town_params.spreader_initial_nodes)

        num_init_spreader = town.town_params.num_init_spreader
        num_pop = town.town_params.num_pop
        num_init_spreader_rd = num_init_spreader - num_init_spreader_nodes

        folks = []
        household_node_indices = set()
        assignments = []

        return num_pop, num_init_spreader, num_init_spreader_rd, folks, household_node_indices, assignments

    def update_population(
            self,
            folks,
            town,
            household_node_indices,
            status_dict_t):
        """
        Update the simulation population (e.g., add or remove agents).

        This method is called at the end of each day. By default, it does nothing.
        Subclasses can override this method to implement population growth, death, or migration.

        Parameters
        ----------
        folks : list of AbstractFolk
            The current list of AbstractFolk agent objects in the simulation.
        town : Town
            The Town object representing the simulation environment.
        status_dict_t : dict
            Dictionary tracking the count of each status at the current timestep.

        Returns
        -------
        int
            An updated number of overall population
        """
        return len(folks)


"""
This module implements the SEIQRDV infection model for epidemic simulation with vaccination.

The implementation is based on:

Ghostine, R., Gharamti, M., Hassrouny, S., & Hoteit, I. (2021).
An extended SEIR model with vaccination for forecasting the COVID-19 pandemic
in Saudi Arabia using an ensemble Kalman filter.
*Mathematics*, 9(6), 636. https://doi.org/10.3390/math9060636

The SEIQRDV model extends the classic SEIR framework by incorporating three additional
compartments: Quarantine (Q) for isolated infectious individuals, Death (D) for
disease-related mortality, and Vaccination (V) for immunized individuals. This model
is particularly suited for simulating epidemics where quarantine measures, vaccination
campaigns, and mortality dynamics are critical factors in disease progression and
public health intervention strategies.
"""

import random as rd

from .abstract_model import (AbstractInfectionModel, AbstractFolk,
                             AbstractModelParameters)


class SEIQRDVModelParameters(AbstractModelParameters):
    """
    Model parameters for the SEIQRDV infection model.

    This class encapsulates all tunable parameters required for the SEIQRDV
    infection model, including epidemiological rates, probabilities, and
    healthcare system constraints. It validates parameter types and ranges
    upon initialization.
    """

    def __init__(
            self,
            max_energy,
            lam_cap,
            beta,
            alpha,
            gamma,
            delta,
            lam,
            rho,
            kappa,
            mu,
            hospital_capacity=float('inf')):
        """
        Initialize SEIQRDV model parameters and validate all inputs.

        Parameters
        ----------
        max_energy : int
            Maximum energy for each agent (must be a positive integer).
        lam_cap : float
            Rate of new population due to birth or migration (must be between 0 and 1).
        beta : float
            Transmission probability (must be between 0 and 1).
        alpha : float
            Vaccination rate (must be between 0 and 1).
        gamma : int
            Average latent time (must be a positive integer).
        delta : int
            Average days until the infected case is confirmed and quarantined (must be a positive integer).
        lam : int
            Average days until recovery for quarantined agents (must be a positive integer).
        rho : int
            Average days until death for quarantined agents (must be a positive integer).
        kappa : float
            Disease mortality rate (must be between 0 and 1).
        mu : float
            Natural background death rate (must be between 0 and 1).
        hospital_capacity : int or float, optional
            Average number of people a healthcare facility can vaccinate per event.
            Must be a positive integer or float('inf') for unlimited capacity (default: float('inf')).

        Raises
        ------
        TypeError
            If any parameter is not of the correct type or out of valid range.
        """
        for name, value in zip(
            ['lam_cap', 'beta', 'alpha', 'gamma', 'delta', 'lam',
                'rho', 'kappa', 'mu', 'hospital_capacity'],
            [lam_cap, beta, alpha, gamma, delta, lam,
                rho, kappa, mu, hospital_capacity]
        ):
            if name in ['lam_cap', 'beta', 'kappa', 'alpha', 'mu']:
                if not isinstance(
                        value, (float, int)) or not (
                        0 <= value <= 1):
                    raise TypeError(
                        f"{name} must be a float between 0 and 1!")
            elif name == 'hospital_capacity':
                if not isinstance(value, int) and value != float('inf'):
                    raise TypeError(
                        f"{name} must be a positive integer or a value of infinity, got {value}")
            else:
                if not isinstance(value, int) or value <= 0:
                    raise TypeError(
                        f"{name} must be a positive integer, got {value}")

        super().__init__(max_energy)

        # Rate of new population due to birth or migration etc.
        self.lam_cap = lam_cap
        self.beta = beta  # Transmission probability
        self.alpha = alpha  # Vaccination rate
        self.gamma = gamma  # Average latent time
        self.delta = delta  # Average day until the infected case got confirmed and quarantined
        self.lam = lam  # Average day until recovery
        self.rho = rho  # Average day until death
        self.kappa = kappa  # Disease mortality rate
        self.mu = mu  # Natural back ground death rate
        # Average number of people a healthcare facility can contain
        self.hospital_capacity = hospital_capacity

    def to_config_dict(self):
        """
        Convert SEIQRDV model parameters to a dictionary for configuration serialization.

        Returns
        -------
        dict
            Dictionary containing all model parameters as key-value pairs.
        """
        return {
            'max_energy': self.max_energy,
            'lam_cap': self.lam_cap,
            'beta': self.beta,
            'alpha': self.alpha,
            'gamma': self.gamma,
            'delta': self.delta,
            'lam': self.lam,
            'rho': self.rho,
            'kappa': self.kappa,
            'mu': self.mu,
            'hospital_capacity': self.hospital_capacity
        }


class FolkSEIQRDV(AbstractFolk):
    """
    Agent class for the SEIQRDV infection model with vaccination and mortality dynamics.
    FolkSEIQRDV agents extend the basic AbstractFolk with two critical attributes for epidemic modeling:
    `will_die` and `want_vaccine`. The `will_die` attribute is probabilistically set when an agent enters
    quarantine and determines their eventual outcome (recovery or death),
    reflecting the stochastic nature of disease severity. The `want_vaccine` attribute models vaccination-seeking behavior,
    where susceptible agents can spontaneously decide to seek vaccination based on the model's `alpha` parameter,
    creating realistic vaccine demand patterns. These agents exhibit complex behavioral dynamics including healthcare-seeking
    movement (prioritizing healthcare facilities when `want_vaccine` is True), quarantine compliance
    (restricted movement when infectious), and status-dependent interaction patterns.
    The vaccination system implements a queue-based mechanism at healthcare facilities with capacity constraints,
    ensuring fair vaccine distribution while maintaining epidemiological realism.
    Additionally, agents undergo natural aging and mortality processes independent of disease status, allowing for
    comprehensive population dynamics that include births, deaths, migration, and demographic changes throughout the simulation period.
    """

    def __init__(self, id, home_address, max_energy, status):
        """
        Initialize a FolkSEIQRDV agent with 2 more attributes than the standard AbstractFolk.
        The first one being will_die which plays a role in determining if the infected agent
        will pass away or not. The second one, want_vaccine, signifies the agent's will to
        get vaccinated. An agent with this attribute == True will try to get vaccinated at
        their nearest healthcare facility.

        Parameters
        ----------
        id : int
            Unique identifier for the agent.
        home_address : int
            Node index of the agent's home.
        max_energy : int
            Maximum social energy.
        status : str
            Initial status of the agent.
        """
        super().__init__(id, home_address, max_energy, status)
        self.will_die = False
        self.want_vaccine = False

    def inverse_bernoulli(self, folks_here, conversion_prob, stats):
        """
        Calculate the probability of status transition given contact with specific statuses.

        Parameters
        ----------
        folks_here : list of FolkSEIQRDV
            List of FolkSEIQRDV agents present at the same node.
        conversion_prob : float
            Probability of conversion per contact.
        stats : list of str
            List of statuses to consider as infectious.

        Returns
        -------
        float
            Probability of at least one successful conversion.
        """
        # beta * I / N is the non-linear term that defines conversion
        # This inverse bernoulli function is an interpretation of the term
        # in agent-based modeling
        num_contact = len(
            [folk for folk in folks_here if folk != self and folk.status in stats])
        return super().inverse_bernoulli(num_contact, conversion_prob / len(folks_here))

    def interact(
            self,
            folks_here,
            current_place_type,
            status_dict_t,
            model_params,
            dice):
        """
        Perform interaction with other agents in the area and the environment for this agent.

        Transition Rules
        ----------------

        - If the agent is Susceptible ('S'):

            - If the agent comes into contact with at least one Infectious ('I') agent at the same node,
            the probability of becoming Exposed ('E') is calculated using the inverse Bernoulli formula with
            the transmission probability (`beta`). If this probability exceeds the random value `dice`,
            the agent transitions to Exposed ('E').

        - If the agent is Susceptible ('S'), wants a vaccine, and is at a healthcare facility:

            - If the number of agents at the facility wanting a vaccine is less than the hospital capacity,
            the agent transitions to Vaccinated ('V') and `want_vaccine` is set to False.

        Parameters
        ----------
        folks_here : list of FolkSEIQRDV
            List of FolkSEIQRDV agents present at the same node.
        current_place_type : str
            The type of place where the interaction occurs.
        status_dict_t : dict
            Dictionary tracking the count of each status at the current timestep.
        model_params : SEIQRDVModelParameters
            Model parameters for the simulation.
        dice : float
            Random float for stochastic transitions.

        Returns
        -------
        None
        """
        # When a susceptible person comes into contact with an infectious person,
        # they have a likelihood to become exposed to the disease
        if self.status == 'S' and self.inverse_bernoulli(
                folks_here, model_params.beta, ['I']) > dice:
            self.convert('E', status_dict_t)

        if current_place_type == 'healthcare_facility':
            # Vaccine is only effective for susceptible people but anyone who
            # wants it can queue up
            want_vaccine_list = [
                folk for folk in folks_here if folk.want_vaccine]
            if self in want_vaccine_list and self.status == 'S':
                idx = want_vaccine_list.index(self)
                if idx < model_params.hospital_capacity:
                    self.convert('V', status_dict_t)

    def sleep(
            self,
            folks_here,
            current_place_type,
            status_dict_t,
            model_params,
            dice):
        """
        Perform end-of-day updates and state transitions for this agent.

        This method handles all status progressions and transitions that occur at the end of a simulation day,
        including quarantine outcomes, recovery, death, infection progression, and vaccination planning.

        Transition Rules
        ----------------
        - **If the agent is in Quarantine ('Q'):**

            - If `will_die` is True and the agent has been in quarantine for `rho` days,
            the agent transitions to Dead ('D'), is marked as not alive, and `want_vaccine` is set to False.

            - If `will_die` is False and the agent has been in quarantine for `lam` days,
            the agent transitions to Recovered ('R') and their movement restriction is lifted.

        - **If the agent is Exposed ('E')** and has been exposed for `gamma` days,
        they transition to Infectious ('I').

        - **If the agent is Infectious ('I')** and has been infectious for `delta` days,
        their symptoms are confirmed and they must quarantine. They transition to Quarantine ('Q'),
        their movement is restricted, `want_vaccine` is set to False, and with probability `kappa`
        they are marked to die (`will_die = True`).

        - **If the agent is Susceptible ('S')** and a random draw is less than `alpha`,
        they plan to get vaccinated by setting `want_vaccine` to True.

        - **If the agent is Vaccinated ('V')**, their `want_vaccine` attribute is reset to False
        at the end of the day to ensure correct vaccine queue handling during the next day's events.

        - **For any agent with `want_vaccine = True`**, 'healthcare_facility' is added to their
        priority place types to guide movement toward vaccination sites.

        Parameters
        ----------

        folks_here : list of FolkSEIQRDV
            List of agents present at the same node (not used in this method, for interface compatibility).
        current_place_type : str
            Type of place where the agent is sleeping (not used in this method, for interface compatibility).
        status_dict_t : dict
            Dictionary tracking the count of each status at the current timestep.
        model_params : SEIQRDVModelParameters
            Model parameters for the simulation.
        dice : float
            Random float for stochastic transitions.

        Returns
        -------

        None

        Notes
        -----

        The `want_vaccine` attribute is reset to False in `sleep()` rather than immediately after
        vaccination in `interact()` to maintain queue integrity. If reset during `interact()`,
        it would modify the vaccination queue while agents are still being processed, potentially
        causing some agents to be skipped or processed incorrectly. Deferring the reset ensures
        fair and consistent vaccination queue processing.
        """
        super().sleep()
        if self.status == 'Q':
            if self.will_die:
                if self.status_step_streak == model_params.rho:
                    self.convert('D', status_dict_t)
                    self.want_vaccine = False
                    self.alive = False
            else:
                if self.status_step_streak == model_params.lam:
                    self.convert('R', status_dict_t)
                    self.movement_restricted = False
        elif self.status == 'E' and self.status_step_streak == model_params.gamma:
            self.convert('I', status_dict_t)
        elif self.status == 'I' and self.status_step_streak == model_params.delta:
            self.convert('Q', status_dict_t)
            self.movement_restricted = True
            self.want_vaccine = False
            if dice < model_params.kappa:
                self.will_die = True
        elif self.status == 'S':
            # We only apply the rate of planning to get vaccination on
            # susceptible agents
            if model_params.alpha > dice:
                self.want_vaccine = True

        if self.want_vaccine:
            self.priority_place_type.append('healthcare_facility')

    def clear_previous_event_effect(self):
        """
        Reset vaccination-related attributes following step events.

        This method updates vaccination-seeking behavior attributes after events
        to maintain consistent state. It performs two key functions:

        1. For vaccinated agents: Clears the 'want_vaccine' flag since
        they've already received vaccination

        2. For other agents seeking vaccination: Ensures 'healthcare_facility'
        remains in their priority places for the next day's movement

        The method specifically handles:
        - Resetting vaccination desire for already-vaccinated agents (status 'V')
        - Maintaining healthcare facility priority for agents still seeking vaccination
        (except those in recovered, dead, or quarantined states who cannot benefit)

        Returns
        -------
        None
        """
        if self.want_vaccine:
            if self.status == 'V':
                self.want_vaccine = False
            elif self.status not in ['R', 'D', 'Q']:
                self.priority_place_type.append('healthcare_facility')


class SEIQRDVModel(AbstractInfectionModel):
    """
    SEIQRDV infection model implementation for epidemic simulation with vaccination.

    The SEIQRDV model extends the classic SEIR model by adding three additional compartments:
    Quarantine (Q), Death (D), and Vaccination (V). This model is particularly suited for
    simulating disease outbreaks where quarantine measures, vaccination campaigns, and
    mortality are important factors.
    """

    def __init__(self, model_params, step_events=None):
        self.folk_class = FolkSEIQRDV
        self.all_statuses = (['S', 'E', 'I', 'Q', 'R', 'D', 'V'])
        self.infected_statuses = ['I', 'E', 'Q']
        self.required_place_types = set(
            ['healthcare_facility'])
        self.step_events = step_events
        super().__init__(model_params)

    def initialize_sim_population(self, town):
        """
        Initialize the simulation population and their assignments.

        This method assigns initial statuses and home locations to all agents in the simulation,
        including initial spreaders (both randomly assigned and those at specified nodes) and susceptible agents.
        It also creates agent objects, updates the town graph with agent assignments, and tracks household nodes.

        Parameters
        ----------
        town : Town
            The Town object representing the simulation environment.

        Returns
        -------
        tuple
            (folks, household_node_indices, status_dict_t0)

            - folks : list of FolkSEIQRDV
                List of all agent objects created for the simulation.

            - household_node_indices : set
                Set of node indices where households are tracked.

            - status_dict_t0 : dict
                Dictionary with the initial count of each status at timestep 0.
        """
        num_pop, num_init_spreader, num_init_spreader_rd, folks, household_node_indices, assignments = super(
        ).initialize_sim_population(town)

        # Randomly assign initial spreaders (not on specified nodes)
        for i in range(num_init_spreader_rd):
            node = rd.choice(town.accommodation_node_ids)
            assignments.append((node, 'I'))

        # Assign the rest as susceptible
        for i in range(num_pop - num_init_spreader):
            node = rd.choice(town.accommodation_node_ids)
            assignments.append((node, 'S'))

        # Assign initial spreaders to specified nodes
        for node in town.town_params.spreader_initial_nodes:
            assignments.append((node, 'I'))

        # Create folks and update graph/node information
        for i, (node, status) in enumerate(assignments):
            folk = self.create_folk(
                i, node, self.model_params.max_energy, status)
            if status == 'S' and rd.random() < self.model_params.alpha:
                folk.priority_place_type.append('healthcare_facility')
                folk.want_vaccine = True
            folks.append(folk)
            town.town_graph.nodes[node]["folks"].append(folk)
            if len(town.town_graph.nodes[node]["folks"]) == 2:
                household_node_indices.add(node)

        status_dict_t0 = {
            'current_event': None,
            'timestep': 0,
            'S': num_pop - num_init_spreader,
            'E': 0,
            'Q': 0,
            'I': num_init_spreader,
            'R': 0,
            'D': 0,
            'V': 0
        }
        return folks, household_node_indices, status_dict_t0

    def update_population(
            self,
            folks,
            town,
            household_node_indices,
            status_dict_t):
        """
        Update the simulation population at the end of each day.

        This function performs two main operations:
        1. **Natural Deaths:** Iterates through all currently alive agents and, with probability `mu` (the natural death rate), transitions them to the 'D' (Dead) status and marks them as not alive.

        2. **Population Growth:** Calculates the number of possible new agents to add based on the current alive population and the parameter `lam_cap` (birth/migration rate). For each new agent:
            - Randomly selects an accommodation node as their home.
            - Randomly assigns a status from all possible statuses except 'D' (Dead) and 'Q' (Quarantine).
            - Adds the new agent to the simulation, updates the status count, and tracks their household node.

        Parameters
        ----------
        folks : list of FolkSEIQRDV
            The current list of FolkSEIQRDV agent objects in the simulation.
        town : Town
            The Town object representing the simulation environment.
        household_node_indices : set
            Set of node indices where households are tracked.
        status_dict_t : dict
            Dictionary tracking the count of each status at the current timestep.

        Returns
        -------
        int
            The updated total number of agents in the simulation after deaths and births/migration.
        """

        num_current_pop = len(folks)
        folks_alive = [folk for folk in folks if folk.alive]
        num_current_folks = len(folks_alive)
        # Account for death by natural causes here
        for folk in folks_alive:
            if rd.random() < self.model_params.mu:
                folk.convert('D', status_dict_t)
                folk.alive = False
        num_possible_new_folks = num_current_folks * self.model_params.lam_cap
        if num_possible_new_folks > 1:
            num_possible_new_folks = round(num_possible_new_folks)
            for i in range(num_possible_new_folks):
                node = rd.choice(town.accommodation_node_ids)
                stat = rd.choice(
                    [s for s in self.all_statuses if s not in ('D', 'Q')])
                folk = self.create_folk(
                    num_current_pop + i, node, self.model_params.max_energy, stat)

                status_dict_t[stat] += 1
                folks.append(folk)
                # Account for which folks live where in the graph as well
                town.town_graph.nodes[node]["folks"].append(folk)

                # Track which node has a 'family' living in it
                if len(town.town_graph.nodes[node]["folks"]) == 2:
                    # Add operation and set() data structure ensures that there
                    # is no duplicate
                    household_node_indices.add(node)

        return len(folks)


"""
This module implements the SEIR infection model for epidemic simulation with loss of immunity.

The implementation follows one of the infection models described in:

Kerr, C. C., Stuart, R. M., Mistry, D., Abeysuriya, R. G., Rosenfeld, K., Hart, G. R.,
Núñez, R. C., Cohen, J. A., Selvaraj, P., Hagedorn, B., George, L., Jastrzębski, M.,
Izzo, A. S., Fowler, G., Palmer, A., Delport, D., Scott, N., Kelly, S. L., Bennette, C. S.,
Wagner, B., Chang, S. T., Oron, A. P., Wenger, E. A., Panovska-Griffiths, J.,
Famulare, M., & Klein, D. J. (2021). Covasim: An agent-based model of COVID-19 dynamics
and interventions. *PLOS Computational Biology*, 17(7), e1009149.
https://doi.org/10.1371/journal.pcbi.1009149

The SEIR model represents disease progression through four compartments: Susceptible (S)
individuals who can contract the disease, Exposed (E) individuals who are infected but not
yet infectious, Infectious (I) individuals who can transmit the disease, and Recovered (R)
individuals who have immunity. This implementation includes waning immunity, where recovered
individuals return to susceptible status after a specified duration, making it suitable for
modeling diseases with temporary immunity like influenza or seasonal coronaviruses.
"""

import random as rd

from .abstract_model import (AbstractInfectionModel, AbstractFolk,
                             AbstractModelParameters)


class SEIRModelParameters(AbstractModelParameters):
    """
    Model parameters for the SEIR infection model.

    This class encapsulates all tunable parameters required for the SEIR
    infection model, including transmission rates and duration parameters.
    It validates parameter types and ranges upon initialization.
    """

    def __init__(self, max_energy, beta, sigma, gamma, xi):
        """
        Initialize SEIR model parameters and validate all inputs.

        Parameters
        ----------
        max_energy : int
            Maximum energy for each agent (must be a positive integer).
        beta : float
            Transmission probability (must be between 0 and 1, exclusive).
        sigma : int
            Incubation duration in days (must be a positive integer).
        gamma : int
            Symptom duration in days (must be a positive integer).
        xi : int
            Immune duration in days (must be a positive integer).

        Raises
        ------
        TypeError
            If any parameter is not of the correct type or out of valid range.
        """
        for name, value in zip(
            ['beta', 'sigma', 'gamma', 'xi'],
            [beta, sigma, gamma, xi]
        ):
            if name == 'beta':
                if not isinstance(value, (float, int)) or not (0 < value < 1):
                    raise TypeError(
                        "beta must be a float between 0 and 1 (exclusive)!")
            else:
                if not isinstance(value, int) or value <= 0:
                    raise TypeError(
                        f"{name} must be a positive integer since it is a value that described duration, got {value}")

        super().__init__(max_energy)

        self.beta = beta  # Transmission probability
        self.sigma = sigma  # Incubation duration
        self.gamma = gamma  # Symptom duration
        self.xi = xi  # Immune duration

    def to_config_dict(self):
        """
        Convert SEIR model parameters to a dictionary for configuration serialization.

        Returns
        -------
        dict
            Dictionary containing all model parameters as key-value pairs.
        """
        return {
            'max_energy': self.max_energy,
            'beta': self.beta,
            'sigma': self.sigma,
            'gamma': self.gamma,
            'xi': self.xi
        }


class FolkSEIR(AbstractFolk):
    """
    Agent class for the SEIR model.

    This class represents individual agents in the SEIR infection model,
    handling transitions between Susceptible (S), Exposed (E), Infectious (I),
    and Recovered (R) states based on contact with infectious agents and
    time-based progression rules.
    """

    def __init__(self, id, home_address, max_energy, status):
        """
        Initialize a FolkSEIR agent.

        Parameters
        ----------
        id : int
            Unique identifier for the agent.
        home_address : int
            Node index of the agent's home location.
        max_energy : int
            Maximum social energy for the agent.
        status : str
            Initial infection status ('S', 'E', 'I', or 'R').
        """
        super().__init__(id, home_address, max_energy, status)

    def inverse_bernoulli(self, folks_here, conversion_prob, stats):
        """
        Calculate the probability of status transition given contact with infectious agents.

        This method implements the inverse Bernoulli probability calculation used
        in agent-based modeling to approximate the continuous ODE dynamics of
        infection models. It calculates the probability of infection based
        on the number of infectious contacts and transmission probability.

        Parameters
        ----------
        folks_here : list of FolkSEIR
            List of agents present at the same node.
        conversion_prob : float
            Base transmission probability per contact.
        stats : list of str
            List of infectious status types to consider.

        Returns
        -------
        float
            Probability of at least one successful transmission event.
        """
        num_contact = len(
            [folk for folk in folks_here if folk != self and folk.status in stats])
        # beta * I / N is the non-linear term that defines conversion
        # This inverse bernoulli function is an interpretation of the term
        # in agent-based modeling
        return super().inverse_bernoulli(num_contact, conversion_prob / len(folks_here))

    def interact(
            self,
            folks_here,
            current_place_type,
            status_dict_t,
            model_params,
            dice):
        """
        Perform interactions with other agents and handle potential disease transmission.

        Transition Rules
        ----------------
        - If the agent is Susceptible ('S') and comes into contact with at least one
          Infectious ('I') agent, the probability of becoming Exposed ('E') is calculated
          using the inverse Bernoulli formula with transmission probability (`beta`).
          If this probability exceeds the random value `dice`, the agent transitions to Exposed.

        Parameters
        ----------
        folks_here : list of FolkSEIR
            List of agents present at the same node.
        current_place_type : str
            Type of place where the interaction occurs.
        status_dict_t : dict
            Dictionary tracking the count of each status at the current timestep.
        model_params : SEIRModelParameters
            Model parameters for the simulation.
        dice : float
            Random float for stochastic transitions.

        Returns
        -------
        None
        """
        # When a susceptible person comes into contact with an infectious person,
        # they have a likelihood to become exposed to the disease
        if self.status == 'S' and self.inverse_bernoulli(
                folks_here, model_params.beta, ['I']) > dice:
            self.convert('E', status_dict_t)

        self.energy -= 1

    def sleep(
            self,
            folks_here,
            current_place_type,
            status_dict_t,
            model_params,
            dice):
        """
        Perform end-of-day status transitions based on disease progression.

        This method handles the deterministic time-based transitions between
        infection states at the end of each simulation day.

        Transition Rules
        ----------------
        - If the agent is Exposed ('E') and has been exposed for `sigma` days,
          they transition to Infectious ('I').

        - If the agent is Infectious ('I') and has been infectious for `gamma` days,
          they transition to Recovered ('R').

        - If the agent is Recovered ('R') and has been recovered for `xi` days,
          they transition back to Susceptible ('S') (waning immunity).

        Parameters
        ----------
        folks_here : list of FolkSEIR
            List of agents present at the same node (not used, for interface compatibility).
        current_place_type : str
            Type of place where the agent is sleeping (not used, for interface compatibility).
        status_dict_t : dict
            Dictionary tracking the count of each status at the current timestep.
        model_params : SEIRModelParameters
            Model parameters for the simulation.
        dice : float
            Random float for stochastic transitions (not used for deterministic transitions).

        Returns
        -------

        None
        """
        super().sleep()
        if self.status == 'E' and self.status_step_streak == model_params.sigma:
            self.convert('I', status_dict_t)
        elif self.status == 'I' and self.status_step_streak == model_params.gamma:
            self.convert('R', status_dict_t)
        elif self.status == 'R' and self.status_step_streak == model_params.xi:
            self.convert('S', status_dict_t)


class SEIRModel(AbstractInfectionModel):
    """
    SEIR infection model implementation.

    This class implements the Susceptible-Exposed-Infectious-Recovered model
    for epidemic simulation. It includes waning immunity where recovered
    individuals return to susceptible status after a specified duration.
    """

    def __init__(self, model_params, step_events=None):
        """
        Initialize the SEIR model with specified parameters and events.

        Parameters
        ----------
        model_params : SEIRModelParameters
            Configuration parameters for the SEIR model.
        step_events : list of StepEvent, optional
            Custom step events for the simulation. If None, default events are used.
        """
        self.folk_class = FolkSEIR
        self.all_statuses = (['S', 'E', 'I', 'R'])
        self.infected_statuses = ['I', 'E']
        self.required_place_types = set(
            ['workplace', 'education', 'religious'])
        self.step_events = step_events
        super().__init__(model_params)

    def initialize_sim_population(self, town):
        """
        Initialize the simulation population and their initial status assignments.

        This method assigns initial statuses and home locations to all agents in the simulation,
        including initial spreaders (both randomly assigned and those at specified nodes) and
        susceptible agents. It creates agent objects, updates the town graph with agent
        assignments, and tracks household nodes.

        Parameters
        ----------
        town : Town
            The Town object representing the simulation environment.

        Returns
        -------
        tuple
            (folks, household_node_indices, status_dict_t0)

            - folks : list of FolkSEIR
                List of all agent objects created for the simulation.

            - household_node_indices : set
                Set of node indices where households are tracked.

            - status_dict_t0 : dict
                Dictionary with the initial count of each status at timestep 0.
        """
        num_pop, num_init_spreader, num_init_spreader_rd, folks, household_node_indices, assignments = super(
        ).initialize_sim_population(town)

        # Randomly assign initial spreaders (not on specified nodes)
        for i in range(num_init_spreader_rd):
            node = rd.choice(town.accommodation_node_ids)
            assignments.append((node, 'I'))

        # Assign the rest as susceptible
        for i in range(num_pop - num_init_spreader):
            node = rd.choice(town.accommodation_node_ids)
            assignments.append((node, 'S'))

        # Assign initial spreaders to specified nodes
        for node in town.town_params.spreader_initial_nodes:
            assignments.append((node, 'I'))

        # Create folks and update graph/node info
        for i, (node, status) in enumerate(assignments):
            folk = self.create_folk(
                i, node, self.model_params.max_energy, status)
            folks.append(folk)
            town.town_graph.nodes[node]["folks"].append(folk)
            if len(town.town_graph.nodes[node]["folks"]) == 2:
                household_node_indices.add(node)

        status_dict_t0 = {
            'current_event': None,
            'timestep': 0,
            'S': num_pop - num_init_spreader,
            'E': 0,
            'I': num_init_spreader,
            'R': 0
        }
        return folks, household_node_indices, status_dict_t0


"""
This module implements the adaptation of rumor spreading dynamics with considerations for
credibility, correlation, and crowd personality-based classification or the so-called
SEIsIrR model in the paper mentioned below.

The implementation is based on:

Chen, X., & Wang, N. (2020).
Rumor spreading model considering rumor credibility, correlation and crowd classification based on personality.
*Scientific Reports*, 10, 5887. https://doi.org/10.1038/s41598-020-62585-9
"""

import random as rd

from .abstract_model import (AbstractInfectionModel, AbstractFolk,
                             AbstractModelParameters)


class SEIsIrRModelParameters(AbstractModelParameters):
    """
    Model parameters for the SEIsIrR rumor spreading model.

    This class encapsulates all tunable parameters required for the SEIsIrR
    rumor spreading model, including rumor credibility, spreading probabilities,
    and population literacy characteristics. It validates parameter types and
    ranges upon initialization.
    """

    def __init__(
            self,
            max_energy,
            literacy,
            gamma,
            alpha,
            lam,
            phi,
            theta,
            mu,
            eta1,
            eta2,
            mem_span=10):
        """
        Initialize SEIsIrR model parameters and validate all inputs.

        Parameters
        ----------
        max_energy : int
            Maximum energy for each agent (must be a positive integer).
        literacy : float
            Fraction of the population that is literate (must be between 0 and 1, affects Is/Ir split).
        gamma : float
            Fraction representing how credible the rumor is (must be between 0 and 1).
        alpha : float
            Fraction representing how relevant the rumor is to a human's life (must be between 0 and 1).
        lam : float
            Rumor spreading probability (must be between 0 and 1).
        phi : float
            Stifling probability parameter for E to R transition (must be between 0 and 1).
        theta : float
            Probability parameter for E to S transition (must be between 0 and 1).
        mu : float
            The spreading desire ratio of individuals in class Is to individuals in class Ir (must be between 0 and 1).
        eta1 : float
            Probability parameter for S to R transition (must be between 0 and 1).
        eta2 : float
            Probability parameter for forgetting (S to R) in sleep (must be between 0 and 1).
        mem_span : int, optional
            Memory span for forgetting mechanism (must be >= 1, default: 10).

        Raises
        ------
        TypeError
            If any parameter is not of the correct type or out of valid range.
        """
        super().__init__(max_energy)
        self.literacy = literacy

        for name, value in zip(
            ['gamma', 'alpha', 'lam', 'phi', 'theta', 'mu', 'eta1', 'eta2'],
            [gamma, alpha, lam, phi, theta, mu, eta1, eta2]
        ):
            if not isinstance(value, (float, int)):
                raise TypeError(
                    f"{name} must be a float or int, got {
                        type(value).__name__}")

        gamma, alpha, lam, phi, theta, mu, eta1, eta2 = map(
            float, [gamma, alpha, lam, phi, theta, mu, eta1, eta2])

        if not isinstance(mem_span, int) or mem_span < 1:
            raise TypeError(
                f"mem_span must be an integer greater or equal to 1, got {mem_span}")

        self.alpha = alpha
        self.gamma = gamma
        self.mu = mu
        self.lam = lam
        gamma_alpha_lam = gamma * alpha * lam

        # We use number 2 to signify transition that happens because of
        # interaction
        self.Is2E = (1 - gamma) * gamma_alpha_lam
        self.Is2S = gamma_alpha_lam * mu
        self.Ir2S = gamma_alpha_lam
        self.E2S = theta
        self.E2R = phi
        self.S2R = eta1
        self.forget = eta2
        self.mem_span = mem_span

    def to_config_dict(self):
        """
        Convert SEIsIrR model parameters to a dictionary for configuration serialization.

        Returns
        -------
        dict
            Dictionary containing all model parameters as key-value pairs.
        """
        return {
            'max_energy': self.max_energy,
            'literacy': self.literacy,
            'lam': self.lam,
            'alpha': self.alpha,
            'gamma': self.gamma,
            'phi': self.E2R,
            'theta': self.E2S,
            'mu': self.mu,
            'eta1': self.S2R,
            'eta2': self.forget,
            'mem_span': self.mem_span,
        }


class FolkSEIsIrR(AbstractFolk):
    """
    Agent class for the SEIsIrR rumor spreading model.

    This class represents individual agents in the SEIsIrR infection model,
    handling transitions between Susceptible (S), Exposed (E), Ignorant spreaders (Is),
    Intelligent spreaders (Ir), and Recovered/Stifler (R) states based on rumor
    credibility, literacy levels, and social interactions.
    """

    def __init__(self, id, home_address, max_energy, status):
        """
        Initialize a FolkSEIsIrR agent.

        Parameters
        ----------

        id : int
            Unique identifier for the agent.
        home_address : int
            Node index of the agent's home location.
        max_energy : int
            Maximum social energy for the agent.
        status : str
            Initial infection status ('S', 'E', 'Is', 'Ir', or 'R').
        """
        super().__init__(id, home_address, max_energy, status)

    def inverse_bernoulli(self, folks_here, conversion_prob, stats):
        """
        Calculate the probability of status transition given contact with specific statuses.

        This method implements an energy-weighted inverse Bernoulli probability calculation
        for rumor spreading dynamics. The probability is scaled by the agent's current
        energy relative to their maximum energy, representing decreased social influence
        when tired.

        Parameters
        ----------
        folks_here : list of FolkSEIsIrR
            List of agents present at the same node.
        conversion_prob : float
            Base transition probability per contact.
        stats : list of str
            List of status types to consider for contact counting.

        Returns
        -------
        float
            Probability of at least one successful transition event.
        """
        num_contact = len(
            [folk for folk in folks_here if folk != self and folk.status in stats])
        return super().inverse_bernoulli(num_contact,
                                         conversion_prob * self.energy / self.max_energy)

    def interact(
            self,
            folks_here,
            current_place_type,
            status_dict_t,
            model_params,
            dice):
        """
        Perform interactions with other agents and handle rumor spreading dynamics.

        Transition Rules
        ----------------
        - **Rule 1:** If the agent is Intelligent spreader ('Ir') and contacts Susceptible ('S') agents,
          they may transition to Susceptible ('S') based on `Ir2S` probability.

        - **Rule 2:** If the agent is Ignorant spreader ('Is') and contacts Susceptible ('S') agents,
          they may transition to either Exposed ('E') or Susceptible ('S') based on `Is2E` and `Is2S`
          probabilities respectively. The transition with higher probability is evaluated first.

        - **Rule 3:** If the agent is Exposed ('E'), they may transition to either Susceptible ('S')
          or Recovered ('R') based on `E2S` and `E2R` probabilities respectively. The transition
          with higher probability is evaluated first.

        - **Rule 4.1:** If the agent is Susceptible ('S'), they may transition to Recovered ('R')
          when contacting any other agents ('S', 'E', 'R') based on `S2R` probability.

        Parameters
        ----------
        folks_here : list of FolkSEIsIrR
            List of agents present at the same node.
        current_place_type : str
            Type of place where the interaction occurs.
        status_dict_t : dict
            Dictionary tracking the count of each status at the current timestep.
        model_params : SEIsIrRModelParameters
            Model parameters for the simulation.
        dice : float
            Random float for stochastic transitions.

        Returns
        -------

        None
        """
        # The rule numbers below are references to each rule defined in the literature of
        # SEIsIrR model

        # Rule 1
        if self.status == 'Ir' and self.inverse_bernoulli(
                folks_here, model_params.Ir2S, ['S']) > dice:
            self.convert('S', status_dict_t)
        # Rule 2
        elif self.status == 'Is':
            conversion_rate_S = self.inverse_bernoulli(
                folks_here, model_params.Is2S, ['S'])
            conversion_rate_E = self.inverse_bernoulli(
                folks_here, model_params.Is2E, ['S'])

            if conversion_rate_S > conversion_rate_E:
                if conversion_rate_E > dice:
                    self.convert('E', status_dict_t)
                elif conversion_rate_S > dice:
                    self.convert('S', status_dict_t)
            else:
                if conversion_rate_S > dice:
                    self.convert('S', status_dict_t)
                elif conversion_rate_E > dice:
                    self.convert('E', status_dict_t)

        # Rule 3
        elif self.status == 'E':
            conversion_rate_S = self.inverse_bernoulli(
                folks_here, model_params.E2S, ['S'])
            conversion_rate_R = self.inverse_bernoulli(
                folks_here, model_params.E2R, ['R'])

            if conversion_rate_S > conversion_rate_R:
                if conversion_rate_R > dice:
                    self.convert('R', status_dict_t)
                elif conversion_rate_S > dice:
                    self.convert('S', status_dict_t)
            else:
                if conversion_rate_R > dice:
                    self.convert('R', status_dict_t)
                elif conversion_rate_S > dice:
                    self.convert('S', status_dict_t)

        # Rule 4.1
        elif self.status == 'S' and self.inverse_bernoulli(folks_here, model_params.S2R, ['S', 'E', 'R']) > dice:
            self.convert('R', status_dict_t)

        self.energy -= 1

    def sleep(
            self,
            folks_here,
            current_place_type,
            status_dict_t,
            model_params,
            dice):
        """
        Perform end-of-day status transitions and forgetting mechanisms.

        This method handles the forgetting mechanism for Susceptible agents,
        representing the natural tendency to lose interest in rumors over time.

        Transition Rules
        ----------------
        - **Rule 4.2:** If the agent is Susceptible ('S'), they may transition to Recovered ('R')
          through forgetting if either:
          - They have been in 'S' status for longer than `mem_span` days, OR
          - A random draw is less than the forgetting probability `forget`

        Parameters
        ----------
        folks_here : list of FolkSEIsIrR
            List of agents present at the same node (not used, for interface compatibility).
        current_place_type : str
            Type of place where the agent is sleeping (not used, for interface compatibility).
        status_dict_t : dict
            Dictionary tracking the count of each status at the current timestep.
        model_params : SEIsIrRModelParameters
            Model parameters for the simulation.
        dice : float
            Random float for stochastic transitions.

        Returns
        -------
        None
        """
        super().sleep()
        if self.status == 'S':
            # Rule 4.2: Forgetting mechanism
            if model_params.mem_span <= self.status_step_streak or dice < model_params.forget:
                self.convert('R', status_dict_t)


class SEIsIrRModel(AbstractInfectionModel):
    """
    SEIsIrR rumor spreading model implementation.

    This class implements the Susceptible-Exposed-Ignorant spreader-Intelligent spreader-Recovered
    model for rumor spreading dynamics. The model considers rumor credibility, population literacy,
    and personality-based classification of spreaders.
    """

    def __init__(self, model_params, step_events=None):
        """
        Initialize the SEIsIrR model with specified parameters and events.

        Parameters
        ----------

        model_params : SEIsIrRModelParameters
            Configuration parameters for the SEIsIrR model.
        step_events : list of StepEvent, optional
            Custom step events for the simulation. If None, default events are used.
        """
        self.folk_class = FolkSEIsIrR
        self.all_statuses = (['S', 'E', 'Ir', 'Is', 'R'])
        self.infected_statuses = 'S'
        self.step_events = step_events
        super().__init__(model_params)

    def initialize_sim_population(self, town):
        """
        Initialize the simulation population and their initial status assignments.

        This method assigns initial statuses and home locations to all agents in the simulation.
        The population is divided between Ignorant spreaders (Is) and Intelligent spreaders (Ir)
        based on the literacy parameter, with initial rumor spreaders assigned to 'S' status.

        Parameters
        ----------
        town : Town
            The Town object representing the simulation environment.

        Returns
        -------
        tuple
            (folks, household_node_indices, status_dict_t0

            - folks : list of FolkSEIsIrR
                List of all agent objects created for the simulation.

            - household_node_indices : set
                Set of node indices where households are tracked.

            - status_dict_t0 : dict
                Dictionary with the initial count of each status at timestep 0.
        """
        num_pop, num_init_spreader, num_init_spreader_rd, folks, household_node_indices, assignments = super(
        ).initialize_sim_population(town)

        num_IsIr = num_pop - num_init_spreader

        # Divide remaining population between Is and Ir based on literacy
        num_Is = round(self.model_params.literacy * num_IsIr)
        num_Ir = num_IsIr - num_Is

        # Randomly assign initial spreaders (not on specified nodes)
        for _ in range(num_init_spreader_rd):
            node = rd.choice(town.accommodation_node_ids)
            assignments.append((node, 'S'))

        # Assign the rest as Is and Ir
        for _ in range(num_Is):
            node = rd.choice(town.accommodation_node_ids)
            assignments.append((node, 'Is'))
        for _ in range(num_Ir):
            node = rd.choice(town.accommodation_node_ids)
            assignments.append((node, 'Ir'))

        # Assign initial spreaders to specified nodes
        for node in town.town_params.spreader_initial_nodes:
            assignments.append((node, 'S'))

        # Create folks and update graph/node info
        for i, (node, status) in enumerate(assignments):
            folk = self.create_folk(
                i, node, self.model_params.max_energy, status)
            folks.append(folk)
            town.town_graph.nodes[node]["folks"].append(folk)
            if len(town.town_graph.nodes[node]["folks"]) == 2:
                household_node_indices.add(node)

        status_dict_t0 = {
            'current_event': None,
            'timestep': 0,
            'S': num_init_spreader,
            'E': 0,
            'Is': num_Is,
            'Ir': num_Ir,
            'R': 0
        }
        return folks, household_node_indices, status_dict_t0


import inspect
from enum import Enum

import numpy as np


def log_normal_mobility(distances, folk, median_distance=2000, sigma=1.0):
    """
    Return probabilities inversely proportional to log-normal PDF of distances. Log-normal PDF has been studied to model
    the human mobility pattern in this following literature and its predecessor:
    Wang, W., & Osaragi, T. (2024). Lognormal distribution of daily travel time and a utility model for its emergence.
    Transportation Research Part A: Policy and Practice, 181, 104058. https://doi.org/10.1016/j.tra.2024.104058

    Parameters
    ----------
    median_distance : float
        Median travel distance in meters. This is where the distribution peaks.
        Default 1100m (1.1km) for typical neighborhood activities.
        Common values:
        - 400m → local/walking activities
        - 1100m → neighborhood activities
        - 3000m → city-wide activities
        - 8000m → regional activities
    sigma : float
        Shape parameter controlling spread around median.
        - sigma=0.5 → narrow distribution, consistent travel patterns
        - sigma=1.0 → moderate distribution (default)
        - sigma=1.5 → wide distribution, highly variable travel patterns
    """
    distances = np.array(distances)
    # Avoid log(0) and negative/zero distances
    distances = np.clip(distances, 1e-6, None)

    # Convert median distance to mu parameter: mu = ln(median)
    mu = np.log(median_distance)

    probs = 1 / (distances * sigma * np.sqrt(2 * np.pi)) * \
        np.exp(- (np.log(distances) - mu) ** 2 / (2 * sigma ** 2))
    probs = np.nan_to_num(probs, nan=0.0, posinf=0.0, neginf=0.0)
    probs = probs / probs.sum() if probs.sum() > 0 else np.ones_like(probs) / len(probs)
    return probs


def energy_exponential_mobility(distances, folk, distance_scale=1000):
    """
    Return probabilities proportional to exponential PDF of distances. With lam = inverse of normalized energy
    as a rate parameter - higher energy = lower decay rate = more willing to travel far.

    Parameters
    ----------
    distance_scale : float
        Scale factor for distances to control decay rate. Higher values = slower decay.
        Default 1000 means distances are scaled to kilometers.
    """
    distances = np.array(distances)

    # Scale distances to control decay rate
    scaled_distances = distances / distance_scale

    # Higher energy = lower lambda (less decay) = more willing to travel far
    # Lower energy = higher lambda (more decay) = prefer nearby locations
    energy_ratio = folk.energy / folk.max_energy  # 0 to 1
    # This gives lambda from 1.0 (high energy) to 2.0 (no energy)
    lam = 2.0 - energy_ratio

    probs = lam * np.exp(-lam * scaled_distances)

    # Normalize probabilities to sum to 1
    probs = probs / probs.sum() if probs.sum() > 0 else np.ones_like(probs) / len(probs)

    return probs


class EventType(Enum):
    """
    Step events are classified into two types.

    - DISPERSE is the type of event
    that send agents around the town graph to specific locations in a given range and allow them to interact with other agents
    who are in the same nodes..

    - SEND_HOME is the type of event that every agents in the simulation back to their home address without any interaction.
    SEND_HOME can represent the end of the day where everybody go home and sleep or an emergency announcement
    that sends everyone around town straight back home.
    """
    SEND_HOME = "send_home"
    DISPERSE = "disperse"


class StepEvent:
    """
    Defines agent activities and movement patterns during simulation timesteps.

    StepEvent objects represent discrete activities that agents perform during
    simulation timesteps. Each event specifies how agents move through the town
    network, what types of locations they visit, and what actions they perform
    when they arrive at destinations.

    Purpose
    -------
    1. **Agent Movement Control**: Define where and how far agents travel during
       specific activities (work, shopping, healthcare visits, etc.).

    2. **Location Targeting**: Specify which types of places agents visit for
       different activities using place type filters.

    3. **Mobility Modeling**: Apply realistic human mobility patterns through
       customizable probability functions based on distance and agent characteristics.

    4. **Agent-Dependent Behavior**: Enable mobility patterns that adapt to individual
       agent properties such as energy levels, status, or other attributes.

    Event Types
    -----------
    - **DISPERSE**: Agents move to locations within specified distance and place
      type constraints. Enables agent-to-agent interactions at destinations.

    - **SEND_HOME**: All agents return directly to their home addresses without
      movement or interaction. Represents end-of-day or emergency scenarios.

    Probability Functions
    --------------------
    Custom probability functions must:

    - Accept exactly 2 non-default arguments: `(distances, agent)`

    - Return probabilities between 0 and 1 (will be normalized automatically)

    - Handle numpy arrays for distances

    - Be robust to edge cases (empty arrays, zero distances)

    Built-in mobility functions include:
    - `log_normal_mobility`: Human mobility based on log-normal distance

    - `energy_exponential_mobility`: Agent energy-dependent exponential decay

    Attributes
    ----------
    name : str
        Event identifier.
    max_distance : int
        Maximum travel distance in meters.
    place_types : list
        Allowed destination place types.
    event_type : EventType
        Movement behavior type.
    folk_action : callable
        Agent interaction function.
    probability_func : callable or None
        Distance and agent-based mobility probability function.

    Examples
    --------
    >>> # End of day event
    >>> end_day = StepEvent("end_day", folk_class.sleep)
    >>>
    >>> # Work event with distance constraints
    >>> work = StepEvent("work", folk_class.interact, EventType.DISPERSE,
    ...                  max_distance=10000, place_types=['workplace'])
    >>>
    >>> # Shopping with log-normal mobility
    >>> shopping = StepEvent("shopping", folk_class.interact, EventType.DISPERSE,
    ...                      max_distance=5000, place_types=['commercial'],
    ...                      probability_func=log_normal_mobility)
    >>>
    >>> # Energy-dependent movement
    >>> leisure = StepEvent("leisure", folk_class.interact, EventType.DISPERSE,
    ...                     max_distance=8000, place_types=['commercial', 'religious'],
    ...                     probability_func=energy_exponential_mobility)
    >>>
    >>> # Custom agent-dependent mobility
    >>> def age_based_mobility(distances, agent):
    ...     import numpy as np
    ...     distances = np.array(distances)
    ...     # Older agents prefer shorter distances
    ...     age_factor = getattr(agent, 'age', 30) / 100.0  # Normalize age
    ...     decay_rate = 0.0001 * (1 + age_factor)  # Higher decay for older agents
    ...     probs = np.exp(-decay_rate * distances)
    ...     return probs / probs.sum() if probs.sum() > 0 else np.ones_like(probs) / len(probs)
    >>>
    >>> custom_event = StepEvent("age_sensitive", folk_class.interact, EventType.DISPERSE,
    ...                          max_distance=15000, place_types=['healthcare'],
    ...                          probability_func=age_based_mobility)
    """

    def __init__(
            self,
            name,
            folk_action,
            event_type=EventType.SEND_HOME,
            max_distance=0,
            place_types=[],
            probability_func=None):
        """
        Initialize a StepEvent for agent activity simulation.

        Parameters
        ----------

        name : str
            Descriptive name for the event (e.g., "work", "shopping", "end_day").
        folk_action : callable
            Function executed for each agent during the event. Must accept arguments:
            (folks_here, current_place_type, status_dict_t, model_params, dice).
        event_type : EventType, optional
            Movement behavior type (default: EventType.SEND_HOME).
        max_distance : int, optional
            Maximum travel distance in meters for DISPERSE events (default: 0).
        place_types : list, optional
            Place type categories agents can visit. Examples: ['commercial', 'workplace']
            (default: []).
        probability_func : callable, optional
            Function taking (distances, agent) and returning movement probabilities [0,1].
            Must have exactly 2 non-default arguments. Cannot be used with SEND_HOME events
            (default: None).

        Raises
        ------

        ValueError
            - If probability_func is specified for SEND_HOME events

            - If probability_func is not callable

            - If probability_func doesn't have exactly 2 non-default arguments

            - If probability_func returns invalid probability values during validation

            - If probability_func fails signature inspection

        """
        self.name = name
        self.max_distance = max_distance
        self.place_types = place_types
        self.event_type = event_type
        self.folk_action = folk_action
        self.probability_func = probability_func
        if event_type == EventType.SEND_HOME and probability_func is not None:
            raise ValueError(
                "You cannot define a mobility probability function for an event that does not disperse people")

        if probability_func is not None:
            if not callable(probability_func):
                raise ValueError(
                    "probability_func must be a callable function")

            # Check function signature only
            try:
                sig = inspect.signature(probability_func)
                non_default_params = [
                    p for p in sig.parameters.values()
                    if p.default == inspect.Parameter.empty
                ]

                if len(non_default_params) != 2:
                    raise ValueError(
                        f"probability_func must have exactly 2 non-default arguments, "
                        f"got {len(non_default_params)}. Expected signature: func(distances, agent, **kwargs)")

            except Exception as e:
                raise ValueError(
                    f"Could not inspect probability_func signature: {e}")


from .abstract_model import AbstractInfectionModel, AbstractFolk, AbstractModelParameters
from .SEIR_model import SEIRModel, SEIRModelParameters, FolkSEIR
from .SEIQRDV_model import SEIQRDVModel, SEIQRDVModelParameters, FolkSEIQRDV
from .SEIsIrR_model import SEIsIrRModel, SEIsIrRModelParameters, FolkSEIsIrR
from .step_event import EventType, StepEvent, log_normal_mobility, energy_exponential_mobility


import json
import warnings
from itertools import product

import h5py
import pandas as pd
import plotly.express as px

from .visualization_util import (_load_node_info_from_graphmlz,
                                 _set_plotly_renderer,
                                 _validate_and_merge_colormap)


def plot_place_types_scatter(town_graph_path, town_config_path, colormap=None):
    """
    Visualizes nodes from a .graphmlz town graph file as colored points with colors representing
    different place types (e.g., accommodation, commercial, education), using Plotly and OpenStreetMap..

    Parameters
    ----------
    town_graph_path : str
        Path to the .graphmlz file containing the town graph with node coordinates
        and place_type classifications.
    town_config_path : str
        Path to the .json file containing town metadata with 'epsg_code' for coordinate
        conversion and 'place_types' list defining valid place types.
    colormap : dict, optional
        Custom color mapping {'place_type': '#HEXCOLOR'}. If None, uses defaults.
        Custom colors override defaults for matching place types.

    Returns
    -------
    None
        Displays an interactive Plotly scatter map with colored points, legend,
        and hover information showing node IDs.

    Raises
    ------
    AssertionError
        If file extensions are incorrect (.graphmlz and .json required).
    KeyError
        If town_config_path doesn't contain required 'epsg_code' field.
    ValueError
        If colormap doesn't provide colors for all place types defined in
        town_config_path's 'place_types' list.
    FileNotFoundError
        If specified file paths don't exist.

    Notes
    -----
    - Default colors provided for: accommodation, commercial, religious, education,
      workplace, healthcare_facility
    - Nodes with undefined place types are colored gray (#CCCCCC)
    - Requires internet connection for OpenStreetMap tiles
    """
    assert town_graph_path.endswith(
        ".graphmlz"), f"Expected a .graphmlz file for town_graph_path, got {town_graph_path}"
    assert town_config_path.endswith(
        ".json"), f"Expected a .json file for town_config_path, got {town_config_path}"

    with open(town_config_path, 'r') as f:
        config = json.load(f)

    # Set a correct render for the environment the script is being run
    _set_plotly_renderer()

    # Get valid place types from config
    valid_place_types = config.get('place_types', [])
    epsg_code = config["epsg_code"]  # Also epsg code

    # Default colormap that supports the place types defined in the
    # default place classification function
    default_colormap = {
        "accommodation": "#FFD700",
        "commercial": "#FFA07A",
        "religious": "#9370DB",
        "education": "#00BFFF",
        "workplace": "#4682B4",
        "healthcare_facility": "#17EEA6",
    }

    # Validate and merge colormaps
    color_map = _validate_and_merge_colormap(
        default_colormap,
        colormap,
        valid_place_types,
        "place type"
    )

    node_positions, node_place_types = _load_node_info_from_graphmlz(
        town_graph_path, epsg_code, return_place_type=True
    )

    # Assemble DataFrame
    node_data_list = []
    for node_id, (lat, lon) in node_positions.items():
        place_type = node_place_types.get(node_id, "other")
        node_data_list.append({
            "node_id": node_id,
            "lat": lat,
            "lon": lon,
            "place_type": place_type,
            # Default gray for 'other' or unknown types
            "color": color_map.get(place_type, "#CCCCCC")
        })

    df = pd.DataFrame(node_data_list)

    fig = px.scatter_map(
        df,
        lat="lat",
        lon="lon",
        color="place_type",
        color_discrete_map=color_map,
        hover_name="node_id",
        zoom=13,
        height=700
    )

    fig.update_layout(
        mapbox_style="open-street-map",
        title="Town Graph Nodes by Place Type",
        legend_title="Place Type",
        margin={"r": 0, "t": 50, "l": 0, "b": 0}
    )
    fig.update_traces(marker=dict(size=9, opacity=0.8))
    fig.show()


def plot_agents_scatter(
        output_hdf5_path,
        town_graph_path,
        time_interval=None):
    """
    Visualize the movement and status of agents over time on a map using simulation output using Plotly and OpenStreetMap.

    Parameters
    ----------
    output_hdf5_path : str
        Path to the HDF5 file containing simulation results.
    town_graph_path : str
        Path to the .graphmlz file containing the town graph.
    time_interval : tuple or list of int, optional
        (start, end) timestep range to visualize. If None, visualize all timesteps.

    Returns
    -------
    None
        Displays an Plotly map with a time slider showing agent locations and statuses over time.

    Notes
    -----
    - Each status is represented by a random color (A colormap parameter is a to-be-implemented)
    - Requires internet connection for OpenStreetMap tiles
    """
    assert output_hdf5_path.endswith(
        ".h5"), f"Expected a .h5 file for output_hdf5_path, got {output_hdf5_path}"
    assert town_graph_path.endswith(
        ".graphmlz"), f"Expected a .graphmlz file for town_graph_path, got {town_graph_path}"

    # Set a correct render for the environment the script is being run
    _set_plotly_renderer()

    # Load HDF5 data
    with h5py.File(output_hdf5_path, "r") as h5:
        town_config_json_bytes = h5["config/town_config"][()]
        town_config = json.loads(town_config_json_bytes.decode("utf-8"))
        epsg_code = town_config["epsg_code"]

        folk_data = h5["individual_logs/log"][:]
        metadata_json_bytes = h5["config/simulation_config"][()]
        metadata = json.loads(metadata_json_bytes.decode("utf-8"))
        all_statuses = metadata["all_statuses"]
        step_events_order = [e['name']
                             for e in metadata.get("step_events", [])]

    # Load node positions
    node_pos = _load_node_info_from_graphmlz(town_graph_path, epsg_code)

    # Validate the user input time_interval
    if time_interval is not None:
        assert isinstance(time_interval, (tuple, list)) and len(
            time_interval) == 2, "time_interval must be a tuple or list of two integers (start, end)"
        assert all(isinstance(x, int)
                   for x in time_interval), "time_interval must contain only integers"
        assert time_interval[0] >= 0 and time_interval[1] > 0, "Timestep values in time_interval cannot be negative."
        assert time_interval[1] >= time_interval[0], "Start timestep cannot be greater than end timestep."

        max_timestep_in_data = int(folk_data["timestep"].max())

        if time_interval[1] > max_timestep_in_data:
            warnings.warn(
                f"Given end timestep {time_interval[1]} exceeds maximum timestep {max_timestep_in_data} in data. "
                f"Plotting will only include timesteps up to {max_timestep_in_data}."
            )
            time_interval = (time_interval[0], max_timestep_in_data)

            # Check again after adjustment - if start > adjusted end, it's an
            # error
            if time_interval[0] > time_interval[1]:
                raise ValueError(
                    f"Start timestep {time_interval[0]} is greater than maximum available timestep {max_timestep_in_data}. "
                    f"Please specify a start timestep <= {max_timestep_in_data}."
                )
    # Aggregate for all (or selected) timesteps
    points = []
    for entry in folk_data:
        timestep = int(entry["timestep"])

        # Filter by time_interval if given
        if time_interval is not None:
            if timestep < time_interval[0] or timestep > time_interval[1]:
                continue

        event = entry["event"].decode("utf-8")
        status = entry["status"].decode("utf-8")
        address = int(entry["address"])
        if address in node_pos:
            lat, lon = node_pos[address]
            frame_label = f"{timestep}: {event}"

            points.append({
                "frame": frame_label,
                "lat": lat,
                "lon": lon,
                "status": status,
                "size": 1
            })

    df_raw = pd.DataFrame(points)
    df_raw["timestep"] = df_raw["frame"].str.extract(r"^(\d+):")[0].astype(int)
    df_raw["event_name"] = df_raw["frame"].str.extract(r": (.*)$")[0]

    # Map event names to their order (within each day)
    event_order_map = {name: i for i, name in enumerate(step_events_order)}
    df_raw["event_order"] = df_raw["event_name"].map(event_order_map)

    # Sort by timestep then by event order
    df_raw.sort_values(by=["timestep", "event_order"], inplace=True)

    # Re-create frame column with the correct order
    df_raw["frame"] = df_raw.apply(
        lambda row: f"{
            row['timestep']}: {
            row['event_name']}",
        axis=1)

    unique_frames = df_raw["frame"].drop_duplicates().tolist()
    unique_coords = df_raw[["lat", "lon"]].drop_duplicates().values.tolist()
    full_index = list(
        product(
            unique_frames, all_statuses, [
                tuple(c) for c in unique_coords]))

    full_df = pd.DataFrame([
        {
            "frame": f,
            "status": s,
            "lat": lat,
            "lon": lon,
            "size": 0
        }
        for f, s, (lat, lon) in full_index
    ])
    df_grouped = df_raw.groupby(
        ["frame", "status", "lat", "lon"], as_index=False).agg({"size": "sum"})

    df_filled = pd.concat([df_grouped, full_df], ignore_index=True).drop_duplicates(
        subset=["frame", "status", "lat", "lon"], keep="first")

    fig = px.scatter_map(
        df_filled,
        lat="lat",
        lon="lon",
        size="size",
        color="status",
        animation_frame="frame",
        category_orders={
            "status": all_statuses,
            "frame": unique_frames
        },
        size_max=20,
        zoom=13,
        height=600,
        hover_data={"size": True}
    )
    fig.update_layout(mapbox_style="open-street-map")
    fig.update_layout(
        title="Population status over time with marker size representing the number of people of that status at each time frame")
    fig.update_traces(marker=dict(opacity=0.7))
    fig.show()


import json

import h5py
import matplotlib.pyplot as plt


def _plot_status_summary_data(
        status_keys,
        timesteps,
        data_dict,
        status_type,
        ylabel="Density"):
    # A helper function to plot simulation status data over time using matplotlib.
    # Selects and validates which status types to plot, then generates the line plot.
    # Validate and select keys to plot

    if status_type is None:
        keys_to_plot = status_keys
    elif isinstance(status_type, str):
        if status_type not in status_keys:
            raise ValueError(
                f"Invalid status_type '{status_type}'. Must be one of {status_keys}.")
        keys_to_plot = [status_type]
    elif isinstance(status_type, list):
        invalid = [k for k in status_type if k not in status_keys]
        if invalid:
            raise ValueError(
                f"Invalid status types {invalid}. Must be from {status_keys}.")
        keys_to_plot = status_type
    else:
        raise TypeError(
            f"status_type must be None, str, or list of str, got {
                type(status_type).__name__}.")

    # Plotting
    plt.figure(figsize=(10, 6))
    for key in keys_to_plot:
        plt.plot(timesteps, data_dict[key], label=key)

    plt.xlabel("Timestep")
    plt.ylabel(ylabel)
    plt.title("Simulation Status Over Timesteps")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def plot_status_summary_from_hdf5(output_hdf5_path, status_type=None):
    """
    Plot the normalized status summary (density) over time from a simulation HDF5 output file.

    This function reads the simulation status summary from the specified HDF5 file,
    normalizes each status by the total population, and plots the density of each status
    (or a subset of statuses) over simulation timesteps.

    Parameters
    ----------
    output_hdf5_path : str
        Path to the HDF5 file containing simulation results.
    status_type : str or list of str or None, optional
        If None (default), plot all status types.
        If str, plot only the specified status type.
        If list of str, plot only the specified status types.

    Raises
    ------
    ValueError
        If the HDF5 file contains no status data or if the total population is zero.
        If an invalid status_type is provided.
    TypeError
        If status_type is not None, str, or list of str.

    Returns
    -------
    None
        Displays a matplotlib plot of the status densities over time.
    """
    with h5py.File(output_hdf5_path, "r") as h5file:
        status_ds = h5file["status_summary/summary"]
        if len(status_ds) == 0:
            raise ValueError("No status data found in HDF5 file.")

        # Extract status keys from dtype
        all_keys = [
            name for name in status_ds.dtype.names if name not in (
                "timestep", "current_event")]

        # Extract total population from metadata
        metadata_str = h5file["config/simulation_config"][()
                                                          ].decode("utf-8")
        metadata = json.loads(metadata_str)
        total_population = metadata["population"]
        if total_population == 0:
            raise ValueError("Total population in configurations is zero.")

        # Prepare data dicts
        last_entry_by_timestep = {}
        for row in status_ds:
            timestep = int(row["timestep"])
            # Always keep the last one seen per timestep
            last_entry_by_timestep[timestep] = row

        final_timesteps = sorted(last_entry_by_timestep.keys())
        final_status_data = {key: [] for key in all_keys}

        for ts in final_timesteps:
            row = last_entry_by_timestep[ts]
            for key in all_keys:
                final_status_data[key].append(row[key] / total_population)

    _plot_status_summary_data(
        all_keys,
        final_timesteps,
        final_status_data,
        status_type,
        ylabel="Density")


import os
import re
import tempfile
import warnings
import zipfile

import networkx as nx
import plotly.io as pio
from IPython import get_ipython
from pyproj import Transformer


def _validate_and_merge_colormap(
        default_map,
        user_map,
        valid_keys,
        parameter_name):
    # A helper function used to validate and merge the colormap for plotly visualization
    # if the user gives us a custom color map.

    # Start with the default colormap
    result = default_map.copy()

    # If user map provided, merge it
    # This is for the case where user might not provide all the colors for all the place types
    # in their location of interest. The place type without user specified color will fall back
    # to using default color.
    if user_map is not None:
        # Check user entries
        for key, color in user_map.items():
            if key not in valid_keys:
                warnings.warn(
                    f"Warning: '{key}' is not a valid {parameter_name}. "
                    f"Valid values are: {', '.join(valid_keys)}"
                )

            # Basic validation for hex color codes
            if not isinstance(
                    color,
                    str) or not re.match(
                    r'^#(?:[0-9a-fA-F]{3}){1,2}$',
                    color):
                warnings.warn(
                    f"Warning: '{color}' for {key} is not a valid hex color. "
                    "Expected format: '#RRGGBB' or '#RGB'"
                )

            # Add to result anyway (user's responsibility)
            result[key] = color

    # AFTER merging, check if there are still valid keys without colors
    # If the place type - color mapping doesn't exist in the default map and the custom map,
    # tell the user to provide it.
    missing_colors = set(valid_keys) - set(result.keys())
    if missing_colors:
        raise ValueError(
            f"Missing colors for valid {parameter_name}(s): {', '.join(sorted(missing_colors))}. "
            f"Please provide colors for these in the colormap parameter."
        )

    return result


def _set_plotly_renderer():
    # A helper function to determine on which platform plotly is supposed to
    # render the plot
    try:
        # Check if running in a Jupyter notebook
        shell = get_ipython().__class__.__name__
        if shell == 'ZMQInteractiveShell':
            pio.renderers.default = "notebook"
        else:
            pio.renderers.default = "browser"
    except NameError:
        # Not running in IPython/Jupyter
        pio.renderers.default = "browser"


def _load_node_info_from_graphmlz(
        town_graph_path,
        epsg_code,
        return_place_type=False):
    # A helper function that loads the information necessary for plotting
    # from the compressed input graph file
    with tempfile.TemporaryDirectory() as tmpdirname:
        with zipfile.ZipFile(town_graph_path, 'r') as zf:
            zf.extractall(tmpdirname)
            graphml_path = os.path.join(tmpdirname, "graph.graphml")
            G = nx.read_graphml(graphml_path)
            G = nx.relabel_nodes(G, lambda x: int(x))

    transformer = Transformer.from_crs(
        f"EPSG:{epsg_code}", "EPSG:4326", always_xy=True)

    node_positions = {}
    node_place_types = {} if return_place_type else None

    for node, data in G.nodes(data=True):
        x = float(data["x"])
        y = float(data["y"])

        lon, lat = transformer.transform(x, y)
        node_positions[node] = (lat, lon)

        if return_place_type:
            place_type = data.get("place_type", "unknown")
            node_place_types[node] = place_type

    if return_place_type:
        return node_positions, node_place_types
    return node_positions


from .plot_status_summary import plot_status_summary_from_hdf5
from .plot_scatter import plot_agents_scatter, plot_place_types_scatter


import pytest
import tempfile
import os
import glob
import nbformat
from nbconvert.preprocessors import ExecutePreprocessor


class TestNotebookExecution:

    def get_notebook_paths(self):
        examples_dir = os.path.join(
            os.path.dirname(__file__), "..", "examples")
        examples_dir = os.path.abspath(examples_dir)
        notebook_paths = glob.glob(os.path.join(examples_dir, "*.ipynb"))
        return notebook_paths

    @pytest.mark.slow
    def test_all_notebooks_batch(self):
        # Test all notebook in the ../examples directory
        notebook_paths = self.get_notebook_paths()
        failed_notebooks = []

        for notebook_path in notebook_paths:
            notebook_name = os.path.basename(notebook_path)

            try:
                with tempfile.TemporaryDirectory() as tmpdir:
                    # Read and execute notebook
                    with open(notebook_path, 'r', encoding='utf-8') as f:
                        nb = nbformat.read(f, as_version=4)

                    ep = ExecutePreprocessor(
                        timeout=300,
                        kernel_name='python3',
                        allow_errors=False
                    )

                    ep.preprocess(nb, {'metadata': {'path': tmpdir}})

            except Exception as e:
                failed_notebooks.append((notebook_name, str(e)))

        if failed_notebooks:
            error_msg = "The following notebooks failed to execute:\n"
            for name, error in failed_notebooks:
                error_msg += f"  - {name}: {error}\n"
            pytest.fail(error_msg)


from simcronomicon import Town, TownParameters, Simulation, infection_models
from pyproj import Transformer

# Common coordinates
POINT_DOM = (50.7753, 6.0839)
POINT_UNIKLINIK = (50.77583, 6.045277)
COORDS_THERESIENKIRCHE = (50.77809, 6.081859)
COORDS_HAUSARZT = (50.76943, 6.081437)
COORDS_SUPERC = (50.77828, 6.078571)

# Default town parameters for `test_town.py`
DEFAULT_TOWN_PARAMS = TownParameters(100, 10)

DEFAULT_TEST_TOWN_CONFIG = {
    'point': (50.7753, 6.0839),  # POINT_DOM
    'distance': 500,
    'num_pop': 20,
    'num_init_spreader': 2
}


def create_test_town_files(prefix="test_viz", **kwargs):
    # Merge defaults with provided overrides
    config = {**DEFAULT_TEST_TOWN_CONFIG, **kwargs}

    town_params = TownParameters(
        num_pop=config['num_pop'],
        num_init_spreader=config['num_init_spreader']
    )

    town = Town.from_point(
        config['point'],
        config['distance'],
        town_params,
        file_prefix=prefix
    )

    graphml_path = f"{prefix}.graphmlz"
    config_path = f"{prefix}_config.json"

    return graphml_path, config_path, town


def get_nearest_node(town, coords):
    lat, lon = coords
    transformer = Transformer.from_crs(
        "EPSG:4326", f"EPSG:{town.epsg_code}", always_xy=True)
    x, y = transformer.transform(lon, lat)
    min_dist = float("inf")
    closest_node = None
    for node, data in town.town_graph.nodes(data=True):
        dx = float(data["x"]) - x
        dy = float(data["y"]) - y
        dist = dx ** 2 + dy ** 2
        if dist < min_dist:
            min_dist = dist
            closest_node = node
    return closest_node


def get_shortest_path_length(town, node_a, node_b):
    G = town.town_graph
    import networkx as nx
    assert nx.has_path(G, node_a, node_b), \
        "A path between the nodes isn't found!"
    return nx.shortest_path_length(G, node_a, node_b, weight="weight")


MODEL_MATRIX = {
    "seir": (
        infection_models.SEIRModel,
        infection_models.SEIRModelParameters,
        infection_models.FolkSEIR,
        dict(max_energy=5, beta=0.4, sigma=6, gamma=5, xi=20),
        "test/test_data/aachen_dom_500m_config.json",
        "test/test_data/aachen_dom_500m.graphmlz"
    ),
    "seisir": (
        infection_models.SEIsIrRModel,
        infection_models.SEIsIrRModelParameters,
        infection_models.FolkSEIsIrR,
        dict(max_energy=5, literacy=0.5, gamma=0.5, alpha=0.5, lam=0.9,
             phi=0.5, theta=0.8, mu=0.5, eta1=0.5, eta2=0.5, mem_span=10),
        "test/test_data/aachen_dom_500m_config.json",
        "test/test_data/aachen_dom_500m.graphmlz"
    ),
    "seiqrdv": (
        infection_models.SEIQRDVModel,
        infection_models.SEIQRDVModelParameters,
        infection_models.FolkSEIQRDV,
        dict(max_energy=5, lam_cap=0.01, beta=0.4, alpha=0.5, gamma=3,
             delta=2, lam=4, rho=5, kappa=0.2, mu=0.01, hospital_capacity=100),
        "test/test_data/uniklinik_500m_config.json",
        "test/test_data/uniklinik_500m.graphmlz"
    )
}


def default_test_step_events(folk_class):
    return [
        infection_models.StepEvent(
            "greet_neighbors",
            folk_class.interact,
            infection_models.EventType.DISPERSE,
            5000,
            ['accommodation'],
            infection_models.energy_exponential_mobility),
        infection_models.StepEvent(
            "chore",
            folk_class.interact,
            infection_models.EventType.DISPERSE,
            19000,
            [
                'commercial',
                'workplace',
                'education',
                'religious'],
            infection_models.log_normal_mobility)]


def setup_simulation(
        model_key,
        town_params,
        step_events=None,
        timesteps=1,
        seed=None,
        override_params=None):
    model_class, model_params_class, folk_class, base_params, config_path, graphmlz_path = MODEL_MATRIX[
        model_key]
    params = dict(base_params)
    if override_params:
        params.update(override_params)
    model_params = model_params_class(**params)
    model = model_class(model_params, step_events=step_events)
    town = Town.from_files(config_path, graphmlz_path, town_params)
    return Simulation(town, model, timesteps=timesteps, seed=seed), town, model


import pytest
import h5py
import tempfile
import os

from simcronomicon import Town, TownParameters, Simulation
from simcronomicon.infection_models import StepEvent, EventType
from test.test_helper import MODEL_MATRIX, default_test_step_events, setup_simulation


class TestSimulationInitializationGeneralized:
    @pytest.mark.parametrize("model_key,spreader_status", [
        ("seir", b'I'),
        ("seisir", b'S'),
        ("seiqrdv", b'I'),
    ])
    def test_initial_spreaders_placement(self, model_key, spreader_status):
        _, _, _, _, _, _ = MODEL_MATRIX[model_key]
        town_params = TownParameters(num_pop=100, num_init_spreader=10)
        # Use first 5 accommodation nodes, repeated twice
        town = Town.from_files(
            config_path=MODEL_MATRIX[model_key][4],
            town_graph_path=MODEL_MATRIX[model_key][5],
            town_params=town_params
        )
        spreader_nodes = list(town.accommodation_node_ids)[:5] * 2
        town_params.spreader_initial_nodes = spreader_nodes
        sim, town, _ = setup_simulation(model_key, town_params)
        with tempfile.TemporaryDirectory() as tmpdir:
            h5_path = os.path.join(tmpdir, "out.h5")
            sim.run(hdf5_path=h5_path, silent=True)
            with h5py.File(h5_path, "r") as h5file:
                log = h5file["individual_logs/log"][:]
                spreaders = [row for row in log if row['timestep']
                             == 0 and row['status'] == spreader_status]
                assert len(spreaders) == town_params.num_init_spreader
                spreader_addresses = [row['address'] for row in spreaders]
                assert sorted(spreader_addresses) == sorted(spreader_nodes)

    @pytest.mark.parametrize("model_key", ["seiqrdv"])
    def test_missing_required_place_type(self, model_key):
        # Use test data that does NOT contain 'healthcare_facility' in
        # found_place_types
        config_path = "test/test_data/aachen_dom_500m_config.json"
        graphmlz_path = "test/test_data/aachen_dom_500m.graphmlz"
        town_params = TownParameters(num_pop=10, num_init_spreader=1)
        town = Town.from_files(config_path, graphmlz_path, town_params)
        model_params_class = MODEL_MATRIX[model_key][1]
        model_params = model_params_class(**MODEL_MATRIX[model_key][3])
        model = MODEL_MATRIX[model_key][0](model_params)
        # Should raise ValueError due to missing 'healthcare_facility'
        with pytest.raises(ValueError, match="Missing required place types"):
            Simulation(town, model, timesteps=1)


class TestStepEventFunctionality:
    def test_step_event_invalid_parameters(self):
        # Test SEND_HOME with probability_func (should raise ValueError)
        with pytest.raises(ValueError, match="You cannot define a mobility probability function for an event that does not disperse people"):
            StepEvent(
                "invalid_send_home",
                lambda folk: None,
                EventType.SEND_HOME,
                probability_func=lambda x: 0.5
            )

        # Test non-callable probability_func (should raise ValueError)
        with pytest.raises(ValueError, match="probability_func must be a callable function"):
            StepEvent(
                "invalid_prob_func",
                lambda folk: None,
                EventType.DISPERSE,
                probability_func="not_a_function"
            )

        with pytest.raises(ValueError, match=r"Could not inspect probability_func signature: probability_func must have exactly 2 non-default arguments, got 1\. Expected signature: func\(distances, agent, \*\*kwargs\)"):
            StepEvent(
                "invalid_prob_func_without_folk",
                lambda folk: None,
                EventType.DISPERSE,
                probability_func=lambda x: 0.5
            )

    @pytest.mark.parametrize("model_key", ["seir", "seisir"])
    def test_disperse_and_end_day_events(self, model_key):
        _, _, folk_class, _, _, _ = MODEL_MATRIX[model_key]
        town_params = TownParameters(num_pop=5, num_init_spreader=1)
        step_events = [
            StepEvent(
                "go_to_work",
                folk_class.interact,
                EventType.DISPERSE,
                10000,
                ['workplace']
            )
        ]
        sim, town, _ = setup_simulation(
            model_key, town_params, step_events=step_events, timesteps=1)
        with tempfile.TemporaryDirectory() as tmpdir:
            h5_path = os.path.join(tmpdir, "stepevent_test.h5")
            sim.run(hdf5_path=h5_path, silent=True)
            with h5py.File(h5_path, "r") as h5file:
                log = h5file["individual_logs/log"][:]
                # Check 'go_to_work' event
                go_to_work_rows = log[(log['timestep'] == 1) & (
                    log['event'] == b"go_to_work")]
                for row in go_to_work_rows:
                    folk_id = row['folk_id']
                    address = row['address']
                    home_addr = next(
                        folk.home_address for folk in sim.folks if folk.id == folk_id)
                    place_type = town.town_graph.nodes[address]['place_type']
                    assert address == home_addr or place_type == 'workplace', \
                        f"AbstractFolk {folk_id} at address {address} (type {place_type}) is not at home or workplace during go_to_work"
                # Check 'end_day' event that automatically gets appended
                # regardless of the StepEvents input from the user
                end_day_rows = log[(log['timestep'] == 1) &
                                   (log['event'] == b"end_day")]
                for row in end_day_rows:
                    folk_id = row['folk_id']
                    address = row['address']
                    home_addr = next(
                        folk.home_address for folk in sim.folks if folk.id == folk_id)
                    assert address == home_addr, f"AbstractFolk {folk_id} not at home at end_day (address {address}, home {home_addr})"

# For SEIQRDV, the functionality of priority place is tested in its own dedicated tests,
# since agents may prioritize 'healthcare_facility' and bypass typical
# destinations like 'workplace'.


class TestSimulationUpdate:
    @pytest.mark.parametrize("model_key", ["seir", "seisir", "seiqrdv"])
    def test_population_conservation(self, model_key):
        _, _, _, _, _, _ = MODEL_MATRIX[model_key]
        town_params = TownParameters(num_pop=100, num_init_spreader=10)
        sim, _, _ = setup_simulation(model_key, town_params, timesteps=5)
        with tempfile.TemporaryDirectory() as tmpdir:
            h5_path = os.path.join(tmpdir, "pop_cons_test.h5")
            sim.run(hdf5_path=h5_path, silent=True)
            with h5py.File(h5_path, "r") as h5file:
                summary = h5file["status_summary/summary"][:]
                for row in summary:
                    total = sum(
                        row[name] for name in row.dtype.names if name not in (
                            "timestep", "current_event"))
                    assert total == 100, f"Population not conserved at timestep {
                        row['timestep']}: got {total}, expected 100"

    def test_population_migration_and_death(self):
        # Only SEIQRDV truly updates population size after each day
        model_key = "seiqrdv"
        _, _, _, extra_params, _, _ = MODEL_MATRIX[model_key]
        town_params = TownParameters(num_pop=100, num_init_spreader=10)
        # Test migration (lam_cap=1, mu=0)
        params = dict(extra_params)
        params['lam_cap'] = 1
        params['mu'] = 0
        sim, town, _ = setup_simulation(
            model_key, town_params, timesteps=2, override_params=params)
        with tempfile.TemporaryDirectory() as tmpdir:
            h5_path = os.path.join(tmpdir, "pop_migration_test.h5")
            sim.run(hdf5_path=h5_path, silent=True)
            with h5py.File(h5_path, "r") as h5file:
                summary = h5file["status_summary/summary"][:]
                step1 = summary[-2]
                step2 = summary[-1]
                total1 = sum(
                    step1[name] for name in step1.dtype.names if name not in (
                        "timestep", "current_event", "D"))
                total2 = sum(
                    step2[name] for name in step2.dtype.names if name not in (
                        "timestep", "current_event", "D"))
                assert total1 == 200, f"Population should be doubled at timestep {
                    step1['timestep']}: got {total1}, expected 200"
                assert total2 == total1 * 2, f"Population should be doubled at timestep {
                    step2['timestep']}: got {total2}, expected {
                    total1 * 2}"
        # Test death (lam_cap=0, mu=1)
        params['lam_cap'] = 0
        params['mu'] = 1
        sim, town, _ = setup_simulation(
            model_key, town_params, timesteps=1, override_params=params)
        with tempfile.TemporaryDirectory() as tmpdir:
            h5_path = os.path.join(tmpdir, "pop_death_test.h5")
            sim.run(hdf5_path=h5_path, silent=True)
            with h5py.File(h5_path, "r") as h5file:
                summary = h5file["status_summary/summary"][:]
                last_step = summary[-1]
                death_last = last_step["D"]
                assert death_last == 100, f"Population should be all dead at timestep {
                    last_step['timestep']}: got {death_last}, expected 100"
                other_statuses = ["S", "E", "I", "Q", "R", "V"]
                for status in other_statuses:
                    assert last_step[status] == 0, "Population of other statuses should be equal to 0"


class TestSimulationResults:
    def assert_h5_structure(self, h5_path):
        with h5py.File(h5_path, "r") as h5file:
            assert "config" in h5file, "'metadata' group missing in HDF5 file"
            assert "status_summary" in h5file, "'status_summary' group missing in HDF5 file"
            assert "individual_logs" in h5file, "'individual_logs' group missing in HDF5 file"
            assert "simulation_config" in h5file["config"], "'simulation_config' missing in metadata group"
            assert "town_config" in h5file["config"], "'town_config' missing in metadata group"
            assert "summary" in h5file["status_summary"], "'summary' missing in status_summary group"
            assert "log" in h5file["individual_logs"], "'log' missing in individual_logs group"

    @pytest.mark.parametrize("model_key,expected_status", [
        ("seir", {"S": 94, "E": 2, "I": 0, "R": 4}),
        ("seisir", {"S": 0, "E": 0, "Is": 45, "Ir": 42, "R": 13}),
        ("seiqrdv", {"S": 0, "E": 0, "I": 0,
         "Q": 0, "R": 8, "D": 14, "V": 78}),
    ])
    def test_status_summary_last_step(self, model_key, expected_status):
        town_params = TownParameters(num_pop=100, num_init_spreader=10)
        folk_class = MODEL_MATRIX[model_key][2]
        step_events = default_test_step_events(folk_class)
        sim, _, _ = setup_simulation(
            model_key, town_params, step_events=step_events, timesteps=50, seed=True, override_params=None)
        with tempfile.TemporaryDirectory() as tmpdir:
            h5_path = os.path.join(tmpdir, "out.h5")
            sim.run(hdf5_path=h5_path, silent=False)
            self.assert_h5_structure(h5_path)
            with h5py.File(h5_path, "r") as h5file:
                summary = h5file["status_summary/summary"][:]
                last_step = summary[-1]
                for status, expected_value in expected_status.items():
                    assert last_step[status] == expected_value, f"{status} mismatch: got {
                        last_step[status]}, expected {expected_value}"


import pytest
import os
import shutil
import pyproj
import numpy as np
import tempfile

import osmnx as ox

from simcronomicon import Town, TownParameters
from test.test_helper import (
    POINT_DOM,
    POINT_UNIKLINIK,
    COORDS_THERESIENKIRCHE,
    COORDS_HAUSARZT,
    COORDS_SUPERC,
    get_nearest_node,
    get_shortest_path_length,
    DEFAULT_TOWN_PARAMS)


class TestTownParameters:

    @pytest.mark.parametrize(
        "num_pop, num_init_spreader, spreader_nodes, expected_nodes",
        [
            (1000, 10, [], []),
            (1000, 3, [5, 12, 47], [5, 12, 47]),
            (1000, 5, [10, 25], [10, 25]),  # Partial specification
            (100, 4, [7, 7, 15, 15], [7, 7, 15, 15]),  # Duplicates
            (100, 2, ["1", "2"], ["1", "2"]),  # String convertible
            (100, 3, [1, "2", 3.0], [1, "2", 3.0]),  # Mixed types
            (100, 2, [-1, -5], [-1, -5]),  # Negative node IDs
        ]
    )
    def test_valid_parameters(
            self,
            num_pop,
            num_init_spreader,
            spreader_nodes,
            expected_nodes):
        params = TownParameters(
            num_pop=num_pop,
            num_init_spreader=num_init_spreader,
            spreader_initial_nodes=spreader_nodes
        )
        assert params.num_pop == num_pop
        assert params.num_init_spreader == num_init_spreader
        assert params.spreader_initial_nodes == expected_nodes

    @pytest.mark.parametrize(
        "kwargs, error, match",
        [
            # Type Errors
            ({"num_pop": "1000", "num_init_spreader": 10},
             TypeError, "num_pop must be an integer"),
            ({"num_pop": 1000, "num_init_spreader": 10.5},
             TypeError, "num_init_spreader must be an integer"),
            ({"num_pop": 1000, "num_init_spreader": 2, "spreader_initial_nodes": (
                1, 2)}, TypeError, "spreader_initial_nodes must be a list"),

            # Value Errors for num_pop
            ({"num_pop": 0, "num_init_spreader": 1},
             ValueError, "num_pop must be positive, got 0"),
            ({"num_pop": -5, "num_init_spreader": 1},
             ValueError, "num_pop must be positive, got -5"),

            # Value Errors for num_init_spreader - FIXED: Match actual error
            # messages
            ({"num_pop": 100, "num_init_spreader": 0}, ValueError,
             "num_init_spreader must be positive, got 0"),
            ({"num_pop": 100, "num_init_spreader": -1}, ValueError,
             "num_init_spreader must be positive, got -1"),
            ({"num_pop": 100, "num_init_spreader": 150}, ValueError,
             "num_init_spreader \\(150\\) cannot exceed num_pop \\(100\\)"),

            # Too many spreader locations - 4 locations for 2 spreaders should
            # fail
            ({"num_pop": 100, "num_init_spreader": 2, "spreader_initial_nodes": [
             1, 2, 3, 4]}, ValueError, "There cannot be more locations"),
        ]
    )
    def test_invalid_parameters(self, kwargs, error, match):
        with pytest.raises(error, match=match):
            TownParameters(**kwargs)


class TestTown:
    def setup_method(self):
        ox.settings.use_cache = False
        self._cleanup_all_files()

    def teardown_method(self):
        self._cleanup_all_files()

    def _cleanup_all_files(self):
        # OSMnx cache cleanup
        cache_dirs = [
            os.path.expanduser("~/.osmnx"),
            "cache"
        ]
        for cache_dir in cache_dirs:
            if os.path.exists(cache_dir):
                shutil.rmtree(cache_dir)

        # Default town files cleanup
        default_files = [
            "town_graph.graphmlz",
            "town_graph_config.json"
        ]
        for filename in default_files:
            if os.path.exists(filename):
                os.remove(filename)
                print(f"Cleaned up: {filename}")

    def test_town_invalid_inputs(self):
        # Case 1: classify_place_func is not a function
        with pytest.raises(TypeError, match="`classify_place_func` must be a function."):
            Town.from_point(
                POINT_DOM, 500, DEFAULT_TOWN_PARAMS,
                classify_place_func="not_a_function",
                all_place_types=["accommodation", "workplace"]
            )

        # Case 2: custom classify_place_func but all_place_types is None
        def dummy_classify(row):
            return "workplace"

        with pytest.raises(ValueError, match="If you pass a custom `classify_place_func`, you must also provide `all_place_types`."):
            Town.from_point(
                POINT_DOM, 500, DEFAULT_TOWN_PARAMS,
                classify_place_func=dummy_classify,
                all_place_types=None
            )

        # Case 3: custom classify_place_func but "accommodation" missing in
        # all_place_types
        with pytest.raises(ValueError, match="Your `all_place_types` must include 'accommodation' type buildings."):
            Town.from_point(
                POINT_DOM, 500, DEFAULT_TOWN_PARAMS,
                classify_place_func=dummy_classify,
                all_place_types=["workplace", "education"]
            )

        # Edge Case 1: point is not a tuple/list
        with pytest.raises(ValueError, match="`point` must be a list or tuple in the format \\[latitude, longitude\\]."):
            Town.from_point(
                "not_a_tuple", 500, DEFAULT_TOWN_PARAMS
            )

        # Edge Case 2: point is not valid lat/lon
        with pytest.raises(ValueError, match="`point` values must represent valid latitude and longitude coordinates."):
            Town.from_point(
                (200, 500), 500, DEFAULT_TOWN_PARAMS
            )

        # Edge Case 3: "No relevant nodes remain after filtering. The resulting
        # town graph would be empty."
        with pytest.raises(ValueError, match="No relevant nodes remain after filtering. The resulting town graph would be empty."):
            # Use point a bit further off from Dom and decrease the radius to
            # trigger this error
            Town.from_point((50.7853, 6.0839), 100, DEFAULT_TOWN_PARAMS)

    def test_graphmlz_file_saved_and_overwrite_prompt_and_abort(self):
        import builtins
        with tempfile.TemporaryDirectory() as tmpdir:
            file_prefix = "overwrite_test"
            graphmlz_path = os.path.join(tmpdir, f"{file_prefix}.graphmlz")

            # First save: file should be created
            town = Town.from_point(
                POINT_DOM,
                500,
                DEFAULT_TOWN_PARAMS,
                file_prefix=file_prefix,
                save_dir=tmpdir)
            assert os.path.exists(
                graphmlz_path), "GraphMLZ file was not saved in the specified directory"

            # Second save: should prompt for overwrite and handle both 'y' and
            # 'n'
            prompts = []
            printed = []

            # Case 1: User types 'y' (overwrite)
            def fake_input_yes(prompt):
                prompts.append(prompt)
                return "y"

            original_input = builtins.input
            builtins.input = fake_input_yes
            try:
                town2 = Town.from_point(
                    POINT_DOM,
                    500,
                    DEFAULT_TOWN_PARAMS,
                    file_prefix=file_prefix,
                    save_dir=tmpdir)
            finally:
                builtins.input = original_input

            assert any(
                "already exists. Overwrite?" in p for p in prompts), "Overwrite prompt was not shown for 'y'"
            assert isinstance(
                town2, Town), "Town object was not returned after overwrite"

            # Case 2: User types 'n' (abort)
            prompts.clear()
            printed.clear()

            def fake_input_no(prompt):
                prompts.append(prompt)
                return "n"

            def fake_print(msg):
                printed.append(msg)

            builtins.input = fake_input_no
            original_print = builtins.print
            builtins.print = fake_print
            try:
                town3 = Town.from_point(
                    POINT_DOM,
                    500,
                    DEFAULT_TOWN_PARAMS,
                    file_prefix=file_prefix,
                    save_dir=tmpdir)
            finally:
                builtins.input = original_input
                builtins.print = original_print

            assert any(
                "already exists. Overwrite?" in p for p in prompts), "Overwrite prompt was not shown for 'n'"
            assert any("aborted" in str(p).lower()
                       for p in printed), "Abort message was not printed"
            assert isinstance(
                town3, Town), "Town object was not returned after abort"

    def test_spreader_initial_nodes_assertion_error(self):
        test_graphmlz = "test/test_data/aachen_dom_500m.graphmlz"
        test_metadata = "test/test_data/aachen_dom_500m_config.json"

        # Set spreader_initial_nodes to include non-existent nodes (350, 750)
        town_params_spreader = TownParameters(100, 4, [1, 350, 750])

        expected_error_msg = "Some spreader_initial_nodes do not exist in the town graph: \\[350, 750\\]"

        with tempfile.TemporaryDirectory() as tmpdir:
            with pytest.raises(ValueError, match=expected_error_msg):
                Town.from_point(
                    POINT_DOM, 500, town_params_spreader, save_dir=tmpdir)

        # from_files test remains the same (doesn't create new files)
        with pytest.raises(ValueError, match=expected_error_msg):
            Town.from_files(
                config_path=test_metadata,
                town_graph_path=test_graphmlz,
                town_params=town_params_spreader
            )

    def test_healthcare_presence_and_all_types(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            town_dom = Town.from_point(
                POINT_DOM,
                500,
                DEFAULT_TOWN_PARAMS,
                file_prefix="dom",
                save_dir=tmpdir)
            town_uniklinik = Town.from_point(
                POINT_UNIKLINIK,
                500,
                DEFAULT_TOWN_PARAMS,
                file_prefix="uniklinik",
                save_dir=tmpdir)

            assert 'healthcare_facility' not in town_dom.found_place_types, \
                "Expected the area within 0.5km from Aachener Dom to have no healthcare_facility."
            assert 'healthcare_facility' in town_uniklinik.found_place_types, \
                "Expected the area within 0.5km from Uniklinik to have healthcare_facility"

            assert town_dom.all_place_types == town_uniklinik.all_place_types, \
                "Expected both towns to have the same all_place_types list"

    def test_superc_is_classified_as_education(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            town = Town.from_point(
                POINT_DOM,
                750,
                DEFAULT_TOWN_PARAMS,
                file_prefix="dom_750m",
                save_dir=tmpdir)

            # Project lat/lon to same CRS as town graph
            wgs84 = pyproj.CRS("EPSG:4326")
            target_crs = pyproj.CRS(town.epsg_code)
            transformer = pyproj.Transformer.from_crs(
                wgs84, target_crs, always_xy=True)
            x_proj, y_proj = transformer.transform(
                COORDS_SUPERC[1], COORDS_SUPERC[0])

            # Find closest node by Euclidean distance in CRS
            min_dist = float("inf")
            closest_node_id = None

            for node_id, data in town.town_graph.nodes(data=True):
                dx = data['x'] - x_proj
                dy = data['y'] - y_proj
                dist = dx * dx + dy * dy  # squared distance for speed
                if dist < min_dist:
                    min_dist = dist
                    closest_node_id = node_id

            assert closest_node_id is not None, "No closest node found."
            place_type = town.town_graph.nodes[closest_node_id].get(
                "place_type")
            assert place_type == "education", f"Expected 'education', got '{place_type}'"

            # Confirm that the actual coordinate of super C centroid is not far
            actual_x = town.town_graph.nodes[closest_node_id]["x"]
            actual_y = town.town_graph.nodes[closest_node_id]["y"]
            euclidean_distance = np.sqrt(
                (actual_x - x_proj)**2 + (actual_y - y_proj)**2)
            assert euclidean_distance < 50, f"Too far from SuperC (~{
                euclidean_distance:.2f} m)"

    def test_distance_to_landmarks_dom(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            # We have to construct with from_point since we always want to make sure that
            # our algorithm of shortest path construction works with the most recent
            # open street map information
            town_2000 = Town.from_point(
                POINT_DOM,
                2000,
                DEFAULT_TOWN_PARAMS,
                file_prefix="dom_2000m",
                save_dir=tmpdir)
            town_750 = Town.from_point(
                POINT_DOM,
                750,
                DEFAULT_TOWN_PARAMS,
                file_prefix="dom_750m",
                save_dir=tmpdir)

            node_theresienkirche_2000 = get_nearest_node(
                town_2000, COORDS_THERESIENKIRCHE)
            node_hausarzt_2000 = get_nearest_node(town_2000, COORDS_HAUSARZT)
            node_superC_2000 = get_nearest_node(town_2000, COORDS_SUPERC)
            node_theresienkirche_750 = get_nearest_node(
                town_750, COORDS_THERESIENKIRCHE)
            node_hausarzt_750 = get_nearest_node(town_750, COORDS_HAUSARZT)
            node_superC_750 = get_nearest_node(town_750, COORDS_SUPERC)

            # Expected distances (meters)
            expected_theresienkirche = 335
            expected_hausarzt = 1245
            tolerance = 50

            # Calculate distances
            dist_theresienkirche_2000 = get_shortest_path_length(
                town_2000, node_superC_2000, node_theresienkirche_2000)
            dist_theresienkirche_750 = get_shortest_path_length(
                town_750, node_superC_750, node_theresienkirche_750)
            dist_hausarzt_2000 = get_shortest_path_length(
                town_2000, node_superC_2000, node_hausarzt_2000)
            dist_hausarzt_750 = get_shortest_path_length(
                town_750, node_superC_750, node_hausarzt_750)

            # Assert that 2000m town gives shorter or equal distances than 750m
            # town
            assert dist_theresienkirche_2000 <= dist_theresienkirche_750, "Distance to Theresienkirche should be shorter in 2000m town"
            assert dist_hausarzt_2000 <= dist_hausarzt_750, "Distance to Hausarzt should be shorter in 2000m town"

            # Assert that distances do not deviate from expected values by more
            # than 50m
            assert abs(dist_theresienkirche_2000 - expected_theresienkirche) < tolerance, \
                f"Distance to Theresienkirche deviates by more than {tolerance}m (got {dist_theresienkirche_2000:.2f}m)"
            assert abs(dist_hausarzt_2000 - expected_hausarzt) < tolerance, \
                f"Distance to Hausarzt deviates by more than {tolerance}m (got {dist_hausarzt_2000:.2f}m)"

    def test_save_to_files(self):
        # Create a temporary directory for the test
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create a town using from_point
            original_town = Town.from_point(
                POINT_DOM,
                500,
                DEFAULT_TOWN_PARAMS,
                file_prefix="original",
                save_dir=tmpdir)

            # Modify some node attributes (first accommodation node)
            first_node = original_town.accommodation_node_ids[0]
            original_town.town_graph.nodes[first_node]["custom_node_attr"] = 123
            original_town.town_graph.nodes[first_node]["modified"] = True

            # Swap place types of two nodes if we have different types
            place_types = list(original_town.found_place_types)
            if len(place_types) >= 2 and "accommodation" in place_types:
                other_type = next(
                    pt for pt in place_types if pt != "accommodation")

                # Find one node of each type
                acc_node = original_town.accommodation_node_ids[0]
                other_nodes = [
                    n for n, d in original_town.town_graph.nodes(
                        data=True) if d.get("place_type") == other_type]

                if other_nodes:
                    other_node = other_nodes[0]
                    # Swap place types
                    original_place_types = {
                        acc_node: original_town.town_graph.nodes[acc_node]["place_type"],
                        other_node: original_town.town_graph.nodes[other_node]["place_type"]}
                    original_town.town_graph.nodes[acc_node]["place_type"] = other_type
                    original_town.town_graph.nodes[other_node]["place_type"] = "accommodation"

                    # Update accommodation_node_ids
                    original_town.accommodation_node_ids.remove(acc_node)
                    original_town.accommodation_node_ids.append(other_node)

            # Set a specific file prefix for save_to_files
            custom_prefix = os.path.join(tmpdir, "custom_town")

            # Save the town using save_to_files
            graphml_path, config_path = original_town.save_to_files(
                custom_prefix)

            # Test loading the saved files
            loaded_town = Town.from_files(
                config_path=config_path,
                town_graph_path=graphml_path,
                town_params=DEFAULT_TOWN_PARAMS
            )

            # Verify basic structure
            assert len(loaded_town.town_graph.nodes) == len(
                original_town.town_graph.nodes)
            assert len(loaded_town.town_graph.edges) == len(
                original_town.town_graph.edges)

            # Check node attribute modifications were preserved
            assert "custom_node_attr" in loaded_town.town_graph.nodes[first_node]
            assert loaded_town.town_graph.nodes[first_node]["custom_node_attr"] == 123
            assert loaded_town.town_graph.nodes[first_node]["modified"]

            # Check place type swapping if we did it
            if 'original_place_types' in locals():
                for node, original_type in original_place_types.items():
                    assert loaded_town.town_graph.nodes[node]["place_type"] != original_type

            # Verify accommodation_node_ids reflect our changes
            assert set(loaded_town.accommodation_node_ids) == set(
                original_town.accommodation_node_ids)


import pytest
import h5py
import tempfile
import os
import json
import warnings
import matplotlib
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from unittest.mock import patch, MagicMock

from simcronomicon import Town, TownParameters, Simulation
from simcronomicon.visualization import plot_status_summary_from_hdf5, plot_scatter
from test.test_helper import MODEL_MATRIX, default_test_step_events, setup_simulation, create_test_town_files

# Use non-interactive backend for matplotlib in tests
matplotlib.use('Agg')


class TestPlotStatusSummary:

    @pytest.mark.parametrize("model_key", ["seir", "seisir", "seiqrdv"])
    def test_plot_status_summary_all_statuses(self, model_key):
        with tempfile.TemporaryDirectory() as tmpdir:
            town_params = TownParameters(num_pop=50, num_init_spreader=5)
            folk_class = MODEL_MATRIX[model_key][2]
            step_events = default_test_step_events(folk_class)
            sim, town, _ = setup_simulation(
                model_key, town_params, step_events=step_events, timesteps=10, seed=True)

            h5_path = os.path.join(tmpdir, "test_plot.h5")
            sim.run(hdf5_path=h5_path, silent=True)

            # Mock plt.show() to prevent actual display during testing
            with patch('matplotlib.pyplot.show') as mock_show:
                plot_status_summary_from_hdf5(h5_path)
                mock_show.assert_called_once()

    @pytest.mark.parametrize("model_key,status", [
        ("seir", "S"),
        ("seir", "I"),
        ("seisir", "Is"),
        ("seiqrdv", "V")
    ])
    def test_plot_status_summary_single_status(self, model_key, status):
        with tempfile.TemporaryDirectory() as tmpdir:
            town_params = TownParameters(num_pop=50, num_init_spreader=5)
            folk_class = MODEL_MATRIX[model_key][2]
            step_events = default_test_step_events(folk_class)
            sim, town, _ = setup_simulation(
                model_key, town_params, step_events=step_events, timesteps=10, seed=True)

            h5_path = os.path.join(tmpdir, "test_single_status.h5")
            sim.run(hdf5_path=h5_path, silent=True)

            with patch('matplotlib.pyplot.show') as mock_show:
                plot_status_summary_from_hdf5(h5_path, status_type=status)
                mock_show.assert_called_once()

    @pytest.mark.parametrize("model_key,status_list", [
        ("seir", ["S", "I"]),
        ("seisir", ["Is", "Ir"]),
        ("seiqrdv", ["S", "V", "D"])
    ])
    def test_plot_status_summary_multiple_statuses(
            self, model_key, status_list):
        with tempfile.TemporaryDirectory() as tmpdir:
            town_params = TownParameters(num_pop=50, num_init_spreader=5)
            folk_class = MODEL_MATRIX[model_key][2]
            step_events = default_test_step_events(folk_class)
            sim, town, _ = setup_simulation(
                model_key, town_params, step_events=step_events, timesteps=10, seed=True)

            h5_path = os.path.join(tmpdir, "test_multiple_status.h5")
            sim.run(hdf5_path=h5_path, silent=True)

            with patch('matplotlib.pyplot.show') as mock_show:
                plot_status_summary_from_hdf5(
                    h5_path, status_type=status_list)
                mock_show.assert_called_once()

    def test_plot_status_summary_invalid_status_type(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            town_params = TownParameters(num_pop=50, num_init_spreader=5)
            folk_class = MODEL_MATRIX["seir"][2]
            step_events = default_test_step_events(folk_class)
            sim, town, _ = setup_simulation(
                "seir", town_params, step_events=step_events, timesteps=10, seed=True)

            h5_path = os.path.join(tmpdir, "test_invalid.h5")
            sim.run(hdf5_path=h5_path, silent=True)

            # Test invalid string
            with pytest.raises(ValueError, match="Invalid status_type 'INVALID'"):
                plot_status_summary_from_hdf5(
                    h5_path, status_type="INVALID")

            # Test invalid list
            with pytest.raises(ValueError, match="Invalid status types"):
                plot_status_summary_from_hdf5(
                    h5_path, status_type=["S", "INVALID"])

            # Test invalid type
            with pytest.raises(TypeError, match="status_type must be None, str, or list of str"):
                plot_status_summary_from_hdf5(h5_path, status_type=123)

    def test_plot_status_summary_empty_data(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            empty_h5_path = os.path.join(tmpdir, "empty.h5")

            # Create empty HDF5 file with proper structure but no data
            with h5py.File(empty_h5_path, "w") as h5file:
                status_group = h5file.create_group("status_summary")
                # Create empty dataset with proper dtype
                dt = [('timestep', 'i4'), ('S', 'i4'), ('I', 'i4')]
                status_group.create_dataset("summary", (0,), dtype=dt)

                # Add required config
                config_group = h5file.create_group("config")
                sim_config = {"population": 100}
                config_group.create_dataset(
                    "simulation_config",
                    data=json.dumps(sim_config).encode("utf-8")
                )

            with pytest.raises(ValueError, match="No status data found in HDF5 file"):
                plot_status_summary_from_hdf5(empty_h5_path)

    def test_plot_status_summary_zero_population(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            zero_pop_h5_path = os.path.join(tmpdir, "zero_pop.h5")

            with h5py.File(zero_pop_h5_path, "w") as h5file:
                status_group = h5file.create_group("status_summary")
                dt = [('timestep', 'i4'), ('S', 'i4'), ('I', 'i4')]
                data = [(0, 10, 5)]
                status_group.create_dataset("summary", data=data, dtype=dt)

                # Zero population config -> corrupt simulation!
                config_group = h5file.create_group("config")
                sim_config = {"population": 0}
                config_group.create_dataset(
                    "simulation_config",
                    data=json.dumps(sim_config).encode("utf-8")
                )

            with pytest.raises(ValueError, match="Total population in configurations is zero"):
                plot_status_summary_from_hdf5(zero_pop_h5_path)

    @pytest.mark.parametrize("model_key", ["seir", "seisir", "seiqrdv"])
    def test_plot_status_summary_has_data(self, model_key):
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create simulation output
            town_params = TownParameters(num_pop=50, num_init_spreader=5)
            folk_class = MODEL_MATRIX[model_key][2]
            step_events = default_test_step_events(folk_class)
            sim, town, _ = setup_simulation(
                model_key, town_params, step_events=step_events, timesteps=10, seed=True)

            h5_path = os.path.join(tmpdir, "test_plot.h5")
            sim.run(hdf5_path=h5_path, silent=True)

            # Capture the current figure instead of just mocking show
            with patch('matplotlib.pyplot.show'):
                plot_status_summary_from_hdf5(h5_path)

                # Get the current figure
                fig = plt.gcf()
                axes = fig.get_axes()

                # Verify figure has content
                assert len(axes) > 0, "Plot should have at least one axis"

                # Check that axes have data
                for ax in axes:
                    lines = ax.get_lines()
                    assert len(
                        lines) > 0, f"Axis should have at least one line plot"

                    # Check that lines actually have data points
                    for line in lines:
                        xdata, ydata = line.get_data()
                        assert len(xdata) > 0, "Line should have x-data"
                        assert len(ydata) > 0, "Line should have y-data"

                # Verify labels exist
                assert fig.get_suptitle() or any(ax.get_title()
                                                 for ax in axes), "Plot should have a title"

                plt.close(fig)  # Clean up


class TestValidateAndMergeColormap:

    @pytest.mark.parametrize("default_map,user_map,valid_keys,parameter_name,expected_missing", [
        # Test case: Default map missing colors for valid keys
        ({"type1": "#FF0000", "type2": "#00FF00"}, None,
         ["type1", "type2", "type3", "type4"], "test param", ["type3", "type4"]),

        # Test case: Even after merging user map, there are still some missing
        # colors
        ({"type1": "#FF0000"}, {"type2": "#00FF00"},
         ["type1", "type2", "type3"], "test param", ["type3"]),

        # Test case: User map provided but doesn't cover all missing
        ({"x": "#FF0000"}, {"y": "#00FF00", "z": "#0000FF"},
         ["x", "y", "z", "w", "v"], "category", ["v", "w"]),
    ])
    def test_missing_colors_raises_value_error(
            self,
            default_map,
            user_map,
            valid_keys,
            parameter_name,
            expected_missing):
        from simcronomicon.visualization.visualization_util import _validate_and_merge_colormap

        # Sort expected missing colors to match function behavior
        expected_missing_sorted = sorted(expected_missing)
        expected_error_pattern = rf"Missing colors for valid {parameter_name}\(s\): {
            ', '.join(expected_missing_sorted)}"

        with pytest.raises(ValueError, match=expected_error_pattern):
            _validate_and_merge_colormap(
                default_map, user_map, valid_keys, parameter_name)

    @pytest.mark.parametrize("default_map,user_map,valid_keys,parameter_name,expected_result", [
        # All these test cases shouldn't produce any error
        # Test case: All valid keys have colors.
        ({"type1": "#FF0000", "type2": "#00FF00"}, {"type3": "#0000FF"},
         ["type1", "type2", "type3"], "test param",
         {"type1": "#FF0000", "type2": "#00FF00", "type3": "#0000FF"}),

        # Test case: User map overrides defaults
        ({"type1": "#FF0000", "type2": "#00FF00"}, {"type1": "#FFFFFF"},
         ["type1", "type2"], "test param",
         {"type1": "#FFFFFF", "type2": "#00FF00"}),

        # Test case: Only defaults, all covered
        ({"a": "#FF0000", "b": "#00FF00"}, None,
         ["a", "b"], "status",
         {"a": "#FF0000", "b": "#00FF00"}),

        # Test case: Only user map, all covered
        ({}, {"x": "#FF0000", "y": "#00FF00"},
         ["x", "y"], "category",
         {"x": "#FF0000", "y": "#00FF00"}),

        # Test case: Complex merge with overrides
        ({"a": "#FF0000", "b": "#00FF00", "c": "#0000FF"}, {"b": "#FFFFFF", "d": "#FFFF00"},
         ["a", "b", "c", "d"], "type",
         {"a": "#FF0000", "b": "#FFFFFF", "c": "#0000FF", "d": "#FFFF00"}),
    ])
    def test_successful_color_mapping(
            self,
            default_map,
            user_map,
            valid_keys,
            parameter_name,
            expected_result):
        from simcronomicon.visualization.visualization_util import _validate_and_merge_colormap

        # Should not raise an error
        result = _validate_and_merge_colormap(
            default_map, user_map, valid_keys, parameter_name)
        assert result == expected_result

    @pytest.mark.parametrize("default_map,user_map,valid_keys,parameter_name,expected_warnings", [
        # Test case: User provides invalid key
        ({"type1": "#FF0000"}, {"invalid_key": "#00FF00", "type1": "#FFFFFF"},
         ["type1"], "test param",
         ["Warning: 'invalid_key' is not a valid test param"]),

        # Test case: Invalid color format
        ({"type1": "#FF0000"}, {"type1": "not_a_color"},
         ["type1"], "test param",
         ["Warning: 'not_a_color' for type1 is not a valid hex color"]),

        # Test case: Both invalid key and invalid color
        ({"type1": "#FF0000"}, {"invalid_key": "bad_color", "type1": "#FFFFFF"},
         ["type1"], "test param",
         ["Warning: 'invalid_key' is not a valid test param",
          "Warning: 'bad_color' for invalid_key is not a valid hex color"]),
    ])
    def test_warnings_for_invalid_inputs(
            self,
            default_map,
            user_map,
            valid_keys,
            parameter_name,
            expected_warnings):
        from simcronomicon.visualization.visualization_util import _validate_and_merge_colormap

        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")

            # This should generate warnings but not raise an error (if all
            # valid keys are covered)
            try:
                _validate_and_merge_colormap(
                    default_map, user_map, valid_keys, parameter_name)
            except ValueError:
                pass  # Ignore ValueError for this test, we're only testing warnings

            # Check that expected warnings were issued
            warning_messages = [str(warning.message) for warning in w]

            for expected_warning in expected_warnings:
                assert any(expected_warning in msg for msg in warning_messages), \
                    f"Expected warning '{expected_warning}' not found in {warning_messages}"


class TestVisualizeMap:
    def test_set_plotly_renderer_no_ipython_nameerror(self):
        from simcronomicon.visualization.visualization_util import _set_plotly_renderer
        import plotly.io as pio

        # Store original renderer to restore later
        original_renderer = pio.renderers.default

        try:
            # Import the function first to make the patch target clear
            from simcronomicon.visualization.visualization_util import _set_plotly_renderer

            # Mock NameError when get_ipython is not available
            with patch('simcronomicon.visualization.visualization_util.get_ipython', side_effect=NameError("name 'get_ipython' is not defined")):
                _set_plotly_renderer()
                # Should set browser renderer when not in IPython
                assert pio.renderers.default == "browser"
        finally:
            # Restore original renderer
            pio.renderers.default = original_renderer

    @pytest.mark.parametrize("shell_name,expected_renderer", [
        ('ZMQInteractiveShell', 'notebook'),      # Jupyter notebook
        ('TerminalInteractiveShell', 'browser'),   # IPython terminal
        ('google.colab._shell', 'browser'),        # Google Colab
        ('SpyderShell', 'browser'),                # Spyder IDE
    ])
    def test_set_plotly_renderer_different_shells(
            self, shell_name, expected_renderer):
        from simcronomicon.visualization.visualization_util import _set_plotly_renderer
        import plotly.io as pio

        # Mock different IPython shell environments
        mock_ipython = MagicMock()
        mock_ipython.__class__.__name__ = shell_name

        # Store original renderer
        original_renderer = pio.renderers.default

        try:
            # Patch get_ipython in the correct module where
            # _set_plotly_renderer is defined
            with patch('simcronomicon.visualization.visualization_util.get_ipython', return_value=mock_ipython):
                _set_plotly_renderer()
                assert pio.renderers.default == expected_renderer
        finally:
            # Restore original renderer
            pio.renderers.default = original_renderer

    def test_plot_place_types_scatter(self):
        graphml_path, config_path, town = create_test_town_files()

        try:
            # Mock plotly show to prevent actual display
            with patch('plotly.graph_objects.Figure.show') as mock_show:
                plot_scatter.plot_place_types_scatter(
                    graphml_path, config_path)
                mock_show.assert_called_once()
        finally:
            # Cleanup
            for file_path in [graphml_path, config_path]:
                if os.path.exists(file_path):
                    os.remove(file_path)

    def test_visualize_place_types_invalid_file_extensions(self):
        with pytest.raises(AssertionError, match="Expected a .graphmlz file"):
            plot_scatter.plot_place_types_scatter(
                "wrong.txt", "config.json")

        with pytest.raises(AssertionError, match="Expected a .json file"):
            plot_scatter.plot_place_types_scatter(
                "graph.graphmlz", "wrong.txt")

    def test_plot_agents_scatter(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create simulation output
            town_params = TownParameters(num_pop=20, num_init_spreader=2)
            folk_class = MODEL_MATRIX["seir"][2]
            step_events = default_test_step_events(folk_class)
            sim, town, _ = setup_simulation(
                "seir", town_params, step_events=step_events, timesteps=3, seed=True)

            h5_path = os.path.join(tmpdir, "test_folks.h5")
            sim.run(hdf5_path=h5_path, silent=True)

            graphml_path, config_path, _ = create_test_town_files()

            try:
                # Mock plotly show to prevent actual display
                with patch('plotly.graph_objects.Figure.show') as mock_show:
                    plot_scatter.plot_agents_scatter(
                        h5_path, graphml_path)
                    mock_show.assert_called_once()
            finally:
                # Cleanup
                for file_path in [graphml_path, config_path]:
                    if os.path.exists(file_path):
                        os.remove(file_path)

    def test_visualize_folks_invalid_file_extensions(self):
        with pytest.raises(AssertionError, match="Expected a .h5 file"):
            plot_scatter.plot_agents_scatter(
                "wrong.txt", "graph.graphmlz")

        with pytest.raises(AssertionError, match="Expected a .graphmlz file"):
            plot_scatter.plot_agents_scatter("sim.h5", "wrong.txt")

    @pytest.mark.parametrize("time_interval,should_pass,expected_error", [
        ((0, 2), True, None),
        ((1, 3), True, None),
        ([-1, 2], False, AssertionError),  # Negative start
        ((0, -1), False, AssertionError),  # Negative end
        ((2, 1), False, AssertionError),   # Start > end
        ("invalid", False, AssertionError),  # Wrong type
        ((0, 1, 2), False, AssertionError),  # Too many values
    ])
    def test_visualize_folks_time_interval_validation(
            self, time_interval, should_pass, expected_error):
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create simulation output
            town_params = TownParameters(num_pop=10, num_init_spreader=1)
            folk_class = MODEL_MATRIX["seir"][2]
            step_events = default_test_step_events(folk_class)
            sim, town, _ = setup_simulation(
                "seir", town_params, step_events=step_events, timesteps=5, seed=True)

            h5_path = os.path.join(tmpdir, "test_time_interval.h5")
            sim.run(hdf5_path=h5_path, silent=True)

            graphml_path, config_path, _ = create_test_town_files()

            try:
                if should_pass:
                    with patch('plotly.graph_objects.Figure.show') as mock_show:
                        plot_scatter.plot_agents_scatter(
                            h5_path, graphml_path, time_interval=time_interval
                        )
                        mock_show.assert_called_once()
                else:
                    with pytest.raises(expected_error):
                        plot_scatter.plot_agents_scatter(
                            h5_path, graphml_path, time_interval=time_interval
                        )
            finally:
                # Cleanup
                for file_path in [graphml_path, config_path]:
                    if os.path.exists(file_path):
                        os.remove(file_path)

    def test_visualize_folks_time_interval_exceeds_data(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create simulation with only 3 timesteps
            town_params = TownParameters(num_pop=10, num_init_spreader=1)
            folk_class = MODEL_MATRIX["seir"][2]
            step_events = default_test_step_events(folk_class)
            sim, town, _ = setup_simulation(
                "seir", town_params, step_events=step_events, timesteps=3, seed=True)

            h5_path = os.path.join(tmpdir, "test_exceed.h5")
            sim.run(hdf5_path=h5_path, silent=True)

            graphml_path, config_path, _ = create_test_town_files()

            try:
                with patch('plotly.graph_objects.Figure.show') as mock_show:
                    # Request timesteps beyond what's available
                    import warnings
                    with warnings.catch_warnings(record=True) as w:
                        warnings.simplefilter("always")
                        plot_scatter.plot_agents_scatter(
                            h5_path, graphml_path, time_interval=(0, 10)
                        )

                        # Check that warning was there
                        assert len(w) == 1
                        assert "exceeds maximum timestep" in str(w[0].message)

                    mock_show.assert_called_once()
            finally:
                # Cleanup
                for file_path in [graphml_path, config_path]:
                    if os.path.exists(file_path):
                        os.remove(file_path)

    def test_visualize_folks_no_data_in_interval(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create simulation output
            town_params = TownParameters(num_pop=10, num_init_spreader=1)
            folk_class = MODEL_MATRIX["seir"][2]
            step_events = default_test_step_events(folk_class)
            sim, _, _ = setup_simulation(
                "seir", town_params, step_events=step_events, timesteps=3, seed=True)

            h5_path = os.path.join(tmpdir, "test_no_data.h5")
            sim.run(hdf5_path=h5_path, silent=True)

            graphml_path, config_path, _ = create_test_town_files()

            try:
                # Request time interval with start > max timestep should raise
                # ValueError
                with pytest.raises(ValueError, match="Start timestep .* is greater than maximum available timestep"):
                    plot_scatter.plot_agents_scatter(
                        h5_path, graphml_path, time_interval=(100, 200)
                    )
            finally:
                # Cleanup
                for file_path in [graphml_path, config_path]:
                    if os.path.exists(file_path):
                        os.remove(file_path)


class TestVisualizationUtilities:

    def test_visualize_folks_has_data_flexible(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            # Create simulation and town data
            town_params = TownParameters(num_pop=20, num_init_spreader=2)
            folk_class = MODEL_MATRIX["seir"][2]
            step_events = default_test_step_events(folk_class)
            sim, town, _ = setup_simulation(
                "seir", town_params, step_events=step_events, timesteps=3, seed=True)

            h5_path = os.path.join(tmpdir, "test_folks.h5")
            sim.run(hdf5_path=h5_path, silent=True)

            graphml_path, config_path, _ = create_test_town_files()

            try:
                captured_fig = None

                def capture_figure(self):
                    nonlocal captured_fig
                    captured_fig = self

                with patch.object(go.Figure, 'show', capture_figure):
                    plot_scatter.plot_agents_scatter(
                        h5_path, graphml_path)

                # Verify the figure has data
                assert captured_fig is not None, "Figure should be created"
                assert len(
                    captured_fig.data) > 0, "Figure should have data traces"

                # Look for any traces with coordinate data
                coordinate_traces = []
                for trace in captured_fig.data:
                    has_coords = False

                    # Check for different coordinate systems
                    if hasattr(trace, 'lon') and hasattr(trace, 'lat'):
                        if (trace.lon is not None and trace.lat is not None and
                                len(trace.lon) > 0 and len(trace.lat) > 0):
                            has_coords = True
                    elif hasattr(trace, 'x') and hasattr(trace, 'y'):
                        if (trace.x is not None and trace.y is not None and
                                len(trace.x) > 0 and len(trace.y) > 0):
                            has_coords = True

                    if has_coords:
                        coordinate_traces.append(trace)

                assert len(coordinate_traces) > 0, \
                    f"Should have traces with coordinate data. Found {len(captured_fig.data)} traces of types: " \
                    f"{[trace.type for trace in captured_fig.data]}"

                # Verify the figure has a layout
                assert hasattr(
                    captured_fig, 'layout'), "Figure should have layout"

            finally:
                # Cleanup
                for file_path in [graphml_path, config_path]:
                    if os.path.exists(file_path):
                        os.remove(file_path)




import numpy as np
import h5py
from scipy.integrate import solve_ivp
import tempfile
import os
import pytest
from ..test_helper import default_test_step_events

from simcronomicon import Town, TownParameters, Simulation
from simcronomicon.infection_models import StepEvent, EventType
from simcronomicon.infection_models.SEIQRDV_model import SEIQRDVModel, SEIQRDVModelParameters, FolkSEIQRDV


class TestSEIQRDVModel:
    @classmethod
    def setup_class(cls):
        cls.town_graph_path = "test/test_data/uniklinik_500m.graphmlz"
        cls.town_config_path = "test/test_data/uniklinik_500m_config.json"

    def test_invalid_seiqrdv_model_parameters(self):
        # lam_cap out of range
        with pytest.raises(TypeError, match="lam_cap must be a float between 0 and 1!"):
            SEIQRDVModelParameters(
                max_energy=10,
                lam_cap=1.5,
                beta=0.1,
                alpha=0.1,
                gamma=4,
                delta=5,
                lam=7,
                rho=7,
                kappa=0.2,
                mu=0.01)

        # beta negative
        with pytest.raises(TypeError, match="beta must be a float between 0 and 1!"):
            SEIQRDVModelParameters(
                max_energy=10,
                lam_cap=0.1,
                beta=-0.1,
                alpha=0.1,
                gamma=4,
                delta=5,
                lam=7,
                rho=7,
                kappa=0.2,
                mu=0.01)

        # gamma not positive integer
        with pytest.raises(TypeError, match="gamma must be a positive integer, got -4"):
            SEIQRDVModelParameters(
                max_energy=10,
                lam_cap=0.1,
                beta=0.1,
                alpha=0.1,
                gamma=-4,
                delta=5,
                lam=7,
                rho=7,
                kappa=0.2,
                mu=0.01)

        # hospital_capacity not int or inf
        with pytest.raises(TypeError, match="hospital_capacity must be a positive integer or a value of infinity"):
            SEIQRDVModelParameters(
                max_energy=10,
                lam_cap=0.1,
                beta=0.1,
                alpha=0.1,
                gamma=4,
                delta=5,
                lam=7,
                rho=7,
                kappa=0.2,
                mu=0.01,
                hospital_capacity="a lot")

    def test_seiqrdv_abm_vs_ode_error(self):
        # ODE solution
        model_params = SEIQRDVModelParameters(
            max_energy=2,
            lam_cap=0.01,
            beta=0.7,
            alpha=0.1,
            gamma=4,
            delta=5,
            lam=7,
            rho=7,
            kappa=0.2,
            mu=0.002,
            hospital_capacity=float('Inf'))

        def rhs_func(t, y):
            S, E, I, Q, R, D, V = y
            N = S + E + I + Q + R + V
            rhs = np.zeros(7)
            rhs[0] = model_params.lam_cap / 5 * N + model_params.beta * \
                S * I / N - model_params.alpha * S - model_params.mu * S
            rhs[1] = model_params.lam_cap / 5 * N + model_params.beta * \
                S * I / N - 1 / model_params.gamma * E - model_params.mu * E
            rhs[2] = model_params.lam_cap / 5 * N + 1 / model_params.gamma * \
                E - 1 / model_params.delta * I - model_params.mu * I
            rhs[3] = 1 / model_params.delta * I - (1 - model_params.kappa) / model_params.lam * \
                Q - model_params.kappa / model_params.rho * Q - model_params.mu * Q
            rhs[4] = model_params.lam_cap / 5 * N + \
                (1 - model_params.kappa) / \
                model_params.lam * Q - model_params.mu * R
            rhs[5] = model_params.kappa / model_params.rho * Q
            rhs[6] = model_params.lam_cap / 5 * N + \
                model_params.alpha * S - model_params.mu * V
            return rhs

        t_end = 100
        t_span = (0, t_end)
        y0 = [1980, 0, 20, 0, 0, 0, 0]
        t_eval = np.arange(0, t_end + 1)

        sol = solve_ivp(
            rhs_func,
            t_span,
            y0,
            method='RK45',
            t_eval=t_eval
        )

        # Perform ABM simulation
        town_params = TownParameters(num_pop=2000, num_init_spreader=20)
        town = Town.from_files(
            config_path=self.town_config_path,
            town_graph_path=self.town_graph_path,
            town_params=town_params
        )

        model = SEIQRDVModel(
            model_params, default_test_step_events(FolkSEIQRDV))
        sim = Simulation(town, model, t_end)
        with tempfile.TemporaryDirectory() as tmpdir:
            h5_path = os.path.join(tmpdir, "abm_vs_ode_test_seiqrdv.h5")
            sim.run(hdf5_path=h5_path, silent=True)

            # Extract ABM results
            import h5py
            with h5py.File(h5_path, "r") as h5file:
                summary = h5file["status_summary/summary"][:]
                abm_S = summary['S']
                abm_E = summary['E']
                abm_I = summary['I']
                abm_Q = summary['Q']
                abm_R = summary['R']
                abm_D = summary['D']
                abm_V = summary['V']
                abm_total = abm_S + abm_E + abm_I + abm_Q + abm_R + abm_D + abm_V

            # Normalize ODE results for comparison
            ode_S = sol.y[0]
            ode_E = sol.y[1]
            ode_I = sol.y[2]
            ode_Q = sol.y[3]
            ode_R = sol.y[4]
            ode_D = sol.y[5]
            ode_V = sol.y[6]
            ode_total = ode_S + ode_E + ode_I + ode_Q + ode_R + ode_V

            # Normalize both to initial total population for fair comparison
            abm_S = abm_S / abm_total
            abm_E = abm_E / abm_total
            abm_I = abm_I / abm_total
            abm_Q = abm_Q / abm_total
            abm_R = abm_R / abm_total
            abm_D = abm_D / abm_total
            abm_V = abm_V / abm_total

            ode_S = ode_S / ode_total
            ode_E = ode_E / ode_total
            ode_I = ode_I / ode_total
            ode_Q = ode_Q / ode_total
            ode_R = ode_R / ode_total
            ode_D = ode_D / ode_total
            ode_V = ode_V / ode_total

            # Compute average per time step 2-norm error for each compartment
            # over all time points
            err_S = np.linalg.norm(abm_S - ode_S) / t_end
            err_E = np.linalg.norm(abm_E - ode_E) / t_end
            err_I = np.linalg.norm(abm_I - ode_I) / t_end
            err_Q = np.linalg.norm(abm_Q - ode_Q) / t_end
            err_R = np.linalg.norm(abm_R - ode_R) / t_end
            err_D = np.linalg.norm(abm_D - ode_D) / t_end
            err_V = np.linalg.norm(abm_V - ode_V) / t_end

            assert err_S < 0.03, f"Susceptible compartment error too high: {
                err_S:.4f}"
            assert err_E < 0.03, f"Exposed compartment error too high: {
                err_E:.4f}"
            assert err_I < 0.03, f"Infectious compartment error too high: {
                err_I:.4f}"
            assert err_Q < 0.03, f"Quarantined compartment error too high: {
                err_Q:.4f}"
            assert err_R < 0.03, f"Recovered compartment error too high: {
                err_R:.4f}"
            assert err_D < 0.03, f"Dead compartment error too high: {
                err_D:.4f}"
            assert err_V < 0.03, f"Vaccinated compartment error too high: {
                err_V:.4f}"

    def test_vaccination(self):
        model_params = SEIQRDVModelParameters(
            max_energy=10,
            lam_cap=0,
            beta=0,
            alpha=1.0,
            gamma=4,
            delta=5,
            lam=7,
            rho=7,
            kappa=0.2,
            mu=0,
            hospital_capacity=float('Inf'))

        town_params = TownParameters(num_pop=10, num_init_spreader=1)
        town = Town.from_files(
            config_path=self.town_config_path,
            town_graph_path=self.town_graph_path,
            town_params=town_params
        )

        step_event = StepEvent(
            "chore", FolkSEIQRDV.interact, EventType.DISPERSE, 19000, [
                'commercial', 'workplace', 'education', 'religious'])

        model = SEIQRDVModel(
            model_params, step_event)
        sim = Simulation(town, model, 1)
        with tempfile.TemporaryDirectory() as tmpdir:
            h5_path = os.path.join(tmpdir, "pop_vaccination_test.h5")
            sim.run(hdf5_path=h5_path, silent=True)
            with h5py.File(h5_path, "r") as h5file:
                summary = h5file["status_summary/summary"][:]
                last_step = summary[-1]
                # Since everyone wants vaccines and the hospitle capacity is
                # infinite, they should all get it
                vaccinated_last = last_step["V"]
                assert vaccinated_last == 9, f"Every former susceptible person should be vaccinated at timestep {
                    last_step['timestep']}: got {vaccinated_last}, expected 9"

        model_params = SEIQRDVModelParameters(
            max_energy=10,
            lam_cap=0,
            beta=0,
            alpha=1.0,
            gamma=4,
            delta=5,
            lam=7,
            rho=7,
            kappa=0.2,
            mu=0,
            hospital_capacity=5)

        town_params = TownParameters(num_pop=21, num_init_spreader=1)
        town = Town.from_files(
            config_path=self.town_config_path,
            town_graph_path=self.town_graph_path,
            town_params=town_params
        )

        model = SEIQRDVModel(
            model_params, step_event)
        sim = Simulation(town, model, 1)
        with tempfile.TemporaryDirectory() as tmpdir:
            h5_path = os.path.join(tmpdir, "pop_vaccination_cap_test.h5")
            sim.run(hdf5_path=h5_path, silent=True)
            with h5py.File(h5_path, "r") as h5file:
                log = h5file["individual_logs/log"][:]
                # Filter for the first step event where current_event is
                # "greet_neighbors" and timestep == 1
                first_step = log[(log['timestep'] == 1) & (
                    log['event'] == b"chore")]
                # Count number of people at each healthcare node of interest
                node_counts = {node: 0 for node in [26, 32, 40, 53]}
                for row in first_step:
                    # Only susceptible people can want vaccines. In this case there is no transmission
                    # so there will be no unaware infected person who wants
                    # vaccination
                    if row['address'] in node_counts and row['status'] != b'I':
                        node_counts[row['address']] += 1
                expected = {26: 3, 32: 1, 40: 10, 53: 4}
                print("Actual node counts at timestep 2, chore:", node_counts)
                for node, count in node_counts.items():
                    assert count == expected[node], f"Node {node} has {count} people, expected {
                        expected[node]}"

                summary = h5file["status_summary/summary"][:]
                next_step = summary[-1]
                # There are 4 healthcare_facility type nodes in the graph
                # In this test case, they got allocated 3, 1, 5, 10
                # Therefore the amount of vaccination they should get is 3 + 1
                # + 5 + 4 = 13
                vaccinated_last = next_step["V"]
                assert vaccinated_last == 13, f"Every former susceptible person should be vaccinated at timestep {
                    next_step['timestep']}: got {vaccinated_last}, expected 14"

    def test_quarantine_and_dead_address_stable(self):
        # All agents start as spreaders, delta=1 so all go to quarantine after
        # 1 day, no deaths or births
        model_params = SEIQRDVModelParameters(
            max_energy=10,
            lam_cap=0,
            beta=0,
            alpha=0,
            gamma=4,
            delta=1,
            lam=7,
            rho=2,
            kappa=1,
            mu=0,
            hospital_capacity=5)
        town_params = TownParameters(num_pop=10, num_init_spreader=10)
        town = Town.from_files(
            config_path=self.town_config_path,
            town_graph_path=self.town_graph_path,
            town_params=town_params
        )
        model = SEIQRDVModel(
            model_params, default_test_step_events(FolkSEIQRDV))
        sim = Simulation(town, model, 10)
        with tempfile.TemporaryDirectory() as tmpdir:
            h5_path = os.path.join(tmpdir, "quarantine_address_stable.h5")
            sim.run(hdf5_path=h5_path, silent=True)
            with h5py.File(h5_path, "r") as h5file:
                log = h5file["individual_logs/log"][:]
                # For each folk, track the address when they first become Q or
                # D
                first_q_address = {}
                first_d_address = {}
                for row in log:
                    folk_id = row['folk_id']
                    timestep = row['timestep']
                    status = row['status']
                    address = row['address']
                    if status == b'Q':
                        if folk_id not in first_q_address:
                            first_q_address[folk_id] = address
                        else:
                            assert address == first_q_address[folk_id], (
                                f"AbstractFolk {folk_id} changed address after quarantine at timestep {timestep}!"
                            )
                    if status == b'D':
                        if folk_id not in first_d_address:
                            first_d_address[folk_id] = address
                        else:
                            assert address == first_d_address[folk_id], (
                                f"AbstractFolk {folk_id} changed address after death at timestep {timestep}!"
                            )


import numpy as np
from scipy.integrate import solve_ivp
import tempfile
import os
import pytest
from ..test_helper import default_test_step_events

from simcronomicon import Town, TownParameters, Simulation
from simcronomicon.infection_models import StepEvent, EventType
from simcronomicon.infection_models.SEIR_model import SEIRModel, SEIRModelParameters, FolkSEIR


class TestSEIRModel:
    def test_invalid_seir_model_parameters(self):
        # beta out of range
        with pytest.raises(TypeError, match="beta must be a float between 0 and 1 \\(exclusive\\)!"):
            SEIRModelParameters(
                max_energy=5, beta=1.2, sigma=6, gamma=5, xi=200
            )

        # beta negative
        with pytest.raises(TypeError, match="beta must be a float between 0 and 1 \\(exclusive\\)!"):
            SEIRModelParameters(
                max_energy=5, beta=-0.1, sigma=6, gamma=5, xi=200
            )

        # sigma not positive integer
        with pytest.raises(TypeError, match="sigma must be a positive integer since it is a value that described duration, got 0"):
            SEIRModelParameters(
                max_energy=5, beta=0.4, sigma=0, gamma=5, xi=200
            )

        # gamma not positive integer
        with pytest.raises(TypeError, match="gamma must be a positive integer since it is a value that described duration, got -2"):
            SEIRModelParameters(
                max_energy=5, beta=0.4, sigma=6, gamma=-2, xi=200
            )

        # xi not positive integer
        with pytest.raises(TypeError, match="xi must be a positive integer since it is a value that described duration, got 0"):
            SEIRModelParameters(
                max_energy=5, beta=0.4, sigma=6, gamma=5, xi=0
            )

    def test_seir_abm_vs_ode_error(self):
        # ODE solution
        model_params = SEIRModelParameters(
            max_energy=5, beta=0.4, sigma=6, gamma=5, xi=200)

        def rhs_func(t, y):
            S, E, I, R = y
            N = S + E + I + R
            rhs = np.zeros(4)
            rhs[0] = -model_params.beta * S * I / N + 1 / model_params.xi * R
            rhs[1] = model_params.beta * S * I / N - 1 / model_params.sigma * E
            rhs[2] = 1 / model_params.sigma * E - 1 / model_params.gamma * I
            rhs[3] = 1 / model_params.gamma * I - 1 / model_params.xi * R
            return rhs

        t_end = 70  # Number of steps before termination for the simulation of the default seed
        t_span = (0, t_end)
        y0 = [0.99, 0, 0.01, 0]  # 1000 pop, 10 infected, 990 susceptible
        t_eval = np.arange(0, t_end + 1)

        sol = solve_ivp(
            rhs_func,
            t_span,
            y0,
            method='RK45',
            t_eval=t_eval
        )

        # Perform ABM simulation
        total_pop = 1000
        town_params = TownParameters(total_pop, 10)
        town_graph_path = "test/test_data/aachen_dom_500m.graphmlz"
        town_config_path = "test/test_data/aachen_dom_500m_config.json"
        town = Town.from_files(
            config_path=town_config_path,
            town_graph_path=town_graph_path,
            town_params=town_params
        )
        model = SEIRModel(
            model_params, default_test_step_events(FolkSEIR))
        sim = Simulation(town, model, t_end)
        with tempfile.TemporaryDirectory() as tmpdir:
            h5_path = os.path.join(tmpdir, "abm_vs_ode_test.h5")
            sim.run(hdf5_path=h5_path, silent=True)

            # Extract ABM results
            import h5py
            with h5py.File(h5_path, "r") as h5file:
                summary = h5file["status_summary/summary"][:]
                abm_S = summary['S'] / total_pop
                abm_E = summary['E'] / total_pop
                abm_I = summary['I'] / total_pop
                abm_R = summary['R'] / total_pop

            # Align and compare
            ode_S = sol.y[0]
            ode_E = sol.y[1]
            ode_I = sol.y[2]
            ode_R = sol.y[3]

            # Compute average per time step 2-norm error for each compartment
            # over all time points
            err_S = np.linalg.norm(abm_S - ode_S) / t_end
            err_E = np.linalg.norm(abm_E - ode_E) / t_end
            err_I = np.linalg.norm(abm_I - ode_I) / t_end
            err_R = np.linalg.norm(abm_R - ode_R) / t_end

            assert err_S < 0.05, f"Susceptible compartment error too high: {
                err_S:.4f}"
            assert err_E < 0.05, f"Exposed compartment error too high: {
                err_E:.4f}"
            assert err_I < 0.05, f"Infectious compartment error too high: {
                err_I:.4f}"
            assert err_R < 0.05, f"Recovered compartment error too high: {
                err_R:.4f}"


import numpy as np
import simcronomicon as scon
from scipy.integrate import solve_ivp
import tempfile
import os
import pytest
from ..test_helper import default_test_step_events
from simcronomicon import Town, TownParameters, Simulation
from simcronomicon.infection_models import StepEvent, EventType
from simcronomicon.infection_models.SEIsIrR_model import SEIsIrRModel, SEIsIrRModelParameters, FolkSEIsIrR


class TestSEIsIrRModel:
    def test_invalid_seisir_model_parameters(self):
        # gamma not a float or int
        with pytest.raises(TypeError, match="gamma must be a float or int"):
            SEIsIrRModelParameters(
                max_energy=4,
                literacy=0.7,
                gamma="bad",
                alpha=0.5,
                lam=0.5,
                phi=0.5,
                theta=0.7,
                mu=0.62,
                eta1=0.1,
                eta2=0.1)

        # alpha out of range
        with pytest.raises(TypeError, match="alpha must be a float or int"):
            SEIsIrRModelParameters(
                max_energy=4,
                literacy=0.7,
                gamma=0.9,
                alpha="bad",
                lam=0.5,
                phi=0.5,
                theta=0.7,
                mu=0.62,
                eta1=0.1,
                eta2=0.1)

        # lam negative
        with pytest.raises(TypeError, match="lam must be a float or int"):
            SEIsIrRModelParameters(
                max_energy=4,
                literacy=0.7,
                gamma=0.9,
                alpha=0.5,
                lam="bad",
                phi=0.5,
                theta=0.7,
                mu=0.62,
                eta1=0.1,
                eta2=0.1)

        # mem_span not int > 1
        with pytest.raises(TypeError, match="mem_span must be an integer greater or equal to 1, got 1.03"):
            SEIsIrRModelParameters(
                max_energy=4,
                literacy=0.7,
                gamma=0.9,
                alpha=0.5,
                lam=0.5,
                phi=0.5,
                theta=0.7,
                mu=0.62,
                eta1=0.1,
                eta2=0.1,
                mem_span=1.03)

    def test_SEIsIrR_abm_vs_ode_error(self):
        # ODE solution
        model_params = SEIsIrRModelParameters(
            4, 0.7, 0.9, 0.5, 0.5, 0.5, 0.7, 0.62, 0.1, 0.1)

        def rhs_func(t, y):
            S, E, Is, Ir, R = y
            rhs = np.zeros(5)

            rhs[0] = (
                S * (model_params.mu * Is + Ir) * model_params.gamma *
                model_params.alpha * model_params.lam
                + S * E * model_params.E2S
                - S * (R + S + E) * model_params.S2R
                - S * model_params.forget
            )
            rhs[1] = (Is *
                      S *
                      model_params.gamma *
                      (1 -
                       model_params.gamma) *
                      model_params.alpha *
                      model_params.lam -
                      S *
                      E *
                      model_params.E2S -
                      R *
                      E *
                      model_params.E2R)

            rhs[2] = (
                -Is * S * (model_params.gamma * model_params.alpha * model_params.lam * model_params.mu +
                           model_params.gamma * (1 - model_params.gamma) * model_params.alpha * model_params.lam)
            )
            rhs[3] = -Ir * S * model_params.gamma * \
                model_params.alpha * model_params.lam
            rhs[4] = S * (R + S + E) * model_params.S2R + S * \
                model_params.forget + R * E * model_params.E2R
            return rhs

        t_end = 12  # Number of steps before termination for the simulation of the default seed
        t_span = (0, t_end)
        # 10 spreader, 690 Steady ignorant, 300 Radical ignorant
        # The ODE is adapted from the original paper due to lack of information regarding average degree of the homogeneous networks
        # , therefore, some coefficients have been dropped
        # This results in a system that is no longer based on density
        # We have to use this scaling instead to get a sensible result
        # (yield similar result and trends to the reference literature numerical result)
        y0 = [0.1, 0, 6.9, 3.0, 0]
        t_eval = np.arange(0, t_end + 1)

        sol = solve_ivp(
            rhs_func,
            t_span,
            y0,
            method='RK45',
            t_eval=t_eval
        )

        # Perform ABM simulation
        total_pop = 2000
        town_params = TownParameters(total_pop, 20)
        town_graph_path = "test/test_data/aachen_dom_500m.graphmlz"
        town_config_path = "test/test_data/aachen_dom_500m_config.json"
        town = Town.from_files(
            config_path=town_config_path,
            town_graph_path=town_graph_path,
            town_params=town_params
        )
        model = SEIsIrRModel(
            model_params, default_test_step_events(FolkSEIsIrR))
        sim = Simulation(town, model, t_end)
        with tempfile.TemporaryDirectory() as tmpdir:
            h5_path = os.path.join(tmpdir, "abm_vs_ode_test.h5")
            sim.run(hdf5_path=h5_path, silent=True)

            # Extract ABM results
            import h5py
            with h5py.File(h5_path, "r") as h5file:
                summary = h5file["status_summary/summary"][:]
                abm_S = summary['S'] / total_pop
                abm_E = summary['E'] / total_pop
                abm_Is = summary['Is'] / total_pop
                abm_Ir = summary['Ir'] / total_pop
                abm_R = summary['R'] / total_pop

            # Normalize, align and compare
            ode_S = sol.y[0] / 10
            ode_E = sol.y[1] / 10
            ode_Is = sol.y[2] / 10
            ode_Ir = sol.y[3] / 10
            ode_R = sol.y[4] / 10

            # Compute average per time step 2-norm error for each compartment
            # over all time points
            err_S = np.linalg.norm(abm_S - ode_S) / t_end
            err_E = np.linalg.norm(abm_E - ode_E) / t_end
            err_Is = np.linalg.norm(abm_Is - ode_Is) / t_end
            err_Ir = np.linalg.norm(abm_Ir - ode_Ir) / t_end
            err_R = np.linalg.norm(abm_R - ode_R) / t_end

            assert err_S < 0.05, f"Spreader compartment error too high: {
                err_S:.4f}"
            assert err_E < 0.05, f"Exposed compartment error too high: {
                err_E:.4f}"
            assert err_Is < 0.05, f"Steady ignorant compartment error too high: {
                err_Is:.4f}"
            assert err_Ir < 0.05, f"Radical ignorant compartment error too high: {
                err_Is:.4f}"
            assert err_R < 0.05, f"Stifler compartment error too high: {
                err_R:.4f}"


