from __future__ import annotations

from .drug import Drug
from .patient import Patient
from .study import Study
from collections import defaultdict

import logging

# Logger named after this module (``__name__``); all study progress messages go here.
logger: logging.Logger = logging.getLogger(name=__name__)


class SingleArmStudy(Study):
    """Single-arm study in which every enrolled patient follows one dosing schedule.

    ``run_study`` is a generator driven by ``self.env``: it reads ``env.now`` and
    yields ``env.timeout(...)`` (presumably a SimPy environment — confirm).
    Each cycle doses every patient, waits ``period``, records which patients
    reported which side effects, and the study stops once at least
    ``duration`` has elapsed. Final biomarker values are then stored in
    ``self.results`` and ``self.completion()`` is called.
    """

    def __init__(
        self,
        env,
        study_id,
        drug: Drug,
        patient_list: list[Patient],
        dosage_sequence: dict,
        duration: float,
        period: float,
        biomarkers_list: list[dict],
        responsible,
        approved: bool = False,
        ongoing: bool = False,
    ):
        """Configure the study; nothing is simulated until ``run_study`` runs.

        Args:
            env: Simulation environment exposing ``now`` and ``timeout``.
            study_id: Identifier used in log messages and display strings.
            drug: Drug administered to every patient.
            patient_list: Enrolled patients.
            dosage_sequence: Mapping of start offset (time since the study
                started) to dose. Keys are cast to ``int``, values to ``float``.
            duration: Minimum simulated time the study runs for.
            period: Time waited after each round of dosing (one cycle).
            biomarkers_list: Biomarker descriptors; each must be a mapping
                with a ``"name"`` key.
            responsible: Party responsible for the study, forwarded to
                ``Study.__init__``.
            approved: Initial approval flag.
            ongoing: Initial "running" flag; ``run_study`` sets it to True.
        """
        self.type = "single_arm"
        self.approved = approved
        self.drug = drug
        self.ongoing = ongoing

        super().__init__(env, study_id, responsible)
        self.patient_list = patient_list

        # Normalise keys/values so schedule lookups and ``dosage_id`` do not
        # depend on whether the caller passed e.g. {"0": "10"} or {0: 10}.
        self.dosage_sequence = {int(k): float(v) for k, v in dosage_sequence.items()}
        self.max_period = period
        self.period = period
        self.duration = duration
        self.biomarkers_list = biomarkers_list
        # First element of each side-effect record returned by
        # ``check_side_effects`` -> ids of the patients that reported it.
        self.observed_side_effects = defaultdict(set)
        self.dosage_id = str(self.dosage_sequence)

    def get_current_dosage(self):
        """Return the dose in effect at the current simulation time.

        The applicable entry is the one with the largest start offset that is
        not after the time elapsed since ``start_date``. Before the first
        offset the dose is ``0``. Requires ``run_study`` to have set
        ``start_date``.
        """
        elapsed = self.env.now - self.start_date
        # Only integer offsets are considered; __init__ already casts keys to
        # int, so this filter is purely defensive.
        schedule = sorted(
            ((start, dose) for start, dose in self.dosage_sequence.items() if isinstance(start, int)),
            reverse=True,
        )
        for start, dose in schedule:
            if elapsed >= start:
                return dose
        return 0

    def run_study(self):
        """Simulate the study; a generator to be run as a process on ``self.env``.

        Records ``start_date`` and each patient's initial biomarkers, runs
        dosing cycles until at least ``duration`` has elapsed, then stores the
        final biomarker values in ``self.results`` and calls
        ``self.completion()``.
        """
        logger.info(
            "Starting single arm study %s: dosage=%s period=%s duration=%s",
            self.study_id,
            self.dosage_id,
            self.period,
            self.duration,
        )

        self.start_date = self.env.now
        logger.debug("start_date %s", self.start_date)

        self.ongoing = True

        # Baseline snapshot: {biomarker name: {patient id: initial value}}.
        self.initial_biomarker = {
            bm["name"]: {
                patient.patient_id: patient.initial_biomarkers[bm["name"]]
                for patient in self.patient_list
            }
            for bm in self.biomarkers_list
        }

        # NOTE(review): if ``period`` is 0, ``env.timeout`` presumably does not
        # advance ``env.now`` and this loop never ends — confirm callers
        # always pass period > 0.
        while True:
            logger.info("Running study %s at t=%.1f", self.study_id, self.env.now)
            if logger.isEnabledFor(logging.DEBUG):
                try:
                    attention_count = self.responsible.attention.count
                except AttributeError:
                    # Best effort: ``responsible`` may not expose
                    # ``attention.count``; skip this diagnostic in that case.
                    pass
                else:
                    logger.debug("Investigator attention count: %s", attention_count)

            current_dosage = self.get_current_dosage()
            for patient in self.patient_list:
                patient.receive_dose(self.drug, current_dosage)

            yield self.env.timeout(self.period)

            # Record which patients reported each side effect this cycle.
            for patient in self.patient_list:
                for side_effect in patient.check_side_effects():
                    self.observed_side_effects[side_effect[0]].add(patient.patient_id)

            # Reset per-patient side effects before the next cycle
            # (presumably so they are not re-reported — confirm).
            for patient in self.patient_list:
                patient.clear_side_effects()

            if self.env.now - self.start_date >= self.duration:
                logger.info(
                    "Stopping study %s due to reaching the duration limit.",
                    self.study_id,
                )
                break

        logger.info(
            "Biomarkers measured for %s: %s", self.study_id, self.biomarkers_list
        )
        # Final measurements: {patient id: {biomarker name: value}}.
        self.results = {
            patient.patient_id: {
                bm["name"]: patient.biomarkers[bm["name"]]
                for bm in self.biomarkers_list
            }
            for patient in self.patient_list
        }
        self.completion()

    def display_design(self):
        """Return a message describing the proposed (not yet approved) design."""
        return f"Pending approval, Study {self.study_id} has been designed with dosage {self.dosage_id}, period {self.period}, duration {self.duration}."

    def display_start(self):
        """Return a message announcing the start and planned completion time.

        Requires ``start_date``, which ``run_study`` sets. NOTE(review): the
        run loop stops at the first cycle boundary at or after ``duration``,
        so actual completion can be later when ``duration`` is not a multiple
        of ``period``.
        """
        return f"Study {self.study_id} is now monitored and has started. It will be completed at timestamp {self.start_date + self.duration}."

    def display_results(self):
        """Return a summary of observed side effects once results exist.

        Returns a "no results yet" message when ``run_study`` has not
        produced results, including when ``self.results`` was never
        assigned. Also stores the side-effect summary on
        ``self.str_side_effect``.
        """
        # ``self.results`` is only assigned at the end of run_study, so use
        # getattr to avoid AttributeError when called before completion.
        if not getattr(self, "results", None):
            return f"Study {self.study_id} has no results yet."

        if self.observed_side_effects:
            str_side_effect = ", ".join(
                f"{side_effect} (in {len(patients)} patients)"
                for side_effect, patients in self.observed_side_effects.items()
            )
        else:
            str_side_effect = "No side effects observed"

        self.str_side_effect = str_side_effect

        return f"Study {self.study_id} is finished!\n{str_side_effect}.\nResults are ready to be sent to a statistician for analysis."
