import json5
import logging

from organisation.env import config

# Module-level logger; the study classes below report their results through it.
logger = logging.getLogger(__name__)


class MarketStudy:
    """Simulated market study estimating the number of potential patients.

    Constructing the study registers ``run_study`` as a process on ``env``.
    After a one-week simulated delay, the study loads the population
    statistics file and stores the estimate in ``self.market_size``.

    ``config_study`` may contain these optional keys:

    - ``"age"``: a list of age category names that restricts the market.
    - ``"sex"``: a list of sex category names that restricts the market.
    - ``"diagnosis"``: a list of diagnosis names whose prevalence further
      scales the market.

    The dict is read but never modified.
    """

    def __init__(self, env, study_id, config_study):
        self.env = env
        self.study_id = study_id
        self.config = config_study

        _ = env.process(self.run_study())

    def run_study(self):
        """Wait one week of simulated time, then compute ``self.market_size``.

        This is a generator meant to be driven by ``env.process``. It reads
        the population statistics from ``config.POPULATION_FILE`` (JSON5).
        """
        yield self.env.timeout(24 * 7)  # Wait 1 week gap before getting results
        with open(config.POPULATION_FILE, "r", encoding="utf-8") as file:
            population_stats = json5.load(file)

        self.market_size = self._compute_market_size(population_stats, self.config)
        logger.info("Market study results: %s", self.market_size)

    @staticmethod
    def _category_probability(dimension_stats, category, dimension):
        """Return the population probability of ``category`` within one dimension.

        ``dimension_stats`` holds parallel ``"categories"`` and
        ``"probabilities"`` lists.

        Raises:
            ValueError: if ``category`` is not listed. The message names the
                dimension (e.g. ``"age"``) and the accepted categories.
        """
        categories = dimension_stats["categories"]
        try:
            position = categories.index(category)
        except ValueError:
            raise ValueError(
                f"Unknown {dimension} category {category!r}; "
                f"expected one of {categories!r}"
            ) from None
        return dimension_stats["probabilities"][position]

    @classmethod
    def _compute_market_size(cls, population_stats, study_config):
        """Return the estimated number of potential patients.

        Args:
            population_stats: parsed population statistics with these keys:
                - ``"size"`` (in thousands of people).
                - ``"age"`` / ``"sex"``: each holds ``"categories"`` and
                  ``"probabilities"`` lists.
                - ``"diagnosis"``: maps ``name -> sex -> age`` to a
                  probability.
            study_config: the study configuration described on the class.
                It is NOT modified.

        Returns:
            The population size scaled by the selected age and sex shares
            and by each requested diagnosis prevalence.
        """
        # "size" is stored in thousands; convert to an absolute head count.
        market_size = population_stats["size"] * 1000

        age_stats = population_stats["age"]
        sex_stats = population_stats["sex"]

        # Absent filters mean "every category", with a total share of 1.0.
        if "age" in study_config:
            ages = study_config["age"]
            sum_proba_age = sum(
                cls._category_probability(age_stats, age, "age") for age in ages
            )
            market_size *= sum_proba_age
        else:
            ages = age_stats["categories"]
            sum_proba_age = 1.0

        if "sex" in study_config:
            sexes = study_config["sex"]
            sum_proba_sex = sum(
                cls._category_probability(sex_stats, sex, "sex") for sex in sexes
            )
            market_size *= sum_proba_sex
        else:
            sexes = sex_stats["categories"]
            sum_proba_sex = 1.0

        if "diagnosis" not in study_config:
            return market_size

        normaliser = sum_proba_sex * sum_proba_age
        if normaliser == 0:
            # The selected groups carry zero weight, so market_size is already 0;
            # normalising below would only divide by zero.
            return market_size

        # Joint (sex, age) weights inside the selected sub-population. The
        # product p_sex * p_age treats sex and age as independent.
        weighted_groups = [
            (
                sex,
                age,
                cls._category_probability(sex_stats, sex, "sex")
                * cls._category_probability(age_stats, age, "age")
                / normaliser,
            )
            for sex in sexes
            for age in ages
        ]

        for diagnosis in study_config["diagnosis"]:
            prevalence = population_stats["diagnosis"][diagnosis]
            # NOTE(review): each listed diagnosis multiplies the market.
            # This targets patients having *all* of them, treated as
            # independent; confirm this is intended.
            market_size *= sum(
                prevalence[sex][age] * weight for sex, age, weight in weighted_groups
            )
        return market_size

    def display_results(self):
        """Return a human-readable summary of ``self.market_size``."""
        return f"Market study results: {self.market_size} potential patients"


class PricingStudy:
    """Simulated pricing study that settles on a drug price after one week.

    Constructing the study registers ``run_study`` as a process on ``env``.
    ``config_study["beat_competitor"]`` selects the price tier:

    - truthy: ``drug.price["competitive"]``.
    - falsy: ``drug.price["standard"]``.

    The chosen price is stored in ``self.price``.
    """

    def __init__(self, env, drug, config_study):
        self.env, self.drug, self.config = env, drug, config_study
        env.process(self.run_study())

    def run_study(self):
        """Generator: wait one week of simulated time, then pick the price tier."""
        one_week = 7 * 24  # Wait 1 week gap before getting results
        yield self.env.timeout(one_week)

        tier = "competitive" if self.config["beat_competitor"] else "standard"
        self.price = self.drug.price[tier]

        logger.info(f"Pricing study results: {self.price}")

    def display_results(self):
        """Return a human-readable summary of the selected price."""
        chosen = self.price
        return f"Pricing study results: {chosen}"
