# -*- coding: utf-8 -*-
from setuptools import setup, find_packages, Extension
from setuptools.command.install import install
import subprocess
import os
from os.path import abspath, dirname, join
from glob import glob

# Resolve paths relative to this file so the build works from any CWD.
this_dir = abspath(dirname(__file__))

# NOTE(review): `license` shadows the builtin and is never passed to setup()
# below (a hard-coded string is used there) — kept for compatibility.
with open(join(this_dir, "LICENSE")) as f:
    license = f.read()

with open(join(this_dir, "README.md"), encoding="utf-8") as file:
    long_description = file.read()

# Parse requirements.txt, skipping blank lines and comment lines: a plain
# split("\n") would hand empty strings to install_requires.
with open(join(this_dir, "requirements.txt")) as f:
    requirements = [
        line.strip()
        for line in f.read().splitlines()
        if line.strip() and not line.strip().startswith("#")
    ]

# Ship every helper script bundled alongside the package.
scripts = glob("scripts/*.py") + glob("scripts/*.sh")

# Package metadata. NOTE(review): the LICENSE text read above is not used
# here — `license` is the hard-coded string "GPLv3 license".
setup(
    name="june",
    version="1.2.0",
    description="A framework for high resolution Agent Based Modelling.",
    url="https://github.com/idas-durham/june",
    long_description_content_type="text/markdown",
    long_description=long_description,
    scripts=scripts,
    author="IDAS-Durham",
    author_email="arnauq@protonmail.com",
    license="GPLv3 license",
    install_requires=requirements,
    packages=find_packages(exclude=["docs"]),
    include_package_data=True,
)


from june.geography import Geography
from june.groups import Hospitals, Schools, Companies, CareHomes, Universities
from june.groups.leisure import (
    Pubs,
    Cinemas,
    Groceries,
    Gyms,
    generate_leisure_for_config,
)
from june.groups.travel import Travel
from june.world import generate_world_from_geography
import time
import numpy as np

# load a 20-super-area slice of the London super areas
london_areas = np.loadtxt("./london_areas.txt", dtype=np.str_)[40:60]

# add King's Cross super area for the station; guard on the code we append
# so it is not duplicated. (The original guard tested "E00004734", an
# output-area code that can never occur in a super-area list, so the
# append was unconditional.)
if "E02000187" not in london_areas:
    london_areas = np.append(london_areas, "E02000187")

# add some people commuting from Cambridge
london_areas = np.concatenate((london_areas, ["E02003719", "E02003720", "E02003721"]))
#
# add Bath as well to have a city with no stations
london_areas = np.concatenate(
    (london_areas, ["E02002988", "E02002989", "E02002990", "E02002991", "E02002992"])
)

t1 = time.time()

# default config path
config_path = "./config_simulation.yaml"

# define geography from the selected slice of London super areas
# (plus the extras appended above)
geography = Geography.from_file({"super_area": london_areas})
# geography = Geography.from_file({"region": ["North East"]})

# add buildings: each supergroup is constructed for the geography and
# attached before the world is generated
geography.hospitals = Hospitals.for_geography(geography)
geography.companies = Companies.for_geography(geography)
geography.schools = Schools.for_geography(geography)
geography.universities = Universities.for_geography(geography)
geography.care_homes = CareHomes.for_geography(geography)
# generate world (populates areas with people and households)
world = generate_world_from_geography(geography, include_households=True)

# some leisure activities
world.pubs = Pubs.for_geography(geography)
world.cinemas = Cinemas.for_geography(geography)
world.groceries = Groceries.for_geography(geography)
world.gyms = Gyms.for_geography(geography)
leisure = generate_leisure_for_config(world, config_filename=config_path)
leisure.distribute_social_venues_to_areas(
    areas=world.areas, super_areas=world.super_areas
)  # this assigns possible social venues to people.
travel = Travel()
travel.initialise_commute(world)
t2 = time.time()
print(f"Took {t2 -t1} seconds to run.")
# save the world to hdf5 to load it later
world.to_hdf5("tests.hdf5")
print("Done :)")


import time
import logging
import numpy as np
import numba as nb
import random
import json
from pathlib import Path
from mpi4py import MPI
import h5py
import sys
import cProfile
import argparse
import yaml

from june.hdf5_savers import generate_world_from_hdf5, load_population_from_hdf5
from june.interaction import Interaction
from june.epidemiology.infection import (
    Infection,
    InfectionSelector,
    InfectionSelectors,
    HealthIndexGenerator,
    SymptomTag,
    ImmunitySetter,
    Covid19,
    B16172,
)
from june.groups import Hospitals, Schools, Companies, Households, CareHomes, Cemeteries
from june.groups.travel import Travel
from june.groups.leisure import Cinemas, Pubs, Groceries, generate_leisure_for_config
from june.simulator import Simulator
from june.epidemiology.epidemiology import Epidemiology
from june.epidemiology.infection_seed import (
    InfectionSeed,
    Observed2Cases,
    InfectionSeeds,
)
from june.policy import Policies
from june.event import Events
from june import paths
from june.records import Record
from june.records.records_writer import combine_records
from june.domains import Domain, DomainSplitter
from june.mpi_setup import mpi_comm, mpi_rank, mpi_size

from june.tracker.tracker import Tracker
from june.tracker.tracker_plots import PlotClass
from june.tracker.tracker_merger import MergerClass


from collections import defaultdict


def set_random_seed(seed=999):
    """
    Sets global seeds for testing in numpy, random, and numbaized numpy.
    """

    # The nested function is JIT-compiled so the seeding calls run inside
    # numba's runtime — numba-compiled code keeps its own RNG state, which
    # only calls made from jitted code can reach.
    @nb.njit(cache=True)
    def set_seed_numba(seed):
        random.seed(seed)
        return np.random.seed(seed)

    # Seed the plain (non-jitted) numpy and random generators as well.
    np.random.seed(seed)
    set_seed_numba(seed)
    random.seed(seed)
    return


# Fixed seed: every MPI rank draws identical random sequences.
set_random_seed(0)

# disable logging for all non-root ranks to avoid duplicated output
if mpi_rank > 0:
    logging.disable(logging.CRITICAL)


def keys_to_int(x):
    """Return a copy of mapping *x* with every key coerced to int.

    Used as a ``json.load`` object_hook: JSON object keys are always
    strings, so integer keys must be restored after deserialisation.
    """
    converted = {}
    for key, value in x.items():
        converted[int(key)] = value
    return converted

# =============== Argparse =========================#

parser = argparse.ArgumentParser(description="Full run of the England")

# (option strings, extra keyword arguments) for every CLI option. All
# options are optional; boolean-ish flags arrive as the strings
# "True"/"False" and are converted to real booleans further down.
_CLI_OPTIONS = [
    (("-w", "--world_path"),
     dict(help="path to saved world file",
          default="/cosma5/data/do010/dc-walk3/world.hdf5")),
    (("-c", "--comorbidities"),
     dict(help="True to include comorbidities", default="True")),
    (("-con", "--config"),
     dict(help="Config file",
          default=paths.configs_path / "config_example.yaml")),
    (("-p", "--parameters"),
     dict(help="Parameter file",
          default=paths.configs_path / "defaults/interaction/interaction.yaml")),
    (("-tr", "--tracker"),
     dict(help="Activate Tracker for CM tracing", default="False")),
    (("-ro", "--region_only"),
     dict(help="Run only one region", default="False")),
    (("-hb", "--household_beta"),
     dict(help="Household beta", default=0.25)),
    (("-nnv", "--no_vaccines"),
     dict(help="Implement no vaccine policies", default="False")),
    (("-v", "--vaccines"),
     dict(help="Implement vaccine policies", default="False")),
    (("-nv", "--no_visits"),
     dict(help="No shelter visits", default="False")),
    (("-ih", "--indoor_beta_ratio"),
     dict(help="Indoor/household beta ratio scaling", default=0.55)),
    (("-oh", "--outdoor_beta_ratio"),
     dict(help="Outdoor/household beta ratio scaling", default=0.05)),
    (("-inf", "--infectiousness_path"),
     dict(help="path to infectiousness parameter file", default="nature")),
    (("-cs", "--child_susceptibility"),
     dict(help="Reduce child susceptibility for under 12s", default="False")),
    (("-u", "--isolation_units"),
     dict(help="True to include isolation units", default="False")),
    (("-t", "--isolation_testing"),
     dict(help="Mean testing time", default=3)),
    (("-i", "--isolation_time"),
     dict(help="Ouput file name", default=7)),
    (("-ic", "--isolation_compliance"),
     dict(help="Isolation unit self reporting compliance", default=0.6)),
    (("-m", "--mask_wearing"),
     dict(help="True to include mask wearing", default="False")),
    (("-mc", "--mask_compliance"),
     dict(help="Mask wearing compliance", default="False")),
    (("-mb", "--mask_beta_factor"),
     dict(help="Mask beta factor reduction", default=0.5)),
    (("-s", "--save_path"),
     dict(help="Path of where to save logger", default="results")),
    (("--n_seeding_days",),
     dict(help="number of seeding days", default=10)),
    (("--n_seeding_case_per_day",),
     dict(help="number of seeding cases per day", default=10)),
]
for _option_strings, _extra in _CLI_OPTIONS:
    parser.add_argument(*_option_strings, required=False, **_extra)

args = parser.parse_args()
args.save_path = Path(args.save_path)

if mpi_rank == 0:
    # Only the root rank chooses and creates the output directory: if the
    # requested path exists, take the first free "<save_path>_<k>".
    base_save_path = args.save_path
    suffix = 1
    while args.save_path.is_dir():
        args.save_path = Path(f"{base_save_path}_{suffix}")
        suffix += 1
    args.save_path.mkdir(parents=True, exist_ok=False)

mpi_comm.Barrier()
# every rank must agree on the (possibly renamed) output directory
args.save_path = mpi_comm.bcast(args.save_path, root=0)
mpi_comm.Barrier()


# Convert string-valued CLI flags ("True"/"False") to real booleans.
# Anything other than the exact string "True" becomes False — identical to
# the eight-branch if/else cascade this loop replaces.
for _flag in (
    "tracker",
    "comorbidities",
    "child_susceptibility",
    "no_vaccines",
    "vaccines",
    "no_visits",
    "isolation_units",
    "mask_wearing",
):
    setattr(args, _flag, getattr(args, _flag) == "True")


# Map the CLI keyword onto the matching transmission-profile YAML file;
# unknown keywords are rejected up front.
_TRANSMISSION_FILES = {
    "nature": "nature.yaml",
    "correction_nature": "correction_nature.yaml",
    "nature_larger": "nature_larger_presymptomatic_transmission.yaml",
    "nature_lower": "nature_lower_presymptomatic_transmission.yaml",
    "xnexp": "XNExp.yaml",
}
if args.infectiousness_path not in _TRANSMISSION_FILES:
    raise NotImplementedError
transmission_config_path = (
    paths.configs_path
    / "defaults/transmission"
    / _TRANSMISSION_FILES[args.infectiousness_path]
)

if mpi_rank == 0:
    # Echo the effective configuration once, from the root rank only.
    print(f"Comorbidities set to: {args.comorbidities}")
    print(f"Parameters path set to: {args.parameters}")
    print(f"Indoor beta ratio is set to: {args.indoor_beta_ratio}")
    print(f"Outdoor beta ratio set to: {args.outdoor_beta_ratio}")
    print(f"Infectiousness path set to: {args.infectiousness_path}")
    print(f"Child susceptibility change set to: {args.child_susceptibility}")

    print(f"Isolation units set to: {args.isolation_units}")
    print(f"Household beta set to: {args.household_beta}")
    if args.isolation_units:
        print(f"Testing time set to: {args.isolation_testing}")
        print(f"Isolation time set to: {args.isolation_time}")
        print(f"Isolation compliance set to: {args.isolation_compliance}")

    print(f"Mask wearing set to: {args.mask_wearing}")
    if args.mask_wearing:
        print(f"Mask compliance set to: {args.mask_compliance}")
        print(f"Mask beta factor set up: {args.mask_beta_factor}")

    print(f"World path set to: {args.world_path}")
    print(f"Save path set to: {args.save_path}")

    print("\n", args.__dict__, "\n")


# =============== world creation =========================#
# Simulation config consumed by the simulator and leisure generator below.
CONFIG_PATH = args.config


def generate_simulator():
    record = Record(
        record_path=args.save_path, record_static_data=True, mpi_rank=mpi_rank
    )
    if mpi_rank == 0:
        with h5py.File(args.world_path, "r") as f:
            super_area_ids = f["geography"]["super_area_id"]
            super_area_names = f["geography"]["super_area_name"]
            super_area_name_to_id = {
                name.decode(): id for name, id in zip(super_area_names, super_area_ids)
            }
        super_areas_per_domain, score_per_domain = DomainSplitter.generate_world_split(
            number_of_domains=mpi_size, world_path=args.world_path
        )
        super_area_names_to_domain_dict = {}
        super_area_ids_to_domain_dict = {}
        for domain, super_areas in super_areas_per_domain.items():
            for super_area in super_areas:
                super_area_names_to_domain_dict[super_area] = domain
                super_area_ids_to_domain_dict[
                    int(super_area_name_to_id[super_area])
                ] = domain
        with open("super_area_ids_to_domain.json", "w") as f:
            json.dump(super_area_ids_to_domain_dict, f)
        with open("super_area_names_to_domain.json", "w") as f:
            json.dump(super_area_names_to_domain_dict, f)
    print(f"mpi_rank {mpi_rank} waiting")
    mpi_comm.Barrier()
    if mpi_rank > 0:
        with open("super_area_ids_to_domain.json", "r") as f:
            super_area_ids_to_domain_dict = json.load(f, object_hook=keys_to_int)
    print(f"mpi_rank {mpi_rank} loading domain")
    domain = Domain.from_hdf5(
        domain_id=mpi_rank,
        super_areas_to_domain_dict=super_area_ids_to_domain_dict,
        hdf5_file_path=args.world_path,
        interaction_config=args.parameters,
    )
    print(f"mpi_rank {mpi_rank} has loaded domain")
    # regenerate lesiure
    leisure = generate_leisure_for_config(domain, CONFIG_PATH)
    #
    selector = InfectionSelector.from_file()
    selectors = InfectionSelectors([selector])

    infection_seed = InfectionSeed.from_uniform_cases(
        world=domain,
        infection_selector=selector,
        cases_per_capita=0.01,
        date="2020-03-02 9:00",
        seed_past_infections=False,
    )
    infection_seeds = InfectionSeeds([infection_seed])

    epidemiology = Epidemiology(
        infection_selectors=selectors, infection_seeds=infection_seeds
    )

    interaction = Interaction.from_file(config_filename=args.parameters)

    policies = Policies.from_file(
        paths.configs_path / "defaults/policy/policy.yaml",
        base_policy_modules=("june.policy", "camps.policy"),
    )

    # events
    events = Events.from_file()

    # create simulator

    travel = Travel()

    group_types = []
    domainVenues = {}
    if domain.households is not None:
        if len(domain.households) > 0:
            group_types.append(domain.households)
            domainVenues["households"] = {
                "N": len(domain.households),
                "bins": domain.households[0].subgroup_bins,
            }
        else:
            domainVenues["households"] = {"N": 0, "bins": "NaN"}

    if domain.care_homes is not None:
        if len(domain.care_homes) > 0:
            group_types.append(domain.care_homes)
            domainVenues["care_homes"] = {
                "N": len(domain.care_homes),
                "bins": domain.care_homes[0].subgroup_bins,
            }
        else:
            domainVenues["care_homes"] = {"N": 0, "bins": "NaN"}

    if domain.schools is not None:
        if len(domain.schools) > 0:
            group_types.append(domain.schools)
            domainVenues["schools"] = {
                "N": len(domain.schools),
                "bins": domain.schools[0].subgroup_bins,
            }
        else:
            domainVenues["schools"] = {"N": 0, "bins": "NaN"}

    if domain.hospitals is not None:
        if len(domain.hospitals) > 0:
            group_types.append(domain.hospitals)
            domainVenues["hospitals"] = {"N": len(domain.hospitals)}
        else:
            domainVenues["hospitals"] = {"N": 0, "bins": "NaN"}

    if domain.companies is not None:
        if len(domain.companies) > 0:
            group_types.append(domain.companies)
            domainVenues["companies"] = {
                "N": len(domain.companies),
                "bins": domain.companies[0].subgroup_bins,
            }
        else:
            domainVenues["companies"] = {"N": 0, "bins": "NaN"}

    if domain.universities is not None:
        if len(domain.universities) > 0:
            group_types.append(domain.universities)
            domainVenues["universities"] = {
                "N": len(domain.universities),
                "bins": domain.universities[0].subgroup_bins,
            }
        else:
            domainVenues["universities"] = {"N": 0, "bins": "NaN"}

    if domain.pubs is not None:
        if len(domain.pubs) > 0:
            group_types.append(domain.pubs)
            domainVenues["pubs"] = {
                "N": len(domain.pubs),
                "bins": domain.pubs[0].subgroup_bins,
            }
        else:
            domainVenues["pubs"] = {"N": 0, "bins": "NaN"}

    if domain.groceries is not None:
        if len(domain.groceries) > 0:
            group_types.append(domain.groceries)
            domainVenues["groceries"] = {
                "N": len(domain.groceries),
                "bins": domain.groceries[0].subgroup_bins,
            }
        else:
            domainVenues["groceries"] = {"N": 0, "bins": "NaN"}

    if domain.cinemas is not None:
        if len(domain.cinemas) > 0:
            group_types.append(domain.cinemas)
            domainVenues["cinemas"] = {
                "N": len(domain.cinemas),
                "bins": domain.cinemas[0].subgroup_bins,
            }
        else:
            domainVenues["cinemas"] = {"N": 0, "bins": "NaN"}

    if domain.gyms is not None:
        if len(domain.gyms) > 0:
            group_types.append(domain.gyms)
            domainVenues["gyms"] = {
                "N": len(domain.gyms),
                "bins": domain.gyms[0].subgroup_bins,
            }
        else:
            domainVenues["gyms"] = {"N": 0, "bins": "NaN"}

    if domain.city_transports is not None:
        if len(domain.city_transports) > 0:
            group_types.append(domain.city_transports)
            domainVenues["city_transports"] = {"N": len(domain.city_transports)}
        else:
            domainVenues["city_transports"] = {"N": 0, "bins": "NaN"}

    if domain.inter_city_transports is not None:
        if len(domain.inter_city_transports) > 0:
            group_types.append(domain.inter_city_transports)
            domainVenues["inter_city_transports"] = {
                "N": len(domain.inter_city_transports)
            }
        else:
            domainVenues["inter_city_transports"] = {"N": 0, "bins": "NaN"}

    # print(mpi_rank, domainVenues)

    # ==================================================================================#

    # =================================== tracker ===============================#
    if args.tracker:
        tracker = Tracker(
            world=domain,
            record_path=args.save_path,
            group_types=group_types,
            load_interactions_path=args.parameters,
            contact_sexes=["unisex", "male", "female"],
            MaxVenueTrackingSize=100000,
        )
    else:
        tracker = None

    simulator = Simulator.from_file(
        world=domain,
        policies=policies,
        events=events,
        interaction=interaction,
        leisure=leisure,
        travel=travel,
        epidemiology=epidemiology,
        config_filename=CONFIG_PATH,
        record=record,
        tracker=tracker,
    )
    return simulator


# ==================================================================================#

# =================================== simulator ===============================#

print(f"mpi_rank {mpi_rank} generate simulator")
simulator = generate_simulator()
simulator.run()

# ==================================================================================#

# =================================== read logger ===============================#

mpi_comm.Barrier()

# only the root rank merges the per-rank record files into one output
if mpi_rank == 0:
    combine_records(args.save_path)

mpi_comm.Barrier()

# ==================================================================================#

# =================================== tracker figures ===============================#

if args.tracker:
    if mpi_rank == 0:
        print("Tracker stuff now")

    # Contact matrices at two age binnings: a coarse adult/child split
    # ("AC") and the fine binning used for the paper plots ("Paper").
    simulator.tracker.contract_matrices("AC", np.array([0, 18, 100]))
    simulator.tracker.contract_matrices(
        "Paper",
        [0, 5, 10, 13, 15, 18, 20, 22, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 100],
    )
    simulator.tracker.post_process_simulation(save=True)

    mpi_comm.Barrier()

    # merge the per-rank tracker outputs on the root rank
    if mpi_rank == 0:
        print("Combine Tracker results")
        Merger = MergerClass(record_path=args.save_path)
        Merger.Merge()


import numpy as np
import random
import numba as nb
import pandas as pd
import time
import matplotlib.pyplot as plt
from datetime import datetime, timedelta
import sys
import argparse
from pathlib import Path
import yaml

from collections import defaultdict

from june import World
from june.geography import Geography
from june.demography import Demography
from june.interaction import Interaction
from june.epidemiology.epidemiology import Epidemiology
from june.epidemiology.infection import Infection, InfectionSelector, InfectionSelectors
from june.epidemiology.infection.health_index import Data2Rates
from june.epidemiology.infection.health_index.health_index import HealthIndexGenerator
from june.epidemiology.infection.transmission import TransmissionConstant
from june.groups import (
    Hospitals,
    Schools,
    Companies,
    Households,
    CareHomes,
    Cemeteries,
    Universities,
)
from june.groups.leisure import (
    generate_leisure_for_config,
    Cinemas,
    Pubs,
    Groceries,
    Gyms,
)
from june.groups.travel import Travel
from june.groups.travel.transport import (
    CityTransport,
    CityTransports,
    InterCityTransport,
    InterCityTransports,
)
from june.simulator import Simulator
from june.epidemiology.infection_seed import InfectionSeed, InfectionSeeds
from june.policy import Policy, Policies
from june import paths
from june.hdf5_savers import load_geography_from_hdf5
from june.records import Record, RecordReader

from june.world import generate_world_from_geography
from june.hdf5_savers import generate_world_from_hdf5

from june.tracker.tracker import Tracker
from june.tracker.tracker_plots import PlotClass

from june.activity import ActivityManager


def set_random_seed(seed=999):
    """
    Sets global seeds for testing in numpy, random, and numbaized numpy.
    """

    # The nested function is JIT-compiled so the seeding calls run inside
    # numba's runtime — numba-compiled code keeps its own RNG state, which
    # only calls made from jitted code can reach.
    @nb.njit(cache=True)
    def set_seed_numba(seed):
        random.seed(seed)
        return np.random.seed(seed)

    # Seed the plain (non-jitted) numpy and random generators as well.
    np.random.seed(seed)
    set_seed_numba(seed)
    random.seed(seed)
    return


# Fixed seed so repeated runs are reproducible.
set_random_seed(0)

# =============== Argparse =========================#

parser = argparse.ArgumentParser(description="Full run of the England")

# (option strings, extra keyword arguments) for every CLI option. All
# options are optional; boolean-ish flags arrive as the strings
# "True"/"False" and are converted to real booleans further down.
_CLI_OPTIONS = [
    (("-w", "--world_path"),
     dict(help="path to saved world file",
          default="/cosma5/data/do010/dc-walk3/world.hdf5")),
    (("-c", "--comorbidities"),
     dict(help="True to include comorbidities", default="True")),
    (("-con", "--config"),
     dict(help="Config file",
          default=paths.configs_path / "config_example.yaml")),
    (("-p", "--parameters"),
     dict(help="Parameter file",
          default=paths.configs_path / "defaults/interaction/interaction.yaml")),
    (("-tr", "--tracker"),
     dict(help="Activate Tracker for CM tracing", default="False")),
    (("-ro", "--region_only"),
     dict(help="Run only one region", default="False")),
    (("-hb", "--household_beta"),
     dict(help="Household beta", default=0.25)),
    (("-nnv", "--no_vaccines"),
     dict(help="Implement no vaccine policies", default="False")),
    (("-v", "--vaccines"),
     dict(help="Implement vaccine policies", default="False")),
    (("-nv", "--no_visits"),
     dict(help="No shelter visits", default="False")),
    (("-ih", "--indoor_beta_ratio"),
     dict(help="Indoor/household beta ratio scaling", default=0.55)),
    (("-oh", "--outdoor_beta_ratio"),
     dict(help="Outdoor/household beta ratio scaling", default=0.05)),
    (("-inf", "--infectiousness_path"),
     dict(help="path to infectiousness parameter file", default="nature")),
    (("-cs", "--child_susceptibility"),
     dict(help="Reduce child susceptibility for under 12s", default="False")),
    (("-u", "--isolation_units"),
     dict(help="True to include isolation units", default="False")),
    (("-t", "--isolation_testing"),
     dict(help="Mean testing time", default=3)),
    (("-i", "--isolation_time"),
     dict(help="Ouput file name", default=7)),
    (("-ic", "--isolation_compliance"),
     dict(help="Isolation unit self reporting compliance", default=0.6)),
    (("-m", "--mask_wearing"),
     dict(help="True to include mask wearing", default="False")),
    (("-mc", "--mask_compliance"),
     dict(help="Mask wearing compliance", default="False")),
    (("-mb", "--mask_beta_factor"),
     dict(help="Mask beta factor reduction", default=0.5)),
    (("-s", "--save_path"),
     dict(help="Path of where to save logger", default="results")),
    (("--n_seeding_days",),
     dict(help="number of seeding days", default=10)),
    (("--n_seeding_case_per_day",),
     dict(help="number of seeding cases per day", default=10)),
]
for _option_strings, _extra in _CLI_OPTIONS:
    parser.add_argument(*_option_strings, required=False, **_extra)

args = parser.parse_args()
args.save_path = Path(args.save_path)

# If the requested output directory exists, take the first free
# "<save_path>_<k>" instead, then create it.
base_save_path = args.save_path
suffix = 1
while args.save_path.is_dir():
    args.save_path = Path(f"{base_save_path}_{suffix}")
    suffix += 1
args.save_path.mkdir(parents=True, exist_ok=False)


# Convert string-valued CLI flags ("True"/"False") to real booleans.
# Anything other than the exact string "True" becomes False — identical to
# the eight-branch if/else cascade this loop replaces.
for _flag in (
    "tracker",
    "comorbidities",
    "child_susceptibility",
    "no_vaccines",
    "vaccines",
    "no_visits",
    "isolation_units",
    "mask_wearing",
):
    setattr(args, _flag, getattr(args, _flag) == "True")


# Map the CLI keyword onto the matching transmission-profile YAML file;
# unknown keywords are rejected up front.
_TRANSMISSION_FILES = {
    "nature": "nature.yaml",
    "correction_nature": "correction_nature.yaml",
    "nature_larger": "nature_larger_presymptomatic_transmission.yaml",
    "nature_lower": "nature_lower_presymptomatic_transmission.yaml",
    "xnexp": "XNExp.yaml",
}
if args.infectiousness_path not in _TRANSMISSION_FILES:
    raise NotImplementedError
transmission_config_path = (
    paths.configs_path
    / "defaults/transmission"
    / _TRANSMISSION_FILES[args.infectiousness_path]
)

# Echo the effective configuration.
print(f"Comorbidities set to: {args.comorbidities}")
print(f"Parameters path set to: {args.parameters}")
print(f"Indoor beta ratio is set to: {args.indoor_beta_ratio}")
print(f"Outdoor beta ratio set to: {args.outdoor_beta_ratio}")
print(f"Infectiousness path set to: {args.infectiousness_path}")
print(f"Child susceptibility change set to: {args.child_susceptibility}")

print(f"Isolation units set to: {args.isolation_units}")
print(f"Household beta set to: {args.household_beta}")
if args.isolation_units:
    print(f"Testing time set to: {args.isolation_testing}")
    print(f"Isolation time set to: {args.isolation_time}")
    print(f"Isolation compliance set to: {args.isolation_compliance}")

print(f"Mask wearing set to: {args.mask_wearing}")
if args.mask_wearing:
    print(f"Mask compliance set to: {args.mask_compliance}")
    print(f"Mask beta factor set up: {args.mask_beta_factor}")

print(f"World path set to: {args.world_path}")
print(f"Save path set to: {args.save_path}")

print("\n", args.__dict__, "\n")


# NOTE(review): purpose of this pause is unclear — possibly to let the user
# read the settings before the long run starts; confirm before removing.
time.sleep(10)

# =============== world creation =========================#
# Simulation config consumed by the simulator and leisure generator below.
CONFIG_PATH = args.config


# Load the full pre-built world from HDF5 (single process: no domain split).
world = generate_world_from_hdf5(args.world_path, interaction_config=args.parameters)

leisure = generate_leisure_for_config(world, CONFIG_PATH)
travel = Travel()

# ==================================================================================#

# =================================== Infection ===============================#


# Infection model: the selector chooses the infection realisation; the seed
# infects 1% of the population uniformly on the given date.
selector = InfectionSelector.from_file()
selectors = InfectionSelectors([selector])

infection_seed = InfectionSeed.from_uniform_cases(
    world=world,
    infection_selector=selector,
    cases_per_capita=0.01,
    date="2020-03-02 9:00",
    seed_past_infections=False,
)
infection_seeds = InfectionSeeds([infection_seed])

epidemiology = Epidemiology(
    infection_selectors=selectors, infection_seeds=infection_seeds
)

# contact/interaction parameters from the CLI-supplied parameter file
interaction = Interaction.from_file(config_filename=args.parameters)
# ============================================================================#

# =================================== policies ===============================#


# Load the intervention policy set shipped with the package defaults.
policies = Policies.from_file(
    paths.configs_path / "defaults/policy/policy.yaml",
    base_policy_modules=("june.policy", "camps.policy"),
)

print(
    "Policy path set to: {}".format(paths.configs_path / "defaults/policy/policy.yaml")
)

# Single-process run: no mpi_rank argument, unlike the MPI variant above.
record = Record(record_path=args.save_path, record_static_data=True)

# ==================================================================================#

# =================================== tracker ===============================#
# The tracker is handed every group type in the world; MaxVenueTrackingSize
# presumably caps the number of venues tracked per type — confirm in Tracker.
if args.tracker:
    group_types = [
        world.households,
        world.care_homes,
        world.schools,
        world.hospitals,
        world.companies,
        world.universities,
        world.pubs,
        world.groceries,
        world.cinemas,
        world.gyms,
        world.city_transports,
        world.inter_city_transports,
    ]

    tracker = Tracker(
        world=world,
        record_path=args.save_path,
        group_types=group_types,
        load_interactions_path=args.parameters,
        contact_sexes=["unisex", "male", "female"],
        MaxVenueTrackingSize=10000,
    )
else:
    tracker = None

# ==================================================================================#

# =================================== simulator ===============================#


# Assemble the simulator from all components built above; the run loop
# writes its outputs through `record` (and `tracker`, when enabled).
simulator = Simulator.from_file(
    world=world,
    epidemiology=epidemiology,
    interaction=interaction,
    config_filename=CONFIG_PATH,
    leisure=leisure,
    travel=travel,
    record=record,
    policies=policies,
    tracker=tracker,
)

simulator.run()

# ==================================================================================#

# =================================== read logger ===============================#

# Read the records back and dump infection counts per location and time.
read = RecordReader(args.save_path)

infections_df = read.get_table_with_extras("infections", "infected_ids")

locations_df = infections_df.groupby(["location_specs", "timestamp"]).size()

locations_df.to_csv(args.save_path / "locations.csv")

# ==================================================================================#

# =================================== tracker figures ===============================#

if args.tracker:
    print("Tracker stuff now")
    # Contact matrices at two age binnings: a coarse adult/child split
    # ("AC") and the fine binning used for the paper plots ("Paper").
    simulator.tracker.contract_matrices("AC", np.array([0, 18, 100]))
    simulator.tracker.contract_matrices(
        "Paper",
        [0, 5, 10, 13, 15, 18, 20, 22, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 100],
    )
    simulator.tracker.post_process_simulation(save=True)


class GroupException(Exception):
    """Raised for errors concerning groups and their membership."""


class PolicyError(Exception):
    """Raised when a policy is misconfigured or cannot be applied.

    Fix: subclass ``Exception`` rather than ``BaseException`` — the latter is
    reserved for system-exiting exceptions (KeyboardInterrupt, SystemExit),
    and ``except Exception`` handlers would silently miss this error.
    """


class HospitalError(Exception):
    """Raised for errors in hospital allocation or configuration.

    Fix: subclass ``Exception`` rather than ``BaseException`` — the latter is
    reserved for system-exiting exceptions, and generic ``except Exception``
    handlers would silently miss this error.
    """


class SimulatorError(Exception):
    """Raised when the simulator reaches an inconsistent state
    (e.g. the people-conservation check fails during a timestep).

    Fix: subclass ``Exception`` rather than ``BaseException`` — the latter is
    reserved for system-exiting exceptions, and generic ``except Exception``
    handlers would silently miss this error.
    """


class InteractionError(Exception):
    """Raised for errors in the interaction model or its parameters.

    Fix: subclass ``Exception`` rather than ``BaseException`` — the latter is
    reserved for system-exiting exceptions, and generic ``except Exception``
    handlers would silently miss this error.
    """


# -*- coding: utf-8 -*-
"""
https://gist.github.com/chengdi123000/42ec8ed2cbef09ee050766c2f25498cb#file-mpifilehandler-py
Created on Wed Feb 14 16:17:38 2018
This handler is used to deal with logging with mpi4py in Python3.

@author: cheng

@reference:
    https://cvw.cac.cornell.edu/python/logging
    https://groups.google.com/forum/#!topic/mpi4py/SaNzc8bdj6U
    https://gist.github.com/JohnCEarls/8172807
"""

# %% mpi4py logging handler
from mpi4py import MPI
import logging
from os.path import abspath


class MPIFileHandler(logging.FileHandler):
    """
    Logging handler that lets every MPI rank append to one shared log file.

    Subclasses FileHandler for its flush/close plumbing, but opens the file
    collectively through ``MPI.File`` so concurrent writes from different
    ranks interleave atomically instead of clobbering each other.
    """

    def __init__(
        self,
        filename,
        mode=MPI.MODE_WRONLY | MPI.MODE_CREATE | MPI.MODE_APPEND,
        encoding="utf-8",
        delay=False,
        comm=MPI.COMM_WORLD,
    ):
        # Set the attributes FileHandler would normally set, without opening
        # the path through the builtin open().
        self.baseFilename = abspath(filename)
        self.mode = mode
        self.encoding = encoding
        self.comm = comm
        if delay:
            # We don't open the stream, but we still need to call the
            # Handler constructor to set level, formatter, lock etc.
            logging.Handler.__init__(self)
            self.stream = None
        else:
            # Deliberately bypass FileHandler.__init__ (it would open a local
            # file object); StreamHandler.__init__ simply stores the MPI file
            # returned by _open() as self.stream and initialises the Handler.
            logging.StreamHandler.__init__(self, self._open())

    def _open(self):
        # Collective call: every rank in self.comm must reach this open.
        stream = MPI.File.Open(self.comm, self.baseFilename, self.mode)
        # Atomic mode keeps concurrent Write_shared calls from interleaving.
        stream.Set_atomicity(True)
        return stream

    def emit(self, record):
        """
        Emit a record.

        If a formatter is specified, it is used to format the record.
        The record is then written to the stream with a trailing newline.  If
        exception information is present, it is formatted using
        traceback.print_exception and appended to the stream.  If the stream
        has an 'encoding' attribute, it is used to determine how to do the
        output to the stream.

        Modification:
            stream is MPI.File, so it must use `Write_shared` method rather
            than `write` method. And `Write_shared` method only accept
            bytestring, so `encode` is used. `Write_shared` should be invoked
            only once in each all of this emit function to keep atomicity.
        """
        try:
            msg = self.format(record)
            stream = self.stream
            # Single Write_shared call = one atomic append to the shared file.
            stream.Write_shared((msg + self.terminator).encode(self.encoding))
            # self.flush()
        except Exception:
            self.handleError(record)

    def close(self):
        # Sync flushes the MPI file's buffers to disk before closing.
        if self.stream:
            self.stream.Sync()
            self.stream.Close()
            self.stream = None


# %% example code
if __name__ == "__main__":
    # Demonstration: every MPI rank logs into the same shared "logfile.log",
    # each under a logger named after its rank.
    world_comm = MPI.COMM_WORLD
    logger = logging.getLogger("rank[%i]" % world_comm.rank)
    logger.setLevel(logging.DEBUG)

    handler = MPIFileHandler("logfile.log")
    handler.setFormatter(
        logging.Formatter("%(asctime)s:%(name)s:%(levelname)s:%(message)s")
    )
    logger.addHandler(handler)

    # 'application' code: emit one message per severity level.
    for emit, text in (
        (logger.debug, "debug message"),
        (logger.info, "info message"),
        (logger.warning, "warn message"),
        (logger.error, "error message"),
        (logger.critical, "critical message"),
    ):
        emit(text)


from mpi4py import MPI
import numpy as np


# Module-level MPI state shared across the package: the world communicator,
# this process's rank within it, and the total number of ranks.
mpi_comm = MPI.COMM_WORLD
mpi_rank = mpi_comm.Get_rank()
mpi_size = mpi_comm.Get_size()


class MovablePeople:
    """
    Holds information about people who might be present in a domain, but may or may not be,
    given circumstances. They have skinny profiles, which only have their id, infection
    probability, susceptibility, home domain, and whether active or not. For now, we mimic
    the original structure, but with an additional interface.

    Data layout (both directions):
        skinny_out[domain_id][group_spec][group_id][subgroup_type][person_id] -> view
        skinny_in[group_spec][group_id][subgroup_type][person_id] -> dict

    where a "view" is the 8-element list
    [person_id, inf_prob, inf_id, susc, immunity_inf_ids, immunity_suscs, origin_rank, active].
    """

    def __init__(self):
        # People we will send to other domains, keyed by destination domain.
        self.skinny_out = {}
        # People other domains sent to us, keyed by group structure.
        self.skinny_in = {}
        # Reserved for future indexing; currently unused.
        self.index = {}

    def add_person(self, person, external_subgroup):
        """Add or update a person in the outward facing group."""
        domain_id = external_subgroup.domain_id
        group_spec = external_subgroup.spec
        group_id = external_subgroup.group_id
        subgroup_type = external_subgroup.subgroup_type

        # setdefault chain replaces the four manual "if key not in" blocks.
        subgroup_views = (
            self.skinny_out.setdefault(domain_id, {})
            .setdefault(group_spec, {})
            .setdefault(group_id, {})
            .setdefault(subgroup_type, {})
        )

        if person.infected:
            # Infected people carry their transmission probability; immunity
            # arrays are empty because they cannot be reinfected right now.
            view = [
                person.id,
                person.infection.transmission.probability,
                person.infection.infection_id(),
                False,
                np.array([], dtype=np.int64),
                np.array([], dtype=np.float64),
                mpi_rank,
                True,
            ]
        else:
            # Susceptible people carry their per-infection immunity profile.
            (
                susceptibility_inf_ids,
                susceptibility_inf_suscs,
            ) = person.immunity.serialize()
            view = [
                person.id,
                0.0,
                0,
                True,
                np.array(susceptibility_inf_ids, dtype=np.int64),
                np.array(susceptibility_inf_suscs, dtype=np.float64),
                mpi_rank,
                True,
            ]

        subgroup_views[person.id] = view

    def delete_person(self, person, external_subgroup):
        """Remove a person from the external subgroup. For now we actually do it. Later
        we may flag them.

        Returns 0 on success, 1 if the person was not present."""
        domain_id = external_subgroup.domain_id
        group_spec = external_subgroup.spec
        group_id = external_subgroup.group_id
        subgroup_type = external_subgroup.subgroup_type
        try:
            del self.skinny_out[domain_id][group_spec][group_id][subgroup_type][
                person.id
            ]
            return 0
        except KeyError:
            return 1

    def serialise(self, rank):
        """Pack everyone destined for `rank` into (keys, data, n).

        Hopefully more efficient than standard pickle.

        Returns
        -------
        keys:
            list of (group_spec, group_id, subgroup_type, n_people) tuples
            describing consecutive slices of `data`, or None if nothing to send
        data:
            object ndarray of person views (see class docstring), or None
        n:
            number of person views in `data`
        """
        if rank not in self.skinny_out:
            return None, None, 0
        keys, data = [], []
        for group_spec, groups in self.skinny_out[rank].items():
            for group_id, subgroups in groups.items():
                for subgroup_type, person_views in subgroups.items():
                    keys.append((group_spec, group_id, subgroup_type, len(person_views)))
                    # only the views are sent; person ids are inside each view
                    data.extend(person_views.values())
        outbound = np.array(data, dtype=object)
        return keys, outbound, outbound.shape[0]

    def update(self, rank, keys, rank_data):
        """Update the information we have about people coming into our domain
        :param rank: domain of origin
        :param keys: dictionary keys for the group structure
        :param rank_data: numpy array of all the person data
        """
        offset = 0

        for group_spec, group_id, subgroup_type, n_data in keys:
            target = (
                self.skinny_in.setdefault(group_spec, {})
                .setdefault(group_id, {})
                .setdefault(subgroup_type, {})
            )
            # keys describe consecutive slices of rank_data, in order
            data = rank_data[offset : offset + n_data]
            offset += n_data

            try:
                target.update(
                    {
                        int(k): {
                            "inf_prob": i,
                            "inf_id": t,
                            "susc": s,
                            "immunity_inf_ids": iids,
                            "immunity_suscs": iis,
                            "dom": d,
                            "active": a,
                        }
                        for k, i, t, s, iids, iis, d, a in data
                    }
                )
            except Exception:
                print("failing", rank, "f-done")
                raise


def move_info(info2move):
    """
    Exchange uneven arrays of uint32 integers with every other rank.

    Parameters
    ----------
    info2move:
        one uint32 array per destination rank (length must equal mpi_size)

    Returns
    -------
    (received_buffer, n_sending, n_receiving)
    """
    # Flatten the per-destination vectors into one contiguous send buffer.
    assert len(info2move) == mpi_size
    send_buffer = np.concatenate(info2move)
    assert send_buffer.dtype == np.uint32

    n_sending = len(send_buffer)
    send_counts = np.array([len(chunk) for chunk in info2move])
    # Exclusive prefix sums give the offset of each destination's slice.
    send_displacements = np.array(
        [sum(send_counts[:destination]) for destination in range(len(info2move))]
    )

    # Tell every rank how many items it will receive from us (and vice versa).
    recv_counts = mpi_comm.alltoall(send_counts)

    n_receiving = sum(recv_counts)

    # Every rank now knows its total incoming volume and the per-source split,
    # so the variable-size collective exchange can be posted.
    recv_buffer = np.zeros(n_receiving, dtype=np.uint32)
    recv_displacements = np.array(
        [sum(recv_counts[:source]) for source in range(len(recv_counts))]
    )

    mpi_comm.Alltoallv(
        [send_buffer, send_counts, send_displacements, MPI.UINT32_T],
        [recv_buffer, recv_counts, recv_displacements, MPI.UINT32_T],
    )

    return recv_buffer, n_sending, n_receiving


import logging
import os
import subprocess
from pathlib import Path
from sys import argv

logger = logging.getLogger(__name__)

# Directory containing this module (where the June package source lives).
project_directory = Path(os.path.abspath(__file__)).parent

# Directory the process was launched from.
working_directory = Path(os.getcwd())

# Parent of the launch directory; searched for the build pipeline's layout.
working_directory_parent = working_directory.parent


def find_default(name: str, look_in_package=True) -> Path:
    """
    Get a default path when no command line argument is passed.

    Directories are searched in this order (doc fixed to match the code):

    1. the current working directory,
    2. the directory above the current working directory (for the build
       pipeline),
    3. if ``look_in_package``, the directory in which June lives, then its
       parent.

    This means that tests will find the configuration regardless of whether
    they are run together or individually.

    Parameters
    ----------
    name
        The name of some folder

    Returns
    -------
    The full path to that directory

    Raises
    ------
    FileNotFoundError
        If `name` exists in none of the candidate directories.
    """
    directories_to_look = [working_directory, working_directory_parent]
    if look_in_package:
        directories_to_look.append(project_directory)
        directories_to_look.append(project_directory.parent)
    for directory in directories_to_look:
        candidate = directory / name
        # pathlib's own exists() avoids the Path -> str -> os.path round trip
        if candidate.exists():
            return candidate
    raise FileNotFoundError(f"Could not find a default path for {name}")


def path_for_name(name: str, look_in_package=True) -> Path:
    """
    Get a path input using a flag when the program is run.

    A flag ``--<name>`` followed by a path on the command line selects that
    path explicitly; an explicit path that does not exist is an error. When
    the flag is absent (or given without a value) the usual default-search
    locations are used instead (see `find_default`).

    Parameters
    ----------
    name
        A string such as "data" which corresponds to the flag --data

    Returns
    -------
    A path
    """
    flag = f"--{name}"
    try:
        flag_position = argv.index(flag)
        path = Path(argv[flag_position + 1])
        if not path.exists():
            raise FileNotFoundError(f"No such folder {path}")
    except (IndexError, ValueError):
        # ValueError: flag not on the command line; IndexError: flag given
        # with no value after it. Either way, fall back to the defaults.
        path = find_default(name, look_in_package=look_in_package)
        logger.warning(f"No {flag} argument given - defaulting to:\n{path}")

    return path


# Locate the data folder, interactively offering to download it when missing.
try:
    data_path = path_for_name("data", look_in_package=True)
except FileNotFoundError:
    answer = input(
        "I couldn't find any data folder, do you want me to download it for you? (y/N) "
    )
    # Robustness fix: accept "Y"/"yes" in any capitalisation, not only "y".
    if answer.strip().lower() in ("y", "yes"):
        script_path = Path(__file__).parent.parent / "scripts" / "get_june_data.sh"
        with open(script_path, "rb") as file:
            script = file.read()
        # SECURITY NOTE: this pipes the repository's shell script straight to
        # the shell; only safe as long as the checked-out copy is trusted.
        rc = subprocess.call(script, shell=True)
    # Retry the lookup; still raises FileNotFoundError if the download
    # was declined or failed.
    data_path = path_for_name("data", look_in_package=True)

configs_path = path_for_name("configs")


import logging
import datetime
import yaml
from typing import Optional, List
from pathlib import Path
from time import perf_counter
from time import time as wall_clock

from june import paths
from june.activity import ActivityManager
from june.exc import SimulatorError
from june.groups.leisure import Leisure
from june.groups.travel import Travel
from june.epidemiology.epidemiology import Epidemiology
from june.interaction import Interaction
from june.tracker import Tracker
from june.policy import Policies
from june.event import Events
from june.time import Timer
from june.records import Record
from june.world import World
from june.mpi_setup import mpi_comm, mpi_size, mpi_rank

# Fallback run configuration used when the caller gives no config file.
default_config_filename = paths.configs_path / "config_example.yaml"

output_logger = logging.getLogger("simulator")  # human-readable progress
mpi_logger = logging.getLogger("mpi")  # per-rank timestep timing records
rank_logger = logging.getLogger("rank")  # cross-rank debug (see enable_mpi_debug)
mpi_logger.propagate = False
if mpi_rank > 0:
    # Only rank 0 forwards messages to the root handlers; other ranks stay quiet.
    output_logger.propagate = False
    mpi_logger.propagate = False  # NOTE(review): already set above — redundant but harmless


def enable_mpi_debug(results_folder):
    """Attach a shared MPI file handler so every rank's debug messages
    end up in <results_folder>/mpi.log."""
    from june.logging import MPIFileHandler

    log_path = Path(results_folder) / "mpi.log"
    # Create (or truncate) the log file before the collective MPI open.
    log_path.write_text("")
    rank_logger.addHandler(MPIFileHandler(log_path))


def _read_checkpoint_dates_from_file(config_filename):
    """Load the YAML run config and normalise its `checkpoint_save_dates`
    entry (absent entries yield an empty tuple)."""
    # NOTE: yaml.load with FullLoader is fine for trusted project configs,
    # but should not be pointed at untrusted input.
    with open(config_filename) as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)
    return _read_checkpoint_dates(config.get("checkpoint_save_dates", None))


def _read_checkpoint_dates(checkpoint_dates):
    if isinstance(checkpoint_dates, datetime.date):
        return (checkpoint_dates,)
    elif type(checkpoint_dates) == str:
        return (datetime.datetime.strptime(checkpoint_dates, "%Y-%m-%d"),)
    elif type(checkpoint_dates) in [list, tuple]:
        ret = []
        for date in checkpoint_dates:
            if type(date) == str:
                dd = datetime.datetime.strptime(date, "%Y-%m-%d").date()
            else:
                dd = date
            ret.append(dd)
        return tuple(ret)
    else:
        return ()


class Simulator:
    """
    Runs one epidemic simulation over a World: each timestep the activity
    manager distributes people to groups, the interaction model computes new
    infections, the epidemiology module updates health states, and optional
    record/tracker objects log output. One instance runs per MPI rank.
    """

    # Class-level hook so subclasses can substitute a custom manager type.
    ActivityManager = ActivityManager

    def __init__(
        self,
        world: World,
        interaction: Interaction,
        timer: Timer,
        activity_manager: ActivityManager,
        epidemiology: Epidemiology,
        tracker: Tracker,
        events: Optional[Events] = None,
        record: Optional[Record] = None,
        checkpoint_save_dates: List[datetime.date] = None,
        checkpoint_save_path: str = None,
    ):
        """
        Class to run an epidemic spread simulation on the world.

        Parameters
        ----------
        world:
            instance of World class
        interaction:
            model computing infections within a group
        timer:
            simulation clock driving the activity shifts
        activity_manager:
            assigns people to their subgroups each timestep
        epidemiology:
            infection seeding/progression and medical care (may be falsy)
        tracker:
            contact tracker, or None to disable contact tracking
        events:
            optional scripted events applied each timestep
        record:
            optional output recorder
        checkpoint_save_dates:
            dates (or "YYYY-MM-DD" strings) on which to save a checkpoint,
            written at the last timestep of that day
        checkpoint_save_path:
            directory for checkpoint files; defaults to "results/checkpoints"
        """
        self.activity_manager = activity_manager
        self.world = world
        self.interaction = interaction
        self.events = events
        self.timer = timer
        self.epidemiology = epidemiology
        if self.epidemiology:
            # Wire hospitals/medical care and initial immunity into the world.
            self.epidemiology.set_medical_care(
                world=world, activity_manager=activity_manager
            )
            self.epidemiology.set_immunity(self.world)
            self.epidemiology.set_past_vaccinations(
                people=self.world.people, date=self.timer.date, record=record
            )
        self.tracker = tracker
        if self.events is not None:
            self.events.init_events(world=world)
        # self.comment = comment
        self.checkpoint_save_dates = _read_checkpoint_dates(checkpoint_save_dates)
        if self.checkpoint_save_dates:
            if not checkpoint_save_path:
                checkpoint_save_path = "results/checkpoints"
            self.checkpoint_save_path = Path(checkpoint_save_path)
            self.checkpoint_save_path.mkdir(parents=True, exist_ok=True)
        self.record = record
        if self.record is not None and self.record.record_static_data:
            self.record.static_data(world=world)

    @classmethod
    def from_file(
        cls,
        world: World,
        interaction: Interaction,
        policies: Optional[Policies] = None,
        events: Optional[Events] = None,
        epidemiology: Optional[Epidemiology] = None,
        tracker: Optional[Tracker] = None,
        leisure: Optional[Leisure] = None,
        travel: Optional[Travel] = None,
        config_filename: str = default_config_filename,
        checkpoint_save_path: str = None,
        record: Optional[Record] = None,
    ) -> "Simulator":

        """
        Load config for simulator from world.yaml

        Parameters
        ----------
        leisure
        policies
        interaction
        world
        config_filename
            The path to the world yaml configuration
        comment
            A brief description of the purpose of the run(s)

        Returns
        -------
        A Simulator
        """
        checkpoint_save_dates = _read_checkpoint_dates_from_file(config_filename)
        timer = Timer.from_file(config_filename=config_filename)
        activity_manager = cls.ActivityManager.from_file(
            config_filename=config_filename,
            world=world,
            leisure=leisure,
            travel=travel,
            policies=policies,
            timer=timer,
            record=record,
        )
        return cls(
            world=world,
            interaction=interaction,
            timer=timer,
            events=events,
            activity_manager=activity_manager,
            epidemiology=epidemiology,
            tracker=tracker,
            record=record,
            checkpoint_save_dates=checkpoint_save_dates,
            checkpoint_save_path=checkpoint_save_path,
        )

    @classmethod
    def from_checkpoint(
        cls,
        world: World,
        checkpoint_load_path: str,
        interaction: Interaction,
        epidemiology: Optional[Epidemiology] = None,
        tracker: Optional[Tracker] = None,
        policies: Optional[Policies] = None,
        leisure: Optional[Leisure] = None,
        travel: Optional[Travel] = None,
        config_filename: str = default_config_filename,
        record: Optional[Record] = None,
        events: Optional[Events] = None,
        reset_infections=False,
    ):
        """Resume a simulation from an HDF5 checkpoint written by
        `save_checkpoint`; delegates to the hdf5_savers module."""
        from june.hdf5_savers.checkpoint_saver import generate_simulator_from_checkpoint

        return generate_simulator_from_checkpoint(
            world=world,
            checkpoint_path=checkpoint_load_path,
            interaction=interaction,
            policies=policies,
            epidemiology=epidemiology,
            tracker=tracker,
            leisure=leisure,
            travel=travel,
            config_filename=config_filename,
            record=record,
            events=events,
            reset_infections=reset_infections,
        )

    def clear_world(self):
        """
        Removes everyone from all possible groups, and sets everyone's busy attribute
        to False.
        """
        for super_group_name in self.activity_manager.all_super_groups:
            # "visits" supergroups hold references into other groups, so they
            # are not cleared directly.
            if "visits" in super_group_name:
                continue
            grouptype = getattr(self.world, super_group_name)
            if grouptype is not None:
                for group in grouptype.members:
                    group.clear()

        for person in self.world.people.members:
            person.busy = False
            person.subgroups.leisure = None

    def do_timestep(self):
        """
        Perform a time step in the simulation. First, ActivityManager is called
        to send people to the corresponding subgroups according to the current daytime.
        Then we iterate over all the groups and create an InteractiveGroup object, which
        extracts the relevant information of each group to carry the interaction in it.
        We then pass the interactive group to the interaction module, which returns the ids
        of the people who got infected. We record the infection locations, update the health
        status of the population, and distribute scores among the infectors to calculate R0.
        """
        output_logger.info("==================== timestep ====================")
        tick_s, tickw_s = perf_counter(), wall_clock()
        tick, tickw = perf_counter(), wall_clock()
        if self.activity_manager.policies is not None:
            # Policies may rescale betas / regional compliance for this date.
            self.activity_manager.policies.interaction_policies.apply(
                date=self.timer.date, interaction=self.interaction
            )
            self.activity_manager.policies.regional_compliance.apply(
                date=self.timer.date, regions=self.world.regions
            )
        activities = self.timer.activities
        # apply events
        if self.events is not None:
            self.events.apply(
                date=self.timer.date,
                world=self.world,
                activities=activities,
                day_type=self.timer.day_type,
                simulator=self,
            )
        if not activities or len(activities) == 0:
            output_logger.info("==== do_timestep(): no active groups found. ====")
            return
        (
            people_from_abroad_dict,
            n_people_from_abroad,
            n_people_going_abroad,
            to_send_abroad,  # useful for knowing who's MPI-ing, so can send extra info as needed.
        ) = self.activity_manager.do_timestep(record=self.record)
        tick_interaction = perf_counter()

        # get the supergroup instances that are active in this time step:
        active_super_groups = self.activity_manager.active_super_groups
        super_group_instances = []
        for super_group_name in active_super_groups:
            if "visits" not in super_group_name:
                super_group_instance = getattr(self.world, super_group_name)
                if super_group_instance is None or len(super_group_instance) == 0:
                    continue
                super_group_instances.append(super_group_instance)

        # for checking that people is conserved
        n_people = 0
        # count people in the cemetery
        for cemetery in self.world.cemeteries.members:
            n_people += len(cemetery.people)

        output_logger.info(
            f"Info for rank {mpi_rank}, "
            f"Date = {self.timer.date}, "
            f"number of deaths =  {n_people}, "
            f"number of infected = {len(self.world.people.infected)}"
        )

        # main interaction loop
        infected_ids = []  # ids of the newly infected people
        infection_ids = []  # ids of the viruses they got

        for super_group in super_group_instances:
            for group in super_group:
                # external groups live on another rank; their interactions are
                # computed there.
                if group.external:
                    continue
                else:
                    people_from_abroad = people_from_abroad_dict.get(
                        group.spec, {}
                    ).get(group.id, None)
                    (
                        new_infected_ids,
                        new_infection_ids,
                        group_size,
                    ) = self.interaction.time_step_for_group(
                        group=group,
                        people_from_abroad=people_from_abroad,
                        delta_time=self.timer.duration,
                        record=self.record,
                    )

                    infected_ids += new_infected_ids
                    infection_ids += new_infection_ids
                    n_people += group_size

        tock_interaction = perf_counter()
        rank_logger.info(
            f"Rank {mpi_rank} -- interaction -- {tock_interaction-tick_interaction}"
        )

        tick_tracker = perf_counter()
        # Loop in here
        # NOTE(review): `isinstance(self.tracker, type(None))` is an unusual
        # way to write `self.tracker is None`; behaviour is equivalent.
        if isinstance(self.tracker, type(None)):
            pass
        else:
            self.tracker.trackertimestep(
                self.activity_manager.all_super_groups, self.timer
            )
        tock_tracker = perf_counter()
        rank_logger.info(f"Rank {mpi_rank} -- tracker -- {tock_tracker-tick_tracker}")

        # Progress infections / health states using this timestep's results.
        self.epidemiology.do_timestep(
            world=self.world,
            timer=self.timer,
            record=self.record,
            infected_ids=infected_ids,
            infection_ids=infection_ids,
            people_from_abroad_dict=people_from_abroad_dict,
        )

        # Synchronise all ranks before the conservation check; the wait time
        # measures rank load imbalance.
        tick, tickw = perf_counter(), wall_clock()
        mpi_comm.Barrier()
        tock, tockw = perf_counter(), wall_clock()
        rank_logger.info(f"Rank {mpi_rank} -- interaction_waiting -- {tock-tick}")

        # recount people active to check people conservation
        people_active = (
            len(self.world.people) + n_people_from_abroad - n_people_going_abroad
        )
        if n_people != people_active:

            raise SimulatorError(
                f"Number of people active {n_people} does not match "
                f"the total people number {people_active}.\n"
                f"People in the world {len(self.world.people)}\n"
                f"People going abroad {n_people_going_abroad}\n"
                f"People coming from abroad {n_people_from_abroad}\n"
                f"Current rank {mpi_rank}\n"
            )

        # remove everyone from their active groups
        self.clear_world()
        tock, tockw = perf_counter(), wall_clock()
        output_logger.info(
            f"CMS: Timestep for rank {mpi_rank}/{mpi_size} - {tock - tick_s},"
            f"{tockw-tickw_s} - {self.timer.date}\n"
        )
        mpi_logger.info(f"{self.timer.date},{mpi_rank},timestep,{tock-tick_s}")

    def run(self):
        """
        Run simulation with n_seed initial infections
        """
        # NOTE(review): both placeholders below are total_days, so the message
        # repeats itself; the first was presumably meant to be the start date
        # or elapsed days — confirm intended wording.
        output_logger.info(
            f"Starting simulation for {self.timer.total_days} days at day {self.timer.date},"
            f"to run for {self.timer.total_days} days"
        )
        self.clear_world()
        if self.record is not None:
            self.record.parameters(
                interaction=self.interaction,
                epidemiology=self.epidemiology,
                activity_manager=self.activity_manager,
            )
        while self.timer.date < self.timer.final_date:
            if self.epidemiology:
                # Seed new infections scheduled for this timestep.
                self.epidemiology.infection_seeds_timestep(
                    self.timer, record=self.record
                )
            mpi_comm.Barrier()
            if mpi_rank == 0:
                rank_logger.info("Next timestep")
            self.do_timestep()
            if (
                self.timer.date.date() in self.checkpoint_save_dates
                and (self.timer.now + self.timer.duration).is_integer()
            ):  # this saves in the last time step of the day
                saving_date = self.timer.date.date()
                # we can resume consistenly
                output_logger.info(
                    f"Saving simulation checkpoint at {self.timer.date.date()}"
                )
                self.save_checkpoint(saving_date)
            next(self.timer)

    def save_checkpoint(self, saving_date):
        """Write this rank's population state to an HDF5 checkpoint file
        (one file per rank when running under MPI)."""
        from june.hdf5_savers.checkpoint_saver import save_checkpoint_to_hdf5

        if mpi_size == 1:
            save_path = self.checkpoint_save_path / f"checkpoint_{saving_date}.hdf5"
        else:
            # Suffix the rank so parallel runs do not clobber each other.
            save_path = (
                self.checkpoint_save_path / f"checkpoint_{saving_date}.{mpi_rank}.hdf5"
            )
        save_checkpoint_to_hdf5(
            population=self.world.people,
            date=str(saving_date),
            hdf5_file_path=save_path,
        )


import calendar
import datetime
# NOTE(review): `home` is never used below, and importing `turtle` pulls in
# tkinter, which fails on headless machines — this looks like an accidental
# IDE auto-import; confirm and remove.
from turtle import home
import yaml
from typing import List

# Seconds in one day; used to express timedeltas as fractional days.
SECONDS_PER_DAY = 24 * 60 * 60


class Timer:
    """
    Simulation clock. Tracks the current datetime and divides every day into
    activity "shifts" whose durations and activities differ between day types
    (weekday/weekend by default). Advancing with ``next(timer)`` moves to the
    following shift, restarting the shift counter when the day rolls over.
    """

    def __init__(
        self,
        initial_day: str = "2020-03-01 9:00",
        total_days: int = 10,
        weekday_step_duration: List[int] = (12, 12),
        weekend_step_duration: List[int] = (24,),
        weekday_activities: List[List[str]] = (
            ("primary_activity", "residence"),
            ("residence",),
        ),
        weekend_activities: List[List[str]] = (("residence",),),
        day_types=None,
    ):
        """Build a timer starting at `initial_day` ("YYYY-MM-DD[ H:MM]")
        that runs for `total_days` days."""
        parts = initial_day.split(" ")
        year, month, day = (int(piece) for piece in parts[0].split("-"))
        start_hour = int(parts[1].split(":")[0]) if len(parts) > 1 else 0
        self.initial_date = datetime.datetime(year, month, day) + datetime.timedelta(
            hours=start_hour
        )

        # Which calendar day names belong to which day type.
        self.day_types = (
            day_types
            if day_types is not None
            else {
                "weekday": ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"],
                "weekend": ["Saturday", "Sunday"],
            }
        )

        self.total_days = total_days
        self.weekday_step_duration = weekday_step_duration
        self.weekend_step_duration = weekend_step_duration
        self.weekday_activities = weekday_activities
        self.weekend_activities = weekend_activities

        self.previous_date = self.initial_date
        self.final_date = self.initial_date + datetime.timedelta(days=total_days)
        self.date = self.initial_date
        self.shift = 0
        self.delta_time = datetime.timedelta(hours=self.shift_duration)

    @classmethod
    def from_file(cls, config_filename):
        """Build a Timer from the `time` section of a YAML run config;
        optional top-level `weekday`/`weekend` keys override the day types."""
        with open(config_filename) as config_file:
            config = yaml.load(config_file, Loader=yaml.FullLoader)
        time_config = config["time"]
        day_types = None
        if "weekday" in config.keys() and "weekend" in config.keys():
            day_types = {"weekday": config["weekday"], "weekend": config["weekend"]}

        return cls(
            initial_day=time_config["initial_day"],
            total_days=time_config["total_days"],
            weekday_step_duration=time_config["step_duration"]["weekday"],
            weekend_step_duration=time_config["step_duration"]["weekend"],
            weekday_activities=time_config["step_activities"]["weekday"],
            weekend_activities=time_config["step_activities"]["weekend"],
            day_types=day_types,
        )

    @property
    def is_weekend(self):
        """True when the current day falls in the "weekend" day type."""
        return self.day_of_week in self.day_types["weekend"]

    @property
    def day_type(self):
        """Either "weekend" or "weekday" for the current date."""
        return "weekend" if self.is_weekend else "weekday"

    @property
    def now(self):
        """Days elapsed since the simulation start, as a float."""
        return (self.date - self.initial_date).total_seconds() / SECONDS_PER_DAY

    @property
    def date_str(self):
        """Current date formatted as YYYY-MM-DD."""
        return self.date.strftime("%Y-%m-%d")

    @property
    def duration(self):
        """Length of the current shift in days, as a float."""
        return self.delta_time.total_seconds() / SECONDS_PER_DAY

    @property
    def day(self):
        """Whole days elapsed since the simulation start."""
        return int(self.now)

    @property
    def day_of_week(self):
        """Calendar day name, e.g. "Monday"."""
        return calendar.day_name[self.date.weekday()]

    @property
    def activities(self):
        """Activity tuple for the current shift of the current day type."""
        return getattr(self, f"{self.day_type}_activities")[self.shift]

    @property
    def shift_duration(self):
        """Duration (hours) of the current shift of the current day type."""
        return getattr(self, f"{self.day_type}_step_duration")[self.shift]

    def reset(self):
        """Rewind to the initial date and the first shift of that day."""
        self.reset_to_new_date(self.initial_date)

    def reset_to_new_date(self, date):
        """Jump to `date`, restarting at the first shift of the day."""
        self.date = date
        self.shift = 0
        self.delta_time = datetime.timedelta(hours=self.shift_duration)
        self.previous_date = self.initial_date

    def __next__(self):
        """Advance by one shift; the shift counter restarts on day rollover."""
        self.previous_date = self.date
        self.date += self.delta_time
        self.shift += 1
        if self.date.day != self.previous_date.day:
            self.shift = 0
        self.delta_time = datetime.timedelta(hours=self.shift_duration)
        return self.date


import logging
from typing import Optional
from june.demography import Demography, Population
from june.distributors import (
    SchoolDistributor,
    HospitalDistributor,
    HouseholdDistributor,
    CareHomeDistributor,
    WorkerDistributor,
    CompanyDistributor,
    UniversityDistributor,
)
from june.geography import Geography, Areas
from june.groups import Supergroup, Cemeteries

logger = logging.getLogger("world")

# Supergroup attribute names that a Geography may carry and that are copied
# onto a World by generate_world_from_geography.
possible_groups = [
    "households",
    "care_homes",
    "schools",
    "hospitals",
    "companies",
    "universities",
    "pubs",
    "groceries",
    "cinemas",
]


def _populate_areas(areas: Areas, demography, ethnicity=True, comorbidity=True):
    """Create people in every area according to ``demography``.

    Returns the combined Population of all areas.
    """
    logger.info("Populating areas")
    population = Population()
    for area in areas:
        area.populate(demography, ethnicity=ethnicity, comorbidity=comorbidity)
        population.extend(area.people)
    n_people = len(population)
    logger.info(f"Areas populated. This world's population is: {n_people}")
    return population


class World:
    """
    This Class creates the world that will later be simulated.
    The world will be stored in pickle, but a better option needs to be found.

    """

    def __init__(self):
        """
        Initializes a world given a geography and a demography. For now, households are
        a special group because they require a mix of both groups (we need to fix
        this later).
        """
        # Geography levels (filled in externally, e.g. by generate_world_from_geography).
        self.areas = None
        self.super_areas = None
        self.regions = None
        # Everyone living in this world.
        self.people = None
        # Supergroups; any of these may stay None if absent from the source geography.
        self.households = None
        self.care_homes = None
        self.schools = None
        self.companies = None
        self.hospitals = None
        self.pubs = None
        self.groceries = None
        self.cinemas = None
        self.cemeteries = None
        self.universities = None
        self.cities = None
        self.stations = None

    def __iter__(self):
        """Iterate over every attribute that currently holds a Supergroup."""
        ret = []
        for attr_name, attr_value in self.__dict__.items():
            if isinstance(attr_value, Supergroup):
                ret.append(attr_value)
        return iter(ret)

    def distribute_people(self, include_households=True):
        """
        Distributes people to buildings assuming default configurations.

        Parameters
        ----------
        include_households:
            whether to also create households and assign people to them
        """
        # Assign workers to their work super areas first, before any of the
        # building-specific distributors run.
        if (
            self.companies is not None
            or self.hospitals is not None
            or self.schools is not None
            or self.care_homes is not None
        ):
            worker_distr = WorkerDistributor.for_super_areas(
                area_names=[super_area.name for super_area in self.super_areas]
            )  # atm only for_geography()
            worker_distr.distribute(
                areas=self.areas, super_areas=self.super_areas, population=self.people
            )
        # Care-home residents are placed before households are built; the same
        # distributor instance is reused further below for care-home workers.
        if self.care_homes is not None:
            carehome_distr = CareHomeDistributor.from_file()
            carehome_distr.populate_care_homes_in_super_areas(
                super_areas=self.super_areas
            )

        if include_households:
            household_distributor = HouseholdDistributor.from_file()

            self.households = (
                household_distributor.distribute_people_and_households_to_areas(
                    self.areas
                )
            )

        if self.schools is not None:
            school_distributor = SchoolDistributor(self.schools)
            school_distributor.distribute_kids_to_school(self.areas)
            school_distributor.limit_classroom_sizes()
            school_distributor.distribute_teachers_to_schools_in_super_areas(
                self.super_areas
            )

        if self.universities is not None:
            uni_distributor = UniversityDistributor(self.universities)
            uni_distributor.distribute_students_to_universities(
                areas=self.areas, people=self.people
            )
        if self.care_homes is not None:
            # this goes after unis to ensure students go to uni
            carehome_distr.distribute_workers_to_care_homes(
                super_areas=self.super_areas
            )

        if self.hospitals is not None:
            hospital_distributor = HospitalDistributor.from_file(self.hospitals)
            hospital_distributor.distribute_medics_to_super_areas(self.super_areas)
            hospital_distributor.assign_closest_hospitals_to_super_areas(
                self.super_areas
            )

        # Companies last because need hospital and school workers first
        if self.companies is not None:
            company_distributor = CompanyDistributor()
            company_distributor.distribute_adults_to_companies_in_super_areas(
                self.super_areas
            )

    def to_hdf5(self, file_path: str, chunk_size=100000):
        """
        Saves the world to an hdf5 file. All supergroups and geography
        are stored as groups. Class instances are substituted by ids of the
        instances. To load the world back, one needs to call the
        generate_world_from_hdf5 function.

        Parameters
        ----------
        file_path
            path of the hdf5 file
        chunk_size
            how many units of supergroups to process at a time.
            It is advised to keep it around 1e5
        """
        # Imported lazily to avoid a circular import at module load time.
        from june.hdf5_savers import save_world_to_hdf5

        save_world_to_hdf5(world=self, file_path=file_path, chunk_size=chunk_size)


def generate_world_from_geography(
    geography: Geography,
    demography: Optional[Demography] = None,
    include_households=True,
    ethnicity=True,
    comorbidity=True,
):
    """Build a World from ``geography``.

    When ``demography`` is omitted, the default demography for the geography
    is used. Supergroups already attached to the geography are transferred
    onto the world before people are distributed to them.
    """
    if demography is None:
        demography = Demography.for_geography(geography)
    world = World()
    world.areas = geography.areas
    world.super_areas = geography.super_areas
    world.regions = geography.regions
    world.people = _populate_areas(world.areas, demography)
    for group_name in possible_groups:
        group = getattr(geography, group_name)
        if group is not None:
            setattr(world, group_name, group)
    world.distribute_people(include_households=include_households)
    world.cemeteries = Cemeteries()
    return world


import logging.config
import os

import yaml

from june import paths
from . import demography
from . import distributors
from . import groups
from . import interaction
from . import simulator
from . import activity
from .demography import Person
from .exc import GroupException
from .time import Timer
from .world import World

default_logging_config_filename = paths.configs_path / "logging.yaml"

# Configure logging from the packaged YAML config when it exists; otherwise
# fall back to a basic DEBUG file logger in the current working directory.
if os.path.isfile(default_logging_config_filename):
    with open(default_logging_config_filename, "rt") as f:
        log_config = yaml.safe_load(f.read())
        logging.config.dictConfig(log_config)
else:
    print("The logging config file does not exist.")
    log_file = os.path.join("./", "world_creation.log")
    logging.basicConfig(filename=log_file, level=logging.DEBUG)


import logging
import yaml
from datetime import datetime
from itertools import chain
from typing import List, Optional
from time import perf_counter
from time import time as wall_clock

from june.demography import Person
from june.exc import SimulatorError
from june.groups import Subgroup
from june.groups.leisure import Leisure
from june.groups.travel import Travel
from june.mpi_setup import mpi_comm, mpi_size, mpi_rank, MovablePeople
from june.records import Record

logger = logging.getLogger("activity_manager")
mpi_logger = logging.getLogger("mpi")
rank_logger = logging.getLogger("rank")
# Let non-root MPI ranks forward their log records to the parent handlers.
if mpi_rank > 0:
    logger.propagate = True

# Priority order when a person could do several activities in one time step:
# earlier entries win (e.g. being in a medical facility beats leisure).
activity_hierarchy = [
    "medical_facility",
    "rail_travel_out",
    "rail_travel_back",
    "commute",
    "primary_activity",
    "leisure",
    "residence",
]


class ActivityManager:
    """
    Decides which activity every person performs at each time step, applies
    the active policies, moves people into the corresponding subgroups and
    exchanges people whose target group lives on another MPI domain.
    """

    def __init__(
        self,
        world,
        policies,
        timer,
        all_activities,
        activity_to_super_groups: dict,
        record: Optional[Record] = None,
        leisure: Optional[Leisure] = None,
        travel: Optional[Travel] = None,
    ):
        """
        Parameters
        ----------
        world
            world holding the people and groups to manage
        policies
            policy container; its policies are initialised here (may be None)
        timer
            simulation clock providing the date and current activities
        all_activities
            set of every activity name appearing in the time configuration
        activity_to_super_groups
            maps each activity name to the supergroup names it can use
        record
            optional event recorder passed to the policies
        leisure
            optional leisure generator used for the "leisure" activity
        travel
            optional travel handler used for the "commute" activity
        """
        self.policies = policies
        if self.policies is not None:
            self.policies.init_policies(world=world, date=timer.date, record=record)
        self.world = world
        self.timer = timer
        self.leisure = leisure
        self.travel = travel
        self.all_activities = all_activities

        # Activities missing from the config default to no supergroups
        # instead of raising a KeyError later.
        self.activity_to_super_group_dict = {
            "medical_facility": activity_to_super_groups.get("medical_facility", []),
            "primary_activity": activity_to_super_groups.get("primary_activity", []),
            "leisure": activity_to_super_groups.get("leisure", []),
            "residence": activity_to_super_groups.get("residence", []),
            "commute": activity_to_super_groups.get("commute", []),
            "rail_travel": activity_to_super_groups.get("rail_travel", []),
        }

    @classmethod
    def from_file(
        cls,
        config_filename,
        world,
        policies,
        timer,
        record: Optional[Record] = None,
        leisure: Optional[Leisure] = None,
        travel: Optional[Travel] = None,
    ):
        """
        Build an ActivityManager from a YAML config containing a "time"
        section and an "activity_to_super_groups" (or the deprecated
        "activity_to_groups") mapping.
        """
        with open(config_filename) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        try:
            activity_to_super_groups = config["activity_to_super_groups"]
        except KeyError:
            # Old configs used "activity_to_groups"; accept it but warn.
            logger.warning(
                "Activity to groups in config is deprecated"
                "please change it to activity_to_super_groups"
            )
            activity_to_super_groups = config["activity_to_groups"]
        time_config = config["time"]

        cls.check_inputs(time_config)
        # Flatten the per-step activity lists into one set of activity names.
        weekday_activities = [
            activity for activity in time_config["step_activities"]["weekday"].values()
        ]
        weekend_activities = [
            activity for activity in time_config["step_activities"]["weekend"].values()
        ]
        all_activities = set(
            chain.from_iterable(weekday_activities + weekend_activities)
        )
        return cls(
            world=world,
            policies=policies,
            timer=timer,
            all_activities=all_activities,
            activity_to_super_groups=activity_to_super_groups,
            leisure=leisure,
            travel=travel,
            record=record,
        )

    @staticmethod
    def check_inputs(time_config: dict):
        """
        Check that the input time configuration is correct, i.e., activities are among allowed activities
        and days have 24 hours.

        Parameters
        ----------
        time_config:
            dictionary with time steps configuration
        """

        try:
            assert sum(time_config["step_duration"]["weekday"].values()) == 24
            assert sum(time_config["step_duration"]["weekend"].values()) == 24
        except AssertionError:
            raise SimulatorError(
                "Daily activity durations in config do not add to 24 hours."
            )

        # Check that all groups given in time_config file are in the valid group hierarchy
        all_super_groups = activity_hierarchy
        try:
            for step, activities in time_config["step_activities"]["weekday"].items():
                assert all(group in all_super_groups for group in activities)

            for step, activities in time_config["step_activities"]["weekend"].items():
                assert all(group in all_super_groups for group in activities)
        except AssertionError:
            raise SimulatorError("Config file contains unsupported activity name.")

    @property
    def all_super_groups(self):
        """Supergroup names reachable from any configured activity."""
        return self.activities_to_super_groups(self.all_activities)

    @property
    def active_super_groups(self):
        """Supergroup names reachable from the current time step's activities."""
        return self.activities_to_super_groups(self.timer.activities)

    @staticmethod
    def apply_activity_hierarchy(activities: List[str]) -> List[str]:
        """
        Returns a list of activities with the right order, obeying the permanent activity hierarchy.

        Parameters
        ----------
        activities:
            list of activities that take place at a given time step
        Returns
        -------
        Ordered list of activities according to hierarchy
        """
        activities.sort(key=lambda x: activity_hierarchy.index(x))
        return activities

    def activities_to_super_groups(self, activities: List[str]) -> List[str]:
        """
        Converts activities into Supergroups, the interaction will run over these Groups.

        Parameters
        ---------
        activities:
            list of activities that take place at a given time step
        Returns
        -------
        List of groups that are active.
        """
        return list(
            chain.from_iterable(
                self.activity_to_super_group_dict[activity] for activity in activities
            )
        )

    def get_personal_subgroup(self, person: "Person", activity: str):
        """Return the person's own subgroup for ``activity`` (an attribute lookup)."""
        return getattr(person, activity)

    def do_timestep(self, record=None):
        """
        Run one activity time step: apply leisure policies, move everyone to
        their active subgroup, then exchange cross-domain people over MPI.

        Returns
        -------
        (people_from_abroad, n_people_from_abroad, n_people_going_abroad,
        to_send_abroad) for the interaction step that follows.
        """
        # get time data
        tick_interaction_timestep = perf_counter()
        date = self.timer.date
        day_type = self.timer.day_type
        activities = self.apply_activity_hierarchy(self.timer.activities)
        delta_time = self.timer.duration
        # apply leisure policies
        if self.leisure is not None:
            if self.policies is not None:
                self.policies.leisure_policies.apply(date=date, leisure=self.leisure)
            self.leisure.generate_leisure_probabilities_for_timestep(
                delta_time=delta_time,
                date=date,
                working_hours="primary_activity" in activities,
            )
        # move people to subgroups and get going abroad people
        to_send_abroad = self.move_people_to_active_subgroups(
            activities=activities,
            date=date,
            days_from_start=self.timer.now,
            record=record,
        )
        tock_interaction_timestep = perf_counter()
        rank_logger.info(
            f"Rank {mpi_rank} -- move_people -- {tock_interaction_timestep-tick_interaction_timestep}"
        )
        # Synchronise all ranks before the send/receive phase starts.
        tick_waiting = perf_counter()
        mpi_comm.Barrier()
        tock_waiting = perf_counter()
        rank_logger.info(
            f"Rank {mpi_rank} -- move_people_waiting -- {tock_waiting-tick_waiting}"
        )
        (
            people_from_abroad,
            n_people_from_abroad,
            n_people_going_abroad,
        ) = self.send_and_receive_people_from_abroad(to_send_abroad)
        return (
            people_from_abroad,
            n_people_from_abroad,
            n_people_going_abroad,
            to_send_abroad,
        )

    def move_people_to_active_subgroups(
        self,
        activities: List[str],
        date: datetime = datetime(2020, 2, 2),
        days_from_start=0,
        record=None,
    ):
        """
        Sends every person to one subgroup. If a person has a mild illness,
        they stay at home

        Parameters
        ----------
        activities:
            hierarchy-ordered list of activities for this time step
        date:
            current simulation date, used to select the active policies
        days_from_start:
            days elapsed since the start of the simulation
        record:
            currently unused here  # NOTE(review): kept for interface parity

        Returns
        -------
        MovablePeople collecting everyone headed to another MPI domain.
        """
        tick = perf_counter()
        active_individual_policies = self.policies.individual_policies.get_active(
            date=date
        )
        to_send_abroad = MovablePeople()
        counter = 0
        for person in self.world.people:
            counter += 1
            # Dead people and people already placed this step are skipped.
            if person.dead or person.busy:
                continue
            allowed_activities = self.policies.individual_policies.apply(
                active_policies=active_individual_policies,
                person=person,
                activities=activities,
                days_from_start=days_from_start,
            )
            external_subgroup = self.move_to_active_subgroup(
                allowed_activities, person, to_send_abroad
            )
            if external_subgroup is not None:
                to_send_abroad.add_person(person, external_subgroup)

        tock = perf_counter()
        mpi_logger.info(f"{self.timer.date},{mpi_rank},activity,{tock-tick}")
        return to_send_abroad

    def move_to_active_subgroup(
        self, activities: List[str], person: Person, to_send_abroad=None
    ) -> Optional["Subgroup"]:
        """
        Given the hierarchy of activities and a person, decide what subgroup
        should they go to

        Parameters
        ----------
        activities:
            list of activities that take place at a given time step
        person:
            person that is looking for a subgroup to go to
        Returns
        -------
        Subgroup to which person has to go, given the hierarchy of activities;
        only returned (instead of appended to) when the subgroup is external
        to this MPI domain.
        """
        for activity in activities:
            if activity == "leisure" and person.leisure is None:
                subgroup = self.leisure.get_subgroup_for_person_and_housemates(
                    person=person, to_send_abroad=to_send_abroad
                )
            elif activity == "commute":
                subgroup = self.travel.get_commute_subgroup(person=person)
            else:
                subgroup = self.get_personal_subgroup(person=person, activity=activity)
            if subgroup is not None:
                if subgroup.external:
                    person.busy = True
                    # this person goes to another MPI domain
                    return subgroup

                subgroup.append(person)
                return
        raise SimulatorError(
            "Attention! Some people do not have an activity in this timestep."
        )

    def send_and_receive_people_from_abroad(self, movable_people):
        """
        Deal with the MPI comms.

        Non-blocking sends go to every other rank first (tag 100 carries the
        keys, tag 200 the data; None is sent when there is nothing to move),
        then blocking receives collect the matching messages from each rank.
        """
        n_people_going_abroad = 0
        n_people_from_abroad = 0
        tick, tickw = perf_counter(), wall_clock()
        reqs = []

        for rank in range(mpi_size):

            if mpi_rank == rank:
                continue
            keys, data, n_this_rank = movable_people.serialise(rank)
            if n_this_rank:
                reqs.append(mpi_comm.isend(keys, dest=rank, tag=100))
                reqs.append(mpi_comm.isend(data, dest=rank, tag=200))
                n_people_going_abroad += n_this_rank
            else:
                # Always send a (possibly empty) pair so receivers never block.
                reqs.append(mpi_comm.isend(None, dest=rank, tag=100))
                reqs.append(mpi_comm.isend(None, dest=rank, tag=200))

        # now it has all been sent, we can start the receiving.

        for rank in range(mpi_size):

            if rank == mpi_rank:
                continue
            keys = mpi_comm.recv(source=rank, tag=100)
            data = mpi_comm.recv(source=rank, tag=200)

            if keys is not None:
                movable_people.update(rank, keys, data)
                n_people_from_abroad += data.shape[0]

        # Make sure every non-blocking send has completed before returning.
        for r in reqs:
            r.wait()

        tock, tockw = perf_counter(), wall_clock()
        logger.info(
            f"CMS: People COMS for rank {mpi_rank}/{mpi_size} - {tock - tick},{tockw - tickw} - {self.timer.date}"
        )
        mpi_logger.info(f"{self.timer.date},{mpi_rank},people_comms,{tock-tick}")
        return movable_people.skinny_in, n_people_from_abroad, n_people_going_abroad


from .activity_manager import ActivityManager, activity_hierarchy


# Code to reformat of school, household, and social matrices
# UK data into our code's input system

import pandas as pd
import numpy as np
import os
from shutil import copyfile


def read_df(
    DATA_DIR: str, filename: str, column_names: list, usecols: list, index: str
) -> pd.DataFrame:
    """Load a census csv file, rename its columns and set the index.

    Args:
        DATA_DIR: path to dataset folder (default should be output_area folder)
        filename: name of the csv file inside ``DATA_DIR``
        column_names: names to assign to the selected columns
        usecols: positional ids of the columns to read
        index: name (from ``column_names``) of the column used as index

    Returns:
        formatted dataframe indexed by ``index``
    """
    csv_path = os.path.join(DATA_DIR, filename)
    # header=0 discards the raw header row; column_names replace it.
    df = pd.read_csv(csv_path, names=column_names, usecols=usecols, header=0)
    return df.set_index(index)


def read_population_df(OUTPUT_AREA_DIR) -> tuple:
    """Read population dataset downloaded from https://www.nomisweb.co.uk/census/2011/ks101ew

    Args:
        OUTPUT_AREA_DIR: folder containing ``usual_resident_population.csv``

    Returns:
        tuple ``(n_residents, sex_df)``: a Series of residents per output
        area and a dataframe with the male/female counts per output area

    Raises:
        AssertionError: if males + females does not equal n_residents
    """
    # TODO: column names need to be more general for other datasets.
    population = "usual_resident_population.csv"
    population_column_names = ["output_area", "n_residents", "males", "females"]
    population_usecols = [
        "geography code",
        "Variable: All usual residents; measures: Value",
        "Variable: Males; measures: Value",
        "Variable: Females; measures: Value",
    ]
    population_df = pd.read_csv(
        os.path.join(OUTPUT_AREA_DIR, population), usecols=population_usecols
    )
    names_dict = dict(zip(population_usecols, population_column_names))
    population_df.rename(columns=names_dict, inplace=True)
    population_df.set_index("output_area", inplace=True)

    # Sanity check: the sexes must add up to the total residents.
    try:
        pd.testing.assert_series_equal(
            population_df["n_residents"],
            population_df["males"] + population_df["females"],
            check_names=False,
        )
    except AssertionError:
        print("males: ", len(population_df["males"]))
        print("females: ", len(population_df["females"]))
        # Re-raise the original error: `raise AssertionError` would discard
        # pandas' informative mismatch message.
        raise

    return population_df["n_residents"], population_df.drop(columns="n_residents")


def read_ages_df(OUTPUT_AREA_DIR: str, freq: bool = True) -> pd.DataFrame:
    """Read ages dataset downloaded from https://www.nomisweb.co.uk/census/2011/ks102ew

    Args:
        OUTPUT_AREA_DIR: folder containing ``age_structure.csv``
        freq: unused  # NOTE(review): kept for backward compatibility

    Returns:
        dataframe of age-band counts per output area
    """
    filename = "age_structure.csv"
    age_bands = [
        "0-4",
        "5-7",
        "8-9",
        "10-14",
        "15",
        "16-17",
        "18-19",
        "20-24",
        "25-29",
        "30-44",
        "45-59",
        "60-64",
        "65-74",
        "75-84",
        "85-89",
        "90-XXX",
    ]
    column_names = ["output_area"] + age_bands
    # Column 2 is the area code; columns 5-20 hold the 16 age bands.
    usecols = [2] + list(range(5, 21))
    return read_df(OUTPUT_AREA_DIR, filename, column_names, usecols, "output_area")


def read_minimal_household_composition(OUTPUT_AREA_DIR):
    # TODO: not implemented yet; placeholder kept so the module API is stable.
    pass


def read_household_composition_people(OUTPUT_AREA_DIR, ages_df):
    """
    TableID: QS112EW
    https://www.nomisweb.co.uk/census/2011/qs112ew

    Reads household-composition people counts per output area, merges the
    married/cohabiting/same-sex variants into single "Family_*" columns and
    redistributes the "Other" categories, using ``ages_df`` to decide which
    areas contain old people.
    """
    household_people = "household_composition_people.csv"
    usecols = [
        2,
        6,
        7,
        9,
        11,
        12,
        13,
        14,
        16,
        17,
        18,
        19,
        21,
        22,
        23,
        24,
        26,
        27,
        28,
        30,
        31,
        32,
        33,
        34,
    ]
    column_names = [
        "output_area",
        "Person_old",
        "Person",
        "Old_Family",
        "Family_0k",
        "Family_1k",
        "Family_2k",
        "Family_adult_children",
        "SS_Family_0k",
        "SS_Family_1k",
        "SS_Family_2k",
        "SS_Family_adult_children",
        "Couple_Family_0k",
        "Couple_Family_1k",
        "Couple_Family_2k",
        "Couple_Family_adult_children",
        "Lone_1k",
        "Lone_2k",
        "Lone_adult_children",
        "Other_1k",
        "Other_2k",
        "Students",
        "Old_Unclassified",
        "Other",
    ]
    # Index of the first "old" band in ages_df.columns ("65-74" onwards).
    OLD_THRESHOLD = 12
    comp_people_df = read_df(
        OUTPUT_AREA_DIR, household_people, column_names, usecols, "output_area"
    )

    # Combine equivalent fields
    comp_people_df["Family_0k"] += (
        comp_people_df["SS_Family_0k"] + comp_people_df["Couple_Family_0k"]
    )
    comp_people_df["Family_1k"] += (
        comp_people_df["SS_Family_1k"]
        + comp_people_df["Couple_Family_1k"]
        + comp_people_df["Other_1k"]
    )
    comp_people_df["Family_2k"] += (
        comp_people_df["SS_Family_2k"]
        + comp_people_df["Couple_Family_2k"]
        + comp_people_df["Other_2k"]
    )
    comp_people_df["Family_adult_children"] += (
        comp_people_df["SS_Family_adult_children"]
        + comp_people_df["Couple_Family_adult_children"]
    )

    # Since other contains some old, give it some probability when there are old people in the area
    areas_with_old = ages_df[ages_df.columns[OLD_THRESHOLD:]].sum(axis=1) > 0
    areas_no_house_old = (
        comp_people_df["Person_old"]
        + comp_people_df["Old_Family"]
        + comp_people_df["Old_Unclassified"]
        == 0
    )

    # NOTE(review): the chained assignments below (df[col].loc[mask] += ...)
    # rely on the intermediate Series being a view of the dataframe; newer
    # pandas versions emit SettingWithCopyWarning here — verify on upgrade.
    comp_people_df["Family_0k"].loc[
        ~((areas_no_house_old) & (areas_with_old))
    ] += comp_people_df["Other"].loc[~((areas_no_house_old) & (areas_with_old))]

    comp_people_df["Old_Family"].loc[(areas_no_house_old) & (areas_with_old)] += (
        comp_people_df["Other"].loc[(areas_no_house_old) & (areas_with_old)]
        + 0.4 * comp_people_df["Other_1k"].loc[(areas_no_house_old) & (areas_with_old)]
    )

    # The variant columns have been folded into "Family_*"; drop them.
    comp_people_df = comp_people_df.drop(
        columns=[
            c
            for c in comp_people_df.columns
            if "SS" in c or "Couple" in c or "Other" in c
        ]
    )

    return comp_people_df


def read_household_df(OUTPUT_AREA_DIR: str) -> pd.DataFrame:
    """Read household dataset downloaded from https://www.nomisweb.co.uk/census/2011/ks105ew

    Args:
        OUTPUT_AREA_DIR: folder containing ``household_composition.csv``

    Returns:
        dataframe with the number of households per output area
    """
    return read_df(
        OUTPUT_AREA_DIR,
        "household_composition.csv",
        ["output_area", "n_households"],
        [2, 4],
        "output_area",
    )


def people_compositions2households(comp_people_df):
    """Convert per-area counts of people by household type into household counts.

    Each output column label encodes a household composition; values may be
    fractional where a remainder is split between two compositions.
    """

    def _non_negative(series):
        # Clamp negative counts (the remainder can exceed the quotient) to 0.
        return series.apply(lambda value: max(value, 0))

    households_df = pd.DataFrame()

    # SINGLES
    households_df["0 0 0 1"] = comp_people_df["Person_old"]
    households_df["0 0 1 0"] = comp_people_df["Person"]

    # COUPLES NO KIDS
    households_df["0 0 0 2"] = comp_people_df["Old_Family"] // 2
    households_df["0 0 2 0"] = comp_people_df["Family_0k"] // 2

    # COUPLES 1 DEPENDENT KID
    family_1k = comp_people_df["Family_1k"]
    households_df["1 0 2 0"] = _non_negative(family_1k // 3 - family_1k % 3)
    # i) Assumption: there can be only one independent child, and is a young adult
    households_df["1 1 2 0"] = family_1k % 3

    # COUPLES >2 DEPENDENT KIDS
    family_2k = comp_people_df["Family_2k"]
    households_df["2 0 2 0"] = _non_negative(family_2k // 4 - family_2k % 4)
    # ii) Assumption: the maximum number of children is 3, it could be a young adult or a kid
    households_df["3 0 2 0"] = 0.5 * (family_2k % 4)
    households_df["2 1 2 0"] = 0.5 * (family_2k % 4)

    # COUPLES WITH ONLY INDEPENDENT CHILDREN
    # iii) Assumption: either one or two children (no more than two)
    adult_children = comp_people_df["Family_adult_children"]
    households_df["0 1 2 0"] = _non_negative(adult_children // 3 - adult_children % 3)
    households_df["0 2 2 0"] = adult_children % 3

    # LONE PARENTS 1 DEPENDENT KID
    lone_1k = comp_people_df["Lone_1k"]
    households_df["1 0 1 0"] = _non_negative(lone_1k // 2 - lone_1k % 2)
    # i) Assumption: there can be only one independent child, and is a young adult
    households_df["1 1 1 0"] = lone_1k % 2

    lone_2k = comp_people_df["Lone_2k"]
    households_df["2 0 1 0"] = _non_negative(lone_2k // 3 - lone_2k % 3)
    # ii) Assumption: the maximum number of children is 3, it could be a young adult or a kid
    households_df["3 0 1 0"] = 0.5 * (lone_2k % 3)
    households_df["2 1 1 0"] = 0.5 * (lone_2k % 3)

    # STUDENTS
    # iv) Students live in houses of 3 or 4
    students = comp_people_df["Students"]
    households_df["0 3 0 0"] = _non_negative(students // 3 - students % 3)
    households_df["0 4 0 0"] = students % 3

    # OLD OTHER
    # v) old other live in houses of 2 or 3
    old_other = comp_people_df["Old_Unclassified"]
    households_df["0 0 0 2"] += _non_negative(old_other // 2 - old_other % 2)
    households_df["0 0 0 3"] = old_other % 2

    return households_df


def read_school_census(DATA_DIR):
    """Read the UK schools census and clamp the age ranges.

    Rows with missing data are dropped; minimum ages below 4 are raised to 4
    and maximum ages in 20-49 are capped at 19.
    """
    csv_path = os.path.join(DATA_DIR, "school_data", "uk_schools_data.csv")
    schools = pd.read_csv(csv_path, index_col=0)
    schools.dropna(inplace=True)
    schools["age_min"] = schools["age_min"].replace(np.arange(0, 4), 4)
    schools["age_max"] = schools["age_max"].replace(np.arange(20, 50), 19)
    assert schools["age_min"].min() <= 4
    assert schools["age_max"].max() < 20
    return schools


def downsample_social_matrix(matrix):
    # low_res_matrix = pd.DataFrame()

    """
    print(matrix)

    low_res_matrix["0-4"] = matrix["0-4"]
    low_res_matrix["5-9"] = matrix["5-9"]
    low_res_matrix.loc["10-14"] = matrix.loc["10-12"] + matrix.loc["13-14"]
    low_res_matrix["10-14"] = matrix["10-12"] + matrix["13-14"]
    print(matrix.loc["10-12"])
    print(matrix.loc["13-14"])
    low_res_matrix["15-17"] = matrix["15-17"]
    low_res_matrix["18-19"] = matrix["18-19"]
    low_res_matrix["20-24"] = matrix["20-21"] + matrix["22-24"]
    low_res_matrix.loc["20-24"] = matrix.loc["20-21"] + matrix.loc["22-24"]
    low_res_matrix["25-29"] = matrix["25-29"]
    low_res_matrix["30-44"] = matrix["30-34"] + matrix["35-39"] + matrix["40-44"]
    low_res_matrix.loc["30-44"] = (
        matrix.loc["30-34"] + matrix.loc["35-39"] + matrix.loc["40-44"]
    )
    low_res_matrix["45-59"] = matrix["45-49"] + matrix["50-54"] + matrix["55-59"]
    low_res_matrix.loc["45-59"] = (
        matrix.loc["45-49"] + matrix.loc["50-54"] + matrix.loc["55-59"]
    )
    low_res_matrix["60-64"] = matrix["60-64"]
    low_res_matrix["65-74"] = matrix["65-69"] + matrix["70-74"]
    low_res_matrix.loc["65-74"] = matrix.loc["65-69"] + matrix.loc["70-74"]

    low_res_matrix.drop(
        [
            "10-12",
            "13-14",
            "20-21",
            "22-24",
            "30-34",
            "35-39",
            "40-44",
            "45-49",
            "50-54",
            "55-59",
            "65-69",
            "70-74",
        ], inplace=True
    )
    """

    return matrix


def reformat_social_matrices(raw_mixing_dir, processed_mixing_dir):
    """Convert the BBC mixing-matrix spreadsheet sheets into per-sheet csv files."""
    # NOTE(review): "repriprocal" must match the actual file name on disk.
    workbook = os.path.join(
        raw_mixing_dir, "BBC_repriprocal_matrices_by_type_context.xls"
    )
    for sm in ("all_school", "physical_school", "conversational_school"):
        matrix = pd.read_excel(workbook, sheet_name=sm, index_col=0)
        matrix.fillna(0.0, inplace=True)
        low_res_matrix = downsample_social_matrix(matrix)
        low_res_matrix.to_csv(os.path.join(processed_mixing_dir, f"{sm}.csv"))


if __name__ == "__main__":

    # Reformat the raw England & Wales census tables into the processed
    # layout consumed by the simulation code.
    region = "EnglandWales"
    RAW_DATA_DIR = os.path.join("..", "data", "census_data")
    RAW_OUTPUT_AREA_DIR = os.path.join(RAW_DATA_DIR, "output_area", region)

    # Read and transform the raw tables.
    residents, sex_df = read_population_df(RAW_OUTPUT_AREA_DIR)
    ages_df = read_ages_df(RAW_OUTPUT_AREA_DIR)
    comp_people_df = read_household_composition_people(RAW_OUTPUT_AREA_DIR, ages_df)
    households_df = people_compositions2households(comp_people_df)
    school_df = read_school_census(RAW_DATA_DIR)

    # Write the processed output-area tables.
    DATA_DIR = os.path.join("..", "data", "processed", "census_data")
    OUTPUT_AREA_DIR = os.path.join(DATA_DIR, "output_area", region)
    if not os.path.exists(OUTPUT_AREA_DIR):
        os.makedirs(OUTPUT_AREA_DIR)

    residents.to_csv(os.path.join(OUTPUT_AREA_DIR, "residents.csv"))
    sex_df.to_csv(os.path.join(OUTPUT_AREA_DIR, "sex.csv"))
    ages_df.to_csv(os.path.join(OUTPUT_AREA_DIR, "age_structure.csv"))
    households_df.to_csv(os.path.join(OUTPUT_AREA_DIR, "household_composition.csv"))

    SCHOOL_DIR = os.path.join(DATA_DIR, "school_data")
    if not os.path.exists(SCHOOL_DIR):
        os.makedirs(SCHOOL_DIR)

    school_df.to_csv(os.path.join(DATA_DIR, "school_data", "uk_schools_data.csv"))

    GEO_DIR = os.path.join("..", "data", "processed", "geographical_data")

    if not os.path.exists(GEO_DIR):
        os.makedirs(GEO_DIR)
    # NOTE(review): "oa_coorindates.csv" looks misspelled ("coordinates"), but
    # it must match the actual file on disk -- do not rename here.
    copyfile(
        os.path.join("..", "data", "geographical_data", "oa_coorindates.csv"),
        os.path.join(
            "..", "data", "processed", "geographical_data", "oa_coorindates.csv"
        ),
    )


import pandas as pd

from june import paths

# Paths must be path objects: the "/" joins below raise TypeError on the
# plain strings the original f-strings produced.
raw_path = paths.data_path / "census_data/output_area"
processed_path = paths.data_path / "processed/census_data/output_area"

carehome_df = pd.read_csv(raw_path / "communal_people.csv")
carehome_df.set_index(carehome_df["geography"], inplace=True)

# Keep only the care-home columns; total residents across them per area.
carehome_df = carehome_df[[col for col in carehome_df.columns if "Care home" in col]]
all_care_homes = carehome_df.sum(axis=1)
print(all_care_homes)
# Sanity check: one row per England & Wales output area.
assert len(all_care_homes) == 181408
all_care_homes.to_csv(processed_path / "carehomes.csv")


import os

import numpy as np
import pandas as pd

from june import paths

# NOTE(review): these must be path objects, not f-strings -- the code below
# joins with the "/" operator (e.g. `raw_path / "household_houses.csv"`),
# which raises TypeError on a plain str.
raw_path = paths.data_path / "census_data/output_area"
processed_path = paths.data_path / "processed/census_data/output_area"


def filter_and_sum(df, in_column):
    """Row-wise sum of the columns whose names match any pattern in `in_column`."""
    pattern = in_column[0] if len(in_column) == 1 else "|".join(in_column)
    return df.filter(regex=pattern).sum(axis=1)


# Build the minimum household-composition table per output area.
# Column keys appear to encode "<kids students dependent adults old>" counts,
# with ">=" meaning "at least" -- TODO confirm against the consumer of this file.
# os.path.join (not "/") because raw_path is a plain string here.
df = pd.read_csv(os.path.join(raw_path, "household_houses.csv"), index_col=0)

# Sanity check: one row per England and Wales output area.
assert len(df) == 181408

df.set_index("geography", inplace=True)
df.drop(columns=["date", "geography code"], inplace=True)

# Drop aggregate columns; keep only the individual composition categories.
df = df[
    [col for col in df.columns if "Total" not in col and "All categories" not in col]
]

encoding_households = pd.DataFrame()

encoding_households["0 0 0 0 1"] = filter_and_sum(df, ["One person household: Aged 65"])
encoding_households["0 0 0 1 0"] = filter_and_sum(df, ["One person household: Other"])
encoding_households["0 0 0 0 2"] = filter_and_sum(
    df, ["One family only: All aged 65 and over"]
)
encoding_households["0 0 0 2 0"] = filter_and_sum(df, ["No children"])

encoding_households["1 0 >=0 2 0"] = filter_and_sum(
    df,
    ["Married couple: One dependent child", "Cohabiting couple: One dependent child"],
)

encoding_households[">=2 0 >=0 2 0"] = filter_and_sum(
    df,
    [
        "Married couple: Two or more dependent",
        "Cohabiting couple: Two or more dependent",
    ],
)
encoding_households["0 0 >=1 2 0"] = filter_and_sum(
    df,
    [
        "Married couple: All children non-dependent",
        "Cohabiting couple: All children non-dependent",
    ],
)
encoding_households["1 0 >=0 1 0"] = filter_and_sum(
    df, ["Lone parent: One dependent child"]
)
encoding_households[">=2 0 >=0 1 0"] = filter_and_sum(
    df, ["Lone parent: Two or more dependent children"]
)
encoding_households["0 0 >=1 1 0"] = filter_and_sum(
    df, ["Lone parent: All children non-dependent"]
)

encoding_households["1 0 >=0 >=1 >=0"] = filter_and_sum(
    df, ["Other household types: With one dependent child"]
)
encoding_households[">=2 0 >=0 >=1 >=0"] = filter_and_sum(
    df, ["Other household types: With two"]
)
encoding_households["0 >=1 0 0 0"] = filter_and_sum(df, ["All full-time students"])
encoding_households["0 0 0 0 >=2"] = filter_and_sum(
    df, ["Other household types: All aged 65 and over"]
)
encoding_households["0 0 >=0 >=0 >=0"] = filter_and_sum(
    df, ["Other household types: Other"]
)

encoding_households.index.name = "output_area"

# Every census household must land in exactly one encoded category.
np.testing.assert_array_equal(
    encoding_households.sum(axis=1).values, df.sum(axis=1).values
)
# Communal establishments: non-care-home residents get a catch-all category.
comunal_df = pd.read_csv(os.path.join(raw_path, "communal_houses.csv"))
comunal_df.set_index(comunal_df["geography"], inplace=True)

all_comunal_df = comunal_df[
    [col for col in comunal_df.columns if "All categories" in col]
]
carehome_df = comunal_df[[col for col in comunal_df.columns if "Care home" in col]]
carehome_df = carehome_df.sum(axis=1)
comunal_not_carehome_df = all_comunal_df[all_comunal_df.columns[0]] - carehome_df

# Care homes plus other communal establishments must account for all communal residents.
assert (
    comunal_not_carehome_df.sum() + carehome_df.sum()
    == all_comunal_df[all_comunal_df.columns[0]].sum()
)

encoding_households[">=0 >=0 >=0 >=0 >=0"] = comunal_not_carehome_df
encoding_households.to_csv(
    os.path.join(processed_path, "minimum_household_composition.csv")
)


import pandas as pd

from june import paths

# Paths as path objects so the "/" joins below work (an f-string would be a
# plain str, and `str / str` raises TypeError).
raw_path = paths.data_path / "census_data/output_area"
processed_path = paths.data_path / "processed/census_data/output_area"

# Number of people living in communal establishments, excluding care homes,
# per output area.
comunal = pd.read_csv(raw_path / "communal_people.csv")

comunal.set_index("geography", inplace=True)
all_comunal_df = comunal[[col for col in comunal.columns if "All categories" in col]]
carehome_df = comunal[[col for col in comunal.columns if "Care home" in col]]
carehome_df = carehome_df.sum(axis=1)
comunal = all_comunal_df[all_comunal_df.columns[0]] - carehome_df

# Sanity check: communal = all communal residents minus care-home residents.
assert (
    comunal.sum() + carehome_df.sum() == all_comunal_df[all_comunal_df.columns[0]].sum()
)

comunal.index.name = "output_area"

# One row per England & Wales output area.
assert len(comunal) == 181408
comunal.to_csv(processed_path / "n_people_in_communal.csv")


import pandas as pd

from june import paths

# Paths as path objects so the "/" joins below work (f-strings give plain str,
# and `str / str` raises TypeError).
raw_path = paths.data_path / "census_data/output_area"
processed_path = paths.data_path / "processed/census_data/output_area"

# Number of full-time students per output area.
household_composition_people = pd.read_csv(
    raw_path / "household_composition_people.csv"
)

household_composition_people.set_index("geography", inplace=True)
household_composition_people = household_composition_people.filter(
    regex="All full-time students"
)

# The single matching column becomes "n_students".
household_composition_people = household_composition_people.rename(
    {household_composition_people.columns[0]: "n_students"}, axis=1
)

household_composition_people.index.name = "output_area"

print(household_composition_people)
household_composition_people.to_csv(processed_path / "n_students.csv")


import googlemaps
import time

import requests


class APICall:
    """
    Handling API calls to the Google Maps API
    Interacts through url API calls directly, as well as using the Python client

    Note: This requires the Google Maps Place API for running and making calls
    """

    def __init__(self, key):
        """
        :param key: (string) Google Maps API key
        """
        self.key = key
        self.client = googlemaps.Client(self.key)
        self.raise_warning()

    def raise_warning(self):
        # Billing reminder: every method on this class may consume paid quota.
        print(
            "WARNING: By running this class you will be making Google Maps API calls \n This will use API credits and may charge you money - please proceed with caution"
        )

    def get_request(self, url):
        """GET `url` and return the raw response; re-raise with the cause chained."""
        try:
            response = requests.get(url)
            return response
        except Exception as exc:
            # Chain the original error so the real cause is not lost.
            raise Exception("Error: GET request failed") from exc

    def process_results(self, results):
        """
        Extract coordinates and names from a list of Places API result dicts.

        Returns (locations, names, reviews, ratings); reviews and ratings are
        currently returned empty -- their collection is deliberately disabled.
        """
        locations = []
        names = []
        reviews = []
        ratings = []
        for i in results:
            locations.append(i["geometry"]["location"])
            names.append(i["name"])
            # ratings.append(i['rating'])
            # reviews.append(i['user_ratings_total'])

        return locations, names, reviews, ratings

    def process_pagetoken(self, resp_json_payload, out):
        """Append the next-page token to `out` when the payload carries one."""
        locations, names, reviews, ratings = out
        try:
            next_page_token = resp_json_payload["next_page_token"]
            return [locations, names, reviews, ratings, next_page_token]
        except KeyError:
            # Last page of results: the API omits "next_page_token".
            print("No more next page tokens")
            return [locations, names, reviews, ratings]

    def nearby_search(self, location, radius, location_type, return_pagetoken=False):
        """
        Searches nearby locations given a location and a radius for particular location types

        :param location: (tuple of ints) location is a tuple of (latitude, longitude)
        :param radius: (int) meter radius search area around location coordinate
        :param location_type: (string) type of location being searched for
        :param return_pagetoken: (bool) if True and there is another page to be generated, then returns token for next page

        Note: location types can be found here: https://developers.google.com/places/supported_types#table1
        """

        lat, lon = location
        # NOTE: "&&key=" (double ampersand) kept as-is; the API tolerates it.
        url = "https://maps.googleapis.com/maps/api/place/nearbysearch/json?location={},{}&radius={}&type={}&&key={}".format(
            lat, lon, radius, location_type, self.key
        )

        response = self.get_request(url)
        resp_json_payload = response.json()

        out = self.process_results(resp_json_payload["results"])

        if return_pagetoken:
            out = self.process_pagetoken(resp_json_payload, out)
        return out

    def nearby_search_next_page(self, next_page_token, return_pagetoken=False):
        """
        After running nearby search with next_page token, call next page
        :param next_page_token: (string) output from self.nearby_search([...], return_pagetoken = True)
        :param return_pagetoken: (bool) if True and there is another page to be generated, then returns token for next page
        """

        url = "https://maps.googleapis.com/maps/api/place/nearbysearch/json?pagetoken={}&key={}".format(
            next_page_token, self.key
        )

        response = self.get_request(url)
        resp_json_payload = response.json()

        out = self.process_results(resp_json_payload["results"])

        if return_pagetoken:
            out = self.process_pagetoken(resp_json_payload, out)
        return out

    def out_len_check(self, out):
        """Return the trailing page token if `out` carries one, else None."""
        if len(out) == 5:
            return out[4]
        return None

    def nearby_search_loop(self, location, radius, location_type):
        """
        In cases where there may be multiple next pages (up to Google's max 3), run loop over all pages
        :param location: (tuple of ints) location is a tuple of (latitude, longitude)
        :param radius: (int) meter radius search area around location coordinate
        :param location_type: (string) type of location being searched for
        """
        print("Calling API")
        out = self.nearby_search(location, radius, location_type, return_pagetoken=True)
        token = self.out_len_check(out)
        outs = [out]
        while token is not None:
            print("Calling API")
            # Page tokens take a moment to become valid server-side.
            time.sleep(2)
            out_token = self.nearby_search_next_page(token, return_pagetoken=True)
            outs.append(out_token)
            token = self.out_len_check(out_token)

        return outs

    def places(self, query, location, radius, location_type=None):
        """
        Search places according to a certain query

        :param query: (string) e.g. 'restaurant'
        :param location: (tuple of ints) location is a tuple of (latitude, longitude)
        :param radius: (int) meter radius search area around location coordinate
        :param location_type: (string, optional) type of location being searched for
        """

        try:
            if location_type is not None:
                call = self.client.places(
                    query,
                    location=location,
                    radius=radius,
                    region="UK",
                    type=location_type,
                )
            else:
                call = self.client.places(
                    query,
                    location=location,
                    radius=radius,
                    region="UK",
                )
        except Exception as exc:
            # Chain the underlying client error so the real cause is visible.
            raise Exception("Error: GET request failed") from exc

        results = call["results"]

        locations, names, reviews, ratings = self.process_results(results)

        return [locations, names, reviews, ratings]

    def distance(self, origin_location, destination_location, mode):
        """
        Determine distance between two locations according to the mode of transport

        :param origin_location: (tuple of ints) origin location is a tuple of (latitude, longitude)
        :param destination_location: (tuple of ints) destination location is a tuple of (latitude, longitude)
        :param mode: (string) mode of transport valid values are “driving”, “walking”, “transit” or “bicycling”
        """

        dist = self.client.distance_matrix(origin_location, destination_location, mode)

        # TODO finish this if needed

        return dist


import numpy as np
import pandas as pd
import argparse

from gmapi import APICall


class MSOASearch:
    """
    Functions for running Google Maps API by region at the MSOA level for a given type

    More information on Google types can be found here: https://developers.google.com/places/supported_types
    """

    def __init__(self):
        self.args = self.parse()

    def parse(self):
        """
        Parse input arguments
        """
        parser = argparse.ArgumentParser(
            description="Run the Google Maps API by region at the MSOA level for a given type"
        )
        for flag, dest, help_text in (
            (
                "--apikey_file",
                "apikey_file",
                "location of txt file containing apikey",
            ),
            (
                "--type",
                "location_type",
                "Google maps type being selected (found on Google Cloud documentation)",
            ),
            (
                "--msoa_coord_dir",
                "msoa_coord",
                "directory containing MSOA centroids - assume also where file will be saved to",
            ),
        ):
            parser.add_argument(flag, dest=dest, help=help_text, type=str)
        return parser.parse_args()

    def get_msoas_type(self, apikey, msoas):
        """
        For a given type, call Google Maps API to search for type

        Note: Currently the radius is fixed at the average required to cover the whole of England and Wales
        """
        self.apikey = apikey
        apicall = APICall(self.apikey)

        # Centroids are stored as (Y, X) = (latitude, longitude).
        centroids = [(msoas["Y"][row], msoas["X"][row]) for row in range(len(msoas))]
        return [
            apicall.nearby_search_loop(
                location=centroid,
                radius=4600,
                location_type=self.args.location_type,
            )
            for centroid in centroids
        ]


if __name__ == "__main__":

    msoasearch = MSOASearch()

    # The API key is the first line of the supplied key file.
    with open(msoasearch.args.apikey_file, "r") as f:
        contents = f.read()
    apikey = contents.split("\n")[0]

    # Remaining regions: 'SouthEast', 'SouthWest', 'NorthEast', 'NorthWest',
    # 'Yorkshire', 'London', 'Wales', 'EastMidlands', 'WestMidlands'
    regions = ["East"]

    for region in regions:
        print(f"Working on region: {region}")
        msoas = pd.read_csv(
            f"{msoasearch.args.msoa_coord}/msoa_coordinates_{region}.csv"
        )
        outs = msoasearch.get_msoas_type(apikey, msoas)
        np.save(
            f"{msoasearch.args.msoa_coord}/outs_{msoasearch.args.location_type}_{region}.npy",
            outs,
            allow_pickle=True,
        )


import numpy as np
import pandas as pd
import argparse


def parse():
    """
    Parse input arguments
    """
    parser = argparse.ArgumentParser(
        description="Clean Google Maps API data pulled using msoa_search.py"
    )
    for flag, dest, help_text in (
        (
            "--type",
            "location_type",
            "Google maps type being selected (found on Google Cloud documentation)",
        ),
        (
            "--msoa_coord_dir",
            "msoa_coord",
            "directory containing MSOA centroids - assume also where file will be saved to",
        ),
    ):
        parser.add_argument(flag, dest=dest, help=help_text, type=str)
    return parser.parse_args()


def clean(region_file, msoa_file):
    """
    Flatten raw Google Maps API results into a DataFrame of venues.

    Parameters
    ----------
    region_file
        Nested results: one entry per MSOA, each a list of result pages,
        where page[0] is a list of {"lat", "lng"} dicts and page[1] the
        matching list of venue names.
    msoa_file
        DataFrame with an "MSOA11CD" column aligned with region_file's order.

    Returns
    -------
    DataFrame with lat/lon/name/msoa columns, de-duplicated on (lat, lon).
    """
    # Hoisted out of the loop: the original rebuilt this list once per venue,
    # making the flattening accidentally O(n^2).
    msoa_codes = list(msoa_file["MSOA11CD"])

    msoa = []
    latitude = []
    longitude = []
    name = []
    for idx, msoa_results in enumerate(region_file):
        for page in msoa_results:
            for location in page[0]:
                latitude.append(location["lat"])
                longitude.append(location["lng"])
            for venue_name in page[1]:
                name.append(venue_name)
                msoa.append(msoa_codes[idx])

    df = pd.DataFrame({"lat": latitude, "lon": longitude, "name": name, "msoa": msoa})

    # The same venue can be returned by several MSOA searches; keep the first.
    return df.drop_duplicates(subset=["lat", "lon"], keep="first")


if __name__ == "__main__":

    args = parse()

    regions = [
        "Yorkshire",
        "London",
        "Wales",
        "EastMidlands",
        "WestMidlands",
        "SouthEast",
        "SouthWest",
        "NorthEast",
        "NorthWest",
        "East",
    ]
    for region in regions:
        print(f"Working on region: {region}")
        # Raw API dump produced by msoa_search.py for this region.
        region_file = np.load(
            f"{args.msoa_coord}/outs_{args.location_type}_{region}.npy",
            allow_pickle=True,
        )
        msoa_file = pd.read_csv(f"{args.msoa_coord}/msoa_coordinates_{region}.csv")
        df_clean = clean(region_file, msoa_file)
        df_clean.to_csv(
            f"{args.msoa_coord}/outs_{args.location_type}_{region}_clean.csv"
        )


import pandas as pd
from june import paths

raw_path = paths.data_path
processed_path = paths.data_path / "input/hospitals/"

# Trust-level hospital data: bed counts keyed by postcode.
hospitals_df = pd.read_csv(raw_path / "hospital_data/option1_trusts.csv")
hospitals_df = hospitals_df[
    ["Code", "Regular beds", "Intensive care beds (MV+ITU+IDU)", "Postcode"]
].set_index("Postcode")
hospitals_df.columns = ["code", "beds", "icu_beds"]

# Postcode -> (output area, super area) translation table.
area_translation_df = pd.read_csv(
    raw_path / "census_data/area_code_translations/areas_mapping.csv"
)
area_translation_df = area_translation_df[["postcode", "oa", "msoa"]].set_index(
    "postcode"
)

# Postcode centroid coordinates.
postcodes_coord = pd.read_csv(
    raw_path / "geographical_data/ukpostcodes_coordinates.csv"
).set_index("postcode")

# Attach area codes, then coordinates, and key the result by super area.
hospitals_df = hospitals_df.join(area_translation_df)
hospitals_df.columns = ["code", "beds", "icu_beds", "area", "super_area"]
hospitals_df = hospitals_df.join(postcodes_coord).set_index("super_area")
hospitals_df.to_csv(processed_path / "trusts.csv")


import pandas as pd
from june import paths

raw_path = paths.data_path / "seed/"
processed_path = paths.data_path / "processed/seed/"

# Seeding cases per NHS region over the first 10 days of March.
seed_df = pd.read_csv(raw_path / "Seeding_March_10days.csv", index_col=0)
seed_df = seed_df.drop(columns=["Trust", "Code"]).rename(
    columns={"NHS England Region": "region"}
)
n_cases_region = seed_df.groupby("region").sum()
# The raw file contains both "London" and "London " (trailing space);
# merge the two rows before saving.
n_cases_region.loc["London"] += n_cases_region.loc["London "]
n_cases_region = n_cases_region.drop("London ")

n_cases_region.to_csv(processed_path / "n_cases_region.csv")


import pandas as pd
from june import paths

raw_path = paths.data_path / "time_series/"
processed_path = paths.data_path / "processed/time_series/"

# Daily confirmed cases, region-level rows only.
confirmed_cases_df = pd.read_csv(raw_path / "coronavirus-cases_latest.csv", index_col=0)
confirmed_cases_df = confirmed_cases_df.loc[
    confirmed_cases_df["Area type"] == "Region",
    ["Specimen date", "Daily lab-confirmed cases"],
]
# Bring the area name (the index) back as a column, then pivot to one
# column per region indexed by specimen date.
confirmed_cases_df = (
    confirmed_cases_df.reset_index()
    .set_index("Specimen date")
    .pivot(columns="Area name", values="Daily lab-confirmed cases")
)
confirmed_cases_df.to_csv(processed_path / "n_confirmed_cases.csv")


import pandas as pd
from june import paths

# Input/output locations for the hospital-admissions time series.
raw_path = paths.data_path / "time_series/"
processed_path = paths.data_path / "processed/time_series/"

# skiprows=1: skip the first line of the raw export before the header.
# NOTE(review): assumes the real column names are on the second line -- confirm.
hosp_df = pd.read_csv(raw_path / "COVID_output_pivot_v3.csv", sep=",", skiprows=1)

hosp_df.set_index("ReportingPeriod", inplace=True)
hosp_df.index = pd.to_datetime(hosp_df.index)
# Total COVID admissions = sum of the three SIT reporting columns.
hosp_df["covid_admissions"] = hosp_df[
    ["SIT008_Total", "SIT009_Total", "SIT009_Suspected"]
].sum(axis=1)
# Aggregate per (date, region), then pivot to one column per region.
hosp_df = hosp_df.groupby([hosp_df.index, "Region_Name"]).sum()
hosp_df = hosp_df[["covid_admissions"]].reset_index()
hosp_df = hosp_df.pivot(
    index="ReportingPeriod", columns="Region_Name", values="covid_admissions"
)
hosp_df.to_csv(processed_path / "hospital_admissions_region.csv")


from typing import List, Dict, Optional

import numpy as np
import pandas as pd

from june import paths
from june.demography import Person
from june.geography import Geography
from june.utils import random_choice_numba

# Default location of the demographic input tables within the JUNE data tree.
default_data_path = paths.data_path / "input/demography"

# Maps each output area to its super area and region.
default_areas_map_path = paths.data_path / "input/geography/area_super_area_region.csv"

# Default configuration directory (passed through to for_areas's config_path).
default_config_path = paths.configs_path


def parse_age_bin(age_bin: str):
    """Split an age-bin label like "18-65" into its integer bounds, e.g. [18, 65]."""
    bounds = [int(edge) for edge in age_bin.split("-")]
    return bounds


class DemographyError(Exception):
    """
    Raised when demographic data cannot be generated, e.g. an area's
    generator runs out of people or a geography is empty.

    Derives from Exception (not BaseException) so generic
    ``except Exception`` handlers can catch it.
    """


class AgeSexGenerator:
    def __init__(
        self,
        age_counts: list,
        sex_bins: list,
        female_fractions: list,
        ethnicity_age_bins: list = None,
        ethnicity_groups: list = None,
        ethnicity_structure: list = None,
        max_age=99,
    ):
        """
        age_counts is an array where the index in the array indicates the age,
        and the value indicates the number of counts in that age.
        sex_bins are the lower edges of each sex bin where we have a fraction of females from
        census data, and female_fractions are those fractions.
        ethnicity_age_bins are the lower edges of the age bins that ethnicity data is in
        ethnicity_groups are the labels of the ethnicities which we have data for.
        ethnicity_structure are (integer) ratios of the ethnicities, for each age bin. the sum
        of this structure need NOT be the total number of people returned by the generator.
        Example:
            age_counts = [1, 2, 3] means 1 person of age 0, 2 people of age 1 and 3 people of age 2.
            sex_bins = [1, 3] defines two bins: (0,1) and (3, infinity)
            female_fractions = [0.3, 0.5] means that between the ages 0 and 1 there are 30% females,
                                          and there are 50% females in the bin 3+ years
            ethnicity_age_bins - see sex_bins
            ethnicity_groups = ['A','B','C'] - there are three types of ethnicities that we are
                                          assigning here.
            ethnicity_structure = [[0,5,3],[2,3,0],...] in the first age bin, we assign people
                                          ethnicities A:B:C with probability 0:5:3, and so on.
        Given this information we initialize two generators for age and sex, that can be accessed
        through gen = AgeSexGenerator().age() and AgeSexGenerator().sex().

        Parameters
        ----------
        age_counts
            A list or array with the counts for each age.
        sex_bins
            Lower edges of the age bins for which female fractions are known.
        female_fractions
            The fraction of females inside each sex age bin.
        ethnicity_age_bins, ethnicity_groups, ethnicity_structure
            Optional ethnicity data (see above); when omitted no ethnicity
            iterator is created.
        max_age
            Ages returned by ``age()`` are capped at this value.
        """
        self.n_residents = np.sum(age_counts)
        # One entry per resident: age i repeated age_counts[i] times.
        ages = np.repeat(np.arange(0, len(age_counts)), age_counts)
        # Draw each person's sex using the female fraction of their age bin.
        female_fraction_bins = np.digitize(ages, bins=list(map(int, sex_bins))) - 1
        sexes = (
            np.random.uniform(0, 1, size=self.n_residents)
            < np.array(female_fractions)[female_fraction_bins]
        ).astype(int)
        sexes = map(lambda x: ["m", "f"][x], sexes)
        self.age_iterator = iter(ages)
        self.sex_iterator = iter(sexes)
        self.max_age = max_age

        if ethnicity_age_bins is not None:
            # Count residents per ethnicity age bin, then sample ethnicities
            # within each bin with probability proportional to the given ratios.
            ethnicity_age_counts, _ = np.histogram(
                ages, bins=list(map(int, ethnicity_age_bins)) + [100]
            )
            ethnicities = []
            for age_ind, age_count in enumerate(ethnicity_age_counts):
                ethnicities.extend(
                    np.random.choice(
                        np.repeat(ethnicity_groups, ethnicity_structure[age_ind]),
                        age_count,
                    )
                )
            self.ethnicity_iterator = iter(ethnicities)

    @classmethod
    def from_age_sex_bins(
        cls, men_age_dict: dict, women_age_dict: dict, exponential_decay: int = 2
    ):
        """
        Initializes age and sex generator (no ethnicity and socioecon_index for now) from
        a dictionary containing age bins and counts for man and woman. An example of the input is
        men_age_dict = {"0-2" : 10, "2-99": 50}. If the bin contains the 99 value at the end,
        the age will be sampled with an exponential decay of the form e^(-x/exponential_decay).
        """
        # 100 slots so ages 0..99 inclusive are countable: the sampling below
        # draws from np.arange(age1, age2 + 1), which includes 99, and with
        # only 99 slots that raised IndexError whenever age 99 was drawn.
        age_counts = np.zeros(100, dtype=np.int64)
        sex_bins = []
        female_fractions = []
        for (key_man, value_man), (_, value_woman) in zip(
            men_age_dict.items(), women_age_dict.items()
        ):
            age1, age2 = parse_age_bin(key_man)
            total_people = value_man + value_woman
            sex_bins.append(age1)
            # Guard against empty bins (avoid division by zero).
            female_fractions.append(value_woman / total_people if total_people else 0)
            if age2 == 99:
                # Open-ended top bin: sample ages with exponential decay.
                exp_values = np.exp(-np.arange(0, age2 - age1 + 1) / exponential_decay)
                p = exp_values / exp_values.sum()
            else:
                p = None  # uniform within the bin
            age_dist = np.random.choice(
                np.arange(age1, age2 + 1), size=total_people, p=p
            )
            ages, counts = np.unique(age_dist, return_counts=True)
            age_counts[ages] += counts
        return cls(age_counts, sex_bins, female_fractions)

    def age(self) -> int:
        """Next person's age, capped at ``max_age``."""
        try:
            return min(next(self.age_iterator), self.max_age)
        except StopIteration:
            raise DemographyError("No more people living here!")

    def sex(self) -> str:
        """Next person's sex, "m" or "f"."""
        try:
            return next(self.sex_iterator)
        except StopIteration:
            raise DemographyError("No more people living here!")

    def ethnicity(self) -> str:
        """Next person's ethnicity label (only if ethnicity data was given)."""
        try:
            return next(self.ethnicity_iterator)
        except StopIteration:
            raise DemographyError("No more people living here!")


class Population:
    def __init__(self, people: Optional[List[Person]] = None):
        """
        A population of people.

        Behaves mostly like a list, but additionally keeps an id -> person
        mapping and the set of ids for fast lookups and membership tests.

        Parameters
        ----------
        people
            A list of people generated to match census data for that area
        """
        self.people = [] if people is None else people
        self.people_dict = {person.id: person for person in self.people}
        self.people_ids = set(self.people_dict.keys())

    def __len__(self):
        return len(self.people)

    def __iter__(self):
        return iter(self.people)

    def __getitem__(self, index):
        return self.people[index]

    def __add__(self, population: "Population"):
        # NOTE: in-place merge -- extends this population and returns it.
        self.people.extend(population.people)
        self.people_dict = {**self.people_dict, **population.people_dict}
        self.people_ids = set(self.people_dict.keys())
        return self

    def add(self, person):
        """Register one person in all three internal containers."""
        self.people_dict[person.id] = person
        self.people.append(person)
        self.people_ids.add(person.id)

    def remove(self, person):
        """Drop one person from all three internal containers."""
        del self.people_dict[person.id]
        self.people.remove(person)
        self.people_ids.remove(person.id)

    def extend(self, people):
        """Add several people at once."""
        for new_person in people:
            self.add(new_person)

    def get_from_id(self, id):
        """Look a person up by their unique id."""
        return self.people_dict[id]

    @property
    def members(self):
        # Alias for the underlying list of people.
        return self.people

    @property
    def total_people(self):
        return len(self.members)

    @property
    def infected(self):
        """People currently flagged as infected."""
        return [member for member in self.people if member.infected]

    @property
    def dead(self):
        """People currently flagged as dead."""
        return [member for member in self.people if member.dead]

    @property
    def vaccinated(self):
        """People currently flagged as vaccinated."""
        return [member for member in self.people if member.vaccinated]


class Demography:
    def __init__(
        self,
        area_names,
        age_sex_generators: Dict[str, AgeSexGenerator],
        comorbidity_data=None,
    ):
        """
        Tool to generate population for a certain geographical region.

        Parameters
        ----------
        area_names
            The area identifiers this demography covers.
        age_sex_generators
            A dictionary mapping area identifiers to functions that generate
            age and sex for individuals.
        comorbidity_data
            Optional comorbidity prevalence data consumed by
            ComorbidityGenerator when populating areas.
        """
        self.area_names = area_names
        self.age_sex_generators = age_sex_generators
        self.comorbidity_data = comorbidity_data

    def populate(self, area_name: str, ethnicity=True, comorbidity=True) -> Population:
        """
        Generate a population for a given area. Age, sex and number of residents
        are all based on census data for that area.

        Parameters
        ----------
        area_name
            The name of an area a population should be generated for
        ethnicity
            Whether to also draw an ethnicity for each person.
        comorbidity
            Whether to also assign a comorbidity to each person.

        Returns
        -------
        A population of people
        """
        people = []
        age_and_sex_generator = self.age_sex_generators[area_name]
        if comorbidity:
            comorbidity_generator = ComorbidityGenerator(self.comorbidity_data)
        for _ in range(age_and_sex_generator.n_residents):
            # Draw ethnicity first to keep the generator streams aligned.
            ethnicity_value = age_and_sex_generator.ethnicity() if ethnicity else None
            person = Person.from_attributes(
                age=age_and_sex_generator.age(),
                sex=age_and_sex_generator.sex(),
                ethnicity=ethnicity_value,
            )
            if comorbidity:
                person.comorbidity = comorbidity_generator.get_comorbidity(person)
            people.append(person)  # add person to population
        return Population(people=people)

    @classmethod
    def for_geography(
        cls,
        geography: Geography,
        data_path: str = default_data_path,
        config: Optional[dict] = None,
    ) -> "Demography":
        """
        Initializes demography from an existing geography.

        Parameters
        ----------
        geography
            an instance of the geography class
        """
        if not geography.areas:
            raise DemographyError("Empty geography!")
        area_names = [area.name for area in geography.areas]
        return cls.for_areas(area_names, data_path, config)

    @classmethod
    def for_zone(
        cls,
        filter_key: Dict[str, list],
        data_path: str = default_data_path,
        areas_maps_path: str = default_areas_map_path,
        config: Optional[dict] = None,
    ) -> "Demography":
        """
        Initializes a geography for a specific list of zones. The zones are
        specified by the filter_dict dictionary where the key denotes the
        kind of zone, and the value is a list with the different zone names.

        Example
        -------
            filter_key = {"region" : "North East"}
            filter_key = {"super_area" : ["EXXXX", "EYYYY"]}
        """
        if len(filter_key) > 1:
            raise NotImplementedError("Only one type of area filtering is supported.")
        geo_hierarchy = pd.read_csv(areas_maps_path)
        # Read the single (zone_type, zone_list) pair without mutating the
        # caller's dictionary (popitem would empty it as a side effect).
        zone_type, zone_list = next(iter(filter_key.items()))
        area_names = geo_hierarchy[geo_hierarchy[zone_type].isin(zone_list)]["area"]
        if not area_names.size:
            raise DemographyError("Region returned empty area list.")
        return cls.for_areas(area_names, data_path, config)

    @classmethod
    def for_areas(
        cls,
        area_names: List[str],
        data_path: str = default_data_path,
        config: Optional[dict] = None,
        config_path: str = default_config_path,
    ) -> "Demography":
        """
        Load data from files and construct classes capable of generating demographic
        data for individuals in the population.

        Parameters
        ----------
        area_names
            List of areas for which to create a demographic generator.
        data_path
            The path to the data directory
        config
            Optional configuration. At the moment this just gives an asymptomatic
            ratio.
        config_path
            Path to the configuration directory (currently unused here).

        Returns
        -------
            A demography representing the super area
        """
        age_structure_path = data_path / "age_structure_single_year.csv"
        female_fraction_path = data_path / "female_ratios_per_age_bin.csv"
        ethnicity_structure_path = data_path / "ethnicity_structure.csv"
        m_comorbidity_path = data_path / "uk_male_comorbidities.csv"
        f_comorbidity_path = data_path / "uk_female_comorbidities.csv"
        age_sex_generators = _load_age_and_sex_generators(
            age_structure_path,
            female_fraction_path,
            ethnicity_structure_path,
            area_names,
        )
        comorbidity_data = load_comorbidity_data(m_comorbidity_path, f_comorbidity_path)
        # Use cls so subclasses get instances of themselves.
        return cls(
            age_sex_generators=age_sex_generators,
            area_names=area_names,
            comorbidity_data=comorbidity_data,
        )


def _load_age_and_sex_generators(
    age_structure_path: str,
    female_ratios_path: str,
    ethnicity_structure_path: str,
    area_names: List[str],
) -> Dict[str, AgeSexGenerator]:
    """
    Build a dictionary mapping area identifiers to a generator of age, sex
    and ethnicity for that area.

    Parameters
    ----------
    age_structure_path
        File with the number of people per single-year age bin, per area.
    female_ratios_path
        File with the fraction of females per age bin, per area.
    ethnicity_structure_path
        File containing ethnicity nr. per Area.
        This approach chosen based on:
        Davis, J. A., & Smith, T. W. (1999); Chicago: National Opinion Research Center
    area_names
        Areas to load. Each of the three tables is restricted to these areas
        and then sorted by area name, so that the rows line up when zipped
        together below.

    Returns
    -------
    A dict mapping area name -> AgeSexGenerator for that area.
    """
    age_structure_df = pd.read_csv(age_structure_path, index_col=0)
    age_structure_df = age_structure_df.loc[area_names]
    age_structure_df.sort_index(inplace=True)

    female_ratios_df = pd.read_csv(female_ratios_path, index_col=0)
    female_ratios_df = female_ratios_df.loc[area_names]
    female_ratios_df.sort_index(inplace=True)

    # Two-level index: (area, ethnicity-related key) — see the groupby below.
    ethnicity_structure_df = pd.read_csv(
        ethnicity_structure_path, index_col=[0, 1]
    )  # pd MultiIndex!!!
    ethnicity_structure_df = ethnicity_structure_df.loc[pd.IndexSlice[area_names]]
    ethnicity_structure_df.sort_index(level=0, inplace=True)
    # "sort" is required as .loc slicing a multi_index df doesn't work as expected --
    # it preserves original order, and ignoring "repeat slices".
    # TODO fix this to use proper complete indexing.

    ret = {}
    # The three iterables are aligned by the sort_index calls above; `index`
    # (the area name) is taken from the female-ratios frame.
    for ((_, age_structure), (index, female_ratios), (_, ethnicity_df)) in zip(
        age_structure_df.iterrows(),
        female_ratios_df.iterrows(),
        ethnicity_structure_df.groupby(level=0),
    ):
        # One column per ethnicity category; each entry is the per-bin counts.
        ethnicity_structure = [ethnicity_df[col].values for col in ethnicity_df.columns]
        ret[index] = AgeSexGenerator(
            age_structure.values,
            # For a single row, the Series index holds the age-bin labels
            # (i.e. the original column names of the female-ratios table).
            female_ratios.index.values,
            female_ratios.values,
            ethnicity_df.columns,
            ethnicity_df.index.get_level_values(1),
            ethnicity_structure,
        )

    return ret


def load_comorbidity_data(m_comorbidity_path=None, f_comorbidity_path=None):
    """
    Read the male and female comorbidity prevalence tables and renormalize them.

    Each csv has a "comorbidity" column (used as index, with "no_condition" as
    the last row) and one column per age bin. For every age column, the rows
    other than the last are rescaled so that, together with the untouched
    "no_condition" probability, the column sums to one.

    Returns ``[male_df, female_df]``, or ``None`` when either path is missing.
    """
    if m_comorbidity_path is None or f_comorbidity_path is None:
        return None

    male_df = pd.read_csv(m_comorbidity_path).set_index("comorbidity")
    female_df = pd.read_csv(f_comorbidity_path).set_index("comorbidity")

    # Rows to rescale: every comorbidity except the final one ("no_condition").
    rescale_rows = list(male_df.index)[:-1]
    for column in male_df.columns:
        for frame in (male_df, female_df):
            baseline = frame[column].loc["no_condition"]
            target_mass = 1 - baseline           # probability left for conditions
            current_mass = np.sum(frame[column]) - baseline
            frame.loc[rescale_rows, column] = (
                frame.loc[rescale_rows, column] / current_mass * target_mass
            )

    return [male_df, female_df]


class ComorbidityGenerator:
    """
    Samples a comorbidity label for a person, given the sex-specific
    probability tables produced by ``load_comorbidity_data``.
    """

    def __init__(self, comorbidity_data):
        """
        Parameters
        ----------
        comorbidity_data
            ``[male_df, female_df]`` — DataFrames indexed by comorbidity name
            with one column per age threshold. Age labels and comorbidity
            names are taken from the male table.
        """
        male_df = comorbidity_data[0]
        female_df = comorbidity_data[1]
        # Transpose so that each row of the arrays is one age column's
        # probability vector over comorbidities.
        self.male_comorbidities_probabilities = np.array(
            male_df.values.T, dtype=np.float64
        )
        self.female_comorbidities_probabilities = np.array(
            female_df.values.T, dtype=np.float64
        )
        self.ages = np.array(male_df.columns).astype(int)
        self.comorbidities = np.array(male_df.index).astype(str)
        self.comorbidities_idx = np.arange(0, len(self.comorbidities))

    def _get_age_index(self, person):
        """
        Map ``person.age`` to a column index of the comorbidity tables.

        Reproduces the original lookup exactly: take the position of the last
        threshold strictly below the person's age, then bump it by one when
        it is nonzero. NOTE(review): under this scheme index 1 is never
        returned for sorted thresholds — confirm the intended binning.
        """
        last_below = 0
        for position, threshold in enumerate(self.ages):
            if person.age <= threshold:
                break
            last_below = position
        if last_below != 0:
            last_below += 1
        return last_below

    def get_comorbidity(self, person):
        """Draw a comorbidity label for ``person`` based on sex and age."""
        age_index = self._get_age_index(person)
        if person.sex == "m":
            probabilities = self.male_comorbidities_probabilities[age_index]
        else:
            probabilities = self.female_comorbidities_probabilities[age_index]
        chosen_idx = random_choice_numba(self.comorbidities_idx, probabilities)
        return self.comorbidities[chosen_idx]


def generate_comorbidity(person, comorbidity_data):
    """
    Draw a comorbidity label for ``person`` from the sex-specific tables.

    Returns ``None`` when no data is supplied, or when the person's sex is
    neither "m" nor "f" (mirroring the original implicit fall-through).
    """
    if comorbidity_data is None:
        return None

    male_table = comorbidity_data[0]
    female_table = comorbidity_data[1]
    thresholds = np.array(male_table.columns).astype(int)

    # Locate the age column; same scheme as ComorbidityGenerator._get_age_index.
    column_index = 0
    for position, threshold in enumerate(thresholds):
        if person.age <= threshold:
            break
        column_index = position
    if column_index != 0:
        column_index += 1

    if person.sex == "m":
        return random_choice_numba(
            male_table.index.values.astype(str),
            male_table[male_table.columns[column_index]].values,
        )
    if person.sex == "f":
        return random_choice_numba(
            female_table.index.values.astype(str),
            female_table[female_table.columns[column_index]].values,
        )
    return None


def load_age_and_sex_generators_for_bins(
    age_sex_bins_filename: str, by="super_area"
) -> Dict[str, "AgeSexGenerator"]:
    """
    Load binned age/sex counts and build an AgeSexGenerator per geography unit.

    Parameters
    ----------
    age_sex_bins_filename
        csv with one row per geography unit. Columns whose name contains "M"
        ("F") hold male (female) counts; the age-bin label is the second
        whitespace-separated token of the column name (e.g. "M 0-17" -> "0-17").
    by
        Name of the column holding the geography-unit identifier.

    Returns
    -------
    Dict mapping unit name -> AgeSexGenerator built from its male/female bins.
    """

    def _sex_table(data: pd.DataFrame, marker: str) -> pd.DataFrame:
        # Select this sex's columns and strip the sex marker from their names.
        table = data.loc[:, data.columns.str.contains(marker)].copy()
        table.rename(
            columns={column: column.split(" ")[1] for column in table.columns},
            inplace=True,
        )
        return table

    data = pd.read_csv(age_sex_bins_filename, index_col=0)
    area_names = data[by].values
    men = _sex_table(data, "M")
    women = _sex_table(data, "F")
    ret = {}
    # zip keeps rows aligned with area_names (same frame, same row order),
    # replacing the manual integer counter of the original implementation.
    for area_name, (_, men_row), (_, women_row) in zip(
        area_names, men.iterrows(), women.iterrows()
    ):
        ret[area_name] = AgeSexGenerator.from_age_sex_bins(
            men_row.to_dict(), women_row.to_dict()
        )
    return ret


from itertools import count
from random import choice
from recordclass import dataobject

from june.epidemiology.infection import Infection, Immunity

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from june.geography.geography import Area
    from june.geography.geography import SuperArea
    from june.groups.travel.mode_of_transport import ModeOfTransport
    from june.policy.vaccine_policy import VaccineTrajectory


class Activities(dataobject):
    """
    The fixed set of activity "slots" a person can occupy; each field holds
    the subgroup for that activity, or None when unassigned.
    """

    residence: None
    primary_activity: None
    medical_facility: None
    commute: None
    rail_travel: None
    leisure: None

    def iter(self):
        """Return the value stored in every activity slot, in field order."""
        return [getattr(self, slot_name) for slot_name in self.__fields__]


# NOTE(review): this counter appears unused in this section — Person defines
# its own `_id = count()` below. Confirm no external user before removing.
person_ids = count()


class Person(dataobject):
    """
    A single agent in the simulation, implemented as a ``recordclass.dataobject``
    for a compact per-instance footprint. Create instances through
    :meth:`from_attributes` so that mutable per-person state (immunity,
    activity subgroups) is freshly allocated for each person.
    """

    # Class-level counter used by `from_attributes` to assign unique ids.
    _id = count()
    id: int = 0
    sex: str = "f"
    age: int = 27
    ethnicity: str = None
    area: "Area" = None
    # work info
    work_super_area: "SuperArea" = None
    sector: str = None
    sub_sector: str = None
    lockdown_status: str = None
    # vaccine
    vaccine_trajectory: "VaccineTrajectory" = None
    vaccinated: int = None
    vaccine_type: str = None
    # comorbidity
    comorbidity: str = None
    # commute
    mode_of_transport: "ModeOfTransport" = None
    # activities
    busy: bool = False
    subgroups: Activities = None
    infection: Infection = None
    immunity: Immunity = None
    # infection
    dead: bool = False

    @classmethod
    def from_attributes(
        cls,
        sex="f",
        age=27,
        susceptibility_dict: dict = None,
        ethnicity=None,
        id=None,
        comorbidity=None,
    ):
        """
        Build a Person with fresh mutable state.

        When ``id`` is None, a unique id is drawn from the class counter.
        ``susceptibility_dict`` is forwarded to the new Immunity object.
        """
        if id is None:
            id = next(Person._id)
        return Person(
            id=id,
            sex=sex,
            age=age,
            ethnicity=ethnicity,
            # IMPORTANT, these objects need to be recreated, otherwise the default
            # is always the same object !!!!
            immunity=Immunity(susceptibility_dict=susceptibility_dict),
            comorbidity=comorbidity,
            subgroups=Activities(None, None, None, None, None, None),
        )

    @property
    def infected(self):
        """True while the person carries an Infection object."""
        return self.infection is not None

    @property
    def residence(self):
        """Subgroup stored in the residence activity slot."""
        return self.subgroups.residence

    @property
    def primary_activity(self):
        """Subgroup stored in the primary-activity slot (work/school/...)."""
        return self.subgroups.primary_activity

    @property
    def medical_facility(self):
        """Subgroup stored in the medical-facility slot, if hospitalised."""
        return self.subgroups.medical_facility

    @property
    def commute(self):
        """Subgroup stored in the commute slot."""
        return self.subgroups.commute

    @property
    def rail_travel(self):
        """Subgroup stored in the rail-travel slot."""
        return self.subgroups.rail_travel

    @property
    def leisure(self):
        """Subgroup stored in the leisure slot."""
        return self.subgroups.leisure

    @property
    def hospitalised(self):
        """
        True when the person is in a hospital's `patients` subgroup.
        The AttributeError fallback covers medical_facility being None.
        """
        try:
            return all(
                [
                    self.medical_facility.group.spec == "hospital",
                    self.medical_facility.subgroup_type
                    == self.medical_facility.group.SubgroupType.patients,
                ]
            )
        except AttributeError:
            return False

    @property
    def intensive_care(self):
        """True when the person is in a hospital's `icu_patients` subgroup."""
        try:
            return all(
                [
                    self.medical_facility.group.spec == "hospital",
                    self.medical_facility.subgroup_type
                    == self.medical_facility.group.SubgroupType.icu_patients,
                ]
            )
        except AttributeError:
            return False

    @property
    def housemates(self):
        """
        Other residents of the person's household; care-home residents are
        deliberately treated as having no housemates.
        """
        if self.residence.group.spec == "care_home":
            return []
        return self.residence.group.residents

    def find_guardian(self):
        """
        Pick a random adult housemate to act as guardian, or None when no
        adult is available or the chosen one is dead / should be hospitalised.
        """
        possible_guardians = [person for person in self.housemates if person.age >= 18]
        if not possible_guardians:
            return None
        guardian = choice(possible_guardians)
        if (
            guardian.infection is not None and guardian.infection.should_be_in_hospital
        ) or guardian.dead:
            return None
        else:
            return guardian

    @property
    def symptoms(self):
        """Symptoms of the current infection, or None when not infected."""
        if self.infection is None:
            return None
        else:
            return self.infection.symptoms

    @property
    def super_area(self):
        """Super area of the home area; None when the area is unset."""
        try:
            return self.area.super_area
        except Exception:
            return None

    @property
    def region(self):
        """Region of the home super area; None when unresolvable."""
        try:
            return self.super_area.region
        except Exception:
            return None

    @property
    def home_city(self):
        """City of the home super area (raises if area is unset)."""
        return self.area.super_area.city

    @property
    def work_city(self):
        """City of the work super area, or None when the person has no workplace."""
        if self.work_super_area is None:
            return None
        return self.work_super_area.city

    @property
    def available(self):
        """True when the person is alive, not in a medical facility, and not busy."""
        if (not self.dead) and (self.medical_facility is None) and (not self.busy):
            return True
        return False

    @property
    def socioeconomic_index(self):
        """Socioeconomic index of the home area; None when the area is unset."""
        try:
            return self.area.socioeconomic_index
        except Exception:
            return


from .person import Person, Activities
from .demography import Demography, Population, AgeSexGenerator


import logging
import yaml
from random import shuffle, randint
from collections import OrderedDict, defaultdict

import numpy as np
import pandas as pd

from june import paths
from june.geography import Area, SuperAreas


logger = logging.getLogger("care_home_distributor")

care_homes_per_area_filename = paths.data_path / "input/care_homes/care_homes_ew.csv"

default_config_filename = paths.configs_path / "defaults/groups/care_home.yaml"
default_communal_men_by_super_area = (
    paths.data_path / "input/care_homes/communal_male_residents_by_super_area.csv"
)
default_communal_women_by_super_area = (
    paths.data_path / "input/care_homes/communal_female_residents_by_super_area.csv"
)


class CareHomeError(Exception):
    """
    Raised for care-home related configuration or distribution errors.

    Derives from Exception rather than BaseException: BaseException is
    reserved for system-exiting events (KeyboardInterrupt, SystemExit), and
    ordinary ``except Exception`` handlers would otherwise not catch this.
    """


class CareHomeDistributor:
    """
    Distributes people living in communal establishments into care homes and
    assigns social-care workers to staff them.
    """

    def __init__(
        self,
        communal_men_by_super_area: dict,
        communal_women_by_super_area: dict,
        n_residents_per_worker: int = 10,
        workers_sector="Q",
    ):
        """
        Tool to distribute people from a certain area into a care home, if there is one.

        Parameters
        ----------
        communal_men_by_super_area
            maps super area name -> {"age1-age2": number of male communal residents}
        communal_women_by_super_area
            same as above, for women
        n_residents_per_worker
            how many residents each care worker looks after
        workers_sector
            sector letter that identifies care workers (default "Q")
        """
        self.communal_men_by_super_area = communal_men_by_super_area
        self.communal_women_by_super_area = communal_women_by_super_area
        self.n_residents_per_worker = n_residents_per_worker
        self.workers_sector = workers_sector

    @classmethod
    def from_file(
        cls,
        communal_men_by_super_area_filename: str = None,
        communal_women_by_super_area_filename: str = None,
        config_filename: str = None,
    ):
        """
        Build a distributor from the census csv files and a yaml config.

        ``None`` arguments fall back to the module-level default paths; they
        are resolved lazily here so defaults are read at call time.
        """
        if communal_men_by_super_area_filename is None:
            communal_men_by_super_area_filename = default_communal_men_by_super_area
        if communal_women_by_super_area_filename is None:
            communal_women_by_super_area_filename = default_communal_women_by_super_area
        if config_filename is None:
            config_filename = default_config_filename
        with open(config_filename) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        communal_men_df = pd.read_csv(communal_men_by_super_area_filename, index_col=0)
        communal_women_df = pd.read_csv(
            communal_women_by_super_area_filename, index_col=0
        )
        return cls(
            communal_men_by_super_area=communal_men_df.T.to_dict(),
            communal_women_by_super_area=communal_women_df.T.to_dict(),
            n_residents_per_worker=config["n_residents_per_worker"],
            workers_sector=config["workers_sector"],
        )

    def _create_people_dicts(self, area: "Area"):
        """
        Creates dictionaries with the men and women per age key living in the area.
        """
        men_by_age = defaultdict(list)
        women_by_age = defaultdict(list)
        for person in area.people:
            if person.sex == "m":
                men_by_age[person.age].append(person)
            else:
                women_by_age[person.age].append(person)
        return men_by_age, women_by_age

    def _find_person_in_age_range(self, people_by_age: dict, age_1, age_2):
        """
        Pop and return a random person aged in [age_1, age_2] from the
        age->people mapping, or None when nobody qualifies.
        """
        available_people = []
        for age in range(age_1, age_2 + 1):
            available_people += people_by_age[age]
        if not available_people:
            return None
        chosen_person_idx = randint(0, len(available_people) - 1)
        chosen_person = available_people[chosen_person_idx]
        people_by_age[chosen_person.age].remove(chosen_person)
        # Drop empty buckets so callers can test membership cheaply.
        if not people_by_age[chosen_person.age]:
            del people_by_age[chosen_person.age]
        return chosen_person

    def _sort_dictionary_by_age_range_key(self, d: dict):
        """
        Sorts a dictionary by decreasing order of the age range in the keys.
        """
        ret = OrderedDict()
        # NOTE(review): sorts by the first *character* of each "a-b" key, which
        # is lexicographic for multi-digit lower bounds — confirm key format.
        ages = [age_range[0] for age_range in d.keys()]
        men_age_ranges_sorted = np.array(list(d.keys()))[np.argsort(ages)[::-1]]
        for key in men_age_ranges_sorted:
            ret[key] = d[key]
        return ret

    def populate_care_homes_in_super_areas(self, super_areas: "SuperAreas"):
        """
        Populates care homes in the super areas. For each super area, we look into the
        population that lives in communal establishments, from there we pick the oldest ones
        to live in care homes.
        """
        logger.info("Populating care homes")
        total_care_home_residents = 0
        for super_area in super_areas:
            men_communal_residents = self.communal_men_by_super_area[super_area.name]
            women_communal_residents = self.communal_women_by_super_area[
                super_area.name
            ]
            # Oldest age ranges first, so care homes fill with the oldest people.
            communal_men_sorted = self._sort_dictionary_by_age_range_key(
                men_communal_residents
            )
            communal_women_sorted = self._sort_dictionary_by_age_range_key(
                women_communal_residents
            )
            areas_with_care_homes = [
                area for area in super_area.areas if area.care_home is not None
            ]
            # now we need to choose from each area population which people go to the care home based on
            # the super area statistics. Check who goes first.
            shuffle(areas_with_care_homes)
            areas_dicts = [
                self._create_people_dicts(area) for area in areas_with_care_homes
            ]
            found_person = True
            assert communal_men_sorted.keys() == communal_women_sorted.keys()
            # Round-robin over care homes, placing one resident per home per
            # pass, until no suitable person is found or all homes are full.
            while found_person:
                found_person = False
                for i, area in enumerate(areas_with_care_homes):
                    care_home = area.care_home
                    if len(care_home.residents) < care_home.n_residents:
                        # look for men first
                        for age_range in communal_men_sorted:
                            age1, age2 = list(map(int, age_range.split("-")))
                            if communal_men_sorted[age_range] <= 0:
                                if communal_women_sorted[age_range] <= 0:
                                    continue
                                # find woman
                                person = self._find_person_in_age_range(
                                    areas_dicts[i][1], age1, age2
                                )
                                if person is None:
                                    continue
                                care_home.add(person, care_home.SubgroupType.residents)
                                communal_women_sorted[age_range] -= 1
                                total_care_home_residents += 1
                                found_person = True
                                break
                            person = self._find_person_in_age_range(
                                areas_dicts[i][0], age1, age2
                            )
                            if person is None:
                                # find woman
                                person = self._find_person_in_age_range(
                                    areas_dicts[i][1], age1, age2
                                )
                                if person is None:
                                    continue
                                care_home.add(person, care_home.SubgroupType.residents)
                                communal_women_sorted[age_range] -= 1
                                total_care_home_residents += 1
                                found_person = True
                                break
                            care_home.add(person, care_home.SubgroupType.residents)
                            communal_men_sorted[age_range] -= 1
                            total_care_home_residents += 1
                            found_person = True
                            break
        logger.info(
            f"This world has {total_care_home_residents} people living in care homes."
        )

    def distribute_workers_to_care_homes(self, super_areas: "SuperAreas"):
        """
        Assigns unemployed workers of the configured care sector to staff the
        care homes of each super area, marking them as key workers.
        """
        for super_area in super_areas:
            care_homes = [
                area.care_home
                for area in super_area.areas
                if area.care_home is not None
            ]
            if not care_homes:
                continue
            carers = [
                person
                for person in super_area.workers
                if (
                    # BUGFIX: was hard-coded "Q", which ignored the
                    # workers_sector configured in __init__ / from_file.
                    person.sector == self.workers_sector
                    and person.primary_activity is None
                    and person.sub_sector is None
                )
            ]
            shuffle(carers)
            for care_home in care_homes:
                while len(care_home.workers) < care_home.n_workers:
                    try:
                        carer = carers.pop()
                    except Exception:
                        logger.info(
                            f"Care home in area {care_home.area.name} has not enough workers!"
                        )
                        break
                    care_home.add(
                        person=carer,
                        subgroup_type=care_home.SubgroupType.workers,
                        activity="primary_activity",
                    )
                    carer.lockdown_status = "key_worker"


from collections import defaultdict
import logging
import numpy as np
from random import randint


logger = logging.getLogger("company_distributor")

"""
This file contains routines to attribute people with different characteristics
according to census data.
"""


class CompanyDistributor:
    """
    Distributes workers that are not yet working in key company sectors
    (e.g. such as schools and hospitals) to companies. This assumes that
    the WorkerDistributor has already been run to allocate workers in
    a super_area
    """

    def __init__(self):
        """Get all companies within SuperArea"""

    def distribute_adults_to_companies_in_super_areas(self, super_areas):
        """Distribute workers to companies in every super area, logging progress."""
        logger.info("Distributing workers to companies")
        for i, super_area in enumerate(super_areas):
            if i % 100 == 0:
                logger.info(
                    f"Distributed workers to companies in {i} of {len(super_areas)} super areas."
                )
            self.distribute_adults_to_companies_in_super_area(super_area)
        logger.info("Workers distributed to companies")

    def distribute_adults_to_companies_in_super_area(self, super_area):
        """
        Looks for all workers and companies in the super area and matches
        them.

        Companies of each sector are filled one after another (``full_idx``
        points at the company currently being filled); once every company of
        a sector is full, the remaining workers of that sector are assigned
        to a random company. Workers of a sector with no companies at all are
        collected and spread randomly over all companies at the end.
        """
        company_dict = defaultdict(list)
        full_idx = defaultdict(int)
        unallocated_workers = []
        for company in super_area.companies:
            company_dict[company.sector].append(company)
            full_idx[company.sector] = 0

        for worker in super_area.workers:
            if worker.primary_activity is not None:
                continue
            if company_dict[worker.sector]:
                if full_idx[worker.sector] >= len(company_dict[worker.sector]):
                    # Every company of this sector is full: assign at random.
                    idx = randint(0, len(company_dict[worker.sector]) - 1)
                    company = company_dict[worker.sector][idx]
                else:
                    # BUGFIX: was `[0]`, which kept filling the first company
                    # while full_idx advanced, skipping companies 1..n-1
                    # entirely until the random fallback kicked in.
                    company = company_dict[worker.sector][full_idx[worker.sector]]
                    if company.n_workers >= company.n_workers_max:
                        full_idx[company.sector] += 1
                company.add(worker)
            else:
                unallocated_workers.append(worker)

        if unallocated_workers:
            companies_for_unallocated = np.random.choice(
                super_area.companies, len(unallocated_workers)
            )
            for worker, company in zip(unallocated_workers, companies_for_unallocated):
                company.add(worker)


import logging

import numpy as np
from random import shuffle
import yaml
from typing import List, Optional

from june import paths
from june.geography import SuperAreas, SuperArea
from june.groups import Hospitals

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from june.demography import Person
    from june.groups import Hospital

logger = logging.getLogger("hospital_distributor")

default_config_filename = (
    paths.configs_path / "defaults/distributors/hospital_distributor.yaml"
)


class HospitalDistributor:
    """
    Distributes people to work as health care workers in hospitals

        #TODO: sub sectors of doctors and nurses should be found
        Healthcares sector
        2211: Medical practitioners
        2217: Medical radiographers
        2231: Nurses
        2232: Midwives
    """

    def __init__(
        self,
        hospitals: "Hospitals",
        medic_min_age: int,
        patients_per_medic: int,
        healthcare_sector_label: Optional[str] = None,
    ):
        """

        Parameters
        ----------
        hospitals:
            hospitals to populate with workers
        medic_min_age:
            minimum age to qualify as a worker (inclusive)
        patients_per_medic:
            ratio of patients per medic
        healthcare_sector_label:
            string that characterizes the healthcare workers
        """
        self.hospitals = hospitals
        self.medic_min_age = medic_min_age
        self.patients_per_medic = patients_per_medic
        self.healthcare_sector_label = healthcare_sector_label

    @classmethod
    def from_file(cls, hospitals, config_filename=None):
        """
        Build a distributor from a yaml config; a None config_filename falls
        back to the module-level default path, resolved lazily at call time.
        """
        if config_filename is None:
            config_filename = default_config_filename
        with open(config_filename) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        return HospitalDistributor(
            hospitals=hospitals,
            medic_min_age=config["medic_min_age"],
            patients_per_medic=config["patients_per_medic"],
            healthcare_sector_label=config["healthcare_sector_label"],
        )

    def distribute_medics_from_world(self, people: List["Person"]):
        """
        Randomly distribute people from the world to work as medics for hospitals,
        useful if we don't have data on where do people work. It will still
        match the patients to medic ratio and the minimum age to be a medic.

        Parameters
        ----------
        people:
            list of Persons in the world
        """
        medics = [person for person in people if person.age >= self.medic_min_age]
        shuffle(medics)
        for hospital in self.hospitals:
            max_capacity = hospital.n_beds + hospital.n_icu_beds
            if max_capacity == 0:
                continue
            # ROBUSTNESS FIX: cap by the medics still available so .pop()
            # cannot raise IndexError (mirrors distribute_medics_to_hospitals).
            n_medics = min(
                max(int(np.floor(max_capacity / self.patients_per_medic)), 1),
                len(medics),
            )
            for _ in range(n_medics):
                medic = medics.pop()
                hospital.add(medic, hospital.SubgroupType.workers)
                medic.lockdown_status = "key_worker"

    def distribute_medics_to_super_areas(self, super_areas: "SuperAreas"):
        """
        Distribute medics to super areas, flow data is necessary to find medics in the
        super area according to their sector.

        Parameters
        ----------
        super_areas:
            object containing all the super areas to distribute medics
        """
        logger.info("Distributing medics to hospitals")
        for super_area in super_areas:
            self.distribute_medics_to_hospitals(super_area)
        logger.info("Medics distributed to hospitals")

    def get_hospitals_in_super_area(self, super_area: "SuperArea") -> List["Hospital"]:
        """
        From all hospitals, filter the ones placed in a given super_area

        Parameters
        ----------
        super_area:
            super area
        """
        hospitals_in_super_area = [
            hospital
            for hospital in self.hospitals.members
            if hospital.super_area.name == super_area.name
        ]
        return hospitals_in_super_area

    def distribute_medics_to_hospitals(self, super_area: "SuperArea"):
        """
        Distribute medics to hospitals within a super area
        Parameters
        ----------
        super_area:
            super area to distribute medics
        """
        hospitals_in_super_area = self.get_hospitals_in_super_area(super_area)
        if not hospitals_in_super_area:
            return
        medics = [
            person
            for person in super_area.workers
            if person.sector == self.healthcare_sector_label
            # CONSISTENCY FIX: was `>`; the minimum age is inclusive, matching
            # distribute_medics_from_world and the documented contract.
            and person.age >= self.medic_min_age
            and person.primary_activity is None
        ]
        if not medics:
            logger.info(
                f"\n The SuperArea {super_area.name} has no people that work in it!"
            )
            return
        else:
            shuffle(medics)
            for hospital in hospitals_in_super_area:
                max_capacity = hospital.n_beds + hospital.n_icu_beds
                if max_capacity == 0:
                    continue
                n_medics = min(
                    max(int(np.floor(max_capacity / self.patients_per_medic)), 1),
                    len(medics),
                )
                for _ in range(n_medics):
                    medic = medics.pop()
                    hospital.add(medic, hospital.SubgroupType.workers)
                    medic.lockdown_status = "key_worker"

    def assign_closest_hospitals_to_super_areas(self, super_areas):
        """Precompute and cache each super area's nearest hospitals."""
        if not self.hospitals.members:
            return
        for super_area in super_areas:
            super_area.closest_hospitals = self.hospitals.get_closest_hospitals(
                super_area.coordinates, self.hospitals.neighbour_hospitals
            )


from collections import OrderedDict
from collections import defaultdict
from typing import List
import logging

import numpy as np
import pandas as pd
import yaml
from scipy.stats import rv_discrete

from june import paths
from june.demography import Person
from june.geography import Area
from june.groups import Household, Households

logger = logging.getLogger("household_distributor")

default_config_filename = (
    paths.configs_path / "defaults/distributors/household_distributor.yaml"
)

default_household_composition_filename = (
    paths.data_path / "input/households/household_composition_ew.csv"
)

default_number_students_filename = (
    paths.data_path / "input/households/n_students_ew.csv"
)

default_number_communal_filename = (
    paths.data_path / "input/households/n_communal_ew.csv"
)

default_couples_age_difference_filename = (
    paths.data_path / "input/households/couples_age_difference.csv"
)

default_parent_kid_age_difference_filename = (
    paths.data_path / "input/households/parent_kid_age_difference.csv"
)

default_logging_config_filename = (
    paths.configs_path / "config_world_creation_logger.yaml"
)


"""
This file contains routines to distribute people to households
according to census data.
"""


class HouseholdError(Exception):
    """
    class for throwing household related errors

    Derives from Exception rather than BaseException: BaseException is
    reserved for system-exiting events, and ordinary ``except Exception``
    handlers would otherwise not catch this error.
    """


def get_closest_element_in_array(array, value):
    """Return the element of ``array`` nearest to ``value`` (first one on ties)."""
    distances = np.abs(array - value)
    return array[np.argmin(distances)]


def count_items_in_dict(dictionary):
    """
    Return the total number of items across all of the dict's values.

    Values are expected to be sized collections (here: lists of people keyed
    by age). Uses a generator over ``.values()`` instead of the original
    key-loop with repeated indexing.
    """
    return sum(len(items) for items in dictionary.values())


def count_remaining_people(dict1, dict2):
    """Total number of people left across the two age->people dictionaries."""
    total = 0
    for people_by_age in (dict1, dict2):
        total += count_items_in_dict(people_by_age)
    return total


class HouseholdDistributor:
    def __init__(
        self,
        first_kid_parent_age_differences: dict,
        second_kid_parent_age_differences: dict,
        couples_age_differences: dict,
        number_of_random_numbers=int(1e3),
        kid_max_age=17,
        student_min_age=18,
        student_max_age=25,
        old_min_age=65,
        old_max_age=99,
        adult_min_age=18,
        adult_max_age=64,
        young_adult_min_age=18,
        young_adult_max_age=24,
        max_age_to_be_parent=64,
        max_household_size=8,
        allowed_household_compositions: list = None,
        ignore_orphans: bool = False,
    ):
        """
        Tool to populate areas with households and fill them with the correct
        composition based on census data. The most important function is
        "distribute_people_to_households" which takes people in an area
        and fills them into households.

        Parameters
        ----------
        first_kid_parent_age_differences:
            dictionary where keys are the age differences between a mother and
            her FIRST kid. The values are the probabilities of each age difference.
        second_kid_parent_age_differences:
            dictionary where keys are the age differences between a mother and
            her SECOND kid. The values are the probabilities of
            each age difference.
        couples_age_differences:
            dictionary where keys are the age differences between a woman and
            a man at the time of marriage. A value of 20 means that the man
            is 20 years older than the woman. The values are the probabilities
            of each age difference.
        number_of_random_numbers:
            Number of random numbers required. This should be set to the
            number of people living in the area, minimum.
        kid_max_age, student_min_age, student_max_age, old_min_age, old_max_age,
        adult_min_age, adult_max_age, young_adult_min_age, young_adult_max_age:
            inclusive age boundaries defining the kid / student / young adult /
            adult / old brackets used when filling households.
        max_age_to_be_parent:
            maximum age for a person to be assigned as a parent.
        max_household_size:
            maximum number of people allowed in any household.
        allowed_household_compositions:
            list of accepted composition strings, encoded as
            "kids students young_adults adults old"; defaults to the full set
            below when None.
        ignore_orphans:
            flag stored on the instance (used by the household-filling
            routines; behavior not visible in this block).
        """
        self.kid_max_age = kid_max_age
        self.student_min_age = student_min_age
        self.student_max_age = student_max_age
        self.old_min_age = old_min_age
        self.old_max_age = old_max_age
        self.adult_min_age = adult_min_age
        self.adult_max_age = adult_max_age
        self.young_adult_min_age = young_adult_min_age
        self.young_adult_max_age = young_adult_max_age
        self.max_age_to_be_parent = max_age_to_be_parent
        self.max_household_size = max_household_size
        self.ignore_orphans = ignore_orphans
        self.allowed_household_compositions = allowed_household_compositions
        # default composition codes: "kids students young_adults adults old"
        if self.allowed_household_compositions is None:
            self.allowed_household_compositions = [
                "0 0 0 0 1",
                "0 0 0 1 0",
                "0 0 0 0 2",
                "0 0 0 2 0",
                "1 0 >=0 2 0",
                ">=2 0 >=0 2 0",
                "0 0 >=1 2 0",
                "1 0 >=0 1 0",
                ">=2 0 >=0 1 0",
                "0 0 >=1 1 0",
                "1 0 >=0 >=1 >=0",
                ">=2 0 >=0 >=1 >=0",
                "0 >=1 0 0 0",
                "0 0 0 0 >=2",
                "0 0 >=0 >=0 >=0",
                ">=0 >=0 >=0 >=0 >=0",
            ]

        # discrete random variables used to sample age gaps and sexes in batches
        self._first_kid_parent_age_diff_rv = rv_discrete(
            values=(
                list(first_kid_parent_age_differences.keys()),
                list(first_kid_parent_age_differences.values()),
            )
        )
        self._second_kid_parent_age_diff_rv = rv_discrete(
            values=(
                list(second_kid_parent_age_differences.keys()),
                list(second_kid_parent_age_differences.values()),
            )
        )
        self._couples_age_rv = rv_discrete(
            values=(
                list(couples_age_differences.keys()),
                list(couples_age_differences.values()),
            )
        )
        self._random_sex_rv = rv_discrete(values=((0, 1), (0.5, 0.5)))
        # pre-draw the random-number pools that the filling routines pop from
        self._refresh_random_numbers_list(number_of_random_numbers)

    @classmethod
    def from_file(
        cls,
        husband_wife_filename: str = default_couples_age_difference_filename,
        parent_child_filename: str = default_parent_kid_age_difference_filename,
        config_filename: str = default_config_filename,
        number_of_random_numbers=int(1e3),
    ) -> "HouseholdDistributor":
        """
        Build a household distributor from data files. Files not specified are
        read from their default locations.

        Parameters
        ----------
        husband_wife_filename:
            CSV with the wife-husband age difference (relative to the wife) in
            the first column and its probability in the second.
        parent_child_filename:
            CSV with mother-kid age differences in the first column and the
            probabilities for the first and second kid in the next two columns.
        config_filename:
            YAML config defining the different age groups.
        number_of_random_numbers:
            size of the pre-sampled random pools; should match the number of
            people in the area being filled.
        """
        husband_wife_df = pd.read_csv(husband_wife_filename, index_col=0)
        parent_child_df = pd.read_csv(parent_child_filename, index_col=0)
        with open(config_filename) as config_file:
            config = yaml.load(config_file, Loader=yaml.FullLoader)
        return cls.from_df(
            husband_wife_df,
            parent_child_df,
            number_of_random_numbers=number_of_random_numbers,
            **config,
        )

    @classmethod
    def from_df(
        cls, husband_wife_df: pd.DataFrame, parent_child_df: pd.DataFrame, **kwargs
    ) -> "HouseholdDistributor":
        """
        Initializes a household distributor from dataframes. If they are not specified they are assumed to be in the default
        location.

        Parameters
        ----------
        husband_wife_filename:
            Dataframe containing as index the age differences between wife and husband (relative to the wife)
            and one column with the probability of that age difference.
        parent_child_filename:
            Dataframe containing as index the age differences between a mother and her kids.
            The first and second columns must contain the probabilities for the first and second kid respectively.
        Keyword Arguments:
            Any extra argument that is taken by the __init__ of HouseholdDistributor.
        """
        kids_parents_age_diff_1 = parent_child_df["0"]
        kids_parents_age_diff_2 = parent_child_df["1"]
        couples_age_diff = husband_wife_df
        first_kid_parent_age_differences = kids_parents_age_diff_1.to_dict()
        second_kid_parent_age_differences = kids_parents_age_diff_2.to_dict()
        couples_age_differences = couples_age_diff.to_dict()["frequency"]
        return cls(
            first_kid_parent_age_differences,
            second_kid_parent_age_differences,
            couples_age_differences,
            **kwargs,
        )

    def _refresh_random_numbers_list(self, n=1000) -> None:
        """
        Samples one million age differences for couples and parents-kids. Sampling in batches makes the code much faster. They are converted to lists so they can be popped.
        """
        # create one million random age difference array to save time
        self._couples_age_differences_list = list(self._couples_age_rv.rvs(size=n))
        self._first_kid_parent_age_diff_list = list(
            self._first_kid_parent_age_diff_rv.rvs(size=n)
        )
        self._second_kid_parent_age_diff_list = list(
            self._second_kid_parent_age_diff_rv.rvs(size=n)
        )
        self._random_sex_list = list(self._random_sex_rv.rvs(size=2 * n))

    def _create_people_dicts(self, area: Area):
        """
        Creates dictionaries with the men and women per age key living in the area.
        """
        men_by_age = defaultdict(list)
        women_by_age = defaultdict(list)
        for person in area.people:
            if person.residence is not None:
                continue
            if person.sex == "m":
                men_by_age[person.age].append(person)
            else:
                women_by_age[person.age].append(person)
        return men_by_age, women_by_age

    def distribute_people_and_households_to_areas(
        self,
        areas: List[Area],
        number_households_per_composition_filename: str = default_household_composition_filename,
        n_students_filename: str = default_number_students_filename,
        n_people_in_communal_filename: str = default_number_communal_filename,
    ):
        """
        Create households for every area in ``areas`` and fill them with the
        areas' people; the households of each area are stored in
        ``area.households`` and all of them are returned as one Households group.

        Parameters
        ----------
        areas
            list of instances of Area
        number_households_per_composition_filename
            path to the data file containing the number of households per household composition per area
        n_students_filename
            path to file containing the number of students per area
        n_people_in_communal_filename
            path to file containing the number of people living in communal establishments per area
        """
        logger.info("Distributing people to households")
        area_names = [area.name for area in areas]
        # restrict every table to the requested areas, in the same order
        household_numbers_df = pd.read_csv(
            number_households_per_composition_filename, index_col=0
        ).loc[area_names]
        n_students_df = pd.read_csv(n_students_filename, index_col=0).loc[area_names]
        n_communal_df = pd.read_csv(n_people_in_communal_filename, index_col=0).loc[
            area_names
        ]
        households_total = []
        per_area_rows = zip(
            areas,
            household_numbers_df.iterrows(),
            n_students_df.iterrows(),
            n_communal_df.iterrows(),
        )
        for counter, (
            area,
            (_, number_households),
            (_, n_students),
            (_, n_communal),
        ) in enumerate(per_area_rows, start=1):
            men_by_age, women_by_age = self._create_people_dicts(area)
            area.households = self.distribute_people_to_households(
                men_by_age,
                women_by_age,
                area,
                number_households.to_dict(),
                n_students.values[0],
                n_communal.values[0],
            )
            households_total.extend(area.households)
            if counter % 5000 == 0:
                logger.info(f"filled {counter} areas of {len(area_names)}")
        logger.info(
            f"People assigned to households. There are {len(households_total)} households in this world."
        )
        return Households(households_total)

    def distribute_people_to_households(
        self,
        men_by_age,
        women_by_age,
        area: Area,
        number_households_per_composition: dict,
        n_students: int,
        n_people_in_communal: int,
    ) -> list:
        """
        Given a populated output area, it distributes the people to households.
        The instance of the Area class, area, should have two dictionary attributes,
        ``men_by_age`` and ``women_by_age``. The keys of the dictionaries are the ages
        and the values are the Person instances. The process of creating these dictionaries
        is done in people_distributor.py.
        The ``number_households_per_composition`` argument is a dictionary containing the
        number of households per each composition. We obtain this from the nomis dataset and
        should be read by the inputs class in the world init.

        Parameters
        ----------
        men_by_age:
            dictionary with age as keys and lists of unallocated men as values.
        women_by_age:
            dictionary with age as keys and lists of unallocated women as values.
        area:
            area from which to take people and distribute households.
        number_households_per_composition:
            dictionary containing the different possible household compositions and the number of
            households with that composition as key.
            Example:
            The area E00062207 has this configuration:
            number_households_per_composition = {
            "0 0 0 0 1"           :   15
            "0 0 0 1 0"           :   20
            "0 0 0 0 2"           :   11
            "0 0 0 2 0"           :   24
            "1 0 >=0 2 0"         :   12
            ">=2 0 >=0 2 0"       :    9
            "0 0 >=1 2 0"         :    6
            "1 0 >=0 1 0"         :    5
            ">=2 0 >=0 1 0"       :    3
            "0 0 >=1 1 0"         :    7
            "1 0 >=0 >=1 >=0"     :    0
            ">=2 0 >=0 >=1 >=0"   :    1
            "0 >=1 0 0 0"         :    0
            "0 0 0 0 >=2"         :    0
            "0 0 >=0 >=0 >=0"     :    1
            ">=0 >=0 >=0 >=0 >=0" :    0
            }
            The encoding follows the rule "1 2 3 4 5" = 1 kid, 2 students (that live in student households), 3 young adults, 4 adults, and 5 old people.
        n_students:
            the number of students living this area.
        n_people_in_communal:
            the number of people living in communal establishments in this area.

        Returns
        -------
        list of all the households created in this area.
        """
        # We use these lists to store households that can accomodate different age groups
        # They will be useful to distribute remaining people at the end.
        households_with_extra_adults = []
        households_with_extra_oldpeople = []
        households_with_extra_kids = []
        households_with_extra_youngadults = []
        households_with_kids = []
        all_households = []
        total_people = count_remaining_people(men_by_age, women_by_age)
        self._refresh_random_numbers_list(total_people)

        if not men_by_age and not women_by_age:
            raise HouseholdError("No people in Area!")
        total_number_of_households = 0
        for key in number_households_per_composition:
            total_number_of_households += number_households_per_composition[key]
            if key not in self.allowed_household_compositions:
                raise HouseholdError(f"Household composition {key} not supported")

        # student households
        key = "0 >=1 0 0 0"
        if key in number_households_per_composition:
            house_number = number_households_per_composition[key]
            if house_number > 0:
                all_households += self.fill_all_student_households(
                    men_by_age=men_by_age,
                    women_by_age=women_by_age,
                    area=area,
                    n_students=n_students,
                    student_houses_number=house_number,
                    composition_type=key,
                )

        # single person old
        key = "0 0 0 0 1"
        if key in number_households_per_composition:
            house_number = number_households_per_composition[key]
            if house_number > 0:
                all_households += self.fill_oldpeople_households(
                    men_by_age=men_by_age,
                    women_by_age=women_by_age,
                    people_per_household=1,
                    n_households=house_number,
                    max_household_size=1,
                    extra_people_lists=(
                        households_with_extra_adults,
                        households_with_extra_oldpeople,
                    ),
                    area=area,
                    composition_type=key,
                )
        # couples old
        key = "0 0 0 0 2"
        if key in number_households_per_composition:
            house_number = number_households_per_composition[key]
            if house_number > 0:
                all_households += self.fill_oldpeople_households(
                    men_by_age=men_by_age,
                    women_by_age=women_by_age,
                    people_per_household=2,
                    n_households=house_number,
                    max_household_size=2,
                    extra_people_lists=(
                        households_with_extra_adults,
                        households_with_extra_oldpeople,
                    ),
                    area=area,
                    composition_type=key,
                )

        # old people houses with possibly more old people
        key = "0 0 0 0 >=2"
        if key in number_households_per_composition:
            house_number = number_households_per_composition[key]
            if house_number > 0:
                all_households += self.fill_oldpeople_households(
                    men_by_age=men_by_age,
                    women_by_age=women_by_age,
                    people_per_household=2,
                    n_households=house_number,
                    area=area,
                    extra_people_lists=(households_with_extra_oldpeople,),
                    composition_type=key,
                )

        # possible multigenerational, one kid and one adult minimum.
        # even though the number of old people is >=0, we put one old person
        # always if possible.
        key = "1 0 >=0 >=1 >=0"
        if key in number_households_per_composition:
            house_number = number_households_per_composition[key]
            if house_number > 0:
                all_households += self.fill_families_households(
                    men_by_age=men_by_age,
                    women_by_age=women_by_age,
                    n_households=house_number,
                    kids_per_house=1,
                    parents_per_house=1,
                    old_per_house=1,
                    extra_people_lists=(
                        households_with_kids,
                        households_with_extra_youngadults,
                        households_with_extra_adults,
                    ),
                    area=area,
                    composition_type=key,
                )
        # same as the previous one but with 2 kids minimum.
        key = ">=2 0 >=0 >=1 >=0"
        if key in number_households_per_composition:
            house_number = number_households_per_composition[key]
            if house_number > 0:
                all_households += self.fill_families_households(
                    men_by_age=men_by_age,
                    women_by_age=women_by_age,
                    n_households=house_number,
                    kids_per_house=2,
                    parents_per_house=1,
                    old_per_house=1,
                    area=area,
                    extra_people_lists=(
                        households_with_extra_kids,
                        households_with_kids,
                        households_with_extra_youngadults,
                        households_with_extra_adults,
                    ),
                    composition_type=key,
                )

        # one kid and one parent for sure, possibly extra young adults.
        key = "1 0 >=0 1 0"
        if key in number_households_per_composition:
            house_number = number_households_per_composition[key]
            if house_number > 0:
                all_households += self.fill_families_households(
                    men_by_age=men_by_age,
                    women_by_age=women_by_age,
                    n_households=house_number,
                    kids_per_house=1,
                    parents_per_house=1,
                    old_per_house=0,
                    area=area,
                    extra_people_lists=(
                        households_with_kids,
                        households_with_extra_youngadults,
                    ),
                    composition_type=key,
                )
        # same as above with two kids instead.
        key = ">=2 0 >=0 1 0"
        if key in number_households_per_composition:
            house_number = number_households_per_composition[key]
            if house_number > 0:
                all_households += self.fill_families_households(
                    men_by_age=men_by_age,
                    women_by_age=women_by_age,
                    n_households=house_number,
                    kids_per_house=2,
                    parents_per_house=1,
                    old_per_house=0,
                    area=area,
                    extra_people_lists=(
                        households_with_extra_kids,
                        households_with_kids,
                        households_with_extra_youngadults,
                    ),
                    composition_type=key,
                )
        # 1 kid and two parents with possibly young adults.
        key = "1 0 >=0 2 0"
        if key in number_households_per_composition:
            house_number = number_households_per_composition[key]
            if house_number > 0:
                all_households += self.fill_families_households(
                    men_by_age=men_by_age,
                    women_by_age=women_by_age,
                    n_households=house_number,
                    kids_per_house=1,
                    parents_per_house=2,
                    old_per_house=0,
                    area=area,
                    extra_people_lists=(
                        households_with_kids,
                        households_with_extra_youngadults,
                    ),
                    composition_type=key,
                )
        # same as above but two kids.
        key = ">=2 0 >=0 2 0"
        if key in number_households_per_composition:
            house_number = number_households_per_composition[key]
            if house_number > 0:
                all_households += self.fill_families_households(
                    men_by_age=men_by_age,
                    women_by_age=women_by_age,
                    n_households=house_number,
                    kids_per_house=2,
                    parents_per_house=2,
                    old_per_house=0,
                    area=area,
                    extra_people_lists=(
                        households_with_kids,
                        households_with_extra_kids,
                        households_with_extra_youngadults,
                    ),
                    composition_type=key,
                )
        # couple adult, it's possible to have a person < 65 with one > 65
        key = "0 0 0 2 0"
        if key in number_households_per_composition:
            house_number = number_households_per_composition[key]
            if house_number > 0:
                all_households += self.fill_nokids_households(
                    men_by_age=men_by_age,
                    women_by_age=women_by_age,
                    adults_per_household=2,
                    n_households=house_number,
                    max_household_size=2,
                    area=area,
                    extra_people_lists=(
                        households_with_extra_adults,
                        households_with_extra_oldpeople,
                    ),
                    composition_type=key,
                )
        # one adult (parent) and one young adult (non-dependable child)
        key = "0 0 >=1 1 0"
        if key in number_households_per_composition:
            house_number = number_households_per_composition[key]
            if house_number > 0:
                all_households += self.fill_youngadult_with_parents_households(
                    men_by_age=men_by_age,
                    women_by_age=women_by_age,
                    adults_per_household=1,
                    n_households=house_number,
                    area=area,
                    extra_people_lists=(households_with_extra_youngadults,),
                    composition_type=key,
                )

        # same as above but two adults
        key = "0 0 >=1 2 0"
        if key in number_households_per_composition:
            house_number = number_households_per_composition[key]
            if house_number > 0:
                all_households += self.fill_youngadult_with_parents_households(
                    men_by_age=men_by_age,
                    women_by_age=women_by_age,
                    adults_per_household=2,
                    n_households=house_number,
                    area=area,
                    extra_people_lists=(households_with_extra_youngadults,),
                    composition_type=key,
                )

        # single person adult
        key = "0 0 0 1 0"
        if key in number_households_per_composition:
            house_number = number_households_per_composition[key]
            if house_number > 0:
                all_households += self.fill_nokids_households(
                    men_by_age=men_by_age,
                    women_by_age=women_by_age,
                    adults_per_household=1,
                    n_households=house_number,
                    max_household_size=1,
                    area=area,
                    composition_type=key,
                )

        # other to be filled with remaining young adults, adults, and old people
        key = "0 0 >=0 >=0 >=0"
        if key in number_households_per_composition:
            house_number = number_households_per_composition[key]
            if house_number > 0:
                for _ in range(house_number):
                    household = self._create_household(
                        area=area, type="other", composition_type=key
                    )
                    households_with_extra_youngadults.append(household)
                    households_with_extra_adults.append(household)
                    households_with_extra_oldpeople.append(household)
                    all_households.append(household)

        # we have so far filled the minimum household configurations.
        # If the area has communal establishments, we fill those next.
        # The remaining people are then assigned to the existing households
        # trying to fit their household composition as much as possible

        remaining_people = count_remaining_people(men_by_age, women_by_age)
        communal_houses = 0  # this is used to count houses later
        key = ">=0 >=0 >=0 >=0 >=0"
        if key in number_households_per_composition:
            house_number = number_households_per_composition[key]
            communal_houses = house_number
            if n_people_in_communal >= 0 and house_number > 0:
                # never put more people in communal establishments than are left
                to_fill_in_communal = min(n_people_in_communal, remaining_people)
                all_households += self.fill_all_communal_establishments(
                    men_by_age=men_by_age,
                    women_by_age=women_by_age,
                    n_establishments=house_number,
                    n_people_in_communal=to_fill_in_communal,
                    area=area,
                    composition_type=key,
                )

        # remaining people
        self.fill_random_people_to_existing_households(
            men_by_age,
            women_by_age,
            households_with_extra_kids,
            households_with_kids,
            households_with_extra_youngadults,
            households_with_extra_adults,
            households_with_extra_oldpeople,
            all_households,
        )

        # make sure we have the correct number of households;
        # the total may fall short of the census figure by at most the
        # number of communal establishments
        if not (
            total_number_of_households - communal_houses
            <= len(all_households)
            <= total_number_of_households
        ):
            raise HouseholdError("Number of households does not match.")
        people_in_households = 0
        # sanity check: every person in the area must have been placed in a household
        for household in all_households:
            people_in_households += len(household.people)
        assert total_people == people_in_households
        return all_households

    def _create_household(
        self, area: Area, composition_type, type=None, max_household_size: int = np.inf
    ) -> Household:
        """Creates household in the area.

        Parameters
        ----------
        area:
            Area in which to create the household.
        composition_type:
            census composition string this household was created for
            (e.g. "0 0 >=0 >=0 >=0").
        type:
            label describing the household type (e.g. "other"); note this
            parameter name shadows the builtin ``type`` but is kept for
            backward compatibility with keyword callers.
        max_household_size:
            Maximum number of people allowed in the household
            (unbounded by default).
        """
        household = Household(
            type=type,
            max_size=max_household_size,
            area=area,
            composition_type=composition_type,
        )
        return household

    def _add_to_household(
        self, household: Household, person: Person, subgroup=None
    ) -> None:
        """
        Adds person to household and assigns them the correct subgroup.
        """
        if subgroup == "kids":
            household.add(person, household.SubgroupType.kids)
        elif subgroup == "young_adults":
            household.add(person, household.SubgroupType.young_adults)
        elif subgroup == "adults":
            household.add(person, household.SubgroupType.adults)
        elif subgroup == "old":
            household.add(person, household.SubgroupType.old_adults)
        elif subgroup == "default":
            household.add(person, household.SubgroupType.adults)
        else:
            raise HouseholdError(f"Subgroup {subgroup} not recognized")

    def _check_if_age_dict_is_empty(self, people_dict: dict, age: int) -> bool:
        """
        Given a people_dict that contains a list of people for each age, it deletes the
        age key if the number of people of that age is 0.

        Parameters
        ----------
        people_dict:
            dictionary with age as keys and list of people of that age as values.
        age:
            age to check if empty.
        """
        if not people_dict[age]:
            del people_dict[age]
            return True
        return False

    def _check_if_oldpeople_left(self, men_by_age: dict, women_by_age: dict) -> bool:
        """
        Checks whether there are still old people without an allocated household.

        Parameters
        ----------
            area:
                the area to check.
        """
        ret = False
        for age in range(65, 100):
            if age in women_by_age or age in men_by_age:
                ret = True
                break
        return ret

    def _get_closest_person_of_age(
        self, first_dict: dict, second_dict: dict, age: int, min_age=0, max_age=100
    ) -> Person:
        """
        Tries to find the person with the closest age in first dict inside the min_age and max_age.
        If it fails, it looks into the second_dict. If it fails again it returns None.

        Parameters
        ----------
        first_dict:
            dictionary with lists of people by age as keys. This is the first dictionary to look for a suitable person.
        second_dict:
            dictionary with lists of people by age as keys. This is the second dictionary to look for a suitable person.
        age:
            the target age of the person.
        min_age:
            minimum age the person should have.
        max_age:
            maximum age the person should have.
        """
        if age < min_age or age > max_age:
            return

        compatible_ages = np.array(list(first_dict.keys()))
        compatible_ages = compatible_ages[
            (min_age <= compatible_ages) & (compatible_ages <= max_age)
        ]
        if not compatible_ages.size:
            compatible_ages = np.array(list(second_dict.keys()))
            compatible_ages = compatible_ages[
                (min_age <= compatible_ages) & (compatible_ages <= max_age)
            ]
            if not compatible_ages.size:
                return
            first_dict = second_dict
        closest_age = get_closest_element_in_array(compatible_ages, age)
        person = first_dict[closest_age].pop()
        self._check_if_age_dict_is_empty(first_dict, closest_age)
        return person

    def _get_random_person_in_age_bracket(
        self, men_by_age: dict, women_by_age: dict, min_age=0, max_age=100
    ) -> Person:
        """
        Returns a random person of random sex within the specified age bracket (inclusive).

        The sex is drawn from the pre-generated ``self._random_sex_list`` and a
        target age is sampled uniformly in [min_age, max_age]; the person with
        the closest available age is then picked, preferring the dictionary
        that matches the sampled sex.

        Parameters
        ----------
        men_by_age
            men left to allocate by age key
        women_by_age
            women left to allocate by age key
        min_age:
            The minimum age the person should have.
        max_age:
            The maximum age the person should have.
        """
        sex = self._random_sex_list.pop()
        age = np.random.randint(min_age, max_age + 1)
        if sex == 0:
            preferred, fallback = men_by_age, women_by_age
        else:
            preferred, fallback = women_by_age, men_by_age
        return self._get_closest_person_of_age(
            preferred, fallback, age, min_age, max_age
        )

    def _get_matching_partner(
        self, person: Person, men_by_age, women_by_age, under_65=False, over_65=False
    ) -> Person:
        """
        Finds a partner of the opposite sex and a similar age for the given person.

        The age gap is sampled from the pre-generated list of observed couple
        age differences (``self._couples_age_differences_list``). The resulting
        target age is clamped to [18, self.old_max_age], optionally capped
        below 65 (``under_65``) or floored at 65 (``over_65``).

        Parameters
        ----------
        person:
            the person instance to find a partner for.
        men_by_age
            men left to allocate by age key
        women_by_age
            women left to allocate by age key
        under_65:
            whether to restrict the search for a partner under 65 years old.
        over_65:
            whether to restrict the search for a partner over 65 years old.
        """
        partner_sex = int(not person.sex)  # opposite sex of the given person
        age_gap = self._couples_age_differences_list.pop()
        if under_65:
            target_age = min(person.age - abs(age_gap), 64)
        else:
            target_age = person.age + age_gap
        if over_65:
            target_age = max(65, target_age)
        # keep the partner an adult and below the distributor's old max age
        target_age = max(min(self.old_max_age, target_age), 18)
        if partner_sex == 0:
            primary, secondary = men_by_age, women_by_age
        else:
            primary, secondary = women_by_age, men_by_age
        return self._get_closest_person_of_age(
            primary, secondary, target_age, min_age=self.adult_min_age
        )

    def _get_matching_parent(
        self, kid: Person, men_by_age: dict, women_by_age: dict
    ) -> Person:
        """
        Finds a parent for a kid (person strictly under 18 years old).

        The parent's target age is the kid's age plus a gap sampled from the
        known parent/first-kid age-difference distribution, clamped to the
        valid parenting window. Women are searched first (single mothers are
        more common than single fathers); men are the fallback.

        Parameters
        ----------
        kid:
            The person to look a parent for.
        men_by_age
            men left to allocate by age key
        women_by_age
            women left to allocate by age key
        """
        age_gap = self._first_kid_parent_age_diff_list.pop()
        # clamp to [adult_min_age, max_age_to_be_parent]
        target_age = min(kid.age + age_gap, self.max_age_to_be_parent)
        target_age = max(target_age, self.adult_min_age)
        return self._get_closest_person_of_age(
            women_by_age,
            men_by_age,
            target_age,
            min_age=self.adult_min_age,
            max_age=self.max_age_to_be_parent,
        )

    def _get_matching_second_kid(
        self, parent: Person, men_by_age: dict, women_by_age: dict
    ) -> Person:
        """
        Given a parent, it finds a person under 18 years old with an age difference matching
        the distribution of age difference between a mother and their second kid.

        The sex dictionary whose closest available age is nearest to the target
        (in absolute terms) is searched first; the other is the fallback.

        Parameters
        ----------
        parent:
            the parent (usually mother) to match with her second kid.
        men_by_age
            men left to allocate by age key
        women_by_age
            women left to allocate by age key
        """
        sampled_age_difference = self._second_kid_parent_age_diff_list.pop()
        # clamp the target age to the valid kid range [0, kid_max_age]
        target_age = min(max(parent.age - sampled_age_difference, 0), self.kid_max_age)

        def _distance_to_target(people_by_age: dict) -> float:
            # Absolute age distance from the target to the closest available
            # age in the dictionary; infinite when the dictionary is empty.
            if not people_by_age:
                return np.inf
            closest = get_closest_element_in_array(
                np.array(list(people_by_age.keys())), target_age
            )
            # FIX: compare absolute differences. The previous signed
            # difference (closest - target_age) let a dictionary whose
            # closest age was far *below* the target beat one whose closest
            # age was just above it.
            return abs(closest - target_age)

        if _distance_to_target(men_by_age) < _distance_to_target(women_by_age):
            first_dict, second_dict = men_by_age, women_by_age
        else:
            # ties go to women, as in the original ordering
            first_dict, second_dict = women_by_age, men_by_age
        kid = self._get_closest_person_of_age(
            first_dict,
            second_dict,
            target_age,
            min_age=0,
            max_age=self.kid_max_age,
        )
        return kid

    def fill_all_student_households(
        self,
        men_by_age: dict,
        women_by_age: dict,
        area: Area,
        n_students: int,
        student_houses_number: int,
        composition_type,
    ) -> List[Household]:
        """
        Creates and fills all student households with people in the appropriate age bin (18-25 by default).

        Each house first receives ``n_students // student_houses_number``
        students; the remainder is then dealt out one student at a time,
        cycling round-robin over the houses.

        Parameters
        ----------
        men_by_age
            men left to allocate by age key.
        women_by_age
            women left to allocate by age key.
        area:
            The area in which to create and fill the households.
        n_students:
            Number of students in this area. Found in the NOMIS data.
        student_houses_number:
            Number of student houses in this area.
        composition_type:
            composition label forwarded to ``_create_household``.

        Returns
        -------
        The list of created student households (None when ``n_students`` is 0).
        """
        if n_students == 0:
            return
        # students per household
        ratio = max(int(n_students / student_houses_number), 1)
        # NOTE(review): if n_students < student_houses_number, ratio is forced
        # to 1, the first loop decrements students_left once per house, and
        # students_left ends negative -- the assert below would then fire.
        # Presumably the input data guarantees n_students >= houses; confirm.
        # get all people in the students age
        # fill students to households
        students_left = n_students
        student_houses = []
        for _ in range(0, student_houses_number):
            household = self._create_household(
                area=area, type="student", composition_type=composition_type
            )
            student_houses.append(household)
            for _ in range(0, ratio):
                student = self._get_random_person_in_age_bracket(
                    men_by_age,
                    women_by_age,
                    min_age=self.student_min_age,
                    max_age=self.student_max_age,
                )
                if student is None:
                    # no-one left in the standard student bracket: widen the
                    # window by 10 years above the student max age.
                    student = self._get_random_person_in_age_bracket(
                        men_by_age,
                        women_by_age,
                        min_age=self.student_min_age,
                        max_age=self.student_max_age + 10,
                    )
                # NOTE(review): student may still be None if even the widened
                # bracket is exhausted -- confirm _add_to_household tolerates it.
                self._add_to_household(household, student, subgroup="young_adults")
                students_left -= 1
        assert students_left >= 0
        # deal out the remaining students round-robin over the houses
        index = 0
        while students_left:
            household = student_houses[index]
            student = self._get_random_person_in_age_bracket(
                men_by_age,
                women_by_age,
                min_age=self.student_min_age,
                max_age=self.student_max_age,
            )
            if student is None:
                # widen the age window, as above
                student = self._get_random_person_in_age_bracket(
                    men_by_age,
                    women_by_age,
                    min_age=self.student_min_age,
                    max_age=self.student_max_age + 10,
                )
            self._add_to_household(household, student, subgroup="young_adults")
            students_left -= 1
            index += 1
            index = index % len(student_houses)
        return student_houses

    def fill_oldpeople_households(
        self,
        men_by_age,
        women_by_age,
        people_per_household: int,
        n_households: int,
        area: Area,
        composition_type,
        extra_people_lists=(),
        max_household_size=np.inf,
    ) -> List[Household]:
        """
        Creates and fills households with old people.

        Each household is seeded with one old person (old_min_age..old_max_age);
        if ``people_per_household`` > 1, an old partner (65+) is also searched
        for. Once the old population runs out, the remaining households are
        still created empty and registered in ``extra_people_lists`` so that
        adults can be allocated there later.

        Parameters
        ----------
        men_by_age
            men left to allocate by age key.
        women_by_age
            women left to allocate by age key.
        people_per_household:
            target number of old people per household.
        n_households:
            Number of households.
        area:
            The area in which to create and fill the households.
        composition_type:
            composition label forwarded to ``_create_household``.
        extra_people_lists:
            Tuple of lists where the created households will be added to be used
            later to allocate unallocated people.
        max_household_size:
            The maximum size of the created households.
        """
        households = []
        for i in range(0, n_households):
            household = self._create_household(
                area=area,
                max_household_size=max_household_size,
                type="old",
                composition_type=composition_type,
            )
            households.append(household)
            person = self._get_random_person_in_age_bracket(
                men_by_age,
                women_by_age,
                min_age=self.old_min_age,
                max_age=self.old_max_age,
            )
            if person is None:
                # no old people left, leave the house and the rest empty and adults can come here later.
                for array in extra_people_lists:
                    array.append(household)
                for _ in range(i + 1, n_households):
                    household = self._create_household(
                        area=area,
                        max_household_size=max_household_size,
                        type="old",
                        composition_type=composition_type,
                    )
                    households.append(household)
                    for array in extra_people_lists:
                        array.append(household)
                return households
            self._add_to_household(household, person, subgroup="old")
            # try to add a partner over 65 when the household expects >1 person
            if people_per_household > 1 and person is not None:
                partner = self._get_matching_partner(
                    person, men_by_age, women_by_age, over_65=True
                )
                if partner is not None:
                    self._add_to_household(household, partner, subgroup="old")
            # households with spare room are offered for extra people later
            if household.size < household.max_size:
                for array in extra_people_lists:
                    array.append(household)
        return households

    def fill_families_households(
        self,
        men_by_age: dict,
        women_by_age: dict,
        n_households: int,
        kids_per_house: int,
        parents_per_house: int,
        old_per_house: int,
        area: Area,
        composition_type,
        max_household_size=np.inf,
        extra_people_lists=(),
    ) -> List[Household]:
        """
        Creates and fills households with families. The strategy is the following:
            - Put the first kid in the household.
            - Find a parent for the kid based on the age difference between parents and their first kid.
            - Find a partner for the first parent, based on age differences at time of marriage.
            - Add a second kid using the age difference with the mother.
            - Fill an extra old person if necessary for multigenerational families.
        If kids (and the young adults used as substitutes) or parents run out,
        the remaining households are created empty and registered in
        ``extra_people_lists`` for later allocation.

        Parameters
        ----------
        men_by_age
            men left to allocate by age key.
        women_by_age
            women left to allocate by age key.
        n_households:
            Number of households.
        kids_per_house:
            Number of kids (<18) in the house.
        parents_per_house:
            Number of parents in the house.
        old_per_house:
            Number of old people in the house.
        area:
            The area in which to create and fill the households.
        composition_type:
            composition label forwarded to ``_create_household``.
        max_household_size:
            The maximum size of the created households.
        extra_people_lists:
            Tuple of lists where the created households will be added to be used
            later to allocate unallocated people.
        """
        households = []
        for i in range(0, n_households):
            household = self._create_household(
                area=area,
                max_household_size=max_household_size,
                type="family",
                composition_type=composition_type,
            )
            households.append(household)
            first_kid = self._get_random_person_in_age_bracket(
                men_by_age, women_by_age, min_age=0, max_age=self.kid_max_age
            )
            if first_kid is not None:
                self._add_to_household(household, first_kid, subgroup="kids")
            else:
                # fill with young adult instead
                first_kid = self._get_random_person_in_age_bracket(
                    men_by_age,
                    women_by_age,
                    min_age=self.young_adult_min_age,
                    max_age=self.young_adult_max_age,
                )
                if first_kid is not None:
                    self._add_to_household(
                        household, first_kid, subgroup="young_adults"
                    )
                else:
                    # no kids nor young adults left: create the remaining
                    # households empty and register them for later filling.
                    for array in extra_people_lists:
                        array.append(household)
                    for _ in range(i + 1, n_households):
                        household = self._create_household(
                            area=area,
                            max_household_size=max_household_size,
                            type="family",
                            composition_type=composition_type,
                        )
                        households.append(household)
                        for array in extra_people_lists:
                            array.append(household)
                    return households
            first_parent = self._get_matching_parent(
                first_kid, men_by_age, women_by_age
            )
            if first_parent is None and not self.ignore_orphans:
                raise HouseholdError(
                    "Orphan kid. Check household configuration and population."
                )
            if first_parent is not None:
                self._add_to_household(household, first_parent, subgroup="adults")
            else:
                # orphans allowed (ignore_orphans is set) but no parents left:
                # create the remaining households empty for later allocation.
                for array in extra_people_lists:
                    array.append(household)
                for _ in range(i + 1, n_households):
                    household = self._create_household(
                        area=area,
                        max_household_size=max_household_size,
                        type="family",
                        composition_type=composition_type,
                    )
                    households.append(household)
                    for array in extra_people_lists:
                        array.append(household)
                return households

            for array in extra_people_lists:
                array.append(household)
            # optional old people for multigenerational households
            if old_per_house > 0:
                for _ in range(old_per_house):
                    random_old = self._get_random_person_in_age_bracket(
                        men_by_age,
                        women_by_age,
                        min_age=self.old_min_age,
                        max_age=self.old_max_age,
                    )
                    if random_old is None:
                        break
                    self._add_to_household(household, random_old, subgroup="old")

            if parents_per_house == 2 and first_parent is not None:
                second_parent = self._get_matching_partner(
                    first_parent, men_by_age, women_by_age
                )
                if second_parent is not None:
                    self._add_to_household(household, second_parent, subgroup="adults")

            if kids_per_house == 2:
                second_kid = self._get_matching_second_kid(
                    first_parent, men_by_age, women_by_age
                )
                if second_kid is not None:
                    self._add_to_household(household, second_kid, subgroup="kids")
                else:
                    # no suitable second kid: fall back to a young adult
                    second_kid = self._get_random_person_in_age_bracket(
                        men_by_age,
                        women_by_age,
                        min_age=self.young_adult_min_age,
                        max_age=self.young_adult_max_age,
                    )
                    if second_kid is not None:
                        self._add_to_household(
                            household, second_kid, subgroup="young_adults"
                        )
        return households

    def fill_nokids_households(
        self,
        men_by_age,
        women_by_age,
        adults_per_household: int,
        n_households: int,
        area: Area,
        composition_type,
        extra_people_lists=(),
        max_household_size=np.inf,
    ) -> List[Household]:
        """
        Fills households with one or two adults and no kids.

        Old people are placed preferentially: while any remain, the first
        adult of each new household is drawn from the old age bracket.

        Parameters
        ----------
        men_by_age
            men left to allocate by age key.
        women_by_age
            women left to allocate by age key.
        adults_per_household:
            number of adults to fill in the household can be one or two.
        n_households:
            number of households with this configuration
        area:
            the area in which to put the household
        composition_type:
            composition label forwarded to ``_create_household``.
        extra_people_lists:
            whether to include the created households in a list for extra people to be put in.
        max_household_size:
            maximum size of the created households.
        """
        households = []
        for _ in range(0, n_households):
            household = self._create_household(
                area=area,
                max_household_size=max_household_size,
                type="nokids",
                composition_type=composition_type,
            )
            households.append(household)
            if self._check_if_oldpeople_left(men_by_age, women_by_age):
                # if there are old people left, then put them here together with another adult.
                first_adult = self._get_random_person_in_age_bracket(
                    men_by_age,
                    women_by_age,
                    min_age=self.old_min_age,
                    max_age=self.old_max_age,
                )
                if first_adult is None:
                    raise HouseholdError("But you said there were old people left!")
            else:
                first_adult = self._get_random_person_in_age_bracket(
                    men_by_age,
                    women_by_age,
                    min_age=self.adult_min_age,
                    max_age=self.adult_max_age,
                )
            if first_adult is not None:
                self._add_to_household(household, first_adult, subgroup="adults")
            if adults_per_household == 1:
                # single-adult household: offer any spare room and move on
                if household.size < household.max_size:
                    for array in extra_people_lists:
                        array.append(household)
                continue
            # two adults requested: find a partner of similar age for the first
            if first_adult is not None:
                second_adult = self._get_matching_partner(
                    first_adult, men_by_age, women_by_age
                )
                if second_adult is not None:
                    self._add_to_household(household, second_adult, subgroup="adults")
            if household.size < household.max_size:
                for array in extra_people_lists:
                    array.append(household)
        return households

    def fill_youngadult_households(
        self,
        men_by_age: dict,
        women_by_age: dict,
        youngadults_per_household: int,
        n_households: int,
        area: Area,
        composition_type,
        extra_people_lists=(),
    ) -> List[Household]:
        """
        Fills households with young adults (18 to 35 years old).

        Parameters
        ----------
        men_by_age
            men left to allocate by age key.
        women_by_age
            women left to allocate by age key.
        youngadults_per_household:
            number of adults to fill in the household. Can be any positive number.
        n_households:
            number of households with this configuration
        area:
            the area in which to put the household
        extra_people_lists:
            whether to include the created households in a list for extra people to be put in.
        """
        households = []
        for _ in range(n_households):
            household = self._create_household(
                area=area, type="youngadults", composition_type=composition_type
            )
            households.append(household)
            for _ in range(youngadults_per_household):
                occupant = self._get_random_person_in_age_bracket(
                    men_by_age,
                    women_by_age,
                    min_age=self.young_adult_min_age,
                    max_age=self.young_adult_max_age,
                )
                # skip silently when no young adult is left
                if occupant is not None:
                    self._add_to_household(household, occupant, subgroup="young_adults")
            for extra_list in extra_people_lists:
                extra_list.append(household)
        return households

    def fill_youngadult_with_parents_households(
        self,
        men_by_age: dict,
        women_by_age: dict,
        adults_per_household: int,
        n_households: int,
        area: Area,
        composition_type,
        extra_people_lists=(),
    ) -> List[Household]:
        """
        Fills households with one young adult (18 to 35) and one or two adults.

        When a young adult is found, the adults are required to be at least
        18 years older than them; otherwise adults are drawn from the full
        adult age bracket.

        Parameters
        ----------
        men_by_age
            men left to allocate by age key.
        women_by_age
            women left to allocate by age key.
        adults_per_household:
            number of adults to fill in the household. Can be one or two.
        n_households:
            number of households with this configuration
        area:
            the area in which to put the household
        extra_people_lists:
            whether to include the created households in a list for extra people to be put in.
        """
        households = []
        for _ in range(n_households):
            household = self._create_household(
                area=area, type="ya_parents", composition_type=composition_type
            )
            households.append(household)
            for extra_list in extra_people_lists:
                extra_list.append(household)
            youngadult = self._get_random_person_in_age_bracket(
                men_by_age,
                women_by_age,
                min_age=self.young_adult_min_age,
                max_age=self.young_adult_max_age,
            )
            # the adult age floor depends on whether a young adult was found;
            # it is constant for the whole household, so compute it once.
            if youngadult is not None:
                self._add_to_household(household, youngadult, subgroup="young_adults")
                adult_min_age = youngadult.age + 18
            else:
                adult_min_age = self.adult_min_age
            for _ in range(adults_per_household):
                adult = self._get_random_person_in_age_bracket(
                    men_by_age,
                    women_by_age,
                    min_age=adult_min_age,
                    max_age=self.adult_max_age,
                )
                if adult is not None:
                    self._add_to_household(household, adult, subgroup="adults")
        return households

    def fill_all_communal_establishments(
        self,
        men_by_age,
        women_by_age,
        n_establishments: int,
        n_people_in_communal: int,
        area: Area,
        composition_type,
    ) -> List[Household]:
        """
        Fills all communal establishments with the remaining people that have not been allocated somewhere else.

        Each establishment is seeded with one adult (18+); the rest of its
        slots are filled with people of any age. Any remaining people are then
        dealt round-robin over the created establishments.

        Parameters
        ----------
        men_by_age
            men left to allocate by age key.
        women_by_age
            women left to allocate by age key.
        n_establishments:
            number of communal establishments.
        n_people_in_communal:
            number of people in each communal establishment
        area:
            the area in which to put the household
        composition_type:
            composition label forwarded to ``_create_household``.
        """
        # target occupants per establishment (at least one)
        ratio = max(int(n_people_in_communal / n_establishments), 1)
        people_left = n_people_in_communal
        communal_houses = []
        no_adults = False
        for _ in range(0, n_establishments):
            for i in range(ratio):
                if i == 0:
                    # the first occupant must be an adult (18+); if none is
                    # available, stop creating establishments altogether.
                    person = self._get_random_person_in_age_bracket(
                        men_by_age, women_by_age, min_age=18
                    )
                    if person is None:
                        no_adults = True
                        break
                    household = self._create_household(
                        area=area, type="communal", composition_type=composition_type
                    )
                    communal_houses.append(household)
                    self._add_to_household(household, person, subgroup="default")
                    people_left -= 1
                else:
                    person = self._get_random_person_in_age_bracket(
                        men_by_age, women_by_age
                    )
                    # NOTE(review): person can be None here if no-one is left;
                    # confirm _add_to_household tolerates a None person.
                    self._add_to_household(household, person, subgroup="default")
                    people_left -= 1
            if no_adults:
                break

        # distribute the remaining people round-robin over the establishments
        index = 0
        while people_left > 0:
            if not communal_houses:
                # this extreme case happens in area E00174453 (only case in England!!!)
                person = self._get_random_person_in_age_bracket(
                    men_by_age, women_by_age, min_age=15
                )
                household = self._create_household(
                    area=area, type="communal", composition_type=composition_type
                )
                communal_houses.append(household)
                self._add_to_household(household, person, subgroup="default")
                people_left -= 1
                continue
            person = self._get_random_person_in_age_bracket(men_by_age, women_by_age)
            household = communal_houses[index]
            self._add_to_household(household, person, subgroup="default")
            people_left -= 1
            index += 1
            index = index % len(communal_houses)
        return communal_houses

    def _remove_household_from_all_lists(self, household, lists: list) -> None:
        """
        Removes the given households from all the lists in lists.

        Parameters
        ----------

        household
            an instance of Household.
        lists
            list of lists of households.
        """
        for lis in lists:
            try:
                lis.remove(household)
            except ValueError:
                pass

    def _check_if_household_is_full(self, household: Household):
        """
        Checks if a household is full or has the maximum household size allowed by the Distributor.

        Parameters
        ----------
        household:
            the household to check.
        """
        size = household.size
        condition = (size >= household.max_size) or (size >= self.max_household_size)
        return condition

    def fill_random_people_to_existing_households(
        self,
        men_by_age,
        women_by_age,
        households_with_extra_kids: list,
        households_with_kids: list,
        households_with_extra_youngadults: list,
        households_with_extra_adults: list,
        households_with_extra_oldpeople: list,
        all_households: list,
    ) -> None:
        """
        The people that have not been associated a household yet are distributed in the following way.
        Given the lists in the arguments, we assign each age group according to this preferences:
        Kids -> households_with_extra_kids, households_with_kids, any
        Young adults -> households_with_extra_youngadults, households_with_adults, any
        Adults -> households_with_extra_adults, any
        Old people -> households_with_extra_oldpeople, any.

        When we allocate someone to any house, we prioritize the houses that have a small
        number of people (less than the max_household_size parameter defined in the __init__)

        Parameters
        ----------
        men_by_age
            dictionary with men left to allocate by age key.
        women_by_age
            dictionary with women left to allocate by age key.
        number_to_fill
            number of remaining people to distribute into spare households.
        households_with_extra_kids
            list of households that take extra kids.
        households_with_kids
            list of households that already have kids.
        households_with_extra_youngadults
            list of households that take extra young adults.
        households_with_extra_oldpeople
            list of households that take extra old people
        area
            area where households are.
        """
        households_with_space = [
            household
            for household in all_households
            if household.size < household.max_size
        ]
        all_households_no_space_restrictions = households_with_space.copy()
        available_lists = [
            all_households_no_space_restrictions,
            households_with_space,
            households_with_extra_kids,
            households_with_kids,
            households_with_extra_youngadults,
            households_with_extra_adults,
            households_with_extra_oldpeople,
        ]
        available_ages = list(men_by_age.keys()) + list(women_by_age.keys())
        available_ages = np.sort(np.unique(available_ages))
        people_left_dict = OrderedDict({})
        for age in available_ages:
            people_left_dict[age] = []
            if age in men_by_age:
                people_left_dict[age] += men_by_age[age]
            if age in women_by_age:
                people_left_dict[age] += women_by_age[age]
            np.random.shuffle(people_left_dict[age])  # mix men and women

        # fill old people first
        for age in range(self.old_min_age, self.old_max_age + 1):
            if age in people_left_dict:
                for person in people_left_dict[age]:
                    # old with old,
                    # otherwise random
                    household = self._find_household_for_nonkid(
                        [households_with_extra_oldpeople]
                    )
                    if household is None:
                        household = self._find_household_for_nonkid(
                            [
                                households_with_space,
                                all_households_no_space_restrictions,
                            ]
                        )
                    if household is None:
                        household = np.random.choice(all_households)
                    self._add_to_household(household, person, subgroup="old")
                    if self._check_if_household_is_full(household):
                        self._remove_household_from_all_lists(
                            household, available_lists
                        )

        # now young adults
        for age in range(self.young_adult_min_age, self.young_adult_max_age + 1):
            if age in people_left_dict:
                for person in people_left_dict[age]:
                    household = self._find_household_for_nonkid(
                        [households_with_extra_youngadults]
                    )
                    if household is None:
                        household = self._find_household_for_nonkid(
                            [
                                households_with_space,
                                all_households_no_space_restrictions,
                            ]
                        )
                    if household is None:
                        household = np.random.choice(all_households)
                    self._add_to_household(household, person, subgroup="young_adults")
                    if self._check_if_household_is_full(household):
                        self._remove_household_from_all_lists(
                            household, available_lists
                        )
        # now adults
        for age in range(self.young_adult_max_age + 1, self.adult_max_age + 1):
            if age in people_left_dict:
                for person in people_left_dict[age]:
                    household = self._find_household_for_nonkid(
                        [households_with_extra_adults]
                    )
                    if household is None:
                        household = self._find_household_for_nonkid(
                            [
                                households_with_space,
                                all_households_no_space_restrictions,
                            ]
                        )
                    if household is None:
                        household = np.random.choice(all_households)
                    self._add_to_household(household, person, subgroup="adults")
                    if self._check_if_household_is_full(household):
                        self._remove_household_from_all_lists(
                            household, available_lists
                        )

        # and lastly, kids
        for age in range(0, self.kid_max_age + 1):
            if age in people_left_dict:
                for person in people_left_dict[age]:
                    household = self._find_household_for_kid(
                        [households_with_extra_kids]
                    )
                    if household is None:
                        household = self._find_household_for_nonkid(
                            [
                                households_with_kids,
                                # households_with_space,
                                # all_households,
                            ]
                        )
                    if household is None:
                        household = self._find_household_for_nonkid(
                            [
                                households_with_space,
                                all_households_no_space_restrictions,
                            ]
                        )
                    if household is None:
                        household = np.random.choice(all_households)
                    self._add_to_household(household, person, subgroup="kids")
                    if self._check_if_household_is_full(household):
                        self._remove_household_from_all_lists(
                            household, available_lists
                        )

    def _find_household_for_kid(self, priority_lists):
        """
        Finds a suitable household for a kid. It first tries to search for a place in priority_lists[0],
        then 1, etc.

        Parameters
        ----------
        priority_lists:
            list of lists of households. The list should be sorted according to priority allocation.
        """
        for lis in priority_lists:
            list2 = [household for household in lis if household.size > 0]
            if not list2:
                continue
            household = np.random.choice(list2)
            return household

    def _find_household_for_nonkid(self, priority_lists):
        """
        Finds a suitable household for a person over 18 years old (who can live alone).
        It first tries to search for a place in priority_lists[0], then 1, etc.

        Parameters
        ----------
        priority_lists:
            list of lists of households. The list should be sorted according to priority allocation.
        """
        for lis in priority_lists:
            if not lis:
                continue
            household = np.random.choice(lis)
            return household


import logging
from typing import List, Optional, Tuple

import numpy as np
import yaml

from june import paths
from june.geography import Area, SuperArea, Geography
from june.groups.school import Schools

# Default YAML configuration consumed by SchoolDistributor.from_file
# (neighbour school count, age ranges, teacher ratios, ...).
default_config_filename = (
    paths.configs_path / "defaults/distributors/school_distributor.yaml"
)

logger = logging.getLogger("school_distributor")

# Mean Earth radius. NOTE(review): not referenced anywhere in this module's
# visible code -- presumably used for distance calculations elsewhere; confirm.
EARTH_RADIUS = 6371  # km

# Maps numeric school-sector codes (as found in the input data) to readable
# sector names. NOTE(review): also unreferenced in the visible code -- confirm.
default_decoder = {2314: "secondary", 2315: "primary", 2316: "special_needs"}


class SchoolDistributor:
    """
    Distributes students and teachers to schools.

    Students in an area are sent to the nearest school (among a fixed number
    of neighbouring schools) admitting their age group; teachers are drawn
    from the education-sector workers of each super area, following a mean
    student-to-teacher ratio.
    """

    def __init__(
        self,
        schools: "Schools",
        education_sector_label="P",
        neighbour_schools: int = 35,
        age_range: Tuple[int, int] = (0, 19),
        mandatory_age_range: Tuple[int, int] = (5, 18),
        teacher_student_ratio_primary=21,
        teacher_student_ratio_secondary=16,
        teacher_min_age=21,
        max_classroom_size=40,
    ):
        """
        Parameters
        ----------
        schools:
            instance of Schools, with information on all schools in world.
        education_sector_label:
            sector label(s) identifying education workers (teacher candidates).
        neighbour_schools:
            number of nearest schools considered per area and age group.
        age_range:
            (min, max) ages for which school attendance is possible.
        mandatory_age_range:
            (min, max) ages for which school attendance is compulsory.
        teacher_student_ratio_primary:
            mean students-per-teacher ratio for primary schools.
        teacher_student_ratio_secondary:
            mean students-per-teacher ratio for secondary schools.
        teacher_min_age:
            minimum age for a worker to be assigned as a teacher.
        max_classroom_size:
            maximum number of students per classroom subgroup.
        """
        self.schools = schools
        self.neighbour_schools = neighbour_schools
        self.school_age_range = age_range
        self.mandatory_school_age_range = mandatory_age_range
        self.education_sector_label = education_sector_label
        self.teacher_min_age = teacher_min_age
        self.teacher_student_ratio_primary = teacher_student_ratio_primary
        self.teacher_student_ratio_secondary = teacher_student_ratio_secondary
        self.max_classroom_size = max_classroom_size

    @classmethod
    def from_file(
        cls,
        schools: "Schools",
        config_filename: Optional[str] = None,
    ) -> "SchoolDistributor":
        """
        Initialize SchoolDistributor from the path to its YAML config file.

        Parameters
        ----------
        schools:
            instance of Schools, with information on all schools in world.
        config_filename:
            path to the config file; defaults to ``default_config_filename``
            (resolved lazily, at call time).

        Returns
        -------
        SchoolDistributor instance
        """
        if config_filename is None:
            config_filename = default_config_filename
        with open(config_filename) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        education_sector_label = cls.find_jobs(config)
        # Keyword arguments so config values cannot be bound to the wrong
        # parameter (bug fix: the old positional call fed teacher_min_age
        # into teacher_student_ratio_primary and max_classroom_size into
        # teacher_student_ratio_secondary).
        return cls(
            schools,
            education_sector_label=education_sector_label,
            neighbour_schools=config["neighbour_schools"],
            age_range=config["age_range"],
            mandatory_age_range=config["mandatory_age_range"],
            teacher_min_age=config["teacher_min_age"],
            max_classroom_size=config["max_classroom_size"],
        )

    @classmethod
    def from_geography(
        cls, geography: "Geography", config_filename: Optional[str] = None
    ):
        """Create a distributor for the schools of a Geography instance."""
        return cls.from_file(geography.schools, config_filename)

    @staticmethod
    def find_jobs(config: dict):
        """
        Collect the ``sector_id`` of every sub-sector entry in the config.

        Returns
        -------
        list of sector-id labels identifying education workers.
        """
        sector_labels = []
        for section in config.values():
            if isinstance(section, dict):
                for entry in section.values():
                    sector_labels.append(entry["sector_id"])
        return sector_labels

    def distribute_kids_to_school(self, areas: List["Area"]):
        """
        Distribute kids in the given areas to schools according to distance.

        For each area, the ``self.neighbour_schools`` closest schools per
        age group are looked up once; then mandatory-age kids and
        optional-age kids are allocated to them.
        """
        logger.info("Distributing kids to schools")
        for i, area in enumerate(areas):
            if i % 4000 == 0:
                logger.info(f"Distributed kids in {i} of {len(areas)} areas.")
            closest_schools_by_age = {}
            is_school_full = {}
            for agegroup in self.schools.school_trees:
                closest_schools = []
                closest_schools_idx = self.schools.get_closest_schools(
                    agegroup, area.coordinates, self.neighbour_schools
                )
                for idx in closest_schools_idx:
                    real_idx = self.schools.school_agegroup_to_global_indices[
                        agegroup
                    ][idx]
                    closest_schools.append(self.schools.members[real_idx])
                closest_schools_by_age[agegroup] = closest_schools
                is_school_full[agegroup] = False
            self.distribute_mandatory_kids_to_school(
                area, is_school_full, closest_schools_by_age
            )
            self.distribute_non_mandatory_kids_to_school(
                area, is_school_full, closest_schools_by_age
            )
        logger.info("Kids distributed to schools")

    def distribute_mandatory_kids_to_school(
        self, area: "Area", is_school_full: dict, closest_schools_by_age: dict
    ):
        """
        Send kids of mandatory school age to the nearest school with
        vacancies among the ``self.neighbour_schools`` candidates.

        If *all* candidate schools are full, pick one of them at random
        (making it larger than it should be) and mark the age group as full
        so later kids of that age skip the vacancy search.
        """
        for person in area.people:
            if not (
                self.mandatory_school_age_range[0]
                <= person.age
                <= self.mandatory_school_age_range[1]
            ):
                continue
            if person.age not in is_school_full:
                # no school tree admits this age
                continue
            candidates = closest_schools_by_age[person.age]
            n_candidates = min(len(candidates), self.neighbour_schools)
            if n_candidates == 0:
                # robustness: no candidate schools at all for this age
                continue
            if is_school_full[person.age]:
                school = candidates[np.random.randint(0, n_candidates)]
            else:
                schools_full = 0
                school = candidates[0]
                for i in range(n_candidates):
                    school = candidates[i]
                    if school.n_pupils >= school.n_pupils_max:
                        schools_full += 1
                    else:
                        # found a school with vacancies; keep it
                        break
                if schools_full == n_candidates:
                    # Every candidate is full: overfill a random one.
                    # Bug fix: the old code marked the age group full and
                    # random-picked after the *first* full school, so kids
                    # were randomly placed even when vacancies existed.
                    is_school_full[person.age] = True
                    school = candidates[np.random.randint(0, n_candidates)]
            # pupils no longer count as workers
            if person.work_super_area is not None:
                person.work_super_area.remove_worker(person)
            school.add(person)

    def distribute_non_mandatory_kids_to_school(
        self, area: "Area", is_school_full: dict, closest_schools_by_age: dict
    ):
        """
        Send kids of optional (non-mandatory) school age to the closest
        school that has both overall vacancies and room in their year group.
        If none of the neighbouring schools qualifies, they do not attend
        school.
        """
        min_age, max_age = self.school_age_range
        mandatory_min, mandatory_max = self.mandatory_school_age_range
        for person in area.people:
            if not (
                min_age < person.age < mandatory_min
                or mandatory_max < person.age < max_age
            ):
                continue
            if person.age not in is_school_full or is_school_full[person.age]:
                continue
            school_found = None
            for school in closest_schools_by_age[person.age][
                : self.neighbour_schools
            ]:
                # room in the school overall AND in this year's subgroup
                yearindex = person.age - school.age_min + 1
                n_pupils_age = len(school.subgroups[yearindex].people)
                if school.n_pupils < school.n_pupils_max and n_pupils_age < (
                    school.n_pupils_max / (school.age_max - school.age_min)
                ):
                    school_found = school
                    break
            if school_found is not None:
                if person.work_super_area is not None:
                    person.work_super_area.remove_worker(person)
                school_found.add(person)

    def distribute_teachers_to_schools_in_super_areas(
        self, super_areas: List["SuperArea"]
    ):
        """Assign teachers to the schools of every given super area."""
        for super_area in super_areas:
            self.distribute_teachers_to_school(super_area)

    def distribute_teachers_to_school(self, super_area: "SuperArea"):
        """
        Assign teachers to the schools of a super area.

        Schools are split into primary and secondary (a school that is both,
        or has an unknown sector, is assigned randomly to one side). Each
        school gets a target teacher count from a fixed student-to-teacher
        ratio, and education-sector workers of the super area older than
        ``self.teacher_min_age`` without a primary activity are distributed
        accordingly, specialised teachers first.
        """
        primary_schools, secondary_schools = self._split_schools_by_sector(
            super_area
        )
        # sample a target number of teachers per school around the mean ratio
        for school in primary_schools:
            school.n_teachers_max = self._sample_n_teachers(
                school.n_pupils, self.teacher_student_ratio_primary
            )
        for school in secondary_schools:
            school.n_teachers_max = self._sample_n_teachers(
                school.n_pupils, self.teacher_student_ratio_secondary
            )
        np.random.shuffle(primary_schools)
        np.random.shuffle(secondary_schools)
        # NOTE(review): education_sector_label is a *list* when built via
        # from_file, in which case this equality against a person's sector
        # string never matches -- confirm how workers are labelled.
        all_teachers = [
            person
            for person in super_area.workers
            if person.sector == self.education_sector_label
            and person.age > self.teacher_min_age
            and person.primary_activity is None
        ]
        primary_teachers = []
        secondary_teachers = []
        extra_teachers = []
        for teacher in all_teachers:
            if teacher.sub_sector == "teacher_primary":
                primary_teachers.append(teacher)
            elif teacher.sub_sector == "teacher_secondary":
                secondary_teachers.append(teacher)
            else:
                extra_teachers.append(teacher)
        np.random.shuffle(primary_teachers)
        np.random.shuffle(secondary_teachers)
        np.random.shuffle(extra_teachers)
        # specialised teachers go to their own school type first
        self._fill_teachers(primary_schools, primary_teachers)
        self._fill_teachers(secondary_schools, secondary_teachers)
        remaining_teachers = primary_teachers + secondary_teachers + extra_teachers
        # make sure no school with pupils is left without a single teacher
        empty_schools = [
            school
            for school in primary_schools + secondary_schools
            if school.n_pupils > 0 and school.n_teachers == 0
        ]
        for school in empty_schools:
            if not remaining_teachers:
                break
            teacher = remaining_teachers.pop()
            school.add(teacher)
            teacher.lockdown_status = "key_worker"
        # distribute whoever is left over all schools with vacancies
        self._fill_teachers(
            primary_schools + secondary_schools, remaining_teachers
        )

    def _split_schools_by_sector(self, super_area: "SuperArea"):
        """
        Split the non-empty schools of a super area into
        (primary_schools, secondary_schools). Schools that are both primary
        and secondary, or whose sector is unknown, go to a random side.
        """
        primary_schools = []
        secondary_schools = []
        for area in super_area.areas:
            for school in area.schools:
                if school.n_pupils == 0:
                    continue
                sector = school.sector
                if (
                    isinstance(sector, str)
                    and "primary" in sector
                    and "secondary" not in sector
                ):
                    primary_schools.append(school)
                elif (
                    isinstance(sector, str)
                    and "secondary" in sector
                    and "primary" not in sector
                ):
                    secondary_schools.append(school)
                else:
                    # unknown sector, or both primary and secondary
                    if np.random.randint(0, 2) == 0:
                        primary_schools.append(school)
                    else:
                        secondary_schools.append(school)
        return primary_schools, secondary_schools

    def _sample_n_teachers(self, n_pupils: int, mean_ratio) -> int:
        """
        Draw a students-per-teacher ratio from a Poisson distribution around
        ``mean_ratio`` and return the implied teacher count. The ratio is
        clamped to at least 1 (bug fix: a Poisson draw of 0 previously
        caused a division by zero).
        """
        ratio = max(1, np.random.poisson(mean_ratio))
        return int(np.round(n_pupils / ratio))

    def _fill_teachers(self, schools, teachers):
        """
        Round-robin ``teachers`` over ``schools`` until every school reached
        its ``n_teachers_max`` or no teachers remain. ``teachers`` is
        consumed in place (popped from the end).

        Bug fix: the old inline loops popped the last teacher and broke out
        *before* adding them to a school, silently dropping that teacher.
        """
        while teachers:
            any_vacancy = False
            for school in schools:
                if not teachers:
                    return
                if school.n_pupils == 0:
                    continue
                if school.n_teachers < school.n_teachers_max:
                    any_vacancy = True
                    teacher = teachers.pop()
                    school.add(teacher)
                    teacher.lockdown_status = "key_worker"
            if not any_vacancy:
                return

    def limit_classroom_sizes(self):
        """
        Cap classroom (year subgroup) sizes at ``self.max_classroom_size``,
        creating new subgroups where needed so students are spread
        homogeneously.
        """
        for school in self.schools:
            school.limit_classroom_sizes(self.max_classroom_size)


from typing import List
from collections import defaultdict
import logging
import numpy as np

from june.groups import University
from june.geography import Areas
from june.demography import Population

logger = logging.getLogger("university_distributor")


class UniversityDistributor:
    def __init__(self, universities: List["University"]):
        """
        Allocates students to universities.

        For each university, nearby areas are searched for people of student
        age (19 to 24, inclusive) who do not yet have a primary activity:
        first those living in student households, then those in communal
        households, then everyone else.

        Parameters
        ----------
        universities
            a list of universities to fill
        """
        self.universities = universities
        # inclusive age band considered "student age"
        self.min_student_age = 19
        self.max_student_age = 24

    def find_students_in_areas(
        self, students_dict: dict, areas: "Areas", university: "University"
    ):
        """
        Record the ids of potential students living in ``areas`` under
        ``students_dict[university.ukprn]``, keyed by household type:
        "student", "communal" or "other".

        Only people of student age without a primary activity qualify.
        """
        for area in areas:
            for household in area.households:
                # the three original branches only differed in the bucket key
                if household.type == "student":
                    category = "student"
                elif household.type == "communal":
                    category = "communal"
                else:
                    category = "other"
                for resident in household.residents:
                    if (
                        self.min_student_age
                        <= resident.age
                        <= self.max_student_age
                        and resident.primary_activity is None
                    ):
                        students_dict[university.ukprn][category].append(
                            resident.id
                        )

    def distribute_students_to_universities(
        self, areas: "Areas", people: "Population"
    ):
        """
        Fill each university with students found in nearby areas, widening
        the search distance (5, 15, 25, 35) until every university is full
        or the maximum distance is reached.
        """
        logger.info("Distributing students to universities")
        need_more_students = True
        distance_increment = 10
        distance = 5
        while need_more_students and distance < 45:
            students_dict = self._build_student_dict(areas=areas, distance=distance)
            self._assign_students_to_unis(students_dict=students_dict, people=people)
            distance += distance_increment
            need_more_students = any(
                university.n_students < university.n_students_max
                for university in self.universities
            )
        uni_info_dict = {
            university.ukprn: university.n_students for university in self.universities
        }
        for key, value in uni_info_dict.items():
            logger.info(f"University {key} has {value} students.")

    def _build_student_dict(self, areas, distance):
        """
        Collect candidate student ids per university from all areas closer
        than ``distance`` to each university's coordinates.
        """
        students_dict = defaultdict(lambda: defaultdict(list))
        for university in self.universities:
            close_areas, distances = areas.get_closest_areas(
                coordinates=university.coordinates,
                k=min(len(areas), 1000),
                return_distance=True,
            )
            close_areas = np.array(close_areas)[distances < distance]
            self.find_students_in_areas(
                students_dict=students_dict, areas=close_areas, university=university
            )
        return students_dict

    def _assign_students_to_unis(self, students_dict, people):
        """
        Assign candidates to universities in rounds -- one student per
        university per round -- preferring "student" households, then
        "communal", then "other", until universities are full or the
        candidates are exhausted.
        """
        for key in ["student", "communal", "other"]:
            assigned_any = True
            while assigned_any:
                assigned_any = False
                for university in self.universities:
                    student_candidates = students_dict[university.ukprn][key]
                    if student_candidates and not university.is_full:
                        student_id = student_candidates.pop()
                        university.add(
                            people.get_from_id(student_id), subgroup="student"
                        )
                        assigned_any = True


import logging
from itertools import count
from typing import List, Dict, Optional

import numpy as np
from random import randint
import pandas as pd
import yaml
from scipy.stats import rv_discrete

from june import paths
from june.demography import Person, Population
from june.geography import Geography, Areas, SuperAreas
from june.utils import random_choice_numba

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from june.groups.company import CompanyError

# Origin/destination commuting flows per super area. NOTE(review): contents
# inferred from the filename -- confirm against the loader.
default_workflow_file = paths.data_path / "input/work/work_flow.csv"
# Workers by sex and industry sector (presumably per area/super area --
# confirm against the loader).
default_sex_per_sector_per_superarea_file = (
    paths.data_path / "input/work/industry_by_sex_ew.csv"
)
# Lookup table mapping areas to super areas and regions.
default_areas_map_path = paths.data_path / "input/geography/area_super_area_region.csv"
# Default YAML configuration for the WorkerDistributor.
default_config_file = (
    paths.configs_path / "defaults/distributors/worker_distributor.yaml"
)
# Company-closure policy defaults. NOTE(review): presumably consumed by
# policy code elsewhere -- confirm.
default_policy_config_file = paths.configs_path / "defaults/policy/company_closure.yaml"

logger = logging.getLogger("worker_distributor")


class WorkerDistributor:
    """
    This class distributes people to their work. Work is understood as the main
    activity any individuum pursues during the week, e.g. for pupils it is
    learning in schools and for adults it is their work in companies and
    key sectors for which data was provided.
    """

    def __init__(
        self,
        workflow_df: pd.DataFrame,
        sex_per_sector_df: pd.DataFrame,
        company_closure: dict,
        age_range: List[int],
        sub_sector_ratio: dict,
        sub_sector_distr: dict,
        non_geographical_work_location: dict,
    ):
        """
        Parameters
        ----------
        workflow_df
            DataFrame that contains information about where man and woman go to work
            with respect to their SuperArea of residence.
        sex_per_sector_df
            DataFrame that contains information on the nr. of man and woman working
            in different sectors per Area (note that it is thus not provided for the
            SuperArea).
        company_closure
            Proportion of each company sector who will be defined as a key worker,
            become furloughed or will be randomly assigned to go to work during a
            lockdown.
        age_range
            [min_age, max_age] inclusive interval of ages considered working age.
        sub_sector_ratio
            For each region containing the ratio of man and woman respectively that
            work in any key sector type. (e.g. for healthcare, how many man work
            in the key occupations, such as nurses within that sector)
        sub_sector_distr
            For each region containing how many of man and woman respectively
            work in any key sector jobs, such as primary teachers or medical
            practitioners.
        non_geographical_work_location
            Special work place locations in dataset that do not correspond to a
            SuperArea name but to special cases such as:
            "home", "oversea", "offshore", ...
            They are the key of the dictionary. The value carries the action
            on what should be done with these workers. Currently they are:
            "home": let them work from home
            "bind": randomly select a SuperArea to send the worker to work in
        """
        self.workflow_df = workflow_df
        self.sex_per_sector_df = sex_per_sector_df
        self.age_range = age_range
        self.sub_sector_ratio = sub_sector_ratio
        self.sub_sector_distr = sub_sector_distr
        self.non_geographical_work_location = non_geographical_work_location
        self.company_closure = company_closure
        self._boundary_workers_counter = count()
        self.n_boundary_workers = 0

    def distribute(self, areas: Areas, super_areas: SuperAreas, population: Population):
        """
        Assign any person within the eligible working age range a location
        (SuperArea) of their work, the sector (e.g. "P"=education) of
        their work, and a lockdown status tag.

        Parameters
        ----------
        areas
            areas where the people live; iterated to reach each person
        super_areas
            super areas people may be sent to work in
        population
            part of the public interface but not used directly here — people
            are reached through each area's ``people`` list
        """
        self.areas = areas
        self.super_areas = super_areas
        lockdown_tags = np.array(["key_worker", "random", "furlough"])
        lockdown_tags_idx = np.arange(0, len(lockdown_tags))
        lockdown_tags_probabilities_by_sector = (
            self._parse_closure_probabilities_by_sector(
                company_closure=self.company_closure, lockdown_tags=lockdown_tags
            )
        )
        logger.info("Distributing workers to super areas...")
        for i, area in enumerate(self.areas):
            wf_area_df = self.workflow_df.loc[(area.super_area.name,)]
            # Pre-draw one random work location / sector / lockdown draw per
            # resident of this area; the per-person index picks from them.
            self._work_place_lottery(area.name, wf_area_df, len(area.people))
            self._lockdown_status_lottery(len(area.people))
            for idx, person in enumerate(area.people):
                if self.age_range[0] <= person.age <= self.age_range[1]:
                    self._assign_work_location(idx, person, wf_area_df)
                    self._assign_work_sector(idx, person)
                    self._assign_lockdown_status(
                        lockdown_tags_probabilities_by_sector,
                        lockdown_tags,
                        lockdown_tags_idx,
                        person,
                    )
            if i % 5000 == 0 and i != 0:
                logger.info(f"Distributed workers in {i} areas of {len(self.areas)}")
        logger.info("Workers distributed.")

    def _work_place_lottery(
        self, area_name: str, wf_area_df: pd.DataFrame, n_workers: int
    ):
        """
        Create lottery that randomly assigns people a sector and location
        of work.
        """
        # work msoa area/flow data
        work_msoa_man_rv = rv_discrete(
            values=(
                np.arange(0, len(wf_area_df.index.values)),
                wf_area_df["n_man"].values,
            )
        )
        self.work_msoa_man_rnd = work_msoa_man_rv.rvs(size=n_workers)
        work_msoa_woman_rv = rv_discrete(
            values=(
                np.arange(0, len(wf_area_df.index.values)),
                wf_area_df["n_woman"].values,
            )
        )
        self.work_msoa_woman_rnd = work_msoa_woman_rv.rvs(size=n_workers)
        # companies data: 21 sectors, keyed 1..21 in self.sector_dict
        numbers = np.arange(1, 22)
        m_col = [col for col in self.sex_per_sector_df.columns.values if "m " in col]

        f_col = [col for col in self.sex_per_sector_df.columns.values if "f " in col]
        self.sector_dict = {
            (idx + 1): col.split(" ")[-1] for idx, col in enumerate(m_col)
        }
        try:
            # fails if no female work in this Area
            distribution_female = (
                self.sex_per_sector_df.loc[area_name][f_col].fillna(0).values
            )
            self.sector_distribution_female = rv_discrete(
                values=(numbers, distribution_female)
            )
            self.sector_female_rnd = self.sector_distribution_female.rvs(size=n_workers)
        except Exception:
            # Best-effort: keep whatever draws are already set.
            # NOTE(review): self.sector_female_rnd is not refreshed here, so
            # _assign_work_sector would reuse the previous area's draws —
            # confirm this fallback is intended.
            logger.info(f"The Area {area_name} has no woman working in it.")
        try:
            # fails if no male work in this Area
            distribution_male = (
                self.sex_per_sector_df.loc[area_name][m_col].fillna(0).values
            )
            self.sector_distribution_male = rv_discrete(
                values=(numbers, distribution_male)
            )
            self.sector_male_rnd = self.sector_distribution_male.rvs(size=n_workers)
        except Exception:
            # Same caveat as the female branch above.
            logger.info(f"The Area {area_name} has no man working in it.")

    def _assign_work_location(self, i: int, person: Person, wf_area_df: pd.DataFrame):
        """
        Assign the i-th pre-drawn work location of this area to ``person``.

        Unknown locations are either handled according to
        ``non_geographical_work_location`` ("home"/"bind") or bound to a
        random super area.
        """
        if person.sex == "f":
            work_location = wf_area_df.index.values[self.work_msoa_woman_rnd[i]]
        else:
            work_location = wf_area_df.index.values[self.work_msoa_man_rnd[i]]
        try:
            super_area = self.super_areas.members_by_name[work_location]
            super_area.add_worker(person)
        except KeyError:
            if work_location in self.non_geographical_work_location:
                if self.non_geographical_work_location[work_location] == "home":
                    person.work_super_area = None
                elif self.non_geographical_work_location[work_location] == "bind":
                    self._select_rnd_superarea(person)
                else:
                    raise KeyError(
                        f"Work location {work_location} not found in world's geography"
                    )
            else:
                self._select_rnd_superarea(person)

    def _select_rnd_superarea(self, person: Person):
        """
        Selects random SuperArea to send a worker to work in
        """
        idx = randint(0, len(self.super_areas) - 1)
        self.super_areas.members[idx].add_worker(person)

    def _assign_work_sector(self, i: int, person: Person):
        """
        Assign the i-th pre-drawn sector of this area to ``person``, plus a
        sub-sector when the sector has key occupations configured.
        """
        if person.sex == "f":
            sector_idx = self.sector_female_rnd[i]
        else:
            sector_idx = self.sector_male_rnd[i]
        person.sector = self.sector_dict[sector_idx]

        if person.sector in self.sub_sector_ratio:
            self._assign_sub_sector(person)

    def _assign_sub_sector(self, person):
        """
        Assign sub-sector job as defined in config
        """
        MC_random = np.random.uniform()
        # Only the configured fraction of this sector/sex works in a key
        # sub-sector at all; the rest keep sector-only employment.
        ratio = self.sub_sector_ratio[person.sector][person.sex]
        distr = self.sub_sector_distr[person.sector][person.sex]
        if MC_random < ratio:
            sub_sector_idx = rv_discrete(values=(np.arange(len(distr)), distr)).rvs()
            person.sub_sector = self.sub_sector_distr[person.sector]["label"][
                sub_sector_idx
            ]

    def _lockdown_status_lottery(self, n_workers):
        """
        Creates run-once random list for each person in an area for assigning to a lockdown status
        """
        # NOTE(review): currently only consumed by commented-out code in
        # _assign_lockdown_status — kept for the documented fixed-1/5 option.
        self.lockdown_status_random = np.random.choice(2, n_workers, p=[4 / 5, 1 / 5])

    def _parse_closure_probabilities_by_sector(
        self, company_closure: dict, lockdown_tags: List
    ):
        """
        parses config file of closure probabilities

        Returns a dict mapping sector -> np.array of probabilities ordered
        like ``lockdown_tags``.
        """
        ret = {}
        for sector in company_closure:
            ret[sector] = np.array(
                [
                    self.company_closure[sector][lockdown_tags[0]],
                    self.company_closure[sector][lockdown_tags[1]],
                    self.company_closure[sector][lockdown_tags[2]],
                ]
            )
        return ret

    def _assign_lockdown_status(
        self,
        probabilities_by_sector: dict,
        lockdown_tags: List[str],
        lockdown_tags_idx: List[int],
        person: Person,
    ):
        """
        Assign lockdown_status in proportion to definitions in the policy config
        """
        # value = np.random.choice(values, 1, p=probs)[0]
        idx = random_choice_numba(
            lockdown_tags_idx, probabilities_by_sector[person.sector]
        )

        # Currently all people definitely not furloughed or key are assigned a 'random' tag which allows for
        # them to dynamically be sent to work. For now we fix this so that the same 1/5 people go to work once a week
        # rather than a 1/5 chance that a person with a 'random' tag goes to work.
        # If commented out then people will be correctly assigned random tag for going to work randomly
        # if value == "random" and self.lockdown_status_random[idx] == 0:
        #    value = "furlough"

        person.lockdown_status = lockdown_tags[idx]

    @classmethod
    def for_geography(
        cls,
        geography: Geography,
        workflow_file: str = default_workflow_file,
        sex_per_sector_file: str = default_sex_per_sector_per_superarea_file,
        config_file: str = default_config_file,
        policy_config_file: str = default_policy_config_file,
    ) -> "WorkerDistributor":
        """
        Parameters
        ----------
        geography
            an instance of the geography class

        Raises
        ------
        CompanyError
            if the geography contains no super areas
        """
        area_names = [super_area.name for super_area in geography.super_areas]
        if not area_names:
            # Imported locally: the module-level import is guarded by
            # TYPE_CHECKING, so the name is not available at runtime.
            from june.groups.company import CompanyError

            raise CompanyError("Empty geography!")
        return cls.for_super_areas(
            area_names,
            workflow_file,
            sex_per_sector_file,
            config_file,
            policy_config_file,
        )

    @classmethod
    def for_zone(
        cls,
        filter_key: Dict[str, list],
        areas_maps_path: str = default_areas_map_path,
        workflow_file: str = default_workflow_file,
        sex_per_sector_file: str = default_sex_per_sector_per_superarea_file,
        config_file: str = default_config_file,
        policy_config_file: str = default_policy_config_file,
    ) -> "WorkerDistributor":
        """
        Build a distributor for the super areas selected by ``filter_key``.

        Example
        -------
            filter_key = {"region" : "North East"}
            filter_key = {"super_area" : ["EXXXX", "EYYYY"]}
        """
        if len(filter_key) > 1:
            raise NotImplementedError("Only one type of area filtering is supported.")
        # Fixed: the original tested `"area" in len(filter_key.keys())`,
        # which raises TypeError (membership test against an int).
        if "area" in filter_key:
            raise NotImplementedError(
                "Company data only for the SuperArea (MSOA) and above."
            )
        geo_hierarchy = pd.read_csv(areas_maps_path)
        # Read the single (zone_type, zone_list) pair without mutating the
        # caller's dictionary (popitem removed it before).
        zone_type, zone_list = next(iter(filter_key.items()))
        area_names = geo_hierarchy[geo_hierarchy[zone_type].isin(zone_list)][
            "super_area"
        ]
        # `area_names` is a pandas Series: `if not area_names` would raise
        # "truth value of a Series is ambiguous" — use .empty instead.
        if area_names.empty:
            from june.groups.company import CompanyError

            raise CompanyError("Region returned empty area list.")
        return cls.for_super_areas(
            area_names,
            workflow_file,
            sex_per_sector_file,
            config_file,
            policy_config_file,
        )

    @classmethod
    def for_super_areas(
        cls,
        area_names: List[str],
        workflow_file: str = default_workflow_file,
        sex_per_sector_file: str = default_sex_per_sector_per_superarea_file,
        config_file: str = default_config_file,
        policy_config_file: str = default_policy_config_file,
    ) -> "WorkerDistributor":
        """Build a distributor restricted to the given super area names."""
        return cls.from_file(
            area_names,
            workflow_file,
            sex_per_sector_file,
            config_file,
            policy_config_file,
        )

    @classmethod
    def from_file(
        cls,
        area_names: Optional[List[str]] = None,
        workflow_file: str = default_workflow_file,
        sex_per_sector_file: str = default_sex_per_sector_per_superarea_file,
        config_file: str = default_config_file,
        policy_config_file: str = default_policy_config_file,
    ) -> "WorkerDistributor":
        """
        Parameters
        ----------
        area_names
            List of SuperArea names for which to initiate WorkerDistributor
            (empty/None means no area filtering)
        workflow_file
            Filename to data containing information about where man and woman
            go to work with respect to their SuperArea of residence.
        sex_per_sector_file
            Filename to data with workers per sector and sex per Area.
        config_file
            YAML file providing the remaining constructor parameters.
        policy_config_file
            YAML file providing the company closure proportions.
        """
        if area_names is None:
            area_names = []
        workflow_df = load_workflow_df(workflow_file, area_names)
        sex_per_sector_df = load_sex_per_sector(sex_per_sector_file, area_names)
        with open(config_file) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        with open(policy_config_file) as f:
            policy_config = yaml.load(f, Loader=yaml.FullLoader)
        return WorkerDistributor(
            workflow_df,
            sex_per_sector_df,
            policy_config["company_closure"]["sectors"],
            **config,
        )


def load_workflow_df(
    workflow_file: str = default_workflow_file, area_names: Optional[List[str]] = None
) -> pd.DataFrame:
    """
    Load the commuting work-flow table and turn counts into ratios.

    The result is indexed by (super_area, work_super_area), with "n_man" and
    "n_woman" holding, for every super area of residence, the fraction of
    each sex commuting to each work super area (each group sums to one).
    """
    wf_df = pd.read_csv(
        workflow_file,
        delimiter=",",
        skiprows=1,
        usecols=[0, 1, 3, 4],
        names=["super_area", "work_super_area", "n_man", "n_woman"],
    )
    if area_names:
        wf_df = wf_df[wf_df["super_area"].isin(area_names)]
    wf_df = wf_df.groupby(["super_area", "work_super_area"]).agg(
        {"n_man": "sum", "n_woman": "sum"}
    )
    # Normalise each sex column within its residence super area.
    for column in ("n_man", "n_woman"):
        group_totals = wf_df.groupby(level=0)[column].transform("sum")
        wf_df[column] = wf_df[column] / group_totals
    return wf_df


def load_sex_per_sector(
    sector_by_sex_file: str = default_sex_per_sector_per_superarea_file,
    area_names: Optional[List[str]] = None,
) -> pd.DataFrame:
    """
    Load the per-area counts of workers by sector and sex and convert them
    into per-area ratios.

    Aggregate columns ("m all", "f all", "m R S T U", "f R S T U" and any
    "all ..." column) are dropped; each remaining sex's sector columns are
    normalised so they sum to one per area.
    """
    sector_by_sex_df = pd.read_csv(sector_by_sex_file, index_col=0)
    # Per-sex sector columns; the totals are removed with .remove so that a
    # malformed file (missing aggregates) fails loudly.
    m_columns = [col for col in sector_by_sex_df.columns.values if "m " in col]
    m_columns.remove("m all")
    m_columns.remove("m R S T U")
    f_columns = [col for col in sector_by_sex_df.columns.values if "f " in col]
    f_columns.remove("f all")
    f_columns.remove("f R S T U")

    uni_columns = [col for col in sector_by_sex_df.columns.values if "all " in col]
    sector_by_sex_df = sector_by_sex_df.drop(
        uni_columns + ["m all", "m R S T U", "f all", "f R S T U"], axis=1
    )

    if area_names:
        # Resolve the requested super areas to their member areas.
        geo_hierarchy = pd.read_csv(default_areas_map_path)
        area_names = geo_hierarchy[geo_hierarchy["super_area"].isin(area_names)]["area"]
        sector_by_sex_df = sector_by_sex_df.loc[area_names]
        # Warn when a whole sector is absent from this geography.
        for code, label in (("Q", "Healthcare"), ("P", "Education")):
            male_total = np.sum(sector_by_sex_df[f"m {code}"])
            female_total = np.sum(sector_by_sex_df[f"f {code}"])
            if male_total == 0 and female_total == 0:
                logger.info(f"There exists no {label} sector in this geography.")

    # Convert counts to ratios; cast to float first to avoid dtype
    # incompatibility warnings when assigning the ratios back.
    sector_by_sex_df[m_columns] = sector_by_sex_df[m_columns].astype(float)
    sector_by_sex_df[f_columns] = sector_by_sex_df[f_columns].astype(float)
    for columns in (m_columns, f_columns):
        row_totals = sector_by_sex_df[columns].sum(axis=1)
        sector_by_sex_df.loc[:, columns] = sector_by_sex_df[columns].div(
            row_totals, axis=0
        )
    return sector_by_sex_df


from .worker_distributor import WorkerDistributor, load_sex_per_sector, load_workflow_df
from .care_home_distributor import CareHomeDistributor
from .company_distributor import CompanyDistributor
from .household_distributor import HouseholdDistributor
from .hospital_distributor import HospitalDistributor
from .school_distributor import SchoolDistributor
from .university_distributor import UniversityDistributor


from itertools import count

from june.hdf5_savers import generate_domain_from_hdf5


class Domain:
    """
    The idea is that the world is divided in domains, which are just collections of super areas with
    people living/working/doing leisure in them.

    If we think as domains as sets, then world is the union of all domains, and each domain can have
    a non-zero intersection with other domains (some people can work and live in different domains).

    Domains are sent to MPI core to perfom calculation, and communcation between the processes is
    required to transfer the infection status of people.
    """

    # Class-wide counter supplying sequential ids when none is given.
    _id = count()

    def __init__(self, id: int = None):
        """
        Parameters
        ----------
        id
            explicit domain id; when None, the next sequential id is used
        """
        # Fixed: the original unconditionally executed `self.id = id` after
        # the branch, overwriting the counter value with None.
        if id is None:
            self.id = next(self._id)
        else:
            self.id = id

    def __iter__(self):
        # Iterating a domain iterates its super areas (attribute set
        # externally, e.g. by generate_domain_from_hdf5).
        return iter(self.super_areas)

    @classmethod
    def from_hdf5(
        cls,
        domain_id,
        super_areas_to_domain_dict: dict,
        hdf5_file_path: str,
        interaction_config: str = None,
    ):
        """
        Load the domain with the given id from an HDF5 world file.

        Parameters
        ----------
        domain_id
            id of the domain to build (also assigned to the result)
        super_areas_to_domain_dict
            mapping of super area names to domain ids
        hdf5_file_path
            path of the HDF5 world file
        interaction_config
            optional interaction configuration forwarded to the loader
        """
        domain = generate_domain_from_hdf5(
            domain_id=domain_id,
            super_areas_to_domain_dict=super_areas_to_domain_dict,
            file_path=hdf5_file_path,
            interaction_config=interaction_config,
        )
        domain.id = domain_id
        return domain


import logging
import json
import pandas as pd
import numpy as np
from score_clustering import Point, ScoreClustering

from june import paths
from june.hdf5_savers import load_data_for_domain_decomposition

# Adjacency graph between super areas (JSON keyed by super area name).
default_super_area_adjaceny_graph_path = (
    paths.data_path / "input/geography/super_area_adjacency_graph.json"
)
# Centroid coordinates (X, Y columns) of every super area.
default_super_area_centroids_path = (
    paths.data_path / "input/geography/super_area_centroids.csv"
)


logger = logging.getLogger("domain")

# Relative importance of each quantity when scoring a super area for the
# domain balancing (see DomainSplitter.get_score).
default_weights = {"population": 5.0, "workers": 1.0, "commuters": 1.0}


class DomainSplitter:
    """
    Class used to split the world into ``n`` domains containing an equal number
    of super areas continuous to each other.
    """

    def __init__(
        self,
        number_of_domains: int,
        super_area_data: dict,
        super_area_centroids_path: str = default_super_area_centroids_path,
        super_area_adjacency_graph_path: str = default_super_area_adjaceny_graph_path,
        weights=default_weights,
    ):
        """
        Parameters
        ----------
        number_of_domains
            how many domains to split for
        super_area_data
            dictionary specifying the number of people, workers, pupils and commmuters
            per super area
        super_area_centroids_path
            CSV with the X/Y centroid of every super area (indexed by name)
        super_area_adjacency_graph_path
            JSON adjacency graph between super areas
        weights
            relative weight of population, workers and commuters in the
            balancing score (see get_score)
        """
        self.number_of_domains = number_of_domains
        with open(super_area_adjacency_graph_path, "r") as f:
            self.adjacency_graph = json.load(f)
        self.super_area_data = super_area_data
        self.super_area_df = pd.read_csv(super_area_centroids_path, index_col=0)
        # Restrict to (and order by) the super areas present in the data;
        # a plain list avoids indexing .loc with a dict view.
        self.super_area_df = self.super_area_df.loc[list(super_area_data.keys())]
        super_area_scores = [
            self.get_score(super_area, weights=weights)
            for super_area in self.super_area_df.index
        ]
        self.super_area_df.loc[:, "score"] = super_area_scores

    @classmethod
    def generate_world_split(
        cls,
        number_of_domains: int,
        world_path: str,
        weights=default_weights,
        super_area_centroids_path: str = default_super_area_centroids_path,
        super_area_adjacency_graph_path: str = default_super_area_adjaceny_graph_path,
        maxiter=100,
    ):
        """
        Split the world stored at ``world_path`` into domains.

        Returns the (super_areas_per_domain, score_per_domain) pair produced
        by generate_domain_split.
        """
        super_area_data = load_data_for_domain_decomposition(world_path)
        ds = cls(
            number_of_domains=number_of_domains,
            super_area_data=super_area_data,
            super_area_centroids_path=super_area_centroids_path,
            super_area_adjacency_graph_path=super_area_adjacency_graph_path,
            weights=weights,
        )
        return ds.generate_domain_split(maxiter=maxiter)

    def get_score(self, super_area, weights=default_weights):
        """Weighted load score of a super area (people, workers+pupils, commuters)."""
        data = self.super_area_data[super_area]
        return (
            weights["population"] * data["n_people"]
            + weights["workers"] * (data["n_workers"] + data["n_pupils"])
            + weights["commuters"] * data["n_commuters"]
        )

    def generate_domain_split(self, maxiter=100):
        """
        Run score clustering and return the split.

        Returns
        -------
        (super_areas_per_domain, score_per_domain)
            dicts keyed by domain index holding the member super area names
            and the cluster score, respectively.
        """
        points = list(
            self.super_area_df.apply(
                lambda row: Point(row["X"], row["Y"], row["score"], row.name), axis=1
            ).values
        )
        for point in points:
            point.neighbors = [
                points[i]
                for i in np.where(self.adjacency_graph[point.name])[0]
                if i < len(points)
            ]
        sc = ScoreClustering(n_clusters=self.number_of_domains)
        clusters = sc.fit(points, maxiter=maxiter)
        super_areas_per_domain = {}
        score_per_domain = {}
        for i, cluster in enumerate(clusters):
            super_areas_per_domain[i] = [point.name for point in cluster.points]
            score_per_domain[i] = cluster.score
        # Use the module logger instead of print, consistent with the rest
        # of this module.
        logger.info(f"Score is {sc.calculate_score_unbalance(clusters)}")
        return super_areas_per_domain, score_per_domain


from .domain import Domain
from .domain_decomposition import DomainSplitter


import numpy as np
from time import perf_counter
from time import time as wall_clock
import logging

from .infection import InfectionSelectors, ImmunitySetter
from june.demography import Activities
from june.policy import MedicalCarePolicies
from june.epidemiology.vaccines import VaccinationCampaigns
from june.mpi_setup import mpi_comm, mpi_size, mpi_rank, move_info
from june.groups import MedicalFacilities
from june.records import Record
from june.world import World
from june.time import Timer

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    from june.demography.person import Person
    from june.epidemiology.infection_seed.infection_seed import InfectionSeeds

logger = logging.getLogger("epidemiology")
mpi_logger = logging.getLogger("mpi")

# Only rank 0 propagates records upward, so parallel runs do not emit one
# copy of every log message per MPI process.
if mpi_rank > 0:
    logger.propagate = False


def _get_medical_facilities(world, activity_manager):
    """
    Return the world's super groups that are medical facilities.

    Super groups whose name mentions "visits" are skipped, as are
    attributes that are not set (None) on the world.
    """
    candidates = (
        getattr(world, name)
        for name in activity_manager.all_super_groups
        if "visits" not in name
    )
    return [
        group
        for group in candidates
        if group is not None and isinstance(group, MedicalFacilities)
    ]


class Epidemiology:
    """
    This class boxes all the functionality related to epidemics,
    namely the infections, infection seeds, infection selectors,
    and susceptibility setter.
    """

    def __init__(
        self,
        infection_selectors: Optional[InfectionSelectors] = None,
        infection_seeds: Optional["InfectionSeeds"] = None,
        immunity_setter: Optional[ImmunitySetter] = None,
        medical_care_policies: Optional[MedicalCarePolicies] = None,
        medical_facilities: Optional[MedicalFacilities] = None,
        vaccination_campaigns: Optional[VaccinationCampaigns] = None,
    ):
        """
        Parameters
        ----------
        infection_selectors
            selectors used to infect exposed people (see infect_people)
        infection_seeds
            daily seeded infections (see infection_seeds_timestep)
        immunity_setter
            sets the population's immunity (see set_immunity)
        medical_care_policies
            policies applied when symptoms change; usually taken from the
            activity manager in set_medical_care
        medical_facilities
            medical facility super groups; usually collected from the world
            in set_medical_care
        vaccination_campaigns
            vaccination campaigns, applied at most once per simulated day
        """
        self.infection_selectors = infection_selectors
        self.infection_seeds = infection_seeds
        self.immunity_setter = immunity_setter
        self.medical_care_policies = medical_care_policies
        self.medical_facilities = medical_facilities
        self.vaccination_campaigns = vaccination_campaigns
        # Last date on which vaccination campaigns ran; do_timestep uses it
        # to apply them at most once per day.
        self.current_date = None

    def set_immunity(self, world):
        """Apply the configured immunity setter to the world, if any."""
        setter = self.immunity_setter
        if not setter:
            return
        setter.set_immunity(world)

    def set_past_vaccinations(self, people, date, record=None):
        """
        Apply past vaccination campaigns to the given people up to ``date``.
        No-op when no campaigns are configured.
        """
        campaigns = self.vaccination_campaigns
        if campaigns is None:
            return
        campaigns.apply_past_campaigns(people=people, date=date, record=record)

    def set_effective_multipliers(self, population):
        """
        Apply the effective-multiplier setter to the population, if present.

        ``effective_multiplier_setter`` is never assigned in ``__init__``, so
        it is read defensively: a missing attribute behaves like ``None``
        (no-op) instead of raising AttributeError.
        """
        setter = getattr(self, "effective_multiplier_setter", None)
        if setter:
            setter.set_multipliers(population)

    def set_medical_care(self, world, activity_manager):
        """
        Collect the world's medical facilities and, when the activity
        manager has policies, cache its medical care policies.
        """
        self.medical_facilities = _get_medical_facilities(
            world=world, activity_manager=activity_manager
        )
        policies = activity_manager.policies
        if policies:
            self.medical_care_policies = policies.medical_care_policies

    def infection_seeds_timestep(self, timer, record: Record = None):
        """Unleash the day's seeded infections, if any seeds are configured."""
        seeds = self.infection_seeds
        if not seeds:
            return
        seeds.unleash_virus_per_day(date=timer.date, record=record, time=timer.now)

    def do_timestep(
        self,
        world: World,
        timer: Timer,
        record: Record = None,
        infected_ids: list = None,
        infection_ids: list = None,
        people_from_abroad_dict: dict = None,
    ):
        """
        Advance the epidemic by one time step: infect the exposed people,
        update everyone's health status, and (at most once per simulated
        day) apply vaccination campaigns.

        Parameters
        ----------
        world
            the simulation world of this domain
        timer
            provides current time, date and step duration
        record
            optional event record; when given, a time-step summary is
            written at the end
        infected_ids
            ids of the people exposed this step
        infection_ids
            infection id matching each entry of ``infected_ids``
        people_from_abroad_dict
            info on people visiting from other domains, used to route
            infections of non-resident people back to their home domain
        """
        # Vaccination runs at most once per simulated day: only when the
        # date changed since the last application.
        if self.vaccination_campaigns is not None and (
            self.current_date is None or timer.date.date() != self.current_date.date()
        ):
            self.current_date = timer.date
            vaccinate = True
        else:
            vaccinate = False

        # infect the people that got exposed
        if self.infection_selectors:
            infect_in_domains = self.infect_people(
                world=world,
                time=timer.now,
                infected_ids=infected_ids,
                infection_ids=infection_ids,
                people_from_abroad_dict=people_from_abroad_dict,
            )
            # Forward infections of people belonging to other domains.
            self.tell_domains_to_infect(
                world=world, timer=timer, infect_in_domains=infect_in_domains
            )

        # update the health status of the population
        self.update_health_status(
            world=world,
            time=timer.now,
            date=timer.date,
            duration=timer.duration,
            record=record,
            vaccinate=vaccinate,
        )
        if record:
            record.summarise_time_step(timestamp=timer.date, world=world)
            record.time_step(timestamp=timer.date)

    @staticmethod
    def bury_the_dead(world: World, person: "Person", record: Record = None):
        """
        When someone dies, send them to cemetery.
        ZOMBIE ALERT!!

        The person is flagged dead, their infection is cleared, they are
        removed from their household's residents (if they lived in one) and
        from all activity subgroups, and the death is optionally recorded.

        Parameters
        ----------
        world
            world providing the cemeteries
        person:
            person to send to cemetery
        record:
            optional record; the death location is the medical facility the
            person was in, or their residence otherwise
        """
        if record is not None:
            if person.medical_facility is not None:
                death_location = person.medical_facility.group
            else:
                death_location = person.residence.group
            record.accumulate(
                table_name="deaths",
                location_spec=death_location.spec,
                location_id=death_location.id,
                dead_person_id=person.id,
            )
        person.dead = True
        person.infection = None
        cemetery = world.cemeteries.get_nearest(person)
        cemetery.add(person)
        # Remove the dead person from their household's residents tuple.
        if person.residence.group.spec == "household":
            household = person.residence.group
            person.residence.residents = tuple(
                mate for mate in household.residents if mate != person
            )
        # Clear all activity subgroups so the person no longer participates.
        person.subgroups = Activities(None, None, None, None, None, None)

    @staticmethod
    def recover(person: "Person", record: Record = None):
        """
        When someone recovers, erase the health information they carry.

        Parameters
        ----------
        person:
            person to recover
        record:
            optional record to accumulate the recovery event in
        """
        infection = person.infection
        if record:
            record.accumulate(
                table_name="recoveries",
                recovered_person_id=person.id,
                infection_id=infection.infection_id(),
            )
        person.infection = None

    def update_health_status(
        self,
        world: World,
        time: float,
        duration: float,
        date=None,
        record: Record = None,
        vaccinate: bool = False,
    ):
        """
        Update symptoms and health status of infected people.
        Send them to hospital if necessary, or bury them if they
        have died.

        Parameters
        ----------
        world
            world whose population is updated
        time:
            time now
        duration:
            duration of time step
        date
            current date, forwarded to vaccination campaigns
        record
            optional record for symptom/recovery/death events
        vaccinate
            when True, vaccination campaigns are applied to every living
            person (set by do_timestep at most once per day)
        """
        for person in world.people:
            if person.infected:
                previous_tag = person.infection.tag
                new_status = person.infection.update_health_status(time, duration)
                if record is not None:
                    # Only record a symptoms row when the tag changed.
                    if previous_tag != person.infection.tag:
                        record.accumulate(
                            table_name="symptoms",
                            infected_id=person.id,
                            symptoms=person.infection.tag.value,
                            infection_id=person.infection.infection_id(),
                        )
                # Take actions on new symptoms
                if self.medical_care_policies:
                    self.medical_care_policies.apply(
                        person=person,
                        medical_facilities=self.medical_facilities,
                        days_from_start=time,
                        record=record,
                    )
                if new_status == "recovered":
                    self.recover(person, record=record)
                elif new_status == "dead":
                    self.bury_the_dead(world, person, record=record)
            # People who died (this step or earlier) are not vaccinated.
            if person.dead:
                continue
            if vaccinate:
                self.vaccination_campaigns.apply(
                    person=person, date=date, record=record
                )
                if person.vaccine_trajectory is not None:
                    person.vaccine_trajectory.update_vaccine_effect(
                        person=person, date=date, record=record
                    )

    def infect_people(
        self, world, time, infected_ids, infection_ids, people_from_abroad_dict
    ):
        """
        Given a list of infected ids, it initialises an infection object for them
        and sets it to person.infection. For the people who do not live in this domain
        a dictionary with their ids and domains is prepared to be sent through MPI.

        Parameters
        ----------
        world:
            world holding this domain's population
        time:
            time at which the infections start
        infected_ids:
            ids of the people to infect (local or foreign)
        infection_ids:
            infection id for each entry of ``infected_ids`` (parallel lists)
        people_from_abroad_dict:
            nested dict spec -> group -> subgroup -> person_id -> data,
            where data["dom"] is the domain the person belongs to

        Returns
        -------
        infect_in_domains:
            dict domain -> {"id": [...], "inf_id": [...]} for foreign people
        """
        # Map foreign person id -> infection id directly, so the pairing does
        # not depend on the iteration order of people_from_abroad_dict.
        # (Previously a running counter indexed foreign_infection_ids in
        # abroad-dict order, mismatching ids whenever the orders differed.)
        foreign_infection_by_id = {}
        for person_id, infection_id in zip(infected_ids, infection_ids):
            if person_id in world.people.people_ids:
                person = world.people.get_from_id(person_id)
                self.infection_selectors.infect_person_at_time(
                    person=person, time=time, infection_id=infection_id
                )
            else:
                foreign_infection_by_id[person_id] = infection_id

        infect_in_domains = {}
        if foreign_infection_by_id:
            for spec in people_from_abroad_dict:
                for group in people_from_abroad_dict[spec]:
                    for subgroup in people_from_abroad_dict[spec][group]:
                        subgroup_data = people_from_abroad_dict[spec][group][subgroup]
                        for person_id, person_data in subgroup_data.items():
                            if person_id not in foreign_infection_by_id:
                                continue
                            domain = person_data["dom"]
                            if domain not in infect_in_domains:
                                infect_in_domains[domain] = {"id": [], "inf_id": []}
                            infect_in_domains[domain]["id"].append(person_id)
                            infect_in_domains[domain]["inf_id"].append(
                                foreign_infection_by_id[person_id]
                            )
        return infect_in_domains

    def tell_domains_to_infect(self, world, timer, infect_in_domains):
        """
        Sends information about the people who got infected in this domain to the other domains.

        Parameters
        ----------
        world:
            world holding this domain's population
        timer:
            simulation timer; ``timer.now`` is used as the infection time and
            ``timer.date`` for logging
        infect_in_domains:
            mapping domain -> {"id": [...], "inf_id": [...]} as produced by
            ``infect_people``
        """
        # Synchronise all ranks before the collective exchange.
        mpi_comm.Barrier()
        tick, tickw = perf_counter(), wall_clock()

        invalid_id = 4294967295  # largest possible uint32

        # Sentinel array sent to domains with no infections, so every rank
        # transfers at least one element.
        empty = np.array([invalid_id], dtype=np.uint32)

        # we want to make sure we transfer something for every domain.
        # (we have an np.concatenate which doesn't work on empty arrays)

        people_ids = [empty for x in range(mpi_size)]
        infection_ids = [empty for x in range(mpi_size)]

        # FIXME: domain id should not be floats! Origin is well upstream!
        for x in infect_in_domains:
            people_ids[int(x)] = np.array(infect_in_domains[x]["id"], dtype=np.uint32)
            infection_ids[int(x)] = np.array(
                infect_in_domains[x]["inf_id"], dtype=np.uint32
            )

        # Two collective exchanges: person ids and their infection ids.
        # NOTE(review): n_sending/n_receiving from the first call are
        # overwritten by the second, so the log below reports only the
        # infection-id transfer counts — confirm this is intended.
        people_to_infect, n_sending, n_receiving = move_info(people_ids)
        infection_to_infect, n_sending, n_receiving = move_info(infection_ids)

        tock, tockw = perf_counter(), wall_clock()
        logger.info(
            f"CMS: Infection COMS-v2 for rank {mpi_rank}/{mpi_size}({n_sending+n_receiving})"
            f"{tock-tick},{tockw-tickw} - {timer.date}"
        )
        mpi_logger.info(f"{timer.date},{mpi_rank},infection,{tock-tick}")

        # Infect the people received by this domain, skipping sentinel padding.
        for person_id, infection_id in zip(people_to_infect, infection_to_infect):
            try:
                person = world.people.get_from_id(person_id)
                self.infection_selectors.infect_person_at_time(
                    person=person, time=timer.now, infection_id=infection_id
                )
            except Exception:
                # The sentinel id is expected to fail the lookup; anything
                # else is a genuine error and is re-raised.
                if person_id == invalid_id:
                    continue
                raise


# from .epidemiology import Epidemiology
from .vaccines import Vaccine, Vaccines, VaccinationCampaigns


class Immunity:
    """
    The person's "medical record": which infections they are effectively
    immune to (zero susceptibility) and the symptom multipliers they carry.
    Infections with no recorded entry default to full susceptibility (1.0)
    and a neutral multiplier (1.0).
    """

    __slots__ = "susceptibility_dict", "effective_multiplier_dict"

    def __init__(
        self, susceptibility_dict: dict = None, effective_multiplier_dict: dict = None
    ):
        # A falsy argument (None or an empty dict) starts an empty record.
        self.susceptibility_dict = susceptibility_dict or {}
        self.effective_multiplier_dict = effective_multiplier_dict or {}

    def add_immunity(self, infection_ids):
        """Record full immunity (zero susceptibility) to every given infection."""
        for inf_id in infection_ids:
            self.susceptibility_dict[inf_id] = 0.0

    def add_multiplier(self, infection_id, multiplier):
        """Record the effective symptom multiplier for one infection."""
        self.effective_multiplier_dict[infection_id] = multiplier

    def get_susceptibility(self, infection_id):
        """Susceptibility to the infection; 1.0 when nothing is recorded."""
        return self.susceptibility_dict.get(infection_id, 1.0)

    def get_effective_multiplier(self, infection_id):
        """Symptom multiplier for the infection; 1.0 when nothing is recorded."""
        return self.effective_multiplier_dict.get(infection_id, 1.0)

    def serialize(self):
        """Return the susceptibility record as parallel (ids, values) lists."""
        ids, values = [], []
        for inf_id, susceptibility in self.susceptibility_dict.items():
            ids.append(inf_id)
            values.append(susceptibility)
        return ids, values

    def is_immune(self, infection_id):
        """True when the recorded susceptibility to the infection is exactly 0."""
        return self.get_susceptibility(infection_id) == 0.0


from typing import Optional
from collections import Counter
import numpy as np
import yaml
from random import random

from june.utils import (
    parse_age_probabilities,
    parse_prevalence_comorbidities_in_reference_population,
    read_comorbidity_csv,
    convert_comorbidities_prevalence_to_dict,
)

from . import Covid19, B117, B16172

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from june.records.records_writer import Record


# Default age-banded susceptibilities per infection: under-13s are assumed
# half as susceptible as everyone else, for all three variants.
default_susceptibility_dict = {
    Covid19.infection_id(): {"0-13": 0.5, "13-100": 1.0},
    B117.infection_id(): {"0-13": 0.5, "13-100": 1.0},
    B16172.infection_id(): {"0-13": 0.5, "13-100": 1.0},
}
# Default effective (symptom) multipliers per infection: B117 and B16172 are
# assigned 1.5x the severity of the original strain.
default_multiplier_dict = {
    Covid19.infection_id(): 1.0,
    B117.infection_id(): 1.5,
    B16172.infection_id(): 1.5,
}


class ImmunitySetter:
    """
    Sets the initial immunity state of a population: age-dependent
    susceptibility per infection, symptom multipliers (optionally corrected
    for comorbidities), previous infections (seroprevalence) and starting
    vaccinations.
    """

    def __init__(
        self,
        susceptibility_dict: dict = default_susceptibility_dict,
        multiplier_dict: dict = default_multiplier_dict,
        vaccination_dict: dict = None,
        previous_infections_dict=None,
        multiplier_by_comorbidity: Optional[dict] = None,
        comorbidity_prevalence_reference_population: Optional[dict] = None,
        susceptibility_mode="average",
        previous_infections_distribution="uniform",
        record: "Record" = None,
    ):
        """
        Sets immunity parameters to different viruses.

        Parameters
        ----------
        susceptibility_dict:
           A dictionary mapping infection_id -> susceptibility by age.
           Example:
            susceptibility_dict = {"123" : {"0-50" : 0.5, "50-100" : 0.2}}
        multiplier_dict:
           A dictionary mapping infection_id -> symptoms reduction by age.
           Example:
            multiplier_dict = {"123" : {"0-50" : 0.5, "50-100" : 0.2}}
        vaccination_dict:
            A dictionary specifying the starting vaccination status of the population.
            Example:
                vaccination_dict = {
                    "pfizer": {
                        "percentage_vaccinated": {"0-50": 0.7, "50-100": 1.0},
                        "infections": {
                            Covid19.infection_id(): {
                                "sterilisation_efficacy": {"0-100": 0.5},
                                "symptomatic_efficacy": {"0-100": 0.5},
                            },
                        },
                    },
                    "sputnik": {
                        "percentage_vaccinated": {"0-30": 0.3, "30-100": 0.0},
                        "infections": {
                            B117.infection_id(): {
                                "sterilisation_efficacy": {"0-100": 0.8},
                                "symptomatic_efficacy": {"0-100": 0.8},
                            },
                        },
                    },
                }
        previous_infections_dict:
            A dictionary specifying the current seroprevalence per region and age.
            Example:
                previous_infections_dict = {
                    "infections": {
                        Covid19.infection_id(): {
                            "sterilisation_efficacy": 0.5,
                            "symptomatic_efficacy": 0.6,
                        },
                        B117.infection_id(): {
                            "sterilisation_efficacy": 0.2,
                            "symptomatic_efficacy": 0.3,
                        },
                    },
                    "ratios": {
                        "London": {"0-50": 0.5, "50-100": 0.2},
                        "North East": {"0-70": 0.3, "70-100": 0.8},
                    },
                }
        multiplier_by_comorbidity:
            mapping comorbidity name -> symptom multiplier
        comorbidity_prevalence_reference_population:
            prevalence of each comorbidity by sex and age in the reference
            population (parsed via the june.utils helper)
        susceptibility_mode:
            "average": scale everyone's susceptibility by the age fraction;
            "individual": make that fraction of each age group fully immune.
        previous_infections_distribution:
            "uniform" (per person) or "clustered" (whole households)
        record:
            optional Record in which starting vaccinations are logged
        """
        self.susceptibility_dict = self._read_susceptibility_dict(susceptibility_dict)
        if multiplier_dict is None:
            self.multiplier_dict = {}
        else:
            self.multiplier_dict = multiplier_dict
        self.vaccination_dict = self._read_vaccination_dict(vaccination_dict)
        self.previous_infections_dict = self._read_previous_infections_dict(
            previous_infections_dict
        )
        self.multiplier_by_comorbidity = multiplier_by_comorbidity
        if comorbidity_prevalence_reference_population is not None:
            self.comorbidity_prevalence_reference_population = (
                parse_prevalence_comorbidities_in_reference_population(
                    comorbidity_prevalence_reference_population
                )
            )
        else:
            self.comorbidity_prevalence_reference_population = None
        self.susceptibility_mode = susceptibility_mode
        self.previous_infections_distribution = previous_infections_distribution
        self.record = record

    @classmethod
    def from_file_with_comorbidities(
        cls,
        susceptibility_dict: dict = default_susceptibility_dict,
        multiplier_dict: dict = default_multiplier_dict,
        vaccination_dict: dict = None,
        previous_infections_dict: dict = None,
        comorbidity_multipliers_path: Optional[str] = None,
        male_comorbidity_reference_prevalence_path: Optional[str] = None,
        female_comorbidity_reference_prevalence_path: Optional[str] = None,
        susceptibility_mode="average",
        record: "Record" = None,
    ) -> "ImmunitySetter":
        """
        Build an ImmunitySetter reading comorbidity multipliers and reference
        prevalences from files; when no multipliers path is given, the
        comorbidity correction is disabled.
        """
        if comorbidity_multipliers_path is not None:
            with open(comorbidity_multipliers_path) as f:
                comorbidity_multipliers = yaml.load(f, Loader=yaml.FullLoader)
            female_prevalence = read_comorbidity_csv(
                female_comorbidity_reference_prevalence_path
            )
            male_prevalence = read_comorbidity_csv(
                male_comorbidity_reference_prevalence_path
            )
            comorbidity_prevalence_reference_population = (
                convert_comorbidities_prevalence_to_dict(
                    female_prevalence, male_prevalence
                )
            )
        else:
            comorbidity_multipliers = None
            comorbidity_prevalence_reference_population = None
        return ImmunitySetter(
            susceptibility_dict=susceptibility_dict,
            multiplier_dict=multiplier_dict,
            vaccination_dict=vaccination_dict,
            previous_infections_dict=previous_infections_dict,
            multiplier_by_comorbidity=comorbidity_multipliers,
            comorbidity_prevalence_reference_population=comorbidity_prevalence_reference_population,
            susceptibility_mode=susceptibility_mode,
            record=record,
        )

    def set_immunity(self, world):
        """Apply all configured immunity effects to the world's population."""
        if self.multiplier_dict:
            self.set_multipliers(world.people)
        if self.susceptibility_dict:
            self.set_susceptibilities(world.people)
        if self.previous_infections_dict:
            self.set_previous_infections(world)
        if self.vaccination_dict:
            self.set_vaccinations(world.people)

    def get_multiplier_from_reference_prevalence(self, age, sex):
        """
        Compute mean comorbidity multiplier given the prevalence of the different comorbidities
        in the reference population (for example the UK). It will be used to remove effect of
        comorbidities in the reference population
        Parameters
        ----------
        age:
            age group to compute average multiplier
        sex:
            sex group to compute average multiplier
        Returns
        -------
            weighted_multiplier:
                weighted mean of the multipliers given prevalence
        """
        weighted_multiplier = 0.0
        for comorbidity in self.comorbidity_prevalence_reference_population.keys():
            weighted_multiplier += (
                self.multiplier_by_comorbidity[comorbidity]
                * self.comorbidity_prevalence_reference_population[comorbidity][sex][
                    age
                ]
            )
        return weighted_multiplier

    def get_weighted_multipliers_by_age_sex(
        self,
    ):
        """Reference-population mean multiplier for every (sex, age 0-99)."""
        reference_multipliers = {"m": [], "f": []}
        for sex in ("m", "f"):
            for age in range(100):
                reference_multipliers[sex].append(
                    self.get_multiplier_from_reference_prevalence(age=age, sex=sex)
                )
        return reference_multipliers

    def set_multipliers(self, population):
        """
        Set each person's effective symptom multiplier per infection,
        optionally shifted by their comorbidity relative to the reference
        population average.
        """
        if (
            self.multiplier_by_comorbidity is not None
            and self.comorbidity_prevalence_reference_population is not None
        ):
            set_comorbidity_multipliers = True
            reference_weighted_multipliers = self.get_weighted_multipliers_by_age_sex()
        else:
            set_comorbidity_multipliers = False
        for person in population:
            for inf_id in self.multiplier_dict:
                person.immunity.effective_multiplier_dict[
                    inf_id
                ] = self.multiplier_dict[inf_id]
                if set_comorbidity_multipliers:
                    multiplier = self.multiplier_by_comorbidity.get(
                        person.comorbidity, 1.0
                    )
                    reference_multiplier = reference_weighted_multipliers[person.sex][
                        person.age
                    ]
                    # Additive correction: ratio of the person's comorbidity
                    # multiplier to the age/sex reference average, minus 1.
                    person.immunity.effective_multiplier_dict[inf_id] += (
                        multiplier / reference_multiplier
                    ) - 1.0

    def _read_susceptibility_dict(self, susceptibility_dict):
        """Parse age-binned susceptibilities into per-age arrays (fill 1.0)."""
        if susceptibility_dict is None:
            return {}
        ret = {}
        for inf_id in susceptibility_dict:
            ret[inf_id] = parse_age_probabilities(
                susceptibility_dict[inf_id], fill_value=1.0
            )
        return ret

    def _read_vaccination_dict(self, vaccination_dict):
        """Parse all age-binned vaccination rates and efficacies into per-age arrays."""
        if vaccination_dict is None:
            return {}
        ret = {}
        for vaccine, vdata in vaccination_dict.items():
            ret[vaccine] = {}
            ret[vaccine]["percentage_vaccinated"] = parse_age_probabilities(
                vdata["percentage_vaccinated"]
            )
            ret[vaccine]["infections"] = {}
            for inf_id in vdata["infections"]:
                ret[vaccine]["infections"][inf_id] = {}
                for key in vdata["infections"][inf_id]:
                    ret[vaccine]["infections"][inf_id][key] = parse_age_probabilities(
                        vdata["infections"][inf_id][key], fill_value=0.0
                    )
        return ret

    def _read_previous_infections_dict(self, previous_infections_dict):
        """Parse region seroprevalence ratios into per-age arrays; efficacies kept as-is."""
        if previous_infections_dict is None:
            return {}
        ret = {}
        ret["infections"] = previous_infections_dict["infections"]
        ret["ratios"] = {}
        for region, region_ratios in previous_infections_dict["ratios"].items():
            ret["ratios"][region] = parse_age_probabilities(region_ratios)
        return ret

    def set_susceptibilities(self, population):
        """Dispatch to the configured susceptibility mode."""
        if self.susceptibility_mode == "average":
            self._set_susceptibilities_avg(population)
        elif self.susceptibility_mode == "individual":
            self._set_susceptibilities_individual(population)
        else:
            raise NotImplementedError()

    def _set_susceptibilities_avg(self, population):
        # Everyone gets the (fractional) susceptibility of their age bin.
        for person in population:
            for inf_id in self.susceptibility_dict:
                # Ages beyond the parsed array keep the default of 1.0.
                if person.age >= len(self.susceptibility_dict[inf_id]):
                    continue
                person.immunity.susceptibility_dict[inf_id] = self.susceptibility_dict[
                    inf_id
                ][person.age]

    def _set_susceptibilities_individual(self, population):
        # A random (1 - fraction) share of each age bin becomes fully immune;
        # the rest keep full susceptibility.
        for person in population:
            for inf_id in self.susceptibility_dict:
                if person.age >= len(self.susceptibility_dict[inf_id]):
                    continue
                fraction = self.susceptibility_dict[inf_id][person.age]
                if random() > fraction:
                    person.immunity.susceptibility_dict[inf_id] = 0.0

    def set_vaccinations(self, population):
        """
        Sets previous vaccination on the starting population.
        """
        # NOTE(review): "susccesfully" is misspelled, but it is also the
        # record column name; renaming would change the output schema.
        vaccine_type = []
        susccesfully_vaccinated = np.zeros(len(population), dtype=int)
        if not self.vaccination_dict:
            return
        vaccines = list(self.vaccination_dict.keys())
        for i, person in enumerate(population):
            # Clamp to the last age bin of the parsed probability arrays.
            if person.age > 99:
                age = 99
            else:
                age = person.age
            vaccination_rates = np.array(
                [
                    self.vaccination_dict[vaccine]["percentage_vaccinated"][age]
                    for vaccine in vaccines
                ]
            )
            total_vacc_rate = np.sum(vaccination_rates)
            # Vaccinate with the combined probability, then pick which vaccine
            # in proportion to the individual rates.
            if random() < total_vacc_rate:
                vaccination_rates /= total_vacc_rate
                vaccine = np.random.choice(vaccines, p=vaccination_rates)
                vdata = self.vaccination_dict[vaccine]
                for inf_id, inf_data in vdata["infections"].items():
                    person.immunity.add_multiplier(
                        inf_id, 1.0 - inf_data["symptomatic_efficacy"][age]
                    )
                    person.immunity.susceptibility_dict[inf_id] = (
                        1.0 - inf_data["sterilisation_efficacy"][age]
                    )
                    susccesfully_vaccinated[i] = 1
                person.vaccinated = True
                vaccine_type.append(vaccine)
            else:
                vaccine_type.append("none")
        if self.record is not None:
            self.record.statics["people"].extra_str_data["vaccine_type"] = vaccine_type
            self.record.statics["people"].extra_int_data[
                "susccesfully_vaccinated"
            ] = susccesfully_vaccinated

    def set_previous_infections(self, world):
        """Dispatch to the configured previous-infections distribution."""
        if self.previous_infections_distribution == "uniform":
            self.set_previous_infections_uniform(world.people)
        elif self.previous_infections_distribution == "clustered":
            self.set_previous_infections_clustered(world)
        else:
            raise ValueError(
                f"Previous infection distr. {self.previous_infections_distribution} not recognized"
            )

    def set_previous_infections_uniform(self, population):
        """
        Sets previous infections on the starting population in a uniform way.
        """
        for i, person in enumerate(population):
            if person.region.name not in self.previous_infections_dict["ratios"]:
                continue
            # NOTE(review): unlike set_vaccinations this does not clamp
            # person.age to 99 — confirm the parsed ratio arrays cover all ages.
            ratio = self.previous_infections_dict["ratios"][person.region.name][
                person.age
            ]
            if random() < ratio:
                for inf_id, inf_data in self.previous_infections_dict[
                    "infections"
                ].items():
                    person.immunity.add_multiplier(
                        inf_id, 1.0 - inf_data["symptomatic_efficacy"]
                    )
                    person.immunity.susceptibility_dict[inf_id] = (
                        1.0 - inf_data["sterilisation_efficacy"]
                    )

    def _get_people_to_infect_by_age(self, people, seroprev_by_age):
        """
        Returns total people to infect according to the serorev age profile
        """
        people_by_age = Counter([person.age for person in people])
        people_to_infect = {
            age: people_by_age[age] * seroprev_by_age[age] for age in people_by_age
        }
        return people_to_infect

    def _get_household_score(self, household, age_distribution):
        # Sampling weight for a household: sum of the target age-distribution
        # mass of its residents, normalised by sqrt of household size.
        if len(household.residents) == 0:
            return 0
        ret = 0
        for resident in household.residents:
            ret += age_distribution[resident.age]
        return ret / np.sqrt(len(household.residents))

    def set_previous_infections_clustered(self, world):
        """
        Sets previous infections on the starting population by clustering households.
        """
        infection_ids = list(self.previous_infections_dict["infections"].keys())
        infection_data = list(self.previous_infections_dict["infections"].values())
        for region in world.regions:
            seroprev_by_age = self.previous_infections_dict["ratios"][region.name]
            people = region.people
            to_infect_by_age = self._get_people_to_infect_by_age(
                people=people, seroprev_by_age=seroprev_by_age
            )
            total_to_infect = sum(to_infect_by_age.values())
            age_distribution = {
                age: to_infect_by_age[age] / total_to_infect for age in to_infect_by_age
            }
            households = np.array(region.households)
            scores = [
                self._get_household_score(h, age_distribution) for h in households
            ]
            # Cumulative scores allow weighted sampling via searchsorted.
            cum_scores = np.cumsum(scores)
            prev_inf_households = set()
            # NOTE(review): if every household in the region enters
            # prev_inf_households before total_to_infect reaches 0, this loop
            # never terminates — TODO confirm the input ratios make this safe.
            while total_to_infect > 0:
                num = random() * cum_scores[-1]
                idx = np.searchsorted(cum_scores, num)
                household = households[idx]
                if household.id in prev_inf_households:
                    continue
                for person in household.residents:
                    for inf_id, inf_data in zip(infection_ids, infection_data):
                        target_symptom_mult = 1.0 - inf_data["symptomatic_efficacy"]
                        target_susceptibility = 1.0 - inf_data["sterilisation_efficacy"]
                        current_multiplier = person.immunity.get_effective_multiplier(
                            inf_id
                        )
                        current_susc = person.immunity.get_susceptibility(inf_id)
                        # Keep the strongest protection seen so far.
                        person.immunity.effective_multiplier_dict[inf_id] = min(
                            current_multiplier, target_symptom_mult
                        )
                        person.immunity.susceptibility_dict[inf_id] = min(
                            current_susc, target_susceptibility
                        )
                    total_to_infect -= 1
                    if total_to_infect < 1:
                        return
                    # NOTE(review): adding inside the residents loop is
                    # redundant (same id every iteration) but harmless.
                    prev_inf_households.add(household.id)


from zlib import adler32

from .symptoms import Symptoms, SymptomTag

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from june.epidemiology.infection.transmission import Transmission


class Infection:
    """
    The infection class combines the transmission (infectiousness profile) of the infected
    person, and their symptoms trajectory. We also keep track of how many people someone has
    infected, which is useful to compute R0. The infection probability is updated at every
    time step, according to an infectivity profile.
    """

    __slots__ = ("start_time", "transmission", "symptoms", "time_of_testing")
    # Per-subclass cache for the unique infection id (computed lazily).
    _infection_id = None

    def __init__(
        self, transmission: "Transmission", symptoms: "Symptoms", start_time: float = -1
    ):
        """
        Parameters
        ----------
        transmission:
            instance of the class that controls the infectiousness profile
        symptoms:
            instance of the class that controls the symptoms' evolution
        start_time:
            time at which the person is infected
        """
        self.start_time = start_time
        self.transmission = transmission
        self.symptoms = symptoms
        self.time_of_testing = None

    @classmethod  # this could be a property but it is complicated (needs meta classes)
    def infection_id(cls):
        """Unique, stable id for this infection class (adler32 of the class name)."""
        # this creates a unique id for each inherited class
        if not cls._infection_id:
            cls._infection_id = adler32(cls.__name__.encode("ascii"))
        return cls._infection_id

    @classmethod
    def immunity_ids(cls):
        """
        Ids of the infections that upon recovery this infection gives immunity to.
        """
        return (cls.infection_id(),)

    def update_health_status(self, time, delta_time):
        """
        Updates the infection probability and symptoms of the person's infection
        given the simulation time. Returns the new status of the person.

        Parameters:
        -----------
        time: float
            total time since the beginning of the simulation (in days)
        delta_time: float
            duration of the time step.

        Returns:
        --------
        status: str
            new status of the person. one of ``['recovered', 'dead', 'infected']``
        """
        self.update_symptoms_and_transmission(time + delta_time)
        if self.symptoms.recovered:
            status = "recovered"
        elif self.symptoms.dead:
            status = "dead"
        else:
            status = "infected"
        return status

    def update_symptoms_and_transmission(self, time: float):
        """
        Updates the infection's symptoms and transmission probability.
        Parameters
        ----------
        time:
            time elapsed (in days) from time of infection
        """
        time_from_infection = time - self.start_time
        self.transmission.update_infection_probability(
            time_from_infection=time_from_infection
        )
        self.symptoms.update_trajectory_stage(time_from_infection=time_from_infection)

    def length_of_infection(self, time):
        """Days elapsed since the infection started."""
        return time - self.time_of_infection

    @property
    def tag(self):
        """Current symptom tag."""
        return self.symptoms.tag

    @property
    def max_tag(self):
        """Most severe symptom tag reached so far."""
        return self.symptoms.max_tag

    @property
    def time_of_infection(self):
        return self.start_time

    @property
    def should_be_in_hospital(self) -> bool:
        return self.tag in (SymptomTag.hospitalised, SymptomTag.intensive_care)

    @property
    def infected_at_home(self) -> bool:
        """Whether the infected person should stay at home (alive, not hospitalised)."""
        # Fix: previously this read ``self.infected``, which is an attribute
        # of Person — not of Infection — and always raised AttributeError.
        # An Infection instance only exists while its carrier is infected,
        # so "infected" is implicit here.
        return not (self.dead or self.should_be_in_hospital)

    @property
    def dead(self) -> bool:
        return self.symptoms.dead

    @property
    def time_of_symptoms_onset(self):
        return self.symptoms.time_of_symptoms_onset

    @property
    def infection_probability(self):
        return self.transmission.probability


class Covid19(Infection):
    """Base Covid19 strain; recovery also immunises against the B117 variant."""

    @classmethod
    def immunity_ids(cls):
        """Ids of the infections recovery from this one protects against."""
        return tuple(variant.infection_id() for variant in (cls, B117))


class B117(Infection):
    """B117 variant; recovery also immunises against the base Covid19 strain."""

    @classmethod
    def immunity_ids(cls):
        """Ids of the infections recovery from this one protects against."""
        return tuple(variant.infection_id() for variant in (cls, Covid19))


class B16172(Infection):
    """
    B16172 variant (presumably the B.1.617.2 lineage — note the separate
    Delta class carries an identical immunity set; verify which one is used).
    """

    @classmethod
    def immunity_ids(cls):
        """Ids of the infections recovery from this one protects against."""
        protected = (cls, Covid19, B117, Omicron)
        return tuple(variant.infection_id() for variant in protected)


class Delta(Infection):
    """Delta variant; recovery also protects against Covid19, B117 and Omicron."""

    @classmethod
    def immunity_ids(cls):
        """Ids of the infections recovery from this one protects against."""
        protected = (cls, Covid19, B117, Omicron)
        return tuple(variant.infection_id() for variant in protected)


class Omicron(Infection):
    """Omicron variant; recovery also protects against Covid19, B117 and Delta."""

    @classmethod
    def immunity_ids(cls):
        """Ids of the infections recovery from this one protects against."""
        protected = (cls, Covid19, B117, Delta)
        return tuple(variant.infection_id() for variant in protected)


import yaml

from june import paths
from .health_index.health_index import HealthIndexGenerator
from . import Infection, Covid19
from .symptoms import Symptoms
from .trajectory_maker import TrajectoryMakers
from .transmission import TransmissionConstant, TransmissionGamma
from .transmission_xnexp import TransmissionXNExp
from .trajectory_maker import CompletionTime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from june.demography import Person
    from .transmission import Transmission
    from june.epidemiology.infection.symptom_tag import SymptomTag

# Default config locations under the package's configs_path.
default_transmission_config_path = (
    paths.configs_path / "defaults/epidemiology/infection/transmission/covid19.yaml"
)
default_trajectories_config_path = (
    paths.configs_path / "defaults/epidemiology/infection/symptoms/trajectories.yaml"
)
# Table of infection outcome rates used to build the HealthIndexGenerator.
default_rates_file = paths.data_path / "input/health_index/infection_outcome_rates.csv"


class InfectionSelector:
    """
    Selects the type of infection a person is given, generating the symptoms
    trajectory and the transmission (infectiousness) profile for each new
    infection.
    """

    def __init__(
        self,
        transmission_config_path: str = default_transmission_config_path,
        infection_class: Infection = Covid19,
        trajectory_maker=None,
        health_index_generator: HealthIndexGenerator = None,
    ):
        """
        Selects the type of infection a person is given

        Parameters
        ----------
        transmission_config_path:
            path to transmission config file
        infection_class:
            class of infection to create (e.g. ``Covid19``)
        trajectory_maker:
            maker of symptom trajectories; if None, the default trajectories
            config is loaded
        health_index_generator:
            generator of per-person outcome probabilities
        """
        self.infection_class = infection_class
        self.transmission_config_path = transmission_config_path
        # Load the default lazily. The previous default argument
        # ``TrajectoryMakers.from_file(...)`` was evaluated once at class
        # definition time, performing file I/O as a side effect of merely
        # importing this module.
        if trajectory_maker is None:
            trajectory_maker = TrajectoryMakers.from_file(
                default_trajectories_config_path
            )
        self.trajectory_maker = trajectory_maker
        self.health_index_generator = health_index_generator
        self._load_transmission()

    @classmethod
    def from_file(
        cls,
        infection_class: Infection = Covid19,
        transmission_config_path: str = default_transmission_config_path,
        trajectories_config_path: str = default_trajectories_config_path,
        rates_file: str = default_rates_file,
    ) -> "InfectionSelector":
        """
        Generate infection selector from config files.

        Parameters
        ----------
        infection_class:
            class of infection to create
        transmission_config_path:
            path to transmission config file
        trajectories_config_path:
            path to trajectories config file
        rates_file:
            path to the infection outcome rates used to build the
            health index generator
        """
        health_index_generator = HealthIndexGenerator.from_file(rates_file=rates_file)
        trajectory_maker = TrajectoryMakers.from_file(trajectories_config_path)
        # use cls so subclasses construct instances of themselves
        return cls(
            infection_class=infection_class,
            transmission_config_path=transmission_config_path,
            trajectory_maker=trajectory_maker,
            health_index_generator=health_index_generator,
        )

    @property
    def infection_id(self):
        """Numeric id of the infection class this selector creates."""
        return self.infection_class.infection_id()

    def infect_person_at_time(self, person: "Person", time: float):
        """
        Infects a person at a given time.

        Parameters
        ----------
        person:
            person that will be infected
        time:
            time at which infection happens
        """
        person.infection = self._make_infection(person, time)
        # infection also grants immunity against its cross-immune variants
        person.immunity.add_immunity(person.infection.immunity_ids())

    def _make_infection(self, person: "Person", time: float):
        """
        Generates the symptoms and infectiousness of the person being infected

        Parameters
        ----------
        person:
            person that will be infected
        time:
            time at which infection happens
        """
        symptoms = self._select_symptoms(person)
        time_to_symptoms_onset = symptoms.time_exposed
        transmission = self._select_transmission(
            time_to_symptoms_onset=time_to_symptoms_onset,
            max_symptoms_tag=symptoms.max_tag.name,
        )
        return self.infection_class(
            transmission=transmission, symptoms=symptoms, start_time=time
        )

    def _load_transmission(self):
        """
        Load transmission config file, and store objects that will generate random realisations
        """
        with open(self.transmission_config_path) as f:
            transmission_config = yaml.safe_load(f)
        self.transmission_type = transmission_config["type"]
        if self.transmission_type == "xnexp":
            self._load_transmission_xnexp(transmission_config)
        elif self.transmission_type == "gamma":
            self._load_transmission_gamma(transmission_config)
        elif self.transmission_type == "constant":
            self._load_transmission_constant(transmission_config)
        else:
            raise NotImplementedError("This transmission type has not been implemented")

    def _load_transmission_xnexp(self, transmission_config: dict):
        """
        Given transmission config dictionary, load parameter generators from which
        transmission xnexp parameters will be sampled

        Parameters
        ----------
        transmission_config:
            dictionary of transmission config parameters
        """
        self.smearing_time_first_infectious = CompletionTime.from_dict(
            transmission_config["smearing_time_first_infectious"]
        )
        self.smearing_peak_position = CompletionTime.from_dict(
            transmission_config["smearing_peak_position"]
        )
        self.alpha = CompletionTime.from_dict(transmission_config["alpha"])
        self.max_probability = CompletionTime.from_dict(
            transmission_config["max_probability"]
        )
        self.norm_time = CompletionTime.from_dict(transmission_config["norm_time"])
        self.asymptomatic_infectious_factor = CompletionTime.from_dict(
            transmission_config["asymptomatic_infectious_factor"]
        )
        self.mild_infectious_factor = CompletionTime.from_dict(
            transmission_config["mild_infectious_factor"]
        )

    def _load_transmission_gamma(self, transmission_config: dict):
        """
        Given transmission config dictionary, load parameter generators from which
        transmission gamma parameters will be sampled

        Parameters
        ----------
        transmission_config:
            dictionary of transmission config parameters
        """
        self.max_infectiousness = CompletionTime.from_dict(
            transmission_config["max_infectiousness"]
        )
        self.shape = CompletionTime.from_dict(transmission_config["shape"])
        self.rate = CompletionTime.from_dict(transmission_config["rate"])
        self.shift = CompletionTime.from_dict(transmission_config["shift"])
        self.asymptomatic_infectious_factor = CompletionTime.from_dict(
            transmission_config["asymptomatic_infectious_factor"]
        )
        self.mild_infectious_factor = CompletionTime.from_dict(
            transmission_config["mild_infectious_factor"]
        )

    def _load_transmission_constant(self, transmission_config: dict):
        """
        Given transmission config dictionary, load parameter generators from which
        transmission constant parameters will be sampled

        Parameters
        ----------
        transmission_config:
            dictionary of transmission config parameters
        """
        self.probability = CompletionTime.from_dict(transmission_config["probability"])

    def _select_transmission(
        self, time_to_symptoms_onset: float, max_symptoms_tag: "SymptomTag"
    ) -> "Transmission":
        """
        Selects the transmission type specified by the user in the init,
        and links its parameters to the symptom onset for the person (incubation
        period)

        Parameters
        ----------
        time_to_symptoms_onset:
            time of symptoms onset for person
        max_symptoms_tag:
            name of the most severe symptom tag the person will reach
        """
        if self.transmission_type == "xnexp":
            time_first_infectious = (
                self.smearing_time_first_infectious() + time_to_symptoms_onset
            )
            peak_position = (
                time_to_symptoms_onset
                - time_first_infectious
                + self.smearing_peak_position()
            )
            # Sample alpha once so that ``n`` and ``alpha`` are consistent:
            # the x^n exp(-x/alpha) profile peaks at n * alpha, so using two
            # independent draws (as before) shifted the peak away from
            # ``peak_position``.
            alpha = self.alpha()
            return TransmissionXNExp(
                max_probability=self.max_probability(),
                time_first_infectious=time_first_infectious,
                norm_time=self.norm_time(),
                n=peak_position / alpha,
                alpha=alpha,
                max_symptoms=max_symptoms_tag,
                asymptomatic_infectious_factor=self.asymptomatic_infectious_factor(),
                mild_infectious_factor=self.mild_infectious_factor(),
            )
        elif self.transmission_type == "gamma":
            return TransmissionGamma(
                max_infectiousness=self.max_infectiousness(),
                shape=self.shape(),
                rate=self.rate(),
                shift=self.shift() + time_to_symptoms_onset,
                max_symptoms=max_symptoms_tag,
                asymptomatic_infectious_factor=self.asymptomatic_infectious_factor(),
                mild_infectious_factor=self.mild_infectious_factor(),
            )
        elif self.transmission_type == "constant":
            return TransmissionConstant(probability=self.probability())
        else:
            raise NotImplementedError("This transmission type has not been implemented")

    def _select_symptoms(self, person: "Person") -> "Symptoms":
        """
        Select the symptoms that a given person has, and how they will evolve
        in the future

        Parameters
        ----------
        person:
            person that will be infected
        """
        health_index = self.health_index_generator(
            person, infection_id=self.infection_id
        )
        return Symptoms(health_index=health_index)


class InfectionSelectors:
    """Container mapping infection ids to the selector that creates them."""

    def __init__(self, infection_selectors: list = None):
        """
        Parameters
        ----------
        infection_selectors:
            list of ``InfectionSelector`` instances; when None or empty, a
            single default Covid19 selector is created.
        """
        self._infection_selectors = infection_selectors
        self.infection_id_to_selector = self.make_dict()

    def make_dict(self):
        """
        Build a dict mapping infection_type_id -> infection_selector
        (needed for easier MPI comms, to map an infection id back to the
        selector that creates infections of that type).

        Falls back to a default Covid19 selector when no selectors were
        provided.
        """
        if not self._infection_selectors:
            return {Covid19.infection_id(): InfectionSelector.from_file()}
        return {
            selector.infection_class.infection_id(): selector
            for selector in self._infection_selectors
        }

    def infect_person_at_time(
        self, person: "Person", time: float, infection_id: int = Covid19.infection_id()
    ):
        """
        Infects a person at a given time with the infection type matching
        ``infection_id``.

        Parameters
        ----------
        person:
            person that will be infected
        time:
            time at which infection happens
        infection_id:
            id of the infection type to create (defaults to Covid19)
        """
        selector = self.infection_id_to_selector[infection_id]
        selector.infect_person_at_time(person=person, time=time)

    def __iter__(self):
        # NOTE(review): when constructed with the default (None) this raises
        # TypeError — confirm callers that iterate always pass an explicit
        # list of selectors.
        return iter(self._infection_selectors)

    def __getitem__(self, item):
        return self._infection_selectors[item]


from random import random

import numpy as np
from .symptom_tag import SymptomTag
from .trajectory_maker import TrajectoryMakers

# Symptom tags that correspond to a death outcome (used by Symptoms.dead).
dead_tags = (SymptomTag.dead_home, SymptomTag.dead_hospital, SymptomTag.dead_icu)


class Symptoms:
    """
    Class to represent the symptoms of a person. The symptoms class composes the
    ``Infection`` class alongside with the ``Transmission`` class. Once infected,
    a person is assigned a symptoms trajectory according to a health index generated
    by the ``HealthIndexGenerator``. A trajectory is a collection of symptom tags with
    characteristic timings.
    """

    # NOTE: the docstring must be the first statement in the class body — it
    # previously appeared *after* __slots__, making it a no-op string
    # statement so the class had no __doc__.
    __slots__ = (
        "tag",
        "max_tag",
        "max_severity",
        "trajectory",
        "stage",
        "time_of_symptoms_onset",
    )

    def __init__(self, health_index=None):
        """
        Parameters
        ----------
        health_index:
            cumulative outcome probabilities for this person; ``None``
            produces a trivial single-stage trajectory.
        """
        self.max_tag = None
        self.tag = SymptomTag.exposed
        # severity drawn uniformly in [0, 1); compared against the health
        # index to choose the worst symptoms this person will reach
        self.max_severity = random()
        self.trajectory = self._make_symptom_trajectory(
            health_index
        )  # this also sets max_tag
        self.stage = 0
        self.time_of_symptoms_onset = self._compute_time_from_infection_to_symptoms()

    def _compute_time_from_infection_to_symptoms(self):
        """
        Time in days from infection to the onset of mild symptoms, or None
        when the person never develops symptoms (asymptomatic trajectory).
        """
        symptoms_onset = 0
        for completion_time, tag in self.trajectory:
            symptoms_onset += completion_time
            if tag == SymptomTag.mild:
                break
            elif tag == SymptomTag.asymptomatic:
                return None
        return symptoms_onset

    def _make_symptom_trajectory(self, health_index):
        """
        Pick the maximum symptom tag from the health index and generate the
        corresponding trajectory (also stores the tag in ``max_tag``).
        """
        if health_index is None:
            return [(0, SymptomTag(0))]
        trajectory_maker = TrajectoryMakers.from_file()
        index_max_symptoms_tag = np.searchsorted(health_index, self.max_severity)
        self.max_tag = SymptomTag(index_max_symptoms_tag)
        return trajectory_maker[self.max_tag]

    def update_trajectory_stage(self, time_from_infection):
        """
        Updates the current symptom tag from the symptoms trajectory,
        given how much time has passed since the person was infected.

        Parameters
        ----------
        time_from_infection: float
            Time in days since the person got infected.
        """
        # NOTE(review): indexes stage + 1, so this assumes the trajectory's
        # final stage is absorbing (recovered/dead) and no further update is
        # requested past it — confirm trajectories always end that way.
        if time_from_infection > self.trajectory[self.stage + 1][0]:
            self.stage += 1
            self.tag = self.trajectory[self.stage][1]

    @property
    def time_exposed(self):
        """Duration of the initial exposed stage (start time of stage 1)."""
        return self.trajectory[1][0]

    @property
    def recovered(self):
        """True when the current tag is recovered."""
        return self.tag == SymptomTag.recovered

    @property
    def dead(self):
        """True when the current tag is any of the death outcomes."""
        return self.tag in dead_tags


from enum import IntEnum


class SymptomTag(IntEnum):
    """
    A tag for the symptoms exhibited by a person.

    Higher numbers are more severe.
    0 - 5 correspond to indices in the health index array.
    """

    recovered = -3
    healthy = -2
    exposed = -1
    asymptomatic = 0
    mild = 1
    severe = 2
    hospitalised = 3
    intensive_care = 4
    dead_home = 5
    dead_hospital = 6
    dead_icu = 7

    @classmethod
    def from_string(cls, string: str) -> "SymptomTag":
        for item in SymptomTag:
            if item.name == string:
                return item
        raise AssertionError(f"{string} is not the name of a SymptomTag")


from abc import ABC, abstractmethod
from typing import List, Tuple

import yaml
from scipy.stats import beta, lognorm, norm, expon, exponweib

from june import paths
from .symptom_tag import SymptomTag

# Default symptom-trajectory configuration shipped with the package.
default_config_path = (
    paths.configs_path / "defaults/epidemiology/infection/symptoms/trajectories.yaml"
)


class CompletionTime(ABC):
    """Callable that samples how long a trajectory stage takes to complete."""

    @abstractmethod
    def __call__(self) -> float:
        """
        Compute the time a given stage should take to complete
        """

    @staticmethod
    def class_for_type(type_string: str) -> type:
        """
        Get a CompletionTime class from a string in configuration

        Parameters
        ----------
        type_string
            The type of CompletionTime
            e.g. constant/exponential/beta

        Returns
        -------
        The corresponding class

        Raises
        ------
        AssertionError
            If the type string is not recognised
        """
        if type_string == "constant":
            return ConstantCompletionTime
        elif type_string == "exponential":
            return ExponentialCompletionTime
        elif type_string == "beta":
            return BetaCompletionTime
        elif type_string == "lognormal":
            return LognormalCompletionTime
        elif type_string == "normal":
            return NormalCompletionTime
        elif type_string == "exponweib":
            return ExponweibCompletionTime
        raise AssertionError(f"Unrecognised variation type {type_string}")

    @classmethod
    def from_dict(cls, variation_type_dict):
        """
        Build a CompletionTime instance from a config dictionary holding a
        "type" key plus the distribution's keyword parameters.

        The input is copied first so the caller's dictionary is not mutated
        (the previous implementation popped "type" out of the caller's dict).
        """
        config = dict(variation_type_dict)
        type_string = config.pop("type")
        return cls.class_for_type(type_string)(**config)


class ConstantCompletionTime(CompletionTime):
    """Completion time that is always the same fixed value."""

    def __init__(self, value: float):
        # fixed duration returned on every call
        self.value = value

    def __call__(self):
        """Return the constant duration."""
        return self.value


class DistributionCompletionTime(CompletionTime, ABC):
    """Completion time drawn at random from a scipy.stats distribution."""

    def __init__(self, distribution, *args, **kwargs):
        # Keep the *unfrozen* distribution together with its parameters.
        # Freezing (``distribution(*args, **kwargs)``) on every draw is
        # expensive, see https://github.com/scipy/scipy/issues/9394, so
        # __call__ samples via ``distribution.rvs(*args, **kwargs)`` instead
        # of ``distribution(*args, **kwargs).rvs()``.
        self._distribution = distribution
        self.args = args
        self.kwargs = kwargs

    def __call__(self):
        """Draw one random completion time."""
        return self._distribution.rvs(*self.args, **self.kwargs)

    @property
    def distribution(self):
        """The frozen scipy distribution with this instance's parameters."""
        return self._distribution(*self.args, **self.kwargs)


class ExponentialCompletionTime(DistributionCompletionTime):
    """Exponentially distributed completion time (scipy.stats.expon)."""

    def __init__(self, loc: float, scale):
        super().__init__(expon, loc=loc, scale=scale)


class BetaCompletionTime(DistributionCompletionTime):
    """Beta-distributed completion time (scipy.stats.beta)."""

    def __init__(self, a, b, loc=0.0, scale=1.0):
        super().__init__(beta, a, b, loc=loc, scale=scale)


class LognormalCompletionTime(DistributionCompletionTime):
    """Log-normally distributed completion time (scipy.stats.lognorm)."""

    def __init__(self, s, loc=0.0, scale=1.0):
        super().__init__(lognorm, s, loc=loc, scale=scale)


class NormalCompletionTime(DistributionCompletionTime):
    """Normally distributed completion time (scipy.stats.norm)."""

    def __init__(self, loc, scale):
        super().__init__(norm, loc=loc, scale=scale)


class ExponweibCompletionTime(DistributionCompletionTime):
    """Exponentiated-Weibull distributed completion time (scipy.stats.exponweib)."""

    def __init__(self, a, c, loc=0.0, scale=1.0):
        super().__init__(exponweib, a, c, loc=loc, scale=scale)


class Stage:
    def __init__(
        self,
        *,
        symptoms_tag: SymptomTag,
        completion_time: CompletionTime = ConstantCompletionTime,
    ):
        """
        A stage on an illness,

        Parameters
        ----------
        symptoms_tag
            What symptoms does the person have at this stage?
        completion_time
            Function that returns value for how long this stage takes
            to complete.
            NOTE(review): the default is the ConstantCompletionTime *class*,
            not an instance — calling it would construct an object instead of
            returning a duration. All in-file callers pass an instance via
            ``from_dict``; confirm before relying on the default.
        """
        self.symptoms_tag = symptoms_tag
        self.completion_time = completion_time

    @classmethod
    def from_dict(cls, stage_dict):
        """Build a Stage from a config dict with 'completion_time' and 'symptom_tag'."""
        completion_time = CompletionTime.from_dict(stage_dict["completion_time"])
        symptom_tag = SymptomTag.from_string(stage_dict["symptom_tag"])
        # use cls (not the hard-coded class) so subclasses construct themselves
        return cls(symptoms_tag=symptom_tag, completion_time=completion_time)


class TrajectoryMaker:
    def __init__(self, *stages):
        """
        Generate trajectories of a particular kind.

        This defines how a given person moves through a series of symptoms.

        Parameters
        ----------
        stages
            A list of stages through which the person progresses
        """
        self.stages = stages

    @property
    def _symptoms_tags(self):
        """Symptom tags of every stage, in trajectory order."""
        return [stage.symptoms_tag for stage in self.stages]

    @property
    def most_severe_symptoms(self) -> SymptomTag:
        """
        The most severe symptoms experienced at any stage in this trajectory
        """
        return max(self._symptoms_tags)

    def generate_trajectory(self) -> List[Tuple[float, SymptomTag]]:
        """
        Generate a trajectory for a person: a list of
        (start time in days since infection, symptom tag) tuples.

        Each stage's completion time is sampled anew, so repeated calls give
        different timings.
        """
        trajectory = []
        cumulative = 0.0
        for stage in self.stages:
            # record the stage's *start* time, then advance by its duration
            time = stage.completion_time()
            trajectory.append((cumulative, stage.symptoms_tag))
            cumulative += time
        return trajectory

    @classmethod
    def from_dict(cls, trajectory_dict):
        """Build a trajectory maker from a config dict holding a 'stages' list."""
        # use cls (not the hard-coded class) so subclasses construct themselves
        return cls(*map(Stage.from_dict, trajectory_dict["stages"]))


class TrajectoryMakers:
    """
    The various trajectories should depend on external data, and may depend on age &
    gender of the patient.  This would lead to a table of tons of trajectories, with
    lots of mean values/deviations and an instruction on how to vary them.
    For this first simple implementation I will choose everything to be fixed (constant)

    The trajectories will count "backwards" with zero time being the moment of
    infection.
    """

    # cached singleton instance and the config path it was loaded from
    __instance = None
    __path = None

    def __init__(self, trajectories: List[TrajectoryMaker]):
        """
        Trajectories and their stages should be parsed from configuration. I've
        removed params for now as they weren't being used but it will be trivial
        to reintroduce them when we are ready for configurable trajectories.
        """
        # index trajectories by their most severe tag for direct lookup
        self.trajectories = {
            trajectory.most_severe_symptoms: trajectory for trajectory in trajectories
        }

    @classmethod
    def from_file(cls, config_path: str = default_config_path) -> "TrajectoryMakers":
        """
        Currently this doesn't do what it says it does.

        By setting an instance on the class we can make the trajectory maker
        something like a singleton. However, if it were being loaded from
        configurations we'd need to be careful as this could give unexpected
        effects.
        """
        # reload only when nothing is cached or the requested path changed
        if cls.__instance is None or cls.__path != config_path:
            with open(config_path) as f:
                # use cls (not the hard-coded class) so subclasses construct
                # and cache instances of themselves
                cls.__instance = cls.from_list(yaml.safe_load(f)["trajectories"])
                cls.__path = config_path
        return cls.__instance

    def __getitem__(self, tag: SymptomTag) -> List[Tuple[float, SymptomTag]]:
        """
        Generate a trajectory from a tag.

        It might be better to have this return the Trajectory class
        rather than generating the trajectory itself. I feel the getitem
        syntax disguises the fact that something new is being created.

        I've removed the person (patient) argument because it was not
        being used. It can be passed to the generate_trajectory class.

        Parameters
        ----------
        tag
            A tag describing the symptoms being experienced by a
            patient.

        Returns
        -------
        A list describing the symptoms experienced by the patient
        at given times.
        """
        return self.trajectories[tag].generate_trajectory()

    @classmethod
    def from_list(cls, trajectory_dicts):
        """Build from a list of trajectory config dictionaries."""
        return cls(
            trajectories=list(map(TrajectoryMaker.from_dict, trajectory_dicts))
        )


import yaml
import numpy as np
import numba as nb
from typing import Optional
from math import gamma

from .trajectory_maker import CompletionTime
from june import paths

# Default config for the constant (time-independent) transmission profile.
default_config_path = (
    paths.configs_path
    / "defaults/epidemiology/infection/transmission/TransmissionConstant.yaml"
)
# Default config for the gamma infectiousness profile.
default_gamma_config_path = (
    paths.configs_path / "defaults/epidemiology/infection/transmission/nature.yaml"
)


class Transmission:
    """Base class for time-dependent infectiousness profiles."""

    __slots__ = ("probability",)

    def __init__(self):
        # current transmission probability; subclasses evolve it over time
        self.probability = 0.0

    def update_infection_probability(self, time_from_infection):
        """Update ``probability`` given days since infection; subclasses implement this."""
        raise NotImplementedError()


class TransmissionConstant(Transmission):
    """Infectiousness profile that is constant in time."""

    def __init__(self, probability=0.3):
        """
        Parameters
        ----------
        probability:
            fixed transmission probability
        """
        super().__init__()
        self.probability = probability

    @classmethod
    def from_file(
        cls, config_path: str = default_config_path
    ) -> "TransmissionConstant":
        """
        Read the constant probability from a YAML config file.

        Parameters
        ----------
        config_path:
            path to the config file
        """
        with open(config_path) as f:
            config = yaml.safe_load(f)
        probability = CompletionTime.from_dict(config["probability"])()
        # use cls (not the hard-coded class) so subclasses construct themselves
        return cls(probability=probability)

    def update_infection_probability(self, time_from_infection):
        """Constant profile: nothing to update."""
        pass


@nb.jit(nopython=True)
def gamma_pdf(x: float, a: float, loc: float, scale: float) -> float:
    """
    Numba implementation of the gamma probability density function.

    Parameters
    ----------
    x:
        point at which to evaluate the pdf
    a:
        shape parameter of the gamma distribution
    loc:
        location (shift) of the distribution; the pdf is 0 for x < loc
    scale:
        scale parameter of the gamma distribution

    Returns
    -------
        value of the gamma pdf at x
    """
    if x < loc:
        return 0.0
    return (
        1.0
        / gamma(a)
        * ((x - loc) / scale) ** (a - 1)
        * np.exp(-(x - loc) / scale)
        / scale
    )


@nb.jit(nopython=True)
def gamma_pdf_vectorized(x: np.ndarray, a: float, loc: float, scale: float) -> np.ndarray:
    """
    Numba implementation of the gamma probability density function,
    evaluated element-wise over an array of points.

    Parameters
    ----------
    x:
        array of points at which to evaluate the pdf
    a:
        shape parameter of the gamma distribution
    loc:
        location (shift) of the distribution; the pdf is 0 where x < loc
    scale:
        scale parameter of the gamma distribution

    Returns
    -------
        array of gamma pdf values at each x
    """
    return np.where(
        x < loc,
        0.0,
        1.0
        / gamma(a)
        * ((x - loc) / scale) ** (a - 1)
        * np.exp(-(x - loc) / scale)
        / scale,
    )


class TransmissionGamma(Transmission):
    """
    Module to simulate the infectiousness profiles found in :
        - https://www.nature.com/articles/s41591-020-0869-5
        - https://arxiv.org/pdf/2007.06602.pdf
    """

    # "probability" is already slotted on the Transmission base class;
    # re-declaring it here created a duplicate slot shadowing the base's.
    __slots__ = ("shape", "shift", "scale", "norm")

    def __init__(
        self,
        max_infectiousness: float = 1.0,
        shape: float = 2.0,
        rate: float = 3.0,
        shift: float = -2.0,
        max_symptoms: Optional[str] = None,
        asymptomatic_infectious_factor: Optional[float] = None,
        mild_infectious_factor: Optional[float] = None,
    ):
        """
        Parameters
        ----------
        max_infectiousness:
            value of the infectiousness at its peak
        shape:
            shape parameter of the gamma distribution (a for scipy stats)
        rate:
            rate parameter of the gamma distribution (1/rate = scale for scipy stats)
        shift:
            location parameter of the gamma distribution
        max_symptoms:
            maximum symptoms the individual will develop, used to reduce the infectiousness
            of asymptomatic and mild individuals if wanted
        asymptomatic_infectious_factor:
            factor to reduce the infectiousness of asymptomatic individuals
        mild_infectious_factor:
            factor to reduce the infectiousness of mild individuals
        """
        self.shape = shape
        self.shift = shift
        self.scale = 1.0 / rate
        self.norm = max_infectiousness
        if (
            asymptomatic_infectious_factor is not None
            and mild_infectious_factor is not None
        ):
            # damp the profile for people who never get worse than
            # asymptomatic/mild symptoms
            self.norm *= self._modify_infectiousness_for_symptoms(
                max_symptoms=max_symptoms,
                asymptomatic_infectious_factor=asymptomatic_infectious_factor,
                mild_infectious_factor=mild_infectious_factor,
            )
        self.probability = 0.0

    @classmethod
    def _sample_config_parameters(cls, config_path) -> dict:
        """
        Load a YAML config and sample one realisation of each gamma-profile
        parameter. Shared by ``from_file`` and ``from_file_linked_symptoms``.
        """
        with open(config_path) as f:
            config = yaml.safe_load(f)
        return {
            key: CompletionTime.from_dict(config[key])()
            for key in (
                "max_infectiousness",
                "shape",
                "rate",
                "shift",
                "asymptomatic_infectious_factor",
                "mild_infectious_factor",
            )
        }

    @classmethod
    def from_file(
        cls, max_symptoms: str = None, config_path: str = default_gamma_config_path
    ) -> "TransmissionGamma":
        """
        Generate transmission class reading parameters from config file

        Parameters
        ----------
        max_symptoms:
            maximum symptoms the individual will develop, used to reduce the infectiousness
            of asymptomatic and mild individuals if wanted
        config_path:
            path to config parameters

        Returns
        -------
            TransmissionGamma instance
        """
        params = cls._sample_config_parameters(config_path)
        return cls(max_symptoms=max_symptoms, **params)

    @classmethod
    def from_file_linked_symptoms(
        cls,
        time_to_symptoms_onset: float,
        max_symptoms: str = None,
        config_path: str = default_gamma_config_path,
    ) -> "TransmissionGamma":
        """
        Generate transmission class reading parameters from config file, linked to
        the time of symptoms onset

        Parameters
        ----------
        time_to_symptoms_onset:
            time (from infection) at which the person becomes symptomatic
        max_symptoms:
            maximum symptoms the individual will develop, used to reduce the infectiousness
            of asymptomatic and mild individuals if wanted
        config_path:
            path to config parameters

        Returns
        -------
            TransmissionGamma instance
        """
        params = cls._sample_config_parameters(config_path)
        # anchor the infectiousness profile to this person's symptom onset
        params["shift"] += time_to_symptoms_onset
        return cls(max_symptoms=max_symptoms, **params)

    def update_infection_probability(self, time_from_infection: float):
        """
        Performs a probability update given time from infection

        Parameters
        ----------
        time_from_infection:
            time elapsed since person became infected
        """
        self.probability = self.norm * gamma_pdf(
            x=time_from_infection, a=self.shape, loc=self.shift, scale=self.scale
        )

    @property
    def time_at_maximum_infectivity(self) -> float:
        """
        Computes the time at which the individual is maximally infectious
        (the mode of the shifted gamma distribution).

        Returns
        -------
        t_max:
            time at maximal infectiousness
        """
        return (self.shape - 1) * self.scale + self.shift

    def _modify_infectiousness_for_symptoms(
        self,
        max_symptoms: str,
        asymptomatic_infectious_factor=None,
        mild_infectious_factor=None,
    ):
        """
        Lowers the infectiousness of asymptomatic and mild cases, by modifying
        the norm of the distribution

        Parameters
        ----------
        max_symptoms:
            maximum symptom severity the person will ever have
        asymptomatic_infectious_factor:
            factor to reduce the infectiousness of asymptomatic individuals
        mild_infectious_factor:
            factor to reduce the infectiousness of mild individuals
        """
        if (
            asymptomatic_infectious_factor is not None
            and max_symptoms == "asymptomatic"
        ):
            return asymptomatic_infectious_factor
        elif mild_infectious_factor is not None and max_symptoms == "mild":
            return mild_infectious_factor
        return 1.0


from .transmission import Transmission
from .trajectory_maker import CompletionTime
from .symptom_tag import SymptomTag
from june import paths
import yaml
import numpy as np
import numba as nb

# default YAML config holding the XNExp transmission parameters
# (max_probability, norm_time and symptom-dependent infectiousness factors)
default_config_path = (
    paths.configs_path / "defaults/epidemiology/infection/transmission/XNExp.yaml"
)


@nb.jit(nopython=True)
def xnexp(x: float, n: float, alpha: float) -> float:
    """
    Evaluate x^n * exp(-x / alpha).

    Parameters
    ----------
    x:
        point at which to evaluate the function
    n:
        exponent of x
    alpha:
        denominator in the exponential

    Returns
    -------
        value of the xnexp function at x
    """
    power_term = x**n
    decay_term = np.exp(-x / alpha)
    return power_term * decay_term


@nb.jit(nopython=True)
def update_probability(
    time_from_infection: float,
    time_first_infectious: float,
    norm: float,
    norm_time: float,
    alpha: float,
    n: float,
) -> float:
    """
    Value of the infectiousness profile at a given time since infection.

    Before ``time_first_infectious`` the person is not yet infectious and the
    value is 0; afterwards the profile is norm * xnexp(tau, n, alpha) with
    tau the elapsed infectious time rescaled by ``norm_time``.

    Parameters
    ----------
    time_from_infection:
        time from infection
    time_first_infectious:
        time from infection at which the person becomes infectious
    norm:
        multiplier to the infectiousness profile
    norm_time:
        controls the definition of tau
    alpha:
        denominator in exponential for xnexp function
    n:
        exponent of x in xnexp

    Returns
    -------
        Value of infectiousness at time
    """
    # guard clause: not yet infectious
    if time_from_infection <= time_first_infectious:
        return 0.0
    delta_tau = (time_from_infection - time_first_infectious) / norm_time
    return norm * xnexp(x=delta_tau, n=n, alpha=alpha)


class TransmissionXNExp(Transmission):
    """
    Infectiousness profile of the form x^n * exp(-x / alpha), where
    x = (time_from_infection - time_first_infectious) / norm_time.

    The profile is normalised so that its peak value equals
    ``max_probability``, optionally reduced for asymptomatic / mild cases.
    """

    __slots__ = (
        "time_first_infectious",
        "norm_time",
        "n",
        "alpha",
        "norm",
        "probability",
    )

    def __init__(
        self,
        max_probability: float = 1.0,
        time_first_infectious: float = 2.6,
        norm_time: float = 1.0,
        n: float = 1.0,
        alpha: float = 5.0,
        max_symptoms: str = None,
        asymptomatic_infectious_factor: float = None,
        mild_infectious_factor: float = None,
    ):
        """
        Class that defines the time profile of the infectiousness to be of the form x^n exp(-x/alpha)

        Parameters
        ----------
        max_probability:
            value of the infectiousness at its peak. Used to control the number of super spreaders
        time_first_infectious:
            time at which the person becomes infectious
        norm_time:
            controls the definition of x, x = (time_from_infection - time_first_infectious)/norm_time
        n:
            exponent of x in the x^n exp(-x/alpha) function
        alpha:
            denominator in exponential
        max_symptoms:
            maximum symptoms that the person will ever have, used to lower the infectiousness of
            asymptomatic and mild cases
        asymptomatic_infectious_factor:
            multiplier that lowers the infectiousness of asymptomatic cases
        mild_infectious_factor:
            multiplier that lowers the infectiousness of mild cases
        """
        self.time_first_infectious = time_first_infectious
        self.norm_time = norm_time
        self.n = n
        self.alpha = alpha
        # x^n exp(-x/alpha) peaks at x = n * alpha; evaluate the profile there
        # so that its maximum equals max_probability
        max_delta_time = self.n * self.alpha * self.norm_time
        max_tau = max_delta_time / self.norm_time
        self.norm = max_probability / xnexp(max_tau, self.n, self.alpha)
        # rescales self.norm in place for asymptomatic / mild cases
        self._modify_infectiousness_for_symptoms(
            max_symptoms=max_symptoms,
            asymptomatic_infectious_factor=asymptomatic_infectious_factor,
            mild_infectious_factor=mild_infectious_factor,
        )
        self.probability = 0.0

    @classmethod
    def from_file(
        cls,
        time_first_infectious: float,
        n: float,
        alpha: float,
        max_symptoms: "SymptomTag" = None,
        config_path: str = default_config_path,
    ) -> "TransmissionXNExp":
        """
        Generates transmission class from config file

        Parameters
        ----------
        time_first_infectious:
            time at which the person becomes infectious
        n:
            exponent of x in the x^n exp(-x/alpha) function
        alpha:
            denominator in exponential
        max_symptoms:
            maximum symptoms that the person will ever have, used to lower the infectiousness of
            asymptomatic and mild cases
        config_path:
            path to the yaml config providing max_probability, norm_time and
            the symptom-dependent infectiousness factors

        Returns
        -------
            class instance
        """
        with open(config_path) as f:
            config = yaml.safe_load(f)
        max_probability = CompletionTime.from_dict(config["max_probability"])()
        norm_time = CompletionTime.from_dict(config["norm_time"])()
        asymptomatic_infectious_factor = CompletionTime.from_dict(
            config["asymptomatic_infectious_factor"]
        )()
        mild_infectious_factor = CompletionTime.from_dict(
            config["mild_infectious_factor"]
        )()
        # use cls (not the hard-coded class name) so subclasses constructed
        # through this factory get the right type
        return cls(
            max_probability=max_probability,
            time_first_infectious=time_first_infectious,
            norm_time=norm_time,
            n=n,
            alpha=alpha,
            max_symptoms=max_symptoms,
            asymptomatic_infectious_factor=asymptomatic_infectious_factor,
            mild_infectious_factor=mild_infectious_factor,
        )

    @classmethod
    def from_file_linked_symptoms(
        cls,
        time_to_symptoms_onset: float,
        max_symptoms: "SymptomTag" = None,
        config_path: str = default_config_path,
    ) -> "TransmissionXNExp":
        """
        Generates transmission class from config file, placing the infectious
        period relative to the time of symptoms onset using the smearing
        parameters read from the config.

        Parameters
        ----------
        time_to_symptoms_onset:
            time from infection at which the person develops symptoms; the
            start and the peak of the infectiousness profile are offset from
            this time by the config's smearing parameters
        max_symptoms:
            maximum symptoms that the person will ever have, used to lower the infectiousness of
            asymptomatic and mild cases
        config_path:
            path to the yaml config providing the smearing parameters, alpha,
            max_probability, norm_time and the symptom-dependent factors

        Returns
        -------
            class instance
        """
        with open(config_path) as f:
            config = yaml.safe_load(f)
        smearing_time_first_infectious = CompletionTime.from_dict(
            config["smearing_time_first_infectious"]
        )()
        time_first_infectious = time_to_symptoms_onset + smearing_time_first_infectious
        smearing_peak_position = CompletionTime.from_dict(
            config["smearing_peak_position"]
        )()
        alpha = CompletionTime.from_dict(config["alpha"])()
        # peak position relative to the start of infectiousness
        peak_position = (
            time_to_symptoms_onset - time_first_infectious + smearing_peak_position
        )
        # xnexp peaks at x = n * alpha, so choose n to put the peak there
        n = peak_position / alpha
        max_probability = CompletionTime.from_dict(config["max_probability"])()
        norm_time = CompletionTime.from_dict(config["norm_time"])()
        asymptomatic_infectious_factor = CompletionTime.from_dict(
            config["asymptomatic_infectious_factor"]
        )()
        mild_infectious_factor = CompletionTime.from_dict(
            config["mild_infectious_factor"]
        )()
        # use cls (not the hard-coded class name) so subclasses constructed
        # through this factory get the right type
        return cls(
            max_probability=max_probability,
            time_first_infectious=time_first_infectious,
            norm_time=norm_time,
            n=n,
            alpha=alpha,
            max_symptoms=max_symptoms,
            asymptomatic_infectious_factor=asymptomatic_infectious_factor,
            mild_infectious_factor=mild_infectious_factor,
        )

    def update_infection_probability(self, time_from_infection: float):
        """
        Performs a probability update given time from infection

        Parameters
        ----------
        time_from_infection:
            time elapsed since person became infected (in days).
        """
        self.probability = update_probability(
            time_from_infection,
            self.time_first_infectious,
            self.norm,
            self.norm_time,
            self.alpha,
            self.n,
        )

    def _modify_infectiousness_for_symptoms(
        self, max_symptoms: str, asymptomatic_infectious_factor, mild_infectious_factor
    ):
        """
        Lowers the infectiousness of asymptomatic and mild cases, by
        rescaling self.norm in place.

        Parameters
        ----------
        max_symptoms:
            maximum symptom severity the person will ever have
        asymptomatic_infectious_factor:
            multiplier applied to self.norm for asymptomatic cases
        mild_infectious_factor:
            multiplier applied to self.norm for mild cases
        """
        if (
            asymptomatic_infectious_factor is not None
            and max_symptoms == "asymptomatic"
        ):
            self.norm *= asymptomatic_infectious_factor
        elif mild_infectious_factor is not None and max_symptoms == "mild":
            self.norm *= mild_infectious_factor


from .infection import Infection, Covid19, B117, B16172
from .immunity import Immunity
from .infection_selector import InfectionSelector, InfectionSelectors
from .trajectory_maker import TrajectoryMakers
from .health_index.health_index import HealthIndexGenerator
from .health_index.data_to_rates import Data2Rates
from .symptom_tag import SymptomTag
from .symptoms import Symptoms
from .transmission import Transmission, TransmissionConstant, TransmissionGamma
from .transmission_xnexp import TransmissionXNExp
from .immunity_setter import ImmunitySetter


import logging
import pandas as pd
import numpy as np
from typing import List, Union, Optional
from june import paths
import yaml


# ch = care home, gp = general population (so everyone not in a care home)

# base directory for all health-index input data files
hi_data = paths.data_path / "input/health_index"
# seroprevalence by age, overall and for care home residents
default_seroprevalence_file = hi_data / "seroprevalence_by_age.csv"
default_care_home_seroprevalence_file = hi_data / "care_home_seroprevalence_by_age.csv"

# population counts by age and sex
default_population_file = hi_data / "population_by_age_sex_2020_england.csv"
default_care_home_population_file = hi_data / "care_home_residents_by_age_sex_june.csv"

# deaths and hospital outcomes by age and sex; "cocin_gp_*" / "chess_ch_*"
# hold the general-population and care-home splits respectively
default_all_deaths_file = hi_data / "all_deaths_by_age_sex.csv"
default_care_home_deaths_file = hi_data / "care_home_deaths_by_age_sex.csv"
default_all_hospital_deaths_file = hi_data / "hospital_deaths_by_age_sex.csv"
default_all_hospital_admissions_file = hi_data / "hospital_admissions_by_age_sex.csv"
default_gp_admissions_file = hi_data / "cocin_gp_hospital_admissions_by_age_sex.csv"
default_ch_admissions_file = hi_data / "chess_ch_hospital_admissions_by_age_sex.csv"
default_gp_hospital_deaths_file = hi_data / "cocin_gp_hospital_deaths_by_age_sex.csv"
default_ch_hospital_deaths_file = hi_data / "chess_ch_hospital_deaths_by_age_sex.csv"
# ratio files: ICU admissions per hospital admission, and
# ICU deaths per hospital death
default_icu_hosp_rate_file = hi_data / "icu_hosp_rate.csv"
default_deathsicu_deathshosp_rate_file = hi_data / "dicu_dhosp_rate.csv"
# symptom-severity rates by age and sex
default_asymptomatic_rate_file = hi_data / "asymptomatic_rates_by_age_sex.csv"
default_mild_rate_file = hi_data / "mild_rates_by_age_sex.csv"

logger = logging.getLogger("rates")


def convert_to_intervals(ages: List[str], is_interval=False) -> pd.IntervalIndex:
    """
    Convert age-bin labels into a pandas IntervalIndex of closed intervals.

    Labels are of the form "a-b", or "[a,b]" when ``is_interval`` is True.
    """
    intervals = []
    for age in ages:
        if is_interval:
            bounds = age.strip("[]").split(",")
        else:
            bounds = age.split("-")
        intervals.append(
            pd.Interval(left=int(bounds[0]), right=int(bounds[1]), closed="both")
        )
    return pd.IntervalIndex(intervals)


def check_age_intervals(df: pd.DataFrame):
    """
    Ensure the interval index of ``df`` spans ages 0 to 99: widen the first
    interval down to 0 and clamp the last interval's upper edge to 99,
    warning whenever an adjustment is made.
    """
    age_intervals = list(df.index)
    lower_age = age_intervals[0].left
    upper_age = age_intervals[-1].right
    if lower_age != 0:
        logger.warning(
            f"Your age intervals do not contain values smaller than {lower_age}."
            f"We will presume ages from 0 to {lower_age} all have the same value."
        )
        age_intervals[0] = pd.Interval(
            left=0, right=age_intervals[0].right, closed="both"
        )
    if upper_age != 99:
        # both the too-low and too-high cases end with the last interval
        # clamped to an upper edge of 99; only the warning text differs
        if upper_age < 99:
            logger.warning(
                f"Your age intervals do not contain values larger than {upper_age}."
                f"We will presume ages {upper_age} all have the same value."
            )
        else:
            logger.warning(
                "Your age intervals contain values larger than 99."
                "Setting that to the be the uper limit"
            )
        age_intervals[-1] = pd.Interval(
            left=age_intervals[-1].left, right=99, closed="both"
        )
    df.index = age_intervals
    return df


def weighted_interpolation(value, weights):
    """
    Split ``value`` across bins proportionally to ``weights``; the returned
    array has the same length as ``weights`` and sums to ``value``.
    """
    weight_array = np.array(weights)
    return value * weight_array / weight_array.sum()


def read_comorbidity_csv(filename: str):
    """
    Read a comorbidity prevalence CSV and renormalise it.

    The file has comorbidity names as the row index (including a
    "no_condition" row) and age thresholds as columns; columns are relabelled
    as age ranges ("0-10", "10-20", ...). For each column, all rows except
    the last are rescaled so that they sum to 1 - no_condition.
    NOTE(review): the rescaling excludes only the *last* row — this assumes
    "no_condition" is the final row of the file; confirm against the data.

    Parameters
    ----------
    filename:
        path (or buffer) of the CSV file to read

    Returns
    -------
        transposed dataframe: age ranges as rows, comorbidities as columns
    """
    comorbidity_df = pd.read_csv(filename, index_col=0)
    column_names = [f"0-{comorbidity_df.columns[0]}"]
    for i in range(len(comorbidity_df.columns) - 1):
        column_names.append(
            f"{comorbidity_df.columns[i]}-{comorbidity_df.columns[i+1]}"
        )
    comorbidity_df.columns = column_names
    for column in comorbidity_df.columns:
        no_comorbidity = comorbidity_df[column].loc["no_condition"]
        should_have_comorbidity = 1 - no_comorbidity
        has_comorbidity = np.sum(comorbidity_df[column]) - no_comorbidity
        # .loc assignment instead of chained `df[col].iloc[:-1] *=`, which
        # can silently operate on a copy (pandas chained assignment)
        comorbidity_df.loc[comorbidity_df.index[:-1], column] *= (
            should_have_comorbidity / has_comorbidity
        )

    return comorbidity_df.T


def convert_comorbidities_prevalence_to_dict(prevalence_female, prevalence_male):
    """
    Merge the per-sex comorbidity prevalence dataframes into a nested dict
    keyed by comorbidity, then sex ("f"/"m"), then age bin.
    """
    return {
        comorbidity: {
            "f": prevalence_female[comorbidity].to_dict(),
            "m": prevalence_male[comorbidity].to_dict(),
        }
        for comorbidity in prevalence_female.columns
    }


class Data2Rates:
    """
    Combines seroprevalence, population, death and hospital data to derive
    infection outcome rates (fatality, admission, ICU, etc.) by age, sex and
    residence type.

    Abbreviations: ch = care home, gp = general population (everyone not in
    a care home).
    """

    def __init__(
        self,
        seroprevalence_df: pd.DataFrame,
        population_by_age_sex_df: pd.DataFrame,
        care_home_population_by_age_sex_df: pd.DataFrame,
        all_deaths_by_age_sex_df: pd.DataFrame,
        hospital_all_deaths_by_age_sex_df: pd.DataFrame,
        hospital_all_admissions_by_age_sex_df: pd.DataFrame,
        hospital_gp_deaths_by_age_sex_df: pd.DataFrame,
        hospital_ch_deaths_by_age_sex_df: pd.DataFrame,
        hospital_gp_admissions_by_age_sex_df: pd.DataFrame,
        hospital_ch_admissions_by_age_sex_df: pd.DataFrame,
        care_home_deaths_by_age_sex_df: pd.DataFrame = None,
        care_home_seroprevalence_by_age_df: pd.DataFrame = None,
        icu_hosp_rate_by_age_sex_df: pd.DataFrame = None,
        deathsicu_deathshosp_rate_by_age_df: pd.DataFrame = None,
        comorbidity_multipliers: Optional[dict] = None,
        comorbidity_prevalence_reference_population: Optional[dict] = None,
        asymptomatic_rates_by_age_sex_df: pd.DataFrame = None,
        mild_rates_by_age_sex_df: pd.DataFrame = None,
    ):
        # seroprev
        self.seroprevalence_df = self._process_df(seroprevalence_df, converters=True)
        self.care_home_seroprevalence_by_age_df = self._process_df(
            care_home_seroprevalence_by_age_df, converters=True
        )

        # populations
        self.population_by_age_sex_df = self._process_df(
            population_by_age_sex_df, converters=False
        )
        self.care_home_population_by_age_sex_df = self._process_df(
            care_home_population_by_age_sex_df, converters=False
        )
        self.all_deaths_by_age_sex_df = self._process_df(
            all_deaths_by_age_sex_df, converters=True
        )
        self.care_home_deaths_by_age_sex_df = self._process_df(
            care_home_deaths_by_age_sex_df, converters=True
        )
        self.all_hospital_deaths_by_age_sex = self._process_df(
            hospital_all_deaths_by_age_sex_df, converters=True
        )
        self.all_hospital_admissions_by_age_sex = self._process_df(
            hospital_all_admissions_by_age_sex_df, converters=True
        )
        self.hospital_gp_deaths_by_age_sex_df = self._process_df(
            hospital_gp_deaths_by_age_sex_df, converters=True
        )
        self.hospital_ch_deaths_by_age_sex_df = self._process_df(
            hospital_ch_deaths_by_age_sex_df, converters=True
        )
        self.hospital_gp_admissions_by_age_sex_df = self._process_df(
            hospital_gp_admissions_by_age_sex_df, converters=True
        )
        self.hospital_ch_admissions_by_age_sex_df = self._process_df(
            hospital_ch_admissions_by_age_sex_df, converters=True
        )
        self.icu_hosp_rate_by_age_sex_df = self._process_df(
            icu_hosp_rate_by_age_sex_df, converters=False
        )
        self.deathsicu_deathshosp_rate_by_age_df = self._process_df(
            deathsicu_deathshosp_rate_by_age_df, converters=False
        )
        self.comorbidity_multipliers = comorbidity_multipliers
        self.comorbidity_prevalence_reference_population = (
            comorbidity_prevalence_reference_population
        )
        self.mild_rates_by_age_sex_df = self._process_df(
            mild_rates_by_age_sex_df, converters=True
        )
        self.asymptomatic_rates_by_age_sex_df = self._process_df(
            asymptomatic_rates_by_age_sex_df, converters=True
        )
        self._init_mappers()

    @classmethod
    def from_file(
        cls,
        seroprevalence_file: str = default_seroprevalence_file,
        care_home_seroprevalence_by_age_file: str = default_care_home_seroprevalence_file,
        population_file: str = default_population_file,
        care_home_population_file: str = default_care_home_population_file,
        all_deaths_file: str = default_all_deaths_file,
        all_hospital_deaths_file: str = default_all_hospital_deaths_file,
        all_hospital_admissions_file: str = default_all_hospital_admissions_file,
        hospital_gp_deaths_file: str = default_gp_hospital_deaths_file,
        hospital_ch_deaths_file: str = default_ch_hospital_deaths_file,
        hospital_gp_admissions_file: str = default_gp_admissions_file,
        hospital_ch_admissions_file: str = default_ch_admissions_file,
        icu_hosp_rate_file: str = default_icu_hosp_rate_file,
        deathsicu_deathshosp_rate_file: str = default_deathsicu_deathshosp_rate_file,
        care_home_deaths_file: str = default_care_home_deaths_file,
        asymptomatic_rates_file: str = default_asymptomatic_rate_file,
        mild_rates_file: str = default_mild_rate_file,
        comorbidity_multipliers_file: Optional[str] = None,
        comorbidity_prevalence_female_file: Optional[str] = None,
        comorbidity_prevalence_male_file: Optional[str] = None,
    ) -> "Data2Rates":
        """
        Build a Data2Rates instance by loading all input CSVs (and optional
        comorbidity YAML/CSV files) from disk. Optional files set to None
        yield None dataframes, which the rate getters handle explicitly.
        """
        seroprevalence_df = cls._read_csv(seroprevalence_file)
        population_df = cls._read_csv(population_file)
        all_deaths_df = cls._read_csv(all_deaths_file)
        hospital_gp_deaths_df = cls._read_csv(hospital_gp_deaths_file)
        hospital_ch_deaths_df = cls._read_csv(hospital_ch_deaths_file)
        hospital_all_deaths_df = cls._read_csv(all_hospital_deaths_file)
        hospital_all_admissions_df = cls._read_csv(all_hospital_admissions_file)
        hospital_gp_admissions_df = cls._read_csv(hospital_gp_admissions_file)
        hospital_ch_admissions_df = cls._read_csv(hospital_ch_admissions_file)
        mild_rates_df = cls._read_csv(mild_rates_file)
        asymptomatic_rates_df = cls._read_csv(asymptomatic_rates_file)
        icu_hosp_rate_df = cls._read_csv(icu_hosp_rate_file)
        deathsicu_deathshosp_rate_df = cls._read_csv(deathsicu_deathshosp_rate_file)

        if care_home_deaths_file is None:
            care_home_deaths_df = None
        else:
            care_home_deaths_df = cls._read_csv(care_home_deaths_file)
        if care_home_population_file is None:
            care_home_population_df = None
        else:
            care_home_population_df = cls._read_csv(care_home_population_file)
        if care_home_seroprevalence_by_age_file is None:
            care_home_seroprevalence_by_age_df = None
        else:
            care_home_seroprevalence_by_age_df = cls._read_csv(
                care_home_seroprevalence_by_age_file
            )
        if comorbidity_multipliers_file is not None:
            with open(comorbidity_multipliers_file) as f:
                comorbidity_multipliers = yaml.load(f, Loader=yaml.FullLoader)
        else:
            comorbidity_multipliers = None
        if (
            comorbidity_prevalence_female_file is not None
            and comorbidity_prevalence_male_file is not None
        ):
            comorbidity_female_prevalence = read_comorbidity_csv(
                comorbidity_prevalence_female_file
            )
            comorbidity_male_prevalence = read_comorbidity_csv(
                comorbidity_prevalence_male_file
            )
            prevalence_reference_population = convert_comorbidities_prevalence_to_dict(
                comorbidity_female_prevalence, comorbidity_male_prevalence
            )
        else:
            prevalence_reference_population = None
        return cls(
            seroprevalence_df=seroprevalence_df,
            population_by_age_sex_df=population_df,
            care_home_population_by_age_sex_df=care_home_population_df,
            all_deaths_by_age_sex_df=all_deaths_df,
            hospital_all_deaths_by_age_sex_df=hospital_all_deaths_df,
            hospital_all_admissions_by_age_sex_df=hospital_all_admissions_df,
            hospital_gp_deaths_by_age_sex_df=hospital_gp_deaths_df,
            hospital_ch_deaths_by_age_sex_df=hospital_ch_deaths_df,
            hospital_gp_admissions_by_age_sex_df=hospital_gp_admissions_df,
            hospital_ch_admissions_by_age_sex_df=hospital_ch_admissions_df,
            icu_hosp_rate_by_age_sex_df=icu_hosp_rate_df,
            deathsicu_deathshosp_rate_by_age_df=deathsicu_deathshosp_rate_df,
            care_home_deaths_by_age_sex_df=care_home_deaths_df,
            care_home_seroprevalence_by_age_df=care_home_seroprevalence_by_age_df,
            asymptomatic_rates_by_age_sex_df=asymptomatic_rates_df,
            mild_rates_by_age_sex_df=mild_rates_df,
            comorbidity_multipliers=comorbidity_multipliers,
            comorbidity_prevalence_reference_population=prevalence_reference_population,
        )

    @classmethod
    def _read_csv(cls, filename):
        """Read a CSV and set its "age" column as the index."""
        df = pd.read_csv(filename)
        df.set_index("age", inplace=True)
        return df

    def _process_df(self, df, converters=True):
        """
        Sort a dataframe by its age index; when ``converters`` is True, first
        convert string age labels into closed intervals spanning 0-99.
        Returns None unchanged so optional inputs can be absent.
        """
        if df is None:
            # optional inputs (care home data, icu rates, ...) may be
            # missing; downstream getters check for None before use
            return None
        if converters:
            new_index = convert_to_intervals(df.index)
            df.index = new_index
            df = check_age_intervals(df=df)
        df = df.sort_index()
        return df

    def _init_mappers(self):
        """
        These mappers (age, sex) -> float are used to weight bins.
        """
        self.gp_mapper = (
            lambda age, sex: self.population_by_age_sex_df.loc[age, sex]
            - self.care_home_population_by_age_sex_df.loc[age, sex]
        )
        self.ch_mapper = lambda age, sex: self.care_home_population_by_age_sex_df.loc[
            age, sex
        ]
        self.all_mapper = lambda age, sex: self.population_by_age_sex_df.loc[age, sex]
        self.gp_deaths_mapper = lambda age, sex: self.get_n_deaths(
            age=age, sex=sex, is_care_home=False
        )
        self.ch_deaths_mapper = lambda age, sex: self.get_n_deaths(
            age=age, sex=sex, is_care_home=True
        )
        self.all_deaths_mapper = lambda age, sex: self.gp_deaths_mapper(
            age, sex
        ) + self.ch_deaths_mapper(age, sex)

    def _get_interpolated_value(self, df, age, sex, weight_mapper=None):
        """
        Interpolates bins to single years by weighting each year by its
        population times the death rate.

        Parameters
        ----------
        df
            dataframe with the structure
            age | male | female
            0-5 | 2    | 3
            etc.
        weight_mapper
            function mapping (age,sex) -> weight
            if not provided uses uniform weights.
        """
        if weight_mapper is None:
            weight_mapper = lambda age, sex: 1
        age_bin = df.loc[age].name
        data_bin = df.loc[age, sex]
        bin_weight = sum(
            [
                weight_mapper(age_i, sex)
                for age_i in range(age_bin.left, age_bin.right + 1)
            ]
        )
        if bin_weight == 0:
            return 0
        value_weight = weight_mapper(age, sex)
        return value_weight * data_bin / bin_weight

    def get_n_care_home(self, age: int, sex: str):
        """Number of care home residents of the given age and sex."""
        return self.care_home_population_by_age_sex_df.loc[age, sex]

    def get_n_cases(self, age: int, sex: str, is_care_home: bool = False) -> float:
        """
        Estimated number of infections, from seroprevalence applied to the
        living population plus the deaths correction.
        """
        if is_care_home:
            sero_prevalence = self.care_home_seroprevalence_by_age_df.loc[
                age, "seroprevalence"
            ]
            n_people = self.get_n_care_home(age, sex)
        else:
            sero_prevalence = self.seroprevalence_df.loc[age, "seroprevalence"]
            n_people = self.population_by_age_sex_df.loc[age, sex]
            if self.care_home_population_by_age_sex_df is not None:
                n_people -= self.get_n_care_home(age=age, sex=sex)
        # including death correction
        n_deaths = self.get_n_deaths(age=age, sex=sex, is_care_home=is_care_home)
        n_cases = (n_people - n_deaths) * sero_prevalence + n_deaths
        return n_cases

    def get_care_home_deaths(self, age: int, sex: str):
        """Care home deaths interpolated to a single year of age."""
        return self._get_interpolated_value(
            df=self.care_home_deaths_by_age_sex_df,
            age=age,
            sex=sex,
            weight_mapper=self.ch_mapper,
        )

    def get_all_deaths(self, age: int, sex: str):
        """Total deaths interpolated to a single year of age."""
        return self._get_interpolated_value(
            df=self.all_deaths_by_age_sex_df,
            age=age,
            sex=sex,
            weight_mapper=self.all_mapper,
        )

    def get_n_deaths(self, age: int, sex: str, is_care_home: bool = False) -> float:
        """Deaths for the requested population (care home or general)."""
        if is_care_home:
            return self.get_care_home_deaths(age=age, sex=sex)
        else:
            deaths_total = self.get_all_deaths(age=age, sex=sex)
            if self.care_home_deaths_by_age_sex_df is None:
                return deaths_total
            else:
                deaths_care_home = self.get_care_home_deaths(age=age, sex=sex)
                return deaths_total - deaths_care_home

    # ### hospital ###
    def get_all_hospital_deaths(self, age: int, sex: str):
        """Hospital deaths (all populations) interpolated to a single year."""
        return self._get_interpolated_value(
            df=self.all_hospital_deaths_by_age_sex,
            age=age,
            sex=sex,
            weight_mapper=self.all_deaths_mapper,
        )

    def get_icu_hospital_rate(self, age: int, sex: str):
        """Ratio of ICU admissions to hospital admissions."""
        return self.icu_hosp_rate_by_age_sex_df.loc[age, sex]

    def get_deathsicu_deathshosp_rate(self, age: int, sex: str):
        """Ratio of ICU deaths to hospital deaths."""
        return self.deathsicu_deathshosp_rate_by_age_df.loc[age, sex]

    def get_gp_hospital_deaths(self, age: int, sex: str):
        """General-population hospital deaths (all minus care home)."""
        return self.get_all_hospital_deaths(
            age=age, sex=sex
        ) - self.get_care_home_hospital_deaths(age=age, sex=sex)

    def get_care_home_hospital_deaths(self, age: int, sex: str):
        """Care-home hospital deaths interpolated to a single year."""
        return self._get_interpolated_value(
            df=self.hospital_ch_deaths_by_age_sex_df,
            age=age,
            sex=sex,
            weight_mapper=self.ch_deaths_mapper,
        )

    def get_n_hospital_deaths(
        self, age: int, sex: str, is_care_home: bool = False
    ) -> float:
        """Hospital deaths for the requested population."""
        if is_care_home:
            return self.get_care_home_hospital_deaths(age=age, sex=sex)
        else:
            return self.get_gp_hospital_deaths(age=age, sex=sex)

    def get_all_hospital_admissions(self, age: int, sex: str):
        """Hospital admissions (all populations) interpolated to a year."""
        return self._get_interpolated_value(
            df=self.all_hospital_admissions_by_age_sex,
            age=age,
            sex=sex,
            weight_mapper=self.all_deaths_mapper,
        )

    def get_gp_hospital_admissions(self, age: int, sex: str):
        """General-population hospital admissions (all minus care home)."""
        return self.get_all_hospital_admissions(
            age=age, sex=sex
        ) - self.get_care_home_hospital_admissions(age=age, sex=sex)

    def get_care_home_hospital_admissions(self, age: int, sex: str):
        """Care-home hospital admissions interpolated to a single year."""
        return self._get_interpolated_value(
            df=self.hospital_ch_admissions_by_age_sex_df,
            age=age,
            sex=sex,
            weight_mapper=self.ch_deaths_mapper,
        )

    def get_n_hospital_admissions(
        self, age: int, sex: str, is_care_home: bool = False
    ) -> float:
        """Hospital admissions for the requested population."""
        if is_care_home:
            return self.get_care_home_hospital_admissions(age=age, sex=sex)
        else:
            return self.get_gp_hospital_admissions(age=age, sex=sex)

    def get_n_icu_admissions(
        self, age: int, sex: str, is_care_home: bool = False
    ) -> float:
        """ICU admissions, as hospital admissions times the ICU rate."""
        if is_care_home:
            return self.get_care_home_hospital_admissions(
                age=age, sex=sex
            ) * self.get_icu_hospital_rate(age=age, sex=sex)
        else:
            return self.get_gp_hospital_admissions(
                age=age, sex=sex
            ) * self.get_icu_hospital_rate(age=age, sex=sex)

    def get_n_icu_deaths(
        self, age: int, sex: str, is_care_home: bool = False
    ) -> float:
        """ICU deaths, as hospital deaths times the ICU death rate."""
        return self.get_n_hospital_deaths(
            age=age, sex=sex, is_care_home=is_care_home
        ) * self.get_deathsicu_deathshosp_rate(age=age, sex=sex)

    def get_hospital_death_rate(
        self, age: int, sex: str, is_care_home: bool = False
    ) -> float:
        """Hospital deaths per hospital admission."""
        return self.get_n_hospital_deaths(
            age=age, sex=sex, is_care_home=is_care_home
        ) / self.get_n_hospital_admissions(age=age, sex=sex, is_care_home=is_care_home)

    def get_icu_death_rate(
        self, age: int, sex: str, is_care_home: bool = False
    ) -> float:
        """ICU deaths per ICU admission."""
        return self.get_n_icu_deaths(
            age=age, sex=sex, is_care_home=is_care_home
        ) / self.get_n_icu_admissions(age=age, sex=sex, is_care_home=is_care_home)

    # ## home ###
    def get_care_home_home_deaths(self, age: int, sex: str):
        """Care-home deaths that occurred outside hospital."""
        return self.get_n_deaths(
            age=age, sex=sex, is_care_home=True
        ) - self.get_n_hospital_deaths(age=age, sex=sex, is_care_home=True)

    def get_n_home_deaths(self, age: int, sex: str, is_care_home: bool = False):
        """Deaths outside hospital for the requested population."""
        if is_care_home:
            return self.get_care_home_home_deaths(age=age, sex=sex)
        else:
            return self.get_n_deaths(
                age=age, sex=sex, is_care_home=False
            ) - self.get_n_hospital_deaths(age=age, sex=sex, is_care_home=False)

    # ### IFRS ###
    def _get_ifr(
        self,
        function,
        age: Union[int, pd.Interval],
        sex: str,
        is_care_home: bool = False,
    ):
        """
        Rate = function(outcome counts) / estimated cases, summed over an age
        interval and/or over both sexes when sex == "all". Returns 0 when
        either term vanishes; clipped below at 0.
        """
        if isinstance(age, pd.Interval):
            if sex == "all":
                function_values = sum(
                    function(age=agep, sex="male", is_care_home=is_care_home)
                    + function(age=agep, sex="female", is_care_home=is_care_home)
                    for agep in range(age.left, age.right + 1)
                )
                n_cases = sum(
                    self.get_n_cases(age=agep, sex="male", is_care_home=is_care_home)
                    + self.get_n_cases(
                        age=agep, sex="female", is_care_home=is_care_home
                    )
                    for agep in range(age.left, age.right + 1)
                )
            else:
                function_values = sum(
                    function(age=agep, sex=sex, is_care_home=is_care_home)
                    for agep in range(age.left, age.right + 1)
                )
                n_cases = sum(
                    self.get_n_cases(age=agep, sex=sex, is_care_home=is_care_home)
                    for agep in range(age.left, age.right + 1)
                )
        else:
            if sex == "all":
                function_values = function(
                    age=age, sex="male", is_care_home=is_care_home
                ) + function(age=age, sex="female", is_care_home=is_care_home)
                n_cases = self.get_n_cases(
                    age=age, sex="male", is_care_home=is_care_home
                ) + self.get_n_cases(age=age, sex="female", is_care_home=is_care_home)
            else:
                function_values = function(age=age, sex=sex, is_care_home=is_care_home)
                n_cases = self.get_n_cases(age=age, sex=sex, is_care_home=is_care_home)
        if n_cases * function_values == 0:
            return 0
        return max(function_values / n_cases, 0)

    def get_infection_fatality_rate(
        self, age: Union[int, pd.Interval], sex: str, is_care_home: bool = False
    ) -> float:
        """Deaths per infection."""
        return self._get_ifr(
            function=self.get_n_deaths, age=age, sex=sex, is_care_home=is_care_home
        )

    def get_hospital_infection_fatality_rate(
        self, age: Union[int, pd.Interval], sex: str, is_care_home: bool = False
    ) -> float:
        """Hospital deaths per infection."""
        return self._get_ifr(
            function=self.get_n_hospital_deaths,
            age=age,
            sex=sex,
            is_care_home=is_care_home,
        )

    def get_icu_infection_fatality_rate(
        self, age: Union[int, pd.Interval], sex: str, is_care_home: bool = False
    ) -> float:
        """ICU deaths per infection."""
        return self._get_ifr(
            function=self.get_n_icu_deaths, age=age, sex=sex, is_care_home=is_care_home
        )

    def get_hospital_infection_admission_rate(
        self, age: Union[int, pd.Interval], sex: str, is_care_home: bool = False
    ) -> float:
        """Hospital admissions per infection."""
        return self._get_ifr(
            function=self.get_n_hospital_admissions,
            age=age,
            sex=sex,
            is_care_home=is_care_home,
        )

    def get_icu_infection_admission_rate(
        self, age: Union[int, pd.Interval], sex: str, is_care_home: bool = False
    ) -> float:
        """ICU admissions per infection."""
        return self._get_ifr(
            function=self.get_n_icu_admissions,
            age=age,
            sex=sex,
            is_care_home=is_care_home,
        )

    def get_home_infection_fatality_rate(
        self, age: Union[int, pd.Interval], sex: str, is_care_home: bool = False
    ):
        """Out-of-hospital deaths per infection."""
        return self._get_ifr(
            function=self.get_n_home_deaths, age=age, sex=sex, is_care_home=is_care_home
        )

    def get_mild_rate(self, age: Union[int, pd.Interval], sex: str, is_care_home):
        """Mild-symptom rate (mean over an age interval if one is given)."""
        if isinstance(age, pd.Interval):
            return self.mild_rates_by_age_sex_df.loc[age.left : age.right, sex].mean()
        else:
            return self.mild_rates_by_age_sex_df.loc[age, sex]

    def get_asymptomatic_rate(
        self, age: Union[int, pd.Interval], sex: str, is_care_home
    ):
        """Asymptomatic rate (mean over an age interval if one is given)."""
        if isinstance(age, pd.Interval):
            return self.asymptomatic_rates_by_age_sex_df.loc[
                age.left : age.right, sex
            ].mean()
        else:
            # bug fix: this previously read from the mild rates dataframe
            return self.asymptomatic_rates_by_age_sex_df.loc[age, sex]


def get_outputs_df(rates, age_bins):
    """
    Build the combined rates table used as input for HealthIndexGenerator.

    One column per (population, rate name, sex) combination, named
    ``{pop}_{rate}_{sex}`` with pop in {"gp", "ch"}; one row per age bin.
    Values are obtained by calling the corresponding getter on ``rates``.
    """
    rate_getters = {
        "asymptomatic": rates.get_asymptomatic_rate,
        "mild": rates.get_mild_rate,
        "ifr": rates.get_infection_fatality_rate,
        "hospital_ifr": rates.get_hospital_infection_fatality_rate,
        "icu_ifr": rates.get_icu_infection_fatality_rate,
        "hospital": rates.get_hospital_infection_admission_rate,
        "icu": rates.get_icu_infection_admission_rate,
        "home_ifr": rates.get_home_infection_fatality_rate,
    }
    outputs = pd.DataFrame(index=age_bins)
    for pop in ("gp", "ch"):
        is_care_home = pop == "ch"
        for sex in ("male", "female"):
            for rate_name, getter in rate_getters.items():
                column = f"{pop}_{rate_name}_{sex}"
                for age_bin in age_bins:
                    outputs.loc[age_bin, column] = getter(
                        age=age_bin, sex=sex, is_care_home=is_care_home
                    )
    return outputs


import numpy as np
import pandas as pd

from june import paths

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from june.demography.person import Person

# Map from the single-letter sex codes used on Person to the long form used
# in the rates-dataframe column names.
_sex_short_to_long = {"m": "male", "f": "female"}
# Position of each infection outcome in the per-age probability arrays built
# by HealthIndexGenerator; the inverse mapping is used to locate the
# "severe" boundary between mild and severe outcomes.
index_to_maximum_symptoms_tag = {
    0: "asymptomatic",
    1: "mild",
    2: "severe",
    3: "hospitalised",
    4: "intensive_care",
    5: "dead_home",
    6: "dead_hospital",
    7: "dead_icu",
}

# Default location of the outcome-rates table shipped with the package data.
default_rates_file = paths.data_path / "input/health_index/infection_outcome_rates.csv"


def _parse_interval(interval):
    age1, age2 = interval.split(",")
    age1 = int(age1.split("[")[-1])
    age2 = int(age2.split("]")[0])
    return pd.Interval(left=age1, right=age2, closed="both")


class HealthIndexGenerator:
    def __init__(
        self,
        rates_df: pd.DataFrame,
        care_home_min_age: int = 50,
        max_age=99,
        m_exp_baseline=79.4,
        f_exp_baseline=83.1,
        m_exp=79.4,
        f_exp=83.1,
        cutoff_age=16,
    ):
        """
        A Generator to determine the final outcome of an infection.

        Parameters
        ----------
        rates_df
            a dataframe containing all the different outcome rates,
            check the default file for a reference
        care_home_min_age
            the age from which a care home resident follows the health index
            for care homes.
        max_age
            oldest age (inclusive) for which outcome probabilities are
            precomputed.
        m_exp_baseline, f_exp_baseline
            baseline male/female life expectancies used as reference for the
            physiological-age rescaling.
        m_exp, f_exp
            actual male/female life expectancies; when they differ from the
            baselines, ages are rescaled before probability lookup.
        cutoff_age
            age at and below which no physiological-age rescaling is applied.
        """
        self.care_home_min_age = care_home_min_age
        self.rates_df = rates_df
        self.age_bins = self.rates_df.index
        self.probabilities = self._get_probabilities(max_age)
        # Index of the "severe" outcome: the first entry of the probability
        # array that is not a mild outcome; used to split arrays in
        # apply_effective_multiplier.
        self.max_mild_symptom_tag = {
            value: key for key, value in index_to_maximum_symptoms_tag.items()
        }["severe"]

        self.m_exp_baseline = m_exp_baseline
        self.f_exp_baseline = f_exp_baseline
        self.m_exp = m_exp
        self.f_exp = f_exp
        self.cutoff_age = cutoff_age
        # Rescaling is only active when the actual life expectancy differs
        # from the baseline for at least one sex.
        if self.m_exp_baseline == self.m_exp and self.f_exp_baseline == self.f_exp:
            self.use_physiological_age = False
        else:
            self.use_physiological_age = True

    @classmethod
    def from_file(
        cls,
        rates_file: str = default_rates_file,
        care_home_min_age=50,
        m_exp_baseline=79.4,
        f_exp_baseline=83.1,
        m_exp=79.4,
        f_exp=83.1,
        cutoff_age=16,
    ):
        """
        Construct a generator from a CSV of outcome rates whose index holds
        age bins encoded as "[low,high]" strings (parsed by _parse_interval).
        """
        ifrs = pd.read_csv(rates_file, index_col=0)
        ifrs = ifrs.rename(_parse_interval)
        return cls(
            rates_df=ifrs,
            care_home_min_age=care_home_min_age,
            m_exp_baseline=m_exp_baseline,
            f_exp_baseline=f_exp_baseline,
            m_exp=m_exp,
            f_exp=f_exp,
            cutoff_age=cutoff_age,
        )

    def physiological_age(self, person_age, sex):
        """
        Map a chronological age to a "physiological" age by linearly
        rescaling ages above the cutoff so that the actual life expectancy
        maps onto the baseline one; the result is rounded and capped at 99.
        """
        if sex == "f":
            exp_baseline_age = self.f_exp_baseline
            exp_age = self.f_exp
        elif sex == "m":
            exp_baseline_age = self.m_exp_baseline
            exp_age = self.m_exp

        if person_age > self.cutoff_age:
            # Degenerate case: a life expectancy equal to the cutoff would
            # make the slope below divide by zero; map to the oldest age.
            if exp_age == self.cutoff_age:
                return 99
            # Line through (cutoff_age, cutoff_age) and (exp_age, exp_baseline_age).
            m = (exp_baseline_age - self.cutoff_age) / (exp_age - self.cutoff_age)
            c = self.cutoff_age * (1 - m)
            scaled_age = person_age * m + c
        else:
            scaled_age = person_age

        if scaled_age > 99.0:
            scaled_age = 99.0
        return int(round(scaled_age))

    def __call__(self, person: "Person", infection_id: int):
        """
        Return the cumulative probabilities over the 8 possible infection
        outcomes for ``person``, given the id of the infection responsible
        for the symptoms.

        Care-home residents at or above ``care_home_min_age`` use the
        care-home ("ch") rates; everyone else uses the general-population
        ("gp") rates.
        """
        if (
            person.residence is not None
            and person.residence.group.spec == "care_home"
            and person.age >= self.care_home_min_age
        ):
            population = "ch"
        else:
            population = "gp"
        if self.use_physiological_age:
            physiological_age = self.physiological_age(int(person.age), person.sex)
        else:
            physiological_age = int(person.age)
        probabilities = self.probabilities[population][person.sex][physiological_age]
        if infection_id is not None:
            # Rescale the severe tail if this person's immunity modifies the
            # outcome probabilities for this particular infection.
            effective_multiplier = person.immunity.get_effective_multiplier(
                infection_id
            )
            if effective_multiplier != 1.0:
                probabilities = self.apply_effective_multiplier(
                    probabilities, effective_multiplier
                )
        return np.cumsum(probabilities)

    def apply_effective_multiplier(self, probabilities, effective_multiplier):
        """
        Scale the probability of a severe-or-worse outcome by
        ``effective_multiplier`` and redistribute the remaining mass over
        the mild outcomes, preserving their relative proportions.
        """
        modified_probabilities = np.zeros_like(probabilities)
        probability_mild = probabilities[: self.max_mild_symptom_tag].sum()
        # Any probability mass missing from the array is counted as severe.
        probability_severe = probabilities[self.max_mild_symptom_tag :].sum() + (
            1 - probabilities.sum()
        )
        modified_probability_severe = probability_severe * effective_multiplier
        modified_probability_mild = 1.0 - modified_probability_severe
        modified_probabilities[: self.max_mild_symptom_tag] = (
            probabilities[: self.max_mild_symptom_tag]
            * modified_probability_mild
            / probability_mild
        )
        modified_probabilities[self.max_mild_symptom_tag :] = (
            probabilities[self.max_mild_symptom_tag :]
            * modified_probability_severe
            / probability_severe
        )
        return modified_probabilities

    def _set_probability_per_age_bin(self, p, age_bin, sex, population):
        """
        Fill, in place, the 8 outcome probabilities for every age covered by
        ``age_bin``, for the given sex and population ("gp" or "ch"),
        reading the rates from ``self.rates_df``.
        """
        _sex = _sex_short_to_long[sex]
        asymptomatic_rate = self.rates_df.loc[
            age_bin, f"{population}_asymptomatic_{_sex}"
        ]
        mild_rate = self.rates_df.loc[age_bin, f"{population}_mild_{_sex}"]
        hospital_rate = self.rates_df.loc[age_bin, f"{population}_hospital_{_sex}"]
        icu_rate = self.rates_df.loc[age_bin, f"{population}_icu_{_sex}"]
        home_dead_rate = self.rates_df.loc[age_bin, f"{population}_home_ifr_{_sex}"]
        hospital_dead_rate = self.rates_df.loc[
            age_bin, f"{population}_hospital_ifr_{_sex}"
        ]
        icu_dead_rate = self.rates_df.loc[age_bin, f"{population}_icu_ifr_{_sex}"]
        # Severe (but not hospitalised) takes the remaining probability mass,
        # clamped at zero.
        severe_rate = max(
            0, 1 - (hospital_rate + home_dead_rate + asymptomatic_rate + mild_rate)
        )
        # fill each age in bin
        for age in range(age_bin.left, age_bin.right + 1):
            p[population][sex][age][0] = asymptomatic_rate  # recovers as asymptomatic
            p[population][sex][age][1] = mild_rate  # recovers as mild
            p[population][sex][age][2] = severe_rate  # recovers as severe
            # NOTE(review): unlike entries 4 and 6, this difference is not
            # clamped at zero — assumes hospital_rate >= hospital_dead_rate
            # always holds in the input data; confirm.
            p[population][sex][age][3] = (
                hospital_rate - hospital_dead_rate
            )  # recovers in the ward
            p[population][sex][age][4] = max(
                icu_rate - icu_dead_rate, 0
            )  # recovers in the icu
            p[population][sex][age][5] = max(home_dead_rate, 0)  # dies at home
            p[population][sex][age][6] = max(
                hospital_dead_rate - icu_dead_rate, 0
            )  # dies in the ward
            p[population][sex][age][7] = icu_dead_rate
            # renormalise all but death rates (since those are the most certain ones)
            to_keep_sum = p[population][sex][age][5:].sum()
            to_adjust_sum = p[population][sex][age][:5].sum()
            target_adjust_sum = max(1 - to_keep_sum, 0)
            p[population][sex][age][:5] *= target_adjust_sum / to_adjust_sum

    def _get_probabilities(self, max_age=99):
        """
        Build the per-age outcome probability arrays (one row per age, one
        column per outcome) for both sexes and both populations: general
        population ("gp") and care homes ("ch").
        """
        n_outcomes = 8
        probabilities = {
            "ch": {
                "m": np.zeros((max_age + 1, n_outcomes)),
                "f": np.zeros((max_age + 1, n_outcomes)),
            },
            "gp": {
                "m": np.zeros((max_age + 1, n_outcomes)),
                "f": np.zeros((max_age + 1, n_outcomes)),
            },
        }
        for population in ("ch", "gp"):
            for sex in ["m", "f"]:
                # values are constant at each bin
                for age_bin in self.age_bins:
                    self._set_probability_per_age_bin(
                        p=probabilities, age_bin=age_bin, sex=sex, population=population
                    )
        return probabilities


from .data_to_rates import Data2Rates
from .health_index import HealthIndexGenerator


import pandas as pd
import numpy as np

from june import paths

# Default lookup table mapping each area to its super area and region.
default_super_area_to_region_file = (
    paths.data_path / "input/geography/area_super_area_region.csv"
)
# Default table with the number of residents in each super area.
default_residents_per_super_area_file = (
    paths.data_path / "input/demography/residents_per_super_area.csv"
)


def get_super_area_population_weights_by_region(
    super_area_to_region: pd.DataFrame, residents_per_super_area: pd.DataFrame
) -> pd.DataFrame:
    """
    Compute the weight in population that a super area has over its whole
    region, used to convert regional cases to cases by super area by
    population density.

    Returns
    -------
    data frame indexed by super area, with weights and region
    """
    merged = pd.merge(
        residents_per_super_area, super_area_to_region, on="super_area"
    )
    # Total residents of each region, for normalisation.
    region_totals = merged.groupby("region").sum()["n_residents"]
    merged["weights"] = merged.apply(
        lambda row: row.n_residents / region_totals.loc[row.region], axis=1
    )
    return merged[["super_area", "weights"]].set_index("super_area")


def get_super_area_population_weights(
    residents_per_super_area: pd.DataFrame,
) -> pd.DataFrame:
    """
    Compute the weight in population that each super area has over the whole
    population, used to convert national cases to cases by super area by
    population density.

    Parameters
    ----------
    residents_per_super_area
        df with a "super_area" column and an "n_residents" column.

    Returns
    -------
    data frame indexed by super area, with each super area's share of the
    total population.

    Fix: the previous implementation used ``set_index(..., inplace=True)``,
    silently mutating the caller's dataframe; the index is now set on a new
    dataframe instead.
    """
    residents = residents_per_super_area.set_index("super_area")
    return residents / residents["n_residents"].sum()


class CasesDistributor:
    """
    Class to distribute cases to super areas from different
    geographic and demographic granularities.
    """

    def __init__(self, cases_per_super_area):
        """
        Parameters
        ----------
        cases_per_super_area
            df with dates as index (anything ``pd.to_datetime`` accepts) and
            one column of case counts per super area.
        """
        cases_per_super_area.index = pd.to_datetime(cases_per_super_area.index)
        self.cases_per_super_area = cases_per_super_area

    @classmethod
    def from_regional_cases(
        cls,
        cases_per_day_region: pd.DataFrame,
        super_area_to_region: pd.DataFrame,
        residents_per_super_area: pd.DataFrame,
    ):
        """
        Creates cases per super area from specifying the number of cases per region.

        Parameters
        ----------
        cases_per_day_region
            A Pandas df with date as index, regions as columns, and cases as values.
        super_area_to_region
            A df containing two columns ['super_area', 'region']
        residents_per_super_area
            A df with the number of residents per super area (index).
        """
        # Non-mutating set_index: the previous inplace=True silently modified
        # the caller's dataframe.
        residents_per_super_area = residents_per_super_area.set_index("super_area")
        ret = pd.DataFrame(index=cases_per_day_region.index)
        weights_per_super_area = get_super_area_population_weights_by_region(
            super_area_to_region=super_area_to_region,
            residents_per_super_area=residents_per_super_area,
        )
        for region in cases_per_day_region.columns:
            region_cases = cases_per_day_region.loc[:, region]
            region_super_areas = super_area_to_region.loc[
                super_area_to_region.region == region, "super_area"
            ]
            ret.loc[:, region_super_areas] = 0
            # Series.iteritems() was removed in pandas 2.0; items() is the
            # supported equivalent.
            for date, n_cases in region_cases.items():
                weights = weights_per_super_area.loc[
                    region_super_areas
                ].values.flatten()
                # Randomly assign the day's regional cases to super areas,
                # weighted by population share within the region.
                cases_distributed = np.random.choice(
                    region_super_areas, size=n_cases, p=weights, replace=True
                )
                super_areas, cases = np.unique(cases_distributed, return_counts=True)
                ret.loc[date, super_areas] = cases
        return cls(ret)

    @classmethod
    def from_regional_cases_file(
        cls,
        cases_per_day_region_file: str,
        super_area_to_region_file: str = default_super_area_to_region_file,
        residents_per_super_area_file: str = default_residents_per_super_area_file,
    ):
        """Load the three input tables from CSV and delegate to from_regional_cases."""
        cases_per_day_region = pd.read_csv(cases_per_day_region_file, index_col=0)
        super_area_to_region = pd.read_csv(super_area_to_region_file)
        super_area_to_region = super_area_to_region.loc[
            :, ["super_area", "region"]
        ].drop_duplicates()
        residents_per_super_area = pd.read_csv(residents_per_super_area_file)
        return cls.from_regional_cases(
            cases_per_day_region=cases_per_day_region,
            super_area_to_region=super_area_to_region,
            residents_per_super_area=residents_per_super_area,
        )

    @classmethod
    def from_national_cases(
        cls,
        cases_per_day: pd.DataFrame,
        super_area_to_region: pd.DataFrame,
        residents_per_super_area: pd.DataFrame,
    ):
        """
        Creates cases per super area from national daily case counts,
        distributing each day's cases randomly across all super areas,
        weighted by their share of the total population.
        """
        ret = pd.DataFrame(index=cases_per_day.index)
        weights_per_super_area = get_super_area_population_weights(
            residents_per_super_area=residents_per_super_area
        )
        for date, n_cases in cases_per_day.iterrows():
            weights = weights_per_super_area.values.flatten()
            cases_distributed = np.random.choice(
                list(weights_per_super_area.index),
                size=n_cases.values[0],
                p=weights,
                replace=True,
            )
            super_areas, cases = np.unique(cases_distributed, return_counts=True)
            ret.loc[date, super_areas] = cases
        return cls(ret)

    @classmethod
    def from_national_cases_file(
        cls,
        cases_per_day_file,
        super_area_to_region_file: str = default_super_area_to_region_file,
        residents_per_super_area_file: str = default_residents_per_super_area_file,
    ):
        """Load the input tables from CSV and delegate to from_national_cases."""
        cases_per_day = pd.read_csv(cases_per_day_file, index_col=0)
        residents_per_super_area = pd.read_csv(residents_per_super_area_file)
        super_area_to_region = pd.read_csv(super_area_to_region_file)
        super_area_to_region = super_area_to_region.loc[
            :, ["super_area", "region"]
        ].drop_duplicates()

        return cls.from_national_cases(
            cases_per_day=cases_per_day,
            super_area_to_region=super_area_to_region,
            residents_per_super_area=residents_per_super_area,
        )


import pandas as pd
import numpy as np
from random import random
from collections import defaultdict

from .infection_seed import InfectionSeed
from june.epidemiology.infection import InfectionSelector

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from june.world import World


class ClusteredInfectionSeed(InfectionSeed):
    """
    Infection seed that clusters initial infections by household: households
    are sampled with probability proportional to a score measuring how well
    their residents match the target age distribution, and every susceptible
    resident of a selected household is infected.
    """

    def __init__(
        self,
        world: "World",
        infection_selector: InfectionSelector,
        daily_cases_per_capita_per_age_per_region: pd.DataFrame,
        seed_past_infections: bool = True,
        seed_strength=1.0,
        account_secondary_infections=False,
    ):
        # Pure pass-through: all configuration is handled by InfectionSeed.
        super().__init__(
            world=world,
            infection_selector=infection_selector,
            daily_cases_per_capita_per_age_per_region=daily_cases_per_capita_per_age_per_region,
            seed_past_infections=seed_past_infections,
            seed_strength=seed_strength,
            account_secondary_infections=account_secondary_infections,
        )

    def get_total_people_to_infect(self, people, cases_per_capita_per_age):
        """
        Expected number of people to infect among ``people`` given the
        per-capita cases per age; the fractional part of the expectation is
        resolved stochastically so the mean is preserved.
        """
        people_by_age = defaultdict(int)
        for person in people:
            people_by_age[person.age] += 1
        total = sum(
            [
                people_by_age[age] * cases_per_capita_per_age.loc[age]
                for age in people_by_age
            ]
        )
        ret = int(total)
        # Add one extra case with probability equal to the truncated fraction.
        ret += int(random() < (total - ret))
        return ret

    def get_household_score(self, household, age_distribution):
        """
        Score a household by summing the target age-distribution weight of
        each resident, normalised by sqrt(household size); empty households
        score 0.
        """
        if len(household.residents) == 0:
            return 0
        ret = 0
        for resident in household.residents:
            ret += age_distribution.loc[resident.age]
        return ret / np.sqrt(len(household.residents))

    def infect_super_area(
        self, super_area, cases_per_capita_per_age, time, record=None
    ):
        """
        Seed infections in ``super_area`` by repeatedly drawing households
        (weighted by score) and infecting all their susceptible residents
        until the target number of cases is reached. Negative ``time`` means
        the infection started before the simulation and is brought up to
        date.
        """
        infection_id = self.infection_selector.infection_class.infection_id()
        people = super_area.people
        total_to_infect = self.get_total_people_to_infect(
            people=people, cases_per_capita_per_age=cases_per_capita_per_age
        )
        age_distribution = cases_per_capita_per_age / cases_per_capita_per_age.sum()
        households = np.array(super_area.households)
        scores = [self.get_household_score(h, age_distribution) for h in households]

        # Inverse-CDF sampling: a household is drawn with probability
        # proportional to its score.
        cum_scores = np.cumsum(scores)
        seeded_households = set()
        while total_to_infect > 0:
            num = random() * cum_scores[-1]
            idx = np.searchsorted(cum_scores, num)
            household = households[idx]
            # NOTE(review): if every household ends up in seeded_households
            # while total_to_infect is still positive, this loop never
            # terminates — confirm callers guarantee enough households.
            if household.id in seeded_households:
                continue
            for person in household.residents:
                if person.immunity.get_susceptibility(infection_id) > 0:
                    self.infect_person(person=person, time=time, record=record)
                    if time < 0:
                        self.bring_infection_up_to_date(
                            person=person, time_from_infection=-time, record=record
                        )
                    total_to_infect -= 1
                    if total_to_infect < 1:
                        return
                    # Only marked once at least one resident was infected.
                    seeded_households.add(household.id)


import numpy as np
import pandas as pd
import random
import datetime
import logging
from collections import defaultdict
from typing import List, Optional

from june.records import Record
from june.epidemiology.infection import InfectionSelector

from .infection_seed import InfectionSeed
from june.epidemiology.infection import InfectionSelector

from june.world import World
from june.geography import Region, Area, SuperArea

seed_logger = logging.getLogger("seed")


class ExactNumInfectionSeed(InfectionSeed):
    """
    Infection seed that interprets the input dataframe as exact numbers of
    cases per age bin (rather than per-capita rates) and can seed the whole
    world, regions, super areas or areas depending on the column names.
    """

    def __init__(
        self,
        world: "World",
        infection_selector: InfectionSelector,
        daily_cases_per_capita_per_age_per_region: pd.DataFrame,
        seed_past_infections: bool = True,
        seed_strength=1.0,
        # account_secondary_infections=False,
    ):
        super().__init__(
            world=world,
            infection_selector=infection_selector,
            daily_cases_per_capita_per_age_per_region=daily_cases_per_capita_per_age_per_region,
            seed_past_infections=seed_past_infections,
            seed_strength=seed_strength,
            account_secondary_infections=False,
        )
        # Use age bins in exact-number mode: overwrite the base-class parsed
        # dataframe, keeping bins as-is instead of expanding individual ages.
        self.daily_cases_per_capita_per_age_per_region = (
            daily_cases_per_capita_per_age_per_region * seed_strength
        )

        # Geography collections (world.regions / super_areas / areas) that
        # the seeding columns refer to; iterated in infect_super_areas.
        self.iter_type_set = set()
        if "all" not in daily_cases_per_capita_per_age_per_region.columns:
            # generate list of existing regions, superareas, areas
            regions = [region.name for region in self.world.regions]
            super_areas = [super_area.name for super_area in self.world.super_areas]
            areas = [area.name for area in self.world.areas]

            # check that the seeding locations exist in the current world
            for loc_name in self.daily_cases_per_capita_per_age_per_region.columns:
                if loc_name in regions:
                    self.iter_type_set.add(self.world.regions)
                elif loc_name in super_areas:
                    self.iter_type_set.add(self.world.super_areas)
                elif loc_name in areas:
                    self.iter_type_set.add(self.world.areas)
                else:
                    raise TypeError(
                        "invalid seeding location (column) name: " + loc_name
                    )

    def infect_super_area(
        self, super_area, cases_per_capita_per_age, time, record=None
    ):
        """
        Infect exactly the requested number of people per age bin in
        ``super_area`` (despite the argument name, values are absolute
        counts). People are visited in random order until every bin quota
        is filled.
        """
        people = super_area.people
        infection_id = self.infection_selector.infection_class.infection_id()

        # Parse "agemin-agemax" bin labels; bins are treated as half-open
        # ranges [agemin, agemax) below.
        age_ranges = []
        for age in cases_per_capita_per_age.index:
            agemin, agemax = age.split("-")
            age_ranges.append([int(agemin), int(agemax)])

        N_seeded = np.zeros(len(age_ranges), dtype="int")
        # NOTE(review): reseeding from OS entropy here defeats any global
        # seeding done for reproducibility — confirm this is intended.
        random.seed()
        for person in random.sample(list(people), len(people)):
            in_seed_age_range = False
            for j in range(len(age_ranges)):
                # NOTE(review): cases_per_capita_per_age[j] is positional
                # indexing on a Series with string labels — deprecated in
                # pandas 2.x; consider .iloc[j].
                if (
                    person.age >= age_ranges[j][0]
                    and person.age < age_ranges[j][1]
                    and N_seeded[j] < cases_per_capita_per_age[j]
                ):
                    in_seed_age_range = True
                    break
            if (
                in_seed_age_range
                and person.immunity.get_susceptibility(infection_id) > 0
            ):
                self.infect_person(person=person, time=time, record=record)
                self.current_seeded_cases[person.region.name] += 1
                if time < 0:
                    self.bring_infection_up_to_date(
                        person=person, time_from_infection=-time, record=record
                    )

                # j is the bin index found by the break above.
                N_seeded[j] += 1
                if np.all(N_seeded == np.array(cases_per_capita_per_age)):
                    break

    def infect_super_areas(
        self,
        cases_per_capita_per_age_per_region: pd.DataFrame,
        time: float,
        date: datetime.datetime,
        record: Optional[Record] = None,
    ):
        """
        Infect world/region/super_area/area with number of cases given by data frame
        Not only super area, but still keep the old function name for now.

        Parameters
        ----------
        cases_per_capita_per_age_per_region:
            data frame containing the number of cases per world/region/super_area/area
        time:
            Time where infections start (could be negative if they started before the simulation)
        date:
            date of this seeding step
        record:
            optional Record used to log the seeded infections
        """
        if "all" in cases_per_capita_per_age_per_region.columns:
            # A single "all" column seeds the whole world at once.
            self.infect_super_area(
                super_area=self.world,
                cases_per_capita_per_age=cases_per_capita_per_age_per_region["all"],
                time=time,
                record=record,
            )
        else:
            num_locations_to_seed = len(cases_per_capita_per_age_per_region.columns)
            for geo_type in self.iter_type_set:
                for this_loc in geo_type:
                    try:
                        cases_per_capita_per_age = cases_per_capita_per_age_per_region[
                            this_loc.name
                        ]
                    except KeyError:
                        # This location has no column in the seeding table.
                        continue

                    """ 
                    ### 
                    # TO DO: rewite self._adjust_seed_accounting_secondary_infections to work for superarea/area
                    ###
                    if self._need_to_seed_accounting_secondary_infections(date=date):
                        cases_per_capita_per_age = (
                            self._adjust_seed_accounting_secondary_infections(
                                cases_per_capita_per_age=cases_per_capita_per_age,
                                region=this_loc,
                                date=date,
                                time=time,
                            )
                        )
                    """
                    self.infect_super_area(
                        super_area=this_loc,
                        cases_per_capita_per_age=cases_per_capita_per_age,
                        time=time,
                        record=record,
                    )
                    num_locations_to_seed -= 1

            # check if all columns are seeded
            assert (
                num_locations_to_seed < 1
            ), "something wrong in location (column) name !!!"


class ExactNumClusteredInfectionSeed(ExactNumInfectionSeed):
    """
    Exact-number infection seed that clusters cases by household: households
    are drawn with probability proportional to how well their residents match
    the target age distribution, and all susceptible residents of a drawn
    household are infected.
    """

    def __init__(
        self,
        world: "World",
        infection_selector: InfectionSelector,
        daily_cases_per_capita_per_age_per_region: pd.DataFrame,
        seed_past_infections: bool = True,
        seed_strength=1.0,
        # account_secondary_infections=False,
    ):
        # Pure pass-through to ExactNumInfectionSeed.
        super().__init__(
            world=world,
            infection_selector=infection_selector,
            daily_cases_per_capita_per_age_per_region=daily_cases_per_capita_per_age_per_region,
            seed_past_infections=seed_past_infections,
            seed_strength=seed_strength,
            # account_secondary_infections=account_secondary_infections,
        )

    def get_household_score(self, household, age_distribution):
        """
        Score a household by the total age-distribution weight of its
        residents (bins are "agemin-agemax" labels, treated as half-open
        ranges), normalised by sqrt(household size); empty households
        score 0.
        """
        if len(household.residents) == 0:
            return 0
        age_ranges = []
        for age in age_distribution.index:
            agemin, agemax = age.split("-")
            age_ranges.append([int(agemin), int(agemax)])
        ret = 0
        for resident in household.residents:
            for ii, age_bin in enumerate(age_ranges):
                if resident.age >= age_bin[0] and resident.age < age_bin[1]:
                    # NOTE(review): positional Series indexing — deprecated
                    # in pandas 2.x; consider .iloc[ii].
                    ret += age_distribution[ii]
                    break
        return ret / np.sqrt(len(household.residents))

    def infect_super_area(
        self, super_area, cases_per_capita_per_age, time, record=None
    ):
        """
        Seed the exact total number of cases (sum over age bins) in the
        given location — a World, Region, SuperArea or Area — by sampling
        households weighted by score and infecting all their susceptible
        residents.
        """
        # Collect every household of the target location, whatever its type.
        households = []
        if isinstance(super_area, World):
            for r in super_area.regions:
                for sa in r.super_areas:
                    for area in sa.areas:
                        households += area.households
        elif isinstance(super_area, Region):
            for sa in super_area.super_areas:
                for area in sa.areas:
                    households += area.households
        elif isinstance(super_area, SuperArea):
            for area in super_area.areas:
                households += area.households
        elif isinstance(super_area, Area):
            households += super_area.households
        else:
            raise TypeError(
                "invalid seeding location type: " + type(super_area).__name__
            )

        age_distribution = cases_per_capita_per_age / cases_per_capita_per_age.sum()
        scores = [self.get_household_score(h, age_distribution) for h in households]
        cum_scores = np.cumsum(scores)

        infection_id = self.infection_selector.infection_class.infection_id()
        # Despite the argument name, values here are absolute case counts.
        total_to_infect = cases_per_capita_per_age.sum()

        seeded_households = set()
        while total_to_infect > 0:
            # Inverse-CDF sampling of a household proportional to its score.
            num = random.random() * cum_scores[-1]
            idx = np.searchsorted(cum_scores, num)
            household = households[idx]
            # NOTE(review): if all households get seeded while cases remain,
            # this loop never terminates — confirm callers guarantee
            # sufficient capacity.
            if household.id in seeded_households:
                continue
            for person in household.residents:
                if person.immunity.get_susceptibility(infection_id) > 0:
                    self.infect_person(person=person, time=time, record=record)
                    if time < 0:
                        self.bring_infection_up_to_date(
                            person=person,
                            time_from_infection=-time,
                            record=record,
                        )
                    total_to_infect -= 1
                    if total_to_infect < 1:
                        return
                    seeded_households.add(household.id)


import numpy as np
import pandas as pd
from random import random
import datetime
import logging
from collections import defaultdict
from typing import List, Optional

from june.records import Record
from june.epidemiology.infection import InfectionSelector
from june.epidemiology.epidemiology import Epidemiology
from june.utils import parse_age_probabilities

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from june.world import World

seed_logger = logging.getLogger("seed")


class InfectionSeed:
    """
    The infection seed takes a dataframe of cases to seed per capita, per age, and per region.
    There are multiple ways to construct the dataframe, from deaths, tests, etc. Each infection seed
    is associated to one infection selector, so if we run multiple infection types, there could be multiple infection
    seeds for each infection type.
    """

    def __init__(
        self,
        world: "World",
        infection_selector: InfectionSelector,
        daily_cases_per_capita_per_age_per_region: pd.DataFrame,
        seed_past_infections: bool = True,
        seed_strength=1.0,
        account_secondary_infections=False,
    ):
        """
        Class that generates the seed for the infection.

        Parameters
        ----------
        world:
            world to infect
        infection_selector:
            selector to generate infections
        daily_cases_per_capita_per_region:
            Double indexed dataframe. First index: date, second index: age in brackets "0-100",
            columns: region names, use "all" as placeholder for whole England.
            Example:
                date,age,North East,London
                2020-07-01,0-100,0.05,0.1
        seed_past_infections:
            whether to seed infections that started past the initial simulation point.
        seed_strength:
            multiplicative factor applied to every input case rate.
        account_secondary_infections:
            if True, today's seed is reduced by the secondary infections produced
            since the previous seeding, to avoid double counting cases.
        """
        self.world = world
        self.infection_selector = infection_selector
        # Expand bracketed age bins (e.g. "0-100") into one row per single year
        # of age, already rescaled by seed_strength.
        self.daily_cases_per_capita_per_age_per_region = self._parse_input_dataframe(
            df=daily_cases_per_capita_per_age_per_region, seed_strength=seed_strength
        )
        # First and last dates for which seeding data is available.
        self.min_date = (
            self.daily_cases_per_capita_per_age_per_region.index.get_level_values(
                "date"
            ).min()
        )
        self.max_date = (
            self.daily_cases_per_capita_per_age_per_region.index.get_level_values(
                "date"
            ).max()
        )
        # "YYYY-MM-DD" strings of the dates already seeded, so a date is never seeded twice.
        self.dates_seeded = set()
        # Trivially "done" when seeding of past infections is disabled.
        self.past_infections_seeded = not (seed_past_infections)
        self.seed_past_infections = seed_past_infections
        self.seed_strength = seed_strength
        self.account_secondary_infections = account_secondary_infections
        # Cases seeded per region at the previous / current seeding step; used to
        # estimate secondary infections when account_secondary_infections is on.
        self.last_seeded_cases = defaultdict(int)
        self.current_seeded_cases = defaultdict(int)

    def _parse_input_dataframe(self, df, seed_strength=1.0):
        """
        Parses ages by expanding the intervals.

        Returns a dataframe indexed by (date, age) with one row per single year of
        age in [0, 100), all values multiplied by ``seed_strength``.
        """
        multi_index = pd.MultiIndex.from_product(
            [df.index.get_level_values("date").unique(), range(0, 100)],
            names=["date", "age"],
        )
        ret = pd.DataFrame(index=multi_index, columns=df.columns, dtype=float)
        # NOTE(review): get_level_values yields one label per (date, age) row, so
        # each date is visited once per age bracket; redundant but harmless, since
        # the same values are simply rewritten.
        for date in df.index.get_level_values("date"):
            for region in df.loc[date].columns:
                cases_per_age = parse_age_probabilities(
                    df.loc[date, region].to_dict(), fill_value=0.0
                )
                ret.loc[date, region] = np.array(cases_per_age)
        ret *= seed_strength
        return ret

    @classmethod
    def from_global_age_profile(
        cls,
        world: "World",
        infection_selector: InfectionSelector,
        daily_cases_per_region: pd.DataFrame,
        seed_past_infections: bool,
        seed_strength: float = 1.0,
        age_profile: Optional[dict] = None,
        account_secondary_infections=False,
    ):
        """
        Builds a seed from a per-region time series, weighting all regions with a
        single (global) age profile.

        seed_strength:
            float that controls the strength of the seed
        age_profile:
            dictionary with weight on age groups. Example:
            age_profile = {'0-20': 0., '21-50':1, '51-100':0.}
            would only infect people aged between 21 and 50
        """
        if age_profile is None:
            age_profile = {"0-100": 1.0}
        multi_index = pd.MultiIndex.from_product(
            [daily_cases_per_region.index.values, age_profile.keys()],
            names=["date", "age"],
        )
        df = pd.DataFrame(
            index=multi_index, columns=daily_cases_per_region.columns, dtype=float
        )
        # Scale each region's daily cases by the weight of every age bracket.
        for region in daily_cases_per_region.columns:
            for age_key, age_value in age_profile.items():
                df.loc[(daily_cases_per_region.index, age_key), region] = (
                    age_value * daily_cases_per_region[region].values
                )
        return cls(
            world=world,
            infection_selector=infection_selector,
            daily_cases_per_capita_per_age_per_region=df,
            seed_past_infections=seed_past_infections,
            seed_strength=seed_strength,
            account_secondary_infections=account_secondary_infections,
        )

    @classmethod
    def from_uniform_cases(
        cls,
        world: "World",
        infection_selector: InfectionSelector,
        cases_per_capita: float,
        date: str,
        seed_past_infections,
        seed_strength=1.0,
        account_secondary_infections=False,
    ):
        """
        Builds a seed with the same number of cases per capita for all ages and
        all regions ("all"), on a single ``date``.
        """
        date = pd.to_datetime(date)
        mi = pd.MultiIndex.from_product([[date], ["0-100"]], names=["date", "age"])
        df = pd.DataFrame(index=mi, columns=["all"])
        df[:] = cases_per_capita
        return cls(
            world=world,
            infection_selector=infection_selector,
            daily_cases_per_capita_per_age_per_region=df,
            seed_past_infections=seed_past_infections,
            seed_strength=seed_strength,
            account_secondary_infections=account_secondary_infections,
        )

    def infect_person(self, person, time, record):
        """
        Infects ``person`` at ``time`` and, if ``record`` is given, logs the event
        with location "infection_seed" (the person appears as their own infector).
        """
        self.infection_selector.infect_person_at_time(person=person, time=time)
        if record:
            record.accumulate(
                table_name="infections",
                location_spec="infection_seed",
                region_name=person.super_area.region.name,
                location_id=0,
                infected_ids=[person.id],
                infector_ids=[person.id],
                infection_ids=[person.infection.infection_id()],
            )

    def infect_super_area(
        self, super_area, cases_per_capita_per_age, time, record=None
    ):
        """
        Randomly infects susceptible people of ``super_area`` so that the expected
        number of cases per age matches ``cases_per_capita_per_age`` (a series
        indexed by single year of age).
        """
        people = super_area.people
        infection_id = self.infection_selector.infection_class.infection_id()
        n_people_by_age = defaultdict(int)
        susceptible_people_by_age = defaultdict(list)
        for person in people:
            n_people_by_age[person.age] += 1
            if person.immunity.get_susceptibility(infection_id) > 0:
                susceptible_people_by_age[person.age].append(person)
        for age, susceptible in susceptible_people_by_age.items():
            # Need to rescale to number of susceptible people in the simulation.
            rescaling = n_people_by_age[age] / len(susceptible_people_by_age[age])
            for person in susceptible:
                prob = cases_per_capita_per_age.loc[age] * rescaling
                if random() < prob:
                    self.infect_person(person=person, time=time, record=record)
                    self.current_seeded_cases[super_area.region.name] += 1
                    # Negative time means the infection started before the
                    # simulation; fast-forward its state to the present.
                    if time < 0:
                        self.bring_infection_up_to_date(
                            person=person, time_from_infection=-time, record=record
                        )

    def bring_infection_up_to_date(self, person, time_from_infection, record):
        """
        Fast-forwards an infection seeded in the past: updates the transmission
        probability and symptom stage to ``time_from_infection`` days after
        infection, and buries or recovers the person if their symptom trajectory
        has already ended.
        """
        # Update transmission probability
        person.infection.transmission.update_infection_probability(
            time_from_infection=time_from_infection
        )
        # Need to update trajectories to current stage
        symptoms = person.symptoms
        # Advance while the next stage's onset time has already passed; the break
        # guards against indexing past the final stage.
        while time_from_infection > symptoms.trajectory[symptoms.stage + 1][0]:
            symptoms.stage += 1
            symptoms.tag = symptoms.trajectory[symptoms.stage][1]
            if symptoms.stage == len(symptoms.trajectory) - 1:
                break
        # Need to check if the person has already recovered or died
        if "dead" in symptoms.tag.name:
            Epidemiology.bury_the_dead(world=self.world, person=person, record=record)
        elif "recovered" == symptoms.tag.name:
            Epidemiology.recover(person=person, record=record)

    def infect_super_areas(
        self,
        cases_per_capita_per_age_per_region: pd.DataFrame,
        time: float,
        date: datetime.datetime,
        record: Optional[Record] = None,
    ):
        """
        Infect super areas with number of cases given by data frame

        Parameters
        ----------
        n_cases_per_super_area:
            data frame containing the number of cases per super area
        time:
            Time where infections start (could be negative if they started before the simulation)
        """
        for region in self.world.regions:
            # Check if secondary infections already provide seeding.
            # "all" acts as a placeholder column for the whole country.
            if "all" in cases_per_capita_per_age_per_region.columns:
                cases_per_capita_per_age = cases_per_capita_per_age_per_region["all"]
            else:
                cases_per_capita_per_age = cases_per_capita_per_age_per_region[
                    region.name
                ]
            if self._need_to_seed_accounting_secondary_infections(date=date):
                cases_per_capita_per_age = (
                    self._adjust_seed_accounting_secondary_infections(
                        cases_per_capita_per_age=cases_per_capita_per_age,
                        region=region,
                        date=date,
                        time=time,
                    )
                )
            for super_area in region.super_areas:
                self.infect_super_area(
                    super_area=super_area,
                    cases_per_capita_per_age=cases_per_capita_per_age,
                    time=time,
                    record=record,
                )

    def unleash_virus_per_day(
        self, date: datetime, time, record: Optional[Record] = None
    ):
        """
        Infect super areas at a given ```date```

        Parameters
        ----------
        date:
            current date
        time:
            time since start of the simulation
        record:
            Record object to record infections
        """
        # Seed infections dated before the simulation start, exactly once.
        if (not self.past_infections_seeded) and self.seed_past_infections:
            self._seed_past_infections(date=date, time=time, record=record)
            self.past_infections_seeded = True
        is_seeding_date = self.max_date >= date >= self.min_date
        date_str = date.date().strftime("%Y-%m-%d")
        # NOTE(review): the membership test relies on the pandas DatetimeIndex
        # accepting date strings in __contains__ — confirm with the pandas version in use.
        not_yet_seeded_date = (
            date_str not in self.dates_seeded
            and date_str
            in self.daily_cases_per_capita_per_age_per_region.index.get_level_values(
                "date"
            )
        )
        if is_seeding_date and not_yet_seeded_date:
            seed_logger.info(
                f"Seeding {self.infection_selector.infection_class.__name__} infections at date {date.date()}"
            )
            cases_per_capita_per_age_per_region = (
                self.daily_cases_per_capita_per_age_per_region.loc[date]
            )
            self.infect_super_areas(
                cases_per_capita_per_age_per_region=cases_per_capita_per_age_per_region,
                time=time,
                record=record,
                date=date,
            )
            self.dates_seeded.add(date_str)
            # Roll the per-region counters over for secondary-infection accounting.
            self.last_seeded_cases = self.current_seeded_cases.copy()
            self.current_seeded_cases = defaultdict(int)

    def _seed_past_infections(self, date, time, record):
        """
        Seeds every date of the input dataframe earlier than ``date``, using
        negative times (days before the simulation start) so infections are
        brought up to date.
        """
        past_dates = []
        for (
            past_date
        ) in self.daily_cases_per_capita_per_age_per_region.index.get_level_values(
            "date"
        ).unique():
            if past_date.date() < date.date():
                past_dates.append(past_date)
        for past_date in past_dates:
            seed_logger.info(f"Seeding past infections at {past_date.date()}")
            # Negative number of days between the past date and today.
            past_time = (past_date.date() - date.date()).days
            past_date_str = past_date.date().strftime("%Y-%m-%d")
            self.dates_seeded.add(past_date_str)
            self.infect_super_areas(
                cases_per_capita_per_age_per_region=self.daily_cases_per_capita_per_age_per_region.loc[
                    past_date
                ],
                time=past_time,
                record=record,
                date=past_date,
            )
            self.last_seeded_cases = self.current_seeded_cases.copy()
            self.current_seeded_cases = defaultdict(int)
            if record:
                # record past infections and deaths.
                record.time_step(timestamp=past_date)

    def _need_to_seed_accounting_secondary_infections(self, date):
        """
        Returns True when secondary-infection accounting is enabled and the
        previous day is present in the seeding dataframe (otherwise there is no
        previous seed to compare against).
        """
        if self.account_secondary_infections:
            yesterday = date - datetime.timedelta(days=1)
            if yesterday not in self.daily_cases_per_capita_per_age_per_region.index:
                return False
            return True
        return False

    def _adjust_seed_accounting_secondary_infections(
        self, cases_per_capita_per_age, region, date, time
    ):
        """
        Rescales today's per-capita cases for ``region`` so that the number of
        cases to seed is reduced by the secondary infections generated since the
        previous seeding (infections observed minus infections we seeded).
        """
        people_by_age = defaultdict(int)
        for person in region.people:
            people_by_age[person.age] += 1
        yesterday_seeded_cases = self.last_seeded_cases[region.name]
        today_df = self.daily_cases_per_capita_per_age_per_region.loc[date]
        today_seeded_cases = sum(
            [
                today_df.loc[age, region.name] * people_by_age[age]
                for age in people_by_age
            ]
        )
        yesterday_total_cases = len(
            [
                p
                for p in region.people
                if p.infected
                and (time - p.infection.start_time)
                <= 1  # infection starting time less than one day ago
                and p.infection.__class__.__name__
                == self.infection_selector.infection_class.__name__
            ]
        )
        secondary_infs = yesterday_total_cases - yesterday_seeded_cases
        toseed = max(0, today_seeded_cases - secondary_infs)
        # NOTE(review): if today's input rates are all zero, ``previous`` is 0 and
        # this division raises — presumably such dates never reach here; verify.
        previous = sum(
            [
                cases_per_capita_per_age.loc[age] * people_by_age[age]
                for age in people_by_age
            ]
        )
        cases_per_capita_per_age = cases_per_capita_per_age * toseed / previous
        return cases_per_capita_per_age


class InfectionSeeds:
    """
    Thin container around several infection seeds that applies them one after
    another, in the order given.
    """

    def __init__(self, infection_seeds: List[InfectionSeed]):
        self.infection_seeds = infection_seeds

    def unleash_virus_per_day(
        self, date: datetime, time, record: Optional[Record] = None
    ):
        """Delegate the daily seeding to every contained seed, in order."""
        for infection_seed in self.infection_seeds:
            infection_seed.unleash_virus_per_day(date=date, time=time, record=record)

    def __iter__(self):
        yield from self.infection_seeds

    def __getitem__(self, item):
        return self.infection_seeds[item]


import pandas as pd
import yaml
import numpy as np
from datetime import timedelta
from typing import List, Optional, Tuple
from collections import defaultdict, Counter
from scipy.ndimage import gaussian_filter1d

from june import paths
from june.demography import Person
from june.epidemiology.infection.symptom_tag import SymptomTag
from june.epidemiology.infection.trajectory_maker import TrajectoryMaker

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from june.epidemiology.infection.health_index.health_index import (
        HealthIndexGenerator,
    )
    from june.epidemiology.infection.trajectory_maker import Stage

# Default data/config locations consumed by Observed2Cases.from_file.
# Symptom-trajectory timings used to estimate time from infection to death.
default_trajectories_path = (
    paths.configs_path / "defaults/epidemiology/infection/symptoms/trajectories.yaml"
)
# Mapping between areas, super areas and regions.
default_area_super_region_path = (
    paths.data_path / "input/geography/area_super_area_region.csv"
)
# Observed hospital deaths per region, per date.
default_observed_deaths_path = (
    paths.data_path / "input/infection_seed/hospital_deaths_per_region_per_date.csv"
)
# Number of people per single year of age, per area.
default_age_per_area_path = (
    paths.data_path / "input/demography/age_structure_single_year.csv"
)
# Fraction of females per age bin, per area.
default_female_fraction_per_area_path = (
    paths.data_path / "input/demography/female_ratios_per_age_bin.csv"
)


class Observed2Cases:
    def __init__(
        self,
        age_per_area_df: pd.DataFrame,
        female_fraction_per_area_df: pd.DataFrame,
        regional_infections_per_hundred_thousand=100,
        health_index_generator: "HealthIndexGenerator" = None,
        symptoms_trajectories: Optional["TrajectoryMaker"] = None,
        n_observed_deaths: Optional[pd.DataFrame] = None,
        area_super_region_df: Optional[pd.DataFrame] = None,
        smoothing=False,
    ):
        """
        Class to convert observed deaths over time into predicted number of latent cases
        over time, use for the seed of the infection.
        It reads population data, to compute average death rates for a particular region,
        timings from config file estimate the median time it takes for someone infected to
        die in hospital, and the health index to obtain the death rate as a function of
        age and sex.

        Parameters
        ----------
        age_per_area_df:
            data frame with the age distribution per area, to compute the weighted death rate
        female_fraction_per_area_df:
            data frame with the fraction of females per area as a function of age to compute
            the weighted death rate
        health_index_generator:
            generator of the health index to compute death_rate(age,sex)
        symptoms_trajectories:
            used to read the trajectory config file and compute the
            median time it takes to die in hospital
        n_observed_deaths:
            time series with the number of observed deaths per region
        area_super_region_df:
            df with area, super_area, region mapping
        smoothing:
            whether to smooth the observed deaths time series before computing
            the expected number of cases (therefore the estimates becomes less
            dependent on spikes in the data)
        """
        self.regional_infections_per_hundred_thousand = (
            regional_infections_per_hundred_thousand
        )
        self.area_super_region_df = area_super_region_df
        self.age_per_area_df = age_per_area_df
        (
            self.females_per_age_region_df,
            self.males_per_age_region_df,
        ) = self.aggregate_age_sex_dfs_by_region(
            age_per_area_df=age_per_area_df,
            female_fraction_per_area_df=female_fraction_per_area_df,
        )
        self.symptoms_trajectories = symptoms_trajectories
        self.health_index_generator = health_index_generator
        self.regions = self.area_super_region_df["region"].unique()
        # TODO: this are particularities of England that should not be here.
        if (
            n_observed_deaths is not None
            and "East Of England" in n_observed_deaths.columns
        ):
            n_observed_deaths = n_observed_deaths.rename(
                columns={"East Of England": "East of England"}
            )
        if smoothing:
            n_observed_deaths = self._smooth_time_series(n_observed_deaths)
        self.n_observed_deaths = n_observed_deaths

    @classmethod
    def from_file(
        cls,
        health_index_generator,
        regional_infections_per_hundred_thousand=100,
        age_per_area_path: str = default_age_per_area_path,
        female_fraction_per_area_path: str = default_female_fraction_per_area_path,
        trajectories_path: str = default_trajectories_path,
        observed_deaths_path: str = default_observed_deaths_path,
        area_super_region_path: str = default_area_super_region_path,
        smoothing=False,
    ) -> "Observed2Cases":
        """
        Creates class from paths to data

        Parameters
        ----------
        health_index_generator:
            generator of the health index to compute death_rate(age,sex)
        age_per_area_path:
            path to data with number of people of a given age by area
        female_fraction_per_area_df:
            path to data with fraction of people that are female by area and age bin
        trajectories_path:
            path to config file with possible symptoms trajectories and their timings
        observed_deaths_path:
            path to time series of observed deaths over time
        area_super_region_path:
            path to data on area, super_area, region mapping
        smoothing:
            whether to smooth the observed deaths time series before computing
            the expected number of cases (therefore the estimates becomes less
            dependent on spikes in the data)

        Returns
        -------
        Instance of Observed2Cases
        """
        age_per_area_df = pd.read_csv(age_per_area_path, index_col=0)
        female_fraction_per_area_df = pd.read_csv(
            female_fraction_per_area_path, index_col=0
        )
        with open(trajectories_path) as f:
            symptoms_trajectories = yaml.safe_load(f)["trajectories"]
        symptoms_trajectories = [
            TrajectoryMaker.from_dict(trajectory)
            for trajectory in symptoms_trajectories
        ]
        n_observed_deaths = pd.read_csv(observed_deaths_path, index_col=0)
        n_observed_deaths.index = pd.to_datetime(n_observed_deaths.index)
        area_super_region_df = pd.read_csv(area_super_region_path, index_col=0)
        # Combine regions as in deaths dataset
        # TODO: do this outside here for generality
        area_super_region_df = area_super_region_df.replace(
            {
                "region": {
                    "West Midlands": "Midlands",
                    "East Midlands": "Midlands",
                    "North East": "North East And Yorkshire",
                    "Yorkshire and The Humber": "North East And Yorkshire",
                }
            }
        )
        return cls(
            regional_infections_per_hundred_thousand=regional_infections_per_hundred_thousand,
            age_per_area_df=age_per_area_df,
            female_fraction_per_area_df=female_fraction_per_area_df,
            health_index_generator=health_index_generator,
            symptoms_trajectories=symptoms_trajectories,
            n_observed_deaths=n_observed_deaths,
            area_super_region_df=area_super_region_df,
            smoothing=smoothing,
        )

    def aggregate_areas_by_region(self, df_per_area: pd.DataFrame) -> pd.DataFrame:
        """
        Aggregates an area dataframe into a region dataframe

        Parameters
        ----------
        df_per_area:
            data frame indexed by area

        Returns
        -------
        """
        return (
            pd.merge(
                df_per_area,
                self.area_super_region_df.drop(columns="super_area"),
                left_index=True,
                right_index=True,
            )
            .groupby("region")
            .sum()
        )

    def aggregate_age_sex_dfs_by_region(
        self, age_per_area_df: pd.DataFrame, female_fraction_per_area_df: pd.DataFrame
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """
        Combines the age per area dataframe and female fraction per area to
        create two data frames with numbers of females by age per region, and
        numbers of males by age per region

        Parameters
        ----------
        age_per_area_df:
            data frame with the number of people with a certain age per area
        female_fraction_per_area_df:
            fraction of those that are females per area and age bin

        Returns
        -------
        females_per_age_region_df:
            number of females as a function of age per region
        males_per_age_region_df:
            number of males as a function of age per region
        """
        sex_bins = list(map(int, female_fraction_per_area_df.columns))
        females_per_age_area_df = age_per_area_df.apply(
            lambda x: x
            * female_fraction_per_area_df[
                female_fraction_per_area_df.columns[
                    np.digitize(int(x.name), bins=sex_bins) - 1
                ]
            ]
        ).astype("int")
        males_per_age_area_df = age_per_area_df - females_per_age_area_df
        females_per_age_region_df = self.aggregate_areas_by_region(
            females_per_age_area_df
        )
        males_per_age_region_df = self.aggregate_areas_by_region(males_per_age_area_df)
        return females_per_age_region_df, males_per_age_region_df

    def get_symptoms_rates_per_age_sex(
        self,
    ) -> dict:
        """
        Computes the rates of ending up with certain SymptomTag for all
        ages and sex.

        Returns
        -------
        dictionary with rates of symptoms (fate) as a function of age and sex
        """
        symptoms_rates_dict = {"m": defaultdict(int), "f": defaultdict(int)}
        for sex in ("m", "f"):
            for age in np.arange(100):
                symptoms_rates_dict[sex][age] = np.diff(
                    self.health_index_generator(
                        Person(sex=sex, age=age), infection_id=None
                    ),
                    prepend=0.0,
                    append=1.0,
                )  # need np.diff because health index is cummulative
        return symptoms_rates_dict

    def weight_rates_by_age_sex_per_region(
        self, symptoms_rates_dict: dict, symptoms_tags: List["SymptomTag"]
    ) -> List[float]:
        """
        Get the weighted average by age and sex of symptoms rates for symptoms in symptoms_tags.
        For example to get the weighted average death rate per region,
        select symptoms_tags = ('dead_hospital', 'dead_icu', 'dead_home')

        Parameters
        ----------
        symtpoms_rates_dict:
            dictionary with rates for all the possible final symptoms, indexed by sex and age.
        symptoms_tags:
            final symptoms to keep
        Returns
        -------
        List of weighted rates for symptoms in ```symptoms_tags``` (ordered in the same way!!)
        """
        idx_symptoms_to_keep = [getattr(SymptomTag, tag) for tag in symptoms_tags]
        weighted_rates = {}
        for region in self.regions:
            weighted_rates_region = 0
            for age in np.arange(100):
                weighted_rates_region += (
                    symptoms_rates_dict["f"][age][idx_symptoms_to_keep]
                    * self.females_per_age_region_df.loc[region][str(age)]
                )
                weighted_rates_region += (
                    symptoms_rates_dict["m"][age][idx_symptoms_to_keep]
                    * self.males_per_age_region_df.loc[region][str(age)]
                )
            n_people_region = (
                self.females_per_age_region_df.loc[region].sum()
                + self.males_per_age_region_df.loc[region].sum()
            )
            weighted_rates[region] = weighted_rates_region / n_people_region
        return weighted_rates

    def get_latent_cases_from_observed(self, n_observed: int, avg_rates: List) -> int:
        """
        Given a number of observed cases, such as observed deaths or observed hospital
        admissions, this function converts it into number of latent cases necessary to
        produce such an observation.

        Parameters
        ----------
        n_observed:
            observed number of cases (such as deaths or hospital admissions)
        avg_rates:
            average rates to produce the observed cases, such as average death rate or average
            admission rates. It is a list, since we might want to look at, for instance,
            death rate, which is a combination of deat_home, deat_hospital, dead_icu rates.
        Returns
        -------
        Number of latent cases
        """
        avg_rate = sum(avg_rates)
        return round(n_observed / avg_rate)

    def get_latent_cases_per_region(
        self,
        n_observed_df: pd.DataFrame,
        time_to_get_symptoms: int,
        avg_rates_per_region: dict,
    ) -> pd.DataFrame:
        """
        Converts observed cases per region into latent cases per region.

        Parameters
        ----------
        n_observed_df:
            time series of the observed cases
        time_to_get_symptoms:
            days it takes form infection to the symptoms of interest (such as time to death)
        avg_rates_per_region:
            average probability to get those symptoms per region

        Returns
        -------
        n_cases_per_region_df:
            number of latent cases per region time series
        """
        n_cases_per_region_df = n_observed_df.apply(
            lambda x: self.get_latent_cases_from_observed(
                x, avg_rates_per_region[x.name]
            )
        )
        n_cases_per_region_df.index = n_observed_df.index - timedelta(
            days=time_to_get_symptoms
        )
        return n_cases_per_region_df

    def get_super_area_population_weights(
        self,
    ) -> pd.DataFrame:
        """
        Compute the weight in population that a super area has over its whole region, used
        to convert regional cases to cases by super area by population density

        Returns
        -------
        data frame indexed by super area, with weights and region
        """
        people_per_super_area = (
            pd.merge(
                self.age_per_area_df.sum(axis=1).to_frame("n_people"),
                self.area_super_region_df.drop(columns="region"),
                left_index=True,
                right_index=True,
            )
            .groupby("super_area")
            .sum()
        )
        people_per_super_aera_and_region = pd.merge(
            people_per_super_area,
            self.area_super_region_df.drop_duplicates().set_index("super_area"),
            left_index=True,
            right_index=True,
            how="left",
        )
        people_per_region = people_per_super_aera_and_region.groupby("region").sum()[
            "n_people"
        ]
        people_per_super_aera_and_region[
            "weights"
        ] = people_per_super_aera_and_region.apply(
            lambda x: x.n_people / people_per_region.loc[x.region], axis=1
        )
        return people_per_super_aera_and_region[["weights", "region"]]

    def limit_cases_per_region(self, n_cases_per_region_df, starting_date="2020-02-24"):
        people_per_region = self.females_per_age_region_df.sum(
            axis=1
        ) + self.males_per_age_region_df.sum(axis=1)
        n_cases_per_region_df = n_cases_per_region_df.loc[starting_date:]
        cummulative_infections_hundred_thousand = (
            n_cases_per_region_df.cumsum() / people_per_region * 100_000
        )
        regional_series = []
        for region in n_cases_per_region_df.columns:
            regional_index = np.searchsorted(
                cummulative_infections_hundred_thousand[region].values,
                self.regional_infections_per_hundred_thousand,
            )
            regional_cases_to_seed = n_cases_per_region_df[region].iloc[
                : regional_index + 1
            ]
            target_cases = (
                self.regional_infections_per_hundred_thousand
                * people_per_region.loc[region]
                / 100_000
            )
            remaining_cases = np.round(
                max(0, target_cases - regional_cases_to_seed.iloc[:-1].sum())
            )
            regional_cases_to_seed.iloc[-1] = remaining_cases
            regional_series.append(regional_cases_to_seed)
        return pd.concat(regional_series, axis=1).fillna(0.0)

    def convert_regional_cases_to_super_area(
        self, n_cases_per_region_df: pd.DataFrame, starting_date: str
    ) -> pd.DataFrame:
        """
        Converts regional cases to cases by super area by sampling, for every
        date, super areas within the region with probability proportional to
        their population.

        Parameters
        ----------
        n_cases_per_region_df:
            data frame with the number of cases by region, indexed by date
        starting_date:
            first date (inclusive) from which cases are seeded

        Returns
        -------
        data frame with the number of cases by super area, indexed by date
        """
        n_cases_per_region_df = self.limit_cases_per_region(
            n_cases_per_region_df=n_cases_per_region_df, starting_date=starting_date
        )
        n_cases_per_super_area_df = pd.DataFrame(
            0,
            index=n_cases_per_region_df.index,
            columns=self.area_super_region_df["super_area"].unique(),
        )
        super_area_weights = self.get_super_area_population_weights()
        for region in n_cases_per_region_df.columns:
            super_area_weights_for_region = super_area_weights[
                super_area_weights["region"] == region
            ]
            # Series.iteritems() was removed in pandas 2.0; items() is the
            # drop-in replacement.
            for date, n_cases in n_cases_per_region_df[region].items():
                chosen_super_areas = np.random.choice(
                    list(super_area_weights_for_region.index),
                    replace=True,
                    size=round(n_cases),
                    p=super_area_weights_for_region["weights"],
                )
                n_cases_super_area = Counter(chosen_super_areas)
                # Materialise the dict views so pandas receives proper
                # list-likes for the aligned assignment.
                n_cases_per_super_area_df.loc[
                    date, list(n_cases_super_area.keys())
                ] = list(n_cases_super_area.values())
        return n_cases_per_super_area_df

    def _smooth_time_series(self, time_series_df: pd.DataFrame) -> pd.DataFrame:
        """
        Smooth a time series by applying a gaussian filter in 1d

        Parameters
        ----------
        time_series_df:
            df with time as index

        Returns
        -------
        smoothed time series df
        """
        return time_series_df.apply(lambda x: gaussian_filter1d(x, sigma=2))

    def filter_symptoms_trajectories(
        self,
        symptoms_trajectories: List["TrajectoryMaker"],
        symptoms_to_keep: Tuple[str] = ("dead_hospital", "dead_icu"),
    ) -> List["TrajectoryMaker"]:
        """
        Keep only the symptoms trajectories that pass through at least one of
        the symptoms in ``symptoms_to_keep``.

        Parameters
        ----------
        symptoms_trajectories:
            list of all symptoms trajectories
        symptoms_to_keep:
            tuple of strings containing the desired symptoms for which to find trajectories

        Returns
        -------
        trajectories containing ``symptoms_to_keep``, in their original order
        """
        wanted = set(symptoms_to_keep)
        return [
            trajectory
            for trajectory in symptoms_trajectories
            if any(stage.symptoms_tag.name in wanted for stage in trajectory.stages)
        ]

    def get_median_completion_time(self, stage: "Stage") -> float:
        """
        Get the median time spent in a stage: the median of its completion-time
        distribution when it has one, otherwise its fixed completion-time value.

        Parameters
        ----------
            stage:
                given stage in trajectory
        Returns
        -------
        Median time spent in stage
        """
        # Sentinel distinguishes "attribute absent" from "attribute is None",
        # exactly mirroring the hasattr check.
        _missing = object()
        distribution = getattr(stage.completion_time, "distribution", _missing)
        if distribution is _missing:
            return stage.completion_time.value
        return distribution.median()

    def get_time_it_takes_to_symptoms(
        self, symptoms_trajectories: List["TrajectoryMaker"], symptoms_tags: List[str]
    ):
        """
        Compute, for each trajectory, the median time elapsed before reaching
        any of the symptoms in ``symptoms_tags`` (such as death or hospital
        admission).

        Parameters
        ----------
        symptoms_trajectories:
            list of symptoms trajectories
        symptoms_tags:
            symptoms tags for the symptoms of interest
        """
        times_to_symptoms = []
        for trajectory in symptoms_trajectories:
            stage_durations = []
            for stage in trajectory.stages:
                # Stop accumulating once the symptom of interest is reached.
                if stage.symptoms_tag.name in symptoms_tags:
                    break
                stage_durations.append(self.get_median_completion_time(stage))
            times_to_symptoms.append(sum(stage_durations))
        return times_to_symptoms

    def get_weighted_time_to_symptoms(
        self,
        symptoms_trajectories: List["TrajectoryMaker"],
        avg_rate_for_symptoms: List["float"],
        symptoms_tags: List[str],
    ) -> float:
        """
        Get the time it takes to get certain symptoms weighted by population. For instance,
        when computing the death rate, more people die in hospital than in icu,
        therefore the median time to die in hospital gets weighted more than the median time
        to die in icu.

        Parameters
        ----------
        symptoms_trajectories:
            trajectories for symptoms that include the symptoms of interest
        avg_rate_for_symptoms:
            list (or array) containing the average rate for certain symptoms given in
            ```symptoms tags```.
            WARNING: should be in the same order and of the same length as
            ```symptoms_tags```
        symptoms_tags:
            tags of the symptoms for which we want to know the median time

        Returns
        -------
        Weighted median time to symptoms

        """
        times_to_symptoms = self.get_time_it_takes_to_symptoms(
            symptoms_trajectories, symptoms_tags=symptoms_tags
        )
        # Element-wise weighting via zip works for plain lists as well as
        # numpy arrays; the previous `rates * times` form raised a TypeError
        # for lists, despite the List annotation.
        weighted_sum = sum(
            rate * time
            for rate, time in zip(avg_rate_for_symptoms, times_to_symptoms)
        )
        return weighted_sum / sum(avg_rate_for_symptoms)

    def get_regional_latent_cases(
        self,
    ) -> pd.DataFrame:
        """
        Estimate the latent (unobserved) cases per region from the observed
        hospital and ICU deaths.

        Returns
        -------
        data frame with latent cases per region indexed by date
        """
        symptoms_tags = ("dead_hospital", "dead_icu")
        symptoms_rates = self.get_symptoms_rates_per_age_sex()
        death_rates_by_region = self.weight_rates_by_age_sex_per_region(
            symptoms_rates, symptoms_tags=symptoms_tags
        )
        # Average the regional death rates to weight the trajectory times.
        mean_death_rate = np.mean(list(death_rates_by_region.values()), axis=0)
        death_trajectories = self.filter_symptoms_trajectories(
            self.symptoms_trajectories, symptoms_to_keep=symptoms_tags
        )
        median_time_to_death = round(
            self.get_weighted_time_to_symptoms(
                death_trajectories,
                mean_death_rate,
                symptoms_tags=symptoms_tags,
            )
        )
        return self.get_latent_cases_per_region(
            self.n_observed_deaths, median_time_to_death, death_rates_by_region
        )


from .observed_to_cases import Observed2Cases
from .infection_seed import InfectionSeed, InfectionSeeds
from .clustered_infection_seed import ClusteredInfectionSeed
from .cases_distributor import CasesDistributor
from .exact_num_infection_seed import (
    ExactNumInfectionSeed,
    ExactNumClusteredInfectionSeed,
)


import operator
from typing import List, Optional, Tuple, Set, TYPE_CHECKING
from random import random
import numpy as np
import datetime
import yaml
import logging
from pathlib import Path

from june import paths
from june.utils import read_date
from .vaccines import Vaccine, Vaccines

logger = logging.getLogger("vaccination")


default_config_filename = (
    paths.configs_path / "defaults/epidemiology/vaccines/vaccination_campaigns.yaml"
)
default_vaccines_config_filename = (
    paths.configs_path / "defaults/epidemiology/vaccines/vaccines.yaml"
)


# TODO:
# iii) Vaccinate individually given age, region, n doses, and vaccine type (could be made of combinations)


if TYPE_CHECKING:
    from june.demography import Person
    from june.records import Record


class VaccinationCampaign:
    """
    Defines a campaign to vaccinate a group of people in
    a given time span and with a given vaccine
    """

    def __init__(
        self,
        vaccine: Vaccine,
        days_to_next_dose: List[int],
        dose_numbers: Optional[List[int]] = None,
        start_time: str = "2100-01-01",
        end_time: str = "2100-01-02",
        group_by: str = "age",
        group_type: str = "50-100",
        group_coverage: float = 1.0,
        last_dose_type: Optional[str] = None,
    ):
        """__init__.

        Parameters
        ----------
        vaccine : Vaccine
            vaccine to give out
        days_to_next_dose : List[int]
            days to wait from the moment a person is vaccinated to
            their next dose. Should have same length as dose_numbers
        dose_numbers : Optional[List[int]]
            what doses to give out; defaults to [0, 1].
            Example: dose_numbers = [0,1] would give out first
            and second dose, whereas dose_numbers = [2] would
            only give a third dose
        start_time : str
            date at which to start vaccinating people
        end_time : str
            date at which to stop vaccinating people
        group_by : str
            defines what group to vaccinate.
            Examples: 'age', 'sex', 'residence', 'primary_activity'
        group_type : str
            from the group defined by group_by, what people to vaccinate.
            Examples:
            if group_by = 'age' -> group_type = '20-40' would vaccinate
            people aged between 20 and 40
            if group_by = 'residence' -> group_type = 'carehome' would vaccinate
            people living in care homes.
        group_coverage : float
            percentage of the eligible group to vaccinate. Must be between 0. and 1.
        last_dose_type : Optional[str]
            if not starting with a first dose (dose_numbers[0] = 0), whether to
            vaccinate only people whose previous vaccines were of a certain type.
        """
        # None sentinel avoids the shared-mutable-default pitfall of the
        # previous ``dose_numbers=[0, 1]`` default.
        if dose_numbers is None:
            dose_numbers = [0, 1]
        self.start_time = read_date(start_time)
        self.end_time = read_date(end_time)
        self.vaccine = vaccine
        self.days_to_next_dose = days_to_next_dose
        self.group_attribute, self.group_value = self.process_group_description(
            group_by, group_type
        )
        self.total_days = (self.end_time - self.start_time).days
        self.group_coverage = group_coverage
        if last_dose_type is None:
            self.last_dose_type = []
        else:
            self.last_dose_type = last_dose_type
        self.dose_numbers = dose_numbers
        # ids of people vaccinated by this campaign
        self.vaccinated_ids = set()
        self.starting_dose = self.dose_numbers[0]
        # Total days from the first administration until the last dose has
        # fully waned: inter-dose waits + ramp-up + plateau + waning periods.
        self.days_from_administered_to_finished = (
            sum(self.days_to_next_dose)
            + sum(
                self.vaccine.days_administered_to_effective[dose]
                for dose in self.dose_numbers
            )
            + sum(
                self.vaccine.days_effective_to_waning[dose]
                for dose in self.dose_numbers
            )
            + sum(self.vaccine.days_waning[dose] for dose in self.dose_numbers)
        )

    def is_active(self, date: datetime.datetime) -> bool:
        """
        Returns true if the policy is active, false otherwise
        Parameters
        ----------
        date:
            date to check (start inclusive, end exclusive)
        """
        return self.start_time <= date < self.end_time

    def process_group_description(self, group_by: str, group_type: str) -> Tuple[str]:
        """
        Translate a (group_by, group_type) pair into the attribute path and
        value used for eligibility checks.

        Parameters
        ----------
        group_by : str
            e.g. 'age', 'sex', 'residence', 'primary_activity'
        group_type : str
            value (or 'min-max' range for ages) identifying the group

        Returns
        -------
        Tuple[str]
            (attribute path to read on a person, value to compare against)
        """
        # Residence / primary activity live on a subgroup object, so the
        # comparison goes against the group spec rather than the attribute.
        if group_by in ("residence", "primary_activity"):
            return f"{group_by}.group.spec", group_type
        else:
            return f"{group_by}", group_type

    def is_target_group(self, person: "Person") -> bool:
        """
        Whether the person belongs to the group this campaign targets.

        Ages are specified as 'min-max' (min inclusive, max exclusive); any
        other attribute is compared for equality.

        Parameters
        ----------
        person : "Person"
            person to check

        Returns
        -------
        bool
        """
        if self.group_attribute != "age":
            try:
                return (
                    operator.attrgetter(self.group_attribute)(person)
                    == self.group_value
                )
            except Exception:
                # A missing attribute (e.g. no residence) means the person is
                # not in the target group.
                return False
        min_age, max_age = (int(limit) for limit in self.group_value.split("-"))
        return min_age <= getattr(person, self.group_attribute) < max_age

    def has_right_dosage(self, person: "Person") -> bool:
        """
        Whether the person's vaccination history matches this campaign:
        unvaccinated for a first-dose campaign, or exactly one dose behind
        (and, if required, with the right previous vaccine type) for a
        booster campaign.

        Parameters
        ----------
        person : "Person"
            person to check

        Returns
        -------
        bool
        """
        if person.vaccinated is not None and self.starting_dose == 0:
            return False
        if self.starting_dose > 0:
            if person.vaccinated is None or person.vaccinated != self.starting_dose - 1:
                return False
            if self.last_dose_type and person.vaccine_type not in self.last_dose_type:
                return False
        return True

    def should_be_vaccinated(self, person: "Person") -> bool:
        """
        Whether the person is eligible for this campaign (right dosage
        history and in the target group).

        Parameters
        ----------
        person : "Person"
            person to check

        Returns
        -------
        bool
        """
        return self.has_right_dosage(person) and self.is_target_group(person)

    def vaccinate(
        self,
        person: "Person",
        date: datetime.datetime,
        record: Optional["Record"] = None,
    ):
        """
        Administer this campaign's doses to the person: attach the generated
        vaccine trajectory, update the dosage bookkeeping, and register the id.

        Parameters
        ----------
        person : "Person"
            person to vaccinate
        date : datetime.datetime
            date of the first administered dose
        record : Optional["Record"]
            record to log the vaccination event in
        """
        vaccine_trajectory = self.vaccine.generate_trajectory(
            person=person,
            dose_numbers=self.dose_numbers,
            days_to_next_dose=self.days_to_next_dose,
            date=date,
        )
        vaccine_trajectory.update_dosage(person=person, record=record)
        person.vaccine_trajectory = vaccine_trajectory
        self.vaccinated_ids.add(person.id)

    def daily_vaccination_probability(self, days_passed: int) -> float:
        """
        Probability that an eligible person is vaccinated today, chosen so
        that the campaign reaches ``group_coverage`` of the group by its end.

        NOTE(review): diverges when days_passed * group_coverage approaches
        total_days — assumed to be called only while the campaign is active.

        Parameters
        ----------
        days_passed : int
            days elapsed since the campaign started

        Returns
        -------
        float
        """
        return self.group_coverage * (
            1 / (self.total_days - days_passed * self.group_coverage)
        )


class VaccinationCampaigns:
    """Collection of vaccination campaigns that are applied jointly."""

    def __init__(self, vaccination_campaigns: List[VaccinationCampaign]):
        """__init__.

        Parameters
        ----------
        vaccination_campaigns : List[VaccinationCampaign]
            campaigns to administer
        """
        self.vaccination_campaigns = vaccination_campaigns

    @classmethod
    def from_config(
        cls,
        config_file: Path = default_config_filename,
        vaccines_config_file: Path = default_vaccines_config_filename,
    ):
        """Build the campaigns from YAML configuration files.

        Parameters
        ----------
        config_file : Path
            path to the vaccination campaigns configuration
        vaccines_config_file : Path
            path to the vaccines configuration
        """
        vaccines = Vaccines.from_config(vaccines_config_file)
        with open(config_file) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        vaccination_campaigns = []
        # The config keys are only labels; each value holds the campaign
        # parameters, with "vaccine_type" resolved to a Vaccine object.
        for params_dict in config.values():
            params_dict["vaccine"] = vaccines.get_by_name(params_dict["vaccine_type"])
            vaccination_campaigns.append(
                VaccinationCampaign(
                    **{k: v for k, v in params_dict.items() if k != "vaccine_type"}
                )
            )
        return cls(vaccination_campaigns=vaccination_campaigns)

    def __iter__(self):
        """Iterate over the individual campaigns."""
        return iter(self.vaccination_campaigns)

    def get_active(self, date: datetime.datetime) -> List[VaccinationCampaign]:
        """Return the campaigns whose time window contains ``date``.

        Parameters
        ----------
        date : datetime.datetime
            date to check

        Returns
        -------
        List[VaccinationCampaign]
        """
        return [vc for vc in self.vaccination_campaigns if vc.is_active(date)]

    def apply(
        self, person: "Person", date: datetime.datetime, record: Optional["Record"] = None
    ):
        """Possibly vaccinate ``person`` on ``date``.

        Among the active campaigns the person is eligible for, at most one is
        drawn, with probability proportional to each campaign's daily
        vaccination probability.

        Parameters
        ----------
        person : "Person"
            person to consider
        date : datetime.datetime
            current date
        record : Optional["Record"]
            record to log vaccinations in
        """
        active_campaigns = self.get_active(date=date)
        daily_probability, campaigns_to_choose_from = [], []
        for vc in active_campaigns:
            if vc.should_be_vaccinated(person=person):
                days_passed = (date - vc.start_time).days
                daily_probability.append(
                    vc.daily_vaccination_probability(days_passed=days_passed)
                )
                campaigns_to_choose_from.append(vc)
        daily_probability = np.array(daily_probability)
        norm = daily_probability.sum()
        if norm > 0.0:
            # Vaccinate with total probability ``norm``, then pick the
            # campaign proportionally to its contribution.
            if random() < norm:
                daily_probability /= norm
                campaign = np.random.choice(
                    campaigns_to_choose_from, p=daily_probability
                )
                campaign.vaccinate(person=person, date=date, record=record)

    def collect_all_dates_in_past(
        self, current_date: datetime.datetime
    ) -> List[datetime.datetime]:
        """Collect every date before ``current_date`` on which any campaign
        (including the tail needed for its doses to take full effect) was
        running.

        Returns
        -------
        sorted list of dates
        """
        dates = set()
        for cv in self.vaccination_campaigns:
            start_time = cv.start_time
            if start_time < current_date:
                days_to_finished = cv.days_from_administered_to_finished
                end_time = min(
                    current_date,
                    cv.end_time + datetime.timedelta(days=days_to_finished),
                )
                delta = end_time - start_time
                for i in range(delta.days + 1):
                    date = start_time + datetime.timedelta(days=i)
                    dates.add(date)
        return sorted(dates)

    def apply_past_campaigns(
        self, people, date: datetime.datetime, record: Optional["Record"] = None
    ):
        """Replay all campaigns that started before ``date`` on ``people``,
        day by day, updating vaccine effects and the record as time steps.

        Parameters
        ----------
        people :
            iterable of people to (retro-)vaccinate
        date : datetime.datetime
            current simulation date
        record : Optional["Record"]
            record to log vaccinations and time steps in
        """
        dates_to_vaccinate = self.collect_all_dates_in_past(current_date=date)
        for date_to_vax in dates_to_vaccinate:
            logger.info(f"Vaccinating at date {date_to_vax.date()}")
            for person in people:
                self.apply(person=person, date=date_to_vax, record=record)
                if person.vaccine_trajectory is not None:
                    person.vaccine_trajectory.update_vaccine_effect(
                        person=person, date=date_to_vax, record=record
                    )
            if record is not None:
                record.time_step(timestamp=date_to_vax)


import yaml
import operator
from pathlib import Path
import datetime
from typing import List, Dict, Optional, TYPE_CHECKING
from dataclasses import dataclass
import numpy as np

from june import paths
from june.epidemiology.infection import infection as infection_module
from june.utils.parse_probabilities import parse_age_probabilities

default_config_filename = (
    paths.configs_path / "defaults/epidemiology/vaccines/vaccines.yaml"
)

# TODO: apply doses from file
# Start with (date, pseudo_id, region, age, dose, vaccine_type)
# (person id, dose number, vaccine type) -> append to trajectory.doses

if TYPE_CHECKING:
    from june.demography import Person


@dataclass
class Efficacy:
    """Vaccine efficacy against infection and against symptoms, per
    infection (variant) id."""

    # efficacy against getting infected, keyed by infection id
    infection: Dict[int, float]
    # efficacy against developing symptoms, keyed by infection id
    symptoms: Dict[int, float]
    # multiplicative factor the efficacy decays to once the dose has waned
    waning_factor: float

    def __call__(self, protection_type: str, infection_id: int):
        """Return the efficacy for ``protection_type`` ("infection" or
        "symptoms") against ``infection_id``; None for unknown ids.

        Parameters
        ----------
        protection_type : str
            protection_type
        infection_id : int
            infection_id
        """
        # Direct attribute lookup; the previous f-string wrapper was a no-op.
        return getattr(self, protection_type).get(infection_id)

    def __mul__(self, factor: float):
        """Return a new Efficacy with both dictionaries scaled by ``factor``.

        NOTE(review): the result's waning_factor is reset to 1.0 — presumably
        because scaling is used to produce already-waned efficacies; confirm
        before reusing elsewhere.

        Parameters
        ----------
        factor : float
            factor
        """
        return Efficacy(
            infection={k: v * factor for k, v in self.infection.items()},
            symptoms={k: v * factor for k, v in self.symptoms.items()},
            waning_factor=1.0,
        )


class Dose:
    """A single vaccine dose and the time evolution of its efficacy."""

    def __init__(
        self,
        number: int,
        date_administered: datetime.datetime,
        days_administered_to_effective: int,
        days_effective_to_waning: int,
        days_waning: int,
        prior_efficacy: Efficacy,
        efficacy: Efficacy,
    ):
        """__init__.

        Parameters
        ----------
        number : int
            dose number (0 = first dose)
        date_administered : datetime.datetime
            date the dose is given
        days_administered_to_effective : int
            days for the dose to ramp up to full efficacy
        days_effective_to_waning : int
            days the dose stays at full efficacy
        days_waning : int
            days over which the efficacy decays to its waned level
        prior_efficacy : Efficacy
            efficacy the person had before this dose
        efficacy : Efficacy
            full efficacy of this dose
        """
        self.number = number
        self.days_administered_to_effective = days_administered_to_effective
        self.days_effective_to_waning = days_effective_to_waning
        self.days_waning = days_waning
        self.efficacy = efficacy
        self.prior_efficacy = prior_efficacy
        self.date_administered = date_administered
        # Milestone dates delimiting the ramp-up / plateau / waning phases.
        self.date_effective = self.date_administered + datetime.timedelta(
            days=self.days_administered_to_effective
        )
        self.date_waning = self.date_administered + datetime.timedelta(
            days=self.days_effective_to_waning + self.days_administered_to_effective
        )
        self.date_finished = self.date_administered + datetime.timedelta(
            days=self.days_waning
            + self.days_effective_to_waning
            + self.days_administered_to_effective
        )

    def get_efficacy(
        self, date: datetime.datetime, infection_id: int, protection_type: str
    ):
        """Efficacy of this dose at ``date``, piecewise in time:
        prior efficacy before administration, a linear ramp up to full
        efficacy, a constant plateau, a linear decay while waning, and a
        constant waned level afterwards.

        Parameters
        ----------
        date : datetime.datetime
            date at which to evaluate the efficacy
        infection_id : int
            id of the infection (variant)
        protection_type : str
            "infection" or "symptoms"
        """
        efficacy = self.efficacy(
            protection_type=protection_type, infection_id=infection_id
        )
        if date > self.date_finished:
            return self.efficacy.waning_factor * efficacy
        if date > self.date_waning:
            # Linear decay from full efficacy down to the waned level.
            prior_efficacy = efficacy
            final_efficacy = self.efficacy.waning_factor * efficacy
            prior_date = self.date_waning
            duration = self.days_waning
        elif date > self.date_effective:
            # Plateau at full efficacy.
            return efficacy
        elif date >= self.date_administered:
            # Linear ramp from the pre-dose efficacy up to full efficacy.
            prior_efficacy = self.prior_efficacy(
                protection_type=protection_type, infection_id=infection_id
            )
            final_efficacy = efficacy
            prior_date = self.date_administered
            duration = self.days_administered_to_effective
        else:
            # Before the dose is given only the prior immunity applies.
            # (Previously this case fell through to undefined locals and
            # raised UnboundLocalError.)
            return self.prior_efficacy(
                protection_type=protection_type, infection_id=infection_id
            )
        n_days = (date - prior_date).days
        slope = (final_efficacy - prior_efficacy) / duration
        return slope * n_days + prior_efficacy


class VaccineTrajectory:
    """Sequence of vaccine doses given to one person and the resulting time
    evolution of their immunity."""

    def __init__(self, doses: List[Dose], name: str, infection_ids: List[int]):
        """__init__.

        Parameters
        ----------
        doses : List[Dose]
            doses of the trajectory (sorted internally by administration date)
        name : str
            vaccine name
        infection_ids : List[int]
            ids of the infections the vaccine protects against
        """
        self.doses = sorted(doses, key=operator.attrgetter("date_administered"))
        self.name = name
        self.infection_ids = infection_ids
        self.first_dose_date = self.doses[0].date_administered
        # Day offsets of each dose relative to the first one, used by the
        # searchsorted lookup in get_dose_index.
        self.dates_administered = [
            (dose.date_administered - self.first_dose_date).days for dose in self.doses
        ]
        (
            self.prior_susceptibility,
            self.prior_effective_multiplier,
        ) = self._get_immunity_prior_to_trajectory()
        # Index into self.doses of the dose currently in effect.
        self.stage = 0

    @property
    def current_dose(
        self,
    ):
        """Dose number of the dose currently in effect."""
        return self.doses[self.stage].number

    def _get_immunity_prior_to_trajectory(
        self,
    ):
        """Susceptibility and effective multiplier (both 1 - efficacy) the
        person had before the first dose of this trajectory."""
        prior_efficacy = self.doses[0].prior_efficacy
        susceptibility = {
            inf_id: 1 - value for inf_id, value in prior_efficacy.infection.items()
        }
        effective_multiplier = {
            inf_id: 1 - value for inf_id, value in prior_efficacy.symptoms.items()
        }
        return susceptibility, effective_multiplier

    def get_dose_index(self, date: datetime.datetime):
        """Index into ``self.doses`` of the last dose administered on or
        before ``date``, clipped to the last dose.

        Parameters
        ----------
        date : datetime.datetime
            date
        """
        days_from_start = (date - self.first_dose_date).days
        return min(
            np.searchsorted(self.dates_administered, days_from_start, side="right") - 1,
            len(self.doses) - 1,
        )

    def get_dose_number(self, date: datetime.datetime):
        """Dose number of the dose administered most recently before ``date``.

        Parameters
        ----------
        date : datetime.datetime
            date
        """
        return self.doses[self.get_dose_index(date=date)].number

    def update_trajectory_stage(self, date: datetime.datetime):
        """Advance ``self.stage`` by one if the next dose has been
        administered by ``date``.

        Parameters
        ----------
        date : datetime.datetime
            date
        """
        if (
            self.stage < len(self.doses) - 1
            and date >= self.doses[self.stage + 1].date_administered
        ):
            self.stage += 1
            # NOTE(review): self.dose_number is only written here (reads go
            # through the current_dose property) — kept for compatibility.
            self.dose_number = self.doses[self.stage].number

    def get_efficacy(
        self, date: datetime.datetime, infection_id: int, protection_type: str
    ):
        """Efficacy of the current dose at ``date``.

        Parameters
        ----------
        date : datetime.datetime
            date
        infection_id : int
            infection_id
        protection_type : str
            "infection" or "symptoms"
        """
        return self.doses[self.stage].get_efficacy(
            date=date, infection_id=infection_id, protection_type=protection_type
        )

    def susceptibility(self, date: datetime.datetime, infection_id: int):
        """Susceptibility (1 - efficacy against infection) at ``date``.

        Parameters
        ----------
        date : datetime.datetime
            date
        infection_id : int
            infection_id
        """
        return 1.0 - self.get_efficacy(
            date=date, protection_type="infection", infection_id=infection_id
        )

    def effective_multiplier(self, date, infection_id: int):
        """Effective multiplier (1 - efficacy against symptoms) at ``date``.

        Parameters
        ----------
        date :
            date
        infection_id : int
            infection_id
        """
        return 1.0 - self.get_efficacy(
            date=date, protection_type="symptoms", infection_id=infection_id
        )

    def is_finished(self, date):
        """True once the last dose has fully waned at ``date``.

        Parameters
        ----------
        date :
            date
        """
        return date > self.doses[-1].date_finished

    def update_dosage(self, person, record=None):
        """Write the current dose number and vaccine type onto the person,
        and log the vaccination event if a record is given.

        Parameters
        ----------
        person :
            person
        record :
            record
        """
        dose_number = self.current_dose
        person.vaccinated = dose_number
        person.vaccine_type = self.name
        if record is not None:
            record.events["vaccines"].accumulate(person.id, self.name, dose_number)

    def update_vaccine_effect(
        self, person: "Person", date: datetime.datetime, record=None
    ):
        """Advance the trajectory to ``date`` and update the person's
        immunity; detach the trajectory once it is finished.

        Parameters
        ----------
        person : "Person"
            person
        date : datetime.datetime
            date
        record :
            record
        """
        if self.is_finished(date=date):
            person.vaccine_trajectory = None
            return
        immunity = person.immunity
        # Remember the dose before advancing the stage so a dose change can
        # be detected (and recorded) below.
        dose_number = self.current_dose
        self.update_trajectory_stage(date=date)
        for infection_id in self.infection_ids:
            updated_susceptibility = self.susceptibility(
                date=date, infection_id=infection_id
            )
            updated_effective_multiplier = self.effective_multiplier(
                date=date, infection_id=infection_id
            )
            # Vaccination never worsens pre-existing immunity.
            immunity.susceptibility_dict[infection_id] = min(
                self.prior_susceptibility.get(infection_id, 1.0), updated_susceptibility
            )
            immunity.effective_multiplier_dict[infection_id] = min(
                self.prior_effective_multiplier.get(infection_id, 1.0),
                updated_effective_multiplier,
            )
        if self.current_dose != dose_number:
            self.update_dosage(person=person, record=record)


class Vaccine:
    """A vaccine type: per-dose timing parameters and final efficacies."""

    def __init__(
        self,
        name: str,
        days_administered_to_effective: List[int],
        days_effective_to_waning: List[int],
        days_waning: List[int],
        sterilisation_efficacies,
        symptomatic_efficacies,
        waning_factor: Optional[float] = 1.0,
    ):
        """
        Class defining a vaccine type and its effectiveness.

        Parameters
        ----------
        name:
            vaccine name
        days_administered_to_effective:
            per dose, number of days from administration until fully effective
        days_effective_to_waning:
            per dose, number of days the dose stays fully effective before it
            starts waning
        days_waning:
            per dose, number of days over which the dose effect wanes
        sterilisation_efficacies:
            per dose, final full efficacy against infection, by variant and age
        symptomatic_efficacies:
            per dose, final full efficacy against symptoms, by variant and age
        waning_factor:
            factor multiplying a dose's efficacy once it has waned
        """
        self.name = name
        self.days_administered_to_effective = days_administered_to_effective
        self.days_effective_to_waning = days_effective_to_waning
        self.days_waning = days_waning
        self.sterilisation_efficacies = self._parse_efficacies(sterilisation_efficacies)
        self.symptomatic_efficacies = self._parse_efficacies(symptomatic_efficacies)
        self.infection_ids = self._read_infection_ids(self.sterilisation_efficacies)
        self.waning_factor = waning_factor

    @classmethod
    def from_config_dict(cls, name: str, config: Dict):
        """Build a vaccine from one entry of the vaccines config mapping.

        Parameters
        ----------
        name : str
            vaccine name
        config : Dict
            mapping with the per-dose timing and efficacy parameters
        """
        return cls(
            name=name,
            days_administered_to_effective=config["days_administered_to_effective"],
            days_effective_to_waning=config["days_effective_to_waning"],
            days_waning=config["days_waning"],
            sterilisation_efficacies=config["sterilisation_efficacies"],
            symptomatic_efficacies=config["symptomatic_efficacies"],
            waning_factor=config["waning_factor"],
        )

    @classmethod
    def from_config(
        cls, vaccine_type: str, config_file: Path = default_config_filename
    ):
        """Build a vaccine of the given type from a YAML config file.

        Parameters
        ----------
        vaccine_type : str
            key of the vaccine in the config file
        config_file : Path
            path to the YAML configuration
        """
        with open(config_file) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        return cls.from_config_dict(name=vaccine_type, config=config[vaccine_type])

    def _read_infection_ids(self, sterilisation_efficacies):
        """Collect the unique infection ids appearing in the per-dose efficacy dicts."""
        return list({key for dd in sterilisation_efficacies for key in dd})

    def _parse_efficacies(self, efficacies):
        """Convert infection names to ids and expand age-binned probabilities.

        Each per-dose dict maps an infection class name to age-binned values;
        the name is resolved to its numeric id and the values are expanded
        via ``parse_age_probabilities`` (presumably to one value per year of
        age, matching its use with ``person.age`` below).
        """
        return [
            {
                getattr(infection_module, key).infection_id(): parse_age_probabilities(
                    dd[key]
                )
                for key in dd
            }
            for dd in efficacies
        ]

    def collect_prior_efficacy(self, person):
        """Return the person's current protection expressed as an Efficacy.

        Efficacy is 1 minus the current susceptibility / effective multiplier;
        missing infections default to no protection (efficacy 0).
        """
        immunity = person.immunity
        return Efficacy(
            infection={
                inf_id: 1.0 - immunity.susceptibility_dict.get(inf_id, 1.0)
                for inf_id in self.infection_ids
            },
            symptoms={
                inf_id: 1.0 - immunity.effective_multiplier_dict.get(inf_id, 1.0)
                for inf_id in self.infection_ids
            },
            waning_factor=1.0,
        )

    def generate_trajectory(
        self,
        person: "Person",
        dose_numbers: List[int],
        days_to_next_dose: List[int],
        date: datetime.datetime,
    ) -> VaccineTrajectory:
        """Build the dose-by-dose vaccination trajectory for a person.

        Parameters
        ----------
        person : "Person"
            person to vaccinate; their age selects the efficacy values
        dose_numbers : List[int]
            indices of the doses to administer (index into the per-dose lists)
        days_to_next_dose : List[int]
            days between consecutive doses; the first entry offsets the first
            dose from ``date``
        date : datetime.datetime
            reference date from which dose dates are computed

        Returns
        -------
        VaccineTrajectory
        """
        prior_efficacy = self.collect_prior_efficacy(person=person)
        doses = []
        for i, dose in enumerate(dose_numbers):
            date += datetime.timedelta(days=days_to_next_dose[i])
            efficacy = Efficacy(
                infection={
                    inf_id: self.sterilisation_efficacies[dose][inf_id][person.age]
                    for inf_id in self.infection_ids
                },
                symptoms={
                    inf_id: self.symptomatic_efficacies[dose][inf_id][person.age]
                    for inf_id in self.infection_ids
                },
                waning_factor=self.waning_factor,
            )
            doses.append(
                Dose(
                    number=dose,
                    date_administered=date,
                    days_administered_to_effective=self.days_administered_to_effective[
                        dose
                    ],
                    days_effective_to_waning=self.days_effective_to_waning[dose],
                    days_waning=self.days_waning[dose],
                    prior_efficacy=prior_efficacy,
                    efficacy=efficacy,
                )
            )
            # the next dose builds on this dose's waned protection
            prior_efficacy = efficacy * efficacy.waning_factor
        return VaccineTrajectory(
            doses=doses, name=self.name, infection_ids=self.infection_ids
        )


class Vaccines:
    """A collection of ``Vaccine`` objects, retrievable by name."""

    def __init__(self, vaccines: List[Vaccine]):
        """
        Parameters
        ----------
        vaccines : List[Vaccine]
            the vaccines to hold
        """
        self.vaccines = vaccines
        self.vaccines_dict = {vaccine.name: vaccine for vaccine in vaccines}

    def get_by_name(self, vaccine_name: str):
        """Return the vaccine called ``vaccine_name``.

        Raises
        ------
        ValueError
            if no vaccine with that name is registered.
        """
        if vaccine_name not in self.vaccines_dict:
            raise ValueError(f"{vaccine_name} does not exist")
        return self.vaccines_dict[vaccine_name]

    @classmethod
    def from_config_dict(cls, config: Dict):
        """Build the collection from a mapping of vaccine name -> parameters."""
        vaccines = [Vaccine(name=key, **values) for key, values in config.items()]
        return cls(vaccines=vaccines)

    @classmethod
    def from_config(cls, config_file: Path = default_config_filename):
        """Build the collection from a YAML config file."""
        with open(config_file) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        return cls.from_config_dict(config=config)

    def __iter__(self):
        """Iterate over the held vaccines."""
        return iter(self.vaccines)

    def get_max_effective_date(self):
        """Return the largest total days-to-full-effectiveness across vaccines.

        Fixed: this previously summed ``vaccine.days_to_effective``, an
        attribute ``Vaccine`` never defines (the per-dose day counts live in
        ``days_administered_to_effective``), which raised AttributeError.
        """
        return max(
            sum(vaccine.days_administered_to_effective) for vaccine in self.vaccines
        )


from .vaccines import Vaccine, Vaccines
from .vaccination_campaign import VaccinationCampaign, VaccinationCampaigns


from typing import Dict, Union
from random import random, shuffle
import logging
import datetime

from .event import Event
from june.utils import parse_age_probabilities

# Module-level logger for the domestic-care event.
logger = logging.getLogger("domestic_care")


class DomesticCare(Event):
    """
    This event models people taking care of their elderly who live
    alone or in couples. The logic is that at the beginning of each
    leisure time-step, people who have caring responsibilities go
    to their relatives' household for the duration of their time-step.

    Parameters
    ----------
    start_time
        time from when the event is active (default is always)
    end_time
        time when the event ends (default is always)
    needs_care_probabilities
        dictionary mapping the probability of needing care per age.
        Example:
        needs_care_probabilities = {"0-65" : 0.0, "65-100" : 0.5}
    daily_going_probability
        probability per day that a linked carer actually goes to provide care.
        NOTE(review): stored but not used in ``apply`` here — presumably
        consumed elsewhere; verify before relying on it.
    """

    def __init__(
        self,
        start_time: Union[str, datetime.datetime],
        end_time: Union[str, datetime.datetime],
        needs_care_probabilities: Dict[str, float],
        daily_going_probability=1.0,
    ):
        super().__init__(start_time=start_time, end_time=end_time)
        # expand the "min-max" age bins into per-age probabilities
        self.needs_care_probabilities = parse_age_probabilities(
            needs_care_probabilities
        )
        self.daily_going_probability = daily_going_probability

    def initialise(self, world):
        """Link each caring household to the household it will care for."""
        self._link_carers_to_households(world=world)

    def apply(self, world, activities, day_type, simulator=None):
        """
        When a household is responsible for caring of another household,
        a random adult is sent during leisure to take care of that household,
        and the cared-for residents are kept at home. Does nothing on
        weekends, outside leisure, or during primary-activity time steps.
        We checked that the person is not at hospital when we send them.
        """
        if (
            "leisure" not in activities
            or day_type == "weekend"
            or "primary_activity" in activities
        ):
            return
        for household in world.households:
            if household.household_to_care is None:
                continue
            household_to_care = household.household_to_care
            carers = list(household.residents)
            shuffle(carers)
            receives_care = False
            for person in carers:
                # send the first available adult carer
                if person.age > 18 and person.available:
                    household_to_care.add(person, activity="leisure")
                    receives_care = True
                    break
            if receives_care:
                household_to_care.receiving_care = True
                # make residents stay at home to receive the carer
                for person in household_to_care.residents:
                    if person.available:
                        person.residence.append(person)

    def _link_carers_to_households(self, world):
        """
        Links old people households to other households that provide them with care aid.
        All linking is restricted to the super area level.
        """
        total_need_care = 0
        for super_area in world.super_areas:
            # get households that need care and those that can provide it
            need_care = []
            can_provide_care = []
            for area in super_area.areas:
                for household in area.households:
                    if self._check_household_needs_care(household):
                        need_care.append(household)
                    if self._check_household_can_provide_care(household):
                        can_provide_care.append(household)
            shuffle(need_care)
            shuffle(can_provide_care)
            if len(need_care) > len(can_provide_care):
                # fixed: the two concatenated f-strings were missing a space
                # ("does nothave enough carers")
                logger.warning(
                    f"super area {super_area.id} does not have enough carers"
                )
            for needer, provider in zip(need_care, can_provide_care):
                total_need_care += 1
                provider.household_to_care = needer

    def _check_household_needs_care(self, household):
        """
        Check if a household needs care. Only households of type "old" can
        need care; every resident draws against the age-dependent probability
        and the household needs care if any draw succeeds.
        """
        if household.type == "old":
            for person in household.residents:
                care_probability = self.needs_care_probabilities[person.age]
                if random() < care_probability:
                    return True
        return False

    def _check_household_can_provide_care(self, household):
        """
        We limit care providers to non-student, non-old households.
        """
        return household.type not in ("student", "old")


from abc import ABC
import yaml
import datetime
import logging
from typing import Union, List

from june.utils import read_date, str_to_class
from june.paths import configs_path
from june.mpi_setup import mpi_rank

# Default events configuration shipped with the package.
default_config_filename = configs_path / "defaults/event/events.yaml"
logger = logging.getLogger("events")
# Under MPI, stop records from non-root ranks propagating to ancestor
# handlers — presumably to avoid duplicated log output across ranks.
if mpi_rank > 0:
    logger.propagate = False


class Event(ABC):
    """
    Base class for events: a sequence of actions applied to the world at the
    beginning of each time step, during a defined period of time.
    """

    def __init__(
        self,
        start_time: Union[str, datetime.datetime],
        end_time: Union[str, datetime.datetime],
    ):
        # Dates may arrive as strings from config files; normalise both.
        self.start_time = read_date(start_time)
        self.end_time = read_date(end_time)

    def is_active(self, date: datetime.datetime):
        """Return True when ``date`` lies in [start_time, end_time)."""
        return self.start_time <= date and date < self.end_time

    def initialise(self, world):
        """One-off setup hook; concrete events must implement it."""
        raise NotImplementedError

    def apply(self, world, simulator, activities, day_type):
        """Per-time-step effect hook; concrete events must implement it."""
        raise NotImplementedError


class Events:
    """Container for simulation events; applies the active ones each time step."""

    def __init__(self, events=None):
        # Default to an empty list so iteration in ``init_events`` / ``apply``
        # is safe when no events are configured (previously stored None).
        self.events = events if events is not None else []

    @classmethod
    def from_file(
        cls, config_file=default_config_filename, base_event_modules=("june.event",)
    ):
        """Read the events config and instantiate one event per entry.

        Config keys are snake_case event names; each is converted to its
        CamelCase class name and resolved via ``str_to_class``. An entry is
        either a single event config (it has ``start_time``) or a mapping of
        several configurations of the same event type.

        Raises
        ------
        ValueError
            if a nested event config lacks ``start_time`` or ``end_time``.
        """
        with open(config_file) as f:
            config = yaml.load(f, Loader=yaml.FullLoader) or {}
        events = []
        for event, event_data in config.items():
            camel_case_key = "".join(x.capitalize() or "_" for x in event.split("_"))
            if "start_time" not in event_data:
                # several configurations of the same event type
                for event_data_i in event_data.values():
                    if (
                        "start_time" not in event_data_i.keys()
                        or "end_time" not in event_data_i.keys()
                    ):
                        raise ValueError("event config file not valid.")
                    events.append(
                        str_to_class(camel_case_key, base_event_modules)(**event_data_i)
                    )
            else:
                events.append(
                    str_to_class(camel_case_key, base_event_modules)(**event_data)
                )
        return cls(events)

    def init_events(self, world):
        """Run each event's one-off initialisation against the world."""
        logger.info("Initialising events...")
        for event in self.events:
            event.initialise(world=world)
            logger.info(f"Event {event.__class__.__name__} initialised")

    def apply(self, date, world, simulator, activities: List[str], day_type: str):
        """Apply every event whose active window contains ``date``.

        ``day_type`` carries a string (e.g. "weekend") that events compare
        against; the previous ``bool`` annotation was wrong.
        """
        for event in self.events:
            if event.is_active(date=date):
                event.apply(
                    world=world,
                    simulator=simulator,
                    activities=activities,
                    day_type=day_type,
                )


from typing import Union, Dict
import datetime
from random import sample, choices

from .event import Event


class IncidenceSetter(Event):
    """
    This Event is used to set a specific incidence per region at some point in the code.
    It can be used to correct, based on data, the current epidemiological state of the code.
    The added infection types are sampled from the current ones.
    """

    def __init__(
        self,
        start_time: Union[str, datetime.datetime],
        end_time: Union[str, datetime.datetime],
        incidence_per_region: Dict[str, float],
    ):
        super().__init__(start_time=start_time, end_time=end_time)
        self.incidence_per_region = incidence_per_region

    def initialise(self, world):
        """Nothing to set up for this event."""
        pass

    def apply(self, world, simulator, activities=None, day_type=None):
        """Cure or infect people so each configured region matches its target incidence."""
        selectors = simulator.epidemiology.infection_selectors
        for region in world.regions:
            if region.name not in self.incidence_per_region:
                continue
            target_incidence = self.incidence_per_region[region.name]
            people = region.people
            if len(people) == 0:
                # empty region: incidence is undefined, nothing to adjust
                continue
            infected_people = [person for person in people if person.infected]
            incidence = len(infected_people) / len(people)
            if incidence > target_incidence:
                n_to_remove = int((incidence - target_incidence) * len(people))
                for person in sample(infected_people, n_to_remove):
                    person.infection = None
            elif incidence < target_incidence:
                if not infected_people:
                    # no existing infection to copy a type from; previously
                    # ``choices`` would raise IndexError on an empty population
                    continue
                n_to_add = int((target_incidence - incidence) * len(people))
                # oversample to compensate for drawing already-infected people,
                # but never request more than the population size (``sample``
                # raises ValueError otherwise)
                k = min(2 * n_to_add, len(people))
                to_infect = sample(people, k=k)
                infected = choices(infected_people, k=k)
                counter = 0
                for person, infected_ref in zip(to_infect, infected):
                    if person.infected:
                        continue
                    counter += 1
                    selectors.infect_person_at_time(
                        person,
                        simulator.timer.now,
                        infected_ref.infection.infection_id(),
                    )
                    if counter == n_to_add:
                        break


import datetime
from typing import Union, Dict
from random import random

from june.epidemiology.infection import B117
from .event import Event


class Mutation(Event):
    """
    This event aims to reproduce a mutation effect. It was originally
    implemented to model the new Covid19 variant detected in the UK around
    November 2020. A fraction of the active infections (which can vary
    region to region) is converted to the new variant, with different
    epidemiological characteristics. Note: currently only the infection
    transmission characteristics change; the symptoms trajectory is kept.

    Parameters
    ----------
    start_time
        time from when the event is active (default is always)
    end_time
        time when the event ends (default is always)
    regional_probabilities
        fraction of current infections that will be transformed to the new variant
    mutation_id
        unique id of the new mutation. These are generated with an adler32 encoding on
        the name.
        Covid19: 170852960
        B117: 37224668
    """

    def __init__(
        self,
        start_time: Union[str, datetime.datetime],
        end_time: Union[str, datetime.datetime],
        regional_probabilities: Dict[str, float],
        mutation_id=B117.infection_id(),
    ):
        super().__init__(start_time=start_time, end_time=end_time)
        self.mutation_id = mutation_id
        self.regional_probabilities = regional_probabilities

    def initialise(self, world=None):
        """Nothing to set up for this event."""
        pass

    def apply(self, world, simulator, activities=None, day_type=None):
        """Convert a regional fraction of active infections to the mutated variant."""
        id_to_selector = (
            simulator.epidemiology.infection_selectors.infection_id_to_selector
        )
        selector = id_to_selector[self.mutation_id]
        for person in world.people:
            if not person.infected:
                continue
            probability = self.regional_probabilities.get(person.region.name, 0)
            if random() >= probability:
                continue
            old_infection = person.infection
            mutated = selector._make_infection(person, time=old_infection.start_time)
            # carry over the original infection's timing and symptoms
            mutated.time_of_testing = old_infection.time_of_testing
            mutated.start_time = old_infection.start_time
            mutated.symptoms = old_infection.symptoms
            person.infection = mutated


from .event import Event, Events
from .domestic_care import DomesticCare
from .mutation import Mutation
from .incidence_setter import IncidenceSetter


import pandas as pd
from typing import List
import numpy as np
from random import randint
from sklearn.neighbors import BallTree
from itertools import count
import logging

from june.paths import data_path
from june.geography import SuperArea, Geography
from june.groups.group import Supergroup
from june.groups.group.external import ExternalGroup

# CSV mapping super areas to the city they belong to (the "_ew" suffix
# suggests England & Wales coverage — verify against the data file).
default_cities_filename = data_path / "input/geography/cities_per_super_area_ew.csv"

earth_radius = 6371  # km; used to scale haversine (radian) distances to km

logger = logging.getLogger(__name__)


def _calculate_centroid(latitudes, longitudes):
    """
    Calculates the centroid of the city.
    WARNING: This currently takes the mean of the latitude and longitude, however this is not correct for some cases,
    eg, the mean angle between 1 and 359 should be 0, not 180, etc.
    """
    return [np.mean(latitudes), np.mean(longitudes)]


class City:
    """
    A city is a collection of areas, with some added methods for functionality,
    such as commuting or local lockdowns.
    """

    external = False

    _id = count()

    def __init__(
        self,
        super_areas: List[str] = None,
        super_area: SuperArea = None,
        name: str = None,
        coordinates=None,
    ):
        """
        Initializes a city: a collection of ``super_areas`` anchored at one
        particular ``super_area``. The anchor super area is necessary for
        domain parallelisation.

        Parameters
        ----------
        super_areas:
            A list of super area names
        super_area:
            The ``SuperArea`` instance of where the city resides
        name
            The city name
        coordinates
            A tuple or array of floats indicating latitude and longitude of the city (in degrees).
        """
        self.id = next(self._id)
        self.name = name
        self.coordinates = coordinates
        self.super_area = super_area
        self.super_areas = super_areas
        self.super_stations = None
        self.city_stations = []
        self.inter_city_stations = []
        # ids of the people commuting internally within the city
        self.internal_commuter_ids = set()

    @classmethod
    def from_file(cls, name, city_super_areas_filename=default_cities_filename):
        """Load a city's super areas from a CSV indexed by city name."""
        df = pd.read_csv(city_super_areas_filename)
        df.set_index("city", inplace=True)
        return cls.from_df(name=name, city_super_areas_df=df)

    @classmethod
    def from_df(cls, name, city_super_areas_df):
        """Build a city from a dataframe indexed by city name."""
        return cls(super_areas=city_super_areas_df.loc[name].values, name=name)

    def get_commute_subgroup(self, person):
        """
        Gets the commute subgroup of the person. First check whether the
        person commutes internally within the city; otherwise check whether
        they commute through their closest inter-city station. If neither,
        that person doesn't need commuting.
        """
        if not self.has_stations:
            return
        if person.id in self.internal_commuter_ids:
            # pick one of the city's internal stations at random
            station = self.city_stations[randint(0, len(self.city_stations) - 1)]
            return station.get_commute_subgroup()
        station = person.super_area.closest_inter_city_station_for_city[self.name]
        if person.id in station.commuter_ids:
            return station.get_commute_subgroup()

    def get_closest_inter_city_station(self, coordinates):
        # NOTE(review): ``inter_city_stations`` is a plain list in __init__,
        # which has no ``get_closest_station`` — presumably replaced by a
        # Stations-like object elsewhere; verify.
        return self.inter_city_stations.get_closest_station(coordinates)

    @property
    def has_stations(self):
        """Whether the city has any internal or inter-city station."""
        has_internal = self.city_stations is not None and len(self.city_stations) > 0
        has_inter = (
            self.inter_city_stations is not None and len(self.inter_city_stations) > 0
        )
        return has_internal or has_inter


class Cities(Supergroup):
    """
    A collection of cities.
    """

    def __init__(self, cities: List[City], ball_tree=True):
        """
        Parameters
        ----------
        cities:
            the City instances to hold
        ball_tree:
            whether to build a haversine BallTree for nearest-city queries
        """
        super().__init__(cities)
        self.members_by_name = {city.name: city for city in cities}
        # Always define the attribute: previously it was only set when
        # ball_tree=True, so the check in ``get_closest_cities`` raised
        # AttributeError for trees built with ball_tree=False.
        self._ball_tree = self._construct_ball_tree() if ball_tree else None

    @classmethod
    def for_super_areas(
        cls,
        super_areas: List[SuperArea],
        city_super_areas_filename=default_cities_filename,
    ):
        """
        Initializes the cities which are on the given super areas.
        """
        city_super_areas = pd.read_csv(city_super_areas_filename)
        city_super_areas = city_super_areas.loc[
            city_super_areas.super_area.isin(
                [super_area.name for super_area in super_areas]
            )
        ]
        city_super_areas.reset_index(inplace=True)
        city_super_areas.set_index("city", inplace=True)
        cities = []
        for city_name in city_super_areas.index.unique():
            super_area_names = city_super_areas.loc[city_name, "super_area"]
            if isinstance(super_area_names, str):
                # a city spanning a single super area yields a scalar, not a Series
                super_area_names = [super_area_names]
            else:
                super_area_names = super_area_names.values.astype(str)
            city = City(name=city_name, super_areas=super_area_names)
            lats = []
            lons = []
            for super_area_name in super_area_names:
                super_area = super_areas.members_by_name[super_area_name]
                super_area.city = city
                lats.append(super_area.coordinates[0])
                lons.append(super_area.coordinates[1])
            city.coordinates = _calculate_centroid(lats, lons)
            city.super_area = super_areas.get_closest_super_area(city.coordinates)
            cities.append(city)
        return cls(cities)

    @classmethod
    def for_geography(
        cls, geography: Geography, city_super_areas_filename=default_cities_filename
    ):
        """Initializes the cities for the super areas of a geography."""
        return cls.for_super_areas(
            super_areas=geography.super_areas,
            city_super_areas_filename=city_super_areas_filename,
        )

    def _construct_ball_tree(self):
        """
        Constructs a NN tree with the haversine metric for the cities.
        """
        coordinates = np.array([np.deg2rad(city.coordinates) for city in self])
        ball_tree = BallTree(coordinates, metric="haversine")
        return ball_tree

    def get_closest_cities(self, coordinates, k=1, return_distance=False):
        """Return the k cities closest to ``coordinates`` (degrees lat/lon).

        When ``return_distance`` is True, distances in km are also returned.

        Raises
        ------
        ValueError
            if the collection was built without a BallTree.
        """
        coordinates = np.array(coordinates)
        if self._ball_tree is None:
            raise ValueError("Cities initialized without a BallTree")
        if coordinates.shape == (2,):
            coordinates = coordinates.reshape(1, -1)
        if return_distance:
            # fixed: this branch queried ``self.ball_tree``, an attribute that
            # is never defined (the tree is stored as ``self._ball_tree``)
            distances, indcs = self._ball_tree.query(
                np.deg2rad(coordinates), return_distance=return_distance, k=k
            )
            if coordinates.shape == (1, 2):
                cities = [self[idx] for idx in indcs[0]]
                return cities, distances[0] * earth_radius
            cities = [self[idx] for idx in indcs[:, 0]]
            return cities, distances[:, 0] * earth_radius
        indcs = self._ball_tree.query(
            np.deg2rad(coordinates), return_distance=return_distance, k=k
        )
        cities = [self[idx] for idx in indcs[0]]
        return cities

    def get_by_name(self, city_name):
        """Return the city with the given name (KeyError if absent)."""
        return self.members_by_name[city_name]

    def get_closest_city(self, coordinates):
        """Return the single closest city to the given coordinates."""
        return self.get_closest_cities(coordinates, k=1, return_distance=False)[0]

    def get_closest_commuting_city(self, coordinates):
        """Return the closest city that has stations, or None with a warning."""
        cities_by_distance = self.get_closest_cities(coordinates, k=len(self.members))
        for city in cities_by_distance:
            # NOTE(review): ``city.stations`` is not set by City.__init__
            # (which defines city_stations / inter_city_stations); presumably
            # attached by the travel module — verify.
            if city.stations.members:
                return city
        logger.warning("No commuting city in this world.")


class ExternalCity(ExternalGroup):
    """
    A city that lives outside the simulated domain.
    """

    external = True

    def __init__(self, id, domain_id, coordinates=None, commuter_ids=None, name=None):
        """Create a proxy for a city hosted on another domain.

        Parameters
        ----------
        id:
            the city's group id
        domain_id:
            id of the domain hosting the real city
        coordinates:
            latitude/longitude of the city, if known
        commuter_ids:
            ids of people commuting internally within the city
        name:
            the city name
        """
        super().__init__(spec="city", domain_id=domain_id, id=id)
        # fall back to an empty set when no commuters are supplied
        self.internal_commuter_ids = commuter_ids or set()
        self.city_stations = []
        self.inter_city_stations = []
        self.super_area = None
        self.coordinates = coordinates
        self.name = name

    @property
    def has_stations(self):
        """Whether this city has any internal stations."""
        return len(self.city_stations) > 0

    def get_commute_subgroup(self, person):
        """
        Gets the commute subgroup of the person. First check whether the
        person commutes internally within the city; otherwise check whether
        they commute through their closest inter-city station. If neither,
        that person doesn't need commuting.
        """
        if not self.has_stations:
            return
        if person.id in self.internal_commuter_ids:
            # pick one of the city's internal stations at random
            station = self.city_stations[randint(0, len(self.city_stations) - 1)]
            return station.get_commute_subgroup()
        station = person.super_area.closest_inter_city_station_for_city[self.name]
        if person.id in station.commuter_ids:
            return station.get_commute_subgroup()


import logging
from itertools import count, chain
from typing import List, Dict, Tuple, Optional
import pandas as pd
import numpy as np
from sklearn.neighbors import BallTree

from june import paths
from june.demography.person import Person

# Default geography input files: area/super-area/region hierarchy,
# coordinate tables, and the per-area socioeconomic index.
default_hierarchy_filename = (
    paths.data_path / "input/geography/area_super_area_region.csv"
)
default_area_coord_filename = (
    paths.data_path / "input/geography/area_coordinates_sorted.csv"
)
default_superarea_coord_filename = (
    paths.data_path / "input/geography/super_area_coordinates_sorted.csv"
)
default_area_socioeconomic_index_filename = (
    paths.data_path / "input/geography/socioeconomic_index.csv"
)

logger = logging.getLogger(__name__)

# used to scale haversine (radian) distances to kilometres
earth_radius = 6371  # km


class GeographyError(Exception):
    """Raised for errors while building or querying geography objects.

    Fixed: this previously derived from ``BaseException``, which Python
    reserves for interpreter-level exits (SystemExit, KeyboardInterrupt);
    user-defined errors should derive from ``Exception`` so that generic
    ``except Exception`` handlers can catch them.
    """


class Area:
    """
    Fine geographical resolution.
    """

    # slots keep per-instance memory low: there is one Area per output area
    __slots__ = (
        "people",
        "id",
        "name",
        "coordinates",
        "super_area",
        "care_home",
        "schools",
        "households",
        "social_venues",
        "socioeconomic_index",
    )
    _id = count()

    def __init__(
        self,
        name: str = None,
        super_area: "SuperArea" = None,
        coordinates: Tuple[float, float] = None,
        socioeconomic_index: float = None,
    ):
        """
        Coordinate is given in the format [Y, X] where X is longitude and Y is latitude.
        """
        self.id = next(self._id)
        self.name = name
        self.super_area = super_area
        self.coordinates = coordinates
        self.socioeconomic_index = socioeconomic_index
        self.care_home = None
        self.people = []
        self.schools = []
        self.households = []
        self.social_venues = {}

    def add(self, person: Person):
        """Attach a person to this area and set the back-reference."""
        person.area = self
        self.people.append(person)

    def populate(self, demography, ethnicity=True, comorbidity=True):
        """Generate this area's population from the demography object."""
        new_people = demography.populate(
            self.name, ethnicity=ethnicity, comorbidity=comorbidity
        )
        for person in new_people:
            self.add(person)

    @property
    def region(self):
        """The region this area belongs to, via its super area."""
        return self.super_area.region


class Areas:
    """Collection of ``Area`` objects with optional nearest-neighbour search."""

    __slots__ = "members_by_id", "super_area", "ball_tree", "members_by_name"

    def __init__(self, areas: List[Area], super_area=None, ball_tree: bool = True):
        self.members_by_id = {area.id: area for area in areas}
        try:
            self.members_by_name = {area.name: area for area in areas}
        except AttributeError:
            # areas lacking a ``name`` attribute cannot be indexed by name
            self.members_by_name = None
        self.super_area = super_area
        self.ball_tree = self.construct_ball_tree() if ball_tree else None

    def __iter__(self):
        return iter(self.members)

    def __len__(self):
        return len(self.members_by_id)

    def __getitem__(self, index):
        return self.members[index]

    def get_from_id(self, id):
        """Return the area with the given id."""
        return self.members_by_id[id]

    def get_from_name(self, name):
        """Return the area with the given name."""
        return self.members_by_name[name]

    @property
    def members(self):
        """All areas, in insertion order of the id map."""
        return list(self.members_by_id.values())

    def construct_ball_tree(self):
        """Build a haversine BallTree over the areas' coordinates (radians)."""
        coords_rad = np.array([np.deg2rad(area.coordinates) for area in self.members])
        return BallTree(coords_rad, metric="haversine")

    def get_closest_areas(self, coordinates, k=1, return_distance=False):
        """
        Return the closest areas to ``coordinates`` (degrees latitude/longitude).

        With ``return_distance`` a tuple ``(areas, distances_km)`` is returned:
        for a single query point the ``k`` nearest areas; for several query
        points, the single nearest area of each point.
        """
        if self.ball_tree is None:
            raise GeographyError("Areas initialized without a BallTree")
        coordinates = np.array(coordinates)
        if coordinates.shape == (2,):
            coordinates = coordinates.reshape(1, -1)
        all_areas = self.members
        if not return_distance:
            indices = self.ball_tree.query(
                np.deg2rad(coordinates), return_distance=False, k=k
            )
            return [all_areas[index] for index in indices.flatten()]
        distances, indices = self.ball_tree.query(
            np.deg2rad(coordinates), return_distance=True, k=k
        )
        if coordinates.shape == (1, 2):
            # single query point: its k nearest areas
            nearest = [all_areas[index] for index in indices[0]]
            return nearest, distances[0] * earth_radius
        # several query points: the closest area of each point
        nearest = [all_areas[index] for index in indices[:, 0]]
        return nearest, distances[:, 0] * earth_radius

    def get_closest_area(self, coordinates, return_distance=False):
        """Return the single closest area (optionally with its distance in km)."""
        if return_distance:
            areas, distances = self.get_closest_areas(
                coordinates, k=1, return_distance=True
            )
            return areas[0], distances[0]
        return self.get_closest_areas(coordinates, k=1, return_distance=False)[0]


class SuperArea:
    """
    Coarse geographical resolution.

    Aggregates several areas and keeps track of the workers and companies
    associated with it.
    """

    __slots__ = (
        "id",
        "name",
        "city",
        "coordinates",
        "closest_inter_city_station_for_city",
        "region",
        "workers",
        "areas",
        "companies",
        "closest_hospitals",
    )
    external = False
    _id = count()

    def __init__(
        self,
        name: Optional[str] = None,
        areas: List[Area] = None,
        coordinates: Tuple[float, float] = None,
        region: Optional[str] = None,
    ):
        # ids come from a counter shared by all SuperArea instances
        self.id = next(self._id)
        self.name = name
        self.coordinates = coordinates
        self.region = region
        self.areas = areas if areas else []
        # filled in later by the world / commute construction
        self.city = None
        self.closest_inter_city_station_for_city = {}
        self.workers = []
        self.companies = []
        self.closest_hospitals = None

    def add_worker(self, person: Person):
        """Register ``person`` as working here and back-link this super area."""
        self.workers.append(person)
        person.work_super_area = self

    def remove_worker(self, person: Person):
        """Remove ``person`` from the workers and clear the back-link."""
        self.workers.remove(person)
        person.work_super_area = None

    @property
    def people(self):
        """All people living in the areas of this super area."""
        return [person for area in self.areas for person in area.people]

    @property
    def households(self):
        """All households in the areas of this super area."""
        return [household for area in self.areas for household in area.households]

    def __eq__(self, other):
        # equality is by name only
        return self.name == other.name


class SuperAreas:
    """Collection of super areas with optional nearest-neighbour search."""

    __slots__ = "members_by_id", "ball_tree", "members_by_name"

    def __init__(self, super_areas: List["SuperArea"], ball_tree: bool = True):
        """
        Group to aggregate SuperArea objects.

        Parameters
        ----------
        super_areas
            list of super areas
        ball_tree
            whether to construct a NN tree for the super areas
        """
        self.members_by_id = {area.id: area for area in super_areas}
        try:
            self.members_by_name = {
                super_area.name: super_area for super_area in super_areas
            }
        except AttributeError:
            # super areas without a ``name`` attribute cannot be indexed by name
            self.members_by_name = None
        if ball_tree:
            self.ball_tree = self.construct_ball_tree()
        else:
            self.ball_tree = None

    def __iter__(self):
        return iter(self.members)

    def __len__(self):
        return len(self.members)

    def __getitem__(self, index):
        return self.members[index]

    def get_from_id(self, id):
        """Return the super area with the given id."""
        return self.members_by_id[id]

    def get_from_name(self, name):
        """Return the super area with the given name."""
        return self.members_by_name[name]

    @property
    def members(self):
        """All super areas, in insertion order of the id map."""
        return list(self.members_by_id.values())

    def construct_ball_tree(self):
        """Build a haversine BallTree over the super areas' coordinates."""
        all_members = self.members
        coordinates = np.array(
            [np.deg2rad(super_area.coordinates) for super_area in all_members]
        )
        ball_tree = BallTree(coordinates, metric="haversine")
        return ball_tree

    def get_closest_super_areas(self, coordinates, k=1, return_distance=False):
        """
        Return the ``k`` closest super areas to ``coordinates`` (degrees
        latitude/longitude); with ``return_distance`` also the distances in km.

        Raises
        ------
        GeographyError
            if the collection was built without a ball tree
        """
        coordinates = np.array(coordinates)
        if self.ball_tree is None:
            # message fixed: this is the SuperAreas container, not Areas
            raise GeographyError("SuperAreas initialized without a BallTree")
        if coordinates.shape == (2,):
            coordinates = coordinates.reshape(1, -1)
        if return_distance:
            distances, indcs = self.ball_tree.query(
                np.deg2rad(coordinates),
                return_distance=return_distance,
                k=k,
                sort_results=True,
            )
            indcs = chain.from_iterable(indcs)
            all_super_areas = self.members
            super_areas = [all_super_areas[idx] for idx in indcs]
            distances = distances.flatten()
            # haversine distances are in radians; scale to km
            return super_areas, distances * earth_radius
        else:
            indcs = self.ball_tree.query(
                np.deg2rad(coordinates),
                return_distance=return_distance,
                k=k,
                sort_results=True,
            )
            all_super_areas = self.members
            super_areas = [all_super_areas[idx] for idx in indcs.flatten()]
            return super_areas

    def get_closest_super_area(self, coordinates, return_distance=False):
        """Return the single closest super area (optionally with distance in km)."""
        if return_distance:
            closest_areas, distances = self.get_closest_super_areas(
                coordinates, k=1, return_distance=return_distance
            )
            return closest_areas[0], distances[0]
        else:
            return self.get_closest_super_areas(
                coordinates, k=1, return_distance=return_distance
            )[0]


class ExternalSuperArea:
    """
    A super area that lives outside the simulated domain; only the
    information needed to reference it remotely is stored.
    """

    external = True
    __slots__ = "city", "spec", "id", "domain_id", "coordinates"

    def __init__(self, id, domain_id, coordinates):
        self.spec = "super_area"
        self.id = id
        self.domain_id = domain_id
        self.coordinates = coordinates
        # assigned later, if the super area belongs to a city
        self.city = None


class Region:
    """
    Coarsest geographical resolution.

    Holds a list of super areas and a mutable ``policy`` dict used by the
    policy machinery at runtime.
    """

    __slots__ = ("id", "name", "super_areas", "policy")
    _id = count()

    def __init__(
        self,
        name: Optional[str] = None,
        super_areas: Optional[List["SuperArea"]] = None,
    ):
        """
        Parameters
        ----------
        name
            name of the region
        super_areas
            super areas belonging to this region
            (type hint fixed: a list of ``SuperArea``, not ``SuperAreas``)
        """
        self.id = next(self._id)
        self.name = name
        self.super_areas = super_areas or []
        # per-region policy state, mutated externally during a simulation
        self.policy = {
            "regional_compliance": 1.0,
            "lockdown_tier": None,
            "local_closed_venues": set(),
            "global_closed_venues": set(),
        }

    @property
    def people(self):
        """All people living in this region's super areas."""
        return list(
            chain.from_iterable(super_area.people for super_area in self.super_areas)
        )

    @property
    def regional_compliance(self):
        """Compliance factor of this region with active policies."""
        return self.policy["regional_compliance"]

    @regional_compliance.setter
    def regional_compliance(self, value):
        self.policy["regional_compliance"] = value

    @property
    def closed_venues(self):
        """Union of locally and globally closed venue types."""
        return self.policy["local_closed_venues"] | self.policy["global_closed_venues"]

    @property
    def households(self):
        """All households in this region's super areas."""
        return list(
            chain.from_iterable(
                super_area.households for super_area in self.super_areas
            )
        )


class Regions:
    """Collection of ``Region`` objects, searchable by id or by name."""

    __slots__ = "members_by_id", "members_by_name"

    def __init__(self, regions: List[Region]):
        self.members_by_id = {region.id: region for region in regions}
        try:
            self.members_by_name = {region.name: region for region in regions}
        except AttributeError:
            # regions lacking a ``name`` attribute cannot be indexed by name
            self.members_by_name = None

    def __iter__(self):
        return iter(self.members)

    def __len__(self):
        return len(self.members_by_id)

    def __getitem__(self, index):
        return self.members[index]

    def get_from_id(self, id):
        """Return the region with the given id."""
        return self.members_by_id[id]

    def get_from_name(self, name):
        """Return the region with the given name."""
        return self.members_by_name[name]

    @property
    def members(self):
        """All regions, in insertion order of the id map."""
        return list(self.members_by_id.values())


class Geography:
    """
    Ties together the three geographical resolutions (areas, super areas,
    regions) of a world, plus slots for the building supergroups that are
    placed on them later.
    """

    def __init__(
        self, areas: List[Area], super_areas: List[SuperArea], regions: List[Region]
    ):
        """
        Generate hierarchical division of geography.

        Parameters
        ----------
        areas
            fine-resolution geographical units
        super_areas
            coarse-resolution units, each containing several areas
        regions
            coarsest-resolution units, each containing several super areas

        Note: It would be nice to find a better way to handle coordinates.
        """
        self.areas = areas
        self.super_areas = super_areas
        self.regions = regions
        # possible buildings; populated later by the world-construction code
        self.households = None
        self.schools = None
        self.hospitals = None
        self.companies = None
        self.care_homes = None
        self.pubs = None
        self.cinemas = None
        self.groceries = None
        self.cemeteries = None
        self.universities = None

    @classmethod
    def _create_areas(
        cls,
        area_coords: pd.DataFrame,
        super_area: "SuperArea",
        socioeconomic_indices: pd.Series,
    ) -> List[Area]:
        """
        Applies the Area constructor through the area_coords dataframe.
        If area_coords is a series object, then it does not use the apply()
        function as it does not support the axis=1 parameter.

        Parameters
        ----------
        area_coords
            pandas Dataframe with the area name as index and the coordinates
            X, Y where X is longitude and Y is latitude.
        super_area
            the SuperArea instance all created areas belong to
            (type hint fixed: this is a SuperArea, not a DataFrame)
        socioeconomic_indices
            series mapping area name to its socioeconomic index
        """
        # if a single area is given, then area_coords is a series
        # and we cannot do iterrows()
        if isinstance(area_coords, pd.Series):
            # the series name is the area name, its values the coordinates
            areas = [
                Area(
                    area_coords.name,
                    super_area,
                    area_coords.values,
                    socioeconomic_indices.loc[area_coords.name],
                )
            ]
        else:
            areas = []
            for name, coordinates in area_coords.iterrows():
                areas.append(
                    Area(
                        name,
                        super_area,
                        coordinates=np.array(
                            [coordinates.latitude, coordinates.longitude]
                        ),
                        socioeconomic_index=socioeconomic_indices.loc[name],
                    )
                )
        return areas

    @classmethod
    def _create_super_areas(
        cls,
        super_area_coords: pd.DataFrame,
        area_coords: pd.DataFrame,
        area_socioeconomic_indices: pd.Series,
        region: "Region",
        hierarchy: pd.DataFrame,
    ) -> Tuple[List[SuperArea], List[Area]]:
        """
        Applies the SuperArea constructor through the super_area_coords
        dataframe. If super_area_coords is a series object, then it does not
        use the apply() function as it does not support the axis=1 parameter.

        Parameters
        ----------
        super_area_coords
            pandas Dataframe with the super area name as index and the coordinates
            X, Y where X is longitude and Y is latitude.
        area_coords
            area coordinates frame, indexed by area name
        area_socioeconomic_indices
            series mapping area name to its socioeconomic index
        region
            region instance to what all the super areas belong to
        hierarchy
            area/super_area/region relationship frame

        Returns
        -------
        A tuple ``(super_areas, areas)`` with all created units
        (return annotation fixed: the original ``List[Area]`` was wrong).
        """
        # if a single area is given, then area_coords is a series
        # and we cannot do iterrows()
        area_hierarchy = hierarchy.reset_index()
        area_hierarchy.set_index("super_area", inplace=True)
        total_areas_list, super_areas_list = [], []
        if isinstance(super_area_coords, pd.Series):
            super_areas_list = [
                SuperArea(
                    super_area_coords.name,
                    areas=None,
                    region=region,
                    coordinates=np.array(
                        [super_area_coords.latitude, super_area_coords.longitude]
                    ),
                )
            ]
            areas_df = area_coords.loc[
                area_hierarchy.loc[super_area_coords.name, "area"]
            ]
            areas_list = cls._create_areas(
                areas_df, super_areas_list[0], area_socioeconomic_indices
            )
            super_areas_list[0].areas = areas_list
            total_areas_list += areas_list
        else:
            for super_area_name, row in super_area_coords.iterrows():
                super_area = SuperArea(
                    areas=None,
                    name=super_area_name,
                    coordinates=np.array([row.latitude, row.longitude]),
                    region=region,
                )
                # all areas belonging to this super area, via the hierarchy
                areas_df = area_coords.loc[area_hierarchy.loc[super_area_name, "area"]]
                areas_list = cls._create_areas(
                    areas_df, super_area, area_socioeconomic_indices
                )
                super_area.areas = areas_list
                total_areas_list += list(areas_list)
                super_areas_list.append(super_area)
        return super_areas_list, total_areas_list

    @classmethod
    def create_geographical_units(
        cls,
        hierarchy: pd.DataFrame,
        area_coordinates: pd.DataFrame,
        super_area_coordinates: pd.DataFrame,
        area_socioeconomic_indices: pd.Series,
        sort_identifiers=True,
    ):
        """
        Create geo-graph of the used geographical units.

        Builds Region, SuperArea and Area objects from the hierarchy and
        coordinate frames, and wraps them in their container classes.
        """
        # ensures that geo.super_areas, geo.areas, etc. are ordered by identifier
        region_hierarchy = hierarchy.reset_index().set_index("region")["super_area"]
        region_hierarchy = region_hierarchy.drop_duplicates()
        region_list = []
        total_areas_list, total_super_areas_list = [], []
        for region_name in region_hierarchy.index.unique():
            region = Region(name=region_name, super_areas=None)

            # all super areas belonging to this region
            super_areas_df = super_area_coordinates.loc[
                region_hierarchy.loc[region_name]
            ]
            super_areas_list, areas_list = cls._create_super_areas(
                super_areas_df,
                area_coordinates,
                area_socioeconomic_indices,
                region,
                hierarchy=hierarchy,
            )
            region.super_areas = super_areas_list
            total_super_areas_list += list(super_areas_list)
            total_areas_list += list(areas_list)
            region_list.append(region)
        if sort_identifiers:
            # sort by name and reassign consecutive ids
            total_areas_list = sort_geo_unit_by_identifier(total_areas_list)
            total_super_areas_list = sort_geo_unit_by_identifier(total_super_areas_list)

        areas = Areas(total_areas_list)
        super_areas = SuperAreas(total_super_areas_list)
        regions = Regions(region_list)
        logger.info(
            f"There are {len(areas)} areas and "
            + f"{len(super_areas)} super_areas "
            + f"and {len(regions)} regions in the world."
        )
        return areas, super_areas, regions

    @classmethod
    def from_file(
        cls,
        filter_key: Optional[Dict[str, list]] = None,
        hierarchy_filename: str = default_hierarchy_filename,
        area_coordinates_filename: str = default_area_coord_filename,
        super_area_coordinates_filename: str = default_superarea_coord_filename,
        area_socioeconomic_index_filename: str = default_area_socioeconomic_index_filename,
        sort_identifiers=True,
    ) -> "Geography":
        """
        Load data from files and construct classes capable of generating
        hierarchical structure of geographical areas.

        Example usage
        -------------
            ```
            geography = Geography.from_file(filter_key={"region" : "North East"})
            geography = Geography.from_file(filter_key={"super_area" : ["E02005728"]})
            ```
        Parameters
        ----------
        filter_key
            Filter out geo-units which should enter the world.
            At the moment this can only be one of [area, super_area, region]
        hierarchy_filename
            Pandas df file containing the relationships between the different
            geographical units.
        area_coordinates_filename:
            coordinates of the area units
        super_area_coordinates_filename
            coordinates of the super area units
        area_socioeconomic_index_filename
            socioeconomic index of each area; if falsy, a None-valued series
            is used instead
        sort_identifiers
            whether to sort areas and super areas by name and reassign ids
        """
        geo_hierarchy = pd.read_csv(hierarchy_filename)
        areas_coord = pd.read_csv(area_coordinates_filename)
        super_areas_coord = pd.read_csv(super_area_coordinates_filename)
        if filter_key is not None:
            geo_hierarchy = _filtering(geo_hierarchy, filter_key)
        # keep only coordinates of units present in the (possibly filtered) hierarchy
        areas_coord = areas_coord.loc[areas_coord.area.isin(geo_hierarchy.area)]
        super_areas_coord = super_areas_coord.loc[
            super_areas_coord.super_area.isin(geo_hierarchy.super_area)
        ].drop_duplicates()
        areas_coord.set_index("area", inplace=True)
        areas_coord = areas_coord[["latitude", "longitude"]]
        super_areas_coord.set_index("super_area", inplace=True)
        super_areas_coord = super_areas_coord[["latitude", "longitude"]]
        geo_hierarchy.set_index("super_area", inplace=True)
        if area_socioeconomic_index_filename:
            area_socioeconomic_df = pd.read_csv(area_socioeconomic_index_filename)
            area_socioeconomic_df = area_socioeconomic_df.loc[
                area_socioeconomic_df.area.isin(geo_hierarchy.area)
            ]
            area_socioeconomic_df.set_index("area", inplace=True)
            area_socioeconomic_index = area_socioeconomic_df["socioeconomic_centile"]
        else:
            # no socioeconomic data: fall back to a None-valued series
            area_socioeconomic_index = pd.Series(
                data=np.full(len(areas_coord), None),
                index=areas_coord.index,
                name="socioeconomic_centile",
            )
        areas, super_areas, regions = cls.create_geographical_units(
            geo_hierarchy,
            areas_coord,
            super_areas_coord,
            area_socioeconomic_index,
            sort_identifiers=sort_identifiers,
        )
        return cls(areas, super_areas, regions)


def _filtering(data: pd.DataFrame, filter_key: Dict[str, list]) -> pd.DataFrame:
    """
    Filter DataFrame for given geo-unit and it's listed names
    """
    return data[
        data[list(filter_key.keys())[0]].isin(list(filter_key.values())[0]).values
    ]


def sort_geo_unit_by_identifier(geo_units):
    """
    Return ``geo_units`` sorted by their ``name`` attribute, reassigning
    consecutive ids starting from the id of the first unit of the input.
    """
    base_id = geo_units[0].id
    order = np.argsort([unit.name for unit in geo_units])
    sorted_units = [geo_units[position] for position in order]
    # reassign ids so they follow the sorted order
    for offset, unit in enumerate(sorted_units):
        unit.id = base_id + offset
    return sorted_units


from typing import List
import numpy as np
import logging
from random import randint
from sklearn.neighbors import BallTree
from itertools import count

from june.paths import data_path
from june.geography import City, SuperAreas, SuperArea
from june.groups import Supergroup, ExternalGroup, ExternalSubgroup
from june.utils.distances import add_distance_to_lat_lon

default_super_stations_filename = (
    data_path / "input/geography/stations_per_super_area_ew.csv"
)

logger = logging.getLogger(__name__)


class Station:
    """
    Base class for commuting stations.

    A station lives in a super area, belongs to a city (by name) and keeps
    the ids of the people commuting through it.
    """

    external = False
    _id = count()

    def __init__(self, city: str = None, super_area: SuperArea = None):
        # ids come from a counter shared by all stations
        self.id = next(self._id)
        self.city = city
        self.super_area = super_area
        self.commuter_ids = set()

    @property
    def coordinates(self):
        """Coordinates of the super area hosting this station."""
        return self.super_area.coordinates


class CityStation(Station):
    """
    A station used for commuting within a city.
    """

    def __init__(self, city: str = None, super_area: SuperArea = None):
        super().__init__(city=city, super_area=super_area)
        self.city_transports = []

    @property
    def n_city_transports(self):
        """Number of city transports attached to this station."""
        return len(self.city_transports)

    @property
    def station_type(self):
        return "city"

    def get_commute_subgroup(self):
        """Pick a random city transport and return its first subgroup."""
        chosen = randint(0, self.n_city_transports - 1)
        return self.city_transports[chosen][0]


class InterCityStation(Station):
    """
    A station used for commuting between cities.
    """

    def __init__(self, city: str = None, super_area: SuperArea = None):
        super().__init__(city=city, super_area=super_area)
        self.inter_city_transports = []

    @property
    def n_inter_city_transports(self):
        """Number of inter-city transports attached to this station."""
        return len(self.inter_city_transports)

    @property
    def station_type(self):
        return "inter_city"

    def get_commute_subgroup(self):
        """Pick a random inter-city transport and return its first subgroup."""
        chosen = randint(0, self.n_inter_city_transports - 1)
        return self.inter_city_transports[chosen][0]


class Stations(Supergroup):
    """
    A collection of stations belonging to a city.
    """

    def __init__(self, stations: List[Station]):
        super().__init__(stations)
        self._ball_tree = None

    @classmethod
    def from_city_center(
        cls,
        city: City,
        type: str,
        super_areas: SuperAreas,
        number_of_stations: int = 4,
        distance_to_city_center: int = 20,
    ):
        """
        Initialises ``number_of_stations`` radially around the city center.

        Parameters
        ----------
        city
            the city the stations belong to
        type
            either "city_station" or "inter_city_station"
            (parameter name shadows the builtin, kept for interface
            compatibility)
        super_areas
            The super_areas where to put the hubs on
        number_of_stations:
            How many stations to initialise
        distance_to_city_center
            The distance from the center to the each station (km)

        Raises
        ------
        ValueError
            if ``type`` is not a recognised station type
        """
        stations = []
        angle = 0
        delta_angle = 2 * np.pi / number_of_stations
        # first station offset along x from the center; subsequent ones rotated
        # by delta_angle each iteration
        x = distance_to_city_center
        y = 0
        city_coordinates = city.coordinates
        for _ in range(number_of_stations):
            station_position = add_distance_to_lat_lon(
                city_coordinates[0], city_coordinates[1], x=x, y=y
            )
            angle += delta_angle
            x = distance_to_city_center * np.cos(angle)
            y = distance_to_city_center * np.sin(angle)
            # place the station in the closest super area to the computed point
            super_area = super_areas.get_closest_super_area(np.array(station_position))
            if type == "city_station":
                station = CityStation(city=city.name, super_area=super_area)
            elif type == "inter_city_station":
                station = InterCityStation(city=city.name, super_area=super_area)
            else:
                # fixed: was a bare ``raise ValueError`` with no message
                raise ValueError(f"Unknown station type: {type}")
            stations.append(station)
        return cls(stations)

    def _construct_ball_tree(self):
        """Build a haversine BallTree over the stations' coordinates."""
        coordinates = np.array([np.deg2rad(station.coordinates) for station in self])
        self._ball_tree = BallTree(coordinates, metric="haversine")

    def get_closest_station(self, coordinates):
        """
        Return the station closest to ``coordinates`` (degrees lat/lon).

        Raises
        ------
        ValueError
            if the ball tree has not been constructed
        """
        coordinates = np.array(coordinates)
        if self._ball_tree is None:
            raise ValueError("Stations initialized without a BallTree")
        if coordinates.shape == (2,):
            coordinates = coordinates.reshape(1, -1)
        indcs = self._ball_tree.query(
            np.deg2rad(coordinates), return_distance=False, k=1
        )
        # one nearest-station index per query point (renamed: these are
        # stations, not super areas as the old local name suggested)
        closest_stations = [self[idx] for idx in indcs[:, 0]]
        return closest_stations[0]


class ExternalStation(ExternalGroup):
    """
    A station hosted in another domain; only the information needed to
    reference it remotely is kept.
    """

    external = True

    def __init__(self, id: int, domain_id: int, city: str = None):
        super().__init__(spec="station", domain_id=domain_id, id=id)
        self.city = city
        self.commuter_ids = set()

    @property
    def coordinates(self):
        # NOTE(review): relies on a ``super_area`` attribute that is not set
        # in this class — presumably provided by the base class or assigned
        # externally; confirm before calling.
        return self.super_area.coordinates

    def get_commute_subgroup(self):
        # subclasses provide the concrete transport selection
        raise NotImplementedError


class ExternalCityStation(ExternalStation):
    """
    A city station that lives outside the simulated domain.
    """

    def __init__(self, id: int, domain_id: int, city: str = None):
        super().__init__(id=id, domain_id=domain_id, city=city)
        self.city_transports = []

    @property
    def n_city_transports(self):
        """Number of city transports attached to this station."""
        return len(self.city_transports)

    def get_commute_subgroup(self):
        """Pick a random city transport, wrapped as an external subgroup."""
        pick = randint(0, self.n_city_transports - 1)
        return ExternalSubgroup(group=self.city_transports[pick], subgroup_type=0)


class ExternalInterCityStation(ExternalStation):
    """
    An inter-city station that lives outside the simulated domain.
    """

    def __init__(self, id: int, domain_id: int, city: str = None):
        super().__init__(id=id, domain_id=domain_id, city=city)
        self.inter_city_transports = []

    @property
    def n_inter_city_transports(self):
        """Number of inter-city transports attached to this station."""
        return len(self.inter_city_transports)

    def get_commute_subgroup(self):
        """Pick a random inter-city transport, wrapped as an external subgroup."""
        pick = randint(0, self.n_inter_city_transports - 1)
        return ExternalSubgroup(
            group=self.inter_city_transports[pick], subgroup_type=0
        )


from .geography import (
    Area,
    SuperArea,
    Areas,
    SuperAreas,
    Geography,
    ExternalSuperArea,
    Region,
    Regions,
)
from .city import City, Cities, ExternalCity
from .station import (
    Station,
    Stations,
    CityStation,
    InterCityStation,
    ExternalCityStation,
    ExternalInterCityStation,
)


import logging

import numpy as np
import pandas as pd
from scipy.stats import rv_discrete

from june.demography.person import Person
from june.groups import Group

logger = logging.getLogger(__name__)


class BoundaryError(Exception):
    """Class for throwing boundary related errors.

    Derives from ``Exception`` (not ``BaseException``) so generic
    ``except Exception`` handlers catch it.
    """


class Boundary(Group):
    """
    Group of people living outside the simulated region, recruited to fill
    the missing workforce of the region's companies.

    NOTE(review): this class reads ``world.config``, ``world.inputs``,
    ``world.companies``, ``world.people`` and ``world.msoareas``, whose
    schemas are not visible here — the assumptions documented below should
    be confirmed against those objects.
    """

    def __init__(self, world):
        super().__init__()
        self.world = world
        # running total of boundary workers created so far; also used to
        # offset the ids of newly created people
        self.n_residents = 0
        self.missing_workforce_nr()

    def missing_workforce_nr(self):
        """
        Estimate missing workforce in simulated region.
        This will establish the number of workers recruited
        from the boundary.
        """

        self.ADULT_THRESHOLD = self.world.config["people"]["adult_threshold"]
        self.OLD_THRESHOLD = self.world.config["people"]["old_threshold"]
        self._init_frequencies()

        for company in self.world.companies.members:
            # nr. of missing workforce
            # TODO: companies shouldn't always be completely full
            n_residents = company.n_employees_max - company.n_employees

            # draw sex, nomis age bin and age for each missing worker
            (sex_rnd_arr, nomis_bin_rnd_arr, age_rnd_arr) = self.init_random_variables(
                n_residents, company.industry
            )

            for i in range(n_residents):
                # create new person
                # NOTE(review): positional Person signature assumed to be
                # (world, id, area-tag, msoa, age, nomis_bin, sex, ...) —
                # confirm against Person.__init__
                person = Person(
                    self.world,
                    (i + self.n_residents),
                    "boundary",
                    company.msoa,
                    age_rnd_arr[i],
                    nomis_bin_rnd_arr[i],
                    sex_rnd_arr[i],
                    econ_index=0,
                    mode_of_transport=None,
                )
                person.industry = company.industry

                # Inform groups about new person
                self.people.append(person)
                self.world.people.members.append(person)
                company.people.append(person)
                # find the msoa area the company belongs to (first name match)
                idx = [
                    idx
                    for idx, msoa in enumerate(self.world.msoareas.members)
                    if msoa.name == company.msoa
                ][0]
                self.world.msoareas.members[idx].work_people.append(person)

            self.n_residents += n_residents

    def _init_frequencies(self):
        """
        Create the frequencies for different attributes of the whole
        simulated region.
        """
        # sex-frequencies per company sector
        # NOTE(review): columns are matched by the substrings "f " / "m " —
        # assumes labels like "f <sector>" / "m <sector>"; confirm
        f_col = [
            col
            for col in self.world.inputs.compsec_by_sex_df.columns.values
            if "f " in col
        ]
        f_nrs_per_compsec = self.world.inputs.compsec_by_sex_df[f_col].sum(axis="rows")

        m_col = [
            col
            for col in self.world.inputs.compsec_by_sex_df.columns.values
            if "m " in col
        ]
        m_nrs_per_compsec = self.world.inputs.compsec_by_sex_df[m_col].sum(axis="rows")

        # one row per company sector, columns "f"/"m", normalised to sum to 1
        sex_freq_per_compsec = pd.DataFrame(
            data=np.vstack((f_nrs_per_compsec.values, m_nrs_per_compsec.values)).T,
            index=[idx.split(" ")[-1] for idx in m_nrs_per_compsec.index.values],
            columns=["f", "m"],
        )
        self.sex_freq_per_compsec = sex_freq_per_compsec.div(
            sex_freq_per_compsec.sum(axis=1), axis=0
        )

        # age-frequencies of people at work, based on the whole simulated region
        nomis_and_age_list = [
            [person.nomis_bin, person.age] for person in self.world.people.members
        ]
        nomis_bin_arr, age_arr = np.array(nomis_and_age_list).T
        nomis_bin_unique, nomis_bin_counts = np.unique(
            nomis_bin_arr, return_counts=True
        )
        age_unique, age_counts = np.unique(age_arr, return_counts=True)

        # keep only working-age nomis bins, then normalise
        nomis_bin_df = pd.DataFrame(
            data=np.vstack((nomis_bin_unique, nomis_bin_counts)).T,
            columns=["age", "freq"],
        )
        nomis_bin_df = nomis_bin_df[
            (nomis_bin_df["age"] >= self.ADULT_THRESHOLD)
            & (nomis_bin_df["age"] <= self.OLD_THRESHOLD)
        ]
        # NOTE(review): this divides both the "age" and "freq" columns by
        # their column sums, so "age" no longer holds bin labels afterwards —
        # confirm downstream only uses the normalised "freq" column
        self.nomis_bins = nomis_bin_df.div(nomis_bin_df.sum(axis=0), axis=1)

        age_df = pd.DataFrame(
            data=np.vstack((age_unique, age_counts)).T, columns=["age", "freq"]
        )
        # hard-coded working-age window 20..65 (cf. the thresholds above)
        age_df = age_df[(age_df["age"] >= 20) & (age_df["age"] <= 65)]
        self.ages = age_df.div(age_df.sum(axis=0), axis=1)

    def init_random_variables(self, n_residents, compsec):
        """
        Create the random variables following the discrete distributions
        for different attributes of the whole simulated region.

        Returns arrays of length ``n_residents`` with sampled sex (0/1),
        nomis age-bin index and age index, for company sector ``compsec``.
        """
        # sex: 0/1 drawn with the sector-specific frequencies
        sex_rv = rv_discrete(
            values=(np.arange(0, 2), self.sex_freq_per_compsec.loc[compsec].values)
        )
        sex_rnd_arr = sex_rv.rvs(size=n_residents)

        # nomis age bin: index into the normalised bin frequencies
        nomis_bin_rv = rv_discrete(
            values=(
                np.arange(len(self.nomis_bins.freq.values)),
                self.nomis_bins.freq.values,
            )
        )
        nomis_bin_rnd_arr = nomis_bin_rv.rvs(size=n_residents)

        # age: index into the normalised working-age frequencies
        age_rv = rv_discrete(
            values=(np.arange(len(self.ages.freq.values)), self.ages.freq.values)
        )
        age_rnd_arr = age_rv.rvs(size=n_residents)

        return sex_rnd_arr, nomis_bin_rnd_arr, age_rnd_arr


import logging
import yaml
from enum import IntEnum
from typing import List
import numpy as np

import pandas as pd

from june import paths
from june.geography import Geography, Area
from june.groups import Group, Supergroup

default_data_filename = paths.data_path / "input/care_homes/care_homes_ew.csv"
default_areas_map_path = paths.data_path / "input/geography/area_super_area_region.csv"
default_config_filename = paths.configs_path / "defaults/groups/care_home.yaml"
logger = logging.getLogger("care_homes")


class CareHomeError(Exception):
    """
    Raised for care-home related errors (e.g. building care homes from an
    empty geography).

    Note: derives from ``Exception`` rather than ``BaseException`` so that
    generic ``except Exception`` handlers can catch it (PEP 8 guidance).
    """


class CareHome(Group):
    """
    A care home together with the people attached to it.

    Three subgroups are assumed:
    0 - workers
    1 - residents
    2 - visitors
    """

    __slots__ = ("n_residents", "area", "n_workers", "quarantine_starting_date")

    def __init__(
        self, area: Area = None, n_residents: int = None, n_workers: int = None
    ):
        super().__init__()
        self.area = area
        self.n_residents = n_residents
        self.n_workers = n_workers
        self.quarantine_starting_date = None

    def add(self, person, subgroup_type, activity: str = "residence"):
        """Add a person; leisure visits always land in the visitors subgroup."""
        if activity != "leisure":
            super().add(person, subgroup_type=subgroup_type, activity=activity)
            return
        super().add(
            person, subgroup_type=self.SubgroupType.visitors, activity="leisure"
        )

    @property
    def workers(self):
        return self.subgroups[self.SubgroupType.workers]

    @property
    def residents(self):
        return self.subgroups[self.SubgroupType.residents]

    @property
    def visitors(self):
        return self.subgroups[self.SubgroupType.visitors]

    def quarantine(self, time, quarantine_days, household_compliance):
        # Care homes always report being in quarantine, regardless of the
        # arguments (contrast Household.quarantine).
        return True

    @property
    def coordinates(self):
        return self.area.coordinates

    @property
    def super_area(self):
        return None if self.area is None else self.area.super_area

    @property
    def households_to_visit(self):
        # Care homes never generate visits to households.
        return None

    @property
    def care_homes_to_visit(self):
        # Care homes never generate visits to other care homes.
        return None

    def get_leisure_subgroup(self, person, subgroup_type, to_send_abroad):
        """Leisure visitors always join the visitors subgroup."""
        return self[self.SubgroupType.visitors]

    @property
    def type(self):
        return "care_home"


class CareHomes(Supergroup):
    """Supergroup of every care home in the simulated world."""

    venue_class = CareHome

    def __init__(self, care_homes: List[venue_class]):
        super().__init__(members=care_homes)

    @classmethod
    def for_geography(
        cls,
        geography: Geography,
        data_file: str = default_data_filename,
        config_file: str = default_config_filename,
    ) -> "CareHomes":
        """
        Initialize care homes for all areas of the given geography.

        Raises
        ------
        CareHomeError
            when the geography contains no areas.
        """
        areas = geography.areas
        if not areas:
            raise CareHomeError("Empty geography!")
        return cls.for_areas(areas, data_file, config_file)

    @classmethod
    def for_areas(
        cls,
        areas: List[Area],
        data_file: str = default_data_filename,
        config_file: str = default_config_filename,
    ) -> "CareHomes":
        """
        Create care homes for the given areas from census data, attaching
        each care home to its area.

        Parameters
        ----------
        areas:
            areas for which to create care homes
        data_file:
            path to the care home census data
        config_file:
            path to the care home configuration file
        """
        with open(config_file) as config_handle:
            config = yaml.load(config_handle, Loader=yaml.FullLoader)
        care_home_df = pd.read_csv(data_file, index_col=0)
        if areas:
            # Restrict the census data to the areas of interest.
            care_home_df = care_home_df.loc[[area.name for area in areas]]
        logger.info(
            f"There are {len(care_home_df.loc[care_home_df.values!=0])} care_homes in this geography."
        )
        residents_per_worker = config["n_residents_per_worker"]
        care_homes = []
        for area in areas:
            n_residents = care_home_df.loc[area.name].values[0]
            if n_residents != 0:
                n_workers = max(
                    int(np.ceil(n_residents / residents_per_worker)), 1
                )
                area.care_home = cls.venue_class(area, n_residents, n_workers)
                care_homes.append(area.care_home)
        return cls(care_homes)


from june.groups import Supergroup, Group


class Cemetery(Group):
    """A group gathering people who have died; all go into one subgroup."""

    def add(self, person):
        # Everyone is appended to the single default subgroup (index 0).
        default_subgroup = self[0]
        default_subgroup.people.append(person)


class Cemeteries(Supergroup):
    """Supergroup holding a single shared cemetery."""

    def __init__(self):
        super().__init__([Cemetery()])

    def get_nearest(self, person):
        # Only one cemetery exists, so it is trivially the nearest.
        return self.members[0]


import logging
from enum import IntEnum
from random import shuffle
from june import paths
from typing import List
import yaml

import numpy as np
import pandas as pd

from june.geography import Geography, SuperArea
from june.groups import Group, Supergroup
from june.groups.group.interactive import InteractiveGroup

default_size_nr_file = paths.data_path / "input/companies/company_size_2011.csv"
default_sector_nr_per_msoa_file = (
    paths.data_path / "input/companies/company_sector_2011.csv"
)
default_areas_map_path = paths.data_path / "input/geography/area_super_area_region.csv"
default_config_filename = paths.configs_path / "defaults/groups/companies.yaml"

logger = logging.getLogger(__name__)


def _get_size_brackets(sizegroup: str):
    """
    Given company size group calculates mean
    """
    # ensure that read_companysize_census() also returns number of companies
    # in each size category
    size_min, size_max = sizegroup.split("-")
    if size_max == "XXX" or size_max == "xxx":
        size_min = int(size_min)
        size_max = 1500
    else:
        size_min = int(size_min)
        size_max = int(size_max)
    return size_min, size_max


class CompanyError(Exception):
    """
    Raised for company related errors (e.g. building companies from an empty
    geography).

    Note: derives from ``Exception`` rather than ``BaseException`` so that
    generic ``except Exception`` handlers can catch it (PEP 8 guidance).
    """


class Company(Group):
    """
    A company to which workers not assigned to key sectors (such as schools
    and hospitals) can be attached.

    The workforce is treated as belonging to a single venue; the subgroup a
    worker joins is decided by ``get_index_subgroup`` from the base class.
    """

    __slots__ = ("super_area", "sector", "n_workers_max")

    def __init__(self, super_area=None, n_workers_max=np.inf, sector=None):
        super().__init__()
        self.super_area = super_area
        self.sector = sector
        self.n_workers_max = n_workers_max

    def add(self, person):
        """Register a worker at this company as their primary activity."""
        target_subgroup = self.get_index_subgroup(person)
        super().add(person, subgroup_type=target_subgroup, activity="primary_activity")

    @property
    def n_workers(self):
        # Current number of workers registered at the company.
        return len(self.people)

    @property
    def coordinates(self):
        return self.super_area.coordinates

    @property
    def area(self):
        # The company is conventionally placed in the first area of its super area.
        return self.super_area.areas[0]

    def get_interactive_group(self, people_from_abroad=None):
        return InteractiveCompany(self, people_from_abroad=people_from_abroad)


class Companies(Supergroup):
    """
    Supergroup of all companies in the simulated world.

    Companies are instantiated from census data giving, for each super area,
    the number of companies per size bracket and per industry sector.
    """

    venue_class = Company

    def __init__(self, companies: List["Company"]):
        """
        Create companies and provide functionality to allocate workers.

        Parameters
        ----------
        companies:
            list of companies to aggregate
        """
        super().__init__(members=companies)

    @classmethod
    def for_geography(
        cls,
        geography: Geography,
        size_nr_file: str = default_size_nr_file,
        sector_nr_per_msoa_file: str = default_sector_nr_per_msoa_file,
        default_config_filename: str = default_config_filename,
    ) -> "Companies":
        """
        Creates companies for the specified geography, and saves them
        to the super areas they belong to.

        Parameters
        ----------
        geography
            an instance of the geography class
        size_nr_file:
            nr. of companies within a size-range per super area
        sector_nr_per_msoa_file:
            nr. of companies per industry sector per super area

        Raises
        ------
        CompanyError
            when the geography contains no super areas.
        """
        if not geography.super_areas:
            raise CompanyError("Empty geography!")
        return cls.for_super_areas(
            geography.super_areas,
            size_nr_file,
            sector_nr_per_msoa_file,
            default_config_filename,
        )

    @classmethod
    def for_super_areas(
        cls,
        super_areas: List[SuperArea],
        size_nr_per_super_area_file: str = default_size_nr_file,
        sector_nr_per_super_area_file: str = default_sector_nr_per_msoa_file,
        default_config_filename: str = default_config_filename,
    ) -> "Companies":
        """
        Creates companies for the specified super areas, and saves them
        to the super areas they belong to.

        Parameters
        ----------
        super_areas
            list of super areas
        size_nr_per_super_area_file:
            nr. of companies within a size-range per super area
        sector_nr_per_super_area_file:
            nr. of companies per industry sector per super area
        """
        size_per_superarea_df = pd.read_csv(size_nr_per_super_area_file, index_col=0)
        sector_per_superarea_df = pd.read_csv(
            sector_nr_per_super_area_file, index_col=0
        )
        super_area_names = [super_area.name for super_area in super_areas]
        company_sizes_per_super_area = size_per_superarea_df.loc[super_area_names]
        company_sectors_per_super_area = sector_per_superarea_df.loc[super_area_names]
        # both dataframes must describe exactly the requested super areas
        assert len(company_sectors_per_super_area) == len(company_sizes_per_super_area)
        if len(company_sectors_per_super_area) == 1:
            # NOTE(review): with a single super area, the one-row DataFrames are
            # passed down directly (not Series rows as in the loop below);
            # create_companies_in_super_area then iterates (column, length-1
            # Series) pairs and relies on int(Series) coercion — confirm this
            # still behaves on newer pandas versions.
            super_area = super_areas[0]
            companies = cls.create_companies_in_super_area(
                super_area, company_sizes_per_super_area, company_sectors_per_super_area
            )
            super_area.companies = companies
        else:
            companies = []
            # iterate super areas in lockstep with their size and sector rows
            for super_area, (_, company_sizes), (_, company_sectors) in zip(
                super_areas,
                company_sizes_per_super_area.iterrows(),
                company_sectors_per_super_area.iterrows(),
            ):
                super_area.companies = cls.create_companies_in_super_area(
                    super_area, company_sizes, company_sectors
                )
                companies += super_area.companies
        return cls(companies)

    @classmethod
    def create_companies_in_super_area(
        cls, super_area: SuperArea, company_sizes, company_sectors
    ) -> list:
        """
        Creates companies in a super area using the size and sector
        distributions.

        Parameters
        ----------
        super_area:
            super area the companies belong to
        company_sizes:
            mapping of size-bracket label -> number of companies in the bracket
        company_sectors:
            mapping of sector label -> number of companies in the sector
        """
        sizes = np.array([])
        for size_bracket, counts in company_sizes.items():
            size_min, size_max = _get_size_brackets(size_bracket)
            # NOTE(review): np.random.randint's upper bound is exclusive, so the
            # bracket maximum itself is never drawn — confirm this is intended.
            sizes = np.concatenate(
                (sizes, np.random.randint(max(size_min, 1), size_max, int(counts)))
            )
        np.random.shuffle(sizes)
        sectors = []
        for sector, counts in company_sectors.items():
            sectors += [sector] * int(counts)
        shuffle(sectors)
        # pair each random size with a random sector; map() stops at the
        # shorter of the two sequences
        companies = list(
            map(
                lambda company_size, company_sector: cls.create_company(
                    super_area, company_size, company_sector
                ),
                sizes,
                sectors,
            )
        )
        return companies

    @classmethod
    def create_company(cls, super_area, company_size, company_sector):
        """Instantiate a single company of the given size and sector."""
        company = cls.venue_class(super_area, company_size, company_sector)
        return company


def _read_sector_betas():
    """
    Load per-sector beta factors from the default companies config file.

    Returns an empty dict when the config file has no content.
    """
    with open(default_config_filename) as config_file:
        loaded = yaml.load(config_file, Loader=yaml.FullLoader)
    return loaded or {}


class InteractiveCompany(InteractiveGroup):
    """Interactive view of a company, applying a sector-specific beta factor."""

    # sector -> multiplicative beta factor, loaded once at class creation
    sector_betas = _read_sector_betas()

    def __init__(self, group: "Group", people_from_abroad=None):
        super().__init__(group=group, people_from_abroad=people_from_abroad)
        self.sector = group.sector

    def get_processed_beta(self, betas, beta_reductions):
        """Scale the generic processed beta by this company's sector factor."""
        base_beta = super().get_processed_beta(
            betas=betas, beta_reductions=beta_reductions
        )
        sector_factor = self.sector_betas.get(self.sector, 1.0)
        return base_beta * sector_factor


import yaml
import logging
from enum import IntEnum
from june import paths
from typing import List, Tuple, Optional
import numpy as np
import pandas as pd
from sklearn.neighbors import BallTree

from june.groups import Group, Supergroup, ExternalGroup, ExternalSubgroup
from june.exc import HospitalError
from june.groups.group.make_subgroups import SubgroupParams

logger = logging.getLogger("hospitals")

default_data_filename = paths.data_path / "input/hospitals/trusts.csv"
default_config_filename = paths.configs_path / "defaults/groups/hospitals.yaml"


class MedicalFacility:
    """Marker base class for venues acting as medical facilities (see Hospital)."""

    pass


class MedicalFacilities:
    """Marker base class for collections of medical facilities (see Hospitals)."""

    pass


class AbstractHospital:
    """
    Hospital functionality common to all hospitals (internal to the domain
    and external). Tracks ward and ICU occupancy by person id; the ``ward``
    and ``icu`` subgroup attributes are provided by the concrete subclasses.
    """

    def __init__(self):
        # ids of people currently occupying the ward / the icu
        self.ward_ids = set()
        self.icu_ids = set()

    def add_to_ward(self, person):
        """Register the person in the ward and point them at the ward subgroup."""
        self.ward_ids.add(person.id)
        person.subgroups.medical_facility = self.ward

    def remove_from_ward(self, person):
        """Remove the person from the ward and clear their medical facility."""
        self.ward_ids.remove(person.id)
        person.subgroups.medical_facility = None

    def add_to_icu(self, person):
        """Register the person in the ICU and point them at the ICU subgroup."""
        self.icu_ids.add(person.id)
        person.subgroups.medical_facility = self.icu

    def remove_from_icu(self, person):
        """Remove the person from the ICU and clear their medical facility."""
        self.icu_ids.remove(person.id)
        person.subgroups.medical_facility = None

    def allocate_patient(self, person):
        """
        Allocate a patient inside the hospital, in the ward, in the ICU, or transfer.
        To correctly log if the person has been just admitted, transfered, or released,
        we return a few flags:
        - "ward_admitted" : this person has been admitted to the ward.
        - "icu_admitted" : this person has been directly admitted to icu.
        - "ward_transferred" : this person has been transferred to ward (from icu)
        - "icu_transferred" : this person has been transferred to icu (from ward)
        - "no_change" : no change respect to last time step.
        """
        symptom_tag = person.infection.tag.name
        not_yet_hospitalised = (
            person.medical_facility is None
            or person.medical_facility.spec != "hospital"
        )
        if not_yet_hospitalised:
            # fresh admission, directly to the right bed
            if symptom_tag == "hospitalised":
                self.add_to_ward(person)
                return "ward_admitted"
            if symptom_tag == "intensive_care":
                self.add_to_icu(person)
                return "icu_admitted"
            raise HospitalError(
                f"Person with symptoms {person.infection.tag} trying to enter hospital."
            )
        # this person has already been allocated in a hospital (this one)
        if symptom_tag == "hospitalised":
            if person.id in self.ward_ids:
                return "no_change"
            self.remove_from_icu(person)
            self.add_to_ward(person)
            return "ward_transferred"
        if symptom_tag == "intensive_care":
            if person.id in self.icu_ids:
                return "no_change"
            self.remove_from_ward(person)
            self.add_to_icu(person)
            return "icu_transferred"

    def release_patient(self, person):
        """
        Releases patient from hospital, whichever bed they occupy.
        """
        if person.id in self.ward_ids:
            self.remove_from_ward(person)
        elif person.id in self.icu_ids:
            self.remove_from_icu(person)
        else:
            raise HospitalError("Trying to release patient not located in icu or ward.")


class Hospital(Group, AbstractHospital, MedicalFacility):
    """
    A hospital with its workers and patients.

    Three subgroups are used:
    0 - workers (i.e. nurses, doctors, etc.)
    1 - patients
    2 - ICU patients
    """

    __slots__ = "id", "n_beds", "n_icu_beds", "coordinates", "area", "trust_code"

    def __init__(
        self,
        n_beds: int,
        n_icu_beds: int,
        area: str = None,
        coordinates: Optional[Tuple[float, float]] = None,
        trust_code: str = None,
    ):
        """
        Create a hospital from its description.

        Parameters
        ----------
        n_beds:
            total number of regular beds in the hospital
        n_icu_beds:
            total number of ICU beds in the hospital
        area:
            area the hospital belongs to
        coordinates:
            latitude and longitude
        trust_code:
            code of the trust the hospital belongs to
        """
        Group.__init__(self)
        AbstractHospital.__init__(self)
        self.area = area
        self.coordinates = coordinates
        self.n_beds = n_beds
        self.n_icu_beds = n_icu_beds
        self.trust_code = trust_code

    @property
    def super_area(self):
        return self.area.super_area

    @property
    def region(self):
        return self.super_area.region

    @property
    def region_name(self):
        return self.region.name

    @property
    def full(self):
        """True when every regular bed is occupied."""
        return self[self.SubgroupType.patients].size >= self.n_beds

    @property
    def full_ICU(self):
        """True when every ICU bed is occupied."""
        return self[self.SubgroupType.icu_patients].size >= self.n_icu_beds

    def add(self, person, subgroup_type):
        """Patients join a medical-facility subgroup; everyone else works here."""
        patient_subgroups = (
            self.SubgroupType.patients,
            self.SubgroupType.icu_patients,
        )
        if subgroup_type in patient_subgroups:
            super().add(
                person, activity="medical_facility", subgroup_type=subgroup_type
            )
        else:
            super().add(
                person,
                activity="primary_activity",
                subgroup_type=self.SubgroupType.workers,
            )

    @property
    def ward(self):
        return self.subgroups[self.SubgroupType.patients]

    @property
    def icu(self):
        return self.subgroups[self.SubgroupType.icu_patients]


class Hospitals(Supergroup, MedicalFacilities):
    """
    Supergroup of all hospitals, with a ball tree over their coordinates to
    find the hospitals closest to a given location.
    """

    venue_class = Hospital

    def __init__(
        self, hospitals: List["Hospital"], neighbour_hospitals: int = 5, ball_tree=True
    ):
        """
        Create a group of hospitals, and provide functionality to locate patients
        to a nearby hospital. It will check in order the first ``neighbour_hospitals``,
        when one has space available the patient is allocated to it. If none of the
        closest ones has beds available it will pick one of them at random and that
        hospital will overflow.

        Parameters
        ----------
        hospitals:
            list of hospitals to aggregate
        neighbour_hospitals:
            number of closest hospitals to look for
        ball_tree:
            whether to build the ball tree of hospital coordinates
        """
        super().__init__(members=hospitals)
        self.neighbour_hospitals = neighbour_hospitals
        if ball_tree and self.members:
            coordinates = np.array([hospital.coordinates for hospital in hospitals])
            self.init_trees(coordinates)

    @classmethod
    def from_file(
        cls,
        filename: str = default_data_filename,
        config_filename: str = default_config_filename,
    ) -> "Hospitals":
        """
        Initialize Hospitals from path to data frame, and path to config file.

        Parameters
        ----------
        filename:
            path to hospital dataframe
        config_filename:
            path to hospital config dictionary

        Returns
        -------
        Hospitals instance
        """
        hospital_df = pd.read_csv(filename)
        with open(config_filename) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        neighbour_hospitals = config["neighbour_hospitals"]
        logger.info(f"There are {len(hospital_df)} hospitals in the world.")
        # init_hospitals is a classmethod; it used to be an instance method
        # called with ``cls`` smuggled in as ``self``.
        hospitals = cls.init_hospitals(hospital_df)
        # use cls/venue_class rather than hard-coded names so subclasses work
        return cls(hospitals, neighbour_hospitals)

    @classmethod
    def for_geography(
        cls,
        geography,
        filename: str = default_data_filename,
        config_filename: str = default_config_filename,
    ):
        """
        Create hospitals for the areas contained in the given geography.

        Parameters
        ----------
        geography:
            geography whose areas are scanned for hospitals
        filename:
            path to hospital dataframe
        config_filename:
            path to hospital config dictionary
        """
        with open(config_filename) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        neighbour_hospitals = config["neighbour_hospitals"]
        hospital_df = pd.read_csv(filename, index_col=4)
        area_names = [area.name for area in geography.areas]
        hospital_df = hospital_df.loc[hospital_df.index.isin(area_names)]
        logger.info(f"There are {len(hospital_df)} hospitals in this geography.")
        total_hospitals = len(hospital_df)
        hospitals = []
        for area in geography.areas:
            if area.name in hospital_df.index:
                hospitals_in_area = hospital_df.loc[area.name]
                if isinstance(hospitals_in_area, pd.Series):
                    # a single hospital in the area comes back as a Series
                    hospital = cls.create_hospital_from_df_row(area, hospitals_in_area)
                    hospitals.append(hospital)
                else:
                    for _, row in hospitals_in_area.iterrows():
                        hospital = cls.create_hospital_from_df_row(area, row)
                        hospitals.append(hospital)
                if len(hospitals) == total_hospitals:
                    # all hospitals found; stop scanning areas early
                    break
        return cls(
            hospitals=hospitals, neighbour_hospitals=neighbour_hospitals, ball_tree=True
        )

    @classmethod
    def create_hospital_from_df_row(cls, area, row):
        """Build one hospital from a dataframe row, attached to ``area``."""
        coordinates = row[["latitude", "longitude"]].values.astype(np.float64)
        hospital = cls.venue_class(
            area=area,
            coordinates=coordinates,
            n_beds=row["beds"],
            n_icu_beds=row["icu_beds"],
            trust_code=row["code"],
        )
        return hospital

    @classmethod
    def init_hospitals(cls, hospital_df: pd.DataFrame) -> List["Hospital"]:
        """
        Create hospital objects with the right characteristics,
        as given by dataframe.

        Parameters
        ----------
        hospital_df:
            dataframe with hospital characteristics data
        """
        hospitals = []
        for _, row in hospital_df.iterrows():
            coordinates = row[["latitude", "longitude"]].values.astype(np.float64)
            hospital = cls.venue_class(
                coordinates=coordinates,
                n_beds=row["beds"],
                n_icu_beds=row["icu_beds"],
                trust_code=row["code"],
            )
            hospitals.append(hospital)
        return hospitals

    def init_trees(self, hospital_coordinates: np.array) -> BallTree:
        """
        Initialize a ball tree on a sphere (haversine metric) over the
        hospital coordinates, to query the closest hospital to a location.

        Parameters
        ----------
        hospital_coordinates:
            array of (latitude, longitude) pairs, in degrees
        """
        self.hospital_trees = BallTree(
            np.deg2rad(hospital_coordinates), metric="haversine"
        )

    def get_closest_hospitals_idx(
        self, coordinates: Tuple[float, float], k: int
    ) -> np.ndarray:
        """
        Get the indices of the ``k`` closest hospitals to a given coordinate.

        Parameters
        ----------
        coordinates:
            latitude and longitude, in degrees
        k:
            number of neighbours, capped at the number of hospitals

        Returns
        -------
        Indices of the closest hospitals, nearest first.
        """
        k = min(k, len(list(self.hospital_trees.data)))
        _, neighbours = self.hospital_trees.query(
            np.deg2rad(coordinates.reshape(1, -1)), k=k, sort_results=True
        )
        return neighbours[0]

    def get_closest_hospitals(
        self, coordinates: Tuple[float, float], k: int
    ) -> List["Hospital"]:
        """
        Get the ``k`` closest hospitals to a given coordinate.

        Parameters
        ----------
        coordinates:
            latitude and longitude, in degrees
        k:
            number of neighbours, capped at the number of hospitals

        Returns
        -------
        The closest hospital instances, nearest first.
        """
        # reuse the index query instead of duplicating the ball-tree code
        return [
            self.members[index]
            for index in self.get_closest_hospitals_idx(coordinates, k)
        ]


class ExternalHospital(ExternalGroup, AbstractHospital, MedicalFacility):
    """
    A hospital that lives outside the current domain, referenced only by its
    id and domain id. Provides external stand-ins for the ward and ICU
    subgroups used by AbstractHospital.
    """

    # flag distinguishing external groups from local ones
    external = True
    __slots__ = "spec", "id", "domain_id", "region_name", "ward_ids", "icu_ids"

    class SubgroupType(IntEnum):
        # mirrors the subgroup ordering documented on the Hospital class
        workers = 0
        patients = 1
        icu_patients = 2

    def __init__(self, id, spec, domain_id, region_name):
        ExternalGroup.__init__(self, id=id, spec=spec, domain_id=domain_id)
        AbstractHospital.__init__(self)
        self.region_name = region_name

        # external proxies targeted by AbstractHospital.add_to_ward/add_to_icu
        self.ward = ExternalSubgroup(
            group=self, subgroup_type=self.SubgroupType.patients
        )
        self.icu = ExternalSubgroup(
            group=self, subgroup_type=self.SubgroupType.icu_patients
        )


from enum import IntEnum
from collections import defaultdict
import numpy as np
from random import random

from june.groups import Group, Supergroup
from june.groups.group.interactive import InteractiveGroup

from typing import List


class Household(Group):
    """
    The Household class represents a household and contains information about
    its residents.
    We assume four subgroups:
    0 - kids
    1 - young adults
    2 - adults
    3 - old adults
    """

    __slots__ = (
        "area",
        "type",
        "composition_type",
        "max_size",
        "residents",
        "quarantine_starting_date",
        "residences_to_visit",
        "being_visited",
        "household_to_care",
        "receiving_care",
    )

    def __init__(self, type=None, area=None, max_size=np.inf, composition_type=None):
        """
        Parameters
        ----------
        type:
            one of ["family", "student", "young_adults", "old", "other",
            "nokids", "ya_parents", "communal"].
        area:
            area the household is located in
        max_size:
            maximum number of residents
        composition_type:
            household composition label
        """
        super().__init__()
        self.area = area
        self.type = type
        # -99 means "no quarantine started yet"; see quarantine() below
        self.quarantine_starting_date = -99
        self.max_size = max_size
        self.residents = ()
        self.residences_to_visit = defaultdict(tuple)
        self.household_to_care = None
        # True while people from other households have been added to the group
        self.being_visited = False
        self.receiving_care = False
        self.composition_type = composition_type

    def get_leisure_subgroup_type(self, person):
        """
        A person wants to come and visit this household. We need to assign the
        person to the relevant age subgroup.
        """
        # note: was declared with a misleading ``cls`` first parameter despite
        # being called as an instance method; renamed to ``self``.
        if person.age < 18:
            return self.SubgroupType.kids
        elif person.age <= 25:
            return self.SubgroupType.young_adults
        elif person.age < 65:
            return self.SubgroupType.adults
        else:
            return self.SubgroupType.old_adults

    def _get_leisure_subgroup_for_person(self, person):
        # same age banding as get_leisure_subgroup_type; kept as an alias to
        # avoid the previous duplicated implementation
        return self.get_leisure_subgroup_type(person)

    def add(self, person, subgroup_type=None, activity="residence"):
        """
        Add a person to this household, either as a resident or as a leisure
        visitor; any other activity raises NotImplementedError.
        """
        if subgroup_type is None:
            subgroup_type = self.get_leisure_subgroup_type(person)

        if activity == "leisure":
            subgroup_type = self.get_leisure_subgroup_type(person)
            person.subgroups.leisure = self[subgroup_type]
            self[subgroup_type].append(person)
            self.being_visited = True
        elif activity == "residence":
            self[subgroup_type].append(person)
            self.residents = (*self.residents, person)
            person.subgroups.residence = self[subgroup_type]
        else:
            raise NotImplementedError(f"Activity {activity} not supported in household")

    def make_household_residents_stay_home(self, to_send_abroad=None):
        """
        Forces the residents to stay home if they are away doing leisure.
        This is used to welcome visitors.
        """
        for mate in self.residents:
            if mate.busy:
                if (
                    mate.leisure is not None
                ):  # this person has already been assigned somewhere
                    if not mate.leisure.external:
                        if mate not in mate.leisure.people:
                            # person active somewhere else, let's not disturb them
                            continue
                        mate.leisure.remove(mate)
                    else:
                        ret = to_send_abroad.delete_person(mate, mate.leisure)
                        if ret:
                            # person active somewhere else, let's not disturb them
                            continue
                    mate.subgroups.leisure = mate.residence
                    mate.residence.append(mate)
            else:
                mate.subgroups.leisure = (
                    mate.residence  # person will be added later in the simulator.
                )

    @property
    def coordinates(self):
        return self.area.coordinates

    @property
    def n_residents(self):
        return len(self.residents)

    def quarantine(self, time, quarantine_days, household_compliance):
        """
        Whether the household quarantines at ``time``, decided by a draw
        against the household compliance. Communal households never quarantine.
        """
        if self.type == "communal":
            return False
        # the default of -99 is truthy, so the window check below decides
        if self.quarantine_starting_date:
            if (
                self.quarantine_starting_date
                < time
                < self.quarantine_starting_date + quarantine_days
            ):
                return random() < household_compliance
        return False

    @property
    def super_area(self):
        try:
            return self.area.super_area
        except AttributeError:
            return None

    def clear(self):
        """Empty the subgroups and reset the per-timestep visit flags."""
        super().clear()
        self.being_visited = False
        self.receiving_care = False

    def get_interactive_group(self, people_from_abroad=None):
        return InteractiveHousehold(self, people_from_abroad=people_from_abroad)

    def get_leisure_subgroup(self, person, subgroup_type, to_send_abroad):
        """Welcome a leisure visitor: the residents are made to stay home."""
        self.being_visited = True
        self.make_household_residents_stay_home(to_send_abroad=to_send_abroad)
        return self[self._get_leisure_subgroup_for_person(person=person)]


class Households(Supergroup):
    """
    Supergroup collecting every household of the simulated world, with
    information about them.
    """

    venue_class = Household

    def __init__(self, households: List[venue_class]):
        super().__init__(members=households)


class InteractiveHousehold(InteractiveGroup):
    def get_processed_beta(self, betas, beta_reductions):
        """
        Effective contact intensity (beta) for this household.

        Care visits take precedence over ordinary household visits; with
        neither, the plain household beta applies. The policy reduction is
        scaled by the regional compliance of the household's region.
        """
        group = self.group
        if group.receiving_care:
            # care visits must be checked before ordinary visits
            beta = betas["care_visits"]
            reduction = beta_reductions.get("care_visits", 1.0)
        elif group.being_visited:
            beta = betas["household_visits"]
            reduction = beta_reductions.get("household_visits", 1.0)
        else:
            beta = betas["household"]
            reduction = beta_reductions.get(self.spec, 1.0)
        compliance = self.super_area.region.regional_compliance
        return beta * (1 + compliance * (reduction - 1))


import logging
import numba as nb
import math
from enum import IntEnum
from copy import deepcopy
from june import paths
from typing import List, Tuple, Dict, Optional

import numpy as np
import pandas as pd
from sklearn.neighbors import BallTree

from june.geography import Geography, Areas, Area
from june.groups import Group, Subgroup, Supergroup
from june.groups.group.interactive import InteractiveGroup


# Default locations of the school input data and configuration files.
default_data_filename = paths.data_path / "input/schools/england_schools.csv"
default_areas_map_path = paths.data_path / "input/geography/area_super_area_region.csv"
default_config_filename = paths.configs_path / "defaults/groups/schools.yaml"

logger = logging.getLogger("schools")


class SchoolError(Exception):
    """
    Error raised for invalid school operations.

    Derives from ``Exception`` (not ``BaseException``): BaseException is
    reserved for system-exiting exceptions such as KeyboardInterrupt, and a
    BaseException subclass would slip past ``except Exception`` handlers.
    """


class SchoolClass(Subgroup):
    """A classroom subgroup of a school that can be put under quarantine."""

    def __init__(self, group, subgroup_type: int):
        super().__init__(group, subgroup_type)
        # -inf means "no quarantine has ever started"
        self.quarantine_starting_date = -np.inf


class School(Group):
    """
    A school with one teachers' subgroup (index 0) and one subgroup
    (classroom) per school year; ``limit_classroom_sizes`` can later split
    a year into several classrooms.
    """

    # NOTE(review): "area" and "n_classrooms" are assigned in __init__ but
    # are not listed here; this only works because a __dict__ is still
    # available further up the MRO -- confirm this is intended.
    __slots__ = (
        "id",
        "coordinates",
        "n_pupils_max",
        "n_teachers_max",
        "age_min",
        "age_max",
        "age_structure",
        "sector",
        "years",
    )

    # class SubgroupType(IntEnum):
    #     teachers = 0
    #     students = 1

    def __init__(
        self,
        coordinates: Tuple[float, float] = None,
        n_pupils_max: int = None,
        age_min: int = 0,
        age_max: int = 18,
        sector: str = None,
        area: Area = None,
        n_classrooms: Optional[int] = None,
        years: Optional[int] = None,
    ):
        """
        Create a School given its description.

        Parameters
        ----------
        coordinates:
            latitude and longitude
        n_pupils_max:
            maximum number of pupils that can attend the school
        age_min:
            minimum age of the pupils
        age_max:
            maximum age of the pupils
        sector:
            whether it is a "primary", "secondary" or both "primary_secondary"
        area:
            area the school belongs to
        n_classrooms:
            number of classrooms in the school
        years:
            age group year per classroom

        number of SubGroups N = age_max-age_min year +1 (student years) + 1 (teachers):
        0 - teachers
        1 - year of lowest age (age_min)
        ...
        n - year of highest age (age_max)
        """
        super().__init__()
        self.subgroups = []  # placeholder; replaced just below
        # for i, _ in enumerate(range(age_min, age_max + 2)):
        if n_classrooms is None:
            # default: one classroom per school year
            n_classrooms = age_max - age_min

        # +2: index 0 is reserved for teachers and the year range
        # age_min..age_max is inclusive
        self.subgroups = [SchoolClass(self, i) for i in range(n_classrooms + 2)]

        self.n_classrooms = n_classrooms
        self.coordinates = coordinates
        self.area = area
        self.n_pupils_max = n_pupils_max
        # presumably filled in once teachers are assigned -- TODO confirm
        self.n_teachers_max = None
        self.age_min = age_min
        self.age_max = age_max
        self.sector = sector
        if years is None:
            # one school year per age in the inclusive age range
            self.years = tuple(range(age_min, age_max + 1))
        else:
            self.years = tuple(years)

    def get_interactive_group(self, people_from_abroad=None):
        """School-specific interactive view (sector betas, year contacts)."""
        return InteractiveSchool(self, people_from_abroad=people_from_abroad)

    def add(self, person):
        """
        Add a pupil (placed by age) or a teacher (anyone older than
        ``age_max``) and register the school as their primary activity.
        """
        if person.age <= self.age_max:
            # classroom subgroups start at index 1 (index 0 = teachers).
            # NOTE(review): a person younger than age_min would yield an
            # index < 1 -- assumed not to happen; confirm with callers.
            subgroup = self.subgroups[1 + person.age - self.age_min]
            subgroup.append(person)
            person.subgroups.primary_activity = subgroup
        else:  # teacher
            subgroup = self.subgroups[0]
            subgroup.append(person)
            person.subgroups.primary_activity = subgroup

    def limit_classroom_sizes(self, max_classroom_size: int):
        """
        Make all subgroups smaller than ```max_classroom_size```

        Oversized year groups are split into several classrooms of at most
        ``max_classroom_size`` pupils; ``self.years`` is rebuilt so each
        classroom still maps to its school year.

        Parameters
        ----------
        max_classroom_size:
           maximum number of students per classroom (subgroup)
        """
        age_subgroups = self.subgroups.copy()
        year_age_group = deepcopy(self.years)
        self.subgroups = [age_subgroups[0]]  # keep teachers
        self.years = []
        counter = 1  # next free subgroup index (0 is teachers)
        for idx, subgroup in enumerate(age_subgroups[1:]):
            if len(subgroup.people) > max_classroom_size:
                # split this year into evenly sized classrooms
                n_classrooms = math.ceil(len(subgroup.people) / max_classroom_size)
                self.years += [year_age_group[idx]] * n_classrooms
                pupils_in_classroom = np.array_split(subgroup.people, n_classrooms)
                for i in range(n_classrooms):
                    classroom = SchoolClass(self, counter)
                    for pupil in pupils_in_classroom[i]:
                        classroom.append(pupil)
                        # re-point each pupil at their new classroom
                        pupil.subgroups.primary_activity = classroom
                    self.subgroups.append(classroom)
                    counter += 1
            else:
                # small enough: keep the subgroup, renumbering it
                subgroup.subgroup_type = counter
                self.subgroups.append(subgroup)
                counter += 1
                self.years.append(year_age_group[idx])
        self.years = tuple(self.years)
        self.n_classrooms = len(self.subgroups) - 1

    @property
    def is_full(self):
        """True when the school has reached its maximum number of pupils."""
        if self.n_pupils >= self.n_pupils_max:
            return True
        return False

    @property
    def n_pupils(self):
        """Current number of pupils across all classrooms."""
        return len(self.students)

    @property
    def n_teachers(self):
        """Current number of teachers."""
        return len(self.teachers)

    @property
    def teachers(self):
        # NOTE(review): relies on self.SubgroupType (built in Group.__init__
        # from the subgroup config) exposing a "teachers" label mapping to
        # index 0 -- confirm against the schools config file.
        return self.subgroups[self.SubgroupType.teachers]

    @property
    def students(self):
        """All pupils, i.e. everyone in the non-teacher subgroups."""
        ret = []
        for subgroup in self.subgroups[1:]:
            ret += subgroup.people
        return ret

    @property
    def super_area(self):
        """Super area of the school's area, or None when no area is set."""
        if self.area is None:
            return None
        return self.area.super_area


class Schools(Supergroup):
    """
    Supergroup of all schools, with per-age BallTrees to find the closest
    school accepting a pupil of a given age.
    """

    venue_class = School

    def __init__(
        self,
        schools: List["venue_class"],
        school_trees: Optional[Dict[int, BallTree]] = None,
        agegroup_to_global_indices: dict = None,
    ):
        """
        Create a group of Schools, and provide functionality to access closest school

        Parameters
        ----------
        schools:
            list of school instances
        school_trees:
            dictionary mapping an age to a BallTree built on the coordinates
            of the schools accepting that age
        agegroup_to_global_indices:
            dictionary mapping an age to the dataframe indices of the
            schools that accept pupils of that age
        """
        super().__init__(members=schools)
        self.school_trees = school_trees
        self.school_agegroup_to_global_indices = agegroup_to_global_indices

    @classmethod
    def for_geography(
        cls,
        geography: Geography,
        data_file: str = default_data_filename,
        config_file: str = default_config_filename,
    ) -> "Schools":
        """
        Build schools for all areas of a geography.

        Parameters
        ----------
        geography
            an instance of the geography class
        data_file
            path to the school dataframe
        config_file
            path to the school config file
        """
        return cls.for_areas(geography.areas, data_file, config_file)

    @classmethod
    def for_areas(
        cls,
        areas: Areas,
        data_file: str = default_data_filename,
        config_file: str = default_config_filename,
    ) -> "Schools":
        """
        Build schools for the given areas.

        Parameters
        ----------
        areas
            areas for which to create schools
        data_file
            path to the school dataframe
        config_file
            path to the school config file
        """
        return cls.from_file(areas, data_file, config_file)

    @classmethod
    def from_file(
        cls,
        areas: Areas,
        data_file: str = default_data_filename,
        config_file: str = default_config_filename,
    ) -> "Schools":
        """
        Initialize Schools from path to data frame, and path to config file

        Parameters
        ----------
        data_file:
            path to school dataframe
        config_file:
            path to school config dictionary

        Returns
        -------
        Schools instance
        """
        school_df = pd.read_csv(data_file, index_col=0)
        area_names = [area.name for area in areas]
        # NOTE(review): a list is never None, so this check is always true;
        # probably intended as `if area_names:` -- confirm.
        if area_names is not None:
            # filter out schools that are in the area of interest
            school_df = school_df[school_df["oa"].isin(area_names)]
        school_df.reset_index(drop=True, inplace=True)
        logger.info(f"There are {len(school_df)} schools in this geography.")
        return cls.build_schools_for_areas(areas, school_df)  # , **config,)

    @classmethod
    def build_schools_for_areas(
        cls,
        areas: Areas,
        school_df: pd.DataFrame,
        age_range: Tuple[int, int] = (0, 19),
        employee_per_clients: Dict[str, int] = None,
    ) -> "Schools":
        """
        Build a School per dataframe row, attach each to its closest area,
        and construct the per-age search trees.

        Parameters
        ----------
        areas
            areas the schools are assigned to
        school_df
            dataframe with one row per school (NOR, sector, coordinates, ages)
        age_range
            inclusive age range covered by the search trees
        employee_per_clients
            staff ratios per sector; also supplies the fallback sector name

        Returns
        -------
            An infrastructure of schools
        """
        employee_per_clients = employee_per_clients or {"primary": 30, "secondary": 30}
        # build schools
        schools = []
        for school_name, row in school_df.iterrows():
            n_pupils_max = row["NOR"]
            school_type = row["sector"]
            # NOTE(review): identity check against np.nan only catches the
            # nan singleton; pd.isna would be safer -- confirm data source.
            if school_type is np.nan:
                # fall back to the first configured sector ("primary")
                school_type = list(employee_per_clients.keys())[0]
            coordinates = np.array(
                row[["latitude", "longitude"]].values, dtype=np.float64
            )
            area = areas.get_closest_area(coordinates)
            school = cls.venue_class(
                coordinates=coordinates,
                n_pupils_max=n_pupils_max,
                age_min=int(row["age_min"]),
                age_max=int(row["age_max"]),
                sector=school_type,
                area=area,
            )
            schools.append(school)
            area.schools.append(school)

        # link schools
        school_trees, agegroup_to_global_indices = Schools.init_trees(
            school_df, age_range
        )
        return Schools(
            schools,
            school_trees=school_trees,
            agegroup_to_global_indices=agegroup_to_global_indices,
        )

    @staticmethod
    def init_trees(
        school_df: pd.DataFrame, age_range: Tuple[int, int]
    ) -> Tuple[Dict[int, BallTree], dict]:
        """
        Create trees to easily find the closest school that
        accepts a pupil given their age

        Parameters
        ----------
        school_df:
            dataframe with school characteristics data

        Returns
        -------
        A (school_trees, school_agegroup_to_global_indices) pair: one
        BallTree per age, and the dataframe indices of the schools that
        accept each age.
        """
        school_trees = {}
        school_agegroup_to_global_indices = {
            k: [] for k in range(int(age_range[0]), int(age_range[1]) + 1)
        }
        # have a tree per age
        for age in range(int(age_range[0]), int(age_range[1]) + 1):
            # schools whose inclusive age interval contains this age
            _school_df_agegroup = school_df[
                (school_df["age_min"] <= age) & (school_df["age_max"] >= age)
            ]
            schools_coords = _school_df_agegroup[["latitude", "longitude"]].values
            if not schools_coords.size:
                logger.info(f"No school for the age {age} in this world.")
                continue
            school_trees[age] = Schools._create_school_tree(schools_coords)
            school_agegroup_to_global_indices[age] = _school_df_agegroup.index.values
        return school_trees, school_agegroup_to_global_indices

    @staticmethod
    def _create_school_tree(schools_coordinates: np.ndarray) -> BallTree:
        """
        Reads school location and sizes, it initializes a KD tree on a sphere,
        to query the closest schools to a given location.

        Parameters
        ----------
        schools_coordinates:
            (n, 2) array of latitude/longitude pairs in degrees

        Returns
        -------
        Tree to query nearby schools

        """
        # haversine metric requires coordinates in radians
        school_tree = BallTree(np.deg2rad(schools_coordinates), metric="haversine")
        return school_tree

    def get_closest_schools(
        self, age: int, coordinates: Tuple[float, float], k: int
    ) -> np.ndarray:
        """
        Get the k-th closest school to a given coordinate, that accepts pupils
        aged age

        Parameters
        ----------
        age:
            age of the pupil
        coordinates:
            latitude and longitude
        k:
            k-th neighbour

        Returns
        -------
        Indices of the k closest schools (within the tree for this age
        group), sorted by distance.

        """
        school_tree = self.school_trees[age]
        # convert to radians to match the haversine tree
        coordinates_rad = np.deg2rad(coordinates).reshape(1, -1)
        # cannot ask for more neighbours than there are schools in the tree
        k = min(k, school_tree.data.shape[0])
        distances, neighbours = school_tree.query(
            coordinates_rad, k=k, sort_results=True
        )
        return neighbours[0]

    @property
    def n_teachers(self):
        """Total number of teachers across all schools."""
        return sum([school.n_teachers for school in self.members])

    @property
    def n_pupils(self):
        """Total number of pupils across all schools."""
        return sum([school.n_pupils for school in self.members])


@nb.jit(nopython=True)
def _translate_school_subgroup(idx, school_years):
    # Map a subgroup index (0 = teachers) to its row in the global school
    # contact matrix: subgroup i (i > 0) holds year school_years[i - 1], and
    # the matrix reserves row 0 for teachers, hence the +1 offset.
    if idx > 0:
        idx = school_years[idx - 1] + 1
    return idx


class InteractiveSchool(InteractiveGroup):
    """Interactive view of a school used during the interaction timestep."""

    def __init__(self, group: "Group", people_from_abroad=None):
        super().__init__(group=group, people_from_abroad=people_from_abroad)
        # cached so the contact-matrix helpers don't touch the group again
        self.school_years = group.years
        self.sector = group.sector

    @classmethod
    def get_raw_contact_matrix(
        cls, contact_matrix, alpha_physical, proportion_physical, characteristic_time
    ):
        """
        Creates a global contact matrix for school, which is by default 20x20, to take into account
        all possible school years combinations. Each school will then use a slice of this matrix.
        We assume that the number of contacts between two different school years goes as
        $ xi**abs(age_difference_between_school_years) * contacts_between_students$
        Teacher contacts are left as specified in the config file.
        """
        xi = 0.3  # decay of student contacts per year of age difference
        age_min = 0
        age_max = 30
        n_subgroups_max = (age_max - age_min) + 2  # adding teachers
        # age_differences[i, j] = i - j for every pair of school years
        age_differences = np.subtract.outer(
            range(age_min, age_max + 1), range(age_min, age_max + 1)
        )
        # block structure: row/col 0 is teachers, the rest are school years
        processed_contact_matrix = np.zeros((n_subgroups_max, n_subgroups_max))
        processed_contact_matrix[0, 0] = contact_matrix[0][0]
        processed_contact_matrix[0, 1:] = contact_matrix[0][1]
        processed_contact_matrix[1:, 0] = contact_matrix[1][0]
        # student-student contacts decay with the years separating them
        processed_contact_matrix[1:, 1:] = (
            xi ** abs(age_differences) * contact_matrix[1][1]
        )
        physical_ratios = np.zeros((n_subgroups_max, n_subgroups_max))
        physical_ratios[0, 0] = proportion_physical[0][0]
        physical_ratios[0, 1:] = proportion_physical[0][1]
        physical_ratios[1:, 0] = proportion_physical[1][0]
        physical_ratios[1:, 1:] = proportion_physical[1][1]
        # add physical contacts
        processed_contact_matrix = processed_contact_matrix * (
            1.0 + (alpha_physical - 1.0) * physical_ratios
        )
        # rescale from the characteristic time (hours) to a full day
        processed_contact_matrix *= 24 / characteristic_time
        # If same age but different class room, reduce contacts
        return processed_contact_matrix

    def get_processed_contact_matrix(self, contact_matrix):
        """
        Slice the global school contact matrix for this school's years.

        Index 0 is the teachers' subgroup; indices 1..n are classrooms.
        """
        n_school_years = len(self.school_years)
        n_subgroups = n_school_years + 1
        ret = np.zeros((n_subgroups, n_subgroups))
        for i in range(0, n_subgroups):
            for j in range(0, n_subgroups):
                if i == j:
                    if i != 0:
                        ret[i, j] = contact_matrix[1, 1]
                    else:
                        ret[0, 0] = contact_matrix[0, 0]
                else:
                    if i == 0:
                        # teacher contacts are spread over all school years
                        ret[0, j] = contact_matrix[0][1] / n_school_years
                    elif j == 0:
                        ret[i, 0] = contact_matrix[1][0] / n_school_years
                    else:
                        year_idx_i = _translate_school_subgroup(i, self.school_years)
                        year_idx_j = _translate_school_subgroup(j, self.school_years)
                        if year_idx_i == year_idx_j:
                            # same school year but different classroom:
                            # contacts are reduced by a factor of 4
                            ret[i, j] = contact_matrix[year_idx_i, year_idx_j] / 4
                        else:
                            ret[i, j] = contact_matrix[year_idx_i, year_idx_j]
        return ret

    def get_processed_beta(self, betas, beta_reductions):
        """
        Returns the processed contact intensity, by taking into account the policies
        beta reductions and regional compliance. This is a group method as different interactive
        groups may choose to treat this differently.
        """
        # pick the most specific spec available for this school's sector
        if self.sector is None:
            spec = "school"
        elif "secondary" in self.sector:
            spec = "secondary_school"
        else:
            spec = "primary_school"
        if spec in betas:
            beta = betas[spec]
        else:
            beta = betas["school"]  # fall back to the generic school beta
        if spec in beta_reductions:
            beta_reduction = beta_reductions[spec]
        else:
            beta_reduction = beta_reductions.get("school", 1.0)
        try:
            regional_compliance = self.super_area.region.regional_compliance
        except AttributeError:
            # no region information available: assume full compliance
            regional_compliance = 1
        try:
            lockdown_tier = self.super_area.region.policy["lockdown_tier"]
            if lockdown_tier is None:
                lockdown_tier = 1
        except Exception:
            lockdown_tier = 1
        # in tier 4 the policy beta reduction is applied at half strength
        if int(lockdown_tier) == 4:
            tier_reduction = 0.5
        else:
            tier_reduction = 1.0

        return beta * (1 + regional_compliance * tier_reduction * (beta_reduction - 1))


import numpy as np
import pandas as pd
from random import randint
from typing import List
import logging

from june.groups import Group, Subgroup, Supergroup
from june.geography import Areas, Geography
from june.paths import data_path

# Maps a student's age to their university year (0-indexed). Ages outside
# this range get a random year (see University.add).
age_to_years = {19: 0, 20: 1, 21: 2, 22: 3, 23: 4}

default_universities_filename = data_path / "input/universities/uk_universities.csv"

logger = logging.getLogger("universities")


class University(Group):
    """A university with one subgroup per year of study."""

    def __init__(
        self, n_students_max=None, n_years=5, ukprn=None, area=None, coordinates=None
    ):
        # ukprn: UK Provider Reference Number identifying the institution
        self.n_students_max = n_students_max
        self.n_years = n_years
        self.ukprn = ukprn
        self.area = area
        self.coordinates = coordinates
        super().__init__()
        # one subgroup per year of study, replacing those built by Group
        self.subgroups = [Subgroup(self, i) for i in range(self.n_years)]

    @property
    def students(self):
        # NOTE(review): iterates every subgroup, including index 0 which
        # `add` also uses for professors, while n_students excludes index 0
        # -- confirm this asymmetry is intended.
        return [person for subgroup in self.subgroups[:] for person in subgroup]

    @property
    def n_students(self):
        # NOTE(review): skips subgroup 0, so year-0 students (age 19 in
        # age_to_years) are not counted here -- confirm intended.
        return sum([self.subgroups[i].size for i in range(1, len(self.subgroups))])

    @property
    def super_area(self):
        """Super area of the university's area."""
        return self.area.super_area

    def add(self, person, subgroup="student"):
        """
        Add a student (year chosen by age, random when age is not in
        ``age_to_years``) or a professor (always subgroup 0).
        """
        if subgroup == "student":
            if person.age not in age_to_years:
                # age outside 19-23: assign a random year of study
                year = randint(0, len(self.subgroups) - 1)
            else:
                year = age_to_years[person.age]
            self.subgroups[year].append(person)
            person.subgroups.primary_activity = self.subgroups[year]
            if person.work_super_area is not None:
                # students leave the workforce of their work super area
                person.work_super_area.remove_worker(person)
        elif subgroup == "professors":
            # No professors in the modeling of the code!
            self.subgroups[0].append(person)
            person.subgroups.primary_activity = self.subgroups[0]

    @property
    def is_full(self):
        """True when the university reached its maximum number of students."""
        return self.n_students >= self.n_students_max


class Universities(Supergroup):
    """Supergroup of all universities in the simulated world."""

    venue_class = University

    def __init__(self, universities: List[venue_class]):
        super().__init__(members=universities)

    @classmethod
    def for_areas(
        cls,
        areas: Areas,
        universities_filename: str = default_universities_filename,
        # NOTE(review): default differs from for_geography's 20; the units
        # depend on Areas.get_closest_areas -- confirm both.
        max_distance_to_area=5,
    ):
        """
        Initializes universities from super areas. By looking at the coordinates
        of each university in the filename, we initialize those universities who
        are close to any of the super areas.

        Parameters
        ----------
        areas:
            an instance of Areas
        universities_filename:
            path to the university data
        max_distance_to_area:
            universities further than this from their closest area are skipped
        """
        universities_df = pd.read_csv(universities_filename)
        longitudes = universities_df["longitude"].values
        latitudes = universities_df["latitude"].values
        coordinates = np.array(list(zip(latitudes, longitudes)))
        n_students = universities_df["n_students"].values
        ukprn_values = universities_df["UKPRN"].values
        universities = []
        for coord, n_stud, ukprn in zip(coordinates, n_students, ukprn_values):
            closest_area, distance = areas.get_closest_areas(
                coordinates=coord, return_distance=True, k=1
            )
            # k=1 queries return length-1 sequences
            distance = distance[0]
            closest_area = closest_area[0]
            if distance > max_distance_to_area:
                # too far from the simulated world: skip this university
                continue
            university = cls.venue_class(
                area=closest_area, n_students_max=n_stud, ukprn=ukprn, coordinates=coord
            )
            universities.append(university)
        logger.info(f"There are {len(universities)} universities in this world.")
        return cls(universities)

    @classmethod
    def for_geography(
        cls,
        geography: Geography,
        universities_filename: str = default_universities_filename,
        max_distance_to_area: float = 20,
    ):
        """Build universities for all areas of the given geography."""
        return cls.for_areas(
            geography.areas,
            universities_filename=universities_filename,
            max_distance_to_area=max_distance_to_area,
        )

    # @property
    # def n_professors(self):
    #     return sum([uni.n_professors for uni in self.members])

    @property
    def n_students(self):
        """Total number of students across all universities."""
        return sum([uni.n_students for uni in self.members])


from .group.group import Group
from .group import AbstractGroup, Subgroup, Supergroup, ExternalSubgroup, ExternalGroup
from .boundary import Boundary
from .care_home import CareHome, CareHomes
from .cemetery import Cemetery, Cemeteries
from .company import Company, Companies, InteractiveCompany
from .hospital import (
    Hospital,
    Hospitals,
    MedicalFacility,
    MedicalFacilities,
    ExternalHospital,
)
from .household import Household, Households, InteractiveHousehold
from .school import School, Schools, InteractiveSchool
from .university import University, Universities
from .leisure import Pub, Pubs, Grocery, Groceries, Cinema, Cinemas, Leisure, Gym, Gyms


from abc import abstractmethod, ABC


class AbstractGroup(ABC):
    """
    Common interface shared by groups and subgroups.

    Implementations expose collections of people broken down by health
    state; the size helpers below are derived from those collections.
    """

    @property
    @abstractmethod
    def susceptible(self):
        """People who can still be infected."""

    @property
    @abstractmethod
    def infected(self):
        """People currently infected."""

    @property
    @abstractmethod
    def recovered(self):
        """People who have recovered."""

    @property
    @abstractmethod
    def in_hospital(self):
        """People currently in hospital."""

    @property
    @abstractmethod
    def dead(self):
        """People who have died."""

    @property
    def size(self):
        """Total number of people."""
        return len(self.people)

    @property
    def size_susceptible(self):
        """Number of susceptible people."""
        return len(self.susceptible)

    @property
    def size_infected(self):
        """Number of infected people."""
        return len(self.infected)

    @property
    def size_recovered(self):
        """Number of recovered people."""
        return len(self.recovered)


from typing import List, Tuple
from june.demography.person import Person


class ExternalGroup:
    """
    Lightweight stand-in for a group that lives in another domain.

    Only the identifying information is stored; the real group lives in
    the domain given by ``domain_id``.
    """

    external = True
    __slots__ = ("spec", "id", "domain_id")

    def __init__(self, id, spec, domain_id):
        self.id = id
        self.spec = spec
        self.domain_id = domain_id

    def clear(self):
        """Nothing to clear: external groups hold no people locally."""

    def get_leisure_subgroup(self, person, subgroup_type, to_send_abroad):
        """Return a placeholder subgroup pointing back at this group."""
        return ExternalSubgroup(group=self, subgroup_type=subgroup_type)


class ExternalSubgroup:
    """
    Placeholder for a subgroup whose real group lives in another domain.

    Stores only the external group and the subgroup index; identifying
    properties delegate to the wrapped group.
    """

    external = True
    __slots__ = ("subgroup_type", "group")

    def __init__(self, group, subgroup_type):
        self.group = group
        self.subgroup_type = subgroup_type

    @property
    def group_id(self):
        return self.group.id

    @property
    def spec(self):
        return self.group.spec

    @property
    def domain_id(self):
        return self.group.domain_id

    def clear(self):
        """Nothing to clear: external subgroups hold no people locally."""


import logging
import re
import numpy as np
from collections import defaultdict
from enum import IntEnum
from itertools import count
from typing import List, Tuple

from june.demography.person import Person
from .interactive import InteractiveGroup
from . import AbstractGroup
from . import Subgroup

from june.groups.group.make_subgroups import SubgroupParams

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # imported only for type annotations, avoiding a circular import
    from june.geography.geography import Region

logger = logging.getLogger(__name__)


class Group(AbstractGroup):
    """
    A group of people enjoying social interactions.  It contains three lists,
    all people in the group, the healthy ones and the infected ones (we may
    have to add the immune ones as well).

    This is very basic and we will have to specify derived classes with
    additional information - like household, work, commute - where some,
    like household groups are stable and others, like commute groups, are
    randomly assorted on a step-by-step base.

    The logic is that the group will enjoy one Interaction per time step,
    where the infection spreads, with a probablity driven by transmission
    probabilities and inteaction intensity, plus, possilby, individual
    susceptibility to become infected.

    TODO: we will have to decide in how far specific groups define behavioral
    patterns, which may be time-dependent.  So, far I have made a first pass at
    a list of group specifiers - we could promote it to a dicitonary with
    default intensities (maybe mean+width with a pre-described range?).
    """

    # groups living in this domain (cf. ExternalGroup.external = True)
    external = False
    # subgroup configuration (labels, bins, contact data) loaded once at
    # import time and shared by every Group subclass
    subgroup_params = SubgroupParams.from_file()

    # @property
    # def SubgroupType(self):
    #     if self.get_spec() in self.subgroup_params.specs:
    #         return IntEnum("SubgroupType", self.subgroup_labels, start=0)
    #     else:
    #         self.subgroup_params.params = {
    #             self.get_spec(): {
    #                 'contacts': [[0]],
    #                 'proportion_physical': [[0]],
    #                 'characteristic_time': 0,
    #                 'type': 'Age',
    #                 'bins': [0,100]
    #             }
    #         }
    #         self.subgroup_params.specs = self.subgroup_params.params.keys()
    #         return IntEnum("SubgroupType", ["default"], start=0)

    # NOTE(review): __init__ also assigns self.SubgroupType, which is not
    # listed here; this only works because a __dict__ still exists further
    # up the MRO -- confirm intended.
    __slots__ = ("id", "subgroups", "spec")

    # one independent id counter per concrete Group subclass (see _next_id)
    __id_generators = defaultdict(count)

    @classmethod
    def _next_id(cls) -> int:
        """
        Iterate an id for this class. Each group class has its own id iterator
        starting at 0.
        """
        # defaultdict(count) lazily creates a fresh counter per subclass
        return next(cls.__id_generators[cls])

    def __init__(self):
        """
        A group of people such as in a hospital or a school.

        If a spec attribute is not defined in the child class then it is generated
        by converting the class name into snakecase.
        """
        self.id = self._next_id()
        self.spec = self.get_spec()
        # per-instance enum of subgroup labels taken from the config for
        # this spec; one Subgroup is created per label
        self.SubgroupType = IntEnum(
            "SubgroupType", self.subgroup_params.subgroup_labels(self.spec), start=0
        )
        # noinspection PyTypeChecker
        self.subgroups = [Subgroup(self, i) for i in range(len(self.SubgroupType))]

    @property
    def name(self) -> str:
        """
        The name is computed on the fly to reduce memory footprint. It combines
        the name of the class with the id of the instance.
        """
        return f"{self.__class__.__name__}_{self.id:05d}"

    @property
    def region(self) -> "Region":
        """Region of this group's super area, or None if unresolvable."""
        try:
            return self.super_area.region
        except Exception:
            # super_area may be None or missing for some group types
            return None

    def get_spec(self) -> str:
        """
        Returns the specialization of the group.

        The class name is converted to snake_case, e.g. CareHome -> "care_home".
        """
        return re.sub(r"(?<!^)(?=[A-Z])", "_", self.__class__.__name__).lower()

    def remove_person(self, person: Person):
        """
        Remove ``person`` from every subgroup of this group that contains them.

        Parameters
        ----------
        person
            The person to remove.
        """
        for subgroup in self.subgroups:
            if person in subgroup:
                subgroup.remove(person)

    def __getitem__(self, item) -> "Subgroup":
        """
        A subgroup with a given index
        """
        return self.subgroups[item]

    def add(self, person: Person, activity: str, subgroup_type=None):
        """
        Add a person to a given subgroup. For example, in a school
        a student is added to the subgroup matching their age.

        Parameters
        ----------
        person
            A person
        activity
            Attribute of person.subgroups to point at the chosen subgroup;
            skipped when None.
        subgroup_type
            Index of the target subgroup. Defaults to None, in which case
            the leisure subgroup for this person is used — the original
            signature annotated this as ``subgroup_type: None`` (a bare
            annotation with no default), which made the None-handling branch
            unreachable unless callers passed None explicitly.
        """
        if subgroup_type is None:
            # NOTE(review): get_leisure_subgroup returns a Subgroup object,
            # which is then used below as an index into self — confirm the
            # None path is exercised / intended.
            subgroup_type = self.get_leisure_subgroup(person)

        self[subgroup_type].append(person)
        if activity is not None:
            setattr(person.subgroups, activity, self[subgroup_type])

    @property
    def people(self) -> Tuple[Person]:
        """Tuple of every person currently in any subgroup of this group."""
        members = []
        for subgroup in self.subgroups:
            members.extend(subgroup.people)
        return tuple(members)

    @property
    def contains_people(self) -> bool:
        """
        Does this group contain at least one person?
        """
        return any(subgroup.contains_people for subgroup in self.subgroups)

    def _collate_from_subgroups(self, attribute: str) -> List[Person]:
        """
        Collect every person in any subgroup for whom ``attribute`` is truthy.

        Parameters
        ----------
        attribute
            The name of the attribute in the subgroup, e.g. "in_hospital"

        Returns
        -------
        All matching people across all of the subgroups.
        """
        collated = []
        for subgroup in self.subgroups:
            for person in subgroup.people:
                if getattr(person, attribute):
                    collated.append(person)
        return collated

    @property
    def susceptible(self):
        """People in this group who can still be infected."""
        return self._collate_from_subgroups("susceptible")

    @property
    def infected(self):
        """People in this group who are currently infected."""
        return self._collate_from_subgroups("infected")

    @property
    def recovered(self):
        """People in this group who have recovered."""
        return self._collate_from_subgroups("recovered")

    @property
    def in_hospital(self):
        """People in this group who are currently in hospital."""
        return self._collate_from_subgroups("in_hospital")

    @property
    def dead(self):
        """People in this group who have died."""
        return self._collate_from_subgroups("dead")

    @property
    def must_timestep(self):
        # An interaction time step is only worthwhile when the group has at
        # least two people and contains both infected and susceptible members.
        return self.size > 1 and self.size_infected > 0 and self.size_susceptible > 0

    @property
    def size_infected(self):
        # Total number of infected people, summed over all subgroups.
        return np.sum([subgroup.size_infected for subgroup in self.subgroups])

    @property
    def size_recovered(self):
        # Total number of recovered people, summed over all subgroups.
        return np.sum([subgroup.size_recovered for subgroup in self.subgroups])

    @property
    def size_susceptible(self):
        # Total number of susceptible people, summed over all subgroups.
        return np.sum([subgroup.size_susceptible for subgroup in self.subgroups])

    def clear(self):
        # Empty every subgroup; the subgroup objects themselves remain.
        for subgroup in self.subgroups:
            subgroup.clear()

    def get_interactive_group(self, people_from_abroad=None):
        # Wrap this group in an InteractiveGroup, which extracts the data
        # needed to run an interaction time step efficiently.
        return InteractiveGroup(self, people_from_abroad=people_from_abroad)

    def get_leisure_subgroup(self, person, subgroup_type=None, to_send_abroad=None):
        """
        Return the subgroup a visiting person should join, or None if the
        person fits none (e.g. their age falls outside the configured bins,
        or the venue has several discrete subgroups).

        Delegates the bin lookup to ``get_index_subgroup`` so the binning
        logic lives in a single place instead of being duplicated here.
        """
        subgroup_idx = self.get_index_subgroup(person, subgroup_type, to_send_abroad)
        if subgroup_idx is None:
            return None
        return self.subgroups[subgroup_idx]

    def get_index_subgroup(self, person, subgroup_type=None, to_send_abroad=None):
        """
        Return the index of the interaction subgroup a person belongs to,
        or None if the person fits none.

        For "Age" binning the index is the age band containing person.age;
        for "Discrete" binning only single-subgroup venues can be resolved
        here (venues with several discrete subgroups return None).
        """
        if self.subgroup_type == "Age":
            # Bins are edges, e.g. [0, 18, 60] -> bands [0, 18) and [18, 60).
            min_age = self.subgroup_bins[0]
            max_age = self.subgroup_bins[-1] - 1

            if person.age >= min_age and person.age <= max_age:
                # searchsorted(..., side="right") - 1 yields the band whose
                # left edge is <= age.
                subgroup_idx = (
                    np.searchsorted(self.subgroup_bins, person.age, side="right") - 1
                )
                return subgroup_idx
            else:
                return
        elif self.subgroup_type == "Discrete":
            if len(self.subgroups) == 1:
                return 0
            else:
                return

    @property
    def subgroup_type(self):
        # Binning scheme for this group's spec: "Age" or "Discrete".
        return self.subgroup_params.subgroup_type(self.get_spec())

    @property
    def subgroup_labels(self):
        # Labels of the interaction subgroups configured for this spec.
        return self.subgroup_params.subgroup_labels(self.get_spec())

    @property
    def subgroup_bins(self):
        # Age-bin edges or discrete categories configured for this spec.
        return self.subgroup_params.subgroup_bins(self.get_spec())

    @property
    def kids(self):
        # Everyone younger than the young-adult age threshold (AgeYoungAdult).
        return [
            person
            for subgroup in self.subgroups
            for person in subgroup.people
            if person.age < self.subgroup_params.AgeYoungAdult
        ]

    # @property
    # def young_adults(self):
    #     return [
    #         person
    #         for subgroup in self.subgroups
    #         for person in subgroup.people
    #         if person.age >= self.subgroup_params.AgeYoungAdult and person.age < self.subgroup_params.AgeAdult
    #     ]

    @property
    def adults(self):
        # Everyone at or above the adult age threshold (AgeAdult).
        return [
            person
            for subgroup in self.subgroups
            for person in subgroup.people
            if person.age >= self.subgroup_params.AgeAdult
        ]

    # @property
    # def old_adults(self):
    #     return [
    #         person
    #         for subgroup in self.subgroups
    #         for person in subgroup.people
    #         if person.age >= self.subgroup_params.AgeOldAdult
    #     ]

    @classmethod
    def get_leisure_subgroup_type(cls, person):
        """
        A person wants to come and visit this household. Decide which age-band
        subgroup they belong to, so the residents welcome them and do not go
        off to any other leisure activity.
        Bands: under 18 kids, 18-35 young adults, over 35 and under 65
        adults, 65 and over old adults.
        """
        age = person.age
        if age >= 65:
            return cls.SubgroupType.old_adults
        if age > 35:
            return cls.SubgroupType.adults
        if age >= 18:
            return cls.SubgroupType.young_adults
        return cls.SubgroupType.kids


from collections import defaultdict
import numba as nb

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from june.groups.group.group import Group


@nb.jit(nopython=True)
def _get_processed_contact_matrix(contact_matrix, alpha_physical, proportion_physical):
    """
    Boost the physical-contact component of a contact matrix for use in the
    interaction.

    Parameters
    ----------
    - contact_matrix : contact matrix
    - alpha_physical : relative weight of physical contacts respect to the normal ones.
    (1 = same as normal).
    - proportion_physical : proportion of physical contacts.
    """
    physical_boost = 1.0 + (alpha_physical - 1.0) * proportion_physical
    return contact_matrix * physical_boost


class InteractiveGroup:
    """
    Extracts the necessary information about a group to perform an interaction time
    step over it. This step is necessary, since all the information is stored in numpy
    arrays that allow for efficient computation.

    Parameters
    ----------
    - group : group that we want to prepare for interaction.
    """

    def __init__(self, group: "Group", people_from_abroad=None):
        """
        This function is very long to avoid function calls for performance reasons.
        InteractiveGroups are created millions of times. Given a group, we need to extract:
        - ids of the people that can infect (infector).
        - ids of the people that can be infected (susceptible).
        - probabilities of transmission of the infectors.
        - susceptibilities of the susceptible.
        - indices of the subgroups that contain infectors.
        - sizes of the subgroups that contain infectors.
        - indices of the subgroups that contain susceptible.
        - spec of the group
        - super area of the group (for geo attributes like local regional compliances)
        """
        people_from_abroad = people_from_abroad or {}
        self.group = group
        self.infectors_per_infection_per_subgroup = defaultdict(
            lambda: defaultdict(lambda: defaultdict(list))
        )  # maps virus variant -> subgroup -> infectors -> {infector ids, transmission probs}
        self.susceptibles_per_subgroup = defaultdict(
            dict
        )  # maps subgroup -> susceptible id -> {variant -> susceptibility}
        self.subgroup_sizes = {}
        group_size = 0

        # Walk each subgroup once, collecting its size, its susceptibles and
        # its infectors from both local residents and people visiting from
        # other domains ("abroad").
        for subgroup_index, subgroup in enumerate(group.subgroups):
            subgroup_size = len(subgroup.people)
            if subgroup.subgroup_type in people_from_abroad:
                people_abroad_data = people_from_abroad[subgroup.subgroup_type]
                people_abroad_ids = people_abroad_data.keys()
                subgroup_size += len(people_abroad_ids)
            else:
                people_abroad_data = None
                people_abroad_ids = []
            # Empty subgroups are skipped entirely and do not appear in
            # subgroup_sizes.
            if subgroup_size == 0:
                continue
            self.subgroup_sizes[subgroup_index] = subgroup_size
            group_size += subgroup_size

            # Get susceptible people
            # local
            for person in subgroup:
                if not person.infected:
                    self.susceptibles_per_subgroup[subgroup_index][
                        person.id
                    ] = person.immunity.susceptibility_dict
            # from abroad
            for id in people_abroad_ids:
                if people_abroad_data[id]["susc"]:
                    # Rebuild the {infection id -> susceptibility} mapping
                    # from the parallel lists in the transfer payload.
                    dd = {
                        key: value
                        for key, value in zip(
                            people_abroad_data[id]["immunity_inf_ids"],
                            people_abroad_data[id]["immunity_suscs"],
                        )
                    }
                    self.susceptibles_per_subgroup[subgroup_index][id] = dd

            # Get infectors
            for person in subgroup:
                if person.infection is not None:
                    infection_id = person.infection.infection_id()
                    self.infectors_per_infection_per_subgroup[infection_id][
                        subgroup_index
                    ]["ids"].append(person.id)
                    self.infectors_per_infection_per_subgroup[infection_id][
                        subgroup_index
                    ]["trans_probs"].append(person.infection.transmission.probability)
            for id in people_abroad_ids:
                # NOTE(review): inf_id == 0 appears to mark "not infected" in
                # the abroad payload — confirm against the domain transfer code.
                if people_abroad_data[id]["inf_id"] != 0:
                    infection_id = people_abroad_data[id]["inf_id"]
                    self.infectors_per_infection_per_subgroup[infection_id][
                        subgroup_index
                    ]["ids"].append(id)
                    self.infectors_per_infection_per_subgroup[infection_id][
                        subgroup_index
                    ]["trans_probs"].append(people_abroad_data[id]["inf_prob"])
        # A time step is only needed when both sides of an interaction exist.
        self.must_timestep = self.has_susceptible and self.has_infectors
        self.size = group_size

    @classmethod
    def get_raw_contact_matrix(
        cls, contact_matrix, alpha_physical, proportion_physical, characteristic_time
    ):
        """
        Returns the processed contact matrix, by default it returns the input,
        but children of this class will interact differently.
        """
        # Boost the physical-contact component, then normalise by the venue's
        # characteristic time (in hours) to a per-day rate.
        processed_contact_matrix = contact_matrix * (
            1.0 + (alpha_physical - 1.0) * proportion_physical
        )
        processed_contact_matrix *= 24 / characteristic_time
        return processed_contact_matrix

    def get_processed_beta(self, betas, beta_reductions):
        """
        Returns the processed contact intensity, by taking into account the policies
        beta reductions and regional compliance. This is a group method as different interactive
        groups may choose to treat this differently.
        """
        beta = betas[self.spec]
        # No reduction configured for this spec -> factor of 1 (no change).
        beta_reduction = beta_reductions.get(self.spec, 1.0)
        try:
            regional_compliance = self.super_area.region.regional_compliance
        except AttributeError:
            regional_compliance = 1
        try:
            lockdown_tier = self.super_area.region.policy["lockdown_tier"]
            if lockdown_tier is None:
                lockdown_tier = 1
        except Exception:
            # Missing region / policy information defaults to tier 1.
            lockdown_tier = 1
        # Tier 4 halves the strength of the policy beta reduction.
        if int(lockdown_tier) == 4:
            tier_reduction = 0.5
        else:
            tier_reduction = 1.0

        # Interpolate between full and reduced beta according to regional
        # compliance and the tier factor.
        return beta * (1 + regional_compliance * tier_reduction * (beta_reduction - 1))

    def get_processed_contact_matrix(self, contact_matrix):
        # Identity by default; subclasses may rescale.
        return contact_matrix

    @property
    def spec(self):
        return self.group.spec

    @property
    def super_area(self):
        return self.group.super_area

    @property
    def regional_compliance(self):
        return self.group.super_area.region.regional_compliance

    @property
    def has_susceptible(self):
        # True when any subgroup recorded at least one susceptible person.
        return bool(self.susceptibles_per_subgroup)

    @property
    def has_infectors(self):
        # True when any infection recorded at least one infector.
        return bool(self.infectors_per_infection_per_subgroup)


import itertools
import string
import yaml
from june import paths
import numpy as np
import logging

# Interaction yaml holding the default contact matrices / subgroup bins.
default_config_filename = paths.configs_path / "defaults/interaction/interaction.yaml"

logger = logging.getLogger("subgroup maker")


# Default subgroup definitions per group spec: a list of age-bin edges for
# "Age" type bins, or a list of category labels for "Discrete" type bins.
_DEFAULT_SUBGROUPS = {
    "pub": ([0, 100], "Age"),
    "grocery": ([0, 100], "Age"),
    "cinema": ([0, 100], "Age"),
    "city_transport": ([0, 100], "Age"),
    "inter_city_transport": ([0, 100], "Age"),
    "gym": ([0, 100], "Age"),
    "care_home": (["workers", "residents", "visitors"], "Discrete"),
    "university": (["1", "2", "3", "4", "5"], "Discrete"),
    "school": (["teachers", "students"], "Discrete"),
    "household": (["kids", "young_adults", "adults", "old_adults"], "Discrete"),
    "company": (["workers"], "Discrete"),
    # Cox defaults
    "communal": ([0, 18, 60], "Age"),
    "distribution_center": ([0, 18, 60], "Age"),
    "e_voucher": ([0, 18, 60], "Age"),
    "female_communal": ([0, 18, 60], "Age"),
    "isolation_unit": ([0, 18, 60], "Age"),
    "n_f_distribution_center": ([0, 18, 60], "Age"),
    "pump_latrine": ([0, 18, 60], "Age"),
    "religious": ([0, 18, 60], "Age"),
    "play_group": ([3, 7, 12, 18], "Age"),
    "learning_center": (["students", "teachers"], "Discrete"),
    "hospital": (["workers", "patients", "icu_patients"], "Discrete"),
    "shelter": (["inter", "intra"], "Discrete"),
    "informal_work": ([0, 100], "Age"),
}


def get_defaults(spec):
    """
    Return the default interaction subgroup definition for a group spec.

    Parameters
    ----------
    spec
        Group specification string, e.g. "pub" or "school".

    Returns
    -------
    tuple
        (bins, bin_type): a list of age edges (type "Age") or category labels
        (type "Discrete"), plus the bin type string. Unknown specs fall back
        to a single "default" Discrete bin (the previous "defualt" typo is
        fixed for consistency with the fallback label used elsewhere).
    """
    bins, bin_type = _DEFAULT_SUBGROUPS.get(spec, (["default"], "Discrete"))
    # Return a fresh list, matching the old behaviour of building a new
    # literal on each call, so callers may mutate their copy safely.
    return list(bins), bin_type


class SubgroupParams:
    """
    Class to read and collect Interaction matrix information. Allows for
    reading of subgroups from generic bins.

    Parameters
    ----------
        params:
            dict mapping group spec -> {"bins": ..., "type": ...}, typically
            the "contact_matrices" section of the interaction yaml, or None.

    Returns
    -------
        SubgroupParams class
    """

    # Age thresholds used to split people into broad age classes.
    AgeYoungAdult = 18
    AgeAdult = 18
    AgeOldAdult = 65

    # Group specs for which get_defaults can supply default bins.
    PossibleLocs = [
        "pub",
        "grocery",
        "cinema",
        "city_transport",
        "inter_city_transport",
        "gym",
        "care_home",
        "university",
        "school",
        "household",
        "company",
        "communal",
        "distribution_center",
        "e_voucher",
        "female_communal",
        "isolation_unit",
        "n_f_distribution_center",
        "pump_latrine",
        "religious",
        "play_group",
        "learning_center",
        "hospital",
        "shelter",
        "informal_work",
    ]

    def __init__(self, params=None) -> None:
        # params may be None when no interaction yaml has been loaded yet.
        self.params = params
        self.specs = None if params is None else params.keys()

    def subgroup_bins(self, spec):
        """Return the configured bins (age edges or labels) for a spec."""
        return self.params[spec]["bins"]

    def subgroup_type(self, spec):
        """Return the configured bin type for a spec: "Age" or "Discrete"."""
        return self.params[spec]["type"]

    def subgroup_labels(self, spec):
        """
        Return the subgroup labels for a spec, filling defaults into
        self.params for any missing or incomplete configuration (this method
        mutates self.params on purpose so later lookups see the defaults).
        """
        if spec not in self.params.keys():
            if spec not in self.PossibleLocs:
                # Previously a print with a misspelled message ("defualt");
                # warn via the module logger for consistency.
                logger.warning(
                    f"{spec} not defined in interaction yaml or default options"
                )
                return ["default"]
            Bins, Type = get_defaults(spec)
            logger.info(
                f"{spec} interaction bins not specified. Using default values {Bins}"
            )
            self.params[spec] = {"bins": Bins, "type": Type}

        if (
            "bins" not in self.params[spec].keys()
            or "type" not in self.params[spec].keys()
        ):
            Bins, Type = get_defaults(spec)
            logger.info(
                f"{spec} interaction bins not specified. Using default values {Bins}"
            )
            self.params[spec]["bins"] = Bins
            self.params[spec]["type"] = Type
        elif spec in [
            "learning_center",
            "hospital",
            "shelter",
            "university",
            "school",
            "care_home",
            "household",
            "company",
        ]:
            # These groups have code paths that rely on the canonical bins,
            # so any non-default configuration is overridden.
            Bins, Type = get_defaults(spec)
            if self.params[spec]["bins"] != Bins:
                logger.info(f"{spec} interaction bins need default values for methods.")
                self.params[spec]["bins"] = Bins
                self.params[spec]["type"] = Type

        if self.subgroup_type(spec) == "Age":  # Make dummy names for N age bins
            Nbins = len(self.params[spec]["bins"]) - 1
            return list(itertools.islice(self.excel_cols(), Nbins))
        elif self.subgroup_type(spec) == "Discrete":
            return list(self.params[spec]["bins"])  # Already have our names!

    def excel_cols(self):
        """
        Generate generic string labels in form ["A", "B", "C", ... , "Z", "AA", "AB", .... ]

        Parameters
        ----------
            None

        Returns
        -------
            List of unique strings
        """
        n = 1
        while True:
            yield from (
                "".join(group)
                for group in itertools.product(string.ascii_uppercase, repeat=n)
            )
            n += 1

    @classmethod
    def from_file(cls, config_filename=default_config_filename) -> "SubgroupParams":
        """
        Read from interaction yaml and extract information on bins and bin types.

        Parameters
        ----------
            config_filename:
                yaml location (None falls back to the default interaction yaml)

        Returns
        -------
            SubgroupParams class instance
        """
        if config_filename is None:
            config_filename = default_config_filename
        with open(config_filename) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        # Use cls so subclasses get instances of themselves.
        return cls(params=config["contact_matrices"])


from june.demography.person import Person
from .abstract import AbstractGroup
from typing import List


class Subgroup(AbstractGroup):
    """
    A group within a group. For example, children in a household.
    """

    external = False
    __slots__ = ("group", "subgroup_type", "people")

    def __init__(self, group, subgroup_type: int):
        """
        Parameters
        ----------
        group
            The group this subgroup belongs to.
        subgroup_type
            Integer index identifying this subgroup within its group.
        """
        self.group = group
        self.subgroup_type = subgroup_type
        self.people = []

    def _collate(self, attribute: str) -> List[Person]:
        # People in this subgroup whose given boolean attribute is truthy.
        return [person for person in self.people if getattr(person, attribute)]

    @property
    def spec(self):
        # Specialisation string of the parent group.
        return self.group.spec

    @property
    def infected(self):
        return self._collate("infected")

    @property
    def susceptible(self):
        return self._collate("susceptible")

    @property
    def recovered(self):
        return self._collate("recovered")

    @property
    def dead(self):
        return self._collate("dead")

    @property
    def in_hospital(self):
        return self._collate("in_hospital")

    def __contains__(self, item):
        return item in self.people

    def __iter__(self):
        return iter(self.people)

    def __len__(self):
        return len(self.people)

    def clear(self):
        # Rebind to a new list so external references to the old list are
        # left untouched.
        self.people = []

    @property
    def contains_people(self) -> bool:
        """
        Whether or not the group contains people.
        """
        return len(self.people) > 0

    def append(self, person: Person):
        """
        Add a person to this group and mark them as busy.
        """
        self.people.append(person)
        person.busy = True

    def remove(self, person: Person):
        """
        Remove a person from this group and mark them as free.
        """
        self.people.remove(person)
        person.busy = False

    def __getitem__(self, item):
        # Index directly; the previous implementation copied the whole people
        # list on every access (list(self.people)[item]), which is O(n).
        return self.people[item]


import re
from collections import OrderedDict

from june.groups.group.make_subgroups import SubgroupParams
import numpy as np


class Supergroup:
    """
    A group containing a collection of groups of the same specification,
    like households, carehomes, etc.
    This class is meant to be used as template to inherit from, and it
    integrates basic functionality like iteration, etc.
    It also includes a method to delete information about people in the
    groups.
    """

    def __init__(self, members):
        """
        Parameters
        ----------
        members
            List of group objects (each with an ``id``) forming this supergroup.
        """
        self.group_type = self.__class__.__name__
        self.spec = self.get_spec()
        self.members = members
        self.members_by_id = self._make_member_ids_dict(members)

    def _make_member_ids_dict(self, members):
        """
        Build an ordered id -> member mapping, preserving member order.
        """
        ret = OrderedDict()
        for member in members:
            ret[member.id] = member
        return ret

    def __iter__(self):
        return iter(self.members)

    def __len__(self):
        return len(self.members)

    def __getitem__(self, item):
        return self.members[item]

    def get_from_id(self, id):
        """Return the member group with the given id."""
        return self.members_by_id[id]

    def __add__(self, supergroup: "Supergroup"):
        # Absorbs all groups of the other supergroup (mutates and returns self).
        for group in supergroup:
            self.add(group)
        return self

    def clear(self):
        # NOTE(review): only the id mapping is cleared; self.members keeps its
        # entries, so iteration still yields them — confirm this asymmetry
        # with `add` (which updates both) is intended.
        self.members_by_id.clear()

    def add(self, group):
        """Register a group both in the id mapping and the members list."""
        self.members_by_id[group.id] = group
        self.members.append(group)

    @property
    def member_ids(self):
        return list(self.members_by_id.keys())

    def get_spec(self) -> str:
        """
        Returns the specialization of the super group: the class name
        converted to snake_case (e.g. "CareHomes" -> "care_homes").
        """
        return re.sub(r"(?<!^)(?=[A-Z])", "_", self.__class__.__name__).lower()

    @property
    def group_spec(self):
        # Spec of the contained groups, taken from the first member.
        return self.members[0].spec

    @property
    def group_subgroups_size(self):
        """
        Total size of each subgroup summed over all members. Schools are
        special-cased to two effective subgroups: teachers and pupils.
        """
        Nsubgroups = len(self.members[0].subgroups)
        if self.spec in ["schools"]:
            Nsubgroups = 2
        subgroup_sizes = np.zeros(Nsubgroups)
        for member in self.members:
            if self.spec not in ["schools"]:
                for sub_i in range(Nsubgroups):
                    subgroup_sizes[sub_i] += member.subgroups[sub_i].size
            elif self.spec in ["schools"]:
                subgroup_sizes[0] += member.n_teachers
                subgroup_sizes[1] += member.n_pupils
        return subgroup_sizes

    @classmethod
    def for_geography(cls):
        raise NotImplementedError(
            "Geography initialization not available for this supergroup."
        )

    @classmethod
    def from_file(cls):
        raise NotImplementedError(
            "From file initialization not available for this supergroup."
        )

    @classmethod
    def get_interaction(cls, config_filename=None):
        # First parameter renamed self -> cls: this is a classmethod and the
        # bound argument is the class object, not an instance.
        cls.venue_class.subgroup_params = SubgroupParams.from_file(config_filename)


from .abstract import AbstractGroup
from .subgroup import Subgroup
from .supergroup import Supergroup
from .external import ExternalSubgroup, ExternalGroup


from .social_venue import SocialVenue, SocialVenues
from .social_venue_distributor import SocialVenueDistributor
from june.paths import data_path, configs_path

# Default locations of the cinema coordinates data and the cinema leisure config.
default_cinemas_coordinates_filename = (
    data_path / "input/leisure/cinemas_per_super_area.csv"
)
default_config_filename = configs_path / "defaults/groups/leisure/cinemas.yaml"


class Cinema(SocialVenue):
    """A cinema social venue."""

    # Presumably the maximum number of simultaneous attendees, enforced by
    # the SocialVenue machinery (not visible here) — TODO confirm.
    max_size = 1000


class Cinemas(SocialVenues):
    """Supergroup of all Cinema venues."""

    venue_class = Cinema
    default_coordinates_filename = default_cinemas_coordinates_filename


class CinemaDistributor(SocialVenueDistributor):
    """Social venue distributor configured from the cinemas yaml."""

    default_config_filename = default_config_filename


from .social_venue import SocialVenue, SocialVenues
from .social_venue_distributor import SocialVenueDistributor
from june.paths import data_path, configs_path

# Default locations of the grocery leisure config and the coordinates data.
default_config_filename = configs_path / "defaults/groups/leisure/groceries.yaml"
default_groceries_coordinates_filename = (
    data_path / "input/leisure/groceries_per_super_area.csv"
)


class Grocery(SocialVenue):
    """A grocery social venue."""

    # Presumably the maximum number of simultaneous attendees, enforced by
    # the SocialVenue machinery (not visible here) — TODO confirm.
    max_size = 200


class Groceries(SocialVenues):
    """Supergroup of all Grocery venues."""

    venue_class = Grocery
    default_coordinates_filename = default_groceries_coordinates_filename


class GroceryDistributor(SocialVenueDistributor):
    """Social venue distributor configured from the groceries yaml."""

    default_config_filename = default_config_filename


from .social_venue import SocialVenue, SocialVenues
from .social_venue_distributor import SocialVenueDistributor
from june.paths import data_path, configs_path

# Default locations of the gym coordinates data and the gym leisure config.
default_gym_coordinates_filename = data_path / "input/leisure/gyms_per_super_area.csv"
default_config_filename = configs_path / "defaults/groups/leisure/gyms.yaml"


class Gym(SocialVenue):
    """A gym social venue."""

    # Presumably the maximum number of simultaneous attendees, enforced by
    # the SocialVenue machinery (not visible here) — TODO confirm.
    # (Redundant trailing `pass` removed: the class body already has a statement.)
    max_size = 300


class Gyms(SocialVenues):
    """Supergroup of all Gym venues."""

    venue_class = Gym
    default_coordinates_filename = default_gym_coordinates_filename


class GymDistributor(SocialVenueDistributor):
    """Social venue distributor configured from the gyms yaml."""

    default_config_filename = default_config_filename


import numpy as np
import yaml
import logging
from random import random
from typing import Dict
from june.demography import Person
from june.geography import SuperAreas, Areas, Regions, Region
from june.groups.leisure import (
    SocialVenueDistributor,
    PubDistributor,
    GroceryDistributor,
    CinemaDistributor,
    ResidenceVisitsDistributor,
    GymDistributor,
)
from june.utils import random_choice_numba
from june import paths
from june.utils.parse_probabilities import parse_opens


# Simulation config used when no explicit config file is given.
default_config_filename = paths.configs_path / "config_example.yaml"

logger = logging.getLogger("leisure")


def generate_leisure_for_world(list_of_leisure_groups, world, daytypes):
    """
    Generates an instance of the leisure class for the specified geography and leisure groups.

    Parameters
    ----------
    list_of_leisure_groups
        list of names of the leisure groups desired. Ex: ["pubs", "cinemas"]
    world
        world holding the venue supergroups (world.pubs, world.gyms, ...) and
        world.regions
    daytypes
        dict mapping "weekday"/"weekend" to lists of day names

    Returns
    -------
    A Leisure instance with one distributor per available activity.
    """
    leisure_distributors = {}
    # (world attribute, activity key, distributor class) for the plain venue
    # types, which all follow the same presence check + from_config pattern.
    venue_specs = [
        ("pubs", "pub", PubDistributor),
        ("gyms", "gym", GymDistributor),
        ("cinemas", "cinema", CinemaDistributor),
        ("groceries", "grocery", GroceryDistributor),
    ]
    for attribute, activity, distributor_class in venue_specs:
        if attribute not in list_of_leisure_groups:
            continue
        venues = getattr(world, attribute, None)
        if venues is None or len(venues) == 0:
            logger.warning(f"No {attribute} in this world/domain")
        else:
            leisure_distributors[activity] = distributor_class.from_config(
                venues, daytypes=daytypes
            )
    if (
        "household_visits" in list_of_leisure_groups
        or "care_home_visits" in list_of_leisure_groups
    ):
        # Residence visits need both households and care homes to exist.
        if not hasattr(world, "care_homes") or not hasattr(world, "households"):
            raise ValueError(
                "Your world does not have care homes or households for visits."
            )
        leisure_distributors[
            "residence_visits"
        ] = ResidenceVisitsDistributor.from_config(daytypes=daytypes)
    leisure = Leisure(leisure_distributors=leisure_distributors, regions=world.regions)
    return leisure


def generate_leisure_for_config(world, config_filename=default_config_filename):
    """
    Generates an instance of the leisure class from a simulation config file.

    Parameters
    ----------
    world
        world holding the venue supergroups and regions
    config_filename
        path to the simulation yaml config; its leisure activity list and
        optional weekday/weekend day lists are used

    Returns
    -------
    A Leisure instance built by generate_leisure_for_world.
    """
    with open(config_filename) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    try:
        list_of_leisure_groups = config["activity_to_super_groups"]["leisure"]
    except (KeyError, TypeError):
        # Fall back to the legacy config key (narrowed from a blanket
        # `except Exception`, which hid unrelated errors).
        list_of_leisure_groups = config["activity_to_groups"]["leisure"]

    if "weekday" in config and "weekend" in config:
        daytypes = {"weekday": config["weekday"], "weekend": config["weekend"]}
    else:
        # Default day typing when the config does not specify one.
        daytypes = {
            "weekday": ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"],
            "weekend": ["Saturday", "Sunday"],
        }
    leisure_instance = generate_leisure_for_world(
        list_of_leisure_groups, world, daytypes
    )
    return leisure_instance


class Leisure:
    """
    Class to manage all possible activities that happen during leisure time.
    """

    def __init__(
        self,
        leisure_distributors: Dict[str, SocialVenueDistributor],
        regions: Regions = None,
    ):
        """
        Parameters
        ----------
        leisure_distributors
            dictionary mapping activity names to social venue distributors.
        regions
            regions of the world; required to apply regional compliances.
        """
        self.probabilities_by_region_sex_age = None
        self.leisure_distributors = leisure_distributors
        self.n_activities = len(self.leisure_distributors)
        self.policy_reductions = {}
        self.regions = regions  # needed for regional compliances

    def distribute_social_venues_to_areas(self, areas: Areas, super_areas: SuperAreas):
        """
        Links residences to each other for visits, and assigns to every area
        the social venues its residents can attend for each non-visit activity.
        """
        logger.info("Linking households and care homes for visits")
        if "residence_visits" in self.leisure_distributors:
            self.leisure_distributors["residence_visits"].link_households_to_households(
                super_areas
            )
            self.leisure_distributors["residence_visits"].link_households_to_care_homes(
                super_areas
            )
        logger.info("Done")
        logger.info("Distributing social venues to areas")
        for i, area in enumerate(areas):
            if i % 2000 == 0:
                logger.info(f"Distributed in {i} of {len(areas)} areas.")
            for activity, distributor in self.leisure_distributors.items():
                # residence visits are linked above rather than distributed
                if "visits" in activity:
                    continue
                social_venues = distributor.get_possible_venues_for_area(area)
                if social_venues is not None:
                    area.social_venues[activity] = social_venues
        logger.info(f"Distributed in {len(areas)} of {len(areas)} areas.")

    def generate_leisure_probabilities_for_timestep(
        self, delta_time: float, working_hours: bool, date: "datetime.datetime"
    ):
        """
        Precomputes, for this timestep, the leisure probability tables for
        every region (when regions are present), sex, and age.

        Parameters
        ----------
        delta_time
            timestep length used in the Poisson exponent
        working_hours
            whether the timestep falls inside working hours
        date
            current simulation date; must support .weekday() and .hour
            (the original annotation said str, but a datetime is required)
        """
        self.probabilities_by_region_sex_age = {}
        if self.regions:
            for region in self.regions:
                self.probabilities_by_region_sex_age[
                    region.name
                ] = self._generate_leisure_probabilities_for_age_and_sex(
                    delta_time=delta_time,
                    working_hours=working_hours,
                    date=date,
                    region=region,
                )
        else:
            # no regions: store a single sex/age table directly
            self.probabilities_by_region_sex_age = (
                self._generate_leisure_probabilities_for_age_and_sex(
                    delta_time=delta_time,
                    working_hours=working_hours,
                    date=date,
                    region=None,
                )
            )

    def get_subgroup_for_person_and_housemates(
        self, person: Person, to_send_abroad: dict = None
    ):
        """
        Main function of the Leisure class. For every possible activity a person
        can do, we check the Poisson parameter lambda = probability / day * deltat
        of that activity taking place. We then sum up the Poisson parameters to
        decide whether a person does any activity at all. The relative weight of
        the Poisson parameters gives then the specific activity a person does.
        If a person ends up going to a social venue, we do a second check to see
        if his/her entire household accompanies him/her.
        The social venue subgroups are attached to the involved people, but they
        are not added to the subgroups, since it is possible they change their
        plans if a policy is in place or they have other responsibilities.
        The function returns None if no activity takes place.

        Parameters
        ----------
        person
            an instance of person
        to_send_abroad
            bookkeeping dict for people whose venue lives in another domain
        """
        # NOTE(review): experimental age-reassignment scaffolding; currently a
        # no-op (age is saved and restored unchanged). See the commented AorC
        # logic below.
        age_before = person.age
        age = person.age
        # AorC_value = self.AorC(person.age)
        # if age < 18 and AorC_value == "Adult":
        #     age = 18
        person.age = age

        # care home residents do not do leisure activities
        if person.residence.group.spec == "care_home":
            person.age = age_before
            return
        prob_age_sex = self._get_activity_probabilities_for_person(person=person)
        if random() < prob_age_sex["does_activity"]:
            # sample which activity, weighted by the normalised Poisson parameters
            activity_idx = random_choice_numba(
                arr=np.arange(0, len(prob_age_sex["activities"])),
                prob=np.array(list(prob_age_sex["activities"].values())),
            )
            activity = list(prob_age_sex["activities"].keys())[activity_idx]
            activity_distributor = self.leisure_distributors[activity]
            subgroup = activity_distributor.get_leisure_subgroup(
                person, to_send_abroad=to_send_abroad
            )
            person.subgroups.leisure = subgroup
            activity_distributor.send_household_with_person_if_necessary(
                person=person, to_send_abroad=to_send_abroad
            )
            person.age = age_before
            return subgroup
        person.age = age_before

    def _generate_leisure_probabilities_for_age_and_sex(
        self,
        delta_time: float,
        working_hours: bool,
        date: "datetime.datetime",
        region: Region,
    ):
        """
        Returns, for each sex ("m"/"f"), a list indexed by age (0-99) of the
        probability dictionaries produced by
        _get_leisure_probability_for_age_and_sex.
        """
        ret = {}
        for sex in ["m", "f"]:
            probs = [
                self._get_leisure_probability_for_age_and_sex(
                    age=age,
                    sex=sex,
                    delta_time=delta_time,
                    date=date,
                    working_hours=working_hours,
                    region=region,
                )
                for age in range(0, 100)
            ]
            ret[sex] = probs

        return ret

    def _get_leisure_probability_for_age_and_sex(
        self,
        age: int,
        sex: str,
        delta_time: float,
        date: "datetime.datetime",
        working_hours: bool,
        region: Region,
    ):
        """
        Computes the probabilities of going to different leisure activities,
        and dragging the household with the person that does the activity.
        When policies are present, then the regional leisure poisson parameters are
        changed according to the present policy poisson parameter (lambda_2) and the local
        regional compliance like so:
        $ lambda = lambda_1 + regional_compliance * (lambda_2 - lambda_1) $
        where lambda_1 is the original poisson parameter.
        lockdown tier: 1,2,3 - has different implications for leisure:
            1: do nothing
            2: stop household-to-household probability with regional compliance and
               reduce pub probability by 20% - conservative to account for the serving of meals
            3: stop household-to-household probability with regional compliance and
               reduce pub and cinema probability to 0 to simulate closure

        Returns a dict with keys "does_activity" (overall probability),
        "drags_household" (per-activity probability), and "activities"
        (per-activity normalised weights).
        """
        poisson_parameters = []
        drags_household_probabilities = []
        activities = []
        for activity, distributor in self.leisure_distributors.items():
            drags_household_probabilities.append(
                distributor.drags_household_probability
            )

            activity_poisson_parameter = self._get_activity_poisson_parameter(
                activity=activity,
                distributor=distributor,
                age=age,
                sex=sex,
                date=date,
                working_hours=working_hours,
                region=region,
            )
            poisson_parameters.append(activity_poisson_parameter)
            activities.append(activity)
        total_poisson_parameter = sum(poisson_parameters)
        # probability of doing at least one activity in the timestep
        does_activity_probability = 1.0 - np.exp(-delta_time * total_poisson_parameter)
        activities_probabilities = {}
        drags_household_probabilities_dict = {}
        for i in range(len(activities)):
            # guard the zero case so a 0/0 cannot occur when everything is closed
            if poisson_parameters[i] == 0:
                activities_probabilities[activities[i]] = 0
            else:
                activities_probabilities[activities[i]] = (
                    poisson_parameters[i] / total_poisson_parameter
                )
            drags_household_probabilities_dict[
                activities[i]
            ] = drags_household_probabilities[i]
        return {
            "does_activity": does_activity_probability,
            "drags_household": drags_household_probabilities_dict,
            "activities": activities_probabilities,
        }

    def _get_activity_poisson_parameter(
        self,
        activity: str,
        distributor: SocialVenueDistributor,
        age: int,
        sex: str,
        date: "datetime.datetime",
        working_hours: bool,
        region: Region,
    ):
        """
        Computes an activity poisson parameter taking into account active policies,
        regional compliances and lockdown tiers. Returns 0 when the venue is
        closed at the current hour.
        """
        day = [
            "Monday",
            "Tuesday",
            "Wednesday",
            "Thursday",
            "Friday",
            "Saturday",
            "Sunday",
        ][date.weekday()]
        if day in distributor.daytypes["weekday"]:
            day_type = "weekday"
        elif day in distributor.daytypes["weekend"]:
            day_type = "weekend"
        else:
            # bug fix: day_type was previously left unbound here, producing a
            # confusing UnboundLocalError when the day types are misconfigured.
            raise ValueError(
                f"Day {day} is in neither the weekday nor the weekend day type."
            )

        # TODO check closures etc!
        open_times = parse_opens(distributor.open)[day_type]
        # closed if the opening window is empty or the hour falls outside it
        # (renamed from `open`, which shadowed the builtin)
        is_open = 1
        if open_times[1] - open_times[0] == 0:
            is_open = 0
        if date.hour < open_times[0] or date.hour >= open_times[1]:
            is_open = 0

        if activity in self.policy_reductions:
            policy_reduction = self.policy_reductions[activity][day_type][sex][age]
        else:
            policy_reduction = 1

        activity_poisson_parameter = distributor.get_poisson_parameter(
            sex=sex,
            age=age,
            day_type=day_type,
            working_hours=working_hours,
            policy_reduction=policy_reduction,
            region=region,
        )
        return activity_poisson_parameter * is_open

    def _drags_household_to_activity(self, person, activity):
        """
        Checks whether the person drags the household to the activity.
        The probability table may or may not be keyed by region name,
        hence the layered fallbacks.
        """
        try:
            prob = self.probabilities_by_region_sex_age[person.region.name][person.sex][
                person.age
            ]["drags_household"][activity]
        except KeyError:
            # table not keyed by region: index directly by sex
            prob = self.probabilities_by_region_sex_age[person.sex][person.age][
                "drags_household"
            ][activity]
        except AttributeError:
            # person.region is not available
            if person.sex in self.probabilities_by_region_sex_age:
                prob = self.probabilities_by_region_sex_age[person.sex][person.age][
                    "drags_household"
                ][activity]
            else:
                # fall back to the first region in the table
                prob = self.probabilities_by_region_sex_age[
                    list(self.probabilities_by_region_sex_age.keys())[0]
                ][person.sex][person.age]["drags_household"][activity]
        return random() < prob

    # TESTING TODO
    ######################################################################
    def P_IsAdult(self, age):
        """
        Probability that a person of the given age is treated as an adult:
        0 below 13, 1 above 17, a tanh ramp in between.
        """
        tanh_halfpeak_age = 15  # 17.1
        tanh_width = 0.7  # 1

        minageadult = 13
        maxagechild = 17
        if age < minageadult:
            return 0
        elif age > maxagechild:
            return 1
        else:
            return (np.tanh(tanh_width * (age - tanh_halfpeak_age)) + 1) / 2

    def P_IsChild(self, age):
        """Complement of P_IsAdult."""
        return 1 - self.P_IsAdult(age)

    def AorC(self, age):
        """Randomly classifies the given age as "Adult" or "Child"."""
        r = np.random.rand(1)[0]
        if r < self.P_IsAdult(age):
            return "Adult"
        else:
            return "Child"

    ######################################################################

    def _get_activity_probabilities_for_person(self, person: Person):
        """
        Looks up the precomputed probability dict for the person's region,
        sex and age, with the same region-keyed/unkeyed fallbacks as
        _drags_household_to_activity.
        """
        try:
            return self.probabilities_by_region_sex_age[person.region.name][person.sex][
                person.age
            ]
        except KeyError:
            return self.probabilities_by_region_sex_age[person.sex][person.age]
        except AttributeError:
            if person.sex in self.probabilities_by_region_sex_age:
                return self.probabilities_by_region_sex_age[person.sex][person.age]
            else:
                return self.probabilities_by_region_sex_age[
                    list(self.probabilities_by_region_sex_age.keys())[0]
                ][person.sex][person.age]


from .social_venue import SocialVenue, SocialVenues
from .social_venue_distributor import SocialVenueDistributor
from june.paths import data_path, configs_path

# default locations of the pub coordinates data and the pub leisure config
default_pub_coordinates_filename = data_path / "input/leisure/pubs_per_super_area.csv"
default_config_filename = configs_path / "defaults/groups/leisure/pubs.yaml"


class Pub(SocialVenue):
    """A pub social venue with a maximum occupancy of 100 people."""

    # removed a redundant trailing `pass` (the class body is non-empty)
    max_size = 100


class Pubs(SocialVenues):
    """Supergroup holding all Pub venues, built from the default coordinates file."""
    venue_class = Pub
    default_coordinates_filename = default_pub_coordinates_filename


class PubDistributor(SocialVenueDistributor):
    """Social venue distributor that assigns people to pubs, configured from the default pubs yaml."""
    default_config_filename = default_config_filename


import yaml
from random import shuffle, randint
import numpy as np

from june.groups.leisure import SocialVenueDistributor
from june.paths import configs_path
from june.utils import random_choice_numba

# default location of the residence-visits leisure config
default_config_filename = configs_path / "defaults/groups/leisure/visits.yaml"

# default day-type split used when no custom one is supplied
default_daytypes = {
    "weekday": ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"],
    "weekend": ["Saturday", "Sunday"],
}


class ResidenceVisitsDistributor(SocialVenueDistributor):
    """
    This is a social distributor specific to model visits between residences,
    ie, visits between households or to care homes. The meaning of the parameters
    is the same as for the SVD. Residence visits are not decided on neighbours or
    distances so we ignore some parameters.
    """

    def __init__(
        self,
        residence_type_probabilities,
        times_per_week,
        hours_per_day,
        daytypes=default_daytypes,
        drags_household_probability=0,
    ):
        """
        Parameters
        ----------
        residence_type_probabilities
            mapping of residence type (e.g. "household", "care_home") to the
            relative probability of visiting that type.
        times_per_week
            how many times per day type, age and sex this activity is done.
        hours_per_day
            leisure hours available per day type, age and sex.
        daytypes
            mapping of "weekday"/"weekend" to lists of day names.
        drags_household_probability
            probability of taking the whole household along on a visit.
        """
        # NOTE(review): the original comment said these need to be arrays for
        # performance, but they are stored as-is -- confirm whether a
        # conversion step is missing.
        self.residence_type_probabilities = residence_type_probabilities
        self.policy_reductions = {}
        super().__init__(
            social_venues=None,
            times_per_week=times_per_week,
            daytypes=daytypes,
            hours_per_day=hours_per_day,
            drags_household_probability=drags_household_probability,
            neighbours_to_consider=None,
            maximum_distance=None,
            leisure_subgroup_type=None,
        )

    @classmethod
    def from_config(cls, daytypes, config_filename: str = default_config_filename):
        """Builds the distributor from a yaml config file."""
        with open(config_filename) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        return cls(daytypes=daytypes, **config)

    def link_households_to_households(self, super_areas):
        """
        Links people between households. Strategy: We pair each household with
        2 to 4 other households (with equal prob.). The household of the former
        then has a probability of visiting the household of the latter
        at every time step.
        (Docstring fixed: it previously said "0, 1, or 2", but the code draws
        randint(2, 4).)

        Parameters
        ----------
        super_areas
            list of super areas
        """
        for super_area in super_areas:
            households_in_super_area = [
                household for area in super_area.areas for household in area.households
            ]
            for household in households_in_super_area:
                if household.n_residents == 0:
                    continue
                households_to_link_n = randint(2, 4)
                households_to_visit = []
                n_linked = 0
                # NOTE(review): this loop does not terminate if the super area
                # has fewer than 2 other occupied households -- confirm that
                # cannot happen upstream.
                while n_linked < households_to_link_n:
                    house_idx = randint(0, len(households_in_super_area) - 1)
                    house = households_in_super_area[house_idx]
                    # never link a household to itself or to an empty house
                    if house.id == household.id or not house.residents:
                        continue
                    households_to_visit.append(house)
                    n_linked += 1
                if households_to_visit:
                    household.residences_to_visit["household"] = tuple(
                        households_to_visit
                    )

    def link_households_to_care_homes(self, super_areas):
        """
        Links households and care homes in the given super areas. For each care home,
        we find a random house in the super area and link it to it.
        The house needs to be occupied by a family, or a couple.

        Parameters
        ----------
        super_areas
            list of super areas
        """
        for super_area in super_areas:
            households_super_area = []
            for area in super_area.areas:
                households_super_area += [
                    household
                    for household in area.households
                    if household.type in ["families", "ya_parents", "nokids"]
                ]
            shuffle(households_super_area)
            for area in super_area.areas:
                if area.care_home is not None:
                    people_in_care_home = list(area.care_home.residents)
                    # NOTE(review): assumes at least as many eligible
                    # households as care home residents; an IndexError is
                    # raised otherwise -- confirm this holds in practice.
                    for i, person in enumerate(people_in_care_home):
                        household = households_super_area[i]
                        household.residences_to_visit["care_home"] = (
                            *household.residences_to_visit["care_home"],
                            area.care_home,
                        )

    def get_leisure_group(self, person):
        """
        Chooses a residence (household or care home) for the person to visit:
        the residence type is sampled according to the configured probabilities
        (or policy reductions when active), then one of the linked candidates
        of that type is picked uniformly. Returns None if nothing to visit.
        """
        residence_types = list(person.residence.group.residences_to_visit.keys())
        if not residence_types:
            return
        if len(residence_types) == 1:
            # bug fix: this fast path previously checked ``== 0``, which is
            # unreachable (the empty case returns above) and would have raised
            # an IndexError, so single-type residences always went through the
            # sampling branch below.
            which_type = residence_types[0]
        else:
            if self.policy_reductions:
                probabilities = self.policy_reductions
            else:
                probabilities = self.residence_type_probabilities
            residence_type_probabilities = np.array(
                [probabilities[residence_type] for residence_type in residence_types]
            )
            # normalise so the weights sum to 1
            residence_type_probabilities = (
                residence_type_probabilities / residence_type_probabilities.sum()
            )
            type_sample = random_choice_numba(
                tuple(range(len(residence_type_probabilities))),
                residence_type_probabilities,
            )
            which_type = residence_types[type_sample]
        candidates = person.residence.group.residences_to_visit[which_type]
        n_candidates = len(candidates)
        if n_candidates == 0:
            return
        elif n_candidates == 1:
            group = candidates[0]
        else:
            group = candidates[randint(0, n_candidates - 1)]
        return group

    def get_poisson_parameter(
        self, sex, age, day_type, working_hours, region=None, policy_reduction=None
    ):
        """
        This differs from the super() implementation in that we do not allow
        visits during working hours as most people are away.
        """
        if working_hours:
            return 0
        return super().get_poisson_parameter(
            sex=sex,
            age=age,
            day_type=day_type,
            working_hours=working_hours,
            region=region,
            policy_reduction=policy_reduction,
        )


import numpy as np
import pandas as pd
import logging
from typing import List, Optional
from enum import IntEnum
from sklearn.neighbors import BallTree

from june.groups import Supergroup, Group
from june.geography import Area, Areas, SuperArea, SuperAreas, Geography
from june.mpi_setup import mpi_rank

earth_radius = 6371  # km; used to convert query radii for the haversine metric

logger = logging.getLogger("social_venue")
# silence duplicate log output from non-root processes when running under MPI
if mpi_rank > 0:
    logger.propagate = False


class SocialVenueError(Exception):
    """
    Raised for invalid social-venue configuration or usage.

    Derives from Exception rather than BaseException so that generic
    ``except Exception`` handlers catch it; BaseException is reserved for
    interpreter-level exits such as KeyboardInterrupt and SystemExit.
    """

    pass


class SocialVenue(Group):
    """
    A generic place that people can attend during leisure time.
    Venues have no size limit unless a subclass overrides ``max_size``.
    """

    max_size = np.inf

    def __init__(self, area=None):
        """
        Parameters
        ----------
        area
            geographic area the venue is located in (may be None).
        """
        super().__init__()
        self.area = area

    def add(self, person, activity="leisure"):
        """
        Puts the person into the venue's first subgroup and registers that
        subgroup under ``activity`` on the person's subgroups.
        """
        venue_subgroup = self.subgroups[0]
        venue_subgroup.append(person)
        setattr(person.subgroups, activity, venue_subgroup)

    @property
    def super_area(self):
        """Super area containing this venue's area."""
        return self.area.super_area

    @property
    def get_coordinates(self):
        """Coordinates of the venue's area, or None when no area is set."""
        if self.area is not None:
            return self.area.coordinates
        return None


class SocialVenues(Supergroup):
    """
    Supergroup holding all social venues of a given type, with an optional
    ball tree over venue coordinates for fast nearest-venue queries.
    """

    venue_class = SocialVenue

    def __init__(self, social_venues: List[venue_class], make_tree=True):
        """
        Parameters
        ----------
        social_venues
            list of venues to hold.
        make_tree
            whether to build the ball tree of venue coordinates on init.
        """
        super().__init__(members=social_venues)
        logger.info(f"Domain {mpi_rank} has {len(self)} {self.spec}(s)")
        self.ball_tree = None
        if make_tree:
            if not social_venues:
                logger.warning(f"No social venues of type {self.spec} in this domain")
            else:
                self.make_tree()

    @classmethod
    def from_coordinates(
        cls,
        coordinates: List[np.array],
        super_areas: Optional[Areas],
        max_distance_to_area=10,
        **kwargs,
    ):
        """
        Creates venues at the given coordinates, linking each one to its
        closest area. Venues further than ``max_distance_to_area`` (in the
        units returned by the closest-area queries) are discarded.
        """
        if len(coordinates) == 0:
            return cls([], **kwargs)

        closest_super_areas = None
        if super_areas:
            closest_super_areas, distances = super_areas.get_closest_super_areas(
                coordinates, k=1, return_distance=True
            )
            close_idx = np.where(distances < max_distance_to_area)[0]
            coordinates = coordinates[close_idx]
            # bug fix: the matched super areas must be filtered alongside the
            # coordinates; previously only coordinates were filtered, so venue
            # i was paired with the super area of the unfiltered coordinate i
            # whenever any coordinate was dropped.
            closest_super_areas = [closest_super_areas[i] for i in close_idx]
        social_venues = []

        for i, coord in enumerate(coordinates):
            sv = cls.venue_class()
            sv.coordinates = coord
            if closest_super_areas is not None:
                super_area = closest_super_areas[i]
                area, dist = Areas(super_area.areas).get_closest_area(
                    coordinates=coord, return_distance=True
                )
                # skip venues too far from any area of their super area
                if dist > max_distance_to_area:
                    continue
                sv.area = area
            social_venues.append(sv)
        return cls(social_venues, **kwargs)

    @classmethod
    def for_super_areas(
        cls, super_areas: List[SuperArea], coordinates_filename: str = None
    ):
        """Builds venues for the given super areas from a coordinates csv."""
        if coordinates_filename is None:
            coordinates_filename = cls.default_coordinates_filename
        sv_coordinates = pd.read_csv(coordinates_filename, index_col=0).values
        return cls.from_coordinates(sv_coordinates, super_areas=super_areas)

    @classmethod
    def for_areas(cls, areas: Areas, coordinates_filename: str = None):
        """Builds venues for the super areas containing the given areas."""
        if coordinates_filename is None:
            coordinates_filename = cls.default_coordinates_filename
        super_areas = SuperAreas([area.super_area for area in areas])
        return cls.for_super_areas(super_areas, coordinates_filename)

    @classmethod
    def for_geography(cls, geography: Geography, coordinates_filename: str = None):
        """Builds venues for all super areas of the given geography."""
        if coordinates_filename is None:
            coordinates_filename = cls.default_coordinates_filename
        return cls.for_super_areas(geography.super_areas, coordinates_filename)

    @classmethod
    def distribute_for_areas(
        cls,
        areas: List[Area],
        venues_per_capita: float = None,
        venues_per_area: int = None,
    ):
        """
        Generates social venues in the given areas.

        Parameters
        ----------
        areas
            list of areas to generate the venues in
        venues_per_capita
            number of venues per person in each area.
        venues_per_area
            number of venues in each area.

        Raises
        ------
        SocialVenueError
            if both or neither of the two options are specified.
        """
        if venues_per_area is not None and venues_per_capita is not None:
            raise SocialVenueError(
                "Please specify only one of venues_per_capita or venues_per_area."
            )
        social_venues = []
        if venues_per_area is not None:
            for area in areas:
                for _ in range(venues_per_area):
                    sv = cls.venue_class()
                    sv.area = area
                    social_venues.append(sv)
        elif venues_per_capita is not None:
            for area in areas:
                area_population = len(area.people)
                for _ in range(int(np.ceil(venues_per_capita * area_population))):
                    sv = cls.venue_class()
                    sv.area = area
                    sv.coordinates = area.coordinates
                    social_venues.append(sv)
        else:
            raise SocialVenueError(
                "Specify one of venues_per_capita or venues_per_area"
            )
        return cls(social_venues)

    @classmethod
    def distribute_for_super_areas(
        cls, super_areas: List[SuperArea], venues_per_super_area=None, venues_per_capita=None
    ):
        """
        Generates social venues in the given super areas.

        Parameters
        ----------
        super_areas
            list of super areas to generate the venues in
        venues_per_super_area
            how many venues per super_area to generate
        venues_per_capita
            number of venues per person in each super area

        Raises
        ------
        SocialVenueError
            if both or neither of the two options are specified.
        """
        # bug fix: defaults were venues_per_super_area=1, venues_per_capita=1,
        # so any call that did not explicitly pass None for one of them hit
        # the mutual-exclusivity check below and raised. Defaults are now
        # None/None, matching distribute_for_areas.
        if venues_per_super_area is not None and venues_per_capita is not None:
            raise SocialVenueError(
                "Please specify only one of venues_per_capita or venues_per_area."
            )
        social_venues = []
        if venues_per_super_area is not None:
            for super_area in super_areas:
                for _ in range(venues_per_super_area):
                    sv = cls.venue_class()
                    sv.area = super_area
                    social_venues.append(sv)
        elif venues_per_capita is not None:
            for super_area in super_areas:
                super_area_population = len(super_area.people)
                for _ in range(int(np.ceil(venues_per_capita * super_area_population))):
                    sv = cls.venue_class()
                    area = Areas(super_area.areas).get_closest_area(
                        coordinates=super_area.coordinates
                    )
                    sv.area = area
                    sv.coordinates = area.coordinates
                    social_venues.append(sv)
        else:
            raise SocialVenueError(
                "Specify one of venues_per_capita or venues_per_area"
            )
        return cls(social_venues)

    def make_tree(self):
        """Builds the haversine ball tree over venue coordinates (radians)."""
        self.ball_tree = BallTree(
            np.array([np.deg2rad(sv.coordinates) for sv in self]), metric="haversine"
        )

    def add_to_areas(self, areas: Areas):
        """
        Adds all venues to the closest area.
        """
        for venue in self:
            if not hasattr(venue, "coordinates"):
                raise SocialVenueError(
                    "Can't add to super area if venues don't have coordinates."
                )
            venue.area = areas.get_closest_areas(venue.coordinates)[0]

    def get_closest_venues(self, coordinates, k=1):
        """
        Queries the ball tree for the closest venues.

        Parameters
        ----------
        coordinates
            coordinates in the format [Latitude, Longitude]
        k
            number of neighbours desired
        """
        if not self.members:
            return
        if self.ball_tree is None:
            raise SocialVenueError("Initialise ball tree first with self.make_tree()")
        venue_idxs = self.ball_tree.query(
            np.deg2rad(coordinates).reshape(1, -1), return_distance=False, k=k
        ).flatten()
        social_venues = self.members
        return [social_venues[idx] for idx in venue_idxs]

    def get_venues_in_radius(self, coordinates, radius=5):
        """
        Queries the ball tree for all venues within the given radius,
        sorted by distance. Returns None if none are found.

        Parameters
        ----------
        coordinates
            coordinates in the format [Latitude, Longitude]
        radius
            radius in km to query
        """
        if not self.members:
            return
        if self.ball_tree is None:
            raise SocialVenueError("Initialise ball tree first with self.make_tree()")
        # convert km to radians for the haversine metric
        radius = radius / earth_radius
        venue_idxs, _ = self.ball_tree.query_radius(
            np.deg2rad(coordinates).reshape(1, -1),
            r=radius,
            sort_results=True,
            return_distance=True,
        )
        venue_idxs = venue_idxs[0]
        if not venue_idxs.size:
            return None
        social_venues = self.members
        return [social_venues[idx] for idx in venue_idxs]

    def get_leisure_subgroup(self, person, subgroup_type, to_send_abroad):
        """Returns the venue subgroup of the requested type."""
        return self[subgroup_type]


import numpy as np
from random import random, sample, randint
from numba import jit
from typing import Dict
import yaml
import re

from june.groups.leisure import SocialVenues
from june.utils.parse_probabilities import parse_age_probabilities
from june.geography import Area


@jit(nopython=True)
def random_choice_numba(arr, prob):
    """
    Fast replacement for np.random.choice: samples an index according to the
    probability weights in ``prob`` and returns the matching element of ``arr``.
    """
    cumulative = np.cumsum(prob)
    sampled_idx = np.searchsorted(cumulative, random(), side="right")
    return arr[sampled_idx]


# default day-type split used when no custom one is supplied
default_daytypes = {
    "weekday": ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"],
    "weekend": ["Saturday", "Sunday"],
}


class SocialVenueDistributor:
    """
    Tool to associate social venues to people.
    """

    def __init__(
        self,
        social_venues: SocialVenues,
        times_per_week: Dict[Dict, float],
        daytypes: Dict[str, str] = default_daytypes,
        hours_per_day: Dict[Dict, float] = None,
        drags_household_probability=0.0,
        neighbours_to_consider=5,
        maximum_distance=5,
        leisure_subgroup_type=0,
        nearest_venues_to_visit=0,
        open={"weekday": "0-24", "weekend": "0-24"},
    ):
        """
        A sex/age profile for the social venue attendees can be specified as
        male_age_probabilities = {"18-65" : 0.3}
        any non-specified ages in the range (0,99) will have 0 probabilty
        Parameters
        ----------
        social_venues
            A SocialVenues object
        times_per_week:
            How many times per day type, age, and sex, a person does this activity.
            Example:
            times_per_week = {"weekday" : {"male" : {"0-50":0.5, "50-100" : 0.2},
                                            "female" : {"0-100" : 0.5}},
                              "weekend" : {"male" : {"0-100" : 1.0},
                                            "female" : {"0-100" : 1.0}}}
        hours_per_day:
            How many leisure hours per day a person has. This is the time window in which
            a person can do leisure.
            Example:
            hours_per_day = {"weekday" : {"male" : {"0-65": 3, "65-100" : 11},
                                          "female" : {"0-65" : 3, "65-100" : 11}},
                              "weekend" : {"male" : {"0-100" : 12},
                                            "female" : {"0-100" : 12}}}
        drags_household_probabilitiy:
            Probability of doing a certain activity together with the housheold.
        maximum_distance:
            Maximum distance to travel until the social venue
        leisure_subgroup_type
            Subgroup of the venue that the person will be appended to
            (for instance, the visitors subgroup of the care home)
        nearest_venues_to_visit:
            restrict people only travelling to nearest venue(s). 0 means no restriction.
            if >0, "neighbours_to_consider" will be ignored.
        """
        self.spec = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[:-1]
        self.spec = "_".join(self.spec).lower()
        if hours_per_day is None:
            hours_per_day = {
                "weekday": {
                    "male": {"0-65": 3, "65-100": 11},
                    "female": {"0-65": 3, "65-100": 11},
                },
                "weekend": {"male": {"0-100": 12}, "female": {"0-100": 12}},
            }
        self.social_venues = social_venues
        self.open = open
        self.daytypes = daytypes

        self.poisson_parameters = self._parse_poisson_parameters(
            times_per_week=times_per_week, hours_per_day=hours_per_day
        )
        self.neighbours_to_consider = neighbours_to_consider
        self.maximum_distance = maximum_distance
        self.drags_household_probability = drags_household_probability
        self.leisure_subgroup_type = leisure_subgroup_type
        self.spec = re.findall("[A-Z][^A-Z]*", self.__class__.__name__)[:-1]
        self.spec = "_".join(self.spec).lower()
        self.nearest_venues_to_visit = nearest_venues_to_visit

    @classmethod
    def from_config(
        cls,
        social_venues: SocialVenues,
        daytypes: dict = default_daytypes,
        config_filename: str = None,
        config_override: Dict[str, int] = None,
    ):
        """
        Parameters
        ----------
        config_override
            a dict of parameters overrides their values in "config_filename"
        """
        if config_filename is None:
            config_filename = cls.default_config_filename
        with open(config_filename) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        if config_override is not None:
            for key, value in config_override.items():
                if value is not None:
                    config[key] = value
        return cls(social_venues, daytypes=daytypes, **config)

    def _compute_poisson_parameter_from_times_per_week(
        self, times_per_week, hours_per_day, day_type
    ):
        if times_per_week == 0:
            return 0
        ndays = len(self.daytypes[day_type])
        return (times_per_week / ndays) * (24 / hours_per_day)

    def _parse_poisson_parameters(self, times_per_week, hours_per_day):
        ret = {}
        _sex_t = {"male": "m", "female": "f"}

        for day_type in ["weekday", "weekend"]:
            ret[day_type] = {}
            for sex in ["male", "female"]:
                parsed_times_per_week = parse_age_probabilities(
                    times_per_week[day_type][sex]
                )
                parsed_hours_per_day = parse_age_probabilities(
                    hours_per_day[day_type][sex]
                )

                ret[day_type][_sex_t[sex]] = [
                    self._compute_poisson_parameter_from_times_per_week(
                        times_per_week=parsed_times_per_week[i],
                        hours_per_day=parsed_hours_per_day[i],
                        day_type=day_type,
                    )
                    for i in range(len(parsed_times_per_week))
                ]
        return ret

    def get_poisson_parameter(
        self, sex, age, day_type, working_hours, region=None, policy_reduction=None
    ):
        """
        Poisson parameter (lambda) of a person going to one social venue according to their
        age and sex and the distribution of visitors in the venue.

        Parameters
        ----------
        person
            an instance of Person
        delta_t
            interval of time in units of days
        weekday or weekend

            whether it is a weekend or not
        """
        if region is None:
            regional_compliance = 1
        else:
            if self.spec in region.closed_venues:
                return 0
            regional_compliance = region.regional_compliance

        original_poisson_parameter = self.poisson_parameters[day_type][sex][age]
        if policy_reduction is None:
            return original_poisson_parameter
        poisson_parameter = original_poisson_parameter * (
            1 + regional_compliance * (policy_reduction - 1)
        )
        return poisson_parameter

    def probability_to_go_to_social_venue(
        self, person, delta_time, day_type, working_hours
    ):
        """
        Probabilty of a person going to one social venue according to their
        age and sex and the distribution of visitors in the venue.

        Parameters
        ----------
        person
            an instance of Person
        delta_t
            interval of time in units of days
        day_type
            weekday or weekend
        """
        poisson_parameter = self.get_poisson_parameter(
            sex=person.sex,
            age=person.age,
            day_type=day_type,
            working_hours=working_hours,
        )
        return 1 - np.exp(-poisson_parameter * delta_time)

    def get_possible_venues_for_area(self, area: Area):
        """
        Given an area, searches for the social venues inside
        ``self.maximum_distance``. It then returns ``self.neighbours_to_consider``
        of them randomly, or ``nearest_venues_to_visit`` of them sorting by distance ascending.
        If there are no social venues inside the maximum distance, it returns the closest one.
        """
        area_location = area.coordinates
        potential_venues = self.social_venues.get_venues_in_radius(
            area_location, self.maximum_distance
        )
        if potential_venues is None:
            closest_venue = self.social_venues.get_closest_venues(area_location, k=1)
            if closest_venue is None:
                return
            return (closest_venue[0],)
        if self.nearest_venues_to_visit > 0:
            indices_len = min(len(potential_venues), self.nearest_venues_to_visit)
            return tuple([potential_venues[idx] for idx in range(indices_len)])
        else:
            indices_len = min(len(potential_venues), self.neighbours_to_consider)
            random_idx_choice = sample(range(len(potential_venues)), indices_len)
            return tuple([potential_venues[idx] for idx in random_idx_choice])

    def get_leisure_group(self, person):
        candidates = person.area.social_venues[self.spec]
        n_candidates = len(candidates)
        if n_candidates == 0:
            return
        elif n_candidates == 1:
            group = candidates[0]
        else:
            group = candidates[randint(0, n_candidates - 1)]
        return group

    def get_leisure_subgroup(self, person, to_send_abroad=None):
        group = self.get_leisure_group(person)
        # this may not necessary be the same subgroup, allow for customization here.
        if group is None:
            return
        subgroup = group.get_leisure_subgroup(
            person=person,
            subgroup_type=self.leisure_subgroup_type,
            to_send_abroad=to_send_abroad,
        )
        return subgroup

    def person_drags_household(self):
        """
        Check whether person drags household or not.
        """
        return random() < self.drags_household_probability

    def send_household_with_person_if_necessary(self, person, to_send_abroad=None):
        """
        When we know that the person does an activity in the social venue X,
        then we ask X whether the person needs to drag the household with
        him or her.
        """
        if (
            person.residence.group.spec == "care_home"
            or person.residence.group.type in ["communal", "other", "student"]
        ):
            return
        subgroup = person.leisure
        if self.person_drags_household():
            for mate in person.residence.group.residents:
                if mate != person:
                    if mate.busy:
                        if (
                            mate.leisure is not None
                        ):  # this perosn has already been assigned somewhere
                            if not mate.leisure.external:
                                if mate not in mate.leisure.people:
                                    # person active somewhere else, let's not disturb them
                                    continue
                                mate.leisure.remove(mate)
                            else:
                                ret = to_send_abroad.delete_person(mate, mate.leisure)
                                if ret:
                                    # person active somewhere else, let's not disturb them
                                    continue
                            if not subgroup.external:
                                subgroup.append(mate)
                            else:
                                to_send_abroad.add_person(mate, subgroup)
                    mate.subgroups.leisure = (
                        subgroup  # person will be added later in the simulator.
                    )


from .social_venue import SocialVenue, SocialVenues, SocialVenueError
from .social_venue_distributor import SocialVenueDistributor
from .pub import Pub, Pubs, PubDistributor
from .cinema import Cinema, Cinemas, CinemaDistributor
from .grocery import Groceries, Grocery, GroceryDistributor
from .gym import Gym, Gyms, GymDistributor
from .residence_visits import ResidenceVisitsDistributor
from .leisure import Leisure, generate_leisure_for_world, generate_leisure_for_config


import csv
from typing import List, Tuple, Dict

import numpy as np
import yaml

from june import paths
from june.utils import random_choice_numba

# Default locations for the commute inputs; ``paths`` resolves them relative
# to the installed package's config/data directories.
default_config_filename = (
    paths.configs_path / "defaults/groups/travel/mode_of_transport.yaml"
)

default_commute_file = paths.data_path / "input/travel/mode_of_transport_ew.csv"


class ModeOfTransport:
    """
    Flyweight describing a mode of transport (e.g. "Bus, minibus or coach").

    At most one instance exists per description; constructing the same
    description again returns the already-registered object.
    """

    __all = {}
    __slots__ = "description", "is_public"

    def __new__(cls, description, is_public=False):
        registry = ModeOfTransport.__all
        if description not in registry:
            registry[description] = super().__new__(cls)
        return registry[description]

    def __init__(self, description: str, is_public: bool = False):
        """
        Create (or fetch) the ModeOfTransport for ``description``.

        Only one instance of each mode of transport exists, with instances
        being retrieved from the class-level registry.

        Parameters
        ----------
        description
            e.g. "Bus, minibus or coach"
        is_public
            True if this is public transport, for example a bus.
        """
        # NOTE: this runs on every construction, including registry hits,
        # so the most recent is_public value wins.
        self.description = description
        self.is_public = is_public

    @property
    def is_private(self) -> bool:
        """
        True if this is private transport, for example a car.
        """
        return not self.is_public

    @classmethod
    def with_description(cls, description: str) -> "ModeOfTransport":
        """
        Retrieve the registered mode of transport for ``description``.

        Parameters
        ----------
        description
            A description, e.g. 'Bus, minibus or coach'

        Returns
        -------
        The corresponding ModeOfTransport instance
        """
        return ModeOfTransport.__all[description]

    def index(self, headers: List[str]) -> int:
        """
        Column index of the first header containing this description.

        Parameters
        ----------
        headers
            A list of headers from a CSV file.

        Returns
        -------
        The column index corresponding to this mode of transport.

        Raises
        ------
        AssertionError if no matching header is found.
        """
        match = next(
            (col for col, header in enumerate(headers) if self.description in header),
            None,
        )
        if match is None:
            raise AssertionError(f"{self} not found in headers {headers}")
        return match

    def __eq__(self, other):
        # Equal to another mode with the same description, or to the bare
        # description string itself.
        if isinstance(other, ModeOfTransport):
            return other.description == self.description
        if isinstance(other, str):
            return other == self.description
        return super().__eq__(other)

    def __hash__(self):
        return hash(self.description)

    def __str__(self):
        return self.description

    def __repr__(self):
        return f"<{self.__class__.__name__} {self}>"

    def __getnewargs__(self):
        # Makes unpickling route through __new__ so the registry is honoured.
        return self.description, self.is_public

    @classmethod
    def load_from_file(
        cls, config_filename=default_config_filename
    ) -> List["ModeOfTransport"]:
        """
        Load all of the modes of transport from the yaml configuration.

        Modes of transport are globally unique, so repeated calls return
        identical objects.

        Parameters
        ----------
        config_filename
            The path to the mode of transport yaml configuration

        Returns
        -------
        A list of modes of transport
        """
        with open(config_filename) as f:
            entries = yaml.load(f, Loader=yaml.FullLoader)
        return [ModeOfTransport(**entry) for entry in entries]


class RegionalGenerator:
    def __init__(self, area: str, weighted_modes: List[Tuple[int, "ModeOfTransport"]]):
        """
        Weighted-random generator of commute modes for one output area.

        Parameters
        ----------
        area
            A unique identifier for a Output region
        weighted_modes
            Tuples of (number of people using a mode, the mode itself)
        """
        self.area = area
        self.weighted_modes = weighted_modes
        self.total = self._get_total()
        self.modes = self._get_modes()
        self.weights = self._get_weights()
        self.modes_idx = np.arange(len(self.modes))

    def _get_total(self) -> int:
        """Total number of people across every mode of transport."""
        return sum(count for count, _ in self.weighted_modes)

    def _get_modes(self) -> List["ModeOfTransport"]:
        """The modes of transport, in input order."""
        return [transport for _, transport in self.weighted_modes]

    def _get_weights(self) -> List[float]:
        """Per-mode usage weights, normalised to sum to one."""
        return np.array([count / self.total for count, _ in self.weighted_modes])

    def weighted_random_choice(self) -> "ModeOfTransport":
        """Draw one mode of transport, weighted by usage in this region."""
        chosen = random_choice_numba(self.modes_idx, self.weights)
        return self.modes[chosen]

    def __repr__(self):
        return f"<{self.__class__.__name__} {self}>"

    def __str__(self):
        return self.area


class ModeOfTransportGenerator:
    def __init__(self, regional_generators: Dict[str, RegionalGenerator]):
        """
        Generate the mode of transport a person uses in their commute.

        Modes of transport are chosen randomly, weighted by the counts
        taken from census data for each given output area.

        Parameters
        ----------
        regional_generators
            Maps an output-area code to the object that weighted-randomly
            generates modes of transport for that area
        """
        self.regional_generators = regional_generators

    def regional_gen_from_area(self, area: str) -> RegionalGenerator:
        """
        Look up the regional generator for an output area code,
        e.g. E00062207.

        Parameters
        ----------
        area
            An output-area code

        Returns
        -------
        An object that weighted-randomly selects modes of transport for the region.
        """
        return self.regional_generators[area]

    @classmethod
    def from_file(
        cls,
        filename: str = default_commute_file,
        config_filename: str = default_config_filename,
    ) -> "ModeOfTransportGenerator":
        """
        Parse the mode-of-transport configuration together with the census
        csv giving per-area usage counts, and build one RegionalGenerator
        per output area.

        Parameters
        ----------
        filename
            The path to the csv with the number of people using each mode
            of transport per area.
        config_filename
            The path to the mode-of-transport yaml configuration

        Returns
        -------
        An object used to generate commutes
        """
        regional_generators = {}
        with open(filename) as f:
            rows = csv.reader(f)
            headers = next(rows)
            area_column = headers.index("geography code")
            modes_of_transport = ModeOfTransport.load_from_file(config_filename)
            for row in rows:
                weighted_modes = [
                    (int(row[mode.index(headers)]), mode)
                    for mode in modes_of_transport
                ]
                area = row[area_column]
                regional_generators[area] = RegionalGenerator(
                    area=area, weighted_modes=weighted_modes
                )

        return ModeOfTransportGenerator(regional_generators)


from enum import IntEnum
from typing import List

from june.groups import Group, Supergroup


class Transport(Group):
    """
    A class representing a transport unit (e.g. one train carriage)
    attached to a station.
    """

    def __init__(self, station):
        """
        Parameters
        ----------
        station
            The station this transport unit belongs to; location properties
            are delegated to it.
        """
        super().__init__()
        self.station = station

    @property
    def area(self):
        """First area of the owning station's super area."""
        return self.station.super_area.areas[0]

    @property
    def super_area(self):
        """Super area of the owning station."""
        return self.station.super_area

    @property
    def coordinates(self):
        """Coordinates of this transport's area."""
        return self.area.coordinates


class Transports(Supergroup):
    """
    A collection of transport units.
    """

    def __init__(self, transports: List[Transport]):
        # Storage and iteration over the units is handled by Supergroup.
        super().__init__(transports)


class CityTransport(Transport):
    """
    Inner city transport (a unit serving commuters within one city).
    """


class CityTransports(Transports):

    """
    Inner city transports
    """

    # Concrete unit class instantiated when creating transports for a world.
    venue_class = CityTransport


class InterCityTransport(Transport):
    """
    Transport between cities (a unit serving commuters who live outside
    their work city).
    """


class InterCityTransports(Transports):
    """
    Inter city transports
    """

    # Concrete unit class instantiated when creating transports for a world.
    venue_class = InterCityTransport


import logging
import yaml
import numpy as np

from june.paths import configs_path, data_path
from june.geography import Cities, Stations
from june.world import World
from .mode_of_transport import ModeOfTransport, ModeOfTransportGenerator
from .transport import CityTransports, InterCityTransports


logger = logging.getLogger("travel")
# Default input/configuration locations, resolved relative to the package's
# data and config directories.
default_cities_filename = data_path / "input/geography/cities_per_super_area_ew.csv"

default_city_stations_config_filename = (
    configs_path / "defaults/travel/city_stations.yaml"
)

default_commute_config_filename = configs_path / "defaults/groups/travel/commute.yaml"


class Travel:
    """
    This class handles all functionality related to travel, from local commute,
    to inter-city and inter-regional travel.
    """

    def __init__(
        self,
        city_super_areas_filename=default_cities_filename,
        city_stations_filename=default_city_stations_config_filename,
        commute_config_filename=default_commute_config_filename,
    ):
        """
        Parameters
        ----------
        city_super_areas_filename
            csv mapping super areas to cities
        city_stations_filename
            yaml giving the number of inter-city stations per city
        commute_config_filename
            yaml with commute parameters (e.g. seats_per_passenger)
        """
        self.city_super_areas_filename = city_super_areas_filename
        self.city_stations_filename = city_stations_filename
        with open(commute_config_filename) as f:
            self.commute_config = yaml.load(f, Loader=yaml.FullLoader)

    def initialise_commute(
        self, world: World, maximum_number_commuters_per_city_station=200000
    ):
        """
        Set up commuting in ``world``: create cities, assign everyone a mode
        of transport, build stations, assign commuters to them and create
        the transport units.
        """
        logger.info("Initialising commute...")
        self._generate_cities(
            world=world, city_super_areas_filename=self.city_super_areas_filename
        )
        self._assign_mode_of_transport_to_people(world=world)
        commuters_dict = self._get_city_commuters(
            world=world, city_stations_filename=self.city_stations_filename
        )
        self._create_stations(
            world=world,
            commuters_dict=commuters_dict,
            maximum_number_commuters_per_city_station=maximum_number_commuters_per_city_station,
            city_stations_filename=self.city_stations_filename,
        )
        self._distribute_commuters_to_stations(
            world=world, commuters_dict=commuters_dict
        )
        self._create_transports_in_cities(world)

    def get_commute_subgroup(self, person):
        """
        Return the commute subgroup for ``person``, or None if they have no
        work city or do not use public transport.
        """
        work_city = person.work_city
        if work_city is None or not person.mode_of_transport.is_public:
            return
        subgroup = work_city.get_commute_subgroup(person)
        person.subgroups.commute = subgroup
        return subgroup

    def _generate_cities(self, world, city_super_areas_filename: str):
        """
        Generates cities in the current world.
        """
        # initialise cities
        logger.info("Creating cities...")
        world.cities = Cities.for_super_areas(
            world.super_areas, city_super_areas_filename=city_super_areas_filename
        )
        city_names = [city.name for city in world.cities]
        if len(city_names) > 0:
            logger.info(
                f"This world has {len(city_names)} cities, with names\n" f"{city_names}"
            )
        else:
            logger.info("This world has no important cities in it")

    def _assign_mode_of_transport_to_people(self, world: World):
        """
        Assigns a mode of transport (public or not) to the world's population.
        People under 18 or 65 and over are marked as not in employment;
        everyone else draws a mode weighted by their area's census data.
        """
        logger.info("Determining people mode of transport")
        mode_of_transport_generator = ModeOfTransportGenerator.from_file()
        for i, area in enumerate(world.areas):
            if i % 4000 == 0:
                logger.info(
                    f"Mode of transport allocated in {i} of {len(world.areas)} areas."
                )
            mode_of_transport_generator_area = (
                mode_of_transport_generator.regional_gen_from_area(area.name)
            )
            for person in area.people:
                if person.age < 18 or person.age >= 65:
                    person.mode_of_transport = ModeOfTransport(
                        description="Not in employment", is_public=False
                    )
                else:
                    person.mode_of_transport = (
                        mode_of_transport_generator_area.weighted_random_choice()
                    )
        logger.info("Mode of transport determined for everyone.")

    def _get_city_commuters(self, world: World, city_stations_filename: str):
        """
        Gets internal and external commuters per city.
        - If the person lives and works in the same city, then the person is assigned
          to be an internal commuter (think as the person takes the subway).
        - If the person lives outside their working city, then that person has to commute
          through a station, and is assigned to the city external commuters.
        - Likewise for the people living in the city but working outside.

        Returns a dict mapping city name to {"internal": [...], "external": [...]}
        lists of person ids, for cities that have stations configured.
        """
        with open(city_stations_filename) as f:
            cities_with_stations = yaml.load(f, Loader=yaml.FullLoader)[
                "number_of_inter_city_stations"
            ]
        ret = {}
        for city in world.cities:
            if city.name in cities_with_stations:
                ret[city.name] = {"internal": [], "external": []}
        logger.info("Assigning commuters to stations...")
        for i, person in enumerate(world.people):
            if person.mode_of_transport.is_public:
                if (
                    person.work_city is not None
                    and person.work_city.name in cities_with_stations
                ):
                    if person.home_city == person.work_city:
                        # this person commutes internally
                        ret[person.work_city.name]["internal"].append(person.id)
                    else:
                        # commutes away to an external station
                        ret[person.work_city.name]["external"].append(person.id)
            if i % 500_000 == 0:
                logger.info(
                    f"Assigned {i} of {len(world.people)} potential commuters..."
                )
        logger.info("Commuters assigned")
        for key, value in ret.items():
            internal = value["internal"]
            external = value["external"]
            if len(internal) + len(external) > 0:
                logger.info(
                    f"City {key} has {len(internal)} internal and {len(external)} external commuters."
                )
        return ret

    def _create_stations(
        self,
        world: World,
        city_stations_filename: str,
        commuters_dict: dict,
        maximum_number_commuters_per_city_station: int,
    ):
        """
        Generates cities, super stations, and stations on the given world.
        """
        with open(city_stations_filename) as f:
            inter_city_stations_per_city = yaml.load(f, Loader=yaml.FullLoader)[
                "number_of_inter_city_stations"
            ]
        logger.info("Creating stations...")
        world.stations = Stations([])
        for city in world.cities:
            if city.name not in inter_city_stations_per_city:
                continue
            else:
                n_inter_city_stations = inter_city_stations_per_city[city.name]
                city.inter_city_stations = Stations.from_city_center(
                    city=city,
                    super_areas=world.super_areas,
                    number_of_stations=n_inter_city_stations,
                    type="inter_city_station",
                    distance_to_city_center=10,
                )
                city.inter_city_stations._construct_ball_tree()
                world.stations += city.inter_city_stations
                n_internal_commuters = len(commuters_dict[city.name]["internal"])
                n_city_stations = int(
                    np.ceil(
                        n_internal_commuters / maximum_number_commuters_per_city_station
                    )
                )
                city.city_stations = Stations.from_city_center(
                    city=city,
                    super_areas=world.super_areas,
                    number_of_stations=n_city_stations,
                    type="city_station",
                    distance_to_city_center=5,
                )
                city.city_stations._construct_ball_tree()
                world.stations += city.city_stations
                # initialise ball tree for stations in the city.
                logger.info(
                    f"City {city.name} has {n_city_stations} city "
                    f"and {n_inter_city_stations} inter city stations."
                )
        for super_area in world.super_areas:
            for city in world.cities:
                if city.has_stations:
                    super_area.closest_inter_city_station_for_city[
                        city.name
                    ] = city.get_closest_inter_city_station(super_area.coordinates)

    def _distribute_commuters_to_stations(self, world: World, commuters_dict: dict):
        """
        Store internal commuter ids on each city and route every external
        commuter to the inter-city station closest to their super area.
        """
        for city, commuters in commuters_dict.items():
            city = world.cities.get_by_name(city)
            city.internal_commuter_ids = set(commuters["internal"])
            for external_commuter_id in commuters["external"]:
                external_commuter = world.people.get_from_id(external_commuter_id)
                work_city = external_commuter.work_city.name
                station = (
                    external_commuter.super_area.closest_inter_city_station_for_city[
                        work_city
                    ]
                )
                station.commuter_ids.add(external_commuter_id)

    def _create_transports_in_cities(
        self, world, seats_per_city_transport=50, seats_per_inter_city_transport=50
    ):
        """
        Creates city transports and inter city transports in CityStations and
        InterCityStations respectively.
        """
        logger.info("Creating transport units for the population")
        if not hasattr(world, "city_transports"):
            world.city_transports = CityTransports([])
        if not hasattr(world, "inter_city_transports"):
            world.inter_city_transports = InterCityTransports([])

        for city in world.cities:
            if city.has_stations:
                seats_per_passenger = self.commute_config["seats_per_passenger"].get(
                    city.name, 1
                )
                n_commute_internal = len(city.internal_commuter_ids)
                number_city_transports = int(
                    np.ceil(
                        (
                            seats_per_passenger
                            * n_commute_internal
                            / seats_per_city_transport
                        )
                    )
                )
                logger.info(
                    f"City {city.name} has {number_city_transports} city train carriages."
                )
                n_city_stations = len(city.city_stations)
                # guard against cities with no internal commuters (and hence
                # no city stations), which previously divided by zero
                if number_city_transports > 0 and n_city_stations > 0:
                    transports_per_station = int(
                        np.ceil(number_city_transports / n_city_stations)
                    )
                    for station in city.city_stations:
                        if number_city_transports <= 0:
                            break
                        for _ in range(transports_per_station):
                            # check the budget *before* creating, so we never
                            # allocate more carriages than computed above
                            if number_city_transports <= 0:
                                break
                            city_transport = world.city_transports.venue_class(
                                station=station
                            )
                            station.city_transports.append(city_transport)
                            world.city_transports.add(city_transport)
                            number_city_transports -= 1
                number_inter_city_transports_total = 0
                for station in city.inter_city_stations:
                    if len(station.commuter_ids) == 0:
                        continue
                    number_inter_city_transports = int(
                        np.ceil(
                            (
                                seats_per_passenger
                                * len(station.commuter_ids)
                                / seats_per_inter_city_transport
                            )
                        )
                    )
                    number_inter_city_transports_total += number_inter_city_transports
                    for _ in range(number_inter_city_transports):
                        inter_city_transport = world.inter_city_transports.venue_class(
                            station=station
                        )
                        station.inter_city_transports.append(inter_city_transport)
                        world.inter_city_transports.add(inter_city_transport)
                logger.info(
                    f"City {city.name} has {number_inter_city_transports_total} inter-city train carriages."
                )
        logger.info("Cities' transport initialised")


from .mode_of_transport import ModeOfTransport, ModeOfTransportGenerator
from .travel import Travel
from .transport import (
    Transport,
    Transports,
    CityTransport,
    CityTransports,
    InterCityTransport,
    InterCityTransports,
)


import h5py
import numpy as np

from june.groups import CareHome, CareHomes
from june.world import World
from june.groups.group.make_subgroups import SubgroupParams
from .utils import read_dataset

nan_integer = -999


def save_care_homes_to_hdf5(
    care_homes: CareHomes, file_path: str, chunk_size: int = 50000
):
    """
    Saves the care_homes object to hdf5 format file ``file_path``. Currently for each
    care home, the following values are stored:
    - id, area, super_area, n_residents, n_workers

    Parameters
    ----------
    care_homes
        collection of care homes to save
    file_path
        path of the saved hdf5 file
    chunk_size
        number of care homes to save at a time. Note that they have to be copied to be
        saved, so keep the number below 1e6.
    """
    n_care_homes = len(care_homes)
    n_chunks = int(np.ceil(n_care_homes / chunk_size))
    with h5py.File(file_path, "a") as f:
        care_homes_dset = f.create_group("care_homes")
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_care_homes)
            ids = []
            areas = []
            super_areas = []
            n_residents = []
            n_workers = []
            for carehome in care_homes[idx1:idx2]:
                ids.append(carehome.id)
                if carehome.area is None:
                    # no geography attached; hdf5 cannot store None, use sentinel
                    areas.append(nan_integer)
                    super_areas.append(nan_integer)
                else:
                    areas.append(carehome.area.id)
                    super_areas.append(carehome.super_area.id)
                n_residents.append(carehome.n_residents)
                n_workers.append(carehome.n_workers)

            ids = np.array(ids, dtype=np.int64)
            areas = np.array(areas, dtype=np.int64)
            # fix: super_areas was previously left as a plain Python list while
            # every sibling column was converted; make the dtype explicit so the
            # hdf5 dataset is stored as int64 like "area"
            super_areas = np.array(super_areas, dtype=np.int64)
            n_residents = np.array(n_residents, dtype=np.float64)
            n_workers = np.array(n_workers, dtype=np.float64)
            if chunk == 0:
                # first chunk creates the resizable datasets
                care_homes_dset.attrs["n_care_homes"] = n_care_homes
                care_homes_dset.create_dataset("id", data=ids, maxshape=(None,))
                care_homes_dset.create_dataset("area", data=areas, maxshape=(None,))
                care_homes_dset.create_dataset(
                    "super_area", data=super_areas, maxshape=(None,)
                )
                care_homes_dset.create_dataset(
                    "n_residents", data=n_residents, maxshape=(None,)
                )
                care_homes_dset.create_dataset(
                    "n_workers", data=n_workers, maxshape=(None,)
                )
            else:
                # subsequent chunks grow the datasets and write in place
                newshape = (care_homes_dset["id"].shape[0] + ids.shape[0],)
                care_homes_dset["id"].resize(newshape)
                care_homes_dset["id"][idx1:idx2] = ids
                care_homes_dset["area"].resize(newshape)
                care_homes_dset["area"][idx1:idx2] = areas
                care_homes_dset["super_area"].resize(newshape)
                care_homes_dset["super_area"][idx1:idx2] = super_areas
                care_homes_dset["n_residents"].resize(newshape)
                care_homes_dset["n_residents"][idx1:idx2] = n_residents
                care_homes_dset["n_workers"].resize(newshape)
                care_homes_dset["n_workers"][idx1:idx2] = n_workers


def load_care_homes_from_hdf5(
    file_path: str, chunk_size=50000, domain_super_areas=None, config_filename=None
):
    """
    Loads care homes from an hdf5 file located at ``file_path``.
    The returned object is not ready to use yet: links to instances of
    other classes still have to be restored afterwards.
    This function should be rarely be called oustide world.py
    """
    carehome_class = CareHome
    carehome_class.subgroup_params = SubgroupParams.from_file(
        config_filename=config_filename
    )

    care_homes_list = []
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        dset = f["care_homes"]
        n_total = dset.attrs["n_care_homes"]
        # walk the datasets chunk by chunk to bound memory usage
        for start in range(0, n_total, chunk_size):
            stop = min(start + chunk_size, n_total)
            ids = read_dataset(dset["id"], start, stop)
            n_residents = read_dataset(dset["n_residents"], start, stop)
            n_workers = read_dataset(dset["n_workers"], start, stop)
            super_areas = read_dataset(dset["super_area"], start, stop)
            for k in range(stop - start):
                if domain_super_areas is not None:
                    super_area = super_areas[k]
                    if super_area == nan_integer:
                        raise ValueError(
                            "if ``domain_super_areas`` is True, I expect not Nones super areas."
                        )
                    if super_area not in domain_super_areas:
                        # care home belongs to another domain, skip it
                        continue
                care_home = carehome_class(
                    area=None, n_residents=n_residents[k], n_workers=n_workers[k]
                )
                care_home.id = ids[k]
                care_homes_list.append(care_home)
    return CareHomes(care_homes_list)


def restore_care_homes_properties_from_hdf5(
    world: World, file_path: str, chunk_size=50000, domain_super_areas=None
):
    """
    Restores the area links of the care homes in ``world`` from the hdf5
    file located at ``file_path``. The care homes (and areas) must already
    have been loaded into ``world``.
    This function should be rarely be called oustide world.py

    Parameters
    ----------
    world
        world whose care homes will be linked to their areas
    file_path
        path of the hdf5 file to read
    chunk_size
        number of care homes to read at a time
    domain_super_areas
        if given, only care homes whose super area is in this collection
        are restored
    """
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        carehomes = f["care_homes"]
        n_carehomes = carehomes.attrs["n_care_homes"]
        n_chunks = int(np.ceil(n_carehomes / chunk_size))
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_carehomes)
            ids = carehomes["id"][idx1:idx2]
            areas = carehomes["area"][idx1:idx2]
            super_areas = carehomes["super_area"][idx1:idx2]
            for k in range(idx2 - idx1):
                if domain_super_areas is not None:
                    super_area = super_areas[k]
                    if super_area == nan_integer:
                        raise ValueError(
                            "if ``domain_super_areas`` is True, I expect not Nones super areas."
                        )
                    if super_area not in domain_super_areas:
                        continue
                care_home = world.care_homes.get_from_id(ids[k])
                if areas[k] == nan_integer:
                    area = None
                else:
                    area = world.areas.get_from_id(areas[k])
                care_home.area = area
                # fix: previously the back-link was set unconditionally, which
                # raised AttributeError when the stored area was the nan sentinel
                if area is not None:
                    area.care_home = care_home


from typing import Optional
import numpy as np
from datetime import datetime, timedelta
import h5py
from glob import glob
import logging

from june.world import World
from june.interaction import Interaction
from june.groups.leisure import Leisure
from june.policy import Policies
from june.event import Events
from june.simulator import Simulator
from june.epidemiology.epidemiology import Epidemiology
from june.hdf5_savers.utils import write_dataset
from june.demography import Population
from june.demography.person import Activities
from june.hdf5_savers import (
    save_infections_to_hdf5,
    load_infections_from_hdf5,
    save_immunities_to_hdf5,
    load_immunities_from_hdf5,
)
from june.groups.travel import Travel
import june.simulator as june_simulator_module

from june.tracker import Tracker

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from june.records.records_writer import Record

# re-export so checkpoint generation uses the same default simulator config file
default_config_filename = june_simulator_module.default_config_filename

# module-level logger for checkpoint save/load progress messages
logger = logging.getLogger("checkpoint_saver")


def save_checkpoint_to_hdf5(
    population: Population, date: str, hdf5_file_path: str, chunk_size: int = 50000
):
    """
    Saves a checkpoint at the given date by saving the infection information of the world.

    Parameters
    ----------
    population:
        world's population
    date:
        date of the checkpoint
    hdf5_file_path
        path where to save the hdf5 checkpoint
    chunk_size
        hdf5 chunk_size to write data
    """
    # collect ids of everyone, of the dead, and of the currently infected
    people_ids = [person.id for person in population]
    dead_people_ids = [person.id for person in population if person.dead]
    infected_people = [person for person in population if person.infected]
    infected_people_ids = [person.id for person in infected_people]
    infection_list = [person.infection for person in infected_people]
    with h5py.File(hdf5_file_path, "w") as f:
        time_group = f.create_group("time")
        time_group.attrs["date"] = date
        people_group = f.create_group("people_data")
        id_datasets = {
            "people_id": people_ids,
            "infected_id": infected_people_ids,
            "dead_id": dead_people_ids,
        }
        for name, data in id_datasets.items():
            write_dataset(
                group=people_group,
                dataset_name=name,
                data=np.array(data, dtype=np.int64),
            )
    # infections and immunities are appended by the dedicated savers
    save_infections_to_hdf5(
        hdf5_file_path=hdf5_file_path, infections=infection_list, chunk_size=chunk_size
    )
    save_immunities_to_hdf5(
        hdf5_file_path=hdf5_file_path,
        immunities=[person.immunity for person in population],
    )


def load_checkpoint_from_hdf5(hdf5_file_path: str, chunk_size=50000, load_date=True):
    """
    Loads checkpoint data from hdf5.

    Parameters
    ----------
    hdf5_file_path
        hdf5 path to load from
    chunk_size
        number of hdf5 chunks to use while loading
    """
    checkpoint = {
        "infection_list": load_infections_from_hdf5(
            hdf5_file_path, chunk_size=chunk_size
        ),
        "immunity_list": load_immunities_from_hdf5(
            hdf5_file_path, chunk_size=chunk_size
        ),
    }
    with h5py.File(hdf5_file_path, "r") as f:
        people_group = f["people_data"]
        for key in ("infected_id", "dead_id", "people_id"):
            checkpoint[key] = people_group[key][:]
        if load_date:
            checkpoint["date"] = f["time"].attrs["date"]
    return checkpoint


def combine_checkpoints_for_ranks(hdf5_file_root: str):
    """
    After running a parallel simulation with checkpoints, the
    checkpoint data will be scattered across processes, with each process
    saving a checkpoint_date.{rank}.hdf5 file. This function can be used
    to unify all data in one single checkpoint, so that we can load it
    later with any arbitrary number of cores.

    Parameters
    ----------
    hdf5_file_root
        the str root of the paths like "checkpoint_2020-01-01". The checkpoint files
        will be expected to have names like "checkpoint_2020-01-01.{rank}.hdf5" where
        rank = 0, 1, 2, etc.

    Raises
    ------
    FileNotFoundError
        if no per-rank checkpoint files matching ``hdf5_file_root`` are found
    """
    # sort for a deterministic combination order (glob order is arbitrary)
    checkpoint_files = sorted(glob(hdf5_file_root + ".[0-9]*.hdf5"))
    if not checkpoint_files:
        # fail with a clear message instead of an IndexError below
        raise FileNotFoundError(
            f"No checkpoint files matching {hdf5_file_root}.<rank>.hdf5 found."
        )
    try:
        cp_date = hdf5_file_root.split("_")[-1]
    except Exception:
        cp_date = hdf5_file_root
    logger.info(f"found {len(checkpoint_files)} {cp_date} checkpoint files")
    # start from the first rank's data and concatenate the rest onto it
    ret = load_checkpoint_from_hdf5(checkpoint_files[0])
    for file in checkpoint_files[1:]:
        ret2 = load_checkpoint_from_hdf5(file, load_date=False)
        for key, value in ret2.items():
            ret[key] = np.concatenate((ret[key], value))

    unified_checkpoint_path = hdf5_file_root + ".hdf5"
    with h5py.File(unified_checkpoint_path, "w") as f:
        f.create_group("time")
        f["time"].attrs["date"] = ret["date"]
        f.create_group("people_data")
        for name in ["people_id", "infected_id", "dead_id"]:
            write_dataset(
                group=f["people_data"],
                dataset_name=name,
                data=np.array(ret[name], dtype=np.int64),
            )
    save_infections_to_hdf5(
        hdf5_file_path=unified_checkpoint_path,
        infections=ret["infection_list"],
        chunk_size=1000000,
    )
    save_immunities_to_hdf5(
        hdf5_file_path=unified_checkpoint_path, immunities=ret["immunity_list"]
    )


def restore_simulator_to_checkpoint(
    simulator,
    world: World,
    checkpoint_path: str,
    chunk_size: Optional[int] = 50000,
    reset_infections=False,
):
    """
    Initializes the simulator from a saved checkpoint. The arguments are the same as the standard .from_file()
    initialisation but with the additional path to where the checkpoint pickle file is located.
    The checkpoint saves information about the infection status of all the people in the world as well as the timings.
    Note, nonetheless, that all the past infections / deaths will have the checkpoint date as date.

    Parameters
    ----------
    simulator:
        An instance of the Simulator class
    world:
        the world whose people are restored (checkpoint ids not present in
        this world are skipped, e.g. in domain-decomposed runs)
    checkpoint_path:
        path to the hdf5 file containing the checkpoint data
    chunk_size
        chunk load size of the hdf5
    reset_infections
        whether to skip restoring the saved infections (people are left
        uninfected). Useful for reseeding.
    """
    # ids of people that exist in this (possibly partial) world; everything
    # below is filtered against this set
    people_ids = set(world.people.people_ids)
    checkpoint_data = load_checkpoint_from_hdf5(checkpoint_path, chunk_size=chunk_size)
    # restore deaths: mark person dead, bury them, and clear their activities
    for dead_id in checkpoint_data["dead_id"]:
        if dead_id not in people_ids:
            continue
        person = simulator.world.people.get_from_id(dead_id)
        person.dead = True
        cemetery = world.cemeteries.get_nearest(person)
        cemetery.add(person)
        person.subgroups = Activities(None, None, None, None, None, None)
    if not reset_infections:
        # restore active infections; infected_id and infection_list are
        # parallel sequences written together by save_checkpoint_to_hdf5
        for infected_id, infection in zip(
            checkpoint_data["infected_id"], checkpoint_data["infection_list"]
        ):
            if infected_id not in people_ids:
                continue
            person = simulator.world.people.get_from_id(infected_id)
            person.infection = infection
    # restore immunities
    for person_id, immunity in zip(
        checkpoint_data["people_id"], checkpoint_data["immunity_list"]
    ):
        if person_id not in people_ids:
            continue
        person = world.people.get_from_id(person_id)
        person.immunity = immunity
    # restore timer
    checkpoint_date = datetime.strptime(checkpoint_data["date"], "%Y-%m-%d")
    # we need to start the next day
    checkpoint_date += timedelta(days=1)
    simulator.timer.reset_to_new_date(checkpoint_date)
    logger.info(f"Restored checkpoint at date {checkpoint_date.date()}")
    return simulator


def generate_simulator_from_checkpoint(
    world: World,
    checkpoint_path: str,
    interaction: Interaction,
    chunk_size: Optional[int] = 50000,
    epidemiology: Optional[Epidemiology] = None,
    tracker: Optional[Tracker] = None,
    policies: Optional[Policies] = None,
    leisure: Optional[Leisure] = None,
    travel: Optional[Travel] = None,
    events: Optional[Events] = None,
    config_filename: str = default_config_filename,
    record: "Record" = None,
    reset_infections=False,
):
    """
    Builds a Simulator with ``Simulator.from_file`` and then restores it to
    the state saved in the checkpoint at ``checkpoint_path``.

    Parameters are passed through unchanged to ``Simulator.from_file``
    (world, interaction, epidemiology, tracker, policies, leisure, travel,
    events, config_filename, record) and to
    ``restore_simulator_to_checkpoint`` (world, checkpoint_path, chunk_size,
    reset_infections); see those functions for details.

    Returns
    -------
    The restored simulator, ready to run from the day after the checkpoint.
    """
    simulator = Simulator.from_file(
        world=world,
        interaction=interaction,
        epidemiology=epidemiology,
        tracker=tracker,
        policies=policies,
        leisure=leisure,
        travel=travel,
        events=events,
        config_filename=config_filename,
        record=record,
    )
    return restore_simulator_to_checkpoint(
        world=world,
        checkpoint_path=checkpoint_path,
        chunk_size=chunk_size,
        simulator=simulator,
        reset_infections=reset_infections,
    )


import h5py
import numpy as np
from typing import List

from june.world import World
from june.geography import (
    City,
    Cities,
    CityStation,
    InterCityStation,
    Stations,
    ExternalCityStation,
    ExternalInterCityStation,
    ExternalCity,
)
from .utils import read_dataset
from june.groups import ExternalGroup
from june.groups.travel import (
    CityTransport,
    CityTransports,
    InterCityTransport,
    InterCityTransports,
)
from june.groups.group.make_subgroups import SubgroupParams

# sentinel stored in hdf5 integer datasets where the real value is None
nan_integer = -999
# variable-length dtypes for ragged per-city / per-station id and name arrays
int_vlen_type = h5py.vlen_dtype(np.dtype("int64"))
string_15_vlen_type = h5py.vlen_dtype(np.dtype("S15"))
string_30_vlen_type = h5py.vlen_dtype(np.dtype("S30"))


def save_cities_to_hdf5(cities: Cities, file_path: str):
    """
    Saves the cities to the hdf5 file at ``file_path`` under the "cities" group.

    For each city we store: id, name, coordinates, the names of its super
    areas, the id of the super area it sits in, the ids of its internal
    commuters, and the ids of its city / inter-city stations. Ragged
    per-city arrays are stored with variable-length dtypes; uniform ones as
    plain 2-D datasets.
    """
    n_cities = len(cities)
    with h5py.File(file_path, "a") as f:
        cities_dset = f.create_group("cities")
        ids = []
        city_super_area_list = []
        super_areas_list = []
        super_areas_list_lengths = []
        names = []
        internal_commuters_list = []
        internal_commuters_list_lengths = []
        city_stations_id_list = []
        city_station_ids_lengths = []
        inter_city_stations_id_list = []
        inter_city_station_ids_lengths = []
        coordinates = []
        for city in cities:
            ids.append(city.id)
            names.append(city.name.encode("ascii", "ignore"))
            internal_commuters = list(city.internal_commuter_ids)
            internal_commuters_list.append(np.array(internal_commuters, dtype=np.int64))
            internal_commuters_list_lengths.append(len(internal_commuters))
            # NOTE(review): encoded as S20 here but stored as S15 in the ragged
            # branch below — names longer than 15 chars would be truncated in
            # that case; confirm super area name lengths.
            super_areas = np.array(
                [
                    super_area.encode("ascii", "ignore")
                    for super_area in city.super_areas
                ],
                dtype="S20",
            )
            super_areas_list.append(super_areas)
            super_areas_list_lengths.append(len(super_areas))
            coordinates.append(np.array(city.coordinates, dtype=np.float64))
            if city.super_area is None:
                # hdf5 cannot store None; use the integer sentinel
                city_super_area_list.append(nan_integer)
            else:
                city_super_area_list.append(city.super_area.id)
            # stations
            city_stations_ids = np.array(
                [station.id for station in city.city_stations], dtype=np.int64
            )
            inter_city_stations_ids = np.array(
                [station.id for station in city.inter_city_stations], dtype=np.int64
            )
            city_station_ids_lengths.append(len(city_stations_ids))
            inter_city_station_ids_lengths.append(len(inter_city_stations_ids))
            city_stations_id_list.append(city_stations_ids)
            inter_city_stations_id_list.append(inter_city_stations_ids)

        ids = np.array(ids, dtype=np.int64)
        names = np.array(names, dtype="S30")
        # uniform-length lists can be stored as plain 2-D datasets; ragged ones
        # need a variable-length dtype
        if len(np.unique(super_areas_list_lengths)) == 1:
            super_areas_list = np.array(super_areas_list, dtype="S15")
        else:
            super_areas_list = np.array(super_areas_list, dtype=string_15_vlen_type)
        if len(np.unique(city_station_ids_lengths)) == 1:
            city_stations_id_list = np.array(city_stations_id_list, dtype=np.int64)
        else:
            city_stations_id_list = np.array(city_stations_id_list, dtype=int_vlen_type)
        # fix: this previously tested city_station_ids_lengths (copy-paste),
        # which could choose the wrong dtype for the inter-city station dataset
        if len(np.unique(inter_city_station_ids_lengths)) == 1:
            inter_city_stations_id_list = np.array(
                inter_city_stations_id_list, dtype=np.int64
            )
        else:
            inter_city_stations_id_list = np.array(
                inter_city_stations_id_list, dtype=int_vlen_type
            )
        if len(np.unique(internal_commuters_list_lengths)) == 1:
            internal_commuters_list = np.array(internal_commuters_list, dtype=np.int64)
        else:
            internal_commuters_list = np.array(
                internal_commuters_list, dtype=int_vlen_type
            )
        city_super_area_list = np.array(city_super_area_list, dtype=np.int64)

        cities_dset.attrs["n_cities"] = n_cities
        cities_dset.create_dataset("id", data=ids)
        cities_dset.create_dataset("name", data=names)
        cities_dset.create_dataset("coordinates", data=coordinates)
        cities_dset.create_dataset("super_areas", data=super_areas_list)
        cities_dset.create_dataset("city_super_area", data=city_super_area_list)
        cities_dset.create_dataset("internal_commuters", data=internal_commuters_list)
        # stations
        cities_dset.create_dataset("city_station_id", data=city_stations_id_list)
        cities_dset.create_dataset(
            "inter_city_station_id", data=inter_city_stations_id_list
        )


def load_cities_from_hdf5(
    file_path: str,
    domain_super_areas: List[int] = None,
    super_areas_to_domain_dict: dict = None,
):
    """
    Loads cities from an hdf5 file located at ``file_path``.
    The returned cities are not ready to use yet: links to instances of
    other classes still have to be restored first.
    This function should be rarely be called oustide world.py
    """
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        cities_dset = f["cities"]
        n_cities = cities_dset.attrs["n_cities"]
        ids = read_dataset(cities_dset["id"])
        names = read_dataset(cities_dset["name"])
        coordinates = read_dataset(cities_dset["coordinates"])
        super_areas_list = read_dataset(cities_dset["super_areas"])
        city_super_areas = read_dataset(cities_dset["city_super_area"])
        city_list = []
        for k in range(n_cities):
            decoded_super_areas = [sa.decode() for sa in super_areas_list[k]]
            city_super_area = city_super_areas[k]
            in_domain = (
                domain_super_areas is None or city_super_area in domain_super_areas
            )
            if in_domain:
                city = City(
                    name=names[k].decode(),
                    super_areas=decoded_super_areas,
                    coordinates=coordinates[k],
                )
                city.id = ids[k]
            else:
                # the city lives in another domain; keep a lightweight proxy
                city = ExternalCity(
                    id=ids[k],
                    domain_id=super_areas_to_domain_dict[city_super_area],
                    commuter_ids=None,
                    name=names[k].decode(),
                )
            city_list.append(city)
    return Cities(city_list, ball_tree=False)


def save_stations_to_hdf5(stations: Stations, file_path: str):
    """
    Saves the stations (with their city, type, super area, commuter ids and
    transport ids) to the hdf5 file at ``file_path`` under "stations".
    """
    with h5py.File(file_path, "a") as f:
        stations_dset = f.create_group("stations")
        stations_dset.attrs["n_stations"] = len(stations)
        ids = []
        cities_of_stations = []
        types_of_stations = []
        super_area_ids = []
        commuters_per_station = []
        transports_per_station = []
        transports_lengths = []
        for station in stations:
            # a CityStation carries city transports; anything else is treated
            # as an inter-city station with inter-city transports
            if isinstance(station, CityStation):
                types_of_stations.append("city".encode("ascii", "ignore"))
                transports = station.city_transports
            else:
                types_of_stations.append("inter".encode("ascii", "ignore"))
                transports = station.inter_city_transports
            ids.append(station.id)
            super_area_ids.append(station.super_area.id)
            commuters_per_station.append(
                np.array(list(station.commuter_ids), dtype=np.int64)
            )
            transport_ids = [transport.id for transport in transports]
            transports_per_station.append(np.array(transport_ids, dtype=np.int64))
            transports_lengths.append(len(transport_ids))
            cities_of_stations.append(station.city.encode("ascii", "ignore"))
        ids = np.array(ids, dtype=np.int64)
        super_area_ids = np.array(super_area_ids, dtype=np.int64)
        commuters_per_station = np.array(commuters_per_station, dtype=int_vlen_type)
        cities_of_stations = np.array(cities_of_stations, dtype="S30")
        types_of_stations = np.array(types_of_stations, dtype="S10")
        # uniform transport counts become a plain 2-D dataset, ragged ones a
        # variable-length dataset
        if len(np.unique(transports_lengths)) == 1:
            transports_per_station = np.array(transports_per_station, dtype=np.int64)
        else:
            transports_per_station = np.array(
                transports_per_station, dtype=int_vlen_type
            )
        stations_dset.create_dataset("id", data=ids)
        stations_dset.create_dataset("super_area", data=super_area_ids)
        stations_dset.create_dataset("commuters", data=commuters_per_station)
        stations_dset.create_dataset("transport_ids", data=transports_per_station)
        stations_dset.create_dataset("station_cities", data=cities_of_stations)
        stations_dset.create_dataset("type", data=types_of_stations)


def load_stations_from_hdf5(
    file_path: str,
    domain_super_areas: List[int] = None,
    super_areas_to_domain_dict: dict = None,
    config_filename=None,
):
    """
    Loads stations (and their transports) from an hdf5 file located at
    ``file_path``. Stations inside ``domain_super_areas`` become real
    CityStation/InterCityStation objects with real transports; stations
    outside become external proxies with ExternalGroup transports.
    Note that this object will not be ready to use, as the links to
    object instances of other classes need to be restored first.
    This function should be rarely be called oustide world.py
    """

    InterCityTransport_Class = InterCityTransport
    InterCityTransport_Class.subgroup_params = SubgroupParams.from_file(
        config_filename=config_filename
    )

    CityTransport_Class = CityTransport
    CityTransport_Class.subgroup_params = SubgroupParams.from_file(
        config_filename=config_filename
    )

    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        stations = f["stations"]
        n_stations = stations.attrs["n_stations"]
        ids = read_dataset(stations["id"])
        if len(stations["transport_ids"].shape) == 1:
            # 1-D dataset: variable-length per-station transport id arrays
            transport_ids = read_dataset(stations["transport_ids"])
        else:
            # NOTE(review): a 2-D dataset (written when every station had the
            # same number of transports) is replaced by empty lists here, so
            # those transport ids are not reconstructed — confirm intended.
            transport_ids = [[] for _ in range(stations["transport_ids"].len())]
        cities = read_dataset(stations["station_cities"])
        super_areas = read_dataset(stations["super_area"])
        types = read_dataset(stations["type"])
        stations = []
        inter_city_transports = []
        city_transports = []
        for k in range(n_stations):
            super_area = super_areas[k]
            transports_station = []
            station_type = types[k].decode()
            city = cities[k].decode()
            if domain_super_areas is None or super_area in domain_super_areas:
                # station belongs to this domain: build real objects
                if station_type == "inter":
                    station = InterCityStation(city=city)
                else:
                    station = CityStation(city=city)
                station.id = ids[k]
                for transport_id in transport_ids[k]:
                    if station_type == "inter":
                        transport = InterCityTransport_Class(station=station)
                    else:
                        transport = CityTransport_Class(station=station)
                    transport.id = transport_id
                    transports_station.append(transport)
            else:
                # station belongs to another domain: build external proxies
                # that only carry ids and the owning domain
                if station_type == "inter":
                    station = ExternalInterCityStation(
                        id=ids[k],
                        domain_id=super_areas_to_domain_dict[super_area],
                        city=city,
                    )
                else:
                    station = ExternalCityStation(
                        id=ids[k],
                        domain_id=super_areas_to_domain_dict[super_area],
                        city=city,
                    )
                for transport_id in transport_ids[k]:
                    if station_type == "inter":
                        transport = ExternalGroup(
                            domain_id=super_areas_to_domain_dict[super_area],
                            spec="inter_city_transport",
                            id=transport_id,
                        )
                    else:
                        transport = ExternalGroup(
                            domain_id=super_areas_to_domain_dict[super_area],
                            spec="city_transport",
                            id=transport_id,
                        )
                    transports_station.append(transport)
            # attach transports to the station and accumulate them by kind
            if station_type == "inter":
                station.inter_city_transports = transports_station
                inter_city_transports += transports_station
            else:
                station.city_transports = transports_station
                city_transports += transports_station
            stations.append(station)
    return (
        Stations(stations),
        InterCityTransports(inter_city_transports),
        CityTransports(city_transports),
    )


def restore_cities_and_stations_properties_from_hdf5(
    world: World,
    file_path: str,
    chunk_size: int,
    domain_super_areas: List[int] = None,
    super_areas_to_domain_dict: dict = None,
):
    """
    Restores the links between the cities, stations and super areas already
    loaded into ``world`` from the hdf5 file at ``file_path``: station
    commuters and super areas, city stations and commuters, and each super
    area's city and closest inter-city stations.
    """
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        # load cities data
        cities = f["cities"]
        n_cities = cities.attrs["n_cities"]
        city_ids = read_dataset(cities["id"])
        city_city_station_ids = read_dataset(cities["city_station_id"])
        city_inter_city_station_ids = read_dataset(cities["inter_city_station_id"])
        city_internal_commuters_list = read_dataset(cities["internal_commuters"])
        city_super_areas = read_dataset(cities["city_super_area"])
        # load stations data
        stations = f["stations"]
        n_stations = stations.attrs["n_stations"]
        station_ids = read_dataset(stations["id"])
        station_super_areas = read_dataset(stations["super_area"])
        if len(stations["commuters"].shape) == 1:
            # 1-D dataset: variable-length per-station commuter id arrays
            station_commuters_list = read_dataset(stations["commuters"])
        else:
            # NOTE(review): a 2-D dataset is replaced by empty lists, so those
            # commuter ids would not be restored — confirm this is intended
            # (commuters are always saved with a vlen dtype, so the dataset
            # should normally be 1-D).
            station_commuters_list = [[] for _ in range(stations["commuters"].len())]
        # restore station commuters and (for in-domain stations) super areas
        for k in range(n_stations):
            station_id = station_ids[k]
            station = world.stations.get_from_id(station_id)
            station.commuter_ids = set([c_id for c_id in station_commuters_list[k]])
            station_super_area = station_super_areas[k]
            if domain_super_areas is None or station_super_area in domain_super_areas:
                station.super_area = world.super_areas.get_from_id(
                    station_super_areas[k]
                )

        # restore city commuters, station lists, and super area back-links
        for k in range(n_cities):
            city_id = city_ids[k]
            city_super_area = city_super_areas[k]
            city = world.cities.get_from_id(city_id)
            commuters = set(
                [commuter_id for commuter_id in city_internal_commuters_list[k]]
            )
            city.internal_commuter_ids = commuters
            city.city_stations = []
            city.inter_city_stations = []
            for station_id in city_city_station_ids[k]:
                station = world.stations.get_from_id(station_id)
                city.city_stations.append(station)
            for station_id in city_inter_city_station_ids[k]:
                station = world.stations.get_from_id(station_id)
                city.inter_city_stations.append(station)
            if domain_super_areas is None or city_super_area in domain_super_areas:
                city_super_area_instance = world.super_areas.get_from_id(
                    city_super_area
                )
                city.super_area = city_super_area_instance
                city_super_area_instance.city = city
        # super areas info
        geography = f["geography"]
        n_super_areas = geography.attrs["n_super_areas"]
        n_chunks = int(np.ceil(n_super_areas / chunk_size))
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_super_areas)
            length = idx2 - idx1
            super_area_ids = read_dataset(geography["super_area_id"], idx1, idx2)
            super_area_city = read_dataset(geography["super_area_city"], idx1, idx2)
            super_area_closest_stations_cities = read_dataset(
                geography["super_area_closest_stations_cities"], idx1, idx2
            )
            super_area_closest_stations_stations = read_dataset(
                geography["super_area_closest_stations_stations"], idx1, idx2
            )
            # load closest station
            for k in range(length):
                super_area_id = super_area_ids[k]
                if domain_super_areas is not None:
                    if super_area_id == nan_integer:
                        raise ValueError(
                            "if ``domain_super_areas`` is True, I expect not Nones super areas."
                        )
                    if super_area_id not in domain_super_areas:
                        continue
                super_area = world.super_areas.get_from_id(super_area_id)
                if super_area_city[k] == nan_integer:
                    # the sentinel means this super area has no city
                    super_area.city = None
                else:
                    super_area.city = world.cities.get_from_id(super_area_city[k])
                # map each city name to this super area's closest inter-city
                # station for that city
                for city, station in zip(
                    super_area_closest_stations_cities[k],
                    super_area_closest_stations_stations[k],
                ):
                    super_area.closest_inter_city_station_for_city[
                        city.decode()
                    ] = world.stations.get_from_id(station)


import h5py
import numpy as np
import logging

from june.groups import Company, Companies
from june.world import World
from june.groups.group.make_subgroups import SubgroupParams
from june.mpi_setup import mpi_rank
from .utils import read_dataset

# Sentinel stored in integer hdf5 datasets to encode ``None`` values,
# since integer columns cannot hold nulls directly.
nan_integer = -999

logger = logging.getLogger("company_saver")
# Under MPI, only rank 0 should emit log records to avoid duplicated output.
if mpi_rank > 0:
    logger.propagate = False


def save_companies_to_hdf5(
    companies: Companies, file_path: str, chunk_size: int = 500000
):
    """
    Saves the Companies object to hdf5 format file ``file_path``. Currently for
    each company, the following values are stored:
    - id, super_area, sector, n_workers_max

    Parameters
    ----------
    companies
        the Companies supergroup to save
    file_path
        path of the saved hdf5 file
    chunk_size
        number of companies to save at a time. Note that they have to be copied
        to be saved, so keep the number below 1e6.
    """
    n_companies = len(companies)
    n_chunks = int(np.ceil(n_companies / chunk_size))
    with h5py.File(file_path, "a") as f:
        companies_dset = f.create_group("companies")
        # Write the count up-front so the file is loadable even when there
        # are no companies (previously the attribute was only written
        # together with the first chunk).
        companies_dset.attrs["n_companies"] = n_companies
        if n_companies == 0:
            # Nothing to store; indexing companies[0] below would raise.
            return
        # Companies are assumed to be stored contiguously by id, so the
        # position of a company in the supergroup is (id - first id).
        first_company_idx = companies[0].id
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_companies)
            ids = []
            super_areas = []
            sectors = []
            n_workers_max = []
            company_idx = [company.id for company in companies[idx1:idx2]]
            # sort companies by id so the datasets are monotonically ordered
            companies_sorted = [
                companies[i - first_company_idx] for i in np.sort(company_idx)
            ]
            for company in companies_sorted:
                ids.append(company.id)
                if company.super_area is None:
                    # encode a missing super area with the integer sentinel
                    super_areas.append(nan_integer)
                else:
                    super_areas.append(company.super_area.id)
                sectors.append(company.sector.encode("ascii", "ignore"))
                n_workers_max.append(company.n_workers_max)

            ids = np.array(ids, dtype=np.int64)
            super_areas = np.array(super_areas, dtype=np.int64)
            sectors = np.array(sectors, dtype="S10")
            n_workers_max = np.array(n_workers_max, dtype=np.float64)
            if chunk == 0:
                # first chunk creates the resizable datasets
                companies_dset.create_dataset("id", data=ids, maxshape=(None,))
                companies_dset.create_dataset(
                    "super_area", data=super_areas, maxshape=(None,)
                )
                companies_dset.create_dataset("sector", data=sectors, maxshape=(None,))
                companies_dset.create_dataset(
                    "n_workers_max", data=n_workers_max, maxshape=(None,)
                )
            else:
                # later chunks grow the datasets and fill the new slice
                newshape = (companies_dset["id"].shape[0] + ids.shape[0],)
                companies_dset["id"].resize(newshape)
                companies_dset["id"][idx1:idx2] = ids
                companies_dset["super_area"].resize(newshape)
                companies_dset["super_area"][idx1:idx2] = super_areas
                companies_dset["sector"].resize(newshape)
                companies_dset["sector"][idx1:idx2] = sectors
                companies_dset["n_workers_max"].resize(newshape)
                companies_dset["n_workers_max"][idx1:idx2] = n_workers_max


def load_companies_from_hdf5(
    file_path: str, chunk_size=50000, domain_super_areas=None, config_filename=None
):
    """
    Loads companies from an hdf5 file located at ``file_path``.

    Note that the returned object is not ready to use yet, as the links to
    instances of other classes still need to be restored.
    This function should rarely be called outside world.py.
    """
    # make subgroup parameters available to every Company instance
    Company.subgroup_params = SubgroupParams.from_file(
        config_filename=config_filename
    )

    logger.info("loading companies...")
    companies_list = []
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        companies = f["companies"]
        n_companies = companies.attrs["n_companies"]
        n_chunks = int(np.ceil(n_companies / chunk_size))
        for chunk in range(n_chunks):
            logger.info(f"Companies chunk {chunk} of {n_chunks}")
            start = chunk * chunk_size
            stop = min(start + chunk_size, n_companies)
            ids = read_dataset(companies["id"], start, stop)
            sectors = read_dataset(companies["sector"], start, stop)
            n_workers_maxs = read_dataset(companies["n_workers_max"], start, stop)
            super_areas = read_dataset(companies["super_area"], start, stop)
            for k in range(stop - start):
                if domain_super_areas is not None:
                    super_area = super_areas[k]
                    if super_area == nan_integer:
                        raise ValueError(
                            "if ``domain_super_areas`` is True, I expect not Nones super areas."
                        )
                    # companies outside this domain are not instantiated
                    if super_area not in domain_super_areas:
                        continue
                company = Company(
                    super_area=None,
                    n_workers_max=n_workers_maxs[k],
                    sector=sectors[k].decode(),
                )
                company.id = ids[k]
                companies_list.append(company)
    return Companies(companies_list)


def restore_companies_properties_from_hdf5(
    world: World, file_path: str, chunk_size, domain_super_areas=None
):
    """
    Relinks each company in ``world`` to its super area instance, reading
    the saved ids back from ``file_path`` in chunks of ``chunk_size``.
    """
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        companies = f["companies"]
        n_companies = companies.attrs["n_companies"]
        n_chunks = int(np.ceil(n_companies / chunk_size))
        for chunk in range(n_chunks):
            start = chunk * chunk_size
            stop = min(start + chunk_size, n_companies)
            ids = read_dataset(companies["id"], start, stop)
            super_area_ids = read_dataset(companies["super_area"], start, stop)
            for company_id, super_area_id in zip(ids, super_area_ids):
                if domain_super_areas is not None:
                    if super_area_id == nan_integer:
                        raise ValueError(
                            "if ``domain_super_areas`` is True, I expect not Nones super areas."
                        )
                    # skip companies that belong to another domain
                    if super_area_id not in domain_super_areas:
                        continue
                company = world.companies.get_from_id(company_id)
                if super_area_id == nan_integer:
                    company.super_area = None
                else:
                    company.super_area = world.super_areas.get_from_id(super_area_id)


import h5py
import numpy as np
from collections import defaultdict
from june.world import World

from .utils import read_dataset


def get_commuters_per_super_area(world: World):
    """
    Returns a mapping from super area name to the number of commuters
    attached to stations located in that super area.

    Commuters registered directly at a station are counted at the station's
    super area; each city's internal commuters are split evenly across that
    city's stations.

    Parameters
    ----------
    world
        world with (possibly empty) ``stations`` and ``cities`` supergroups
    """
    ret = defaultdict(int)
    if world.stations:
        for station in world.stations:
            ret[station.super_area.name] += len(station.commuter_ids)
    if world.cities:
        for city in world.cities:
            n_internal_commuters = len(city.internal_commuter_ids)
            if n_internal_commuters == 0:
                continue
            city_stations = city.city_stations
            n_stations = len(city_stations)
            if n_stations == 0:
                # A city with commuters but no stations has nowhere to
                # assign them; skip it instead of dividing by zero.
                continue
            n_commuters_per_station = n_internal_commuters / n_stations
            for station in city_stations:
                ret[station.super_area.name] += n_commuters_per_station
    return ret


def save_data_for_domain_decomposition(world: World, file_path: str):
    """
    Saves data required to generate a domain decomposition. For each super area,
    we save:
        - Population
        - Number of workers
        - Number of pupils
        - Number of commuters
    """
    names = []
    populations = []
    pupils = []
    workers = []
    commuters = []
    commuter_counts = get_commuters_per_super_area(world)
    for super_area in world.super_areas:
        names.append(super_area.name.encode("ascii", "ignore"))
        populations.append(len(super_area.people))
        workers.append(len(super_area.workers))
        # pupils are counted over every school of every area in the super area
        n_pupils = sum(
            len(school.people)
            for area in super_area.areas
            for school in area.schools
        )
        pupils.append(n_pupils)
        commuters.append(commuter_counts[super_area.name])
    with h5py.File(file_path, "a") as f:
        group = f.create_group("domain_decomposition_data")
        group.create_dataset(
            "super_area_names", data=np.array(names, dtype="S20")
        )
        group.create_dataset(
            "super_area_population", data=np.array(populations, dtype=np.int64)
        )
        group.create_dataset(
            "super_area_workers", data=np.array(workers, dtype=np.int64)
        )
        group.create_dataset(
            "super_area_pupils", data=np.array(pupils, dtype=np.int64)
        )
        group.create_dataset(
            "super_area_commuters", data=np.array(commuters, dtype=np.int64)
        )


def load_data_for_domain_decomposition(file_path: str):
    """
    Reads the saved data for the domain decomposition.
    See the docs of read_data_for_domain_decomposition for more information.
    """
    ret = {}
    with h5py.File(file_path, "r") as f:
        data = f["domain_decomposition_data"]
        names = read_dataset(data["super_area_names"])
        populations = read_dataset(data["super_area_population"])
        workers = read_dataset(data["super_area_workers"])
        pupils = read_dataset(data["super_area_pupils"])
        commuters = read_dataset(data["super_area_commuters"])
        for name, n_people, n_workers, n_pupils, n_commuters in zip(
            names, populations, workers, pupils, commuters
        ):
            # one entry per super area, keyed by its decoded name
            ret[name.decode()] = {
                "n_people": n_people,
                "n_workers": n_workers,
                "n_pupils": n_pupils,
                "n_commuters": n_commuters,
            }
    return ret


import h5py
import numpy as np
from collections import defaultdict

from june.groups import ExternalGroup
from june.geography import (
    Geography,
    Area,
    SuperArea,
    Areas,
    SuperAreas,
    Region,
    Regions,
)
from .utils import read_dataset
from june.world import World

# Sentinel stored in integer hdf5 datasets to encode ``None`` values.
nan_integer = -999
# Variable-length hdf5 dtypes used when per-area lists have ragged lengths.
int_vlen_type = h5py.vlen_dtype(np.dtype("int64"))
str_vlen_type = h5py.vlen_dtype(np.dtype("S40"))

# Maps a social venue spec name (as stored in the hdf5 file) to the name of
# the World attribute holding the corresponding supergroup.
spec_to_supergroup_mapper = {
    "pub": "pubs",
    "cinema": "cinemas",
    "grocery": "groceries",
    "gym": "gyms",
}


def save_geography_to_hdf5(geography: Geography, file_path: str):
    """
    Saves the geography object to hdf5 format file ``file_path``.

    For each area we store: id, name, super area, coordinates,
    socioeconomic index and social venues. For each super area: id, name,
    region, city, closest stations, closest hospitals, coordinates and
    population counts. For each region: id and name.

    Parameters
    ----------
    geography
        the geography object to save
    file_path
        path of the saved hdf5 file
    """
    n_areas = len(geography.areas)
    area_ids = []
    area_names = []
    area_super_areas = []
    area_coordinates = []
    area_socioeconomic_indices = []
    n_super_areas = len(geography.super_areas)
    super_area_ids = []
    super_area_names = []
    super_area_coordinates = []
    super_area_regions = []
    closest_hospitals_ids = []
    closest_hospitals_super_areas = []
    hospital_lengths = []
    social_venues_specs_list = []
    social_venues_ids_list = []
    social_venues_super_areas = []
    social_venues_lengths = []
    super_area_city = []
    super_area_closest_stations_cities = []
    super_area_closest_stations_stations = []
    super_area_closest_stations_lengths = []
    super_area_n_people = []
    super_area_n_workers = []
    super_area_n_pupils = []
    n_regions = len(geography.regions)
    region_ids = []
    region_names = []

    for area in geography.areas:
        area_ids.append(area.id)
        area_super_areas.append(area.super_area.id)
        area_names.append(area.name.encode("ascii", "ignore"))
        area_coordinates.append(np.array(area.coordinates, dtype=np.float64))
        area_socioeconomic_indices.append(area.socioeconomic_index)
        # flatten this area's social venues into parallel spec/id/super-area lists
        social_venues_ids = []
        social_venues_specs = []
        social_venues_sas = []
        for spec in area.social_venues.keys():
            for social_venue in area.social_venues[spec]:
                social_venues_specs.append(spec.encode("ascii", "ignore"))
                social_venues_ids.append(social_venue.id)
                social_venues_sas.append(social_venue.super_area.id)
        social_venues_specs_list.append(np.array(social_venues_specs, dtype="S20"))
        social_venues_ids_list.append(np.array(social_venues_ids, dtype=np.int64))
        social_venues_super_areas.append(np.array(social_venues_sas, dtype=np.int64))
        social_venues_lengths.append(len(social_venues_specs))
    # use fixed-size 2d arrays when all areas have the same number of venues,
    # otherwise fall back to variable-length hdf5 dtypes
    if len(np.unique(social_venues_lengths)) == 1:
        social_venues_specs_list = np.array(social_venues_specs_list, dtype="S20")
        social_venues_ids_list = np.array(social_venues_ids_list, dtype=np.int64)
        social_venues_super_areas = np.array(social_venues_super_areas, dtype=np.int64)
    else:
        social_venues_specs_list = np.array(
            social_venues_specs_list, dtype=str_vlen_type
        )
        social_venues_ids_list = np.array(social_venues_ids_list, dtype=int_vlen_type)
        social_venues_super_areas = np.array(
            social_venues_super_areas, dtype=int_vlen_type
        )

    for super_area in geography.super_areas:
        super_area_ids.append(super_area.id)
        super_area_names.append(super_area.name.encode("ascii", "ignore"))
        super_area_regions.append(super_area.region.id)
        super_area_coordinates.append(np.array(super_area.coordinates))
        super_area_n_people.append(len(super_area.people))
        super_area_n_workers.append(len(super_area.workers))
        super_area_n_pupils.append(
            sum(
                [
                    len(school.people)
                    for area in super_area.areas
                    for school in area.schools
                ]
            )
        )
        if super_area.closest_hospitals is None:
            # no hospital information: store a single sentinel entry
            closest_hospitals_ids.append(np.array([nan_integer], dtype=np.int64))
            closest_hospitals_super_areas.append(
                np.array([nan_integer], dtype=np.int64)
            )
            hospital_lengths.append(1)
        else:
            hospital_ids = np.array(
                [hospital.id for hospital in super_area.closest_hospitals],
                dtype=np.int64,
            )
            hospital_sas = np.array(
                [hospital.super_area.id for hospital in super_area.closest_hospitals],
                dtype=np.int64,
            )
            closest_hospitals_ids.append(hospital_ids)
            closest_hospitals_super_areas.append(hospital_sas)
            hospital_lengths.append(len(hospital_ids))
        if super_area.city is None:
            super_area_city.append(nan_integer)
        else:
            super_area_city.append(super_area.city.id)
        cities = []
        stations = []
        for key, value in super_area.closest_inter_city_station_for_city.items():
            cities.append(key.encode("ascii", "ignore"))
            stations.append(value.id)
        super_area_closest_stations_cities.append(cities)
        super_area_closest_stations_stations.append(stations)
        super_area_closest_stations_lengths.append(
            len(super_area.closest_inter_city_station_for_city)
        )

    # BUGFIX: this loop used to be nested inside the super-area loop above,
    # appending the full region list once per super area and so saving
    # duplicated region data. Regions are collected exactly once here.
    for region in geography.regions:
        region_ids.append(region.id)
        region_names.append(region.name)

    area_ids = np.array(area_ids, dtype=np.int64)
    area_names = np.array(area_names, dtype="S20")
    area_super_areas = np.array(area_super_areas, dtype=np.int64)
    area_coordinates = np.array(area_coordinates, dtype=np.float64)
    area_socioeconomic_indices = np.array(area_socioeconomic_indices, dtype=np.float64)
    super_area_ids = np.array(super_area_ids, dtype=np.int64)
    super_area_names = np.array(super_area_names, dtype="S20")
    super_area_coordinates = np.array(super_area_coordinates, dtype=np.float64)
    super_area_regions = np.array(super_area_regions, dtype=np.int64)
    super_area_n_people = np.array(super_area_n_people, dtype=np.int64)
    super_area_n_workers = np.array(super_area_n_workers, dtype=np.int64)
    super_area_n_pupils = np.array(super_area_n_pupils, dtype=np.int64)
    region_ids = np.array(region_ids, dtype=np.int64)
    region_names = np.array(region_names, dtype="S50")
    if len(np.unique(hospital_lengths)) == 1:
        closest_hospitals_ids = np.array(closest_hospitals_ids, dtype=np.int64)
        closest_hospitals_super_areas = np.array(
            closest_hospitals_super_areas, dtype=np.int64
        )
    else:
        closest_hospitals_ids = np.array(closest_hospitals_ids, dtype=int_vlen_type)
        closest_hospitals_super_areas = np.array(
            closest_hospitals_super_areas, dtype=int_vlen_type
        )
    super_area_city = np.array(super_area_city, dtype=np.int64)
    if len(np.unique(super_area_closest_stations_lengths)) == 1:
        super_area_closest_stations_cities = np.array(
            super_area_closest_stations_cities, dtype="S40"
        )
        super_area_closest_stations_stations = np.array(
            super_area_closest_stations_stations, dtype=np.int64
        )
    else:
        super_area_closest_stations_cities = np.array(
            super_area_closest_stations_cities, dtype=str_vlen_type
        )
        super_area_closest_stations_stations = np.array(
            super_area_closest_stations_stations, dtype=int_vlen_type
        )

    with h5py.File(file_path, "a") as f:
        geography_dset = f.create_group("geography")
        geography_dset.attrs["n_areas"] = n_areas
        geography_dset.attrs["n_super_areas"] = n_super_areas
        geography_dset.attrs["n_regions"] = n_regions
        geography_dset.create_dataset("area_id", data=area_ids)
        geography_dset.create_dataset("area_name", data=area_names)
        geography_dset.create_dataset("area_super_area", data=area_super_areas)
        geography_dset.create_dataset("area_coordinates", data=area_coordinates)
        geography_dset.create_dataset(
            "area_socioeconomic_indices", data=area_socioeconomic_indices
        )
        geography_dset.create_dataset("super_area_id", data=super_area_ids)
        geography_dset.create_dataset("super_area_name", data=super_area_names)
        geography_dset.create_dataset("super_area_region", data=super_area_regions)
        geography_dset.create_dataset("super_area_city", data=super_area_city)
        geography_dset.create_dataset("super_area_n_people", data=super_area_n_people)
        geography_dset.create_dataset("super_area_n_workers", data=super_area_n_workers)
        geography_dset.create_dataset("super_area_n_pupils", data=super_area_n_pupils)
        geography_dset.create_dataset(
            "super_area_closest_stations_cities",
            data=super_area_closest_stations_cities,
        )
        geography_dset.create_dataset(
            "super_area_closest_stations_stations",
            data=super_area_closest_stations_stations,
        )
        geography_dset.create_dataset(
            "super_area_coordinates", data=super_area_coordinates
        )
        geography_dset.create_dataset(
            "closest_hospitals_ids", data=closest_hospitals_ids
        )
        geography_dset.create_dataset(
            "closest_hospitals_super_areas", data=closest_hospitals_super_areas
        )
        geography_dset.create_dataset("region_id", data=region_ids)
        geography_dset.create_dataset("region_name", data=region_names)
        # BUGFIX: this used to check ``social_venues_specs``/``social_venues_ids``,
        # which are the inner loop variables of the *last* area processed —
        # raising NameError when there are no areas and skipping the write
        # whenever only the last area happened to have no venues. Write the
        # datasets whenever any area has at least one social venue.
        if any(social_venues_lengths):
            geography_dset.create_dataset(
                "social_venues_specs", data=social_venues_specs_list
            )
            geography_dset.create_dataset(
                "social_venues_ids", data=social_venues_ids_list
            )
            geography_dset.create_dataset(
                "social_venues_super_areas", data=social_venues_super_areas
            )


def load_geography_from_hdf5(file_path: str, chunk_size=50000, domain_super_areas=None):
    """
    Loads geography from an hdf5 file located at ``file_path``.

    Areas, super areas, and regions are read back in chunks of ``chunk_size``
    rows. If ``domain_super_areas`` is given, only areas/super areas belonging
    to those super area ids (and the regions containing them) are created.

    Note that this object will not be ready to use, as the links to
    object instances of other classes need to be restored first.
    This function should rarely be called outside world.py.
    """
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        geography = f["geography"]
        n_areas = geography.attrs["n_areas"]
        area_list = []
        n_super_areas = geography.attrs["n_super_areas"]
        n_regions = geography.attrs["n_regions"]
        # areas
        n_chunks = int(np.ceil(n_areas / chunk_size))
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_areas)
            length = idx2 - idx1
            area_ids = read_dataset(geography["area_id"], index1=idx1, index2=idx2)
            area_names = read_dataset(geography["area_name"], index1=idx1, index2=idx2)
            area_coordinates = read_dataset(geography["area_coordinates"], idx1, idx2)
            area_socioeconomic_indices = read_dataset(
                geography["area_socioeconomic_indices"], idx1, idx2
            )
            area_super_areas = read_dataset(geography["area_super_area"], idx1, idx2)
            for k in range(length):
                if domain_super_areas is not None:
                    super_area = area_super_areas[k]
                    if super_area == nan_integer:
                        raise ValueError(
                            "if ``domain_super_areas`` is True, I expect not Nones super areas."
                        )
                    # skip areas whose super area lies outside this domain
                    if super_area not in domain_super_areas:
                        continue
                area = Area(
                    name=area_names[k].decode(),
                    super_area=None,  # relinked later during property restore
                    coordinates=area_coordinates[k],
                    socioeconomic_index=area_socioeconomic_indices[k],
                )
                # keep the saved id instead of a freshly generated one
                area.id = area_ids[k]
                area_list.append(area)
        # super areas
        super_area_list = []
        # region ids referenced by the loaded super areas; used below to
        # load only the regions present in this domain
        domain_regions = set()
        n_chunks = int(np.ceil(n_super_areas / chunk_size))
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_super_areas)
            length = idx2 - idx1
            super_area_ids = read_dataset(geography["super_area_id"], idx1, idx2)
            super_area_names = read_dataset(geography["super_area_name"], idx1, idx2)
            super_area_regions = read_dataset(
                geography["super_area_region"], idx1, idx2
            )
            super_area_coordinates = read_dataset(
                geography["super_area_coordinates"], idx1, idx2
            )
            for k in range(idx2 - idx1):
                if domain_super_areas is not None:
                    super_area_id = super_area_ids[k]
                    if super_area_id == nan_integer:
                        raise ValueError(
                            "if ``domain_super_areas`` is True, I expect not Nones super areas."
                        )
                    if super_area_id not in domain_super_areas:
                        continue
                super_area = SuperArea(
                    name=super_area_names[k].decode(),
                    areas=None,  # relinked later during property restore
                    coordinates=super_area_coordinates[k],
                )
                super_area.id = super_area_ids[k]
                super_area_list.append(super_area)
                domain_regions.add(super_area_regions[k])
        # regions
        region_list = []
        n_chunks = int(np.ceil(n_regions / chunk_size))
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_regions)
            length = idx2 - idx1
            region_ids = read_dataset(geography["region_id"], idx1, idx2)
            region_names = read_dataset(geography["region_name"], idx1, idx2)
            for k in range(idx2 - idx1):
                if domain_super_areas is not None:
                    region_id = region_ids[k]
                    if region_id == nan_integer:
                        raise ValueError(
                            "if ``domain_super_areas`` is True, I expect not Nones regions."
                        )
                    # only keep regions referenced by the loaded super areas
                    if region_id not in domain_regions:
                        continue
                region = Region(name=region_names[k].decode(), super_areas=None)
                region.id = region_ids[k]
                region_list.append(region)

    areas = Areas(area_list)
    super_areas = SuperAreas(super_area_list)
    regions = Regions(region_list)
    return Geography(areas, super_areas, regions)


def restore_geography_properties_from_hdf5(
    world: World,
    file_path: str,
    chunk_size,
    domain_super_areas=None,
    super_areas_to_domain_dict: dict = None,
):
    """
    Long function to restore geographic attributes to the world's geography.
    The closest hospitals, commuting cities, stations, and social venues are restored
    to areas and super areas. For the cases that the super areas would be outside the
    simulated domain, the instances of cities,stations, etc. are substituted by
    external groups, which point to the domain where they are at.

    Parameters
    ----------
    world
        partially loaded world whose areas/super areas get relinked
    file_path
        path of the hdf5 file to read from
    chunk_size
        number of rows to read per hdf5 access
    domain_super_areas
        if given, only rows belonging to these super area ids are restored
    super_areas_to_domain_dict
        maps a super area id to the id of the domain owning it; used to
        build ExternalGroup references for venues outside this domain
    """
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        geography = f["geography"]
        n_areas = geography.attrs["n_areas"]
        n_chunks = int(np.ceil(n_areas / chunk_size))
        # areas
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_areas)
            length = idx2 - idx1
            areas_ids = read_dataset(geography["area_id"], idx1, idx2)
            super_areas = read_dataset(geography["area_super_area"], idx1, idx2)
            # social venue datasets are optional (only written when at least
            # one area had venues at save time)
            if "social_venues_specs" in geography and "social_venues_ids" in geography:
                social_venues_specs = read_dataset(
                    geography["social_venues_specs"], idx1, idx2
                )
                social_venues_ids = read_dataset(
                    geography["social_venues_ids"], idx1, idx2
                )
                # TODO:
                social_venues_super_areas = read_dataset(
                    geography["social_venues_super_areas"], idx1, idx2
                )
            for k in range(length):
                if domain_super_areas is not None:
                    super_area_id = super_areas[k]
                    if super_area_id == nan_integer:
                        raise ValueError(
                            "if ``domain_super_areas`` is True, I expect not Nones super areas."
                        )
                    if super_area_id not in domain_super_areas:
                        continue
                # relink area <-> super area
                super_area = world.super_areas.get_from_id(super_areas[k])
                area = world.areas.get_from_id(areas_ids[k])
                area.super_area = super_area
                area.super_area.areas.append(area)
                # social venues: stored per spec as tuples of group instances
                area.social_venues = defaultdict(tuple)
                if (
                    "social_venues_specs" in geography
                    and "social_venues_ids" in geography
                ):
                    for group_spec, group_id, group_super_area in zip(
                        social_venues_specs[k],
                        social_venues_ids[k],
                        social_venues_super_areas[k],
                    ):
                        spec = group_spec.decode()
                        spec_mapped = spec_to_supergroup_mapper[spec]
                        supergroup = getattr(world, spec_mapped)
                        if (
                            domain_super_areas is not None
                            and group_super_area not in domain_super_areas
                        ):
                            # venue lives in another domain: stand in an
                            # ExternalGroup pointing at that domain
                            domain_of_group = super_areas_to_domain_dict[
                                group_super_area
                            ]
                            group = ExternalGroup(
                                id=group_id, domain_id=domain_of_group, spec=spec
                            )
                        else:
                            group = supergroup.get_from_id(group_id)
                        area.social_venues[spec] = (*area.social_venues[spec], group)
        n_super_areas = geography.attrs["n_super_areas"]
        n_chunks = int(np.ceil(n_super_areas / chunk_size))
        # super areas
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_super_areas)
            length = idx2 - idx1
            super_area_ids = read_dataset(
                geography["super_area_id"], index1=idx1, index2=idx2
            )
            regions = read_dataset(
                geography["super_area_region"], index1=idx1, index2=idx2
            )
            for k in range(length):
                if domain_super_areas is not None:
                    super_area_id = super_area_ids[k]
                    if super_area_id == nan_integer:
                        raise ValueError(
                            "if ``domain_super_areas`` is True, I expect not Nones super areas."
                        )
                    if super_area_id not in domain_super_areas:
                        continue
                # relink super area <-> region
                super_area = world.super_areas.get_from_id(super_area_ids[k])
                region = world.regions.get_from_id(regions[k])
                super_area.region = region
                region.super_areas.append(super_area)


import h5py
import numpy as np

from june.world import World
from june.groups import Hospital, Hospitals, ExternalHospital
from june.groups.group.make_subgroups import SubgroupParams
from .utils import read_dataset

nan_integer = -999


def save_hospitals_to_hdf5(
    hospitals: Hospitals, file_path: str, chunk_size: int = 50000
):
    """
    Saves the Hospitals object to hdf5 format file ``file_path``. Currently for each
    hospital, the following values are stored:
    - id, n_beds, n_icu_beds, area, super_area, region_name, trust_code, coordinates

    Parameters
    ----------
    hospitals
        Hospitals supergroup to save
    file_path
        path of the saved hdf5 file
    chunk_size
        number of hospitals to save at a time. Note that they have to be copied to be
        saved, so keep the number below 1e6.
    """
    n_hospitals = len(hospitals)
    n_chunks = int(np.ceil(n_hospitals / chunk_size))
    with h5py.File(file_path, "a") as f:
        hospitals_dset = f.create_group("hospitals")
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_hospitals)
            ids = []
            n_beds = []
            n_icu_beds = []
            areas = []
            super_areas = []
            region_names = []
            coordinates = []
            trust_code = []
            for hospital in hospitals[idx1:idx2]:
                ids.append(hospital.id)
                # hospitals without a geography link are stored with sentinel values
                if hospital.area is None:
                    areas.append(nan_integer)
                    super_areas.append(nan_integer)
                    region_names.append(nan_integer)
                else:
                    areas.append(hospital.area.id)
                    super_areas.append(hospital.super_area.id)
                    region_names.append(hospital.region_name)
                n_beds.append(hospital.n_beds)
                n_icu_beds.append(hospital.n_icu_beds)
                coordinates.append(np.array(hospital.coordinates))
                trust_code.append(hospital.trust_code)

            ids = np.array(ids, dtype=np.int64)
            areas = np.array(areas, dtype=np.int64)
            super_areas = np.array(super_areas, dtype=np.int64)
            region_names = np.array(region_names, dtype="S50")
            trust_code = np.array(trust_code, dtype="S10")
            n_beds = np.array(n_beds, dtype=np.int64)
            n_icu_beds = np.array(n_icu_beds, dtype=np.int64)
            coordinates = np.array(coordinates, dtype=np.float64)
            if chunk == 0:
                # first chunk creates the resizable datasets (maxshape=None)
                hospitals_dset.attrs["n_hospitals"] = n_hospitals
                hospitals_dset.create_dataset("id", data=ids, maxshape=(None,))
                hospitals_dset.create_dataset("area", data=areas, maxshape=(None,))
                hospitals_dset.create_dataset(
                    "super_area", data=super_areas, maxshape=(None,)
                )
                hospitals_dset.create_dataset(
                    "region_name", data=region_names, maxshape=(None,)
                )
                hospitals_dset.create_dataset(
                    "trust_code", data=trust_code, maxshape=(None,)
                )
                hospitals_dset.create_dataset("n_beds", data=n_beds, maxshape=(None,))
                hospitals_dset.create_dataset(
                    "n_icu_beds", data=n_icu_beds, maxshape=(None,)
                )
                hospitals_dset.create_dataset(
                    "coordinates",
                    data=coordinates,
                    maxshape=(None, coordinates.shape[1]),
                )
            else:
                # subsequent chunks grow each dataset and fill the new slice
                newshape = (hospitals_dset["id"].shape[0] + ids.shape[0],)
                hospitals_dset["id"].resize(newshape)
                hospitals_dset["id"][idx1:idx2] = ids
                hospitals_dset["area"].resize(newshape)
                hospitals_dset["area"][idx1:idx2] = areas
                hospitals_dset["super_area"].resize(newshape)
                hospitals_dset["super_area"][idx1:idx2] = super_areas
                hospitals_dset["region_name"].resize(newshape)
                hospitals_dset["region_name"][idx1:idx2] = region_names
                hospitals_dset["trust_code"].resize(newshape)
                hospitals_dset["trust_code"][idx1:idx2] = trust_code
                hospitals_dset["n_beds"].resize(newshape)
                hospitals_dset["n_beds"][idx1:idx2] = n_beds
                hospitals_dset["n_icu_beds"].resize(newshape)
                hospitals_dset["n_icu_beds"][idx1:idx2] = n_icu_beds
                hospitals_dset["coordinates"].resize(newshape[0], axis=0)
                hospitals_dset["coordinates"][idx1:idx2] = coordinates


def load_hospitals_from_hdf5(
    file_path: str,
    chunk_size=50000,
    domain_super_areas=None,
    super_areas_to_domain_dict: dict = None,
    config_filename=None,
):
    """
    Loads hospitals from an hdf5 file located at ``file_path``.
    Note that this object will not be ready to use, as the links to
    object instances of other classes need to be restored first.
    This function should be rarely be called oustide world.py

    Parameters
    ----------
    file_path
        path of the hdf5 file to read
    chunk_size
        number of hospitals to read at a time
    domain_super_areas
        if given, hospitals whose super area is not in this collection are
        loaded as lightweight ``ExternalHospital`` stubs instead of full
        ``Hospital`` instances
    super_areas_to_domain_dict
        maps super area id -> domain id; required to build external hospitals
    config_filename
        configuration used to set the hospital subgroup parameters
    """
    Hospital_Class = Hospital
    Hospital_Class.subgroup_params = SubgroupParams.from_file(
        config_filename=config_filename
    )
    ExternalHospital_Class = ExternalHospital

    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        hospitals = f["hospitals"]
        hospitals_list = []
        # bug fix: ``chunk_size`` used to be re-assigned to 50000 here,
        # silently ignoring the caller-supplied argument.
        n_hospitals = hospitals.attrs["n_hospitals"]
        n_chunks = int(np.ceil(n_hospitals / chunk_size))
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_hospitals)
            ids = read_dataset(hospitals["id"], idx1, idx2)
            n_beds_list = read_dataset(hospitals["n_beds"], idx1, idx2)
            n_icu_beds_list = read_dataset(hospitals["n_icu_beds"], idx1, idx2)
            trust_codes = read_dataset(hospitals["trust_code"], idx1, idx2)
            coordinates = read_dataset(hospitals["coordinates"], idx1, idx2)
            super_areas = read_dataset(hospitals["super_area"], idx1, idx2)
            region_name = read_dataset(hospitals["region_name"], idx1, idx2)
            for k in range(idx2 - idx1):
                super_area = super_areas[k]
                # NOTE(review): unlike the other loaders in this file, this raises
                # for a nan super area even when ``domain_super_areas`` is None;
                # kept as-is to preserve behavior — confirm whether intended.
                if super_area == nan_integer:
                    raise ValueError(
                        "if ``domain_super_areas`` is True, I expect not Nones super areas."
                    )
                # a single-space string marks a missing trust code in the file
                trust_code = trust_codes[k]
                if trust_code.decode() == " ":
                    trust_code = None
                else:
                    trust_code = trust_code.decode()

                if (
                    domain_super_areas is not None
                    and super_area not in domain_super_areas
                ):
                    # hospital belongs to another domain: keep only a stub that
                    # records its id, spec, owning domain and region
                    hospital = ExternalHospital_Class(
                        id=ids[k],
                        spec="hospital",
                        domain_id=super_areas_to_domain_dict[super_area],
                        region_name=region_name[k].decode(),
                    )
                else:
                    hospital = Hospital_Class(
                        n_beds=n_beds_list[k],
                        n_icu_beds=n_icu_beds_list[k],
                        coordinates=coordinates[k],
                        trust_code=trust_code,
                    )
                    hospital.id = ids[k]
                hospitals_list.append(hospital)
    # ball tree construction is deferred to the caller
    return Hospitals(hospitals_list, ball_tree=False)


def restore_hospital_properties_from_hdf5(
    world: World,
    file_path: str,
    chunk_size=50000,
    domain_super_areas=None,
    domain_areas=None,
    super_areas_to_domain_dict: dict = None,
):
    """
    Restores hospital geography links from the hdf5 file at ``file_path``:
    first each hospital's area, then each super area's list of closest
    hospitals. ``domain_areas`` / ``domain_super_areas`` restrict the work to
    the current domain when given. ``super_areas_to_domain_dict`` is accepted
    for signature symmetry with the other restore functions but not used here.
    """
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        hospitals = f["hospitals"]
        n_hospitals = hospitals.attrs["n_hospitals"]
        n_chunks = int(np.ceil(n_hospitals / chunk_size))
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_hospitals)
            length = idx2 - idx1
            ids = np.empty(length, dtype=int)
            hospitals["id"].read_direct(ids, np.s_[idx1:idx2], np.s_[0:length])
            areas = np.empty(length, dtype=int)
            hospitals["area"].read_direct(areas, np.s_[idx1:idx2], np.s_[0:length])
            for k in range(length):
                if domain_areas is not None:
                    area = areas[k]
                    if area == nan_integer:
                        raise ValueError(
                            "if ``domain_areas`` is True, I expect not Nones areas."
                        )
                    if area not in domain_areas:
                        continue
                hospital = world.hospitals.get_from_id(ids[k])
                area = areas[k]
                if area == nan_integer:
                    area = None
                else:
                    area = world.areas.get_from_id(area)
                hospital.area = area

        # super areas
        geography = f["geography"]
        n_super_areas = geography.attrs["n_super_areas"]
        n_chunks = int(np.ceil(n_super_areas / chunk_size))
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_super_areas)
            length = idx2 - idx1
            super_areas_ids = read_dataset(geography["super_area_id"], idx1, idx2)
            closest_hospitals_ids = read_dataset(
                geography["closest_hospitals_ids"], idx1, idx2
            )
            # read for completeness; the super-area ids of the closest hospitals
            # are not needed to resolve them via world.hospitals below
            closest_hospitals_super_areas = read_dataset(
                geography["closest_hospitals_super_areas"], idx1, idx2
            )
            for k in range(length):
                if domain_super_areas is not None:
                    super_area_id = super_areas_ids[k]
                    if super_area_id == nan_integer:
                        raise ValueError(
                            "if ``domain_super_areas`` is True, I expect not Nones super_areas."
                        )
                    if super_area_id not in domain_super_areas:
                        continue
                super_area = world.super_areas.get_from_id(super_areas_ids[k])
                # load closest hospitals; a distinct name is used so the hdf5
                # group variable ``hospitals`` is no longer shadowed
                closest_hospitals = []
                for hospital_id in closest_hospitals_ids[k]:
                    closest_hospitals.append(world.hospitals.get_from_id(hospital_id))
                super_area.closest_hospitals = closest_hospitals


import h5py
import numpy as np
import logging
from itertools import chain

from june.world import World
from june.groups import Household, Households, ExternalGroup
from june.groups.group.make_subgroups import SubgroupParams
from june.mpi_setup import mpi_rank
from .utils import read_dataset

nan_integer = -999  # sentinel stored in hdf5 integer datasets where the value is None

# variable-length dtypes used when the "residences to visit" rows are ragged
int_vlen_type = h5py.vlen_dtype(np.dtype("int64"))
str_vlen_type = h5py.vlen_dtype(np.dtype("S20"))
logger = logging.getLogger("household_saver")
# only rank 0 emits log messages to avoid duplicates in MPI runs
if mpi_rank > 0:
    logger.propagate = False


def save_households_to_hdf5(
    households: Households, file_path: str, chunk_size: int = 50000
):
    """
    Saves the households object to hdf5 format file ``file_path``. Currently for each
    household, the following values are stored:
    - id, area, super_area, type, composition_type, max_size, and the
      residences-to-visit links (ids, specs and super areas)

    Parameters
    ----------
    households
        Households supergroup to save
    file_path
        path of the saved hdf5 file
    chunk_size
        number of households to save at a time. Note that they have to be copied to be
        saved, so keep the number below 1e6.
    """
    n_households = len(households)
    n_chunks = int(np.ceil(n_households / chunk_size))
    with h5py.File(file_path, "a") as f:
        households_dset = f.create_group("households")
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_households)
            ids = []
            areas = []
            super_areas = []
            types = []
            composition_types = []
            max_sizes = []
            for household in households[idx1:idx2]:
                ids.append(household.id)
                # missing geography and missing strings are stored as sentinels
                if household.area is None:
                    areas.append(nan_integer)
                    super_areas.append(nan_integer)
                else:
                    areas.append(household.area.id)
                    super_areas.append(household.super_area.id)
                if household.type is None:
                    types.append(" ".encode("ascii", "ignore"))
                else:
                    types.append(household.type.encode("ascii", "ignore"))
                if household.composition_type is None:
                    composition_types.append(" ".encode("ascii", "ignore"))
                else:
                    composition_types.append(
                        household.composition_type.encode("ascii", "ignore")
                    )
                max_sizes.append(household.max_size)

            ids = np.array(ids, dtype=np.int64)
            areas = np.array(areas, dtype=np.int64)
            super_areas = np.array(super_areas, dtype=np.int64)
            types = np.array(types, dtype="S20")
            composition_types = np.array(composition_types, dtype="S20")
            max_sizes = np.array(max_sizes, dtype=np.float64)
            if chunk == 0:
                # first chunk creates the resizable datasets (maxshape=None)
                households_dset.attrs["n_households"] = n_households
                households_dset.create_dataset("id", data=ids, maxshape=(None,))
                households_dset.create_dataset("area", data=areas, maxshape=(None,))
                households_dset.create_dataset(
                    "super_area", data=super_areas, maxshape=(None,)
                )
                households_dset.create_dataset("type", data=types, maxshape=(None,))
                households_dset.create_dataset(
                    "composition_type", data=composition_types, maxshape=(None,)
                )
                households_dset.create_dataset(
                    "max_size", data=max_sizes, maxshape=(None,)
                )

            else:
                # subsequent chunks grow each dataset and fill the new slice
                newshape = (households_dset["id"].shape[0] + ids.shape[0],)
                households_dset["id"].resize(newshape)
                households_dset["id"][idx1:idx2] = ids
                households_dset["area"].resize(newshape)
                households_dset["area"][idx1:idx2] = areas
                households_dset["super_area"].resize(newshape)
                households_dset["super_area"][idx1:idx2] = super_areas
                households_dset["type"].resize(newshape)
                households_dset["type"][idx1:idx2] = types
                households_dset["composition_type"].resize(newshape)
                households_dset["composition_type"][idx1:idx2] = composition_types
                households_dset["max_size"].resize(newshape)
                households_dset["max_size"][idx1:idx2] = max_sizes

        # residences-to-visit links are ragged per household; households with
        # none get a single placeholder row ("none" spec / nan_integer ids)
        residences_to_visit_specs = []
        residences_to_visit_ids = []
        residences_to_visit_super_areas = []
        for household in households:
            if not household.residences_to_visit:
                residences_to_visit_specs.append(np.array(["none"], dtype="S20"))
                residences_to_visit_ids.append(np.array([nan_integer], dtype=np.int64))
                residences_to_visit_super_areas.append(
                    np.array([nan_integer], dtype=np.int64)
                )
            else:
                to_visit_ids = []
                to_visit_specs = []
                to_visit_super_areas = []
                for residence_type in household.residences_to_visit:
                    for residence_to_visit in household.residences_to_visit[
                        residence_type
                    ]:
                        to_visit_specs.append(residence_type)
                        to_visit_ids.append(residence_to_visit.id)
                        to_visit_super_areas.append(residence_to_visit.super_area.id)
                residences_to_visit_specs.append(np.array(to_visit_specs, dtype="S20"))
                residences_to_visit_ids.append(np.array(to_visit_ids, dtype=np.int64))
                residences_to_visit_super_areas.append(
                    np.array(to_visit_super_areas, dtype=np.int64)
                )

        # if any real residence id appears the rows may be ragged, so use
        # variable-length dtypes; otherwise (placeholders only) the arrays are
        # rectangular and plain fixed-width dtypes suffice
        if len(np.unique(list(chain(*residences_to_visit_ids)))) > 1:
            residences_to_visit_ids = np.array(
                residences_to_visit_ids, dtype=int_vlen_type
            )
            residences_to_visit_specs = np.array(
                residences_to_visit_specs, dtype=str_vlen_type
            )
            residences_to_visit_super_areas = np.array(
                residences_to_visit_super_areas, dtype=int_vlen_type
            )
        else:
            residences_to_visit_ids = np.array(residences_to_visit_ids, dtype=np.int64)
            residences_to_visit_specs = np.array(residences_to_visit_specs, dtype="S20")
            residences_to_visit_super_areas = np.array(
                residences_to_visit_super_areas, dtype=np.int64
            )
        households_dset.create_dataset(
            "residences_to_visit_ids", data=residences_to_visit_ids
        )
        households_dset.create_dataset(
            "residences_to_visit_specs", data=residences_to_visit_specs
        )
        households_dset.create_dataset(
            "residences_to_visit_super_areas", data=residences_to_visit_super_areas
        )


def load_households_from_hdf5(
    file_path: str, chunk_size=50000, domain_super_areas=None, config_filename=None
):
    """
    Loads households from an hdf5 file located at ``file_path``.
    The returned households are bare-bones: links to object instances of
    other classes still need to be restored afterwards, so this should
    rarely be called outside world.py.
    """

    Household_Class = Household
    Household_Class.subgroup_params = SubgroupParams.from_file(
        config_filename=config_filename
    )

    logger.info("loading households...")
    households_list = []
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        households_group = f["households"]
        total = households_group.attrs["n_households"]
        n_chunks = int(np.ceil(total / chunk_size))
        for chunk in range(n_chunks):
            logger.info(f"Loaded chunk {chunk} of {n_chunks}")
            start = chunk * chunk_size
            stop = min(start + chunk_size, total)
            ids = read_dataset(households_group["id"], start, stop)
            types = read_dataset(households_group["type"], start, stop)
            composition_types = read_dataset(
                households_group["composition_type"], start, stop
            )
            max_sizes = read_dataset(households_group["max_size"], start, stop)
            super_areas = read_dataset(households_group["super_area"], start, stop)
            for hid, htype, comp_type, max_size, super_area in zip(
                ids, types, composition_types, max_sizes, super_areas
            ):
                if domain_super_areas is not None:
                    # domain decomposition: skip households outside this domain
                    if super_area == nan_integer:
                        raise ValueError(
                            "if ``domain_super_areas`` is True, I expect not Nones super areas."
                        )
                    if super_area not in domain_super_areas:
                        continue
                household = Household_Class(
                    area=None,
                    type=htype.decode(),
                    max_size=max_size,
                    composition_type=comp_type.decode(),
                )
                household.id = hid
                households_list.append(household)
    return Households(households_list)


def restore_households_properties_from_hdf5(
    world: World,
    file_path: str,
    chunk_size=50000,
    domain_super_areas=None,
    super_areas_to_domain_dict: dict = None,
):
    """
    Restores household links from the hdf5 file at ``file_path``: the
    household's area, its residents, and its residences-to-visit (household /
    care home visits), creating ``ExternalGroup`` stubs for residences that
    live in another domain.
    """
    logger.info("restoring households...")
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        households = f["households"]
        n_households = households.attrs["n_households"]
        n_chunks = int(np.ceil(n_households / chunk_size))
        for chunk in range(n_chunks):
            logger.info(f"Restored chunk {chunk} of {n_chunks}")
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_households)
            length = idx2 - idx1
            ids = read_dataset(households["id"], idx1, idx2)
            super_areas = read_dataset(households["super_area"], idx1, idx2)
            areas = read_dataset(households["area"], idx1, idx2)
            residences_to_visit_ids = read_dataset(
                households["residences_to_visit_ids"], idx1, idx2
            )
            residences_to_visit_specs = read_dataset(
                households["residences_to_visit_specs"], idx1, idx2
            )
            residences_to_visit_super_areas = read_dataset(
                households["residences_to_visit_super_areas"], idx1, idx2
            )
            for k in range(length):
                if domain_super_areas is not None:
                    """
                    Note: if the relatives live outside the super area this will fail.
                    """
                    super_area = super_areas[k]
                    if super_area == nan_integer:
                        raise ValueError(
                            "if ``domain_super_areas`` is True, I expect not Nones super areas."
                        )
                    if super_area not in domain_super_areas:
                        continue
                household = world.households.get_from_id(ids[k])
                area = world.areas.get_from_id(areas[k])
                household.area = area
                area.households.append(household)
                household.residents = tuple(household.people)
                # visits; a nan_integer placeholder means "no visits recorded"
                visit_ids = residences_to_visit_ids[k]
                if visit_ids[0] == nan_integer:
                    continue
                visit_specs = residences_to_visit_specs[k]
                visit_super_areas = residences_to_visit_super_areas[k]
                for visit_id, visit_spec, visit_super_area in zip(
                    visit_ids, visit_specs, visit_super_areas
                ):
                    # bug fix: decode the spec once so the residences_to_visit
                    # key is always a str (previously the external branch used
                    # the raw bytes as the dict key, diverging from the local
                    # branch and from the saver's str keys)
                    visit_spec = visit_spec.decode()
                    if (
                        domain_super_areas is not None
                        and visit_super_area not in domain_super_areas
                    ):
                        residence = ExternalGroup(
                            id=visit_id,
                            domain_id=super_areas_to_domain_dict[visit_super_area],
                            spec=visit_spec,
                        )
                    else:
                        # NOTE(review): specs other than household/care_home
                        # would leave ``residence`` unbound — presumably only
                        # these two are ever saved; confirm against the saver.
                        if visit_spec == "household":
                            residence = world.households.get_from_id(visit_id)
                        elif visit_spec == "care_home":
                            residence = world.care_homes.get_from_id(visit_id)
                    household.residences_to_visit[visit_spec] = (
                        *household.residences_to_visit[visit_spec],
                        residence,
                    )

import h5py
import numpy as np
from typing import List
from june.groups.group.make_subgroups import SubgroupParams

from .utils import read_dataset
from june.groups.leisure import (
    Pub,
    Pubs,
    Grocery,
    Groceries,
    Cinema,
    Cinemas,
    Gym,
    Gyms,
    SocialVenues,
)
from june.world import World

nan_integer = -999  # sentinel stored in hdf5 integer datasets where the value is None


def save_social_venues_to_hdf5(social_venues_list: List[SocialVenues], file_path: str):
    """
    Saves each social venue supergroup (pubs, cinemas, ...) to the hdf5 file at
    ``file_path`` under ``social_venues/<spec>``, storing per venue:
    - id, coordinates, area id (``nan_integer`` when the venue has no area)
    """
    with h5py.File(file_path, "a") as f:
        f.create_group("social_venues")
        for social_venues in social_venues_list:
            n_svs = len(social_venues)
            social_venues_dset = f["social_venues"].create_group(social_venues.spec)
            ids = []
            coordinates = []
            areas = []
            for sv in social_venues:
                ids.append(sv.id)
                coordinates.append(np.array(sv.coordinates, dtype=np.float64))
                # bug fix: the dataset stores *area* ids (compared against
                # domain_areas / world.areas on load), so the None-check must be
                # on ``sv.area`` — previously it checked ``sv.super_area`` and
                # could crash on ``sv.area.id`` or store a stale sentinel.
                if sv.area is None:
                    areas.append(nan_integer)
                else:
                    areas.append(sv.area.id)
            ids = np.array(ids, dtype=np.int64)
            coordinates = np.array(coordinates, dtype=np.float64)
            areas = np.array(areas, dtype=np.int64)
            social_venues_dset.attrs["n"] = n_svs
            social_venues_dset.create_dataset("id", data=ids)
            social_venues_dset.create_dataset("coordinates", data=coordinates)
            social_venues_dset.create_dataset("area", data=areas)


def load_social_venues_from_hdf5(
    file_path: str, domain_areas=None, config_filename=None
):
    """
    Reads every social venue supergroup (pubs, cinemas, groceries, gyms) from
    the hdf5 file at ``file_path`` and returns a dict keyed by spec. Specs with
    no venues map to None; when ``domain_areas`` is given, venues outside those
    areas are skipped.
    """
    # attach subgroup parameters to each venue class before instantiating
    spec_to_group_dict = {}
    for spec_name, venue_class in (
        ("pubs", Pub),
        ("cinemas", Cinema),
        ("groceries", Grocery),
        ("gyms", Gym),
    ):
        venue_class.subgroup_params = SubgroupParams.from_file(
            config_filename=config_filename
        )
        spec_to_group_dict[spec_name] = venue_class

    spec_to_supergroup_dict = {
        "pubs": Pubs,
        "cinemas": Cinemas,
        "groceries": Groceries,
        "gyms": Gyms,
    }

    social_venues_dict = {}
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        for spec in f["social_venues"]:
            data = f["social_venues"][spec]
            n = data.attrs["n"]
            if n == 0:
                social_venues_dict[spec] = None
                continue
            ids = read_dataset(data["id"])
            coordinates = read_dataset(data["coordinates"])
            areas = read_dataset(data["area"])
            venues = []
            for k in range(n):
                if domain_areas is not None:
                    area = areas[k]
                    if area == nan_integer:
                        raise ValueError(
                            "if ``domain_areas`` is True, I expect not Nones super areas."
                        )
                    if area not in domain_areas:
                        continue
                venue = spec_to_group_dict[spec]()
                venue.id = ids[k]
                venue.coordinates = coordinates[k]
                venues.append(venue)
            social_venues_dict[spec] = spec_to_supergroup_dict[spec](venues)
    return social_venues_dict


def restore_social_venues_properties_from_hdf5(
    world: World, file_path: str, domain_areas=None
):
    """
    Restores the area link of every social venue in ``world`` from the hdf5
    file at ``file_path``. When ``domain_areas`` is given, venues outside those
    areas are skipped.
    """
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        for spec in f["social_venues"]:
            data = f["social_venues"][spec]
            n = data.attrs["n"]
            if n == 0:
                continue
            social_venues = getattr(world, spec)
            ids = read_dataset(data["id"])
            areas = read_dataset(data["area"])
            for k in range(n):
                area_id = areas[k]
                if domain_areas is not None:
                    # bug fix: error message previously said "super areas"
                    # although the check is on areas
                    if area_id == nan_integer:
                        raise ValueError(
                            "if ``domain_areas`` is True, I expect not Nones areas."
                        )
                    if area_id not in domain_areas:
                        continue
                social_venue = social_venues.get_from_id(ids[k])
                if area_id == nan_integer:
                    social_venue.area = None
                else:
                    social_venue.area = world.areas.get_from_id(area_id)


import h5py
import numpy as np
import logging


from .utils import read_dataset
from june.groups import ExternalSubgroup, ExternalGroup
from june.groups.travel import ModeOfTransport
from june.demography import Population, Person
from june.demography.person import Activities
from june.geography import ExternalSuperArea
from june.world import World
from june.mpi_setup import mpi_rank

logger = logging.getLogger("population saver")
# only rank 0 emits log messages to avoid duplicates in MPI runs
if mpi_rank > 0:
    logger.propagate = False

nan_integer = -999  # only used to store/load hdf5 integer arrays with inf/nan values
# maps a group's singular spec name to the plural world attribute holding its supergroup
spec_mapper = {
    "hospital": "hospitals",
    "company": "companies",
    "school": "schools",
    "household": "households",
    "care_home": "care_homes",
    "university": "universities",
    "pub": "pubs",
    "grocery": "groceries",
    "cinema": "cinemas",
}


def save_population_to_hdf5(
    population: Population, file_path: str, chunk_size: int = 100000
):
    """
    Saves the Population object to hdf5 format file ``file_path``. Currently for each person,
    the following values are stored:
    - id, age, sex, ethnicity, area, subgroup memberships ids, housemate ids, mode_of_transport,

    None values are encoded as a single blank space (for strings) or as
    ``nan_integer`` (for integer ids) so that fixed-dtype arrays can be used.

    Parameters
    ----------
    population
        population object
    file_path
        path of the saved hdf5 file
    chunk_size
        number of people to save at a time. Note that they have to be copied to be saved,
        so keep the number below 1e6.
    """
    n_people = len(population.people)
    n_chunks = int(np.ceil(n_people / chunk_size))
    # append mode: several save_* functions write groups into the same file
    with h5py.File(file_path, "a") as f:
        people_dset = f.create_group("population")
        for chunk in range(n_chunks):
            # [idx1, idx2) is the slice of people serialised in this pass
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_people)
            # per-chunk staging buffers, converted to numpy arrays below
            ids = []
            ages = []
            sexes = []
            ethns = []
            areas = []
            super_areas = []
            work_super_areas = []
            work_super_areas_cities = []
            work_super_area_coords = []
            sectors = []
            sub_sectors = []
            group_ids = []
            group_specs = []
            group_super_areas = []
            subgroup_types = []
            mode_of_transport_description = []
            mode_of_transport_is_public = []
            lockdown_status = []

            for person in population.people[idx1:idx2]:
                ids.append(person.id)
                ages.append(person.age)
                sexes.append(person.sex.encode("ascii", "ignore"))
                if person.ethnicity is None:
                    # blank space stands in for None in string datasets
                    ethns.append(" ".encode("ascii", "ignore"))
                else:
                    ethns.append(person.ethnicity.encode("ascii", "ignore"))
                if person.area is not None:
                    areas.append(person.area.id)
                    super_areas.append(person.area.super_area.id)
                else:
                    areas.append(nan_integer)
                    super_areas.append(nan_integer)
                if person.work_super_area is not None:
                    work_super_areas.append(person.work_super_area.id)
                    work_super_area_coords.append(
                        np.array(person.work_super_area.coordinates, dtype=np.float64)
                    )
                    if person.work_super_area.city is not None:
                        work_super_areas_cities.append(person.work_super_area.city.id)
                    else:
                        work_super_areas_cities.append(nan_integer)
                else:
                    work_super_areas.append(nan_integer)
                    work_super_areas_cities.append(nan_integer)
                    work_super_area_coords.append(
                        np.array([nan_integer, nan_integer], dtype=np.float64)
                    )
                if person.sector is None:
                    sectors.append(" ".encode("ascii", "ignore"))
                else:
                    sectors.append(person.sector.encode("ascii", "ignore"))
                if person.sub_sector is None:
                    sub_sectors.append(" ".encode("ascii", "ignore"))
                else:
                    sub_sectors.append(person.sub_sector.encode("ascii", "ignore"))
                if person.lockdown_status is None:
                    lockdown_status.append(" ".encode("ascii", "ignore"))
                else:
                    lockdown_status.append(
                        person.lockdown_status.encode("ascii", "ignore")
                    )
                # one row per person: the group/subgroup membership for each
                # activity slot, aligned with Activities' field order
                gids = []
                stypes = []
                specs = []
                group_super_areas_temp = []
                for subgroup in person.subgroups.iter():
                    if subgroup is None:
                        gids.append(nan_integer)
                        stypes.append(nan_integer)
                        specs.append(" ".encode("ascii", "ignore"))
                        group_super_areas_temp.append(nan_integer)
                    else:
                        gids.append(subgroup.group.id)
                        stypes.append(subgroup.subgroup_type)
                        specs.append(subgroup.group.spec.encode("ascii", "ignore"))
                        if subgroup.group.super_area is None:
                            group_super_areas_temp.append(nan_integer)
                        else:
                            group_super_areas_temp.append(subgroup.group.super_area.id)
                group_specs.append(np.array(specs, dtype="S20"))
                group_ids.append(np.array(gids, dtype=np.int64))
                subgroup_types.append(np.array(stypes, dtype=np.int64))
                group_super_areas.append(
                    np.array(group_super_areas_temp, dtype=np.int64)
                )
                if person.mode_of_transport is None:
                    mode_of_transport_description.append(" ".encode("ascii", "ignore"))
                    mode_of_transport_is_public.append(False)
                else:
                    mode_of_transport_description.append(
                        person.mode_of_transport.description.encode("ascii", "ignore")
                    )
                    mode_of_transport_is_public.append(
                        person.mode_of_transport.is_public
                    )

            # fix dtypes so every chunk writes into the same dataset layout
            ids = np.array(ids, dtype=np.int64)
            ages = np.array(ages, dtype=np.int64)
            sexes = np.array(sexes, dtype="S10")
            ethns = np.array(ethns, dtype="S10")
            areas = np.array(areas, dtype=np.int64)
            super_areas = np.array(super_areas, dtype=np.int64)
            work_super_areas = np.array(work_super_areas, dtype=np.int64)
            work_super_areas_cities = np.array(work_super_areas_cities, dtype=np.int64)
            work_super_area_coords = np.array(work_super_area_coords, dtype=np.float64)
            group_ids = np.array(group_ids, dtype=np.int64)
            subgroup_types = np.array(subgroup_types, dtype=np.int64)
            group_specs = np.array(group_specs, dtype="S20")
            group_super_areas = np.array(group_super_areas, dtype=np.int64)
            sectors = np.array(sectors, dtype="S30")
            sub_sectors = np.array(sub_sectors, dtype="S30")
            mode_of_transport_description = np.array(
                mode_of_transport_description, dtype="S100"
            )
            mode_of_transport_is_public = np.array(
                mode_of_transport_is_public, dtype=bool
            )
            lockdown_status = np.array(lockdown_status, dtype="S20")

            if chunk == 0:
                # first chunk: create resizable datasets (maxshape=None on axis 0)
                people_dset.attrs["n_people"] = n_people
                people_dset.create_dataset("id", data=ids, maxshape=(None,))
                people_dset.create_dataset("age", data=ages, maxshape=(None,))
                people_dset.create_dataset("sex", data=sexes, maxshape=(None,))
                people_dset.create_dataset("sector", data=sectors, maxshape=(None,))
                people_dset.create_dataset(
                    "sub_sector", data=sub_sectors, maxshape=(None,)
                )
                people_dset.create_dataset("ethnicity", data=ethns, maxshape=(None,))
                people_dset.create_dataset(
                    "group_ids", data=group_ids, maxshape=(None, group_ids.shape[1])
                )
                people_dset.create_dataset(
                    "group_specs",
                    data=group_specs,
                    maxshape=(None, group_specs.shape[1]),
                )
                people_dset.create_dataset(
                    "subgroup_types",
                    data=subgroup_types,
                    maxshape=(None, subgroup_types.shape[1]),
                )
                people_dset.create_dataset(
                    "group_super_areas",
                    data=group_super_areas,
                    maxshape=(None, subgroup_types.shape[1]),
                )
                people_dset.create_dataset("area", data=areas, maxshape=(None,))
                people_dset.create_dataset(
                    "super_area", data=super_areas, maxshape=(None,)
                )
                people_dset.create_dataset(
                    "work_super_area", data=work_super_areas, maxshape=(None,)
                )
                people_dset.create_dataset(
                    "work_super_area_coords",
                    data=work_super_area_coords,
                    maxshape=(None, work_super_area_coords.shape[1]),
                )
                people_dset.create_dataset(
                    "work_super_area_city",
                    data=work_super_areas_cities,
                    maxshape=(None,),
                )
                people_dset.create_dataset(
                    "mode_of_transport_description",
                    data=mode_of_transport_description,
                    maxshape=(None,),
                )
                people_dset.create_dataset(
                    "mode_of_transport_is_public",
                    data=mode_of_transport_is_public,
                    maxshape=(None,),
                )
                people_dset.create_dataset(
                    "lockdown_status", data=lockdown_status, maxshape=(None,)
                )
            else:
                # later chunks: grow each dataset along axis 0, then write the slice
                newshape = (people_dset["id"].shape[0] + ids.shape[0],)
                people_dset["id"].resize(newshape)
                people_dset["id"][idx1:idx2] = ids
                people_dset["age"].resize(newshape)
                people_dset["age"][idx1:idx2] = ages
                people_dset["sex"].resize(newshape)
                people_dset["sex"][idx1:idx2] = sexes
                people_dset["ethnicity"].resize(newshape)
                people_dset["ethnicity"][idx1:idx2] = ethns
                people_dset["sector"].resize(newshape)
                people_dset["sector"][idx1:idx2] = sectors
                people_dset["sub_sector"].resize(newshape)
                people_dset["sub_sector"][idx1:idx2] = sub_sectors
                people_dset["area"].resize(newshape)
                people_dset["area"][idx1:idx2] = areas
                people_dset["super_area"].resize(newshape)
                people_dset["super_area"][idx1:idx2] = super_areas
                people_dset["work_super_area"].resize(newshape)
                people_dset["work_super_area"][idx1:idx2] = work_super_areas
                people_dset["work_super_area_coords"].resize(newshape[0], axis=0)
                people_dset["work_super_area_coords"][
                    idx1:idx2
                ] = work_super_area_coords
                people_dset["work_super_area_city"].resize(newshape)
                people_dset["work_super_area_city"][idx1:idx2] = work_super_areas_cities
                people_dset["group_ids"].resize(newshape[0], axis=0)
                people_dset["group_ids"][idx1:idx2] = group_ids
                people_dset["group_specs"].resize(newshape[0], axis=0)
                people_dset["group_specs"][idx1:idx2] = group_specs
                people_dset["subgroup_types"].resize(newshape[0], axis=0)
                people_dset["subgroup_types"][idx1:idx2] = subgroup_types
                people_dset["group_super_areas"].resize(newshape[0], axis=0)
                people_dset["group_super_areas"][idx1:idx2] = group_super_areas
                people_dset["mode_of_transport_description"].resize(newshape)
                people_dset["mode_of_transport_description"][
                    idx1:idx2
                ] = mode_of_transport_description
                people_dset["mode_of_transport_is_public"].resize(newshape)
                people_dset["mode_of_transport_is_public"][
                    idx1:idx2
                ] = mode_of_transport_is_public
                people_dset["lockdown_status"].resize(newshape)
                people_dset["lockdown_status"][idx1:idx2] = lockdown_status


def load_population_from_hdf5(
    file_path: str, chunk_size=100000, domain_super_areas=None
):
    """
    Loads the population from an hdf5 file located at ``file_path``.
    The returned Population is not ready to use yet: the links to
    instances of other classes still need to be restored afterwards.
    This function should rarely be called outside world.py.
    """

    def _decode_or_none(raw):
        # a single blank space was stored to represent None
        text = raw.decode()
        return None if text == " " else text

    people = []
    logger.info("loading population...")
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        population = f["population"]
        n_people = population.attrs["n_people"]
        # process the file in chunks to bound memory use
        n_chunks = int(np.ceil(n_people / chunk_size))
        for chunk in range(n_chunks):
            logger.info(f"Loaded chunk {chunk} of {n_chunks}")
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_people)
            ids = read_dataset(population["id"], idx1, idx2)
            ages = read_dataset(population["age"], idx1, idx2)
            sexes = read_dataset(population["sex"], idx1, idx2)
            ethns = read_dataset(population["ethnicity"], idx1, idx2)
            super_areas = read_dataset(population["super_area"], idx1, idx2)
            sectors = read_dataset(population["sector"], idx1, idx2)
            sub_sectors = read_dataset(population["sub_sector"], idx1, idx2)
            lockdown_status = read_dataset(population["lockdown_status"], idx1, idx2)
            mode_of_transport_is_public_list = read_dataset(
                population["mode_of_transport_is_public"], idx1, idx2
            )
            mode_of_transport_description_list = read_dataset(
                population["mode_of_transport_description"], idx1, idx2
            )
            for offset in range(idx2 - idx1):
                if domain_super_areas is not None:
                    super_area = super_areas[offset]
                    if super_area == nan_integer:
                        raise ValueError(
                            "if ``domain_super_areas`` is True, I expect not Nones super areas."
                        )
                    # skip people that belong to another MPI domain
                    if super_area not in domain_super_areas:
                        continue
                person = Person.from_attributes(
                    id=ids[offset],
                    age=ages[offset],
                    sex=sexes[offset].decode(),
                    ethnicity=_decode_or_none(ethns[offset]),
                )
                people.append(person)
                transport_description = mode_of_transport_description_list[offset]
                if transport_description.decode() == " ":
                    person.mode_of_transport = None
                else:
                    person.mode_of_transport = ModeOfTransport(
                        description=transport_description.decode(),
                        is_public=mode_of_transport_is_public_list[offset],
                    )
                person.sector = _decode_or_none(sectors[offset])
                person.sub_sector = _decode_or_none(sub_sectors[offset])
                person.lockdown_status = _decode_or_none(lockdown_status[offset])
    return Population(people)


def restore_population_properties_from_hdf5(
    world: World,
    file_path: str,
    chunk_size=50000,
    domain_super_areas=None,
    super_areas_to_domain_dict: dict = None,
):
    """
    Restores, for every person already loaded into ``world``, the links that
    could not be stored directly in hdf5: home area, work super area (and its
    city), and the group/subgroup memberships for each activity.

    Parameters
    ----------
    world
        world whose people are updated in place
    file_path
        path of the hdf5 file to read
    chunk_size
        number of people to process per chunk
    domain_super_areas
        if given, people whose super area id is not in this collection are
        skipped; groups outside the domain become External* placeholders
    super_areas_to_domain_dict
        maps a super area id to the MPI domain that owns it; required when
        ``domain_super_areas`` is given and external links must be built
    """
    logger.info("restoring population...")
    activities_fields = Activities.__fields__
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        # people = []
        population = f["population"]
        # read in chunks of 100k people
        n_people = population.attrs["n_people"]
        n_chunks = int(np.ceil(n_people / chunk_size))
        for chunk in range(n_chunks):
            logger.info(f"Restored chunk {chunk} of {n_chunks}")
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_people)
            length = idx2 - idx1
            ids = read_dataset(population["id"], idx1, idx2)
            group_ids = read_dataset(population["group_ids"], idx1, idx2)
            group_specs = read_dataset(population["group_specs"], idx1, idx2)
            subgroup_types = read_dataset(population["subgroup_types"], idx1, idx2)
            group_super_areas = read_dataset(
                population["group_super_areas"], idx1, idx2
            )
            areas = read_dataset(population["area"], idx1, idx2)
            super_areas = read_dataset(population["super_area"], idx1, idx2)
            work_super_areas = read_dataset(population["work_super_area"], idx1, idx2)
            work_super_areas_coords = read_dataset(
                population["work_super_area_coords"], idx1, idx2
            )
            work_super_areas_cities = read_dataset(
                population["work_super_area_city"], idx1, idx2
            )
            for k in range(length):
                if domain_super_areas is not None:
                    super_area = super_areas[k]
                    if super_area == nan_integer:
                        raise ValueError(
                            "if ``domain_super_areas`` is True, I expect not Nones super areas."
                        )
                    # person belongs to another domain; nothing to restore here
                    if super_area not in domain_super_areas:
                        continue
                person = world.people.get_from_id(ids[k])
                # restore area
                person.area = world.areas.get_from_id(areas[k])
                person.area.people.append(person)
                work_super_area_id = work_super_areas[k]
                if work_super_area_id == nan_integer:
                    person.work_super_area = None
                else:
                    if (
                        domain_super_areas is None
                        or work_super_area_id in domain_super_areas
                    ):
                        # work super area lives in this domain: link the real instance
                        person.work_super_area = world.super_areas.get_from_id(
                            work_super_area_id
                        )
                        if person.work_super_area.city is not None:
                            # sanity check: stored city id must match the instance
                            assert (
                                person.work_super_area.city.id
                                == work_super_areas_cities[k]
                            )
                        person.work_super_area.workers.append(person)
                    else:
                        # work super area is owned by another domain: build a proxy
                        person.work_super_area = ExternalSuperArea(
                            domain_id=super_areas_to_domain_dict[work_super_area_id],
                            id=work_super_area_id,
                            coordinates=work_super_areas_coords[k],
                        )
                        if work_super_areas_cities[k] == nan_integer:
                            person.work_super_area.city = None
                        else:
                            person.work_super_area.city = world.cities.get_from_id(
                                work_super_areas_cities[k]
                            )
                # restore groups and subgroups
                subgroups_instances = Activities(None, None, None, None, None, None)
                for (
                    i,
                    (group_id, subgroup_type, group_spec, group_super_area),
                ) in enumerate(
                    zip(
                        group_ids[k],
                        subgroup_types[k],
                        group_specs[k],
                        group_super_areas[k],
                    )
                ):
                    # nan_integer marks an empty activity slot
                    if group_id == nan_integer:
                        continue
                    group_spec = group_spec.decode()
                    supergroup = getattr(world, spec_mapper[group_spec])
                    if (
                        domain_super_areas is None
                        or group_super_area in domain_super_areas
                    ):
                        # group lives in this domain: subscribe the person to it
                        group = supergroup.get_from_id(group_id)
                        assert group_id == group.id
                        subgroup = group[subgroup_type]
                        subgroup.append(person)
                        setattr(subgroups_instances, activities_fields[i], subgroup)
                    else:
                        # group owned elsewhere: store an external placeholder only
                        domain_of_subgroup = super_areas_to_domain_dict[
                            group_super_area
                        ]
                        group = ExternalGroup(
                            domain_id=domain_of_subgroup, id=group_id, spec=group_spec
                        )
                        subgroup_external = ExternalSubgroup(
                            group=group, subgroup_type=subgroup_type
                        )
                        setattr(
                            subgroups_instances, activities_fields[i], subgroup_external
                        )
                person.subgroups = subgroups_instances


import h5py
import numpy as np

from june.groups import Schools, School
from june.world import World
from june.groups.group.make_subgroups import SubgroupParams
from .utils import read_dataset

nan_integer = -999

int_vlen_type = h5py.vlen_dtype(np.dtype("int64"))


def save_schools_to_hdf5(schools: Schools, file_path: str, chunk_size: int = 50000):
    """
    Saves the schools object to hdf5 format file ``file_path``. Currently for each school,
    the following values are stored:
    - id, n_pupils_max, age_min, age_max, sector, coordinates, area, super_area,
      n_classrooms, years

    Parameters
    ----------
    schools
        schools object
    file_path
        path of the saved hdf5 file
    chunk_size
        number of schools to save at a time. Note that they have to be copied to be saved,
        so keep the number below 1e6.
    """
    n_schools = len(schools)
    n_chunks = int(np.ceil(n_schools / chunk_size))
    # append mode: several save_* functions write groups into the same file
    with h5py.File(file_path, "a") as f:
        schools_dset = f.create_group("schools")
        for chunk in range(n_chunks):
            # [idx1, idx2) is the slice of schools serialised in this pass
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_schools)
            ids = []
            n_pupils_max = []
            age_min = []
            age_max = []
            sectors = []
            coordinates = []
            n_classrooms = []
            years = []
            areas = []
            super_areas = []
            for school in schools[idx1:idx2]:
                ids.append(school.id)
                n_pupils_max.append(school.n_pupils_max)
                age_min.append(school.age_min)
                age_max.append(school.age_max)
                # sector may be a float NaN (e.g. from a pandas read) or None;
                # both are stored as a single blank space
                if type(school.sector) is float or school.sector is None:
                    sectors.append(" ".encode("ascii", "ignore"))
                else:
                    sectors.append(school.sector.encode("ascii", "ignore"))
                if school.area is None:
                    areas.append(nan_integer)
                    super_areas.append(nan_integer)
                else:
                    areas.append(school.area.id)
                    super_areas.append(school.super_area.id)
                coordinates.append(np.array(school.coordinates))
                n_classrooms.append(school.n_classrooms)
                years.append(np.array(school.years))

            ids = np.array(ids, dtype=np.int64)
            n_pupils_max = np.array(n_pupils_max, dtype=np.int64)
            age_min = np.array(age_min, dtype=np.int64)
            age_max = np.array(age_max, dtype=np.int64)
            sectors = np.array(sectors, dtype="S20")
            areas = np.array(areas, dtype=np.int64)
            super_areas = np.array(super_areas, dtype=np.int64)
            coordinates = np.array(coordinates, dtype=np.float64)
            n_classrooms = np.array(n_classrooms, dtype=np.int64)
            # schools can have different numbers of years, so with two or more
            # rows a variable-length dtype is needed; a single row fits int64
            if len(years) < 2:
                years = np.array(years, dtype=np.int64)
            else:
                years = np.array(years, dtype=int_vlen_type)
            if chunk == 0:
                # first chunk: create resizable datasets (maxshape=None on axis 0)
                schools_dset.attrs["n_schools"] = n_schools
                schools_dset.create_dataset("id", data=ids, maxshape=(None,))
                schools_dset.create_dataset(
                    "n_pupils_max", data=n_pupils_max, maxshape=(None,)
                )
                schools_dset.create_dataset("age_min", data=age_min, maxshape=(None,))
                schools_dset.create_dataset("age_max", data=age_max, maxshape=(None,))
                schools_dset.create_dataset("sector", data=sectors, maxshape=(None,))
                schools_dset.create_dataset(
                    "coordinates",
                    data=coordinates,
                    maxshape=(None, coordinates.shape[1]),
                )
                schools_dset.create_dataset("area", data=areas, maxshape=(None,))
                schools_dset.create_dataset(
                    "super_area", data=super_areas, maxshape=(None,)
                )
                schools_dset.create_dataset(
                    "n_classrooms", data=n_classrooms, maxshape=(None,)
                )
                # NOTE(review): "years" is created without maxshape, yet the
                # multi-chunk branch below resizes it — confirm behaviour when
                # n_schools > chunk_size
                schools_dset.create_dataset("years", data=years)
            else:
                # later chunks: grow each dataset along axis 0, then write the slice
                newshape = (schools_dset["id"].shape[0] + ids.shape[0],)
                schools_dset["id"].resize(newshape)
                schools_dset["id"][idx1:idx2] = ids
                schools_dset["n_pupils_max"].resize(newshape)
                schools_dset["n_pupils_max"][idx1:idx2] = n_pupils_max
                schools_dset["age_min"].resize(newshape)
                schools_dset["age_min"][idx1:idx2] = age_min
                schools_dset["age_max"].resize(newshape)
                schools_dset["age_max"][idx1:idx2] = age_max
                schools_dset["sector"].resize(newshape)
                schools_dset["sector"][idx1:idx2] = sectors
                schools_dset["coordinates"].resize(newshape[0], axis=0)
                schools_dset["coordinates"][idx1:idx2] = coordinates
                schools_dset["area"].resize(newshape[0], axis=0)
                schools_dset["area"][idx1:idx2] = areas
                schools_dset["super_area"].resize(newshape[0], axis=0)
                schools_dset["super_area"][idx1:idx2] = super_areas
                schools_dset["n_classrooms"].resize(newshape[0], axis=0)
                schools_dset["n_classrooms"][idx1:idx2] = n_classrooms
                schools_dset["years"].resize(newshape[0], axis=0)
                schools_dset["years"][idx1:idx2] = years


def load_schools_from_hdf5(
    file_path: str,
    chunk_size: int = 50000,
    domain_super_areas=None,
    config_filename=None,
):
    """
    Loads schools from an hdf5 file located at ``file_path``.
    The returned Schools object is not ready to use yet: the links to
    instances of other classes still need to be restored afterwards.
    This function should rarely be called outside world.py.
    """

    # attach the subgroup parameters to the School class before building instances
    School.subgroup_params = SubgroupParams.from_file(
        config_filename=config_filename
    )

    schools_list = []
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        schools = f["schools"]
        n_schools = schools.attrs["n_schools"]
        # process the file in chunks to bound memory use
        n_chunks = int(np.ceil(n_schools / chunk_size))
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_schools)
            ids = read_dataset(schools["id"], idx1, idx2)
            n_pupils_max = read_dataset(schools["n_pupils_max"], idx1, idx2)
            age_min = read_dataset(schools["age_min"], idx1, idx2)
            age_max = read_dataset(schools["age_max"], idx1, idx2)
            coordinates = read_dataset(schools["coordinates"], idx1, idx2)
            n_classrooms = read_dataset(schools["n_classrooms"], idx1, idx2)
            years = read_dataset(schools["years"], idx1, idx2)
            super_areas = read_dataset(schools["super_area"], idx1, idx2)
            sectors = read_dataset(schools["sector"], idx1, idx2)
            for offset in range(idx2 - idx1):
                if domain_super_areas is not None:
                    super_area = super_areas[offset]
                    if super_area == nan_integer:
                        raise ValueError(
                            "if ``domain_super_areas`` is True, I expect not Nones super areas."
                        )
                    # skip schools that belong to another MPI domain
                    if super_area not in domain_super_areas:
                        continue
                raw_sector = sectors[offset].decode()
                # a single blank space was stored to represent None
                sector = None if raw_sector == " " else raw_sector
                school = School(
                    coordinates=coordinates[offset],
                    n_pupils_max=n_pupils_max[offset],
                    age_min=age_min[offset],
                    age_max=age_max[offset],
                    sector=sector,
                    n_classrooms=n_classrooms[offset],
                    years=years[offset],
                )
                school.id = ids[offset]
                schools_list.append(school)
    return Schools(schools_list)


def restore_school_properties_from_hdf5(
    world: World, file_path: str, chunk_size: int = 50000, domain_super_areas=None
):
    """
    Relinks each school in ``world`` to its Area instance, using the
    school -> area ids stored in the "schools" group of the hdf5 file.

    Parameters
    ----------
    world
        world whose schools are updated in place
    file_path
        path of the hdf5 file to read
    chunk_size
        number of schools to process per chunk; now defaults to 50000 for
        consistency with the other restore functions in this module
        (previously a required positional argument)
    domain_super_areas
        if given, schools whose super area id is not in this collection are
        skipped (used when the world is split across MPI domains)
    """
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        schools = f["schools"]
        n_schools = schools.attrs["n_schools"]
        n_chunks = int(np.ceil(n_schools / chunk_size))
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_schools)
            length = idx2 - idx1
            ids = read_dataset(schools["id"], idx1, idx2)
            areas = read_dataset(schools["area"], idx1, idx2)
            super_areas = read_dataset(schools["super_area"], idx1, idx2)
            for k in range(length):
                if domain_super_areas is not None:
                    super_area = super_areas[k]
                    if super_area == nan_integer:
                        raise ValueError(
                            "if ``domain_super_areas`` is True, I expect not Nones super areas."
                        )
                    # school belongs to another domain; nothing to restore here
                    if super_area not in domain_super_areas:
                        continue
                school = world.schools.get_from_id(ids[k])
                area = areas[k]
                # nan_integer marks a school that was saved without an area
                if area == nan_integer:
                    school.area = None
                else:
                    school.area = world.areas.get_from_id(area)


import h5py
import numpy as np

from june.groups import University, Universities
from june.groups.group.make_subgroups import SubgroupParams
from .utils import read_dataset

nan_integer = -999


def save_universities_to_hdf5(universities: Universities, file_path: str):
    """
    Saves the universities object to hdf5 format file ``file_path``. For each
    university, the following values are stored:
    - id, n_students_max, n_years, ukprn, area id, coordinates

    Parameters
    ----------
    universities
        Universities supergroup to save
    file_path
        path of the saved hdf5 file
    """
    n_universities = len(universities)
    with h5py.File(file_path, "a") as f:
        universities_dset = f.create_group("universities")
        ids = []
        n_students_max = []
        n_years = []
        ukprns = []
        areas = []
        coordinates = []
        for university in universities:
            ids.append(university.id)
            n_students_max.append(university.n_students_max)
            n_years.append(university.n_years)
            coordinates.append(np.array(university.coordinates, dtype=np.float64))
            ukprns.append(university.ukprn)
            # a university without an area is stored as the nan_integer sentinel
            if university.area is None:
                areas.append(nan_integer)
            else:
                areas.append(university.area.id)

        ids = np.array(ids, dtype=np.int64)
        n_students_max = np.array(n_students_max, dtype=np.int64)
        n_years = np.array(n_years, dtype=np.int64)
        ukprns = np.array(ukprns, dtype=np.int64)
        areas = np.array(areas, dtype=np.int64)
        coordinates = np.array(coordinates, dtype=np.float64)
        universities_dset.attrs["n_universities"] = n_universities
        universities_dset.create_dataset("id", data=ids)
        universities_dset.create_dataset("n_students_max", data=n_students_max)
        universities_dset.create_dataset("n_years", data=n_years)
        universities_dset.create_dataset("area", data=areas)
        universities_dset.create_dataset("coordinates", data=coordinates)
        universities_dset.create_dataset("ukprns", data=ukprns)


def load_universities_from_hdf5(
    file_path: str, chunk_size: int = 50000, domain_areas=None, config_filename=None
):
    """
    Loads universities from an hdf5 file located at ``file_path``.
    Note that this object will not be ready to use, as the links to
    object instances of other classes need to be restored first.
    This function should rarely be called outside world.py

    Parameters
    ----------
    file_path
        path of the hdf5 file
    chunk_size
        kept for API compatibility with the other loaders; universities are
        few enough that all of them are read in a single pass
    domain_areas
        if given, only universities whose area id is in this set are loaded
    config_filename
        config file used to build the university subgroup parameters
    """
    # attach the subgroup configuration to the class before instantiating
    University_Class = University
    University_Class.subgroup_params = SubgroupParams.from_file(
        config_filename=config_filename
    )

    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        universities = f["universities"]
        universities_list = []
        n_universities = universities.attrs["n_universities"]
        ids = read_dataset(universities["id"])
        n_students_max = read_dataset(universities["n_students_max"])
        n_years = read_dataset(universities["n_years"])
        ukprns = read_dataset(universities["ukprns"])
        areas = read_dataset(universities["area"])
        coordinates = read_dataset(universities["coordinates"])
        for k in range(n_universities):
            if domain_areas is not None:
                area = areas[k]
                # the saver always writes a real area id or the sentinel;
                # the sentinel is not allowed in domain decomposition mode
                if area == nan_integer:
                    raise ValueError(
                        "if ``domain_areas`` is True, I expect not Nones areas."
                    )
                if area not in domain_areas:
                    continue
            university = University_Class(
                n_students_max=n_students_max[k],
                n_years=n_years[k],
                ukprn=ukprns[k],
                coordinates=coordinates[k],
            )
            university.id = ids[k]
            universities_list.append(university)
    return Universities(universities_list)


def restore_universities_properties_from_hdf5(
    world, file_path: str, chunk_size: int = 50000, domain_areas=None
):
    """
    Re-links every university in ``world`` to its area instance, reading the
    stored area ids back from the hdf5 file.

    Parameters
    ----------
    world
        world whose universities are re-linked (must be loaded already)
    file_path
        path of the hdf5 file
    chunk_size
        kept for API compatibility; all universities are read in one pass
    domain_areas
        if given, universities outside these areas are skipped
    """
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        universities = f["universities"]
        n_universities = universities.attrs["n_universities"]
        # use the shared helper instead of duplicating read_direct boilerplate
        ids = read_dataset(universities["id"], 0, n_universities)
        areas = read_dataset(universities["area"], 0, n_universities)
        for k in range(n_universities):
            area = areas[k]
            if domain_areas is not None:
                if area == nan_integer:
                    # message fixed: this checks areas, not super areas
                    raise ValueError(
                        "if ``domain_areas`` is True, I expect not Nones areas."
                    )
                if area not in domain_areas:
                    continue
            university = world.universities.get_from_id(ids[k])
            if area == nan_integer:
                university.area = None
            else:
                university.area = world.areas.get_from_id(area)


import numpy as np


def read_dataset(dataset, index1=None, index2=None):
    """
    Reads rows ``[index1, index2)`` of an hdf5 dataset into a new numpy array.

    Parameters
    ----------
    dataset
        hdf5 dataset to read from
    index1
        first row to read (defaults to 0)
    index2
        one past the last row to read (defaults to the dataset length)
    """
    start = 0 if index1 is None else index1
    stop = dataset.len() if index2 is None else index2
    n_rows = stop - start
    full_shape = dataset.shape
    if len(full_shape) > 1:
        # keep trailing dimensions, only the first axis is sliced
        out = np.empty([n_rows] + list(full_shape[1:]), dtype=dataset.dtype)
    else:
        out = np.empty(n_rows, dtype=dataset.dtype)
    dataset.read_direct(out, np.s_[start:stop], np.s_[0:n_rows])
    return out


def write_dataset(group, dataset_name, data, index1=None, index2=None):
    """
    Creates, or appends to, a resizable dataset ``dataset_name`` in ``group``.

    On first use the dataset is created from ``data`` with an unlimited first
    axis; on later calls the dataset is grown by ``data.shape[0]`` rows and
    ``data`` is written into the slice ``[index1:index2]``.
    """
    if dataset_name in group:
        existing = group[dataset_name]
        grown_length = existing.shape[0] + data.shape[0]
        if data.ndim > 1:
            existing.resize((grown_length, *data.shape[1:]))
        else:
            existing.resize((grown_length,))
        existing[index1:index2] = data
    else:
        # unlimited first axis so later chunks can be appended
        if data.ndim > 1:
            group.create_dataset(
                dataset_name, data=data, maxshape=(None, *data.shape[1:])
            )
        else:
            group.create_dataset(dataset_name, data=data, maxshape=(None,))


import h5py
import logging

from june.geography import Geography
from june.world import World
from june.groups import Cemeteries
from . import (
    load_geography_from_hdf5,
    load_hospitals_from_hdf5,
    load_schools_from_hdf5,
    load_companies_from_hdf5,
    load_population_from_hdf5,
    load_care_homes_from_hdf5,
    load_households_from_hdf5,
    load_universities_from_hdf5,
    load_stations_from_hdf5,
    load_cities_from_hdf5,
    load_social_venues_from_hdf5,
    save_geography_to_hdf5,
    save_population_to_hdf5,
    save_schools_to_hdf5,
    save_hospitals_to_hdf5,
    save_companies_to_hdf5,
    save_universities_to_hdf5,
    save_cities_to_hdf5,
    save_stations_to_hdf5,
    save_care_homes_to_hdf5,
    save_social_venues_to_hdf5,
    save_households_to_hdf5,
    save_data_for_domain_decomposition,
    restore_population_properties_from_hdf5,
    restore_households_properties_from_hdf5,
    restore_care_homes_properties_from_hdf5,
    restore_cities_and_stations_properties_from_hdf5,
    restore_geography_properties_from_hdf5,
    restore_companies_properties_from_hdf5,
    restore_school_properties_from_hdf5,
    restore_social_venues_properties_from_hdf5,
    restore_universities_properties_from_hdf5,
    restore_hospital_properties_from_hdf5,
)
from june.mpi_setup import mpi_rank

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from june.domains import Domain

# Module-level logger for the world save/load progress messages.
logger = logging.getLogger("world_saver")
# On non-root MPI ranks, stop records from propagating to ancestor loggers —
# presumably so progress messages are not duplicated once per rank.
if mpi_rank > 0:
    logger.propagate = False


def save_world_to_hdf5(world: World, file_path: str, chunk_size=100000):
    """
    Saves the world to an hdf5 file. All supergroups and geography
    are stored as groups. Class instances are substituted by ids of the
    instances. To load the world back, one needs to call the
    generate_world_from_hdf5 function.

    Parameters
    ----------
    file_path
        path of the hdf5 file
    chunk_size
        how many units of supergroups to process at a time.
        It is advise to keep it around 1e5
    """
    logger.info("saving world to HDF5")
    # truncate any pre-existing file before appending the groups
    with h5py.File(file_path, "w"):
        pass
    geo = Geography(world.areas, world.super_areas, world.regions)
    save_geography_to_hdf5(geo, file_path)
    logger.info("saving population...")
    save_population_to_hdf5(world.people, file_path, chunk_size)

    def _has_members(supergroup):
        # only save supergroups that exist and are non-empty
        return supergroup is not None and len(supergroup) > 0

    # (attribute, log label, saver, whether the saver takes chunk_size)
    savers = (
        ("hospitals", "hospitals", save_hospitals_to_hdf5, True),
        ("schools", "schools", save_schools_to_hdf5, True),
        ("companies", "companies", save_companies_to_hdf5, True),
        ("households", "households", save_households_to_hdf5, True),
        ("care_homes", "care homes", save_care_homes_to_hdf5, True),
        ("cities", "cities", save_cities_to_hdf5, False),
        ("stations", "stations", save_stations_to_hdf5, False),
        ("universities", "universities", save_universities_to_hdf5, False),
    )
    for attribute, label, saver, chunked in savers:
        supergroup = getattr(world, attribute)
        if _has_members(supergroup):
            logger.info(f"saving {label}...")
            if chunked:
                saver(supergroup, file_path, chunk_size)
            else:
                saver(supergroup, file_path)

    # TODO: generalise
    social_venues_list = [
        getattr(world, spec)
        for spec in ("pubs", "groceries", "cinemas", "gyms")
        if hasattr(world, spec) and getattr(world, spec) is not None
    ]
    if social_venues_list:
        logger.info("saving social venues...")
        save_social_venues_to_hdf5(social_venues_list, file_path)
    logger.info("Saving domain decomposition data...")
    save_data_for_domain_decomposition(world, file_path)


def generate_world_from_hdf5(
    file_path: str, chunk_size=500000, interaction_config=None
) -> World:
    """
    Loads the world from an hdf5 file. All id references are substituted
    by actual references to the relevant instances.

    Parameters
    ----------
    file_path
        path of the hdf5 file
    chunk_size
        how many units of supergroups to process at a time.
        It is advised to keep it around 1e6
    interaction_config
        optional interaction config file forwarded to the group loaders
    """
    logger.info("loading world from HDF5")
    world = World()
    # snapshot the top-level group names so the file handle can be closed
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        f_keys = list(f.keys())
    geography = load_geography_from_hdf5(file_path=file_path, chunk_size=chunk_size)
    world.areas = geography.areas
    world.super_areas = geography.super_areas
    world.regions = geography.regions
    if "hospitals" in f_keys:
        logger.info("loading hospitals...")
        world.hospitals = load_hospitals_from_hdf5(
            file_path=file_path,
            chunk_size=chunk_size,
            config_filename=interaction_config,
        )
    if "schools" in f_keys:
        logger.info("loading schools...")
        world.schools = load_schools_from_hdf5(
            file_path=file_path,
            chunk_size=chunk_size,
            config_filename=interaction_config,
        )
    if "companies" in f_keys:
        world.companies = load_companies_from_hdf5(
            file_path=file_path,
            chunk_size=chunk_size,
            config_filename=interaction_config,
        )
    if "care_homes" in f_keys:
        logger.info("loading care homes...")
        world.care_homes = load_care_homes_from_hdf5(
            file_path=file_path,
            chunk_size=chunk_size,
            config_filename=interaction_config,
        )
    if "universities" in f_keys:
        logger.info("loading universities...")
        world.universities = load_universities_from_hdf5(
            file_path=file_path,
            chunk_size=chunk_size,
            config_filename=interaction_config,
        )
    if "cities" in f_keys:
        logger.info("loading cities...")
        world.cities = load_cities_from_hdf5(file_path)
    if "stations" in f_keys:
        logger.info("loading stations...")
        (
            world.stations,
            world.inter_city_transports,
            world.city_transports,
        ) = load_stations_from_hdf5(file_path, config_filename=interaction_config)
    if "households" in f_keys:
        world.households = load_households_from_hdf5(
            file_path, chunk_size=chunk_size, config_filename=interaction_config
        )
    if "population" in f_keys:
        world.people = load_population_from_hdf5(file_path, chunk_size=chunk_size)
    if "social_venues" in f_keys:
        logger.info("loading social venues...")
        social_venues_dict = load_social_venues_from_hdf5(
            file_path, config_filename=interaction_config
        )
        for social_venues_spec, social_venues in social_venues_dict.items():
            setattr(world, social_venues_spec, social_venues)

    # restore world: replace stored ids with references between the instances
    logger.info("restoring world...")
    restore_geography_properties_from_hdf5(
        world=world, file_path=file_path, chunk_size=chunk_size
    )
    if "population" in f_keys:
        restore_population_properties_from_hdf5(
            world=world, file_path=file_path, chunk_size=chunk_size
        )
    if "households" in f_keys:
        restore_households_properties_from_hdf5(
            world=world, file_path=file_path, chunk_size=chunk_size
        )
    if "care_homes" in f_keys:
        logger.info("restoring care homes...")
        restore_care_homes_properties_from_hdf5(
            world=world, file_path=file_path, chunk_size=chunk_size
        )
    if "hospitals" in f_keys:
        logger.info("restoring hospitals...")
        restore_hospital_properties_from_hdf5(
            world=world, file_path=file_path, chunk_size=chunk_size
        )
    # bug fix: the old condition (`"cities" and "stations" in f_keys`) only
    # tested for stations because the literal "cities" is always truthy;
    # both groups are needed to restore the commute.
    if "cities" in f_keys and "stations" in f_keys:
        logger.info("restoring commute...")
        restore_cities_and_stations_properties_from_hdf5(
            world=world, file_path=file_path, chunk_size=chunk_size
        )
    if "companies" in f_keys:
        logger.info("restoring companies...")
        restore_companies_properties_from_hdf5(
            world=world, file_path=file_path, chunk_size=chunk_size
        )
    if "schools" in f_keys:
        logger.info("restoring schools...")
        restore_school_properties_from_hdf5(
            world=world, file_path=file_path, chunk_size=chunk_size
        )
    if "universities" in f_keys:
        logger.info("restoring unis...")
        restore_universities_properties_from_hdf5(world=world, file_path=file_path)

    if "social_venues" in f_keys:
        logger.info("restoring social venues...")
        restore_social_venues_properties_from_hdf5(world=world, file_path=file_path)
    world.cemeteries = Cemeteries()
    return world


def generate_domain_from_hdf5(
    domain_id,
    super_areas_to_domain_dict: dict,
    file_path: str,
    chunk_size=500000,
    interaction_config=None,
) -> "Domain":
    """
    Loads one domain of the world from an hdf5 file. Only geography, people,
    and groups that belong to the super areas assigned to ``domain_id`` are
    loaded; all id references are substituted by actual references to the
    relevant instances.

    Parameters
    ----------
    domain_id
        id of the domain to load
    super_areas_to_domain_dict
        mapping from super area id to the domain id that owns it
    file_path
        path of the hdf5 file
    chunk_size
        how many units of supergroups to process at a time.
        It is advised to keep it around 1e6
    interaction_config
        optional interaction config file forwarded to the group loaders
    """
    logger.info(f"loading domain {domain_id} from HDF5")
    # import here to avoid recursive imports
    from june.domains import Domain

    # get the super area ids of this domain
    super_area_ids = {
        super_area
        for super_area, did in super_areas_to_domain_dict.items()
        if did == domain_id
    }
    domain = Domain()
    # snapshot the top-level group names so the file handle can be closed
    with h5py.File(file_path, "r", libver="latest", swmr=True) as f:
        f_keys = list(f.keys())
    geography = load_geography_from_hdf5(
        file_path=file_path, chunk_size=chunk_size, domain_super_areas=super_area_ids
    )
    domain.areas = geography.areas
    area_ids = set([area.id for area in domain.areas])
    domain.super_areas = geography.super_areas
    domain.regions = geography.regions

    # load world data
    if "hospitals" in f_keys:
        logger.info("loading hospitals...")
        domain.hospitals = load_hospitals_from_hdf5(
            file_path=file_path,
            chunk_size=chunk_size,
            domain_super_areas=super_area_ids,
            super_areas_to_domain_dict=super_areas_to_domain_dict,
            config_filename=interaction_config,
        )
    if "schools" in f_keys:
        logger.info("loading schools...")
        domain.schools = load_schools_from_hdf5(
            file_path=file_path,
            chunk_size=chunk_size,
            domain_super_areas=super_area_ids,
            config_filename=interaction_config,
        )
    if "companies" in f_keys:
        domain.companies = load_companies_from_hdf5(
            file_path=file_path,
            chunk_size=chunk_size,
            domain_super_areas=super_area_ids,
            config_filename=interaction_config,
        )
    if "care_homes" in f_keys:
        logger.info("loading care homes...")
        domain.care_homes = load_care_homes_from_hdf5(
            file_path=file_path,
            chunk_size=chunk_size,
            domain_super_areas=super_area_ids,
            config_filename=interaction_config,
        )
    if "universities" in f_keys:
        logger.info("loading universities...")
        domain.universities = load_universities_from_hdf5(
            file_path=file_path,
            chunk_size=chunk_size,
            domain_areas=area_ids,
            config_filename=interaction_config,
        )
    if "cities" in f_keys:
        logger.info("loading cities...")
        domain.cities = load_cities_from_hdf5(
            file_path=file_path,
            domain_super_areas=super_area_ids,
            super_areas_to_domain_dict=super_areas_to_domain_dict,
        )
    if "stations" in f_keys:
        logger.info("loading stations...")
        (
            domain.stations,
            domain.inter_city_transports,
            domain.city_transports,
        ) = load_stations_from_hdf5(
            file_path,
            domain_super_areas=super_area_ids,
            super_areas_to_domain_dict=super_areas_to_domain_dict,
            config_filename=interaction_config,
        )
    if "households" in f_keys:
        domain.households = load_households_from_hdf5(
            file_path,
            chunk_size=chunk_size,
            domain_super_areas=super_area_ids,
            config_filename=interaction_config,
        )
    if "population" in f_keys:
        domain.people = load_population_from_hdf5(
            file_path, chunk_size=chunk_size, domain_super_areas=super_area_ids
        )
    if "social_venues" in f_keys:
        logger.info("loading social venues...")
        social_venues_dict = load_social_venues_from_hdf5(
            file_path, domain_areas=area_ids, config_filename=interaction_config
        )
        for social_venues_spec, social_venues in social_venues_dict.items():
            setattr(domain, social_venues_spec, social_venues)

    # restore world: replace stored ids with references between the instances
    logger.info("restoring world...")
    restore_geography_properties_from_hdf5(
        world=domain,
        file_path=file_path,
        chunk_size=chunk_size,
        domain_super_areas=super_area_ids,
        super_areas_to_domain_dict=super_areas_to_domain_dict,
    )
    if "population" in f_keys:
        restore_population_properties_from_hdf5(
            world=domain,
            file_path=file_path,
            chunk_size=chunk_size,
            domain_super_areas=super_area_ids,
            super_areas_to_domain_dict=super_areas_to_domain_dict,
        )
    if "households" in f_keys:
        restore_households_properties_from_hdf5(
            world=domain,
            file_path=file_path,
            chunk_size=chunk_size,
            domain_super_areas=super_area_ids,
            super_areas_to_domain_dict=super_areas_to_domain_dict,
        )
    if "care_homes" in f_keys:
        logger.info("restoring care homes...")
        restore_care_homes_properties_from_hdf5(
            world=domain,
            file_path=file_path,
            chunk_size=chunk_size,
            domain_super_areas=super_area_ids,
        )
    if "hospitals" in f_keys:
        logger.info("restoring hospitals...")
        restore_hospital_properties_from_hdf5(
            world=domain,
            file_path=file_path,
            chunk_size=chunk_size,
            domain_super_areas=super_area_ids,
            domain_areas=area_ids,
            super_areas_to_domain_dict=super_areas_to_domain_dict,
        )
    if "companies" in f_keys:
        logger.info("restoring companies...")
        restore_companies_properties_from_hdf5(
            world=domain,
            file_path=file_path,
            chunk_size=chunk_size,
            domain_super_areas=super_area_ids,
        )
    if "schools" in f_keys:
        logger.info("restoring schools...")
        restore_school_properties_from_hdf5(
            world=domain,
            file_path=file_path,
            chunk_size=chunk_size,
            domain_super_areas=super_area_ids,
        )
    if "universities" in f_keys:
        logger.info("restoring unis...")
        restore_universities_properties_from_hdf5(
            world=domain, file_path=file_path, domain_areas=area_ids
        )

    # bug fix: the old condition (`"cities" and "stations" in f_keys`) only
    # tested for stations because the literal "cities" is always truthy;
    # both groups are needed to restore the commute.
    if "cities" in f_keys and "stations" in f_keys:
        logger.info("restoring commute...")
        restore_cities_and_stations_properties_from_hdf5(
            world=domain,
            file_path=file_path,
            chunk_size=chunk_size,
            domain_super_areas=super_area_ids,
            super_areas_to_domain_dict=super_areas_to_domain_dict,
        )

    if "social_venues" in f_keys:
        logger.info("restoring social venues...")
        restore_social_venues_properties_from_hdf5(
            world=domain, file_path=file_path, domain_areas=area_ids
        )
    domain.cemeteries = Cemeteries()
    return domain


from .population_saver import (
    save_population_to_hdf5,
    load_population_from_hdf5,
    restore_population_properties_from_hdf5,
)
from .household_saver import (
    save_households_to_hdf5,
    load_households_from_hdf5,
    restore_households_properties_from_hdf5,
)
from .carehome_saver import (
    save_care_homes_to_hdf5,
    load_care_homes_from_hdf5,
    restore_care_homes_properties_from_hdf5,
)
from .school_saver import (
    save_schools_to_hdf5,
    load_schools_from_hdf5,
    restore_school_properties_from_hdf5,
)
from .company_saver import (
    save_companies_to_hdf5,
    load_companies_from_hdf5,
    restore_companies_properties_from_hdf5,
)
from .geography_saver import (
    save_geography_to_hdf5,
    load_geography_from_hdf5,
    restore_geography_properties_from_hdf5,
)
from .hospital_saver import (
    save_hospitals_to_hdf5,
    load_hospitals_from_hdf5,
    restore_hospital_properties_from_hdf5,
)
from .commute_saver import (
    save_cities_to_hdf5,
    save_stations_to_hdf5,
    load_cities_from_hdf5,
    load_stations_from_hdf5,
    restore_cities_and_stations_properties_from_hdf5,
)
from .university_saver import (
    save_universities_to_hdf5,
    load_universities_from_hdf5,
    restore_universities_properties_from_hdf5,
)
from .leisure_saver import (
    save_social_venues_to_hdf5,
    load_social_venues_from_hdf5,
    restore_social_venues_properties_from_hdf5,
)
from .domain_data_saver import (
    save_data_for_domain_decomposition,
    load_data_for_domain_decomposition,
)

from .infection_savers import *  # noqa

# important this needs to be last:
from .world_saver import (
    generate_world_from_hdf5,
    save_world_to_hdf5,
    generate_domain_from_hdf5,
)


import numpy as np
import h5py
from typing import List

from june.hdf5_savers.utils import read_dataset
from june.epidemiology.infection import Immunity

int_vlen_type = h5py.vlen_dtype(np.dtype("int64"))
float_vlen_type = h5py.vlen_dtype(np.dtype("float64"))

nan_integer = -999
nan_float = -999.0


def save_immunities_to_hdf5(hdf5_file_path: str, immunities: List[Immunity]):
    """
    Saves immunity data to hdf5.

    For each Immunity object, the infection ids and susceptibilities of its
    susceptibility dict are stored; an empty dict is stored as a single
    sentinel pair (nan_integer, nan_float).

    Parameters
    ----------
    hdf5_file_path
        hdf5 path to save immunities
    immunities
        list of Immunity objects
    """
    with h5py.File(hdf5_file_path, "a") as f:
        g = f.create_group("immunities")
        n_immunities = len(immunities)
        g.attrs["n_immunities"] = n_immunities
        if n_immunities == 0:
            return
        susc_infection_ids = []
        susc_susceptibilities = []
        lengths = []
        for immunity in immunities:
            infection_ids = list(immunity.susceptibility_dict.keys())
            susceptibilities = list(immunity.susceptibility_dict.values())
            if not infection_ids:
                # sentinel pair so the row is never empty
                infection_ids = [nan_integer]
                susceptibilities = [nan_float]
            susc_infection_ids.append(np.array(infection_ids, dtype=np.int64))
            susc_susceptibilities.append(np.array(susceptibilities, dtype=np.float64))
            lengths.append(len(susceptibilities))
        # ragged rows need hdf5 variable-length dtypes
        if len(set(lengths)) > 1:
            ids_data = np.array(susc_infection_ids, dtype=int_vlen_type)
            suscs_data = np.array(susc_susceptibilities, dtype=float_vlen_type)
        else:
            ids_data = np.array(susc_infection_ids, dtype=np.int64)
            suscs_data = np.array(susc_susceptibilities, dtype=np.float64)
        g.create_dataset("susc_infection_ids", data=ids_data)
        g.create_dataset("susc_susceptibilities", data=suscs_data)


def load_immunities_from_hdf5(hdf5_file_path: str, chunk_size=50000):
    """
    Loads immunities data from hdf5.

    Parameters
    ----------
    hdf5_file_path
        hdf5 path to load from
    chunk_size
        number of hdf5 chunks to use while loading
    """
    immunities = []
    with h5py.File(hdf5_file_path, "r") as f:
        g = f["immunities"]
        n_immunities = g.attrs["n_immunities"]
        if n_immunities == 0:
            return []
        n_chunks = int(np.ceil(n_immunities / chunk_size))
        for chunk_index in range(n_chunks):
            start = chunk_index * chunk_size
            stop = min(start + chunk_size, n_immunities)
            ids_chunk = read_dataset(g["susc_infection_ids"], start, stop)
            suscs_chunk = read_dataset(g["susc_susceptibilities"], start, stop)
            for infection_ids, susceptibilities in zip(ids_chunk, suscs_chunk):
                # a leading sentinel id marks an empty susceptibility dict
                if infection_ids[0] == nan_integer:
                    immunities.append(Immunity())
                else:
                    immunities.append(
                        Immunity(dict(zip(infection_ids, susceptibilities)))
                    )
    return immunities


import numpy as np
import h5py
from collections import defaultdict
from typing import List

from june.hdf5_savers.utils import read_dataset, write_dataset
from june.epidemiology.infection import infection as infection_module
from june.epidemiology.infection import Infection
from .symptoms_saver import save_symptoms_to_hdf5, load_symptoms_from_hdf5
from .transmission_saver import save_transmissions_to_hdf5, load_transmissions_from_hdf5

int_vlen_type = h5py.vlen_dtype(np.dtype("int64"))
float_vlen_type = h5py.vlen_dtype(np.dtype("float64"))


def save_infection_classes_to_hdf5(
    hdf5_file_path: str, infections: List[Infection], chunk_size: int = 50000
):
    """
    Appends the concrete class name of each infection to the ``infections``
    group of the hdf5 file, in chunks of ``chunk_size``.
    """
    n_infections = len(infections)
    n_chunks = int(np.ceil(n_infections / chunk_size))
    with h5py.File(hdf5_file_path, "a") as f:
        infections_group = f["infections"]
        for chunk_index in range(n_chunks):
            start = chunk_index * chunk_size
            stop = min(start + chunk_size, n_infections)
            # class names are stored as fixed-width ascii strings
            class_names = [
                infection.__class__.__name__.encode("ascii", "ignore")
                for infection in infections[start:stop]
            ]
            write_dataset(
                group=infections_group,
                dataset_name="infection_class",
                data=np.array(class_names, dtype="S20"),
                index1=start,
                index2=stop,
            )


def save_infections_to_hdf5(
    hdf5_file_path: str, infections: List[Infection], chunk_size: int = 50000
):
    """
    Saves infections data to hdf5.

    Parameters
    ----------
    hdf5_file_path
        hdf5 path to save infections
    infections
        list of Infection objects
    chunk_size
        number of hdf5 chunks to use while saving
    """
    with h5py.File(hdf5_file_path, "a") as f:
        f.create_group("infections")
        n_infections = len(infections)
        f["infections"].attrs["n_infections"] = n_infections
        if n_infections == 0:
            return
        symptoms_list = [infection.symptoms for infection in infections]
        transmission_list = [infection.transmission for infection in infections]
        # NOTE(review): these helpers re-open the same file in append mode
        # while it is still open here — presumably relies on h5py permitting
        # multiple handles with identical flags; confirm.
        save_symptoms_to_hdf5(
            symptoms_list=symptoms_list,
            hdf5_file_path=hdf5_file_path,
            chunk_size=chunk_size,
        )
        save_transmissions_to_hdf5(
            transmissions=transmission_list,
            hdf5_file_path=hdf5_file_path,
            chunk_size=chunk_size,
        )
        # scalar infection attributes stored as float columns (None -> nan)
        attributes_to_save = ["start_time"]
        n_chunks = int(np.ceil(n_infections / chunk_size))
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_infections)
            attribute_dict = defaultdict(list)
            for index in range(idx1, idx2):
                infection = infections[index]
                for attribute_name in attributes_to_save:
                    attribute = getattr(infection, attribute_name)
                    if attribute is None:
                        attribute_dict[attribute_name].append(np.nan)
                    else:
                        attribute_dict[attribute_name].append(attribute)
            for attribute_name in attributes_to_save:
                data = np.array(attribute_dict[attribute_name], dtype=np.float64)
                write_dataset(
                    group=f["infections"],
                    dataset_name=attribute_name,
                    data=data,
                    index1=idx1,
                    index2=idx2,
                )
    save_infection_classes_to_hdf5(
        hdf5_file_path=hdf5_file_path, infections=infections, chunk_size=chunk_size
    )


def load_infections_from_hdf5(hdf5_file_path: str, chunk_size=50000):
    """
    Loads infections data from hdf5.

    Parameters
    ----------
    hdf5_file_path
        hdf5 path to load from
    chunk_size
        number of hdf5 chunks to use while loading

    Returns
    -------
    list of reconstructed infection objects
    """
    infections = []
    with h5py.File(hdf5_file_path, "r") as f:
        infections_group = f["infections"]
        n_infections = infections_group.attrs["n_infections"]
        if n_infections == 0:
            return []
        # symptoms and transmissions are stored in their own subgroups and
        # share the same global ordering as the infections themselves
        symptoms_list = load_symptoms_from_hdf5(
            hdf5_file_path=hdf5_file_path, chunk_size=chunk_size
        )
        transmissions = load_transmissions_from_hdf5(
            hdf5_file_path=hdf5_file_path, chunk_size=chunk_size
        )
        trans_symp_index = 0
        n_chunks = int(np.ceil(n_infections / chunk_size))
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_infections)
            attribute_dict = {}
            for attribute_name in infections_group.keys():
                # "symptoms"/"transmissions" are subgroups handled above;
                # "infection_class" is read per infection below
                if attribute_name in ["symptoms", "transmissions", "infection_class"]:
                    continue
                attribute_dict[attribute_name] = read_dataset(
                    infections_group[attribute_name], idx1, idx2
                )
            for index in range(idx2 - idx1):
                infection_class_str = infections_group["infection_class"][
                    trans_symp_index
                ].decode()
                infection_class = getattr(infection_module, infection_class_str)
                infection = infection_class(
                    transmission=transmissions[trans_symp_index],
                    symptoms=symptoms_list[trans_symp_index],
                )
                trans_symp_index += 1
                for attribute_name in attribute_dict:
                    attribute_value = attribute_dict[attribute_name][index]
                    # the saver encodes None as nan; nan never compares equal
                    # to itself, so the previous `== np.nan` check never fired
                    if isinstance(
                        attribute_value, (float, np.floating)
                    ) and np.isnan(attribute_value):
                        attribute_value = None
                    setattr(infection, attribute_name, attribute_value)
                infections.append(infection)
    return infections


import numpy as np
import h5py
from typing import List

from june.hdf5_savers.utils import read_dataset, write_dataset
from june.epidemiology.infection import Symptoms, SymptomTag

# variable-length hdf5 dtypes, used to store ragged symptom trajectories
int_vlen_type = h5py.vlen_dtype(np.dtype("int64"))
float_vlen_type = h5py.vlen_dtype(np.dtype("float64"))


def save_symptoms_to_hdf5(
    hdf5_file_path: str, symptoms_list: List[Symptoms], chunk_size: int = 50000
):
    """
    Saves symptoms data to hdf5, under the "infections/symptoms" group.

    Parameters
    ----------
    hdf5_file_path
        hdf5 path to save symptoms to
    symptoms_list
        list of symptom objects
    chunk_size
        number of hdf5 chunks to use while saving
    """
    with h5py.File(hdf5_file_path, "a") as f:
        if "infections" not in f:
            f.create_group("infections")
        f["infections"].create_group("symptoms")
        symptoms_group = f["infections"]["symptoms"]
        n_symptoms = len(symptoms_list)
        symptoms_group.attrs["n_symptoms"] = n_symptoms
        n_chunks = int(np.ceil(n_symptoms / chunk_size))
        # scalar per-symptom attributes are collected and written chunk by chunk
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_symptoms)
            attribute_dict = {}
            max_tag_list = []
            tag_list = []
            max_severity_list = []
            stage_list = []
            time_of_symptoms_onset_list = []
            for index in range(idx1, idx2):
                symptoms = symptoms_list[index]
                # tags are enum members; persist their integer values
                max_tag_list.append(symptoms.max_tag.value)
                tag_list.append(symptoms.tag.value)
                max_severity_list.append(symptoms.max_severity)
                stage_list.append(symptoms.stage)
                time_of_symptoms_onset_list.append(symptoms.time_of_symptoms_onset)
            attribute_dict["max_tag"] = np.array(max_tag_list, dtype=np.int64)
            attribute_dict["tag"] = np.array(tag_list, dtype=np.int64)
            attribute_dict["max_severity"] = np.array(
                max_severity_list, dtype=np.float64
            )
            attribute_dict["stage"] = np.array(stage_list, dtype=np.int64)
            attribute_dict["time_of_symptoms_onset"] = np.array(
                time_of_symptoms_onset_list, dtype=np.float64
            )
            for attribute_name, attribute_value in attribute_dict.items():
                write_dataset(
                    group=symptoms_group,
                    dataset_name=attribute_name,
                    data=attribute_value,
                    index1=idx1,
                    index2=idx2,
                )
        # trajectories are sequences of (time, symptom tag) pairs whose
        # length may differ between symptoms, so they are written in one go
        trajectory_times_list = []
        trajectory_symptom_list = []
        trajectory_lengths = []
        for symptoms in symptoms_list:
            times = []
            symps = []
            for time, symp in symptoms.trajectory:
                times.append(time)
                symps.append(symp.value)
            trajectory_times_list.append(np.array(times, dtype=np.float64))
            trajectory_symptom_list.append(np.array(symps, dtype=np.int64))
            trajectory_lengths.append(len(times))
        if len(np.unique(trajectory_lengths)) == 1:
            # all trajectories share one length: store regular 2-D arrays
            write_dataset(
                group=symptoms_group,
                dataset_name="trajectory_times",
                data=np.array(trajectory_times_list, dtype=float),
            )
            write_dataset(
                group=symptoms_group,
                dataset_name="trajectory_symptoms",
                data=np.array(trajectory_symptom_list, dtype=int),
            )
        else:
            # ragged trajectories: fall back to hdf5 variable-length rows
            write_dataset(
                group=symptoms_group,
                dataset_name="trajectory_times",
                data=np.array(trajectory_times_list, dtype=float_vlen_type),
            )
            write_dataset(
                group=symptoms_group,
                dataset_name="trajectory_symptoms",
                data=np.array(trajectory_symptom_list, dtype=int_vlen_type),
            )


def load_symptoms_from_hdf5(hdf5_file_path: str, chunk_size=50000):
    """
    Loads symptoms data from hdf5.

    Parameters
    ----------
    hdf5_file_path
        hdf5 path to load from
    chunk_size
        number of hdf5 chunks to use while loading
    """
    symptoms = []
    dataset_names = (
        "max_tag",
        "tag",
        "max_severity",
        "stage",
        "time_of_symptoms_onset",
        "trajectory_times",
        "trajectory_symptoms",
    )
    with h5py.File(hdf5_file_path, "r") as f:
        group = f["infections"]["symptoms"]
        n_symptoms = group.attrs["n_symptoms"]
        # walk the datasets chunk by chunk to bound memory use
        for idx1 in range(0, n_symptoms, chunk_size):
            idx2 = min(idx1 + chunk_size, n_symptoms)
            fields = {
                name: read_dataset(group[name], idx1, idx2)
                for name in dataset_names
            }
            for i in range(idx2 - idx1):
                symptom = Symptoms()
                # integer codes are mapped back to SymptomTag enum members
                symptom.tag = SymptomTag(fields["tag"][i])
                symptom.max_tag = SymptomTag(fields["max_tag"][i])
                symptom.stage = fields["stage"][i]
                symptom.max_severity = fields["max_severity"][i]
                symptom.time_of_symptoms_onset = fields["time_of_symptoms_onset"][i]
                # rebuild the trajectory as a tuple of (time, tag) pairs
                symptom.trajectory = tuple(
                    (time, SymptomTag(code))
                    for time, code in zip(
                        fields["trajectory_times"][i],
                        fields["trajectory_symptoms"][i],
                    )
                )
                symptoms.append(symptom)
    return symptoms


import numpy as np
import h5py
from collections import defaultdict
from typing import List

from june.epidemiology.infection import (
    TransmissionGamma,
    Transmission,
    TransmissionConstant,
    TransmissionXNExp,
)
from june.hdf5_savers.utils import read_dataset, write_dataset

# maps the class-name string stored in hdf5 back to the transmission class
str_to_class = {
    "TransmissionXNExp": TransmissionXNExp,
    "TransmissionGamma": TransmissionGamma,
    "TransmissionConstant": TransmissionConstant,
}
# attributes persisted to hdf5 for each transmission class
attributes_to_save_dict = {
    "TransmissionXNExp": ["time_first_infectious", "norm_time", "n", "norm", "alpha"],
    "TransmissionGamma": ["shape", "shift", "scale", "norm"],
    "TransmissionConstant": ["probability"],
}


def save_transmissions_to_hdf5(
    hdf5_file_path: str, transmissions: List[Transmission], chunk_size: int = 50000
):
    """
    Saves transmissions data to hdf5. The transmission type is inferred from the first
    element of the list, so all elements are assumed to share one class.

    Parameters
    ----------
    hdf5_file_path
        hdf5 path to save transmissions to
    transmissions
        list of transmission objects
    chunk_size
        number of hdf5 chunks to use while saving
    """
    with h5py.File(hdf5_file_path, "a") as f:
        if "infections" not in f:
            f.create_group("infections")
        f["infections"].create_group("transmissions")
        transmissions_group = f["infections"]["transmissions"]
        n_transmissions = len(transmissions)
        # NOTE: the misspelled attribute key is kept as-is for compatibility
        # with files written (and read) by existing code
        transmissions_group.attrs["n_transsmissions"] = n_transmissions
        if n_transmissions == 0:
            # nothing to save; previously this crashed on transmissions[0]
            return
        transmission_type = transmissions[0].__class__.__name__
        transmissions_group.attrs["transmission_type"] = transmission_type
        n_chunks = int(np.ceil(n_transmissions / chunk_size))
        attributes_to_save = attributes_to_save_dict[transmission_type]
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_transmissions)
            attribute_dict = defaultdict(list)
            for index in range(idx1, idx2):
                transmission = transmissions[index]
                for attribute_name in attributes_to_save:
                    attribute = getattr(transmission, attribute_name)
                    # encode None as nan so datasets stay homogeneous float64
                    attribute_dict[attribute_name].append(
                        np.nan if attribute is None else attribute
                    )
            for attribute_name in attributes_to_save:
                data = np.array(attribute_dict[attribute_name], dtype=np.float64)
                write_dataset(
                    group=transmissions_group,
                    dataset_name=attribute_name,
                    data=data,
                    index1=idx1,
                    index2=idx2,
                )


def load_transmissions_from_hdf5(hdf5_file_path: str, chunk_size=50000):
    """
    Loads transmissions data from hdf5.

    Parameters
    ----------
    hdf5_file_path
        hdf5 path to load from
    chunk_size
        number of hdf5 chunks to use while loading

    Returns
    -------
    list of reconstructed transmission objects
    """
    transmissions = []
    with h5py.File(hdf5_file_path, "r") as f:
        transmissions_group = f["infections"]["transmissions"]
        # the misspelled key matches what the saver writes
        n_transmissions = transmissions_group.attrs["n_transsmissions"]
        if n_transmissions == 0:
            # empty file section: no transmission_type attribute to read
            return []
        transmission_type = transmissions_group.attrs["transmission_type"]
        transmission_class = str_to_class[transmission_type]
        n_chunks = int(np.ceil(n_transmissions / chunk_size))
        for chunk in range(n_chunks):
            idx1 = chunk * chunk_size
            idx2 = min((chunk + 1) * chunk_size, n_transmissions)
            attribute_dict = {
                attribute_name: read_dataset(
                    transmissions_group[attribute_name], idx1, idx2
                )
                for attribute_name in transmissions_group.keys()
            }
            for index in range(idx2 - idx1):
                transmission = transmission_class()
                for attribute_name, values in attribute_dict.items():
                    attribute_value = values[index]
                    # the saver encodes None as nan; nan never compares equal
                    # to itself, so the previous `== np.nan` check never fired
                    if isinstance(
                        attribute_value, (float, np.floating)
                    ) and np.isnan(attribute_value):
                        attribute_value = None
                    setattr(transmission, attribute_name, attribute_value)
                transmissions.append(transmission)
    return transmissions


from .transmission_saver import save_transmissions_to_hdf5, load_transmissions_from_hdf5
from .symptoms_saver import save_symptoms_to_hdf5, load_symptoms_from_hdf5
from .infection_saver import save_infections_to_hdf5, load_infections_from_hdf5
from .immunity_saver import save_immunities_to_hdf5, load_immunities_from_hdf5


import numpy as np
import yaml
from random import random
from typing import List, Dict

from june.groups.group.interactive import InteractiveGroup
from june.groups import InteractiveSchool
from june.records import Record
from june import paths

# default interaction configuration shipped with june
default_config_filename = paths.configs_path / "defaults/interaction/interaction.yaml"

# default per-sector beta configuration (not referenced in this module's
# visible code; presumably consumed elsewhere — TODO confirm)
default_sector_beta_filename = (
    paths.configs_path / "defaults/interaction/sector_beta.yaml"
)


class Interaction:
    """
    Class to handle interaction in groups.

    Parameters
    ----------
    alpha_physical
        Scaling factor for physical contacts, an alpha_physical factor of 1, means that physical
        contacts count as much as non-physical contacts.
    betas
        dictionary mapping the group specs with their contact intensities
    contact_matrices
        dictionary mapping the group specs with their contact matrices
    """

    def __init__(
        self, alpha_physical: float, betas: Dict[str, float], contact_matrices: dict
    ):
        self.alpha_physical = alpha_physical
        # tolerate None so the class can be built from partial configurations
        self.betas = betas or {}
        contact_matrices = contact_matrices or {}
        self.contact_matrices = self.get_raw_contact_matrices(
            input_contact_matrices=contact_matrices,
            groups=self.betas.keys(),
            alpha_physical=alpha_physical,
        )
        # group spec -> multiplicative beta reduction, set externally
        self.beta_reductions = {}

    @classmethod
    def from_file(cls, config_filename: str = default_config_filename) -> "Interaction":
        """
        Builds an interaction from a yaml config file containing
        ``alpha_physical``, ``betas`` and ``contact_matrices``.
        """
        with open(config_filename) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        contact_matrices = config["contact_matrices"]
        # use cls (not Interaction) so subclasses construct themselves
        return cls(
            alpha_physical=config["alpha_physical"],
            betas=config["betas"],
            contact_matrices=contact_matrices,
        )

    def get_raw_contact_matrices(
        self, groups: List[str], input_contact_matrices: dict, alpha_physical: float
    ):
        """
        Processes the input data regarding to contacts to construct the contact matrix used in the interaction.
        In particular, given a contact matrix, a matrix of physical contact ratios, and the physical contact weighting
        (alpha_physical) constructs the contact matrix via:
        $ contact_matrix = contact_matrix * (1 + (alpha_physical - 1) * physical_ratios) $

        Parameters
        ----------
        groups
            a list of group names that will be handled by the interaction
        input_contact_matrices
            configuration regarding contact matrices and physical contacts
        alpha_physical
            The relative weight of physical contacts with respect to non-physical ones.
        """
        contact_matrices = {}
        for group in groups:
            contact_data = input_contact_matrices.get(group, {})
            # groups without configuration default to a single subgroup with
            # one contact and no physical contacts
            contact_matrix = np.array(contact_data.get("contacts", [[1]]))
            proportion_physical = np.array(
                contact_data.get("proportion_physical", [[0]])
            )
            characteristic_time = contact_data.get("characteristic_time", 8)
            # school is a special case with its own raw-matrix computation
            if group == "school":
                contact_matrix = InteractiveSchool.get_raw_contact_matrix(
                    contact_matrix=contact_matrix,
                    proportion_physical=proportion_physical,
                    alpha_physical=alpha_physical,
                    characteristic_time=characteristic_time,
                )
            else:
                contact_matrix = InteractiveGroup.get_raw_contact_matrix(
                    contact_matrix=contact_matrix,
                    proportion_physical=proportion_physical,
                    alpha_physical=alpha_physical,
                    characteristic_time=characteristic_time,
                )
            contact_matrices[group] = contact_matrix
        return contact_matrices

    def _get_interactive_group_beta(self, interactive_group):
        # the interactive group combines the base betas with any reductions
        return interactive_group.get_processed_beta(
            betas=self.betas, beta_reductions=self.beta_reductions
        )

    def create_infector_tensor(
        self,
        infectors_per_infection_per_subgroup,
        subgroup_sizes,
        contact_matrix,
        beta,
        delta_time,
    ):
        """
        Builds, for each infection id, a matrix whose (i, j) entry combines the
        contact rate of subgroup i with subgroup j, the summed transmission
        probability of subgroup j's infectors, beta and the time step length.
        """
        ret = {}
        for inf_id in infectors_per_infection_per_subgroup:
            infector_matrix = np.zeros_like(contact_matrix, dtype=np.float64)
            for subgroup_id in infectors_per_infection_per_subgroup[inf_id]:
                subgroup_trans_prob = sum(
                    infectors_per_infection_per_subgroup[inf_id][subgroup_id][
                        "trans_probs"
                    ]
                )
                for i in range(len(contact_matrix)):
                    subgroup_size = subgroup_sizes[subgroup_id]
                    if i == subgroup_id:
                        # a person does not make contact with themselves
                        subgroup_size = max(1, subgroup_size - 1)
                    infector_matrix[i, subgroup_id] = (
                        contact_matrix[i, subgroup_id]
                        * subgroup_trans_prob
                        / subgroup_size
                    )
            ret[inf_id] = infector_matrix * beta * delta_time
        return ret

    def time_step_for_group(
        self,
        group: InteractiveGroup,
        delta_time: float,
        people_from_abroad: dict = None,
        record: Record = None,
    ):
        """
        Runs an interaction time step for the given interactive group. First, we
        give the beta and contact matrix to the group to process it. There may be groups
        that change the betas depending on the situation, ie, a school interactive group,
        has to treat the contact matrix on a special way, or the company beta may change
        due to the company's sector. Second, we iterate over all subgroups that contain
        susceptible people, and compute the interaction between them and the subgroups that
        contain infected people.

        Parameters
        ----------
        group:
            An instance of InteractiveGroup
        delta_time:
            Time interval of the interaction
        people_from_abroad:
            forwarded to ``group.get_interactive_group``
        record:
            if given, new infections are logged to it

        Returns
        -------
        (infected_ids, infection_ids, group_size)
        """
        interactive_group = group.get_interactive_group(
            people_from_abroad=people_from_abroad
        )
        # the interactive group decides whether this step needs computing
        if not interactive_group.must_timestep:
            return [], [], interactive_group.size
        infected_ids = []
        infection_ids = []
        to_blame_subgroups = []
        beta = self._get_interactive_group_beta(interactive_group)
        contact_matrix_raw = self.contact_matrices[group.spec]
        contact_matrix = interactive_group.get_processed_contact_matrix(
            contact_matrix_raw
        )
        infector_tensor = self.create_infector_tensor(
            interactive_group.infectors_per_infection_per_subgroup,
            interactive_group.subgroup_sizes,
            contact_matrix,
            beta,
            delta_time,
        )

        for (
            susceptible_subgroup_id,
            subgroup_susceptibles,
        ) in interactive_group.susceptibles_per_subgroup.items():
            (
                new_infected_ids,
                new_infection_ids,
                new_to_blame_subgroups,
            ) = self._time_step_for_subgroup(
                infector_tensor=infector_tensor,
                susceptible_subgroup_id=susceptible_subgroup_id,
                subgroup_susceptibles=subgroup_susceptibles,
            )
            infected_ids += new_infected_ids
            infection_ids += new_infection_ids
            to_blame_subgroups += new_to_blame_subgroups
        to_blame_ids = self._blame_individuals(
            to_blame_subgroups,
            infection_ids,
            interactive_group.infectors_per_infection_per_subgroup,
        )
        if record:
            self._log_infections_to_record(
                infected_ids=infected_ids,
                infection_ids=infection_ids,
                to_blame_ids=to_blame_ids,
                record=record,
                group=group,
            )
        return infected_ids, infection_ids, interactive_group.size

    def _time_step_for_subgroup(
        self, infector_tensor, susceptible_subgroup_id, subgroup_susceptibles
    ):
        """
        Time step for one susceptible subgroup. We first compute the combined
        effective transmission probability of all the subgroups that contain infected
        people, and then run this effective transmission over the susceptible subgroup,
        to check who got infected.

        Parameters
        ----------
        infector_tensor
            infection id -> transmission matrix, from ``create_infector_tensor``
        susceptible_subgroup_id
            index of the subgroup whose susceptibles are being checked
        subgroup_susceptibles
            person id -> {infection id: susceptibility} for this subgroup
        """
        new_infected_ids = []
        new_infection_ids = []
        new_to_blame_subgroups = []
        infection_ids = list(infector_tensor.keys())
        for susceptible_id, susceptibility_dict in subgroup_susceptibles.items():
            infection_transmission_parameters = []
            for infection_id in infector_tensor:
                # default susceptibility is 1 if not specified for this infection
                susceptibility = susceptibility_dict.get(infection_id, 1.0)
                infector_transmission = infector_tensor[infection_id][
                    susceptible_subgroup_id
                ].sum()
                infection_transmission_parameters.append(
                    infector_transmission * susceptibility
                )
            infection_id = self._gets_infected(
                np.array(infection_transmission_parameters), infection_ids
            )
            if infection_id is not None:
                new_infected_ids.append(susceptible_id)
                new_infection_ids.append(infection_id)
                new_to_blame_subgroups.append(
                    self._blame_subgroup(
                        infector_tensor[infection_id][susceptible_subgroup_id]
                    )
                )
        return new_infected_ids, new_infection_ids, new_to_blame_subgroups

    def _gets_infected(self, infection_transmission_parameters, infection_ids):
        """
        Draws whether a susceptible person gets infected, returning the id of
        the infecting infection, or None if no infection occurs.
        """
        total_exp = infection_transmission_parameters.sum()
        # infection probability is 1 - exp(-sum of transmission parameters)
        if random() < 1 - np.exp(-total_exp):
            if len(infection_ids) == 1:
                return infection_ids[0]
            # choose an infection proportionally to its transmission parameter
            return np.random.choice(
                infection_ids, p=infection_transmission_parameters / total_exp
            )
        return None

    def _blame_subgroup(self, vector):
        """Samples the subgroup to blame, weighted by transmission pressure."""
        probs = vector / vector.sum()
        return np.random.choice(len(vector), p=probs)

    def _blame_individuals(
        self, to_blame_subgroups, infection_ids, infectors_per_infection_per_subgroup
    ):
        """
        For each new infection, samples the responsible infector from the
        blamed subgroup, weighted by individual transmission probabilities.
        """
        ret = []
        for infection_id, subgroup in zip(infection_ids, to_blame_subgroups):
            candidates_ids = infectors_per_infection_per_subgroup[infection_id][
                subgroup
            ]["ids"]
            candidates_probs = np.array(
                infectors_per_infection_per_subgroup[infection_id][subgroup][
                    "trans_probs"
                ]
            )
            candidates_probs /= candidates_probs.sum()
            ret.append(np.random.choice(candidates_ids, p=candidates_probs))
        return ret

    def _log_infections_to_record(
        self,
        infected_ids: list,
        infection_ids: list,
        to_blame_ids: list,
        group: InteractiveGroup,
        record: Record,
    ):
        """
        Logs new infected people to record, and their infectors.
        """
        record.accumulate(
            table_name="infections",
            location_spec=group.spec,
            location_id=group.id,
            region_name=group.super_area.region.name,
            infected_ids=infected_ids,
            infection_ids=infection_ids,
            infector_ids=to_blame_ids,
        )


from .interaction import Interaction

# from .interactive_group import InteractiveGroup


import numpy as np
from typing import List, Optional, Union
import datetime
from random import random

from june.epidemiology.infection import SymptomTag
from june.demography.person import Person
from june.policy import Policy, PolicyCollection
from june.mpi_setup import mpi_size
from june.utils.distances import haversine_distance


class IndividualPolicy(Policy):
    """
    Base class for policies that act on single people. Instances are tagged
    with ``policy_type = "individual"``; subclasses set ``policy_subtype``.
    """

    def __init__(
        self,
        start_time: Union[str, datetime.datetime],
        end_time: Union[str, datetime.datetime],
    ):
        super().__init__(start_time=start_time, end_time=end_time)
        self.policy_subtype = None
        self.policy_type = "individual"


class IndividualPolicies(PolicyCollection):
    """Collection of individual policies, applied person by person."""

    policy_type = "individual"
    # below this age a person cannot stay home alone and needs a guardian
    min_age_home_alone = 15

    def get_active(self, date: datetime.date):
        """Returns a new collection with only the policies active on ``date``."""
        return IndividualPolicies(
            [policy for policy in self.policies if policy.is_active(date)]
        )

    def apply(
        self,
        active_policies,
        person: Person,
        days_from_start: float,
        activities: List[str],
    ):
        """
        Applies all active individual policies to the person. Stay home policies are applied first,
        since if the person stays home we don't need to check for the others.
        If a person is below 15 years old, then we look for a guardian to stay with that person at home.
        """
        for policy in active_policies:
            if policy.policy_subtype == "stay_home":
                if policy.check_stay_home_condition(person, days_from_start):
                    activities = policy.apply(
                        person=person,
                        days_from_start=days_from_start,
                        activities=activities,
                    )
                    # TODO: make it work with parallelisation
                    if mpi_size == 1:
                        if (
                            person.age < self.min_age_home_alone
                        ):  # can't stay home alone
                            # any adult already at the residence can watch them
                            possible_guardians = [
                                housemate
                                for housemate in person.residence.group.people
                                if housemate.age >= 18
                            ]
                            if not possible_guardians:
                                guardian = person.find_guardian()
                                if guardian is not None:
                                    if guardian.busy:
                                        # pull the guardian out of whatever
                                        # subgroup they currently occupy
                                        for subgroup in guardian.subgroups.iter():
                                            if (
                                                subgroup is not None
                                                and guardian in subgroup
                                            ):
                                                subgroup.remove(guardian)
                                                break
                                    guardian.residence.append(guardian)
                    return activities  # if it stays at home we don't need to check the rest
            elif policy.policy_subtype == "skip_activity":
                if policy.check_skips_activity(person):
                    activities = policy.apply(activities=activities)
            else:
                raise ValueError("policy type not expected")
        return activities


class StayHome(IndividualPolicy):
    """
    Base class for policies that confine a person to their residence.
    Subclasses decide *when* the confinement applies by implementing
    ``check_stay_home_condition``; this class defines what staying home means.
    """

    def __init__(self, start_time="1900-01-01", end_time="2100-01-01"):
        super().__init__(start_time=start_time, end_time=end_time)
        self.policy_subtype = "stay_home"

    def apply(self, person: Person, days_from_start: float, activities: List[str]):
        """
        Strips the activity list down to the residence, keeping a medical
        facility visit if one was scheduled.
        """
        keeps_medical = "medical_facility" in activities
        return ("medical_facility", "residence") if keeps_medical else ("residence",)

    def check_stay_home_condition(self, person: Person, days_from_start: float):
        """
        Returns True if ``person`` must stay at home at ``days_from_start``
        (days since the start of the simulation). Must be overridden.
        """
        raise NotImplementedError(
            f"Need to implement check_stay_home_condition for policy {self.__class__.__name__}"
        )


class SevereSymptomsStayHome(StayHome):
    """Confines people whose current symptom tag is severe."""

    def check_stay_home_condition(self, person: Person, days_from_start: float) -> bool:
        infection = person.infection
        if infection is None:
            return False
        return infection.tag is SymptomTag.severe


class Quarantine(StayHome):
    def __init__(
        self,
        start_time: Union[str, datetime.datetime] = "1900-01-01",
        end_time: Union[str, datetime.datetime] = "2100-01-01",
        n_days: int = 7,
        n_days_household: int = 14,
        compliance: float = 1.0,
        household_compliance: float = 1.0,
        vaccinated_household_compliance: float = 1.0,
    ):
        """
        This policy forces people to stay at home for ``n_days`` days after they show symptoms, and for ``n_days_household`` if someone else in their household shows symptoms

        Parameters
        ----------
        start_time:
            date at which to start applying the policy
        end_time:
            date from which the policy won't apply
        n_days:
            days for which the person has to stay at home if they show symptoms
        n_days_household:
            days for which the person has to stay at home if someone in their household shows symptoms
        compliance:
            percentage of symptomatic people that will adhere to the quarantine policy
        household_compliance:
            percentage of people that will adhere to the household quarantine policy
        vaccinated_household_compliance:
            extra compliance factor applied to the household quarantine of
            vaccinated people (without an ongoing vaccine trajectory) and under 18s
        """
        super().__init__(start_time, end_time)
        self.n_days = n_days
        self.n_days_household = n_days_household
        self.compliance = compliance
        self.household_compliance = household_compliance
        self.vaccinated_household_compliance = vaccinated_household_compliance

    def check_stay_home_condition(self, person: Person, days_from_start: float):
        """
        Returns True if the person must quarantine, either because of their own
        symptoms or because a housemate is quarantining.
        """
        try:
            regional_compliance = person.region.regional_compliance
        except Exception:
            # person may have no region attached — assume full compliance
            regional_compliance = 1
        if person.infected:
            time_of_symptoms_onset = person.infection.time_of_symptoms_onset
            if time_of_symptoms_onset is not None:
                # record to the household that this person is infected:
                person.residence.group.quarantine_starting_date = time_of_symptoms_onset
                if person.symptoms.tag in (SymptomTag.mild, SymptomTag.severe):
                    release_day = time_of_symptoms_onset + self.n_days
                    # quarantine only inside the n_days window after onset
                    if 0 < release_day - days_from_start < self.n_days:
                        if random() < self.compliance * regional_compliance:
                            return True

        # household quarantine: vaccinated people with no ongoing vaccine
        # trajectory, and under 18s, get the extra vaccinated_household_compliance
        # factor multiplied into the household compliance
        if (person.vaccinated and person.vaccine_trajectory is None) or person.age < 18:
            housemates_quarantine = person.residence.group.quarantine(
                time=days_from_start,
                quarantine_days=self.n_days_household,
                household_compliance=self.vaccinated_household_compliance
                * self.household_compliance
                * regional_compliance,
            )

        else:
            housemates_quarantine = person.residence.group.quarantine(
                time=days_from_start,
                quarantine_days=self.n_days_household,
                household_compliance=self.household_compliance * regional_compliance,
            )

        return housemates_quarantine


class SchoolQuarantine(StayHome):
    def __init__(
        self,
        start_time: Union[str, datetime.datetime] = "1900-01-01",
        end_time: Union[str, datetime.datetime] = "2100-01-01",
        compliance: float = 1.0,
        n_days: int = 7,
        isolate_on: str = "symptoms",
    ):
        """
        This policy forces kids to stay at home if there is a symptomatic case of covid in their classroom.

        Parameters
        ----------
        start_time:
            date at which to start applying the policy
        end_time:
            date from which the policy won't apply
        compliance:
            probability that a pupil adheres to the classroom quarantine
        n_days:
            number of days the classroom quarantine lasts
        isolate_on:
            "infection" to start the quarantine clock at infection time;
            anything else (default "symptoms") starts it at symptom onset
        """
        super().__init__(start_time, end_time)
        self.compliance = compliance
        self.n_days = n_days
        self.isolate_on = isolate_on

    def check_stay_home_condition(self, person: Person, days_from_start):
        """
        Return True if this school pupil should quarantine at home today.

        Infected pupils stamp their school's ``quarantine_starting_date``;
        anyone attending that school then stays home while that date is
        less than ``n_days`` old, subject to compliance.
        """
        try:
            # policy only applies to people attending a local (non-external) school
            if (
                not person.primary_activity.group.spec == "school"
                or person.primary_activity.group.external
            ):
                return False
        except Exception:
            # no primary activity at all -> policy does not apply
            return False
        try:
            regional_compliance = person.region.regional_compliance
        except Exception:
            regional_compliance = 1
        compliance = self.compliance * regional_compliance
        if person.infected:
            # infected people set quarantine date to the school.
            # there is no problem in order as this will activate
            # days before it is actually applied (during incubation time).
            if self.isolate_on == "infection":
                time_start_quarantine = person.infection.start_time
            else:
                if person.infection.time_of_symptoms_onset:
                    # symptom onset is relative to infection start, so add them
                    time_start_quarantine = (
                        person.infection.start_time
                        + person.infection.time_of_symptoms_onset
                    )
                else:
                    # asymptomatic (or onset unknown): do not stamp the school
                    time_start_quarantine = None
            if time_start_quarantine is not None:
                if (
                    time_start_quarantine
                    < person.primary_activity.quarantine_starting_date
                ):
                    # If the agent will show symptoms earlier than the quarantine time, update it.
                    person.primary_activity.quarantine_starting_date = (
                        time_start_quarantine
                    )
                if (
                    days_from_start - person.primary_activity.quarantine_starting_date
                ) > self.n_days:
                    # If it's been more than n_days since last quarantine
                    person.primary_activity.quarantine_starting_date = (
                        time_start_quarantine
                    )
        # stay home (with probability ``compliance``) while the school's
        # quarantine window is open
        if (
            0
            < (days_from_start - person.primary_activity.quarantine_starting_date)
            < self.n_days
        ):
            return random() < compliance
        return False


class Shielding(StayHome):
    """
    Keeps people at or above a minimum age at home (shielding),
    optionally with a compliance probability.
    """

    def __init__(
        self,
        start_time: str,
        end_time: str,
        min_age: int,
        compliance: Optional[float] = None,
    ):
        super().__init__(start_time, end_time)
        self.min_age = min_age
        self.compliance = compliance

    def check_stay_home_condition(self, person: Person, days_from_start: float):
        """Return True if the person shields (stays home) today."""
        if person.age < self.min_age:
            return False
        # no compliance configured means everyone above the age threshold shields
        if self.compliance is None:
            return True
        try:
            regional_compliance = person.region.regional_compliance
        except Exception:
            # no regional information available: assume full compliance
            regional_compliance = 1
        return random() < self.compliance * regional_compliance


class SkipActivity(IndividualPolicy):
    """
    Base class for policies that ban one or more activities for a person.
    """

    def __init__(
        self,
        start_time: Union[str, datetime.datetime] = "1900-01-01",
        end_time: Union[str, datetime.datetime] = "2100-01-01",
        activities_to_remove=None,
    ):
        super().__init__(start_time=start_time, end_time=end_time)
        self.activities_to_remove = activities_to_remove
        self.policy_subtype = "skip_activity"

    def check_skips_activity(self, person: "Person") -> bool:
        """
        Returns True if the activity is to be skipped, otherwise False.
        Subclasses implement the actual decision.
        """

    def apply(self, activities: List[str]) -> List[str]:
        """
        Filter the banned activities out of ``activities``.

        Parameters
        ----------
        activities:
            list of activities to filter
        """
        banned = self.activities_to_remove
        return [activity for activity in activities if activity not in banned]


class CloseSchools(SkipActivity):
    """
    Closes schools, either fully or for selected school years, with an
    attendance compliance. Young children of two or more key workers keep
    attending during a lockdown.
    """

    def __init__(
        self,
        start_time: str,
        end_time: str,
        years_to_close=None,
        attending_compliance=1.0,
        full_closure=None,
    ):
        # bug fix: ("primary_activity") without a trailing comma is a plain
        # string, so ``SkipActivity.apply`` performed substring membership
        # tests instead of tuple membership. Use a real 1-tuple.
        super().__init__(
            start_time, end_time, activities_to_remove=("primary_activity",)
        )
        self.full_closure = full_closure
        self.years_to_close = years_to_close
        self.attending_compliance = attending_compliance  # compliance with opening
        if self.years_to_close == "all":
            self.years_to_close = list(np.arange(20))

    def _check_kid_goes_to_school(self, person: "Person"):
        """
        Checks if a kid should go to school when there is a lockdown.
        The rule is that a kid goes to school if the age is below 14 (not included)
        and there are at least two key workers at home.
        """
        if person.age < 14:
            keyworkers_parents = 0
            # bug fix: the loop variable used to shadow the ``person``
            # parameter; use a distinct name for the household members.
            for resident in person.residence.group.residents:
                if resident.lockdown_status == "key_worker":
                    keyworkers_parents += 1
                    if keyworkers_parents > 1:
                        return True
        return False

    def check_skips_activity(self, person: "Person") -> bool:
        """
        Returns True if the activity is to be skipped, otherwise False
        """
        try:
            if person.primary_activity.group.spec == "school":
                if self.full_closure:
                    return True
                elif not self._check_kid_goes_to_school(person):
                    if self.years_to_close and person.age in self.years_to_close:
                        return True
                    else:
                        # school nominally open: attend with ``attending_compliance``
                        if random() > self.attending_compliance:
                            return True
        except AttributeError:
            # person has no primary activity / school information
            return False
        return False


class CloseUniversities(SkipActivity):
    """Closes universities: all university students skip their primary activity."""

    def __init__(self, start_time: str, end_time: str):
        # bug fix: ("primary_activity") without a trailing comma is a plain
        # string, so ``SkipActivity.apply`` performed substring membership
        # tests instead of tuple membership. Use a real 1-tuple.
        super().__init__(
            start_time, end_time, activities_to_remove=("primary_activity",)
        )

    def check_skips_activity(self, person: "Person") -> bool:
        """
        Returns True if the activity is to be skipped, otherwise False
        """
        if (
            person.primary_activity is not None
            and person.primary_activity.group.spec == "university"
        ):
            return True
        return False


class CloseCompaniesLockdownTiers(SkipActivity):
    """
    Stops cross-region commuting to or from regions in lockdown tier 3 or 4,
    subject to regional compliance. Only applies to workers with
    ``lockdown_status == "random"``.
    """

    # lockdown tiers in which commuting restrictions apply
    TIERS = set([3, 4])

    def __init__(self, start_time: str, end_time: str):
        super().__init__(start_time, end_time, ("primary_activity", "commute"))

    def check_skips_activity(self, person: "Person") -> bool:
        """
        Returns True if the activity is to be skipped, otherwise False
        """
        if (
            person.primary_activity is not None
            and person.primary_activity.group.spec == "company"
        ):
            if person.lockdown_status == "random":
                # stop people going to work in Tier 3 or 4 regions
                # if they don't work in the same region
                # and if their region is not in Tier 3 or 4
                # subject to regional compliance
                try:
                    if (
                        person.work_super_area != person.area.super_area
                        and person.work_super_area.region.policy["lockdown_tier"]
                        in CloseCompaniesLockdownTiers.TIERS
                        and person.region.policy["lockdown_tier"]
                        not in CloseCompaniesLockdownTiers.TIERS
                    ):
                        try:
                            # skip with probability equal to regional compliance
                            return random() < person.region.regional_compliance
                        except Exception:
                            # no compliance information -> always skip
                            return True
                except AttributeError:
                    # missing work_super_area / region / policy info:
                    # this rule is not applicable, try the next one
                    pass

                # stop people going to work who are living in a Tier 3 or 4 region unless they work
                # in that same region
                # subject to regional compliance
                try:
                    if (
                        person.work_super_area != person.area.super_area
                        and person.region.policy["lockdown_tier"]
                        in CloseCompaniesLockdownTiers.TIERS
                    ):
                        try:
                            return random() < person.region.regional_compliance
                        except Exception:
                            return True
                except AttributeError:
                    pass
        return False


class CloseCompanies(SkipActivity):
    """
    Prevents people from going to work (and commuting), either fully or
    depending on their ``lockdown_status`` ("furlough", "key_worker",
    "random") and the configured target probabilities.
    """

    # population shares of each lockdown status, computed by ``initialize``
    furlough_ratio = None
    key_ratio = None
    random_ratio = None

    def __init__(
        self,
        start_time: str,
        end_time: str,
        full_closure=False,
        avoid_work_probability=None,
        furlough_probability=None,
        key_probability=None,
    ):
        """
        Prevents workers with the tag ``person.lockdown_status=furlough" to go to work.
        If full_closure is True, then no one will go to work.

        Parameters
        ----------
        full_closure:
            if True, all company workers skip work
        avoid_work_probability:
            probability that a "random" worker stays home
        furlough_probability:
            target population share of furloughed workers
        key_probability:
            target population share of key workers
        """
        super().__init__(start_time, end_time, ("primary_activity", "commute"))
        self.full_closure = full_closure
        self.avoid_work_probability = avoid_work_probability
        self.furlough_probability = furlough_probability
        self.key_probability = key_probability

    @classmethod
    def initialize(cls, world, date, record):
        """
        Count workers per lockdown status and store their population shares
        on the class. If any status is absent, all shares are set to None.
        """
        furlough_ratio = 0
        key_ratio = 0
        random_ratio = 0
        for person in world.people:
            if person.lockdown_status == "furlough":
                furlough_ratio += 1
            elif person.lockdown_status == "key_worker":
                key_ratio += 1
            elif person.lockdown_status == "random":
                random_ratio += 1
        if furlough_ratio != 0 and key_ratio != 0 and random_ratio != 0:
            # bug fix: compute the total once *before* normalising.
            # Previously ``furlough_ratio`` was normalised first and then
            # reused inside the denominators for ``key_ratio`` and
            # ``random_ratio``, making those two shares wrong.
            total = furlough_ratio + key_ratio + random_ratio
            furlough_ratio /= total
            key_ratio /= total
            random_ratio /= total
        else:
            furlough_ratio = None
            key_ratio = None
            random_ratio = None
        cls.furlough_ratio = furlough_ratio
        cls.key_ratio = key_ratio
        cls.random_ratio = random_ratio

    def check_skips_activity(self, person: "Person") -> bool:
        """
        Returns True if the activity is to be skipped, otherwise False
        """
        if (
            person.primary_activity is not None
            and person.primary_activity.group.spec == "company"
        ):
            # if companies closed skip
            if self.full_closure:
                return True

            elif person.lockdown_status == "furlough":
                if (
                    self.furlough_ratio is not None
                    and self.furlough_probability is not None
                ):
                    # if there are too few furloughed people then always furlough all
                    if self.furlough_ratio < self.furlough_probability:
                        return True
                    # if there are too many or correct number of furloughed people then furlough with a probability
                    elif self.furlough_ratio >= self.furlough_probability:
                        if random() < self.furlough_probability / self.furlough_ratio:
                            return True
                        # otherwise treat them as random
                        elif self.avoid_work_probability is not None:
                            if random() < self.avoid_work_probability:
                                return True
                else:
                    return True

            elif (
                person.lockdown_status == "key_worker"
                and self.key_ratio is not None
                and self.key_probability is not None
            ):
                # if there are too many key workers, scale them down - otherwise send all to work
                if self.key_ratio > self.key_probability:
                    if random() > self.key_probability / self.key_ratio:
                        return True

            elif (
                person.lockdown_status == "random"
                and self.avoid_work_probability is not None
            ):

                if (
                    self.furlough_ratio is not None
                    and self.furlough_probability is not None
                    and self.key_ratio is not None
                    and self.key_probability is not None
                    and self.random_ratio is not None
                ):
                    # if there are too few furloughed people and too few key workers
                    if (
                        self.furlough_ratio < self.furlough_probability
                        and self.key_ratio < self.key_probability
                    ):
                        if (
                            random()
                            < (self.furlough_probability - self.furlough_ratio)
                            / self.random_ratio
                        ):
                            return True
                        # correct for some random workers now being treated as furloughed
                        elif random() < (self.key_probability - self.key_ratio) / (
                            self.random_ratio
                            - (self.furlough_probability - self.furlough_ratio)
                        ):
                            return False
                    # if there are too few furloughed people
                    elif self.furlough_ratio < self.furlough_probability:
                        if (
                            random()
                            < (self.furlough_probability - self.furlough_ratio)
                            / self.random_ratio
                        ):
                            return True
                    # if there are too few key workers
                    elif self.key_ratio < self.key_probability:
                        if (
                            random()
                            < (self.key_probability - self.key_ratio)
                            / self.random_ratio
                        ):
                            return False

                elif (
                    self.furlough_ratio is not None
                    and self.furlough_probability is not None
                    and self.random_ratio is not None
                ):
                    # if there are too few furloughed people then randomly stop extra people from going to work
                    if self.furlough_ratio < self.furlough_probability:
                        if (
                            random()
                            < (self.furlough_probability - self.furlough_ratio)
                            / self.random_ratio
                        ):
                            return True

                elif (
                    self.key_ratio is not None
                    and self.key_probability is not None
                    and self.random_ratio is not None
                ):
                    # if there are too few key workers then randomly boost more people going to work and do not subject them to the random choice
                    if self.key_ratio < self.key_probability:
                        if (
                            random()
                            < (self.key_probability - self.key_ratio)
                            / self.random_ratio
                        ):
                            return False

                # default behaviour for "random" workers
                if random() < self.avoid_work_probability:
                    return True

        return False


class LimitLongCommute(SkipActivity):
    """
    Limits long distance commuting from a certain distance.
    If the person has its workplace further than a certain threshold,
    then their probability of going to work every day decreases.
    """

    long_distance_commuter_ids = set()
    apply_from_distance = 150

    def __init__(
        self,
        start_time: str = "1000-01-01",
        end_time: str = "9999-12-31",
        apply_from_distance: float = 150,
        going_to_work_probability: float = 0.2,
    ):
        """
        Parameters
        ----------
        apply_from_distance:
            commute distance above which the limitation applies
            (units as returned by ``haversine_distance``)
        going_to_work_probability:
            daily probability that a long-distance commuter still goes to work
        """
        super().__init__(
            start_time, end_time, activities_to_remove=("primary_activity", "commute")
        )
        self.going_to_work_probability = going_to_work_probability
        # stored on the class so the classmethod helpers can see them
        self.__class__.apply_from_distance = apply_from_distance
        self.__class__.long_distance_commuter_ids = set()

    def initialize(self, world, date, record):
        # collect the ids of everyone whose commute exceeds the threshold
        return self.get_long_commuters(world.people)

    @classmethod
    def get_long_commuters(cls, people):
        """Register the ids of all people with a long commute."""
        for person in people:
            if cls._does_long_commute(person):
                cls.long_distance_commuter_ids.add(person.id)

    @classmethod
    def _does_long_commute(cls, person: Person):
        """True if the person's home-to-work distance exceeds the threshold."""
        if person.work_super_area is None:
            # no workplace -> no commute
            return False
        distance_to_work = haversine_distance(
            person.area.coordinates, person.work_super_area.coordinates
        )
        if distance_to_work > cls.apply_from_distance:
            return True
        return False

    def check_skips_activity(self, person: Person):
        """
        Long-distance commuters go to work only with probability
        ``going_to_work_probability``; everyone else is unaffected.
        """
        if person.id not in self.long_distance_commuter_ids:
            return False
        # bug fix: this previously returned True (skip) when
        # random() < going_to_work_probability, i.e. commuters *went to work*
        # with probability 1 - going_to_work_probability, inverting the
        # parameter's meaning. Now they go with that probability and skip
        # otherwise.
        return random() >= self.going_to_work_probability

import datetime

from .policy import Policy, PolicyCollection
from june.interaction import Interaction
from collections import defaultdict


class InteractionPolicy(Policy):
    """Base class for policies that act on interaction intensity (beta) factors."""

    policy_type = "interaction"


class InteractionPolicies(PolicyCollection):
    """
    Collection of interaction policies. Combines the beta reductions of all
    active policies multiplicatively and hands them to the interaction object.
    """

    policy_type = "interaction"

    def apply(self, date: datetime, interaction: Interaction):
        """Combine all active policies' beta reductions for ``date``."""
        # unspecified groups keep a neutral factor of 1.0
        reductions = defaultdict(lambda: 1.0)
        for policy in self.get_active(date):
            for group, factor in policy.apply().items():
                reductions[group] *= factor
        interaction.beta_reductions = reductions


class SocialDistancing(InteractionPolicy):
    """
    Social distancing policy: scales the interaction betas by fixed factors.

    Assumptions:
    - social distancing is applied first and affects all interactions
      and intensities globally
    - the reductions are not group dependent
    TODO:
    - Implement structure for people to adhere to social distancing with a certain compliance
    - Check per group in config file
    """

    policy_subtype = "beta_factor"

    def __init__(self, start_time: str, end_time: str, beta_factors: dict = None):
        super().__init__(start_time, end_time)
        self.beta_factors = beta_factors

    def apply(self):
        """Return the configured beta factors, e.g. a dict like
        ``DefaultInteraction.from_file(selector=selector).beta``."""
        return self.beta_factors


class MaskWearing(InteractionPolicy):
    """
    Mask wearing policy: reduces betas per group via a mean-field effect,
    similarly to social distancing but weighted by the probability of mask
    use in each group and the overall compliance.
    """

    policy_subtype = "beta_factor"

    def __init__(
        self,
        start_time: str,
        end_time: str,
        compliance: float,
        beta_factor: float,
        mask_probabilities: dict = None,
    ):
        super().__init__(start_time, end_time)
        self.compliance = compliance
        self.beta_factor = beta_factor
        self.mask_probabilities = mask_probabilities

    def apply(self):
        """
        Return per-group beta reduction factors:
        ``1 - p_mask * compliance * (1 - beta_factor)`` for each group.
        """
        return {
            group: 1 - (probability * self.compliance * (1 - self.beta_factor))
            for group, probability in self.mask_probabilities.items()
        }


import datetime
from typing import Dict, Union

from .policy import Policy, PolicyCollection
from june.utils.parse_probabilities import parse_age_probabilities
from june.groups.leisure import Leisure


class LeisurePolicy(Policy):
    """Base class for policies that act on leisure activities."""

    policy_type = "leisure"

    def __init__(
        self,
        start_time: Union[str, datetime.datetime],
        end_time: Union[str, datetime.datetime],
    ):
        super().__init__(start_time, end_time)
        # instance attribute kept for backward compatibility with code that
        # reads it off the instance; shadows the identical class attribute
        self.policy_type = "leisure"


class LeisurePolicies(PolicyCollection):
    policy_type = "leisure"

    def apply(self, date: datetime, leisure: Leisure):
        """
        Applies all the leisure policies. Each Leisure policy will change the probability of
        doing a certain leisure activity. For instance, closing Pubs sets the probability of
        going to the Pub to zero. We store a dictionary with the relative reductions in leisure
        probabilities per activity, and this dictionary is then looked at by the leisure module.

        This is very similar to how we deal with social distancing / mask wearing policies.

        Raises
        ------
        ValueError
            if more than one change-leisure-probability policy is active
            on the same date.
        """
        # reset yesterday's state before applying today's policies
        for region in leisure.regions:
            region.policy["global_closed_venues"] = set()
        leisure.policy_reductions = {}
        if "residence_visits" in leisure.leisure_distributors:
            leisure.leisure_distributors["residence_visits"].policy_reductions = {}
        change_leisure_probability_policies_counter = 0
        for policy in self.get_active(date):
            if policy.policy_subtype == "change_leisure_probability":
                change_leisure_probability_policies_counter += 1
                if change_leisure_probability_policies_counter > 1:
                    # bug fix: the two fragments previously concatenated to
                    # "...policyactive..." (missing separating space).
                    raise ValueError(
                        "Having more than one change leisure probability policy "
                        "active is not supported."
                    )
                leisure.policy_reductions = policy.apply(leisure=leisure)
            else:
                policy.apply(leisure=leisure)


class CloseLeisureVenue(LeisurePolicy):
    """
    Closes the given types of leisure venues in every region.

    Parameters
    ----------
    start_time:
        date at which to start applying the policy
    end_time:
        date from which the policy won't apply
    venues_to_close:
        list of leisure venue types that will close
    """

    policy_subtype = "close_venues"

    def __init__(
        self,
        start_time: Union[str, datetime.datetime],
        end_time: Union[str, datetime.datetime],
        venues_to_close=("cinemas", "groceries"),
    ):
        super().__init__(start_time, end_time)
        self.venues_to_close = venues_to_close

    def apply(self, leisure: Leisure):
        """Register every closed venue type in every region's policy set."""
        for venue_type in self.venues_to_close:
            for region in leisure.regions:
                region.policy["global_closed_venues"].add(venue_type)


class ChangeLeisureProbability(LeisurePolicy):
    policy_subtype = "change_leisure_probability"

    def __init__(
        self,
        start_time: str,
        end_time: str,
        activity_reductions: Dict[str, Dict[str, Dict[str, float]]],
    ):
        """
        Changes the probability of the specified leisure activities.

        Parameters
        ----------
        - start_time : starting time of the policy.
        - end_time : end time of the policy.
        - activity_reductions : dictionary keyed by activity, optionally split by
            day type ("weekday"/"weekend") and/or sex ("male"/"female"/"both_sexes"),
            with age-range strings mapping to the new probabilities. Example:
            * activity_reductions = {"pubs" : {"male" : {"0-50" : 0.5, "50-99" : 0.2}, "female" : {"0-70" : 0.2, "71-99" : 0.8}}}
        """
        super().__init__(start_time, end_time)
        self.activity_reductions = self._read_activity_reductions(activity_reductions)

    def _read_activity_reductions(self, activity_reductions):
        """
        Normalise the user config into
        ``ret[activity][day_type]["m"|"f"] = per-age probabilities``
        so the leisure module can look reductions up uniformly.
        """
        ret = {}
        day_types = ["weekday", "weekend"]
        sexes = ["male", "female"]
        # translate config sex names to the single-letter codes used internally
        _sex_t = {"male": "m", "female": "f"}
        for activity, pp in activity_reductions.items():
            ret[activity] = {}
            ret[activity]["weekday"] = {}
            ret[activity]["weekend"] = {}
            for first_entry in pp:
                if first_entry in ["weekday", "weekend"]:
                    # config is split by day type first
                    day_type = first_entry
                    if "both_sexes" in pp[day_type]:
                        # a single table shared by both sexes for this day type
                        for sex in sexes:
                            june_sex = _sex_t[sex]
                            probs = parse_age_probabilities(
                                activity_reductions[activity][day_type]["both_sexes"]
                            )
                            ret[activity][day_type][june_sex] = probs
                    else:
                        for sex in sexes:
                            june_sex = _sex_t[sex]
                            probs = parse_age_probabilities(
                                activity_reductions[activity][day_type][sex]
                            )
                            ret[activity][day_type][june_sex] = probs
                elif first_entry == "any" or first_entry in ["male", "female"]:
                    # NOTE(review): when first_entry == "any" this still indexes
                    # the config by "male"/"female"; it only works if those keys
                    # are present -- confirm this matches the intended schema.
                    for sex in sexes:
                        june_sex = _sex_t[sex]
                        probs = parse_age_probabilities(
                            activity_reductions[activity][sex]
                        )
                        for day_type in day_types:
                            ret[activity][day_type][june_sex] = probs
                elif first_entry == "both_sexes":
                    # same table for both sexes on all day types
                    for sex in sexes:
                        june_sex = _sex_t[sex]
                        probs = parse_age_probabilities(
                            activity_reductions[activity]["both_sexes"]
                        )
                        for day_type in day_types:
                            ret[activity][day_type][june_sex] = probs
                else:
                    # fallback: treat first_entry as a day-type key with per-sex tables
                    for day_type in day_types:
                        for sex in sexes:
                            june_sex = _sex_t[sex]
                            ret[activity][day_type][june_sex] = parse_age_probabilities(
                                activity_reductions[activity][day_type][sex]
                            )
        return ret

    def apply(self, leisure: Leisure):
        # the reductions were precomputed in __init__; just hand them over
        return self.activity_reductions


class ChangeVisitsProbability(LeisurePolicy):
    """
    Changes the split of residence-visit probabilities.

    Parameters
    ----------
    - start_time : starting time of the policy.
    - end_time : end time of the policy.
    - new_residence_type_probabilities :
        new probabilities for residence visits splits,
        e.g. ``{"household": 0.8, "care_home": 0.2}``
    """

    policy_subtype = "change_visits_probability"

    def __init__(
        self,
        start_time: str,
        end_time: str,
        new_residence_type_probabilities: Dict[str, float],
    ):
        super().__init__(start_time, end_time)
        self.policy_reductions = new_residence_type_probabilities

    def apply(self, leisure: Leisure):
        """Install the new residence-visit split on the visits distributor."""
        distributor = leisure.leisure_distributors["residence_visits"]
        distributor.policy_reductions = self.policy_reductions


import datetime
from typing import List, Optional

from .policy import Policy, PolicyCollection
from june.demography import Person
from june.epidemiology.infection import SymptomTag
from june.records import Record

hospitalised_tags = (SymptomTag.hospitalised, SymptomTag.intensive_care)
dead_hospital_tags = (SymptomTag.dead_hospital, SymptomTag.dead_icu)


class MedicalCarePolicy(Policy):
    """Base class for medical-care policies (e.g. hospitalisation)."""

    def __init__(self, start_time="1900-01-01", end_time="2500-01-01"):
        super().__init__(start_time, end_time)
        self.policy_type = "medical_care"

    def is_active(self, date: datetime.datetime) -> bool:
        # medical care policies are always active, regardless of the date
        return True


class MedicalCarePolicies(PolicyCollection):
    policy_type = "medical_care"

    def __init__(self, policies: List[Policy]):
        """
        A collection of like policies active on the same date.
        Hospitalisation policies are kept separate so they can be
        applied with priority.
        """
        self.policies = policies
        self.policies_by_name = {}
        for policy in policies:
            self.policies_by_name[self._get_policy_name(policy)] = policy
        hospitalisation = []
        other = []
        for policy in policies:
            if isinstance(policy, Hospitalisation):
                hospitalisation.append(policy)
            else:
                other.append(policy)
        self.hospitalisation_policies = hospitalisation
        self.non_hospitalisation_policies = other

    def apply(
        self,
        person: Person,
        medical_facilities,
        days_from_start: float,
        record: Optional[Record],
    ):
        """
        Applies medical care policies. Hospitalisation takes preference over all.
        """
        for policy in self.hospitalisation_policies:
            if policy.apply(person=person, record=record):
                return
        for policy in self.non_hospitalisation_policies:
            if policy.apply(person, medical_facilities, days_from_start):
                return


class Hospitalisation(MedicalCarePolicy):
    """
    Hospitalisation policy. When applied to a sick person, allocates that person to a hospital, if the symptoms are severe
    enough. When the person recovers, releases the person from the hospital.
    """

    def __init__(self, start_time="1900-01-01", end_time="2500-01-01"):
        super().__init__(start_time, end_time)

    def apply(self, person: Person, record: Optional[Record] = None):
        """
        Admit ``person`` to hospital when their symptoms require it, or
        discharge them once the symptoms no longer do.

        Returns True when the person needs (and keeps/obtains) a hospital
        place, False otherwise.
        """
        tag = person.infection.tag
        currently_in_hospital = (
            person.medical_facility is not None
            and person.medical_facility.group.spec == "hospital"
        )
        if tag in hospitalised_tags:
            if currently_in_hospital:
                patient_hospital = person.medical_facility.group
            else:
                patient_hospital = person.super_area.closest_hospitals[0]
            # note, we dont model hospital capacity here.
            status = patient_hospital.allocate_patient(person)
            if record is not None:
                if status == "ward_admitted":
                    record.accumulate(
                        table_name="hospital_admissions",
                        hospital_id=patient_hospital.id,
                        patient_id=person.id,
                    )
                elif status == "icu_transferred":
                    record.accumulate(
                        table_name="icu_admissions",
                        hospital_id=patient_hospital.id,
                        patient_id=person.id,
                    )
            return True
        if currently_in_hospital and tag not in dead_hospital_tags:
            # symptoms have subsided: discharge the patient
            if record is not None:
                record.accumulate(
                    table_name="discharges",
                    hospital_id=person.medical_facility.group.id,
                    patient_id=person.id,
                )
            person.medical_facility.group.release_patient(person)
        return False


import datetime
import re
from abc import ABC
from typing import List, Union

import yaml

from june import paths
from june.utils import read_date, str_to_class

# Default policy configuration file shipped with the package.
default_config_filename = paths.configs_path / "defaults/policy/policy.yaml"


class Policy(ABC):
    """
    Template for a general policy, active from ``start_time`` (inclusive)
    up to ``end_time`` (exclusive).
    """

    def __init__(
        self,
        start_time: Union[str, datetime.datetime] = "1900-01-01",
        end_time: Union[str, datetime.datetime] = "2100-01-01",
    ):
        """
        Parameters
        ----------
        start_time:
            date at which to start applying the policy
        end_time:
            date from which the policy won't apply
        """
        self.spec = self.get_spec()
        self.start_time = read_date(start_time)
        self.end_time = read_date(end_time)

    def get_spec(self) -> str:
        """
        Return the specialization of the policy: the class name converted
        from CamelCase to snake_case.
        """
        class_name = self.__class__.__name__
        return re.sub(r"(?<!^)(?=[A-Z])", "_", class_name).lower()

    def is_active(self, date: datetime.datetime) -> bool:
        """
        Return True when ``date`` falls inside the policy's active window.

        Parameters
        ----------
        date:
            date to check
        """
        return self.start_time <= date < self.end_time

    def initialize(self, world, date, record=None):
        """Hook for policies that need world information; default is a no-op."""
        pass


class Policies:
    """Container for every simulation policy, partitioned by policy type."""

    def __init__(self, policies=None):
        self.policies = policies
        # Note (Arnau): This import here is ugly, but I couldn't
        # find a way to get around a redundant import loop.
        from june.policy import (
            IndividualPolicies,
            InteractionPolicies,
            MedicalCarePolicies,
            LeisurePolicies,
            RegionalCompliances,
            TieredLockdowns,
        )

        self.individual_policies = IndividualPolicies.from_policies(self)
        self.interaction_policies = InteractionPolicies.from_policies(self)
        self.medical_care_policies = MedicalCarePolicies.from_policies(self)
        self.leisure_policies = LeisurePolicies.from_policies(self)
        self.regional_compliance = RegionalCompliances.from_policies(self)
        self.tiered_lockdown = TieredLockdowns.from_policies(self)

    @classmethod
    def from_file(
        cls, config_file=default_config_filename, base_policy_modules=("june.policy",)
    ):
        """
        Build a ``Policies`` instance from a YAML configuration file.

        Each top-level key is the snake_case name of a policy class; an entry
        either holds the policy kwargs directly, or a mapping of several
        sub-entries each holding their own kwargs.
        """
        with open(config_file) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        policies = []
        for policy_name, policy_data in config.items():
            camel_case_key = "".join(
                part.capitalize() or "_" for part in policy_name.split("_")
            )
            if "start_time" in policy_data:
                # single policy entry: kwargs are given directly
                policies.append(
                    str_to_class(camel_case_key, base_policy_modules)(**policy_data)
                )
            else:
                # multiple entries of the same policy type
                for sub_policy_data in policy_data.values():
                    if (
                        "start_time" not in sub_policy_data
                        or "end_time" not in sub_policy_data
                    ):
                        raise ValueError("policy config file not valid.")
                    policies.append(
                        str_to_class(camel_case_key, base_policy_modules)(
                            **sub_policy_data
                        )
                    )
        return Policies(policies=policies)

    def get_policies_for_type(self, policy_type):
        """Return every policy whose ``policy_type`` matches."""
        return [policy for policy in self if policy.policy_type == policy_type]

    def __iter__(self):
        if self.policies is None:
            return iter([])
        return iter(self.policies)

    def init_policies(self, world, date, record=None):
        """
        This function is meant to be used for those policies that need world information to initialise,
        like policies depending on workers' behaviours during lockdown.
        """
        for policy in self:
            policy.initialize(world=world, date=date, record=record)


class PolicyCollection:
    def __init__(self, policies: List[Policy]):
        """
        A collection of like policies active on the same date
        """
        self.policies = policies
        self.policies_by_name = {}
        for policy in policies:
            self.policies_by_name[self._get_policy_name(policy)] = policy

    def _get_policy_name(self, policy):
        # CamelCase class name -> snake_case lookup key.
        return re.sub(r"(?<!^)(?=[A-Z])", "_", policy.__class__.__name__).lower()

    @classmethod
    def from_policies(cls, policies: Policies):
        """Build the collection from policies matching ``cls.policy_type``."""
        return cls(policies.get_policies_for_type(policy_type=cls.policy_type))

    def get_active(self, date: datetime):
        """Return the subset of policies active at ``date``."""
        active = []
        for policy in self.policies:
            if policy.is_active(date):
                active.append(policy)
        return active

    def apply(self, active_policies):
        # Subclasses implement the actual application logic.
        raise NotImplementedError()

    def __iter__(self):
        return iter(self.policies)

    def __getitem__(self, index):
        return self.policies[index]

    def get_from_name(self, name):
        """Return the policy registered under snake_case ``name``."""
        return self.policies_by_name[name]

    def __contains__(self, policy_name):
        return policy_name in self.policies_by_name


from datetime import datetime

from .policy import Policy, PolicyCollection, read_date
from june.geography import Regions


class RegionalCompliance(Policy):
    """Sets a per-region compliance factor while the policy is active."""

    policy_type = "regional_compliance"

    def __init__(self, start_time: str, end_time: str, compliances_per_region: dict):
        super().__init__(start_time=start_time, end_time=end_time)
        self.compliances_per_region = compliances_per_region

    def apply(self, date: datetime, regions: Regions):
        """Write the configured compliance into every region, when active."""
        date = read_date(date)
        if not self.is_active(date):
            return
        for region in regions:
            region.regional_compliance = self.compliances_per_region[region.name]


class RegionalCompliances(PolicyCollection):
    policy_type = "regional_compliance"

    def apply(self, date: datetime, regions: Regions):
        """Reset every region to full compliance, then apply active policies."""
        # before applying compliances, reset all of them to 1.0
        if self.policies:
            for region in regions:
                region.regional_compliance = 1.0
        for compliance_policy in self.policies:
            compliance_policy.apply(date=date, regions=regions)


class TieredLockdown(Policy):
    """
    Tiered-lockdown policy: while active, assigns each region its configured
    lockdown tier and closes the leisure venues associated with that tier.

    Parameters
    ----------
    start_time:
        date at which to start applying the policy
    end_time:
        date from which the policy won't apply
    tiers_per_region:
        mapping of region name to lockdown tier (2, 3 or 4)
    """

    policy_type = "tiered_lockdown"

    # Venues closed at each lockdown tier; higher tiers close strictly more.
    closed_venues_per_tier = {
        2: {"residence_visits"},
        3: {"cinema", "residence_visits"},
        4: {"pub", "cinema", "gym", "residence_visits"},
    }

    def __init__(self, start_time: str, end_time: str, tiers_per_region: dict):
        super().__init__(start_time=start_time, end_time=end_time)
        self.tiers_per_region = tiers_per_region

    def apply(self, date: datetime, regions: Regions):
        date = read_date(date)
        if self.is_active(date):
            for region in regions:
                lockdown_tier = int(self.tiers_per_region[region.name])
                region.policy["lockdown_tier"] = lockdown_tier
                # BUG FIX: the original tier-2 branch called
                # set.update("residence_visits"), which adds the individual
                # *characters* of the string ('r', 'e', 's', ...) to the set
                # instead of the venue name itself.
                closed_venues = self.closed_venues_per_tier.get(lockdown_tier)
                if closed_venues is not None:
                    region.policy["local_closed_venues"].update(closed_venues)


class TieredLockdowns(PolicyCollection):
    policy_type = "tiered_lockdown"

    def apply(self, date: datetime, regions: Regions):
        """Reset regional lockdown state, then apply every lockdown policy."""
        # before applying, reset all regions to no tier and no closed venues
        if self.policies:
            for region in regions:
                region.policy["lockdown_tier"] = None
                region.policy["local_closed_venues"] = set()
        for lockdown_policy in self.policies:
            lockdown_policy.apply(date=date, regions=regions)


from .policy import (
    Policy,
    Policies,
    PolicyCollection,
)  # , regional_compliance_is_active
from .interaction_policies import (
    InteractionPolicy,
    InteractionPolicies,
    SocialDistancing,
    MaskWearing,
)
from .leisure_policies import (
    LeisurePolicy,
    LeisurePolicies,
    CloseLeisureVenue,
    ChangeLeisureProbability,
    ChangeVisitsProbability,
)
from .individual_policies import (
    IndividualPolicy,
    IndividualPolicies,
    StayHome,
    SevereSymptomsStayHome,
    Quarantine,
    SchoolQuarantine,
    Shielding,
    CloseCompanies,
    CloseSchools,
    CloseUniversities,
    LimitLongCommute,
)

from .medical_care_policies import (
    MedicalCarePolicy,
    MedicalCarePolicies,
    Hospitalisation,
)

from .regional_compliance import (
    RegionalCompliance,
    RegionalCompliances,
    TieredLockdown,
    TieredLockdowns,
)


import tables
import numpy as np
from june.records.helper_records_writer import _get_description_for_event


class EventRecord:
    """
    In-memory accumulator for one type of simulation event, flushed
    periodically into a table of an HDF5 file.

    Parameters
    ----------
    hdf5_filename:
        path of the HDF5 file to write to
    table_name:
        name of the table created at the file root
    int_names / float_names / str_names:
        column names, grouped by column dtype
    """

    def __init__(self, hdf5_filename, table_name, int_names, float_names, str_names):
        self.filename = hdf5_filename
        self.table_name = table_name
        self.int_names = int_names
        self.float_names = float_names
        self.str_names = str_names
        self.attributes = int_names + float_names + str_names
        # one accumulator list per column, stored as instance attributes
        for attribute in self.attributes:
            setattr(self, attribute, [])
        self._create_table(int_names, float_names, str_names)

    def _create_table(self, int_names, float_names, str_names):
        # Create the (initially empty) HDF5 table, with a timestamp column.
        with tables.open_file(self.filename, mode="a") as file:
            table_description = _get_description_for_event(
                int_names=int_names,
                float_names=float_names,
                str_names=str_names,
                timestamp=True,
            )
            self.table = file.create_table(
                file.root, self.table_name, table_description
            )

    @property
    def number_of_events(self):
        # All accumulator lists grow in lockstep; any column gives the count.
        return len(getattr(self, self.attributes[0]))

    def accumulate(self):
        # Subclasses override this to append one event to the accumulators.
        pass

    def record(self, hdf5_file, timestamp: str):
        """
        Flush all accumulated events to the HDF5 table, stamped with
        ``timestamp`` (a datetime-like object), then clear the accumulators.
        """
        data = np.rec.fromarrays(
            [
                np.array(
                    [timestamp.strftime("%Y-%m-%d")] * self.number_of_events,
                    dtype="S10",
                )
            ]
            + [np.array(getattr(self, name), dtype=np.int32) for name in self.int_names]
            + [
                # BUG FIX: was ``np.float6432``, which does not exist and
                # raised AttributeError for any table with float columns;
                # float32 matches the Float32Col columns created by default.
                np.array(getattr(self, name), dtype=np.float32)
                for name in self.float_names
            ]
            + [np.array(getattr(self, name), dtype="S20") for name in self.str_names]
        )

        table = getattr(hdf5_file.root, self.table_name)
        table.append(data)
        table.flush()
        for attribute in self.attributes:
            setattr(self, attribute, [])


class InfectionRecord(EventRecord):
    """Event record for infections, batched per location."""

    def __init__(self, hdf5_filename):
        super().__init__(
            hdf5_filename=hdf5_filename,
            table_name="infections",
            int_names=["location_ids", "infector_ids", "infected_ids", "infection_ids"],
            float_names=[],
            str_names=["location_specs", "region_names"],
        )

    def accumulate(
        self,
        location_spec,
        location_id,
        region_name,
        infector_ids,
        infected_ids,
        infection_ids,
    ):
        """Append a batch of infections that occurred at a single location."""
        n_infected = len(infected_ids)
        # location and region are shared by every infection in the batch
        self.location_specs.extend([location_spec] * n_infected)
        self.location_ids.extend([location_id] * n_infected)
        self.region_names.extend([region_name] * n_infected)
        self.infector_ids.extend(infector_ids)
        self.infected_ids.extend(infected_ids)
        self.infection_ids.extend(infection_ids)


class HospitalAdmissionsRecord(EventRecord):
    """Event record for ward admissions to hospital."""

    def __init__(self, hdf5_filename):
        super().__init__(
            hdf5_filename=hdf5_filename,
            table_name="hospital_admissions",
            int_names=["hospital_ids", "patient_ids"],
            float_names=[],
            str_names=[],
        )

    def accumulate(self, hospital_id, patient_id):
        """Record one admission of ``patient_id`` to ``hospital_id``."""
        self.hospital_ids += [hospital_id]
        self.patient_ids += [patient_id]


class ICUAdmissionsRecord(EventRecord):
    """Event record for admissions to intensive care units."""

    def __init__(self, hdf5_filename):
        super().__init__(
            hdf5_filename=hdf5_filename,
            table_name="icu_admissions",
            int_names=["hospital_ids", "patient_ids"],
            float_names=[],
            str_names=[],
        )

    def accumulate(self, hospital_id, patient_id):
        """Record one ICU admission of ``patient_id`` at ``hospital_id``."""
        self.hospital_ids += [hospital_id]
        self.patient_ids += [patient_id]


class DischargesRecord(EventRecord):
    """Event record for hospital discharges."""

    def __init__(self, hdf5_filename):
        super().__init__(
            hdf5_filename=hdf5_filename,
            table_name="discharges",
            int_names=["hospital_ids", "patient_ids"],
            float_names=[],
            str_names=[],
        )

    def accumulate(self, hospital_id, patient_id):
        """Record one discharge of ``patient_id`` from ``hospital_id``."""
        self.hospital_ids += [hospital_id]
        self.patient_ids += [patient_id]


class DeathsRecord(EventRecord):
    """Event record for deaths, tagged with the location where they occur."""

    def __init__(self, hdf5_filename):
        super().__init__(
            hdf5_filename=hdf5_filename,
            table_name="deaths",
            int_names=["location_ids", "dead_person_ids"],
            float_names=[],
            str_names=["location_specs"],
        )

    def accumulate(self, location_spec, location_id, dead_person_id):
        """Record a single death at the given location."""
        self.location_specs += [location_spec]
        self.location_ids += [location_id]
        self.dead_person_ids += [dead_person_id]


class RecoveriesRecord(EventRecord):
    """Event record for recoveries from infection."""

    def __init__(self, hdf5_filename):
        super().__init__(
            hdf5_filename=hdf5_filename,
            table_name="recoveries",
            int_names=["recovered_person_ids", "infection_ids"],
            float_names=[],
            str_names=[],
        )

    def accumulate(self, recovered_person_id, infection_id):
        """Record one person recovering from the given infection."""
        self.recovered_person_ids += [recovered_person_id]
        self.infection_ids += [infection_id]


class SymptomsRecord(EventRecord):
    """Event record for symptom changes of infected people."""

    def __init__(self, hdf5_filename):
        super().__init__(
            hdf5_filename=hdf5_filename,
            table_name="symptoms",
            int_names=["infected_ids", "new_symptoms", "infection_ids"],
            float_names=[],
            str_names=[],
        )

    def accumulate(self, infected_id, symptoms, infection_id):
        """Record a symptom change for one infected person."""
        self.infected_ids += [infected_id]
        self.new_symptoms += [symptoms]
        self.infection_ids += [infection_id]


class VaccinesRecord(EventRecord):
    """Event record for administered vaccine doses."""

    def __init__(self, hdf5_filename):
        super().__init__(
            hdf5_filename=hdf5_filename,
            table_name="vaccines",
            int_names=["vaccinated_ids", "dose_numbers"],
            float_names=[],
            str_names=["vaccine_names"],
        )

    def accumulate(self, vaccinated_id, vaccine_name, dose_number):
        """Record one vaccine dose given to one person."""
        self.vaccinated_ids += [vaccinated_id]
        self.vaccine_names += [vaccine_name]
        self.dose_numbers += [dose_number]


import tables


def _get_description_for_event(
    int_names,
    float_names,
    str_names,
    int_size=32,
    float_size=32,
    str_size=20,
    timestamp=True,
):
    """
    Build a PyTables table-description dict for an event table.

    Parameters
    ----------
    int_names / float_names / str_names:
        column names, grouped by dtype
    int_size / float_size:
        bit width of the integer / float columns (32 or 64)
    str_size:
        itemsize of string columns
    timestamp:
        whether to prepend a 10-character string "timestamp" column

    Returns
    -------
    dict mapping column name to a ``tables.*Col`` instance, with ``pos``
    preserving the column order: timestamp, ints, floats, strings.

    Raises
    ------
    ValueError
        if ``int_size`` or ``float_size`` is not 32 or 64.
    """
    # BUG FIX: the original used ``raise "<message>"``; raising a plain
    # string is itself a TypeError in Python 3 — exceptions must derive
    # from BaseException.
    if int_size not in (32, 64):
        raise ValueError("int_size must be left unspecified, or should equal 32 or 64")
    if float_size not in (32, 64):
        raise ValueError(
            "float_size must be left unspecified, or should equal 32 or 64"
        )
    int_constructor = tables.Int32Col if int_size == 32 else tables.Int64Col
    float_constructor = tables.Float64Col if float_size == 64 else tables.Float32Col
    str_constructor = tables.StringCol
    description = {}
    pos = 0
    if timestamp:
        description["timestamp"] = tables.StringCol(itemsize=10, pos=pos)
        pos += 1
    for name in int_names:
        description[name] = int_constructor(pos=pos)
        pos += 1
    for name in float_names:
        description[name] = float_constructor(pos=pos)
        pos += 1
    for name in str_names:
        description[name] = str_constructor(itemsize=str_size, pos=pos)
        pos += 1
    return description


from pathlib import Path
from typing import Optional, Tuple
import numpy as np
import pandas as pd
import tables
import logging


# Module-level logger for record-reading diagnostics.
logger = logging.getLogger(__name__)


class RecordReader:
    """
    Reads the outputs written by ``Record`` (the regional summary CSV and
    the HDF5 event record) from a results directory into pandas DataFrames.
    """

    def __init__(self, results_path=Path("results"), record_name: str = None):
        """
        Parameters
        ----------
        results_path:
            directory containing ``summary.csv`` and the HDF5 record file
        record_name:
            name of the HDF5 record file; defaults to "june_record.h5"
        """
        self.results_path = Path(results_path)
        try:
            self.regional_summary = self.get_regional_summary(
                self.results_path / "summary.csv"
            )
        except Exception:
            # best effort: the HDF5 record can be read without a summary file
            self.regional_summary = None
            logger.warning("No summary available to read...")
        if self.regional_summary is not None:
            self.world_summary = self.get_world_summary()
        if record_name is None:
            self.record_name = "june_record.h5"
        else:
            self.record_name = record_name

    def decode_bytes_columns(self, df):
        """Decode all byte-string (object dtype) columns of ``df`` to str."""
        str_df = df.select_dtypes([object])
        for col in str_df:
            df[col] = str_df[col].str.decode("utf-8")
        return df

    def get_regional_summary(self, summary_path):
        """
        Load the summary CSV and aggregate it by (region, time_stamp);
        returns a DataFrame indexed by datetime.
        """
        df = pd.read_csv(summary_path)
        cols = [col for col in df.columns if col not in ["time_stamp", "region"]]
        # "current_*" columns are snapshots -> mean; the rest are counts -> sum
        self.aggregator = {col: np.mean if "current" in col else sum for col in cols}
        df = df.groupby(["region", "time_stamp"], as_index=False).agg(self.aggregator)
        df.set_index("time_stamp", inplace=True)
        df.index = pd.to_datetime(df.index)
        return df

    def get_world_summary(self):
        """Collapse the regional summary over regions into a world-level one."""
        return (
            self.regional_summary.drop(columns="region")
            .groupby("time_stamp")
            .agg(self.aggregator)
        )

    def table_to_df(
        self, table_name: str, index: str = "id", fields: Optional[Tuple] = None
    ) -> pd.DataFrame:
        """
        Read one HDF5 table into a DataFrame indexed by ``index``, decoding
        byte-string columns.
        """
        # TODO: include fields to read only certain columns
        with tables.open_file(self.results_path / self.record_name, mode="r") as f:
            table = getattr(f.root, table_name)
            df = pd.DataFrame.from_records(table.read(), index=index)
        df = self.decode_bytes_columns(df)
        return df

    def get_geography_df(
        self,
    ):
        """
        Join the areas, super_areas and regions tables into one DataFrame
        mapping each area to its super area and region names.
        """
        areas_df = self.table_to_df("areas")
        super_areas_df = self.table_to_df("super_areas")
        regions_df = self.table_to_df("regions")

        geography_df = areas_df[["super_area_id", "name"]].merge(
            super_areas_df[["region_id", "name"]],
            how="inner",
            left_on="super_area_id",
            right_index=True,
            suffixes=("_area", "_super_area"),
        )
        geography_df = geography_df.merge(
            regions_df, how="inner", left_on="region_id", right_index=True
        )
        return geography_df.rename(
            columns={geography_df.index.name: "area_id", "name": "name_region"}
        )

    def get_table_with_extras(
        self,
        table_name,
        index,
        with_people=True,
        with_geography=True,
        people_df=None,
        geography_df=None,
    ):
        """
        Read ``table_name`` and optionally enrich it with population and
        geography columns; ``people_df``/``geography_df`` can be passed in to
        avoid re-reading them. Parses a "timestamp" column when present.
        """
        logger.info(f"Loading {table_name} table")
        df = self.table_to_df(table_name, index=index)
        if with_people:
            logger.info("Loading population table")
            if people_df is None:
                people_df = self.table_to_df("population", index="id")
            logger.info("Merging infection and population tables")
            df = df.merge(people_df, how="inner", left_index=True, right_index=True)
            if with_geography:
                logger.info("Loading geography table")
                if geography_df is None:
                    geography_df = self.get_geography_df()
                logger.info("Mergeing infection and geography tables")
                df = df.merge(
                    geography_df.drop_duplicates(),
                    left_on="area_id",
                    right_index=True,
                    how="inner",
                )
        if "timestamp" in df.columns:
            df["timestamp"] = pd.to_datetime(df["timestamp"])
        return df


import os
import tables
import pandas as pd
import yaml
import numpy as np
import csv
import json
import subprocess
from datetime import datetime
from pathlib import Path
from typing import Optional
from collections import defaultdict
import logging

import june
from june.records.event_records_writer import (
    InfectionRecord,
    HospitalAdmissionsRecord,
    ICUAdmissionsRecord,
    DischargesRecord,
    DeathsRecord,
    RecoveriesRecord,
    SymptomsRecord,
    VaccinesRecord,
)
from june.records.static_records_writer import (
    PeopleRecord,
    LocationRecord,
    AreaRecord,
    SuperAreaRecord,
    RegionRecord,
)

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from june.world import World
    from june.interaction.interaction import Interaction
    from june.epidemiology.infection_seed.infection_seed import InfectionSeeds
    from june.epidemiology.infection import InfectionSelectors
    from june.epidemiology.epidemiology import Epidemiology
    from june.activity.activity_manager import ActivityManager

# Module-level logger for record-writing diagnostics.
logger = logging.getLogger("records_writer")


class Record:
    """
    Writes simulation outputs to disk: per-event HDF5 tables, a per-region
    summary CSV and a YAML file with run configuration / meta information.

    Parameters
    ----------
    record_path:
        directory in which all output files are created (created if missing)
    record_static_data:
        whether to also record static world tables (people, locations, areas,
        super areas, regions) via ``static_data``
    mpi_rank:
        rank of the current MPI process; when given, output filenames are
        suffixed with the rank so parallel processes do not clash
    """

    def __init__(
        self, record_path: str, record_static_data=False, mpi_rank: Optional[int] = None
    ):
        self.record_path = Path(record_path)
        self.record_path.mkdir(parents=True, exist_ok=True)
        self.mpi_rank = mpi_rank
        if mpi_rank is not None:
            self.filename = f"june_record.{mpi_rank}.h5"
            self.summary_filename = f"summary.{mpi_rank}.csv"
        else:
            self.filename = "june_record.h5"
            self.summary_filename = "summary.csv"
        self.configs_filename = "config.yaml"
        self.record_static_data = record_static_data
        # remove any leftover record file from a previous run
        try:
            os.remove(self.record_path / self.filename)
        except OSError:
            pass
        filename = self.record_path / self.filename
        # one in-memory accumulator per event table
        self.events = {
            "infections": InfectionRecord(hdf5_filename=filename),
            "hospital_admissions": HospitalAdmissionsRecord(hdf5_filename=filename),
            "icu_admissions": ICUAdmissionsRecord(hdf5_filename=filename),
            "discharges": DischargesRecord(hdf5_filename=filename),
            "deaths": DeathsRecord(hdf5_filename=filename),
            "recoveries": RecoveriesRecord(hdf5_filename=filename),
            "symptoms": SymptomsRecord(hdf5_filename=filename),
            "vaccines": VaccinesRecord(hdf5_filename=filename),
        }
        if self.record_static_data:
            self.statics = {
                "people": PeopleRecord(),
                "locations": LocationRecord(),
                "areas": AreaRecord(),
                "super_areas": SuperAreaRecord(),
                "regions": RegionRecord(),
            }
        # write the summary CSV header
        with open(
            self.record_path / self.summary_filename, "w", newline=""
        ) as summary_file:
            writer = csv.writer(summary_file)
            fields = ["infected", "hospitalised", "intensive_care"]
            header = ["time_stamp", "region"]
            for field in fields:
                header.append("current_" + field)
                header.append("daily_" + field)
            header.extend(["daily_hospital_deaths", "daily_deaths"])
            writer.writerow(header)
        description = {
            "description": f"Started runnning at {datetime.now()}. Good luck!"
        }
        with open(self.record_path / self.configs_filename, "w") as f:
            yaml.dump(description, f)

    def static_data(self, world: "World"):
        """Record the static world tables (requires ``record_static_data=True``)."""
        with tables.open_file(self.record_path / self.filename, mode="a") as file:
            for static_name in self.statics.keys():
                self.statics[static_name].record(hdf5_file=file, world=world)

    def accumulate(self, table_name: str, **kwargs):
        """Add one event (or batch) to the ``table_name`` accumulator."""
        self.events[table_name].accumulate(**kwargs)

    def time_step(self, timestamp: str):
        """Flush all accumulated events to the HDF5 file, stamped with ``timestamp``."""
        with tables.open_file(self.record_path / self.filename, mode="a") as file:
            for event_name in self.events.keys():
                self.events[event_name].record(hdf5_file=file, timestamp=timestamp)

    def summarise_hospitalisations(self, world: "World"):
        """
        Return per-region dicts of daily ward/ICU admissions (from the
        accumulators) and current ward/ICU occupancy (from the hospitals).
        """
        hospital_admissions, icu_admissions = defaultdict(int), defaultdict(int)
        for hospital_id in self.events["hospital_admissions"].hospital_ids:
            hospital = world.hospitals.get_from_id(hospital_id)
            hospital_admissions[hospital.region_name] += 1
        for hospital_id in self.events["icu_admissions"].hospital_ids:
            hospital = world.hospitals.get_from_id(hospital_id)
            icu_admissions[hospital.region_name] += 1
        current_hospitalised, current_intensive_care = (
            defaultdict(int),
            defaultdict(int),
        )
        for hospital in world.hospitals:
            # external hospitals belong to other MPI domains
            if not hospital.external:
                current_hospitalised[hospital.region_name] += len(hospital.ward)
                current_intensive_care[hospital.region_name] += len(hospital.icu)
        return (
            hospital_admissions,
            icu_admissions,
            current_hospitalised,
            current_intensive_care,
        )

    # BUG FIX (here and in summarise_deaths): the parameter was declared as
    # ``world="World"`` — a string *default* instead of the intended type
    # annotation ``world: "World"``.
    def summarise_infections(self, world: "World"):
        """Return per-region dicts of daily new infections and currently infected."""
        daily_infections, current_infected = defaultdict(int), defaultdict(int)
        for region in self.events["infections"].region_names:
            daily_infections[region] += 1
        for region in world.regions:
            current_infected[region.name] = len(
                [person for person in region.people if person.infected]
            )
        return daily_infections, current_infected

    def summarise_deaths(self, world: "World"):
        """Return per-region dicts of daily deaths (total and in hospital)."""
        daily_deaths, daily_deaths_in_hospital = defaultdict(int), defaultdict(int)
        for i, person_id in enumerate(self.events["deaths"].dead_person_ids):
            region = world.people.get_from_id(person_id).super_area.region.name
            daily_deaths[region] += 1
            if self.events["deaths"].location_specs[i] == "hospital":
                hospital_id = self.events["deaths"].location_ids[i]
                region = world.hospitals.get_from_id(hospital_id).region_name
                daily_deaths_in_hospital[region] += 1
        return daily_deaths, daily_deaths_in_hospital

    def summarise_time_step(self, timestamp: str, world: "World"):
        """
        Append one row per region with any activity to the summary CSV for
        this time step. Must be called before ``time_step`` clears the
        accumulators.
        """
        daily_infected, current_infected = self.summarise_infections(world=world)
        (
            daily_hospitalised,
            daily_intensive_care,
            current_hospitalised,
            current_intensive_care,
        ) = self.summarise_hospitalisations(world=world)

        daily_deaths, daily_deaths_in_hospital = self.summarise_deaths(world=world)
        all_hospital_regions = [hospital.region_name for hospital in world.hospitals]
        all_world_regions = [region.name for region in world.regions]
        all_regions = set(all_hospital_regions + all_world_regions)
        with open(
            self.record_path / self.summary_filename, "a", newline=""
        ) as summary_file:
            summary_writer = csv.writer(summary_file)
            for region in all_regions:
                data = [
                    current_infected.get(region, 0),
                    daily_infected.get(region, 0),
                    current_hospitalised.get(region, 0),
                    daily_hospitalised.get(region, 0),
                    current_intensive_care.get(region, 0),
                    daily_intensive_care.get(region, 0),
                    daily_deaths_in_hospital.get(region, 0),
                    daily_deaths.get(region, 0),
                ]
                # skip all-zero rows to keep the summary compact
                if sum(data) > 0:
                    summary_writer.writerow(
                        [timestamp.strftime("%Y-%m-%d"), region] + data
                    )

    def combine_outputs(self, remove_left_overs=True):
        """Merge per-rank outputs into single files after an MPI run."""
        # NOTE(review): combine_records is expected to be defined elsewhere
        # in this module (not visible here) — verify it exists.
        combine_records(self.record_path, remove_left_overs=remove_left_overs)

    def append_dict_to_configs(self, config_dict):
        """Merge ``config_dict`` into the YAML configuration file on disk."""
        with open(self.record_path / self.configs_filename, "r") as f:
            configs = yaml.safe_load(f)
            configs.update(config_dict)
        with open(self.record_path / self.configs_filename, "w") as f:
            yaml.safe_dump(
                configs,
                f,
                allow_unicode=True,
                default_flow_style=False,
                sort_keys=False,
            )

    def parameters_interaction(self, interaction: "Interaction" = None):
        """Record interaction parameters (betas, alpha, contact matrices)."""
        if interaction is not None:
            interaction_dict = {}
            interaction_dict["betas"] = interaction.betas
            interaction_dict["alpha_physical"] = interaction.alpha_physical
            interaction_dict["contact_matrices"] = {}
            for key, values in interaction.contact_matrices.items():
                # numpy arrays are not YAML-serialisable; store plain lists
                interaction_dict["contact_matrices"][key] = values.tolist()
            self.append_dict_to_configs(config_dict={"interaction": interaction_dict})

    def parameters_seed(self, infection_seeds: "InfectionSeeds" = None):
        """Record seeding parameters per infection class."""
        if infection_seeds is not None:
            infection_seeds_dict = {}
            for infection_seed in infection_seeds:
                inf_seed_dict = {}
                inf_seed_dict["seed_strength"] = infection_seed.seed_strength
                inf_seed_dict["min_date"] = infection_seed.min_date.strftime("%Y-%m-%d")
                inf_seed_dict["max_date"] = infection_seed.max_date.strftime("%Y-%m-%d")
                infection_seeds_dict[
                    infection_seed.infection_selector.infection_class.__name__
                ] = inf_seed_dict
            self.append_dict_to_configs(
                config_dict={"infection_seeds": infection_seeds_dict}
            )

    def parameters_infection(self, infection_selectors: "InfectionSelectors" = None):
        """Record the transmission type of each infection class."""
        if infection_selectors is not None:
            save_dict = {}
            for selector in infection_selectors._infection_selectors:
                class_name = selector.infection_class.__name__
                save_dict[class_name] = {}
                save_dict[class_name]["transmission_type"] = selector.transmission_type
            self.append_dict_to_configs(config_dict={"infections": save_dict})

    def parameters_policies(self, activity_manager: "ActivityManager" = None):
        """Dump all policy attributes to ``policies.json``."""
        if activity_manager is not None:
            policy_dicts = []
            for policy in activity_manager.policies.policies:
                policy_dicts.append(policy.__dict__.copy())
            with open(self.record_path / "policies.json", "w") as f:
                json.dump(policy_dicts, f, indent=4, default=str)

    @staticmethod
    def get_username():
        """Return the login name of the current user, or "no_user"."""
        try:
            username = os.getlogin()
        except Exception:
            username = "no_user"
        return username

    def parameters(
        self,
        interaction: "Interaction" = None,
        epidemiology: "Epidemiology" = None,
        activity_manager: "ActivityManager" = None,
    ):
        """Record all run parameters (only on rank 0 for MPI runs)."""
        # BUG FIX: these names were left unbound when ``epidemiology`` was
        # falsy, raising UnboundLocalError below.
        infection_seeds = None
        infection_selectors = None
        if epidemiology:
            infection_seeds = epidemiology.infection_seeds
            infection_selectors = epidemiology.infection_selectors
        if self.mpi_rank is None or self.mpi_rank == 0:
            self.parameters_interaction(interaction=interaction)
            self.parameters_seed(infection_seeds=infection_seeds)
            self.parameters_infection(infection_selectors=infection_selectors)
            self.parameters_policies(activity_manager=activity_manager)

    def meta_information(
        self,
        comment: Optional[str] = None,
        random_state: Optional[int] = None,
        number_of_cores: Optional[int] = None,
    ):
        """
        Record run metadata (git branch/SHA, user, comment, cores, RNG state)
        into the YAML configuration file; only on rank 0 for MPI runs.
        """
        if self.mpi_rank is None or self.mpi_rank == 0:
            june_git = Path(june.__path__[0]).parent / ".git"
            meta_dict = {}
            branch_cmd = f"git --git-dir {june_git} rev-parse --abbrev-ref HEAD".split()
            try:
                meta_dict["branch"] = (
                    subprocess.run(branch_cmd, stdout=subprocess.PIPE)
                    .stdout.decode("utf-8")
                    .strip()
                )
            except Exception as e:
                print(e)
                print("Could not record git branch")
                meta_dict["branch"] = "unavailable"
            local_SHA_cmd = f'git --git-dir {june_git} log -n 1 --format="%h"'.split()
            try:
                meta_dict["local_SHA"] = (
                    subprocess.run(local_SHA_cmd, stdout=subprocess.PIPE)
                    .stdout.decode("utf-8")
                    .strip()
                )
            except Exception:
                print("Could not record local git SHA")
                meta_dict["local_SHA"] = "unavailable"
            user = self.get_username()
            meta_dict["user"] = user
            if comment is None:
                comment = "No comment provided."
            meta_dict["user_comment"] = f"{comment}"
            meta_dict["june_path"] = str(june.__path__[0])
            meta_dict["number_of_cores"] = number_of_cores
            meta_dict["random_state"] = random_state
            with open(self.record_path / self.configs_filename, "r") as f:
                configs = yaml.safe_load(f)
                configs.update({"meta_information": meta_dict})
            with open(self.record_path / self.configs_filename, "w") as f:
                yaml.safe_dump(configs, f)


def combine_summaries(record_path, remove_left_overs=False, save_dir=None):
    """Merge per-rank ``summary.*.csv`` files into a single ``summary.csv``.

    Within each per-rank file, rows are grouped by (region, time_stamp);
    columns containing "current" are snapshots and are averaged, all other
    value columns are summed. The per-rank aggregates are then summed across
    ranks and written to ``summary.csv``.

    Parameters
    ----------
    record_path:
        directory (str or Path) containing the per-rank summary files
    remove_left_overs:
        if True, delete each per-rank file after reading it (the original
        skipped deletion for files that held no rows, leaving them behind)
    save_dir:
        directory for the merged file; defaults to ``record_path``
    """
    record_path = Path(record_path)
    dfs = []
    for summary_file in record_path.glob("summary.*.csv"):
        df = pd.read_csv(summary_file)
        if len(df) > 0:
            # "current_*" columns are point-in-time snapshots -> mean;
            # the remaining value columns are counts -> sum
            aggregator = {
                col: "mean" if "current" in col else "sum" for col in df.columns[2:]
            }
            dfs.append(
                df.groupby(["region", "time_stamp"], as_index=False).agg(aggregator)
            )
        if remove_left_overs:
            # remove the per-rank file even when it held no rows
            summary_file.unlink()
    summary = pd.concat(dfs)
    summary = summary.groupby(["region", "time_stamp"]).sum()
    save_path = record_path if save_dir is None else Path(save_dir)
    summary.to_csv(save_path / "summary.csv")


def combine_hdf5s(
    record_path,
    table_names=("infections", "population"),
    remove_left_overs=False,
    save_dir=None,
):
    """Merge per-rank ``june_record.*.h5`` files into one ``june_record.h5``.

    Parameters
    ----------
    record_path:
        directory (str or Path) containing the per-rank record files
    table_names:
        unused; kept for backward compatibility with existing callers
    remove_left_overs:
        if True, delete each per-rank file after merging it
    save_dir:
        directory for the merged file; defaults to ``record_path``
    """
    # accept str or Path, matching combine_summaries (the original assumed
    # a Path and crashed on a plain string)
    record_path = Path(record_path)
    save_path = record_path if save_dir is None else Path(save_dir)
    full_record_save_path = save_path / "june_record.h5"
    with tables.open_file(full_record_save_path, "w") as merged_record:
        for record_file in record_path.glob("june_record.*.h5"):
            with tables.open_file(str(record_file), "r") as record:
                for dataset in record.root._f_list_nodes():
                    arr_data = dataset[:]
                    # create the destination table on first encounter of this
                    # name; the original only created tables found in the
                    # first file, crashing on tables introduced later
                    if not hasattr(merged_record.root, dataset.name):
                        merged_record.create_table(
                            merged_record.root,
                            dataset.name,
                            description=dataset.description,
                        )
                    if len(arr_data) > 0:
                        table = getattr(merged_record.root, dataset.name)
                        table.append(arr_data)
                        table.flush()
            if remove_left_overs:
                record_file.unlink()


def combine_records(record_path, remove_left_overs=False, save_dir=None):
    """Merge the per-rank summary CSVs and HDF5 records into single files."""
    base_path = Path(record_path)
    combine_summaries(base_path, remove_left_overs=remove_left_overs, save_dir=save_dir)
    combine_hdf5s(base_path, remove_left_overs=remove_left_overs, save_dir=save_dir)


def prepend_checkpoint_hdf5(
    pre_checkpoint_record_path,
    post_checkpoint_record_path,
    tables_to_merge=(
        "deaths",
        "discharges",
        "hospital_admissions",
        "icu_admissions",
        "infections",
        "recoveries",
        "symptoms",
    ),
    merged_record_path=None,
    checkpoint_date: str = None,
):
    """Stitch a pre-checkpoint HDF5 record onto a post-checkpoint record.

    Tables listed in ``tables_to_merge`` get the pre-checkpoint rows dated
    strictly before the first post-checkpoint infection date prepended; every
    other table is copied verbatim from the post-checkpoint record.

    Parameters
    ----------
    pre_checkpoint_record_path:
        HDF5 record written before the checkpoint
    post_checkpoint_record_path:
        HDF5 record written after restarting from the checkpoint
    tables_to_merge:
        names of time-stamped tables to stitch together
    merged_record_path:
        output path; defaults to ``merged_checkpoint_june_record.h5`` next to
        the post-checkpoint record
    checkpoint_date:
        "%Y-%m-%d" date the checkpoint is expected to start at; a warning is
        printed if it disagrees with the earliest post-checkpoint infection
    """
    pre_checkpoint_record_path = Path(pre_checkpoint_record_path)
    post_checkpoint_record_path = Path(post_checkpoint_record_path)
    if merged_record_path is None:
        merged_record_path = (
            post_checkpoint_record_path.parent / "merged_checkpoint_june_record.h5"
        )

    with tables.open_file(merged_record_path, "w") as merged_record:
        with tables.open_file(pre_checkpoint_record_path, "r") as pre_record:
            with tables.open_file(post_checkpoint_record_path, "r") as post_record:
                # the earliest infection date in the post-checkpoint run
                # defines the cut point for the time-stamped tables
                post_infection_dates = np.array(
                    [
                        datetime.strptime(x.decode("utf-8"), "%Y-%m-%d")
                        for x in post_record.root["infections"][:]["timestamp"]
                    ]
                )
                min_date = min(post_infection_dates)
                if checkpoint_date is None:
                    print("provide the date you expect the checkpoint to start at!")
                else:
                    # BUGFIX: the original compared checkpoint_date with
                    # itself, so this warning could never fire; compare the
                    # parsed date against the actual minimum instead
                    if datetime.strptime(checkpoint_date, "%Y-%m-%d") != min_date:
                        print(
                            f"provided date {checkpoint_date} does not match min date {min_date}"
                        )

                for dataset in post_record.root._f_list_nodes():
                    description = getattr(post_record.root, dataset.name).description
                    if dataset.name not in tables_to_merge:
                        # not time-stamped: copy verbatim from post record
                        arr_data = dataset[:]
                        merged_record.create_table(
                            merged_record.root, dataset.name, description=description
                        )
                        if len(arr_data) > 0:
                            table = getattr(merged_record.root, dataset.name)
                            table.append(arr_data)
                            table.flush()
                    else:
                        # keep only pre-checkpoint rows strictly before the cut
                        pre_arr_data = pre_record.root[dataset.name][:]
                        pre_dates = np.array(
                            [
                                datetime.strptime(x.decode("utf-8"), "%Y-%m-%d")
                                for x in pre_arr_data["timestamp"]
                            ]
                        )
                        pre_arr_data = pre_arr_data[pre_dates < min_date]
                        post_arr_data = dataset[:]

                        merged_record.create_table(
                            merged_record.root, dataset.name, description=description
                        )
                        table = getattr(merged_record.root, dataset.name)
                        if len(pre_arr_data) > 0:
                            table.append(pre_arr_data)
                        if len(post_arr_data) > 0:
                            table.append(post_arr_data)
                        table.flush()
    logger.info(f"written prepended record to {merged_record_path}")


def prepend_checkpoint_summary(
    pre_checkpoint_summary_path,
    post_checkpoint_summary_path,
    merged_summary_path=None,
    checkpoint_date=None,
):
    """Stitch a pre-checkpoint summary CSV onto a post-checkpoint one.

    Pre-checkpoint rows dated on/after the first post-checkpoint time stamp
    are dropped; the frames are concatenated, indexed and sorted by
    (region, time_stamp), and written to ``merged_summary_path``.

    Parameters
    ----------
    pre_checkpoint_summary_path:
        summary CSV written before the checkpoint
    post_checkpoint_summary_path:
        summary CSV written after restarting from the checkpoint
    merged_summary_path:
        output path; defaults to ``merged_checkpoint_summary.csv`` next to
        the post-checkpoint summary
    checkpoint_date:
        date the post-checkpoint summary is expected to start at; a warning
        is printed on mismatch
    """
    pre_checkpoint_summary_path = Path(pre_checkpoint_summary_path)
    post_checkpoint_summary_path = Path(post_checkpoint_summary_path)

    if merged_summary_path is None:
        merged_summary_path = (
            post_checkpoint_summary_path.parent / "merged_checkpoint_summary.csv"
        )

    pre_summary = pd.read_csv(pre_checkpoint_summary_path)
    post_summary = pd.read_csv(post_checkpoint_summary_path)
    pre_summary["time_stamp"] = pd.to_datetime(pre_summary["time_stamp"])
    post_summary["time_stamp"] = pd.to_datetime(post_summary["time_stamp"])
    min_date = min(post_summary["time_stamp"])
    if checkpoint_date is None:
        print("Provide the checkpoint date you expect the post-summary to start at!")
    else:
        if min_date != checkpoint_date:
            print(
                f"Provided date {checkpoint_date} does not match the earliest date in the summary!"
            )
    pre_summary = pre_summary[pre_summary["time_stamp"] < min_date]
    merged_summary = pd.concat([pre_summary, post_summary], ignore_index=True)
    # BUGFIX: set_index returns a new frame; the original discarded the
    # result, so the sort and the saved CSV index used the default integer
    # index instead of (region, time_stamp)
    merged_summary = merged_summary.set_index(["region", "time_stamp"])
    merged_summary.sort_index(inplace=True)
    merged_summary.to_csv(merged_summary_path, index=True)
    logger.info(f"Written merged summary to {merged_summary_path}")


import numpy as np

from june.records.helper_records_writer import _get_description_for_event
from june.groups import Supergroup


class StaticRecord:
    """Base class for one-off ("static") world snapshots written to HDF5.

    Subclasses declare the column layout (integer, float and string column
    names) and implement :meth:`get_data`, which extracts one list per column
    from a ``world`` object. :meth:`record` creates the table in the target
    file and writes all rows in a single append.

    The ``extra_int_data`` / ``extra_float_data`` / ``extra_str_data``
    mappings (column name -> column values) allow callers to attach ad-hoc
    columns, which are appended after the declared ones when recording.
    """

    def __init__(self, table_name, int_names, float_names, str_names, expectedrows):
        self.table_name = table_name
        self.int_names = int_names
        self.float_names = float_names
        self.str_names = str_names
        # hint for PyTables to pre-size the table
        self.expectedrows = expectedrows
        self.extra_int_data = {}
        self.extra_float_data = {}
        self.extra_str_data = {}

    def _create_table(self, file, int_names, float_names, str_names, expectedrows):
        """Create this record's (timestamp-less) table in `file`."""
        table_description = _get_description_for_event(
            int_names=int_names,
            float_names=float_names,
            str_names=str_names,
            timestamp=False,
        )
        self.table = file.create_table(
            file.root, self.table_name, table_description, expectedrows=expectedrows
        )

    def _record(self, hdf5_file, int_data, float_data, str_data):
        """Append all rows at once; column order (ints, floats, strings)
        must match the description built in _create_table."""
        data = np.rec.fromarrays(
            [np.array(data, dtype=np.uint32) for data in int_data]
            + [np.array(data, dtype=np.float32) for data in float_data]
            + [np.array(data, dtype="S20") for data in str_data]
        )
        table = getattr(hdf5_file.root, self.table_name)
        table.append(data)
        table.flush()

    def get_data(self, world):
        """Return (int_data, float_data, str_data) column lists extracted
        from `world`; subclasses must override (base returns None)."""
        pass

    def record(self, hdf5_file, world):
        """Create this record's table in `hdf5_file` and write the snapshot."""
        int_data, float_data, str_data = self.get_data(world=world)
        extra_ints = self.extra_int_data or {}
        extra_floats = self.extra_float_data or {}
        extra_strs = self.extra_str_data or {}
        # BUGFIX: build the final column-name lists locally — the original
        # appended the extra-column keys to self.*_names on every call, so
        # recording twice duplicated the names.
        int_names = list(self.int_names) + list(extra_ints.keys())
        int_data = list(int_data) + list(extra_ints.values())
        float_names = list(self.float_names) + list(extra_floats.keys())
        float_data = list(float_data) + list(extra_floats.values())
        str_names = list(self.str_names) + list(extra_strs.keys())
        str_data = list(str_data) + list(extra_strs.values())
        self._create_table(
            hdf5_file,
            int_names,
            float_names,
            str_names,
            self.expectedrows,
        )
        self._record(
            hdf5_file=hdf5_file,
            int_data=int_data,
            float_data=float_data,
            str_data=str_data,
        )


class PeopleRecord(StaticRecord):
    """Static snapshot of every person in the world ("population" table)."""

    def __init__(self):
        super().__init__(
            table_name="population",
            int_names=["id", "age", "primary_activity_id", "residence_id", "area_id"],
            float_names=[],
            str_names=["sex", "ethnicity", "primary_activity_type", "residence_type"],
            expectedrows=1_000_000,
        )

        self.extra_float_data = {}
        self.extra_int_data = {}
        self.extra_str_data = {}

    def get_data(self, world):
        """Collect one column list per declared column from ``world.people``.

        Missing links (no primary activity, residence, area or ethnicity)
        are encoded as the string "None" / the id 0.
        """
        ids, ages, sexes, ethnicities = [], [], [], []
        activity_types, activity_ids = [], []
        residence_types, residence_ids, area_ids = [], [], []
        for person in world.people:
            ids.append(person.id)
            ages.append(person.age)
            activity = person.primary_activity
            if activity is None:
                activity_types.append("None")
                activity_ids.append(0)
            else:
                activity_types.append(activity.group.spec)
                activity_ids.append(activity.group.id)
            residence = person.residence
            if residence is None:
                residence_types.append("None")
                residence_ids.append(0)
            else:
                residence_types.append(residence.group.spec)
                residence_ids.append(residence.group.id)
            area_ids.append(0 if person.area is None else person.area.id)
            sexes.append(person.sex)
            ethnicities.append("None" if person.ethnicity is None else person.ethnicity)
        int_data = [ids, ages, activity_ids, residence_ids, area_ids]
        str_data = [sexes, ethnicities, activity_types, residence_types]
        return int_data, [], str_data


class LocationRecord(StaticRecord):
    """Static snapshot of every local (non-external) social venue."""

    def __init__(self):
        super().__init__(
            table_name="locations",
            int_names=["id", "group_id", "area_id"],
            float_names=["latitude", "longitude"],
            str_names=["spec"],
            expectedrows=1_000_000,
        )

    def get_data(self, world):
        """Walk every Supergroup attribute of ``world`` (except cities,
        cemeteries and stations) and collect one row per local venue; the
        "id" column is a running counter, not the group id."""
        ids, group_ids, area_ids = [], [], []
        latitudes, longitudes, specs = [], [], []
        skipped = ("cities", "cemeteries", "stations")
        running_id = 0
        for attribute, value in world.__dict__.items():
            if not isinstance(value, Supergroup) or attribute in skipped:
                continue
            for group in value:
                if group.external:
                    # venues hosted on another MPI domain are skipped
                    continue
                ids.append(running_id)
                specs.append(group.spec)
                group_ids.append(group.id)
                area_ids.append(group.area.id)
                latitudes.append(group.coordinates[0])
                longitudes.append(group.coordinates[1])
                running_id += 1
        return [ids, group_ids, area_ids], [latitudes, longitudes], [specs]


class AreaRecord(StaticRecord):
    """Static snapshot of every output area in the world."""

    def __init__(self):
        super().__init__(
            table_name="areas",
            int_names=["id", "super_area_id"],
            float_names=["latitude", "longitude", "socioeconomic_index"],
            str_names=["name"],
            expectedrows=10_000,
        )

    def get_data(self, world):
        """Collect one row per area; empty columns when world has no areas."""
        area_ids, super_area_ids = [], []
        latitudes, longitudes, indices, names = [], [], [], []
        areas = world.areas
        if areas is not None:
            for area in areas:
                area_ids.append(area.id)
                super_area_ids.append(area.super_area.id)
                latitudes.append(area.coordinates[0])
                longitudes.append(area.coordinates[1])
                indices.append(area.socioeconomic_index)
                names.append(area.name)
        return (
            [area_ids, super_area_ids],
            [latitudes, longitudes, indices],
            [names],
        )


class SuperAreaRecord(StaticRecord):
    """Static snapshot of every super area in the world."""

    def __init__(self):
        super().__init__(
            table_name="super_areas",
            int_names=["id", "region_id"],
            float_names=["latitude", "longitude"],
            str_names=["name"],
            expectedrows=5_000,
        )

    def get_data(self, world):
        """Collect one row per super area; empty when world has none."""
        ids, region_ids = [], []
        latitudes, longitudes, names = [], [], []
        super_areas = world.super_areas
        if super_areas is not None:
            for super_area in super_areas:
                ids.append(super_area.id)
                region_ids.append(super_area.region.id)
                latitudes.append(super_area.coordinates[0])
                longitudes.append(super_area.coordinates[1])
                names.append(super_area.name)
        return [ids, region_ids], [latitudes, longitudes], [names]


class RegionRecord(StaticRecord):
    """Static snapshot of every region in the world."""

    def __init__(self):
        super().__init__(
            table_name="regions",
            int_names=["id"],
            float_names=[],
            str_names=["name"],
            expectedrows=50,
        )

    def get_data(self, world):
        """Collect id and name per region; empty when world has none."""
        ids, names = [], []
        regions = world.regions
        if regions is not None:
            for region in regions:
                ids.append(region.id)
                names.append(region.name)
        return [ids], [], [names]


from .records_writer import Record
from .records_writer import combine_records
from .records_reader import RecordReader


import numpy as np
from june.mpi_setup import mpi_comm, mpi_size, mpi_rank
import logging

import yaml
import pandas as pd
import warnings

from pathlib import Path
from june import paths

from june.world import World
import geopy.distance

from june.groups.group import make_subgroups

# Silence pandas "DataFrame is highly fragmented" warnings; the tracker
# assembles its frames incrementally.
warnings.simplefilter(action="ignore", category=pd.errors.PerformanceWarning)

# Age separating children from (young) adults, taken from the subgroup
# parameter defaults.
AgeAdult = make_subgroups.SubgroupParams.AgeYoungAdult
# "Adult/Child" bin edges: child = [0, AgeAdult), adult = [AgeAdult, 100)
ACArray = np.array([0, AgeAdult, 100])
# Weekday names, Sunday-first ordering.
DaysOfWeek_Names = [
    "Sunday",
    "Monday",
    "Tuesday",
    "Wednesday",
    "Thursday",
    "Friday",
    "Saturday",
]

# Default interaction-matrix yaml shipped with june.
default_interaction_path = paths.configs_path / "defaults/interaction/interaction.yaml"

logger = logging.getLogger("tracker")
mpi_logger = logging.getLogger("mpi")

# Only rank 0 propagates to the parent handlers, avoiding duplicated log
# lines under MPI.
if mpi_rank > 0:
    logger.propagate = False


class Tracker:
    """
    Class to handle the contact tracker.

    Parameters
    ----------
    world:
        instance of World class
    age_bins:
        dictionary mapping of bin structure and array of bin edges
    contact_sexes:
        list of sexes for which to create contact matrix. "male", "female" and or "unisex" (for both together)
    group_types:
        list of world.locations for tracker to loop over
    record_path:
        path for results directory
    load_interactions_path:
        path for interactions yaml directory
    Tracker_Contact_Type:
        NONE, Not used
    MaxVenueTrackingSize:
        int, Maximum number for venue type to track. Default is all venues in world.VENUE are tracked

    Returns
    -------
        A Tracker

    """

    def __init__(
        self,
        world: World,
        age_bins={"syoa": np.arange(0, 101, 1)},
        contact_sexes=["unisex"],
        group_types=None,
        record_path=Path(""),
        load_interactions_path=default_interaction_path,
        Tracker_Contact_Type=None,
        MaxVenueTrackingSize=np.inf,
    ):
        """Set up venue lists, interaction matrices and all counters.

        See the class docstring for parameter descriptions.
        NOTE(review): age_bins and contact_sexes are mutable default
        arguments shared across instances — safe only if never mutated;
        verify before reuse.
        """

        # deprecated argument, kept so old call sites do not break
        if Tracker_Contact_Type is None:
            pass
        else:
            print("Tracker_Contact_Type argument no longer required")
        self.world = world
        self.age_bins = age_bins
        self.contact_sexes = contact_sexes
        self.group_types = group_types
        # simulation timer is attached later, not at construction
        self.timer = None
        self.record_path = record_path
        self.load_interactions_path = load_interactions_path

        self.MaxVenueTrackingSize = MaxVenueTrackingSize

        # If we want to track total persons at each location
        self.initialise_group_names()

        # Maximum number of locations...
        # Plural attribute names of the venue supergroups on world; the
        # aggregate "global" and shelter pseudo-locations have no venue list.
        locations = []
        for locs in self.group_type_names:
            if locs in ["global", "shelter_inter", "shelter_intra"]:
                continue

            locations.append(self.pluralize(locs))

        # Per venue type: indices of the venues actually tracked, subsampled
        # without replacement when there are more venues than the cap.
        self.venues_which = {}
        for spec in locations:
            if len(getattr(self.world, spec).members) > MaxVenueTrackingSize:
                # NOTE(review): np.random.choice needs an integer size; this
                # branch is unreachable for the default np.inf cap, but a
                # non-integer cap would raise here — confirm callers pass ints.
                self.venues_which[spec] = np.random.choice(
                    np.arange(0, len(getattr(self.world, spec).members), 1),
                    size=self.MaxVenueTrackingSize,
                    replace=False,
                )
            else:
                self.venues_which[spec] = np.arange(
                    0, len(getattr(self.world, spec).members), 1
                )

        self.initialise_location_counters()

        self.load_interactions(
            self.load_interactions_path
        )  # Load in pre-made contact matrices
        self.initialise_contact_matrices()

        # store all ages/ index to age bins in python dict for quick lookup.
        self.hash_ages()

        # Initialize time, pop and contact counters
        self.initialise_location_cum_time()
        self.initialise_location_cum_pop()
        self.initialise_contact_counters()

        # per-venue travel distances, filled in during the simulation
        self.travel_distance = {}

    #####################################################
    # Useful functions ##################################
    #####################################################

    @staticmethod
    def _random_round(x):
        """
        Round a float up or down at random, weighted by its fractional part.

        Parameters
        ----------
            x:
                A float

        Returns
        -------
            int
        """
        fraction = x % 1
        whole = int(x)
        # round up with probability equal to the fractional part
        if np.random.uniform(0, 1, 1) < fraction:
            return whole + 1
        return whole

    def intersection(self, list_A, list_B, permute=True):
        """
        Return the elements common to both lists.

        Parameters
        ----------
            list_A:
                list of objects
            list_B:
                second list of objects
            permute: default = True
                bool, shuffle the returned list

        Returns
        -------
            list of shared elements
        """
        common = np.array(list(set(list_A) & set(list_B)))
        if not permute:
            return list(common)
        shuffled_order = np.random.permutation(len(common))
        return list(common[shuffled_order])

    def union(self, list_A, list_B):
        """
        Return the sorted unique elements appearing in either list.

        Parameters
        ----------
            list_A:
                list of objects
            list_B:
                second list of objects


        Returns
        -------
            list of all unique elements, sorted
        """
        combined = set(list_A + list_B)
        return sorted(combined)

    def pluralize_r(self, loc):
        """
        Venue naming is inconsistent about plurals; convert a (possibly
        plural) location name to its singular form.

        Parameters
        ----------
            loc:
                string
        Returns
        -------
            string, singular
        """
        # "global" is an aggregate pseudo-location, not a venue name
        if loc == "global":
            return loc
        if loc[-3:] == "ies":
            return loc[:-3] + "y"
        if loc[-1] == "s":
            return loc[:-1]
        return loc

    def pluralize(self, loc):
        """
        Venue naming is inconsistent about plurals; convert a singular
        location name to its plural form.

        Parameters
        ----------
            loc:
                string
        Returns
        -------
            string, pluralized
        """
        # "global" is an aggregate pseudo-location, not a venue name
        if loc == "global":
            return loc
        if loc[-1] == "y":
            return loc[:-1] + "ies"
        return loc + "s"

    ########################################################
    # CM Normalization functions ###########################
    ########################################################

    def cm_shelter_renorm(self, cm, shelter_shared=0.75):
        """
        Special Normalization for shelters. Re-weight based on households sharing shelters
        TODO Feed this in so not to be hard coded

        Parameters
        ----------
            cm:
                np.array: The contact matrix between households in a shelter
            shelter_shared:
                np.float: The proportion of shelters with multiple households

        Returns
        -------
            cm:
                np.array, the re-weighted 2x2 matrix (also modified in place)
        """

        # NOTE(review): (1 - shelter_shared) / (2 * (1 - shelter_shared))
        # simplifies to 1/2 for any shelter_shared, so FIntraIntra is always
        # 2 and independent of the parameter — confirm this is the intended
        # formula before feeding shelter_shared in from config.
        FIntraExtra = shelter_shared / (2 * (1 - shelter_shared))
        FIntraIntra = 1 / ((1 - shelter_shared) / (2 * (1 - shelter_shared)))
        # diagonal (intra-household) entries are down-weighted by FIntraIntra,
        # off-diagonal (inter-household) entries by FIntraExtra
        cm[0, 0] /= FIntraIntra
        cm[1, 1] /= FIntraIntra
        cm[0, 1] /= FIntraExtra
        cm[1, 0] /= FIntraExtra
        return cm

    #############################################
    # Grab CM  ##################################
    #############################################

    def CMPlots_GetCM(self, bin_type, contact_type, sex="unisex", which="UNCM"):
        """
        Get cm out of dictionary.

        Parameters
        ----------
            bin_type:
                Name of bin type syoa, AC etc
            contact_type:
                Location of contacts
            sex:
                Sex contact matrix
            which:
                str, which matrix type to collect "CM", "UNCM", "UNCM_R", "CMV", "UNCM_V"

        Returns
        -------
            cm:
                np.array contact matrix
            cm_err:
                np.array contact matrix errors

        Raises
        ------
            ValueError:
                if `which` is not a known matrix type (the original if/elif
                chain fell through to an UnboundLocalError instead)
        """
        if which not in ("CM", "UNCM", "UNCM_R", "CMV", "UNCM_V"):
            raise ValueError(f"Unknown matrix type: {which}")
        # each matrix family is stored as an attribute (self.CM, self.UNCM,
        # ...) with a matching error dictionary suffixed "_err"
        matrices = getattr(self, which)
        errors = getattr(self, which + "_err")
        if bin_type != "Interaction":
            # age-binned matrices carry an extra per-sex layer
            cm = matrices[bin_type][contact_type]["sex"][sex]
            cm_err = errors[bin_type][contact_type]["sex"][sex]
        else:
            cm = matrices[bin_type][contact_type]
            cm_err = errors[bin_type][contact_type]
        return np.array(cm), np.array(cm_err)

    def IMPlots_GetIM(self, contact_type):
        """
        Fetch the interaction matrix (and its errors, if present) for a venue.

        Parameters
        ----------
            contact_type:
                Location of contacts

        Returns
        -------
            im:
                np.array interaction matrix
            im_err:
                np.array interaction matrix errors, or None when absent
        """
        entry = self.IM[contact_type]
        im = np.array(entry["contacts"], dtype=float)
        if "contacts_err" in entry:
            im_err = np.array(entry["contacts_err"], dtype=float)
        else:
            im_err = None
        return im, im_err

    ########################################################
    # CM Metric functions ##################################
    ########################################################

    def Canberra_distance(self, x, y):
        """
        calculate the Canberra distance metric between two matrices, x and y

        Parameters
        ----------
            x:
                np.array, a matrix
            y:
                np.array, a matrix
        Returns
        -------
            CD:
                float, the Canberra distance: mean of DM over the entries
                where x and y differ
            DM:
                np.array, the element-wise distance matrix |x-y| / (|x|+|y|)
        """
        # total number of entries and number of exactly-equal entries
        n = np.prod(x.shape)
        Z = np.nansum((x - y) == 0)
        # normalise only over entries that actually differ
        Norm = n - Z

        if Norm == 0:
            # all entries identical; avoid dividing by zero below
            Norm = 1

        # 0/0 entries become nan and are ignored by nansum
        DM = abs(x - y) / (abs(x) + abs(y))
        return np.nansum(DM) / Norm, DM

    def Calc_QIndex(self, cm):
        """
        calculate the Q index of assortativeness of a contact matrix

        Parameters
        ----------
            cm:
                np.array, the contact matrix. Should be normalized per capita eg. UNCM or UNCM_R types.
        Returns
        -------
            Q:
                float, Q index of assortativeness: (trace(P) - 1) / (n - 1),
                1 for fully assortative mixing, 0 for proportionate mixing
        """
        P = np.zeros_like(cm, dtype=float)
        # NOTE(review): np.nansum(cm, axis=1) broadcasts along the LAST axis,
        # i.e. cm[i, j] is divided by the row-sum of row j. For a symmetric cm
        # this equals row-normalisation; otherwise the axes may be swapped —
        # confirm intended. (The zeros_like above is immediately overwritten.)
        P = np.nan_to_num(cm / np.nansum(cm, axis=1), nan=0.0)
        return (np.trace(P) - 1) / (P.shape[0] - 1)

    def Calc_NPCDM(self, cm, pop_by_bin, pop_width):
        """
        Compute the normalized population contact density matrix (NPCDM).

        Parameters
        ----------
            cm:
                np.array, the contact matrix. Should be normalized per capita eg. UNCM or UNCM_R types.
            pop_by_bin:
                np.array, un-normalized population counts per age bin
            pop_width:
                np.array, age bin widths

        Returns
        -------
            NPCDM:
                np.array, The normalized population contact density matrix
        """
        # weight each matrix entry by the product of the two bin populations
        population_outer = np.multiply.outer(pop_by_bin, pop_by_bin)
        density = cm * population_outer
        # normalisation integrates the density over the bin areas
        normalisation = np.nansum(np.multiply.outer(pop_width, pop_width) * density)
        return density / normalisation

    def Expectation_Assortativeness(self, NPCDM, pop_bins):
        """
        Expectation of assortativeness E(age_i - age_j)^2 over the normalized population contact density matrix, NPCDM

        Parameters
        ----------
            NPCDM:
                np.array, The normalized population contact density matrix
            pop_bins:
                np.array, The age binning bin edges for the population bin type

        Returns
        -------
            I_sq:
                float, The expectation value for assortativeness I^2
        """
        pop_width = np.diff(pop_bins)
        # bin mid-points stand in for ages
        ages = (pop_bins[1:] + pop_bins[:-1]) / 2
        # vectorised form of the original O(n^2) Python double loop:
        # sum_ij width_i*width_j * NPCDM_ij * ((a_i - a_j)/sqrt(2))^2, halved
        bin_area = np.multiply.outer(pop_width, pop_width)
        age_gap_sq = (np.subtract.outer(ages, ages) / np.sqrt(2)) ** 2
        I_sq = np.sum(bin_area * NPCDM * age_gap_sq)
        return I_sq / 2.0

    def Population_Metrics(self, pop_by_bin, pop_bins):
        """
        Get the mean the variance of the population using binned population data

        Parameters
        ----------
            pop_by_bin:
                np.array, un-normalized population counts per age bin
            pop_bins:
                np.array, The age binning bin edges for the population bin type

        Returns
        -------
            mean:
                float, mean age of population
            variance:
                float, NOTE(review): despite the name this is the sample
                STANDARD DEVIATION (sqrt of the variance) — callers square it
                before use; confirm before renaming
        """
        # bin mid-points stand in for individual ages
        ages = (pop_bins[1:] + pop_bins[:-1]) / 2
        Npeople = np.sum(pop_by_bin)
        # population-weighted mean age
        mean = np.sum(pop_by_bin * ages) / Npeople
        # sqrt of the (N-1)-normalised weighted second moment
        variance = np.sqrt(np.nansum(pop_by_bin * (ages - mean) ** 2) / (Npeople - 1))
        return mean, variance

    def Calculate_CM_Metrics(
        self, bin_type, contact_type, CM, CM_err, ratio, sex="unisex"
    ):
        """
        Calculate key metrics for CM, {Q, I^2, I^2_s} and return as formatted string dict for saving

        Parameters
        ----------
            bin_type:
                string, Name of bin type syoa, AC etc ("Interaction" is skipped)
            contact_type:
                string, location to be considered
            CM:
                dict, dictionary of all matrices of type. eg self.CM
            CM_err:
                dict, dictionary of all matrices of type. eg self.CM_err
            ratio:
                float, attendance fraction of population
            sex:
                string, sex matrix to use

        Returns
        -------
            dict of formatted strings with keys "Q", "I_sq", "I_sq_s",
            or None when bin_type == "Interaction"

        """
        if bin_type == "Interaction":
            # interaction-binned matrices carry no age structure to measure
            return None

        # NOTE(review): indexing here is CM[bin_type][contact_type][sex],
        # without the intermediate "sex" key used by CMPlots_GetCM — confirm
        # callers pass the matching dictionary layout.
        cm = CM[bin_type][contact_type][sex]
        cm_err = CM_err[bin_type][contact_type][sex]

        # convert per-capita ("UN") matrices to per-population scale
        cm = self.UNtoPNConversion(cm, ratio)
        cm_err = self.UNtoPNConversion(cm_err, ratio)

        cm = np.nan_to_num(cm, nan=0.0)
        cm_err = np.nan_to_num(cm_err, nan=0.0)

        # binned population and bin geometry for the requested binning
        pop_by_bin = np.array(self.age_profiles[bin_type][contact_type][sex])
        pop_bins = np.array(self.age_bins[bin_type])
        pop_width = np.diff(pop_bins)

        # population density per year of age
        pop_density = pop_by_bin / (np.nansum(pop_by_bin) * pop_width)

        # mean/spread computed on the finest (single year of age) binning
        pop_by_bin_true = np.array(self.age_profiles["syoa"][contact_type][sex])
        pop_bins_true = np.array(self.age_bins["syoa"])
        mean, var = self.Population_Metrics(pop_by_bin_true, pop_bins_true)

        Q = self.Calc_QIndex(cm)
        NPCDM = self.Calc_NPCDM(cm, pop_density, pop_width)
        I_sq = self.Expectation_Assortativeness(NPCDM, pop_bins)
        # scale-free assortativeness: var holds the standard deviation
        # (see Population_Metrics), hence the square here
        I_sq_s = I_sq / var**2
        return {"Q": f"{Q}", "I_sq": f"{I_sq}", "I_sq_s": f"{I_sq_s}"}

    ########################################################
    # Useful CM functions ##################################
    ########################################################

    def Probabilistic_Contacts(self, mean, mean_err, Probabilistic=True):
        """
        Draw a statistical number of contacts from a Poisson process.

        Parameters
        ----------
            mean:
                float, the mean expected counts
            mean_err:
                float, the 1 sigma error on the mean
            Probabilistic:
                bool, True to sample a Poisson count (jittered by the error);
                False to simply round the mean randomly

        Returns
        -------
            C_i:
                int, the randomly drawn number of contacts
        """
        if not Probabilistic:
            return self._random_round(mean)

        if mean_err == 0:
            # No uncertainty on the mean: draw straight from Poisson(mean).
            return self._random_round(np.random.poisson(mean))

        # Jitter the Poisson mean by the supplied 1 sigma error, floored at 0.
        jittered_mean = max(0, np.random.normal(mean, mean_err))
        return self._random_round(np.random.poisson(jittered_mean))

    def contract_matrix(self, CM, bins, method=np.sum):
        """
        Re-bin a "syoa" (single year of age) matrix into the coarser bins
        given by ``bins``, combining each window with ``method``.

        Parameters
        ----------
            CM:
                np.array, the contact matrix (un-normalized)
            bins:
                np.array, bin edges used for re-binning
            method:
                callable, the contraction method. np.sum, np.mean etc

        Returns
        -------
            np.array, the contracted matrix
        """
        n_bins = len(bins) - 1
        contracted = np.zeros((n_bins, n_bins), dtype=float)
        for xi in range(n_bins):
            row_lo, row_hi = bins[xi], bins[xi + 1]
            for yi in range(n_bins):
                col_lo, col_hi = bins[yi], bins[yi + 1]
                # Collapse the syoa window into one coarse cell.
                contracted[xi, yi] = method(CM[row_lo:row_hi, col_lo:col_hi])
        return contracted

    def contract_matrices(self, Name, bins=np.arange(0, 100 + 5, 5)):
        """
        Re-bin the single-year ("syoa") contact matrices into custom bins and
        store the result under ``Name`` in self.CM and self.CMV (the "1D" and
        "All" contact tracking types).

        Parameters
        ----------
            Name:
                string, name of the new matrix binning

            bins:
                array, bin edges used for re-binning

        Returns
        -------
            None

        """
        # Register the new bin edges once. (Previously this membership check
        # was repeated inside both location loops — loop-invariant work — and
        # blank matrices were allocated only to be overwritten immediately.)
        known_bins = [list(item) for item in self.age_bins.values()]
        if list(bins) not in known_bins:
            self.age_bins = {Name: bins, **self.age_bins}

        def _contract_all(source):
            # Contract every location/sex matrix of one tracking type.
            # The same logic previously existed twice, once per tracker.
            return {
                group: {
                    sex: self.contract_matrix(sex_matrices[sex], bins, np.sum)
                    for sex in self.contact_sexes
                }
                for group, sex_matrices in source.items()
            }

        self.CM[Name] = _contract_all(self.CM["syoa"])
        self.CMV[Name] = _contract_all(self.CMV["syoa"])

        # Rehash the ages against the (possibly extended) bin set.
        self.hash_ages()
        return 1

    def get_characteristic_time(self, location):
        """
        Get the characteristic time and proportion_physical for a location.
        (Characteristic time is returned in days, i.e. config hours / 24.)

        Parameters
        ----------
            location:
                string, location

        Returns
        -------
            (characteristic_time, proportion_physical):
                tuple of floats. The two shelter sub-locations share the
                "shelter" interaction-matrix entry; "global" has no entry and
                falls back to fixed defaults. (The original docstring wrongly
                said this returns None.)

        """
        if location in ("shelter_intra", "shelter_inter"):
            # Both shelter contact types share the one "shelter" IM entry.
            key = "shelter"
        elif location != "global":
            key = location
        else:
            # No IM entry for the global tracker; use fixed defaults.
            return 1, 0.12
        characteristic_time = self.IM[key]["characteristic_time"] / 24
        proportion_physical = self.IM[key]["proportion_physical"]
        return characteristic_time, proportion_physical

    ##############################################
    # Initialize ##################################
    ##############################################

    def initialise_group_names(self):
        """
        Collect the spec names of the tracked location sites.
        initialise;
            self.group_type_names

        Parameters
        ----------
            None

        Returns
        -------
            None

        """
        names = []
        for groups in self.group_types:
            # Skip supergroups that are missing or empty.
            if groups is None or len(groups) == 0:
                continue
            spec = groups[0].spec
            names.append(spec)
            # Shelters are additionally tracked per intra/inter contact type.
            if spec == "shelter":
                names.extend([spec + "_intra", spec + "_inter"])
        self.group_type_names = names
        return 1

    def initialise_contact_matrices(self):
        """
        Create the sets of empty contact matrices for both tracking types.
        initialise;
            self.CM   (1D tracker)
            self.CMV  (All tracker)

        Parameters
        ----------
            None

        Returns
        -------
            None

        """

        def _blank_tracker():
            # One zeroed (n_bins x n_bins) matrix per sex, per location, per
            # age-bin scheme, plus a "global" tracker for each scheme. This
            # structure was previously duplicated verbatim for CM and CMV.
            tracker = {}
            for bin_type, bins in self.age_bins.items():
                shape = (len(bins) - 1, len(bins) - 1)
                tracker[bin_type] = {
                    spec: {
                        sex: np.zeros(shape, dtype=float)
                        for sex in self.contact_sexes
                    }
                    for spec in ["global"] + list(self.group_type_names)
                }
            # Matrices matching the shape of the input interaction matrices,
            # only for locations that are actually tracked.
            tracker["Interaction"] = {
                spec: np.zeros_like(self.IM[spec]["contacts"], dtype=float)
                for spec in self.IM
                if spec in tracker["syoa"]
            }
            return tracker

        # 1D tracker and All tracker start from identical blank structures.
        self.CM = _blank_tracker()
        self.CMV = _blank_tracker()
        return 1

    def initialise_contact_counters(self):
        """
        Create an empty interaction counter for each person at each location.
        initialise;
            self.contact_counts

        Parameters
        ----------
            None

        Returns
        -------
            None

        """
        # Every person starts at zero for every tracked location, the two
        # visit categories, and the global counter.
        tracked_specs = self.group_type_names + [
            "care_home_visits",
            "household_visits",
            "global",
        ]
        self.contact_counts = {}
        for person in self.world.people:
            self.contact_counts[person.id] = dict.fromkeys(tracked_specs, 0)
        return 1

    def initialise_location_counters(self):
        """
        Create empty per-venue person counters for all time steps, days and
        the current day.
        initialise;
            self.location_counters
            self.location_counters_day
            self.location_counters_day_i

        Parameters
        ----------
            None

        Returns
        -------
            None

        """
        # Venue-level counting only applies to concrete locations.
        locations = [
            self.pluralize(spec)
            for spec in self.group_type_names
            if spec not in ["global", "shelter_inter", "shelter_intra"]
        ]

        def _venue_counters():
            # One empty list per sex, per venue (capped at
            # self.MaxVenueTrackingSize venues), per location type. This
            # comprehension was previously copy-pasted three times.
            return {
                spec: {
                    N: {sex: [] for sex in self.contact_sexes}
                    for N in range(
                        min(
                            len(getattr(self.world, spec).members),
                            self.MaxVenueTrackingSize,
                        )
                    )
                }
                for spec in locations
            }

        self.location_counters = {
            "Timestamp": [],
            "delta_t": [],
            "loc": _venue_counters(),
        }
        self.location_counters_day = {"Timestamp": [], "loc": _venue_counters()}
        self.location_counters_day_i = {"loc": _venue_counters()}
        return 1

    def initialise_location_cum_pop(self):
        """
        Initialize the cumulative population counters at tracked venues.
        initialise;
            self.location_cum_pop

        Parameters
        ----------
            None

        Returns
        -------
            None

        """
        self.location_cum_pop = {}
        # One zeroed histogram per sex for each binning scheme, covering the
        # global tracker and every tracked location.
        for bin_type, bins in self.age_bins.items():
            n_bins = len(bins) - 1
            self.location_cum_pop[bin_type] = {
                spec: {
                    sex: np.zeros(n_bins, dtype=float)
                    for sex in self.contact_sexes
                }
                for spec in ["global"] + list(self.group_type_names)
            }

        # Histograms matching the binning of the input interaction matrices.
        self.location_cum_pop["Interaction"] = {
            spec: np.zeros(len(self.IM[spec]["contacts"]), dtype=float)
            for spec in self.IM
        }
        return 1

    def initialise_location_cum_time(self):
        """
        Initialize the cumulative population time counters at tracked venues.
        initialise;
            self.location_cum_time

        Parameters
        ----------
            None

        Returns
        -------
            None

        """
        # Zero accumulated person-time for every location and globally.
        cum_time = dict.fromkeys(self.group_type_names, 0)
        cum_time["global"] = 0
        self.location_cum_time = cum_time
        return 1

    def hash_ages(self):
        """
        Precompute lookups from person id to age, sex and age-bin index for
        every binning scheme, for fast access later.
        Sets;
            self.age_idxs
            self.ages
            self.sexes

        Parameters
        ----------
            None

        Returns
        -------
            None

        """
        self.age_idxs = {}
        for bins_name, bins in self.age_bins.items():
            # np.digitize returns the 1-based bin number; shift to 0-based.
            self.age_idxs[bins_name] = {
                p.id: np.digitize(p.age, bins) - 1 for p in self.world.people
            }
        self.ages = {p.id: p.age for p in self.world.people}
        self.sexes = {p.id: p.sex for p in self.world.people}
        return 1

    def load_interactions(self, interaction_path):
        """
        Load the initial interaction matrices from a yaml config file.
        Loads;
            self.IM

        Parameters
        ----------
            interaction_path:
                string, location of the yaml file for interactions

        Returns
        -------
            None

        """
        with open(interaction_path) as f:
            interaction_config = yaml.load(f, Loader=yaml.FullLoader)
        self.IM = interaction_config["contact_matrices"]

        # Backfill missing type/bins settings from the per-location defaults.
        # get_defaults is looked up once per location: the original called it
        # twice when both keys were absent.
        for loc, settings in self.IM.items():
            if "type" not in settings or "bins" not in settings:
                default_bins, default_type = make_subgroups.get_defaults(loc)
                settings.setdefault("type", default_type)
                settings.setdefault("bins", default_bins)
        return 1

    #################################################
    # Post Process ##################################
    #################################################

    def convert_dict_to_df(self):
        """
        Transform self.contact_counts into a pandas DataFrame (indexed by
        person id) for easy sorting and grouping.
        Sets;
            self.contacts_df

        Parameters
        ----------
            None

        Returns
        -------
            None

        """
        df = pd.DataFrame.from_dict(self.contact_counts, orient="index")
        # Attach demographics so contacts can be grouped by age/sex later.
        df["age"] = pd.Series(self.ages)
        df["sex"] = pd.Series(self.sexes)
        # One bin-index column per binning scheme, e.g. "syoa_idx".
        for bins_type, age_idxes in self.age_idxs.items():
            df[f"{bins_type}_idx"] = pd.Series(age_idxes)
        self.contacts_df = df
        return 1

    def calc_age_profiles(self):
        """
        Group persons by their ages for contacts in each location
        Sets;
            self.age_profiles
        Also contracts self.location_cum_pop["syoa"] into every coarser
        binning scheme.

        Parameters
        ----------
            None

        Returns
        -------
            None
        """

        def BinCounts(bins_idx, contact_type, ExpN):
            # Count the people with at least one contact at this location,
            # grouped by age-bin index, for the whole/male/female populations.
            # ExpN is the number of bin EDGES, so there are ExpN - 1 bins;
            # reindex makes empty bins appear as explicit zeros.
            contacts_loc = self.contacts_df[self.contacts_df[contact_type] != 0]
            AgesCount = contacts_loc.groupby([bins_idx], dropna=False).size()
            AgesCount = AgesCount.reindex(range(ExpN - 1), fill_value=0)

            MaleCount = (
                contacts_loc[contacts_loc["sex"] == "m"]
                .groupby([bins_idx], dropna=False)
                .size()
            )
            MaleCount = MaleCount.reindex(range(ExpN - 1), fill_value=0)

            FemaleCount = (
                contacts_loc[contacts_loc["sex"] == "f"]
                .groupby([bins_idx], dropna=False)
                .size()
            )
            FemaleCount = FemaleCount.reindex(range(ExpN - 1), fill_value=0)
            return {
                "unisex": AgesCount.values,
                "male": MaleCount.values,
                "female": FemaleCount.values,
            }

        # Build an age profile per binning scheme for the global population
        # and for every tracked location.
        self.age_profiles = {}
        for bin_type in self.age_bins.keys():
            self.age_profiles[bin_type] = {}
            bins_idx = f"{bin_type}_idx"
            self.age_profiles[bin_type]["global"] = BinCounts(
                bins_idx, "global", len(self.age_bins[bin_type])
            )
            for contact_type in self.location_cum_pop["syoa"].keys():
                self.age_profiles[bin_type][contact_type] = BinCounts(
                    bins_idx, contact_type, len(self.age_bins[bin_type])
                )

        def Contract(bins_idx, locs):
            """
            Take full syoa year by year binning of full un-normalized contact matrix and reduce to matrix with age bins bins_idx.

            Parameters
            ----------
                bins_idx:
                    array, bin edges indices from syoa binning
                locs:
                    iterable of strings, locations considered

            Returns
            -------
                dict, new matrices for location by sex
            """
            CM = np.zeros(len(bins_idx) - 1, dtype=float)
            APPEND = {}
            for spec in locs:  # Over location
                append = {}
                for sex in self.contact_sexes:
                    append[sex] = np.zeros_like(CM, dtype=float)
                APPEND[spec] = append

            for spec in locs:  # Over location
                for sex in self.contact_sexes:  # Over sex
                    for bin_x in range(len(bins_idx) - 1):
                        # Sum the single-year populations that fall inside
                        # this coarser bin's [lo, hi) window.
                        Win = [bins_idx[bin_x], bins_idx[bin_x + 1]]
                        APPEND[spec][sex][bin_x] = np.sum(
                            self.location_cum_pop["syoa"][spec][sex][Win[0] : Win[1]]
                        )
            return APPEND

        # Contract the cumulative syoa populations into every coarser scheme.
        for bin_type, bins in self.age_bins.items():
            if bin_type == "syoa":
                continue
            self.location_cum_pop[bin_type] = Contract(
                bins, self.location_cum_pop["syoa"].keys()
            )
        return 1

    def calc_average_contacts(self):
        """
        Get average number of contacts per location per day per age bin.
        Sets and rescales;
            self.average_contacts
        (also rescales the contact-count columns of self.contacts_df from
        cumulative counts to per-day averages)

        Parameters
        ----------
            None

        Returns
        -------
            None

        """
        self.average_contacts = {}
        # Contact-count columns only: exclude demographics and the per-scheme
        # bin-index columns.
        bin_idx_cols = [key + "_idx" for key in self.age_bins.keys()]
        colsWhich = [
            col
            for col in self.contacts_df.columns
            if col not in bin_idx_cols and col not in ["age", "sex"]
        ]
        # Convert cumulative counts into per-day averages.
        self.contacts_df[colsWhich] /= self.timer.total_days
        for bin_type in self.age_bins.keys():
            bins_idx = f"{bin_type}_idx"
            ExpN = len(self.age_bins[bin_type])
            # Select the numeric columns BEFORE taking the mean: averaging the
            # whole frame also hit the non-numeric "sex" column, which raises
            # a TypeError on pandas >= 2.
            AgesCount = self.contacts_df.groupby(
                self.contacts_df[bins_idx], dropna=False
            )[colsWhich].mean()
            # One row per age bin even if a bin contains nobody.
            AgesCount = AgesCount.reindex(range(ExpN - 1), fill_value=0)

            self.average_contacts[bin_type] = AgesCount
        return 1

    def normalize_1D_CM(self):
        """
        For 1D tracking
        normalize the contact matrices based on likelihood to interact with each demographic.
        Sets and rescales;
            self.CM         (rescaled to per-day counts)
            self.CM_err     (Poisson errors, per day)

            self.UNCM       (venue-normalized matrix, contacts i -> j)
            self.UNCM_err

            self.UNCM_R     (venue-normalized with reciprocal contacts)
            self.UNCM_R_err


        Parameters
        ----------
            None

        Returns
        -------
            None
        """
        # Preform Normalization
        bin_Keys = self.CM.keys()
        for bin_type in bin_Keys:

            matrices = self.CM[bin_type]
            for contact_type, cm_spec in matrices.items():
                for sex in self.contact_sexes:

                    if bin_type == "Interaction":
                        # Interaction matrices carry no sex split; handle them
                        # once, during the "unisex" pass only.
                        if sex == "unisex":
                            cm = cm_spec
                            age_profile = np.array(
                                self.location_cum_pop["Interaction"][contact_type]
                            )
                            if contact_type == "shelter":
                                cm = self.cm_shelter_renorm(cm)
                            # Raw counts, so assume a Poisson error.
                            cm_err = np.sqrt(cm)
                        else:
                            continue
                    else:
                        cm = cm_spec[sex]
                        # Raw counts, so assume a Poisson error.
                        cm_err = np.sqrt(cm)

                        age_profile = np.array(
                            self.location_cum_pop[bin_type][contact_type][sex]
                        )

                    # Venue-normalized matrix, one-directional contacts.
                    UNCM, UNCM_err = self.CM_Norm(
                        cm=cm,
                        cm_err=cm_err,
                        pop_tots=age_profile,
                        contact_type=contact_type,
                        Which="UNCM",
                    )

                    # Venue-normalized matrix with reciprocal contacts.
                    UNCM_R, UNCM_R_err = self.CM_Norm(
                        cm=cm,
                        cm_err=cm_err,
                        pop_tots=age_profile,
                        contact_type=contact_type,
                        Which="UNCM_R",
                    )

                    UNCM = np.nan_to_num(UNCM, nan=0)
                    UNCM_err = np.nan_to_num(UNCM_err, nan=0)

                    UNCM_R = np.nan_to_num(UNCM_R, nan=0)
                    UNCM_R_err = np.nan_to_num(UNCM_R_err, nan=0)

                    if bin_type == "Interaction":
                        if sex == "unisex":
                            self.UNCM["Interaction"][contact_type] = UNCM
                            self.UNCM_err["Interaction"][contact_type] = UNCM_err

                            self.UNCM_R["Interaction"][contact_type] = UNCM_R
                            self.UNCM_R_err["Interaction"][contact_type] = UNCM_R_err

                            # Basically just counts of interactions so assume a poisson error
                            self.CM["Interaction"][contact_type] = (
                                cm / self.timer.total_days
                            )
                            self.CM_err["Interaction"][contact_type] = (
                                cm_err / self.timer.total_days
                            )
                        else:
                            continue
                    else:
                        self.UNCM[bin_type][contact_type][sex] = UNCM
                        self.UNCM_err[bin_type][contact_type][sex] = UNCM_err

                        self.UNCM_R[bin_type][contact_type][sex] = UNCM_R
                        self.UNCM_R_err[bin_type][contact_type][sex] = UNCM_R_err

                        # Basically just counts of interactions so assume a poisson error
                        self.CM[bin_type][contact_type][sex] = (
                            cm / self.timer.total_days
                        )
                        self.CM_err[bin_type][contact_type][sex] = (
                            cm_err / self.timer.total_days
                        )
        return 1

    def normalize_All_CM(self):
        """
        For All contacts All tracking
        normalize the contact matrices based on likelihood to interact with each demographic.
        Sets and rescales;
            self.CMV        (rescaled to per-day counts)
            self.CMV_err    (Poisson errors, per day)

            self.UNCM_V     (venue-normalized matrix)
            self.UNCM_V_err

        Parameters
        ----------
            None

        Returns
        -------
            None
        """
        # Preform Normalization
        bin_Keys = self.CMV.keys()

        for bin_type in bin_Keys:

            matrices = self.CMV[bin_type]
            for contact_type, cm_spec in matrices.items():
                for sex in self.contact_sexes:

                    if bin_type == "Interaction":
                        # Interaction matrices carry no sex split; handle them
                        # once, during the "unisex" pass only.
                        if sex == "unisex":
                            cm = np.array(cm_spec)
                            age_profile = np.array(
                                self.location_cum_pop["Interaction"][contact_type]
                            )
                            if contact_type == "shelter":
                                cm = self.cm_shelter_renorm(cm)
                            # Raw counts, so assume a Poisson error.
                            cm_err = np.sqrt(cm)
                        else:
                            continue
                    else:
                        cm = np.array(cm_spec[sex])
                        # Raw counts, so assume a Poisson error.
                        cm_err = np.sqrt(cm)
                        age_profile = np.array(
                            self.location_cum_pop[bin_type][contact_type][sex]
                        )

                    # Venue-normalized matrix for the "All" tracker.
                    UNCMV, UNCMV_err = self.CM_Norm(
                        cm=cm,
                        cm_err=cm_err,
                        pop_tots=age_profile,
                        contact_type=contact_type,
                        Which="UNCM_V",
                    )

                    if bin_type == "Interaction":
                        if sex == "unisex":
                            self.UNCM_V["Interaction"][contact_type] = UNCMV
                            self.UNCM_V_err["Interaction"][contact_type] = UNCMV_err

                            # Basically just counts of interactions so assume a poisson error
                            self.CMV["Interaction"][contact_type] = (
                                cm / self.timer.total_days
                            )
                            self.CMV_err["Interaction"][contact_type] = (
                                cm_err / self.timer.total_days
                            )

                        else:
                            continue
                    else:
                        self.UNCM_V[bin_type][contact_type][sex] = UNCMV
                        self.UNCM_V_err[bin_type][contact_type][sex] = UNCMV_err

                        # Basically just counts of interactions so assume a poisson error
                        self.CMV[bin_type][contact_type][sex] = (
                            cm / self.timer.total_days
                        )
                        self.CMV_err[bin_type][contact_type][sex] = (
                            cm_err / self.timer.total_days
                        )

        return 1

    def AttendanceRatio(self, bin_type, contact_type, sex):
        """
        Get the attendance fraction of each age bin at a location relative to
        the total (global) population of that bin.

        Parameters
        ----------
        bin_type:
            string, contact matrix binning type

        contact_type:
            string, the contact_type location considered

        sex:
            string, the sex of the matrix "male", "female", "unisex"

        Returns
        -------
            ratio:
                np.array of attendance fractions per bin, or 1 for the
                "Interaction" binning (which has no population profile)

        """
        # Interaction matrices have no demographic profile to compare with.
        if bin_type == "Interaction":
            return 1
        pops = self.location_cum_pop[bin_type]
        return np.array(pops[contact_type][sex] / pops["global"][sex])

    def UNtoPNConversion(self, cm, ratio):
        """
        Rescale a contact matrix from venue-normalized to
        population-normalized, scaling row i of cm by ratio[i].

        Parameters
        ----------
            cm:
                np.array, the contact matrix

            ratio:
                float or np.array, attendance fraction(s)

        Returns
        -------

            cm:
                np.array, the rescaled contact matrix

        """
        # Multiplying the transpose broadcasts ratio along the rows of cm.
        return np.transpose(np.transpose(cm) * ratio)

    def CM_Norm(self, cm, cm_err, pop_tots, contact_type="global", Which="UNCM"):
        """
        normalize the contact matrices using population at location data and time of simulation run time.

        Parameters
        ----------
            cm:
                np.array contact matrix
            cm_err:
                np.array contact matrix errors
            pop_tots:
                np.array total counts of visits of each age bin for entire simulation time. (1 person can go to same location more than once)
            contact_type:
                string, the location considered
            Which:
                string, normalization variant: "UNCM" or "UNCM_V" for
                one-directional (i to j) contacts, "UNCM_R" for reciprocal
                contacts
        Returns
        -------
            cm:
                np.array normalized contact matrix
            cm_err:
                np.array normalized contact matrix errors

        """
        # normalize based on characteristic time.

        # Normalization over characteristic time and population:
        # (characteristic time [days] * total visits) / cumulative time spent.
        factor = (
            self.get_characteristic_time(location=contact_type)[0] * np.sum(pop_tots)
        ) / self.location_cum_time[contact_type]
        if np.isnan(factor):
            # No time accumulated at this location (0/0); leave matrix zeroed.
            factor = 0

        # Create blanks to fill
        norm_cm = np.zeros_like(cm, dtype=float)
        norm_cm_err = np.zeros_like(cm, dtype=float)

        # Loop over elements
        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                # Population rescaling between bins j and i.
                # NOTE(review): w is computed before the zero-population guard
                # below; with pop_tots[i] == 0 this divides by zero (a numpy
                # warning, inf/nan) before the `continue` skips the cell —
                # harmless for the result but worth confirming.
                w = pop_tots[j] / pop_tots[i]
                # Leave any pairing with an unpopulated bin at zero.
                if pop_tots[i] < 1 or pop_tots[j] < 1:
                    continue

                if Which in ["UNCM", "UNCM_V"]:  # Only count contacts i to j
                    norm_cm[i, j] = (cm[i, j] / pop_tots[i]) * factor

                    norm_cm_err[i, j] = (cm_err[i, j] / pop_tots[i]) * factor
                elif Which == "UNCM_R":  # Symmetrise using reciprocal j to i contacts
                    norm_cm[i, j] = (
                        0.5
                        * (cm[i, j] / pop_tots[i] + (cm[j, i] / pop_tots[j]) * w)
                        * factor
                    )
                    # Errors added in quadrature for the two directions.
                    norm_cm_err[i, j] = (
                        0.5
                        * np.sqrt(
                            (cm_err[i, j] / pop_tots[i]) ** 2
                            + ((cm_err[j, i] / pop_tots[j]) * w) ** 2
                        )
                        * factor
                    )

        # if Which == "UNCM_V":  # Only count contacts i to j
        #     old_frac_err = norm_cm_err / norm_cm

        #     sum_i = np.tile(np.nansum(norm_cm, axis=1), (norm_cm.shape[0], 1)).T
        #     sum_i_err = np.tile(
        #         np.sqrt(np.nansum(norm_cm**2, axis=1)), (norm_cm.shape[0], 1)
        #     ).T
        #     sum_frac_err = sum_i_err / sum_i

        #     norm_cm /= sum_i
        #     norm_cm_err = norm_cm * np.sqrt(old_frac_err**2 + sum_frac_err**2)

        return norm_cm, norm_cm_err

    def initialize_CM_Normalizations(self):
        """
        Create the CM Normalization arrays from the CM_T template.

        Initialise
        ----------
            self.CM_err

            self.UNCM
            self.UNCM_err

            self.UNCM_R
            self.UNCM_R_err

        Parameters
        ----------
            None

        Returns
        -------
            1 (always)

        """

        def _zeros_like_CM():
            # One zeroed float-valued copy of the nested CM structure:
            # {bin_type: {loc: {sex: array}}}, plus the flat
            # "Interaction" entry which is just {loc: array}.
            out = {
                bin_type: {
                    loc: {
                        sex: np.zeros_like(self.CM[bin_type][loc][sex], dtype=float)
                        for sex in self.CM[bin_type][loc].keys()
                    }
                    for loc in self.CM[bin_type].keys()
                }
                for bin_type in self.CM.keys()
                if bin_type != "Interaction"
            }
            out["Interaction"] = {
                loc: np.zeros_like(self.CM["Interaction"][loc], dtype=float)
                for loc in self.CM["Interaction"].keys()
            }
            return out

        # Each attribute gets its own fresh copy so they can be filled
        # independently (error matrix, normalised, and reciprocal-normalised).
        self.CM_err = _zeros_like_CM()
        self.UNCM = _zeros_like_CM()
        self.UNCM_err = _zeros_like_CM()
        self.UNCM_R = _zeros_like_CM()
        self.UNCM_R_err = _zeros_like_CM()
        return 1

    def initialize_CM_All_Normalizations(self):
        """
        Create the CM Normalization arrays from the CM_AC template.

        Initialise
        ----------
            self.CMV_err

            self.UNCM_V
            self.UNCM_V_err

        Parameters
        ----------
            None

        Returns
        -------
            1 (always)

        """

        def _zeros_like_CMV():
            # One zeroed float-valued copy of the nested CMV structure:
            # {bin_type: {loc: {sex: array}}}, plus the flat
            # "Interaction" entry which is just {loc: array}.
            out = {
                bin_type: {
                    loc: {
                        sex: np.zeros_like(self.CMV[bin_type][loc][sex], dtype=float)
                        for sex in self.CMV[bin_type][loc].keys()
                    }
                    for loc in self.CMV[bin_type].keys()
                }
                for bin_type in self.CMV.keys()
                if bin_type != "Interaction"
            }
            out["Interaction"] = {
                loc: np.zeros_like(self.CMV["Interaction"][loc], dtype=float)
                for loc in self.CMV["Interaction"].keys()
            }
            return out

        # Error matrix and normalised matrices, each an independent copy.
        self.CMV_err = _zeros_like_CMV()
        self.UNCM_V = _zeros_like_CMV()
        self.UNCM_V_err = _zeros_like_CMV()
        return 1

    def post_process_simulation(self, save=True):
        """
        Run the post-simulation checks and calculations.

        Builds the contact data frames, the age profiles over age bins and
        locations, the average contacts per location, and the contact
        matrices normalised by population demographics; rank 0 then prints
        out results (yaml) into the results directory.

        Parameters
        ----------
            save:
                bool, Save out contact matrices

        Returns
        -------
            None

        """
        if self.group_type_names == []:
            return 1

        self.convert_dict_to_df()
        self.calc_age_profiles()
        self.calc_average_contacts()

        self.initialize_CM_Normalizations()
        self.normalize_1D_CM()

        self.initialize_CM_All_Normalizations()
        self.normalize_All_CM()

        if mpi_rank == 0:
            self.PrintOutResults()

        if save:
            # Single rank: output is already "merged"; otherwise raw per-rank.
            folder_name = "merged_data_output" if mpi_size == 1 else "raw_data_output"
            tracker_dir = self.record_path / "Tracker"
            # Both output directories are always created, regardless of which
            # one receives this rank's files.
            for sub_dir in ("merged_data_output", "raw_data_output"):
                (tracker_dir / sub_dir).mkdir(exist_ok=True, parents=True)
            self.tracker_tofile(tracker_dir / folder_name)
        return 1

    #################################################
    # Run tracker ##################################
    #################################################

    def get_active_subgroup(self, person):
        """
        Get the subgroups a person is currently active in.

        e.g. household has subgroups[subgroup_type]: kids[0], young_adults[1],
        adults[2], old[3]; subgroup_type is the integer representing the type
        of person you're wanting to look at.

        Parameters
        ----------
            Person:
                The JUNE person

        Returns
        -------
            active_subgroups:
                list of subgroup indexes

        """
        active_subgroups = []
        seen_group_ids = set()
        for subgroup in person.subgroups.iter():
            if subgroup is None:
                continue
            if subgroup.group.spec == "commute_hub":
                continue
            if person not in subgroup.people:
                continue
            subgroup_id = f"{subgroup.group.spec}_{subgroup.group.id}"
            # gotcha: if you're receiving household visits, then you're active
            # in residence and leisure -- but they are actually the same
            # location, so only record it once.
            if subgroup_id in seen_group_ids:
                continue
            seen_group_ids.add(subgroup_id)
            active_subgroups.append(subgroup)
        return active_subgroups

    def get_contacts_per_subgroup(self, subgroup_type, group):
        """
        Get contacts that a person of subgroup type `subgroup_type` will have
        with each of the other subgroups, in a given group.

        e.g. household has subgroups[subgroup_type]: kids[0], young_adults[1],
        adults[2], old[3]; subgroup_type is the integer representing the type
        of person you're wanting to look at.

        Parameters
        ----------
            subgroup_type:
                index of subgroup for the interaction matrix
            group:
                group. Location and group of people at that location

        Returns
        -------
            contacts_per_subgroup:
                Mean number contacts in the time period

            contacts_per_subgroup_error:
                Error on mean number contacts in the time period
        """
        spec = group.spec
        location_im = self.IM[spec]
        cms = location_im["contacts"]
        # Fall back to zero errors when no error matrix was supplied.
        if "contacts_err" in location_im:
            cms_err = location_im["contacts_err"]
        else:
            cms_err = np.zeros_like(cms, dtype=float)

        n_subgroups = len(group.subgroups)
        if spec == "school":
            # School has many subgroups, the 0th being teachers and the rest
            # year groups; collapse everything to teachers(0) / students(1).
            n_subgroups = 2
            if subgroup_type != 0:
                subgroup_type = 1

        # Scale the per-characteristic-time expectation down to this time step.
        delta_t = self.timer.delta_time.seconds / (3600 * 24)  # In Days
        characteristic_time = self.get_characteristic_time(location=spec)[0]  # In Days
        factor = delta_t / characteristic_time

        row = cms[subgroup_type]
        row_err = cms_err[subgroup_type]
        contacts_per_subgroup = [row[ii] * factor for ii in range(n_subgroups)]
        contacts_per_subgroup_error = [row_err[ii] * factor for ii in range(n_subgroups)]
        return contacts_per_subgroup, contacts_per_subgroup_error

    def simulate_1d_contacts(self, group):
        """
        Construct contact matrices by stochastic sampling.

        For each person at the location we draw an integer number of contacts
        towards each subgroup (scaled expectations come from
        get_contacts_per_subgroup) and sample contact partners uniformly with
        replacement, accumulating counts into the "Interaction" (by subgroup)
        and "syoa" (by single year of age, split by sex) matrices.
        Sets;
            self.CM
            self.contact_counts

        Parameters
        ----------
            group:
                The group of interest to build contacts

        Returns
        -------
            1 (always)

        """
        # Loop over people
        if len(group.people) < 2:
            # No pairwise contacts possible with fewer than two people.
            return 1

        for person in group.people:
            # Shelter we want family groups
            if group.spec == "shelter":
                groups_inter = [list(sub.people) for sub in group.families]
            else:  # Want subgroups as defined in groups
                groups_inter = [list(sub.people) for sub in group.subgroups]

            # Work out which subgroup they are in...
            person_subgroup_idx = -1
            for sub_i in range(len(groups_inter)):
                if person in groups_inter[sub_i]:
                    person_subgroup_idx = sub_i
                    break
            if person_subgroup_idx == -1:
                # Person not found in any subgroup; nothing to sample.
                continue

            if group.spec == "school":
                # Allow teachers to mix with ALL students
                if person_subgroup_idx == 0:
                    groups_inter = [list(group.teachers.people), list(group.students)]
                    person_subgroup_idx = 0
                # Allow students to only mix in their classes.
                else:
                    groups_inter = [
                        list(group.teachers.people),
                        list(group.subgroups[person_subgroup_idx].people),
                    ]
                    person_subgroup_idx = 1

            # Get contacts person expects (mean and error, already scaled to
            # this time step).
            (
                contacts_per_subgroup,
                contacts_per_subgroup_error,
            ) = self.get_contacts_per_subgroup(person_subgroup_idx, group)

            total_contacts = 0

            contact_subgroups = np.arange(0, len(groups_inter), 1)
            for subgroup_contacts, subgroup_contacts_error, contact_subgroup_idx in zip(
                contacts_per_subgroup, contacts_per_subgroup_error, contact_subgroups
            ):
                # potential contacts is one less if you're in that subgroup - can't contact yourself!
                subgroup_people = groups_inter[contact_subgroup_idx]
                subgroup_people_without = subgroup_people.copy()

                # Person in this subgroup
                if person in subgroup_people:
                    inside = True
                    subgroup_people_without.remove(person)
                else:
                    inside = False

                # is_same_subgroup = subgroup.subgroup_type == subgroup_idx
                if len(subgroup_people) - inside <= 0:
                    # Nobody available to contact in this subgroup.
                    continue
                # Draw the integer number of contacts for this subgroup.
                # NOTE(review): self.Probabilistic_Contacts is defined
                # elsewhere -- presumably a stochastic draw around the mean
                # with the given error; confirm against its definition.
                int_contacts = self.Probabilistic_Contacts(
                    subgroup_contacts, subgroup_contacts_error, Probabilistic=True
                )

                contact_ids_inter = []
                contact_ids_intra = []
                contact_ids = []
                contact_ages = []

                # Sample partner indices uniformly, with replacement; exclude
                # the person themselves when they belong to this subgroup.
                if inside:
                    contacts_index = np.random.choice(
                        len(subgroup_people_without), int_contacts, replace=True
                    )
                else:
                    contacts_index = np.random.choice(
                        len(subgroup_people), int_contacts, replace=True
                    )

                # Shelters a special case...
                # Interaction Matrix
                if group.spec == "shelter":
                    if inside:
                        # Intra-family contacts are counted on both diagonal bins.
                        self.CM["Interaction"][group.spec][0, 0] += int_contacts
                        self.CM["Interaction"][group.spec][1, 1] += int_contacts
                    else:
                        # Inter-family contacts are recorded symmetrically.
                        self.CM["Interaction"][group.spec][
                            person_subgroup_idx, contact_subgroup_idx
                        ] += int_contacts
                        self.CM["Interaction"][group.spec][
                            contact_subgroup_idx, person_subgroup_idx
                        ] += int_contacts

                else:
                    self.CM["Interaction"][group.spec][
                        person_subgroup_idx, contact_subgroup_idx
                    ] += int_contacts

                # Get the ids
                for contacts_index_i in contacts_index:
                    if inside:
                        contact = subgroup_people_without[contacts_index_i]
                    else:
                        contact = subgroup_people[contacts_index_i]

                    if group.spec == "shelter":
                        if inside:
                            contact_ids_intra.append(contact.id)
                        else:
                            contact_ids_inter.append(contact.id)
                    contact_ids.append(contact.id)
                    contact_ages.append(contact.age)

                age_idx = self.age_idxs["syoa"][person.id]

                contact_age_idxs = [
                    self.age_idxs["syoa"][contact_id] for contact_id in contact_ids
                ]

                # Accumulate each sampled contact into the single-year-of-age
                # matrices (global and per-location, split by tracked sexes).
                for cidx in contact_age_idxs:
                    self.CM["syoa"]["global"]["unisex"][age_idx, cidx] += 1
                    self.CM["syoa"][group.spec]["unisex"][age_idx, cidx] += 1
                    if person.sex == "m" and "male" in self.contact_sexes:
                        self.CM["syoa"]["global"]["male"][age_idx, cidx] += 1
                        self.CM["syoa"][group.spec]["male"][age_idx, cidx] += 1
                    if person.sex == "f" and "female" in self.contact_sexes:
                        self.CM["syoa"]["global"]["female"][age_idx, cidx] += 1
                        self.CM["syoa"][group.spec]["female"][age_idx, cidx] += 1
                    total_contacts += 1

                # For shelter only. We check over inter and intra groups
                if group.spec == "shelter":
                    # Inter (contacts with other families)
                    contact_age_idxs = [
                        self.age_idxs["syoa"][contact_id]
                        for contact_id in contact_ids_inter
                    ]
                    for cidx in contact_age_idxs:

                        self.CM["syoa"][group.spec + "_inter"]["unisex"][
                            age_idx, cidx
                        ] += 1
                        if person.sex == "m" and "male" in self.contact_sexes:
                            self.CM["syoa"][group.spec + "_inter"]["male"][
                                age_idx, cidx
                            ] += 1
                        if person.sex == "f" and "female" in self.contact_sexes:
                            self.CM["syoa"][group.spec + "_inter"]["female"][
                                age_idx, cidx
                            ] += 1

                    # Intra (contacts within the person's own family)
                    contact_age_idxs = [
                        self.age_idxs["syoa"][contact_id]
                        for contact_id in contact_ids_intra
                    ]
                    for cidx in contact_age_idxs:
                        self.CM["syoa"][group.spec + "_intra"]["unisex"][
                            age_idx, cidx
                        ] += 1
                        if person.sex == "m" and "male" in self.contact_sexes:
                            self.CM["syoa"][group.spec + "_intra"]["male"][
                                age_idx, cidx
                            ] += 1
                        if person.sex == "f" and "female" in self.contact_sexes:
                            self.CM["syoa"][group.spec + "_intra"]["female"][
                                age_idx, cidx
                            ] += 1

            # Per-person running totals of sampled contacts.
            self.contact_counts[person.id]["global"] += total_contacts
            self.contact_counts[person.id][group.spec] += total_contacts
            if group.spec == "shelter":
                self.contact_counts[person.id][group.spec + "_inter"] += total_contacts
                self.contact_counts[person.id][group.spec + "_intra"] += total_contacts

        return 1

    def simulate_All_contacts(self, group):
        """
        Construct contact matrices for all contacts all.

        Every person at the venue is assumed to contact every other person
        once; pair counts are accumulated deterministically per subgroup
        (the "Interaction" matrix) and per single year of age and sex (the
        "syoa" matrices).
        Sets;
            self.CMV

        Parameters
        ----------
            group:
                The group of interest to build contacts

        Returns
        -------
            None

        """
        n_people = len(group.people)
        if n_people < 2:
            # No pairs to count.
            return 1

        # Partition the venue into interaction subgroups.
        # Shelter we want family groups; schools collapse to teachers/students.
        if group.spec == "shelter":
            groups_inter = [list(sub.people) for sub in group.families]
        elif group.spec == "school":
            groups_inter = [list(group.teachers.people), list(group.students)]
        else:  # Want subgroups as defined in groups
            groups_inter = [list(sub.people) for sub in group.subgroups]

        # All-pairs counts by interaction subgroup; the diagonal excludes
        # self-contacts (n * (n - 1) ordered pairs within a subgroup).
        sizes = np.array([len(g) for g in groups_inter])
        if group.spec == "shelter":
            if len(sizes) == 1:
                interaction_counts = (
                    np.eye(self.CMV["Interaction"][group.spec].shape[0])
                    * sizes
                    * (sizes - 1)
                )
            if len(sizes) > 1:
                interaction_counts = np.outer(sizes, sizes)
                # Shelters are recorded symmetrically (averaged off-diagonals).
                interaction_counts = 0.5 * (interaction_counts + interaction_counts.T)
                np.fill_diagonal(interaction_counts, sizes * (sizes - 1))
            self.CMV["Interaction"][group.spec] += interaction_counts
        else:
            interaction_counts = np.outer(sizes, sizes)
            np.fill_diagonal(interaction_counts, sizes * (sizes - 1))
            self.CMV["Interaction"][group.spec] += interaction_counts

        # All-pairs counts by single year of age (0-100), split by sex.
        bins = np.arange(0, 101, 1)
        ages_all = np.array([p.age for p in group.people])
        ages_male = np.array([p.age for p in group.people if p.sex == "m"])
        ages_female = np.array([p.age for p in group.people if p.sex == "f"])
        counts_all, bins = np.histogram(ages_all, bins=bins)
        counts_male, _ = np.histogram(ages_male, bins=bins)
        counts_female, _ = np.histogram(ages_female, bins=bins)

        def pair_matrix(row_counts):
            # Ordered pair counts from persons binned in row_counts towards
            # everyone; the diagonal removes the self-pairing.
            pairs = np.outer(row_counts, counts_all)
            np.fill_diagonal(pairs, row_counts * (counts_all - 1))
            return pairs

        contacts_unisex = pair_matrix(counts_all)
        contacts_female = pair_matrix(counts_female)
        contacts_male = pair_matrix(counts_male)

        # Shelters additionally mirror the totals into _inter and _intra.
        target_locs = ["global", group.spec]
        if group.spec == "shelter":
            target_locs += [group.spec + "_inter", group.spec + "_intra"]
        for loc in target_locs:
            self.CMV["syoa"][loc]["unisex"] += contacts_unisex
            self.CMV["syoa"][loc]["female"] += contacts_female
            self.CMV["syoa"][loc]["male"] += contacts_male
        return 1

    def simulate_pop_time_venues(self, group):
        """
        Get the population and cumulative time at all venues over all time steps.
        Sets;
            self.location_cum_pop
            self.location_cum_time

        Parameters
        ----------
            group:
                The group of interest to build contacts

        Returns
        -------
            None

        """
        if len(group.people) < 2:
            # Venues with fewer than two people are not tracked.
            return 1

        spec = group.spec
        # Cumulative population per interaction subgroup.
        for sub_i, subgroup in enumerate(group.subgroups):
            idx = sub_i
            if spec == "school" and idx > 0:
                idx = 1  # collapse all year groups into "students"
            if spec == "shelter":
                # Shelters count the whole venue against every subgroup.
                self.location_cum_pop["Interaction"][spec][idx] += len(group.people)
            else:
                self.location_cum_pop["Interaction"][spec][idx] += len(subgroup.people)

        # Shelters also accumulate into the _inter/_intra pseudo-locations.
        extra_locs = [spec + "_inter", spec + "_intra"] if spec == "shelter" else []
        all_locs = ["global", spec] + extra_locs

        # Cumulative population per single year of age, split by tracked sexes.
        for person in group.people:
            age_idx = self.age_idxs["syoa"][person.id]
            sex_keys = ["unisex"]
            if person.sex == "m" and "male" in self.contact_sexes:
                sex_keys.append("male")
            if person.sex == "f" and "female" in self.contact_sexes:
                sex_keys.append("female")
            for sex in sex_keys:
                for loc in all_locs:
                    self.location_cum_pop["syoa"][loc][sex][age_idx] += 1

        # Cumulative person-time at the venue, in days.
        person_days = (len(group.people) * self.timer.delta_time.seconds) / (3600 * 24)
        for loc in all_locs:
            self.location_cum_time[loc] += person_days
        return 1

    def simulate_attendance(self, group, super_group_name, timer, counter):
        """
        Update person counts at a location for this time step and day.

        Sets;
            self.location_counters        (attendance per time step)
            self.location_counters_day    (unique attendance per day)
            self.location_counters_day_i  (ids of unique attendees so far today)

        Parameters
        ----------
            group:
                The group of interest to build contacts
            super_group_name:
                location name
            timer:
                timestamp of the time step
            counter:
                venue number in locations list

        Returns
        -------
            None

        """
        # Ids of everyone present, plus per-sex breakdowns.
        people = [p.id for p in group.people]
        men = [p.id for p in group.people if p.sex == "m"]
        women = [p.id for p in group.people if p.sex == "f"]
        if super_group_name in self.location_counters["loc"].keys():
            # By dt: raw head-count for this time step, and the running union
            # of ids seen today.
            # NOTE(review): self.union is defined elsewhere -- presumably
            # returns the union of the two id lists; confirm against its
            # definition.
            self.location_counters["loc"][super_group_name][counter]["unisex"].append(
                len(people)
            )
            NewPeople = self.union(
                self.location_counters_day_i["loc"][super_group_name][counter][
                    "unisex"
                ],
                people,
            )
            if "male" in self.contact_sexes:
                self.location_counters["loc"][super_group_name][counter]["male"].append(
                    len(men)
                )
                NewMen = self.union(
                    self.location_counters_day_i["loc"][super_group_name][counter][
                        "male"
                    ],
                    men,
                )
            if "female" in self.contact_sexes:
                self.location_counters["loc"][super_group_name][counter][
                    "female"
                ].append(len(women))
                NewWomen = self.union(
                    self.location_counters_day_i["loc"][super_group_name][counter][
                        "female"
                    ],
                    women,
                )

            # By Date: at the first time step of a new day, reset the daily id
            # lists and start a new daily count; otherwise update today's
            # running unique count in place.
            if (
                timer.date.hour == timer.initial_date.hour
                and timer.date.minute == 0
                and timer.date.second == 0
            ):
                self.location_counters_day_i["loc"][super_group_name][counter][
                    "unisex"
                ] = people
                self.location_counters_day["loc"][super_group_name][counter][
                    "unisex"
                ].append(len(people))
                if "male" in self.contact_sexes:
                    self.location_counters_day_i["loc"][super_group_name][counter][
                        "male"
                    ] = men
                    self.location_counters_day["loc"][super_group_name][counter][
                        "male"
                    ].append(len(men))
                if "female" in self.contact_sexes:
                    self.location_counters_day_i["loc"][super_group_name][counter][
                        "female"
                    ] = women
                    self.location_counters_day["loc"][super_group_name][counter][
                        "female"
                    ].append(len(women))
            else:
                # Mid-day step: overwrite today's last entry with the updated
                # unique-attendee count.
                self.location_counters_day_i["loc"][super_group_name][counter][
                    "unisex"
                ] = NewPeople
                self.location_counters_day["loc"][super_group_name][counter]["unisex"][
                    -1
                ] = len(NewPeople)

                if "male" in self.contact_sexes:
                    self.location_counters_day_i["loc"][super_group_name][counter][
                        "male"
                    ] = NewMen
                    self.location_counters_day["loc"][super_group_name][counter][
                        "male"
                    ][-1] = len(NewMen)
                if "female" in self.contact_sexes:
                    self.location_counters_day_i["loc"][super_group_name][counter][
                        "female"
                    ] = NewWomen
                    self.location_counters_day["loc"][super_group_name][counter][
                        "female"
                    ][-1] = len(NewWomen)

    def simulate_traveldistance(self, day):
        """
        Simulate travels distances from distance to residence from venue

        Sets;
            self.travel_distance

        Parameters
        ----------
            day:
                str, day of the week for time step

        Returns
        -------
            None

        """
        self.travel_distance[day] = {}
        for loc in self.location_counters_day_i["loc"].keys():
            distances = []
            self.travel_distance[day][loc] = distances
            grouptype = getattr(self.world, loc)
            if grouptype is None:
                continue
            # Restrict to the venues tracked on this rank.
            groups_which = np.array(grouptype.members)[
                np.array(self.venues_which[loc])
            ]
            for counter, group in enumerate(groups_which):  # Loop over all locations.
                if group.external:
                    # Venue belongs to another domain; no coordinates here.
                    continue

                venue_coords = group.coordinates
                attendee_ids = self.location_counters_day_i["loc"][loc][counter][
                    "unisex"
                ]
                for ID in attendee_ids:
                    person = self.world.people.get_from_id(ID)
                    if person.residence is None:
                        continue
                    household_coords = person.residence.group.area.coordinates
                    distances.append(
                        geopy.distance.geodesic(household_coords, venue_coords).km
                    )
        return 1

    ####################################################
    # Tracker running ##################################
    ####################################################

    def trackertimestep(self, all_super_groups, timer):
        """
        Loop over all locations at each timestamp to get contact matrices and location population counts.

        Parameters
        ----------
            all_super_groups:
                List of all groups to track contacts over
            timer:
                timer object from simulator class

        Returns
        -------
            None

        """
        self.timer = timer
        # Record this timestep's timestamp and its length in hours.
        self.location_counters["Timestamp"].append(self.timer.date)
        self.location_counters["delta_t"].append(self.timer.delta_time.seconds / 3600)

        # A new calendar day is registered when we are back at the simulation's
        # initial hour exactly on the minute/second boundary.
        if (
            self.timer.date.hour == self.timer.initial_date.hour
            and self.timer.date.minute == 0
            and self.timer.date.second == 0
        ):
            self.location_counters_day["Timestamp"].append(self.timer.date)

        DaysElapsed = len(self.location_counters_day["Timestamp"]) - 1
        day = self.timer.day_of_week

        if DaysElapsed > 0 and DaysElapsed <= 8:
            # Only sample travel distances once the first full day has
            # completed, and for at most the first 8 days (the reason for the
            # 8-day cap is not documented here).
            self.simulate_traveldistance(day)

        for super_group_name in all_super_groups:
            # Skip visit-type supergroups.
            if "visits" in super_group_name:
                continue
            grouptype = getattr(self.world, super_group_name)
            if grouptype is not None:

                # Venue type not in domain
                if super_group_name not in self.venues_which.keys():
                    continue

                counter = 0
                Skipped_E = 0
                # Restrict to the venues selected in self.venues_which.
                groups_which = np.array(grouptype.members)[
                    np.array(self.venues_which[super_group_name])
                ]
                for group in groups_which:  # Loop over all locations.
                    if group.spec in self.group_type_names:
                        if counter == 0:
                            # Log once per supergroup per timestep.
                            logger.info(
                                f"Rank {mpi_rank} -- tracking contacts -- {len(self.venues_which[super_group_name])} of {len(grouptype.members)} of type {group.spec}"
                            )
                        if group.external:
                            Skipped_E += 1
                            counter += 1
                            continue  # Skip external venues to the domain.

                        self.simulate_pop_time_venues(group)
                        self.simulate_attendance(
                            group, super_group_name, self.timer, counter
                        )
                        self.simulate_1d_contacts(group)
                        self.simulate_All_contacts(group)
                        counter += 1
        return 1

    ###########################################################
    # Saving tracker results ##################################
    ###########################################################

    def tracker_tofile(self, tracker_path):
        """
        Save tracker log. Including;
            Input interaction matrices
            Outputs over each contact matrix type syoa, AC, etc etc

        Parameters
        ----------
            tracker_path:
                str, path to save tracker results

        Returns
        -------
            None

        """

        def SaveMatrix(CM, CM_err, Mtype, NormType="U"):
            # Serialise one family of contact matrices (all bin types) to yaml.
            # NormType "P" relabels the output "P..." and makes tracker_CMJSON
            # apply the attendance-ratio normalisation.
            jsonfile = {}
            for binType in list(CM.keys()):

                if NormType == "U":
                    pass
                elif NormType == "P":
                    Mtype = "P" + Mtype[1:]

                jsonfile[binType] = self.tracker_CMJSON(
                    binType=binType, CM=CM, CM_err=CM_err, NormType=NormType
                )
            # Save out the normalized matrix family
            self.Save_CM_JSON(
                dir=self.record_path / "Tracker" / folder_name / "CM_yamls",
                folder=folder_name,
                filename=f"tracker_{Mtype}{mpi_rankname}.yaml",
                jsonfile=jsonfile,
            )

        def SaveMatrixMetrics(CM, CM_err, Mtype, NormType="U"):
            # Save out metric calculations per bin type and location.
            jsonfile = {}
            for binType in list(CM.keys()):
                jsonfile[binType] = {}
                for loc in list(CM[binType].keys()):

                    if NormType == "U":
                        ratio = 1
                    elif NormType == "P":
                        ratio = self.AttendanceRatio(binType, loc, "unisex")
                        Mtype = "P" + Mtype[1:]

                    jsonfile[binType][loc] = self.Calculate_CM_Metrics(
                        bin_type=binType,
                        contact_type=loc,
                        CM=CM,
                        CM_err=CM_err,
                        ratio=ratio,
                        sex="unisex",
                    )
            self.Save_CM_JSON(
                dir=self.record_path / "Tracker" / folder_name / "CM_Metrics",
                folder=folder_name,
                filename=f"tracker_Metrics_{Mtype}{mpi_rankname}.yaml",
                jsonfile=jsonfile,
            )

        def SaveMatrixCanberra(CM, CM_err, Mtype, NormType="U"):
            # Canberra distance between each simulated "Interaction"-binned
            # matrix and the input interaction matrix it was generated from.
            jsonfile = {}
            for loc in list(CM["Interaction"].keys()):

                if NormType == "U":
                    ratio = 1
                elif NormType == "P":
                    ratio = self.AttendanceRatio("Interaction", loc, "unisex")
                    Mtype = "P" + Mtype[1:]

                cm = CM["Interaction"][loc]
                cm = self.UNtoPNConversion(cm, ratio)

                A = np.array(cm, dtype=float)
                B = np.array(self.IM[loc]["contacts"], dtype=float)
                Dc = self.Canberra_distance(A, B)[0]
                jsonfile[loc] = {"Dc": f"{Dc}"}
            self.Save_CM_JSON(
                dir=self.record_path / "Tracker" / folder_name / "CM_Metrics",
                folder=folder_name,
                filename=f"tracker_CanberraDist_{Mtype}{mpi_rankname}.yaml",
                jsonfile=jsonfile,
            )

        # Single-rank runs write straight into the merged folder; multi-rank
        # runs write per-rank raw files to be merged afterwards.
        if mpi_size == 1:
            mpi_rankname = ""
            folder_name = "merged_data_output"
            MPI = False
        else:
            mpi_rankname = f"_r{mpi_rank}_"
            folder_name = "raw_data_output"
            MPI = True

        self.Save_CM_JSON(
            dir=tracker_path,
            folder=folder_name,
            filename=f"tracker_Simulation_Params{mpi_rankname}.yaml",
            jsonfile=self.tracker_Simulation_Params(),
        )

        # All Identical so don't need to do anything here
        if mpi_rank == 0:
            # Save out the IM
            self.Save_CM_JSON(
                dir=self.record_path / "Tracker" / "merged_data_output" / "CM_yamls",
                folder=folder_name,
                filename="tracker_IM.yaml",
                jsonfile=self.tracker_IMJSON(),
            )

        # Saving Contacts tracker results ##################################
        # Fix: the original passed CM_err=self.CM here, saving the matrix
        # itself in place of its error estimate (cf. CMV/CMV_err below).
        SaveMatrix(CM=self.CM, CM_err=self.CM_err, Mtype="CM")
        SaveMatrix(CM=self.CMV, CM_err=self.CMV_err, Mtype="CMV")

        if not MPI:
            # NOTE(review): normalised outputs are only written on single-rank
            # runs — presumably merged multi-rank output is produced elsewhere.
            SaveMatrix(CM=self.UNCM, CM_err=self.UNCM_err, Mtype="UNCM")
            SaveMatrix(CM=self.UNCM_R, CM_err=self.UNCM_R_err, Mtype="UNCM_R")
            SaveMatrix(CM=self.UNCM_V, CM_err=self.UNCM_V_err, Mtype="UNCM_V")

            SaveMatrix(CM=self.UNCM, CM_err=self.UNCM_err, Mtype="UNCM", NormType="P")
            SaveMatrix(
                CM=self.UNCM_R, CM_err=self.UNCM_R_err, Mtype="UNCM_R", NormType="P"
            )
            SaveMatrix(
                CM=self.UNCM_V, CM_err=self.UNCM_V_err, Mtype="UNCM_V", NormType="P"
            )

            SaveMatrixMetrics(CM=self.UNCM, CM_err=self.UNCM_err, Mtype="UNCM")
            SaveMatrixMetrics(CM=self.UNCM_R, CM_err=self.UNCM_R_err, Mtype="UNCM_R")
            SaveMatrixMetrics(CM=self.UNCM_V, CM_err=self.UNCM_V_err, Mtype="UNCM_V")

            SaveMatrixMetrics(
                CM=self.UNCM, CM_err=self.UNCM_err, Mtype="UNCM", NormType="P"
            )
            SaveMatrixMetrics(
                CM=self.UNCM_R, CM_err=self.UNCM_R_err, Mtype="UNCM_R", NormType="P"
            )
            SaveMatrixMetrics(
                CM=self.UNCM_V, CM_err=self.UNCM_V_err, Mtype="UNCM_V", NormType="P"
            )

            SaveMatrixCanberra(CM=self.UNCM, CM_err=self.UNCM_err, Mtype="UNCM")
            SaveMatrixCanberra(CM=self.UNCM_R, CM_err=self.UNCM_R_err, Mtype="UNCM_R")
            SaveMatrixCanberra(CM=self.UNCM_V, CM_err=self.UNCM_V_err, Mtype="UNCM_V")

            SaveMatrixCanberra(
                CM=self.UNCM, CM_err=self.UNCM_err, Mtype="UNCM", NormType="P"
            )
            SaveMatrixCanberra(
                CM=self.UNCM_R, CM_err=self.UNCM_R_err, Mtype="UNCM_R", NormType="P"
            )
            SaveMatrixCanberra(
                CM=self.UNCM_V, CM_err=self.UNCM_V_err, Mtype="UNCM_V", NormType="P"
            )

        # Saving Venue tracker results ##################################
        # Per-venue age profiles, one workbook per bin type, one sheet per
        # location type.
        VD_dir = self.record_path / "Tracker" / folder_name / "Venue_Demographics"
        VD_dir.mkdir(exist_ok=True, parents=True)
        for bin_types in self.age_profiles.keys():
            dat = self.age_profiles[bin_types]
            bins = self.age_bins[bin_types]
            with pd.ExcelWriter(
                VD_dir / f"PersonCounts_{bin_types}{mpi_rankname}.xlsx", mode="w"
            ) as writer:
                for local in dat.keys():

                    df = pd.DataFrame(dat[local])
                    if bin_types == "syoa":
                        # Single-year bins are labelled by lower edge only.
                        df["Ages"] = [
                            f"{low}" for low, high in zip(bins[:-1], bins[1:])
                        ]
                    else:
                        df["Ages"] = [
                            f"{low}-{high-1}" for low, high in zip(bins[:-1], bins[1:])
                        ]
                    df = df.set_index("Ages")
                    df.loc["Total"] = df.sum()
                    df.to_excel(writer, sheet_name=f"{local}")

        # Cumulative person counts per venue type.
        VTD_dir = self.record_path / "Tracker" / folder_name / "Venue_TotalDemographics"
        VTD_dir.mkdir(exist_ok=True, parents=True)
        for bin_types in self.location_cum_pop.keys():
            dat = self.location_cum_pop[bin_types]
            with pd.ExcelWriter(
                VTD_dir / f"CumPersonCounts_{bin_types}{mpi_rankname}.xlsx", mode="w"
            ) as writer:
                for local in dat.keys():

                    df = pd.DataFrame(dat[local])
                    df.to_excel(writer, sheet_name=f"{local}")

        # Histogram of household->venue travel distances per weekday.
        Dist_dir = self.record_path / "Tracker" / folder_name / "Venue_TravelDist"
        Dist_dir.mkdir(exist_ok=True, parents=True)
        days = list(self.travel_distance.keys())
        if len(days) != 0:
            with pd.ExcelWriter(
                Dist_dir / f"Distance_traveled{mpi_rankname}.xlsx", mode="w"
            ) as writer:
                for local in self.travel_distance[days[0]].keys():
                    df = pd.DataFrame()
                    bins = np.arange(0, 50, 0.05)  # 0-50 km in 50 m steps
                    df["bins"] = (bins[:-1] + bins[1:]) / 2
                    for day in days:
                        df[day] = np.histogram(
                            self.travel_distance[day][local], bins=bins, density=False
                        )[0]
                    df.to_excel(writer, sheet_name=f"{local}")

        V_dir = self.record_path / "Tracker" / folder_name / "Venue_UniquePops"
        V_dir.mkdir(exist_ok=True, parents=True)

        # Save out persons per location
        timestamps = self.location_counters["Timestamp"]
        delta_ts = self.location_counters["delta_t"]
        for sex in self.contact_sexes:
            with pd.ExcelWriter(
                V_dir / f"Venues_{sex}_Counts_BydT{mpi_rankname}.xlsx", mode="w"
            ) as writer:
                for loc in self.location_counters["loc"].keys():
                    df = pd.DataFrame()
                    df["t"] = timestamps
                    df["dt"] = delta_ts
                    NVenues = len(self.location_counters["loc"][loc].keys())

                    loc_j = 0
                    for loc_i in range(NVenues):
                        # Skip venues that were never visited.
                        if (
                            np.sum(self.location_counters["loc"][loc][loc_i]["unisex"])
                            == 0
                        ):
                            continue
                        df[loc_j] = self.location_counters["loc"][loc][loc_i][sex]
                        loc_j += 1

                        # Cap the number of venue columns written per sheet.
                        if loc_j > 600:
                            break

                    df.to_excel(writer, sheet_name=f"{loc}")

        # Same counts aggregated per calendar day.
        timestamps = self.location_counters_day["Timestamp"]
        for sex in self.contact_sexes:
            with pd.ExcelWriter(
                V_dir / f"Venues_{sex}_Counts_ByDate{mpi_rankname}.xlsx", mode="w"
            ) as writer:
                for loc in self.location_counters_day["loc"].keys():
                    df = pd.DataFrame()
                    df["t"] = timestamps

                    NVenues = len(self.location_counters_day["loc"][loc].keys())
                    loc_j = 0
                    for loc_i in range(NVenues):
                        if (
                            np.sum(
                                self.location_counters_day["loc"][loc][loc_i]["unisex"]
                            )
                            == 0
                        ):
                            continue
                        df[loc_j] = self.location_counters_day["loc"][loc][loc_i][sex]
                        loc_j += 1

                        if loc_j > 600:
                            break
                    df.to_excel(writer, sheet_name=f"{loc}")

        # Save contacts per location
        Av_dir = self.record_path / "Tracker" / folder_name / "Venue_AvContacts"
        Av_dir.mkdir(exist_ok=True, parents=True)
        with pd.ExcelWriter(
            Av_dir / f"Average_contacts{mpi_rankname}.xlsx", mode="w"
        ) as writer:
            for rbt in self.average_contacts.keys():
                df = self.average_contacts[rbt]
                df.to_excel(writer, sheet_name=f"{rbt}")

        # Save out cumulative time
        CT_dir = self.record_path / "Tracker" / folder_name / "Venue_CumTime"
        CT_dir.mkdir(exist_ok=True, parents=True)
        df = pd.DataFrame.from_dict(self.location_cum_time, orient="index").T
        with pd.ExcelWriter(CT_dir / f"CumTime{mpi_rankname}.xlsx", mode="w") as writer:
            df.to_excel(writer)

        return 1

    def Save_CM_JSON(self, dir, folder, filename, jsonfile):
        """
        Save yaml file for any given json dict.

        A scratch copy is first dumped into a "junk" folder, then rewritten to
        its final location with all quotation marks stripped so the matrix
        strings read as plain yaml lists.

        Parameters
        ----------
            dir:
                Path, the directory to save into
            folder:
                string, raw or merged folder name
            filename:
                string, the filename
            jsonfile:
                dict, data to be saved out

        Returns
        -------
            None

        """
        scratch_dir = self.record_path / "Tracker" / folder / "junk"
        scratch_dir.mkdir(exist_ok=True, parents=True)
        dir.mkdir(exist_ok=True, parents=True)

        scratch_file = scratch_dir / filename
        with open(scratch_file, "w") as handle:
            yaml.dump(
                jsonfile,
                handle,
                allow_unicode=True,
                default_flow_style=False,
                default_style=None,
                sort_keys=False,
            )
        # Re-save, dropping single and double quotes from every line.
        with open(scratch_file, "r") as src, open(dir / filename, "w") as dst:
            dst.writelines(
                line.replace('"', "").replace("'", "") for line in src
            )
        return 1

    def tracker_Simulation_Params(self):
        """
        Get JSON output for Simulation parameters

        Parameters
        ----------
            None

        Returns
        -------
            jsonfile:
                json of simulation parameters. total days, weekend/day names,
                venue counts, population size, bin types and tracked sexes.

        """
        params = {
            "MPI_size": mpi_size,
            "MPI_rank": mpi_rank,
            "total_days": self.timer.total_days,
            "Weekend_Names": self.MatrixString(
                np.array(self.timer.day_types["weekend"])
            ),
            "Weekday_Names": self.MatrixString(
                np.array(self.timer.day_types["weekday"])
            ),
        }

        # Number of tracked venues per location type.
        params["NVenues"] = {
            location: len(venues)
            for location, venues in self.location_counters_day["loc"].items()
        }
        params["NPeople"] = len(self.world.people)
        params["binTypes"] = self.MatrixString(np.array(list(self.CM.keys())))
        params["sexes"] = self.MatrixString(np.array(self.contact_sexes))
        return params

    def tracker_IMJSON(self):
        """
        Get JSON output for the interaction matrix inputs to the contact tracker model

        Parameters
        ----------
            None

        Returns
        -------
            jsonfile:
                json of interaction matrices information; matrix-like entries
                are rendered via MatrixString, scalar entries pass through.

        """
        jsonfile = {}
        for local in self.IM.keys():
            jsonfile[local] = {}
            for item in self.IM[local].keys():
                if item in ["contacts", "contacts_err", "proportion_physical"]:
                    value = self.MatrixString(np.array(self.IM[local][item]))
                elif item in ["bins"]:
                    value = self.MatrixString(
                        np.array(self.IM[local][item]), dtypeString="int"
                    )
                elif item in ["characteristic_time", "type"]:
                    value = self.IM[local][item]
                else:
                    # Fix: previously an unrecognised key re-used the value
                    # left over from the preceding iteration (or raised
                    # NameError on the very first one). Skip keys we do not
                    # know how to serialise.
                    continue
                jsonfile[local][item] = value
        return jsonfile

    def tracker_CMJSON(self, binType, CM, CM_err, NormType="U"):
        """
        Get final JUNE simulated contact matrix.

        Parameters
        ----------
            binType:
                Name of bin type syoa, AC etc
            CM:
                dict, dictionary of all matrices of type. eg self.CM
            CM_err:
                dict, dictionary of all matrices of type. eg self.CM_err
            NormType:
                str, "U" leaves matrices unnormalised (ratio 1); "P" rescales
                each matrix by the attendance ratio via UNtoPNConversion

        Returns
        -------
            jsonfile:
                json of interaction matrices information

        """

        jsonfile = {}
        if binType == "Interaction":
            # Interaction-binned matrices take their metadata (bins, times,
            # physical-contact proportions) straight from the inputs self.IM.
            for local in CM[binType].keys():
                if NormType == "U":
                    ratio = 1
                elif NormType == "P":
                    ratio = self.AttendanceRatio(binType, local, "unisex")

                jsonfile[local] = {}

                c_time = self.IM[local]["characteristic_time"]
                I_bintype = self.IM[local]["type"]
                bins = self.IM[local]["bins"]
                p_physical = np.array(self.IM[local]["proportion_physical"])

                jsonfile[local]["proportion_physical"] = self.MatrixString(p_physical)
                jsonfile[local]["characteristic_time"] = c_time
                jsonfile[local]["type"] = I_bintype
                if I_bintype == "Age":
                    jsonfile[local]["bins"] = self.MatrixString(
                        np.array(bins), dtypeString="int"
                    )
                elif I_bintype == "Discrete":
                    jsonfile[local]["bins"] = self.MatrixString(
                        np.array(bins), dtypeString="float"
                    )

                cm = CM[binType][local]
                cm_err = CM_err[binType][local]
                # ratio == 1 (NormType "U") leaves the matrices untouched.
                cm = self.UNtoPNConversion(cm, ratio)
                cm_err = self.UNtoPNConversion(cm_err, ratio)

                jsonfile[local]["contacts"] = self.MatrixString(np.array(cm))
                jsonfile[local]["contacts_err"] = self.MatrixString(np.array(cm_err))
        else:

            def expand_proportional(self, PM, bins_I, bins_I_Type, bins_target):
                # Expand a matrix defined on interaction-model bins onto
                # single-year-of-age ("syoa") bins, then contract it onto
                # bins_target.
                if bins_I_Type != "Age":
                    # Discrete bins such as students/teachers or
                    # adults/children map onto two age bands split at AgeAdult.
                    ACBins = any(
                        x in ["students", "teachers", "adults", "children"]
                        for x in bins_I
                    )
                    if ACBins:
                        bins_I = np.array([0, AgeAdult, 100])
                    else:
                        return PM
                expand_bins = self.age_bins["syoa"]
                Pmatrix = np.zeros(
                    (len(expand_bins) - 1, len(expand_bins) - 1), dtype=float
                )
                if PM.shape == (1, 1):
                    bins_I = np.array([0, 100])
                for bin_xi in range(len(bins_I) - 1):
                    for bin_yi in range(len(bins_I) - 1):
                        # NOTE(review): bin edges are used directly as matrix
                        # indices, which assumes syoa bins are 1-year wide
                        # starting at age 0 — confirm.
                        Win_Xi = (bins_I[bin_xi], bins_I[bin_xi + 1])
                        Win_Yi = (bins_I[bin_yi], bins_I[bin_yi + 1])
                        Pmatrix[Win_Xi[0] : Win_Xi[1], Win_Yi[0] : Win_Yi[1]] = PM[
                            bin_xi, bin_yi
                        ]
                Pmatrix = self.contract_matrix(Pmatrix, bins_target, method=np.mean)
                return Pmatrix

            locallists = list(CM[binType].keys())
            locallists.sort()
            for local in locallists:
                local = str(local)

                if NormType == "U":
                    ratio = 1
                elif NormType == "P":
                    ratio = self.AttendanceRatio(binType, local, "unisex")

                jsonfile[local] = {}

                # All shelter variants share the "shelter" interaction matrix.
                if "shelter" in local:
                    local_c = "shelter"
                else:
                    local_c = local

                if local == "global":
                    # "global" has no input matrix; use fixed defaults.
                    c_time = 24
                    p_physical = np.array([[0.12]])
                else:
                    c_time = self.IM[local_c]["characteristic_time"]
                    p_physical = expand_proportional(
                        self,
                        np.array(self.IM[local_c]["proportion_physical"]),
                        self.IM[local_c]["bins"],
                        self.IM[local_c]["type"],
                        self.age_bins[binType],
                    )

                bins = self.MatrixString(
                    np.array(self.age_bins[binType]), dtypeString="int"
                )
                p_physical = self.MatrixString(p_physical)

                jsonfile[local]["proportion_physical"] = p_physical
                jsonfile[local]["characteristic_time"] = c_time
                jsonfile[local]["type"] = "Age"
                jsonfile[local]["bins"] = bins

                # Per-sex contact matrices for this location.
                jsonfile[local]["sex"] = {}
                for sex in self.contact_sexes:
                    cm = CM[binType][local][sex]
                    cm_err = CM_err[binType][local][sex]
                    cm = self.UNtoPNConversion(cm, ratio)
                    cm_err = self.UNtoPNConversion(cm_err, ratio)

                    jsonfile[local]["sex"][sex] = {}
                    jsonfile[local]["sex"][sex]["contacts"] = self.MatrixString(
                        np.array(cm)
                    )
                    jsonfile[local]["sex"][sex]["contacts_err"] = self.MatrixString(
                        np.array(cm_err)
                    )
        return jsonfile

    def MatrixString(self, matrix, dtypeString="float"):
        """
        Take square matrix array into a string for clarity of printing

        NaN and inf entries are rendered as 0. Fix: unlike the previous
        version, the input array is no longer modified in place while
        formatting.

        Parameters
        ----------
            matrix:
                np.array matrix, 1D or 2D; 1D entries may also be strings.
                Arrays of any other rank yield "[]" (as before).
            dtypeString:
                str, 'int' or 'float' numeric formatting

        Returns
        -------
            string:
                one line string respresentation of matrix

        """

        def fmt(value):
            # Render one entry; strings pass through untouched.
            if isinstance(value, str):
                return value
            if np.isnan(value) or np.isinf(value):
                value = 0  # sanitise non-finite entries without touching input
            if dtypeString == "float":
                return "{:.2e}".format(value)
            if dtypeString == "int":
                return "%.0f" % value
            return ""  # unknown dtypeString: contribute nothing (as before)

        if matrix.ndim == 1:
            return "[" + ",".join(fmt(v) for v in matrix) + "]"
        if matrix.ndim == 2:
            rows = ("[" + ",".join(fmt(v) for v in row) + "]" for row in matrix)
            return "[" + ",".join(rows) + "]"
        return "[]"

    ##############################################################
    # Print out tracker results ##################################
    ##############################################################

    def PolicyText(
        self, Type, contacts, contacts_err, proportional_physical, characteristic_time
    ):
        """
        Pretty-print the key contact-tracker results for a single location.

        Parameters
        ----------
            Type:
                string bin type, syoa etc
            contacts:
                np.array contact matrix
            contacts_err:
                np.array contact matrix errors
            proportional_physical:
                np.array proportion of physical contact matrix
            characteristic_time:
                np.float The characteristic time at location in hours

        Returns
        -------
            None

        """
        lines = [
            "  %s:" % Type,
            "    contacts: %s" % self.MatrixString(contacts),
            "    contacts_err: %s" % self.MatrixString(contacts_err),
            "    proportion_physical: %s" % self.MatrixString(proportional_physical),
            "    characteristic_time: %.2f" % characteristic_time,
        ]
        for line in lines:
            print(line)
        return 1

    def PrintOutResults(self, WhichLocals=None, sex="unisex", binType="Interaction"):
        """
        Clear printout of results from contact tracker. Loop over all locations for contact matrix of sex and binType

        Parameters
        ----------
            WhichLocals:
                list of location names to print results for; None (or an
                empty list) means all locations in self.CM[binType]
            sex:
                Sex contact matrix
            binType:
                Name of bin type syoa, AC etc


        Returns
        -------
            None
        """
        # Fix: the default was a mutable [] — use None as the sentinel.
        if not WhichLocals:
            WhichLocals = self.CM[binType].keys()

        def printoutfunction(which):
            # Print contacts, errors and input-matrix ratios for every
            # requested location of the matrix family "which".
            for local in WhichLocals:
                contact, contact_err = self.CMPlots_GetCM(
                    binType, local, sex=sex, which=which
                )
                if local in self.IM.keys():
                    (
                        characteristic_time,
                        proportion_physical,
                    ) = self.get_characteristic_time(local)
                    proportional_physical = np.array(proportion_physical)
                    characteristic_time = characteristic_time * 24  # days -> hours
                else:
                    # Location has no input interaction-matrix entry.
                    proportional_physical = np.array(0)
                    characteristic_time = 0

                self.PolicyText(
                    local,
                    contact,
                    contact_err,
                    proportional_physical,
                    characteristic_time,
                )
                print("")
                im, _ = self.IMPlots_GetIM(local)
                print(
                    "    Ratio of contacts and feed in values: %s"
                    % self.MatrixString(contact / np.array(im))
                )
                print("")

        print("Results from UNCM")
        printoutfunction(which="UNCM")
        print("")

        print("Results from UNCM_R")
        printoutfunction(which="UNCM_R")
        print("")

        print("Results from UNCM_V")
        printoutfunction(which="UNCM_V")
        print("")
        return 1


import numpy as np
import yaml
import pandas as pd
from pathlib import Path
import glob

from june.tracker.tracker import Tracker
from june.tracker.tracker_plots import PlotClass

from june.mpi_setup import mpi_comm, mpi_size, mpi_rank
import logging

logger = logging.getLogger("tracker merger")
mpi_logger = logging.getLogger("mpi")

# Stop merger log records propagating on non-root ranks so only rank 0 emits
# merger log messages.
if mpi_rank > 0:
    logger.propagate = False

#######################################################
# Plotting functions ##################################
#######################################################


class MergerClass:
    """
    Class to merge trackers results from multiple MPI runs

    Parameters
    ----------
        record_path:
            location of results directory

    Returns
    -------
    """

    class Timer:
        # Lightweight stand-in for the simulation timer that Tracker methods
        # expect on `self.timer`; only `total_days` is ever read.
        def __init__(
            self,
        ):
            # Placeholder value; MergerClass.__init__ overwrites it with
            # Params["total_days"] read from the rank-0 parameter file.
            self.total_days = 1

    def __init__(self, record_path=Path(""), NRanksTest=None):
        """Locate per-rank tracker output and merge the simulation parameters.

        Parameters
        ----------
            record_path:
                location of the results directory (expects a "Tracker"
                subdirectory for MPI runs)
            NRanksTest:
                optional explicit number of ranks to merge; when None the rank
                count is inferred from the yaml files found on disk.
        """
        self.record_path = record_path
        # Minimal timer stand-in; total_days is overwritten from the params file.
        self.timer = self.Timer()

        # Per-rank raw output only exists for MPI runs.
        self.MPI = (self.record_path / "Tracker" / "raw_data_output").exists()

        if not self.MPI:
            # Nothing to merge. Still define group_type_names so the log line
            # below (and any later look-ups) do not raise AttributeError --
            # previously this path crashed on self.group_type_names['all'].
            self.group_type_names = {"all": []}
        else:
            self.raw_data_path = self.record_path / "Tracker" / "raw_data_output"
            self.merged_data_path = self.record_path / "Tracker" / "merged_data_output"
            self.merged_data_path.mkdir(exist_ok=True, parents=True)

            if NRanksTest is None:
                # One parameter yaml per rank sits directly in raw_data_output.
                self.NRanks = len(glob.glob(str(self.raw_data_path / "*.yaml")))
            else:
                self.NRanks = NRanksTest

            # Rank 0's parameter file seeds the combined parameter set.
            with open(self.raw_data_path / "tracker_Simulation_Params_r0_.yaml") as f:
                Params = yaml.load(f, Loader=yaml.FullLoader)

            self.group_type_names = {}
            self.group_type_names[0] = list(Params["NVenues"].keys()) + [
                "care_home_visits",
                "household_visits",
                "global",
            ]
            self.group_type_names["all"] = list(Params["NVenues"].keys()) + ["global"]
            self.binTypes = list(Params["binTypes"])
            self.contact_sexes = list(Params["sexes"])
            self.timer.total_days = int(Params["total_days"])

            Params["MPI_rank"] = "Combined"
            Params["Weekday_Names"] = self.MatrixString(
                matrix=np.array(Params["Weekday_Names"])
            )
            Params["Weekend_Names"] = self.MatrixString(
                matrix=np.array(Params["Weekend_Names"])
            )
            Params["binTypes"] = self.MatrixString(matrix=np.array(Params["binTypes"]))
            Params["sexes"] = self.MatrixString(matrix=np.array(Params["sexes"]))

            # Fold the remaining ranks' parameters into the combined set.
            for rank in range(1, self.NRanks):
                with open(
                    self.raw_data_path / f"tracker_Simulation_Params_r{rank}_.yaml"
                ) as f:
                    Params_rank = yaml.load(f, Loader=yaml.FullLoader)

                self.group_type_names[rank] = list(Params_rank["NVenues"].keys()) + [
                    "care_home_visits",
                    "household_visits",
                    "global",
                ]

                # "all" becomes the union of group types seen on any rank.
                group_names_update = list(
                    set(self.group_type_names["all"] + self.group_type_names[rank])
                )
                self.group_type_names["all"] = group_names_update

                venues = list(
                    set(Params_rank["NVenues"].keys())
                    & set(self.group_type_names[rank])
                )

                # Venue counts are additive across ranks.
                for v in venues:
                    if (
                        v in Params["NVenues"].keys()
                        and v in Params_rank["NVenues"].keys()
                    ):
                        Params["NVenues"][v] += Params_rank["NVenues"][v]
                    elif (
                        v not in Params["NVenues"].keys()
                        and v in Params_rank["NVenues"].keys()
                    ):
                        Params["NVenues"][v] = Params_rank["NVenues"][v]
                    else:
                        continue

                Params["NPeople"] += Params_rank["NPeople"]
            self.Save_CM_JSON(
                dir=self.merged_data_path,
                folder="merged_data_output",
                filename="tracker_Simulation_Params.yaml",
                jsonfile=Params,
            )

        logger.info(
            f"Rank {mpi_rank} -- Initial params loaded -- have following group types { self.group_type_names['all'] }"
        )

    ###########################################################################################
    # Import the useful functions from other Tracker modules ##################################
    ###########################################################################################

    # --- Thin delegation wrappers --------------------------------------------
    # Each method below reuses Tracker's implementation by calling it unbound
    # (Tracker.<name>(self, ...)).  MergerClass does not inherit from Tracker;
    # these wrappers let the shared helpers run against this object, which only
    # needs to carry whichever attributes each Tracker method reads.
    def CM_Norm(self, cm, cm_err, pop_tots, contact_type="global", Which="UNCM"):
        """Delegates to Tracker.CM_Norm."""
        return Tracker.CM_Norm(self, cm, cm_err, pop_tots, contact_type, Which)

    def get_characteristic_time(self, location):
        """Delegates to Tracker.get_characteristic_time."""
        return Tracker.get_characteristic_time(self, location)

    def PolicyText(
        self, Type, contacts, contacts_err, proportional_physical, characteristic_time
    ):
        """Delegates to Tracker.PolicyText."""
        return Tracker.PolicyText(
            self,
            Type,
            contacts,
            contacts_err,
            proportional_physical,
            characteristic_time,
        )

    def MatrixString(self, matrix, dtypeString="float"):
        """Delegates to Tracker.MatrixString."""
        return Tracker.MatrixString(self, matrix, dtypeString)

    def pluralize_r(self, loc):
        """Delegates to Tracker.pluralize_r."""
        return Tracker.pluralize_r(self, loc)

    def pluralize(self, loc):
        """Delegates to Tracker.pluralize."""
        return Tracker.pluralize(self, loc)

    def initialize_CM_Normalizations(self):
        """Delegates to Tracker.initialize_CM_Normalizations."""
        return Tracker.initialize_CM_Normalizations(self)

    def initialize_CM_All_Normalizations(self):
        """Delegates to Tracker.initialize_CM_All_Normalizations."""
        return Tracker.initialize_CM_All_Normalizations(self)

    def normalize_1D_CM(self):
        """Delegates to Tracker.normalize_1D_CM."""
        return Tracker.normalize_1D_CM(self)

    def normalize_All_CM(self):
        """Delegates to Tracker.normalize_All_CM."""
        return Tracker.normalize_All_CM(self)

    def PrintOutResults(self):
        """Delegates to Tracker.PrintOutResults."""
        return Tracker.PrintOutResults(self)

    def Save_CM_JSON(self, dir, folder, filename, jsonfile):
        """Delegates to Tracker.Save_CM_JSON."""
        return Tracker.Save_CM_JSON(self, dir, folder, filename, jsonfile)

    def tracker_CMJSON(self, binType, CM, CM_err, NormType):
        """Delegates to Tracker.tracker_CMJSON."""
        return Tracker.tracker_CMJSON(self, binType, CM, CM_err, NormType)

    def contract_matrix(self, CM, bins, method=np.sum):
        """Delegates to Tracker.contract_matrix."""
        return Tracker.contract_matrix(self, CM, bins, method)

    def Calculate_CM_Metrics(
        self, bin_type, contact_type, CM, CM_err, ratio, sex="unisex"
    ):
        """Delegates to Tracker.Calculate_CM_Metrics."""
        return Tracker.Calculate_CM_Metrics(
            self, bin_type, contact_type, CM, CM_err, ratio, sex
        )

    def Population_Metrics(self, pop_by_bin, pop_bins):
        """Delegates to Tracker.Population_Metrics."""
        return Tracker.Population_Metrics(self, pop_by_bin, pop_bins)

    def Expectation_Assortativeness(self, NPCDM, pop_bins):
        """Delegates to Tracker.Expectation_Assortativeness."""
        return Tracker.Expectation_Assortativeness(self, NPCDM, pop_bins)

    def Calc_NPCDM(self, cm, pop_by_bin, pop_width):
        """Delegates to Tracker.Calc_NPCDM."""
        return Tracker.Calc_NPCDM(self, cm, pop_by_bin, pop_width)

    def Calc_QIndex(self, cm):
        """Delegates to Tracker.Calc_QIndex."""
        return Tracker.Calc_QIndex(self, cm)

    def Canberra_distance(self, x, y):
        """Delegates to Tracker.Canberra_distance."""
        return Tracker.Canberra_distance(self, x, y)

    def AttendanceRatio(self, bin_type, contact_type, sex):
        """Delegates to Tracker.AttendanceRatio."""
        return Tracker.AttendanceRatio(self, bin_type, contact_type, sex)

    def UNtoPNConversion(self, cm, ratio):
        """Delegates to Tracker.UNtoPNConversion."""
        return Tracker.UNtoPNConversion(self, cm, ratio)

    def CMPlots_GetCM(self, bin_type, contact_type, sex="unisex", which="UNCM"):
        """
        Get cm out of dictionary.

        Parameters
        ----------
            binType:
                Name of bin type syoa, AC etc
            contact_type:
                Location of contacts
            sex:
                Sex contact matrix
            which:
                str, which matrix type to collect "CM", "UNCM", "UNCM_R", "CMV", "UNCM_V"

        Returns
        -------
            cm:
                np.array contact matrix
            cm_err:
                np.array contact matrix errors
        """
        if bin_type != "Interaction":
            if which == "CM":
                cm = self.CM[bin_type][contact_type][sex]
                cm_err = self.CM_err[bin_type][contact_type][sex]
            elif which == "UNCM":
                cm = self.UNCM[bin_type][contact_type][sex]
                cm_err = self.UNCM_err[bin_type][contact_type][sex]
            elif which == "UNCM_R":
                cm = self.UNCM_R[bin_type][contact_type][sex]
                cm_err = self.UNCM_R_err[bin_type][contact_type][sex]

            elif which == "CMV":
                cm = self.CMV[bin_type][contact_type][sex]
                cm_err = self.CMV_err[bin_type][contact_type][sex]
            elif which == "UNCM_V":
                cm = self.UNCM_V[bin_type][contact_type][sex]
                cm_err = self.UNCM_V_err[bin_type][contact_type][sex]

        else:
            if which == "CM":
                cm = self.CM[bin_type][contact_type]
                cm_err = self.CM_err[bin_type][contact_type]
            elif which == "UNCM":
                cm = self.UNCM[bin_type][contact_type]
                cm_err = self.UNCM_err[bin_type][contact_type]
            elif which == "UNCM_R":
                cm = self.UNCM_R[bin_type][contact_type]
                cm_err = self.UNCM_R_err[bin_type][contact_type]

            elif which == "CMV":
                cm = self.CMV[bin_type][contact_type]
                cm_err = self.CMV_err[bin_type][contact_type]
            elif which == "UNCM_V":
                cm = self.UNCM_V[bin_type][contact_type]
                cm_err = self.UNCM_V_err[bin_type][contact_type]
        return np.array(cm), np.array(cm_err)

    def IMPlots_GetIM(self, contact_type):
        """Delegates to Tracker.IMPlots_GetIM."""
        return Tracker.IMPlots_GetIM(self, contact_type)

    #####################################################
    # Individual Merge ##################################
    #####################################################

    def Travel_Distance(self):
        """Combine the per-rank distance-travelled tables into one workbook.

        Reads every rank's "Distance_traveled" sheet per venue type, sums the
        count columns (everything after the first column) across ranks, and
        writes the merged tables under merged_data_output/Venue_TravelDist.
        """
        merged = {}
        # Pseudo-locations that have no distance sheet of their own.
        skipped = [
            "global",
            "shelter_inter",
            "shelter_intra",
            "care_home_visits",
            "household_visits",
        ]
        for rank in range(self.NRanks):
            workbook = (
                self.raw_data_path
                / "Venue_TravelDist"
                / f"Distance_traveled_r{rank}_.xlsx"
            )
            for venue in self.group_type_names[rank]:
                if venue in skipped:
                    continue
                sheet = pd.read_excel(workbook, sheet_name=venue, index_col=0)
                if venue in merged:
                    # First column is kept from the first rank; the rest add up.
                    merged[venue].iloc[:, 1:] += sheet.iloc[:, 1:]
                else:
                    merged[venue] = sheet
        out_dir = self.merged_data_path / "Venue_TravelDist"
        out_dir.mkdir(exist_ok=True, parents=True)
        with pd.ExcelWriter(out_dir / "Distance_traveled.xlsx", mode="w") as writer:
            for venue in merged:
                merged[venue].to_excel(writer, sheet_name=f"{venue}")
        return 1

    def CumPersonCounts(self):
        """Merge per-rank cumulative person counts per venue type.

        For every bin type, sums the "CumPersonCounts" sheets across ranks into
        ``self.location_cum_pop[bin_type][loc]`` and writes one merged workbook
        per bin type under Venue_TotalDemographics.
        """
        self.location_cum_pop = {}
        for rbt in self.binTypes:
            self.location_cum_pop[rbt] = {}
            for rank in range(0, self.NRanks):
                filename = (
                    self.raw_data_path
                    / "Venue_TotalDemographics"
                    / f"CumPersonCounts_{rbt}_r{rank}_.xlsx"
                )
                for loc in self.group_type_names[rank]:
                    # Visit pseudo-locations have no sheet of their own.
                    if loc in ["care_home_visits", "household_visits"]:
                        continue

                    # Map the group name to its sheet-name form via
                    # Tracker.pluralize_r -- presumably the de-pluralized form;
                    # confirm against Tracker if in doubt.
                    loc = self.pluralize_r(loc)

                    # "global" has no Interaction-binned sheet.
                    if loc == "global" and rbt == "Interaction":
                        continue

                    df = pd.read_excel(filename, sheet_name=loc, index_col=0)

                    # First rank seeds the frame; later ranks add elementwise.
                    if loc not in self.location_cum_pop[rbt].keys():
                        self.location_cum_pop[rbt][loc] = df
                    else:
                        self.location_cum_pop[rbt][loc] += df

            Save_dir = self.merged_data_path / "Venue_TotalDemographics"
            Save_dir.mkdir(exist_ok=True, parents=True)
            with pd.ExcelWriter(
                Save_dir / f"CumPersonCounts_{rbt}.xlsx", mode="w"
            ) as writer:
                for local in self.location_cum_pop[rbt].keys():
                    df = pd.DataFrame(self.location_cum_pop[rbt][local])
                    df.to_excel(writer, sheet_name=f"{local}")
        return 1

    def VenueUniquePops(self):
        """Merge per-rank unique-attendance counts for a sample of venues.

        Two workbooks are merged: counts by date ("ByDate", time column "t")
        and counts by timestep ("BydT", columns "t" and "dt").  To keep the
        output manageable at most ~600 venues in total are sampled across all
        ranks, with a fixed RNG seed so the sample is reproducible.
        """

        def _merge(kind, base_cols):
            # Merge the f"Venues_{sex}_Counts_{kind}" sheets across ranks and
            # save one combined workbook per sex.  `base_cols` are the
            # time-axis columns copied straight from the raw sheets.
            np.random.seed(1234)  # fixed seed -> reproducible venue sample
            location_counters = {}
            for sex in self.contact_sexes:
                location_counters[sex] = {}
                for plural_loc in self.group_type_names["all"]:
                    if plural_loc in ["global", "care_home_visits", "household_visits"]:
                        continue
                    NVenues_so_far = 0
                    for rank in range(0, self.NRanks):
                        if plural_loc not in self.group_type_names[rank]:
                            continue

                        filename = (
                            self.raw_data_path
                            / "Venue_UniquePops"
                            / f"Venues_{sex}_Counts_{kind}_r{rank}_.xlsx"
                        )
                        df = pd.read_excel(filename, sheet_name=plural_loc, index_col=0)

                        # Seed the output frame only once.  Previously a rank
                        # with zero venues unconditionally re-created it,
                        # wiping columns already merged from earlier ranks.
                        if plural_loc not in location_counters[sex]:
                            location_counters[sex][plural_loc] = pd.DataFrame(
                                {c: df[c] for c in base_cols}
                            )

                        NVenues_rank_loc = df.shape[1] - 1
                        if NVenues_rank_loc == 0:
                            # Nothing to sample from this rank.
                            continue

                        # Sample at most ~600 venues in total across all ranks.
                        Pick = min(int(600 / self.NRanks), NVenues_rank_loc)
                        rands = np.random.choice(
                            np.arange(1, NVenues_rank_loc + 1, 1),
                            size=Pick,
                            replace=False,
                        )
                        # NOTE: the original's `[0] + rands` was an elementwise
                        # add of zero to the numpy array, i.e. identical to
                        # `rands`.  `.values` makes the assignment positional
                        # rather than index-aligned.
                        location_counters[sex][plural_loc][
                            np.arange(NVenues_so_far, NVenues_so_far + Pick, 1)
                        ] = df.iloc[:, rands].values

                        NVenues_so_far += Pick

                Save_dir = self.merged_data_path / "Venue_UniquePops"
                Save_dir.mkdir(exist_ok=True, parents=True)
                with pd.ExcelWriter(
                    Save_dir / f"Venues_{sex}_Counts_{kind}.xlsx", mode="w"
                ) as writer:
                    for local in location_counters[sex].keys():
                        df = pd.DataFrame(location_counters[sex][local])
                        df.to_excel(writer, sheet_name=f"{local}")

        _merge("ByDate", ("t",))
        _merge("BydT", ("t", "dt"))
        return 1

    def VenuePersonCounts(self):
        """Merge per-rank venue demographics (person counts per age bin).

        Populates:
            self.rank_age_profiles[rbt][loc][rank]: each rank's unisex profile
                (with key "all" holding the sum over ranks); consumed later by
                AvContacts to weight per-rank averages.
            self.age_profiles[rbt][loc]: full demographic table summed over
                ranks (totals row stripped at the end).
        Also writes one merged "PersonCounts" workbook per bin type.
        """
        self.age_profiles = {}
        self.rank_age_profiles = {}
        for rbt in self.binTypes:
            # Interaction bins carry no age demographics.
            if rbt == "Interaction":
                continue

            self.rank_age_profiles[rbt] = {}

            self.age_profiles[rbt] = {}
            for rank in range(0, self.NRanks):

                filename = (
                    self.raw_data_path
                    / "Venue_Demographics"
                    / f"PersonCounts_{rbt}_r{rank}_.xlsx"
                )
                for loc in self.group_type_names[rank]:
                    # Visit pseudo-locations have no demographics sheet.
                    if loc in ["care_home_visits", "household_visits"]:

                        continue

                    loc = self.pluralize_r(loc)
                    if loc not in self.rank_age_profiles[rbt].keys():
                        self.rank_age_profiles[rbt][loc] = {}

                    df = pd.read_excel(filename, sheet_name=loc, index_col=0)

                    # iloc[:-1] drops the final row, which is a totals row
                    # (see the "Remove the total row" pass below).
                    self.rank_age_profiles[rbt][loc][rank] = df.copy()["unisex"].iloc[
                        :-1
                    ]
                    if "all" not in self.rank_age_profiles[rbt][loc].keys():
                        self.rank_age_profiles[rbt][loc]["all"] = df["unisex"].iloc[:-1]
                    else:
                        self.rank_age_profiles[rbt][loc]["all"] += (
                            df["unisex"].iloc[:-1].values
                        )

                    # Full table (totals row included) summed across ranks.
                    if loc not in self.age_profiles[rbt].keys():
                        self.age_profiles[rbt][loc] = df
                    else:
                        self.age_profiles[rbt][loc] += df.values

            Save_dir = self.merged_data_path / "Venue_Demographics"
            Save_dir.mkdir(exist_ok=True, parents=True)
            with pd.ExcelWriter(
                Save_dir / f"PersonCounts_{rbt}.xlsx", mode="w"
            ) as writer:
                for local in self.age_profiles[rbt].keys():
                    df = pd.DataFrame(self.age_profiles[rbt][local])
                    df.to_excel(writer, sheet_name=f"{local}")

        # Remove the total row.
        for rbt in self.age_profiles.keys():
            for loc in self.age_profiles[rbt].keys():
                self.age_profiles[rbt][loc] = self.age_profiles[rbt][loc].iloc[:-1, :]
        return 1

    def AvContacts(self):
        """Merge per-rank average-contact tables into one workbook.

        Ranks are combined as a weighted sum where each rank's weight is its
        share of the *global* population in each age bin (taken from
        self.rank_age_profiles, so VenuePersonCounts must have run first).
        """
        AvContacts = {}
        for rbt in self.binTypes:
            if rbt == "Interaction":
                continue
            for rank in range(0, self.NRanks):
                filename = (
                    self.raw_data_path
                    / "Venue_AvContacts"
                    / f"Average_contacts_r{rank}_.xlsx"
                )
                df = pd.read_excel(filename, sheet_name=rbt, index_col=0)

                if rank == 0:
                    # Seed the output frame with zero columns for every known
                    # group type; visit columns get a trailing "s" appended.
                    dat = {df.columns[0]: df.iloc[0]}
                    nbins = len(self.rank_age_profiles[rbt]["global"]["all"])
                    for col in self.group_type_names["all"]:
                        col = self.pluralize_r(col)
                        if "visit" in col:
                            col += "s"
                        dat[col] = np.zeros(nbins)

                    AvContacts[rbt] = pd.DataFrame(dat)

                for col in df.columns:
                    # Only merge columns for group types present on this rank.
                    if self.pluralize(col) not in self.group_type_names[rank]:
                        continue
                    col_age = self.pluralize_r(col)
                    if col_age == "care_home_visit":
                        col_age = "care_home"
                    if col_age == "household_visit":
                        col_age = "household"

                    # factor = (self.rank_age_profiles[rbt][col_age][rank].values/self.rank_age_profiles[rbt][col_age]["all"].values)
                    # NOTE(review): the per-location weighting above appears to
                    # have been abandoned in favour of the global population
                    # share below, which leaves col_age unused -- confirm
                    # against the tracker module before removing.
                    factor = (
                        self.rank_age_profiles[rbt]["global"][rank].values
                        / self.rank_age_profiles[rbt]["global"]["all"].values
                    )
                    AvContacts[rbt][col] += (df[col] * factor).values

        Save_dir = self.merged_data_path / "Venue_AvContacts"
        Save_dir.mkdir(exist_ok=True, parents=True)
        with pd.ExcelWriter(Save_dir / f"Average_contacts.xlsx", mode="w") as writer:
            for rbt in self.binTypes:
                if rbt == "Interaction":
                    continue
                # Bins with zero global population divide 0/0 -> NaN; write 0.
                df = pd.DataFrame(AvContacts[rbt]).replace(np.nan, 0)

                df.to_excel(writer, sheet_name=f"{rbt}")
        return 1

    def LoadIMatrices(self):
        """Load the merged interaction-matrix yaml into self.IM."""
        im_path = self.merged_data_path / "CM_yamls" / "tracker_IM.yaml"
        with open(im_path) as f:
            self.IM = yaml.load(f, Loader=yaml.FullLoader)
        return 1

    def LoadContactMatrices(self):
        """Load and sum the per-rank contact matrices.

        Accumulates the direct contact matrices (CM) into ``self.CM`` and the
        "visit"/venue variant (CMV) into ``self.CMV``, each scaled by
        ``self.timer.total_days`` and summed over ranks.  Also fills
        ``self.age_bins`` from the first file that provides them, and keeps the
        last-loaded raw per-rank data on ``self.CM_rank`` / ``self.CMV_rank``
        as before.  The two passes were previously ~95-line copies of each
        other; they now share one helper.
        """
        self.age_bins = {}

        def _accumulate(matrix_name, file_stub, skip_global_always):
            # Sum f"{file_stub}_r{rank}_.yaml" over ranks into self.<matrix_name>.
            # skip_global_always preserves an asymmetry of the original code:
            # the CMV merge skips "global" for every bin type, while the CM
            # merge only skips it inside the "Interaction" branch.
            store = None
            for rank in range(0, self.NRanks):
                with open(
                    self.raw_data_path / "CM_yamls" / f"{file_stub}_r{rank}_.yaml"
                ) as f:
                    rank_data = yaml.load(f, Loader=yaml.FullLoader)
                # Keep the raw per-rank structure visible, as the original did.
                setattr(self, matrix_name + "_rank", rank_data)

                if rank == 0:
                    # Rank 0 seeds the accumulator, scaled to total days.
                    store = {
                        bin_type: {
                            loc: {
                                sex: np.array(
                                    rank_data[bin_type][loc]["sex"][sex]["contacts"]
                                )
                                * self.timer.total_days
                                for sex in rank_data[bin_type][loc]["sex"].keys()
                            }
                            for loc in rank_data[bin_type].keys()
                        }
                        for bin_type in rank_data.keys()
                        if bin_type != "Interaction"
                    }
                    store["Interaction"] = {
                        loc: np.array(rank_data["Interaction"][loc]["contacts"])
                        * self.timer.total_days
                        for loc in rank_data["Interaction"].keys()
                    }
                    setattr(self, matrix_name, store)

                    # Record the age-bin edges once per bin type.
                    for rbt in self.binTypes:
                        if rbt == "Interaction" or rbt in self.age_bins.keys():
                            continue
                        loc = list(rank_data[rbt].keys())[0]
                        self.age_bins[rbt] = rank_data[rbt][loc]["bins"]
                    continue

                for bin_type in self.binTypes:
                    for loc_plural in self.group_type_names["all"]:
                        loc = self.pluralize_r(loc_plural)
                        if loc_plural not in self.group_type_names[rank]:
                            continue
                        if loc_plural in ["care_home_visits", "household_visits"]:
                            continue
                        if skip_global_always and loc_plural == "global":
                            continue
                        NEW = loc not in store[bin_type].keys()

                        if bin_type != "Interaction":
                            if NEW:
                                store[bin_type][loc] = {}
                            for sex in self.contact_sexes:
                                contrib = (
                                    np.array(
                                        rank_data[bin_type][loc]["sex"][sex]["contacts"]
                                    )
                                    * self.timer.total_days
                                )
                                if NEW:
                                    store[bin_type][loc][sex] = contrib
                                else:
                                    store[bin_type][loc][sex] += contrib
                        else:
                            # Interaction matrices exist only for real venues.
                            if loc in [
                                "global",
                                "care_home_visits",
                                "household_visits",
                            ]:
                                continue
                            contrib = (
                                np.array(rank_data[bin_type][loc]["contacts"])
                                * self.timer.total_days
                            )
                            if NEW:
                                store[bin_type][loc] = contrib
                            else:
                                store[bin_type][loc] += contrib

        _accumulate("CM", "tracker_CM", skip_global_always=False)
        logger.info(f"Rank {mpi_rank} -- Load CMs Done")

        _accumulate("CMV", "tracker_CMV", skip_global_always=True)
        logger.info(f"Rank {mpi_rank} -- Load CMVs Done")
        return 1

    def LoadCumtimes(self):
        """Sum the per-rank cumulative time tables and save the merged result.

        Fills self.location_cum_time[venue] with the value summed over ranks
        and writes it to merged_data_output/Venue_CumTime/CumTime.xlsx.
        """
        self.location_cum_time = {}
        for rank in range(self.NRanks):
            sheet = pd.read_excel(
                self.raw_data_path / "Venue_CumTime" / f"CumTime_r{rank}_.xlsx",
                index_col=0,
            )

            for plural_col in self.group_type_names["all"]:
                # Visit pseudo-locations carry no cumulative time column.
                if plural_col in ["care_home_visits", "household_visits"]:
                    continue
                col = self.pluralize_r(plural_col)
                if col not in sheet.columns:
                    continue

                value = sheet[col].values[0]
                if col in self.location_cum_time:
                    self.location_cum_time[col] += value
                else:
                    self.location_cum_time[col] = value

        out_dir = self.merged_data_path / "Venue_CumTime"
        out_dir.mkdir(exist_ok=True, parents=True)
        totals = pd.DataFrame.from_dict(self.location_cum_time, orient="index").T
        with pd.ExcelWriter(out_dir / "CumTime.xlsx", mode="w") as writer:
            totals.to_excel(writer)
        return 1

    def SaveOutCM(self):
        """
        Save all merged contact-matrix variants to yaml files.

        For each matrix family the (normalized) matrices go to ``CM_yamls``,
        and the derived metric summaries and Canberra distances go to
        ``CM_Metrics``. Each normalized family is written twice: venue
        normalized ("U") and population normalized ("P").

        Returns
        -------
            1 on completion
        """
        folder_name = self.merged_data_path
        # Merged output carries no per-rank suffix in its filenames.
        mpi_rankname = ""

        def SaveMatrix(CM, CM_err, Mtype, NormType="U"):
            # Serialize one matrix family (all bin types) into a single yaml.
            jsonfile = {}
            for binType in list(CM.keys()):

                if NormType == "U":
                    pass
                elif NormType == "P":
                    # Rename e.g. "UNCM" -> "PNCM"; idempotent after the first
                    # loop iteration since the prefix is then already "P".
                    Mtype = "P" + Mtype[1:]

                jsonfile[binType] = self.tracker_CMJSON(
                    binType=binType, CM=CM, CM_err=CM_err, NormType=NormType
                )
            # Save out the normalized UNCM
            self.Save_CM_JSON(
                dir=self.record_path / "Tracker" / folder_name / "CM_yamls",
                folder=folder_name,
                filename=f"tracker_{Mtype}{mpi_rankname}.yaml",
                jsonfile=jsonfile,
            )

        def SaveMatrixMetrics(CM, CM_err, Mtype, NormType="U"):
            # Save out metric calculations
            jsonfile = {}
            for binType in list(CM.keys()):
                jsonfile[binType] = {}
                for loc in list(CM[binType].keys()):

                    if NormType == "U":
                        ratio = 1
                    elif NormType == "P":
                        # Population normalization rescales by attendance.
                        ratio = self.AttendanceRatio(binType, loc, "unisex")
                        Mtype = "P" + Mtype[1:]

                    jsonfile[binType][loc] = self.Calculate_CM_Metrics(
                        bin_type=binType,
                        contact_type=loc,
                        CM=CM,
                        CM_err=CM_err,
                        ratio=ratio,
                        sex="unisex",
                    )
            self.Save_CM_JSON(
                dir=self.record_path / "Tracker" / folder_name / "CM_Metrics",
                folder=folder_name,
                filename=f"tracker_Metrics_{Mtype}{mpi_rankname}.yaml",
                jsonfile=jsonfile,
            )

        def SaveMatrixCanberra(CM, CM_err, Mtype, NormType="U"):
            # Canberra distance between each interaction-binned matrix and
            # the corresponding input interaction matrix (self.IM).
            jsonfile = {}
            for loc in list(CM["Interaction"].keys()):

                if NormType == "U":
                    ratio = 1
                elif NormType == "P":
                    ratio = self.AttendanceRatio("Interaction", loc, "unisex")
                    Mtype = "P" + Mtype[1:]

                cm = CM["Interaction"][loc]
                cm = self.UNtoPNConversion(cm, ratio)

                A = np.array(cm, dtype=float)
                B = np.array(self.IM[loc]["contacts"], dtype=float)
                Dc = self.Canberra_distance(A, B)[0]
                # Stored as a string for yaml readability.
                jsonfile[loc] = {"Dc": f"{Dc}"}
            self.Save_CM_JSON(
                dir=self.record_path / "Tracker" / folder_name / "CM_Metrics",
                folder=folder_name,
                filename=f"tracker_CanberraDist_{Mtype}{mpi_rankname}.yaml",
                jsonfile=jsonfile,
            )

        # Saving Contacts tracker results ##################################
        # NOTE(review): CM_err is given self.CM here, unlike every other call
        # below which passes the matching *_err attribute — confirm whether
        # self.CM_err exists and was intended.
        SaveMatrix(CM=self.CM, CM_err=self.CM, Mtype="CM")
        SaveMatrix(CM=self.CMV, CM_err=self.CMV_err, Mtype="CMV")

        SaveMatrix(CM=self.UNCM, CM_err=self.UNCM_err, Mtype="UNCM")
        SaveMatrix(CM=self.UNCM_R, CM_err=self.UNCM_R_err, Mtype="UNCM_R")
        SaveMatrix(CM=self.UNCM_V, CM_err=self.UNCM_V_err, Mtype="UNCM_V")

        SaveMatrix(CM=self.UNCM, CM_err=self.UNCM_err, Mtype="UNCM", NormType="P")
        SaveMatrix(CM=self.UNCM_R, CM_err=self.UNCM_R_err, Mtype="UNCM_R", NormType="P")
        SaveMatrix(CM=self.UNCM_V, CM_err=self.UNCM_V_err, Mtype="UNCM_V", NormType="P")

        SaveMatrixMetrics(CM=self.UNCM, CM_err=self.UNCM_err, Mtype="UNCM")
        SaveMatrixMetrics(CM=self.UNCM_R, CM_err=self.UNCM_R_err, Mtype="UNCM_R")
        SaveMatrixMetrics(CM=self.UNCM_V, CM_err=self.UNCM_V_err, Mtype="UNCM_V")

        SaveMatrixMetrics(
            CM=self.UNCM, CM_err=self.UNCM_err, Mtype="UNCM", NormType="P"
        )
        SaveMatrixMetrics(
            CM=self.UNCM_R, CM_err=self.UNCM_R_err, Mtype="UNCM_R", NormType="P"
        )
        SaveMatrixMetrics(
            CM=self.UNCM_V, CM_err=self.UNCM_V_err, Mtype="UNCM_V", NormType="P"
        )

        SaveMatrixCanberra(CM=self.UNCM, CM_err=self.UNCM_err, Mtype="UNCM")
        SaveMatrixCanberra(CM=self.UNCM_R, CM_err=self.UNCM_R_err, Mtype="UNCM_R")
        SaveMatrixCanberra(CM=self.UNCM_V, CM_err=self.UNCM_V_err, Mtype="UNCM_V")

        SaveMatrixCanberra(
            CM=self.UNCM, CM_err=self.UNCM_err, Mtype="UNCM", NormType="P"
        )
        SaveMatrixCanberra(
            CM=self.UNCM_R, CM_err=self.UNCM_R_err, Mtype="UNCM_R", NormType="P"
        )
        SaveMatrixCanberra(
            CM=self.UNCM_V, CM_err=self.UNCM_V_err, Mtype="UNCM_V", NormType="P"
        )
        return 1

    #################################################
    # Master Merge ##################################
    #################################################

    def Merge(self):
        """
        Master merge of the per-rank tracker outputs.

        Runs the individual merge steps (travel distance, cumulative times,
        person counts, venue populations, average contacts), loads the
        interaction and contact matrices, normalizes them and saves the
        merged contact matrices. When the run was not MPI the merge is
        skipped (single-core output is already in merged form). Always
        finishes by printing the results summary.
        """
        # NOTE(review): the message reports NRanks+1 ranks while LoadCumtimes
        # reads files for range(self.NRanks) — confirm which count is right.
        logger.info(f"Rank {mpi_rank} -- Begin Merging from {self.NRanks+1} ranks")
        if self.MPI:
            self.Travel_Distance()
            logger.info(f"Rank {mpi_rank} -- Distance sheet done")

            self.LoadCumtimes()
            logger.info(f"Rank {mpi_rank} -- Cumulative time done")
            self.CumPersonCounts()
            logger.info(f"Rank {mpi_rank} -- Person counts done")
            self.VenueUniquePops()
            logger.info(f"Rank {mpi_rank} -- Unique Venue pops done")
            self.VenuePersonCounts()
            logger.info(f"Rank {mpi_rank} -- Total Venue pops done")
            self.AvContacts()
            logger.info(f"Rank {mpi_rank} -- Average contacts done")

            # Matrices must be loaded before any normalization step runs.
            self.LoadIMatrices()
            self.LoadContactMatrices()
            logger.info(f"Rank {mpi_rank} -- Load IM and CMs done")

            # 1D normalization first, then the all-contacts normalization.
            self.initialize_CM_Normalizations()
            self.normalize_1D_CM()

            self.initialize_CM_All_Normalizations()
            self.normalize_All_CM()

            logger.info(f"Rank {mpi_rank} -- normalized CMs done")

            self.SaveOutCM()
            logger.info(f"Rank {mpi_rank} -- Saved CM done")

        else:
            logger.info(f"Rank {mpi_rank} -- Skip run was on 1 core")
        logger.info(f"Rank {mpi_rank} -- Merging done")
        self.PrintOutResults()


from .tracker_plots_formatting import fig_initialize, set_size, dpi

import numpy as np
import yaml
import pandas as pd

from pathlib import Path
from june import paths

from june.world import World

import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib.dates import DateFormatter
import matplotlib.dates as mdates
import datetime
import logging

from june.tracker.tracker import Tracker

from june.mpi_setup import mpi_comm, mpi_size, mpi_rank
from june.paths import data_path, configs_path

# Apply the shared matplotlib figure styling once at import time.
fig_initialize(setsize=True)

logger = logging.getLogger("tracker plotter")
mpi_logger = logging.getLogger("mpi")

# Silence duplicate log output from non-zero MPI ranks.
if mpi_rank > 0:
    logger.propagate = False

# Default location of the BBC Pandemic survey data shipped with june.
default_BBC_Pandemic_loc = data_path / "BBC_Pandemic"

# Day names in Sunday-first order.
DaysOfWeek_Names = [
    "Sunday",
    "Monday",
    "Tuesday",
    "Wednesday",
    "Thursday",
    "Friday",
    "Saturday",
]

# Colour map names handed to matplotlib (cmap_B is a diverging map).
cmap_A = "RdYlBu_r"
cmap_B = "seismic"


#######################################################
# Plotting functions ##################################
#######################################################


class PlotClass:
    """
    Class to plot everything tracker related

    Parameters
    ----------
    record_path:
        path for results directory

    Tracker_Contact_Type:
        NONE, Not used

    Normalization_Type:
        string, "U" for venue normalized or "P" for population normalized

    Following parameters can be preloaded data from another plot class. If None data automatically loaded.
        Params,
        IM,
        CM,
        NCM,
        NCM_R,
        CMV,
        NCMV,
        average_contacts,
        location_counters,
        location_counters_day,
        location_cum_pop,
        age_profiles,
        travel_distance,


    Returns
    -------
        The tracker plotting class

    """

    def __init__(
        self,
        record_path=Path(""),
        Tracker_Contact_Type=None,
        Params=None,
        IM=None,
        CM=None,
        NCM=None,
        NCM_R=None,
        CMV=None,
        NCM_V=None,
        average_contacts=None,
        location_counters=None,
        location_counters_day=None,
        location_cum_pop=None,
        age_profiles=None,
        travel_distance=None,
        Normalization_Type="U",
    ):
        """
        Load (or accept preloaded) merged tracker output ready for plotting.

        Every data argument left as ``None`` is read from the
        ``merged_data_output`` folder under ``record_path``; passing a value
        skips that file read, so data can be shared between plot classes.
        ``Normalization_Type`` ("U" or "P") selects which normalized matrix
        files are loaded.
        """
        # Deprecated argument kept for backward compatibility of the signature.
        if Tracker_Contact_Type is None:
            pass
        else:
            print("Tracker_Contact_Type argument no longer required")

        self.Normalization_Type = Normalization_Type

        self.record_path = record_path

        # Only plot fully merged data (Only applies to MPI runs, auto saved to merge if single core)
        folder_name = "merged_data_output"

        logger.info(f"Rank {mpi_rank} -- Begin loading")

        # Simulation parameters (total days, weekday/weekend names, venue counts).
        if Params is None:
            with open(
                self.record_path / folder_name / "tracker_Simulation_Params.yaml"
            ) as f:
                self.Params = yaml.load(f, Loader=yaml.FullLoader)
        else:
            self.Params = Params

        # Input interaction matrices.
        if IM is None:
            with open(
                self.record_path / folder_name / "CM_yamls" / "tracker_IM.yaml"
            ) as f:
                self.IM = yaml.load(f, Loader=yaml.FullLoader)
        else:
            self.IM = IM

        # Raw contact matrices.
        if CM is None:
            with open(
                self.record_path / folder_name / "CM_yamls" / f"tracker_CM.yaml"
            ) as f:
                self.CM = yaml.load(f, Loader=yaml.FullLoader)
        else:
            self.CM = CM

        # Normalized contact matrices; the file prefix ("U" or "P") follows
        # Normalization_Type.
        if NCM is None:
            with open(
                self.record_path
                / folder_name
                / "CM_yamls"
                / f"tracker_{self.Normalization_Type}NCM.yaml"
            ) as f:
                self.NCM = yaml.load(f, Loader=yaml.FullLoader)
        else:
            self.NCM = NCM

        if NCM_R is None:
            with open(
                self.record_path
                / folder_name
                / "CM_yamls"
                / f"tracker_{self.Normalization_Type}NCM_R.yaml"
            ) as f:
                self.NCM_R = yaml.load(f, Loader=yaml.FullLoader)
        else:
            self.NCM_R = NCM_R

        # Venue contact matrices and their normalized variant.
        if CMV is None:
            with open(
                self.record_path / folder_name / "CM_yamls" / f"tracker_CMV.yaml"
            ) as f:
                self.CMV = yaml.load(f, Loader=yaml.FullLoader)
        else:
            self.CMV = CMV

        if NCM_V is None:
            with open(
                self.record_path
                / folder_name
                / "CM_yamls"
                / f"tracker_{self.Normalization_Type}NCM_V.yaml"
            ) as f:
                self.NCM_V = yaml.load(f, Loader=yaml.FullLoader)
        else:
            self.NCM_V = NCM_V

        # Get Parameters of simulation
        self.total_days = self.Params["total_days"]
        self.day_types = {
            "weekend": self.Params["Weekend_Names"],
            "weekday": self.Params["Weekday_Names"],
        }
        self.NVenues = self.Params["NVenues"]
        # Get all the bin types
        self.relevant_bin_types = list(self.CM.keys())
        # Get all location names
        self.group_type_names = list(self.CM["syoa"].keys())
        # Get all CM options
        self.CM_Keys = list(self.CM["syoa"][self.group_type_names[0]].keys())
        # Get all contact sexes
        self.contact_sexes = list(
            self.CM["syoa"][self.group_type_names[0]]["sex"].keys()
        )

        # Age bin edges per bin type; "Interaction" bins are not age bins
        # and are skipped here.
        self.age_bins = {}
        for rbt in self.relevant_bin_types:
            if rbt == "Interaction":
                continue
            self.age_bins[rbt] = np.array(
                self.CM[rbt][self.group_type_names[0]]["bins"]
            )

        # Average contacts per bin type, one sheet per bin type.
        if average_contacts is None:
            self.average_contacts = {}
            for rbt in self.relevant_bin_types:
                if rbt == "Interaction":
                    continue
                self.average_contacts[rbt] = pd.read_excel(
                    self.record_path
                    / folder_name
                    / "Venue_AvContacts"
                    / "Average_contacts.xlsx",
                    sheet_name=rbt,
                    index_col=0,
                )
        else:
            self.average_contacts = average_contacts

        # Per-timestep venue attendance counts, split by sex.
        if location_counters is None:
            self.location_counters = {"loc": {}}
            for loc in self.group_type_names:
                if loc in ["global", "shelter_inter", "shelter_intra"]:
                    continue
                self.location_counters["loc"][loc] = {}
                self.location_counters["Timestamp"] = None
                self.location_counters["dt"] = None

                for sex in self.contact_sexes:
                    filename = f"Venues_{sex}_Counts_BydT.xlsx"
                    sheet_name = Tracker.pluralize(self, loc)
                    df = pd.read_excel(
                        self.record_path / folder_name / "Venue_UniquePops" / filename,
                        sheet_name=sheet_name,
                        index_col=0,
                    )
                    self.location_counters["loc"][loc][sex] = df.iloc[:, 2:]
                    # NOTE(review): the "dt" key above stays None — the delta
                    # is stored under "delta_t" here; confirm which key name
                    # downstream code expects.
                    if self.location_counters["Timestamp"] is None:
                        self.location_counters["Timestamp"] = df["t"]
                        self.location_counters["delta_t"] = df["dt"]
        else:
            self.location_counters = location_counters

        # Per-day venue attendance counts, split by sex.
        if location_counters_day is None:
            self.location_counters_day = {"loc": {}}
            for loc in self.group_type_names:
                if loc in ["global", "shelter_inter", "shelter_intra"]:
                    continue
                self.location_counters_day["loc"][loc] = {}
                self.location_counters_day["Timestamp"] = None

                for sex in self.contact_sexes:
                    filename = f"Venues_{sex}_Counts_ByDate.xlsx"
                    sheet_name = Tracker.pluralize(self, loc)
                    df = pd.read_excel(
                        self.record_path / folder_name / "Venue_UniquePops" / filename,
                        sheet_name=sheet_name,
                        index_col=0,
                    )
                    self.location_counters_day["loc"][loc][sex] = df.iloc[:, 0:]
                    if self.location_counters_day["Timestamp"] is None:
                        self.location_counters_day["Timestamp"] = df["t"]
        else:
            self.location_counters_day = location_counters_day

        # Cumulative person counts per bin type and location.
        if location_cum_pop is None:
            self.location_cum_pop = {}
            for rbt in self.relevant_bin_types:
                self.location_cum_pop[rbt] = {}
                filename = (
                    self.record_path
                    / folder_name
                    / "Venue_TotalDemographics"
                    / f"CumPersonCounts_{rbt}.xlsx"
                )
                for loc in self.group_type_names:
                    # NOTE(review): this placeholder is assigned before the
                    # skip below, so skipped Interaction locations keep an
                    # empty dict instead of a DataFrame — confirm intended.
                    self.location_cum_pop[rbt][loc] = {}
                    if rbt == "Interaction" and loc in [
                        "global",
                        "shelter_inter",
                        "shelter_intra",
                    ]:
                        continue
                    df = pd.read_excel(filename, sheet_name=loc, index_col=0)
                    self.location_cum_pop[rbt][loc] = df
        else:
            self.location_cum_pop = location_cum_pop

        # Age profiles of attendees per bin type and location.
        if age_profiles is None:
            self.age_profiles = {}
            for rbt in self.relevant_bin_types:
                if rbt == "Interaction":
                    continue
                self.age_profiles[rbt] = {}
                filename = (
                    self.record_path
                    / folder_name
                    / "Venue_Demographics"
                    / f"PersonCounts_{rbt}.xlsx"
                )
                for loc in self.group_type_names:
                    self.age_profiles[rbt][loc] = {}

                    df = pd.read_excel(filename, sheet_name=loc, index_col=0)
                    # Last row dropped (presumably a totals row — TODO confirm).
                    self.age_profiles[rbt][loc] = df.iloc[:-1, :]

        else:
            self.age_profiles = age_profiles

        # Distance-travelled distributions, one sheet per venue type.
        if travel_distance is None:
            filename = (
                self.record_path
                / folder_name
                / "Venue_TravelDist"
                / "Distance_traveled.xlsx"
            )
            self.travel_distance = {}
            for loc in self.group_type_names:
                if loc in ["global", "shelter_inter", "shelter_intra"]:
                    continue
                sheet_name = Tracker.pluralize(self, loc)
                df = pd.read_excel(filename, sheet_name=sheet_name, index_col=0)
                self.travel_distance[loc] = df
        else:
            self.travel_distance = travel_distance

        logger.info(f"Rank {mpi_rank} -- Data loaded")

    #####################################################
    # Useful functions ##################################
    #####################################################

    def Calculate_CM_Metrics(self, bin_type, contact_type, CM, CM_err, sex="unisex"):
        # Delegate to the shared Tracker implementation, passing this plot
        # class as `self` so the delegate reads our loaded data.
        # NOTE(review): other call sites invoke Tracker.Calculate_CM_Metrics
        # with a `ratio` argument; here `sex` is passed positionally in that
        # slot — confirm the delegate's signature matches this call.
        return Tracker.Calculate_CM_Metrics(
            self, bin_type, contact_type, CM, CM_err, sex
        )

    def Population_Metrics(self, pop_by_bin, pop_bins):
        # Delegate to the shared Tracker implementation (duck-typed `self`).
        return Tracker.Population_Metrics(self, pop_by_bin, pop_bins)

    def Expectation_Assortativeness(self, NPCDM, pop_bins):
        # Delegate to the shared Tracker implementation (duck-typed `self`).
        return Tracker.Expectation_Assortativeness(self, NPCDM, pop_bins)

    def Calc_NPCDM(self, cm, pop_by_bin, pop_width):
        # Delegate to the shared Tracker implementation (duck-typed `self`).
        return Tracker.Calc_NPCDM(self, cm, pop_by_bin, pop_width)

    def Calc_QIndex(self, cm):
        # Delegate to the shared Tracker implementation (duck-typed `self`).
        return Tracker.Calc_QIndex(self, cm)

    def Canberra_distance(self, x, y):
        # Delegate to the shared Tracker implementation (duck-typed `self`).
        return Tracker.Canberra_distance(self, x, y)

    #############################################
    # Grab CM  ##################################
    #############################################

    def CMPlots_GetCM(self, bin_type, contact_type, sex="unisex", which="NCM"):
        """
        Get cm out of dictionary.

        Parameters
        ----------
            binType:
                Name of bin type syoa, AC etc
            contact_type:
                Location of contacts
            sex:
                Sex contact matrix
            which:
                str, which matrix type to collect "CM", "NCM", "NCM_R", "CMV", "NCM_V"

        Returns
        -------
            cm:
                np.array contact matrix
            cm_err:
                np.array contact matrix errors
        """
        if bin_type != "Interaction":
            if which == "CM":
                cm = self.CM[bin_type][contact_type]["sex"][sex]["contacts"]
                cm_err = self.CM[bin_type][contact_type]["sex"][sex]["contacts_err"]
            elif which == "NCM":
                cm = self.NCM[bin_type][contact_type]["sex"][sex]["contacts"]
                cm_err = self.NCM[bin_type][contact_type]["sex"][sex]["contacts_err"]
            elif which == "NCM_R":
                cm = self.NCM_R[bin_type][contact_type]["sex"][sex]["contacts"]
                cm_err = self.NCM_R[bin_type][contact_type]["sex"][sex]["contacts_err"]

            elif which == "CMV":
                cm = self.CMV[bin_type][contact_type]["sex"][sex]["contacts"]
                cm_err = self.CMV[bin_type][contact_type]["sex"][sex]["contacts_err"]
            elif which == "NCM_V":
                cm = self.NCM_V[bin_type][contact_type]["sex"][sex]["contacts"]
                cm_err = self.NCM_V[bin_type][contact_type]["sex"][sex]["contacts_err"]

        else:
            if which == "CM":
                cm = self.CM[bin_type][contact_type]["contacts"]
                cm_err = self.CM[bin_type][contact_type]["contacts_err"]
            elif which == "NCM":
                cm = self.NCM[bin_type][contact_type]["contacts"]
                cm_err = self.NCM[bin_type][contact_type]["contacts_err"]
            elif which == "NCM_R":
                cm = self.NCM_R[bin_type][contact_type]["contacts"]
                cm_err = self.NCM_R[bin_type][contact_type]["contacts_err"]

            elif which == "CMV":
                cm = self.CMV[bin_type][contact_type]["contacts"]
                cm_err = self.CMV[bin_type][contact_type]["contacts_err"]
            elif which == "NCM_V":
                cm = self.NCM_V[bin_type][contact_type]["contacts"]
                cm_err = self.NCM_V[bin_type][contact_type]["contacts_err"]

        return np.array(cm), np.array(cm_err)

    def IMPlots_GetIM(self, contact_type):
        # Delegate to the shared Tracker implementation (duck-typed `self`).
        return Tracker.IMPlots_GetIM(self, contact_type)

    def get_characteristic_time(self, location):
        # Delegate to the shared Tracker implementation (duck-typed `self`).
        return Tracker.get_characteristic_time(self, location)

    #####################################################
    # General Plotting ##################################
    #####################################################

    def Get_SAMECMAP_Norm(self, dim, which="NCM", override=None):
        """
        Build a standardized matplotlib colour norm for contact matrix plots.

        The colour limits depend on the matrix family (`which`), on the
        normalization type of this plot class ("U" or "P"), and on whether
        the matrix is small (fewer than 5 bins) or large.

        Parameters
        ----------
            dim:
                int, the dimension (length) of square matrix
            which:
                string, the contact matrix type
            override:
                string, Log, Lin, SymLog or SymLin. Override if SAMECMAP was False. (Applies to certain plots)

        Returns
        -------
            Norm:
                matplotlib.colors.Norm object, or None when neither
                self.SameCMAP nor `override` selects a scale

        """
        # Limit pairs are (small-matrix value, large-matrix value).
        idx = 0 if dim < 5 else 1

        if which in ("CM", "NCM", "NCM_R"):
            if self.Normalization_Type == "U":
                lin_min = (0, 0)[idx]
                log_min = (1e-1, 1e-2)[idx]
                lin_max = (2.5e1, 4e0)[idx]
                log_max = (2.5e1, 4e0)[idx]
                symlog_max = (3e0, 3e0)[idx]
                symlin_max = (1e0, 0.5e0)[idx]
            elif self.Normalization_Type == "P":
                lin_min = (0, 0)[idx]
                log_min = (1e-1, 1e-4)[idx]
                lin_max = (2.5e1, 1e0)[idx]
                log_max = (2.5e1, 1e0)[idx]
                symlog_max = (3e0, 1e0)[idx]
                symlin_max = (1e0, 1e0)[idx]
        elif which in ("CMV", "NCM_V"):
            if self.Normalization_Type == "U":
                lin_min = (0, 0)[idx]
                log_min = (1, 1e-2)[idx]
                lin_max = (1e2, 1e1)[idx]
                log_max = (1e2, 1e1)[idx]
                symlog_max = (1e2, 1e1)[idx]
                symlin_max = (1e2, 1e1)[idx]
            elif self.Normalization_Type == "P":
                lin_min = (0, 0)[idx]
                log_min = (1e-1, 1e-4)[idx]
                lin_max = (1e2, 1e1)[idx]
                log_max = (1e2, 1e1)[idx]
                symlog_max = (1e2, 1e1)[idx]
                symlin_max = (1e2, 1e1)[idx]

        if override is None:
            # Fall back to the class-wide SameCMAP choice.
            if self.SameCMAP == "Log":
                return colors.LogNorm(vmin=log_min, vmax=log_max)
            if self.SameCMAP == "Lin":
                return colors.Normalize(vmin=lin_min, vmax=lin_max)
        elif override == "SymLog":
            return colors.SymLogNorm(
                linthresh=1e-1, vmin=-symlog_max, vmax=symlog_max
            )
        elif override == "SymLin":
            return colors.Normalize(vmin=-symlin_max, vmax=symlin_max)
        elif override == "Log":
            return colors.LogNorm(vmin=log_min, vmax=log_max)
        elif override == "Lin":
            return colors.Normalize(vmin=lin_min, vmax=lin_max)
        return None

    def AnnotateCM(self, cm, cm_err, ax, thresh=1e10, annotate=True):
        """
        Function to annotate the CM with text. Including error catching for Nonetype errors.

        Parameters
        ----------
            cm:
                np.array contact matrix; modified in place (1e-16 entries are
                zeroed and entries above 1e8 become inf)
            cm_err:
                np.array contact matrix errors, or None; NaNs zeroed in place
            ax:
                matplotlib axes
            thresh:
                threshold value for CM text change colour
            annotate:
                True, or the string "Small" to shrink the annotation font

        Returns
        -------
            ax
        """
        size = mpl.rcParams["font.size"]
        # Exactly one of the two shrink rules applies: up to 2 bins -> -3,
        # 3 or more bins -> -4.
        if cm.shape[0] <= 2:
            size -= 3
        if cm.shape[0] >= 3:
            size -= 4

        if annotate == "Small":
            size -= 2

        for i in range(cm.shape[0]):
            for j in range(cm.shape[1]):
                fmt = ".2f"
                # 1e-16 appears to be an "exactly zero" sentinel written
                # upstream — TODO confirm.
                if cm[i, j] == 1e-16:
                    cm[i, j] = 0
                # Values beyond 1e8 are displayed as infinite.
                if cm[i, j] > 1e8:
                    cm[i, j] = np.inf

                if cm_err is not None:
                    if np.isnan(cm_err[i, j]):
                        cm_err[i, j] = 0

                    # Value and error both zero: render a plain "0 ± 0".
                    if cm_err[i, j] + cm[i, j] == 0:
                        fmt = ".0f"

                    if fmt == ".0f":
                        text = r"$0 \pm 0$"
                    else:
                        text = (
                            r"$%s \pm $" % (format(cm[i, j], fmt))
                            + "\n\t"
                            + "$%s$" % (format(cm_err[i, j], fmt))
                        )

                else:
                    text = r"$%s$" % (format(cm[i, j], fmt))

                # thresh == 1e8 forces black text; otherwise the text flips
                # to white when the value is further than `thresh` from 1.
                if thresh == 1e8:
                    ax.text(
                        j, i, text, ha="center", va="center", color="black", size=size
                    )
                else:
                    ax.text(
                        j,
                        i,
                        text,
                        ha="center",
                        va="center",
                        color="white" if abs(cm[i, j] - 1) > thresh else "black",
                        size=size,
                    )
        return ax

    def PlotCM(
        self,
        cm,
        cm_err,
        labels,
        ax,
        thresh=1e10,
        thumb=False,
        annotate=True,
        **plt_kwargs,
    ):
        """
        Function to imshow plot the CM.

        Parameters
        ----------
            cm:
                np.array contact matrix
            cm_err:
                np.array contact matrix errors
            labels:
                list of string bins labels (or none type)
            ax:
                matplotlib axes
            thresh:
                threshold value for CM text change colour
            thumb:
                bool, make thumbnail style plots. e.g. no axis labels
            annotate:
                True, False, or "Small" (smaller annotation/tick fonts)
            **plt_kwargs:
                plot keyword arguments

        Returns
        -------
            im:
                reference to plot object
        """

        # Transpose so the x axis shows the person's age group and the y
        # axis the contact's age group.
        if cm is None:
            pass
        else:
            cm = cm.T

        if cm_err is None:
            pass
        else:
            cm_err = cm_err.T

        # Abbreviate well-known label sets to keep the ticks readable.
        if labels is not None:
            if "kids" in labels and "young_adults" in labels:
                labels = ["K", "Y", "A", "O"]
            elif len(labels) == 2 and "students" in labels:
                if labels[0] == "students" and labels[1] == "teachers":
                    labels = ["  S  ", "  T  "]
                elif labels[1] == "students" and labels[0] == "teachers":
                    labels = ["  S  ", "  T  "]
                    # labels = ["Stude", "Teach"]
                    # Teachers-first ordering: undo the transpose so the
                    # matrix matches the fixed S/T label order.
                    cm = cm.T
                    if cm_err is None:
                        pass
                    else:
                        cm_err = cm_err.T
            elif "workers" in labels and len(labels) == 1:
                labels = ["W"]
            elif "inter" in labels:
                labels = [r" H$_1$ ", r" H$_2$ "]

        # im = ax.matshow(cm, **plt_kwargs)
        Interpolation = "None"
        im = ax.imshow(cm, **plt_kwargs, interpolation=Interpolation)
        ax.xaxis.tick_bottom()

        if annotate == "Small" and len(labels) >= 3:
            size = mpl.rcParams["xtick.labelsize"] - 4
        elif annotate == "Small":
            size = mpl.rcParams["xtick.labelsize"] - 2
        else:
            size = mpl.rcParams["xtick.labelsize"]

        if labels is not None:
            if len(labels) == 1:
                ax.set_xticks(np.arange(len(cm)))
                ax.set_xticklabels(labels, rotation=0, size=size)
                ax.set_yticks(np.arange(len(cm)))
                ax.set_yticklabels(labels, rotation=0, size=size)
            elif len(labels) < 10:
                ax.set_xticks(np.arange(len(cm)))
                ax.set_xticklabels(labels, rotation=45)
                ax.set_yticks(np.arange(len(cm)))
                ax.set_yticklabels(labels)

            elif len(labels) >= 10 and len(labels) <= 25:
                ax.set_xticks(np.arange(len(cm)))
                ax.set_xticklabels(labels, rotation=90, size=size)
                ax.set_yticks(np.arange(len(cm)))
                ax.set_yticklabels(labels, size=size)
            # NOTE(review): this branch is unreachable — lengths < 10 and
            # 10..25 are handled above, so only len > 25 reaches here and
            # fails the < 25 test. Such matrices get default ticks.
            elif len(labels) < 25:
                ax.set_xticks(np.arange(len(cm)))
                ax.set_xticklabels(labels, rotation=90, size=size)
                ax.set_yticks(np.arange(len(cm)))
                ax.set_yticklabels(labels, size=size)
        else:
            # No labels supplied: show 5 evenly spaced numeric index ticks.
            Nticks = 5
            ticks = np.arange(0, len(cm), int((len(cm) + 1) / (Nticks - 1)))
            ax.set_xticks(ticks)
            ax.set_xticklabels(ticks)
            ax.set_yticks(ticks)
            ax.set_yticklabels(ticks)

        # Loop over data dimensions and create text annotations.
        if cm.shape[0] * cm.shape[1] < 26 and annotate:
            self.AnnotateCM(cm, cm_err, ax, thresh=thresh, annotate=annotate)
        if not thumb:
            ax.set_xlabel("age group")
            ax.set_ylabel("contact age group")
        else:
            # ax.axes.xaxis.set_visible(False)
            # ax.axes.yaxis.set_visible(False)
            ax.set_xlabel("")
            ax.set_ylabel("")
            pass
        return im

    def CMPlots_GetLabels(self, bins):
        """
        Build "low-high" text labels for the age bins used in CM plots.

        Parameters
        ----------
            bins:
                np.array bin edges

        Returns
        -------
            labels:
                list of strings for bin labels, or None when there are too
                many bins (25 or more edges) to label legibly
        """
        if len(bins) >= 25:
            return None
        # Each label covers [low, high) and is printed as "low-(high-1)".
        edge_pairs = zip(bins[:-1], bins[1:])
        return ["{0}-{1}".format(low, high - 1) for low, high in edge_pairs]

    def MaxAgeBinIndex(self, bins, MaxAgeBin=60):
        """
        Get index for truncation of bins upto max age MaxAgeBin

        Parameters
        ----------
            bins:
                Age bins (scanned in index order)
            MaxAgeBin:
                The maximum age at which to truncate the bins

        Returns
        -------
            Index of the first bin edge >= MaxAgeBin, or None when no edge
            reaches it
        """
        # min() over the matching indices equals the first matching index,
        # so a short-circuiting next() avoids building the full list.
        return next(
            (index for index, edge in enumerate(bins) if edge >= MaxAgeBin), None
        )

    def CMPlots_UsefulCM(self, bin_type, cm, cm_err=None, labels=None, MaxAgeBin=100):
        """
        Truncate the CM for the plots to drop age bins of the data with no people.

        Parameters
        ----------
            bin_type:
                Name of bin type syoa, AC etc
            cm:
                np.array contact matrix
            cm_err:
                np.array contact matrix errors
            labels:
                list of strings for bin labels or none type
            MaxAgeBin:
                The maximum age at which to truncate the bins

        Returns
        -------
            Truncated cm, cm_err and labels (errors/labels stay None when
            passed as None).
        """
        # The "Paper" binning is never truncated.
        cutoff = np.inf if bin_type == "Paper" else MaxAgeBin
        index = self.MaxAgeBinIndex(self.age_bins[bin_type], MaxAgeBin=cutoff)
        trimmed = cm[:index, :index]
        trimmed_err = cm_err[:index, :index] if cm_err is not None else None
        trimmed_labels = labels[:index] if labels is not None else None
        return trimmed, trimmed_err, trimmed_labels

    def IMPlots_GetLabels(self, contact_type):
        """
        Create bin labels for the input IM plots. More nuanced than the CM
        case because subgroups are not always age bins.

        Parameters
        ----------
            contact_type:
                Location of contacts

        Returns
        -------
            labels:
                list of strings for bin labels, the raw bins for "Discrete"
                subgroup types, or None when there are too many bins
        """
        spec = self.IM[contact_type]
        bintype = spec["type"]
        bins = np.array(spec["bins"])

        # Only label when few enough bins to remain readable.
        if len(bins) >= 25:
            return None
        if bintype == "Age":
            return [
                f"{int(low)}-{int(high-1)}" for low, high in zip(bins[:-1], bins[1:])
            ]
        if bintype == "Discrete":
            return bins
        return None

    def IMPlots_UsefulCM(self, contact_type, cm, cm_err=None, labels=None):
        """
        Truncate the IM for the plots to drop age bins of the data with no people.

        Parameters
        ----------
            contact_type:
                Location of contacts
            cm:
                np.array contact matrix
            cm_err:
                np.array contact matrix errors
            labels:
                list of strings for bin labels or none type

        Returns
        -------
            Truncated cm, cm_err and labels (errors/labels stay None when
            passed as None).
        """
        spec = self.IM[contact_type]

        # Discrete (non-age) subgroup matrices are never truncated.
        if spec["type"] == "Discrete":
            return cm, cm_err, labels

        # Truncate at the default maximum age of MaxAgeBinIndex.
        cut = self.MaxAgeBinIndex(np.array(spec["bins"]))
        out_cm = cm[:cut, :cut]
        out_err = cm_err[:cut, :cut] if cm_err is not None else None
        out_labels = labels[:cut] if labels is not None else None
        return out_cm, out_err, out_labels

    #############################################
    # Plotting ##################################
    #############################################

    def plot_contact_matrix_INOUT(
        self,
        bin_type,
        contact_type,
        sex="unisex",
        which="NCM_R",
        plot_BBC_Sheet=False,
        MaxAgeBin=100,
    ):
        """
        Function to plot input contact matrix vs output for bin_type, contact_type and sex.

        Parameters
        ----------
            bin_type:
                Name of bin type syoa, AC etc
            contact_type:
                Location of contacts
            sex:
                Sex contact matrix
            which:
                str, which matrix type to collect "NCM", "NCM_R", "CM_T"
            plot_BBC_Sheet:
                False, or the name of a sheet in the BBC Pandemic reciprocal
                matrix workbook to add as a third comparison panel
            MaxAgeBin:
                The maximum age at which to truncate the bins

        Returns
        -------
            (ax1,ax2):
                matplotlib axes objects (IM and model CM), or (ax1,ax2,ax3)
                including the BBC matrix when plot_BBC_Sheet is given
        """

        # Input (survey) interaction matrix, errors and labels, truncated to
        # the populated age range.
        IM, IM_err = self.IMPlots_GetIM(contact_type)
        labels_IM = self.IMPlots_GetLabels(contact_type)
        IM, IM_err, labels_IM = self.IMPlots_UsefulCM(
            contact_type, IM, cm_err=IM_err, labels=labels_IM
        )

        # Robust colour limits: smallest non-zero and largest finite entries,
        # with fallbacks when the matrix is empty or all-NaN.
        if len(np.nonzero(IM)[0]) != 0 and len(np.nonzero(IM)[1]) != 0:
            IM_Min = np.nanmin(IM[np.nonzero(IM)])
        else:
            IM_Min = 1e-1
        if np.isfinite(IM).sum() != 0:
            IM_Max = IM[np.isfinite(IM)].max()
        else:
            IM_Max = 1

        if np.isnan(IM_Min):
            IM_Min = 1e-1
        if np.isnan(IM_Max) or IM_Max == 0:
            IM_Max = 1

        # Replace non-finite entries so the colour mapping gets finite data.
        IM = np.nan_to_num(IM, posinf=IM_Max, neginf=0, nan=0)

        # Model (output) contact matrix for the requested binning and sex.
        labels = self.CMPlots_GetLabels(self.age_bins[bin_type])
        cm, cm_err = self.CMPlots_GetCM(bin_type, contact_type, sex=sex, which=which)
        cm, cm_err, labels = self.CMPlots_UsefulCM(
            bin_type, cm, cm_err, labels, MaxAgeBin
        )

        # Same robust colour-limit logic for the model matrix.
        if len(np.nonzero(cm)[0]) != 0 and len(np.nonzero(cm)[1]) != 0:
            cm_Min = np.nanmin(cm[np.nonzero(cm)])
        else:
            cm_Min = 1e-1
        if np.isfinite(cm).sum() != 0:
            cm_Max = cm[np.isfinite(cm)].max()
        else:
            cm_Max = 1

        if np.isnan(cm_Min):
            cm_Min = 1e-1
        if np.isnan(cm_Max) or cm_Max == 0:
            cm_Max = 1

        cm = np.nan_to_num(cm, posinf=cm_Max, neginf=0, nan=0)

        # Per-panel normalisation, or a colour map shared across figures.
        if not self.SameCMAP:
            norm1 = colors.Normalize(vmin=0, vmax=IM_Max)
            norm2 = colors.Normalize(vmin=0, vmax=cm_Max)
        else:
            norm1 = self.Get_SAMECMAP_Norm(IM.shape[0], which=which)
            norm2 = self.Get_SAMECMAP_Norm(cm.shape[0], which=which)

        if not plot_BBC_Sheet:
            # Two-panel figure: input IM vs model CM.
            # plt.rcParams["figure.figsize"] = (15, 5)
            f, (ax1, ax2) = plt.subplots(1, 2)
            f.set_size_inches(set_size(subplots=(1, 2), fraction=1))
            f.patch.set_facecolor("white")

            # 1e-16 offset keeps exact zeros plottable under log-like norms.
            im1 = self.PlotCM(
                IM + 1e-16,
                IM_err,
                labels_IM,
                ax1,
                origin="lower",
                cmap=cmap_A,
                norm=norm1,
            )
            im2 = self.PlotCM(
                cm + 1e-16, cm_err, labels, ax2, origin="lower", cmap=cmap_A, norm=norm2
            )

            f.colorbar(im1, ax=ax1, extend="both")
            f.colorbar(im2, ax=ax2, extend="both")

            # ax1.set_title(f"IM")
            # ax2.set_title(f"{which}")
            # f.suptitle(f"{bin_type} binned contacts in {contact_type}")
            # plt.tight_layout()
            return (ax1, ax2)
        else:
            # Three-panel comparison against the BBC Pandemic survey matrix.
            df = pd.read_excel(
                default_BBC_Pandemic_loc
                / "BBC reciprocal matrices by type and context.xls",
                sheet_name=plot_BBC_Sheet,
            )
            # First column holds the bin labels; the remainder is the matrix,
            # transposed to match this class's row/column convention.
            bbc_cm = df.iloc[:, 1:].values.T
            bbc_labels = df.iloc[:, 0].values

            bbc_Max = np.nanmax(bbc_cm)
            bbc_Min = np.nanmin(bbc_cm)

            # Put into same contact units

            # Rescale the model matrix by the venue's characteristic time so
            # the two matrices are in comparable units — TODO confirm the
            # units against get_characteristic_time.
            CT = self.get_characteristic_time(contact_type)[0]
            cm /= CT

            cm_Max = max(bbc_Max, cm_Max)

            # NOTE(review): `x in "household"` is a substring test (e.g.
            # "house" would also match) — presumably `==` was intended;
            # confirm before changing.
            if contact_type in "household":
                norm2 = colors.LogNorm(vmin=1e-2, vmax=2)
            elif contact_type in "school":
                norm2 = colors.LogNorm(vmin=1e-3, vmax=1e1)
            elif contact_type in "company":
                norm2 = colors.LogNorm(vmin=1e-2, vmax=1)

            # plt.rcParams["figure.figsize"] = (15, 5)
            # NOTE(review): set_size is given subplots=(1, 2) although the
            # figure has three panels — confirm whether this is deliberate.
            f, (ax1, ax2, ax3) = plt.subplots(1, 3)
            f.set_size_inches(set_size(subplots=(1, 2), fraction=1))
            f.patch.set_facecolor("white")

            im1 = self.PlotCM(
                IM + 1e-16,
                IM_err,
                labels_IM,
                ax1,
                origin="lower",
                cmap=cmap_A,
                norm=norm1,
                annotate=True,
                thumb=True,
            )
            im2 = self.PlotCM(
                (cm + 1e-16),
                cm_err,
                labels,
                ax2,
                origin="lower",
                cmap=cmap_A,
                norm=norm2,
                annotate="Small",
                thumb=True,
            )

            im3 = self.PlotCM(
                bbc_cm,
                None,
                bbc_labels,
                ax3,
                origin="lower",
                cmap=cmap_A,
                norm=norm2,
                annotate="Small",
                thumb=True,
            )

            cm = np.nan_to_num(cm, nan=0.0)
            bbc_cm = np.nan_to_num(bbc_cm, nan=0.0)

            # Assortativity / similarity metrics for model vs BBC matrices,
            # weighted by the population age profile.
            print(contact_type)
            pop_by_bin = np.array(self.age_profiles[bin_type][contact_type][sex])
            pop_bins = np.array(self.age_bins[bin_type])
            pop_width = np.diff(pop_bins)

            pop_density = pop_by_bin / (np.nansum(pop_by_bin) * pop_width)

            # Population mean/variance from the single-year-of-age profile.
            pop_by_bin_true = np.array(self.age_profiles["syoa"][contact_type][sex])
            pop_bins_true = np.array(self.age_bins["syoa"])
            mean, var = self.Population_Metrics(pop_by_bin_true, pop_bins_true)

            Q = self.Calc_QIndex(cm)
            NPCDM = self.Calc_NPCDM(cm, pop_density, pop_width)
            I_sq = self.Expectation_Assortativeness(NPCDM, pop_bins)
            I_sq_s = I_sq / var**2
            print("JUNE", {"Q": f"{Q}", "I_sq": f"{I_sq}", "I_sq_s": f"{I_sq_s}"})

            Q = self.Calc_QIndex(bbc_cm)
            NPCDM = self.Calc_NPCDM(bbc_cm, pop_density, pop_width)
            I_sq = self.Expectation_Assortativeness(NPCDM, pop_bins)
            I_sq_s = I_sq / var**2
            print("BBC", {"Q": f"{Q}", "I_sq": f"{I_sq}", "I_sq_s": f"{I_sq_s}"})
            # "Camberra" key kept verbatim (it is runtime output); the value
            # is the Canberra distance between the two matrices.
            print({"Camberra": self.Canberra_distance(cm, bbc_cm)[0]})
            print("")

            f.colorbar(im1, ax=ax1, extend="both", format="%g")
            f.colorbar(im2, ax=ax2, extend="both", format="%g")
            f.colorbar(im3, ax=ax3, extend="both", format="%g")

            # ax1.set_title(f"IM")
            # ax2.set_title(f"{which}")
            # ax3.set_title(f"BBC ({plot_BBC_Sheet})")
            # f.suptitle(f"{bin_type} binned contacts in {contact_type}")
            plt.tight_layout()
            return (ax1, ax2, ax3)

    def plot_interaction_matrix(self, contact_type):
        """
        Function to plot interaction matrix for contact_type.

        Draws three panels: the input interaction matrix (IM), the measured
        normalised contact matrix (NCM) and their elementwise ratio NCM / IM.

        Parameters
        ----------
            contact_type:
                Location of contacts

        Returns
        -------
            ax1:
                matplotlib axes object
        """
        which = "NCM"
        # Input matrix, errors and labels, truncated to populated age bins.
        IM, IM_err = self.IMPlots_GetIM(contact_type)
        labels_IM = self.IMPlots_GetLabels(contact_type)
        IM, IM_err, labels_IM = self.IMPlots_UsefulCM(
            contact_type, IM, cm_err=IM_err, labels=labels_IM
        )

        # Robust colour limits: smallest non-zero and largest finite entries,
        # with fallbacks when the matrix is empty or all-NaN.
        if len(np.nonzero(IM)[0]) != 0 and len(np.nonzero(IM)[1]) != 0:
            IM_Min = np.nanmin(IM[np.nonzero(IM)])
        else:
            IM_Min = 1e-1
        if np.isfinite(IM).sum() != 0:
            IM_Max = IM[np.isfinite(IM)].max()
        else:
            IM_Max = 1

        if np.isnan(IM_Min):
            IM_Min = 1e-1
        if np.isnan(IM_Max) or IM_Max == 0:
            IM_Max = 1

        # Replace non-finite entries so the colour mapping gets finite data.
        IM = np.nan_to_num(IM, posinf=IM_Max, neginf=0, nan=0)

        labels_CM = labels_IM
        if contact_type in self.CM["Interaction"].keys():
            cm, cm_err = self.CMPlots_GetCM("Interaction", contact_type, which=which)
            cm, cm_err, _ = self.IMPlots_UsefulCM(
                contact_type, cm, cm_err=cm_err, labels=labels_CM
            )
        else:  # The venue wasn't tracked
            cm = np.zeros_like(IM)
            cm_err = np.zeros_like(cm)

        # Same robust colour-limit logic for the measured matrix.
        if len(np.nonzero(cm)[0]) != 0 and len(np.nonzero(cm)[1]) != 0:
            cm_Min = np.nanmin(cm[np.nonzero(cm)])
        else:
            cm_Min = 1e-1
        if np.isfinite(cm).sum() != 0:
            cm_Max = cm[np.isfinite(cm)].max()
        else:
            cm_Max = 1

        if np.isnan(cm_Min):
            cm_Min = 1e-1
        if np.isnan(cm_Max) or cm_Max == 0:
            cm_Max = 1
        cm = np.nan_to_num(cm, posinf=cm_Max, neginf=0, nan=0)

        vMax = max(cm_Max, IM_Max)
        vMin = 1e-2

        # plt.rcParams["figure.figsize"] = (15, 5)
        f, (ax1, ax2, ax3) = plt.subplots(1, 3)
        f.set_size_inches(set_size(subplots=(1, 2), fraction=1))
        f.patch.set_facecolor("white")

        # Per-panel normalisation, or a colour map shared across figures.
        if not self.SameCMAP:
            norm1 = colors.Normalize(vmin=vMin, vmax=vMax)
            norm2 = colors.Normalize(vmin=vMin, vmax=vMax)
        else:

            norm1 = self.Get_SAMECMAP_Norm(IM.shape[0], which=which)
            norm2 = self.Get_SAMECMAP_Norm(cm.shape[0], which=which)

        # 1e-16 offset keeps exact zeros plottable under log-like norms.
        im1 = self.PlotCM(
            IM + 1e-16,
            IM_err,
            labels_IM,
            ax1,
            origin="lower",
            cmap=cmap_A,
            norm=norm1,
            annotate=True,
            thumb=True,
        )
        im2 = self.PlotCM(
            cm + 1e-16,
            cm_err,
            labels_CM,
            ax2,
            origin="lower",
            cmap=cmap_A,
            norm=norm2,
            annotate="Small",
            thumb=True,
        )

        # Elementwise ratio of measured to input matrix (ideal value = 1).
        ratio = cm / IM
        ratio = np.nan_to_num(ratio)
        # NOTE(review): `np.nonzero(ratio) and ratio < 1e3` is not an
        # elementwise AND — `np.nonzero(ratio)` is a (truthy) tuple, so the
        # whole expression reduces to `ratio < 1e3` and zeros are included.
        # Presumably `(ratio != 0) & (ratio < 1e3)` was intended; confirm
        # before changing, as it affects diff_max below.
        ratio_values = ratio[np.nonzero(ratio) and ratio < 1e3]
        if len(ratio_values) != 0:
            ratio_max = np.nanmax(ratio_values)
            ratio_min = np.nanmin(ratio_values)
            # Largest deviation from 1, floored at 0.5 for a readable scale.
            diff_max = np.max([abs(ratio_max - 1), abs(ratio_min - 1)])
            if diff_max < 0.5:
                diff_max = 0.5
        else:
            diff_max = 0.5
        if IM_err is None:
            IM_err = np.zeros_like(IM)
        # Standard quotient error propagation; may emit divide-by-zero
        # warnings where cm or IM contain zeros.
        ratio_errors = ratio * np.sqrt((cm_err / cm) ** 2 + (IM_err / IM) ** 2)

        norm = colors.Normalize(vmin=1 - diff_max, vmax=1 + diff_max)
        # NOTE(review): the norm above is immediately overwritten — the ratio
        # panel always uses a fixed [0, 2] colour range.
        norm = colors.Normalize(vmin=1 - 1, vmax=1 + 1)
        im3 = self.PlotCM(
            ratio,
            ratio_errors,
            labels_CM,
            ax3,
            thresh=diff_max / 3,
            origin="lower",
            cmap=cmap_B,
            norm=norm,
            annotate="Small",
            thumb=True,
        )
        f.colorbar(im1, ax=ax1, extend="both")
        f.colorbar(im2, ax=ax2, extend="both")
        f.colorbar(im3, ax=ax3, extend="both")
        ax1.set_title("IM")
        ax2.set_title("NCM")
        ax3.set_title("NCM / IM")

        # f.suptitle(f"Survey interaction binned contacts in {contact_type}")
        plt.tight_layout()
        return ax1

    def plot_interaction_matrix_thumb(self, log, contact_type):
        """
        Plot a single-panel thumbnail of the input interaction matrix.

        Parameters
        ----------
            log:
                bool, use a logarithmic colour normalisation when True
            contact_type:
                Location of contacts

        Returns
        -------
            f, ax1, im1:
                matplotlib figure, axes and image objects
        """
        which = "NCM"
        # Input matrix, errors and labels, truncated to populated age bins.
        IM, IM_err = self.IMPlots_GetIM(contact_type)
        labels_IM = self.IMPlots_GetLabels(contact_type)
        IM, IM_err, labels_IM = self.IMPlots_UsefulCM(
            contact_type, IM, cm_err=IM_err, labels=labels_IM
        )

        # Robust colour limits: smallest non-zero and largest finite entries,
        # with fallbacks when the matrix is empty or all-NaN.
        if len(np.nonzero(IM)[0]) != 0 and len(np.nonzero(IM)[1]) != 0:
            IM_Min = np.nanmin(IM[np.nonzero(IM)])
        else:
            IM_Min = 1e-1
        if np.isfinite(IM).sum() != 0:
            IM_Max = IM[np.isfinite(IM)].max()
        else:
            IM_Max = 1

        if np.isnan(IM_Min):
            IM_Min = 1e-1
        if np.isnan(IM_Max) or IM_Max == 0:
            IM_Max = 1

        # Replace non-finite entries so the colour mapping gets finite data.
        IM = np.nan_to_num(IM, posinf=IM_Max, neginf=0, nan=0)

        f, ax1 = plt.subplots(1, 1)
        f.set_size_inches(set_size(subplots=(1, 1), fraction=0.5))
        f.patch.set_facecolor("white")

        if not self.SameCMAP:
            normlin = colors.Normalize(vmin=0, vmax=IM_Max)
            # Bug fix: the log norm previously used vmin=IM_Max, collapsing
            # its range to a single value; span [IM_Min, IM_Max] as the other
            # plotting methods do.
            normlog = colors.LogNorm(vmin=IM_Min, vmax=IM_Max)
        else:
            normlin = self.Get_SAMECMAP_Norm(IM.shape[0], which=which)
            normlog = self.Get_SAMECMAP_Norm(IM.shape[0], which=which)

        # 1e-16 offset keeps exact zeros plottable under the log norm.
        im1 = self.PlotCM(
            IM + 1e-16,
            IM_err,
            labels_IM,
            ax1,
            origin="lower",
            cmap=cmap_A,
            norm=normlog if log else normlin,
            thumb=True,
        )

        # f.suptitle(f"Survey interaction binned contacts in {contact_type}")
        # plt.tight_layout()
        return f, ax1, im1

    def plot_contact_matrix(
        self, bin_type, contact_type, sex="unisex", which="NCM", MaxAgeBin=100
    ):
        """
        Function to plot contact matrix for bin_type, contact_type and sex.

        Parameters
        ----------
            bin_type:
                Name of bin type syoa, AC etc
            contact_type:
                Location of contacts
            sex:
                Sex contact matrix
            which:
                str, which matrix type to collect "NCM", "NCM_R", "CM_T"
            MaxAgeBin:
                The maximum age at which to truncate the bins

        Returns
        -------
            (ax1,ax2):
                matplotlib axes objects (Linear and Log)
        """
        cm, cm_err = self.CMPlots_GetCM(bin_type, contact_type, sex=sex, which=which)
        # "Interaction" binning uses venue subgroup labels and is not age
        # truncated; all other binnings are cut at MaxAgeBin.
        if bin_type == "Interaction":
            labels = self.IMPlots_GetLabels(contact_type)
        else:
            labels = self.CMPlots_GetLabels(self.age_bins[bin_type])
            cm, cm_err, labels = self.CMPlots_UsefulCM(
                bin_type, cm, cm_err, labels, MaxAgeBin
            )

        # Robust colour limits: smallest non-zero and largest finite entries,
        # with fallbacks when the matrix is empty or all-NaN.
        if len(np.nonzero(cm)[0]) != 0 and len(np.nonzero(cm)[1]) != 0:
            cm_Min = np.nanmin(cm[np.nonzero(cm)])
        else:
            cm_Min = 1e-1
        if np.isfinite(cm).sum() != 0:
            cm_Max = cm[np.isfinite(cm)].max()
        else:
            cm_Max = 1

        if np.isnan(cm_Min):
            cm_Min = 1e-1
        if np.isnan(cm_Max) or cm_Max == 0:
            cm_Max = 1

        # Replace non-finite entries so the colour mapping gets finite data.
        cm = np.nan_to_num(cm, posinf=cm_Max, neginf=0, nan=0)

        # CM_T always gets a per-plot normalisation even when a shared colour
        # map is requested — presumably totals are not comparable across
        # plots; confirm against Get_SAMECMAP_Norm.
        if not self.SameCMAP or which == "CM_T":
            normlin = colors.Normalize(vmin=0, vmax=cm_Max)
            normlog = colors.LogNorm(vmin=cm_Min, vmax=cm_Max)
        else:
            normlin = self.Get_SAMECMAP_Norm(cm.shape[0], which=which, override="Lin")
            normlog = self.Get_SAMECMAP_Norm(cm.shape[0], which=which, override="Log")

        # plt.rcParams["figure.figsize"] = (15, 5)
        f, (ax1, ax2) = plt.subplots(1, 2)
        f.set_size_inches(set_size(subplots=(1, 2), fraction=1))
        f.patch.set_facecolor("white")

        # Same matrix twice: linear colour scale (left), log scale (right).
        # The 1e-16 offset keeps exact zeros plottable under the log norm.
        im1 = self.PlotCM(
            cm + 1e-16, cm_err, labels, ax1, origin="lower", cmap=cmap_A, norm=normlin
        )
        im2 = self.PlotCM(
            cm + 1e-16, cm_err, labels, ax2, origin="lower", cmap=cmap_A, norm=normlog
        )

        f.colorbar(im1, ax=ax1, extend="both")
        f.colorbar(im2, ax=ax2, extend="both")

        # ax1.set_title("Linear Scale")
        # ax2.set_title("Log Scale")
        # f.suptitle(f"{bin_type} binned contacts in {contact_type} for {sex}")
        # plt.tight_layout()
        return (ax1, ax2)

    def plot_contact_matrix_thumb(
        self, log, bin_type, contact_type, sex="unisex", which="NCM", MaxAgeBin=100
    ):
        """
        Function to plot a single thumbnail contact matrix for bin_type,
        contact_type and sex.

        Parameters
        ----------
            log:
                bool, should the plot use a log colour normalisation?
            bin_type:
                Name of bin type syoa, AC etc
            contact_type:
                Location of contacts
            sex:
                Sex contact matrix
            which:
                str, which matrix type to collect "NCM", "NCM_R", "CM_T"
            MaxAgeBin:
                The maximum age at which to truncate the bins

        Returns
        -------
            f, ax1, im1:
                matplotlib figure, axes and image objects
        """

        cm, cm_err = self.CMPlots_GetCM(bin_type, contact_type, sex=sex, which=which)
        # "Interaction" binning uses venue subgroup labels and is not age
        # truncated; all other binnings are cut at MaxAgeBin.
        if bin_type == "Interaction":
            labels = self.IMPlots_GetLabels(contact_type)
        else:
            labels = self.CMPlots_GetLabels(self.age_bins[bin_type])
            cm, cm_err, labels = self.CMPlots_UsefulCM(
                bin_type, cm, cm_err, labels, MaxAgeBin
            )

        # Robust colour limits: smallest non-zero and largest finite entries,
        # with fallbacks when the matrix is empty or all-NaN.
        if len(np.nonzero(cm)[0]) != 0 and len(np.nonzero(cm)[1]) != 0:
            cm_Min = np.nanmin(cm[np.nonzero(cm)])
        else:
            cm_Min = 1e-1
        if np.isfinite(cm).sum() != 0:
            cm_Max = cm[np.isfinite(cm)].max()
        else:
            cm_Max = 1

        if np.isnan(cm_Min):
            cm_Min = 1e-1
        if np.isnan(cm_Max) or cm_Max == 0:
            cm_Max = 1

        # Replace non-finite entries so the colour mapping gets finite data.
        cm = np.nan_to_num(cm, posinf=cm_Max, neginf=0, nan=0)

        f, ax1 = plt.subplots(1, 1)
        f.set_size_inches(set_size(subplots=(1, 1), fraction=0.5))
        f.patch.set_facecolor("white")

        # CM_T always gets a per-plot normalisation even when a shared colour
        # map is requested — presumably totals are not comparable across
        # plots; confirm against Get_SAMECMAP_Norm.
        if not self.SameCMAP or which == "CM_T":
            normlin = colors.Normalize(vmin=0, vmax=cm_Max)
            normlog = colors.LogNorm(vmin=cm_Min, vmax=cm_Max)
        else:
            normlin = self.Get_SAMECMAP_Norm(cm.shape[0], which=which, override="Lin")
            normlog = self.Get_SAMECMAP_Norm(cm.shape[0], which=which, override="Log")

        # 1e-16 offset keeps exact zeros plottable under the log norm.
        if not log:
            im1 = self.PlotCM(
                cm + 1e-16,
                cm_err,
                labels,
                ax1,
                origin="lower",
                cmap=cmap_A,
                norm=normlin,
                thumb=True,
            )
        else:
            im1 = self.PlotCM(
                cm + 1e-16,
                cm_err,
                labels,
                ax1,
                origin="lower",
                cmap=cmap_A,
                norm=normlog,
                thumb=True,
            )

        # cax1 = f.add_axes([ax1.get_position().x1+0.01,ax1.get_position().y0,0.02,ax1.get_position().height])
        # plt.tight_layout()
        return f, ax1, im1

    def plot_comparesexes_contact_matrix(
        self, bin_type, contact_type, which="NCM", MaxAgeBin=100
    ):
        """
        Function to plot difference in contact matrices between men and women for bin_type, contact_type.

        Parameters
        ----------
            bin_type:
                Name of bin type syoa, AC etc
            contact_type:
                Location of contacts
            which:
                str, which matrix type to collect "NCM", "NCM_R", "CM_T"
            MaxAgeBin:
                The maximum age at which to truncate the bins

        Returns
        -------
            (ax1,ax2):
                matplotlib axes objects (Linear and Log)
        """
        # plt.rcParams["figure.figsize"] = (15, 5)
        f, (ax1, ax2) = plt.subplots(1, 2)
        f.set_size_inches(set_size(subplots=(1, 2), fraction=1))
        f.patch.set_facecolor("white")

        labels = self.CMPlots_GetLabels(self.age_bins[bin_type])

        # Difference matrix: positive entries mean males record more contacts.
        cm_M, _ = self.CMPlots_GetCM(bin_type, contact_type, "male", which)
        cm_F, _ = self.CMPlots_GetCM(bin_type, contact_type, "female", which)
        cm = cm_M - cm_F

        cm, cm_err, labels = self.CMPlots_UsefulCM(
            bin_type, cm, None, labels, MaxAgeBin
        )

        # Fixed symmetric colour range about zero for the M - F difference.
        cm_Min = -1e-1
        cm_Max = 1e-1

        if not self.SameCMAP:
            # Bug fix: the linear norm previously used vmin=cm_Max, collapsing
            # the colour range to a single value; span [cm_Min, cm_Max] like
            # the symmetric-log norm below.
            normlin = colors.Normalize(vmin=cm_Min, vmax=cm_Max)
            normlog = colors.SymLogNorm(linthresh=1, vmin=cm_Min, vmax=cm_Max)
        else:
            normlin = self.Get_SAMECMAP_Norm(
                cm.shape[0], which=which, override="SymLin"
            )
            normlog = self.Get_SAMECMAP_Norm(
                cm.shape[0], which=which, override="SymLog"
            )

        # Replace non-finite entries so the colour mapping gets finite data.
        cm = np.nan_to_num(cm, posinf=cm_Max, neginf=0, nan=0)

        # 1e-16 offset keeps exact zeros plottable under log-like norms.
        im1 = self.PlotCM(
            cm + 1e-16, cm_err, labels, ax1, origin="lower", cmap=cmap_A, norm=normlin
        )
        im2 = self.PlotCM(
            cm + 1e-16, cm_err, labels, ax2, origin="lower", cmap=cmap_B, norm=normlog
        )

        f.colorbar(im1, ax=ax1, extend="both", label="$M - F$")
        f.colorbar(im2, ax=ax2, extend="both", label="$M - F$")

        # ax1.set_title("Linear Scale")
        # ax2.set_title("Log Scale")
        # f.suptitle(f"Male - female {bin_type} binned contacts in {contact_type}")
        # plt.tight_layout()
        return (ax1, ax2)

    def plot_stacked_contacts(self, bin_type, contact_types=None):
        """
        Plot average contacts per day in each location.

        Stacked bars over age, one layer per contact_type; the "global"
        series is drawn as a black line showing the total.

        Parameters
        ----------
            bin_type:
                Name of bin type syoa, AC etc
            contact_types:
                List of the contact_type locations (or none to grab all of them)

        Returns
        -------
            ax:
                matplotlib axes object

        """
        # plt.rcParams["figure.figsize"] = (10, 5)
        f, ax = plt.subplots()
        f.set_size_inches(set_size(subplots=(1, 2), fraction=1))
        f.patch.set_facecolor("white")

        average_contacts = self.average_contacts[bin_type]
        bins = self.age_bins[bin_type]
        # Running top edge of the stack for each age bin.
        lower = np.zeros(len(bins) - 1)

        # Bar centres and widths derived from the bin edges.
        mids = 0.5 * (bins[:-1] + bins[1:])
        widths = bins[1:] - bins[:-1]
        plotted = 0

        if contact_types is None:
            contact_types = self.contact_types

        for ii, contact_type in enumerate(contact_types):
            # Locations deliberately excluded from the stack.
            if contact_type in ["shelter_intra", "shelter_inter", "informal_work"]:
                continue
            if contact_type not in average_contacts.columns:
                print(f"No contact_type {contact_type}")
                continue
            # "global" is the overall total; draw as a line, not a bar layer.
            if contact_type == "global":
                ax.plot(
                    mids,
                    average_contacts[contact_type],
                    linestyle="-",
                    color="black",
                    label="Total",
                )
                continue

            # Once the colour cycle is exhausted, hatch the bars so the
            # repeated colours stay distinguishable.
            if plotted > len(plt.rcParams["axes.prop_cycle"].by_key()["color"]) - 1:
                hatch = "//"
            else:
                hatch = None

            heights = average_contacts[contact_type]
            ax.bar(
                mids,
                heights,
                widths,
                bottom=lower,
                hatch=hatch,
                label=contact_type,
                edgecolor="black",
                linewidth=0,
            )
            plotted += 1

            # Next layer stacks on top of this one.
            lower = lower + heights

        ax.set_xlim(bins[0], bins[-1])

        ax.legend(bbox_to_anchor=(0.5, 1.02), loc="lower center", ncol=3)
        ax.set_xlabel("Age")
        ax.set_ylabel("average contacts per day")
        f.subplots_adjust(top=0.70)
        # plt.tight_layout()
        return ax

    def plot_population_at_locs_variations(self, locations):
        """
        Plot variations of median values of attendence across all venues of each type

        Parameters
        ----------
            locations:
                list of locations to plot for
        Returns
        -------
            ax:
                matplotlib axes object

        """
        # Get variations between days
        Weekday_Names = self.day_types["weekday"]
        Weekend_Names = self.day_types["weekend"]

        #

        df = pd.DataFrame()
        df = self.location_counters_day["loc"][locations]["unisex"]
        df["t"] = pd.to_datetime(self.location_counters_day["Timestamp"].values)
        df["day"] = [day.day_name() for day in df["t"]]

        means = np.zeros(len(DaysOfWeek_Names))
        stds = np.zeros(len(DaysOfWeek_Names))
        medians = np.zeros(len(DaysOfWeek_Names))
        for day_i in range(len(DaysOfWeek_Names)):
            day = DaysOfWeek_Names[day_i]

            data = df[df["day"] == day][
                df.columns[~df.columns.isin(["t", "day"])]
            ].values.flatten()
            data = data[data > 0]

            if len(data) == 0:
                continue

            means[day_i] = np.nanmean(data)
            stds[day_i] = np.nanstd(data, ddof=1)
            medians[day_i] = np.nanmedian(data)

        # plt.rcParams["figure.figsize"] = (15, 5)
        f, (ax1, ax2) = plt.subplots(1, 2)
        f.set_size_inches(set_size(subplots=(1, 2), fraction=1))
        f.patch.set_facecolor("white")
        ax1.bar(
            np.arange(len(DaysOfWeek_Names)), means, alpha=0.4, color="b", label="mean"
        )
        ax1.bar(
            np.arange(len(DaysOfWeek_Names)),
            medians,
            alpha=0.4,
            color="g",
            label="median",
        )
        # ax1.errorbar(
        #     np.arange(len(DaysOfWeek_Names)),
        #     means,
        #     [stds, stds],
        #     color="black",
        #     label="std errorbar",
        # )
        labels = []
        for i in range(len(DaysOfWeek_Names)):
            labels += [DaysOfWeek_Names[i][:2]]

        ax1.set_xticks(np.arange(len(DaysOfWeek_Names)))
        ax1.set_xticklabels(labels)
        # ax1.set_ylabel("Unique Attendees per day")
        # ax1.set_xlabel("Day of week")
        ax1.set_ylim([0, np.nanmax(means) * 1.4])
        ax1.legend()

        # Get variations between days and time of day
        df = pd.DataFrame()
        df = self.location_counters["loc"][locations]["unisex"]
        df["t"] = pd.to_datetime(
            self.location_counters["Timestamp"].values, format="%d/%m/%y %H:%M:%S"
        )
        df["dt"] = np.array(self.location_counters["delta_t"], dtype=float)
        df["day"] = [day.day_name() for day in df["t"]]

        available_days = np.unique(df["day"].values)
        dts = {}
        times = {}
        timesmid = {}

        for day_i in range(len(DaysOfWeek_Names)):
            day = DaysOfWeek_Names[day_i]
            data = df[df["day"] == day]

            dts[day] = []
            times[day] = [df["t"].iloc[0]]
            timesmid[day] = []

            for i in range(len(data["dt"].values)):

                dts[day].append(df["dt"].values[i])
                timesmid[day].append(
                    times[day][-1] + datetime.timedelta(hours=dts[day][-1]) / 2
                )
                times[day].append(
                    times[day][-1] + datetime.timedelta(hours=dts[day][-1])
                )
                if sum(dts[day]) >= 24:
                    break

            dts[day] = np.array(dts[day])
            times[day] = np.array(times[day])
            timesmid[day] = np.array(timesmid[day])

        medians_days = {}
        means_days = {}
        stds_days = {}

        ymax = -1e3
        ymin = 1e3
        for day_i in range(len(DaysOfWeek_Names)):
            day = DaysOfWeek_Names[day_i]
            if day not in available_days:
                continue
            data = df[df["day"] == day][df.columns[~df.columns.isin(["day"])]]
            total_persons = (
                data[data.columns[~data.columns.isin(["dt", "t"])]].sum(axis=0).values
            )
            total_persons = total_persons[total_persons > 0]

            medians_days[day] = []
            means_days[day] = []
            stds_days[day] = []
            for time_i in range(len(dts[day])):
                data_dt = data[data.columns[~data.columns.isin(["dt", "t"])]].values[
                    time_i
                ]
                data_dt = data_dt[data_dt > 1]
                if len(data_dt) == 0:
                    medians_days[day].append(0)
                    means_days[day].append(0)
                    stds_days[day].append(0)
                else:
                    medians_days[day].append(np.nanmedian(data_dt))
                    means_days[day].append(np.nanmean(data_dt))
                    stds_days[day].append(np.nanstd(data_dt, ddof=1))

            if ymax < np.nanmax(means_days[day]):
                ymax = np.nanmax(means_days[day])
            if ymin > np.nanmin(means_days[day]):
                ymin = np.nanmin(means_days[day])

        xlim = [times[Weekday_Names[0]][0], times[Weekday_Names[0]][-1]]
        for day_i in range(len(DaysOfWeek_Names)):
            day = DaysOfWeek_Names[day_i]
            if day not in available_days:
                continue
            timesmid[day] = np.insert(
                timesmid[day], 0, timesmid[day][-1] - datetime.timedelta(days=1), axis=0
            )
            timesmid[day] = np.insert(
                timesmid[day],
                len(timesmid[day]),
                timesmid[day][1] + datetime.timedelta(days=1),
                axis=0,
            )
            medians_days[day] = np.insert(
                medians_days[day], 0, medians_days[day][-1], axis=0
            )
            medians_days[day] = np.insert(
                medians_days[day], len(medians_days[day]), medians_days[day][1], axis=0
            )
            means_days[day] = np.insert(means_days[day], 0, means_days[day][-1], axis=0)
            means_days[day] = np.insert(
                means_days[day], len(means_days[day]), means_days[day][1], axis=0
            )

        for day_i in range(len(DaysOfWeek_Names)):
            day = DaysOfWeek_Names[day_i]
            if day not in available_days:
                continue
            if day in Weekend_Names:
                linestyle = "--"
            else:
                linestyle = "-"
            # ax2.plot(timesmid[day], medians_days[day], label=DaysOfWeek_Names[day_i], linestyle=linestyle)
            ax2.plot(
                timesmid[day],
                means_days[day],
                label=DaysOfWeek_Names[day_i],
                linestyle=linestyle,
            )

        alphas = [0.1, 0.2]
        ylim = [-abs(ymin * 1.1), abs(ymax * 1.1)]
        for time_i in range(len(dts[Weekday_Names[0]])):
            ax2.fill_between(
                [times[Weekday_Names[0]][time_i], times[Weekday_Names[0]][time_i + 1]],
                ylim[0],
                ylim[1],
                color="g",
                alpha=alphas[time_i % 2],
            )
        ax2.axhline(0, color="grey", linestyle="--")

        # ax2.set_ylabel("Mean Unique Attendees per timeslot")
        # ax2.set_xlabel("Time of day [hour]")
        # Define the date format
        ax2.xaxis.set_major_locator(mdates.HourLocator(byhour=None, interval=4))
        ax2.xaxis.set_major_formatter(DateFormatter("%H"))
        ax2.set_xlim(xlim)
        ax2.set_ylim([0, ylim[1]])
        ax2.legend()
        # plt.tight_layout()
        return (ax1, ax2)

    def plot_AgeProfileRatios(
        self, contact_type="global", bin_type="syoa", sex="unisex"
    ):
        """
        Plot demographic counts for each location and ratio of counts in age bins.

        Left panel: heat map of pairwise ratios between per-age-bin population
        densities observed by the tracker. Right panel: normalised age
        histograms of the ground-truth demography vs the tracker counts.

        Parameters
        ----------
            contact_type:
                The contact_type location (or "global" to grab all of them)
            bin_type:
                Name of bin type syoa, AC etc
            sex:
                Which sex of population ["male", "female", "unisex"]


        Returns
        -------
            ax:
                matplotlib axes objects (ax1, ax2)

        """

        if bin_type != "Interaction":
            # Cumulative population seen at this location and the global profile.
            pop_tots = self.location_cum_pop[bin_type][contact_type][sex]
            global_age_profile = self.age_profiles[bin_type]["global"][sex]
            Bins = np.array(self.age_bins[bin_type])

            Labels = self.CMPlots_GetLabels(Bins)
            Bincenters = 0.5 * (Bins[1:] + Bins[:-1])
            Bindiffs = np.abs(Bins[1:] - Bins[:-1])
        else:
            # Interaction-matrix binning comes from the IM definition itself.
            Bins = np.array(self.IM[contact_type]["bins"])
            AgeDiscrete = self.IM[contact_type]["type"]
            if AgeDiscrete == "Age":
                pop_tots = self.location_cum_pop[bin_type][contact_type][sex]

                contacts_loc = self.contacts_df[self.contacts_df[contact_type] != 0]
                # NOTE(review): AgesCount is computed but never used afterwards,
                # and groupby([Bins]) / reindex(len(Bins)) look suspect — verify.
                AgesCount = contacts_loc.groupby([Bins], dropna=False).size()
                AgesCount = AgesCount.reindex(len(Bins), fill_value=0)
                global_age_profile = self.age_profiles[bin_type]["global"][sex]

                Labels = self.CMPlots_GetLabels(Bins)
                Bincenters = 0.5 * (Bins[1:] + Bins[:-1])
                Bindiffs = np.abs(Bins[1:] - Bins[:-1])

            if AgeDiscrete == "Discrete":
                # NOTE(review): this branch leaves pop_tots, global_age_profile,
                # Bincenters and Bindiffs undefined, so the code below would
                # raise NameError — confirm this path is never reached.
                Labels = Bins
                pass

        # Convert counts to densities (per year of age) so unequal bins compare.
        Height_G = global_age_profile / Bindiffs
        Height_P = pop_tots / Bindiffs

        ws_G = np.zeros((Bins.shape[0] - 1, Bins.shape[0] - 1))
        ws_P = np.zeros((Bins.shape[0] - 1, Bins.shape[0] - 1))
        # Loop over elements
        for i in range(ws_G.shape[0]):
            for j in range(ws_G.shape[1]):
                # Population rescaling
                ws_G[i, j] = Height_G[i] / Height_G[j]
                ws_P[i, j] = Height_P[i] / Height_P[j]

        # plt.rcParams["figure.figsize"] = (15, 5)
        f, (ax1, ax2) = plt.subplots(1, 2)
        f.set_size_inches(set_size(subplots=(1, 1), fraction=1))
        f.patch.set_facecolor("white")

        # Log colour scale spanning the finite entries of both ratio matrices,
        # chosen symmetric about 1 (vmin = 1/vmax).
        vmax_G = np.nan
        vmax_P = np.nan
        if np.isfinite(ws_G).sum() != 0:
            vmax_G = ws_G[np.isfinite(ws_G)].max() * 2
        if np.isfinite(ws_P).sum() != 0:
            vmax_P = ws_P[np.isfinite(ws_P)].max() * 2

        vmax = np.nanmax([vmax_G, vmax_P])
        # NOTE(review): np.nanmax cannot return None here, so the second test is dead.
        if np.isnan(vmax) or vmax is None:
            vmax = 1e-1

        vmin = 10 ** (-1 * np.log10(vmax))
        # ax1_ins = ax1.inset_axes([0.8, 1.0, 0.2, 0.2])

        norm = colors.LogNorm(vmin=vmin, vmax=vmax)
        im_P = self.PlotCM(
            ws_P, None, Labels, ax1, origin="lower", cmap=cmap_B, norm=norm
        )
        # im_G = self.PlotCM(
        #    ws_G, None, Labels, ax1_ins, origin="lower", cmap=cmap_B, norm=norm
        # )

        f.colorbar(im_P, ax=ax1, label=r"$\dfrac{Age_{y}}{Age_{x}}$", extend="both")
        # NOTE(review): plt.bar draws on the *current* pyplot axes (expected to
        # be ax2) — this relies on implicit pyplot state.
        plt.bar(
            x=Bincenters,
            height=Height_G / sum(Height_G),
            width=Bindiffs,
            tick_label=Labels,
            alpha=0.5,
            color="blue",
            label="Ground truth",
        )
        plt.bar(
            x=Bincenters,
            height=Height_P / sum(Height_P),
            width=Bindiffs,
            tick_label=Labels,
            alpha=0.5,
            color="red",
            label="tracker",
        )
        ax2.set_xlabel("Age")
        ax2.set_ylabel("Normed Population size")
        ax2.set_xlim([Bins[0], Bins[-1]])
        ax2.set_yscale("log")
        plt.xticks(rotation=90)
        # f.suptitle(f"Age profile of {contact_type}")
        plt.legend()
        plt.tight_layout()
        return (ax1, ax2)

    def plot_DistanceTraveled(self, location, day):
        """
        Histogram of distances traveled from home/shelter to a given venue type.

        Parameters
        ----------
            location:
                The venue to look at
            day:
                The day of the week

        Returns
        -------
            ax:
                matplotlib axes object

        """
        plural = Tracker.pluralize(self, location)
        n_venues = self.NVenues[plural]  # kept for parity with other plotters
        distances = self.travel_distance[location]
        total_counts = distances.iloc[:, 1].sum()

        # Trim the x-axis to the populated portion of the histogram: find the
        # last bin that still adds to the cumulative sum, then pad by a few bins.
        cumulative = np.cumsum(distances.iloc[:, 1].values)
        last_index = len(cumulative) - np.sum(cumulative == cumulative[-1])
        bin_step = distances.iloc[1, 0] - distances.iloc[0, 0]
        x_max = distances.iloc[last_index, 0] + 3.5 * bin_step

        fig, ax = plt.subplots(1, 1)
        fig.set_size_inches(set_size(subplots=(1, 2), fraction=1))
        fig.patch.set_facecolor("white")
        bar_width = distances["bins"].iloc[1] - distances["bins"].iloc[0]
        ax.bar(
            x=distances["bins"],
            height=(100 * distances.iloc[:, 1]) / total_counts,
            width=bar_width,
            color="b",
            alpha=0.4,
        )
        ax.set_ylabel(r"Frequency [%]")
        ax.set_xlabel(r"Travel distance from shelter [km]")
        ax.set_xlim([0, x_max])
        return ax

    ###################################################
    # Master plotter ##################################
    ###################################################

    def make_plots(
        self,
        plot_BBC=False,
        plot_thumbprints=False,
        SameCMAP=False,
        plot_INPUTOUTPUT=True,
        plot_AvContactsLocation=True,
        plot_dTLocationPopulation=True,
        plot_InteractionMatrices=True,
        plot_ContactMatrices=True,
        plot_CompareSexMatrices=True,
        plot_AgeBinning=True,
        plot_Distances=True,
        MaxAgeBin=100,
    ):
        """
        Make all tracker plots, saving them under ``record_path / "Graphs"``.

        Parameters
        ----------
            plot_BBC:
                bool, if we want to compare to BBC Pandemic data.
            plot_thumbprints:
                bool, To plot thumbnail style plots for plot_ContactMatrices and plot_CompareSexMatrices
            SameCMAP:
                bool, To plot same colour map accross all similar dimension contact matrices
            plot_INPUTOUTPUT:
                bool, To plot input (interaction) vs output (contact) matrix comparisons
            plot_AvContactsLocation:
                bool, To plot average contacts per location plots
            plot_dTLocationPopulation:
                bool, To plot average people per location at timestamp
            plot_InteractionMatrices:
                bool, To plot interaction matrices
            plot_ContactMatrices:
                bool, To plot contact matrices
            plot_CompareSexMatrices:
                bool, To plot comparison of sexes matrices
            plot_AgeBinning:
                bool, To plot w weight matrix to compare demographics
            plot_Distances:
                bool, To plot the distance traveled from shelter to locations
            MaxAgeBin:
                int, upper age (years) cut-off passed through to the contact-matrix plots
        Returns
        -------
            1 (returned early, without plotting, when there are no groups)
        """
        # Width multiplier and aspect ratio for the standalone colourbar figures
        # saved alongside the thumbnail matrices.
        CbarMultiplier = 3
        aspect = 40

        logger.info(f"Rank {mpi_rank} -- Begin plotting")
        if self.group_type_names == []:
            return 1

        self.SameCMAP = SameCMAP

        relevant_bin_types = list(self.CM.keys())
        relevant_bin_types_short = ["syoa", "AC"]
        relevant_contact_types = list(self.CM["syoa"].keys())
        IM_contact_types = list(self.CM["Interaction"].keys())

        # NOTE(review): NormFolder is undefined for any Normalization_Type other
        # than "U" or "P" — a later use would raise NameError; confirm invariant.
        if self.Normalization_Type == "U":
            NormFolder = "VenueNorm"
        elif self.Normalization_Type == "P":
            NormFolder = "PopNorm"

        CMTypes = ["NCM", "NCM_R", "NCM_V"]

        # Input (interaction matrix) vs output (recovered contact matrix) plots.
        if plot_INPUTOUTPUT:
            plot_dir_1 = (
                self.record_path / "Graphs" / "Contact_Matrices_INOUT" / NormFolder
            )
            plot_dir_1.mkdir(exist_ok=True, parents=True)
            # Prefer the "Paper" binning when available, otherwise single-year ages.
            if "Paper" in relevant_bin_types:
                rbt = "Paper"
            else:
                rbt = "syoa"
            for rct in self.IM.keys():
                if rct not in relevant_contact_types:
                    continue

                which = "NCM_R"
                plot_BBC_Sheet = False

                # BBC Pandemic comparison only exists for these three settings
                # and only with the "Paper" binning.
                if (
                    plot_BBC
                    and rct in ["household", "school", "company"]
                    and rbt == "Paper"
                ):
                    if rct == "household":
                        plot_BBC_Sheet = "all_home"
                    if rct == "school":
                        plot_BBC_Sheet = "all_school"
                    if rct == "company":
                        plot_BBC_Sheet = "all_work"
                    which = "NCM_R"

                self.plot_contact_matrix_INOUT(
                    bin_type=rbt,
                    contact_type=rct,
                    sex="unisex",
                    which=which,
                    plot_BBC_Sheet=plot_BBC_Sheet,
                    MaxAgeBin=MaxAgeBin,
                )
                plt.savefig(plot_dir_1 / f"{rct}.pdf", dpi=dpi, bbox_inches="tight")
                plt.close()
        logger.info(f"Rank {mpi_rank} -- Input vs output done")

        # Stacked average-contacts-per-location plots.
        if plot_AvContactsLocation:
            plot_dir = self.record_path / "Graphs" / f"Average_Contacts"
            plot_dir.mkdir(exist_ok=True, parents=True)
            for rbt in relevant_bin_types_short:
                stacked_contacts_plot = self.plot_stacked_contacts(
                    bin_type=rbt, contact_types=relevant_contact_types
                )
                stacked_contacts_plot.plot()
                plt.savefig(
                    plot_dir / f"{rbt}_contacts.pdf", dpi=dpi, bbox_inches="tight"
                )
                plt.close()
        logger.info(f"Rank {mpi_rank} -- Av contacts done")

        # Venue-population-over-time variation plots.
        if plot_dTLocationPopulation:
            plot_dir = self.record_path / "Graphs" / "Location_Pops"
            plot_dir.mkdir(exist_ok=True, parents=True)
            for locations in self.location_counters["loc"].keys():
                self.plot_population_at_locs_variations(locations)
                plt.savefig(
                    plot_dir / f"{locations}_Variations.pdf",
                    dpi=dpi,
                    bbox_inches="tight",
                )
                plt.close()

                # self.plot_population_at_locs(locations)
                # plt.savefig(plot_dir / f"{locations}.pdf", dpi=dpi, bbox_inches="tight")
                # plt.close()
        logger.info(f"Rank {mpi_rank} -- Pop at locations done")

        # Interaction (input) matrix plots, plus optional thumbnails.
        if plot_InteractionMatrices:
            plot_dir = self.record_path / "Graphs" / "IM" / NormFolder
            plot_dir.mkdir(exist_ok=True, parents=True)
            for rct in self.IM.keys():
                self.plot_interaction_matrix(contact_type=rct)
                plt.savefig(plot_dir / f"{rct}.pdf", dpi=dpi, bbox_inches="tight")
                plt.close()

                if plot_thumbprints:
                    fig, ax1, im1 = self.plot_interaction_matrix_thumb(
                        log=False, contact_type=rct
                    )
                    plt.savefig(
                        plot_dir / f"{rct}_thumbnail.pdf", dpi=100, bbox_inches="tight"
                    )
                    # With a shared colour map, save a single standalone
                    # colourbar (from the first matrix) for all thumbnails.
                    if rct == list(self.IM.keys())[0] and SameCMAP:
                        cbar = fig.colorbar(
                            im1,
                            ax=ax1,
                            extend="both",
                            orientation="horizontal",
                            aspect=aspect,
                            format="%g",
                        )
                        # cbar.ticklabel_format(style='plain')
                        ax1.remove()
                        fig.set_size_inches(
                            fig.get_size_inches()[0] * CbarMultiplier,
                            fig.get_size_inches()[1],
                        )
                        plt.savefig(
                            plot_dir / f"colourbar.pdf", dpi=100, bbox_inches="tight"
                        )
                    plt.close()

                    fig, ax1, im1 = self.plot_interaction_matrix_thumb(
                        log=True, contact_type=rct
                    )
                    plt.savefig(
                        plot_dir / f"{rct}_thumbnail_log.pdf",
                        dpi=100,
                        bbox_inches="tight",
                    )
                    if rct == list(self.IM.keys())[0] and SameCMAP:
                        fig.colorbar(
                            im1,
                            ax=ax1,
                            extend="both",
                            orientation="horizontal",
                            aspect=aspect,
                            format="%g",
                        )
                        ax1.remove()
                        fig.set_size_inches(
                            fig.get_size_inches()[0] * CbarMultiplier,
                            fig.get_size_inches()[1],
                        )
                        plt.savefig(
                            plot_dir / f"colourbar_log.pdf",
                            dpi=100,
                            bbox_inches="tight",
                        )
                    plt.close()
        logger.info(f"Rank {mpi_rank} -- Interaction matrix plots done")

        # Contact matrices for every normalisation type / binning / sex.
        if plot_ContactMatrices:
            for CMType in CMTypes:
                plot_dir_1 = (
                    self.record_path
                    / "Graphs"
                    / "Contact_Matrices"
                    / NormFolder
                    / CMType
                )
                plot_dir_1.mkdir(exist_ok=True, parents=True)

                for rbt in relevant_bin_types:

                    plot_dir_2 = plot_dir_1 / f"{rbt}"
                    plot_dir_2.mkdir(exist_ok=True, parents=True)

                    if rbt != "Interaction":
                        for sex in self.contact_sexes:
                            plot_dir_3 = plot_dir_2 / f"{sex}"
                            plot_dir_3.mkdir(exist_ok=True, parents=True)

                            for rct in relevant_contact_types:
                                self.plot_contact_matrix(
                                    bin_type=rbt,
                                    contact_type=rct,
                                    sex=sex,
                                    which=CMType,
                                    MaxAgeBin=MaxAgeBin,
                                )
                                plt.savefig(
                                    plot_dir_3 / f"{rct}.pdf",
                                    dpi=dpi,
                                    bbox_inches="tight",
                                )
                                plt.close()

                                if plot_thumbprints:
                                    fig, ax1, im1 = self.plot_contact_matrix_thumb(
                                        log=False,
                                        bin_type=rbt,
                                        contact_type=rct,
                                        sex=sex,
                                        which=CMType,
                                        MaxAgeBin=MaxAgeBin,
                                    )
                                    plt.savefig(
                                        plot_dir_3 / f"{rct}_thumbnail.pdf",
                                        dpi=100,
                                        bbox_inches="tight",
                                    )
                                    if rct == relevant_contact_types[0] and SameCMAP:
                                        fig.colorbar(
                                            im1,
                                            ax=ax1,
                                            extend="both",
                                            orientation="horizontal",
                                            aspect=aspect,
                                            format="%g",
                                        )
                                        ax1.remove()
                                        fig.set_size_inches(
                                            fig.get_size_inches()[0] * CbarMultiplier,
                                            fig.get_size_inches()[1],
                                        )
                                        plt.savefig(
                                            plot_dir_3 / f"colourbar.pdf",
                                            dpi=100,
                                            bbox_inches="tight",
                                        )
                                    plt.close()

                                    fig, ax1, im1 = self.plot_contact_matrix_thumb(
                                        log=True,
                                        bin_type=rbt,
                                        contact_type=rct,
                                        sex=sex,
                                        which=CMType,
                                        MaxAgeBin=MaxAgeBin,
                                    )
                                    plt.savefig(
                                        plot_dir_3 / f"{rct}_thumbnail_log.pdf",
                                        dpi=100,
                                        bbox_inches="tight",
                                    )
                                    if rct == relevant_contact_types[0] and SameCMAP:
                                        fig.colorbar(
                                            im1,
                                            ax=ax1,
                                            extend="both",
                                            orientation="horizontal",
                                            aspect=aspect,
                                            format="%g",
                                        )
                                        ax1.remove()
                                        fig.set_size_inches(
                                            fig.get_size_inches()[0] * CbarMultiplier,
                                            fig.get_size_inches()[1],
                                        )
                                        plt.savefig(
                                            plot_dir_3 / f"colourbar_log.pdf",
                                            dpi=100,
                                            bbox_inches="tight",
                                        )
                                    plt.close()
                    else:
                        # "Interaction" binning has no per-sex split.
                        for rct in IM_contact_types:
                            sex = "unisex"
                            self.plot_contact_matrix(
                                bin_type=rbt,
                                contact_type=rct,
                                sex=sex,
                                which=CMType,
                                MaxAgeBin=MaxAgeBin,
                            )
                            plt.savefig(
                                plot_dir_2 / f"{rct}.pdf", dpi=150, bbox_inches="tight"
                            )
                            plt.close()

                            if plot_thumbprints:
                                fig, ax1, im1 = self.plot_contact_matrix_thumb(
                                    log=False,
                                    bin_type=rbt,
                                    contact_type=rct,
                                    sex=sex,
                                    which=CMType,
                                    MaxAgeBin=MaxAgeBin,
                                )
                                plt.savefig(
                                    plot_dir_2 / f"{rct}_thumbnail.pdf",
                                    dpi=100,
                                    bbox_inches="tight",
                                )
                                if rct == IM_contact_types[0]:
                                    fig.colorbar(
                                        im1,
                                        ax=ax1,
                                        extend="both",
                                        aspect=aspect,
                                        orientation="horizontal",
                                        format="%g",
                                    )
                                    ax1.remove()
                                    fig.set_size_inches(
                                        fig.get_size_inches()[0] * CbarMultiplier,
                                        fig.get_size_inches()[1],
                                    )
                                    plt.savefig(
                                        plot_dir_2 / f"colourbar.pdf",
                                        dpi=100,
                                        bbox_inches="tight",
                                    )
                                plt.close()

                                fig, ax1, im1 = self.plot_contact_matrix_thumb(
                                    log=True,
                                    bin_type=rbt,
                                    contact_type=rct,
                                    sex=sex,
                                    which=CMType,
                                    MaxAgeBin=MaxAgeBin,
                                )
                                plt.savefig(
                                    plot_dir_2 / f"{rct}_thumbnail_log.pdf",
                                    dpi=100,
                                    bbox_inches="tight",
                                )
                                if rct == IM_contact_types[0]:
                                    fig.colorbar(
                                        im1,
                                        ax=ax1,
                                        extend="both",
                                        aspect=aspect,
                                        orientation="horizontal",
                                        format="%g",
                                    )
                                    ax1.remove()
                                    fig.set_size_inches(
                                        fig.get_size_inches()[0] * CbarMultiplier,
                                        fig.get_size_inches()[1],
                                    )
                                    plt.savefig(
                                        plot_dir_2 / f"colourbar_log.pdf",
                                        dpi=100,
                                        bbox_inches="tight",
                                    )
                                plt.close()
        logger.info(f"Rank {mpi_rank} -- CM plots done")

        # Male vs female contact matrix comparisons (only if both were tracked).
        if plot_CompareSexMatrices:
            for CMType in CMTypes:
                plot_dir_1 = (
                    self.record_path
                    / "Graphs"
                    / "Contact_Matrices"
                    / NormFolder
                    / CMType
                )
                plot_dir_1.mkdir(exist_ok=True, parents=True)

                for rbt in relevant_bin_types:
                    if rbt == "Interaction":
                        continue

                    plot_dir_2 = plot_dir_1 / f"{rbt}"
                    plot_dir_2.mkdir(exist_ok=True, parents=True)

                    for rct in relevant_contact_types:
                        if (
                            "male" in self.contact_sexes
                            and "female" in self.contact_sexes
                        ):
                            plot_dir_3 = plot_dir_2 / "CompareSexes"
                            plot_dir_3.mkdir(exist_ok=True, parents=True)

                            self.plot_comparesexes_contact_matrix(
                                bin_type=rbt, contact_type=rct, which=CMType
                            )
                            plt.savefig(
                                plot_dir_3 / f"{rct}.pdf", dpi=dpi, bbox_inches="tight"
                            )
                            plt.close()
        logger.info(f"Rank {mpi_rank} -- CM between sexes done")

        # Demographic weight-matrix plots.
        if plot_AgeBinning:
            plot_dir = self.record_path / "Graphs" / "Age_Binning"
            plot_dir.mkdir(exist_ok=True, parents=True)
            for rbt in ["syoa", "Paper"]:
                if rbt not in self.age_bins.keys():
                    continue
                for rct in relevant_contact_types:
                    self.plot_AgeProfileRatios(
                        contact_type=rct, bin_type=rbt, sex="unisex"
                    )
                    plt.savefig(
                        plot_dir / f"{rbt}_{rct}.pdf", dpi=dpi, bbox_inches="tight"
                    )
                    plt.close()
        logger.info(f"Rank {mpi_rank} -- Age bin matrix done")

        # Travel-distance histograms.
        if plot_Distances:
            plot_dir = self.record_path / "Graphs" / "Distance_Traveled"
            plot_dir.mkdir(exist_ok=True, parents=True)
            for locations in self.location_counters["loc"].keys():
                for day in self.travel_distance.keys():
                    # NOTE(review): only the first day is plotted (break below)
                    # and the filename omits the day — confirm this is intended.
                    self.plot_DistanceTraveled(locations, day)
                    plt.savefig(
                        plot_dir / f"{locations}.pdf", dpi=dpi, bbox_inches="tight"
                    )
                    plt.close()
                    break
        logger.info(f"Rank {mpi_rank} -- Distance plots done")
        return 1


import numpy as np
from cycler import cycler
import matplotlib as mpl
import matplotlib.font_manager
import matplotlib.pyplot as plt

from june.mpi_setup import mpi_comm, mpi_size, mpi_rank

try:
    # Prefer the SciencePlots styles when installed; fall back to the default
    # style otherwise (only rank 0 reports which one is active).
    plt.style.use(["science", "no-latex", "bright"])
    if mpi_rank == 0:
        print("Using 'science' matplotlib style")
except Exception:
    plt.style.use("default")
    if mpi_rank == 0:
        print("Using default matplotlib style")

# Default raster resolution (dots per inch) for saved figures.
dpi = 150


# Some figure initialization
# Some figure initialization
def fig_initialize(setsize=False):
    """
    Configure matplotlib for LaTeX-rendered, white-background figures.

    Parameters
    ----------
    setsize:
        bool, also shrink fonts / line widths to publication sizes.

    Returns
    -------
    1
    """
    # LaTeX text rendering with the AMS packages loaded in the preamble.
    plt.rc("text", usetex=True)
    plt.rc(
        "text.latex",
        preamble=r"\usepackage{amsmath}\usepackage{amsthm}\usepackage{amssymb}\usepackage{amsfonts}",
    )
    # White canvases everywhere: axes, figure, and saved output.
    for background_key in ("axes.facecolor", "figure.facecolor", "savefig.facecolor"):
        plt.rcParams[background_key] = "white"
    if setsize:
        # Publication-sized fonts and line widths.
        publication_sizes = {
            "font.size": 10,
            "lines.linewidth": 1,
            "axes.labelsize": 9,
            "axes.titlesize": 8,
            "xtick.labelsize": 8,
            "ytick.labelsize": 8,
            "legend.labelspacing": 0.5,
        }
        for rc_key, rc_value in publication_sizes.items():
            mpl.rcParams[rc_key] = rc_value
        plt.rc("legend", fontsize=8)
        plt.rc("legend", frameon=False)
    plt.rc("axes")
    return 1


def set_size(width="paper", fraction=1, subplots=(1, 1)):
    """Compute figure dimensions (in inches) that drop into LaTeX unscaled.

    Credit : https://jwalton.info/Embed-Publication-Matplotlib-Latex/

    Parameters
    ----------
    width: float or string
            Document width in points, or "paper" for the default text width (392 pt)
    fraction: float, optional
            Fraction of the width which you wish the figure to occupy
    subplots: array-like, optional
            The (rows, columns) of the subplot grid.
    Returns
    -------
    fig_dim: tuple
            (width, height) of the figure in inches
    """
    points = 392.0 if width == "paper" else width

    # Scale by the requested fraction, then convert pt -> inch (1 in = 72.27 pt,
    # the TeX point).
    width_inches = (points * fraction) * (1 / 72.27)

    # Height follows the golden ratio, adjusted for the subplot grid shape.
    golden = (5**0.5 - 1) / 2
    rows, cols = subplots[0], subplots[1]
    height_inches = width_inches * golden * (rows / cols)

    return (width_inches, height_inches)


def legend(fig, axes, x0=1, y0=0.5, direction="v", padpoints=3, **kwargs):
    """Place a legend anchored at (x0, y0) in figure coordinates and shrink
    every axes in *axes* so the legend does not overlap them.

    Parameters
    ----------
    fig:
        Figure whose dpi scale converts padpoints to an offset.
    axes:
        Sequence of axes; legend handles come from axes[0], all are resized.
    x0, y0:
        Legend anchor point in figure coordinates.
    direction:
        "v" reserves space below the legend (axes heights shrink),
        anything else reserves space to its left (axes widths shrink).
    padpoints:
        Padding between legend and axes, in points.
    kwargs:
        Forwarded to ``Axes.legend``.
    """

    # Anchor the legend in figure coordinates so it can sit outside the axes.
    otrans = axes[0].figure.transFigure
    t = axes[0].legend(
        bbox_to_anchor=(x0, y0), loc="center", bbox_transform=otrans, **kwargs
    )

    plt.tight_layout(pad=0)

    # Draw first so the legend's window extent queried below is up to date.
    axes[0].figure.canvas.draw()
    plt.tight_layout(pad=0)
    # Padding offset in inches (points / 72), vertical or horizontal.
    ppar = [0, -padpoints / 72.0] if direction == "v" else [-padpoints / 72.0, 0]
    trans2 = (
        mpl.transforms.ScaledTranslation(ppar[0], ppar[1], fig.dpi_scale_trans)
        + axes[0].figure.transFigure.inverted()
    )
    # Legend bounding box in figure coordinates, shifted by the padding.
    tbox = t.get_window_extent().transformed(trans2)

    if direction == "v":
        # Shrink each axes' height so it ends below the legend.
        for ax in axes:
            bbox = ax.get_position()
            ax.set_position([bbox.x0, bbox.y0, bbox.width, tbox.y0 - bbox.y0])
    else:
        # Shrink each axes' width so it ends left of the legend.
        for ax in axes:
            bbox = ax.get_position()
            ax.set_position([bbox.x0, bbox.y0, tbox.x0 - bbox.x0, bbox.height])


from .tracker import Tracker

# from .interactive_group import InteractiveGroup


import math
import numpy as np

earth_radius = 6371  # km -- mean Earth radius, used by the distance helpers below


def haversine_distance(origin, destination, radius=6371.0):
    """
    Great-circle distance in km between two (latitude, longitude) points.

    Taken from https://gist.github.com/rochacbruno/2883505
    # Author: Wayne Dyck

    Parameters
    ----------
    origin, destination:
        (latitude, longitude) pairs in degrees.
    radius:
        Sphere radius in km; defaults to the Earth's mean radius (6371 km).
        Parameterized so the helper also works for other spheres/units.

    Returns
    -------
    Distance along the sphere's surface, in the same unit as ``radius``.
    """
    lat1, lon1 = origin
    lat2, lon2 = destination

    dlat = math.radians(lat2 - lat1)
    dlon = math.radians(lon2 - lon1)
    # Haversine formula: a is the square of half the chord length.
    a = (
        math.sin(dlat / 2) ** 2
        + math.cos(math.radians(lat1))
        * math.cos(math.radians(lat2))
        * math.sin(dlon / 2) ** 2
    )
    c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
    return radius * c


def add_distance_to_lat_lon(latitude, longitude, x, y, radius=6371.0):
    """
    Given a latitude and a longitude (in degrees), and two distances (x, y) in km,
    adds those distances to lat and lon.

    Parameters
    ----------
    latitude, longitude:
        Starting point in degrees.
    x, y:
        Eastward and northward displacements in km.
    radius:
        Sphere radius in km (defaults to the Earth's mean radius).

    Returns
    -------
    (lat2, lon2) in degrees.
    """
    lat2 = latitude + 180 * y / (radius * np.pi)
    # The longitude change per km scales with 1/cos(latitude). Latitude is in
    # degrees, so it must be converted to radians before taking the cosine
    # (the previous version passed degrees straight to np.cos, which is wrong
    # for any latitude other than 0).
    lon2 = longitude + 180 * x / (radius * np.pi * np.cos(np.radians(latitude)))
    return lat2, lon2


"""
A few numbaised useful functions for random sampling.
"""
from numba import jit
from random import random
import numpy as np


@jit(nopython=True)
def random_choice_numba(arr, prob):
    """
    Fast implementation of np.random.choice

    Parameters
    ----------
    arr:
        Array of values to sample from.
    prob:
        Per-value probabilities; assumed to sum to 1 -- TODO confirm callers
        always pass a normalised distribution.

    Returns
    -------
    A single element of ``arr`` drawn with the given probabilities.
    """
    # Inverse-CDF sampling: draw u ~ U(0, 1) and pick the first index whose
    # cumulative probability exceeds u.
    return arr[np.searchsorted(np.cumsum(prob), random(), side="right")]


import numpy as np
import pandas as pd
from itertools import chain


def parse_age_probabilities(age_dict: dict, fill_value=0):
    """
    Parse an age-range probability dictionary into per-age probabilities.

    Keys of ``age_dict`` are interval strings such as ``"20-40"``; values are
    probabilities. Returns a list with one entry per age in [0, 100): ages
    inside an interval (lower and upper bounds included) get that interval's
    probability, every other age gets ``fill_value``.
    """
    if age_dict is None:
        return [0], [0]
    interval_edges = []
    raw_probabilities = []
    for age_range, probability in age_dict.items():
        pieces = age_range.split("-")
        if len(pieces) == 1:
            raise NotImplementedError("Please give age ranges as intervals")
        interval_edges.append((int(pieces[0]), int(pieces[1])))
        raw_probabilities.append(probability)
    # Order intervals by lower bound so the searchsorted lookup below is valid.
    order = np.argsort([low for low, _ in interval_edges])
    boundaries = [edge for idx in order for edge in interval_edges[idx]]
    ordered_probabilities = np.array(raw_probabilities)[order]
    # Interleave fill_value around the interval probabilities so that the
    # boundary index returned by searchsorted maps directly to a value.
    lookup = [fill_value]
    for probability in ordered_probabilities:
        lookup.append(probability)
        lookup.append(fill_value)
    # +1 so that an age equal to a lower boundary falls inside its interval.
    return [lookup[np.searchsorted(boundaries, age + 1)] for age in range(100)]


def parse_opens(dict: dict, fill_value=0):
    """
    Parse an opening-times dictionary into per-day-type [open, close] hours.

    ``dict`` maps a day type (e.g. "weekday") to an interval string such as
    "9-17"; the result maps each day type to a two-element [open, close]
    list of ints. ``fill_value`` is accepted for interface compatibility but
    is not used.
    """
    opening_hours = {}
    for day_type, hour_range in dict.items():
        pieces = hour_range.split("-")
        if len(pieces) == 1:
            raise NotImplementedError("Please give open times as intervals")
        opening_hours[day_type] = [int(pieces[0]), int(pieces[1])]
    return opening_hours


def read_comorbidity_csv(filename: str):
    """
    Read a comorbidity prevalence CSV and rescale it to a consistent
    distribution per age bin.

    The CSV has comorbidities as rows and age-bin upper bounds as columns;
    it assumes the "no_condition" row is the last row -- TODO confirm with
    the data files. Column labels are converted into "low-high" interval
    strings, and within each age column the actual conditions are rescaled
    so they sum to ``1 - P(no_condition)``.

    Returns
    -------
    The transposed dataframe (age bins as rows, comorbidities as columns).
    """
    comorbidity_df = pd.read_csv(filename, index_col=0)
    # Turn single-age column labels into "low-high" interval labels.
    column_names = [f"0-{comorbidity_df.columns[0]}"]
    for i in range(len(comorbidity_df.columns) - 1):
        column_names.append(
            f"{comorbidity_df.columns[i]}-{comorbidity_df.columns[i+1]}"
        )
    comorbidity_df.columns = column_names
    for column in comorbidity_df.columns:
        no_comorbidity = comorbidity_df[column].loc["no_condition"]
        should_have_comorbidity = 1 - no_comorbidity
        has_comorbidity = np.sum(comorbidity_df[column]) - no_comorbidity
        # Rescale every row except the last ("no_condition") with a single
        # .loc assignment. The previous chained form
        # ``comorbidity_df[column].iloc[:-1] *= ...`` wrote to an
        # intermediate object and is silently a no-op under pandas
        # copy-on-write semantics.
        comorbidity_df.loc[comorbidity_df.index[:-1], column] *= (
            should_have_comorbidity / has_comorbidity
        )

    return comorbidity_df.T


def convert_comorbidities_prevalence_to_dict(prevalence_female, prevalence_male):
    """
    Merge per-sex comorbidity prevalence dataframes into a nested dict of the
    form ``{comorbidity: {"f": {age_bin: value}, "m": {age_bin: value}}}``.
    """
    return {
        comorbidity: {
            "f": prevalence_female[comorbidity].to_dict(),
            "m": prevalence_male[comorbidity].to_dict(),
        }
        for comorbidity in prevalence_female.columns
    }


def parse_prevalence_comorbidities_in_reference_population(
    comorbidity_prevalence_reference_population,
):
    """
    Expand each comorbidity's per-sex, age-binned prevalences into per-age
    arrays via ``parse_age_probabilities``.
    """
    return {
        comorbidity: {
            sex: parse_age_probabilities(prevalences[sex]) for sex in ("f", "m")
        }
        for comorbidity, prevalences in (
            comorbidity_prevalence_reference_population.items()
        )
    }


import cProfile
from june.mpi_setup import mpi_rank, mpi_comm


# a decorator for profiling
def profile(filename=None, comm=mpi_comm):
    """
    Decorator factory that profiles the wrapped function with cProfile.

    Parameters
    ----------
    filename:
        If None, the stats are printed to stdout; otherwise they are dumped
        to ``<filename>.<mpi_rank>`` (one file per MPI rank so parallel runs
        do not clobber each other).
    comm:
        MPI communicator; kept for interface compatibility but unused here.
    """
    from functools import wraps  # local import: keeps the module imports untouched

    def prof_decorator(f):
        @wraps(f)  # preserve the wrapped function's name and docstring
        def wrap_f(*args, **kwargs):
            pr = cProfile.Profile()
            pr.enable()
            result = f(*args, **kwargs)
            pr.disable()

            if filename is None:
                pr.print_stats()
            else:
                filename_r = filename + ".{}".format(mpi_rank)
                pr.dump_stats(filename_r)

            return result

        return wrap_f

    return prof_decorator


from typing import Union
import importlib
import datetime


def read_date(date: Union[str, datetime.datetime]) -> datetime.datetime:
    """
    Read date in two possible formats, either string or datetime.date, both
    are translated into datetime.datetime to be used by the simulator

    Parameters
    ----------
    date:
        date to translate into datetime.datetime

    Returns
    -------
        date in datetime format

    Raises
    ------
    TypeError
        If ``date`` is neither a string nor a date/datetime object.
    """
    if isinstance(date, str):
        return datetime.datetime.strptime(date, "%Y-%m-%d")
    if isinstance(date, datetime.datetime):
        # Already a datetime: return as-is. Checking datetime *before* date
        # matters -- datetime is a date subclass, and combining one with
        # midnight (as the previous version did) silently dropped its time.
        return date
    if isinstance(date, datetime.date):
        return datetime.datetime.combine(date, datetime.datetime.min.time())
    raise TypeError("date must be a string or a datetime.date object")


def str_to_class(classname, base_policy_modules=("june.policy",)):
    """
    Look up a class by name across a sequence of module paths.

    Parameters
    ----------
    classname:
        Name of the class (attribute) to find.
    base_policy_modules:
        Module paths to search, in order.

    Returns
    -------
    The first matching attribute found.

    Raises
    ------
    ValueError
        If no module in the list provides ``classname``.
    """
    for module_name in base_policy_modules:
        try:
            module = importlib.import_module(module_name)
            return getattr(module, classname)
        except (ImportError, AttributeError):
            # Module missing or class not defined there: try the next path.
            # (Previously only AttributeError was caught, so a single missing
            # module aborted the whole search with an ImportError.)
            continue
    raise ValueError(f"Cannot find policy {classname} in paths!")


from .parse_probabilities import (
    parse_age_probabilities,
    parse_prevalence_comorbidities_in_reference_population,
    read_comorbidity_csv,
    convert_comorbidities_prevalence_to_dict,
)
from .numba_random import random_choice_numba
from .readers import read_date, str_to_class


#!/usr/bin/env python
# Test entry point: importing june first ensures its bundled data/config
# paths are initialised before the suite runs.
import june  # for data

from test_june import run_all_tests

# Run the whole June test suite.
run_all_tests()


from glob import glob
from os import remove

import pytest


@pytest.fixture(autouse=True, scope="session")
def remove_log_files():
    """Session-wide autouse fixture: once the whole test session has run,
    delete any *.log* files left in the working directory."""
    # Yield first so the cleanup below happens at session teardown.
    yield
    for file in glob("*.log*"):
        remove(file)


from glob import glob
from pathlib import Path
import pytest

# Directory containing this file; the test suite is collected from here.
test_path = Path(__file__).parent


def run_all_tests():
    """Run every test under this directory, stopping at the first failure (-x)."""
    pytest.main(["-x", f"{test_path}"])


from june.groups.leisure.gym import Gyms
from june.groups.leisure.leisure import generate_leisure_for_world
from june.groups.leisure.grocery import Groceries, Grocery
from june.groups.leisure.pub import Pub, Pubs
from june.groups.leisure.cinema import Cinema, Cinemas
from june.groups.university import Universities, University
from june.groups.household import Household, Households
from june.groups.cemetery import Cemeteries
from june.groups.care_home import CareHome, CareHomes
from june.groups.school import School, Schools
from june.groups.company import Companies, Company
from june.groups.hospital import Hospital, Hospitals
import random

import numba as nb
import numpy as np
import pytest
import h5py
from pathlib import Path

from june.interaction import Interaction
from june import paths
from june.geography import Geography, Areas, SuperAreas, Regions, Cities, City
from june.geography.station import CityStation
from june.groups.travel import (
    ModeOfTransport,
    CityTransport,
    CityTransports,
    InterCityTransport,
    InterCityTransports,
)
from june.groups.travel import Travel
from june.demography import Person, Population
from june.epidemiology.epidemiology import Epidemiology
from june.epidemiology.infection import (
    Infection,
    Symptoms,
    TrajectoryMakers,
    InfectionSelector,
    InfectionSelectors,
)
from june.epidemiology.infection import transmission as trans
from june.simulator import Simulator
from june.world import generate_world_from_geography, World
import logging

# Constant-transmission config used by several selector fixtures below.
constant_config = (
    paths.configs_path
    / "defaults/epidemiology/infection/transmission/TransmissionConstant.yaml"
)
# Interaction config used by the `interaction` fixture.
interaction_config = paths.configs_path / "tests/interaction.yaml"


# disable logging for testing
logging.disable(logging.CRITICAL)


@pytest.fixture(autouse=True, name="test_results", scope="session")
def make_test_output():
    """Create (once per session) and return the ./test_results output directory."""
    output_dir = Path("./test_results")
    output_dir.mkdir(exist_ok=True)
    return output_dir


@pytest.fixture(autouse=True)
def set_random_seed(seed=999):
    """
    Sets global seeds for testing in numpy, random, and numbaized numpy.
    """

    @nb.njit(cache=True)
    def set_seed_numba(seed):
        # numba keeps its own RNG state per compilation target, so the
        # seeding must run inside a jitted function to affect jitted code.
        random.seed(seed)
        np.random.seed(seed)

    set_seed_numba(seed)
    # Seed the regular (non-jitted) generators as well.
    np.random.seed(seed)
    random.seed(seed)
    return


@pytest.fixture()
def data(pytestconfig):
    """Value of the --data command line option."""
    option_value = pytestconfig.getoption("data")
    return option_value


@pytest.fixture()
def configs(pytestconfig):
    """Value of the --configs command line option."""
    option_value = pytestconfig.getoption("configs")
    return option_value


@pytest.fixture(name="trajectories", scope="session")
def create_trajectories():
    return TrajectoryMakers.from_file()


@pytest.fixture(name="symptoms", scope="session")
def create_symptoms(symptoms_trajectories):
    return symptoms_trajectories


@pytest.fixture(name="health_index_generator", scope="session")
def make_hi():
    return lambda person, infection_id: [0.4, 0.5, 0.7, 0.74, 0.85, 0.90, 0.95]


@pytest.fixture(name="symptoms_trajectories", scope="session")
def create_symptoms_trajectories():
    return Symptoms(health_index=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])


@pytest.fixture(name="transmission", scope="session")
def create_transmission():
    return trans.TransmissionConstant(probability=0.3)


@pytest.fixture(name="infection", scope="session")
def create_infection(transmission, symptoms):
    return Infection(transmission, symptoms)


@pytest.fixture(name="infection_constant", scope="session")
def create_infection_constant(transmission, symptoms_constant):
    return Infection(transmission, symptoms_constant)


@pytest.fixture(name="interaction", scope="session")
def create_interaction(health_index_generator):
    interaction = Interaction.from_file(config_filename=interaction_config)
    interaction.selector = InfectionSelector(
        transmission_config_path=constant_config,
        health_index_generator=health_index_generator,
    )
    return interaction


@pytest.fixture(name="geography", scope="session")
def make_geography():
    geography = Geography.from_file(
        {"super_area": ["E02002512", "E02001697", "E02001731"]}
    )
    return geography


@pytest.fixture(name="world", scope="session")
def create_world(geography):
    geography.hospitals = Hospitals.for_geography(geography)
    geography.companies = Companies.for_geography(geography)
    geography.schools = Schools.for_geography(geography)
    geography.care_homes = CareHomes.for_geography(geography)
    geography.cemeteries = Cemeteries()
    geography.companies = Companies.for_geography(geography)
    world = generate_world_from_geography(geography, include_households=True)
    return world


@pytest.fixture(name="selector", scope="session")
def make_selector(health_index_generator):
    return InfectionSelector(
        transmission_config_path=constant_config,
        health_index_generator=health_index_generator,
    )


@pytest.fixture(name="selectors", scope="session")
def make_selectors(selector):
    return InfectionSelectors([selector])


@pytest.fixture(name="epidemiology", scope="session")
def make_epidemiology(selectors):
    return Epidemiology(infection_selectors=selectors)


# policy dummy world
@pytest.fixture(name="dummy_world")  # , scope="session")
def make_dummy_world():
    g = Geography.from_file(filter_key={"super_area": ["E02002559"]})
    super_area = g.super_areas.members[0]
    area = g.areas.members[0]
    area.households = []
    company = Company(super_area=super_area, n_workers_max=100, sector="S")
    school = School(
        coordinates=super_area.coordinates,
        n_pupils_max=100,
        age_min=4,
        age_max=10,
        sector="primary",
        area=area,
    )
    household = Household(type="family")
    household.area = super_area.areas[0]
    household2 = Household(type="family")
    worker2 = Person.from_attributes(age=40)
    worker2.area = super_area.areas[0]
    household2.area = super_area.areas[0]
    household2.add(worker2)
    area.households.append(household)
    area.households.append(household2)
    hospital = Hospital(
        n_beds=40, n_icu_beds=5, area=area, coordinates=super_area.coordinates
    )
    super_area.closest_hospitals = [hospital]
    worker = Person.from_attributes(age=40)
    worker.area = super_area.areas[0]
    household.add(worker, subgroup_type=household.SubgroupType.adults)
    worker.sector = "Q"
    company.add(worker)

    pupil = Person.from_attributes(age=6)
    pupil.area = super_area.areas[0]
    household.add(pupil, subgroup_type=household.SubgroupType.kids)
    school.add(pupil)

    student = Person.from_attributes(age=21)
    student.area = super_area.areas[0]
    household.add(student, subgroup_type=household.SubgroupType.adults)
    university = University(
        coordinates=super_area.coordinates, n_students_max=100, area=area
    )
    university.add(student)

    commuter = Person.from_attributes(sex="m", age=30)
    commuter.area = super_area.areas[0]
    commuter.work_super_area = super_area
    commuter.mode_of_transport = ModeOfTransport(description="surf", is_public=True)
    household.add(commuter)

    world = World()
    world.schools = Schools([school])
    world.hospitals = Hospitals([hospital])
    world.households = Households([household, household2])
    world.universities = Universities([])
    world.companies = Companies([company])
    world.universities = Universities([university])
    world.care_homes = CareHomes([CareHome(area=area)])
    world.people = Population([worker, pupil, student, commuter, worker2])
    world.areas = Areas([super_area.areas[0]])
    world.areas[0].people = world.people
    world.super_areas = SuperAreas([super_area])
    world.regions = Regions([super_area.region])
    cinema = Cinema(area=area)
    cinema.coordinates = super_area.coordinates
    cinema.area = area
    world.cinemas = Cinemas([cinema])
    pub = Pub(area=area)
    pub.coordinates = super_area.coordinates
    pub.area = area
    world.pubs = Pubs([pub])
    grocery = Grocery(area=area)
    grocery.coordinates = super_area.coordinates
    grocery.area = area
    world.groceries = Groceries([grocery])
    city = City(name="test", coordinates=[1, 2])
    world.cities = Cities([city])
    city.internal_commuter_ids.add(commuter.id)
    city.city_stations = [CityStation(super_area=world.super_areas[0], city=city)]
    world.stations = city.city_stations
    station = city.city_stations[0]
    super_area.city = city
    # world.super_areas[0].closest_inter_city_station_for_city[city.name] = station
    city_transports = CityTransports([CityTransport(station=station)])
    world.city_transports = city_transports
    inter_city_transports = InterCityTransports([InterCityTransport(station=station)])
    world.inter_city_transports = inter_city_transports
    station.city_transports = city_transports
    station.inter_city_transports = inter_city_transports
    # leisure
    leisure = generate_leisure_for_world(
        world=world,
        list_of_leisure_groups=[
            "pubs",
            "cinemas",
            "groceries",
            "household_visits",
            "care_home_visits",
        ],
        daytypes={
            "weekday": ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"],
            "weekend": ["Saturday", "Sunday"],
        },
    )
    leisure.distribute_social_venues_to_areas(
        areas=world.areas, super_areas=world.super_areas
    )
    world.cemeteries = Cemeteries()
    return world


@pytest.fixture(name="policy_simulator")
def make_policy_simulator(dummy_world, interaction, epidemiology):
    config_name = paths.configs_path / "tests/test_simulator_simple.yaml"
    travel = Travel()
    sim = Simulator.from_file(
        dummy_world,
        interaction,
        epidemiology=epidemiology,
        config_filename=config_name,
        record=None,
        travel=travel,
        policies=None,
        leisure=None,
    )
    return sim


@pytest.fixture(name="setup_policy_world")
def setup_world(dummy_world, policy_simulator):
    world = dummy_world
    worker = world.people[0]
    pupil = world.people[1]
    student = world.people[2]
    policy_simulator.clear_world()
    return world, pupil, student, worker, policy_simulator


@pytest.fixture(name="full_world_geography", scope="session")
def make_full_world_geography():
    geography = Geography.from_file({"super_area": ["E02001731", "E02002566"]})
    return geography


@pytest.fixture(name="full_world", scope="session")
def create_full_world(full_world_geography, test_results):
    # clean file
    with h5py.File(test_results / "test.hdf5", "w"):
        pass
    geography = full_world_geography
    geography.hospitals = Hospitals.for_geography(geography)
    geography.schools = Schools.for_geography(geography)
    geography.companies = Companies.for_geography(geography)
    geography.care_homes = CareHomes.for_geography(geography)
    geography.universities = Universities.for_geography(geography)
    world = generate_world_from_geography(geography=geography, include_households=True)
    world.pubs = Pubs.for_geography(geography)
    world.cinemas = Cinemas.for_geography(geography)
    world.groceries = Groceries.for_geography(geography)
    world.gyms = Gyms.for_geography(geography)
    leisure = generate_leisure_for_world(
        [
            "pubs",
            "cinemas",
            "groceries",
            "gyms",
            "household_visits",
            "care_home_visits",
        ],
        world,
        daytypes={
            "weekday": ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"],
            "weekend": ["Saturday", "Sunday"],
        },
    )
    leisure.distribute_social_venues_to_areas(
        areas=world.areas, super_areas=world.super_areas
    )
    travel = Travel()
    travel.initialise_commute(world)
    return world


@pytest.fixture(name="domains_world", scope="session")
def create_domains_world():
    geography = Geography.from_file(
        {"super_area": ["E02001731", "E02001732", "E02002566", "E02002567"]}
    )
    geography.hospitals = Hospitals.for_geography(geography)
    geography.schools = Schools.for_geography(geography)
    geography.companies = Companies.for_geography(geography)
    geography.care_homes = CareHomes.for_geography(geography)
    geography.universities = Universities.for_geography(geography)
    world = generate_world_from_geography(geography=geography, include_households=True)
    world.pubs = Pubs.for_geography(geography)
    world.cinemas = Cinemas.for_geography(geography)
    world.groceries = Groceries.for_geography(geography)
    world.gyms = Gyms.for_geography(geography)
    leisure = generate_leisure_for_world(
        [
            "pubs",
            "cinemas",
            "groceries",
            "gyms",
            "household_visits",
            "care_home_visits",
        ],
        world,
        daytypes={
            "weekday": ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"],
            "weekend": ["Saturday", "Sunday"],
        },
    )
    leisure.distribute_social_venues_to_areas(
        areas=world.areas, super_areas=world.super_areas
    )
    travel = Travel()
    travel.initialise_commute(world)
    return world


import pytest

from june.domains import Domain
from june.groups import Subgroup, ExternalSubgroup, ExternalGroup
from june.geography.station import InterCityStation

# Supergroups checked one by one by the generic decomposition test below;
# hospitals, cities/stations and households have dedicated tests instead.
available_groups = [
    "companies",
    "schools",
    "pubs",
    "groceries",
    "cinemas",
    "universities",
]


@pytest.fixture(name="domains", scope="module")
def decomp(domains_world, test_results):
    world = domains_world
    world.to_hdf5(test_results / "test_domains_world.hdf5")
    domains = []
    super_areas_to_domain_dict = {
        domains_world.super_areas[0].id: 0,
        domains_world.super_areas[1].id: 0,
        domains_world.super_areas[2].id: 1,
        domains_world.super_areas[3].id: 1,
    }
    domain1 = Domain.from_hdf5(
        domain_id=0,
        super_areas_to_domain_dict=super_areas_to_domain_dict,
        hdf5_file_path=test_results / "test_domains_world.hdf5",
    )
    domain2 = Domain.from_hdf5(
        domain_id=1,
        super_areas_to_domain_dict=super_areas_to_domain_dict,
        hdf5_file_path=test_results / "test_domains_world.hdf5",
    )
    domains = [domain1, domain2]
    # append everyone everywhere for checks
    for person in world.people:
        for subgroup in person.subgroups.iter():
            if subgroup is not None:
                subgroup.append(person)
    return domains


class TestDomainDecomposition:
    """Checks that splitting the four-super-area world into two domains
    preserves super areas, people, groups, hospitals, cities/stations and
    household visit links, with cross-domain objects marked external."""

    def test__super_area_decomposition(self, domains_world, domains):
        """Each super area appears in exactly one domain, none are lost."""
        super_areas = [super_area.name for super_area in domains_world.super_areas]
        super_areas_domains = [
            super_area.name for domain in domains for super_area in domain.super_areas
        ]
        assert len(super_areas) == len(super_areas_domains)
        assert set(super_areas) == set(super_areas_domains)

    def test__people_decomposition(self, domains_world, domains):
        """Each person appears in exactly one domain, none are lost."""
        all_people = [person.id for person in domains_world.people]
        all_domains_people = [
            person.id for domain in domains for person in domain.people
        ]
        assert len(all_people) == len(all_domains_people)
        assert set(all_people) == set(all_domains_people)

    def test__all_groups_decomposition(self, domains_world, domains):
        """A group's id is a member of a domain iff the group's super area
        belongs to that domain."""
        for supergroup_name in available_groups:
            world_supergroup = getattr(domains_world, supergroup_name)
            for domain in domains:
                domain_super_area_names = [
                    super_area.name for super_area in domain.super_areas
                ]
                domain_supergroup = getattr(domain, supergroup_name)
                for group in world_supergroup:
                    if group.super_area.name not in domain_super_area_names:
                        assert group.id not in domain_supergroup.member_ids
                    else:
                        assert group.id in domain_supergroup.member_ids

    def test__information_about_away_groups(self, domains_world, domains):
        """Subgroups whose group lives in the other domain are represented by
        External* proxies carrying the right ids/specs; local ones are real."""
        for domain_id, domain in enumerate(domains):
            domain_super_areas = [sa.name for sa in domain.super_areas]
            for person_domain in domain.people:
                person_world = domains_world.people.get_from_id(person_domain.id)
                # work super area
                if person_world.work_super_area is not None:
                    assert (
                        person_domain.work_super_area.coordinates[0]
                        == person_world.work_super_area.coordinates[0]
                    )
                    assert (
                        person_domain.work_super_area.coordinates[1]
                        == person_world.work_super_area.coordinates[1]
                    )
                    assert (
                        person_domain.work_super_area.id
                        == person_world.work_super_area.id
                    )
                for subgroup, subgroup_domain in zip(
                    person_world.subgroups.iter(), person_domain.subgroups.iter()
                ):
                    if subgroup is None:
                        assert subgroup_domain is None
                    else:
                        if subgroup.group.super_area.name not in domain_super_areas:
                            # Group lives in the other domain: expect a proxy.
                            assert isinstance(subgroup_domain, ExternalSubgroup)
                            assert isinstance(subgroup_domain.group, ExternalGroup)
                            # With exactly two domains, "the other" is 1 - domain_id.
                            if domain_id == 0:
                                assert subgroup_domain.domain_id == 1
                            else:
                                assert subgroup_domain.domain_id == 0
                            assert subgroup_domain.group.spec == subgroup.group.spec
                            assert subgroup_domain.group_id == subgroup.group.id
                            assert (
                                subgroup_domain.subgroup_type == subgroup.subgroup_type
                            )
                        else:
                            assert isinstance(subgroup_domain, Subgroup)
                            assert subgroup_domain.group.id == subgroup.group.id
                            assert subgroup_domain.group.spec == subgroup.group.spec
                            assert (
                                subgroup_domain.subgroup_type == subgroup.subgroup_type
                            )
                            assert (
                                subgroup_domain.group.super_area.name
                                == subgroup.group.super_area.name
                            )

    def test__hospitals(self, domains_world, domains):
        """Hospitals are replicated into every domain; local ones keep their
        data, away ones are marked external."""
        assert len(domains_world.hospitals) > 0
        for hospital in domains_world.hospitals:
            for domain in domains:
                assert len(domain.hospitals) == len(domains_world.hospitals)
                domain_super_area_ids = [super_area.id for super_area in domain]
                if hospital.super_area.id in domain_super_area_ids:
                    for hospital_domain in domain.hospitals:
                        if hospital.id == hospital_domain.id:
                            assert hospital_domain.external is False
                            assert (
                                hospital.super_area.id == hospital_domain.super_area.id
                            )
                            assert hospital.trust_code == hospital_domain.trust_code
                            assert hospital.region_name == hospital_domain.region_name
                        else:
                            # NOTE(review): this else re-iterates domain.hospitals
                            # shadowing `hospital_domain`, and compares against
                            # `domain.hospitals.id` -- the supergroup, not a
                            # hospital -- so the inner assertions look
                            # unreachable/unintended; confirm against the
                            # Domain/Hospitals API.
                            for hospital_domain in domain.hospitals:
                                if hospital_domain.id == domain.hospitals.id:
                                    assert hospital_domain.external
                                    assert (
                                        hospital.region_name
                                        == hospital_domain.region_name
                                    )

    def test__stations_and_cities(self, domains_world, domains):
        """Cities, stations and their transports are replicated into every
        domain, with external flags matching which domain owns them."""
        assert len(domains_world.cities) > 0
        assert len(domains_world.stations) > 0
        assert len(domains_world.city_transports) > 0
        assert len(domains_world.inter_city_transports) > 0
        for city in domains_world.cities:
            for domain in domains:
                assert len(domain.cities) == len(domains_world.cities)
                domain_super_area_ids = [super_area.id for super_area in domain]
                if city.super_area.id in domain_super_area_ids:
                    for city_domain in domain.cities:
                        if city.id == city_domain.id:
                            assert city_domain.external is False
                            assert city.super_area.id == city_domain.super_area.id
                            assert city.name == city_domain.name
                            assert (
                                city.internal_commuter_ids
                                == city_domain.internal_commuter_ids
                            )
                            break
                else:
                    for city_domain in domain.cities:
                        if city.id == city_domain.id:
                            assert city_domain.external
                            assert (
                                city.internal_commuter_ids
                                == city_domain.internal_commuter_ids
                            )
                            break

        for station in domains_world.stations:
            for domain in domains:
                assert len(domain.stations) == len(domains_world.stations)
                domain_super_area_ids = [super_area.id for super_area in domain]
                if station.super_area.id in domain_super_area_ids:
                    # Closest-station links for the local domain must be internal.
                    for super_area in domain.super_areas:
                        if (
                            super_area.closest_inter_city_station_for_city[
                                station.city
                            ].id
                            == station.id
                        ):
                            assert (
                                super_area.closest_inter_city_station_for_city[
                                    station.city
                                ].external
                                is False
                            )
                    for station_domain in domain.stations:
                        if station.id == station_domain.id:
                            assert station_domain.external is False
                            assert station.super_area.id == station_domain.super_area.id
                            if isinstance(station, InterCityStation):
                                assert len(station.inter_city_transports) == len(
                                    station_domain.inter_city_transports
                                )
                                for ct1, ct2 in zip(
                                    station.inter_city_transports,
                                    station_domain.inter_city_transports,
                                ):
                                    assert ct2.external is False
                                    assert ct1.id == ct2.id
                            else:
                                assert len(station.city_transports) == len(
                                    station_domain.city_transports
                                )
                                for ct1, ct2 in zip(
                                    station.city_transports,
                                    station_domain.city_transports,
                                ):
                                    assert ct2.external is False
                                    assert ct1.id == ct2.id
                            assert station.commuter_ids == station_domain.commuter_ids
                            break
                else:
                    # Station belongs to the other domain: links are external.
                    for super_area in domain.super_areas:
                        if (
                            super_area.closest_inter_city_station_for_city[
                                station.city
                            ].id
                            == station.id
                        ):
                            assert (
                                super_area.closest_inter_city_station_for_city[
                                    station.city
                                ].external
                                is True
                            )
                    for station_domain in domain.stations:
                        if station.id == station_domain.id:
                            assert station_domain.external
                            if isinstance(station, InterCityStation):
                                for ct1, ct2 in zip(
                                    station.inter_city_transports,
                                    station_domain.inter_city_transports,
                                ):
                                    assert ct2.external
                                    assert ct1.id == ct2.id
                            else:
                                for ct1, ct2 in zip(
                                    station.city_transports,
                                    station_domain.city_transports,
                                ):
                                    assert ct2.external
                                    assert ct1.id == ct2.id
                            assert station.commuter_ids == station_domain.commuter_ids
                            break

    def test__residences_to_visit(self, domains_world, domains):
        """Household visit links survive decomposition; residences in the
        other domain are marked external."""
        assert len(domains_world.households) > 0
        for household in domains_world.households:
            for domain in domains:
                domain_super_area_ids = [super_area.id for super_area in domain]
                if household.super_area.id in domain_super_area_ids:
                    household_domain = domain.households.get_from_id(household.id)
                    assert household.id == household_domain.id
                    assert len(household.residences_to_visit) == len(
                        household_domain.residences_to_visit
                    )
                    assert len(household.residences_to_visit) == len(
                        household_domain.residences_to_visit
                    )
                    for rv1_spec in household.residences_to_visit:
                        for r1, r2 in zip(
                            household.residences_to_visit[rv1_spec],
                            household_domain.residences_to_visit[rv1_spec],
                        ):
                            assert r1.id == r2.id
                            assert r1.spec == r2.spec
                            if r1.super_area.id not in domain_super_area_ids:
                                assert r2.external
                            else:
                                assert not r2.external


from june.simulator import Simulator
from june.interaction import Interaction
from june.epidemiology.infection import InfectionSelectors, Immunity
from june.epidemiology.infection_seed import InfectionSeed
from june.epidemiology.epidemiology import Epidemiology
from june.groups.travel import Travel
from june.policy import Policies
from june.records import Record
from june.groups.leisure import generate_leisure_for_config
from june import paths


# Paths to the YAML configuration files driving the full simulator run below.
test_config = paths.configs_path / "tests/test_simulator.yaml"
interaction_config = paths.configs_path / "tests/interaction.yaml"
selector_config = paths.configs_path / "defaults/infection/InfectionConstant.yaml"


def test__full_run(dummy_world, selector, test_results):
    """Smoke test: seed infections in a small world and run the simulator end to end."""
    world = dummy_world
    # Reset everyone to a healthy, susceptible state before running.
    for person in world.people:
        person.infection = None
        person.immunity = Immunity()
        person.dead = False

    # Assemble the simulation components.
    leisure = generate_leisure_for_config(
        world=dummy_world, config_filename=test_config
    )
    interaction = Interaction.from_file(config_filename=interaction_config)
    epidemiology = Epidemiology(
        infection_selectors=InfectionSelectors([selector])
    )
    sim = Simulator.from_file(
        world=world,
        interaction=interaction,
        epidemiology=epidemiology,
        config_filename=test_config,
        leisure=leisure,
        travel=Travel(),
        policies=Policies.from_file(),
        record=Record(record_path=test_results / "results"),
    )

    # Seed 1% of the population, then run the full simulation loop.
    seed = InfectionSeed.from_uniform_cases(
        world=sim.world,
        infection_selector=selector,
        cases_per_capita=0.01,
        date=sim.timer.date_str,
        seed_past_infections=True,
    )
    seed.unleash_virus_per_day(date=sim.timer.date, time=0)
    sim.run()

    # Clear regional venue-closure policy state so later tests start fresh.
    for region in world.regions:
        region.policy["local_closed_venues"] = set()
        region.policy["global_closed_venues"] = set()


import h5py
import pytest
import numpy as np

from june.epidemiology.infection.transmission_xnexp import TransmissionXNExp
from june.epidemiology.infection.transmission import TransmissionGamma
from june.epidemiology.infection.symptoms import Symptoms, SymptomTag
from june.epidemiology.infection import Immunity, B117, Covid19
from june.hdf5_savers.infection_savers import (
    save_transmissions_to_hdf5,
    load_transmissions_from_hdf5,
    save_symptoms_to_hdf5,
    load_symptoms_from_hdf5,
    save_infections_to_hdf5,
    load_infections_from_hdf5,
    save_immunities_to_hdf5,
    load_immunities_from_hdf5,
)


@pytest.fixture(name="xnexp_transmissions", scope="module")
def setup_xnexp_trans():
    """Build two XNExp transmission profiles with distinct, easy-to-check parameters."""
    parameter_sets = [
        dict(
            max_probability=1,
            time_first_infectious=1,
            norm_time=2,
            n=3,
            alpha=4,
            max_symptoms="asymptomatic",
            asymptomatic_infectious_factor=5,
            mild_infectious_factor=6,
        ),
        dict(
            max_probability=7,
            time_first_infectious=8,
            norm_time=9,
            n=10,
            alpha=11,
            max_symptoms="mild",
            asymptomatic_infectious_factor=12,
            mild_infectious_factor=13,
        ),
    ]
    return [TransmissionXNExp(**params) for params in parameter_sets]


@pytest.fixture(name="gamma_transmissions", scope="module")
def setup_gamma_trans():
    """Build two gamma transmission profiles with distinct parameters."""
    parameter_sets = [
        dict(
            max_infectiousness=1.0,
            shape=2.0,
            rate=3.0,
            shift=-2.0,
            max_symptoms="mild",
            asymptomatic_infectious_factor=0.5,
            mild_infectious_factor=0.7,
        ),
        dict(
            max_infectiousness=1.1,
            shape=2.1,
            rate=3.1,
            shift=-2.1,
            max_symptoms="asymptomatic",
            asymptomatic_infectious_factor=0.2,
            mild_infectious_factor=0.2,
        ),
    ]
    return [TransmissionGamma(**params) for params in parameter_sets]


@pytest.fixture(name="symptoms_list", scope="module")
def setup_symptoms():
    """Build two Symptoms objects sharing the same 5-point health index."""
    health_index = np.linspace(0, 1, 5)
    return [Symptoms(health_index=health_index) for _ in range(2)]


@pytest.fixture(name="infections", scope="module")
def setup_infections(xnexp_transmissions, symptoms_list):
    """Build infections alternating variants: B117 at even indices, Covid19 at odd."""
    variants = (B117, Covid19)  # index parity selects the variant class
    return [
        variants[i % 2](transmission=trans, symptoms=symptoms, start_time=2)
        for i, (symptoms, trans) in enumerate(zip(symptoms_list, xnexp_transmissions))
    ]


class TestTransmissionSavers:
    """Round-trip transmission objects through the HDF5 savers."""

    def test__save_xnexp(self, xnexp_transmissions, test_results):
        hdf5_path = test_results / "checkpoint_tests.hdf5"
        # Truncate the file so the save starts from a clean slate.
        with h5py.File(hdf5_path, "w"):
            pass
        save_transmissions_to_hdf5(hdf5_path, xnexp_transmissions, chunk_size=1)
        recovered = load_transmissions_from_hdf5(hdf5_path, chunk_size=1)
        assert len(recovered) == len(xnexp_transmissions)
        attributes = (
            "time_first_infectious",
            "norm_time",
            "n",
            "norm",
            "alpha",
            "probability",
        )
        for original, loaded in zip(xnexp_transmissions, recovered):
            for name in attributes:
                assert getattr(original, name) == getattr(loaded, name)

    def test__save_gamma(self, gamma_transmissions, test_results):
        hdf5_path = test_results / "checkpoint_tests.hdf5"
        # Truncate the file so the save starts from a clean slate.
        with h5py.File(hdf5_path, "w"):
            pass
        save_transmissions_to_hdf5(hdf5_path, gamma_transmissions, chunk_size=1)
        recovered = load_transmissions_from_hdf5(hdf5_path, chunk_size=1)
        assert len(recovered) == len(gamma_transmissions)
        for original, loaded in zip(gamma_transmissions, recovered):
            for name in ("shape", "shift", "scale", "norm", "probability"):
                assert getattr(original, name) == getattr(loaded, name)


class TestSymptomSavers:
    """Round-trip Symptoms objects through the HDF5 savers."""

    def test__save_symptoms(self, symptoms_list, test_results):
        hdf5_path = test_results / "checkpoint_tests.hdf5"
        # Truncate the file so the save starts from a clean slate.
        with h5py.File(hdf5_path, "w"):
            pass
        save_symptoms_to_hdf5(hdf5_path, symptoms_list, chunk_size=1)
        recovered = load_symptoms_from_hdf5(hdf5_path, chunk_size=1)
        assert len(recovered) == len(symptoms_list)
        scalar_attributes = (
            "max_tag",
            "tag",
            "max_severity",
            "stage",
            "time_of_symptoms_onset",
        )
        for original, loaded in zip(symptoms_list, recovered):
            for name in scalar_attributes:
                assert getattr(original, name) == getattr(loaded, name)
            assert len(original.trajectory) == len(loaded.trajectory)
            # Each trajectory entry is a (time, SymptomTag) pair.
            for stage1, stage2 in zip(original.trajectory, loaded.trajectory):
                assert isinstance(stage2[1], SymptomTag)
                assert stage1[0] == stage2[0]
                assert stage1[1] == stage2[1]


class TestInfectionSavers:
    """Round-trip full infection objects through the HDF5 savers."""

    def test__save_infection(self, infections, test_results):
        hdf5_path = test_results / "checkpoint_tests.hdf5"
        # Truncate the file so the save starts from a clean slate.
        with h5py.File(hdf5_path, "w"):
            pass
        save_infections_to_hdf5(hdf5_path, infections, chunk_size=1)
        recovered = load_infections_from_hdf5(hdf5_path, chunk_size=1)
        assert len(recovered) == len(infections)
        symptom_attributes = (
            "max_tag",
            "tag",
            "max_severity",
            "stage",
            "time_of_symptoms_onset",
        )
        transmission_attributes = (
            "time_first_infectious",
            "norm_time",
            "n",
            "norm",
            "alpha",
            "probability",
        )
        for original, loaded in zip(infections, recovered):
            # The concrete infection class (variant) must be preserved.
            assert type(original).__name__ == type(loaded).__name__
            assert original.start_time == loaded.start_time
            symptoms1, symptoms2 = original.symptoms, loaded.symptoms
            for name in symptom_attributes:
                assert getattr(symptoms1, name) == getattr(symptoms2, name)
            assert len(symptoms1.trajectory) == len(symptoms2.trajectory)
            # Each trajectory entry is a (time, SymptomTag) pair.
            for stage1, stage2 in zip(symptoms1.trajectory, symptoms2.trajectory):
                assert isinstance(stage2[1], SymptomTag)
                assert stage1[0] == stage2[0]
                assert stage1[1] == stage2[1]
            for name in transmission_attributes:
                assert getattr(original.transmission, name) == getattr(
                    loaded.transmission, name
                )


class TestImmunitySavers:
    """Round-trip Immunity objects through the HDF5 savers."""

    def test__save_immunities(self, test_results):
        hdf5_path = test_results / "checkpoint_tests.hdf5"
        # Truncate the file so the save starts from a clean slate.
        with h5py.File(hdf5_path, "w"):
            pass
        # 100 immunities, each with a one-entry susceptibility dict {i: i/10}.
        immunities = [Immunity({i: i / 10}) for i in range(100)]
        save_immunities_to_hdf5(hdf5_path, immunities)
        recovered = load_immunities_from_hdf5(hdf5_path, chunk_size=2)
        assert len(immunities) == len(recovered)
        for original, loaded in zip(immunities, recovered):
            assert original.susceptibility_dict == loaded.susceptibility_dict


import numpy as np
import h5py
import pytest
from june.geography import Geography
from june.geography.station import CityStation, InterCityStation
from june.groups.travel import CityTransport, InterCityTransport
from june.hdf5_savers import (
    save_population_to_hdf5,
    save_geography_to_hdf5,
    save_schools_to_hdf5,
    save_hospitals_to_hdf5,
    save_care_homes_to_hdf5,
    save_households_to_hdf5,
    save_companies_to_hdf5,
    save_cities_to_hdf5,
    save_stations_to_hdf5,
    save_universities_to_hdf5,
    save_social_venues_to_hdf5,
    generate_world_from_hdf5,
    save_data_for_domain_decomposition,
    load_data_for_domain_decomposition,
)
from june.hdf5_savers import (
    load_geography_from_hdf5,
    load_care_homes_from_hdf5,
    load_companies_from_hdf5,
    load_households_from_hdf5,
    load_population_from_hdf5,
    load_schools_from_hdf5,
    load_hospitals_from_hdf5,
    load_cities_from_hdf5,
    load_stations_from_hdf5,
    load_universities_from_hdf5,
    load_social_venues_from_hdf5,
)

from pytest import fixture


@pytest.fixture(autouse=True)
def remove_hdf5(test_results):
    """Truncate test.hdf5 before every test in this module."""
    h5py.File(test_results / "test.hdf5", "w").close()


class TestSavePeople:
    """Round-trip the population through the HDF5 savers."""

    def test__save_population(self, full_world, test_results):
        population = full_world.people
        assert len(population) > 0
        save_population_to_hdf5(population, test_results / "test.hdf5", chunk_size=500)
        pop_recovered = load_population_from_hdf5(
            test_results / "test.hdf5", chunk_size=600
        )
        scalar_attributes = (
            "id",
            "age",
            "sex",
            "ethnicity",
            "sector",
            "sub_sector",
            "lockdown_status",
        )
        for original, loaded in zip(population, pop_recovered):
            # None must stay None; anything else must match exactly.
            for name in scalar_attributes:
                value = getattr(original, name)
                if value is None:
                    assert getattr(loaded, name) is None
                else:
                    assert value == getattr(loaded, name)
            mot1, mot2 = original.mode_of_transport, loaded.mode_of_transport
            assert mot1.description == mot2.description
            assert mot1.is_public == mot2.is_public


class TestSaveHouses:
    """Round-trip households through the HDF5 savers."""

    def test__save_households(self, full_world, test_results):
        households = full_world.households
        assert len(households) > 0
        save_households_to_hdf5(households, test_results / "test.hdf5", chunk_size=500)
        recovered = load_households_from_hdf5(
            test_results / "test.hdf5", chunk_size=600
        )
        for original, loaded in zip(households, recovered):
            # None must stay None; anything else must match exactly.
            for name in ("id", "max_size", "type", "composition_type"):
                value = getattr(original, name)
                if value is None:
                    assert getattr(loaded, name) is None
                else:
                    assert value == getattr(loaded, name)


class TestSaveCompanies:
    """Round-trip companies through the HDF5 savers."""

    def test__save_companies(self, full_world, test_results):
        companies = full_world.companies
        assert len(companies) > 0
        save_companies_to_hdf5(companies, test_results / "test.hdf5", chunk_size=500)
        recovered = load_companies_from_hdf5(
            test_results / "test.hdf5", chunk_size=600
        )
        for original, loaded in zip(companies, recovered):
            # None must stay None; anything else must match exactly.
            for name in ("id", "n_workers_max", "sector"):
                value = getattr(original, name)
                if value is None:
                    assert getattr(loaded, name) is None
                else:
                    assert value == getattr(loaded, name)


class TestSaveHospitals:
    """Round-trip hospitals through the HDF5 savers."""

    def test__save_hospitals(self, full_world, test_results):
        hospitals = full_world.hospitals
        assert len(hospitals) > 0
        save_hospitals_to_hdf5(hospitals, test_results / "test.hdf5", chunk_size=500)
        recovered = load_hospitals_from_hdf5(
            test_results / "test.hdf5", chunk_size=600
        )
        for original, loaded in zip(hospitals, recovered):
            # None must stay None; anything else must match exactly.
            for name in ("id", "n_beds", "n_icu_beds"):
                value = getattr(original, name)
                if value is None:
                    assert getattr(loaded, name) is None
                else:
                    assert value == getattr(loaded, name)
            assert original.coordinates[0] == loaded.coordinates[0]
            assert original.coordinates[1] == loaded.coordinates[1]
            assert original.trust_code == loaded.trust_code


class TestSaveSchools:
    """Round-trip schools through the HDF5 savers."""

    def test__save_schools(self, full_world, test_results):
        schools = full_world.schools
        assert len(schools) > 0
        save_schools_to_hdf5(schools, test_results / "test.hdf5", chunk_size=500)
        recovered = load_schools_from_hdf5(
            test_results / "test.hdf5", chunk_size=600
        )
        attributes = (
            "id",
            "n_pupils_max",
            "age_min",
            "age_max",
            "sector",
            "n_classrooms",
            "years",
        )
        for original, loaded in zip(schools, recovered):
            # None must stay None; anything else must match exactly.
            for name in attributes:
                value = getattr(original, name)
                if value is None:
                    assert getattr(loaded, name) is None
                else:
                    assert value == getattr(loaded, name)
            assert original.coordinates[0] == loaded.coordinates[0]
            assert original.coordinates[1] == loaded.coordinates[1]


class TestSaveCarehomes:
    """Round-trip care homes through the HDF5 savers."""

    def test__save_carehomes(self, full_world, test_results):
        carehomes = full_world.care_homes
        assert len(carehomes) > 0
        save_care_homes_to_hdf5(carehomes, test_results / "test.hdf5", chunk_size=500)
        recovered = load_care_homes_from_hdf5(
            test_results / "test.hdf5", chunk_size=600
        )
        for original, loaded in zip(carehomes, recovered):
            # None must stay None; anything else must match exactly.
            for name in ("id", "n_residents"):
                value = getattr(original, name)
                if value is None:
                    assert getattr(loaded, name) is None
                else:
                    assert value == getattr(loaded, name)


class TestSaveGeography:
    """Round-trip the geography hierarchy through the HDF5 savers."""

    def test__save_geography(self, full_world, test_results):
        areas = full_world.areas
        super_areas = full_world.super_areas
        regions = full_world.regions
        assert len(areas) > 0
        assert len(super_areas) > 0
        assert len(regions) > 0

        geography = Geography(areas, super_areas, regions)
        save_geography_to_hdf5(geography, test_results / "test.hdf5")
        recovered = load_geography_from_hdf5(test_results / "test.hdf5")

        def check_id_and_name(original, loaded):
            # None must stay None; anything else must match exactly.
            for name in ("id", "name"):
                value = getattr(original, name)
                if value is None:
                    assert getattr(loaded, name) is None
                else:
                    assert value == getattr(loaded, name)

        for area, area2 in zip(areas, recovered.areas):
            check_id_and_name(area, area2)
            assert area.coordinates[0] == area2.coordinates[0]
            assert area.coordinates[1] == area2.coordinates[1]

        for super_area, super_area2 in zip(super_areas, recovered.super_areas):
            check_id_and_name(super_area, super_area2)
            assert super_area.coordinates[0] == super_area2.coordinates[0]
            assert super_area.coordinates[1] == super_area2.coordinates[1]

        for region, region2 in zip(regions, recovered.regions):
            check_id_and_name(region, region2)


class TestSaveTravel:
    """Round-trip cities and stations through the HDF5 savers."""

    def test__save_cities(self, full_world, test_results):
        cities = full_world.cities
        assert len(cities) > 0
        save_cities_to_hdf5(cities, test_results / "test.hdf5")
        recovered = load_cities_from_hdf5(test_results / "test.hdf5")
        assert len(cities) == len(recovered)
        for city, city2 in zip(cities, recovered):
            assert city.name == city2.name
            for sa1, sa2 in zip(city.super_areas, city2.super_areas):
                assert sa1 == sa2
            assert city.coordinates[0] == city2.coordinates[0]
            assert city.coordinates[1] == city2.coordinates[1]

    def test__save_stations(self, full_world, test_results):
        stations = full_world.stations
        inter_city_transports = full_world.inter_city_transports
        assert len(stations) > 0
        save_stations_to_hdf5(stations, test_results / "test.hdf5")
        (
            stations2,
            inter_city_transports2,
            _city_transports2,
        ) = load_stations_from_hdf5(test_results / "test.hdf5")
        assert len(stations) == len(stations2)
        assert len(inter_city_transports) == len(inter_city_transports2)
        for station, station2 in zip(stations, stations2):
            assert station.id == station2.id
            assert station.city == station2.city
            if isinstance(station, CityStation):
                # City stations carry intra-city transports.
                assert isinstance(station2, CityStation)
                assert len(station.city_transports) == len(station2.city_transports)
                for ct1, ct2 in zip(
                    station.city_transports, station2.city_transports
                ):
                    assert isinstance(ct1, CityTransport)
                    assert isinstance(ct2, CityTransport)
                    assert ct1.id == ct2.id
            else:
                # Anything that isn't a city station must be inter-city.
                assert isinstance(station, InterCityStation)
                assert isinstance(station2, InterCityStation)
                assert len(station.inter_city_transports) == len(
                    station2.inter_city_transports
                )
                for ict1, ict2 in zip(
                    station.inter_city_transports, station2.inter_city_transports
                ):
                    assert isinstance(ict1, InterCityTransport)
                    assert isinstance(ict2, InterCityTransport)
                    assert ict1.id == ict2.id


class TestSaveUniversities:
    """Round-trip universities through the HDF5 savers."""

    def test__save_universities(self, full_world, test_results):
        universities = full_world.universities
        assert len(universities) > 0
        save_universities_to_hdf5(universities, test_results / "test.hdf5")
        recovered = load_universities_from_hdf5(test_results / "test.hdf5")
        for original, loaded in zip(universities, recovered):
            # None must stay None; anything else must match exactly.
            for name in ("id", "n_students_max", "n_years"):
                value = getattr(original, name)
                if value is None:
                    assert getattr(loaded, name) is None
                else:
                    assert value == getattr(loaded, name)
            assert original.coordinates[0] == loaded.coordinates[0]
            assert original.coordinates[1] == loaded.coordinates[1]


class TestSaveLeisure:
    """Round-trip social venues through the HDF5 savers."""

    def test__save_social_venues(self, full_world, test_results):
        venue_groups = [
            full_world.pubs,
            full_world.groceries,
            full_world.cinemas,
            full_world.gyms,
        ]
        save_social_venues_to_hdf5(
            social_venues_list=venue_groups,
            file_path=test_results / "test.hdf5",
        )
        recovered = load_social_venues_from_hdf5(test_results / "test.hdf5")
        # The loader returns {spec: venues}; compare each spec against the world.
        for spec, venues in recovered.items():
            for v1, v2 in zip(getattr(full_world, spec), venues):
                assert v1.coordinates[0] == v2.coordinates[0]
                assert v1.coordinates[1] == v2.coordinates[1]
                assert v1.id == v2.id
                assert sv1.id == sv2.id


class TestSaveWorld:
    @fixture(name="full_world_loaded", scope="module")
    def reaload_world(self, full_world, test_results):
        """Serialize the full world to HDF5 and load it back for comparison."""
        path = test_results / "test.hdf5"
        full_world.to_hdf5(path)
        return generate_world_from_hdf5(path, chunk_size=500)

    def test__save_geography(self, full_world, full_world_loaded):
        """The area/super-area/region hierarchy is preserved on world reload."""
        assert len(full_world.areas) == len(full_world_loaded.areas)
        for area1, area2 in zip(full_world.areas, full_world_loaded.areas):
            assert area1.id == area2.id
            assert area1.socioeconomic_index == area2.socioeconomic_index
            assert area1.super_area.id == area2.super_area.id
            assert area1.super_area.name == area2.super_area.name
            assert area1.name == area2.name

        assert len(full_world.super_areas) == len(full_world_loaded.super_areas)
        for sa1, sa2 in zip(full_world.super_areas, full_world_loaded.super_areas):
            assert sa1.id == sa2.id
            assert sa1.name == sa2.name
            assert len(sa1.areas) == len(sa2.areas)
            ids1 = [area.id for area in sa1.areas]
            ids2 = [area.id for area in sa2.areas]
            assert set(ids1) == set(ids2)
            # Child areas may be stored in different orders; compare sorted by id.
            for area1, area2 in zip(
                (sa1.areas[i] for i in np.argsort(ids1)),
                (sa2.areas[i] for i in np.argsort(ids2)),
            ):
                assert area1.id == area2.id
                assert area1.socioeconomic_index == area2.socioeconomic_index
                assert area1.super_area.id == area2.super_area.id
                assert area1.super_area.name == area2.super_area.name
                assert area1.name == area2.name
                assert area1.super_area.region.id == area2.super_area.region.id
                assert area1.super_area.region.name == area2.super_area.region.name

        assert len(full_world.regions) == len(full_world_loaded.regions)
        for region1, region2 in zip(full_world.regions, full_world_loaded.regions):
            assert region1.id == region2.id
            assert region1.name == region2.name
            sa_ids1 = [sa.id for sa in region1.super_areas]
            sa_ids2 = [sa.id for sa in region2.super_areas]
            assert len(sa_ids1) == len(sa_ids2)
            assert set(sa_ids1) == set(sa_ids2)
            # Child super areas likewise compared in id order.
            for sa1, sa2 in zip(
                (region1.super_areas[i] for i in np.argsort(sa_ids1)),
                (region2.super_areas[i] for i in np.argsort(sa_ids2)),
            ):
                assert sa1.id == sa2.id
                assert sa1.name == sa2.name

    def test__subgroups(self, full_world, full_world_loaded):
        """Every person's subgroup memberships point at the equivalent groups."""
        for person1, person2 in zip(full_world.people, full_world_loaded.people):
            assert person1.area.id == person2.area.id
            assert (person1.area.coordinates == person2.area.coordinates).all()
            for subgroup1, subgroup2 in zip(
                person1.subgroups.iter(), person2.subgroups.iter()
            ):
                if subgroup1 is None:
                    assert subgroup2 is None
                else:
                    assert subgroup1.group.spec == subgroup2.group.spec
                    assert subgroup1.group.id == subgroup2.group.id
                    assert subgroup1.subgroup_type == subgroup2.subgroup_type

    def test__household_area(self, full_world, full_world_loaded):
        """Households keep their area link after the world round trip.

        Fix: the original length assertion compared the loaded world's
        households against themselves, which was vacuously true; it now
        compares the original world against the reloaded one.
        """
        assert len(full_world.households) == len(full_world_loaded.households)
        for household, household2 in zip(
            full_world.households, full_world_loaded.households
        ):
            if household.area is not None:
                assert household.area.id == household2.area.id
            else:
                assert household2.area is None

    def test__school_area(self, full_world, full_world_loaded):
        """Schools keep their area link after the world round trip.

        Fix: the else-branch previously asserted ``school2.super_area is None``
        although the guard tested ``school.area`` — an apparent copy-paste slip;
        it now mirrors the guard and checks ``school2.area``.
        """
        assert len(full_world_loaded.schools) == len(full_world.schools)
        for school, school2 in zip(full_world.schools, full_world_loaded.schools):
            if school.area is not None:
                assert school.area.id == school2.area.id
            else:
                assert school2.area is None

    def test__work_super_area(self, full_world, full_world_loaded):
        """Work super areas match between worlds and agree with primary activities."""
        found_work_super_area = False
        for p1, p2 in zip(full_world.people, full_world_loaded.people):
            if p1.work_super_area is None:
                assert p2.work_super_area is None
                continue
            found_work_super_area = True
            assert p1.work_super_area.id == p2.work_super_area.id
            # The work super area must be the one hosting the primary activity.
            assert p1.work_super_area == p1.primary_activity.group.super_area
            assert p2.work_super_area == p2.primary_activity.group.super_area
            assert p1.work_super_area.id == p2.primary_activity.group.super_area.id
            assert (
                p1.work_super_area.coordinates[0]
                == p2.work_super_area.coordinates[0]
            )
            assert (
                p1.work_super_area.coordinates[1]
                == p2.work_super_area.coordinates[1]
            )
            if p1.work_super_area.city is None:
                assert p2.work_super_area.city is None
            else:
                assert p1.work_super_area.city.id == p2.work_super_area.city.id
        assert found_work_super_area

        # Company workers must have their work super area set to the company's.
        found_workers = False
        for company1, company2 in zip(
            full_world.companies, full_world_loaded.companies
        ):
            for person1, person2 in zip(company1.people, company2.people):
                found_workers = True
                assert person1.work_super_area is not None
                assert person2.work_super_area is not None
                assert person1.work_super_area == company1.super_area
                assert person2.work_super_area == company2.super_area
        assert found_workers

    def test__super_area_city(self, full_world, full_world_loaded):
        """Super areas keep their city link and closest-station lookup table."""
        found_city = False
        for sa1, sa2 in zip(full_world.super_areas, full_world_loaded.super_areas):
            if sa1.city is None:
                assert sa2.city is None
            else:
                found_city = True
                assert sa1.city.id == sa2.city.id
                assert sa1.city.name == sa2.city.name
            lookup2 = sa2.closest_inter_city_station_for_city
            for city, station in sa1.closest_inter_city_station_for_city.items():
                assert city in lookup2
                assert lookup2[city].id == station.id
        assert found_city

    def test__care_home_area(self, full_world, full_world_loaded):
        """Care homes keep their area link after the world round trip.

        Fix: the original length assertion compared the loaded world's care
        homes against themselves, which was vacuously true; it now compares
        the original world against the reloaded one.
        """
        assert len(full_world.care_homes) == len(full_world_loaded.care_homes)
        for carehome, carehome2 in zip(
            full_world.care_homes, full_world_loaded.care_homes
        ):
            assert carehome.area.id == carehome2.area.id
            assert carehome.area.name == carehome2.area.name

    def test__company_super_area(self, full_world, full_world_loaded):
        """Companies keep their super area after the world round trip."""
        for original, loaded in zip(full_world.companies, full_world_loaded.companies):
            assert original.super_area.id == loaded.super_area.id

    def test__university_super_area(self, full_world, full_world_loaded):
        """Universities keep their area and super area after the world round trip."""
        for original, loaded in zip(
            full_world.universities, full_world_loaded.universities
        ):
            assert original.area.id == loaded.area.id
            assert original.super_area.id == loaded.super_area.id
            assert original.super_area.name == loaded.super_area.name

    def test__hospital_super_area(self, full_world, full_world_loaded):
        """Hospital geography links must be identical after reloading."""
        for original, loaded in zip(full_world.hospitals, full_world_loaded.hospitals):
            assert original.area.id == loaded.area.id
            # both the super-area identity and its human-readable name
            assert original.super_area.id == loaded.super_area.id
            assert original.super_area.name == loaded.super_area.name
            assert original.region_name == loaded.region_name

    def test__social_venues_super_area(self, full_world, full_world_loaded):
        """Pubs, groceries and cinemas must keep their geography links."""
        for venue_type in ("pubs", "groceries", "cinemas"):
            originals = getattr(full_world, venue_type)
            reloaded = getattr(full_world_loaded, venue_type)
            assert len(originals) == len(reloaded)
            for original, loaded in zip(originals, reloaded):
                assert original.area.id == loaded.area.id
                assert original.super_area.id == loaded.super_area.id
                assert original.super_area.name == loaded.super_area.name

    def test__commute(self, full_world, full_world_loaded):
        """Commuting infrastructure must be fully preserved on reload."""

        def check_stations(stations1, stations2):
            # station identity, location and commuter sets must all match
            assert len(stations1) == len(stations2)
            for st1, st2 in zip(stations1, stations2):
                assert st1.id == st2.id
                assert st1.super_area.id == st2.super_area.id
                assert len(st1.commuter_ids) == len(st2.commuter_ids)
                assert st1.commuter_ids == st2.commuter_ids

        assert len(full_world.city_transports) > 0
        assert len(full_world.inter_city_transports) > 0
        assert len(full_world.city_transports) == len(full_world_loaded.city_transports)
        assert len(full_world.inter_city_transports) == len(
            full_world_loaded.inter_city_transports
        )
        for city1, city2 in zip(full_world.cities, full_world_loaded.cities):
            assert city1.name == city2.name
            assert len(city1.internal_commuter_ids) == len(city2.internal_commuter_ids)
            assert city1.internal_commuter_ids == city2.internal_commuter_ids
            assert city1.super_area.id == city2.super_area.id
            check_stations(city1.inter_city_stations, city2.inter_city_stations)
            check_stations(city1.city_stations, city2.city_stations)

    def test__household_residents(self, full_world, full_world_loaded):
        """Each household must hold the same residents (by id) after reload."""
        for original, loaded in zip(full_world.households, full_world_loaded.households):
            assert len(original.residents) == len(loaded.residents)
            # compare sorted id lists so resident order does not matter
            original_ids = sorted(person.id for person in original.residents)
            loaded_ids = sorted(person.id for person in loaded.residents)
            assert original_ids == loaded_ids

    def test__closest_hospitals(self, full_world, full_world_loaded):
        """The ordered closest-hospitals list must match per super area."""
        for original, loaded in zip(full_world.super_areas, full_world_loaded.super_areas):
            assert len(original.closest_hospitals) == len(loaded.closest_hospitals)
            original_ids = [hospital.id for hospital in original.closest_hospitals]
            loaded_ids = [hospital.id for hospital in loaded.closest_hospitals]
            assert original_ids == loaded_ids

    def test__socioeconomic_index(self, full_world, full_world_loaded):
        """Socioeconomic indices must be identical person by person."""
        pairs = zip(full_world.people, full_world_loaded.people)
        assert all(p1.socioeconomic_index == p2.socioeconomic_index for p1, p2 in pairs)

    def test__social_venues(self, full_world, full_world_loaded):
        """Area social venues and household visit links must survive reload."""
        # per-area venue registries agree (order-independent, by venue id)
        for area1, area2 in zip(full_world.areas, full_world_loaded.areas):
            for venue_key, venues in area1.social_venues.items():
                assert venue_key in area2.social_venues.keys()
                recovered = area2.social_venues[venue_key]
                venue_ids = sorted(venue.id for venue in venues)
                recovered_ids = sorted(venue.id for venue in recovered)
                assert venue_ids == recovered_ids
        # households keep their identity and their visitable residences
        for house1, house2 in zip(full_world.households, full_world_loaded.households):
            assert house1.id == house2.id
            assert len(house1.residences_to_visit) == len(house2.residences_to_visit)
            for (key1, residences1), (key2, residences2) in zip(
                house1.residences_to_visit.items(), house2.residences_to_visit.items()
            ):
                assert key1 == key2
                for residence1, residence2 in zip(residences1, residences2):
                    assert residence1.id == residence2.id
                    assert residence1.spec == residence2.spec


class TestSaveDataDomainDecomposition:
    """Round-trip checks for the domain-decomposition summary file."""

    def test__save_data(self, full_world, test_results):
        """Saved per-super-area tallies must match the live world exactly."""
        save_data_for_domain_decomposition(full_world, test_results / "test.hdf5")
        recovered = load_data_for_domain_decomposition(test_results / "test.hdf5")
        # build the expected per-super-area tallies directly from the world
        expected = {}
        for super_area in full_world.super_areas:
            expected[super_area.name] = {
                "n_people": len(super_area.people),
                "n_workers": len(super_area.workers),
                "n_pupils": sum(
                    len(school.people)
                    for area in super_area.areas
                    for school in area.schools
                ),
            }
        expected_commuters = sum(
            len(station.commuter_ids) for station in full_world.stations
        ) + sum(len(city.internal_commuter_ids) for city in full_world.cities)
        recovered_commuters = 0
        checked_any = False
        for name, tallies in expected.items():
            for field in ("n_people", "n_workers", "n_pupils"):
                assert recovered[name][field] == tallies[field]
            # commuters are only checked in aggregate across super areas
            recovered_commuters += recovered[name]["n_commuters"]
            checked_any = True
        assert recovered_commuters == expected_commuters
        assert checked_any


from june.time import Timer


def test_initial_parameters():
    """A freshly constructed timer starts at shift 0 on its (weekday) start date."""
    timer = Timer(initial_day="2020-03-10", total_days=10)
    assert timer.shift == 0
    assert timer.is_weekend is False
    # 2020-03-10 is a Tuesday
    assert timer.date_str == "2020-03-10"
    assert timer.day_of_week == "Tuesday"


def test_time_is_passing():
    """Each step advances the clock by half a day (two shifts per day)."""
    timer = Timer(initial_day="2020-03-10", total_days=10)
    assert timer.now == 0
    next(timer)
    assert timer.now == 0.5
    # after one step the previous date is still the starting date
    assert timer.previous_date == timer.initial_date
    next(timer)
    assert timer.now == 1.0


def test_time_reset():
    """reset() must rewind the timer to its initial state."""
    timer = Timer(initial_day="2020-03-10", total_days=10)
    initial_date = timer.initial_date
    assert timer.date_str == "2020-03-10"
    # two shifts per day: advance one full day at a time
    for _ in range(2):
        next(timer)
    assert timer.date_str == "2020-03-11"
    for _ in range(2):
        next(timer)
    assert timer.day == 2
    assert timer.date_str == "2020-03-12"
    timer.reset()
    # everything is back at the start
    assert timer.day == 0
    assert timer.shift == 0
    assert timer.previous_date == initial_date
    assert timer.date_str == "2020-03-10"
    # and time passes normally again after the reset
    for _ in range(4):
        next(timer)
    assert timer.day == 2


def test_weekend_transition():
    """Weekend days only allow residence; weekdays restore primary activity."""
    timer = Timer(initial_day="2020-03-10", total_days=10)
    # 2020-03-10 is a Tuesday; 8 half-day steps land inside the weekend
    for _ in range(8):
        next(timer)
    assert timer.is_weekend is True
    assert timer.activities == ("residence",)
    next(timer)
    # still the weekend after one more half-day step
    assert timer.is_weekend is True
    assert timer.activities == ("residence",)
    next(timer)
    # back to a weekday with primary activities restored
    assert timer.is_weekend is False
    assert timer.activities == ("primary_activity", "residence")


from june.geography import Geography
from june.world import generate_world_from_geography
from june.groups import Schools, Hospitals, Companies, Households, Cemeteries, CareHomes


def test__onearea_world(geography):
    """Building a world from a single output area yields the expected sizes."""
    # NOTE(review): the injected `geography` fixture is never used — a fresh
    # single-area geography is built instead; confirm the fixture is needed.
    single_area_geography = Geography.from_file(filter_key={"area": ["E00088544"]})
    world = generate_world_from_geography(single_area_geography)
    assert hasattr(world, "households")
    assert isinstance(world.households, Households)
    assert len(world.areas) == 1
    assert len(world.super_areas) == 1
    assert world.super_areas.members[0].name == "E02003616"
    # census figures for this output area
    assert len(world.areas.members[0].people) == 362
    assert len(world.households) <= 148


def test__world_has_everything(world):
    """The generated world exposes every expected supergroup container.

    Bug fix: ``world.companies`` was asserted twice (the second occurrence
    was a copy-paste duplicate); the redundant assertion has been removed.
    """
    assert isinstance(world.schools, Schools)
    assert isinstance(world.cemeteries, Cemeteries)
    assert isinstance(world.companies, Companies)
    assert isinstance(world.households, Households)
    assert isinstance(world.hospitals, Hospitals)
    assert isinstance(world.care_homes, CareHomes)


def test__people_in_world_right_subgroups(world):
    """A sample of people must be registered in every subgroup they hold."""
    sample = world.people.members[:40]
    for person in sample:
        memberships = (sg for sg in person.subgroups.iter() if sg is not None)
        for subgroup in memberships:
            assert person in subgroup.people


import collections
import pytest
import numpy as np

from june.geography import Geography
from june import demography as d
from june.demography import AgeSexGenerator


@pytest.fixture(name="area")
def area_name():
    """Placeholder fixture: no area is provided by default."""
    return None


@pytest.fixture(name="geography_demography_test", scope="module")
def create_geography():
    """Module-scoped geography restricted to a single super area."""
    filter_key = {"super_area": ["E02004935"]}
    return Geography.from_file(filter_key=filter_key)


def test__age_sex_generator():
    """The generator must reproduce the configured age/sex/ethnicity structure."""
    age_counts = [0, 2, 0, 2, 4]
    age_bins = [0, 3]
    female_fractions = [0, 1]
    ethnicity_age_bins = [0, 2, 4]
    ethnicity_groups = ["A1", "B2", "C3"]
    ethnicity_structure = [[2, 0, 0], [0, 0, 2], [0, 4, 0]]

    def make_generator():
        # iterators are consumed, so each phase needs a fresh generator
        return d.demography.AgeSexGenerator(
            age_counts,
            age_bins,
            female_fractions,
            ethnicity_age_bins,
            ethnicity_groups,
            ethnicity_structure,
        )

    expected_ages = [1, 1, 3, 3, 4, 4, 4, 4]
    expected_sexes = ["m", "m", "f", "f", "f", "f", "f", "f"]
    expected_ethnicities = ["A1", "A1", "C3", "C3", "B2", "B2", "B2", "B2"]

    # the raw iterators yield exactly the configured sequences
    generator = make_generator()
    assert list(generator.age_iterator) == expected_ages
    assert list(generator.sex_iterator) == expected_sexes
    assert list(generator.ethnicity_iterator) == expected_ethnicities

    # drawing person by person reproduces the same distributions
    generator = make_generator()
    ages = []
    sexes = []
    ethnicities = []
    for _ in range(sum(age_counts)):
        ages.append(generator.age())
        sexes.append(generator.sex())
        ethnicities.append(generator.ethnicity())

    assert sorted(ages) == expected_ages
    assert collections.Counter(sexes) == collections.Counter(expected_sexes)
    assert collections.Counter(ethnicities) == collections.Counter(
        expected_ethnicities
    )


class TestDemography:
    """Demography construction from census data for areas and geographies."""

    def test__demography_for_areas(self):
        """Populating one output area reproduces its census distributions."""
        geography = Geography.from_file({"area": ["E00088544"]})
        area = list(geography.areas)[0]
        demography = d.Demography.for_areas(area_names=[area.name])
        area.populate(demography)
        population = area.people
        assert len(population) == 362
        ages_histogram = {}
        sexes_histogram = {}
        for person in population:
            # sex/ethnicity constraints known from this area's census data
            if person.age == 0:
                assert person.sex == "f"
            if person.age > 90:
                assert person.sex == "f"
            if person.age == 21:
                assert person.sex == "m"
            if person.age in range(55, 69):
                assert person.ethnicity.startswith("A")
                assert person.ethnicity in ["A1", "A2", "A4"]
            assert person.ethnicity.startswith("D") is False
            # checked in new socioecon_index file.
            assert person.area.socioeconomic_index == 0.59
            ages_histogram[person.age] = ages_histogram.get(person.age, 0) + 1
            sexes_histogram[person.sex] = sexes_histogram.get(person.sex, 0) + 1
        assert ages_histogram[0] == 6
        assert ages_histogram[71] == 3
        assert max(ages_histogram.keys()) == 90

    def test__demography_for_super_areas(self):
        """One generator is produced per output area in the super area."""
        demography = d.Demography.for_zone(filter_key={"super_area": ["E02004935"]})
        assert len(demography.age_sex_generators) == 26

    def test__demography_for_geography(self, geography_demography_test):
        """Building from a Geography object matches building from a zone."""
        demography = d.Demography.for_geography(geography_demography_test)
        assert len(demography.age_sex_generators) == 26

    def test__age_sex_generator_from_bins(self):
        """Binned inputs reproduce the male/female ratios within 5%."""
        men_by_bin = {"0-10": 1000, "11-70": 2000, "71-99": 500}
        women_by_bin = {"0-10": 1500, "11-70": 1000, "71-99": 1000}
        generator = AgeSexGenerator.from_age_sex_bins(men_by_bin, women_by_bin)
        total_people = sum(men_by_bin.values()) + sum(women_by_bin.values())
        ages = []
        sexes = []
        for _ in range(total_people):
            ages.append(generator.age())
            sexes.append(generator.sex())
        ages = np.array(ages)
        sexes = np.array(sexes)
        is_man = sexes == "m"
        _, men_counts = np.unique(
            np.digitize(ages[is_man], [0, 11, 71]), return_counts=True
        )
        _, women_counts = np.unique(
            np.digitize(ages[~is_man], [0, 11, 71]), return_counts=True
        )
        np.testing.assert_allclose(
            men_counts / women_counts, [10 / 15, 20 / 10, 5 / 10], atol=0, rtol=0.05
        )

    def test__comorbidities_for_areas(self):
        """At least one comorbidity is assigned within the populated area."""
        geography = Geography.from_file({"area": ["E00088544"]})
        area = list(geography.areas)[0]
        area.populate(d.Demography.for_areas(area_names=[area.name]))
        comorbidities = [
            person.comorbidity
            for person in area.people
            if person.comorbidity is not None
        ]
        assert len(np.unique(comorbidities)) > 0


class TestPopulation:
    """Population creation from a demography object."""

    def test__create_population_from_demography(self, geography_demography_test):
        """Populating every area yields the known total head count."""
        demography = d.Demography.for_geography(geography_demography_test)
        population = []
        for area in geography_demography_test.areas:
            area.populate(demography)
            population.extend(area.people)
        assert len(population) == 7602


import pytest
from june import paths
from june.distributors.care_home_distributor import CareHomeDistributor
from june.demography import Person
from june.groups.care_home import CareHome
from june.geography import Area, SuperArea, Areas, SuperAreas


# Default care-home group configuration shipped with the package.
default_config_file = paths.configs_path / "defaults/groups/carehome.yaml"


@pytest.fixture(name="geography", scope="module")
def make_geo():
    """Build a small two-super-area toy geography for care-home tests.

    Returns:
        (areas, super_areas): four areas split across two super areas, with
        residents of known age/sex mixes, a pool of sector-"Q" workers per
        super area, and one care home of known capacity per area.

    Bug fix: a previous version first appended a handful of carers to each
    super area's ``workers`` list and then unconditionally replaced the
    whole list further down (``super_areas[i].workers = []``); that dead
    code has been removed.
    """
    super_areas = SuperAreas(
        [SuperArea(name="super_area_1"), SuperArea(name="super_area_2")],
        ball_tree=False,
    )
    areas = Areas(
        [
            Area(super_area=super_areas[0], name="area_1"),
            Area(super_area=super_areas[0], name="area_2"),
            Area(super_area=super_areas[1], name="area_3"),
            Area(super_area=super_areas[1], name="area_4"),
        ],
        ball_tree=False,
    )
    super_areas[0].areas = areas[0:2]
    super_areas[1].areas = areas[2:4]

    # residents, super area 1
    for _ in range(10):
        areas[0].people.append(Person.from_attributes(age=40, sex="m"))
    for _ in range(10):
        areas[0].people.append(Person.from_attributes(age=80, sex="f"))
    for _ in range(30):
        areas[1].people.append(Person.from_attributes(age=80, sex="m"))
    for _ in range(20):
        areas[1].people.append(Person.from_attributes(age=40, sex="f"))

    # residents, super area 2
    for _ in range(5):
        areas[2].people.append(Person.from_attributes(age=40, sex="m"))
    for _ in range(12):
        areas[2].people.append(Person.from_attributes(age=80, sex="f"))
    for _ in range(10):
        areas[3].people.append(Person.from_attributes(age=80, sex="m"))
    for _ in range(8):
        areas[3].people.append(Person.from_attributes(age=40, sex="f"))

    # health-sector ("Q") workers available to staff the care homes
    super_areas[0].workers = []
    for _ in range(30):
        worker = Person.from_attributes(age=28)
        worker.sector = "Q"
        worker.sub_sector = None
        super_areas[0].workers.append(worker)
    super_areas[1].workers = []
    for _ in range(70):
        worker = Person.from_attributes(age=28)
        worker.sector = "Q"
        worker.sub_sector = None
        super_areas[1].workers.append(worker)

    # one care home per area with known capacities
    areas[0].care_home = CareHome(n_residents=20, n_workers=10, area=areas[0])
    areas[1].care_home = CareHome(n_residents=50, n_workers=20, area=areas[1])
    areas[2].care_home = CareHome(n_residents=17, n_workers=30, area=areas[2])
    areas[3].care_home = CareHome(n_residents=18, n_workers=40, area=areas[3])

    return areas, super_areas


class TestCareHomeDistributor:
    """Distribution of residents and workers into care homes."""

    @pytest.fixture(name="carehome_distributor")
    def create_carehome_dist(self):
        """Distributor configured with communal-establishment counts."""
        men_by_super_area = {
            "super_area_1": {"0-60": 10, "60-100": 30},
            "super_area_2": {"0-50": 5, "50-100": 10},
        }
        women_by_super_area = {
            "super_area_1": {"0-60": 20, "60-100": 10},
            "super_area_2": {"0-50": 8, "50-100": 12},
        }
        return CareHomeDistributor(
            communal_men_by_super_area=men_by_super_area,
            communal_women_by_super_area=women_by_super_area,
        )

    def _count_people_in_carehome(self, area):
        """Split an area's care-home residents into (men, women)."""
        residents = list(area.care_home.residents)
        men = [person for person in residents if person.sex == "m"]
        women = [person for person in residents if person.sex != "m"]
        return men, women

    def test__care_home_residents(self, carehome_distributor, geography):
        """Residents land in care homes matching the communal counts."""
        areas, super_areas = geography
        carehome_distributor.populate_care_homes_in_super_areas(super_areas=super_areas)
        # expected (n_men, men_age, n_women, women_age) for each area in order
        expectations = [
            (10, 40, 10, 80),
            (30, 80, 20, 40),
            (5, 40, 12, 80),
            (10, 80, 8, 40),
        ]
        for area, (n_men, men_age, n_women, women_age) in zip(areas, expectations):
            men, women = self._count_people_in_carehome(area)
            assert len(men) == n_men
            assert len(women) == n_women
            assert all(man.age == men_age for man in men)
            assert all(woman.age == women_age for woman in women)

    def test__carehome_workers(self, carehome_distributor, geography):
        """Care homes are fully staffed by health-sector workers."""
        areas, super_areas = geography
        carehome_distributor.distribute_workers_to_care_homes(super_areas=super_areas)
        for area in areas:
            care_home = area.care_home
            assert len(care_home.workers) == care_home.n_workers
            assert all(worker.sector == "Q" for worker in care_home.workers)


import pytest
from june.geography import SuperArea
from june.groups import Company
from june.demography import Person
from june.distributors import CompanyDistributor

# TODO: This test shouldn't build from geography! Create a world that has those characteristics


@pytest.fixture(name="super_area")
def make_super_area():
    """Super area with three single-slot companies and one worker per sector."""
    super_area = SuperArea()
    for sector in range(3):
        super_area.companies.append(Company(sector=sector, n_workers_max=sector))
        worker = Person.from_attributes()
        worker.sector = sector
        super_area.workers.append(worker)
    return super_area


def test__company_distributor(super_area):
    """Each worker must land in the company matching their sector."""
    distributor = CompanyDistributor()
    distributor.distribute_adults_to_companies_in_super_area(super_area)
    for company in super_area.companies:
        employees = list(company.people)
        assert len(employees) == 1
        assert employees[0].sector == company.sector


def test__company_and_work_super_area(full_world):
    """A worker's work_super_area must match their workplace's super area."""
    checked_anyone = False
    for person in full_world.people:
        if person.work_super_area is None:
            continue
        checked_anyone = True
        assert person.work_super_area == person.primary_activity.group.super_area
    assert checked_anyone


class TestLockdownStatus:
    """Lockdown status assignment for different kinds of people."""

    def test__lockdown_status_random(self, full_world):
        """Adults get a lockdown status, children do not.

        Robustness fix: ``worker``/``child`` are now initialised and
        explicitly asserted to have been found, instead of relying on an
        accidental NameError when the sampled area lacks an adult or a
        child.
        """
        worker = None
        child = None
        for person in full_world.areas[0].people:
            if person.age > 18:
                worker = person
            elif person.age < 18:
                child = person
            if worker is not None and child is not None:
                break
        assert worker is not None, "no adult found in the first area"
        assert child is not None, "no child found in the first area"
        assert worker.lockdown_status is not None
        assert child.lockdown_status is None

    def test__lockdown_status_teacher(self, full_world):
        """Teachers are key workers."""
        teacher = full_world.schools[0].teachers.people[0]
        assert teacher.lockdown_status == "key_worker"

    def test__lockdown_status_medic(self, full_world):
        """Hospital staff are key workers."""
        medic = full_world.hospitals[0].people[0]
        assert medic.lockdown_status == "key_worker"

    def test__lockdown_status_care_home(self, full_world):
        """Care-home staff are key workers."""
        care_home_worker = full_world.care_homes[0].people[0]
        assert care_home_worker.lockdown_status == "key_worker"


import pytest

from june.distributors import HospitalDistributor
from june.geography import Geography
from june.groups import Hospital, Hospitals
from june.demography.person import Person


@pytest.fixture(name="young_medic")
def make_medic_young():
    """An 18-year-old healthcare-sector hospital worker."""
    medic = Person.from_attributes(age=18)
    medic.sector, medic.sub_sector = "Q", "Hospital"
    return medic


@pytest.fixture(name="old_medic")
def make_medic_old():
    """A 40-year-old healthcare-sector hospital worker."""
    medic = Person.from_attributes(age=40)
    medic.sector, medic.sub_sector = "Q", "Hospital"
    return medic


@pytest.fixture(name="geography_hospital")
def make_geography(young_medic, old_medic):
    """Geography whose first super area is filled with medic workers.

    NOTE(review): the same two Person objects are registered 200 times
    each — presumably intentional for this test, but worth confirming.
    """
    geography = Geography.from_file({"super_area": ["E02003999", "E02006764"]})
    first_super_area = geography.super_areas.members[0]
    for medic in (young_medic, old_medic):
        for _ in range(200):
            first_super_area.add_worker(medic)
            first_super_area.areas[0].add(medic)
    return geography


@pytest.fixture(name="hospitals")
def make_hospitals(geography_hospital):
    """Two hospitals of different capacities in the first super area."""
    host_super_area = geography_hospital.super_areas.members[0]

    def build(n_beds, n_icu_beds):
        # both hospitals share the same area and coordinates
        return Hospital(
            n_beds=n_beds,
            n_icu_beds=n_icu_beds,
            area=host_super_area.areas[0],
            coordinates=host_super_area.coordinates,
        )

    return Hospitals([build(40, 5), build(80, 20)])


def test__distribution_of_medics(geography_hospital, hospitals):
    """Medics assigned per super area respect age and staffing ratios."""
    geography_hospital.hospitals = hospitals
    distributor = HospitalDistributor(
        hospitals, medic_min_age=25, patients_per_medic=10, healthcare_sector_label="Q"
    )
    distributor.distribute_medics_to_super_areas(geography_hospital.super_areas)
    for hospital in hospitals:
        capacity = hospital.n_beds + hospital.n_icu_beds
        medics = hospital.subgroups[hospital.SubgroupType.workers].people
        assert all(medic.age >= distributor.medic_min_age for medic in medics)
        assert len(medics) == capacity // distributor.patients_per_medic


def test__distribution_of_medics_from_world(geography_hospital, hospitals):
    """Medics drawn from the population respect age and staffing ratios."""
    distributor = HospitalDistributor(
        hospitals, medic_min_age=20, patients_per_medic=10
    )
    distributor.distribute_medics_from_world(
        geography_hospital.super_areas.members[0].people
    )
    for hospital in hospitals:
        capacity = hospital.n_beds + hospital.n_icu_beds
        medics = hospital.subgroups[hospital.SubgroupType.workers].people
        assert all(medic.age >= distributor.medic_min_age for medic in medics)
        assert len(medics) == capacity // distributor.patients_per_medic


from collections import OrderedDict

import numpy as np
import pytest

from june.demography.person import Person
from june.distributors import HouseholdDistributor


class MockHouseholds:
    """Bare-bones stand-in for a Households supergroup."""

    def __init__(self):
        # mirrors the real supergroup's public `members` attribute
        self.members = []


class MockWorld:
    """Bare-bones world exposing only a households container."""

    def __init__(self):
        self.households = MockHouseholds()


class MockArea:
    """Minimal area carrying synthetic by-age population dictionaries."""

    def __init__(self, age_min=0, age_max=99, people_per_age=5):
        # total population implied by an inclusive [age_min, age_max] range
        self.n_people = (age_max - age_min + 1) * people_per_age
        self.world = MockWorld()
        self.households = []
        self.create_dicts(age_min, age_max, people_per_age)

    def create_dicts(self, age_min, age_max, people_per_age):
        """Build the men/women by-age dictionaries for this area."""
        self.men_by_age = create_men_by_age_dict(age_min, age_max, people_per_age)
        self.women_by_age = create_women_by_age_dict(age_min, age_max, people_per_age)


def create_men_by_age_dict(age_min=0, age_max=99, people_per_age=5):
    """Create an ordered dict mapping each age to a list of men of that age.

    Bug fix: the age range is now inclusive of ``age_max``
    (``np.arange``'s stop value is exclusive), which makes the number of
    ages consistent with ``MockArea.n_people``, computed as
    ``(age_max - age_min + 1) * people_per_age``.
    """
    men_by_age = OrderedDict()
    for age in np.arange(age_min, age_max + 1):
        men_by_age[age] = [
            Person.from_attributes(sex=0, age=age) for _ in range(people_per_age)
        ]
    return men_by_age


def create_women_by_age_dict(age_min=0, age_max=99, people_per_age=5):
    """Create an ordered dict mapping each age to a list of women of that age.

    Bug fix: the age range is now inclusive of ``age_max``
    (``np.arange``'s stop value is exclusive), which makes the number of
    ages consistent with ``MockArea.n_people``, computed as
    ``(age_max - age_min + 1) * people_per_age``.
    """
    women_by_age = OrderedDict()
    for age in np.arange(age_min, age_max + 1):
        women_by_age[age] = [
            Person.from_attributes(sex=1, age=age) for _ in range(people_per_age)
        ]
    return women_by_age


def create_area(age_min=0, age_max=99, people_per_age=5):
    """Convenience wrapper building a MockArea with the given parameters."""
    return MockArea(age_min, age_max, people_per_age)


@pytest.fixture(name="household_distributor", scope="module")
def create_household_distributor():
    """HouseholdDistributor with simple two-point age-difference tables."""
    return HouseholdDistributor(
        {20: 0.5, 21: 0.5},  # first kid vs parent age difference
        {30: 0.5, 31: 0.5},  # second kid vs parent age difference
        {0: 0.5, 1: 0.5},  # couples age difference
    )


class TestAuxiliaryFunctions:
    """Unit tests for HouseholdDistributor's private matching helpers.

    These tests consume and delete entries from the shared by-age dicts as
    they go, so the order of the assertions within each test is significant.
    """

    def test__get_closest_person_of_age(self, household_distributor):
        """_get_closest_person_of_age pops the nearest-aged person,
        preferring the first dict and falling back to the second."""
        area = create_area(people_per_age=1)
        # check normal use
        age = 35
        man = household_distributor._get_closest_person_of_age(
            area.men_by_age, area.women_by_age, age
        )
        assert man.sex == 0
        assert man.age == 35
        assert 35 not in area.men_by_age.keys()  # check key has been deleted

        age = 0
        kid = household_distributor._get_closest_person_of_age(
            area.women_by_age, area.men_by_age, 0
        )
        assert kid.sex == 1
        assert kid.age == 0

        # assert returns none when can't find someone in the allowed age range
        none_person = household_distributor._get_closest_person_of_age(
            area.men_by_age, area.women_by_age, 45, min_age=20, max_age=25
        )
        assert none_person is None

        # no men left in [40, 50] and no fallback dict -> None
        for key in range(40, 51):
            del area.men_by_age[key]
        none_person = household_distributor._get_closest_person_of_age(
            area.men_by_age, {}, 45, min_age=40, max_age=50
        )
        assert none_person is None

        # assert return opposite sex if the option is available
        woman = household_distributor._get_closest_person_of_age(
            area.men_by_age, area.women_by_age, 45, min_age=40, max_age=50
        )
        assert woman.sex == 1
        assert woman.age == 45

    def test__get_random_person_in_age_bracket(self, household_distributor):
        """With one person per age and sex, two draws at the same age must
        return one person of each sex."""
        area = create_area(people_per_age=1)
        # check normal use
        person_1 = household_distributor._get_random_person_in_age_bracket(
            area.men_by_age, area.women_by_age, min_age=18, max_age=18
        )
        person_2 = household_distributor._get_random_person_in_age_bracket(
            area.men_by_age, area.women_by_age, min_age=18, max_age=18
        )
        assert person_1.age == 18
        assert person_2.age == 18
        assert person_1.sex != person_2.sex

    def test__get_matching_partner_is_correct(self, household_distributor):
        """Partners are of the opposite sex (when available) with an age
        offset drawn from the configured {0: 0.5, 1: 0.5} table."""
        area = create_area(people_per_age=5)
        man = Person.from_attributes(sex=0, age=40)
        woman = household_distributor._get_matching_partner(
            man, area.men_by_age, area.women_by_age
        )
        assert woman.sex == 1
        assert (woman.age == 40) or (woman.age == 41)
        woman = Person.from_attributes(sex=1, age=40)
        man = household_distributor._get_matching_partner(
            woman, area.men_by_age, area.women_by_age
        )
        assert man.sex == 0
        assert (man.age == 40) or (man.age == 41)
        # check option to get under or over 65
        person = Person.from_attributes(sex=1, age=76)
        partner = household_distributor._get_matching_partner(
            person, area.men_by_age, area.women_by_age, under_65=True
        )
        assert partner.age < 65
        assert partner.sex == 0
        partner2 = household_distributor._get_matching_partner(
            person, area.men_by_age, area.women_by_age, over_65=True
        )
        assert partner2.age > 65
        assert partner2.sex == 0
        # check we get same sex if not available
        area.men_by_age = {}
        woman = household_distributor._get_matching_partner(
            woman, area.men_by_age, area.women_by_age
        )
        assert woman.sex == 1
        assert (woman.age == 40) or (woman.age == 41)

    def test__get_matching_parent(self, household_distributor):
        """Parents are mothers first (age offset from the first-kid table),
        then fathers, then None once no eligible adults remain."""
        area = create_area()
        kid = Person.from_attributes(age=10)
        parent = household_distributor._get_matching_parent(
            kid, area.men_by_age, area.women_by_age
        )
        assert parent.age == 30 or parent.age == 31
        assert parent.sex == 1

        # check if no adult women available it returns men
        age_min_parent = 18
        age_max_parent = household_distributor.max_age_to_be_parent
        for key in range(age_min_parent, age_max_parent + 1):
            del area.women_by_age[key]
        male_parent = household_distributor._get_matching_parent(
            kid, area.men_by_age, area.women_by_age
        )
        assert male_parent.sex == 0
        assert male_parent.age == 30 or male_parent.age == 31

        # check if no adults available it returns None
        for key in range(age_min_parent, age_max_parent + 1):
            del area.men_by_age[key]
        none_parent = household_distributor._get_matching_parent(
            kid, area.men_by_age, area.women_by_age
        )
        assert none_parent is None

    def test__get_matching_second_kid(self, household_distributor):
        """The second kid's age follows the parent's age minus the
        second-kid offset, clamped to the valid child age range."""
        area = create_area()
        parent = Person.from_attributes(age=20)
        kid = household_distributor._get_matching_second_kid(
            parent, area.men_by_age, area.women_by_age
        )
        assert kid.age == 0
        parent = Person.from_attributes(age=35)
        kid = household_distributor._get_matching_second_kid(
            parent, area.men_by_age, area.women_by_age
        )
        assert kid.age == 5 or kid.age == 4
        parent = Person.from_attributes(age=80)
        kid = household_distributor._get_matching_second_kid(
            parent, area.men_by_age, area.women_by_age
        )
        assert kid.age == 17


class TestIndividualHouseholdCompositions:
    """Tests that fill one household composition type at a time and check
    the resulting household counts, sizes, and occupant ages."""

    def test__fill_all_student_households(self, household_distributor):
        """Students are spread over the requested number of houses; every
        occupant is inside the student age band and head counts match."""
        area = create_area(age_min=15, age_max=30, people_per_age=5)  # enough students
        # I put these limits to narrow the age range and make it faster, but
        # they do not reflect the expected age of students
        area.households = household_distributor.fill_all_student_households(
            area.men_by_age,
            area.women_by_age,
            area,
            n_students=20,
            student_houses_number=5,
            composition_type=None,
        )
        assert len(area.households) == 5
        counter = 0
        for household in area.households:
            for person in household.people:
                counter += 1
                assert (
                    household_distributor.student_min_age
                    <= person.age
                    <= household_distributor.student_max_age
                )
        assert counter == 20
        # A student count that does not divide evenly must still fill exactly
        # the requested number of houses.
        area.households = []
        area.households = household_distributor.fill_all_student_households(
            area.men_by_age,
            area.women_by_age,
            area,
            n_students=11,
            student_houses_number=3,
            composition_type=None,
        )
        assert len(area.households) == 3

    def test__fill_oldpeople_households(self, household_distributor):
        """Old-people households hold exactly two old occupants each; with
        max_household_size=2 they leave no room for extra people."""
        area = create_area(age_min=50, age_max=100, people_per_age=20)
        # I put these limits to narrow the age range and make it faster, but
        # they do not reflect the expected age of old people
        households_with_extrapeople_list = []
        area.households = household_distributor.fill_oldpeople_households(
            area.men_by_age,
            area.women_by_age,
            2,
            10,
            area,
            extra_people_lists=(households_with_extrapeople_list,),
            composition_type=None,
        )
        assert len(households_with_extrapeople_list) == 10
        assert len(area.households) == 10
        for household in area.households:
            assert len(household.people) == 2
            for person in household.people:
                assert person.age >= household_distributor.old_min_age
        households_with_extrapeople_list = []
        area.households += household_distributor.fill_oldpeople_households(
            area.men_by_age,
            area.women_by_age,
            2,
            10,
            area,
            extra_people_lists=(households_with_extrapeople_list,),
            max_household_size=2,
            composition_type=None,
        )
        assert len(area.households) == 20
        assert len(households_with_extrapeople_list) == 0  # no spaces left
        for household in area.households:
            assert len(household.people) == 2
            for person in household.people:
                assert person.age >= household_distributor.old_min_age

    def test__fill_families_households(self, household_distributor):
        """Each family household has two parents (one of each sex) and two
        kids whose ages are plausible relative to the mother's."""
        area = create_area(people_per_age=20, age_max=65)
        households_with_extrapeople_list = []
        area.households = household_distributor.fill_families_households(
            area.men_by_age,
            area.women_by_age,
            n_households=10,
            kids_per_house=2,
            parents_per_house=2,
            old_per_house=0,
            area=area,
            extra_people_lists=(households_with_extrapeople_list,),
            composition_type=None,
        )
        assert len(households_with_extrapeople_list) == 10
        assert len(area.households) == 10
        for household in area.households:
            assert len(household.people) == 4
            no_of_kids = 0
            no_of_adults = 0
            mother = None
            father = None
            kid_1 = None
            kid_2 = None
            for person in household.people:
                if person.age >= 18:
                    no_of_adults += 1
                    if person.sex == 1:
                        mother = person
                    else:
                        father = person
                else:
                    if kid_1 is None:
                        kid_1 = person
                    else:
                        kid_2 = person
                    no_of_kids += 1
            assert no_of_adults == 2
            assert no_of_kids == 2
            assert father is not None
            assert mother is not None
            # make kid_1 the older sibling
            if kid_1.age < kid_2.age:
                kid_1, kid_2 = kid_2, kid_1
            # was `(x <= 20) or (x <= 21)`, which collapses to `x <= 21`
            assert mother.age - kid_1.age <= 21
            # was `(x <= 30) or (x <= 31)`, which collapses to `x <= 31`
            assert mother.age - kid_2.age <= 31
            # was the same membership test or-ed with itself
            assert (father.age - mother.age) in [-1, 0, 1]

    def test__fill_nokids_households(self, household_distributor):
        """Two-adult childless households pair one man and one woman, with at
        most one old person per household."""
        area = create_area(age_min=18, people_per_age=10, age_max=60)
        households_with_extrapeople_list = []
        area.households = household_distributor.fill_nokids_households(
            area.men_by_age,
            area.women_by_age,
            adults_per_household=2,
            n_households=10,
            area=area,
            extra_people_lists=(households_with_extrapeople_list,),
            composition_type=None,
        )
        assert len(households_with_extrapeople_list) == 10
        assert len(area.households) == 10
        for household in area.households:
            man = None
            woman = None
            oldpeople = 0
            for person in household.people:
                assert (
                    household_distributor.adult_min_age
                    <= person.age
                    <= household_distributor.old_max_age
                )
                if person.age >= household_distributor.old_min_age:
                    oldpeople += 1
                if person.sex == 0:
                    man = person
                else:
                    woman = person
            assert man is not None
            assert woman is not None
            assert oldpeople <= 1

    def test__fill_youngadult_households(self, household_distributor):
        """Young-adult shared households only contain people in the young
        adult age band."""
        area = create_area(age_min=15, age_max=40, people_per_age=5)
        households_with_extrapeople_list = []
        area.households = household_distributor.fill_youngadult_households(
            area.men_by_age,
            area.women_by_age,
            3,
            20,
            area,
            extra_people_lists=(households_with_extrapeople_list,),
            composition_type=None,
        )
        assert len(households_with_extrapeople_list) == 20
        assert len(area.households) == 20
        for household in area.households:
            for person in household.people:
                assert (
                    household_distributor.adult_min_age
                    <= person.age
                    <= household_distributor.young_adult_max_age
                )

    def test__fill_youngadult_with_parents_households(self, household_distributor):
        """NOTE(review): this is a byte-for-byte duplicate of
        test__fill_youngadult_households and never exercises a
        "with parents" code path — presumably it should call a
        youngadult-with-parents fill method instead; TODO confirm."""
        area = create_area(age_min=15, age_max=40, people_per_age=5)
        households_with_extrapeople_list = []
        area.households = household_distributor.fill_youngadult_households(
            area.men_by_age,
            area.women_by_age,
            3,
            20,
            area,
            extra_people_lists=(households_with_extrapeople_list,),
            composition_type=None,
        )
        assert len(households_with_extrapeople_list) == 20
        assert len(area.households) == 20
        for household in area.households:
            for person in household.people:
                assert (
                    household_distributor.adult_min_age
                    <= person.age
                    <= household_distributor.young_adult_max_age
                )

    def test__fill_communal_establishments(self, household_distributor):
        """Communal establishments split their residents as evenly as
        possible across establishments."""
        area = create_area(people_per_age=5)
        area.households = household_distributor.fill_all_communal_establishments(
            area.men_by_age,
            area.women_by_age,
            n_establishments=5,
            n_people_in_communal=20,
            area=area,
            composition_type=None,
        )
        assert len(area.households) == 5
        for household in area.households:
            assert len(household.people) == 4
        # 7 people over 2 establishments -> sizes 3 and 4
        area.households = []
        area.households = household_distributor.fill_all_communal_establishments(
            area.men_by_age,
            area.women_by_age,
            n_establishments=2,
            n_people_in_communal=7,
            area=area,
            composition_type=None,
        )
        assert len(area.households) == 2
        for household in area.households:
            assert len(household.people) in [3, 4]


class TestMultipleHouseholdCompositions:
    """End-to-end checks of distribute_people_to_households over small,
    hand-built populations with several composition types at once."""

    def test__area_is_filled_properly_1(self, household_distributor):
        """4 kids, 4 adults and 3 old people fill 4 single-parent houses
        plus one single-old and one double-old household."""
        area = create_area(people_per_age=0)
        ages_and_counts = {5: 4, 50: 4, 75: 3}  # kids  # adults  # old people
        area.men_by_age = OrderedDict({})
        area.women_by_age = OrderedDict({})
        for age, count in ages_and_counts.items():
            area.men_by_age[age] = [
                Person.from_attributes(age=age) for _ in range(count)
            ]
        composition_numbers = {"1 0 >=0 1 0": 4, "0 0 0 0 1": 1, "0 0 0 0 2": 1}
        area.households = household_distributor.distribute_people_to_households(
            area.men_by_age, area.women_by_age, area, composition_numbers, 0, 0
        )
        assert len(area.households) == 6
        adult_min = household_distributor.adult_min_age
        ya_max = household_distributor.young_adult_max_age
        old_min = household_distributor.old_min_age
        total_people = 0
        for household in area.households:
            assert len(household.people) <= 6
            n_kids = n_adults = n_youngadults = n_old = 0
            for person in household.people:
                if 0 <= person.age < adult_min:
                    n_kids += 1
                elif adult_min <= person.age <= ya_max:
                    n_youngadults += 1
                elif adult_min <= person.age < old_min:
                    n_adults += 1
                else:
                    n_old += 1
            assert n_kids in [0, 1]
            assert n_adults in [0, 1]
            assert n_youngadults in range(0, 7)
            assert n_old in range(0, 3)
            total_people += n_old + n_kids + n_adults + n_youngadults

        assert total_people == 12 - 1  # 11 of the 11 people placed

    def test__area_is_filled_properly_2(self, household_distributor):
        """Young adults land in the student household; couples and old
        people fill the remaining compositions."""
        area = create_area(people_per_age=0)
        ages_and_counts = {
            23: 5,  # young adults or students
            50: 4,  # adults
            75: 3,  # old people
        }
        area.men_by_age = OrderedDict({})
        area.women_by_age = OrderedDict({})
        for age, count in ages_and_counts.items():
            area.men_by_age[age] = [
                Person.from_attributes(age=age) for _ in range(count)
            ]
        composition_numbers = {
            "0 0 0 2 0": 2,
            "0 >=1 0 0 0": 1,
            "0 0 0 0 2": 1,
            "0 0 0 0 1": 1,
        }
        area.households = household_distributor.distribute_people_to_households(
            area.men_by_age, area.women_by_age, area, composition_numbers, 5, 0
        )
        assert len(area.households) == 5
        adult_min = household_distributor.adult_min_age
        ya_max = household_distributor.young_adult_max_age
        old_min = household_distributor.old_min_age
        total_people = 0
        for household in area.households:
            assert len(household.people) in [1, 2, 5]
            n_adults = n_youngadults = n_old = 0
            for person in household.people:
                if adult_min <= person.age <= ya_max:
                    n_youngadults += 1
                elif adult_min <= person.age < old_min:
                    n_adults += 1
                else:
                    n_old += 1
            assert n_adults in [0, 2]
            assert n_youngadults in [0, 5]
            assert n_old in range(0, 3)
            total_people += n_old + n_adults + n_youngadults

        assert total_people == 12

    def test__area_is_filled_properly_3(self, household_distributor):
        """Family compositions plus fully flexible ones absorb all 17
        people across 4 households."""
        area = create_area(people_per_age=0)
        ages_and_counts = {5: 3, 50: 14}  # kids  # adults
        area.men_by_age = OrderedDict({})
        area.women_by_age = OrderedDict({})
        for age, count in ages_and_counts.items():
            area.men_by_age[age] = [
                Person.from_attributes(age=age) for _ in range(count)
            ]
        composition_numbers = {
            "1 0 >=0 2 0": 1,
            ">=2 0 >=0 2 0": 1,
            ">=0 >=0 >=0 >=0 >=0": 2,
        }
        area.households = household_distributor.distribute_people_to_households(
            area.men_by_age, area.women_by_age, area, composition_numbers, 0, 10
        )
        assert len(area.households) == 4
        adult_min = household_distributor.adult_min_age
        adult_max = household_distributor.adult_max_age
        total_people = 0
        for household in area.households:
            assert len(household.people) in [3, 4, 5]
            n_kids = n_adults = 0
            for person in household.people:
                if 0 <= person.age < adult_min:
                    n_kids += 1
                elif adult_min <= person.age <= adult_max:
                    n_adults += 1
            assert n_kids in [0, 1, 2]
            assert n_adults in [2, 5]
            total_people += n_kids + n_adults

        assert total_people == 17


# class TestSpecificArea:
#    """
#    Let's carefully check the first output area of the test set.
#    This area has no carehomes so we don't have to account for them.
#    The area E00062207 has this configuration:
#    0 0 0 0 1              15
#    0 0 0 1 0              20
#    0 0 0 0 2              11
#    0 0 0 2 0              24
#    1 0 >=0 2 0            12
#    >=2 0 >=0 2 0           9
#    0 0 >=1 2 0             6
#    1 0 >=0 1 0             5
#    >=2 0 >=0 1 0           3
#    0 0 >=1 1 0             7
#    1 0 >=0 >=1 >=0         0
#    >=2 0 >=0 >=1 >=0       1
#    0 >=1 0 0 0             0
#    0 0 0 0 >=2             0
#    0 0 >=0 >=0 >=0         1
#    >=0 >=0 >=0 >=0 >=0     0
#    Name: E00062207, dtype: int64
#    """
#    @pytest.fixture(name="example_area", scope="module")
#    def make_geo(self):
#        geo = Geography.from_file({"oa": ["E00062207"]})
#        dem = Demography.for_geography(geo)
#        geo.areas[0].populate(dem)
#        return geo.areas[0]
#
#    @pytest.fixture(name="hd_area", scope="module")
#    def populate_example_area(self, example_area):
#        area = example_area
#        household_distributor = HouseholdDistributor.from_file()
#        household_distributor.distribute_people_and_households_to_areas(
#            [area],
#        )
#        return household_distributor
#
#    def test__all_household_have_reasonable_size(
#        self, example_area, hd_area
#    ):
#        sizes_dict = {}
#        for household in example_area.households:
#            size = len(household.people)
#            if size not in sizes_dict:
#                sizes_dict[size] = 0
#            sizes_dict[size] += 1
#
#        assert max(list(sizes_dict.keys())) <= 8
#        assert sizes_dict[2] >= 35
#        assert sizes_dict[1] >= 35
#
#    def test__oldpeople_have_suitable_accommodation(
#        self, example_area,
#    ):
#        """
#        run the test ten times to be sure
#        """
#        area = example_area
#        oldpeople_household_sizes = {}
#        maxsize = 0
#        for household in area.households:
#            has_old_people = False
#            house_size = 0
#            for person in household.people:
#                house_size += 1
#                if person.age >= 65:
#                    has_old_people = True
#            if has_old_people:
#                if house_size not in oldpeople_household_sizes:
#                    oldpeople_household_sizes[house_size] = 0
#                oldpeople_household_sizes[house_size] += 1
#            if house_size > maxsize:
#                maxsize = house_size
#
#        # only the three generation family can have more than 3 people in it
#        big_houses = 0
#        for size in oldpeople_household_sizes.keys():
#            if size > 3:
#                big_houses += 1
#        assert big_houses <= 1
#
#    def test__kids_live_in_families(self, example_area):
#        area = example_area
#        kids_household_sizes = {}
#        for household in area.households:
#            has_kids = False
#            has_adults = False
#            house_size = 0
#            for person in household.people:
#                house_size += 1
#                if person.age <= 17:
#                    has_kids = True
#                else:
#                    has_adults = True
#            if has_kids:
#                assert has_adults
#                if house_size not in kids_household_sizes:
#                    kids_household_sizes[house_size] = 0
#                kids_household_sizes[house_size] += 1
#        # only big family is the multigenerational one
#        for size in kids_household_sizes.keys():
#            assert size <= 8
#
#    def test__most_couples_are_heterosexual(self, example_area):
#        different_sex = 0
#        total = 0
#        for household in example_area.households:
#            if len(household.people) == 2:
#                if household.people[0].sex != household.people[1].sex:
#                    different_sex += 1
#                    total += 1
#                else:
#                    total += 1
#
#        assert different_sex / total > 0.65
#
#    def test__household_size_is_acceptable(self, example_area):
#        for household in example_area.households:
#            size = len(household.people)
#            assert size <= 8


# ##class TestSpecificArea2:
# ##    """
# ##    Let's carefully check the first output area of the test set.
# ##    This area has no carehomes so we don't have to account for them.
# ##    The area E00062386 has this configuration:
# ##    0 0 0 0 1               9
# ##    0 0 0 1 0              11
# ##    0 0 0 0 2              20
# ##    0 0 0 2 0              29
# ##    1 0 >=0 2 0             5
# ##    >=2 0 >=0 2 0          12
# ##    0 0 >=1 2 0            13
# ##    1 0 >=0 1 0             6
# ##    >=2 0 >=0 1 0           2
# ##    0 0 >=1 1 0             0
# ##    1 0 >=0 >=1 >=0         1
# ##    >=2 0 >=0 >=1 >=0       0
# ##    0 >=1 0 0 0             0
# ##    0 0 0 0 >=2             1
# ##    0 0 >=0 >=0 >=0         1
# ##    >=0 >=0 >=0 >=0 >=0     0
# ##    Name: E00062386, dtype: int64
# ##    """
# ##    @pytest.fixture(name="example_area2", scope="module")
# ##    def make_geo(self):
# ##        geo = Geography.from_file({"oa": ["E00062386"]})
# ##        dem = Demography.for_geography(geo)
# ##        geo.areas[0].populate(dem)
# ##        return geo.areas[0]
# ##
# ##    @pytest.fixture(name="hd_area2", scope="module")
# ##    def populate_example_area(self, example_area2):
# ##        area = example_area2
# ##        household_distributor = HouseholdDistributor.from_file()
# ##        household_distributor.distribute_people_and_households_to_areas(
# ##            [area],
# ##        )
# ##        return household_distributor
# ##
# ##    def test__households_of_size1(self, hd_area2, example_area2):
# ##        area = example_area2
# ##        households_one = 0
# ##        for household in area.households:
# ##            if len(household.people) == 1:
# ##                households_one += 1
# ##        assert households_one == 20


import os
from pathlib import Path

import numpy as np
import pytest

from june.world import generate_world_from_geography
from june.geography import Geography
from june.groups.school import Schools
from june.distributors.school_distributor import SchoolDistributor

# Path to the default school-distributor config, resolved relative to this
# test file (four directories up from here).
default_config_filename = (
    Path(os.path.abspath(__file__)).parent.parent.parent.parent
    / "configs/defaults/distributors/school_distributor.yaml"
)

# Inclusive age range in which school attendance is mandatory.
default_mandatory_age_range = (5, 18)


@pytest.fixture(name="geography_school", scope="module")
def create_geography():
    """Single super area shared by all the school tests."""
    return Geography.from_file({"super_area": ["E02004935"]})


@pytest.fixture(name="school_world", scope="module")
def make_and_populate_schools(geography_school):
    """World (without households) whose geography carries schools.

    NOTE(review): schools are built twice — kids are distributed into the
    first ``Schools`` object, which is then discarded when
    ``geography_school.schools`` is reassigned a fresh one; presumably
    ``generate_world_from_geography`` redoes the distribution — TODO confirm.
    """
    schools = Schools.for_geography(geography_school)
    school_distributor = SchoolDistributor(schools)
    school_distributor.distribute_kids_to_school(geography_school.areas)
    geography_school.schools = Schools.for_geography(geography_school)
    world = generate_world_from_geography(geography_school, include_households=False)
    return world


def test__years_mapping(school_world):
    """Each pupil subgroup (index > 0) holds only people whose age matches
    that subgroup's school year; subgroup 0 is skipped."""
    for school in school_world.schools:
        pupil_subgroups = (sg for sg in school.subgroups if sg.subgroup_type != 0)
        for subgroup in pupil_subgroups:
            expected_age = school.years[subgroup.subgroup_type - 1]
            for pupil in subgroup.people:
                assert pupil.age == expected_age


def test__all_kids_mandatory_school(school_world):
    """
    Check that all kids in mandatory school ages are assigned a school.

    A kid counts as lost unless their primary activity is a school AND they
    sit in a pupil subgroup (subgroup 0 holds teachers, as in the sibling
    tests).
    """
    KIDS_LOW = default_mandatory_age_range[0]
    KIDS_UP = default_mandatory_age_range[1]
    lost_kids = 0
    for area in school_world.areas.members:
        for person in area.people:
            if (person.age >= KIDS_LOW) and (person.age <= KIDS_UP):
                activity = person.primary_activity
                # Fixed operator-precedence bug: the original condition was
                # `A or B and C`, i.e. `A or (B and C)`, with C testing
                # `subgroup_type != 0` — so a kid attached to a school's
                # teacher subgroup, or to a non-school group in subgroup 0,
                # was never counted as lost. The correct negation of
                # "is enrolled as a pupil" is the disjunction below.
                if (
                    activity is None
                    or activity.group.spec != "school"
                    or activity.subgroup_type == 0
                ):
                    lost_kids += 1
    assert lost_kids == 0


def test__only_kids_school(school_world):
    """Nobody aged 20 or over is enrolled in a school as a pupil."""
    ADULTS_LOW = 20
    schooled_adults = 0
    for area in school_world.areas:
        adults = (p for p in area.people if p.age >= ADULTS_LOW)
        for person in adults:
            activity = person.primary_activity
            if activity is None:
                continue
            if activity.group.spec == "school" and activity.subgroup_type != 0:
                schooled_adults += 1

    assert schooled_adults == 0


def test__n_pupils_counter(school_world):
    """school.n_pupils agrees with the head count over pupil subgroups."""
    for school in school_world.schools.members:
        pupil_count = sum(
            len(subgroup.people)
            for subgroup in school.subgroups
            if subgroup.subgroup_type != 0
        )
        assert pupil_count == school.n_pupils


def test__age_range_schools(school_world):
    """Every pupil's age lies within the school's [age_min, age_max]."""
    n_outside_range = 0
    for school in school_world.schools.members:
        for person in school.people:
            if person.primary_activity.subgroup_type == 0:
                continue  # subgroup 0 is not subject to the pupil age limits
            if not (school.age_min <= person.age <= school.age_max):
                n_outside_range += 1
    assert n_outside_range == 0


def test__non_mandatory_dont_go_if_school_full(school_world):
    """Over-subscribed schools never admitted non-mandatory-age kids.

    For each school over capacity, inspect the overflow students (those past
    n_pupils_max when sorted by age) and check none are outside the
    mandatory age range.
    """
    age_low, age_high = default_mandatory_age_range
    non_mandatory_added = 0
    for school in school_world.schools.members:
        if school.n_pupils <= school.n_pupils_max:
            continue
        students_by_age = sorted(school.students, key=lambda s: s.age)
        overflow_ages = np.array(
            [s.age for s in students_by_age[int(school.n_pupils_max) :]]
        )
        too_old = np.sum(overflow_ages > age_high)
        too_young = np.sum(overflow_ages < age_low)
        if too_old > 0 or too_young > 0:
            non_mandatory_added += 1

    assert non_mandatory_added == 0


def test__teacher_distribution(school_world):
    """The pupil-to-teacher ratio stays below 40 in every school."""
    for school in school_world.schools:
        n_students = len(school.students)
        n_teachers = len(school.teachers.people)
        assert n_students / n_teachers < 40


def test__limit_classroom_sizes(school_world):
    """After capping class sizes: no classroom exceeds the cap, the year/age
    mapping still holds, and the pupil counter stays consistent."""
    distributor = SchoolDistributor(school_world.schools, max_classroom_size=3)
    distributor.limit_classroom_sizes()
    for school in school_world.schools:
        total_pupils = 0
        for subgroup in school.subgroups:
            if subgroup.subgroup_type == 0:
                continue  # subgroup 0 is not a classroom
            assert len(subgroup.people) <= distributor.max_classroom_size
            for pupil in subgroup.people:
                assert pupil.age == school.years[subgroup.subgroup_type - 1]
            total_pupils += len(subgroup.people)
        assert total_pupils == school.n_pupils


import numpy as np
from pytest import fixture
from random import random

from june.groups.leisure import SocialVenueDistributor
from june.groups.leisure import SocialVenue, SocialVenues
from june.utils.parse_probabilities import parse_age_probabilities
from june.demography import Person


@fixture(name="social_venues", scope="module")
def make_social_venues():
    """Ten venues: nine at (1, 2) and one far away at (10, 10)."""
    venues = [SocialVenue() for _ in range(10)]
    for venue in venues:
        venue.coordinates = np.array([1, 2])
    venues[-1].coordinates = np.array([10, 10])
    return SocialVenues(venues)


@fixture(name="sv_input", scope="module")
def make_input():
    """Raw distributor config: weekly visit counts and visit durations
    (hours) keyed by day type, sex, and age band. These numbers must match
    the expectations asserted in TestInput below."""
    times_per_week = {
        "weekday": {
            "male": {"18-65": 2, "65-100": 1},
            "female": {"18-65": 1, "65-100": 0.5},
        },
        "weekend": {"male": {"18-100": 3}, "female": {"18-100": 3}},
    }
    hours_per_day = {
        "weekday": {
            "male": {"18-65": 3, "65-100": 11},
            "female": {"18-65": 3, "65-100": 11},
        },
        "weekend": {"male": {"18-100": 12}, "female": {"18-100": 12}},
    }
    return times_per_week, hours_per_day


@fixture(name="social_venue_distributor", scope="module")
def make_distributor(social_venues, sv_input):
    """Distributor over the test venues with a generous 30-unit max distance."""
    times_per_week, hours_per_day = sv_input
    daytypes = {
        "weekday": ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"],
        "weekend": ["Saturday", "Sunday"],
    }
    return SocialVenueDistributor(
        social_venues,
        times_per_week=times_per_week,
        hours_per_day=hours_per_day,
        maximum_distance=30,
        daytypes=daytypes,
    )


class TestInput:
    """Checks that config parsing yields the expected per-age probabilities
    and Poisson parameters."""

    def test__age_dict_parsing(self):
        """Age-band strings expand to a flat per-age probability array."""
        age_dict = {"40-60": 0.4, "10-20": 0.2}
        probabilities_per_age = parse_age_probabilities(age_dict)
        for age, prob in enumerate(probabilities_per_age):
            if 10 <= age < 20:
                assert prob == 0.2
            elif 40 <= age < 60:
                assert prob == 0.4
            else:
                assert prob == 0.0

    def test__read_input_times_a_week(self, social_venue_distributor):
        """Poisson rate = visits/week / days-of-type * 24 / hours-per-visit,
        and is zero for minors."""
        params = social_venue_distributor.poisson_parameters
        # Minors never go out: every parameter is zero below 18.
        for age in range(0, 18):
            for day in ("weekday", "weekend"):
                for sex in ("m", "f"):
                    assert params[day][sex][age] == 0
        expectations = [
            (range(18, 65), "weekday", "m", 2 * 1 / 5 * 24 / 3),
            (range(18, 65), "weekday", "f", 1 * 1 / 5 * 24 / 3),
            (range(18, 65), "weekend", "m", 3 * 1 / 2 * 24 / 12),
            (range(18, 65), "weekend", "f", 3 * 1 / 2 * 24 / 12),
            (range(65, 100), "weekday", "m", 1 * 1 / 5 * 24 / 11),
            (range(65, 100), "weekday", "f", 0.5 * 1 / 5 * 24 / 11),
            (range(65, 100), "weekend", "m", 3 * 1 / 2 * 24 / 12),
            (range(65, 100), "weekend", "f", 3 * 1 / 2 * 24 / 12),
        ]
        for ages, day, sex, expected in expectations:
            for age in ages:
                assert np.isclose(params[day][sex][age], expected)


class TestProbabilities:
    def get_n_times_a_week(self, person, delta_time, day_type, distrib):
        if day_type == "weekday":
            max_time = 5
        else:
            max_time = 2
        times = []
        for _ in range(100):
            time = 0
            times_this_week = 0
            while time < max_time:
                time += 0.25
                probability = distrib.probability_to_go_to_social_venue(
                    person,
                    delta_time=delta_time,
                    day_type=day_type,
                    working_hours=False,
                )
                if random() < probability:
                    times_this_week += 1
            times.append(times_this_week)
        return np.mean(times)

    def test__decide_person_goes_to_social_venue(
        self, social_venue_distributor, sv_input
    ):
        times_per_week, hours_per_day = sv_input

        # young weekday #
        dt = 3 / 4 / 24  # in days
        person = Person(age=40, sex="m")
        times_per_week_weekday = times_per_week["weekday"]["male"]["18-65"]
        rest = self.get_n_times_a_week(person, dt, "weekday", social_venue_distributor)
        assert np.isclose(rest, times_per_week_weekday, atol=0, rtol=0.2)

        # young weekend #
        dt = 3 / 24
        times_per_week_weekend = times_per_week["weekend"]["male"]["18-100"]
        rest = self.get_n_times_a_week(person, dt, "weekend", social_venue_distributor)
        assert np.isclose(rest, times_per_week_weekend, atol=0, rtol=0.2)

        # retired weekday
        dt = 11 / 4 / 24
        person = Person(age=68, sex="f")
        times_per_week_weekday = times_per_week["weekday"]["female"]["65-100"]
        rest = self.get_n_times_a_week(person, dt, "weekday", social_venue_distributor)
        assert np.isclose(rest, times_per_week_weekday, atol=0, rtol=0.2)

        # retired weekend
        dt = 3 / 24
        times_per_week_weekend = times_per_week["weekend"]["female"]["18-100"]
        rest = self.get_n_times_a_week(person, dt, "weekend", social_venue_distributor)
        assert np.isclose(rest, times_per_week_weekend, atol=0, rtol=0.2)

    class MockArea:
        def __init__(self):
            self.coordinates = np.array([10, 11])


from june.distributors import UniversityDistributor
from june.groups import Universities
from june.geography import Geography
from june.world import generate_world_from_geography

import pytest


@pytest.fixture(name="world")
def create_world():
    """World over three Durham super areas, households included."""
    super_areas = ["E02004314", "E02004315", "E02004313"]
    geography = Geography.from_file({"super_area": super_areas})
    return generate_world_from_geography(geography, include_households=True)


def test__students_go_to_uni(world):
    """After distribution, the first university found (Durham) has many students."""
    universities = Universities.for_areas(world.areas)
    distributor = UniversityDistributor(universities)
    distributor.distribute_students_to_universities(
        areas=world.areas, people=world.people
    )
    # Durham is the first university in these areas
    assert universities[0].n_students > 6000


import unittest

import numpy as np
import pytest
import yaml

from june.demography import Demography, Population
from june.geography import Geography
from june import paths
from june.distributors import WorkerDistributor, load_workflow_df, load_sex_per_sector

# Default worker-distributor configuration shipped with the package.
default_config_file = (
    paths.configs_path / "defaults/distributors/worker_distributor.yaml"
)


@pytest.fixture(name="worker_config", scope="module")
def load_config():
    """Load the default worker-distributor YAML config once per module."""
    with open(default_config_file) as config_file:
        return yaml.load(config_file, Loader=yaml.FullLoader)


@pytest.fixture(name="worker_super_areas", scope="module")
def use_super_areas():
    """Super-area names used throughout these tests."""
    # original author's note referenced output area E00064524
    return ["E02002559", "E02002560", "E02002561", "E02002665"]


@pytest.fixture(name="worker_geography", scope="module")
def create_geography(worker_super_areas):
    """Geography restricted to the test super areas."""
    filter_key = {"super_area": worker_super_areas}
    return Geography.from_file(filter_key=filter_key)


@pytest.fixture(name="worker_demography", scope="module")
def create_demography(worker_geography):
    """Demography matching the test geography."""
    demography = Demography.for_geography(worker_geography)
    return demography


@pytest.fixture(name="worker_population", scope="module")
def create_population(worker_geography, worker_demography):
    """Populate every area, distribute workers, and return the population."""
    population = []
    for area in worker_geography.areas:
        area.populate(worker_demography)
        population.extend(area.people)
    WorkerDistributor.for_geography(worker_geography).distribute(
        areas=worker_geography.areas,
        super_areas=worker_geography.super_areas,
        population=population,
    )
    return population


def test__load_workflow_df(worker_super_areas):
    """Workflow ratios should sum to one per super area for both sexes."""
    workflow = load_workflow_df(area_names=worker_super_areas)
    expected_total = len(worker_super_areas)
    assert workflow["n_man"].sum() == expected_total
    assert workflow["n_woman"].sum() == expected_total


def test__load_sex_per_sector(worker_super_areas):
    """Per-sex sector ratios should be normalised for every super area."""
    sector_by_sex_df = load_sex_per_sector(area_names=worker_super_areas)
    all_columns = sector_by_sex_df.columns.values
    male_columns = [col for col in all_columns if "m " in col]
    female_columns = [col for col in all_columns if "f " in col]
    male_totals = sector_by_sex_df.loc[:, male_columns].sum(axis="columns").values
    female_totals = sector_by_sex_df.loc[:, female_columns].sum(axis="columns").values
    # each area's ratios are expected to sum to ~1, so the unique row sums
    # add up to one per super area
    assert np.sum(np.unique(male_totals)) == len(worker_super_areas)
    assert np.sum(np.unique(female_totals)) == len(worker_super_areas)


class TestInitialization:
    """Smoke tests: the worker distributor can be built both ways."""

    def test__distributor_from_file(self, worker_super_areas: list):
        """Construction from area names alone should not raise."""
        WorkerDistributor.from_file(area_names=worker_super_areas)

    def test__distributor_from_geography(
        self, worker_geography: Geography, worker_population: Population
    ):
        """Construction from a populated geography should not raise."""
        WorkerDistributor.for_geography(worker_geography)


class TestDistribution:
    """Checks on where and how the distributor places workers."""

    def test__workers_stay_in_geography(
        self,
        worker_config: dict,
        worker_super_areas: list,
        worker_population: Population,
    ):
        """Every working-age worker's workplace lies within the test geography."""
        case = unittest.TestCase()
        work_super_area_name = np.array(
            [
                person.work_super_area.name
                for person in worker_population
                if (
                    worker_config["age_range"][0]
                    <= person.age
                    <= worker_config["age_range"][1]
                )
                and person.work_super_area is not None
            ]
        )
        work_super_area_name = list(np.unique(work_super_area_name))
        # order-insensitive comparison of the two name lists
        case.assertCountEqual(work_super_area_name, worker_super_areas)

    def test__workers_that_stay_home(
        self, worker_config: dict, worker_population: Population
    ):
        """Between 5% and 7% of the population should work from home."""
        nr_working_from_home = len(
            [
                person.work_super_area
                for person in worker_population
                if (
                    worker_config["age_range"][0]
                    <= person.age
                    <= worker_config["age_range"][1]
                )
                and person.work_super_area is None
            ]
        )
        assert 0.050 < nr_working_from_home / len(worker_population) < 0.070

    def test__worker_nr_in_sector_larger_than_its_sub(
        self, worker_config: dict, worker_population: Population
    ):
        """Each sector must employ more male workers than its sub-sector."""
        occupations = np.array(
            [
                [person.sex, person.sector, person.sub_sector]
                for person in worker_population
                if person.sector in list(worker_config["sub_sector_ratio"].keys())
            ]
        ).T
        p_sex = occupations[0]
        p_sectors = occupations[1][p_sex == "m"]
        p_sub_sectors = occupations[2][p_sex == "m"]
        for sector in list(worker_config["sub_sector_ratio"].keys()):
            idx = np.where(p_sectors == sector)[0]
            sector_worker_nr = len(idx)
            p_sub_sector = p_sub_sectors[idx]
            # BUG FIX: the original used `p_sub_sector[p_sub_sector is not None]`.
            # `p_sub_sector is not None` is a scalar True, so it indexed the array
            # with one boolean (prepending an axis, len == 1) instead of filtering
            # element-wise. Count entries with a sub-sector assigned; np.array may
            # have stringified None to "None" depending on the resulting dtype.
            sub_sector_worker_nr = len(
                [sub for sub in p_sub_sector if sub is not None and sub != "None"]
            )
            assert sector_worker_nr > sub_sector_worker_nr

    def test__worker_super_area(self, worker_population, worker_geography):
        """Workers registered at a super area point back at that super area."""
        has_workers = False
        for super_area in worker_geography.super_areas:
            for worker in super_area.workers:
                has_workers = True
                assert worker.work_super_area == super_area
        assert has_workers


from june import paths
from june.epidemiology.infection.health_index.data_to_rates import (
    read_comorbidity_csv,
    convert_comorbidities_prevalence_to_dict,
)
import pytest


def test__parse_comorbidity_prevalence():
    """Comorbidity prevalence tables are normalised and parse to a nested dict."""
    male_filename = paths.data_path / "input/demography/uk_male_comorbidities.csv"
    female_filename = paths.data_path / "input/demography/uk_female_comorbidities.csv"
    prevalence_female = read_comorbidity_csv(female_filename)
    prevalence_male = read_comorbidity_csv(male_filename)
    # each age bin's prevalences must sum to 1 for both sexes
    for value in prevalence_female.sum(axis=1):
        assert value == pytest.approx(1.0)
    for value in prevalence_male.sum(axis=1):
        assert value == pytest.approx(1.0)

    prevalence_dict = convert_comorbidities_prevalence_to_dict(
        prevalence_female, prevalence_male
    )
    # spot-check two known entries (the original asserted the tuberculosis
    # value twice; the duplicate has been removed)
    assert prevalence_dict["sickle_cell"]["m"]["0-4"] == pytest.approx(
        3.92152e-05, rel=0.2
    )
    assert prevalence_dict["tuberculosis"]["f"]["4-9"] == pytest.approx(
        5.99818e-05, rel=0.2
    )


import numpy as np
import pytest
from june.epidemiology.infection.health_index.health_index import (
    HealthIndexGenerator,
    index_to_maximum_symptoms_tag,
)
from june.demography import Person, Population
from june.epidemiology.infection import Covid19, ImmunitySetter


@pytest.fixture(name="health_index", scope="module")
def make_hi():
    """Default health-index generator shared by the tests in this module."""
    generator = HealthIndexGenerator.from_file()
    return generator


class TestHealthIndex:
    """Sanity checks on the health-index probability tables."""

    def test__probabilities_positive_sum_to_one(self, health_index):
        """Outcome probabilities are non-negative and normalised for all ages."""
        for population in ("gp", "ch"):
            for sex in ("m", "f"):
                for age in range(100):
                    # care-home ("ch") tables only cover ages 50 and above
                    if population == "ch" and age < 50:
                        continue
                    probs = health_index.probabilities[population][sex][age]
                    assert (probs >= 0).all()
                    assert probs.sum() == pytest.approx(1, rel=1.0e-2)

    def test__physiological_age(self):
        """Life-expectancy rescaling maps real age onto a physiological age."""
        health_index = HealthIndexGenerator.from_file(
            m_exp_baseline=80,
            f_exp_baseline=90,
            m_exp=60,
            f_exp=80,
            cutoff_age=30,
        )
        assert health_index.use_physiological_age
        assert health_index.physiological_age(39, "m") == 45
        assert health_index.physiological_age(60, "f") == 66


class TestMultipliers:
    """Effective-multiplier handling in the health index."""

    @pytest.mark.parametrize("multiplier", [1.5, 0.5])
    def test__apply_large_multiplier(self, multiplier):
        """Rescaling severe outcomes redistributes mass onto the first two."""
        health_index = HealthIndexGenerator.from_file()
        probabilities = np.array([1.0 / 8] * 8)
        modified_probabilities = health_index.apply_effective_multiplier(
            probabilities=probabilities, effective_multiplier=multiplier
        )
        # the first two outcomes split the probability not taken by the six
        # rescaled (more severe) outcomes
        assert modified_probabilities[0] == (1 - 6.0 / 8.0 * multiplier) / 2.0
        assert modified_probabilities[1] == (1 - 6.0 / 8.0 * multiplier) / 2.0
        for i in range(2, 8):
            assert modified_probabilities[i] == 1.0 / 8.0 * multiplier

    def test__comorbidities_effect(self):
        """Comorbidity multipliers shift severity mass vs. a reference person."""
        comorbidity_multipliers = {"guapo": 0.8, "feo": 1.2, "no_condition": 1.0}
        dummy = Person.from_attributes(sex="f", age=60)
        dummy.infection = Covid19(None, None)
        feo = Person.from_attributes(sex="f", age=60, comorbidity="feo")
        feo.infection = Covid19(None, None)
        guapo = Person.from_attributes(sex="f", age=60, comorbidity="guapo")
        guapo.infection = Covid19(None, None)

        population = Population([])
        population.add(dummy)
        population.add(feo)
        population.add(guapo)

        prevalence_reference_population = {
            "feo": {
                "f": {"0-10": 0.2, "10-100": 0.4},
                "m": {"0-10": 0.6, "10-100": 0.5},
            },
            "guapo": {
                "f": {"0-10": 0.1, "10-100": 0.1},
                "m": {"0-10": 0.05, "10-100": 0.2},
            },
            "no_condition": {
                "f": {"0-10": 0.7, "10-100": 0.5},
                "m": {"0-10": 0.35, "10-100": 0.3},
            },
        }
        multiplier_setter = ImmunitySetter(
            multiplier_by_comorbidity=comorbidity_multipliers,
            comorbidity_prevalence_reference_population=prevalence_reference_population,
        )
        multiplier_setter.set_multipliers(population)

        health_index = HealthIndexGenerator.from_file()
        health_index.max_mild_symptom_tag = {
            value: key for key, value in index_to_maximum_symptoms_tag.items()
        }["severe"]
        dummy_health = health_index(dummy, dummy.infection.infection_id())
        feo_health = health_index(feo, feo.infection.infection_id())
        guapo_health = health_index(guapo, guapo.infection.infection_id())

        mean_multiplier_uk = multiplier_setter.get_multiplier_from_reference_prevalence(
            dummy.age, dummy.sex
        )

        # convert cumulative health indices back to per-outcome probabilities
        dummy_probabilities = np.diff(dummy_health, prepend=0.0, append=1.0)
        feo_probabilities = np.diff(feo_health, prepend=0.0, append=1.0)
        guapo_probabilities = np.diff(guapo_health, prepend=0.0, append=1.0)

        # FIX: the original repeated the guapo assertions twice verbatim;
        # the duplicate block was removed and both cases folded into one loop.
        for probabilities, comorbidity in (
            (feo_probabilities, "feo"),
            (guapo_probabilities, "guapo"),
        ):
            expected_severe = (
                comorbidity_multipliers[comorbidity]
                / mean_multiplier_uk
                * dummy_probabilities[2:].sum()
                / comorbidity_multipliers["no_condition"]
                * mean_multiplier_uk
            )
            np.testing.assert_allclose(probabilities[:2].sum(), 1 - expected_severe)
            np.testing.assert_allclose(probabilities[2:].sum(), expected_severe)


from june.epidemiology.infection import Immunity


class TestImmunity:
    """Behaviour of the Immunity susceptibility bookkeeping."""

    def test_immunity(self):
        immunity = Immunity({1: 0.3})
        assert immunity.susceptibility_dict[1] == 0.3
        # becoming immune zeroes the susceptibility for that infection id
        immunity.add_immunity([123])
        assert immunity.is_immune(123) is True
        assert immunity.susceptibility_dict[123] == 0.0


import pytest
import numpy as np

from june.epidemiology.infection import Covid19, B117, ImmunitySetter
from june.demography import Person, Population
from june.geography import Area, Areas, SuperArea, SuperAreas, Region, Regions
from june.groups import Household, Households
from june.records import Record, RecordReader
from june import World


@pytest.fixture(name="susceptibility_dict")
def make_susc():
    """Age-binned susceptibilities for the two infection variants."""
    c19_bins = {"0-13": 0.5, "13-100": 1.0}
    b117_bins = {"20-40": 0.25}
    return {
        Covid19.infection_id(): c19_bins,
        B117.infection_id(): b117_bins,
    }


class TestSusceptibilitySetter:
    """Parsing and application of age-binned susceptibilities."""

    def test__susceptibility_parser(self, susceptibility_dict):
        """Age bins expand to per-age values; uncovered ages default to 1.0."""
        susc_setter = ImmunitySetter(susceptibility_dict)
        susceptibilities_parsed = susc_setter.susceptibility_dict
        c19_id = Covid19.infection_id()
        b117_id = B117.infection_id()
        for i in range(0, 100):
            if i < 13:
                assert susceptibilities_parsed[c19_id][i] == 0.5
            else:
                assert susceptibilities_parsed[c19_id][i] == 1.0
            if i < 20:
                assert susceptibilities_parsed[b117_id][i] == 1.0
            elif i < 40:
                assert susceptibilities_parsed[b117_id][i] == 0.25
            else:
                assert susceptibilities_parsed[b117_id][i] == 1.0

    def test__susceptiblity_setter_avg(self, susceptibility_dict):
        """In default (average) mode everyone gets their bin's susceptibility."""
        population = Population([])
        for i in range(105):
            population.add(Person.from_attributes(age=i))

        susceptibility_setter = ImmunitySetter(susceptibility_dict)
        susceptibility_setter.set_susceptibilities(population)
        c19_id = Covid19.infection_id()
        b117_id = B117.infection_id()

        for person in population:
            if person.age < 13:
                assert person.immunity.get_susceptibility(c19_id) == 0.5
            else:
                assert person.immunity.get_susceptibility(c19_id) == 1.0
            if person.age < 20:
                assert person.immunity.get_susceptibility(b117_id) == 1.0
            elif person.age < 40:
                assert person.immunity.get_susceptibility(b117_id) == 0.25
            else:
                assert person.immunity.get_susceptibility(b117_id) == 1.0

    def test__susceptiblity_setter_individual(self, susceptibility_dict):
        """In 'individual' mode a matching fraction becomes fully immune."""
        population = Population([])
        for i in range(105):
            # FIX: unused loop variable renamed from `j` to `_`.
            # 10 people per age so per-bin fractions are statistically stable.
            for _ in range(10):
                population.add(Person.from_attributes(age=i))

        susceptibility_setter = ImmunitySetter(
            susceptibility_dict, susceptibility_mode="individual"
        )
        susceptibility_setter.set_susceptibilities(population)
        c19_id = Covid19.infection_id()
        b117_id = B117.infection_id()
        immune_c19_13 = 0
        immune_b117_13 = 0
        immune_40 = 0
        for person in population:
            if person.age < 13:
                if person.immunity.get_susceptibility(c19_id) == 0.0:
                    immune_c19_13 += 1
                if person.immunity.get_susceptibility(b117_id) == 0.0:
                    immune_b117_13 += 1
            if person.age < 20:
                assert person.immunity.get_susceptibility(b117_id) == 1.0
            elif person.age < 40:
                if person.immunity.get_susceptibility(b117_id) == 0.0:
                    immune_40 += 1
            else:
                assert person.immunity.get_susceptibility(b117_id) == 1.0
        aged_13 = len([person for person in population if person.age < 13])
        aged_40 = len([person for person in population if 20 <= person.age < 40])
        # a fraction (1 - susceptibility) of each bin should be fully immune
        assert np.isclose(immune_c19_13 / aged_13, 0.5, rtol=1e-1)
        assert immune_b117_13 == 0
        assert np.isclose(immune_40 / aged_40, 0.75, rtol=1e-1)


@pytest.fixture(name="multiplier_dict")
def make_multiplier():
    """Per-variant effective multipliers: baseline Covid19, 1.5x for B117."""
    multipliers = {
        Covid19.infection_id(): 1.0,
        B117.infection_id(): 1.5,
    }
    return multipliers


class TestMultiplierSetter:
    """Tests for ImmunitySetter's effective-multiplier handling.

    NOTE(review): the expected values below (0.8/1.3, 1.2/1.7, 1.0/1.5) encode
    how comorbidity multipliers combine with per-variant multipliers — confirm
    against ImmunitySetter.set_multipliers if these ever change.
    """

    def test__multiplier_variants_setter(self, multiplier_dict):
        """Plain variant multipliers are copied to every person unchanged."""
        population = Population([])
        for i in range(105):
            population.add(Person.from_attributes(age=i))

        multiplier_setter = ImmunitySetter(multiplier_dict=multiplier_dict)
        multiplier_setter.set_multipliers(population)
        c19_id = Covid19.infection_id()
        b117_id = B117.infection_id()

        for person in population:
            assert person.immunity.get_effective_multiplier(c19_id) == 1.0
            assert person.immunity.get_effective_multiplier(b117_id) == 1.5

    def test__mean_multiplier_reference(
        self,
    ):
        """Reference mean multiplier equals the prevalence-weighted average."""
        prevalence_reference_population = {
            "feo": {
                "f": {"0-10": 0.2, "10-100": 0.4},
                "m": {"0-10": 0.6, "10-100": 0.5},
            },
            "guapo": {
                "f": {"0-10": 0.1, "10-100": 0.1},
                "m": {"0-10": 0.05, "10-100": 0.2},
            },
            "no_condition": {
                "f": {"0-10": 0.7, "10-100": 0.5},
                "m": {"0-10": 0.35, "10-100": 0.3},
            },
        }
        comorbidity_multipliers = {"guapo": 0.8, "feo": 1.2, "no_condition": 1.0}
        multiplier_setter = ImmunitySetter(
            multiplier_by_comorbidity=comorbidity_multipliers,
            comorbidity_prevalence_reference_population=prevalence_reference_population,
        )
        dummy = Person.from_attributes(sex="f", age=40)
        # expected: sum over comorbidities of prevalence(sex, age bin) * multiplier
        mean_multiplier_uk = (
            prevalence_reference_population["feo"]["f"]["10-100"]
            * comorbidity_multipliers["feo"]
            + prevalence_reference_population["guapo"]["f"]["10-100"]
            * comorbidity_multipliers["guapo"]
            + prevalence_reference_population["no_condition"]["f"]["10-100"]
            * comorbidity_multipliers["no_condition"]
        )
        assert (
            multiplier_setter.get_multiplier_from_reference_prevalence(
                dummy.age, dummy.sex
            )
            == mean_multiplier_uk
        )

    def test__interaction_changes_multiplier(
        self,
    ):
        """With a comorbidity-free reference population, each person's variant
        multiplier is shifted by their own comorbidity multiplier."""
        c19_id = Covid19.infection_id()
        b117_id = B117.infection_id()
        comorbidity_multipliers = {"guapo": 0.8, "feo": 1.2, "no_condition": 1.0}
        population = Population([])
        for comorbidity in comorbidity_multipliers.keys():
            population.add(Person.from_attributes(age=40, comorbidity=comorbidity))
        for person in population:
            # before set_multipliers everyone is at the default multiplier
            assert person.immunity.get_effective_multiplier(c19_id) == 1.0
            assert person.immunity.get_effective_multiplier(b117_id) == 1.0
        comorbidity_prevalence_reference_population = {
            "guapo": {"f": {"0-100": 0.0}, "m": {"0-100": 0.0}},
            "feo": {"f": {"0-100": 0.0}, "m": {"0-100": 0.0}},
            "no_condition": {"m": {"0-100": 1.0}, "f": {"0-100": 1.0}},
        }

        multiplier_setter = ImmunitySetter(
            multiplier_by_comorbidity=comorbidity_multipliers,
            comorbidity_prevalence_reference_population=comorbidity_prevalence_reference_population,
        )
        multiplier_setter.set_multipliers(population)
        # population order follows comorbidity_multipliers insertion order:
        # guapo, feo, no_condition
        assert population[0].immunity.effective_multiplier_dict[c19_id] == 0.8
        assert population[0].immunity.effective_multiplier_dict[b117_id] == 1.3

        assert population[1].immunity.effective_multiplier_dict[c19_id] == 1.2
        assert population[1].immunity.effective_multiplier_dict[b117_id] == 1.7

        assert population[2].immunity.effective_multiplier_dict[c19_id] == 1.0
        assert population[2].immunity.effective_multiplier_dict[b117_id] == 1.5


class TestVaccinationSetter:
    """Tests for vaccination-driven immunity settings."""

    @pytest.fixture(name="vaccination_dict")
    def make_vacc(self):
        """Two vaccines with age-binned coverage and per-infection efficacies."""
        return {
            "pfizer": {
                "percentage_vaccinated": {"0-50": 0.7, "50-100": 1.0},
                "infections": {
                    Covid19.infection_id(): {
                        "sterilisation_efficacy": {"0-100": 0.5},
                        "symptomatic_efficacy": {"0-100": 0.5},
                    }
                },
            },
            "sputnik": {
                "percentage_vaccinated": {"0-30": 0.3, "30-100": 0.0},
                "infections": {
                    B117.infection_id(): {
                        "sterilisation_efficacy": {"0-100": 0.8},
                        "symptomatic_efficacy": {"0-100": 0.8},
                    }
                },
            },
        }

    def test__vaccination_parser(self, vaccination_dict):
        """Age bins in the vaccination dict expand to per-age values."""
        susc_setter = ImmunitySetter(vaccination_dict=vaccination_dict)
        vp = susc_setter.vaccination_dict
        for age in range(0, 100):
            # pfizer
            if age < 50:
                assert vp["pfizer"]["percentage_vaccinated"][age] == 0.7
            else:
                assert vp["pfizer"]["percentage_vaccinated"][age] == 1.0
            assert (
                vp["pfizer"]["infections"][Covid19.infection_id()][
                    "sterilisation_efficacy"
                ][age]
                == 0.5
            )
            assert (
                vp["pfizer"]["infections"][Covid19.infection_id()][
                    "symptomatic_efficacy"
                ][age]
                == 0.5
            )

            # sputnik
            if age < 30:
                assert vp["sputnik"]["percentage_vaccinated"][age] == 0.3
            else:
                assert vp["sputnik"]["percentage_vaccinated"][age] == 0.0
            assert (
                vp["sputnik"]["infections"][B117.infection_id()][
                    "sterilisation_efficacy"
                ][age]
                == 0.8
            )
            assert (
                vp["sputnik"]["infections"][B117.infection_id()][
                    "symptomatic_efficacy"
                ][age]
                == 0.8
            )

    def test__set_pre_vaccinations(self, vaccination_dict):
        """Coverage fractions and efficacies are reflected in the population."""
        population = Population([])
        for i in range(100):
            for _ in range(200):
                population.add(Person.from_attributes(age=i))
        immunity = ImmunitySetter(vaccination_dict=vaccination_dict)
        immunity.set_vaccinations(population)
        under50_pfizer = 0
        under50_pfizer_not = 0
        over50_pfizer = 0
        over50_pfizer_not = 0
        under30_sputnik = 0
        under30_sputnik_not = 0
        b117id = B117.infection_id()
        c19id = Covid19.infection_id()
        for person in population:
            if person.age < 30:
                if b117id in person.immunity.susceptibility_dict:
                    # sterilisation efficacy 0.8 -> residual susceptibility 0.2
                    assert np.isclose(person.immunity.get_susceptibility(b117id), 0.2)
                    under30_sputnik += 1
                if b117id in person.immunity.effective_multiplier_dict:
                    assert (
                        pytest.approx(person.immunity.get_effective_multiplier(b117id))
                        == 0.2
                    )
                    under30_sputnik_not += 1
            if person.age > 30:
                if b117id in person.immunity.susceptibility_dict:
                    assert False
                if b117id in person.immunity.effective_multiplier_dict:
                    assert False

            if person.age < 50:
                if c19id in person.immunity.susceptibility_dict:
                    assert person.immunity.get_susceptibility(c19id) == 0.5
                    under50_pfizer += 1
                if c19id in person.immunity.effective_multiplier_dict:
                    under50_pfizer_not += 1
                    assert person.immunity.get_effective_multiplier(c19id) == 0.5
            else:
                if c19id in person.immunity.susceptibility_dict:
                    assert person.immunity.get_susceptibility(c19id) == 0.5
                    over50_pfizer += 1
                if c19id in person.immunity.effective_multiplier_dict:
                    over50_pfizer_not += 1
                    assert person.immunity.get_effective_multiplier(c19id) == 0.5

        under30 = len([person for person in population if person.age < 30])
        under50 = len([person for person in population if person.age < 50])
        over50 = len([person for person in population if person.age >= 50])
        assert np.isclose(under30_sputnik / under30, 0.3, rtol=1e-1)
        assert np.isclose(under30_sputnik_not / under30, 0.3, rtol=1e-1)
        assert np.isclose(under50_pfizer / under50, 0.7, rtol=1e-1)
        assert np.isclose(under50_pfizer_not / under50, 0.7, rtol=1e-1)
        assert np.isclose(over50_pfizer / over50, 1, rtol=1e-1)
        assert np.isclose(over50_pfizer_not / over50, 1, rtol=1e-1)

    def test__set_save_vaccine_type_record(self):
        """The vaccine type given to each person is persisted in the record."""
        vaccination_dict = {
            "sputnik": {
                "percentage_vaccinated": {"0-30": 1.0, "30-100": 0.0},
                "infections": {
                    B117.infection_id(): {
                        "sterilisation_efficacy": {"0-100": 0.8},
                        "symptomatic_efficacy": {"0-100": 0.8},
                    }
                },
            }
        }
        record = Record(record_path="results/", record_static_data=True)
        population = Population([])
        world = World()
        world.people = population

        for i in range(100):
            for _ in range(200):
                population.add(Person.from_attributes(age=i))
        immunity = ImmunitySetter(vaccination_dict=vaccination_dict, record=record)
        immunity.set_vaccinations(population)
        record.static_data(world=world)

        record_reader = RecordReader()
        people_df = record_reader.table_to_df("population")
        # FIX: loop index renamed from `id` (shadowed the builtin and was
        # unused) to `_`.
        for _, row in people_df.iterrows():
            if row["age"] < 30:
                assert row["vaccine_type"] == "sputnik"
            else:
                assert row["vaccine_type"] == "none"


class TestPreviousInfectionSetter:
    """Tests for seeding prior-infection immunity, uniformly and clustered
    by household."""

    @pytest.fixture(name="world")
    def create_world(self):
        """Build a toy world: two regions (London, North East), one area each,
        4000 people in 1000 four-person households split evenly between the
        two areas."""
        households = []
        london_area = Area()
        ne_area = Area()
        areas = Areas(areas=[london_area, ne_area], ball_tree=False)
        london_super_area = SuperArea(areas=[london_area])
        ne_super_area = SuperArea(areas=[ne_area])
        london_area.super_area = london_super_area
        ne_area.super_area = ne_super_area
        super_areas = SuperAreas(
            super_areas=[london_super_area, ne_super_area], ball_tree=False
        )
        london = Region(name="London", super_areas=[london_super_area])
        ne = Region(name="North East", super_areas=[ne_super_area])
        london_super_area.region = london
        ne_super_area.region = ne
        regions = Regions(regions=[london, ne])
        # geography = Geography(areas=areas, super_areas=super_areas, regions=regions)
        world = World()
        world.areas = areas
        world.super_areas = super_areas
        world.regions = regions
        # ages cycle 0..99 so every age bin is populated
        people = [Person.from_attributes(age=i % 100) for i in range(4000)]
        world.people = Population(people)
        for i in range(1000):
            # alternate households between the two areas
            if i % 2 == 0:
                area = london_area
            else:
                area = ne_area
            h = Household(area=area)
            area.households.append(h)
            # four consecutive people per household
            for j in range(i * 4, 4 * (i + 1)):
                h.add(people[j])
                area.add(people[j])
            households.append(h)
        world.households = Households(households)
        return world

    @pytest.fixture(name="previous_infections_dict_uniform")
    def make_prev_inf_dict_uniform(self):
        """Uniform seeding config: per-region, age-binned prior-infection
        ratios and per-variant efficacies."""
        dd = {
            "distribution_method": "uniform",
            "infections": {
                Covid19.infection_id(): {
                    "sterilisation_efficacy": 0.7,
                    "symptomatic_efficacy": 0.6,
                },
                B117.infection_id(): {
                    "sterilisation_efficacy": 0.2,
                    "symptomatic_efficacy": 0.3,
                },
            },
            "ratios": {
                "London": {"0-40": 0.5, "40-100": 0.2},
                "North East": {"0-80": 0.3, "80-100": 0.8},
            },
        }
        return dd

    def test__setting_prev_infections_uniform(
        self, world, previous_infections_dict_uniform
    ):
        """Uniform seeding hits the configured per-region, per-age ratios."""
        previous_infections_dict = previous_infections_dict_uniform
        immunity = ImmunitySetter(previous_infections_dict=previous_infections_dict)
        immunity.set_previous_infections_uniform(world.people)
        # counts per region, keyed 1 = younger age bin, 2 = older age bin
        vaccinated = {"London": {1: 0, 2: 0}, "North East": {1: 0, 2: 0}}
        population = world.people
        for person in population:
            c19_susc = person.immunity.get_susceptibility(Covid19.infection_id())
            b117_susc = person.immunity.get_susceptibility(B117.infection_id())
            if c19_susc < 1.0:
                # sterilisation efficacy 0.7 -> residual susceptibility 0.3
                assert np.isclose(c19_susc, 0.3)
                if person.region.name == "London":
                    if person.age < 40:
                        vaccinated[person.region.name][1] += 1
                    else:
                        vaccinated[person.region.name][2] += 1
                else:
                    if person.age < 80:
                        vaccinated[person.region.name][1] += 1
                    else:
                        vaccinated[person.region.name][2] += 1
            if b117_susc < 1.0:
                # sterilisation efficacy 0.2 -> residual susceptibility 0.8
                assert b117_susc == 0.8

        people_london1 = len(
            [
                person
                for person in population
                if person.region.name == "London"
                if person.age < 40
            ]
        )
        people_london2 = len(
            [
                person
                for person in population
                if person.region.name == "London"
                if person.age >= 40
            ]
        )
        people_ne1 = len(
            [
                person
                for person in population
                if person.region.name == "North East"
                if person.age < 80
            ]
        )
        people_ne2 = len(
            [
                person
                for person in population
                if person.region.name == "North East"
                if person.age >= 80
            ]
        )
        # seeded fractions should match the configured ratios
        assert np.isclose(vaccinated["London"][1] / people_london1, 0.5, rtol=0.1)
        assert np.isclose(vaccinated["London"][2] / people_london2, 0.2, rtol=0.1)
        assert np.isclose(vaccinated["North East"][1] / people_ne1, 0.3, rtol=0.1)
        assert np.isclose(vaccinated["North East"][2] / people_ne2, 0.8, rtol=0.1)

    @pytest.fixture(name="previous_infections_dict_clustered")
    def make_prev_inf_dict_clustered(self):
        """Clustered seeding config: only London under-40s have a nonzero
        ratio, so whole households should be seeded together."""
        dd = {
            "distribution_method": "clustered",
            "infections": {
                Covid19.infection_id(): {
                    "sterilisation_efficacy": 0.7,
                    "symptomatic_efficacy": 0.6,
                },
                B117.infection_id(): {
                    "sterilisation_efficacy": 0.2,
                    "symptomatic_efficacy": 0.3,
                },
            },
            "ratios": {
                "London": {"0-40": 0.5, "40-100": 0.0},
                "North East": {"0-80": 0.0, "80-100": 0.0},
            },
        }
        return dd

    def test__setting_prev_infections_clustered(
        self, world, previous_infections_dict_clustered
    ):
        """Clustered seeding matches the overall ratio and fills whole
        households (mean ~4 seeded residents per touched household)."""
        previous_infections_dict = previous_infections_dict_clustered
        immunity = ImmunitySetter(previous_infections_dict=previous_infections_dict)
        immunity.set_previous_infections_clustered(world)
        # Test again for correct age.
        vaxed_london = len(
            [
                p
                for p in world.people
                if p.region.name == "London"
                and np.isclose(
                    p.immunity.get_susceptibility(Covid19.infection_id()), 0.3
                )
            ]
        )
        total_london = len([p for p in world.people if p.region.name == "London"])
        # 0.5 ratio applies to the under-40s, who are 40% of each region
        assert np.isclose(vaxed_london / total_london, 0.5 * 0.4)
        n_prev_infected_per_household = []
        for household in world.households:
            n = 0
            for person in household.residents:
                if np.isclose(
                    person.immunity.get_susceptibility(Covid19.infection_id()), 0.3
                ):
                    n += 1
            if n > 0:
                n_prev_infected_per_household.append(n)
        assert np.isclose(np.mean(n_prev_infected_per_household), 4, rtol=0.1)


import pytest
import statistics
import numpy as np
from pathlib import Path

from june import paths
import june.epidemiology.infection.symptoms
from june.demography import Person
from june.epidemiology.infection.infection_selector import (
    default_transmission_config_path,
)
from june.epidemiology.infection import (
    Infection,
    InfectionSelector,
    Covid19,
    B117,
    InfectionSelectors,
    transmission,
    SymptomTag,
)

path_pwd = Path(__file__)
dir_pwd = path_pwd.parent
constant_config = (
    dir_pwd.parent.parent.parent
    / "configs/defaults/epidemiology/infection/InfectionConstant.yaml"
)


class MockInfection(Infection):
    """Minimal Infection subclass; used below to check that infection ids differ per class."""

    pass


class MockHealthIndexGenerator:
    """Deterministic health-index stub.

    Returns an array of ones with a zero placed at the slot matching the
    requested symptom name; "asymptomatic" maps to -1, i.e. no zero is placed.
    """

    _SYMPTOM_TO_INDEX = {"asymptomatic": -1, "mild": 0, "severe": 1}

    def __init__(self, desired_symptoms):
        self.index = self._SYMPTOM_TO_INDEX[desired_symptoms]

    def __call__(self, person, infection_id):
        health_index = np.ones(3)
        if self.index >= 0:
            health_index[self.index] = 0
        return health_index


def make_selector(
    desired_symptoms, transmission_config_path=default_transmission_config_path
):
    """Build an InfectionSelector whose health index forces *desired_symptoms*."""
    return InfectionSelector(
        health_index_generator=MockHealthIndexGenerator(desired_symptoms),
        transmission_config_path=transmission_config_path,
    )


def infect_person(
    person,
    max_symptom_tag="mild",
    transmission_config_path=default_transmission_config_path,
):
    """Infect *person* at t=0 and sanity-check the resulting maximum symptom tag.

    Returns the (infection, selector) pair for further inspection by the caller.
    """
    selector = make_selector(
        max_symptom_tag, transmission_config_path=transmission_config_path
    )
    infection = selector._make_infection(person, 0.0)
    expected = {
        "asymptomatic": SymptomTag.asymptomatic,
        "mild": SymptomTag.mild,
        "severe": SymptomTag.severe,
    }
    if max_symptom_tag in expected:
        assert infection.max_tag == expected[max_symptom_tag]
    return infection, selector


class TestInfection:
    """Basic checks on Infection construction and time updates."""

    def test__infect_person__gives_them_symptoms_and_transmission(self):
        """Infecting a person attaches symptoms and a gamma transmission profile."""
        selector = InfectionSelector(
            health_index_generator=MockHealthIndexGenerator("severe")
        )
        person = Person.from_attributes(sex="f", age=26)
        selector.infect_person_at_time(person=person, time=0.2)

        infection = person.infection
        assert infection.start_time == 0.2
        assert isinstance(
            infection.symptoms, june.epidemiology.infection.symptoms.Symptoms
        )
        assert isinstance(infection.transmission, transmission.TransmissionGamma)

    def test__update_to_time__calls_transmission_symptoms_methods(
        self, transmission, symptoms
    ):
        """Updating an infection propagates the transmission probability."""
        infection = Infection(
            start_time=0.1, transmission=transmission, symptoms=symptoms
        )
        infection.update_symptoms_and_transmission(time=20.0)
        assert infection.infection_probability == transmission.probability


class TestInfectionSelector:
    """Tests for InfectionSelector configuration and sampled transmission profiles."""

    def test__defaults_when_no_filename_is_given(self):
        """With no config path, the selector uses the gamma transmission profile."""
        selector = InfectionSelector()
        assert selector.transmission_type == "gamma"

    def test__constant_filename(self):
        """Loading the constant-transmission config sets transmission_type to 'constant'."""
        selector = InfectionSelector(
            transmission_config_path=paths.configs_path
            / "defaults/epidemiology/infection/transmission/TransmissionConstant.yaml"
        )
        assert selector.transmission_type == "constant"

    def test__position_max_infectivity(self):
        """The numerically-sampled infectivity curve peaks at time_at_maximum_infectivity."""
        dummy = Person.from_attributes(sex="f", age=26)
        infection, _ = infect_person(person=dummy, max_symptom_tag="severe")
        true_max_t = infection.transmission.time_at_maximum_infectivity
        infectivity = []
        # Scan t in [0, 30] and locate the numerical maximum of the sampled curve.
        time_steps = np.linspace(0.0, 30.0, 500)
        for time in time_steps:
            infection.transmission.update_infection_probability(
                time_from_infection=time
            )
            infectivity.append(infection.transmission.probability)
        max_t = time_steps[np.argmax(np.array(infectivity))]
        assert max_t == pytest.approx(true_max_t, rel=0.01)

    def test__avg_peak_value(self):
        """Mean peak infectivity over 100 draws matches the reference profile's peak."""
        dummy = Person.from_attributes(sex="f", age=26)
        infection, selector = infect_person(
            person=dummy,
            max_symptom_tag="severe",
            transmission_config_path=paths.configs_path
            / "tests/transmission/test_transmission_constant.yaml",
        )
        # Reference profile built from the same config; its peak is the target value.
        avg_gamma = transmission.TransmissionGamma.from_file(
            config_path=paths.configs_path
            / "tests/transmission/test_transmission_constant.yaml"
        )
        avg_gamma.update_infection_probability(avg_gamma.time_at_maximum_infectivity)
        true_avg_peak_infectivity = avg_gamma.probability
        peak_infectivity = []
        for i in range(100):
            infection = selector._make_infection(time=0.1, person=dummy)
            max_t = infection.transmission.time_at_maximum_infectivity
            infection.transmission.update_infection_probability(
                time_from_infection=max_t
            )
            peak_infectivity.append(infection.transmission.probability)
        assert np.mean(peak_infectivity) == pytest.approx(
            true_avg_peak_infectivity, rel=0.05
        )

    def test__lognormal_in_maxprob(self):
        """With the lognormal config, sampled norms have mean ~1.13 and median ~1.0."""
        health_index_generator = MockHealthIndexGenerator("severe")
        selector = InfectionSelector(
            transmission_config_path=paths.configs_path
            / "tests/transmission/test_transmission_lognormal.yaml",
            health_index_generator=health_index_generator,
        )
        # Constant-config reference peak used to normalise the sampled peaks.
        avg_gamma = transmission.TransmissionGamma.from_file(
            config_path=paths.configs_path
            / "tests/transmission/test_transmission_constant.yaml"
        )

        avg_gamma.update_infection_probability(avg_gamma.time_at_maximum_infectivity)
        true_avg_peak_infectivity = avg_gamma.probability
        dummy = Person.from_attributes(sex="f", age=26)
        norms, maxprobs = [], []
        for i in range(1_000):
            infection = selector._make_infection(time=0.1, person=dummy)
            norms.append(infection.transmission.norm)
            max_t = infection.transmission.time_at_maximum_infectivity
            infection.transmission.update_infection_probability(
                time_from_infection=max_t
            )
            maxprobs.append(infection.transmission.probability)

        # Expected statistics of the sampled norm distribution (mean above median of 1).
        np.testing.assert_allclose(statistics.mean(norms), 1.13, rtol=0.05)
        np.testing.assert_allclose(statistics.median(norms), 1.00, rtol=0.05)
        np.testing.assert_allclose(
            statistics.median(maxprobs) / true_avg_peak_infectivity, 1.0, rtol=0.1
        )

    def test__infectivity_for_asymptomatic_carriers(self):
        """Asymptomatic carriers peak at ~0.3x the reference peak infectivity."""
        avg_gamma = transmission.TransmissionGamma.from_file(
            config_path=paths.configs_path
            / "tests/transmission/test_transmission_constant.yaml"
        )
        avg_gamma.update_infection_probability(avg_gamma.time_at_maximum_infectivity)
        true_avg_peak_infectivity = avg_gamma.probability

        dummy = Person.from_attributes(sex="f", age=26)
        infection, selector = infect_person(
            person=dummy,
            max_symptom_tag="asymptomatic",
            transmission_config_path=paths.configs_path
            / "tests/transmission/test_transmission_symptoms.yaml",
        )
        max_t = infection.transmission.time_at_maximum_infectivity
        infection.update_symptoms_and_transmission(max_t)
        max_prob = infection.transmission.probability
        np.testing.assert_allclose(max_prob / true_avg_peak_infectivity, 0.3, atol=0.1)

    def test__infectivity_for_mild_carriers(self):
        """Mild carriers peak at ~0.48x the reference peak infectivity."""
        avg_gamma = transmission.TransmissionGamma.from_file(
            config_path=paths.configs_path
            / "tests/transmission/test_transmission_constant.yaml"
        )
        avg_gamma.update_infection_probability(avg_gamma.time_at_maximum_infectivity)
        true_avg_peak_infectivity = avg_gamma.probability
        dummy = Person.from_attributes(sex="f", age=26)
        infection, selector = infect_person(
            person=dummy,
            max_symptom_tag="mild",
            transmission_config_path=paths.configs_path
            / "tests/transmission/test_transmission_symptoms.yaml",
        )
        max_t = infection.transmission.time_at_maximum_infectivity
        infection.update_symptoms_and_transmission(max_t)
        max_prob = infection.transmission.probability
        np.testing.assert_allclose(max_prob / true_avg_peak_infectivity, 0.48, atol=0.1)


class TestMultipleVirus:
    """Checks involving more than one Infection subclass / selector."""

    def test__infection_id_generation(self):
        """Infection ids are positive ints, stable per class, distinct across classes."""
        infection1 = Covid19(transmission=None, symptoms=None)
        infection11 = Covid19(transmission=None, symptoms=None)
        infection2 = MockInfection(transmission=None, symptoms=None)
        infection22 = MockInfection(transmission=None, symptoms=None)
        # isinstance is the idiomatic type check (was: type(...) == int).
        assert isinstance(infection1.infection_id(), int)
        assert infection1.infection_id() > 0
        assert infection1.infection_id() == infection11.infection_id()
        assert infection2.infection_id() == infection22.infection_id()
        assert infection1.infection_id() != infection2.infection_id()

    def test__multiple_virus(self):
        """Each selector produces its own infection class; selectors are keyed by id."""
        health_index_generator = MockHealthIndexGenerator("asymptomatic")
        selector1 = InfectionSelector(
            health_index_generator=health_index_generator,
            transmission_config_path=default_transmission_config_path,
        )
        person = Person.from_attributes(sex="f", age=26)
        infection = selector1._make_infection(person=person, time=0)
        assert isinstance(infection, Covid19)
        selector2 = InfectionSelector(
            infection_class=MockInfection,
            health_index_generator=health_index_generator,
            transmission_config_path=default_transmission_config_path,
        )
        infection = selector2._make_infection(person=person, time=0)
        assert isinstance(infection, MockInfection)
        infection_selectors = InfectionSelectors([selector1, selector2])
        assert set(infection_selectors.infection_id_to_selector.values()) == {
            selector1,
            selector2,
        }

    def test__immunity_multiple_virus(self):
        """A Covid19 infection grants immunity to both Covid19 and B117."""
        selector = InfectionSelector.from_file()
        person = Person.from_attributes()
        selector.infect_person_at_time(person, 0.0)
        assert person.immunity.is_immune(Covid19.infection_id())
        assert person.immunity.is_immune(B117.infection_id())
        assert person.infected


import pytest

from june import paths
from june.epidemiology.infection import SymptomTag, InfectionSelector
from june.demography.person import Person
from june.epidemiology.infection.trajectory_maker import (
    Stage,
    CompletionTime,
    ConstantCompletionTime,
    ExponentialCompletionTime,
    TrajectoryMaker,
    TrajectoryMakers,
    BetaCompletionTime,
)

health_index = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]


@pytest.fixture(name="constant_completion_dict")
def make_completion_time_dict():
    """Config dict for a constant completion time of 1.0."""
    return dict(type="constant", value=1.0)


@pytest.fixture(name="stage_dict")
def make_stage_dict(constant_completion_dict):
    """Stage config: 'healthy' tag with a constant completion time."""
    return dict(
        completion_time=constant_completion_dict,
        symptom_tag="healthy",
    )


@pytest.fixture(name="trajectory_dict")
def make_trajectory_dict(stage_dict):
    """Single-stage trajectory config tagged 'healthy'."""
    return dict(symptom_tag="healthy", stages=[stage_dict])


class TestCompletionTime:
    """Sanity checks for the CompletionTime implementations."""

    def test_constant_completion_time(self):
        assert ConstantCompletionTime(value=1.0)() == 1.0

    def test_exponential_completion_time(self):
        # An exponential shifted by loc=1 never samples below 1.
        sampled = ExponentialCompletionTime(loc=1.0, scale=1.0)()
        assert sampled >= 1.0

    def test_beta_completion_time(self):
        # Beta-distributed samples are confined to [0, 1].
        sampled = BetaCompletionTime(1.0, 1.0)()
        assert 0.0 <= sampled <= 1.0


class TestParse:
    """Parsing of symptom tags, completion times, stages and trajectories from dicts."""

    def test_symptoms_tag_for_string(self):
        assert SymptomTag.from_string("healthy") == SymptomTag.healthy
        # Unknown tag names are rejected via an assertion inside from_string.
        with pytest.raises(AssertionError):
            SymptomTag.from_string("nonsense")

    def test_parse_completion_time(self, constant_completion_dict):
        parsed_constant = CompletionTime.from_dict(constant_completion_dict)
        assert isinstance(parsed_constant, ConstantCompletionTime)

        parsed_exponential = CompletionTime.from_dict(
            {"type": "exponential", "loc": 1.0, "scale": 2.0}
        )
        assert isinstance(parsed_exponential, ExponentialCompletionTime)
        assert parsed_exponential.kwargs["loc"] == 1.0
        assert parsed_exponential.kwargs["scale"] == 2.0

    def test_parse_stage(self, stage_dict):
        stage = Stage.from_dict(stage_dict)
        assert isinstance(stage.completion_time, ConstantCompletionTime)
        assert stage.symptoms_tag == SymptomTag.healthy
        assert stage.completion_time.value == 1.0

    def test_parse_trajectory(self, trajectory_dict):
        trajectory = TrajectoryMaker.from_dict(trajectory_dict)
        assert trajectory.most_severe_symptoms == SymptomTag.healthy
        # Exactly one stage was configured.
        (only_stage,) = trajectory.stages
        assert only_stage.completion_time.value == 1.0

    def test_parse_trajectory_maker(self, trajectory_dict):
        maker = TrajectoryMakers.from_list([trajectory_dict])
        healthy_first_stage = maker.trajectories[SymptomTag.healthy].stages[0]
        assert healthy_first_stage.completion_time.value == 1.0


class TestTrajectoryMaker:
    """Checks on the default trajectory definitions."""

    def test__make__trajectories(self, trajectories):
        assert len(trajectories.trajectories) == 8

        mild = trajectories.trajectories[SymptomTag.mild]
        first_stage = mild.stages[0]
        assert first_stage.symptoms_tag == SymptomTag.exposed
        # Completion-time distribution parameters of the exposed stage.
        assert first_stage.completion_time.args[0] == 2.29
        assert first_stage.completion_time.args[1] == 19.05
        assert first_stage.completion_time.kwargs["scale"] == 39.8
        assert first_stage.completion_time.kwargs["loc"] == 0.39

        last_stage = mild.stages[-1]
        assert last_stage.symptoms_tag == SymptomTag.recovered
        assert last_stage.completion_time.value == 0.0

    def test_most_severe_symptoms(self, trajectories):
        # Each trajectory is keyed by its own most severe symptom tag.
        for tag, trajectory in trajectories.trajectories.items():
            assert trajectory.most_severe_symptoms == tag


class TestSymptoms:
    """Symptom-trajectory construction from max_severity, and progression over time."""

    def test__construct__trajectory__from__maxseverity(self, symptoms_trajectories):
        """Higher max_severity yields more severe trajectories; a trivial one has no onset."""
        # Severity 0.9: trajectory escalates through hospitalisation to death in ICU.
        symptoms_trajectories.max_severity = 0.9
        symptoms_trajectories.trajectory = (
            symptoms_trajectories._make_symptom_trajectory(health_index)
        )
        symptoms_trajectories.time_of_symptoms_onset = (
            symptoms_trajectories._compute_time_from_infection_to_symptoms()
        )
        assert symptoms_trajectories.trajectory == [
            (0.0, SymptomTag.exposed),
            (pytest.approx(3.4, rel=0.5), SymptomTag.mild),
            (pytest.approx(6.8, rel=0.5), SymptomTag.hospitalised),
            (pytest.approx(6.8, rel=0.5), SymptomTag.intensive_care),
            (pytest.approx(20, rel=0.5), SymptomTag.dead_icu),
        ]
        # Onset is the time of the first post-exposure stage.
        assert (
            symptoms_trajectories.time_of_symptoms_onset
            == symptoms_trajectories.trajectory[1][0]
        )
        assert symptoms_trajectories.time_of_symptoms_onset > 0
        # Severity 0.45: hospitalisation followed by recovery.
        symptoms_trajectories.max_severity = 0.45
        symptoms_trajectories.trajectory = (
            symptoms_trajectories._make_symptom_trajectory(health_index)
        )
        symptoms_trajectories.time_of_symptoms_onset = (
            symptoms_trajectories._compute_time_from_infection_to_symptoms()
        )
        assert symptoms_trajectories.trajectory == [
            (0.0, SymptomTag.exposed),
            (pytest.approx(10, rel=0.5), SymptomTag.mild),
            (pytest.approx(13, rel=0.5), SymptomTag.hospitalised),
            (pytest.approx(15, rel=0.5), SymptomTag.intensive_care),
            (pytest.approx(20, rel=0.5), SymptomTag.hospitalised),
            (pytest.approx(34, rel=0.5), SymptomTag.mild),
            (pytest.approx(40, rel=0.5), SymptomTag.recovered),
        ]
        assert (
            symptoms_trajectories.time_of_symptoms_onset
            == symptoms_trajectories.trajectory[1][0]
        )
        assert symptoms_trajectories.time_of_symptoms_onset > 0
        # Severity 0.05: no symptomatic stage, so there is no onset time.
        symptoms_trajectories.max_severity = 0.05
        symptoms_trajectories.trajectory = (
            symptoms_trajectories._make_symptom_trajectory(health_index)
        )
        symptoms_trajectories.time_of_symptoms_onset = (
            symptoms_trajectories._compute_time_from_infection_to_symptoms()
        )
        assert symptoms_trajectories.time_of_symptoms_onset is None

    def test__symptoms_progression(self, health_index_generator):
        """A fixed severity of 0.72 yields a hospitalised trajectory traversed in order."""
        selector = InfectionSelector(
            health_index_generator=health_index_generator,
            transmission_config_path=paths.configs_path
            / "defaults/epidemiology/infection/transmission/TransmissionConstant.yaml",
        )
        dummy = Person(sex="f", age=65)
        health_index = selector.health_index_generator(dummy, 0)
        fixed_severity = 0.72
        infection = selector._make_infection(person=dummy, time=0.1)
        # Override the random severity so the trajectory shape is deterministic.
        infection.symptoms.max_severity = fixed_severity
        infection.symptoms.trajectory = infection.symptoms._make_symptom_trajectory(
            health_index
        )
        max_tag = infection.symptoms.max_tag
        assert max_tag == SymptomTag.hospitalised
        assert infection.symptoms.trajectory == [
            (0.0, SymptomTag.exposed),
            (pytest.approx(5, 2.5), SymptomTag.mild),
            (pytest.approx(5, rel=5), SymptomTag.hospitalised),
            (pytest.approx(13, rel=5), SymptomTag.mild),
            (pytest.approx(30, rel=5), SymptomTag.recovered),
        ]
        hospitalised_time = infection.symptoms.trajectory[2][0]

        # Walk forward through time and check each stage is reported in order.
        infection.update_symptoms_and_transmission(float(1.0))
        assert infection.symptoms.tag == SymptomTag.exposed
        infection.update_symptoms_and_transmission(float(1.0))
        assert infection.symptoms.tag == SymptomTag.exposed
        infection.update_symptoms_and_transmission(float(6.0))
        assert infection.symptoms.tag == SymptomTag.mild
        infection.update_symptoms_and_transmission(hospitalised_time + 8.0)
        assert infection.symptoms.tag == SymptomTag.hospitalised
        infection.update_symptoms_and_transmission(float(40.0))
        assert infection.symptoms.tag == SymptomTag.mild
        infection.update_symptoms_and_transmission(float(50.0))
        assert infection.symptoms.tag == SymptomTag.recovered


from june.epidemiology.infection import transmission as trans
import scipy.stats
import numpy as np
import os
import pytest

directory = os.path.dirname(os.path.realpath(__file__))


class TestTransmission:
    """A constant transmission profile keeps its configured probability."""

    def test__update_probability_at_time(self):
        constant_transmission = trans.TransmissionConstant(probability=0.3)
        assert constant_transmission.probability == 0.3


class TestTransmissionGamma:
    """Checks on the gamma-shaped transmission profile and its pdf helpers."""

    def test__update_probability_at_time(self):
        """A profile scaled by max_infectiousness peaks at that multiple of the reference."""
        max_infectiousness = 4.0
        shift = 3.0
        shape = 3.0
        rate = 2.0
        scaled = trans.TransmissionGamma(
            max_infectiousness=max_infectiousness, shape=shape, rate=rate, shift=shift
        )
        # Gamma mode is at (shape - 1) / rate, offset by the shift.
        scaled.update_infection_probability((shape - 1) / rate + shift)

        reference = trans.TransmissionGamma(
            max_infectiousness=1.0, shape=shape, rate=rate, shift=shift
        )
        reference.update_infection_probability(reference.time_at_maximum_infectivity)

        ratio = scaled.probability / reference.probability
        assert ratio == pytest.approx(max_infectiousness, rel=0.01)

    @pytest.mark.parametrize("x", [0.0, 1, 3, 5])
    @pytest.mark.parametrize("a", [1, 3, 5])
    @pytest.mark.parametrize("loc", [0, -3, 3])
    @pytest.mark.parametrize("scale", [1, 3, 5])
    def test__gamma_pdf_implementation(self, x, a, loc, scale):
        """gamma_pdf agrees with scipy across a grid of parameters."""
        expected = scipy.stats.gamma(a=a, loc=loc, scale=scale).pdf(x)
        assert trans.gamma_pdf(x, a=a, loc=loc, scale=scale) == pytest.approx(
            expected, rel=0.001
        )

    def test__gamma_pdf_vectorized(self):
        """The vectorized pdf matches scipy over an array of sample points."""
        x = np.linspace(0.0, 10.0, 100)
        a = 1.0
        loc = 1.0
        scale = 1.0
        expected = scipy.stats.gamma(a=a, loc=loc, scale=scale).pdf(x)
        np.testing.assert_allclose(
            trans.gamma_pdf_vectorized(x, a=a, loc=loc, scale=scale), expected
        )


import pandas as pd
import numpy as np
from pytest import fixture

from june.epidemiology.infection_seed import CasesDistributor


class TestCasesDistributor:
    """Distribution of daily case counts to super areas, weighted by residents."""

    @fixture(name="super_area_region")
    def make_super_area_region(self):
        """Two super areas, both in the East of England."""
        data = [["a1", "East of England"], ["a2", "East of England"]]
        return pd.DataFrame(data=data, columns=["super_area", "region"])

    @fixture(name="residents_by_super_area")
    def make_area_super_area_region(self):
        """a2 has twice the residents of a1, so it should receive twice the cases."""
        data = [["a1", 100], ["a2", 200]]
        return pd.DataFrame(data=data, columns=["super_area", "n_residents"])

    @fixture(name="cases_per_region_per_day")
    def make_cases_per_region_per_day(self):
        """600 then 1200 regional cases on two consecutive days."""
        index = ["2020-03-01", "2020-03-02"]
        ret = pd.DataFrame(index=index)
        ret["East of England"] = [600, 1200]
        return ret

    @staticmethod
    def _assert_population_weighted(cases_per_super_area):
        # 600/1200 daily cases split 1:2 between a1 (100 residents) and a2 (200).
        # Extracted to remove the assertion block duplicated in both tests below.
        assert np.allclose(
            cases_per_super_area.loc[:, "a1"].values,
            np.array([200, 400], dtype=np.float64),
            rtol=0.25,
        )
        assert np.allclose(
            cases_per_super_area.loc[:, "a2"].values,
            np.array([400, 800], dtype=np.float64),
            rtol=0.25,
        )

    def test__from_regional_cases(
        self, super_area_region, residents_by_super_area, cases_per_region_per_day
    ):
        """Regional case counts are shared among super areas by resident count."""
        cd = CasesDistributor.from_regional_cases(
            cases_per_day_region=cases_per_region_per_day,
            super_area_to_region=super_area_region,
            residents_per_super_area=residents_by_super_area,
        )
        self._assert_population_weighted(cd.cases_per_super_area)

    def test__from_national_cases(
        self, super_area_region, residents_by_super_area, cases_per_region_per_day
    ):
        """National case counts flow through regions down to super areas."""
        index = ["2020-03-01", "2020-03-02"]
        cases_per_day = pd.DataFrame(index=index)
        cases_per_day["N_cases"] = [600, 1200]

        cd = CasesDistributor.from_national_cases(
            cases_per_day=cases_per_day,
            super_area_to_region=super_area_region,
            residents_per_super_area=residents_by_super_area,
        )
        self._assert_population_weighted(cd.cases_per_super_area)


import pytest
import numpy as np
import pandas as pd
from pathlib import Path

from june.epidemiology.infection_seed.clustered_infection_seed import (
    ClusteredInfectionSeed,
)
from june.demography import Person, Population
from june.geography import Area, SuperArea, Areas, SuperAreas, Region, Regions
from june.world import World
from june.groups import Household, Households


@pytest.fixture(name="world")
def create_world():
    """Toy world: one London area with 5000 people in 1000 five-person households."""
    area = Area()
    super_area = SuperArea(areas=[area])
    region = Region(name="London", super_areas=[super_area])

    world = World()
    world.areas = Areas(areas=[area], ball_tree=False)
    world.super_areas = SuperAreas(super_areas=[super_area], ball_tree=False)
    world.regions = Regions(regions=[region])

    # Ages cycle 0..99 so every age is represented 50 times.
    people = [Person.from_attributes(age=i % 100) for i in range(5000)]
    area.people = people

    households = []
    for start in range(0, 5000, 5):
        household = Household(area=area)
        area.households.append(household)
        for person in people[start : start + 5]:
            household.add(person)
        households.append(household)

    world.households = Households(households)
    world.people = Population(people)
    return world


@pytest.fixture(name="cases")
def make_cases():
    """
    Seed one day, 50% of those aged 0-50, and 20% aged 50-100 in London
    """
    csv_path = Path(__file__).parent / "cases_per_region.csv"
    return pd.read_csv(csv_path, index_col=[0, 1])


def test__world(world):
    """The toy world has the advertised population and household structure."""
    assert len(world.people) == 5000
    assert len(world.households) == 1000
    assert all(len(household.residents) == 5 for household in world.households)
    assert all(person.residence is not None for person in world.people)


@pytest.fixture(name="cis")
def create_seed(world, selector, cases):
    """Clustered infection seed over the toy world, seeding past infections too."""
    return ClusteredInfectionSeed(
        world=world,
        infection_selector=selector,
        daily_cases_per_capita_per_age_per_region=cases,
        seed_past_infections=True,
    )


class TestInfectOneHousehold:
    """Clustered seeding: infection totals, household scores, per-super-area seeding."""

    def test__get_people_to_infect(self, cis, world):
        """Total to infect equals the per-age rates applied to the population."""
        cases = cis.daily_cases_per_capita_per_age_per_region.loc[
            "2021-06-26", "London"
        ]
        total_to_infect = cis.get_total_people_to_infect(
            people=world.people, cases_per_capita_per_age=cases
        )
        n_under_50 = sum(1 for person in world.people if person.age < 50)
        n_over_50 = sum(1 for person in world.people if person.age >= 50)
        # 50% of the under-50s plus 20% of the rest (see the cases fixture).
        assert np.isclose(total_to_infect, n_under_50 * 0.5 + n_over_50 * 0.2)

    def test__get_household_score(self, cis):
        """Expected score: sum of residents' age weights over sqrt(household size)."""
        household = Household()
        age_distribution = pd.Series(
            index=[0, 1, 2, 3, 4], data=[0.1, 0.2, 0, 0.1, 0.3]
        )
        for age in range(3):
            household.add(Person.from_attributes(age=age))
        score = cis.get_household_score(
            household=household, age_distribution=age_distribution
        )
        assert np.isclose(score, 0.3 / np.sqrt(3), rtol=1e-2)

        household.add(Person.from_attributes(age=4))
        score = cis.get_household_score(
            household=household, age_distribution=age_distribution
        )
        assert np.isclose(score, 0.6 / np.sqrt(4), rtol=1e-2)

    def test__infect_super_area(self, cis, world):
        """Seeding hits the target per-age fractions and fills whole households."""
        super_area = world.super_areas[0]
        cases_per_capita_per_age = cis.daily_cases_per_capita_per_age_per_region.loc[
            "2021-06-26"
        ]
        cis.infect_super_area(
            super_area=super_area,
            time=0,
            cases_per_capita_per_age=cases_per_capita_per_age,
        )

        under_50 = [person for person in world.people if person.age < 50]
        over_50 = [person for person in world.people if person.age >= 50]
        infected_under_50 = sum(1 for person in under_50 if person.infected)
        infected_over_50 = sum(1 for person in over_50 if person.infected)
        assert np.isclose(infected_under_50 / len(under_50), 0.5, rtol=0.10)
        assert np.isclose(infected_over_50 / len(over_50), 0.2, rtol=0.10)

        # Infections should cluster into whole households (5 residents each).
        cluster_sizes = [
            sum(1 for person in household.residents if person.infected)
            for household in world.households
        ]
        occupied_clusters = [n for n in cluster_sizes if n > 0]
        assert np.isclose(np.mean(occupied_clusters), 5, rtol=0.1)


import pytest
import numpy as np
import pandas as pd
from pathlib import Path

from june.epidemiology.infection_seed.exact_num_infection_seed import (
    ExactNumClusteredInfectionSeed,
    ExactNumInfectionSeed,
)
from june.demography import Person, Population
from june.geography import Area, SuperArea, Areas, SuperAreas, Region, Regions
from june.world import World
from june.groups import Household, Households, Cemeteries


@pytest.fixture(name="world")
def create_world():
    """Two-region toy world: 5000 people each in London and North East (Durham)."""
    area_1 = Area(name="area_1", super_area=None, coordinates=None)
    area_2 = Area(name="area_2", super_area=None, coordinates=None)
    super_area_1 = SuperArea("super_1", areas=[area_1], coordinates=(1.0, 1.0))
    super_area_2 = SuperArea("Durham", areas=[area_2], coordinates=(1.0, 2.0))
    region1 = Region(name="London", super_areas=[super_area_1])
    region2 = Region(name="North East", super_areas=[super_area_2])
    # Wire the area/super-area/region hierarchy both ways.
    super_area_1.region = region1
    super_area_2.region = region2
    area_1.super_area = super_area_1
    area_2.super_area = super_area_2

    people = [Person.from_attributes(age=i % 100) for i in range(10000)]
    half = len(people) // 2
    area_1.people = people[:half]
    area_2.people = people[half:]

    households = []

    def fill_households(area, residents):
        # Group residents into consecutive 5-person households attached to area.
        for start in range(0, len(residents), 5):
            household = Household(area=area)
            area.households.append(household)
            for person in residents[start : start + 5]:
                household.add(person)
                person.area = area
            households.append(household)

    fill_households(area_1, people[:half])
    fill_households(area_2, people[half:])

    world = World()
    world.households = Households(households)
    world.people = Population(people)
    world.areas = Areas(areas=[area_1, area_2], ball_tree=False)
    world.super_areas = SuperAreas([super_area_1, super_area_2])
    world.regions = Regions([region1, region2])
    world.cemeteries = Cemeteries()
    return world


def test__world(world):
    """The two-region world has 10000 people in 2000 full five-person households."""
    assert len(world.people) == 10000
    assert len(world.households) == 2000
    assert all(len(household.residents) == 5 for household in world.households)
    assert all(person.residence is not None for person in world.people)


@pytest.fixture(name="cases")
def make_cases():
    """
    Seed two days, 40 of those aged 0-50, and 15 aged 50-100 in London per day
    """
    csv_path = Path(__file__).parent / "exact_num_cases_per_region.csv"
    return pd.read_csv(csv_path, index_col=[0, 1])


@pytest.fixture(name="cis")
def create_seed(world, selector, cases):
    """Exact-number clustered infection seed over the two-region world."""
    return ExactNumClusteredInfectionSeed(
        world=world,
        infection_selector=selector,
        daily_cases_per_capita_per_age_per_region=cases,
        seed_past_infections=True,
    )


class TestExactNumInfectOneHousehold:
    def test__get_household_score(self, cis):
        """Household score = sum of residents' age-bin rates / sqrt(household size)."""
        household = Household()
        age_distribution = pd.Series(index=["0-50", "50-100"], data=[0.1, 0.3])
        for age in (20, 50):
            household.add(Person.from_attributes(age=age))
        # one resident per bin -> (0.1 + 0.3) / sqrt(2)
        score = cis.get_household_score(
            household=household, age_distribution=age_distribution
        )
        assert np.isclose(score, 0.4 / np.sqrt(2), rtol=1e-2)
        # a third resident in the 0-50 bin -> (0.1 + 0.3 + 0.1) / sqrt(3)
        household.add(Person.from_attributes(age=40))
        score = cis.get_household_score(
            household=household, age_distribution=age_distribution
        )
        assert np.isclose(score, 0.5 / np.sqrt(3), rtol=1e-2)

    def test__infect_super_area(self, cis, world):
        """Seeding Durham infects exactly the requested per-age numbers there,
        nobody in London, and clusters infections into whole households."""
        date = "2020-03-01"
        super_area = world.super_areas[1]
        cases_per_capita_per_age = cis.daily_cases_per_capita_per_age_per_region.loc[
            date, "Durham"
        ]
        cis.infect_super_area(
            super_area=super_area,
            time=0,
            cases_per_capita_per_age=cases_per_capita_per_age,
        )

        def count_infected(age_ok, place_ok):
            # infected people matching both an age and a location predicate
            return sum(
                1
                for person in world.people
                if person.infected and age_ok(person) and place_ok(person)
            )

        young = lambda p: p.age < 50
        old = lambda p: p.age >= 50
        in_durham = lambda p: p.super_area.name == "Durham"
        in_london = lambda p: p.region.name == "London"

        assert count_infected(young, in_durham) == cases_per_capita_per_age["0-50"]
        assert count_infected(old, in_durham) == cases_per_capita_per_age["50-100"]
        assert count_infected(young, in_london) == 0
        assert count_infected(old, in_london) == 0

        # clustering: seeded infections should nearly fill whole 5-person households
        infected_per_household = []
        for household in world.households:
            n_infected = sum(1 for person in household.residents if person.infected)
            if n_infected > 0:
                infected_per_household.append(n_infected)
        assert np.isclose(np.mean(infected_per_household), 5, rtol=0.1)


import pytest
import numpy as np
import pandas as pd
from pathlib import Path

from june.epidemiology.infection_seed.exact_num_infection_seed import (
    ExactNumClusteredInfectionSeed,
    ExactNumInfectionSeed,
)
from june.demography import Person, Population
from june.geography import Area, SuperArea, Areas, SuperAreas, Region, Regions
from june.world import World
from june.groups import Household, Households, Cemeteries


@pytest.fixture(name="world")
def create_world():
    """World with two regions (London/super_1 and North East/Durham) of 5000
    people each, grouped sequentially into 2000 five-person households."""
    area_1 = Area(name="area_1", super_area=None, coordinates=None)
    area_2 = Area(name="area_2", super_area=None, coordinates=None)
    super_area_1 = SuperArea("super_1", areas=[area_1], coordinates=(1.0, 1.0))
    super_area_2 = SuperArea("Durham", areas=[area_2], coordinates=(1.0, 2.0))
    region1 = Region(name="London", super_areas=[super_area_1])
    region2 = Region(name="North East", super_areas=[super_area_2])
    super_area_1.region = region1
    super_area_2.region = region2
    area_1.super_area = super_area_1
    area_2.super_area = super_area_2

    # ages cycle 0..99 so each area has a uniform age profile
    people = [Person.from_attributes(age=i % 100) for i in range(10000)]
    half = len(people) // 2
    area_1.people = people[:half]
    area_2.people = people[half:]
    households = []
    for household_idx in range(2000):
        # first thousand households in area_1, the rest in area_2
        area = area_1 if household_idx < 1000 else area_2
        household = Household(area=area)
        area.households.append(household)
        for person in people[household_idx * 5 : (household_idx + 1) * 5]:
            household.add(person)
            person.area = area
        households.append(household)

    world = World()
    world.households = Households(households)
    world.people = Population(people)
    world.areas = Areas(areas=[area_1, area_2], ball_tree=False)
    world.super_areas = SuperAreas([super_area_1, super_area_2])
    world.regions = Regions([region1, region2])
    world.cemeteries = Cemeteries()
    return world


def test__world(world):
    """Check fixture invariants: 10k people in 2000 full five-person households."""
    assert len(world.people) == 10000
    assert len(world.households) == 2000
    # household construction must have filled each with five residents
    assert all(len(h.residents) == 5 for h in world.households)
    # and assigned every person a residence
    assert all(p.residence is not None for p in world.people)


@pytest.fixture(name="cases")
def make_cases():
    """
    Seed two days, 40 of those aged 0-50, and 15 aged 50-100 in London per day
    """
    csv_path = Path(__file__).parent / "exact_num_cases_world.csv"
    # rows are indexed by (date, region)
    return pd.read_csv(csv_path, index_col=[0, 1])


@pytest.fixture(name="cis")
def create_seed(world, selector, cases):
    """Exact-number (non-clustered) infection seed for the test world."""
    return ExactNumInfectionSeed(
        world=world,
        infection_selector=selector,
        daily_cases_per_capita_per_age_per_region=cases,
        seed_past_infections=True,
    )


class TestExactNumInfectSeed:
    def test__infect_super_areas(self, cis, world):
        """Seeding all super areas infects exactly the per-age totals in the table."""
        date = "2020-03-01"
        daily_cases = cis.daily_cases_per_capita_per_age_per_region.loc[date]
        cis.infect_super_areas(
            time=0,
            date=date,
            cases_per_capita_per_age_per_region=daily_cases,
        )
        infected_young = sum(
            1 for person in world.people if person.age < 50 and person.infected
        )
        assert infected_young == daily_cases["0-50"].sum()
        infected_old = sum(
            1 for person in world.people if person.age >= 50 and person.infected
        )
        assert infected_old == daily_cases["50-100"].sum()


import pandas as pd
import pytest
import numpy as np
from june.geography import SuperArea, SuperAreas, Area, Region, Regions
from june.demography import Person, Population
from june import World
from june.epidemiology.infection_seed import InfectionSeed
from june.epidemiology.infection import Immunity
from pathlib import Path
from june.time import Timer
from june.groups import Cemeteries, Household

# Location of this test file, used to resolve repo-relative config paths.
path_pwd = Path(__file__)
dir_pwd = path_pwd.parent
# Constant-transmission config shipped with the repository (three levels up).
constant_config = (
    dir_pwd.parent.parent.parent
    / "configs/defaults/transmission/TransmissionConstant.yaml"
)


@pytest.fixture(name="world", scope="module")
def create_world():
    """Module-scoped world: 10k random-aged women in one household, split
    half-and-half across two super areas in two regions."""
    household = Household()
    people = [
        Person.from_attributes(age=np.random.randint(0, 100), sex="f")
        for _ in range(10000)
    ]
    for person in people:
        household.add(person)
    world = World()
    world.people = Population(people)
    half = len(people) // 2
    area_1 = Area(name="area_1", super_area=None, coordinates=None)
    area_1.people = people[:half]
    area_2 = Area(name="area_2", super_area=None, coordinates=None)
    area_2.people = people[half:]
    super_area_1 = SuperArea("super_1", areas=[area_1], coordinates=(1.0, 1.0))
    super_area_2 = SuperArea("super_2", areas=[area_2], coordinates=(1.0, 2.0))
    region1 = Region(name="London", super_areas=[super_area_1])
    region2 = Region(name="North East", super_areas=[super_area_2])
    super_area_1.region = region1
    super_area_2.region = region2
    area_1.super_area = super_area_1
    area_2.super_area = super_area_2
    world.super_areas = SuperAreas([super_area_1, super_area_2])
    world.regions = Regions([region1, region2])
    world.cemeteries = Cemeteries()
    return world


def clean_world(world):
    """Strip all infections and reset immunity for everyone in *world*."""
    for individual in world.people:
        individual.infection = None
        individual.immunity = Immunity()


def test__simplest_seed(world, selector):
    """Uniform seeding at 10% per capita infects roughly 10% of the population."""
    clean_world(world)
    date = "2020-03-01"
    seed = InfectionSeed.from_uniform_cases(
        world=world,
        infection_selector=selector,
        cases_per_capita=0.1,
        date=date,
        seed_past_infections=False,
    )
    seed.unleash_virus_per_day(date=pd.to_datetime(date), time=0.0, record=None)
    n_infected = sum(1 for person in world.people if person.infected)
    assert np.isclose(n_infected, 0.1 * len(world.people), rtol=1e-1)


def test__seed_strength(world, selector):
    """seed_strength multiplies the per-capita cases: 0.05 * 10 -> 50% infected."""
    clean_world(world)
    date = "2020-03-01"
    seed = InfectionSeed.from_uniform_cases(
        world=world,
        infection_selector=selector,
        cases_per_capita=0.05,
        date=date,
        seed_strength=10,
        seed_past_infections=False,
    )
    seed.unleash_virus_per_day(date=pd.to_datetime(date), time=0.0, record=None)
    n_infected = sum(1 for person in world.people if person.infected)
    assert np.isclose(n_infected, 0.5 * len(world.people), rtol=1e-1)


def test__infection_per_day(world, selector):
    """Daily seeding tracks the cases table: each region gains its per-day
    fraction once per day and nothing after the table runs out.

    Fixes a copy-paste bug in the original: ``n_sa2`` was computed from
    ``world.super_areas[0]`` instead of ``[1]`` (harmless only because both
    halves happen to be the same size). The five repeated assert stanzas are
    folded into a single data-driven loop.
    """
    clean_world(world)
    cases_per_region_df = pd.DataFrame(
        {
            "date": ["2020-04-20", "2020-04-21"],
            "London": [0.2, 0.1],
            "North East": [0.3, 0.4],
        }
    )
    cases_per_region_df.set_index("date", inplace=True)
    cases_per_region_df.index = pd.to_datetime(cases_per_region_df.index)
    seed = InfectionSeed.from_global_age_profile(
        world=world,
        infection_selector=selector,
        daily_cases_per_region=cases_per_region_df,
        seed_past_infections=False,
    )
    assert seed.min_date.strftime("%Y-%m-%d") == "2020-04-20"
    assert seed.max_date.strftime("%Y-%m-%d") == "2020-04-21"
    timer = Timer(initial_day="2020-04-20", total_days=7)
    n_sa1 = len(world.super_areas[0].people)
    # BUG FIX: was len(world.super_areas[0].people)
    n_sa2 = len(world.super_areas[1].people)

    def infected_fraction(super_area_idx, n_people):
        # fraction of a super area currently infected
        infected = [
            person
            for person in world.super_areas[super_area_idx].people
            if person.infected
        ]
        return len(infected) / n_people

    # cumulative expected fractions (London, North East) after each daily seeding;
    # the same day may be seeded only once, and days past the table add nothing
    expected_per_day = [
        (0.2, 0.3),
        (0.2, 0.3),
        (0.2 + 0.1, 0.3 + 0.4),
        (0.2 + 0.1, 0.3 + 0.4),
        (0.2 + 0.1, 0.3 + 0.4),
    ]
    for expected_sa1, expected_sa2 in expected_per_day:
        seed.unleash_virus_per_day(timer.date, time=0)
        next(timer)
        assert np.isclose(infected_fraction(0, n_sa1), expected_sa1, rtol=1e-1)
        assert np.isclose(infected_fraction(1, n_sa2), expected_sa2, rtol=1e-1)


def test__age_profile(world, selector):
    """An age profile concentrated on 10-39 infects only that band."""
    clean_world(world)
    cases_per_region_df = pd.DataFrame(
        {
            "date": ["2020-04-20", "2020-04-21"],
            "London": [0.2, 0.1],
            "North East": [0.3, 0.4],
        }
    )
    cases_per_region_df.set_index("date", inplace=True)
    cases_per_region_df.index = pd.to_datetime(cases_per_region_df.index)
    seed = InfectionSeed.from_global_age_profile(
        world=world,
        infection_selector=selector,
        daily_cases_per_region=cases_per_region_df,
        age_profile={"0-9": 0.0, "10-39": 1.0, "40-100": 0.0},
        seed_past_infections=False,
    )
    seed.unleash_virus_per_day(pd.to_datetime("2020-04-20"), time=0)
    infected_outside_band = [
        person
        for person in world.people
        if person.infected and not (10 <= person.age < 40)
    ]
    assert len(infected_outside_band) == 0
    infected_in_band = sum(
        1 for person in world.people if person.infected and 10 <= person.age < 40
    )
    # half the population per region, so the mean seeding rate is (0.2 + 0.3) / 2
    target = (39 - 10) / 100 * 0.25
    assert np.isclose(infected_in_band / len(world.people), target, rtol=2e-1)


def test__ignore_previously_infected(world, selector):
    """Pre-immunised people are skipped by the seed but still count as immune."""
    clean_world(world)
    infection_id = selector.infection_class.infection_id()
    # make every other person immune before seeding
    for person in world.people[::2]:
        person.immunity.add_immunity([infection_id])

    date = "2020-03-01"
    seed = InfectionSeed.from_uniform_cases(
        world=world,
        infection_selector=selector,
        cases_per_capita=0.1,
        date=date,
        seed_past_infections=False,
    )
    seed.unleash_virus_per_day(date=pd.to_datetime(date), time=0.0, record=None)
    n_people = len(world.people)
    n_infected = sum(1 for person in world.people if person.infected)
    n_immune = sum(
        1 for person in world.people if person.immunity.is_immune(infection_id)
    )
    assert np.isclose(n_infected, 0.1 * n_people, rtol=1e-1)
    # the pre-immunised half plus the freshly infected 10%
    assert np.isclose(n_immune, (0.1 + 0.5) * n_people, rtol=1e-1)


def test__seed_past_days(world, selector):
    """With seed_past_infections, dates before the timer start are seeded too:
    old cases end up recovered, recent ones as ongoing infections."""
    clean_world(world)
    cases_per_region_df = pd.DataFrame(
        {
            "date": ["2019-02-01", "2020-03-31", "2020-04-01"],
            "London": [0.5, 0.1, 0.2],
            "North East": [0.3, 0.2, 0.3],
        }
    )
    cases_per_region_df.set_index("date", inplace=True)
    cases_per_region_df.index = pd.to_datetime(cases_per_region_df.index)
    seed = InfectionSeed.from_global_age_profile(
        world=world,
        infection_selector=selector,
        daily_cases_per_region=cases_per_region_df,
        seed_past_infections=True,
    )
    timer = Timer(initial_day="2020-04-01", total_days=7)
    seed.unleash_virus_per_day(timer.date, time=0)
    recovered = infected_1 = infected_2 = 0
    for person in world.people:
        if person.infected:
            start = person.infection.start_time
            if start == -1:  # seeded from the day before the timer start
                infected_1 += 1
            elif start == 0:  # seeded on the timer start day
                infected_2 += 1
            else:
                raise AssertionError
        elif person.immunity.is_immune(selector.infection_class.infection_id()):
            recovered += 1  # old 2019 cases have already recovered
    n_people = len(world.people)
    # each region holds half the population
    expected_recovered = (0.5 * 0.5 + 0.5 * 0.3) * n_people
    expected_inf1 = (0.5 * 0.1 + 0.5 * 0.2) * n_people
    expected_inf2 = (0.5 * 0.2 + 0.5 * 0.3) * n_people
    assert np.isclose(infected_1, expected_inf1, rtol=1e-1)
    assert np.isclose(infected_2, expected_inf2, rtol=1e-1)
    assert np.isclose(recovered, expected_recovered, rtol=1e-1)


def test__account_secondary_infections(world, selector):
    """With secondary-infection accounting, infections arising outside the seed
    count toward the daily target, so the seed adds no extras."""
    clean_world(world)
    cases_per_region_df = pd.DataFrame(
        {
            "date": ["2020-02-01", "2020-02-02", "2020-02-03"],
            "London": [0.5, 0.5, 0.1],
            "North East": [0.3, 0.2, 0.3],
        }
    )
    cases_per_region_df.set_index("date", inplace=True)
    cases_per_region_df.index = pd.to_datetime(cases_per_region_df.index)
    seed = InfectionSeed.from_global_age_profile(
        world=world,
        infection_selector=selector,
        daily_cases_per_region=cases_per_region_df,
        seed_past_infections=False,
        account_secondary_infections=True,
    )
    timer = Timer(initial_day="2020-02-02", total_days=7)
    seed.unleash_virus_per_day(timer.date, time=0)
    london = world.regions.get_from_name("London")
    n_people_london = len(london.people)

    def count_infected():
        return sum(1 for person in london.people if person.infected)

    assert np.isclose(count_infected(), 0.5 * n_people_london, rtol=0.1)

    # manually infect a further 25% to mimic secondary transmission
    timer = Timer(initial_day="2020-02-03", total_days=7)
    remaining = int(0.25 * n_people_london)
    for person in london.people:
        if not person.infected:
            selector.infect_person_at_time(person=person, time=timer.now)
            remaining -= 1
        if remaining <= 0:
            break
    assert np.isclose(count_infected(), 0.75 * n_people_london, rtol=0.1)
    # No more people should be infected
    seed.unleash_virus_per_day(timer.date, time=0)
    assert np.isclose(count_infected(), 0.75 * n_people_london, rtol=0.1)


import yaml
import pytest
import pandas as pd
import numpy as np
from june.epidemiology.infection_seed import Observed2Cases
from june.demography import Person
from june.epidemiology.infection.trajectory_maker import TrajectoryMaker
from june import paths


@pytest.fixture(name="oc")
def get_oc(health_index_generator):
    """Observed2Cases for a one-area world whose entire population is a single
    50-year-old woman, using the repository's default symptom trajectories."""
    area_super_region_df = pd.DataFrame(
        {
            "area": ["beautiful"],
            "super_area": ["marvellous"],
            "region": ["magnificient"],
        }
    ).set_index("area")
    # all population mass on age 50 ...
    age_counts = {str(age): 0 for age in range(101)}
    age_counts["50"] = 1
    age_per_area_df = pd.DataFrame(age_counts, index=["beautiful"])
    # ... and that bin is entirely female
    female_fractions = {str(age): 0 for age in range(0, 101, 5)}
    female_fractions["50"] = 1.0
    female_fraction_per_area_df = pd.DataFrame(
        female_fractions, index=["beautiful"]
    )
    trajectories_path = (
        paths.configs_path
        / "defaults/epidemiology/infection/symptoms/trajectories.yaml"
    )
    with open(trajectories_path) as f:
        trajectories = yaml.safe_load(f)["trajectories"]
    symptoms_trajectories = [
        TrajectoryMaker.from_dict(trajectory) for trajectory in trajectories
    ]
    return Observed2Cases(
        age_per_area_df=age_per_area_df,
        female_fraction_per_area_df=female_fraction_per_area_df,
        area_super_region_df=area_super_region_df,
        health_index_generator=health_index_generator,
        symptoms_trajectories=symptoms_trajectories,
    )


@pytest.fixture(name="oc_multiple_super_areas")
def get_oc_multiple_super_areas(health_index_generator):
    """Observed2Cases with one region split into three identical super areas,
    each populated by a single 50-year-old woman."""
    areas = ["area_1", "area_2", "area_3"]
    area_super_region_df = pd.DataFrame(
        {
            "area": areas,
            "super_area": ["super_1", "super_2", "super_3"],
            "region": ["magnificient", "magnificient", "magnificient"],
        }
    ).set_index("area")
    age_counts = {str(age): [0, 0, 0] for age in range(101)}
    age_counts["50"] = [1, 1, 1]
    age_per_area_df = pd.DataFrame(age_counts, index=areas)
    female_fractions = {str(age): [0, 0, 0] for age in range(0, 101, 5)}
    female_fractions["50"] = [1, 1, 1]
    female_fraction_per_area_df = pd.DataFrame(female_fractions, index=areas)
    return Observed2Cases(
        age_per_area_df=age_per_area_df,
        female_fraction_per_area_df=female_fraction_per_area_df,
        area_super_region_df=area_super_region_df,
        health_index_generator=health_index_generator,
        regional_infections_per_hundred_thousand=2.0e8,
    )


def test__generate_demography_dfs_by_region(oc):
    """All demography mass sits on a single 50-year-old woman."""
    females = oc.females_per_age_region_df.loc["magnificient"]
    assert females["50"] == 1
    assert females.sum() == 1
    assert oc.males_per_age_region_df.loc["magnificient"].sum() == 0


def test__avg_rates_by_age_and_sex(oc):
    """Rates are keyed sex -> age; the regional average equals the one woman's rates."""
    rates_dict = oc.get_symptoms_rates_per_age_sex()
    assert list(rates_dict.keys()) == ["m", "f"]
    for sex in ("f", "m"):
        np.testing.assert_equal(
            np.array(list(rates_dict[sex].keys())), np.arange(100)
        )
    assert rates_dict["f"][0].shape[0] == 8
    avg_death_rate = oc.weight_rates_by_age_sex_per_region(
        rates_dict, symptoms_tags=("dead_home", "dead_hospital", "dead_icu")
    )
    # the population is a single 50-year-old woman, so the regional average
    # equals her individual rates for the last three (death) symptom bins
    individual_rates = np.diff(
        oc.health_index_generator(Person(age=50, sex="f"), infection_id=0),
        prepend=0.0,
        append=1.0,
    )
    np.testing.assert_equal(
        avg_death_rate["magnificient"], np.array(individual_rates)[[5, 6, 7]]
    )


def test__expected_cases(oc):
    """20 observed outcomes with summed rate 0.4 imply 50 latent cases."""
    latent_cases = oc.get_latent_cases_from_observed(
        n_observed=20, avg_rates=[0.2, 0.1, 0.1]
    )
    assert latent_cases == 20 / 0.4


def test__latent_cases_per_region(oc):
    """Observed counts are shifted back 10 days and scaled by the inverse rate."""
    n_observed_df = pd.DataFrame(
        {"date": ["2020-04-20", "2020-04-21"], "magnificient": [100, 200]}
    )
    n_observed_df.set_index("date", inplace=True)
    n_observed_df.index = pd.to_datetime(n_observed_df.index)
    # summed death rate is 0.4, so latent cases = observed / 0.4, dated 10 days earlier
    n_expected_true_df = pd.DataFrame(
        {"date": ["2020-04-10", "2020-04-11"], "magnificient": [100 / 0.4, 200 / 0.4]}
    )
    n_expected_true_df.set_index("date", inplace=True)
    n_expected_true_df.index = pd.to_datetime(n_expected_true_df.index)
    avg_death_rate = {"magnificient": [0.2, 0.1, 0.1]}
    result_df = oc.get_latent_cases_per_region(n_observed_df, 10, avg_death_rate)
    pd.testing.assert_frame_equal(result_df, n_expected_true_df)


def test__filter_trajectories(oc):
    """Filtering keeps only trajectories that pass through the requested symptoms."""
    hospitalised_trajectories = oc.filter_symptoms_trajectories(
        oc.symptoms_trajectories, symptoms_to_keep=["hospitalised"]
    )
    assert all(
        "hospitalised" in [stage.symptoms_tag.name for stage in trajectory.stages]
        for trajectory in hospitalised_trajectories
    )

    dead_trajectories = oc.filter_symptoms_trajectories(
        oc.symptoms_trajectories, symptoms_to_keep=["dead_hospital", "dead_icu"]
    )
    # every fatal trajectory must terminate in some "dead_*" stage
    assert all(
        "dead" in trajectory.stages[-1].symptoms_tag.name
        for trajectory in dead_trajectories
    )


def test__median_completion_time(oc):
    """The second stage of the first trajectory has a 14-day median completion time."""
    second_stage = oc.symptoms_trajectories[0].stages[1]
    assert oc.get_median_completion_time(second_stage) == 14


def test__get_time_it_takes_to_symptoms(oc):
    """Times to reach given symptom stages fall inside plausible ranges."""
    asymptomatic_trajectories = oc.filter_symptoms_trajectories(
        oc.symptoms_trajectories, symptoms_to_keep=["asymptomatic"]
    )
    time_to_asymptomatic = oc.get_time_it_takes_to_symptoms(
        asymptomatic_trajectories, ["asymptomatic"]
    )[0]
    assert 2.0 < time_to_asymptomatic < 5.0

    hospitalised_trajectories = oc.filter_symptoms_trajectories(
        oc.symptoms_trajectories, symptoms_to_keep=["hospitalised", "intensive_care"]
    )
    assert len(hospitalised_trajectories) == 4
    times_to_hospital = oc.get_time_it_takes_to_symptoms(
        hospitalised_trajectories, ["hospitalised", "intensive_care"]
    )
    assert all(1.0 < time < 16.0 for time in times_to_hospital)


def test__get_weighted_time_to_symptoms(oc):
    """Rate-weighted average time to hospitalisation lands in a plausible range."""
    severe_tags = ["hospitalised", "intensive_care", "dead_hospital", "dead_icu"]
    rates_dict = oc.get_symptoms_rates_per_age_sex()
    avg_rates = oc.weight_rates_by_age_sex_per_region(
        rates_dict, symptoms_tags=severe_tags
    )["magnificient"]
    severe_trajectories = oc.filter_symptoms_trajectories(
        oc.symptoms_trajectories, symptoms_to_keep=severe_tags
    )
    avg_time_to_hospital = oc.get_weighted_time_to_symptoms(
        severe_trajectories, avg_rates, ["hospitalised", "intensive_care"]
    )
    assert 1.0 < avg_time_to_hospital < 13.0


def test__cases_from_observation_per_super_area(oc_multiple_super_areas):
    """Regional latent cases are split evenly across three identical super areas
    and the split adds back up to the regional totals.

    Removes an unused ``n_expected_true_df`` construction the original built
    and never asserted against.
    """
    n_observed_df = pd.DataFrame(
        {"date": ["2020-04-20", "2020-04-21"], "magnificient": [100, 200]}
    )
    n_observed_df.set_index("date", inplace=True)
    n_observed_df.index = pd.to_datetime(n_observed_df.index)
    avg_death_rate = {"magnificient": [0.2, 0.1, 0.1]}
    n_expected_per_region_df = oc_multiple_super_areas.get_latent_cases_per_region(
        n_observed_df, 10, avg_death_rate
    )
    super_area_weights = oc_multiple_super_areas.get_super_area_population_weights()

    # weights within a region sum to 1, split evenly over three equal super areas
    assert (
        super_area_weights.groupby("region").sum()["weights"].loc["magnificient"] == 1.0
    )
    for super_area in ("super_1", "super_2", "super_3"):
        assert super_area_weights.loc[super_area]["weights"] == pytest.approx(
            0.33, rel=0.05
        )

    n_expected_per_super_area_df = (
        oc_multiple_super_areas.convert_regional_cases_to_super_area(
            n_expected_per_region_df, starting_date="2020-04-10"
        )
    )
    # the per-super-area split must reconstruct the regional totals
    pd.testing.assert_series_equal(
        n_expected_per_super_area_df.sum(axis=1),
        n_expected_per_region_df["magnificient"].astype(int),
        check_names=False,
    )


import datetime
import pytest
from june.epidemiology.epidemiology import Epidemiology
from june.epidemiology.vaccines.vaccines import Vaccine
from june.epidemiology.vaccines.vaccination_campaign import (
    VaccinationCampaign,
    VaccinationCampaigns,
)
from june.demography import Person, Population
from june.world import World
from june.epidemiology.infection.infection import Delta, Omicron

# Numeric identifiers of the two infection types used throughout these tests.
delta_id = Delta.infection_id()
omicron_id = Omicron.infection_id()


@pytest.fixture(name="dates_values")
def make_dates_and_values():
    """Expected efficacy values at key dates: each dose ramps up linearly over
    5 days, plateaus, then wanes linearly over 10 days to half its peak
    (matching the ``vaccine`` fixture's timing parameters)."""
    expected = {}
    # first dose: ramp 0.1 -> 0.3, plateau, wane to 0.15
    expected[datetime.datetime(2022, 1, 1)] = 0.1
    expected[datetime.datetime(2022, 1, 2)] = (0.3 - 0.1) / 5 + 0.1
    expected[datetime.datetime(2022, 1, 6)] = 0.3
    expected[datetime.datetime(2022, 1, 7)] = 0.3
    expected[datetime.datetime(2022, 1, 8)] = 0.3
    expected[datetime.datetime(2022, 1, 9)] = (0.15 - 0.3) / 10 + 0.3
    expected[datetime.datetime(2022, 1, 19)] = 0.15
    # second dose: ramp 0.15 -> 0.9, plateau, wane to 0.45 and hold
    expected[datetime.datetime(2022, 3, 2)] = (0.9 - 0.15) / 5 + 0.15
    expected[datetime.datetime(2022, 3, 6)] = 0.9
    expected[datetime.datetime(2022, 3, 7)] = 0.9
    expected[datetime.datetime(2022, 3, 8)] = 0.9
    expected[datetime.datetime(2022, 3, 9)] = (0.45 - 0.9) / 10 + 0.9
    expected[datetime.datetime(2022, 3, 19)] = 0.45
    expected[datetime.datetime(2022, 8, 1)] = 0.45
    return expected


@pytest.fixture(name="vaccine")
def make_vaccine():
    """Two-dose Pfizer-like vaccine: 30% then 90% efficacy against Delta and
    Omicron, with 5-day ramp-up, 2-day plateau and 10-day waning to half."""
    efficacy_per_dose = [
        {"Delta": {"0-100": 0.3}, "Omicron": {"0-100": 0.3}},
        {"Delta": {"0-100": 0.9}, "Omicron": {"0-100": 0.9}},
    ]
    return Vaccine(
        "Pfizer",
        days_administered_to_effective=[5, 5, 5],
        days_effective_to_waning=[2, 2, 2],
        days_waning=[10, 10, 10],
        sterilisation_efficacies=efficacy_per_dose,
        symptomatic_efficacies=efficacy_per_dose,
        waning_factor=0.5,
    )


@pytest.fixture(name="vaccination_campaigns")
def make_campaigns(vaccine):
    """One two-dose campaign covering all ages, second dose 59 days after the first."""
    dose_numbers = [0, 1]
    days_to_next_dose = [0, 59]
    campaign = VaccinationCampaign(
        vaccine=vaccine,
        days_to_next_dose=[days_to_next_dose[dose] for dose in dose_numbers],
        dose_numbers=dose_numbers,
        start_time="2022-01-01",
        end_time="2022-01-11",
        group_by="age",
        group_type="0-100",
        group_coverage=1.0,
    )
    return VaccinationCampaigns([campaign])


@pytest.fixture(name="vaccine_epidemiology")
def make_epidemiology(selectors, vaccination_campaigns):
    """Epidemiology wired with the shared selectors and the test campaigns."""
    epidemiology = Epidemiology(
        infection_selectors=selectors, vaccination_campaigns=vaccination_campaigns
    )
    return epidemiology


@pytest.fixture(name="world")
def make_world():
    """World with a single 30-year-old, partially susceptible to Delta and Omicron."""
    person = Person.from_attributes(age=30)
    person.immunity.susceptibility_dict = {delta_id: 0.9, omicron_id: 0.9}
    person.immunity.effective_multiplier_dict = {delta_id: 0.9, omicron_id: 0.9}
    world = World()
    world.people = Population([person])
    return world


class TestEpi:
    def test__update_health(self, world, vaccine_epidemiology, dates_values):
        """Stepping health day by day for 500 days, susceptibility to Delta
        matches the expected trajectory at every checkpoint date."""
        person = world.people[0]
        vc = vaccine_epidemiology.vaccination_campaigns.vaccination_campaigns[0]
        start_date = datetime.datetime(2022, 1, 1)
        vc.vaccinate(person, date=start_date)
        for day_offset in range(500):
            date = start_date + datetime.timedelta(day_offset)
            vaccine_epidemiology.update_health_status(
                world=world, time=0.0, duration=4, date=date, vaccinate=True
            )
            expected = dates_values.get(date)
            if expected is not None:
                # susceptibility drops by exactly the current vaccine efficacy
                assert person.immunity.susceptibility_dict[
                    delta_id
                ] == pytest.approx(1.0 - expected)


import datetime
import pytest
import numpy as np

from june.demography import Person, Population

from june.epidemiology.vaccines.vaccines import Vaccine, VaccineTrajectory
from june.epidemiology.vaccines.vaccination_campaign import (
    VaccinationCampaign,
    VaccinationCampaigns,
)
from june.epidemiology.infection.infection import Delta, Omicron
from june.records import Record, RecordReader

# Numeric identifiers of the two infection types used throughout these tests.
delta_id = Delta.infection_id()
omicron_id = Omicron.infection_id()


@pytest.fixture(name="effectiveness")
def make_effectiveness():
    """Per-dose Delta efficacies for two age bins (doses 0 and 1 plus a booster)."""
    return [
        {"Delta": {"0-50": 0.6, "50-100": 0.7}},
        {"Delta": {"0-50": 0.8, "50-100": 0.9}},
        {"Delta": {"0-50": 0.9, "50-100": 0.99}},
    ]


@pytest.fixture(name="vaccine")
def make_vaccine(effectiveness):
    """Three-dose vaccine: first dose effective immediately, second after 10 days,
    booster after 5; waning_factor=1.0 (presumably no efficacy loss after the
    waning period — confirm against the Vaccine class)."""
    return Vaccine(
        "Pfizer",
        days_administered_to_effective=[0, 10, 5],
        days_effective_to_waning=[2, 2, 2],
        days_waning=[5, 5, 5],
        sterilisation_efficacies=effectiveness,
        symptomatic_efficacies=effectiveness,
        waning_factor=1.0,
    )


@pytest.fixture(name="fast_population")
def make_fast_population():
    """Population of 1000 people: ten of each age from 0 to 99."""
    people = [
        Person.from_attributes(age=age) for age in range(100) for _ in range(10)
    ]
    return Population(people)


def make_campaign(
    vaccine, group_by, group_type, dose_numbers=None, last_dose_type=None
):
    """Build a 10-day VaccinationCampaign (Jan 1-11 2022) with full coverage.

    Parameters
    ----------
    vaccine: the Vaccine to administer
    group_by / group_type: targeting criterion (e.g. "age" / "50-100")
    dose_numbers: which doses the campaign delivers (defaults to [0, 1])
    last_dose_type: required previous vaccine type, passed straight through

    Fixes the mutable default argument ``dose_numbers=[0, 1]``: a shared list
    default could be mutated downstream and leak between calls, so a ``None``
    sentinel is used instead (behaviour for all existing callers is unchanged).
    """
    if dose_numbers is None:
        dose_numbers = [0, 1]
    # gap before each dose, indexed by dose number
    days_to_next_dose = [0, 20, 20, 20]
    return VaccinationCampaign(
        vaccine=vaccine,
        days_to_next_dose=[
            days_to_next_dose[dose_number] for dose_number in dose_numbers
        ],
        dose_numbers=dose_numbers,
        start_time="2022-01-01",
        end_time="2022-01-11",
        group_by=group_by,
        group_type=group_type,
        group_coverage=1.0,
        last_dose_type=last_dose_type,
    )


class TestWhenWho:
    """Campaign eligibility: activity window, target group and dose history."""

    def test__is_active(self, vaccine):
        """Campaign is active inside its 2022-01-01 to 2022-01-11 window."""
        campaign = make_campaign(vaccine=vaccine, group_by="age", group_type="50-100")
        assert campaign.is_active(datetime.datetime(2022, 1, 2)) is True
        assert campaign.is_active(datetime.datetime(2022, 1, 12)) is False

    def test__is_target_group(self, vaccine):
        """Targeting works both by age band and by sex."""
        young_person = Person(age=5, sex="f")
        old_person = Person(age=51, sex="m")
        campaign = make_campaign(vaccine=vaccine, group_by="age", group_type="50-100")
        assert campaign.is_target_group(person=young_person) is False
        assert campaign.is_target_group(person=old_person) is True

        campaign = make_campaign(vaccine=vaccine, group_by="sex", group_type="m")
        assert campaign.is_target_group(person=young_person) is False
        assert campaign.is_target_group(person=old_person) is True

    def test__should_be_vaccinated(self, vaccine):
        """A doses-[0, 1] campaign only accepts people with no prior dose."""
        person = Person(age=5, sex="f")
        campaign = make_campaign(
            vaccine=vaccine, group_by="age", group_type="50-100", dose_numbers=[0, 1]
        )
        person.vaccinated = None
        assert campaign.has_right_dosage(person=person) is True
        person.vaccinated = 1
        assert campaign.has_right_dosage(person=person) is False
        person.vaccinated = 0
        assert campaign.has_right_dosage(person=person) is False

    def test__should_be_vaccinated_booster(self, vaccine):
        """A booster-only campaign requires the last administered dose to be 1."""
        person = Person(age=5, sex="f")
        campaign = make_campaign(
            vaccine=vaccine, group_by="age", group_type="50-100", dose_numbers=[2]
        )
        person.vaccinated = None
        assert campaign.has_right_dosage(person=person) is False

        person.vaccinated = 1
        assert campaign.has_right_dosage(person=person) is True
        person.vaccinated = 0
        assert campaign.has_right_dosage(person=person) is False

    def test__should_be_vaccinated_last_dose(self, vaccine):
        """With last_dose_type set, the previous vaccine brand must match."""
        person = Person(age=5, sex="f")
        campaign = make_campaign(
            vaccine=vaccine,
            group_by="age",
            group_type="50-100",
            dose_numbers=[2],
            last_dose_type=["Pfizer"],
        )
        person.vaccinated = 1
        person.vaccine_type = None
        assert campaign.has_right_dosage(person=person) is False
        person.vaccine_type = "Pfizer"
        assert campaign.has_right_dosage(person=person) is True
        person.vaccine_type = "Other"
        assert campaign.has_right_dosage(person=person) is False


class TestCampaign:
    """Vaccination mechanics: daily probability and vaccinating one person."""

    def test__daily_prob(self, vaccine):
        """With 30% coverage, on day 5 of the 10-day campaign the daily
        probability is coverage / (days_left adjusted for coverage)."""
        campaign = make_campaign(vaccine=vaccine, group_by="age", group_type="0-100")
        campaign.group_coverage = 0.3
        total_days = 10
        assert campaign.daily_vaccination_probability(days_passed=5) == 0.3 * (
            1.0 / (total_days - 5 * 0.3)
        )

    def test__vaccinate(self, vaccine):
        """Vaccinating attaches a VaccineTrajectory dated at the first dose
        and records the person's id in the campaign."""
        person = Person.from_attributes(age=5, sex="f")
        person.immunity.susceptibility_dict = {delta_id: 0.9, omicron_id: 0.9}
        person.immunity.effective_multiplier_dict = {delta_id: 0.9, omicron_id: 0.9}

        date = datetime.datetime(2022, 1, 1)
        campaign = make_campaign(vaccine=vaccine, group_by="age", group_type="0-100")
        campaign.vaccinate(person, date=date)
        assert isinstance(person.vaccine_trajectory, VaccineTrajectory)
        assert person.vaccine_trajectory.doses[0].date_administered == date
        assert person.id in campaign.vaccinated_ids


class TestCampaigns:
    """Two concurrent campaigns each vaccinate roughly coverage x target group."""

    def test__apply(self, fast_population, effectiveness):
        # Two brands compete for the same 0-50 age band with different
        # coverages: 60% Pfizer vs 10% AstraZeneca.
        pfizer = Vaccine(
            "Pfizer",
            days_administered_to_effective=[0, 10, 5],
            days_effective_to_waning=[2, 2, 2],
            days_waning=[5, 5, 5],
            sterilisation_efficacies=effectiveness,
            symptomatic_efficacies=effectiveness,
            waning_factor=1.0,
        )

        az = Vaccine(
            "AstraZeneca",
            days_administered_to_effective=[0, 10, 5],
            days_effective_to_waning=[2, 2, 2],
            days_waning=[5, 5, 5],
            sterilisation_efficacies=effectiveness,
            symptomatic_efficacies=effectiveness,
            waning_factor=1.0,
        )
        pfizer_campaign = VaccinationCampaign(
            vaccine=pfizer,
            days_to_next_dose=[0, 10],
            dose_numbers=[0, 1],
            start_time="2022-01-01",
            end_time="2022-01-11",
            group_by="age",
            group_type="0-50",
            group_coverage=0.6,
        )
        az_campaign = VaccinationCampaign(
            vaccine=az,
            days_to_next_dose=[0, 10],
            dose_numbers=[0, 1],
            start_time="2022-01-01",
            end_time="2022-01-11",
            group_by="age",
            group_type="0-50",
            group_coverage=0.1,
        )
        campaigns = VaccinationCampaigns([pfizer_campaign, az_campaign])
        start_date = datetime.datetime(2021, 12, 31)
        n_days = 11
        # Apply the campaigns to everyone on each day of the window.
        for days in range(n_days):
            date = start_date + datetime.timedelta(days=days)
            for person in fast_population:
                campaigns.apply(person=person, date=date)

        n_pfizer, n_az = 0, 0
        for person in fast_population:
            if person.vaccine_type == "Pfizer":
                n_pfizer += 1
            elif person.vaccine_type == "AstraZeneca":
                n_az += 1
        # Ages 0-49 make up half of the 0-99 population, hence the 0.5 factor.
        assert 0.6 * 0.5 * len(fast_population) == pytest.approx(n_pfizer, rel=0.1)
        assert 0.1 * 0.5 * len(fast_population) == pytest.approx(n_az, rel=0.15)


@pytest.fixture(name="population")
def make_population():
    """Large population: one hundred people at every age 0-99."""
    return Population(
        [Person.from_attributes(age=age) for age in range(100) for _ in range(100)]
    )


@pytest.fixture(name="vax_campaigns")
def make_campaigns():
    """Single three-dose campaign (2021-03-01 to 2021-03-05) targeting ages
    20-40 with 60% coverage, using distinct per-dose efficacies."""
    # Sterilisation efficacies per dose (0, 1, 2), by variant.
    ster_effectiveness = [
        {"Delta": {"0-100": 0.3}, "Omicron": {"0-100": 0.2}},
        {"Delta": {"0-100": 0.7}, "Omicron": {"0-100": 0.2}},
        {"Delta": {"0-100": 0.9}, "Omicron": {"0-100": 0.8}},
    ]
    # Symptomatic efficacies per dose (0, 1, 2), by variant.
    sympto_effectiveness = [
        {"Delta": {"0-100": 0.3}, "Omicron": {"0-100": 0.5}},
        {"Delta": {"0-100": 0.7}, "Omicron": {"0-100": 0.2}},
        {"Delta": {"0-100": 0.7}, "Omicron": {"0-100": 0.1}},
    ]

    vaccine = Vaccine(
        "Test",
        days_administered_to_effective=[1, 2, 10],
        days_effective_to_waning=[0, 0, 0],
        days_waning=[0, 0, 0],
        sterilisation_efficacies=ster_effectiveness,
        symptomatic_efficacies=sympto_effectiveness,
        waning_factor=1.0,
    )
    return VaccinationCampaigns(
        [
            VaccinationCampaign(
                vaccine=vaccine,
                days_to_next_dose=[0, 9, 16],
                start_time="2021-03-01",
                end_time="2021-03-05",
                group_by="age",
                group_type="20-40",
                group_coverage=0.6,
                dose_numbers=[0, 1, 2],
            )
        ]
    )


class TestVaccinationInitialization:
    """Retroactive application of campaigns that finished before sim start."""

    def test__to_finished(self, vax_campaigns):
        # 38 days = inter-dose gaps (9 + 16) plus per-dose ramp-up days
        # (1 + 2 + 10) from the fixture vaccine — TODO confirm derivation.
        assert (
            vax_campaigns.vaccination_campaigns[0].days_from_administered_to_finished
            == 38
        )

    def test__vaccination_from_the_past(self, population, vax_campaigns):
        """People in the 20-40 target band get the final-dose protections;
        everyone outside the band is untouched."""
        date = datetime.datetime(2021, 4, 30)
        vax_campaigns.apply_past_campaigns(people=population, date=date)
        n_vaccinated = 0
        for person in population:
            if (person.age < 20) or (person.age >= 40):
                assert person.vaccinated is None
            else:
                if person.vaccinated is not None:
                    n_vaccinated += 1
                    # Susceptibility = 1 - final-dose sterilisation efficacy.
                    assert np.isclose(
                        person.immunity.susceptibility_dict[delta_id], 0.1
                    )
                    assert np.isclose(
                        person.immunity.susceptibility_dict[omicron_id], 0.2
                    )
                    # Multiplier = 1 - final-dose symptomatic efficacy.
                    assert np.isclose(
                        person.immunity.effective_multiplier_dict[delta_id], 0.3
                    )
                    assert np.isclose(
                        person.immunity.effective_multiplier_dict[omicron_id], 0.9
                    )
        # 60% coverage of 20 target ages x 100 people per age.
        assert np.isclose(n_vaccinated, 60 * 20, atol=0, rtol=0.1)

    def test__record_saving(self, fast_population, vax_campaigns):
        """Past vaccinations are logged to the record, one row per dose."""
        record = Record(record_path="results")
        dates = vax_campaigns.collect_all_dates_in_past(
            current_date=datetime.datetime(2021, 5, 1)
        )
        # No duplicate vaccination dates should be generated.
        assert len(set(dates)) == len(dates)
        vax_campaigns.apply_past_campaigns(
            people=fast_population, date=datetime.datetime(2021, 5, 1), record=record
        )
        n_vaccinated = 0
        for person in fast_population:
            if person.vaccinated is not None:
                n_vaccinated += 1
        read = RecordReader(results_path="results")
        vaccines_df = read.table_to_df("vaccines", "vaccinated_ids")
        first_dose_df = vaccines_df[vaccines_df["dose_numbers"] == 0]
        second_dose_df = vaccines_df[vaccines_df["dose_numbers"] == 1]
        third_dose_df = vaccines_df[vaccines_df["dose_numbers"] == 2]
        # Every vaccinated person received all three doses in the past.
        assert len(first_dose_df) == n_vaccinated
        assert len(third_dose_df) == n_vaccinated
        assert len(first_dose_df) == len(second_dose_df)


import pytest
import datetime
from june.epidemiology.vaccines import Vaccine
from june.epidemiology.vaccines.vaccines import Efficacy, Dose, VaccineTrajectory

from june import Person
from june.epidemiology.infection.infection import Delta, Omicron

# Infection ids of the two variants used by the trajectory tests below.
delta_id = Delta.infection_id()
omicron_id = Omicron.infection_id()


@pytest.fixture(name="efficacy")
def make_efficacy():
    """Per-variant protection against infection and symptoms; waned
    protection is halved (waning_factor=0.5)."""
    return Efficacy(
        infection={delta_id: 0.9, omicron_id: 0.2},
        symptoms={delta_id: 0.4, omicron_id: 0.1},
        waning_factor=0.5,
    )


@pytest.fixture(name="prior_efficacy")
def make_prior_efficacy():
    """Low (0.1) pre-vaccination protection against everything."""
    return Efficacy(
        infection={delta_id: 0.1, omicron_id: 0.1},
        symptoms={delta_id: 0.1, omicron_id: 0.1},
        waning_factor=0.5,
    )


@pytest.fixture(name="dose")
def make_dose(efficacy, prior_efficacy):
    """First dose given 2022-01-01: 5-day ramp-up, 2-day plateau at full
    efficacy, then 10 days of waning."""
    return Dose(
        number=0,
        days_administered_to_effective=5,
        days_effective_to_waning=2,
        days_waning=10,
        efficacy=efficacy,
        prior_efficacy=prior_efficacy,
        date_administered=datetime.datetime(2022, 1, 1),
    )


@pytest.fixture(name="dates_values")
def make_dates_and_values():
    """Expected infection-efficacy values at key dates for a two-dose
    trajectory: prior 0.1, first dose ramps to 0.3 then wanes to 0.15,
    second dose (2022-03-01) ramps to 0.9 then wanes to 0.45."""
    return {
        # Dose 1 administered: still at prior efficacy.
        datetime.datetime(2022, 1, 1): 0.1,
        # One day into the 5-day linear ramp towards 0.3.
        datetime.datetime(2022, 1, 2): (0.3 - 0.1) / 5 + 0.1,
        # Fully effective, with a 2-day plateau.
        datetime.datetime(2022, 1, 6): 0.3,
        datetime.datetime(2022, 1, 7): 0.3,
        datetime.datetime(2022, 1, 8): 0.3,
        # One day into the 10-day waning towards 0.15 (half of 0.3).
        datetime.datetime(2022, 1, 9): (0.15 - 0.3) / 10 + 0.3,
        datetime.datetime(2022, 1, 19): 0.15,
        # Dose 2: one day into its 5-day ramp from 0.15 towards 0.9.
        datetime.datetime(2022, 3, 2): (0.9 - 0.15) / 5 + 0.15,
        # Fully effective again, with its 2-day plateau.
        datetime.datetime(2022, 3, 6): 0.9,
        datetime.datetime(2022, 3, 7): 0.9,
        datetime.datetime(2022, 3, 8): 0.9,
        # Waning towards 0.45 (half of 0.9), where the trajectory stays.
        datetime.datetime(2022, 3, 9): (0.45 - 0.9) / 10 + 0.9,
        datetime.datetime(2022, 3, 19): 0.45,
        datetime.datetime(2022, 8, 1): 0.45,
    }


class TestEfficacy:
    def test_waning(self, efficacy):
        """Calling the efficacy returns the un-waned per-variant value."""
        assert efficacy(protection_type="infection", infection_id=delta_id) == 0.9
        assert efficacy(protection_type="symptoms", infection_id=omicron_id) == 0.1


class TestDose:
    """Dose stage dates and piecewise-linear efficacy evolution."""

    def test_dates(self, dose):
        # Administered 2022-01-01: +5 days to effective, +2 to waning start,
        # +10 more to finished.
        assert dose.date_effective == datetime.datetime(2022, 1, 6)
        assert dose.date_waning == datetime.datetime(2022, 1, 8)
        assert dose.date_finished == datetime.datetime(2022, 1, 18)

    def test_time_evolution(self, dose):
        """Efficacy ramps from the prior 0.1 towards 0.9, then wanes to 0.45."""
        dates = [
            datetime.datetime(2022, 1, 1),
            datetime.datetime(2022, 1, 2),
            datetime.datetime(2022, 1, 9),
            datetime.datetime(2022, 1, 20),
        ]
        values = [0.1, (0.9 - 0.1) / 5 + 0.1, (0.45 - 0.9) / 10 + 0.9, 0.45]
        for date, value in zip(dates, values):
            assert (
                dose.get_efficacy(
                    date=date, infection_id=delta_id, protection_type="infection"
                )
                == value
            )


def get_trajectory_initial_efficacy(prior_efficacy):
    """Build a two-dose trajectory (doses on 2022-01-01 and 2022-03-01) whose
    pre-vaccination protection equals *prior_efficacy* for both variants."""

    def uniform_efficacy(value, waning_factor):
        # Same protection value for both variants and both protection types;
        # independent dicts so downstream mutation cannot alias them.
        return Efficacy(
            infection={delta_id: value, omicron_id: value},
            symptoms={delta_id: value, omicron_id: value},
            waning_factor=waning_factor,
        )

    dose_one = Dose(
        number=0,
        days_administered_to_effective=5,
        days_effective_to_waning=2,
        days_waning=10,
        efficacy=uniform_efficacy(0.3, 0.5),
        prior_efficacy=uniform_efficacy(prior_efficacy, 1.0),
        date_administered=datetime.datetime(2022, 1, 1),
    )
    dose_two = Dose(
        number=1,
        days_administered_to_effective=5,
        days_effective_to_waning=2,
        days_waning=10,
        efficacy=uniform_efficacy(0.9, 0.5),
        # Dose one has waned to half its 0.3 peak by the time dose two lands.
        prior_efficacy=uniform_efficacy(0.15, 1.0),
        date_administered=datetime.datetime(2022, 3, 1),
    )
    return VaccineTrajectory(
        doses=[dose_one, dose_two],
        name="holi",
        infection_ids=[delta_id, omicron_id],
    )


@pytest.fixture(name="trajectory")
def make_trajectory():
    """Two-dose trajectory with low (0.1) pre-vaccination efficacy."""
    return get_trajectory_initial_efficacy(0.1)


class TestVaccineTrajectory:
    """Dose lookup and efficacy/immunity evolution over the two-dose trajectory."""

    def test_dose_index(self, trajectory):
        """Dates map to dose 0 until dose 1 is administered on 2022-03-01."""
        assert trajectory.get_dose_index(date=datetime.datetime(2022, 1, 1)) == 0
        assert trajectory.get_dose_index(date=datetime.datetime(2022, 1, 8)) == 0
        assert trajectory.get_dose_index(date=datetime.datetime(2022, 3, 1)) == 1
        assert trajectory.get_dose_index(date=datetime.datetime(2022, 3, 20)) == 1

    def test_dose_number(self, trajectory):
        assert trajectory.get_dose_number(date=datetime.datetime(2022, 1, 1)) == 0
        assert trajectory.get_dose_number(date=datetime.datetime(2022, 3, 20)) == 1

    def test_is_finished(self, trajectory):
        """Trajectory is finished once the last dose's waning completes."""
        assert trajectory.is_finished(date=datetime.datetime(2022, 1, 1)) is False
        assert trajectory.is_finished(date=datetime.datetime(2022, 3, 19)) is True

    def test_time_evolution(self, trajectory, dates_values):
        """The trajectory reproduces the expected efficacy curve day by day."""
        n_days = 500
        for days in range(n_days):
            date = trajectory.first_dose_date + datetime.timedelta(days=days)
            trajectory.update_trajectory_stage(date=date)
            if date in dates_values:
                efficacy = trajectory.get_efficacy(
                    date=date, infection_id=delta_id, protection_type="infection"
                )
                assert dates_values[date] == pytest.approx(efficacy)

    def test__update_vaccine_effect(self, trajectory):
        """After the full trajectory, susceptibility and multiplier settle at
        1 - 0.45 = 0.55 (the final waned efficacy)."""
        person = Person.from_attributes(age=5, sex="f")
        person.immunity.susceptibility_dict = {delta_id: 0.9, omicron_id: 0.9}
        person.immunity.effective_multiplier_dict = {delta_id: 0.9, omicron_id: 0.9}
        date = datetime.datetime(2022, 1, 1)
        n_days = 200
        for days in range(n_days):
            date = trajectory.first_dose_date + datetime.timedelta(days=days)
            trajectory.update_vaccine_effect(person=person, date=date)
        assert person.immunity.susceptibility_dict[delta_id] == pytest.approx(0.55)
        assert person.immunity.effective_multiplier_dict[delta_id] == pytest.approx(
            0.55
        )

    def test__update_vaccine_effect_high_initial_immunity(
        self,
    ):
        """Vaccination does not degrade stronger pre-existing immunity: a
        person starting at susceptibility 0.1 stays at 0.1."""
        trajectory = get_trajectory_initial_efficacy(0.9)
        person = Person.from_attributes(age=5, sex="f")
        person.immunity.susceptibility_dict = {delta_id: 0.1, omicron_id: 0.1}
        person.immunity.effective_multiplier_dict = {delta_id: 0.1, omicron_id: 0.1}
        date = datetime.datetime(2022, 1, 1)
        n_days = 200
        for days in range(n_days):
            date = trajectory.first_dose_date + datetime.timedelta(days=days)
            trajectory.update_vaccine_effect(person=person, date=date)

        assert person.immunity.susceptibility_dict[delta_id] == pytest.approx(0.1)
        assert person.immunity.effective_multiplier_dict[delta_id] == pytest.approx(0.1)


@pytest.fixture(name="vaccine")
def make_vaccine():
    """Pfizer vaccine with Delta-only, age-banded efficacies; waned
    protection is halved (waning_factor=0.5)."""
    # Per-dose efficacies (doses 0, 1, 2) for the two age bands.
    effectiveness = [
        {"Delta": {"0-50": 0.6, "50-100": 0.7}},
        {"Delta": {"0-50": 0.8, "50-100": 0.9}},
        {"Delta": {"0-50": 0.9, "50-100": 0.99}},
    ]
    return Vaccine(
        "Pfizer",
        days_administered_to_effective=[0, 10, 5],
        days_effective_to_waning=[2, 2, 2],
        days_waning=[5, 5, 5],
        sterilisation_efficacies=effectiveness,
        symptomatic_efficacies=effectiveness,
        waning_factor=0.5,
    )


class TestVaccine:
    """Vaccine-level queries and trajectory generation."""

    def test__infection_ids(self, vaccine):
        # Only Delta appears in the fixture's effectiveness dicts.
        assert set(vaccine.infection_ids) == set([delta_id])

    def test__vt_generation(self, vaccine):
        """A booster-only (dose 2) trajectory picks the efficacy band that
        matches the person's age: 0.9 for age 20, 0.99 for age 70."""
        young_person = Person.from_attributes(age=20)
        old_person = Person.from_attributes(age=70)
        vts = []
        for person in [young_person, old_person]:
            person.immunity.susceptibility_dict = {delta_id: 1.0, omicron_id: 1.0}
            person.immunity.effective_multiplier_dict = {delta_id: 1.0, omicron_id: 1.0}
            vts.append(
                vaccine.generate_trajectory(
                    person=person,
                    dose_numbers=[2],
                    date=datetime.datetime(2022, 1, 1),
                    days_to_next_dose=[0],
                )
            )
        for vt in vts:
            assert vt.doses[0].date_administered == datetime.datetime(2022, 1, 1)
            assert len(vt.doses) == 1
            assert vt.doses[0].number == 2
        assert vts[0].doses[0].efficacy.symptoms[delta_id] == 0.9
        assert vts[1].doses[0].efficacy.symptoms[delta_id] == 0.99

    def test__vt_generation_time_evolution(self, dates_values):
        """A generated two-dose trajectory reproduces the expected
        dates_values efficacy curve."""
        effectiveness = [
            {"Delta": {"0-100": 0.3}, "Omicron": {"0-100": 0.3}},
            {"Delta": {"0-100": 0.9}, "Omicron": {"0-100": 0.9}},
        ]
        vaccine = Vaccine(
            "Pfizer",
            days_administered_to_effective=[5, 5, 5],
            days_effective_to_waning=[2, 2, 2],
            days_waning=[10, 10, 10],
            sterilisation_efficacies=effectiveness,
            symptomatic_efficacies=effectiveness,
            waning_factor=0.5,
        )

        person = Person.from_attributes(age=20)
        person.immunity.susceptibility_dict = {delta_id: 0.9, omicron_id: 0.9}
        person.immunity.effective_multiplier_dict = {delta_id: 0.9, omicron_id: 0.9}
        trajectory = vaccine.generate_trajectory(
            person=person,
            days_to_next_dose=[0, 59],
            dose_numbers=[0, 1],
            date=datetime.datetime(2022, 1, 1),
        )
        # 59 days after 2022-01-01 lands the second dose on 2022-03-01.
        assert trajectory.doses[0].date_administered == datetime.datetime(2022, 1, 1)
        assert trajectory.doses[1].date_administered == datetime.datetime(2022, 3, 1)
        assert len(trajectory.doses) == 2

        n_days = 500
        for days in range(n_days):
            date = trajectory.first_dose_date + datetime.timedelta(days=days)
            trajectory.update_trajectory_stage(date=date)
            if date in dates_values:
                efficacy = trajectory.get_efficacy(
                    date=date, infection_id=delta_id, protection_type="infection"
                )
                assert dates_values[date] == pytest.approx(efficacy)


import numpy as np
import pytest

from june.event import DomesticCare
from june.world import World
from june.groups import Household, Households
from june.geography import Area, Areas, SuperArea, SuperAreas, Region
from june.demography import Population, Person


class TestDomesticCare:
    """Carers from 'random' households visit their linked 'old' household
    during weekday leisure."""

    @pytest.fixture(name="world")
    def make_world(self):
        """World of 10 super-areas x 5 areas; each area holds 5 'old'
        households (residents aged 80 and 75) and 10 'random' households
        (residents aged 50 and 30)."""
        world = World()
        region = Region()
        super_areas = []
        areas = []
        households = []
        people = []
        for _ in range(10):
            _areas = []
            for _ in range(5):
                area = Area()
                area.households = []
                areas.append(area)
                _areas.append(area)
                for _ in range(5):
                    household = Household(type="old", area=area)
                    p1 = Person.from_attributes(age=80)
                    p2 = Person.from_attributes(age=75)
                    household.add(p1)
                    people.append(p1)
                    people.append(p2)
                    household.add(p2)
                    households.append(household)
                    area.households.append(household)
                for _ in range(10):
                    household = Household(type="random", area=area)
                    p1 = Person.from_attributes(age=50)
                    p2 = Person.from_attributes(age=30)
                    household.add(p1)
                    household.add(p2)
                    people.append(p1)
                    people.append(p2)
                    area.households.append(household)
                    households.append(household)
            super_area = SuperArea(areas=_areas, region=region)
            for area in _areas:
                area.super_area = super_area
            super_areas.append(super_area)
        world.areas = Areas(areas, ball_tree=False)
        world.super_areas = SuperAreas(super_areas, ball_tree=False)
        world.households = Households(households)
        world.people = Population(people)
        # Reset activity state so the event under test can assign people.
        for person in world.people:
            person.busy = False
            person.subgroups.leisure = None
        for household in world.households:
            household.clear()
        return world

    @pytest.fixture(name="needs_care_probabilities")
    def make_probs(self):
        # Only people aged 65+ may need care (probability 0.3 each).
        needs_care_probabilities = {"0-65": 0, "65-100": 0.3}
        return needs_care_probabilities

    @pytest.fixture(name="domestic_care")
    def make_domestic_care_event(self, needs_care_probabilities, world):
        """DomesticCare event that is always active, initialised on the world
        (which links carer households to care-needing ones)."""
        domestic_care = DomesticCare(
            start_time="1900-01-01",
            end_time="2999-01-01",
            needs_care_probabilities=needs_care_probabilities,
        )
        domestic_care.initialise(world=world)
        return domestic_care

    def test__care_probs_read(self, domestic_care):
        """Age-binned probabilities are queryable by exact age."""
        assert domestic_care.needs_care_probabilities[75] == 0.3
        assert domestic_care.needs_care_probabilities[45] == 0.0

    def test__household_linking(
        self, domestic_care, world
    ):  # domestic needed for fixture
        """Roughly P(either old resident needs care) = 1 - 0.7^2 of old
        households get linked to a carer household."""
        has_at_least_one = False
        n_linked = 0
        total = 0
        # Each of the two old residents needs care independently with p=0.3.
        probability_care = 1 - (0.7 * 0.7)
        for household in world.households:
            if household.type == "old":
                # Old households receive care; they are never carers.
                assert household.household_to_care is None
                total += 1
            else:
                if household.household_to_care:
                    n_linked += 1
                    assert household.household_to_care.type == "old"
                    has_at_least_one = True
        assert has_at_least_one
        assert np.isclose(n_linked / total, probability_care, rtol=0.1)

    def test__send_carers_during_leisure(self, domestic_care, world):
        # leisure only go weekdays leisure
        domestic_care.apply(world=world, activities=["leisure"], day_type="weekday")
        for household in world.households:
            if household.household_to_care:
                has_active = False
                for person in household.residents:
                    if person.leisure is not None:
                        has_active = True
                        # The carer is busy at the cared-for old household.
                        assert person.busy
                        assert person in person.leisure
                        assert person in household.household_to_care.people
                        assert person.leisure.group.spec == "household"
                        assert person.leisure.group.type == "old"
                        break
                assert has_active

    def test__carers_dont_go_weekends(self, domestic_care, world):
        # leisure only go weekdays leisure
        domestic_care.apply(world=world, activities=["leisure"], day_type="weekend")
        for household in world.households:
            if household.household_to_care:
                for person in household.residents:
                    assert person.leisure is None
                    assert not person.busy

    def test__carers_dont_go_outside_leisure(self, domestic_care, world):
        # leisure only go weekdays leisure
        domestic_care.apply(
            world=world, activities=["primary_activity"], day_type="weekday"
        )
        for household in world.households:
            if household.household_to_care:
                for person in household.residents:
                    assert person.leisure is None
                    assert not person.busy

    def test__residents_stay_home(self, domestic_care, world):
        """Residents of a cared-for household stay (busy) at home."""
        domestic_care.apply(world=world, activities=["leisure"], day_type="weekday")
        active = False
        for household in world.households:
            if household.household_to_care:
                household_to_care = household.household_to_care
                for person in household_to_care.residents:
                    active = True
                    assert person in person.residence.people
                    assert person.busy
        assert active

    def test__care_beta(self, domestic_care, world):
        """A household receiving care uses the plain household beta (1.0)."""
        domestic_care.apply(world=world, activities=["leisure"], day_type="weekday")
        for household in world.households:
            if household.household_to_care:
                household_to_care = household.household_to_care
                assert household_to_care.receiving_care
                int_house = household.get_interactive_group(None)
                assert (
                    int_house.get_processed_beta(
                        {"household": 1, "household_visits": 2, "care_visits": 3}, {}
                    )
                    == 1.0
                )


import datetime
from june.event import Event


def test__event_dates():
    """Event parses its date strings and is active only inside its window."""
    fmt = "%Y-%m-%d"
    event = Event(start_time="2020-01-05", end_time="2020-12-05")
    assert event.start_time.strftime(fmt) == "2020-01-05"
    assert event.end_time.strftime(fmt) == "2020-12-05"
    inside = datetime.datetime.strptime("2020-03-05", fmt)
    outside = datetime.datetime.strptime("2030-03-05", fmt)
    assert event.is_active(inside)
    assert not event.is_active(outside)


import pytest
import numpy as np

from june.event import IncidenceSetter
from june.world import World
from june.geography import Area, SuperArea, Areas, SuperAreas, Region, Regions
from june.demography import Person, Population

incidence_per_region = {"London": 0.1, "North East": 0.01}


class TestIncidenceSetter:
    """IncidenceSetter forces each region's infected fraction to its target,
    both by removing and by adding infections."""

    @pytest.fixture(name="world")
    def setup(self):
        """World with 1000 people in a London area and 1000 in a North East
        area."""
        world = World()

        london_area = Area()
        london_super_area = SuperArea(areas=[london_area])
        london = Region(name="London", super_areas=[london_super_area])

        ne_area = Area()
        ne_super_area = SuperArea(areas=[ne_area])
        ne = Region(name="North East", super_areas=[ne_super_area])
        people = []

        for i in range(1000):
            person = Person.from_attributes()
            london_area.add(person)
            people.append(person)

        for i in range(1000):
            person = Person.from_attributes()
            ne_area.add(person)
            people.append(person)

        world.areas = Areas([ne_area, london_area], ball_tree=False)
        world.super_areas = SuperAreas(
            [ne_super_area, london_super_area], ball_tree=False
        )
        world.regions = Regions([london, ne])
        world.people = Population(people)
        return world

    def test__removing_infections(self, world, policy_simulator):
        """Starting from everyone infected, the setter trims each region down
        to its target incidence."""
        # NOTE(review): the `policy_simulator` fixture is defined in a
        # conftest that is not visible in this file.
        selector = policy_simulator.epidemiology.infection_selectors[0]
        # infect everyone
        for person in world.people:
            selector.infect_person_at_time(person, 0.0)

        setter = IncidenceSetter(
            start_time="2020-03-01",
            end_time="2020-03-02",
            incidence_per_region=incidence_per_region,
        )
        setter.apply(world, policy_simulator)
        london = world.regions.get_from_name("London")
        ne = world.regions.get_from_name("North East")

        # infected london
        infected = 0
        for person in london.people:
            if person.infected:
                infected += 1
        assert np.isclose(infected, 0.1 * len(london.people), rtol=1e-2, atol=0)

        # infected north east
        infected = 0
        for person in ne.people:
            if person.infected:
                infected += 1
        assert np.isclose(infected, 0.01 * len(ne.people), rtol=1e-2, atol=0)

    def test__adding_infections(self, world, policy_simulator):
        """Starting from a single case per region, the setter raises each
        region up to its target incidence."""
        selector = policy_simulator.epidemiology.infection_selectors[0]
        selector.infect_person_at_time(
            world.regions.get_from_name("London").people[0], 0.0
        )
        selector.infect_person_at_time(
            world.regions.get_from_name("North East").people[0], 0.0
        )
        setter = IncidenceSetter(
            start_time="2020-03-01",
            end_time="2020-03-02",
            incidence_per_region=incidence_per_region,
        )
        setter.apply(world, policy_simulator)

        london = world.regions.get_from_name("London")
        ne = world.regions.get_from_name("North East")

        # infected london
        infected = 0
        for person in london.people:
            if person.infected:
                infected += 1
        assert np.isclose(infected, 0.1 * len(london.people), rtol=1e-2, atol=0)

        # infected north east
        infected = 0
        for person in ne.people:
            if person.infected:
                infected += 1
        assert np.isclose(infected, 0.01 * len(ne.people), rtol=1e-2, atol=0)


from dataclasses import dataclass

import numpy as np
import pytest

from june.demography import Person
from june.epidemiology.epidemiology import Epidemiology
from june.epidemiology.infection import (
    B117,
    Covid19,
    InfectionSelector,
    InfectionSelectors,
)
from june.event import Mutation


@dataclass
class MockRegion:
    """Minimal Region stand-in exposing only a name."""

    # Region name matched against keys such as "London" in the tests.
    name: str


@dataclass
class MockArea:
    """Minimal Area stand-in linking a person to its super-area."""

    # Parent super-area; typed loosely since the tests pass a MockSuperArea.
    super_area: object


@dataclass
class MockSuperArea:
    """Minimal SuperArea stand-in linking an area to its region."""

    # Parent region; typed loosely since the tests pass a MockRegion.
    region: object


@dataclass
class MockSimulator:
    """Minimal Simulator stand-in exposing only an epidemiology object."""

    epidemiology: object


@dataclass
class MockWorld:
    """Minimal World stand-in exposing only its people."""

    # Iterable of Person objects, as iterated by the event under test.
    people: object


@pytest.fixture(name="c19_selector")
def covid19_selector(health_index_generator):
    """Selector producing Covid19 infections.

    NOTE(review): `health_index_generator` comes from a conftest fixture not
    visible in this file."""
    return InfectionSelector(
        health_index_generator=health_index_generator, infection_class=Covid19
    )


@pytest.fixture(name="c20_selector")
def covid20_selector(health_index_generator):
    """Selector producing B117 (mutated-variant) infections."""
    return InfectionSelector(
        infection_class=B117, health_index_generator=health_index_generator
    )


class TestMutations:
    """Checks that the Mutation event converts Covid19 infections to B117
    with the configured per-region probabilities."""

    @pytest.fixture(name="people")
    def create_pop(self, c19_selector):
        """1000 people, alternating between London and North East mock
        areas, all infected with Covid19 at t=0."""
        people = []
        london = MockRegion("London")
        london_sa = MockSuperArea(region=london)
        ne = MockRegion("North East")
        ne_sa = MockSuperArea(region=ne)
        for i in range(0, 1000):
            person = Person.from_attributes()
            if i % 2 == 0:
                person.area = MockArea(super_area=london_sa)
            else:
                person.area = MockArea(super_area=ne_sa)
            people.append(person)
        for person in people:
            c19_selector.infect_person_at_time(person, 0)
        return people

    def test_mutation(self, people, c19_selector, c20_selector):
        # Apply a mutation event with 50% probability in London and 1%
        # in the North East, then count resulting infection types.
        infection_selectors = InfectionSelectors([c19_selector, c20_selector])
        epidemiology = Epidemiology(infection_selectors=infection_selectors)
        simulator = MockSimulator(epidemiology)
        world = MockWorld(people=people)
        mutation = Mutation(
            start_time="2020-11-01",
            end_time="2020-11-02",
            mutation_id=B117.infection_id(),
            regional_probabilities={"London": 0.5, "North East": 0.01},
        )
        mutation.initialise()
        mutation.apply(world=world, simulator=simulator)
        c19_london = 0
        c19_ne = 0
        c20_london = 0
        c20_ne = 0
        for person in world.people:
            if person.infection.infection_id() == Covid19.infection_id():
                assert person.infection.__class__.__name__ == "Covid19"
                if person.region.name == "London":
                    c19_london += 1
                else:
                    assert person.region.name == "North East"
                    c19_ne += 1
            else:
                # Anything that is not Covid19 must be the B117 mutation.
                assert person.infection.infection_id() == B117.infection_id()
                assert person.infection.__class__.__name__ == "B117"
                if person.region.name == "London":
                    c20_london += 1
                else:
                    assert person.region.name == "North East"
                    c20_ne += 1
        # Mutated fractions are statistical: compare loosely to the
        # configured regional probabilities.
        assert np.isclose(c20_london / (c20_london + c19_london), 0.5, rtol=1e-1)
        assert np.isclose(c20_ne / (c20_ne + c19_ne), 0.01, atol=0.01)


from pathlib import Path

from june.geography import City, Cities, SuperArea, SuperAreas

# CSV fixture mapping cities to super areas, colocated with this test module.
city_test_file = Path(__file__).parent / "cities.csv"


class TestCity:
    """City construction: direct, from file, and via super-area matching."""

    def test__city_setup(self):
        # Direct construction stores the attributes verbatim.
        city = City(name="Durham", super_areas=["A1", "A2"])
        assert city.name == "Durham"
        assert city.super_areas == ["A1", "A2"]

    def test__city_setup_from_file(self):
        # cities.csv (next to this module) maps each city to its super areas.
        city = City.from_file(name="Durham", city_super_areas_filename=city_test_file)
        assert list(city.super_areas) == ["a1", "a2"]
        city = City.from_file(
            name="Newcastle", city_super_areas_filename=city_test_file
        )
        assert list(city.super_areas) == ["b1"]
        city = City.from_file(name="Leeds", city_super_areas_filename=city_test_file)
        assert list(city.super_areas) == ["c1", "c2", "c3"]

    def test__cities_for_super_areas(self):
        # c1/c2 belong to Leeds in the fixture file, so only Leeds is
        # created, and each super area gets a back-reference to it.
        super_areas = SuperAreas(
            [
                SuperArea(name="c1", coordinates=[1, 2]),
                SuperArea(name="c2", coordinates=[3, 4]),
            ]
        )
        cities = Cities.for_super_areas(
            super_areas, city_super_areas_filename=city_test_file
        )
        assert cities[0].name == "Leeds"
        assert super_areas[0].city == cities[0]
        assert super_areas[1].city == cities[0]
        assert list(cities[0].super_areas) == ["c1", "c2"]


import pytest
import pandas as pd
import numpy.testing as npt

from june.geography import geography as g


@pytest.fixture()
def geography_example():
    # Small geography restricted to one super area to keep the tests fast.
    return g.Geography.from_file(filter_key={"super_area": ["E02000140"]})


def test__create_geographical_hierarchy():
    """create_geographical_units builds linked Area/SuperArea/Region objects
    from a hierarchy table plus coordinate and socioeconomic lookups."""
    hierarchy_df = pd.DataFrame(
        {
            "area": ["area_1", "area_2", "area_3", "area_4"],
            "super_area": [
                "super_area_1",
                "super_area_1",
                "super_area_1",
                "super_area_2",
            ],
            "region": ["region_1", "region_1", "region_1", "region_2"],
        }
    )
    area_coordinates_df = pd.DataFrame(
        {
            "area": ["area_1", "area_2", "area_3", "area_4"],
            "longitude": [0.0, 1.0, 2.0, 3.0],
            "latitude": [0.0, 1.0, 2.0, 3.0],
        }
    )
    area_coordinates_df.set_index("area", inplace=True)
    super_area_coordinates_df = pd.DataFrame(
        {
            "super_area": ["super_area_1", "super_area_2"],
            "longitude": [0.0, 1.0],
            "latitude": [0.0, 1.0],
        }
    )
    super_area_coordinates_df.set_index("super_area", inplace=True)
    # Socioeconomic index per area, indexed like the coordinates frame.
    area_socioeconomic_indices_df = pd.Series(
        index=area_coordinates_df.index, data=[0.01, 0.02, 0.75, 0.90]
    )
    # area_socioeconomic_indices_df.set_index("area", inplace=True)
    areas, super_areas, regions = g.Geography.create_geographical_units(
        hierarchy=hierarchy_df,
        area_coordinates=area_coordinates_df,
        super_area_coordinates=super_area_coordinates_df,
        area_socioeconomic_indices=area_socioeconomic_indices_df,
    )

    assert len(areas) == 4
    assert len(super_areas) == 2
    assert len(regions) == 2

    # Parent/child links must be wired in both directions.
    assert regions[0].super_areas[0].name == super_areas[0].name
    assert regions[1].super_areas[0].name == super_areas[1].name

    assert super_areas[0].region == regions[0]
    assert super_areas[1].region == regions[1]

    assert super_areas[0].areas == [areas[0], areas[1], areas[2]]
    assert super_areas[1].areas == [areas[3]]

    # Socioeconomic indices are carried over per area.
    assert areas[0].socioeconomic_index == 0.01
    assert areas[1].socioeconomic_index == 0.02
    assert areas[2].socioeconomic_index == 0.75
    assert (
        areas[3].socioeconomic_index == 0.90
    )  # this one is important, it's a single-area region.


def test__nr_of_members_in_units(geography_example):
    # E02000140 contains 26 output areas and is itself the only super area.
    assert len(geography_example.areas) == 26
    assert len(geography_example.super_areas) == 1


def test__area_attributes(geography_example):
    """A loaded area exposes name, coordinates, parent super area and
    socioeconomic index matching the data files."""
    area = geography_example.areas.get_from_name("E00003598")
    assert area.name == "E00003598"
    npt.assert_almost_equal(
        area.coordinates, [51.395954503652504, 0.10846483370388499], decimal=3
    )
    assert area.super_area.name == "E02000140"
    assert area.socioeconomic_index == 0.12


def test__super_area_attributes(geography_example):
    """A loaded super area exposes name, coordinates and its child areas."""
    super_area = geography_example.super_areas.get_from_name("E02000140")
    assert super_area.name == "E02000140"
    npt.assert_almost_equal(
        super_area.coordinates, [51.40340615262757, 0.10741193961090514], decimal=3
    )
    assert "E00003595" in [area.name for area in super_area.areas]


def test__create_single_area():
    """Filtering by a single output area yields a one-area geography."""
    geography = g.Geography.from_file(filter_key={"area": ["E00120481"]})
    assert len(geography.areas) == 1


def test__geography_no_socioeconomic_index():
    """Without a socioeconomic index file, every area's index is None."""
    geography = g.Geography.from_file(
        filter_key={"area": ["E00003598", "E00120481"]},
        area_socioeconomic_index_filename=None,
    )
    assert all(area.socioeconomic_index is None for area in geography.areas)


def test_create_ball_tree_for_super_areas():
    """Closest-unit lookups work for both super areas and areas."""
    geo = g.Geography.from_file(filter_key={"super_area": ["E02004935", "E02000140"]})
    # Coordinates near Durham should resolve to the Durham super area.
    super_area = geo.super_areas.get_closest_super_areas(
        coordinates=[54.770512, -1.594221]
    )[0]
    assert super_area.name == "E02004935"
    # k controls how many nearest units are returned.
    assert (
        len(
            geo.super_areas.get_closest_super_areas(
                coordinates=[54.770512, -1.594221], k=2
            )
        )
        == 2
    )
    assert (
        len(geo.areas.get_closest_areas(coordinates=[54.770512, -1.594221], k=10)) == 10
    )


from pathlib import Path
import numpy as np

from june.geography import SuperArea, SuperAreas, Station, Stations, City
from june.geography.station import CityStation, InterCityStation

# CSV fixture with station data, colocated with this test module.
super_stations_test_file = Path(__file__).parent / "stations.csv"


class TestStations:
    """Station construction and placement around a city center."""

    def test__stations_setup(self):
        # Direct construction stores city and super area verbatim.
        station = Station(city="Barcelona", super_area=SuperArea(name="b1"))
        assert station.city == "Barcelona"
        assert station.super_area.name == "b1"

    def test__stations_for_city_center(self):
        # Four super areas on the unit axes around the center at the origin.
        super_areas = SuperAreas(
            [
                SuperArea(name="b1", coordinates=[0, 0]),
                SuperArea(name="b2", coordinates=[1, 0]),
                SuperArea(name="b3", coordinates=[0, 1]),
                SuperArea(name="b4", coordinates=[-1, 0]),
                SuperArea(name="b5", coordinates=[0, -1]),
            ],
            ball_tree=True,
        )
        city = City(name="Barcelona", coordinates=[0, 0], super_area=super_areas[0])
        city_stations = Stations.from_city_center(
            city=city,
            number_of_stations=4,
            distance_to_city_center=500,
            super_areas=super_areas,
            type="city_station",
        )
        assert len(city_stations) == 4
        for st in city_stations:
            assert isinstance(st, CityStation)
        # Each station must land in a distinct super area.
        station_super_areas = []
        for station in city_stations:
            station_super_areas.append(station.super_area.name)
            assert station.city == "Barcelona"
            assert station.super_area.name in ["b1", "b2", "b3", "b4", "b5"]
        assert len(np.unique(station_super_areas)) == 4
        city_stations._construct_ball_tree()
        # Nearest-station lookups through the ball tree.
        station = city_stations.get_closest_station([0.1, 0])
        assert station.coordinates[0] == 1
        assert station.coordinates[1] == 0
        station = city_stations.get_closest_station([-50, -10])
        assert station.coordinates[0] == -1
        assert station.coordinates[1] == 0
        # type="inter_city_station" yields InterCityStation instances.
        inter_city_stations = Stations.from_city_center(
            city=city,
            number_of_stations=4,
            distance_to_city_center=500,
            super_areas=super_areas,
            type="inter_city_station",
        )
        for st in inter_city_stations:
            assert isinstance(st, InterCityStation)


import pytest
from june import paths
from june.groups.care_home import CareHome

# Default care-home group configuration shipped with june.
default_config_file = paths.configs_path / "defaults/groups/carehome.yaml"


class TestCareHome:
    @pytest.fixture(name="carehome")
    def create_carehome(self):
        # "asd" is a dummy placeholder; this test only checks it round-trips.
        return CareHome(n_residents=30, n_workers=8, area="asd")

    def test__carehome_grouptype(self, carehome):
        """Subgroups are ordered workers, residents, visitors, and the
        constructor arguments are stored verbatim."""
        assert carehome.SubgroupType.workers == 0
        assert carehome.SubgroupType.residents == 1
        assert carehome.SubgroupType.visitors == 2
        assert carehome.n_residents == 30
        assert carehome.area == "asd"
        assert carehome.n_workers == 8


# import pytest
#
# from june import World
# from june.world import generate_world_from_geography
# from june.demography.geography import Geography, Area
# from june.demography import Person, Demography
# from june.distributors import WorkerDistributor
# from june.commute import CommuteGenerator
# from june.groups import (
#    CommuteCity,
#    CommuteCities,
#    CommuteCityDistributor,
#    CommuteHub,
#    CommuteHubs,
#    CommuteHubDistributor,
#    CommuteUnit,
#    CommuteUnits,
#    CommuteUnitDistributor,
#    CommuteCityUnit,
#    CommuteCityUnits,
#    CommuteCityUnitDistributor,
# )
#
#
# @pytest.fixture(name="super_area_commute", scope="module")
# def super_area_name():
#    return "E02002559"
#
#
# @pytest.fixture(name="geography_commute", scope="module")
# def create_geography(super_area_companies):
#    return Geography.from_file(filter_key={"super_area": [super_area_commute]})
#
#
# @pytest.fixture(name="person")
# def create_person():
#    return Person(sex="m", age=44)
#
#
# class TestCommuteCity:
#    @pytest.fixture(name="city")
#    def create_city(self, super_area_commute):
#        return CommuteCity(
#            city="Manchester",
#            metro_msoas=super_area_commute,
#            metro_centroid=[-2, 52.0],
#        )
#
#    def test__city_grouptype(self, city):
#        assert len(city.people) == 0
#        assert len(city.commutehubs) == 0
#        assert len(city.commute_internal) == 0
#        assert len(city.commutecityunits) == 0
#
#
# class TestCommuteHub:
#    @pytest.fixture(name="hub")
#    def create_hub(self):
#        return CommuteHub(city="Manchester", lat_lon=[-2, 52.0],)
#
#    def test__hub_grouptype(self, hub):
#        assert len(hub.commute_through) == 0
#        assert len(hub.commuteunits) == 0
#
#
# class TestCommuteUnit:
#    @pytest.fixture(name="unit")
#    def create_hub(self):
#        return CommuteUnit(city="Manchester", commutehub_id=0, is_peak=False,)
#
#    def test__unit_grouptype(self, unit):
#        assert len(unit.people) == 0
#        assert unit.max_passengers != 0
#
#
# class TestCommuteCityUnit:
#    @pytest.fixture(name="unit")
#    def create_hub(self):
#        return CommuteCityUnit(city="Manchester", is_peak=False,)
#
#    def test__unit_grouptype(self, unit):
#        assert len(unit.people) == 0
#        assert unit.max_passengers != 0
#
#
# class TestNewcastle:
#    @pytest.fixture(name="super_area_commute_nc")
#    def super_area_name_nc(self):
#        return ["E02001731", "E02001729"]
#
#    @pytest.fixture(name="geography_commute_nc")
#    def create_geography_nc(self, super_area_commute_nc):
#        geography = Geography.from_file({"super_area": super_area_commute_nc})
#        return geography
#
#    @pytest.fixture(name="world_nc")
#    def create_world_nc(self, geography_commute_nc):
#        world = generate_world_from_geography(
#            geography_commute_nc, include_commute=False, include_households=False
#        )
#        worker_distr = WorkerDistributor.for_geography(geography_commute_nc)
#        worker_distr.distribute(geography_commute_nc.areas, geography_commute_nc.super_areas, world.people)
#        commute_generator = CommuteGenerator.from_file()
#
#        for area in world.areas:
#            commute_gen = commute_generator.regional_gen_from_msoarea(area.name)
#            for person in area.people:
#                person.mode_of_transport = commute_gen.weighted_random_choice()
#
#        return world
#
#    @pytest.fixture(name="commutecities_nc")
#    def create_cities_with_people(self, world_nc):
#        commutecities = CommuteCities.for_super_areas(world_nc.super_areas)
#        commutecity_distributor = CommuteCityDistributor(
#            commutecities.members, world_nc.super_areas.members
#        )
#        commutecity_distributor.distribute_people()
#
#        return commutecities
#
#    def test__commutecities(self, commutecities_nc):
#        assert len(commutecities_nc.members) == 11
#        assert (len(commutecities_nc.members[7].people)) > 0
#
#    @pytest.fixture(name="commutehubs_nc")
#    def create_commutehubs_with_people(self, commutecities_nc):
#        commutehubs = CommuteHubs(commutecities_nc.members)
#        commutehubs.from_file()
#        commutehubs.init_hubs()
#        commutehub_distributor = CommuteHubDistributor(commutecities_nc.members)
#        commutehub_distributor.from_file()
#        commutehub_distributor.distribute_people()
#
#        return commutehubs
#
#    def test__commutehubs(self, commutecities_nc, commutehubs_nc):
#        assert len(commutecities_nc.members[7].commutehubs) == 4
#        # assert len(commutecities_nc.members[7].commute_internal) > 0
#
#    @pytest.fixture(name="commuteunits_nc")
#    def create_commute_units_with_people(self, commutecities_nc, commutehubs_nc):
#        commuteunits = CommuteUnits(commutehubs_nc.members)
#        commuteunits.init_units()
#        commuteunit_distributor = CommuteUnitDistributor(commutehubs_nc.members)
#        commuteunit_distributor.distribute_people()
#
#        return commuteunits
#
#    def test__commuteunits(self, commuteunits_nc):
#        assert len(commuteunits_nc.members[0].people) == 0
#
#    @pytest.fixture(name="commutecityunits_nc")
#    def create_commute_city_units_with_people(self, commutecities_nc, commutehubs_nc):
#        commutecityunits = CommuteCityUnits(commutecities_nc.members)
#        commutecityunits.init_units()
#        commutecityunit_distributor = CommuteCityUnitDistributor(
#            commutecities_nc.members
#        )
#        commutecityunit_distributor.distribute_people()
#
#        return commutecityunits
#
#    def test__commutecityunits(self, commutecityunits_nc):
#        assert len(commutecityunits_nc.members) > 0
#        # assert len(commutecityunits_nc.members[0].people) > 0


import os
from pathlib import Path

import pytest
import numpy as np
from collections import defaultdict

from june.geography import Geography
from june.demography import Person
from june.groups.company import Company, Companies


# Processed census company data, resolved relative to this test file
# (four parent hops up to the repository root).
default_data_path = (
    Path(os.path.abspath(__file__)).parent.parent.parent.parent
    / "data/processed/census_data/company_data/"
)


@pytest.fixture(name="super_area_companies", scope="module")
def create_geography():
    # Single super area E02002559 shared by all company tests in this module.
    g = Geography.from_file(filter_key={"super_area": ["E02002559"]})
    return g.super_areas.members[0]


@pytest.fixture(name="person")
def create_person():
    # A working-age male used as a generic employee.
    return Person.from_attributes(sex="m", age=44)


class TestCompany:
    @pytest.fixture(name="company")
    def create_company(self, super_area_companies):
        # Sector "Q" company with a capacity of 115 workers.
        return Company(super_area=super_area_companies, n_workers_max=115, sector="Q")

    def test__company_grouptype(self, company):
        assert company.SubgroupType.workers == 0

    def test__empty_company(self, company):
        assert len(company.people) == 0

    def test__filling_company(self, person, company):
        # Adding a person makes them a member of the company group.
        company.add(person)
        assert list(company.people)[0] == person

    def test__person_is_employed(self, person, company):
        # Adding also registers the company as the person's primary activity.
        company.add(person)
        assert (
            person.primary_activity == company.subgroups[company.SubgroupType.workers]
        )


@pytest.fixture(name="companies_example")
def create_companies(super_area_companies):
    # All companies generated for the single fixture super area.
    companies = Companies.for_super_areas([super_area_companies])
    return companies


def test__company_sizes(companies_example):
    """Company size distribution roughly matches census bins for E02002559."""
    assert len(companies_example) == 450
    sizes_dict = defaultdict(int)
    bins = [0, 10, 20, 50, 100, 250, 500, 1000, 1500]
    for company in companies_example:
        size = company.n_workers_max
        # NOTE(review): searchsorted is 'left'-sided, so a size exactly on a
        # bin edge (e.g. 10) counts in the lower bin, and a size of 0 would
        # map to index -1 — confirm sizes are always >= 1.
        idx = np.searchsorted(bins, size) - 1
        sizes_dict[idx] += 1
    # Loose tolerances: the distributor samples sizes stochastically.
    assert np.isclose(sizes_dict[0], 400, atol=10)
    assert np.isclose(sizes_dict[1], 30, atol=10)
    assert np.isclose(sizes_dict[2], 10, atol=10)
    assert np.isclose(sizes_dict[3], 0, atol=5)
    assert np.isclose(sizes_dict[4], 5, atol=6)


def test__company_ids(companies_example, super_area_companies):
    """Registry keys match company ids and every company belongs to the
    fixture's super area."""
    for registered_id, registered_company in companies_example.members_by_id.items():
        assert registered_company.id == registered_id
    assert all(
        company.super_area == super_area_companies for company in companies_example
    )


from june.demography.person import Person
from june.groups.care_home import CareHome
from june.groups.household import Household


class TestGroup:
    """Generic Group behaviour: subgroup membership and sequential ids."""

    def test_group_types(self):
        group = Household()
        group.add(Person.from_attributes(), group.SubgroupType.adults)
        assert group[group.SubgroupType.adults].size == 1

    def test_ids(self):
        # Consecutive constructions get consecutive ids per class, and the
        # group name embeds the zero-padded id.
        household_1 = Household()
        household_2 = Household()
        care_home_1 = CareHome(None, None, None)
        care_home_2 = CareHome(None, None, None)

        assert household_2.id == household_1.id + 1
        assert household_1.name == f"Household_{household_1.id:05d}"

        assert care_home_2.id == care_home_1.id + 1
        assert care_home_1.name == f"CareHome_{care_home_1.id:05d}"


from pathlib import Path
import pytest
import pandas as pd
from june.geography import Geography

from june.groups import Hospitals
from june.epidemiology.infection import InfectionSelector
from june.paths import data_path


# NOTE(review): path_pwd/dir_pwd appear unused in this module — confirm
# before removing.
path_pwd = Path(__file__)
dir_pwd = path_pwd.parent


@pytest.fixture(name="hospitals", scope="module")
def create_hospitals():
    # Hospitals built from the NHS trusts input file.
    return Hospitals.from_file(filename=data_path / "input/hospitals/trusts.csv")


@pytest.fixture(name="hospitals_df", scope="module")
def create_hospitals_df():
    # Raw trusts data for cross-checking against the Hospitals objects.
    return pd.read_csv(data_path / "input/hospitals/trusts.csv")


def test__total_number_hospitals_is_correct(hospitals, hospitals_df):
    # One Hospital instance per row of the trusts file.
    assert len(hospitals.members) == len(hospitals_df)


@pytest.mark.parametrize("index", [2, 3])
def test__given_hospital_finds_itself_as_closest(hospitals, hospitals_df, index):
    """Querying with a hospital's own coordinates returns that hospital
    as the nearest of the k candidates."""
    closest_idx = hospitals.get_closest_hospitals_idx(
        hospitals_df[["latitude", "longitude"]].iloc[index].values, k=10
    )

    closest_hospital_idx = closest_idx[0]
    assert hospitals.members[closest_hospital_idx] == hospitals.members[index]


@pytest.fixture(name="selector", scope="module")
def create_selector():
    # Default selector with deterministic rates overridden for tests.
    selector = InfectionSelector.from_file()
    selector.recovery_rate = 0.05
    selector.transmission_probability = 0.7
    return selector


class MockArea:
    """Minimal Area stand-in exposing only ``coordinates``."""

    def __init__(self, coordinates):
        self.coordinates = coordinates

    def __repr__(self):
        # Aid debugging when coordinate-based assertions fail.
        return f"{type(self).__name__}(coordinates={self.coordinates!r})"


def test__initialize_hospitals_from_geography():
    """Hospitals created for a geography get the expected super areas,
    bed counts and trust codes from the data files."""
    geography = Geography.from_file({"super_area": ["E02003282", "E02005560"]})
    hospitals = Hospitals.for_geography(geography)
    assert len(hospitals.members) == 2
    assert hospitals.members[1].super_area.name == "E02005560"
    assert hospitals.members[0].super_area.name == "E02003282"
    assert hospitals.members[1].n_beds + hospitals.members[1].n_icu_beds == 468 + 41
    assert hospitals.members[0].n_beds + hospitals.members[0].n_icu_beds == 2115 + 296
    assert hospitals.members[0].trust_code == "RAJ"


from june.groups import Household, Households
from june.demography import Person


def test__households_adding():
    """Adding two Households supergroups concatenates members in order."""
    first = Household()
    second = Household()
    third = Household()
    combined = Households([first]) + Households([second, third])
    assert combined.members == [first, second, third]


def test__household_mates():
    """Residents added to a household all appear in each other's
    housemates — including the person themself."""
    house = Household()
    person1 = Person.from_attributes()
    house.add(person1, subgroup_type=house.SubgroupType.kids)
    assert house.residents[0] == person1
    person2 = Person.from_attributes()
    person3 = Person.from_attributes()
    house.add(person2)
    house.add(person3)
    # housemates is expected to contain the person themself as well.
    assert person1 in person1.housemates
    assert person2 in person1.housemates
    assert person3 in person1.housemates


def test__being_visited_flag():
    """Only adds with activity="leisure" mark a household as visited."""
    house = Household()
    person = Person.from_attributes()
    assert not house.being_visited
    house.add(person, activity="leisure")
    assert house.being_visited
    house.being_visited = False
    # A default (residence) add must not set the flag.
    house.add(person)
    assert not house.being_visited


from june.groups import Supergroup
from june.groups import Group
from june.demography import Person
from enum import IntEnum
import pytest
from june import paths
import itertools


from june.groups.group import make_subgroups

# Interaction YAML used to exercise subgroup creation in these tests.
interaction_config = (
    paths.configs_path / "tests/groups/make_subgroups_test_interaction.yaml"
)


class MockGroup(Group):
    # Bare Group subclass so MockSupergroup has a concrete venue class.
    def __init__(self):
        super().__init__()


class MockSupergroup(Supergroup):
    # Concrete group class instantiated per venue by the supergroup.
    venue_class = MockGroup

    def __init__(self, groups):
        super().__init__(groups)


@pytest.fixture(name="super_group_default", scope="module")
def make_supergroup_default():
    # NOTE(review): identical to the "super_group" fixture below — confirm
    # whether the duplication is intentional.
    MockSupergroup.get_interaction(interaction_config)
    groups_list = [MockSupergroup.venue_class() for _ in range(10)]
    super_group_default = MockSupergroup(groups_list)
    return super_group_default


@pytest.fixture(name="super_group", scope="module")
def make_supergroup():
    # Supergroup of 10 mock venues with interaction parameters loaded
    # from the test YAML.
    MockSupergroup.get_interaction(interaction_config)
    groups_list = [MockSupergroup.venue_class() for _ in range(10)]
    super_group = MockSupergroup(groups_list)
    return super_group


def test__make_subgroups_defualt(super_group_default):
    # NOTE(review): "defualt" is a typo for "default"; renaming would change
    # the collected test id, so it is only flagged here.
    assert super_group_default[0].subgroup_type == "Age"
    assert super_group_default[0].subgroup_bins == [0, 18, 60, 100]
    assert super_group_default[0].subgroup_labels == ["A", "B", "C"]


def test__make_subgroups(super_group):
    # Subgroup parameters come from the test interaction YAML.
    assert super_group[0].subgroup_type == "Age"
    assert super_group[0].subgroup_bins == [0, 18, 60, 100]
    assert super_group[0].subgroup_labels == ["A", "B", "C"]


def test_excel_cols():
    # excel_cols() yields spreadsheet-style column labels; the first ten
    # are the single letters A..J.
    assert list(itertools.islice(make_subgroups.SubgroupParams().excel_cols(), 10)) == [
        "A",
        "B",
        "C",
        "D",
        "E",
        "F",
        "G",
        "H",
        "I",
        "J",
    ]


import pytest

from june.groups.travel import mode_of_transport as c
from june import paths

# NOTE(review): test_data_filename appears unused by the tests below —
# confirm before removing.
test_data_filename = paths.data_path / "census_data/commute.csv"


class TestModeOfTransport:
    """ModeOfTransport construction, equality, and config loading."""

    def test__setup_with_a_description__check_index(self):
        mode_of_transport = c.ModeOfTransport(description="hello")

        assert mode_of_transport.description == "hello"

        # index() locates the description within a header row...
        index = mode_of_transport.index(headers=["hi", "hello", "whatsup"])

        assert index == 1

        # ...and raises AssertionError when it is absent.
        with pytest.raises(AssertionError):
            mode_of_transport.index(headers=["hi", "hel2lo"])

    def test__equality_override(self):
        # Equality compares against the raw description string.
        mode_of_transport = c.ModeOfTransport(description="hello")

        assert mode_of_transport == "hello"

    def test__load_from_file__uses_correct_values_from_configs(self):
        modes_of_transport = c.ModeOfTransport.load_from_file()
        assert len(modes_of_transport) == 11  # used to be 12 with unemployment
        assert "Work mainly at or from home" in modes_of_transport
        # Repeated loads return the identical cached instances.
        assert c.ModeOfTransport.load_from_file()[0] is modes_of_transport[0]

    def test__is_public(self):
        c.ModeOfTransport.load_from_file()
        bus = c.ModeOfTransport.with_description("Bus, minibus or coach")
        assert bus.is_public is True
        assert bus.is_private is False

        car = c.ModeOfTransport.with_description("Driving a car or van")
        assert car.is_public is False
        assert car.is_private is True


class TestRegionalGenerator:
    """RegionalGenerator: totals, mode listing, weights and weighted draws."""

    def test__total__sum_of_people_using_all_transports(self):
        """total is the sum of the per-mode people counts."""
        weighted_modes = [(2, c.ModeOfTransport("car"))]

        regional_gen = c.RegionalGenerator(
            area="test_area", weighted_modes=weighted_modes
        )

        assert regional_gen.total == 2

        weighted_modes = [
            (2, c.ModeOfTransport("car")),
            (4, c.ModeOfTransport("bus")),
            (1, c.ModeOfTransport("magic_carpet")),
        ]

        regional_gen = c.RegionalGenerator(
            area="test_area", weighted_modes=weighted_modes
        )

        assert regional_gen.total == 7

    def test__modes__list_of_all_transports(self):
        """modes lists the transport descriptions in input order."""
        weighted_modes = [(2, c.ModeOfTransport("car"))]

        regional_gen = c.RegionalGenerator(
            area="test_area", weighted_modes=weighted_modes
        )

        assert regional_gen.modes == ["car"]

        weighted_modes = [
            (2, c.ModeOfTransport("car")),
            (4, c.ModeOfTransport("bus")),
            (1, c.ModeOfTransport("magic_carpet")),
        ]

        regional_gen = c.RegionalGenerator(
            area="test_area", weighted_modes=weighted_modes
        )

        assert regional_gen.modes == ["car", "bus", "magic_carpet"]

    def test__weights__lists_people_per_transport_divided_by_total(self):
        """weights are per-mode counts normalized by the total."""
        weighted_modes = [(2, c.ModeOfTransport("car"))]

        regional_gen = c.RegionalGenerator(
            area="test_area", weighted_modes=weighted_modes
        )

        assert (regional_gen.weights == [1]).all()

        weighted_modes = [
            (2, c.ModeOfTransport("car")),
            (4, c.ModeOfTransport("bus")),
            (1, c.ModeOfTransport("magic_carpet")),
        ]

        regional_gen = c.RegionalGenerator(
            area="test_area", weighted_modes=weighted_modes
        )

        assert (regional_gen.weights == [2 / 7, 4 / 7, 1 / 7]).all()

    def test__weighted_choice__chooses_random_value_from_the_modes(self):
        """A random draw always returns one of the configured modes."""
        weighted_modes = [(2, c.ModeOfTransport("car"))]

        regional_gen = c.RegionalGenerator(
            area="test_area", weighted_modes=weighted_modes
        )

        assert regional_gen.weighted_random_choice() == "car"

        weighted_modes = [
            (2, c.ModeOfTransport("car")),
            (4, c.ModeOfTransport("bus")),
            (1, c.ModeOfTransport("magic_carpet")),
        ]

        regional_gen = c.RegionalGenerator(
            area="test_area", weighted_modes=weighted_modes
        )

        # BUG FIX: the original `x == "car" or "bus" or "magic_carpet"`
        # always passed, because `==` binds tighter than `or` and non-empty
        # string literals are truthy. Assert membership instead.
        assert regional_gen.weighted_random_choice() in ("car", "bus", "magic_carpet")

    def test__weighted_choice__cant_choose_transports_with_0_people(self):
        """Modes with zero people are never drawn."""
        weighted_modes = [
            (0, c.ModeOfTransport("car")),
            (4, c.ModeOfTransport("bus")),
            (0, c.ModeOfTransport("magic_carpet")),
        ]

        regional_gen = c.RegionalGenerator(
            area="test_area", weighted_modes=weighted_modes
        )

        assert regional_gen.weighted_random_choice() == "bus"
        assert regional_gen.weighted_random_choice() == "bus"
        assert regional_gen.weighted_random_choice() == "bus"


import pytest

from june.geography import Geography
from june.demography import Person
from june.groups import School, Schools
from recordclass import dataobject


class Activities(dataobject):
    """recordclass dataobject holding a person's activity subgroup slots.

    The class-level annotations declare the fields; ``None`` is used as a
    placeholder annotation, not a real type.
    """

    residence: None
    primary_activity: None
    medical_facility: None
    commute: None
    rail_travel: None
    leisure: None

    def iter(self):
        # Return every activity slot, in field declaration order.
        return [getattr(self, activity) for activity in self.__fields__]


@pytest.fixture(name="geo_schools", scope="module")
def area_name():
    # Geography restricted to a single super area for the school tests.
    geography = Geography.from_file(filter_key={"super_area": ["E02004935"]})
    return geography


class TestSchool:
    @pytest.fixture(name="school")
    def create_school(self):
        # School accepting pupils aged 6 to 8 with capacity 467.
        return School(coordinates=(1.0, 1.0), n_pupils_max=467, age_min=6, age_max=8)

    def test__school_grouptype(self, school):
        assert school.SubgroupType.teachers == 0
        assert school.SubgroupType.students == 1

    def test__empty_school(self, school):
        # Teachers subgroup plus all student subgroups start empty.
        assert len(school.teachers.people) == 0
        for subgroup in school.subgroups[1:]:
            assert len(subgroup.people) == 0

    def test__filling_school(self, school):
        person = Person(
            sex="f", age=7, subgroups=Activities(None, None, None, None, None, None)
        )

        school.add(person)
        # presumably one student subgroup per year of age after teachers,
        # so age 7 with age_min 6 lands at index 2 — TODO confirm
        assert bool(school.subgroups[2].people) is True


class TestSchools:
    def test__creating_schools_from_file(self, geo_schools):
        # Smoke test: construction from file must not raise.
        Schools.from_file(areas=geo_schools.areas)

    def test_creating_schools_for_areas(self, geo_schools):
        # Smoke test: construction for areas must not raise.
        Schools.for_areas(geo_schools.areas)

    @pytest.fixture(name="schools", scope="module")
    def test__creating_schools_for_geography(self, geo_schools):
        # NOTE(review): despite the test__ name this is a fixture providing
        # "schools"; pytest does not collect fixtures as tests.
        return Schools.for_geography(geo_schools)

    def test__school_nr_for_geography(self, schools):
        assert len(schools) == 4

    def test__school_is_closest_to_itself(self, schools):
        # Query with a school's own coordinates and an age inside its range;
        # that school must be its own nearest match.
        school = schools.members[0]
        age = int(0.5 * (school.age_min + school.age_max))
        closest_school = schools.get_closest_schools(age, school.coordinates, 1)
        # Translate the age-group-local index back to a global member index.
        closest_school = schools.members[
            schools.school_agegroup_to_global_indices.get(age)[closest_school[0]]
        ]
        assert closest_school == school


from june.groups import Supergroup
from june.groups import Group
from june.demography import Person
from enum import IntEnum
import pytest


class MockSupergroup(Supergroup):
    # Thin Supergroup subclass used to exercise generic supergroup behaviour.
    def __init__(self, groups):
        super().__init__(groups)


class MockGroup(Group):
    # Minimal Group subclass with two subgroups (A and B) for supergroup tests.
    class SubgroupType(IntEnum):
        A = 0
        B = 1

    def __init__(self):
        super().__init__()


@pytest.fixture(name="super_group", scope="module")
def make_supergroup():
    """Module-scoped supergroup wrapping ten empty mock groups."""
    member_groups = [MockGroup() for _ in range(10)]
    return MockSupergroup(member_groups)


def test__create_supergroup(super_group):
    assert len(super_group) == 10
    assert super_group.group_type == "MockSupergroup"
    return super_group


import pytest
import numpy as np

from june.geography import Geography
from june.groups.travel import Travel, ModeOfTransport
from june import World
from june.demography import Person, Population


@pytest.fixture(name="geo", scope="module")
def make_sa():
    """Two-super-area geography: the Newcastle work area plus a neighbour."""
    selection = {"super_area": ["E02001731", "E02005123"]}
    return Geography.from_file(selection)


@pytest.fixture(name="travel_world", scope="module")
def make_commuting_network(geo):
    """World with 1200 public-transport commuters working in E02001731.

    Every fourth person lives inside the work super area (internal commuter);
    the rest live in E02005123 and commute in from outside.
    """
    world = World()
    world.areas = geo.areas
    world.super_areas = geo.super_areas
    work_sa = world.super_areas.members_by_name["E02001731"]
    home_sa = world.super_areas.members_by_name["E02005123"]
    commuters = []
    for idx in range(1200):
        commuter = Person.from_attributes()
        commuter.mode_of_transport = ModeOfTransport(is_public=True, description="asd")
        commuter.work_super_area = work_sa
        world.super_areas[0].workers.append(commuter)
        # 1 in 4 commutes internally; the others arrive from the neighbour SA.
        commuter.area = (work_sa if idx % 4 == 0 else home_sa).areas[0]
        commuters.append(commuter)
    world.people = Population(commuters)
    travel = Travel()
    travel.initialise_commute(world, maximum_number_commuters_per_city_station=150)
    return world, travel


class TestCommute:
    """Checks on the commuting network built by the travel_world fixture."""

    def test__generate_commuting_network(self, travel_world):
        world, _ = travel_world
        assert len(world.cities) == 1
        city = world.cities[0]
        assert city.name == "Newcastle upon Tyne"
        assert city.super_areas[0] == "E02001731"
        assert len(city.city_stations) == 2
        assert len(city.inter_city_stations) == 4
        # Only the work super area belongs to the city.
        for sa in world.super_areas:
            if sa.name == "E02001731":
                assert sa.city.name == "Newcastle upon Tyne"
            else:
                assert sa.city is None

    def test__assign_commuters_to_stations(self, travel_world):
        world, _ = travel_world
        city = world.cities[0]
        internal = len(city.internal_commuter_ids)
        external = sum(
            len(station.commuter_ids) for station in city.inter_city_stations
        )
        # The fixture splits 1200 commuters 1:3 internal:external.
        assert internal == 300
        assert external == 900

    def test__get_travel_subgroup(self, travel_world):
        world, travel = travel_world
        # Person 0 commutes internally within the city...
        internal_sub = travel.get_commute_subgroup(world.people[0])
        assert internal_sub.group.spec == "city_transport"
        # ...person 1 commutes in from outside.
        external_sub = travel.get_commute_subgroup(world.people[1])
        assert external_sub.group.spec == "inter_city_transport"

    def test__number_of_commuters(self, travel_world):
        world, _ = travel_world
        # People eligible to commute by public transport into a city with stations.
        public_transports = sum(
            1
            for person in world.people
            if person.mode_of_transport is not None
            and person.mode_of_transport.is_public
            and person.work_super_area.city is not None
            and person.work_super_area.city.has_stations
        )
        commuters = sum(
            len(city.internal_commuter_ids)
            + sum(len(st.commuter_ids) for st in city.inter_city_stations)
            for city in world.cities
        )
        assert public_transports == commuters

    def test__all_commuters_get_commute(self, travel_world):
        world, travel = travel_world
        assigned = sum(
            1
            for person in world.people
            if travel.get_commute_subgroup(person) is not None
        )
        commuters = sum(
            len(city.internal_commuter_ids)
            + sum(len(st.commuter_ids) for st in city.inter_city_stations)
            for city in world.cities
        )
        assert commuters > 0
        assert commuters == assigned

    def test__number_of_transports(self, travel_world):
        world, _ = travel_world
        newcastle = world.cities.get_by_name("Newcastle upon Tyne")
        seats_per_passenger = 2.28
        seats_per_train = 50

        n_city_transports = sum(
            len(station.city_transports) for station in newcastle.city_stations
        )
        assert n_city_transports > 0
        n_city_commuters = len(newcastle.internal_commuter_ids)
        assert n_city_commuters > 0
        # Enough carriages to seat everyone at the configured density.
        assert (
            np.ceil(n_city_commuters * seats_per_passenger / seats_per_train)
            == n_city_transports
        )
        n_inter_transports = sum(
            len(station.inter_city_transports)
            for station in newcastle.inter_city_stations
        )
        assert n_inter_transports > 0
        n_inter_commuters = sum(
            len(station.commuter_ids) for station in newcastle.inter_city_stations
        )
        assert n_inter_commuters > 0
        assert (
            np.ceil(n_inter_commuters * seats_per_passenger / seats_per_train)
            == n_inter_transports
        )


# import pytest
#
# from june import World
# from june.geography import Geography, Area
# from june.demography import Person, Demography
# from june.distributors import WorkerDistributor
# from june.commute import CommuteGenerator
# from june.groups import (
#    CommuteCity,
#    CommuteCities,
#    CommuteCityDistributor,
#    CommuteHub,
#    CommuteHubs,
#    CommuteHubDistributor,
#    CommuteUnit,
#    CommuteUnits,
#    CommuteUnitDistributor,
#    CommuteCityUnit,
#    CommuteCityUnits,
#    CommuteCityUnitDistributor,
# )
# from june.groups import (
#    TravelCity,
#    TravelCities,
#    TravelCityDistributor,
#    TravelUnit,
#    TravelUnits,
#    TravelUnitDistributor,
# )
# from june.world import generate_world_from_geography
#
#
# class TestTravel:
#    @pytest.fixture(name="super_area_commute_nc")
#    def super_area_name_nc(self):
#        # return ['E02001731', 'E02001729', 'E02001688', 'E02001689', 'E02001736',
#        #        'E02001720', 'E02001724', 'E02001730', 'E02006841', 'E02001691',
#        #        'E02001713', 'E02001712', 'E02001694', 'E02006842', 'E02001723',
#        #        'E02001715', 'E02001710', 'E02001692', 'E02001734', 'E02001709']
#        return ["E02001731", "E02001729"]
#
#    @pytest.fixture(name="geography_commute_nc")
#    def create_geography_nc(self, super_area_commute_nc):
#        geography = Geography.from_file({"super_area": super_area_commute_nc})
#        return geography
#
#    @pytest.fixture(name="world_nc")
#    def create_world_nc(self, geography_commute_nc):
#        world = generate_world_from_geography(
#            geography_commute_nc, include_households=False, include_commute=False
#        )
#
#        return world
#
#    @pytest.fixture(name="commutecities_nc")
#    def create_commute_setup(self, world_nc):
#        commutecities = CommuteCities.for_super_areas(world_nc.super_areas)
#        assert len(commutecities.members) == 11
#
#        return commutecities
#
#    def test__travel_all(self, world_nc, commutecities_nc):
#        travelcities = TravelCities(commutecities_nc)
#        travelcities.init_cities()
#        assert len(travelcities.members) == 11
#
#        travelcity_distributor = TravelCityDistributor(
#            travelcities.members, world_nc.super_areas.members
#        )
#        travelcity_distributor.distribute_msoas()
#
#        travelunits = TravelUnits()
#        travelunit_distributor = TravelUnitDistributor(
#            travelcities.members, travelunits.members
#        )
#        travelunit_distributor.from_file()
#        travelunit_distributor.distribute_people_out()
#        assert len(travelunits.members) > 0
#
#        people = 0
#        for i in travelunits.members:
#            no_pass = i.no_passengers
#            people += no_pass
#
#        arrive = 0
#        for city in travelcities.members:
#            arrive += len(city.arrived)
#
#        assert people == arrive
#
#        travelunit_distributor.distribute_people_back()
#        assert len(travelunits.members) > 0
#
#        people = 0
#        for i in travelunits.members:
#            no_pass = i.no_passengers
#            people += no_pass
#
#        assert people == arrive


import numpy as np
from june.groups import University, Universities
from june.geography import Area, Areas, SuperArea, SuperAreas


def test__university_init():
    """A university keeps its coordinates and student capacity."""
    uni = University(coordinates=np.array([1, 2]), n_students_max=500)
    assert np.array_equal(uni.coordinates, np.array([1, 2]))
    assert uni.n_students_max == 500


def test__university_for_super_areas():
    """Durham university is found when building Universities for its area."""
    durham_sa = SuperArea(name="durham", areas=None, coordinates=[54.768, -1.571868])
    central = Area(
        name="durham_central", super_area=durham_sa, coordinates=durham_sa.coordinates
    )
    area_collection = Areas([central])
    durham_sa.areas = area_collection
    # Result unused in the original test; kept for any construction side effects.
    SuperAreas([durham_sa])
    universities = Universities.for_areas(area_collection)
    assert universities[0].n_students_max == 19025


from datetime import datetime
import numpy as np
from pytest import fixture
from june.groups import Household

from june.groups.leisure import (
    Leisure,
    generate_leisure_for_world,
    Pub,
    Pubs,
    Cinemas,
    Cinema,
    PubDistributor,
    CinemaDistributor,
)
from june.demography import Person


class MockArea:
    """Bare stand-in for a geography Area; tests attach attributes ad hoc."""


@fixture(name="leisure")
def make_leisure():
    """Leisure object wired with one pub and one cinema distributor."""
    pubs = Pubs([Pub()], make_tree=False)
    pubs[0].coordinates = [1, 2]
    pub_distributor = PubDistributor(
        pubs,
        times_per_week={
            "weekday": {"male": {"18-50": 0.5}, "female": {"10-40": 0.3}},
            "weekend": {"male": {"18-50": 0.7}, "female": {"18-50": 0.4}},
        },
        daytypes={
            "weekday": ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"],
            "weekend": ["Saturday", "Sunday"],
        },
    )

    cinemas = Cinemas([Cinema()], make_tree=False)
    cinemas[0].coordinates = [1, 2]
    # drags_household_probability=1 so household-dragging is deterministic.
    cinema_distributor = CinemaDistributor(
        cinemas,
        times_per_week={
            "weekday": {"male": {"10-40": 0.1}, "female": {"10-40": 0.2}},
            "weekend": {"male": {"18-50": 0.4}, "female": {"18-50": 0.5}},
        },
        daytypes={
            "weekday": ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"],
            "weekend": ["Saturday", "Sunday"],
        },
        drags_household_probability=1.0,
    )

    return Leisure(
        leisure_distributors={"pub": pub_distributor, "cinema": cinema_distributor}
    )


def _get_times_pub_cinema(leisure, person, day_type):
    """Return (mean pub visits, mean cinema visits) over 5000 simulated weeks."""
    if day_type == "weekend":
        delta_time = 0.125  # in reality is 0.5 but make it smaller for stats
        n_days = 8  # in reality is 2
    else:
        delta_time = 1 / 8
        n_days = 5
    leisure.generate_leisure_probabilities_for_timestep(
        delta_time,
        working_hours=False,
        date=datetime.strptime("2020-03-01", "%Y-%m-%d"),
    )
    pub_counts = []
    cinema_counts = []
    for _ in range(5000):
        pub_visits = 0
        cinema_visits = 0
        for _ in range(n_days):  # one week
            subgroup = leisure.get_subgroup_for_person_and_housemates(person)
            if subgroup is None:
                continue
            spec = subgroup.group.spec
            if spec == "pub":
                pub_visits += 1
            elif spec == "cinema":
                cinema_visits += 1
            else:
                raise ValueError
        pub_counts.append(pub_visits)
        cinema_counts.append(cinema_visits)
    return np.mean(pub_counts), np.mean(cinema_counts)


def test__probability_of_leisure(leisure):
    """Weekly pub/cinema visit rates match the configured frequencies."""
    household = Household(type="student")
    male = Person.from_attributes(sex="m", age=26)
    male.area = MockArea()
    household.add(male)
    female = Person.from_attributes(sex="f", age=26)
    female.area = MockArea()
    household.add(female)
    # Each person sees the single pub and cinema of the fixture.
    for member in (male, female):
        member.area.social_venues = {
            "cinema": [leisure.leisure_distributors["cinema"].social_venues[0]],
            "pub": [leisure.leisure_distributors["pub"].social_venues[0]],
        }
    # (person, day type, expected pub visits/week, expected cinema visits/week)
    cases = [
        (male, "weekday", 0.43, 0.23),
        (female, "weekday", 0.23, 0.3),
        (male, "weekend", 0.7, 0.4),
        (female, "weekend", 0.4, 0.5),
    ]
    for person, day_type, expected_pub, expected_cinema in cases:
        times_pub_a_week, times_cinema_a_week = _get_times_pub_cinema(
            person=person, leisure=leisure, day_type=day_type
        )
        assert np.isclose(times_pub_a_week, expected_pub, rtol=0.1)
        assert np.isclose(times_cinema_a_week, expected_cinema, rtol=0.1)


def test__person_drags_household(leisure):
    """With drag probability 1, the whole household joins the cinema-goer."""
    dragger = Person.from_attributes(sex="m", age=26)
    mate_a = Person.from_attributes(sex="f", age=26)
    mate_b = Person.from_attributes(sex="m", age=27)
    household = Household()
    for member in (dragger, mate_a, mate_b):
        household.add(member)
    mate_a.busy = False
    mate_b.busy = False
    cinema = leisure.leisure_distributors["cinema"].social_venues[0]
    cinema.add(dragger)
    leisure.leisure_distributors["cinema"].send_household_with_person_if_necessary(
        dragger, None
    )
    # Everyone in the household ends up in the cinema's first subgroup.
    for member in (dragger, mate_a, mate_b):
        assert member.subgroups.leisure == cinema.subgroups[0]


def test__generate_leisure_from_world(dummy_world):
    """generate_leisure_for_world wires pubs/cinemas/groceries into a world."""
    world = dummy_world
    person = Person.from_attributes(sex="m", age=27)
    household = Household()
    household.area = world.areas[0]
    household.add(person)
    person.area = world.areas[0]
    leisure = generate_leisure_for_world(
        list_of_leisure_groups=["pubs", "cinemas", "groceries"],
        world=world,
        daytypes={
            "weekday": ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"],
            "weekend": ["Saturday", "Sunday"],
        },
    )
    leisure.generate_leisure_probabilities_for_timestep(
        0.1, False, datetime.strptime("2020-03-01", "%Y-%m-%d")
    )
    visit_counts = {"pub": 0, "cinema": 0, "grocery": 0}
    for _ in range(1000):
        subgroup = leisure.get_subgroup_for_person_and_housemates(person)
        if subgroup is not None and subgroup.group.spec in visit_counts:
            visit_counts[subgroup.group.spec] += 1
    # Sanity only: the draw is stochastic, so just require non-negative counts.
    assert visit_counts["pub"] >= 0
    assert visit_counts["cinema"] >= 0
    assert visit_counts["grocery"] >= 0


import pytest

from june.geography import Geography
from june.groups.leisure import Pubs


@pytest.fixture(name="geography")
def make_geography():
    """Single-super-area geography used by the pub tests."""
    return Geography.from_file({"super_area": ["E02005103"]})


class TestPubs:
    def test__create_pubs_in_geography(self, geography):
        """All 179 pubs of super area E02005103 are created.

        Fix: the original returned ``pubs`` from a test method, which pytest
        flags with PytestReturnNotNoneWarning; tests must return None.
        """
        pubs = Pubs.for_geography(geography)
        assert len(pubs) == 179


import numpy as np
import pytest

from june.groups.leisure.residence_visits import ResidenceVisitsDistributor
from june.demography import Person
from june.groups import Household, CareHome, Pub, Company
from june.geography import SuperArea, SuperAreas, Area


@pytest.fixture(name="rv_distributor", scope="module")
def make_rvd():
    """ResidenceVisitsDistributor with a single uniform 0-100 age bin."""

    def per_sex(value):
        # Fresh dict per call so the distributor never sees shared objects.
        return {"male": {"0-100": value}, "female": {"0-100": value}}

    return ResidenceVisitsDistributor(
        residence_type_probabilities={"household": 0.7, "care_home": 0.3},
        times_per_week={"weekday": per_sex(1), "weekend": per_sex(1)},
        hours_per_day={"weekday": per_sex(3), "weekend": per_sex(3)},
        daytypes={
            "weekday": ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"],
            "weekend": ["Saturday", "Sunday"],
        },
    )


@pytest.fixture(name="super_areas", scope="module")
def make_super_areas(rv_distributor):
    """10 super areas x 5 areas, each area with 10 family + 10 communal homes.

    Every household (and a care home per area) shares the same single Person;
    households are then linked for visiting via the distributor.
    """
    n_super_areas = 10
    n_areas_per_super_area = 5
    n_households_per_area = 10
    built_super_areas = []
    person = Person.from_attributes()
    for i in range(n_super_areas):
        sa_areas = []
        for j in range(n_areas_per_super_area):
            area = Area(coordinates=[i, j])
            for _ in range(n_households_per_area):
                # One family and one communal household per iteration.
                for household_type in ("family", "communal"):
                    home = Household(type=household_type)
                    home.add(person)
                    area.households.append(home)
            area.care_home = CareHome(area=area)
            sa_areas.append(area)
        built_super_areas.append(SuperArea(areas=sa_areas, coordinates=[i, i]))
    super_areas = SuperAreas(built_super_areas)
    rv_distributor.link_households_to_households(super_areas)
    return super_areas


class TestResidenceVisitsDistributor:
    def test__get_residence_to_visit(self, rv_distributor):
        """A kid visiting their only linked household gets its kids subgroup.

        Fix: the original two comparisons were bare expression statements, so
        they asserted nothing (flake8-bugbear B015); ``assert`` added.
        """
        person = Person.from_attributes(age=8)
        household = Household(type="family")
        household2 = Household()
        household.add(person)
        household.residences_to_visit["household"] = (household2,)
        assert rv_distributor.get_leisure_group(person) == household2
        assert (
            rv_distributor.get_leisure_subgroup(person)
            == household2[household2.SubgroupType.kids]
        )

    def test__no_visits_during_working_hours(self, rv_distributor):
        """Visit rate is zero during working hours and positive otherwise."""
        for sex in ["m", "f"]:
            for age in range(0, 100):
                for day_type in ["weekday", "weekend"]:
                    poisson_parameter = rv_distributor.get_poisson_parameter(
                        age=age, sex=sex, day_type=day_type, working_hours=True
                    )
                    assert poisson_parameter == 0
                    poisson_parameter = rv_distributor.get_poisson_parameter(
                        age=age, sex=sex, day_type=day_type, working_hours=False
                    )
                    assert poisson_parameter > 0


class TestHouseholdVisits:
    def test__household_linking(self, super_areas):
        """Every household is linked to between 2 and 4 households to visit."""
        checked_any = False
        for super_area in super_areas:
            for area in super_area.areas:
                for household in area.households:
                    checked_any = True
                    linked = list(household.residences_to_visit["household"])
                    assert 2 <= len(linked) <= 4
        assert checked_any

    def test__visitors_stay_home_when_visited(self, rv_distributor):
        """Residents out at leisure come home for a visit; workers do not."""
        visitor = Person.from_attributes(age=20)
        resident1 = Person.from_attributes()
        resident2 = Person.from_attributes()
        visitor_home = Household(type="family")
        visitor_home.add(visitor)
        residents_home = Household(type="family")
        residents_home.add(resident1)
        residents_home.add(resident2)
        visitor_home.residences_to_visit["household"] = (residents_home,)
        # resident 1 is at the pub, so they can go back home
        pub = Pub()
        pub.add(resident1)
        assert resident1.leisure == pub[0]
        # resident 2 is at the company, so they can't go back home
        company = Company()
        company.add(resident2)
        assert resident2.primary_activity == company[0]
        subgroup = rv_distributor.get_leisure_subgroup(visitor)
        assert subgroup == residents_home[residents_home.SubgroupType.young_adults]
        assert resident1.leisure == resident1.residence
        assert resident2.leisure is None


class TestCareHomeVisits:
    def test__every_resident_has_one_relative(self, super_areas, rv_distributor):
        """Only family-like households get care-home links, at most two each."""
        rv_distributor.link_households_to_care_homes(super_areas)
        found_family = False
        unlinked_types = {"student", "young_adults", "old", "other", "communal"}
        family_types = {"family", "ya_parents", "nokids"}
        for super_area in super_areas:
            for area in super_area.areas:
                for household in area.households:
                    visits = household.residences_to_visit
                    if household.type in unlinked_types:
                        assert "care_home" not in visits
                    elif household.type in family_types:
                        found_family = True
                        assert (
                            "care_home" not in visits
                            or len(visits["care_home"]) <= 2
                        )
                        # for now we only allow household -> care_home
                        for link in visits.get("care_home", ()):
                            assert link.spec == "care_home"
                    else:
                        raise ValueError
        assert found_family

    def test__type_probabilities(self, rv_distributor):
        """Visits split ~70/30 between linked household and care home."""
        visitor = Person.from_attributes(age=20)
        household = Household(type="family")
        other_household = Household(type="family")
        care_home = CareHome()
        household.add(visitor)
        household.residences_to_visit = {
            "household": (other_household,),
            "care_home": (care_home,),
        }
        household_hits = 0
        care_home_hits = 0
        for _ in range(500):
            destination = rv_distributor.get_leisure_group(visitor)
            assert destination in [other_household, care_home]
            if destination == other_household:
                household_hits += 1
            else:
                care_home_hits += 1
        total = household_hits + care_home_hits
        assert np.isclose(household_hits / total, 0.7, rtol=0.1)
        assert np.isclose(care_home_hits / total, 0.3, rtol=0.1)


import numpy as np
from june.groups.leisure import SocialVenues
from june.geography import Geography


def test__social_venue_from_coordinates():
    """Venues built from coordinates attach to the matching super areas."""
    geo = Geography.from_file({"super_area": ["E02004935", "E02004940"]})
    coords = np.array([[51.752179, -0.334667], [51.741485, -0.336645]])
    venues = SocialVenues.from_coordinates(coords, super_areas=geo.super_areas)
    venues.add_to_areas(geo.areas)
    assert len(venues) == 2
    assert venues[0].super_area == geo.super_areas[0]
    assert venues[1].super_area == geo.super_areas[1]


def test__get_closest_venues():
    """Tree queries return venues ordered by distance from the query point."""
    coords = np.array([[51.752179, -0.334667], [51.741485, -0.336645]])
    venues = SocialVenues.from_coordinates(coords, super_areas=None)
    venues.make_tree()

    nearest = venues.get_closest_venues([50, 0])[0]
    assert nearest == venues[1]

    in_radius = venues.get_venues_in_radius([51.7, -0.33], 10)
    assert in_radius[0] == venues[1]
    assert in_radius[1] == venues[0]


from june.interaction import Interaction
from june.groups import School
from june.demography import Person
from june import paths
from june.geography import Geography
from june.world import generate_world_from_geography
from june.epidemiology.infection_seed import InfectionSeed
from june.policy import Policies
from june.simulator import Simulator

import pytest
import numpy as np
import pandas as pd
import pathlib


# Interaction configuration used throughout the interaction tests below.
test_config = paths.configs_path / "tests/interaction.yaml"
# Default per-sector beta file (NOTE(review): appears unused in this chunk —
# confirm before removing).
default_sector_beta_filename = (
    paths.configs_path / "defaults/interaction/sector_beta.yaml"
)


class TestInteractionFunctions:
    """Unit tests for Interaction's probability machinery.

    Left byte-identical apart from comments: the exact numeric expectations
    (tensor entries, 1 - exp(-0.6) fractions) are tightly coupled to the
    statement order and literals below.
    """

    def test__contact_matrices_from_default(self):
        interaction = Interaction.from_file(config_filename=test_config)
        # Expected value mirrors the rescaling formula:
        # contacts * (1 + physical factor) * 24h / characteristic_time
        # — presumably 3 contacts, 0.12 physical, 3h characteristic; confirm
        # against tests/interaction.yaml.
        np.testing.assert_allclose(
            interaction.contact_matrices["pub"],
            np.array([[3 * (1 + 0.12) * 24 / 3]]),
            rtol=0.05,
        )

    def test__create_infector_tensor(self):
        # Two infections: infection 1 has infectors only in subgroup 2,
        # infection 2 in subgroups 0 and 1.
        infectors_per_infection_per_subgroup = {
            1: {2: {"ids": [1, 2, 3], "trans_probs": [0.1, 0.2, 0.3]}},
            2: {
                0: {"ids": [4, 5], "trans_probs": [0.4, 0.5]},
                1: {"ids": [6], "trans_probs": [0.8]},
            },
        }
        subgroup_sizes = [3, 5, 7]
        interaction = Interaction.from_file(config_filename=test_config)
        contact_matrix = np.array([[1, 0, 1], [1, 1, 1], [0, 1, 0]])
        infector_tensor = interaction.create_infector_tensor(
            infectors_per_infection_per_subgroup, subgroup_sizes, contact_matrix, 1, 1
        )
        # Entries are summed trans_probs scaled by contact rates and the
        # infector subgroup size (e.g. 0.6 = 0.1+0.2+0.3 over size 7).
        expected = np.array([[0, 0, 0.6 / 7], [0, 0, 0.6 / 7], [0.0, 0.0, 0.0]])
        assert np.allclose(infector_tensor[1], expected)
        expected = np.array([[0.9 / 2, 0, 0], [0.9 / 3, 0.8 / 4, 0], [0, 0.8 / 5, 0]])
        assert np.allclose(infector_tensor[2], expected)

    def test__gets_infected(self):
        # Statistical test: the miss rate should follow exp(-sum(probs)) and
        # each infection is chosen proportionally to its probability.
        interaction = Interaction.from_file(config_filename=test_config)
        probs = np.array([0.1, 0.2, 0.3])
        possible_infections = [1, 2, 3]
        infections = []
        misses = 0
        n = 10000
        for _ in range(n):
            infection = interaction._gets_infected(probs, possible_infections)
            if infection is None:
                misses += 1
                continue
            infections.append(infection)
        infections = np.array(infections)
        misses_exp = np.exp(-0.6)
        assert np.isclose(misses, misses_exp * n, rtol=0.1)
        assert np.isclose(
            len(infections[infections == 1]), 0.1 / 0.6 * n * (1 - misses_exp), rtol=0.1
        )
        assert np.isclose(
            len(infections[infections == 2]), 0.2 / 0.6 * n * (1 - misses_exp), rtol=0.1
        )
        assert np.isclose(
            len(infections[infections == 3]), 0.3 / 0.6 * n * (1 - misses_exp), rtol=0.1
        )

    def test__blame_subgroup(self):
        # Blame should land on each subgroup proportionally to its weight.
        interaction = Interaction.from_file(config_filename=test_config)
        probs = np.array([20, 30, 100])
        blames = []
        n = 10000
        for _ in range(n):
            blame = interaction._blame_subgroup(probs)
            blames.append(blame)
        blames = np.array(blames)
        assert np.isclose(len(blames[blames == 0]), 20 / 150 * n, rtol=0.1)
        assert np.isclose(len(blames[blames == 1]), 30 / 150 * n, rtol=0.1)
        assert np.isclose(len(blames[blames == 2]), 100 / 150 * n, rtol=0.1)

    def test__blame_individuals(self):
        # Within the blamed subgroup, individuals are picked proportionally
        # to their transmission probability.
        interaction = Interaction.from_file(config_filename=test_config)
        infectors_per_infection_per_subgroup = {
            "a": {
                0: {"ids": [1, 2, 3], "trans_probs": [0.1, 0.2, 0.3]},
                1: {"ids": [4], "trans_probs": [0.4]},
            },
            "b": {2: {"ids": [5, 6, 7], "trans_probs": [0.1, 0.2, 0.3]}},
        }
        infection_ids = ["a", "b"]
        to_blame_subgroups = [0, 2]
        blames1 = []
        blames2 = []
        n = 1000
        for i in range(n):
            to_blame_ids = interaction._blame_individuals(
                to_blame_subgroups, infection_ids, infectors_per_infection_per_subgroup
            )
            blames1.append(to_blame_ids[0])
            blames2.append(to_blame_ids[1])
        blames1 = np.array(blames1)
        blames2 = np.array(blames2)
        assert np.isclose(len(blames1[blames1 == 1]), 0.1 / 0.6 * n, rtol=0.1)
        assert np.isclose(len(blames1[blames1 == 2]), 0.2 / 0.6 * n, rtol=0.1)
        assert np.isclose(len(blames1[blames1 == 3]), 0.3 / 0.6 * n, rtol=0.1)
        assert np.isclose(len(blames2[blames2 == 5]), 0.1 / 0.6 * n, rtol=0.1)
        assert np.isclose(len(blames2[blames2 == 6]), 0.2 / 0.6 * n, rtol=0.1)
        assert np.isclose(len(blames2[blames2 == 7]), 0.3 / 0.6 * n, rtol=0.1)


def days_to_infection(interaction, susceptible_person, group, people, n_students):
    """Run hourly interaction steps until the susceptible person is infected
    (or 100 days pass) and return the elapsed time in days."""
    delta_time = 1 / 24
    elapsed = 0  # renamed: the original local shadowed the function name
    while not susceptible_person.infected and elapsed < 100:
        # Rebuild the group each step: students go to subgroup 1, teachers to 0.
        for student in people[:n_students]:
            group.subgroups[1].append(student)
        for teacher in people[n_students:]:
            group.subgroups[0].append(teacher)
        infected_ids, _, _ = interaction.time_step_for_group(
            group=group, delta_time=delta_time
        )
        if susceptible_person.id in infected_ids:
            break
        elapsed += delta_time
        group.clear()
    return elapsed


def create_school(n_students, n_teachers):
    """Build a single-age school with female pupils (6) and male teachers (40).

    Returns (people, school); students are added before teachers.
    """
    school = School(
        n_pupils_max=n_students,
        age_min=6,
        age_max=6,
        coordinates=(1.0, 1.0),
        sector="primary_secondary",
    )
    people = []
    for _ in range(n_students):
        student = Person.from_attributes(sex="f", age=6)
        school.add(student)
        people.append(student)
    for _ in range(n_teachers):
        teacher = Person.from_attributes(sex="m", age=40)
        school.add(teacher)
        people.append(teacher)
    # Sanity checks: subgroup 0 holds teachers, subgroup 1 the single age bin.
    assert len(people) == n_students + n_teachers
    assert len(school.people) == n_students + n_teachers
    assert len(school.subgroups[1].people) == n_students
    assert len(school.subgroups[0].people) == n_teachers
    return people, school


@pytest.mark.parametrize(
    "n_teachers,mode", [[2, "average"], [4, "average"], [6, "average"]]
)
def test__average_time_to_infect(n_teachers, mode, selector):
    """Mean time to infect the last teacher matches the analytic rate."""
    transmission_probability = 0.1
    n_students = 1
    contact_matrices = {
        "contacts": [[n_teachers - 1, 1], [1, 0]],
        "proportion_physical": [[0, 0], [0, 0]],
        "xi": 1.0,
        "characteristic_time": 24,
    }
    interaction = Interaction(
        betas={"school": 1},
        alpha_physical=1,
        contact_matrices={"school": contact_matrices},
    )
    infection_times = []
    for _ in range(100):
        people, school = create_school(n_students, n_teachers)
        # Infect everyone except the last teacher (students first, same order
        # as adding them to the school).
        for infected in people[: n_students + n_teachers - 1]:
            selector.infect_person_at_time(infected, time=0)
        school.clear()
        susceptible_teacher = people[-1]
        infection_times.append(
            days_to_infection(
                interaction, susceptible_teacher, school, people, n_students
            )
        )
    # Expected rate: teacher-teacher plus student-teacher contributions.
    teacher_teacher = transmission_probability * (n_teachers - 1)
    student_teacher = transmission_probability / n_students
    np.testing.assert_allclose(
        np.mean(infection_times), 1.0 / (teacher_teacher + student_teacher), rtol=0.1
    )


def test__infection_is_isolated(epidemiology, selectors):
    # Seed roughly 5 infections in a small household-only world (no leisure),
    # run the simulation, and check the epidemic never escapes the initially
    # infected households.
    geography = Geography.from_file({"area": ["E00002559"]})
    world = generate_world_from_geography(geography, include_households=True)
    interaction = Interaction.from_file(config_filename=test_config)
    infection_seed = InfectionSeed.from_uniform_cases(
        world,
        selectors[0],
        cases_per_capita=5 / len(world.people),
        date="2020-03-01",
        seed_past_infections=False,
    )
    infection_seed.unleash_virus_per_day(date=pd.to_datetime("2020-03-01"), time=0)
    policies = Policies([])
    # count of people infected by the seed, before the simulation runs
    n_infected = len([person for person in world.people if person.infected])
    simulator = Simulator.from_file(
        world=world,
        interaction=interaction,
        epidemiology=epidemiology,
        config_filename=pathlib.Path(__file__).parent.absolute()
        / "interaction_test_config.yaml",
        leisure=None,
        policies=policies,
        # save_path=None,
    )
    # seeding is stochastic, so only check the seeded count approximately
    assert np.isclose(n_infected, 5, rtol=0.2)
    infected_households = []
    for household in world.households:
        infected = False
        for person in household.people:
            if person.infected:
                infected = True
                break
        if infected:
            infected_households.append(household)
    assert len(infected_households) <= n_infected
    simulator.run()
    # after the run, infection must be confined to the seeded households
    for person in world.people:
        if person.residence is None:
            # assumes only the dead lose their residence -- TODO confirm
            assert person.dead
        elif not (person.residence.group in infected_households):
            assert not person.infected


def test__super_spreaders(selector):
    """
    Students with a higher transmission probability should be blamed for a
    proportionally larger share of infections, and since only students are
    infectious all infections must land on teachers.
    """
    people, school = create_school(n_students=5, n_teachers=1000)
    student_ids = [student.id for student in school.students]
    teacher_ids = [teacher.id for teacher in school.teachers]
    transmission_probabilities = np.linspace(0, 30, len(student_ids))
    total = sum(transmission_probabilities)
    id_to_trans = {}
    for i, student in enumerate(school.students):
        selector.infect_person_at_time(student, time=0)
        student.infection.transmission.probability = transmission_probabilities[i]
        id_to_trans[student.id] = transmission_probabilities[i]
    interactive_school = school.get_interactive_group()
    interaction = Interaction.from_file(config_filename=test_config)
    beta = interaction._get_interactive_group_beta(interactive_school)
    contact_matrix_raw = interaction.contact_matrices["school"]
    contact_matrix = interactive_school.get_processed_contact_matrix(contact_matrix_raw)
    infector_tensor = interaction.create_infector_tensor(
        interactive_school.infectors_per_infection_per_subgroup,
        interactive_school.subgroup_sizes,
        contact_matrix,
        beta,
        delta_time=1,
    )
    (
        subgroup_infected_ids,
        subgroup_infection_ids,
        to_blame_subgroups,
    ) = interaction._time_step_for_subgroup(
        infector_tensor=infector_tensor,
        susceptible_subgroup_id=0,
        subgroup_susceptibles=interactive_school.susceptibles_per_subgroup[0],
    )
    to_blame_ids = interaction._blame_individuals(
        to_blame_subgroups,
        subgroup_infection_ids,
        interactive_school.infectors_per_infection_per_subgroup,
    )
    # only teachers (subgroup 0) are susceptible; only students are infectious
    for infected_id in subgroup_infected_ids:
        assert infected_id in teacher_ids
    for blamed_id in to_blame_ids:
        assert blamed_id in student_ids
    n_infections = len(subgroup_infected_ids)
    assert n_infections > 0
    culpable_ids, culpable_counts = np.unique(to_blame_ids, return_counts=True)
    for culpable_id, culpable_count in zip(culpable_ids, culpable_counts):
        # expected blame share is proportional to the transmission probability
        # (fixed: a trailing comma previously wrapped this in a spurious 1-tuple)
        expected = id_to_trans[culpable_id] / total * n_infections
        assert np.isclose(culpable_count, expected, rtol=0.25)


import numpy as np
from copy import deepcopy

from june.geography import Area, SuperArea, Region
from june.demography.person import Person
from june.groups import (
    Hospital,
    School,
    Pub,
    InteractiveSchool,
    Company,
    InteractiveCompany,
    Household,
)
from june.interaction import Interaction
from june.groups.school import _translate_school_subgroup
from june.groups.group.interactive import InteractiveGroup
from june import paths

# Interaction configuration used by all tests in this module.
test_config = paths.configs_path / "tests/interaction.yaml"


class TestInteractiveGroup:
    """Checks that InteractiveGroup extracts infectors/susceptibles correctly."""

    @staticmethod
    def _populated_hospital():
        """Return (hospital, [p1, p2, p3, p4]) with people in subgroups
        0, 0, 1, 2 and susceptibilities {1: 0.1}, {2: 0.2}, {3: 0.3}, {4: 0.4}.

        Extracted helper: this setup was previously duplicated verbatim in
        both tests below.
        """
        hospital = Hospital(n_beds=None, n_icu_beds=None)
        people = [Person.from_attributes() for _ in range(4)]
        for key, susceptibility, person in zip(
            (1, 2, 3, 4), (0.1, 0.2, 0.3, 0.4), people
        ):
            person.immunity.susceptibility_dict[key] = susceptibility
        for subgroup, person in zip((0, 0, 1, 2), people):
            hospital.add(person, subgroup_type=subgroup)
        return hospital, people

    def test__substract_information_from_group(self, selector):
        hospital, (person1, person2, person3, person4) = self._populated_hospital()
        selector.infect_person_at_time(person1, 1)
        person1.infection.update_health_status(5, 5)
        interactive_group = InteractiveGroup(hospital)
        infection_id = person1.infection.infection_id()
        assert infection_id in interactive_group.infectors_per_infection_per_subgroup
        infectors = interactive_group.infectors_per_infection_per_subgroup[infection_id]
        assert infectors[0]["ids"] == [person1.id]
        assert infectors[0]["trans_probs"] == [
            person1.infection.transmission.probability
        ]
        # person1 is now infected, so only person2 is susceptible in subgroup 0
        assert len(interactive_group.susceptibles_per_subgroup[0]) == 1
        assert interactive_group.susceptibles_per_subgroup[0][person2.id][2] == 0.2
        assert len(interactive_group.susceptibles_per_subgroup[1]) == 1
        assert interactive_group.susceptibles_per_subgroup[1][person3.id][3] == 0.3
        assert len(interactive_group.susceptibles_per_subgroup[2]) == 1
        assert interactive_group.susceptibles_per_subgroup[2][person4.id][4] == 0.4
        # infectors and susceptibles both present -> the group must be stepped
        assert interactive_group.must_timestep

    def test__no_timestep(self, selector):
        hospital, people = self._populated_hospital()
        person1 = people[0]
        interactive_group = InteractiveGroup(hospital)
        # susceptibles but no infectors: nothing to simulate
        assert interactive_group.has_susceptible
        assert not interactive_group.has_infectors
        assert not interactive_group.must_timestep

        hospital.clear()
        hospital.add(person1, subgroup_type=0)
        selector.infect_person_at_time(person1, 1)
        person1.infection.update_health_status(5, 5)
        interactive_group = InteractiveGroup(hospital)
        # an infector but no susceptibles: nothing to simulate either
        assert not interactive_group.has_susceptible
        assert interactive_group.has_infectors
        assert not interactive_group.must_timestep


class TestDispatchOnGroupSpec:
    def test__dispatch(self):
        """get_interactive_group must return the spec-specific interactive class."""
        interactive_pub = Pub().get_interactive_group()
        assert interactive_pub.__class__ == InteractiveGroup
        interactive_school = School().get_interactive_group()
        assert interactive_school.__class__ == InteractiveSchool
        # the specialised class is still a subtype of the generic one
        assert isinstance(interactive_school, InteractiveGroup)


class TestInteractiveSchool:
    def test__school_index_translation(self):
        """Subgroup indices map onto school years (subgroup 0 is teachers)."""
        age_min = 3
        age_max = 7
        school_years = tuple(range(age_min, age_max + 1))
        # fixed: these were bare comparison expressions with no `assert`,
        # so the test previously checked nothing
        assert _translate_school_subgroup(1, school_years) == 4
        assert _translate_school_subgroup(5, school_years) == 8

    def test__school_contact_matrices(self):
        interaction = Interaction.from_file(config_filename=test_config)
        xi = 0.3
        age_min = 3
        age_max = 7
        school_years = tuple(range(age_min, age_max + 1))
        school = School(age_min=age_min, age_max=age_max)
        int_school = InteractiveSchool(school)
        int_school.school_years = school_years
        contact_matrix = interaction.contact_matrices["school"]
        contact_matrix = int_school.get_processed_contact_matrix(contact_matrix)
        n_contacts_same_year = contact_matrix[4, 4]
        assert n_contacts_same_year == 2.875 * 3

        # contacts with a neighbouring year are damped by xi
        n_contacts_year_above = contact_matrix[4, 5]
        assert n_contacts_year_above == xi * 2.875 * 3

        n_contacts_teacher_teacher = contact_matrix[0, 0]
        assert n_contacts_teacher_teacher == 5.25 * 3

        # teacher contacts are spread evenly over the school years
        n_contacts_teacher_student = contact_matrix[0, 4]
        # fixed: this np.isclose result was previously discarded (no assert)
        assert np.isclose(
            n_contacts_teacher_student, (16.2 * 3 / len(school_years)), rtol=1e-6
        )

        n_contacts_student_teacher = contact_matrix[4, 0]
        assert n_contacts_student_teacher == 0.81 * 3 / len(school_years)

    def test__school_contact_matrices_different_classroom(self):
        interaction_instance = Interaction.from_file(config_filename=test_config)
        age_min = 3
        age_max = 7
        # year 4 appears twice: two separate classrooms of the same year
        school_years = (3, 4, 4, 5)
        school = School(age_min=age_min, age_max=age_max)
        school.years = school_years
        int_school = InteractiveSchool(school)
        int_school.school_years = school_years
        contact_matrix = interaction_instance.contact_matrices["school"]
        contact_matrix = int_school.get_processed_contact_matrix(contact_matrix)
        n_contacts_same_year = contact_matrix[2, 3]
        n_contacts_same_class = contact_matrix[2, 2]
        # same year but different classroom has a quarter of the contacts
        assert np.isclose(n_contacts_same_year, n_contacts_same_class / 4)

    def test__contact_matrix_full(self):
        """Check every entry of the raw school contact matrix from config."""
        xi = 0.3
        interaction = Interaction.from_file(config_filename=test_config)
        contacts_school = interaction.contact_matrices["school"]
        for i in range(len(contacts_school)):
            for j in range(len(contacts_school)):
                if i == j:
                    if i == 0:
                        # teacher-teacher contacts
                        assert contacts_school[i][j] == 5.25 * 3  # 24 / 8
                    else:
                        # student-student, same year
                        assert contacts_school[i][j] == 2.875 * 3
                else:
                    if i == 0:
                        # teacher-student contacts
                        assert np.isclose(contacts_school[i][j], 16.2 * 3, rtol=1e-6)
                    elif j == 0:
                        # student-teacher contacts
                        assert np.isclose(
                            contacts_school[i][j], 0.81 * 3, atol=0, rtol=1e-6
                        )
                    else:
                        # student-student across years, damped by xi per year gap
                        assert np.isclose(
                            contacts_school[i][j],
                            xi ** abs(i - j) * 2.875 * 3,
                            atol=0,
                            rtol=1e-6,
                        )

    def test__social_distancing_primary_secondary(self):
        """Sector-specific beta reductions; primary_secondary and None fall back
        to secondary and generic school reductions respectively."""
        beta_reductions = {
            "school": 0.2,
            "primary_school": 0.3,
            "secondary_school": 0.4,
        }
        betas = {"school": 3.0}

        school = School(sector="primary")
        int_school = InteractiveSchool(school)
        processed_beta = int_school.get_processed_beta(
            betas=betas, beta_reductions=beta_reductions
        )
        assert np.isclose(processed_beta, 3 * 0.3)

        school = School(sector="secondary")
        int_school = InteractiveSchool(school)
        processed_beta = int_school.get_processed_beta(
            betas=betas, beta_reductions=beta_reductions
        )
        assert np.isclose(processed_beta, 3 * 0.4)

        school = School(sector="primary_secondary")
        int_school = InteractiveSchool(school)
        processed_beta = int_school.get_processed_beta(
            betas=betas, beta_reductions=beta_reductions
        )
        assert np.isclose(processed_beta, 3 * 0.4)

        school = School(sector=None)
        int_school = InteractiveSchool(school)
        processed_beta = int_school.get_processed_beta(
            betas=betas, beta_reductions=beta_reductions
        )
        assert np.isclose(processed_beta, 3 * 0.2)


class TestInteractiveCompany:
    def test__sector_beta(self):
        """A sector listed in sector_betas scales the company beta; an
        unlisted sector leaves it unchanged."""
        bkp = deepcopy(InteractiveCompany.sector_betas)
        # fixed: restore the class-level dict in a finally block so a failing
        # assertion cannot leak the patched value into other tests
        try:
            InteractiveCompany.sector_betas["R"] = 0.7
            company = Company(sector="R")
            interactive_company = company.get_interactive_group()
            betas = {"company": 2}
            beta_reductions = {}
            beta_processed = interactive_company.get_processed_beta(
                betas=betas, beta_reductions=beta_reductions
            )
            assert beta_processed == 2 * 0.7
            company = Company(sector="Q")
            interactive_company = company.get_interactive_group()
            betas = {"company": 2}
            beta_reductions = {}
            beta_processed = interactive_company.get_processed_beta(
                betas=betas, beta_reductions=beta_reductions
            )
            assert beta_processed == 2
        finally:
            InteractiveCompany.sector_betas = bkp


class TestInteractiveHousehold:
    def test__household_visits_beta(self):
        """The household beta switches to the visits beta while being visited."""
        person = Person.from_attributes()
        region = Region()
        region.regional_compliance = 1.0
        super_area = SuperArea(region=region)
        area = Area(super_area=super_area)
        household = Household(area=area)
        betas = {"household": 1, "household_visits": 3}
        interactive_household = household.get_interactive_group()
        assert interactive_household.get_processed_beta(betas, beta_reductions={}) == 1
        # adding someone with activity="leisure" marks the household as visited
        household.add(person, activity="leisure")
        interactive_household = household.get_interactive_group()
        assert interactive_household.get_processed_beta(betas, beta_reductions={}) == 3

    def test__household_visits_social_distancing(self):
        """Beta reductions apply to the visits beta while being visited."""
        region = Region()
        region.regional_compliance = 1.0
        super_area = SuperArea(region=region)
        area = Area(super_area=super_area)
        household = Household(area=area)
        person = Person.from_attributes()
        household.add(person)
        betas = {"household": 1, "household_visits": 2}
        beta_reductions = {"household": 0.5, "household_visits": 0.1}
        # NOTE(review): the same person is added a second time here -- looks
        # redundant; confirm whether this is intentional
        household.add(person)
        int_household = household.get_interactive_group()
        assert household.being_visited is False
        beta = int_household.get_processed_beta(betas, beta_reductions)
        assert beta == 0.5  # household beta 1 * reduction 0.5
        household = Household(area=area)
        household.add(person, activity="leisure")
        assert household.being_visited is True
        int_household = household.get_interactive_group()
        beta = int_household.get_processed_beta(betas, beta_reductions)
        assert np.isclose(beta, 0.2)  # visits beta 2 * reduction 0.1


import pytest
import numpy as np
from random import randint

from june.interaction import Interaction
from june.demography import Population, Person
from june.groups import Company
from june.epidemiology.infection import Infection, TransmissionConstant


class TestSusceptibilityHasAnEffect:
    """Reduced susceptibility for kids should halve their infection rate."""

    @pytest.fixture(name="simulation_setup")
    def setup_group(self):
        """A company of 50 kids (age <= 12) and 50 adults (age >= 13),
        with exactly one infected kid and one infected adult."""
        company = Company()
        population = Population([])
        n_kids = 50
        n_adults = 50
        for _ in range(n_kids):
            population.add(Person.from_attributes(age=randint(0, 12)))
        for _ in range(n_adults):
            population.add(Person.from_attributes(age=randint(13, 100)))
        for person in population:
            company.add(person)
        # infect one kid and one adult
        kid = population[0]
        assert kid.age <= 12
        adult = population[-1]
        assert adult.age >= 13
        kid.infection = Infection(
            symptoms=None, transmission=TransmissionConstant(probability=0.2)
        )
        adult.infection = Infection(
            symptoms=None, transmission=TransmissionConstant(probability=0.2)
        )
        return company, population

    def run_interaction(self, simulation_setup, interaction):
        """
        Run 1000 interaction time steps and return the mean number of newly
        infected people per step as (mean_kids, mean_adults).

        Fixed: the counters were previously mislabelled (the variable named
        `n_infected_adults` counted ages <= 12 and vice versa); the labels now
        match what is counted and the return order is (kids, adults). The
        assertion in test__run_different_susceptibility is adjusted to match.
        """
        group, population = simulation_setup
        n_infected_kids_list = []
        n_infected_adults_list = []
        for _ in range(1000):
            infected_ids, _, group_size = interaction.time_step_for_group(
                group=group, delta_time=10
            )
            n_infected_kids = len(
                [
                    person
                    for person in population
                    if person.id in infected_ids and person.age <= 12
                ]
            )
            n_infected_adults = len(
                [
                    person
                    for person in population
                    if person.id in infected_ids and person.age > 12
                ]
            )
            n_infected_kids_list.append(n_infected_kids)
            n_infected_adults_list.append(n_infected_adults)
        return np.mean(n_infected_kids_list), np.mean(n_infected_adults_list)

    def test__run_uniform_susceptibility(self, simulation_setup):
        """With uniform susceptibility kids and adults are infected equally."""
        contact_matrices = {
            "company": {
                "contacts": [[1]],
                "proportion_physical": [[1]],
                "characteristic_time": 8,
            }
        }
        interaction = Interaction(
            betas={"company": 1}, alpha_physical=1.0, contact_matrices=contact_matrices
        )
        n_kids_inf, n_adults_inf = self.run_interaction(
            interaction=interaction, simulation_setup=simulation_setup
        )
        assert n_kids_inf > 0
        assert n_adults_inf > 0
        assert np.isclose(n_kids_inf, n_adults_inf, rtol=0.05)

    def test__run_different_susceptibility(self, simulation_setup):
        """Kids with susceptibility 0.5 are infected at half the adult rate."""
        group, population = simulation_setup
        interaction = Interaction(
            betas={"company": 1}, alpha_physical=1.0, contact_matrices=None
        )
        for person in population:
            if person.age < 13:
                person.immunity.susceptibility_dict[Infection.infection_id()] = 0.5
        n_kids_inf, n_adults_inf = self.run_interaction(
            interaction=interaction, simulation_setup=simulation_setup
        )
        assert n_kids_inf > 0
        assert n_adults_inf > 0
        # kids are half as susceptible, so they see half the infections
        assert np.isclose(n_kids_inf, 0.5 * n_adults_inf, rtol=0.05)


from datetime import datetime

import numpy as np
import pytest

from june.demography import Person
from june.geography import Area, SuperArea
from june.geography.geography import Region
from june.groups import School, Household
from june.epidemiology.infection import SymptomTag
from june.policy import (
    SevereSymptomsStayHome,
    CloseSchools,
    CloseCompanies,
    CloseUniversities,
    Quarantine,
    Shielding,
    Policies,
    Hospitalisation,
    LimitLongCommute,
    SchoolQuarantine,
)
from june.policy.individual_policies import CloseCompaniesLockdownTiers
from june.utils.distances import haversine_distance


def infect_person(person, selector, symptom_tag="mild"):
    """Infect ``person`` at t=0 and force the given symptom tag on them.

    Symptomatic cases get an onset time of 5.3 days and their residence's
    quarantine clock is started; asymptomatic cases get no onset time.
    """
    selector.infect_person_at_time(person, 0.0)
    person.infection.symptoms.tag = getattr(SymptomTag, symptom_tag)
    if symptom_tag == "asymptomatic":
        person.infection.symptoms.time_of_symptoms_onset = None
    else:
        person.infection.symptoms.time_of_symptoms_onset = 5.3
        person.residence.group.quarantine_starting_date = 5.3


class TestSevereSymptomsStayHome:
    """Severely symptomatic people must stay home (unless hospitalised),
    and a sick kid must keep at least one guardian at home."""

    def test__policy_adults(self, setup_policy_world, selector):
        world, pupil, student, worker, sim = setup_policy_world
        permanent_policy = SevereSymptomsStayHome()
        policies = Policies([permanent_policy])
        sim.activity_manager.policies = policies
        sim.epidemiology.set_medical_care(world, sim.activity_manager)
        sim.clear_world()
        # baseline: everyone goes to their primary activity
        sim.activity_manager.move_people_to_active_subgroups(
            ("primary_activity", "residence")
        )
        assert worker in worker.primary_activity.people
        assert pupil in pupil.primary_activity.people
        sim.clear_world()
        infect_person(worker, selector, "severe")
        sim.epidemiology.update_health_status(world, 0.0, 0.0)
        sim.activity_manager.move_people_to_active_subgroups(
            ("primary_activity", "residence")
        )
        # the severely symptomatic worker stays home; the pupil is unaffected
        assert worker in worker.residence.people
        assert pupil in pupil.primary_activity.people
        worker.infection = None
        sim.clear_world()

    def test__policy_adults_still_go_to_hospital(self, setup_policy_world, selector):
        world, pupil, student, worker, sim = setup_policy_world
        permanent_policy = SevereSymptomsStayHome()
        hospitalisation = Hospitalisation()
        policies = Policies([permanent_policy, hospitalisation])
        sim.activity_manager.policies = policies
        sim.epidemiology.set_medical_care(world, sim.activity_manager)
        sim.clear_world()
        sim.activity_manager.move_people_to_active_subgroups(
            ["medical_facility", "primary_activity", "residence"]
        )
        assert worker in worker.primary_activity.people
        assert pupil in pupil.primary_activity.people
        sim.clear_world()
        infect_person(worker, selector, "hospitalised")
        sim.epidemiology.update_health_status(world, 0.0, 0.0)
        sim.activity_manager.move_people_to_active_subgroups(
            ["medical_facility", "primary_activity", "residence"]
        )
        # hospitalisation takes precedence over staying home
        assert worker in worker.medical_facility.people
        assert pupil in pupil.primary_activity.people
        worker.infection = None
        sim.clear_world()

    def test__default_policy_kids(self, selector, setup_policy_world):
        world, pupil, student, worker, sim = setup_policy_world
        permanent_policy = SevereSymptomsStayHome()
        policies = Policies([permanent_policy])
        sim.activity_manager.policies = policies
        sim.epidemiology.set_medical_care(world, sim.activity_manager)
        sim.clear_world()
        sim.activity_manager.move_people_to_active_subgroups(
            ["primary_activity", "residence"]
        )
        assert pupil in pupil.primary_activity.people
        assert worker in worker.primary_activity.people
        assert student in student.primary_activity.people
        sim.clear_world()
        infect_person(pupil, selector, "severe")
        sim.epidemiology.update_health_status(world, 0.0, 0.0)
        assert pupil.infection.tag == SymptomTag.severe
        sim.activity_manager.move_people_to_active_subgroups(
            ["primary_activity", "residence"]
        )
        assert pupil in pupil.residence.people
        # a sick kid at home must be accompanied by at least one adult
        has_guardian = False
        for person in [worker, student]:
            if person in person.residence.people:
                has_guardian = True
                break
        assert has_guardian
        sim.clear_world()


class TestClosure:
    def test__close_schools(self, setup_policy_world):
        """Closing schools sends affected pupils home only during the policy
        window, and children of key workers keep attending."""
        world, pupil, student, worker, sim = setup_policy_world
        school_closure = CloseSchools(
            start_time="2020-1-1", end_time="2020-10-1", years_to_close=[6]
        )
        policies = Policies([school_closure])
        sim.activity_manager.policies = policies
        sim.epidemiology.set_medical_care(world, sim.activity_manager)

        # non key worker
        worker.lockdown_status = "furlough"
        sim.clear_world()
        activities = ["primary_activity", "residence"]
        time_before_policy = datetime(2019, 2, 1)
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_before_policy
        )
        assert worker in worker.primary_activity.people
        assert pupil in pupil.primary_activity.people
        sim.clear_world()
        time_during_policy = datetime(2020, 2, 1)
        active_individual_policies = policies.individual_policies.get_active(
            date=time_during_policy
        )
        # during the policy, the pupil's only allowed activity is residence
        assert policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=0,
        ) == ["residence"]
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_during_policy
        )
        assert pupil in pupil.residence.people
        assert worker in worker.primary_activity.people
        sim.clear_world()
        time_after_policy = datetime(2030, 2, 2)
        active_individual_policies = policies.individual_policies.get_active(
            date=time_after_policy
        )
        # after the policy ends, schooling resumes
        assert policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=0,
        ) == ["primary_activity", "residence"]
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_after_policy
        )
        assert pupil in pupil.primary_activity.people
        assert worker in worker.primary_activity.people
        sim.clear_world()

        # key worker
        worker.lockdown_status = "key_worker"
        student.lockdown_status = "key_worker"
        sim.clear_world()
        activities = ["primary_activity", "residence"]
        time_before_policy = datetime(2019, 2, 1)
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_before_policy
        )
        assert worker in worker.primary_activity.people
        assert pupil in pupil.primary_activity.people
        sim.clear_world()
        time_during_policy = datetime(2020, 2, 1)
        active_individual_policies = policies.individual_policies.get_active(
            date=time_during_policy
        )
        # children of key workers are exempt even during the policy
        assert policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=0,
        ) == ["primary_activity", "residence"]
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_during_policy
        )
        assert pupil in pupil.primary_activity.people
        assert worker in worker.primary_activity.people
        sim.clear_world()
        time_after_policy = datetime(2030, 2, 2)
        active_individual_policies = policies.individual_policies.get_active(
            date=time_after_policy
        )
        assert policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=0,
        ) == ["primary_activity", "residence"]
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_after_policy
        )
        assert pupil in pupil.primary_activity.people
        assert worker in worker.primary_activity.people
        sim.clear_world()

    def test__reopen_schools(self, setup_policy_world):
        """With attending_compliance=0.5 pupils attend about half of school
        days during the policy; key-worker children attend every day."""
        world, pupil, student, worker, sim = setup_policy_world
        school_closure = CloseSchools(
            start_time="2020-1-1", end_time="2020-10-1", attending_compliance=0.5
        )
        policies = Policies([school_closure])
        sim.activity_manager.policies = policies

        # non key worker
        worker.lockdown_status = "furlough"
        sim.clear_world()
        activities = ["primary_activity", "residence"]
        time_before_policy = datetime(2019, 2, 1)
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_before_policy
        )
        assert worker in worker.primary_activity.people
        assert pupil in pupil.primary_activity.people
        sim.clear_world()
        time_during_policy = datetime(2020, 2, 1)
        active_individual_policies = policies.individual_policies.get_active(
            date=time_during_policy
        )

        # Move the pupil 500 times for five days
        n_days_in_week = []
        for i in range(500):
            n_days = 0
            for j in range(5):
                if "primary_activity" in policies.individual_policies.apply(
                    active_individual_policies,
                    person=pupil,
                    activities=activities,
                    days_from_start=0,
                ):
                    n_days += 1.0
            n_days_in_week.append(n_days)
        # 50% compliance -> about 2.5 school days per 5-day week
        assert np.mean(n_days_in_week) == pytest.approx(2.5, rel=0.1)
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_during_policy
        )
        assert worker in worker.primary_activity.people
        sim.clear_world()
        time_after_policy = datetime(2030, 2, 2)
        active_individual_policies = policies.individual_policies.get_active(
            date=time_after_policy
        )
        assert policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=0,
        ) == ["primary_activity", "residence"]
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_after_policy
        )
        assert pupil in pupil.primary_activity.people
        assert worker in worker.primary_activity.people
        sim.clear_world()

        # key worker
        worker.lockdown_status = "key_worker"
        student.lockdown_status = "key_worker"
        sim.clear_world()
        activities = ["primary_activity", "residence"]
        time_before_policy = datetime(2019, 2, 1)
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_before_policy
        )
        assert worker in worker.primary_activity.people
        assert pupil in pupil.primary_activity.people
        sim.clear_world()
        time_during_policy = datetime(2020, 2, 1)
        active_individual_policies = policies.individual_policies.get_active(
            date=time_during_policy
        )
        # Move the pupil 500 times for five days
        n_days_in_week = []
        for i in range(500):
            n_days = 0
            for j in range(5):
                if "primary_activity" in policies.individual_policies.apply(
                    active_individual_policies,
                    person=pupil,
                    activities=activities,
                    days_from_start=0,
                ):
                    n_days += 1.0
            n_days_in_week.append(n_days)
        # key-worker children attend all five days regardless of compliance
        assert np.mean(n_days_in_week) == pytest.approx(5.0, rel=0.1)
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_during_policy
        )
        assert pupil in pupil.primary_activity.people
        assert worker in worker.primary_activity.people
        sim.clear_world()
        time_after_policy = datetime(2030, 2, 2)
        active_individual_policies = policies.individual_policies.get_active(
            date=time_after_policy
        )
        assert policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=0,
        ) == ["primary_activity", "residence"]
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_after_policy
        )
        assert pupil in pupil.primary_activity.people
        assert worker in worker.primary_activity.people
        sim.clear_world()

    def test__close_universities(self, setup_policy_world):
        """Closing universities sends students home during the policy window,
        without affecting school pupils."""
        world, pupil, student, worker, sim = setup_policy_world
        university_closure = CloseUniversities(
            start_time="2020-1-1", end_time="2020-10-1"
        )
        policies = Policies([university_closure])
        sim.activity_manager.policies = policies
        sim.clear_world()
        activities = ["primary_activity", "residence"]
        time_before_policy = datetime(2019, 2, 1)
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_before_policy
        )
        assert student in student.primary_activity.people
        assert pupil in pupil.primary_activity.people
        sim.clear_world()
        time_during_policy = datetime(2020, 2, 1)
        active_individual_policies = policies.individual_policies.get_active(
            date=time_during_policy
        )
        # during the policy, the student's only allowed activity is residence
        assert policies.individual_policies.apply(
            active_individual_policies,
            person=student,
            activities=activities,
            days_from_start=0,
        ) == ["residence"]
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_during_policy
        )
        assert student in student.residence.people
        sim.clear_world()
        time_after_policy = datetime(2030, 2, 2)
        active_individual_policies = policies.individual_policies.get_active(
            date=time_after_policy
        )
        assert policies.individual_policies.apply(
            active_individual_policies,
            person=student,
            activities=activities,
            days_from_start=0,
        ) == ["primary_activity", "residence"]
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_after_policy
        )
        assert pupil in pupil.primary_activity.people
        assert student in student.primary_activity.people
        sim.clear_world()

    def test__close_companies_lockdown_tiers(self, setup_policy_world):
        """check_skips_activity follows live/work lockdown tiers and regional compliance."""
        world, pupil, student, worker, sim = setup_policy_world
        worker.lockdown_status = "random"

        # Build a home region (with the worker's area) and a separate work region.
        live_region = Region()
        work_region = Region()
        live_area = Area()
        live_super_area = SuperArea(
            name="live_super_area",
            coordinates=[0, 3],
            region=live_region,
            areas=[live_area],
        )
        live_area.super_area = live_super_area
        live_region.super_areas = [live_super_area]
        work_super_area = SuperArea(
            name="work_super_area", coordinates=[0, 0], region=work_region
        )
        work_region.super_areas = [work_super_area]
        live_area.add(worker)
        worker.area = live_area
        work_super_area.add_worker(worker)

        tier_closure = CloseCompaniesLockdownTiers(
            start_time="2020-1-1", end_time="2020-10-1"
        )

        # apply strips both commute and primary_activity.
        assert tier_closure.apply(["commute", "primary_activity"]) == []

        def configure(live_tier, work_tier, compliance):
            # Set the lockdown tier of each region and the home region's compliance.
            worker.work_super_area.region.policy["lockdown_tier"] = work_tier
            worker.region.policy["lockdown_tier"] = live_tier
            worker.region.policy["regional_compliance"] = compliance

        # Work in tier 3/4 while living in tier 1/2: skip work (only when compliant).
        for live_tier, work_tier in ((1, 3), (2, 4)):
            configure(live_tier, work_tier, 1)
            assert tier_closure.check_skips_activity(worker) is True
            configure(live_tier, work_tier, 0)
            assert tier_closure.check_skips_activity(worker) is False

        # Live in tier 3/4 and work in another region: skip work (only when compliant).
        for live_tier, work_tier in ((3, 1), (4, 2), (3, 3), (4, 4)):
            configure(live_tier, work_tier, 1)
            assert tier_closure.check_skips_activity(worker) is True
            configure(live_tier, work_tier, 0)
            assert tier_closure.check_skips_activity(worker) is False

        # Live and work in tier 1/2: never skipped.
        for live_tier, work_tier in ((1, 1), (2, 1), (1, 2), (2, 2)):
            configure(live_tier, work_tier, 1)
            assert tier_closure.check_skips_activity(worker) is False

        # Tier 3/4 but working in the home region: still allowed to go to work.
        live_super_area.add_worker(worker)
        for live_tier, work_tier in ((3, 3), (4, 4), (3, 4), (4, 4)):
            configure(live_tier, work_tier, 1)
            assert tier_closure.check_skips_activity(worker) is False

    def test__close_companies(self, setup_policy_world):
        """Furloughed workers stay home during closure; key workers keep commuting."""
        world, pupil, student, worker, sim = setup_policy_world
        policies = Policies(
            [CloseCompanies(start_time="2020-1-1", end_time="2020-10-1")]
        )
        sim.activity_manager.policies = policies
        sim.clear_world()
        activities = ["commute", "primary_activity", "residence"]
        before = datetime(2019, 2, 1)
        during = datetime(2020, 2, 1)
        after = datetime(2030, 2, 2)

        def expected_activities(date):
            # Activities the policy set leaves the worker on the given date.
            active = policies.individual_policies.get_active(date=date)
            return policies.individual_policies.apply(
                active, person=worker, activities=activities, days_from_start=0
            )

        def move_and_check(date, worker_stays_home):
            # Move everyone, verify placement, and reset the world.
            sim.activity_manager.move_people_to_active_subgroups(activities, date)
            if worker_stays_home:
                assert worker in worker.residence.people
            else:
                assert worker in worker.primary_activity.people
            assert pupil in pupil.primary_activity.people
            sim.clear_world()

        # Furloughed worker: stays home only while the policy is active.
        worker.lockdown_status = "furlough"
        move_and_check(before, worker_stays_home=False)
        assert expected_activities(during) == ["residence"]
        move_and_check(during, worker_stays_home=True)
        assert expected_activities(after) == [
            "commute",
            "primary_activity",
            "residence",
        ]
        move_and_check(after, worker_stays_home=False)

        # Key worker: never kept home, even while the policy is active.
        sim.clear_world()
        worker.lockdown_status = "key_worker"
        move_and_check(before, worker_stays_home=False)
        assert expected_activities(during) == [
            "commute",
            "primary_activity",
            "residence",
        ]
        move_and_check(during, worker_stays_home=False)
        assert expected_activities(after) == [
            "commute",
            "primary_activity",
            "residence",
        ]
        move_and_check(after, worker_stays_home=False)

    def test__close_companies_frequency_of_randoms(self, setup_policy_world):
        """With avoid_work_probability=0.2 a 'random' worker averages ~4 of 5 work days."""
        world, pupil, student, worker, sim = setup_policy_world
        policies = Policies(
            [
                CloseCompanies(
                    start_time="2020-1-1",
                    end_time="2020-10-1",
                    avoid_work_probability=0.2,
                )
            ]
        )
        sim.activity_manager.policies = policies
        sim.clear_world()
        activities = ["commute", "primary_activity", "residence"]
        worker.lockdown_status = "random"

        # Before the policy: worker attends normally.
        before = datetime(2019, 2, 1)
        sim.activity_manager.move_people_to_active_subgroups(activities, before)
        assert worker in worker.primary_activity.people
        assert pupil in pupil.primary_activity.people
        sim.clear_world()

        during = datetime(2020, 2, 1)
        active = policies.individual_policies.get_active(date=during)

        def mean_days_at_work(days, weight):
            # Weighted mean of working days per simulated week, over 500 weeks.
            weeks = []
            for _ in range(500):
                worked = 0
                for _ in range(days):
                    if "primary_activity" in policies.individual_policies.apply(
                        active,
                        person=worker,
                        activities=activities,
                        days_from_start=0,
                    ):
                        worked += weight
                weeks.append(worked)
            return np.mean(weeks)

        # 5-day weeks counted at full weight, and 10 draws at half weight,
        # both should converge to ~4 working days per week.
        assert mean_days_at_work(5, 1.0) == pytest.approx(4.0, rel=0.1)
        assert mean_days_at_work(10, 0.5) == pytest.approx(4.0, rel=0.1)

        # After the policy: everything is back to normal.
        sim.clear_world()
        after = datetime(2030, 2, 2)
        active = policies.individual_policies.get_active(date=after)
        assert policies.individual_policies.apply(
            active, person=worker, activities=activities, days_from_start=0
        ) == ["commute", "primary_activity", "residence"]
        sim.activity_manager.move_people_to_active_subgroups(activities, after)
        assert pupil in pupil.primary_activity.people
        assert worker in worker.primary_activity.people
        sim.clear_world()

    def test__close_companies_frequency_of_furlough_ratio(self, setup_policy_world):
        """Mean days at work for a furloughed worker tracks furlough_ratio."""
        world, pupil, student, worker, sim = setup_policy_world
        company_closure = CloseCompanies(
            start_time="2020-1-1",
            end_time="2020-10-1",
            furlough_probability=0.2,
            avoid_work_probability=0.2,
        )
        policies = Policies([company_closure])
        sim.activity_manager.policies = policies
        sim.clear_world()
        activities = ["commute", "primary_activity", "residence"]
        worker.lockdown_status = "furlough"

        # Before the policy: worker attends normally.
        before = datetime(2019, 2, 1)
        sim.activity_manager.move_people_to_active_subgroups(activities, before)
        assert worker in worker.primary_activity.people
        assert pupil in pupil.primary_activity.people
        sim.clear_world()

        during = datetime(2020, 2, 1)
        active = policies.individual_policies.get_active(date=during)

        def mean_days_at_work():
            # Mean number of days worked per 5-day week, over 500 weeks.
            weeks = []
            for _ in range(500):
                worked = 0
                for _ in range(5):
                    if "primary_activity" in policies.individual_policies.apply(
                        active,
                        person=worker,
                        activities=activities,
                        days_from_start=0,
                    ):
                        worked += 1.0
                weeks.append(worked)
            return np.mean(weeks)

        # Sweep furlough_ratio: at or below furlough_probability the worker
        # never goes in; above it some days are spent at work.
        company_closure.furlough_ratio = 0.0
        assert mean_days_at_work() == pytest.approx(0.0, rel=0.1)
        company_closure.furlough_ratio = 0.1
        assert mean_days_at_work() == pytest.approx(0.0, rel=0.1)
        company_closure.furlough_ratio = 0.4
        assert mean_days_at_work() == pytest.approx(2.0, rel=0.1)

        # After the policy: everything is back to normal.
        sim.clear_world()
        after = datetime(2030, 2, 2)
        active = policies.individual_policies.get_active(date=after)
        assert policies.individual_policies.apply(
            active, person=worker, activities=activities, days_from_start=0
        ) == ["commute", "primary_activity", "residence"]
        sim.activity_manager.move_people_to_active_subgroups(activities, after)
        assert pupil in pupil.primary_activity.people
        assert worker in worker.primary_activity.people
        sim.clear_world()

    def test__close_companies_frequency_of_key_ratio(self, setup_policy_world):
        """Mean days at work for a key worker tracks key_ratio."""
        world, pupil, student, worker, sim = setup_policy_world
        company_closure = CloseCompanies(
            start_time="2020-1-1", end_time="2020-10-1", key_probability=0.2
        )
        policies = Policies([company_closure])
        sim.activity_manager.policies = policies
        sim.clear_world()
        activities = ["commute", "primary_activity", "residence"]
        worker.lockdown_status = "key_worker"

        # Before the policy: worker attends normally.
        before = datetime(2019, 2, 1)
        sim.activity_manager.move_people_to_active_subgroups(activities, before)
        assert worker in worker.primary_activity.people
        assert pupil in pupil.primary_activity.people
        sim.clear_world()

        during = datetime(2020, 2, 1)
        active = policies.individual_policies.get_active(date=during)

        def mean_days_at_work():
            # Mean number of days worked per 5-day week, over 500 weeks.
            weeks = []
            for _ in range(500):
                worked = 0
                for _ in range(5):
                    if "primary_activity" in policies.individual_policies.apply(
                        active,
                        person=worker,
                        activities=activities,
                        days_from_start=0,
                    ):
                        worked += 1.0
                weeks.append(worked)
            return np.mean(weeks)

        # Sweep key_ratio: at or below key_probability the key worker always
        # attends; above it some days are spent at home.
        company_closure.key_ratio = 0.0
        assert mean_days_at_work() == pytest.approx(5.0, rel=0.1)
        company_closure.key_ratio = 0.1
        assert mean_days_at_work() == pytest.approx(5.0, rel=0.1)
        company_closure.key_ratio = 0.4
        assert mean_days_at_work() == pytest.approx(2.5, rel=0.1)

        # After the policy: everything is back to normal.
        sim.clear_world()
        after = datetime(2030, 2, 2)
        active = policies.individual_policies.get_active(date=after)
        assert policies.individual_policies.apply(
            active, person=worker, activities=activities, days_from_start=0
        ) == ["commute", "primary_activity", "residence"]
        sim.activity_manager.move_people_to_active_subgroups(activities, after)
        assert pupil in pupil.primary_activity.people
        assert worker in worker.primary_activity.people
        sim.clear_world()

    def test__close_companies_frequency_of_random_ratio(self, setup_policy_world):
        """Mean days at work for a 'random' worker as key/furlough ratios vary."""
        world, pupil, student, worker, sim = setup_policy_world
        company_closure = CloseCompanies(
            start_time="2020-1-1",
            end_time="2020-10-1",
            avoid_work_probability=0.2,
            key_probability=0.2,
            furlough_probability=0.2,
        )
        policies = Policies([company_closure])
        sim.activity_manager.policies = policies
        sim.clear_world()
        activities = ["commute", "primary_activity", "residence"]
        worker.lockdown_status = "random"

        # Before the policy: worker attends normally.
        before = datetime(2019, 2, 1)
        sim.activity_manager.move_people_to_active_subgroups(activities, before)
        assert worker in worker.primary_activity.people
        assert pupil in pupil.primary_activity.people
        sim.clear_world()

        during = datetime(2020, 2, 1)
        active = policies.individual_policies.get_active(date=during)

        def mean_days_at_work(trials=500):
            # Mean number of days worked per 5-day week, over `trials` weeks.
            weeks = []
            for _ in range(trials):
                worked = 0
                for _ in range(5):
                    if "primary_activity" in policies.individual_policies.apply(
                        active,
                        person=worker,
                        activities=activities,
                        days_from_start=0,
                    ):
                        worked += 1.0
                weeks.append(worked)
            return np.mean(weeks)

        # Sweep key_ratio with full random_ratio.
        company_closure.random_ratio = 1.0
        company_closure.key_ratio = 0.0
        assert mean_days_at_work() == pytest.approx(4.2, rel=0.1)
        company_closure.random_ratio = 1.0
        company_closure.key_ratio = 0.1
        assert mean_days_at_work() == pytest.approx(4.1, rel=0.1)
        company_closure.random_ratio = 1.0
        company_closure.key_ratio = 0.2
        assert mean_days_at_work() == pytest.approx(4.0, rel=0.1)

        # Sweep furlough_ratio with full random_ratio.
        company_closure.key_ratio = 0.0
        company_closure.random_ratio = 1.0
        company_closure.furlough_ratio = 0.0
        assert mean_days_at_work() == pytest.approx(3.2, rel=0.1)
        company_closure.random_ratio = 1.0
        company_closure.furlough_ratio = 0.1
        assert mean_days_at_work() == pytest.approx(3.6, rel=0.1)
        company_closure.random_ratio = 1.0
        company_closure.furlough_ratio = 0.3
        assert mean_days_at_work() == pytest.approx(4.0, rel=0.1)

        # Mix furlough_ratio and key_ratio together.
        company_closure.random_ratio = 1.0
        company_closure.furlough_ratio = 0.0
        company_closure.key_ratio = 0.0
        assert mean_days_at_work() == pytest.approx(3.4, rel=0.1)
        company_closure.random_ratio = 1.0
        company_closure.key_ratio = 0.0
        company_closure.furlough_ratio = 0.1
        assert mean_days_at_work() == pytest.approx(3.8, rel=0.1)
        company_closure.random_ratio = 1.0
        company_closure.furlough_ratio = 0.1
        company_closure.key_ratio = 0.1
        # More trials here to tighten the estimate around 3.7.
        assert mean_days_at_work(trials=1000) == pytest.approx(3.7, rel=0.1)
        company_closure.random_ratio = 1.0
        company_closure.furlough_ratio = 0.3
        company_closure.key_ratio = 0.3
        assert mean_days_at_work() == pytest.approx(4.0, rel=0.1)

        # After the policy: everything is back to normal.
        sim.clear_world()
        after = datetime(2030, 2, 2)
        active = policies.individual_policies.get_active(date=after)
        assert policies.individual_policies.apply(
            active, person=worker, activities=activities, days_from_start=0
        ) == ["commute", "primary_activity", "residence"]
        sim.activity_manager.move_people_to_active_subgroups(activities, after)
        assert pupil in pupil.primary_activity.people
        assert worker in worker.primary_activity.people
        sim.clear_world()

    def test__close_companies_full_closure(self, setup_policy_world):
        """With full_closure even a key worker is sent home."""
        world, pupil, student, worker, sim = setup_policy_world
        policies = Policies(
            [
                CloseCompanies(
                    start_time="2020-1-1", end_time="2020-10-1", full_closure=True
                )
            ]
        )
        sim.activity_manager.policies = policies
        worker.lockdown_status = "key_worker"
        activities = ["commute", "primary_activity", "residence"]
        sim.clear_world()
        during = datetime(2020, 2, 1)
        active = policies.individual_policies.get_active(date=during)
        assert policies.individual_policies.apply(
            active, person=worker, activities=activities, days_from_start=0
        ) == ["residence"]
        sim.activity_manager.move_people_to_active_subgroups(activities, during)
        assert worker in worker.residence.people
        sim.clear_world()


class TestShielding:
    def test__old_people_shield(self, setup_policy_world):
        """Shielding keeps anyone above min_age away from their primary activity."""
        world, pupil, student, worker, sim = setup_policy_world
        policies = Policies(
            [Shielding(start_time="2020-1-1", end_time="2020-10-1", min_age=30)]
        )
        sim.activity_manager.policies = policies
        activities = ["primary_activity", "residence"]
        sim.clear_world()
        during = datetime(2020, 2, 1)
        active = policies.individual_policies.get_active(date=during)
        allowed = policies.individual_policies.apply(
            active, person=worker, activities=activities, days_from_start=0
        )
        assert "primary_activity" not in allowed
        sim.activity_manager.move_people_to_active_subgroups(activities, during)
        # The (older) worker shields at home; the pupil is unaffected.
        assert worker in worker.residence.people
        assert pupil in pupil.primary_activity.people
        sim.clear_world()

    def test__old_people_shield_with_compliance(self, setup_policy_world):
        """With compliance < 1, the observed shielding rate matches compliance."""
        world, pupil, student, worker, _ = setup_policy_world
        shielding = Shielding(
            start_time="2020-1-1", end_time="2020-10-1", min_age=30, compliance=0.6
        )
        policies = Policies([shielding])
        activities = ["primary_activity", "residence"]
        during = datetime(2020, 2, 1)
        active = policies.individual_policies.get_active(date=during)
        # Count how many of 100 independent draws keep the worker at home.
        compliant_days = sum(
            1
            for _ in range(100)
            if "primary_activity"
            not in policies.individual_policies.apply(
                active,
                person=worker,
                activities=activities,
                days_from_start=0,
            )
        )
        assert compliant_days / 100 == pytest.approx(shielding.compliance, abs=0.1)


class TestQuarantine:
    """Tests for the Quarantine policy: symptomatic people stay home for
    ``n_days`` after symptom onset, and their housemates for
    ``n_days_household``, subject to compliance settings and exemptions
    (vaccinated people, children) — as exercised by the tests below."""

    def test__symptomatic_stays_for_one_week(self, setup_policy_world, selector):
        """A mildly symptomatic worker is kept home within the 7-day window
        and is free to work again after it."""
        world, pupil, student, worker, sim = setup_policy_world
        quarantine = Quarantine(
            start_time="2020-1-1", end_time="2020-1-30", n_days=7, n_days_household=14
        )
        policies = Policies([quarantine])
        sim.activity_manager.policies = policies
        infect_person(worker, selector, "mild")
        sim.epidemiology.update_health_status(world, 0.0, 0.0)
        activities = ["primary_activity", "residence"]
        sim.clear_world()
        time_during_policy = datetime(2020, 1, 2)
        active_individual_policies = policies.individual_policies.get_active(
            date=time_during_policy
        )
        # day 6: still inside the 7-day quarantine window, no work allowed
        assert "primary_activity" not in policies.individual_policies.apply(
            active_individual_policies,
            person=worker,
            activities=activities,
            days_from_start=6,
        )
        # day 20: window over, work is allowed again
        assert "primary_activity" in policies.individual_policies.apply(
            active_individual_policies,
            person=worker,
            activities=activities,
            days_from_start=20,
        )
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_during_policy, 6.0
        )
        assert worker in worker.residence.people
        sim.clear_world()
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_during_policy, 20
        )
        assert worker in worker.primary_activity.people
        worker.infection = None  # reset health state so later tests start clean
        sim.clear_world()

    def test__asymptomatic_is_free(self, setup_policy_world, selector):
        """An asymptomatic infection does not trigger quarantine."""
        world, pupil, student, worker, sim = setup_policy_world
        quarantine = Quarantine(
            start_time="2020-1-1", end_time="2020-1-30", n_days=7, n_days_household=14
        )
        policies = Policies([quarantine])
        sim.activity_manager.policies = policies
        infect_person(worker, selector, "asymptomatic")
        sim.epidemiology.update_health_status(world, 0.0, 0.0)
        activities = ["primary_activity", "residence"]
        sim.clear_world()
        time_during_policy = datetime(2020, 1, 2)
        active_individual_policies = policies.individual_policies.get_active(
            date=time_during_policy
        )
        # no symptoms ever develop, so work stays in the activity list
        assert "primary_activity" in policies.individual_policies.apply(
            active_individual_policies,
            person=worker,
            activities=activities,
            days_from_start=6.0,
        )
        worker.infection = None  # reset health state so later tests start clean
        sim.clear_world()

    def test__housemates_stay_for_two_weeks(self, setup_policy_world, selector):
        """Housemates of a symptomatic person quarantine for 14 days after
        the infected person's symptom onset."""
        world, pupil, student, worker, sim = setup_policy_world
        quarantine = Quarantine(
            start_time="2020-1-1", end_time="2020-1-30", n_days=7, n_days_household=14
        )
        policies = Policies([quarantine])
        sim.activity_manager.policies = policies
        infect_person(worker, selector, "mild")
        sim.epidemiology.update_health_status(world, 0.0, 0.0)
        activities = ["primary_activity", "residence"]
        sim.clear_world()
        time_during_policy = datetime(2020, 1, 2)
        # day 8: within the 14-day household quarantine window
        active_individual_policies = policies.individual_policies.get_active(
            date=time_during_policy
        )
        assert "primary_activity" not in policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=8.0,
        )
        # the pupil is actually moved home during the window
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_during_policy, 8.0
        )
        assert pupil in pupil.residence.people
        # more than two weeks after symptoms onset: the pupil is free again
        assert "primary_activity" in policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=25,
        )
        worker.infection = None  # clear the index case before the final move
        sim.clear_world()
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_during_policy, 25
        )
        assert pupil in pupil.primary_activity.people
        sim.clear_world()

    def test__housemates_of_asymptomatic_are_free(self, setup_policy_world, selector):
        """Housemates do not quarantine if the infected person never
        develops symptoms."""
        world, pupil, student, worker, sim = setup_policy_world
        quarantine = Quarantine(
            start_time="2020-1-1", end_time="2020-1-30", n_days=7, n_days_household=14
        )
        policies = Policies([quarantine])
        sim.activity_manager.policies = policies
        infect_person(worker, selector, "asymptomatic")
        sim.epidemiology.update_health_status(world, 0.0, 0.0)
        activities = ["primary_activity", "residence"]
        sim.clear_world()
        time_during_policy = datetime(2020, 1, 2)
        # no symptom onset occurs, so the pupil is free at both checkpoints
        active_individual_policies = policies.individual_policies.get_active(
            date=time_during_policy
        )
        assert "primary_activity" in policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=8.0,
        )
        assert "primary_activity" in policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=25.0,
        )
        worker.infection = None  # reset health state so later tests start clean
        sim.clear_world()

    def test__quarantine_zero_complacency(self, setup_policy_world, selector):
        """With household_compliance=0, housemates never quarantine.
        NOTE: "complacency" in the test name is a typo for "compliance"."""
        world, pupil, student, worker, sim = setup_policy_world
        quarantine = Quarantine(
            start_time="2020-1-1",
            end_time="2020-1-30",
            n_days=7,
            n_days_household=14,
            household_compliance=0.0,
        )
        policies = Policies([quarantine])
        sim.activity_manager.policies = policies
        infect_person(worker, selector, "mild")
        sim.epidemiology.update_health_status(world, 0.0, 0.0)
        activities = ["primary_activity", "residence"]
        sim.clear_world()
        time_during_policy = datetime(2020, 1, 2)
        active_individual_policies = policies.individual_policies.get_active(
            date=time_during_policy
        )
        # before symptoms onset
        assert "primary_activity" in policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=4.0,
        )
        # after symptoms onset
        assert "primary_activity" in policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=8.0,
        )
        # more than two weeks after symptoms onset
        assert "primary_activity" in policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=25,
        )
        worker.infection = None  # reset health state so later tests start clean
        sim.clear_world()

    def test__quarantine_zero_complacency_regional(self, setup_policy_world, selector):
        """Regional compliance of 0 also disables household quarantine,
        even when the policy itself is fully compliant."""
        world, pupil, student, worker, sim = setup_policy_world
        world.regions[0].regional_compliance = 0
        quarantine = Quarantine(
            start_time="2020-1-1", end_time="2020-1-30", n_days=7, n_days_household=14
        )
        policies = Policies([quarantine])
        sim.activity_manager.policies = policies
        infect_person(worker, selector, "mild")
        sim.epidemiology.update_health_status(world, 0.0, 0.0)
        activities = ["primary_activity", "residence"]
        sim.clear_world()
        time_during_policy = datetime(2020, 1, 2)
        active_individual_policies = policies.individual_policies.get_active(
            date=time_during_policy
        )
        # before symptoms onset
        assert "primary_activity" in policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=4.0,
        )
        # after symptoms onset
        assert "primary_activity" in policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=8.0,
        )
        # more than two weeks after symptoms onset
        assert "primary_activity" in policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=25,
        )
        worker.infection = None  # reset health state so later tests start clean
        sim.clear_world()

    def test__quarantine_vaccinated_are_free(self, setup_policy_world, selector):
        """Vaccinated housemates are exempt from household quarantine when
        vaccinated_household_compliance is 0."""
        world, pupil, student, worker, sim = setup_policy_world
        quarantine = Quarantine(
            start_time="2020-1-1",
            end_time="2020-1-30",
            n_days=7,
            n_days_household=14,
            household_compliance=1.0,
            vaccinated_household_compliance=0.0,
        )
        pupil.age = 19  # such that they aren't caught in the under 18 rule
        pupil.vaccinated = True
        policies = Policies([quarantine])
        sim.activity_manager.policies = policies
        infect_person(worker, selector, "mild")
        sim.epidemiology.update_health_status(world, 0.0, 0.0)
        activities = ["primary_activity", "residence"]
        sim.clear_world()
        time_during_policy = datetime(2020, 1, 2)
        active_individual_policies = policies.individual_policies.get_active(
            date=time_during_policy
        )
        # before symptoms onset
        assert "primary_activity" in policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=4.0,
        )
        # after symptoms onset
        assert "primary_activity" in policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=8.0,
        )
        # more than two weeks after symptoms onset
        assert "primary_activity" in policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=25,
        )
        worker.infection = None  # reset health state so later tests start clean
        sim.clear_world()

    def test__quarantine_children_are_free(self, setup_policy_world, selector):
        """Housemates under 18 appear to be exempt from household quarantine
        (same checks as the vaccinated case, but via age rather than
        vaccination status)."""
        world, pupil, student, worker, sim = setup_policy_world
        quarantine = Quarantine(
            start_time="2020-1-1",
            end_time="2020-1-30",
            n_days=7,
            n_days_household=14,
            household_compliance=1.0,
            vaccinated_household_compliance=0.0,
        )
        pupil.age = 17  # under 18: exempt from household quarantine
        policies = Policies([quarantine])
        sim.activity_manager.policies = policies
        infect_person(worker, selector, "mild")
        sim.epidemiology.update_health_status(world, 0.0, 0.0)
        activities = ["primary_activity", "residence"]
        sim.clear_world()
        time_during_policy = datetime(2020, 1, 2)
        active_individual_policies = policies.individual_policies.get_active(
            date=time_during_policy
        )
        # before symptoms onset
        assert "primary_activity" in policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=4.0,
        )
        # after symptoms onset
        assert "primary_activity" in policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=8.0,
        )
        # more than two weeks after symptoms onset
        assert "primary_activity" in policies.individual_policies.apply(
            active_individual_policies,
            person=pupil,
            activities=activities,
            days_from_start=25,
        )
        worker.infection = None  # reset health state so later tests start clean
        sim.clear_world()


def test__kid_at_home_is_supervised(setup_policy_world, selector):
    """A severely symptomatic child stays home, and at least one adult
    from their household stays home to supervise them."""
    world, pupil, student, worker, sim = setup_policy_world
    sim.activity_manager.policies = Policies([SevereSymptomsStayHome()])
    assert pupil.primary_activity is not None
    infect_person(pupil, selector, "severe")
    assert pupil.infection.tag == SymptomTag.severe
    activities = ["primary_activity", "residence"]
    sim.activity_manager.move_people_to_active_subgroups(activities)
    assert pupil in pupil.residence.people
    # a guardian is any household member aged 18 or over
    has_guardian_at_home = any(
        person.age >= 18 for person in pupil.residence.group.people
    )
    assert has_guardian_at_home
    sim.clear_world()


class TestLimitLongCommute:
    """Tests for the LimitLongCommute policy, which probabilistically
    removes commuting/work for people with very long commutes."""

    def test__haversine_distance(self):
        # one degree of longitude at the equator is roughly 111 km
        area = Area(coordinates=[0, 1])
        super_area = SuperArea(coordinates=[0, 0])
        km = haversine_distance(area.coordinates, super_area.coordinates)
        assert 100 < km < 150

    def test__distance_policy_check(self):
        commuter = Person.from_attributes()
        area = Area(coordinates=[0, 1])
        super_area = SuperArea(coordinates=[0, 0])
        super_area.add_worker(commuter)
        area.add(commuter)
        policy = LimitLongCommute(
            apply_from_distance=150, going_to_work_probability=0.2
        )
        # ~111 km is below the 150 km threshold, so not a long commute
        assert policy._does_long_commute(commuter) is False

    def test__probability_of_going_to_work(self):
        commuter = Person.from_attributes()
        area = Area(coordinates=[0, 3])
        super_area = SuperArea(coordinates=[0, 0])
        area.add(commuter)
        super_area.add_worker(commuter)
        policy = LimitLongCommute(
            apply_from_distance=150, going_to_work_probability=0.2
        )
        assert set(policy.activities_to_remove) == {"commute", "primary_activity"}
        policy.get_long_commuters([commuter])
        # the skip rate over many draws should match 1 - going_to_work_probability... 
        # here the policy reports a skip with probability 0.2 per check
        n_trials = 5000
        n_skips = sum(
            1 for _ in range(n_trials) if policy.check_skips_activity(commuter)
        )
        assert np.isclose(n_skips, 0.2 * n_trials, rtol=0.1)


class TestSchoolQuarantine:
    """Tests that SchoolQuarantine keeps a fraction (``compliance``) of the
    infected child's classmates home for ``n_days`` after symptom onset."""

    def test__school_quarantine(self, selector):
        """Walk simulated time through three phases around the index case's
        symptom onset and check the stay-home rule in each phase."""
        # build a school with 100 kids per age 0-9, all in one household
        kids = []
        school = School()
        household = Household()
        for i in range(10):
            for _ in range(100):
                person = Person.from_attributes(age=i)
                school.add(person)
                household.add(person)
                kids.append(person)
        school_quarantine = SchoolQuarantine(
            start_time="2020-1-1", end_time="2020-1-30", compliance=0.7, n_days=7
        )
        infected = kids[0]
        infect_person(infected, selector=selector, symptom_tag="mild")
        time = 0
        # checks[i] records that phase i was actually reached by the loop
        checks = [False, False, False]
        while True:
            if time > 7 + infected.infection.time_of_symptoms_onset:
                # phase 0: quarantine window over — nobody stays home
                checks[0] = True
                for person in kids:
                    stays_home = school_quarantine.check_stay_home_condition(
                        person=person, days_from_start=time
                    )
                    assert stays_home is False
                break
            if time < infected.infection.time_of_symptoms_onset:
                # phase 1: before symptom onset — nobody stays home yet
                checks[1] = True
                for person in kids:
                    stays_home = school_quarantine.check_stay_home_condition(
                        person=person, days_from_start=time
                    )
                    assert stays_home is False
                time += 1
            else:
                # phase 2: quarantine active — ~70% of the infected child's
                # school subgroup stays home; everyone else is unaffected
                checks[2] = True
                quarantined = 0
                total = 0
                for person in kids:
                    stays_home = school_quarantine.check_stay_home_condition(
                        person=person, days_from_start=time
                    )
                    if person.primary_activity == infected.primary_activity:
                        total += 1
                        if stays_home:
                            quarantined += 1
                    else:
                        assert stays_home is False
                assert np.isclose(quarantined / total, 0.7, rtol=0.15)
                time += 1
        # every phase must have been exercised at least once
        assert min(checks) is True


from datetime import datetime

import pytest

from june import paths
from june.geography import Cities
from june.groups import Cemeteries
from june.groups.leisure import generate_leisure_for_config
from june.policy import Policies, SocialDistancing, MaskWearing


# simulator config used by the policy tests below (relative to june's configs)
test_config = paths.configs_path / "tests/test_simulator_simple.yaml"


class TestSocialDistancing:
    """Tests that SocialDistancing policies rescale group interaction betas
    during their active windows and leave them untouched otherwise."""

    @pytest.fixture(name="social_distancing_sim")
    def setup(self, setup_policy_world):
        """Simulator with two non-overlapping policies: a broad beta
        reduction (March 2-5) and a cinema-only beta increase (March 7-9)."""
        world, pupil, student, worker, sim = setup_policy_world
        world.cemeteries = Cemeteries()
        beta_factors = {
            "pub": 0.7,
            "grocery": 0.7,
            "cinema": 0.7,
            "inter_city_transport": 0.7,
            "city_transport": 0.7,
            "hospital": 0.7,
            "care_home": 0.7,
            "company": 0.7,
            "school": 0.7,
            "household": 1.0,
            "university": 0.7,
            "household_visits": 0.5,
        }
        social_distance = SocialDistancing(
            start_time="2020-03-02", end_time="2020-03-05", beta_factors=beta_factors
        )
        beta_factors2 = {"cinema": 4}
        social_distance2 = SocialDistancing(
            start_time="2020-03-07", end_time="2020-03-09", beta_factors=beta_factors2
        )
        policies = Policies([social_distance, social_distance2])
        leisure_instance = generate_leisure_for_config(
            world=world, config_filename=test_config
        )
        leisure_instance.distribute_social_venues_to_areas(
            world.areas, super_areas=world.super_areas
        )
        sim.activity_manager.policies = policies
        sim.activity_manager.leisure = leisure_instance
        sim.timer.reset()
        sim.clear_world()
        return sim

    def test__social_distancing_basic(self, social_distancing_sim):
        """Within each policy window the non-exempt betas are scaled by the
        configured factors; outside and for households they are unchanged."""
        start_date = datetime(2020, 3, 2)
        end_date = datetime(2020, 3, 5)
        start_date2 = datetime(2020, 3, 7)
        end_date2 = datetime(2020, 3, 9)
        sim = social_distancing_sim
        something_is_checked = False
        while sim.timer.date <= sim.timer.final_date:
            sim.do_timestep()
            if start_date <= sim.timer.date < end_date:
                for super_group in sim.world:
                    if super_group.__class__ in [Cities, Cemeteries]:
                        continue
                    for group in super_group:
                        interactive_group = group.get_interactive_group()
                        beta = interactive_group.get_processed_beta(
                            betas=sim.interaction.betas,
                            beta_reductions=sim.interaction.beta_reductions,
                        )
                        if group.spec == "household":
                            # household factor is 1.0, so beta is unchanged
                            assert beta == sim.interaction.betas["household"]
                        else:
                            something_is_checked = True
                            assert (
                                beta
                                == sim.interaction.betas[group.spec]
                                * sim.interaction.beta_reductions[group.spec]
                            )
                next(sim.timer)
                continue
            if start_date2 <= sim.timer.date < end_date2:
                for super_group in sim.world:
                    # FIX: skip Cities/Cemeteries before iterating their
                    # groups, consistent with the first policy window (the
                    # guard previously sat inside the inner loop).
                    if super_group.__class__ in [Cities, Cemeteries]:
                        continue
                    for group in super_group:
                        interactive_group = group.get_interactive_group()
                        beta = interactive_group.get_processed_beta(
                            betas=sim.interaction.betas,
                            beta_reductions=sim.interaction.beta_reductions,
                        )
                        if group.spec == "cinema":
                            assert beta == 4 * sim.interaction.betas["cinema"]
                        else:
                            assert beta == sim.interaction.betas[group.spec]
                next(sim.timer)
                continue
            next(sim.timer)
        # guard against the loop never entering the first policy window
        assert something_is_checked

    def test__social_distancing_regional_compliance(self, social_distancing_sim):
        """With regional compliance 0.5, only half of each beta reduction is
        applied: beta * (1 + 0.5 * (factor - 1))."""
        start_date = datetime(2020, 3, 2)
        end_date = datetime(2020, 3, 5)
        start_date2 = datetime(2020, 3, 7)
        end_date2 = datetime(2020, 3, 9)
        sim = social_distancing_sim
        sim.world.regions[0].regional_compliance = 0.5
        something_is_checked = False
        while sim.timer.date <= sim.timer.final_date:
            sim.do_timestep()
            if start_date <= sim.timer.date < end_date:
                for super_group in sim.world:
                    if super_group.__class__ in [Cities, Cemeteries]:
                        continue
                    for group in super_group:
                        interactive_group = group.get_interactive_group()
                        beta = interactive_group.get_processed_beta(
                            betas=sim.interaction.betas,
                            beta_reductions=sim.interaction.beta_reductions,
                        )
                        # the reduction is interpolated towards 1 by compliance
                        beta_with_compliance = sim.interaction.betas[group.spec] * (
                            1 + 0.5 * (sim.interaction.beta_reductions[group.spec] - 1)
                        )
                        if group.spec == "household":
                            assert beta == sim.interaction.betas["household"]
                        else:
                            something_is_checked = True
                            assert beta == beta_with_compliance
                next(sim.timer)
                continue
            if start_date2 <= sim.timer.date < end_date2:
                for super_group in sim.world:
                    if super_group.__class__ in [Cities, Cemeteries]:
                        continue
                    for group in super_group:
                        interactive_group = group.get_interactive_group()
                        beta = interactive_group.get_processed_beta(
                            betas=sim.interaction.betas,
                            beta_reductions=sim.interaction.beta_reductions,
                        )
                        if group.spec == "cinema":
                            assert beta == sim.interaction.betas["cinema"] * (
                                1 + 0.5 * (4 - 1)
                            )
                        else:
                            assert beta == sim.interaction.betas[group.spec]
                next(sim.timer)
                continue
            next(sim.timer)
        assert something_is_checked


class TestMaskWearing:
    """Check that MaskWearing rescales interaction betas only while the
    policy is active, and never for households."""

    def test__mask_wearing(self, setup_policy_world):
        world, pupil, student, worker, sim = setup_policy_world
        world.cemeteries = Cemeteries()
        start_date = datetime(2020, 3, 10)
        end_date = datetime(2020, 3, 12)
        compliance = 1.0
        beta_factor = 0.5
        mask_probabilities = {
            "pub": 0.5,
            "grocery": 0.5,
            "cinema": 0.5,
            "city_transport": 0.5,
            "inter_city_transport": 0.5,
            "hospital": 0.5,
            "care_home": 0.5,
            "company": 0.5,
            "school": 0.5,
            "household": 0.0,
            "university": 0.5,
        }
        mask_wearing = MaskWearing(
            start_time="2020-03-10",
            end_time="2020-03-12",
            beta_factor=beta_factor,
            mask_probabilities=mask_probabilities,
            compliance=compliance,
        )
        sim.activity_manager.policies = Policies([mask_wearing])
        leisure_instance = generate_leisure_for_config(
            world=world, config_filename=test_config
        )
        leisure_instance.distribute_social_venues_to_areas(
            world.areas, super_areas=world.super_areas
        )
        sim.activity_manager.leisure = leisure_instance
        sim.timer.reset()
        sim.clear_world()
        while sim.timer.date <= sim.timer.final_date:
            sim.do_timestep()
            policy_active = start_date <= sim.timer.date < end_date
            for super_group in sim.world:
                if super_group.__class__ in [Cities, Cemeteries]:
                    continue
                for group in super_group:
                    interactive_group = group.get_interactive_group()
                    beta = interactive_group.get_processed_beta(
                        betas=sim.interaction.betas,
                        beta_reductions=sim.interaction.beta_reductions,
                    )
                    base_beta = sim.interaction.betas[group.spec]
                    if policy_active and group.spec != "household":
                        # expected: beta * (1 - p_mask * compliance * (1 - beta_factor))
                        masked_beta = base_beta * (
                            1 - (0.5 * compliance * (1 - beta_factor))
                        )
                        assert beta == masked_beta
                    else:
                        # outside the window, or households (p_mask = 0):
                        # the beta is untouched
                        assert beta == base_beta
            next(sim.timer)


from datetime import datetime, timedelta

import numpy as np

from june import paths
from june.demography import Person
from june.groups import Household
from june.groups.leisure import generate_leisure_for_config, generate_leisure_for_world
from june.policy import (
    Policies,
    CloseLeisureVenue,
    ChangeLeisureProbability,
    TieredLockdown,
    TieredLockdowns,
    ChangeVisitsProbability,
)


# simulator config used by the leisure-policy tests below
test_config = paths.configs_path / "tests/test_simulator_simple.yaml"


class TestCloseLeisure:
    """Tests that leisure venues are closed by CloseLeisureVenue and by
    tiered lockdowns, and reopen once the policies end."""

    def test__close_leisure_venues(self, setup_policy_world):
        """Pubs are closed during the policy window (people go to the
        cinema or stay home) and reopen afterwards."""
        world, pupil, student, worker, sim = setup_policy_world
        close_venues = CloseLeisureVenue(
            start_time="2020-3-1", end_time="2020-3-30", venues_to_close=["pub"]
        )
        policies = Policies([close_venues])
        leisure = generate_leisure_for_config(world=world, config_filename=test_config)
        leisure.distribute_social_venues_to_areas(
            world.areas, super_areas=world.super_areas
        )
        sim.activity_manager.leisure = leisure
        sim.activity_manager.policies = policies
        sim.clear_world()
        time_before_policy = datetime(2019, 2, 1)
        activities = ["leisure", "residence"]
        # huge delta_time makes doing some leisure activity near-certain
        leisure.generate_leisure_probabilities_for_timestep(
            delta_time=10000,
            working_hours=False,
            date=datetime.strptime("2020-03-02", "%Y-%m-%d"),
        )
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_before_policy, 0.0
        )
        assert worker in worker.leisure.people
        sim.clear_world()
        time_during_policy = datetime(2020, 3, 14)
        policies.leisure_policies.apply(date=time_during_policy, leisure=leisure)
        assert list(world.regions[0].policy["global_closed_venues"]) == ["pub"]
        leisure.generate_leisure_probabilities_for_timestep(
            delta_time=10000,
            working_hours=False,
            date=datetime.strptime("2020-03-02", "%Y-%m-%d"),
        )
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_during_policy, 0.0
        )
        # with pubs closed the worker ends up at the cinema or at home
        assert (
            worker in worker.leisure.people and worker.leisure.group.spec == "cinema"
        ) or worker in worker.residence.people
        # FIX: removed a duplicated consecutive sim.clear_world() call
        sim.clear_world()
        time_after_policy = datetime(2020, 3, 30)
        policies.leisure_policies.apply(date=time_after_policy, leisure=leisure)
        assert list(world.regions[0].policy["global_closed_venues"]) == []
        leisure.generate_leisure_probabilities_for_timestep(
            delta_time=10000,
            working_hours=False,
            date=datetime.strptime("2020-03-02", "%Y-%m-%d"),
        )
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_after_policy, 0.0
        )
        assert worker in worker.leisure.people

    def test__close_leisure_venues_tiered_lockdowns(self, setup_policy_world):
        """A tier-4 regional lockdown closes both pubs and cinemas locally,
        forcing people home, and venues reopen after the policy ends."""
        world, pupil, student, worker, sim = setup_policy_world
        tiered_lockdown = TieredLockdown(
            start_time="2020-03-02",
            end_time="2020-03-30",
            tiers_per_region={"North East": 4.0},
        )
        tiered_lockdowns = TieredLockdowns([tiered_lockdown])

        policies = Policies([tiered_lockdowns])
        leisure = generate_leisure_for_config(world=world, config_filename=test_config)
        leisure.distribute_social_venues_to_areas(
            world.areas, super_areas=world.super_areas
        )
        sim.activity_manager.leisure = leisure
        sim.activity_manager.policies = policies
        sim.clear_world()
        time_before_policy = datetime(2019, 2, 1)
        activities = ["leisure", "residence"]
        leisure.generate_leisure_probabilities_for_timestep(
            delta_time=10000,
            working_hours=False,
            date=datetime.strptime("2020-03-02", "%Y-%m-%d"),
        )
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_before_policy, 0.0
        )
        assert worker in worker.leisure.people
        sim.clear_world()
        time_during_policy = datetime(2020, 3, 14)
        policies.tiered_lockdown.apply(date=time_during_policy, regions=world.regions)
        assert "pub" in list(world.regions[0].policy["local_closed_venues"])
        assert "cinema" in list(world.regions[0].policy["local_closed_venues"])
        leisure.generate_leisure_probabilities_for_timestep(
            delta_time=10000,
            working_hours=False,
            date=datetime.strptime("2020-03-02", "%Y-%m-%d"),
        )
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_during_policy, 0.0
        )
        # all venues are shut, so everyone must be at home
        assert worker in worker.residence.people
        # FIX: removed a duplicated consecutive sim.clear_world() call
        sim.clear_world()
        time_after_policy = datetime(2020, 3, 30)
        policies.tiered_lockdown.apply(date=time_after_policy, regions=world.regions)
        assert list(world.regions[0].policy["local_closed_venues"]) == []
        leisure.generate_leisure_probabilities_for_timestep(
            delta_time=10000,
            working_hours=False,
            date=datetime.strptime("2020-03-02", "%Y-%m-%d"),
        )
        sim.activity_manager.move_people_to_active_subgroups(
            activities, time_after_policy, 0.0
        )
        assert worker in worker.leisure.people


class TestReduceLeisureProbabilities:
    def test__reduce_household_visits(self, setup_policy_world):
        """Pub-visit probabilities are scaled down while the
        ChangeLeisureProbability policy is active (2020-03-02 .. 2020-03-05)
        and fully restored once the policy ends.

        Fix in review: the `policy_reductions == {}` assertion after the
        policy's end was duplicated verbatim; the duplicate has been removed.
        """
        world, pupil, student, worker, sim = setup_policy_world
        super_area = world.super_areas[0]
        leisure = generate_leisure_for_config(world=world, config_filename=test_config)
        reduce_leisure_probabilities = ChangeLeisureProbability(
            start_time="2020-03-02",
            end_time="2020-03-05",
            activity_reductions={
                "pub": {"male": {"0-50": 0.5, "50-100": 0.0}, "female": {"0-100": 0.5}}
            },
        )
        policies = Policies([reduce_leisure_probabilities])
        sim.activity_manager.policies = policies
        sim.activity_manager.leisure = leisure
        sim.clear_world()
        policies.leisure_policies.apply(
            date=sim.timer.date, leisure=sim.activity_manager.leisure
        )
        sim.activity_manager.leisure.generate_leisure_probabilities_for_timestep(
            0.1, working_hours=False, date=datetime.strptime("2020-03-02", "%Y-%m-%d")
        )
        # the simulation starts one day before the policy becomes active
        assert str(sim.timer.date.date()) == "2020-03-01"
        household = Household()
        household.area = super_area.areas[0]
        leisure.distribute_social_venues_to_areas(
            world.areas, super_areas=world.super_areas
        )
        person1 = Person.from_attributes(age=60, sex="m")
        person1.area = super_area.areas[0]
        household.add(person1)
        person2 = Person.from_attributes(age=19, sex="f")
        person2.area = super_area.areas[0]
        # NOTE(review): venues were already distributed above; this second call
        # looks redundant but is kept so the test's random state is unchanged.
        leisure.distribute_social_venues_to_areas(
            world.areas, super_areas=world.super_areas
        )
        household.add(person2)
        # sample leisure choices before the policy kicks in
        pubs1_visits_before = 0
        pubs2_visits_before = 0
        for _ in range(5000):
            subgroup = leisure.get_subgroup_for_person_and_housemates(person1)
            if subgroup is not None and subgroup.group.spec == "pub":
                pubs1_visits_before += 1
            person1.subgroups.leisure = None
            subgroup = leisure.get_subgroup_for_person_and_housemates(person2)
            if subgroup is not None and subgroup.group.spec == "pub":
                pubs2_visits_before += 1
            person2.subgroups.leisure = None
        assert pubs1_visits_before > 0
        assert pubs2_visits_before > 0
        # next day leisure policies are active
        while str(sim.timer.date.date()) != "2020-03-02":
            next(sim.timer)
        policies.leisure_policies.apply(date=sim.timer.date, leisure=leisure)
        leisure.generate_leisure_probabilities_for_timestep(
            0.1, working_hours=False, date=datetime.strptime("2020-03-02", "%Y-%m-%d")
        )
        # 60-year-old male falls in the 50-100 bin (factor 0), 19-year-old
        # female in the 0-100 bin (factor 0.5)
        assert leisure.policy_reductions["pub"]["weekday"]["m"][60] == 0.0
        assert leisure.policy_reductions["pub"]["weekday"]["f"][19] == 0.5
        assert leisure.policy_reductions["pub"]["weekend"]["m"][60] == 0.0
        assert leisure.policy_reductions["pub"]["weekend"]["f"][19] == 0.5
        pubs1_visits_after = 0
        pubs2_visits_after = 0
        for _ in range(5000):
            subgroup = leisure.get_subgroup_for_person_and_housemates(person1)
            if subgroup is not None and subgroup.group.spec == "pub":
                pubs1_visits_after += 1
            person1.subgroups.leisure = None
            subgroup = leisure.get_subgroup_for_person_and_housemates(person2)
            if subgroup is not None and subgroup.group.spec == "pub":
                pubs2_visits_after += 1
            person2.subgroups.leisure = None
        assert pubs1_visits_after == 0
        assert np.isclose(pubs2_visits_after / pubs2_visits_before, 0.5, rtol=0.1)
        # end of policy: reductions are cleared and visits return to baseline
        sim.clear_world()
        while str(sim.timer.date.date()) != "2020-03-05":
            next(sim.timer)
        policies.leisure_policies.apply(
            date=sim.timer.date, leisure=sim.activity_manager.leisure
        )
        sim.activity_manager.leisure.generate_leisure_probabilities_for_timestep(
            0.1, working_hours=False, date=datetime.strptime("2020-03-02", "%Y-%m-%d")
        )
        assert leisure.policy_reductions == {}
        pubs1_visits_restored = 0
        pubs2_visits_restored = 0
        for _ in range(5000):
            subgroup = leisure.get_subgroup_for_person_and_housemates(person1)
            if subgroup is not None and subgroup.group.spec == "pub":
                pubs1_visits_restored += 1
            person1.subgroups.leisure = None
            subgroup = leisure.get_subgroup_for_person_and_housemates(person2)
            if subgroup is not None and subgroup.group.spec == "pub":
                pubs2_visits_restored += 1
            person2.subgroups.leisure = None
        assert np.isclose(pubs1_visits_restored, pubs1_visits_before, rtol=0.1)
        assert np.isclose(pubs2_visits_restored, pubs2_visits_before, rtol=0.1)
        # reductions remain cleared after sampling
        assert leisure.policy_reductions == {}

    def test__reduce_household_visits_with_regional_compliance(
        self, setup_policy_world
    ):
        """With partial regional compliance, the effective Poisson parameter
        interpolates linearly between the original value and the fully
        compliant (policy-reduced) value."""
        world, pupil, student, worker, sim = setup_policy_world
        while str(sim.timer.date.date()) != "2020-03-02":
            next(sim.timer)
        region = worker.region
        leisure = generate_leisure_for_config(world=world, config_filename=test_config)
        assert leisure.regions[0] == region
        reduce_leisure_probabilities = ChangeLeisureProbability(
            start_time="2020-03-02",
            end_time="2020-03-05",
            activity_reductions={
                "pub": {"male": {"0-50": 0.5, "50-100": 0.0}, "female": {"0-100": 0.5}}
            },
        )
        policies = Policies([reduce_leisure_probabilities])
        sim.activity_manager.policies = policies
        sim.activity_manager.leisure = leisure

        # compliance to 1
        policies.leisure_policies.apply(date=sim.timer.date, leisure=leisure)
        assert leisure.policy_reductions["pub"]["weekday"]["m"][60] == 0.0
        assert leisure.policy_reductions["pub"]["weekday"]["f"][40] == 0.5
        assert leisure.policy_reductions["pub"]["weekend"]["m"][60] == 0.0
        assert leisure.policy_reductions["pub"]["weekend"]["f"][40] == 0.5
        original_poisson_parameter = leisure.leisure_distributors[
            "pub"
        ].get_poisson_parameter(
            sex="m", age=25, day_type="weekday", working_hours=False
        )
        full_comp_poisson_parameter = leisure._get_activity_poisson_parameter(
            activity="pub",
            distributor=leisure.leisure_distributors["pub"],
            sex="m",
            age=25,
            date=datetime.strptime("2020-03-02", "%Y-%m-%d"),
            working_hours=False,
            region=region,
        )
        region.regional_compliance = 0.5
        half_comp_poisson_parameter = leisure._get_activity_poisson_parameter(
            activity="pub",
            distributor=leisure.leisure_distributors["pub"],
            sex="m",
            age=25,
            date=datetime.strptime("2020-03-02", "%Y-%m-%d"),
            working_hours=False,
            region=region,
        )
        # 25-year-old male is in the 0-50 bin, so full compliance halves the rate
        assert np.isclose(
            full_comp_poisson_parameter, 0.5 * original_poisson_parameter, rtol=0.5
        )
        # half compliance lands midway between original and fully reduced
        assert np.isclose(
            half_comp_poisson_parameter,
            original_poisson_parameter
            + 0.5 * (full_comp_poisson_parameter - original_poisson_parameter),
        )

        # check integration with region object
        region.regional_compliance = 1.0
        leisure.generate_leisure_probabilities_for_timestep(
            delta_time=0.1,
            working_hours=False,
            date=datetime.strptime("2020-03-02", "%Y-%m-%d"),
        )
        full_comp_probs = leisure._get_activity_probabilities_for_person(person=worker)
        region.regional_compliance = 0.5
        leisure.generate_leisure_probabilities_for_timestep(
            delta_time=0.1,
            working_hours=False,
            date=datetime.strptime("2020-03-02", "%Y-%m-%d"),
        )
        half_comp_probs = leisure._get_activity_probabilities_for_person(person=worker)
        # this is  a reduction, so being less compliant means you go more often
        assert half_comp_probs["does_activity"] > full_comp_probs["does_activity"]
        assert (
            half_comp_probs["activities"]["pub"] > full_comp_probs["activities"]["pub"]
        )
        # dragging the household along is unaffected by compliance
        assert (
            half_comp_probs["drags_household"]["pub"]
            == full_comp_probs["drags_household"]["pub"]
        )


class TestChangeVisitsProbabilities:
    def test__change_split(self, setup_policy_world):
        """The household/care-home visit split switches to the policy values
        only while the policy window [2020-03-02, 2020-03-05) is active, and
        is empty (defaults) outside it."""
        world, pupil, student, worker, sim = setup_policy_world
        leisure = generate_leisure_for_world(
            world=world,
            list_of_leisure_groups=["care_home_visits", "household_visits"],
            daytypes={
                "weekday": ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"],
                "weekend": ["Saturday", "Sunday"],
            },
        )
        visits_policy = ChangeVisitsProbability(
            start_time="2020-03-02",
            end_time="2020-03-05",
            new_residence_type_probabilities={"household": 0.9, "care_home": 0.1},
        )
        policies = Policies([visits_policy])
        # default split before any policy is applied
        assert leisure.leisure_distributors[
            "residence_visits"
        ].residence_type_probabilities == {"household": 0.66, "care_home": 0.34}
        day = datetime(2020, 2, 25)
        last_day = datetime(2020, 4, 1)
        seen_inactive = False
        seen_active = False
        while day < last_day:
            policies.leisure_policies.apply(date=day, leisure=leisure)
            reductions = leisure.leisure_distributors[
                "residence_visits"
            ].policy_reductions
            if datetime(2020, 3, 2) <= day < datetime(2020, 3, 5):
                assert reductions == {"household": 0.9, "care_home": 0.1}
                seen_active = True
            else:
                assert reductions == {}
                seen_inactive = True
            day += timedelta(days=1)
        # both branches must actually have been exercised
        assert seen_inactive and seen_active


from june.epidemiology.infection import SymptomTag
from june.policy import Policies, Hospitalisation


def test__hospitalise_the_sick(setup_policy_world, selector):
    """An infected person whose symptoms reach 'hospitalised' is assigned a
    medical facility and moved there by the activity manager."""
    world, pupil, student, worker, sim = setup_policy_world
    hospitalisation = Hospitalisation()
    policies = Policies([hospitalisation])
    sim.activity_manager.policies = policies
    sim.epidemiology.set_medical_care(
        world=world, activity_manager=sim.activity_manager
    )
    selector.infect_person_at_time(worker, 0.0)
    # force the symptom stage that triggers the hospitalisation policy
    worker.infection.symptoms.tag = SymptomTag.hospitalised
    assert worker.infection.should_be_in_hospital
    sim.epidemiology.update_health_status(world, 0.0, 0.0)
    assert worker.medical_facility is not None
    sim.activity_manager.move_people_to_active_subgroups(
        ["medical_facility", "residence"]
    )
    assert worker in worker.medical_facility.people
    sim.clear_world()


def test__move_people_from_hospital_to_icu(setup_policy_world, selector):
    """A ward patient whose symptoms worsen to intensive care is transferred
    from the patients subgroup to the ICU subgroup."""
    world, pupil, student, worker, sim = setup_policy_world
    hospital = world.hospitals[0]
    selector.infect_person_at_time(worker, 0.0)
    hospitalisation = Hospitalisation()
    policies = Policies([hospitalisation])
    sim.activity_manager.policies = policies
    sim.epidemiology.set_medical_care(
        world=world, activity_manager=sim.activity_manager
    )
    # first admission: regular ward
    worker.infection.symptoms.tag = SymptomTag.hospitalised
    assert worker.infection.should_be_in_hospital
    sim.epidemiology.update_health_status(world, 0.0, 0.0)
    assert worker.medical_facility == hospital[hospital.SubgroupType.patients]
    sim.clear_world()
    # symptoms worsen: patient should end up in the ICU subgroup
    worker.infection.symptoms.tag = SymptomTag.intensive_care
    sim.epidemiology.update_health_status(world, 0.0, 0.0)
    hospital = worker.medical_facility.group
    sim.activity_manager.move_people_to_active_subgroups(
        ["medical_facility", "residence"]
    )
    assert worker.medical_facility == hospital[hospital.SubgroupType.icu_patients]
    sim.clear_world()


def test__move_people_from_icu_to_hospital(setup_policy_world, selector):
    """An ICU patient whose symptoms improve to 'hospitalised' is moved back
    from the ICU subgroup to the regular patients subgroup."""
    world, pupil, student, worker, sim = setup_policy_world
    selector.infect_person_at_time(worker, 0.0)
    hospitalisation = Hospitalisation()
    policies = Policies([hospitalisation])
    sim.activity_manager.policies = policies
    sim.epidemiology.set_medical_care(
        world=world, activity_manager=sim.activity_manager
    )
    # first admission: straight into intensive care
    worker.infection.symptoms.tag = SymptomTag.intensive_care
    assert worker.infection.should_be_in_hospital
    hospital = world.hospitals[0]
    sim.epidemiology.update_health_status(world, 0.0, 0.0)
    assert worker.medical_facility == hospital[hospital.SubgroupType.icu_patients]
    sim.clear_world()
    # symptoms improve: patient moves back to a regular ward bed
    worker.infection.symptoms.tag = SymptomTag.hospitalised
    sim.epidemiology.update_health_status(world, 0.0, 0.0)
    hospital = worker.medical_facility.group
    sim.activity_manager.move_people_to_active_subgroups(
        ["medical_facility", "residence"]
    )
    assert worker.medical_facility == hospital[hospital.SubgroupType.patients]
    sim.clear_world()


from datetime import datetime
from pathlib import Path

import pytest

from june.geography import Geography
from june.interaction import Interaction
from june.epidemiology.infection.infection_selector import InfectionSelector
from june.policy import Policy


path_pwd = Path(__file__)
dir_pwd = path_pwd.parent
constant_config = (
    dir_pwd.parent.parent.parent / "configs/defaults/infection/InfectionXNExp.yaml"
)


@pytest.fixture(name="selector", scope="module")
def create_selector():
    """Module-scoped infection selector with slow recovery and high
    transmission probability."""
    infection_selector = InfectionSelector.from_file(config_filename=constant_config)
    infection_selector.recovery_rate = 0.05
    infection_selector.transmission_probability = 0.7
    return infection_selector


@pytest.fixture(name="interaction", scope="module")
def create_interaction():
    """Default interaction model loaded from the package configuration."""
    return Interaction.from_file()


@pytest.fixture(name="super_area", scope="module")
def create_geography():
    """Single super area (E02002559) built from the geography data files."""
    geography = Geography.from_file(filter_key={"super_area": ["E02002559"]})
    return geography.super_areas.members[0]


class TestPolicy:
    def test__is_active(self):
        """A policy is active from its start date (inclusive) up to its end
        date (exclusive)."""
        policy = Policy(start_time="2020-5-6", end_time="2020-6-6")
        for active_day in (datetime(2020, 5, 6), datetime(2020, 6, 5)):
            assert policy.is_active(active_day)
        assert not policy.is_active(datetime(2020, 6, 6))


from june.geography import Region, Regions
from june.policy import (
    RegionalCompliance,
    RegionalCompliances,
    TieredLockdown,
    TieredLockdowns,
)


class TestSetRegionCompliance:
    def test__set_compliance_to_region(self):
        """Regional compliance is 1.5 while the policy window is active
        (inclusive start, exclusive end) and reverts to 1.0 outside it."""
        regional_compliances = RegionalCompliances(
            [
                RegionalCompliance(
                    start_time="2020-05-01",
                    end_time="2020-09-01",
                    compliances_per_region={"London": 1.5},
                )
            ]
        )
        london = Region(name="London")
        regions = Regions([london])
        for date, expected in [
            ("2020-05-05", 1.5),  # inside the window
            ("2020-05-01", 1.5),  # start date is inclusive
            ("2020-01-05", 1.0),  # before the window
            ("2021-01-05", 1.0),  # after the window
            ("2020-09-01", 1.0),  # end date is exclusive
        ]:
            regional_compliances.apply(regions=regions, date=date)
            assert london.regional_compliance == expected


class TestSetTiers:
    def test__set_lockdowntiers(self):
        """The lockdown tier is set on the region while the policy window is
        active (inclusive start, exclusive end) and cleared to None outside."""
        tiered_lockdowns = TieredLockdowns(
            [
                TieredLockdown(
                    start_time="2020-05-01",
                    end_time="2020-09-01",
                    tiers_per_region={"London": 2.0},
                )
            ]
        )
        london = Region(name="London")
        regions = Regions([london])
        for date, expected_tier in [
            ("2020-05-05", 2),  # inside the window
            ("2020-05-01", 2),  # start date is inclusive
            ("2020-01-05", None),  # before the window
            ("2021-01-05", None),  # after the window
            ("2020-09-01", None),  # end date is exclusive
        ]:
            tiered_lockdowns.apply(regions=regions, date=date)
            if expected_tier is None:
                assert london.policy["lockdown_tier"] is None
            else:
                assert london.policy["lockdown_tier"] == expected_tier


import datetime
import numpy as np
import pytest
from pathlib import Path

from tables import open_file
from june import paths
from june.records import Record
from june.groups import Hospital, Hospitals, Household, Households, CareHome, CareHomes
from june.demography import Person
from june.geography.geography import Areas, SuperAreas, Regions, Area, SuperArea, Region
from june import World

from june.records.records_writer import prepend_checkpoint_hdf5

# interaction config bundled with the package's test configs
config_interaction = paths.configs_path / "tests/interaction.yaml"


@pytest.fixture(name="dummy_world", scope="module")
def create_dummy_world():
    """Minimal world for record I/O tests: 2 regions, 3 super areas, 6 areas,
    1 household, 1 hospital, 1 care home and 3 people."""
    # 2 regions, 1 hospital, 1 care home, 1 household
    regions = Regions([Region(name="region_1"), Region(name="region_2")])
    regions[0].super_areas = [
        SuperArea(name="super_1", coordinates=(0.0, 0.0), region=regions[0]),
        SuperArea(name="super_2", coordinates=(1.0, 1.0), region=regions[0]),
    ]
    regions[1].super_areas = [
        SuperArea(name="super_3", coordinates=(2.0, 2.0), region=regions[1])
    ]
    super_areas = SuperAreas(regions[0].super_areas + regions[1].super_areas)
    super_areas[0].areas = [
        Area(name="area_1", coordinates=(0.0, 0.0), super_area=super_areas[0]),
        Area(name="area_2", coordinates=(0.0, 0.0), super_area=super_areas[0]),
        Area(name="area_3", coordinates=(0.0, 0.0), super_area=super_areas[0]),
    ]
    super_areas[1].areas = [
        Area(name="area_4", coordinates=(0.0, 0.0), super_area=super_areas[1]),
        Area(name="area_5", coordinates=(0.0, 0.0), super_area=super_areas[1]),
    ]
    super_areas[2].areas = [
        Area(name="area_6", coordinates=(5, 5), super_area=super_areas[2])
    ]
    areas = Areas(super_areas[0].areas + super_areas[1].areas + super_areas[2].areas)
    households = Households([Household(area=super_areas[0].areas[0])])
    # the hospital lives in area_6, i.e. in region_2
    hospitals = Hospitals(
        [Hospital(n_beds=1, n_icu_beds=1, area=areas[5], coordinates=(0.0, 0.0))]
    )
    care_homes = CareHomes([CareHome(area=super_areas[0].areas[0])])
    world = World()
    world.areas = areas
    world.super_areas = super_areas
    world.regions = regions
    world.households = households
    world.hospitals = hospitals
    world.care_homes = care_homes
    world.people = [
        Person.from_attributes(id=0, age=0, ethnicity="A"),
        Person.from_attributes(id=1, age=1, ethnicity="B"),
        Person.from_attributes(id=2, age=2, sex="m", ethnicity="C"),
    ]
    world.people[0].area = super_areas[0].areas[0]  # household resident
    world.people[0].subgroups.primary_activity = hospitals[0].subgroups[0]
    world.people[0].subgroups.residence = households[0].subgroups[0]
    world.people[1].area = super_areas[0].areas[0]
    world.people[1].subgroups.residence = households[0].subgroups[0]
    world.people[2].area = super_areas[0].areas[1]  # care home resident
    world.people[2].subgroups.residence = care_homes[0].subgroups[0]
    return world


def test__prepend_checkpoint_hdf5(dummy_world):
    """Merging a pre-checkpoint record with a post-checkpoint record keeps
    pre-checkpoint rows strictly before the checkpoint date and
    post-checkpoint rows on/after it.

    Pre-checkpoint rows use even person ids and post-checkpoint rows odd ids,
    so the provenance of every merged row can be verified via id parity.
    """

    pre_checkpoint_record_path = Path("./pre_checkpoint_results/june_record.h5")
    pre_checkpoint_record = Record(
        record_path="pre_checkpoint_results", record_static_data=True
    )
    pre_checkpoint_record.static_data(dummy_world)
    # days 1-14: three infections and three deaths per day, all with EVEN ids
    for i in range(1, 15):
        timestamp = datetime.datetime(2020, 3, i)
        # everyone from the second record should have an EVEN id.
        infected_ids = [i * 1000 + 500 + 0 + 2 * x for x in range(3)]
        infector_ids = [i * 1000 + 500 + 10 + 2 * x for x in range(3)]
        dead_ids = [i * 1000 + 500 + 20 + 2 * x for x in range(3)]
        infection_ids = [i * 1000 + 500 + 20 + 2 * x for x in range(3)]
        with open_file(pre_checkpoint_record_path, mode="a") as f:
            pre_checkpoint_record.file = f
            pre_checkpoint_record.accumulate(
                table_name="infections",
                location_spec="pre_check_location",
                region_name="over_here",
                location_id=0,
                infected_ids=infected_ids,
                infector_ids=infector_ids,
                infection_ids=infection_ids,
            )
            for dead_id in dead_ids:
                pre_checkpoint_record.accumulate(
                    table_name="deaths",
                    location_id=0,
                    location_spec="pre_check_location",
                    dead_person_id=dead_id,
                )
        pre_checkpoint_record.time_step(timestamp)

    # days 11-20: overlapping date range, ODD ids
    post_checkpoint_record_path = Path("./post_checkpoint_results/june_record.h5")
    post_checkpoint_record = Record(
        record_path="post_checkpoint_results", record_static_data=True
    )
    post_checkpoint_record.static_data(dummy_world)
    for i in range(11, 21):
        timestamp = datetime.datetime(2020, 3, i)
        # everyone from the second record should have an ODD id.
        infected_ids = [i * 1000 + 500 + 0 + 2 * x + 1 for x in range(3)]
        infector_ids = [i * 1000 + 500 + 10 + 2 * x + 1 for x in range(3)]
        dead_ids = [i * 1000 + 500 + 20 + 2 * x + 1 for x in range(3)]
        infection_ids = [i * 1000 + 500 + 20 + 2 * x + 1 for x in range(3)]
        with open_file(post_checkpoint_record_path, mode="a") as f:
            post_checkpoint_record.file = f
            post_checkpoint_record.accumulate(
                table_name="infections",
                location_spec="post_check_location",
                region_name="way_over_there",
                location_id=0,
                infected_ids=infected_ids,
                infector_ids=infector_ids,
                infection_ids=infection_ids,
            )
            for dead_id in dead_ids:
                post_checkpoint_record.accumulate(
                    table_name="deaths",
                    location_id=0,
                    # NOTE(review): this says "pre_check_location" inside the
                    # post-checkpoint record — looks like a copy-paste; the
                    # merge assertions below don't check it, so it is inert.
                    location_spec="pre_check_location",
                    dead_person_id=dead_id,
                )
        post_checkpoint_record.time_step(timestamp)

    merged_record_path = Path("./post_checkpoint_results/merged_checkpoint_record.h5")
    prepend_checkpoint_hdf5(
        pre_checkpoint_record_path,
        post_checkpoint_record_path,
        merged_record_path=merged_record_path,
        checkpoint_date=datetime.datetime(2020, 3, 11),
    )

    with open_file(merged_record_path) as merged_record:
        unique_infection_dates = np.unique(
            [
                datetime.datetime.strptime(x.decode("utf-8"), "%Y-%m-%d")
                for x in merged_record.root.infections[:]["timestamp"]
            ]
        )

        # 10 pre-checkpoint days (1-10) + 10 post-checkpoint days (11-20)
        assert len(unique_infection_dates) == 20
        assert len(merged_record.root.infections[:]) == 3 * 20

        # rows before the checkpoint must come from the pre record (even ids),
        # rows on/after it from the post record (odd ids)
        for row in merged_record.root.infections[:]:
            timestamp = datetime.datetime.strptime(
                row["timestamp"].decode("utf-8"), "%Y-%m-%d"
            )
            if timestamp < datetime.datetime(2020, 3, 11):
                assert row["infected_ids"] % 2 == 0
                assert row["infector_ids"] % 2 == 0
            else:
                assert row["infected_ids"] % 2 == 1
                assert row["infector_ids"] % 2 == 1

        for row in merged_record.root.deaths[:]:
            timestamp = datetime.datetime.strptime(
                row["timestamp"].decode("utf-8"), "%Y-%m-%d"
            )
            if timestamp < datetime.datetime(2020, 3, 11):
                assert row["dead_person_ids"] % 2 == 0
            else:
                assert row["dead_person_ids"] % 2 == 1


from pathlib import Path
from collections import defaultdict

import random
import tables
import numpy as np
import pytest
import pandas as pd

from june import paths
from june.demography import Person, Population, Activities
from june.geography import Geography
from june.groups import Hospital, School, Company, Household, University
from june.groups import (
    Hospitals,
    Schools,
    Companies,
    Households,
    Universities,
    Cemeteries,
)
from june.epidemiology.infection import (
    SymptomTag,
    Immunity,
    InfectionSelectors,
    InfectionSelector,
)
from june.interaction import Interaction
from june.epidemiology.epidemiology import Epidemiology
from june.epidemiology.infection_seed import InfectionSeed
from june.policy import Policies, Hospitalisation
from june.simulator import Simulator
from june.world import World
from june.records import Record

path_pwd = Path(__file__)
dir_pwd = path_pwd.parent  # directory containing this test module

# simulator (no leisure) and interaction configs bundled with the package
test_config = paths.configs_path / "tests/test_simulator_no_leisure.yaml"
interaction_config = paths.configs_path / "tests/interaction.yaml"


def clean_world(world):
    """Reset everyone's infection/health state and empty all hospital wards."""
    for individual in world.people:
        individual.dead = False
        individual.infection = None
        individual.immunity = Immunity()
        individual.subgroups.medical_facility = None
    for clinic in world.hospitals:
        clinic.ward_ids = set()
        clinic.icu_ids = set()


class MockHealthIndexGenerator:
    """Stand-in health index generator that forces a chosen outcome.

    Calling it returns an array of 8 floats where the first
    ``desired_symptoms`` entries are 0 and the rest are 1, regardless of the
    person or infection id passed in.
    """

    def __init__(self, desired_symptoms):
        # cut-off below which entries are zeroed; SymptomTag values support
        # integer indexing here
        self.index = desired_symptoms

    def __call__(self, person, infection_id):
        # vectorised slice assignment replaces the original element-wise loop
        hi = np.ones(8)
        hi[: self.index] = 0
        return hi


def make_selector(
    desired_symptoms,
):
    """Build an InfectionSelector whose health index always yields
    `desired_symptoms`."""
    return InfectionSelector(
        health_index_generator=MockHealthIndexGenerator(desired_symptoms)
    )


def infect_hospitalised_person(person):
    """Infect `person` so their worst symptom stage is a hospital one
    (ward or ICU, chosen at random)."""
    tag = random.choice([SymptomTag.hospitalised, SymptomTag.intensive_care])
    make_selector(desired_symptoms=tag).infect_person_at_time(person, 0.0)


def infect_dead_person(person):
    """Infect `person` so their worst symptom stage is fatal (place of death
    chosen at random)."""
    tag = random.choice(
        [SymptomTag.dead_home, SymptomTag.dead_hospital, SymptomTag.dead_icu]
    )
    make_selector(desired_symptoms=tag).infect_person_at_time(person, 0.0)


@pytest.fixture(name="selector", scope="module")
def create_selector(health_index_generator):
    """Selector using the XNExp transmission config with certain recovery
    and transmission."""
    infection_selector = InfectionSelector(
        paths.configs_path / "defaults/epidemiology/infection/transmission/XNExp.yaml",
        health_index_generator=health_index_generator,
    )
    infection_selector.recovery_rate = 1.0
    infection_selector.transmission_probability = 1.0
    return infection_selector


@pytest.fixture(name="interaction", scope="module")
def create_interaction():
    """Interaction tuned so contagion is driven by households and schools
    while pubs/cinemas are switched off."""
    interaction = Interaction.from_file(config_filename=interaction_config)
    overrides = {"school": 0.8, "cinema": 0.0, "pub": 0.0, "household": 10.0}
    for group_spec, beta in overrides.items():
        interaction.betas[group_spec] = beta
    interaction.alpha_physical = 2.7
    return interaction


@pytest.fixture(name="geog", scope="module")
def create_geography():
    """Geography restricted to the single output area E00000001."""
    return Geography.from_file(filter_key={"area": ["E00000001"]})


@pytest.fixture(name="world", scope="module")
def make_dummy_world(geog):
    """World for logging tests: one household with 5 people (two Q-sector
    workers, one university student, two children), one company, one
    hospital, one empty school and a university."""
    super_area = geog.super_areas.members[0]
    company = Company(super_area=super_area, n_workers_max=100, sector="Q")

    household1 = Household()
    household1.area = super_area.areas[0]
    hospital = Hospital(
        n_beds=40,
        n_icu_beds=5,
        area=geog.areas.members[0],
        coordinates=super_area.coordinates,
    )
    uni = University(coordinates=super_area.coordinates, n_students_max=2500)

    worker1 = Person.from_attributes(age=44, sex="f", ethnicity="A1")
    worker1.area = super_area.areas[0]
    household1.add(worker1, subgroup_type=household1.SubgroupType.adults)
    worker1.sector = "Q"
    company.add(worker1)

    worker2 = Person.from_attributes(age=42, sex="m", ethnicity="B1")
    worker2.area = super_area.areas[0]
    household1.add(worker2, subgroup_type=household1.SubgroupType.adults)
    worker2.sector = "Q"
    company.add(worker2)

    student1 = Person.from_attributes(age=20, sex="f", ethnicity="A1")
    student1.area = super_area.areas[0]
    household1.add(student1, subgroup_type=household1.SubgroupType.adults)
    uni.add(student1)

    # the pupils are NOT enrolled in the school (enrolment left commented out),
    # so the school stays empty in these tests
    pupil1 = Person.from_attributes(age=8, sex="m", ethnicity="C1")
    pupil1.area = super_area.areas[0]
    household1.add(pupil1, subgroup_type=household1.SubgroupType.kids)
    # school.add(pupil1)

    pupil2 = Person.from_attributes(age=5, sex="f", ethnicity="A1")
    pupil2.area = super_area.areas[0]
    household1.add(pupil2, subgroup_type=household1.SubgroupType.kids)
    # school.add(pupil2)

    world = World()
    world.schools = Schools([School()])
    world.hospitals = Hospitals([hospital])
    world.households = Households([household1])
    world.universities = Universities([uni])
    world.companies = Companies([company])
    world.people = Population([worker1, worker2, student1, pupil1, pupil2])
    world.regions = geog.regions
    world.super_areas = geog.super_areas
    world.areas = geog.areas
    world.cemeteries = Cemeteries()
    world.areas[0].people = world.people
    world.super_areas[0].closest_hospitals = [world.hospitals[0]]
    return world


def create_sim(world, interaction, selector, seed=False):
    """Build a Simulator over `world` with a permanent Hospitalisation policy.

    `seed` selects how infections are introduced:
      - False (default): seed ~2 uniform cases via InfectionSeed on 2020-03-01;
      - "hospitalised": infect everyone with hospital-level symptoms;
      - any other truthy value: infect everyone with fatal symptoms.
    Results are written to a Record under "results".
    """

    record = Record(record_path="results")
    policies = Policies(
        [Hospitalisation(start_time="1000-01-01", end_time="9999-01-01")]
    )
    infection_seed = InfectionSeed.from_uniform_cases(
        world=world,
        infection_selector=selector,
        cases_per_capita=2 / len(world.people),
        date="2020-03-01",
        seed_past_infections=False,
    )
    if not seed:
        infection_seed.unleash_virus_per_day(
            time=0.0, date=pd.to_datetime("2020-03-01"), record=record
        )
    elif seed == "hospitalised":
        for person in world.people:
            infect_hospitalised_person(person)
    else:
        for person in world.people:
            infect_dead_person(person)

    selectors = InfectionSelectors([selector])
    epidemiology = Epidemiology(infection_selectors=selectors)
    sim = Simulator.from_file(
        world=world,
        interaction=interaction,
        epidemiology=epidemiology,
        config_filename=test_config,
        policies=policies,
        record=record,
    )
    return sim


def test__log_infected(world, interaction, selector):
    """Seeded and subsequent infections are written to the infections table
    with the expected ids, locations and timestamps.

    Fix in review: the final two location checks were bare comparison
    expressions (no-ops); they are now real assertions.
    """
    clean_world(world)
    sim = create_sim(world, interaction, selector)
    infections_seed = [person.id for person in world.people.infected]
    sim.timer.reset()
    counter = 0
    new_infected = {}
    already_infected = [person.id for person in world.people.infected]
    while counter < 10:
        time = sim.timer.date.strftime("%Y-%m-%d")
        sim.do_timestep()
        # ids of people who became infected during this timestep only
        current_infected = [
            person.id
            for person in world.people.infected
            if person.id not in already_infected
        ]
        new_infected[time] = current_infected
        next(sim.timer)
        counter += 1
        already_infected += current_infected

    with tables.open_file(sim.record.record_path / sim.record.filename, mode="r") as f:
        table = f.root.infections
        df = pd.DataFrame.from_records(table.read())
    df["timestamp"] = df["timestamp"].str.decode("utf-8")
    df["location_specs"] = df["location_specs"].str.decode("utf-8")
    df.set_index("timestamp", inplace=True)
    # seeded infections are recorded as self-infections at the seed location
    assert set(df.loc["2020-03-01"]["infector_ids"].values) == set(infections_seed)
    assert set(df.loc["2020-03-01"]["infected_ids"].values) == set(infections_seed)
    assert set(df.loc["2020-03-01"]["location_ids"].values) == {0}
    assert set(df.loc["2020-03-01"]["location_specs"].values) == {"infection_seed"}
    for timestamp in list(new_infected.keys())[1:]:
        if new_infected[timestamp]:
            # a single-row .loc lookup yields scalars rather than Series
            if type(df.loc[timestamp]["infected_ids"]) is np.int32:
                assert df.loc[timestamp]["infected_ids"] == new_infected[timestamp]
                assert df.loc[timestamp]["infector_ids"] != new_infected[timestamp]
            else:
                assert set(df.loc[timestamp]["infected_ids"].values) == set(
                    new_infected[timestamp]
                )
    # all non-seed infections happen in the single household (household beta
    # dominates; pub/cinema betas are 0 and the school is empty)
    assert (df.iloc[2:]["location_ids"].values == world.households[0].id).all()
    assert (df.iloc[2:]["location_specs"].values == "household").all()


def test__log_hospital_admissions(world, interaction, selector):
    """Hospital admissions and discharges written to the HDF5 record must
    match the admissions/discharges observed directly on the world.
    """
    clean_world(world)
    sim = create_sim(world, interaction, selector, seed="hospitalised")
    sim.timer.reset()
    counter = 0
    saved_ids, discharged_ids = [], []
    hospital_admissions, hospital_discharges = {}, {}
    while counter < 50:
        timer = sim.timer.date.strftime("%Y-%m-%d")
        daily_hosps_ids, daily_discharges_ids = [], []
        sim.epidemiology.update_health_status(
            sim.world, sim.timer.now, sim.timer.duration, record=sim.record
        )
        # Newly admitted: infected people now inside a medical facility.
        for person in world.people.infected:
            if person.medical_facility is not None and person.id not in saved_ids:
                daily_hosps_ids.append(person.id)
                saved_ids.append(person.id)
        # Newly discharged: previously admitted people no longer in a facility.
        for person in world.people:
            if (
                person.medical_facility is None
                and person.id in saved_ids
                and person.id not in discharged_ids
            ):
                daily_discharges_ids.append(person.id)
                discharged_ids.append(person.id)
        hospital_admissions[timer] = daily_hosps_ids
        hospital_discharges[timer] = daily_discharges_ids
        sim.record.time_step(timestamp=sim.timer.date)
        next(sim.timer)
        counter += 1
    with tables.open_file(sim.record.record_path / sim.record.filename, mode="r") as f:
        table = f.root.hospital_admissions
        admissions_df = pd.DataFrame.from_records(table.read())
        table = f.root.discharges
        discharges_df = pd.DataFrame.from_records(table.read())
    admissions_df["timestamp"] = admissions_df["timestamp"].str.decode("utf-8")
    admissions_df.set_index("timestamp", inplace=True)
    discharges_df["timestamp"] = discharges_df["timestamp"].str.decode("utf-8")
    discharges_df.set_index("timestamp", inplace=True)

    for timestamp in hospital_admissions.keys():
        if hospital_admissions[timestamp]:
            # .loc returns a scalar when only one row matched that day.
            if type(admissions_df.loc[timestamp]["patient_ids"]) is np.int32:
                assert (
                    admissions_df.loc[timestamp]["patient_ids"]
                    == hospital_admissions[timestamp]
                )
            else:
                assert set(admissions_df.loc[timestamp]["patient_ids"].values) == set(
                    hospital_admissions[timestamp]
                )
        if hospital_discharges[timestamp]:
            if type(discharges_df.loc[timestamp]["patient_ids"]) is np.int32:
                assert (
                    discharges_df.loc[timestamp]["patient_ids"]
                    == hospital_discharges[timestamp]
                )
            else:
                # Bug fix: compare recorded discharges against the observed
                # discharges, not the admissions (was a copy-paste error).
                assert set(discharges_df.loc[timestamp]["patient_ids"].values) == set(
                    hospital_discharges[timestamp]
                )
    clean_world(world)


def test__log_icu_admissions(world, interaction, selector):
    """ICU admissions written to the record must match those seen on the world."""
    clean_world(world)
    sim = create_sim(world, interaction, selector, seed="hospitalised")
    sim.timer.reset()
    seen = []
    icu_admissions = {}
    for _ in range(50):
        date_key = sim.timer.date.strftime("%Y-%m-%d")
        sim.epidemiology.update_health_status(
            sim.world, sim.timer.now, sim.timer.duration, record=sim.record
        )
        todays_icu = [
            person.id
            for person in world.people.infected
            if person.infection.symptoms.tag == SymptomTag.intensive_care
            and person.id not in seen
        ]
        seen.extend(todays_icu)
        icu_admissions[date_key] = todays_icu
        sim.record.time_step(timestamp=sim.timer.date)
        next(sim.timer)
    with tables.open_file(sim.record.record_path / sim.record.filename, mode="r") as f:
        admissions_df = pd.DataFrame.from_records(f.root.icu_admissions.read())
    admissions_df["timestamp"] = admissions_df["timestamp"].str.decode("utf-8")
    admissions_df.set_index("timestamp", inplace=True)
    for date_key, expected in icu_admissions.items():
        if not expected:
            continue
        recorded = admissions_df.loc[date_key]["patient_ids"]
        # A single-row day yields a scalar rather than a Series.
        if type(recorded) is np.int32:
            assert recorded == expected
        else:
            assert set(recorded.values) == set(expected)
    clean_world(world)


def test__symptoms_transition(world, interaction, selector):
    """Symptom transitions written to the record must match those observed."""
    sim = create_sim(world, interaction, selector, seed="dead")
    sim.timer.reset()
    ids_transition, symptoms_transition = {}, {}
    last_tag = defaultdict(int)
    for _ in range(20):
        date_key = sim.timer.date.strftime("%Y-%m-%d")
        changed_ids, changed_tags = [], []
        sim.epidemiology.update_health_status(
            sim.world, sim.timer.now, sim.timer.duration, record=sim.record
        )
        for person in world.people.infected:
            tag = person.infection.symptoms.tag.value
            if tag != last_tag[person.id]:
                changed_ids.append(person.id)
                changed_tags.append(tag)
            last_tag[person.id] = tag
        ids_transition[date_key] = changed_ids
        symptoms_transition[date_key] = changed_tags
        sim.record.time_step(timestamp=sim.timer.date)
        next(sim.timer)
    with tables.open_file(sim.record.record_path / sim.record.filename, mode="r") as f:
        df = pd.DataFrame.from_records(f.root.symptoms.read())
    df["timestamp"] = df["timestamp"].str.decode("utf-8")
    df.set_index("timestamp", inplace=True)
    # Exclude rows with symptom tag values 5, 6 and 7 before comparing.
    df = df.loc[~df.new_symptoms.isin([5, 6, 7])]
    for date_key in list(ids_transition)[1:]:
        if not ids_transition[date_key]:
            continue
        recorded_ids = df.loc[date_key]["infected_ids"]
        if type(recorded_ids) is np.int32:
            assert recorded_ids == ids_transition[date_key]
            assert df.loc[date_key]["new_symptoms"] == symptoms_transition[date_key]
        else:
            assert set(recorded_ids.values) == set(ids_transition[date_key])
            assert set(df.loc[date_key]["new_symptoms"].values) == set(
                symptoms_transition[date_key]
            )

    clean_world(world)


def test__log_deaths(world, interaction, selector):
    """Deaths written to the record must match the deaths observed on the world."""
    # Give everybody the household as residence so deaths happen in a known place.
    for person in world.people:
        person.subgroups = Activities(
            world.households[0].subgroups[0], None, None, None, None, None
        )
    sim = create_sim(world, interaction, selector, seed="dead")
    sim.timer.reset()
    recorded_dead = []
    deaths = {}
    for _ in range(50):
        date_key = sim.timer.date.strftime("%Y-%m-%d")
        sim.epidemiology.update_health_status(
            sim.world, sim.timer.now, sim.timer.duration, record=sim.record
        )
        todays_deaths = [
            person.id
            for person in world.people
            if person.dead and person.id not in recorded_dead
        ]
        recorded_dead.extend(todays_deaths)
        deaths[date_key] = todays_deaths
        sim.record.time_step(timestamp=sim.timer.date)
        next(sim.timer)
    with tables.open_file(sim.record.record_path / sim.record.filename, mode="r") as f:
        df = pd.DataFrame.from_records(f.root.deaths.read())
    df["timestamp"] = df["timestamp"].str.decode("utf-8")
    df.set_index("timestamp", inplace=True)
    for date_key, expected in deaths.items():
        if not expected:
            continue
        recorded = df.loc[date_key]["dead_person_ids"]
        if type(recorded) is np.int32:
            assert recorded == expected
        else:
            assert set(recorded.values) == set(expected)
    clean_world(world)


from pathlib import Path

import random
import numpy as np
import pytest
import pandas as pd
from june import paths
from june.demography import Person, Population
from june.geography import Geography
from june.groups import Hospital, School, Company, Household, University
from june.groups import (
    Hospitals,
    Schools,
    Companies,
    Households,
    Universities,
    Cemeteries,
)
from june.interaction import Interaction
from june.epidemiology.epidemiology import Epidemiology
from june.epidemiology.infection import (
    InfectionSelector,
    InfectionSelectors,
    SymptomTag,
    Immunity,
)
from june.epidemiology.infection_seed import InfectionSeed
from june.policy import Policies, Hospitalisation
from june.simulator import Simulator
from june.world import World
from june.records import Record, RecordReader

# Location of this test module and the simulator/interaction configs used below.
path_pwd = Path(__file__)
dir_pwd = path_pwd.parent
test_config = paths.configs_path / "tests/test_simulator_no_leisure.yaml"
interaction_config = paths.configs_path / "tests/interaction.yaml"


def clean_world(world):
    """Strip all infections and reset immunity for every person in *world*."""
    for inhabitant in world.people:
        inhabitant.infection = None
        inhabitant.immunity = Immunity()


class MockHealthIndexGenerator:
    """Health-index stub that forces a chosen symptom severity.

    Calling an instance returns an 8-element array that is 0 for every
    position below ``desired_symptoms`` and 1 elsewhere.
    """

    def __init__(self, desired_symptoms):
        # Severity threshold; positions strictly below it get value 0.
        self.index = desired_symptoms

    def __call__(self, person, infection_id):
        # Vectorized form of the original element-wise loop: every
        # position h with h < self.index is zeroed.
        hi = np.ones(8)
        hi[np.arange(hi.size) < self.index] = 0
        return hi


def make_selector(desired_symptoms):
    """Return an InfectionSelector driven by a mocked health index."""
    generator = MockHealthIndexGenerator(desired_symptoms)
    return InfectionSelector(health_index_generator=generator)


def infect_hospitalised_person(person):
    """Infect *person* so they end up hospitalised or in intensive care."""
    target = random.choice([SymptomTag.hospitalised, SymptomTag.intensive_care])
    make_selector(desired_symptoms=target).infect_person_at_time(person, 0.0)


def infect_dead_person(person):
    """Infect *person* with a fatal outcome (home, hospital or ICU death)."""
    target = random.choice(
        [SymptomTag.dead_home, SymptomTag.dead_hospital, SymptomTag.dead_icu]
    )
    make_selector(desired_symptoms=target).infect_person_at_time(person, 0.0)


@pytest.fixture(name="selector", scope="module")
def create_selector(health_index_generator):
    """Module-scoped InfectionSelector with deterministic recovery/transmission."""
    infection_selector = InfectionSelector(
        paths.configs_path / "defaults/epidemiology/infection/transmission/XNExp.yaml",
        health_index_generator=health_index_generator,
    )
    infection_selector.recovery_rate = 1.0
    infection_selector.transmission_probability = 1.0
    return infection_selector


@pytest.fixture(name="interaction", scope="module")
def create_interaction():
    """Module-scoped Interaction with betas tuned for these tests."""
    interaction = Interaction.from_file(config_filename=interaction_config)
    # Only household (and school) mixing matter here; leisure venues are muted.
    interaction.betas.update(
        {"school": 0.8, "cinema": 0.0, "pub": 0.0, "household": 10.0}
    )
    interaction.alpha_physical = 2.7
    return interaction


@pytest.fixture(name="geog", scope="module")
def create_geography():
    """Single-area geography shared by the world fixture."""
    return Geography.from_file(filter_key={"area": ["E00000001"]})


@pytest.fixture(name="world", scope="module")
def make_dummy_world(geog):
    """Build a minimal world: one household with five residents, plus a
    company, a hospital, a university and an (empty) school in one super area.
    """
    super_area = geog.super_areas.members[0]
    company = Company(super_area=super_area, n_workers_max=100, sector="Q")

    household1 = Household()
    household1.area = super_area.areas[0]
    hospital = Hospital(
        n_beds=40,
        n_icu_beds=5,
        area=geog.areas.members[0],
        coordinates=super_area.coordinates,
    )
    uni = University(coordinates=super_area.coordinates, n_students_max=2500)

    # Two adult workers in sector "Q", both living in the household.
    worker1 = Person.from_attributes(age=44, sex="f", ethnicity="A1")
    worker1.area = super_area.areas[0]
    household1.add(worker1, subgroup_type=household1.SubgroupType.adults)
    worker1.sector = "Q"
    company.add(worker1)

    worker2 = Person.from_attributes(age=42, sex="m", ethnicity="B1")
    worker2.area = super_area.areas[0]
    household1.add(worker2, subgroup_type=household1.SubgroupType.adults)
    worker2.sector = "Q"
    company.add(worker2)

    # One university student, counted as an adult in the household.
    student1 = Person.from_attributes(age=20, sex="f", ethnicity="A1")
    student1.area = super_area.areas[0]
    household1.add(student1, subgroup_type=household1.SubgroupType.adults)
    uni.add(student1)

    # Two children; deliberately not enrolled in the school.
    pupil1 = Person.from_attributes(age=8, sex="m", ethnicity="C1")
    pupil1.area = super_area.areas[0]
    household1.add(pupil1, subgroup_type=household1.SubgroupType.kids)
    # school.add(pupil1)

    pupil2 = Person.from_attributes(age=5, sex="f", ethnicity="A1")
    pupil2.area = super_area.areas[0]
    household1.add(pupil2, subgroup_type=household1.SubgroupType.kids)
    # school.add(pupil2)

    world = World()
    world.schools = Schools([School()])
    world.hospitals = Hospitals([hospital])
    world.households = Households([household1])
    world.universities = Universities([uni])
    world.companies = Companies([company])
    world.people = Population([worker1, worker2, student1, pupil1, pupil2])
    world.regions = geog.regions
    world.super_areas = geog.super_areas
    world.areas = geog.areas
    world.cemeteries = Cemeteries()
    world.areas[0].people = world.people
    world.super_areas[0].closest_hospitals = [world.hospitals[0]]
    return world


def create_sim(world, interaction, selector, seed=False):
    """Build a Simulator over *world* with an initial infection state.

    seed=False releases the uniform infection seed (about 2 cases given
    ``cases_per_capita=2/len(world.people)``) on 2020-03-01;
    seed="hospitalised" infects every person so they require hospital care;
    any other truthy value infects every person fatally.
    """
    record = Record(record_path="results")
    policies = Policies(
        [Hospitalisation(start_time="1000-01-01", end_time="9999-01-01")]
    )
    infection_seed = InfectionSeed.from_uniform_cases(
        world=world,
        infection_selector=selector,
        cases_per_capita=2 / len(world.people),
        date="2020-03-01",
        seed_past_infections=False,
    )
    if not seed:
        # Seeded infections are also written to the record.
        infection_seed.unleash_virus_per_day(
            date=pd.to_datetime("2020-03-01"), time=0, record=record
        )
    elif seed == "hospitalised":
        for person in world.people:
            infect_hospitalised_person(person)
    else:
        for person in world.people:
            infect_dead_person(person)

    selectors = InfectionSelectors([selector])
    epidemiology = Epidemiology(infection_selectors=selectors)
    sim = Simulator.from_file(
        world=world,
        interaction=interaction,
        epidemiology=epidemiology,
        config_filename=test_config,
        policies=policies,
        record=record,
    )
    return sim


def test__log_infected_by_region(world, interaction, selector):
    """Regional summary daily_infected must match the observed new infections."""
    clean_world(world)
    sim = create_sim(world, interaction, selector)
    sim.timer.reset()
    new_infected = {}
    already_infected = [person.id for person in world.people.infected]
    for _ in range(10):
        date_key = sim.timer.date.strftime("%Y-%m-%d")
        sim.do_timestep()
        todays_infected = [
            person.id
            for person in world.people.infected
            if person.id not in already_infected
        ]
        new_infected[date_key] = todays_infected
        next(sim.timer)
        already_infected += todays_infected
    read = RecordReader(results_path=sim.record.record_path)
    assert read.regional_summary.iloc[0]["daily_infected"] == 2  # seed
    for date_key, infected in list(new_infected.items())[1:]:
        if infected:
            assert read.regional_summary.loc[date_key, "daily_infected"] == len(
                infected
            )


def test__log_hospital_admissions(world, interaction, selector):
    """Summaries must report the correct number of daily hospital admissions."""
    clean_world(world)
    sim = create_sim(world, interaction, selector, seed="hospitalised")
    sim.timer.reset()
    admitted = []
    hospital_admissions = {}
    for _ in range(15):
        date_key = sim.timer.date.strftime("%Y-%m-%d")
        sim.epidemiology.update_health_status(
            sim.world, sim.timer.now, sim.timer.duration, record=sim.record
        )
        todays_admissions = [
            person.id
            for person in world.people.infected
            if person.medical_facility is not None and person.id not in admitted
        ]
        admitted.extend(todays_admissions)
        hospital_admissions[date_key] = todays_admissions
        sim.record.summarise_time_step(timestamp=sim.timer.date, world=sim.world)
        sim.record.time_step(timestamp=sim.timer.date)
        next(sim.timer)
    read = RecordReader(results_path=sim.record.record_path)
    for date_key, admissions in hospital_admissions.items():
        if admissions:
            n_admitted = len(admissions)
            assert (
                read.regional_summary.loc[date_key, "daily_hospitalised"] == n_admitted
            )
            assert read.world_summary.loc[date_key, "daily_hospitalised"] == n_admitted
    clean_world(world)


def test__log_deaths(world, interaction, selector):
    """Regional summary must report the correct number of daily deaths."""
    sim = create_sim(world, interaction, selector, seed="dead")
    sim.timer.reset()
    recorded_dead = []
    deaths = {}
    for _ in range(50):
        date_key = sim.timer.date.strftime("%Y-%m-%d")
        sim.epidemiology.update_health_status(
            sim.world, sim.timer.now, sim.timer.duration, record=sim.record
        )
        todays_deaths = [
            person.id
            for person in world.people
            if person.dead and person.id not in recorded_dead
        ]
        recorded_dead.extend(todays_deaths)
        deaths[date_key] = todays_deaths
        sim.record.summarise_time_step(timestamp=sim.timer.date, world=sim.world)
        sim.record.time_step(timestamp=sim.timer.date)
        next(sim.timer)
    read = RecordReader(results_path=sim.record.record_path)
    for date_key, dead in deaths.items():
        if dead:
            assert read.regional_summary.loc[date_key, "daily_deaths"] == len(dead)
    clean_world(world)


import datetime
import numpy as np
import pandas as pd
import yaml
import pytest

from tables import open_file
from june import paths
from june.records import Record
from june.groups import Hospital, Hospitals, Household, Households, CareHome, CareHomes
from june.policy import Policies
from june.time import Timer
from june.activity import ActivityManager
from june.demography import Person, Population
from june.interaction import Interaction
from june.epidemiology.epidemiology import Epidemiology
from june.epidemiology.infection_seed import InfectionSeed, InfectionSeeds
from june.geography.geography import Areas, SuperAreas, Regions, Area, SuperArea, Region
from june.groups import Supergroup
from june import World

# Interaction config shared by the record-writing tests below.
config_interaction = paths.configs_path / "tests/interaction.yaml"


@pytest.fixture(name="dummy_world", scope="module")
def create_dummy_world():
    """Build a small world for the record tests: 2 regions, 3 super areas,
    6 areas, 1 household, 1 hospital, 1 care home and 3 people.
    """
    # 2 regions, 2 hospitals, 1 care home 1 household
    regions = Regions([Region(name="region_1"), Region(name="region_2")])
    regions[0].super_areas = [
        SuperArea(name="super_1", coordinates=(0.0, 0.0), region=regions[0]),
        SuperArea(name="super_2", coordinates=(1.0, 1.0), region=regions[0]),
    ]
    regions[1].super_areas = [
        SuperArea(name="super_3", coordinates=(2.0, 2.0), region=regions[1])
    ]
    super_areas = SuperAreas(regions[0].super_areas + regions[1].super_areas)
    super_areas[0].areas = [
        Area(
            name="area_1",
            coordinates=(0.0, 0.0),
            super_area=super_areas[0],
            socioeconomic_index=0.01,
        ),
        Area(
            name="area_2",
            coordinates=(0.0, 0.0),
            super_area=super_areas[0],
            socioeconomic_index=0.02,
        ),
        Area(
            name="area_3",
            coordinates=(0.0, 0.0),
            super_area=super_areas[0],
            socioeconomic_index=0.03,
        ),
    ]
    super_areas[1].areas = [
        Area(
            name="area_4",
            coordinates=(0.0, 0.0),
            super_area=super_areas[1],
            socioeconomic_index=0.11,
        ),
        Area(
            name="area_5",
            coordinates=(0.0, 0.0),
            super_area=super_areas[1],
            socioeconomic_index=0.12,
        ),
    ]
    super_areas[2].areas = [
        Area(
            name="area_6",
            coordinates=(5, 5),
            super_area=super_areas[2],
            socioeconomic_index=0.90,
        )
    ]
    areas = Areas(super_areas[0].areas + super_areas[1].areas + super_areas[2].areas)
    households = Households([Household(area=super_areas[0].areas[0])])
    hospitals = Hospitals(
        [Hospital(n_beds=1, n_icu_beds=1, area=areas[5], coordinates=(0.0, 0.0))]
    )
    care_homes = CareHomes([CareHome(area=super_areas[0].areas[0])])
    world = World()
    world.areas = areas
    world.super_areas = super_areas
    world.regions = regions
    world.households = households
    world.hospitals = hospitals
    world.care_homes = care_homes
    world.people = [
        Person.from_attributes(id=0, age=0, ethnicity="A"),
        Person.from_attributes(id=1, age=1, ethnicity="B"),
        Person.from_attributes(id=2, age=2, sex="m", ethnicity="C"),
    ]
    # Person 0 lives in the household and works at the hospital.
    world.people[0].area = super_areas[0].areas[0]  # household resident
    world.people[0].subgroups.primary_activity = hospitals[0].subgroups[0]
    world.people[0].subgroups.residence = households[0].subgroups[0]
    # Person 1 lives in the household with no primary activity.
    world.people[1].area = super_areas[0].areas[0]
    world.people[1].subgroups.residence = households[0].subgroups[0]
    world.people[2].area = super_areas[0].areas[1]  # care home resident
    world.people[2].subgroups.residence = care_homes[0].subgroups[0]
    return world


def test__writing_infections():
    """An accumulated infections event must round-trip through the HDF5 record."""
    record = Record(record_path="results")
    timestamp = datetime.datetime(2020, 10, 10)
    with open_file(record.record_path / record.filename, mode="a") as f:
        record.file = f
        record.accumulate(
            table_name="infections",
            location_spec="care_home",
            region_name="made_up",
            location_id=0,
            infected_ids=[0, 10, 20],
            infector_ids=[5, 15, 25],
            infection_ids=[0, 0, 0],
        )
        record.events["infections"].record(hdf5_file=record.file, timestamp=timestamp)
        table = record.file.root.infections
        df = pd.DataFrame.from_records(table.read())
    assert len(df) == 3
    assert df.timestamp.unique()[0].decode() == "2020-10-10"
    assert df.location_ids.unique() == [0]
    assert df.region_names.unique() == [b"made_up"]
    assert df.location_specs.unique() == [b"care_home"]
    # One row per (infected, infector) pair, in accumulation order.
    # (Replaces the repetitive per-index asserts; the trailing dead
    # `del df` statement was removed.)
    for row, (infected, infector) in enumerate(zip([0, 10, 20], [5, 15, 25])):
        assert df.infected_ids[row] == infected
        assert df.infector_ids[row] == infector
        assert df.infection_ids[row] == 0


def test__writing_hospital_admissions():
    """A single hospital admission event must round-trip through the record."""
    record = Record(record_path="results")
    when = datetime.datetime(2020, 4, 4)
    with open_file(record.record_path / record.filename, mode="a") as f:
        record.file = f
        record.accumulate(
            table_name="hospital_admissions", hospital_id=0, patient_id=10
        )
        record.events["hospital_admissions"].record(
            hdf5_file=record.file, timestamp=when
        )
        df = pd.DataFrame.from_records(record.file.root.hospital_admissions.read())
    assert len(df) == 1
    row = df.iloc[0]
    assert row["timestamp"].decode() == "2020-04-04"
    assert row["hospital_ids"] == 0
    assert row["patient_ids"] == 10


def test__writing_hospital_discharges():
    """A single discharge event must round-trip through the record."""
    record = Record(record_path="results")
    when = datetime.datetime(2020, 4, 4)
    with open_file(record.record_path / record.filename, mode="a") as f:
        record.file = f
        record.accumulate(table_name="discharges", hospital_id=0, patient_id=10)
        record.events["discharges"].record(hdf5_file=record.file, timestamp=when)
        df = pd.DataFrame.from_records(record.file.root.discharges.read())
    assert len(df) == 1
    row = df.iloc[0]
    assert row["timestamp"].decode() == "2020-04-04"
    assert row["hospital_ids"] == 0
    assert row["patient_ids"] == 10


def test__writing_intensive_care_admissions():
    """A single ICU admission event must round-trip through the record."""
    record = Record(record_path="results")
    when = datetime.datetime(2020, 4, 4)
    with open_file(record.record_path / record.filename, mode="a") as f:
        record.file = f
        record.accumulate(table_name="icu_admissions", hospital_id=0, patient_id=10)
        record.events["icu_admissions"].record(
            hdf5_file=record.file, timestamp=when
        )
        df = pd.DataFrame.from_records(record.file.root.icu_admissions.read())
    assert len(df) == 1
    row = df.iloc[0]
    assert row["timestamp"].decode() == "2020-04-04"
    assert row["hospital_ids"] == 0
    assert row["patient_ids"] == 10


def test__writing_death():
    """A single death event must round-trip through the HDF5 record."""
    record = Record(record_path="results")
    when = datetime.datetime(2020, 4, 4)
    with open_file(record.record_path / record.filename, mode="a") as f:
        record.file = f
        record.accumulate(
            table_name="deaths",
            location_spec="household",
            location_id=0,
            dead_person_id=10,
        )
        record.events["deaths"].record(hdf5_file=record.file, timestamp=when)
        df = pd.DataFrame.from_records(record.file.root.deaths.read())
    assert len(df) == 1
    row = df.iloc[0]
    assert row["timestamp"].decode() == "2020-04-04"
    assert row["location_specs"].decode() == "household"
    assert row["location_ids"] == 0
    assert row["dead_person_ids"] == 10


def test__static_people(dummy_world):
    """Static population table must reflect the dummy world's people."""
    record = Record(record_path="results", record_static_data=True)
    record.static_data(world=dummy_world)
    with open_file(record.record_path / record.filename, mode="a") as f:
        record.file = f
        df = pd.DataFrame.from_records(record.file.root.population.read(), index="id")
    for col in record.statics["people"].str_names:
        df[col] = df[col].str.decode("utf-8")
    # In the fixture each person's age equals their id.
    for person_id in (0, 1, 2):
        assert df.loc[person_id, "age"] == person_id
    assert df.loc[0, "primary_activity_type"] == "hospital"
    assert df.loc[0, "primary_activity_id"] == dummy_world.hospitals[0].id
    assert df.loc[1, "primary_activity_type"] == "None"
    assert df.loc[1, "primary_activity_id"] == 0
    assert df.loc[1, "residence_type"] == "household"
    assert df.loc[1, "residence_id"] == dummy_world.households[0].id
    assert df.loc[2, "residence_type"] == "care_home"
    assert df.loc[2, "residence_id"] == dummy_world.care_homes[0].id
    for person_id, ethnicity in enumerate("ABC"):
        assert df.loc[person_id, "ethnicity"] == ethnicity
    assert df.loc[0, "sex"] == "f"
    assert df.loc[2, "sex"] == "m"


def test__static_with_extras_people(dummy_world):
    """Extra per-person columns must be written alongside the static data."""
    record = Record(record_path="results", record_static_data=True)
    extras = {
        "tonto": [0.1, 1.3, 5.0],
        "listo": [0.9, 0.7, 0.0],
        "vaccine_type": [0, 1, 2],
        "vaccine_name": ["astra", "pfizer", "moderna"],
    }
    record.statics["people"].extra_float_data["tonto"] = extras["tonto"]
    record.statics["people"].extra_float_data["listo"] = extras["listo"]
    record.statics["people"].extra_int_data["vaccine_type"] = extras["vaccine_type"]
    record.statics["people"].extra_str_data["vaccine_name"] = extras["vaccine_name"]
    record.static_data(world=dummy_world)
    with open_file(record.record_path / record.filename, mode="a") as f:
        record.file = f
        df = pd.DataFrame.from_records(record.file.root.population.read(), index="id")
    for col in record.statics["people"].str_names:
        df[col] = df[col].str.decode("utf-8")
    # Standard static columns, same expectations as test__static_people.
    for person_id in (0, 1, 2):
        assert df.loc[person_id, "age"] == person_id
    assert df.loc[0, "primary_activity_type"] == "hospital"
    assert df.loc[0, "primary_activity_id"] == dummy_world.hospitals[0].id
    assert df.loc[1, "primary_activity_type"] == "None"
    assert df.loc[1, "primary_activity_id"] == 0
    assert df.loc[1, "residence_type"] == "household"
    assert df.loc[1, "residence_id"] == dummy_world.households[0].id
    assert df.loc[2, "residence_type"] == "care_home"
    assert df.loc[2, "residence_id"] == dummy_world.care_homes[0].id
    for person_id, ethnicity in enumerate("ABC"):
        assert df.loc[person_id, "ethnicity"] == ethnicity
    assert df.loc[0, "sex"] == "f"
    assert df.loc[2, "sex"] == "m"
    # Every extra column round-trips with the values we injected.
    for name, expected in extras.items():
        column = df[name].values
        assert len(column) == len(expected)
        assert all(pytest.approx(a) == b for a, b in zip(column, expected))


def test__static_location(dummy_world):
    """Locations table must list every group with its spec, id and area."""
    record = Record(record_path="results", record_static_data=True)
    record.static_data(world=dummy_world)
    with open_file(record.record_path / record.filename, mode="a") as f:
        record.file = f
        df = pd.DataFrame.from_records(record.file.root.locations.read(), index="id")
    # Collect (spec, id) for every group in every supergroup on the world,
    # in attribute order, to compare against the table row order.
    specs, group_ids = [], []
    for value in dummy_world.__dict__.values():
        if isinstance(value, Supergroup):
            for group in value:
                specs.append(group.spec)
                group_ids.append(group.id)

    for index, row in df.iterrows():
        assert row["spec"].decode() == specs[index]
        supergroup = getattr(dummy_world, specs[index] + "s")
        assert supergroup.get_from_id(group_ids[index]).area.id == row["area_id"]
        assert group_ids[index] == row["group_id"]
        if index == 2:
            assert dummy_world.areas.get_from_id(row["area_id"]).name == "area_6"
    expected_total = (
        len(dummy_world.households)
        + len(dummy_world.care_homes)
        + len(dummy_world.hospitals)
    )
    assert len(df) == expected_total


def test__static_geography(dummy_world):
    """Area/super-area/region tables must mirror the world's geography links."""
    record = Record(record_path="results", record_static_data=True)
    record.static_data(world=dummy_world)
    with open_file(record.record_path / record.filename, mode="a") as f:
        record.file = f
        area_df = pd.DataFrame.from_records(record.file.root.areas.read(), index="id")
        super_area_df = pd.DataFrame.from_records(
            record.file.root.super_areas.read(), index="id"
        )
        region_df = pd.DataFrame.from_records(
            record.file.root.regions.read(), index="id"
        )
    assert len(area_df) == len(dummy_world.areas)
    assert len(super_area_df) == len(dummy_world.super_areas)
    assert len(region_df) == len(dummy_world.regions)
    # Each area row must point at the right super area and carry its index.
    for area in dummy_world.areas:
        linked_super = super_area_df.loc[area_df.loc[area.id].super_area_id, "name"]
        assert area.super_area.name == linked_super.decode()
        assert np.isclose(
            area.socioeconomic_index, area_df.loc[area.id]["socioeconomic_index"]
        )
    # Each super area row must point at the right region.
    for super_area in dummy_world.super_areas:
        linked_region = region_df.loc[
            super_area_df.loc[super_area.id].region_id, "name"
        ]
        assert super_area.region.name == linked_region.decode()


def test__sumarise_time_tep(dummy_world):
    """Accumulated events (infections, hospital/ICU admissions, deaths) are
    aggregated per region and per day in the summary.csv written by the Record.

    NOTE(review): the function name contains typos ("sumarise", "tep"); it is
    kept unchanged so the pytest test id stays stable.
    """
    dummy_world.people = Population(dummy_world.people)

    record = Record(record_path="results")
    timestamp = datetime.datetime(2020, 4, 4)
    # Day 1: two infections in region_1, one hospital and one ICU admission.
    with open_file(record.record_path / record.filename, mode="a") as f:
        record.file = f
        record.accumulate(
            table_name="infections",
            location_spec="care_home",
            region_name="region_1",
            location_id=dummy_world.care_homes[0].id,
            infected_ids=[2],
            infector_ids=[0],
            infection_ids=[0],
        )
        record.accumulate(
            table_name="infections",
            location_spec="household",
            region_name="region_1",
            location_id=dummy_world.households[0].id,
            infected_ids=[0],
            infector_ids=[5],
            infection_ids=[0],
        )
        record.accumulate(
            table_name="hospital_admissions",
            hospital_id=dummy_world.hospitals[0].id,
            patient_id=1,
        )
        record.accumulate(
            table_name="icu_admissions",
            hospital_id=dummy_world.hospitals[0].id,
            patient_id=1,
        )
        record.summarise_time_step(timestamp, dummy_world)
    record.time_step(timestamp)
    # Day 2: three deaths (care home, household and hospital locations).
    timestamp = datetime.datetime(2020, 4, 5)
    record.accumulate(
        table_name="deaths",
        location_spec="care_home",
        location_id=dummy_world.care_homes[0].id,
        dead_person_id=2,
    )
    record.accumulate(
        table_name="deaths",
        location_spec="household",
        location_id=dummy_world.households[0].id,
        dead_person_id=0,
    )
    record.accumulate(
        table_name="deaths",
        location_spec="hospital",
        location_id=dummy_world.hospitals[0].id,
        dead_person_id=1,
    )
    record.summarise_time_step(timestamp, dummy_world)
    record.time_step(timestamp)
    # Read back the per-region daily summary and check each aggregate.
    summary_df = pd.read_csv(record.record_path / "summary.csv", index_col=0)
    region_1 = summary_df[summary_df["region"] == "region_1"]
    region_2 = summary_df[summary_df["region"] == "region_2"]
    assert region_1.loc["2020-04-04"]["daily_infected"] == 2
    assert region_1.loc["2020-04-05"]["daily_infected"] == 0
    assert region_2.loc["2020-04-04"]["daily_infected"] == 0
    assert region_2.loc["2020-04-05"]["daily_infected"] == 0

    # Hospital admissions are attributed to the hospital's region (region_2).
    assert region_1.loc["2020-04-04"]["daily_hospitalised"] == 0
    assert region_2.loc["2020-04-04"]["daily_hospitalised"] == 1
    assert region_2.loc["2020-04-04"]["daily_intensive_care"] == 1
    assert region_1.loc["2020-04-05"]["daily_hospitalised"] == 0
    assert region_1.loc["2020-04-05"]["daily_intensive_care"] == 0
    assert region_2.loc["2020-04-05"]["daily_intensive_care"] == 0

    assert region_1.loc["2020-04-05"]["daily_deaths"] == 3
    assert region_2.loc["2020-04-05"]["daily_deaths"] == 0

    assert region_2.loc["2020-04-05"]["daily_hospital_deaths"] == 1


def test__meta_information():
    """Meta information passed to the record round-trips through config.yaml."""
    record = Record(record_path="results")
    user_comment = "I love passing tests"
    record.meta_information(comment=user_comment, random_state=0, number_of_cores=20)
    with open(record.record_path / "config.yaml") as config_file:
        loaded = yaml.load(config_file, Loader=yaml.FullLoader)
    meta = loaded["meta_information"]
    assert meta["user_comment"] == user_comment
    assert meta["random_state"] == 0
    assert meta["number_of_cores"] == 20


def test__parameters(dummy_world, selector, selectors):
    """Record.parameters dumps interaction, infection-seed and infection
    settings to config.yaml, and they round-trip through yaml unchanged.

    Cleanup: removed commented-out dead code (old eval-based policy parsing).
    """
    interaction = Interaction.from_file(config_filename=config_interaction)
    interaction.alpha_physical = 100.0
    infection_seed = InfectionSeed.from_uniform_cases(
        world=None,
        infection_selector=selector,
        seed_strength=0.0,
        cases_per_capita=0,
        date="2020-03-01",
        seed_past_infections=False,
    )
    infection_seeds = InfectionSeeds([infection_seed])
    infection_seed.min_date = datetime.datetime(2020, 10, 10)
    infection_seed.max_date = datetime.datetime(2020, 10, 11)

    policies = Policies.from_file()
    activity_manager = ActivityManager(
        world=dummy_world,
        policies=policies,
        timer=Timer(),
        all_activities=None,
        activity_to_super_groups={"residence": ["household"]},
    )
    record = Record(record_path="results")
    epidemiology = Epidemiology(
        infection_seeds=infection_seeds, infection_selectors=selectors
    )
    record.parameters(
        interaction=interaction,
        epidemiology=epidemiology,
        activity_manager=activity_manager,
    )
    with open(record.record_path / "config.yaml", "r") as file:
        parameters = yaml.load(file, Loader=yaml.FullLoader)

    # Scalar interaction parameters round-trip exactly; contact matrices are
    # arrays, so compare them with numpy's structural equality.
    for attribute in ("betas", "alpha_physical"):
        assert parameters["interaction"][attribute] == getattr(interaction, attribute)
    for key, value in interaction.contact_matrices.items():
        np.testing.assert_equal(
            parameters["interaction"]["contact_matrices"][key], value
        )

    # Seed and infection parameters are stored keyed by infection name.
    assert "Covid19" in parameters["infection_seeds"]
    seed_parameters = parameters["infection_seeds"]["Covid19"]
    assert seed_parameters["seed_strength"] == infection_seed.seed_strength
    assert seed_parameters["min_date"] == infection_seed.min_date.strftime("%Y-%m-%d")
    assert seed_parameters["max_date"] == infection_seed.max_date.strftime("%Y-%m-%d")
    assert "Covid19" in parameters["infections"]
    inf_parameters = parameters["infections"]["Covid19"]
    assert inf_parameters["transmission_type"] == selector.transmission_type


import pandas as pd
import numpy as np
import datetime
import os
from pathlib import Path

from june.groups import Hospitals, Hospital
from june.demography import Population, Person
from june.geography import Area, Areas, SuperArea, SuperAreas, Region, Regions
from june.groups import Households, Household
from june.world import World
from june.groups import Cemeteries
from june.policy import Policies
from june.interaction import Interaction
from june.simulator import Simulator
from june.epidemiology.epidemiology import Epidemiology
from june.epidemiology.infection_seed import InfectionSeed
from june import paths

test_config = paths.configs_path / "tests/test_checkpoint_config.yaml"
config_interaction = paths.configs_path / "tests/interaction.yaml"


def _populate_areas(areas: Areas):
    """Populate every area with 12 female residents (ages 0, 5, ..., 55),
    assigning globally unique sequential ids, and return them as a Population."""
    ages = np.arange(0, 99, 5)
    population = Population()
    next_id = 0
    for area in areas:
        for age in ages[:12]:
            resident = Person.from_attributes(sex="f", age=age, id=next_id)
            resident.area = area
            area.people.append(resident)
            next_id += 1
        population.extend(area.people)
    return population


def _create_households(areas: Areas):
    """Create one single-person household per resident (first 12 per area).

    Cleanup: the original kept a counter ``k`` that always equalled the loop
    index ``i``; iterating the people directly removes both.
    """
    households = []
    for area in areas:
        for person in area.people[:12]:
            household = Household()
            household.add(person)
            households.append(household)
    return Households(households)


def create_world():
    """Build a minimal test world: 2 super areas of 5 areas each, one region,
    12 single-person households per area, one large hospital, and cemeteries."""
    all_areas = []
    all_super_areas = []
    for _ in range(2):
        local_areas = [Area() for _ in range(5)]
        super_area = SuperArea(areas=local_areas, name="asd")
        for area in local_areas:
            area.super_area = super_area
        all_areas.extend(local_areas)
        all_super_areas.append(super_area)
    areas = Areas(all_areas, ball_tree=False)
    super_areas = SuperAreas(all_super_areas, ball_tree=False)
    region = Region(super_areas=super_areas)
    for super_area in super_areas:
        super_area.region = region
    world = World()
    world.people = _populate_areas(areas)
    world.households = _create_households(areas)
    world.areas = areas
    world.super_areas = super_areas
    world.regions = Regions([region])
    # One hospital with plenty of capacity so admissions never saturate.
    world.hospitals = Hospitals(
        [Hospital(n_beds=1000, n_icu_beds=1000, area=None, coordinates=None)],
        ball_tree=False,
    )
    world.cemeteries = Cemeteries()
    return world


def run_simulator(selectors, test_results):
    """Build a fresh test world, seed infections on 2020-03-01 and run the
    full simulation, saving checkpoints under ``test_results / "checkpoint_tests"``.

    Returns the finished Simulator so tests can inspect its end state.
    """
    world = create_world()
    interaction = Interaction.from_file(config_filename=config_interaction)
    policies = Policies([])
    epidemiology = Epidemiology(infection_selectors=selectors)
    sim = Simulator.from_file(
        world=world,
        interaction=interaction,
        epidemiology=epidemiology,
        config_filename=test_config,
        leisure=None,
        policies=policies,
        checkpoint_save_path=test_results / "checkpoint_tests",
    )
    # Seed roughly 50 cases spread uniformly over the population.
    seed = InfectionSeed.from_uniform_cases(
        sim.world,
        selectors[0],
        cases_per_capita=50 / len(world.people),
        date="2020-03-01",
        seed_past_infections=False,
    )
    seed.unleash_virus_per_day(time=0, date=pd.to_datetime("2020-03-01"))
    sim.run()
    return sim


class TestCheckpoints:
    """Saving a checkpoint and restoring from it reproduces the simulation state."""

    def test__checkpoints_are_saved(self, selectors, test_results):
        checkpoint_folder = Path(test_results / "checkpoint_tests")
        checkpoint_folder.mkdir(exist_ok=True, parents=True)
        sim = run_simulator(selectors, test_results)
        # The run must have produced both infections and deaths to be meaningful.
        assert len(sim.world.people.infected) > 0
        assert len(sim.world.people.dead) > 0
        fresh_world = create_world()
        interaction = Interaction.from_file(config_filename=config_interaction)
        policies = Policies([])
        epidemiology = Epidemiology(infection_selectors=selectors)
        sim_recovered = Simulator.from_checkpoint(
            world=fresh_world,
            checkpoint_load_path=checkpoint_folder / "checkpoint_2020-03-25.hdf5",
            interaction=interaction,
            epidemiology=epidemiology,
            config_filename=test_config,
            leisure=None,
            travel=None,
            policies=policies,
        )
        # Timer should resume on the day after the checkpoint date.
        assert sim_recovered.timer.initial_date == sim.timer.initial_date
        assert sim_recovered.timer.final_date == sim.timer.final_date
        assert sim_recovered.timer.now == sim.timer.now
        assert sim_recovered.timer.date.date() == datetime.datetime(2020, 3, 26).date()
        assert sim_recovered.timer.shift == sim.timer.shift
        assert sim_recovered.timer.delta_time == sim.timer.delta_time
        for person1, person2 in zip(sim.world.people, sim_recovered.world.people):
            assert person1.id == person2.id
            if person1.infection is not None:
                assert person2.infection is not None
                inf1 = person1.infection
                inf2 = person2.infection
                assert inf1.infection_id() == inf2.infection_id()
                # BUGFIX: previously compared inf1.start_time with itself
                # (vacuously true); compare against the recovered infection.
                assert inf1.start_time == inf2.start_time
                assert inf1.infection_probability == inf2.infection_probability
                assert inf1.transmission.probability == inf2.transmission.probability
                assert inf1.symptoms.tag == inf2.symptoms.tag
                assert inf1.symptoms.stage == inf2.symptoms.stage
                continue
            assert (
                person1.immunity.susceptibility_dict
                == person2.immunity.susceptibility_dict
            )
            assert person1.dead == person2.dead
        # Clean up: delete the checkpoint, else a rerun could pass on a stale file.
        os.remove(checkpoint_folder / "checkpoint_2020-03-25.hdf5")


class TestCheckpointForReseeding:
    """
    Tests the situation in which we load from a checkpoint and
    want all the infections reset (``reset_infections=True``).
    """

    def test__checkpoints_are_saved(self, selectors, test_results):
        checkpoint_folder = Path(test_results / "checkpoint_tests")
        checkpoint_folder.mkdir(exist_ok=True, parents=True)
        sim = run_simulator(selectors, test_results)
        # The run must have produced both infections and deaths to be meaningful.
        assert len(sim.world.people.infected) > 0
        assert len(sim.world.people.dead) > 0
        epidemiology = Epidemiology(infection_selectors=selectors)
        fresh_world = create_world()
        interaction = Interaction.from_file(config_filename=config_interaction)
        policies = Policies([])
        sim_recovered = Simulator.from_checkpoint(
            world=fresh_world,
            checkpoint_load_path=checkpoint_folder / "checkpoint_2020-03-25.hdf5",
            interaction=interaction,
            epidemiology=epidemiology,
            config_filename=test_config,
            leisure=None,
            travel=None,
            policies=policies,
            reset_infections=True,
        )
        # Timer should resume on the day after the checkpoint date.
        assert sim_recovered.timer.initial_date == sim.timer.initial_date
        assert sim_recovered.timer.final_date == sim.timer.final_date
        assert sim_recovered.timer.now == sim.timer.now
        assert sim_recovered.timer.date.date() == datetime.datetime(2020, 3, 26).date()
        assert sim_recovered.timer.shift == sim.timer.shift
        assert sim_recovered.timer.delta_time == sim.timer.delta_time
        for person1, person2 in zip(sim.world.people, sim_recovered.world.people):
            assert person1.id == person2.id
            # Anyone infected in the original run must come back uninfected.
            if person1.infection is not None:
                assert person2.infection is None
                continue
            assert person1.infected == person2.infected
            assert person1.dead == person2.dead
            assert (
                person1.immunity.susceptibility_dict
                == person2.immunity.susceptibility_dict
            )
        # Clean up: delete the checkpoint, else a rerun could pass on a stale file.
        os.remove(checkpoint_folder / "checkpoint_2020-03-25.hdf5")


import random
from datetime import datetime

import pytest
from june import paths
from june.activity import activity_hierarchy
from june.epidemiology.epidemiology import Epidemiology
from june.epidemiology.infection import Immunity, InfectionSelector, InfectionSelectors
from june.groups.leisure import leisure
from june.groups.travel import Travel
from june.interaction import Interaction
from june.policy import Hospitalisation, MedicalCarePolicies, Policies
from june.simulator import Simulator

constant_config = (
    paths.configs_path
    / "defaults/epidemiology/infection/transmission/TransmissionConstant.yaml"
)
interaction_config = paths.configs_path / "tests/interaction.yaml"
test_config = paths.configs_path / "tests/test_simulator.yaml"


@pytest.fixture(name="selectors", scope="module")
def make_selector(health_index_generator):
    """Module-scoped InfectionSelectors containing one constant-transmission
    selector (recovery_rate=0.05, transmission_probability=0.7).

    Cleanup: removed an unreachable ``return selector`` that followed the
    first return statement.
    """
    selector = InfectionSelector(
        health_index_generator=health_index_generator,
        transmission_config_path=constant_config,
    )
    selector.recovery_rate = 0.05
    selector.transmission_probability = 0.7
    return InfectionSelectors([selector])


@pytest.fixture(name="medical_policies")
def make_policies():
    """Active medical-care policies (hospitalisation only) on 1 March 2020."""
    all_policies = Policies([Hospitalisation()])
    active = MedicalCarePolicies.get_active_policies(
        policies=all_policies, date=datetime(2020, 3, 1)
    )
    return active


@pytest.fixture(name="sim")
def setup_sim(dummy_world, selectors):
    """Fully wired Simulator on the dummy world with leisure, travel and
    default policies; everyone starts healthy, alive and out of facilities."""
    world = dummy_world
    # Reset every person to a clean epidemiological state.
    for person in world.people:
        person.immunity = Immunity()
        person.infection = None
        person.subgroups.medical_facility = None
        person.dead = False
    leisure_instance = leisure.generate_leisure_for_world(
        world=world,
        list_of_leisure_groups=["pubs", "cinemas", "groceries"],
        daytypes={
            "weekday": ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday"],
            "weekend": ["Saturday", "Sunday"],
        },
    )
    leisure_instance.distribute_social_venues_to_areas(
        world.areas, super_areas=world.super_areas
    )
    interaction = Interaction.from_file(config_filename=interaction_config)
    policies = Policies.from_file()
    epidemiology = Epidemiology(infection_selectors=selectors)
    travel = Travel()
    sim = Simulator.from_file(
        world=world,
        epidemiology=epidemiology,
        interaction=interaction,
        config_filename=test_config,
        leisure=leisure_instance,
        travel=travel,
        policies=policies,
    )
    # Pre-compute leisure probabilities for a 3-hour non-working timestep
    # on 2020-03-01 so leisure-dependent tests can run against this fixture.
    sim.activity_manager.leisure.generate_leisure_probabilities_for_timestep(
        delta_time=3,
        working_hours=False,
        date=datetime.strptime("2020-03-01", "%Y-%m-%d"),
    )
    sim.clear_world()
    return sim


@pytest.fixture(name="health_index")
def create_health_index():
    """Stub health-index fixture: a callable returning fixed thresholds
    regardless of age and sex."""

    def constant_health_index(age, sex):
        return [0.1, 0.3, 0.5, 0.7, 0.9]

    return constant_health_index


def test__everyone_has_an_activity(sim: Simulator):
    """Every person has at least one non-None subgroup assigned."""
    for person in sim.world.people.members:
        subgroups = list(person.subgroups.iter())
        assert any(subgroup is not None for subgroup in subgroups)


def test__apply_activity_hierarchy(sim: Simulator):
    """A shuffled activity list is returned in canonical hierarchy order."""
    shuffled = random.sample(activity_hierarchy, len(activity_hierarchy))
    reordered = sim.activity_manager.apply_activity_hierarchy(shuffled)
    assert reordered == activity_hierarchy


def test__activities_to_super_groups(sim: Simulator):
    """Each activity maps onto its expected super groups, in order."""
    requested = [
        "medical_facility",
        "commute",
        "primary_activity",
        "leisure",
        "residence",
    ]
    expected = [
        "hospitals",
        "city_transports",
        "inter_city_transports",
        "schools",
        "companies",
        "universities",
        "pubs",
        "cinemas",
        "groceries",
        "household_visits",
        "care_home_visits",
        "households",
        "care_homes",
    ]
    assert sim.activity_manager.activities_to_super_groups(requested) == expected


def test__clear_world(sim: Simulator):
    """After clearing the world, every subgroup is empty and nobody is busy."""
    sim.clear_world()
    super_group_names = sim.activity_manager.activities_to_super_groups(
        sim.activity_manager.all_activities
    )
    # Visit pseudo-groups have no world attribute of their own.
    skipped = ("household_visits", "care_home_visits")
    for group_name in super_group_names:
        if group_name in skipped:
            continue
        super_group = getattr(sim.world, group_name)
        for group in super_group.members:
            assert all(len(subgroup.people) == 0 for subgroup in group.subgroups)

    for person in sim.world.people.members:
        assert person.busy is False


def test__move_to_active_subgroup(sim: Simulator):
    """Moving someone to their residence lands them in a household or care home."""
    first_person = sim.world.people.members[0]
    sim.activity_manager.move_to_active_subgroup(["residence"], first_person)
    assert first_person.residence.group.spec in ("carehome", "household")


def test__move_people_to_residence(sim: Simulator):
    """After activating residence, everyone is present in their own residence."""
    sim.activity_manager.move_people_to_active_subgroups(["residence"])
    assert all(
        person in person.residence.people for person in sim.world.people.members
    )
    sim.clear_world()


def test__move_people_to_leisure(sim: Simulator):
    """Over many leisure timesteps people end up in cinemas, pubs and groceries,
    and care-home visitors always join the visitors subgroup."""
    counts = {"leisure": 0, "cinema": 0, "pub": 0, "grocery": 0}
    for _ in range(500):
        sim.clear_world()
        sim.activity_manager.move_people_to_active_subgroups(["leisure", "residence"])
        for person in sim.world.people.members:
            if person.leisure is None:
                continue
            counts["leisure"] += 1
            spec = person.leisure.group.spec
            if spec == "care_home":
                assert person.leisure.subgroup_type == 2  # visitors
            elif spec in ("cinema", "pub", "grocery"):
                counts[spec] += 1
            # Anyone not at home must actually be present at the venue.
            if person not in person.residence.people:
                assert person in person.leisure.people
    assert counts["leisure"] > 0
    assert counts["cinema"] > 0
    assert counts["pub"] > 0
    assert counts["grocery"] > 0
    sim.clear_world()


def test__move_people_to_primary_activity(sim: Simulator):
    """Everyone with a primary activity is present at that activity."""
    sim.activity_manager.move_people_to_active_subgroups(
        ["primary_activity", "residence"]
    )
    for person in sim.world.people.members:
        if person.primary_activity is None:
            continue
        assert person in person.primary_activity.people
    sim.clear_world()


def test__move_people_to_commute(sim: Simulator):
    """At least one person commutes, and commuters sit in their commute subgroup."""
    sim.activity_manager.move_people_to_active_subgroups(["commute", "residence"])
    commuters = [
        person for person in sim.world.people.members if person.commute is not None
    ]
    for commuter in commuters:
        assert commuter in commuter.commute.people
    assert len(commuters) > 0
    sim.clear_world()


def test__bury_the_dead(sim: Simulator):
    """Burying an infected person places them in a cemetery, marks them dead
    and clears their infection."""
    victim = sim.world.people.members[0]
    sim.epidemiology.infection_selectors.infect_person_at_time(victim, 0.0)
    sim.epidemiology.bury_the_dead(sim.world, victim)
    assert victim in sim.world.cemeteries.members[0].people
    assert victim.dead
    assert victim.infection is None


import numpy as np
import datetime
import yaml
import pandas as pd
from pathlib import Path

import pytest

from june import paths
from june.tracker.tracker import Tracker
from june.time import Timer


from june.groups.group import make_subgroups

from june.geography import Geography
from june.groups.leisure import Pubs

from june.world import generate_world_from_geography

interaction_config = paths.configs_path / "tests/tracker/tracker_test_interaction.yaml"
test_config = paths.configs_path / "tests/tracker/tracker_test_config.yaml"


class TestTracker:
    """Integration tests for the contact Tracker on a small real-world geography."""

    @pytest.fixture(name="world", autouse=True, scope="class")
    def make_world(self):
        # One super area with households, plus pubs configured from the
        # tracker test interaction file.
        geography = Geography.from_file({"super_area": ["E02005103"]})
        world = generate_world_from_geography(geography, include_households=True)

        Pubs.get_interaction(interaction_config)
        world.pubs = Pubs.for_geography(geography)

        return world

    @pytest.fixture(name="tracker", autouse=True, scope="class")
    def setup_tracker(self, world):
        # Rebuild pubs from their coordinates so they attach to super areas.
        Pubs.get_interaction(interaction_config)
        world.pubs = Pubs.from_coordinates(
            np.array([pub.coordinates for pub in world.pubs]), world.super_areas
        )

        # Track contacts in pubs and households only.
        group_types = [world.pubs, world.households]

        tracker = Tracker(
            world=world,
            record_path=None,
            group_types=group_types,
            load_interactions_path=interaction_config,
            contact_sexes=["unisex", "male", "female"],
        )

        # Give the tracker a 1-hour timestep timer for contact simulation.
        tracker.timer = Timer()
        tracker.timer.delta_time = datetime.timedelta(hours=1)
        return tracker

    def test__tracker_init(self, tracker):
        """Tracker initialisation loads the interaction matrix, bins, group
        names, contact-matrix containers and per-person contact counts."""
        # Check loaded in correct values from made up obscene values
        assert tracker.IM["pub"]["contacts"] == [[10]]
        assert tracker.IM["pub"]["proportion_physical"] == [[0.2]]
        assert tracker.IM["pub"]["type"] == "Age"
        assert tracker.IM["pub"]["bins"] == [1, 99]

        # Check functionality of calls from make_subgroups
        assert tracker.world.pubs[0].subgroup_bins == [1, 99]
        assert tracker.world.pubs[0].subgroup_type == "Age"
        assert tracker.world.pubs[0].subgroup_labels == ["A"]

        # Check the feed in groups we care about tracking
        assert sorted(tracker.group_type_names) == ["household", "pub"]

        # Check CM that are initialized
        assert sorted(tracker.CM["syoa"].keys()) == ["global", "household", "pub"]
        assert sorted(tracker.CM["syoa"]["global"].keys()) == [
            "female",
            "male",
            "unisex",
        ]

        # Check person contact counts
        assert len(tracker.contact_counts) == len(tracker.world.people)

    def test__intersection(self, tracker):
        # Check intersection of lists functionality, with and without
        # permutation; duplicates in the second list must not leak through.
        assert sorted(
            tracker.intersection(
                ["A", "B", "C", "D"], ["C", "D", "E", "F", "D"], permute=False
            )
        ) == ["C", "D"]
        assert sorted(
            tracker.intersection(
                ["A", "B", "C", "D"], ["C", "D", "E", "F", "D"], permute=True
            )
        ) == ["C", "D"]

    def test__contractmatrix(self, tracker):
        # Check contract matrix functionality: rebinning a 101x101 matrix of
        # ones with np.sum gives bin-size products; with np.mean gives ones.
        bins_syoa = np.arange(0, 101, 1)
        CM = np.ones((len(bins_syoa), len(bins_syoa)))

        assert np.array_equal(
            tracker.contract_matrix(CM, [0, 18, 100], method=np.sum),
            np.array([[18**2, (100 - 18) * 18], [(100 - 18) * 18, (100 - 18) ** 2]]),
        )
        assert np.array_equal(
            tracker.contract_matrix(CM, [0, 18, 100], method=np.mean),
            np.array([[1, 1], [1, 1]]),
        )

        assert np.array_equal(
            tracker.contract_matrix(CM, [0, 100], method=np.sum), np.array([[100**2]])
        )
        assert np.array_equal(
            tracker.contract_matrix(CM, [10, 90], method=np.sum), np.array([[80**2]])
        )

    def test__Probabilistic_Contacts(self, tracker):
        func = tracker.Probabilistic_Contacts
        SigTol = 2  # tolerance in standard deviations for the mean check
        N = 100
        Mean = 10
        Error = 0

        Results = np.zeros(N)
        for i in range(N):
            Results[i] = func(Mean, Error)
        Errorless_STD = np.std(Results, ddof=1)

        # Make sure all non neg number of contacts
        assert all(i >= 0 for i in Results)
        # Make distribution is poisson under tolerance
        assert pytest.approx(np.mean(Results), abs=SigTol * np.sqrt(Mean / N)) == Mean

        Error = 5
        for i in range(N):
            Results[i] = func(Mean, Error)
        Errored_STD = np.std(Results, ddof=1)

        # Make sure all non neg number of contacts
        assert all(i >= 0 for i in Results)
        # Make distribution is poisson under tolerance
        assert pytest.approx(np.mean(Results), abs=SigTol * np.sqrt(Mean / N)) == Mean

        # Errored increase the variance of the results
        assert Errored_STD >= Errorless_STD

    def test__All(self, tracker):
        # Tests to make sure groups with 1 person have no contacts.
        found = False
        for group in tracker.world.households:
            if len(group.people) == 1:
                found = True
                break
        if found:
            tracker.simulate_All_contacts(group)
            CM_all_test = np.array(tracker.CMV["Interaction"]["household"])
            assert CM_all_test.sum() == 0.0

            tracker.simulate_1d_contacts(group)
            CM_1d_test = np.array(tracker.CM["Interaction"]["household"])
            assert CM_1d_test.sum() == 0.0

        # Tests to make sure groups with at least 2 people have some contacts.
        # NOTE(review): the threshold below is > 5, stricter than the comment
        # suggests — confirm whether 2+ was intended.
        found = False
        for group in tracker.world.households:
            if len(group.people) > 5:
                found = True
                break
        if found:
            tracker.simulate_All_contacts(group)
            CM_all_test = np.array(tracker.CMV["Interaction"]["household"])
            assert CM_all_test.sum() > 0.0

            tracker.simulate_1d_contacts(group)
            CM_1d_test = np.array(tracker.CM["Interaction"]["household"])
            assert CM_1d_test.sum() > 0.0


def postprocess_functions(tracker: Tracker):
    """Apply the Tracker's post-processing methods in sequence and return it.

    The call order matters: matrices are contracted before conversion to
    DataFrames, and normalization runs last.
    """
    tracker.contract_matrices("Interaction", np.array([]))
    tracker.convert_dict_to_df()
    tracker.calc_age_profiles()
    tracker.calc_average_contacts()
    tracker.normalize_contact_matrices()
    return tracker
