from .patient import Patient
from .drug import Drug
from .single_arm_study import SingleArmStudy
from .study import Study
import random
import logging

logger = logging.getLogger(__name__)


class NonComparativeRandomisedStudy(Study):
    """Randomised study spreading patients over several dosages of one drug.

    Each patient is assigned to one dosage uniformly at random, and one
    ``SingleArmStudy`` is built per dosage with the patients assigned to it.
    No comparison between the arms is performed by this class.
    """

    def __init__(
        self,
        env,
        study_id,
        drug: Drug,
        patient_list: list[Patient],
        dosages: list[dict],
        duration: float,
        period: float,
        biomarkers_list: list[str],
        responsible,
    ):
        """Randomise patients across ``dosages`` and build one sub-study per dosage.

        Args:
            env: Simulation environment, forwarded to ``Study`` and every sub-study.
            study_id: Identifier of this study; each sub-study gets
                ``f"{study_id}_{dosage_id}"``.
            drug: Drug administered in every arm.
            patient_list: Patients to randomise.
            dosages: Dosage descriptors. Each one must be a mapping with a
                ``"dosage_id"`` key, and is passed unchanged to
                ``SingleArmStudy`` as its ``dosage_sequence``.
            duration: Forwarded unchanged to every sub-study.
            period: Forwarded unchanged to every sub-study.
            biomarkers_list: Forwarded unchanged to every sub-study.
            responsible: Forwarded to ``Study.__init__``.
        """
        super().__init__(env, study_id, responsible)
        self.env = env

        # Randomise each patient to one dosage id (one random.choice call per
        # patient, in patient_list order).
        self.patient_dosage_map = {
            patient.patient_id: random.choice(dosages)["dosage_id"]
            for patient in patient_list
        }

        # Build one sub-study per dosage; results are gathered later by
        # display_results().
        # NOTE(review): unlike ComparativeRandomisedStudy, no ``responsible`` /
        # ``approved`` / ``ongoing`` is passed to SingleArmStudy, and the
        # sub-studies are not scheduled with ``env.process`` here — confirm
        # that this is intended.
        self.results = {}
        self.studies = []
        for dosage in dosages:
            dosage_id = dosage["dosage_id"]
            self.studies.append(
                SingleArmStudy(
                    env=env,
                    study_id=f"{study_id}_{dosage_id}",
                    drug=drug,
                    patient_list=[
                        patient
                        for patient in patient_list
                        if self.patient_dosage_map[patient.patient_id] == dosage_id
                    ],
                    dosage_sequence=dosage,
                    duration=duration,
                    period=period,
                    biomarkers_list=biomarkers_list,
                )
            )

    def display_results(self):
        """Collect each sub-study's dosage id and results into ``self.results``.

        Returns:
            The ``self.results`` dict, mapping each sub-study id to
            ``{"dosage": <dosage_id>, "results": <sub-study results>}``.
        """
        for study in self.studies:
            self.results[study.study_id] = {
                "dosage": study.dosage_sequence["dosage_id"],
                "results": study.results,
            }
        return self.results


class ComparativeRandomisedStudy(Study):
    """Two-arm randomised study of ``drug`` versus ``competitor_drug``.

    Patients are randomised between the two arms when the study is built. The
    arms are created as ``SingleArmStudy`` instances and scheduled only when
    :meth:`run_study` runs. An arm whose dosage equals ``{0: 0}`` is labelled
    as placebo.
    """

    # Dosage value that marks an arm as placebo.
    _PLACEBO_DOSAGE = {0: 0}

    def __init__(
        self,
        env,
        study_id,
        drug: Drug,
        competitor_drug: Drug,
        patient_list: list[Patient],
        dosage: dict,
        dosage_competitor: dict,
        duration: float,
        period: float,
        period_competitor: float,
        biomarkers_list: list[str],
        responsible,
    ):
        """Store the study design and randomise patients between the two arms.

        Args:
            env: Simulation environment; :meth:`run_study` uses its ``now``,
                ``process`` and ``timeout``.
            study_id: Identifier of this study; each arm gets
                ``f"{study_id}_{label}"``, where ``label`` is the arm's drug id
                or ``"Placebo"``.
            drug: Drug under test.
            competitor_drug: Comparator drug.
            patient_list: Patients to randomise.
            dosage: Dosage sequence for ``drug``.
            dosage_competitor: Dosage sequence for ``competitor_drug``;
                ``{0: 0}`` labels the arm as placebo.
            duration: Duration forwarded to both arms.
            period: Period of the ``drug`` arm.
            period_competitor: Period of the competitor arm.
            biomarkers_list: Biomarkers forwarded to both arms.
            responsible: Forwarded to ``Study.__init__``.
        """
        # Assigned before Study.__init__, presumably because the base
        # initialiser may read them — keep this order.
        self.drug = drug
        self.competitor_drug = competitor_drug
        super().__init__(env, study_id, responsible)

        self.dosage = dosage
        self.dosage_competitor = dosage_competitor
        self.duration = duration

        self.max_period = max(period, period_competitor)
        self.period = period
        self.period_competitor = period_competitor
        self.biomarker_list = biomarkers_list
        self.approved = False
        self.patient_list = patient_list

        # Randomise each patient (uniformly) to one (dosage, drug_id) arm.
        self.patient_dosage_map = {
            patient.patient_id: random.choice(
                [
                    (dosage, self.drug.drug_id),
                    (dosage_competitor, self.competitor_drug.drug_id),
                ]
            )
            for patient in patient_list
        }

        # Arm sub-studies are created by run_study(); results filled in there.
        self.results = {}
        self.studies = []

    @classmethod
    def _is_placebo(cls, dosage):
        """Return True when ``dosage`` equals the placebo sentinel ``{0: 0}``."""
        return dosage == cls._PLACEBO_DOSAGE

    def run_study(self):
        """Simulation process: build both arms, run them, then collect results.

        Intended to be scheduled with ``env.process``. It polls once per time
        unit until no arm's ``results`` is ``None``. It then fills
        ``self.results`` and ``self.observed_side_effects``, both keyed by arm
        label (drug id or ``"Placebo"``), and calls ``self.completion()``.
        """
        logger.info("Comparative study %s started", self.study_id)

        self.start_date = self.env.now
        self.ongoing = True

        # Each arm carries its own period. Selecting the period with
        # ``drug == self.drug`` gave both arms ``self.period`` whenever the
        # same drug was used in both arms (e.g. against a placebo dosage).
        arms = (
            (self.drug, self.dosage, self.period),
            (self.competitor_drug, self.dosage_competitor, self.period_competitor),
        )
        for arm_drug, arm_dosage, arm_period in arms:
            label = "Placebo" if self._is_placebo(arm_dosage) else f"{arm_drug.drug_id}"
            self.studies.append(
                SingleArmStudy(
                    env=self.env,
                    study_id=f"{self.study_id}_{label}",
                    drug=arm_drug,
                    patient_list=[
                        patient
                        for patient in self.patient_list
                        if self.patient_dosage_map[patient.patient_id]
                        == (arm_dosage, arm_drug.drug_id)
                    ],
                    dosage_sequence=arm_dosage,
                    duration=self.duration,
                    period=arm_period,
                    biomarkers_list=self.biomarker_list,
                    responsible=None,
                    approved=True,
                    ongoing=True,
                )
            )

        for study in self.studies:
            self.env.process(study.run_study())
            logger.info(
                "Comparative sub-study %s: n=%s dosage=%s",
                study.study_id,
                len(study.patient_list),
                study.dosage_sequence,
            )

        # Poll once per time unit until every arm has published its results.
        while any(study.results is None for study in self.studies):
            yield self.env.timeout(1)

        # Arm label = sub-study id minus the "<study_id>_" prefix added above.
        # Splitting on "_" instead truncated drug ids containing underscores
        # and could make two arms collide on the same key.
        prefix_len = len(f"{self.study_id}_")
        self.observed_side_effects = {
            study.study_id[prefix_len:]: study.observed_side_effects
            for study in self.studies
        }
        self.results = {
            study.study_id[prefix_len:]: study.results for study in self.studies
        }

        logger.info("Comparative results ready for %s", self.study_id)
        logger.info("Side effects observed: %s", self.observed_side_effects)
        self.completed = True
        self.completion()

    def display_start(self):
        """Return the start message; only valid once run_study has set start_date."""
        return f"Study {self.study_id} is now monitored and has started. It will be completed at timestamp {self.start_date + self.duration}."

    def display_design(self):
        """Return a human-readable summary of the study design."""
        # Placebo is identified by the competitor dosage, as in run_study.
        # The old check compared the Drug object itself to {0: 0}, so it never
        # reported a placebo arm.
        competitor_label = (
            "Placebo"
            if self._is_placebo(self.dosage_competitor)
            else f"{self.competitor_drug.drug_id}"
        )
        return f"Pending approval, Study {self.study_id} has been designed with drugs {self.drug.drug_id} and {competitor_label} with dosages {self.dosage} and {self.dosage_competitor} respectively. Duration: {self.duration}."

    def display_results(self):
        """Return a side-effect summary, or a "no results yet" message.

        When results exist, the per-arm side-effect text is also stored on
        ``self.str_side_effect``.
        """
        if not self.results:
            return f"Study {self.study_id} has no results yet."

        lines = []
        for arm, side_effects in self.observed_side_effects.items():
            if not side_effects:
                lines.append(f"{arm} : No side effects observed.\n")
            else:
                summary = ", ".join(
                    f"{side_effect} (in {len(patients)} patients)"
                    for side_effect, patients in side_effects.items()
                )
                lines.append(f"{arm} : {summary}\n")

        str_side_effect = "".join(lines)
        self.str_side_effect = str_side_effect

        return f"Study {self.study_id} is finished!\n{str_side_effect}\nResults are ready to be sent to a statistician for analysis."
