import ast
from ..core.ressources import ComparativeRandomisedStudy


class PlanPolicy:
    """Scripted policy that replays a fixed, clock-driven cycle of joint actions.

    Every call to :meth:`select_action` returns a dict mapping each agent id in
    :attr:`AGENTS` to a ``(code, recipients)`` tuple. The entry is picked from a
    repeating cycle indexed by ``int(now) % len(cycle)``. The cycle's content
    depends on whether a ``ComparativeRandomisedStudy`` has completed. Once it
    has, the Legal Team takes part in the Investigator's phases and appears in
    the code-3 recipient lists.

    The integer codes are interpreted by the environment, not by this class.
    NOTE(review): ``(0, [])`` presumably means "no action" - confirm against
    the env.
    """

    # Agent identifiers. Every returned action dict has exactly these keys,
    # in this order.
    AGENTS = ("Investigator:1", "Sponsor:2", "Legal Team:3", "Statistician:4")

    # Sparse phase templates. Only agents with a non-(0, []) entry are
    # listed; _action() fills in the rest. Recipients are tuples so these
    # shared class-level templates cannot be mutated by callers.
    _EVERYONE = ("Sponsor:2", "Investigator:1", "Statistician:4", "Legal Team:3")
    _ALL_CODE2 = {
        "Investigator:1": (2, _EVERYONE),
        "Sponsor:2": (2, _EVERYONE),
        "Legal Team:3": (2, _EVERYONE),
        "Statistician:4": (2, _EVERYONE),
    }
    _SPO_STA_CODE1 = {"Sponsor:2": (1, ()), "Statistician:4": (1, ())}

    # Phases used while no comparative randomised study has completed.
    _INV_CODE1 = {"Investigator:1": (1, ())}
    _INV_CODE3 = {"Investigator:1": (3, ("Sponsor:2", "Statistician:4"))}
    _SPO_STA_CODE3 = {
        "Sponsor:2": (3, ("Investigator:1", "Statistician:4")),
        "Statistician:4": (3, ("Sponsor:2", "Investigator:1")),
    }

    # Phases used once a comparative randomised study has completed.
    _INV_LEG_CODE1 = {"Investigator:1": (1, ()), "Legal Team:3": (1, ())}
    _INV_LEG_CODE3 = {
        "Investigator:1": (3, ("Sponsor:2", "Statistician:4", "Legal Team:3")),
        "Legal Team:3": (3, ("Sponsor:2", "Statistician:4", "Investigator:1")),
    }
    _SPO_STA_CODE3_LEG = {
        "Sponsor:2": (3, ("Investigator:1", "Statistician:4", "Legal Team:3")),
        "Statistician:4": (3, ("Sponsor:2", "Investigator:1", "Legal Team:3")),
    }

    def __init__(self, level=5, study_age_limit=48):
        """Create the policy.

        Args:
            level: Repetition factor for the cycle's repeated phases. A full
                cycle is ``4 * level + 2`` steps long for ``level >= 0``.
            study_age_limit: If the most recently started ongoing study began
                more than this many clock units ago, every agent is given
                ``(0, [])``. Defaults to 48, the previously hard-coded value.
        """
        # Never appended to in this class; only cleared by reset().
        self.log_action_obs = []

        self.level = level
        self.study_age_limit = study_age_limit

    @classmethod
    def _action(cls, active=None):
        """Build a fresh action dict from a sparse phase template.

        Agents missing from *active* get ``(0, [])``. Recipient sequences are
        copied into new lists, so callers may mutate the result freely.
        """
        active = {} if active is None else active
        action = {}
        for agent in cls.AGENTS:
            code, recipients = active.get(agent, (0, ()))
            action[agent] = (code, list(recipients))
        return action

    def _build_cycle(self, comparative_done):
        """Return the ordered list of phase templates forming one full cycle."""
        if comparative_done:
            solo_code1 = self._INV_LEG_CODE1
            solo_code3 = self._INV_LEG_CODE3
            pair_code3 = self._SPO_STA_CODE3_LEG
        else:
            solo_code1 = self._INV_CODE1
            solo_code3 = self._INV_CODE3
            pair_code3 = self._SPO_STA_CODE3
        level = self.level
        return (
            [self._ALL_CODE2] * 2 * level
            + [solo_code1] * level
            + [solo_code3]
            + [self._SPO_STA_CODE1] * level
            + [pair_code3]
        )

    def select_action(self, state, env):
        """Return the joint action for the current simulation time step.

        Args:
            state: Not read by this policy.
            env: Object exposing ``env.simulation.env.now`` (the clock) and
                ``env.simulation.env.studies``. ``studies`` is a dict whose
                values have ``ongoing``, ``completed`` and ``start_date``
                attributes.

        Returns:
            A freshly built dict mapping each id in ``AGENTS`` to a
            ``(code, recipients)`` tuple.
        """
        sim = env.simulation.env
        studies = sim.studies

        ongoing = [s for s in studies.values() if s.ongoing and not s.completed]
        if ongoing:
            # Age of the most recently started study that is still running.
            delta_with_last_study = min(sim.now - s.start_date for s in ongoing)
            if delta_with_last_study > self.study_age_limit:
                return self._action()

        comparative_done = any(
            isinstance(s, ComparativeRandomisedStudy) and s.completed
            for s in studies.values()
        )
        cycle = self._build_cycle(comparative_done)
        # The cycle always has at least two entries, so the modulo is safe.
        return self._action(cycle[int(sim.now) % len(cycle)])

    def reset(self):
        """Clear the action/observation log."""
        self.log_action_obs = []
