from collections import defaultdict
import numpy as np
import copy
import scipy.stats as stats
import json5

from organisation.env import config

from .models import biomarker_constant, biomarker_exponential, biomarker_sigmoid
from .models import risk_constant, risk_exponential

# Dispatch table: dose-response "type" -> function computing the new biomarker
# level. Each function is called as f(total_dose, dose_response, baseline).
DICT_BIOMARKER_LEVEL = {
    "sigmoid": biomarker_sigmoid,
    "exponential": biomarker_exponential,
    "constant": biomarker_constant,
}

# Dispatch table: side-effect "progression" -> function computing the risk of
# the side effect. Each function is called as f(total_dose, side_effect).
DICT_RISK_LEVEL = {
    "exponential": risk_exponential,
    "constant": risk_constant,
}


class Patient:
    """Simulated patient whose body reacts, tick by tick, to the doses received.

    The patient registers a ``body_behavior`` process on ``env`` at
    construction time. On every tick that process decays the active doses,
    draws side effects and recomputes the biomarkers targeted by each drug.

    ``env`` is expected to expose ``event()``, ``process()``, ``timeout()`` and
    ``now`` (presumably a SimPy-style environment; confirm against the caller).
    """

    def __init__(self, patient_id, env, wanted_characteristics=None):
        """Create the patient and start its body process.

        Args:
            patient_id: Identifier reported in the events emitted by the patient.
            env: Simulation environment (see class docstring).
            wanted_characteristics: Optional dict overriding randomly drawn
                characteristics. Recognised keys: ``"age"``, ``"sex"``,
                ``"diagnosis"``, ``"current_medications"``, ``"allergies"``.
        """
        self.patient_id = patient_id

        env.event().succeed(
            {"type": "Patient hire", "Details": patient_id}
        )  # Register the patient in the environment (for logging purposes)

        # drug_id -> list of (original_dose, current_dose, administration_time)
        self.drug_levels = defaultdict(list)
        # drug_id -> number of ticks the drug has been under treatment
        self.drug_duration = defaultdict(lambda: 0)
        # drug_id -> maximum cumulative (summed active) dose observed
        self.max_cumulative_dose = {}
        # side-effect name -> highest risk already accounted for
        self.max_risk = defaultdict(lambda: 0)

        # Hidden attributes
        self.phenotype = []  # Filled based on real stats
        self.env = env
        self.dict_drug = {}  # drug_id -> Drug object
        self.side_effects = []  # List of (side_effect_name, time) tuples
        self.simulated = True  # body_behavior stops once this becomes False

        self.initialize(wanted_characteristics)
        self.action = env.process(self.body_behavior())

    def clear_side_effects(self):
        """Forget every side effect experienced by the patient so far."""
        self.side_effects = []

    def initialize(self, wanted_characteristics=None):
        """Set the patient's characteristics, then its initial biomarkers.

        Age and sex are drawn from the distributions stored in
        ``config.POPULATION_FILE`` unless overridden. The other characteristics
        default to empty lists.

        Args:
            wanted_characteristics: Optional dict of overrides (see ``__init__``).
        """
        if wanted_characteristics is None:
            wanted_characteristics = {}

        # Initial characteristics of the patient (they do not depend on other
        # characteristics)
        with open(config.POPULATION_FILE, "r") as file:
            population_stats = json5.load(file)

        # NOTE: each draw is the eagerly evaluated default of ``.get``, so it
        # happens even when an override is supplied. This is kept on purpose:
        # seeded runs consume NumPy's global RNG the same way with or without
        # overrides.
        age_stats = population_stats["age"]
        self.age = wanted_characteristics.get(
            "age",
            np.random.choice(age_stats["categories"], p=age_stats["probabilities"]),
        )
        sex_stats = population_stats["sex"]
        self.sex = wanted_characteristics.get(
            "sex",
            np.random.choice(sex_stats["categories"], p=sex_stats["probabilities"]),
        )

        # Dependent characteristics
        # The diagnosis is copied so the patient owns its own list.
        self.diagnosis = list(wanted_characteristics.get("diagnosis", []))
        self.current_medications = wanted_characteristics.get("current_medications", [])
        self.allergies = wanted_characteristics.get("allergies", [])

        # Initialize biomarkers
        self.initialize_biomarkers()

    def initialize_biomarkers(self):
        """Draw a baseline for every known biomarker present in the diagnosis.

        Fills ``self.initial_biomarkers`` (the baselines) and
        ``self.biomarkers`` (the current values, starting as a copy of the
        baselines). Diagnosis entries that are not known biomarkers are ignored.
        """

        def sample_tumor_size():
            # truncnorm takes its bounds in standardized units, i.e.
            # (bound - loc) / scale, so the physical lower bound of 0 has to
            # be converted. A sample is drawn with ``rvs``; the density
            # (``pdf``) is not a tumor size.
            loc, scale = 37.5, 1
            lower = (0 - loc) / scale
            return float(stats.truncnorm.rvs(lower, np.inf, loc=loc, scale=scale))

        # Order matters: draws are consumed from the global RNG in this order.
        baseline_samplers = (
            ("C-LDL", lambda: np.random.normal(220, 72)),  # C-LDL (in mg/dL)
            ("Temperature", lambda: np.random.normal(39, 0.5)),
            ("Liver Fat Content", lambda: np.random.normal(22.3, 6.3)),
            ("Complete remission", lambda: 0),
            ("Partial or complete remission", lambda: 0),
            ("Tumor size", sample_tumor_size),
        )

        self.initial_biomarkers = {
            name: draw() for name, draw in baseline_samplers if name in self.diagnosis
        }

        # Current biomarkers start from (a copy of) the baselines
        self.biomarkers = copy.deepcopy(self.initial_biomarkers)

    def simulate_side_effects(self, drug, list_active_dose):
        """Draw, for one tick, the side effects ``drug`` may trigger.

        Args:
            drug: Drug object exposing ``drug_id`` and ``side_effects``, a
                collection of dicts with keys ``"name"``, ``"time"``,
                ``"min_dose"`` and ``"progression"``.
            list_active_dose: Active doses of the drug, as
                ``(original_dose, current_dose, time)`` tuples.
        """
        if not drug.side_effects:
            return

        # Both values are the same for every side effect of the drug.
        sum_dose = sum(dose for _, dose, _ in list_active_dose)

        # update the maximum cumulated dose
        self.max_cumulative_dose[drug.drug_id] = max(
            self.max_cumulative_dose.get(drug.drug_id, 0), sum_dose
        )

        for side_effect in drug.side_effects:
            # The side effect can only appear after enough treatment time.
            if side_effect["time"] > self.drug_duration[drug.drug_id]:
                continue

            name = side_effect["name"]
            if sum_dose < side_effect["min_dose"]:
                risk = 0
            elif any(entry[0] == name for entry in self.side_effects):
                # already had this side effect
                risk = 1
            else:
                # Only the risk added since the highest level already
                # accounted for is drawn against.
                risk = (
                    DICT_RISK_LEVEL[side_effect["progression"]](sum_dose, side_effect)
                    - self.max_risk[name]
                )
                if risk > 0:
                    self.max_risk[name] += risk

            # The draw is made even when the risk is 0 so the RNG stream is
            # consumed identically in every branch.
            if np.random.rand() < risk:
                self.side_effects.append((name, self.env.now))
                self.env.event().succeed(
                    {
                        "type": "Side effect",
                        "Details": {
                            "patient_id": self.patient_id,
                            "side_effect": name,
                            "time": self.env.now,
                        },
                    }
                )

    def simulate_biomarker(self, drug, list_active_dose):
        """Recompute every biomarker targeted by ``drug`` from its active doses.

        Args:
            drug: Drug object exposing ``drug_id`` and ``dose_responses``, a
                collection of dicts with keys ``"biomarker"`` (a dict with a
                ``"name"``), ``"lag"`` and ``"type"``.
            list_active_dose: Active doses of the drug, as
                ``(original_dose, current_dose, time)`` tuples.

        Raises:
            KeyError: If the patient has no baseline for a targeted biomarker.
            ValueError: If a dose response has an unknown ``"type"``.
        """
        for dose_response in drug.dose_responses:
            biomarker_name = dose_response["biomarker"]["name"]
            # Simulate the evolution of the biomarker based on the drug's effect
            baseline = self.initial_biomarkers[biomarker_name]

            # The doses only count once the treatment has outlasted the lag;
            # before that the effective dose is 0.
            D = sum(
                dose
                for _, dose, _ in list_active_dose
                if self.drug_duration[drug.drug_id] > dose_response["lag"]
            )

            if dose_response["type"] not in DICT_BIOMARKER_LEVEL:
                raise ValueError("Unknown dose response type for biomarker simulation")

            level_function = DICT_BIOMARKER_LEVEL[dose_response["type"]]
            self.biomarkers[biomarker_name] = level_function(D, dose_response, baseline)

    def body_behavior(self):
        """Generator process advancing the patient's body one tick at a time.

        Each tick, for every drug with recorded doses: increments the treatment
        duration, halves each dose whose age is a positive multiple of the
        drug's half-life, drops the doses that fell below 5% of their original
        amount, then simulates side effects and biomarkers. Drugs left without
        any active dose are forgotten. Yields ``self.env.timeout(1)`` after
        every tick and stops once ``self.simulated`` is false.
        """
        while self.simulated:
            # Update the patient's body behavior based on the drug levels and
            # other factors
            for drug_id, list_active_dose in self.drug_levels.items():
                drug = self.dict_drug[drug_id]
                # update the drug duration
                self.drug_duration[drug_id] += 1

                expired_indexes = []
                for index, (original_dose, current_dose, time) in enumerate(
                    list_active_dose
                ):
                    # Update each dose level in the body based on half-life.
                    # A dose administered during the current time step
                    # (elapsed == 0) must not be halved straight away.
                    elapsed = self.env.now - time
                    if elapsed > 0 and elapsed % drug.pharmacokinetics["half_life"] == 0:
                        list_active_dose[index] = (
                            original_dose,
                            current_dose * 0.5,
                            time,
                        )

                    # NOTE: the check uses the level from before this tick's
                    # decay, so a dose is dropped on the tick after it fell
                    # below the threshold.
                    if current_dose < 0.05 * original_dose:
                        expired_indexes.append(index)
                        # Reset the maximum cumulative dose when a dose is
                        # cleared
                        self.max_cumulative_dose[drug_id] = 0

                # Remove doses that are too low (from the end, so the
                # remaining indexes stay valid)
                for index in reversed(expired_indexes):
                    del list_active_dose[index]

                # Simulate side effects
                self.simulate_side_effects(drug, list_active_dose)

                # Simulate the evolution of biomarkers
                self.simulate_biomarker(drug, list_active_dose)

            # Remove drugs that have no active doses
            for drug_id in list(self.drug_duration):
                if not self.drug_levels[drug_id]:
                    del self.drug_levels[drug_id]
                    del self.drug_duration[drug_id]

            yield self.env.timeout(1)

    def receive_dose(self, drug, dose):
        """Record a dose of ``drug`` administered at the current time.

        Args:
            drug: Drug object; remembered under its ``drug_id``.
            dose: Amount administered; stored both as the original and as the
                current level of the new active dose.
        """
        # Store the dose and the time it was administered
        self.drug_levels[drug.drug_id].append((dose, dose, self.env.now))
        self.dict_drug[drug.drug_id] = drug

    def check_side_effects(self):
        """Return the list of ``(side_effect_name, time)`` tuples recorded so far."""
        return self.side_effects
