import re
from datetime import timedelta
from tianshou.data import Batch
from tianshou.data.batch import BatchProtocol
from typing import Union


class Conversation:
    """Ordered chat transcript made of ``{"role": ..., "content": ...}`` dicts.

    Every mutating method re-normalises the transcript via ``syntax_check``,
    so two neighbouring entries never share the same role: a new message whose
    role matches its neighbour is concatenated onto that neighbour instead.
    """

    # Roles accepted by add_component / insert_component.
    _VALID_ROLES = ("system", "user", "assistant")

    def __init__(self):
        self._conversation = []

    def get(self):
        """Return the underlying message list (the live object, not a copy)."""
        return self._conversation

    def _require_valid_role(self, role):
        """Raise ValueError unless *role* is one of the accepted roles."""
        if role not in self._VALID_ROLES:
            raise ValueError("Role must be 'system', 'user', or 'assistant'.")

    def add_component(self, role, content):
        """Append a message, then merge it into the previous one if the roles match.

        Raises:
            ValueError: if *role* is not 'system', 'user' or 'assistant'.
        """
        self._require_valid_role(role)
        self._conversation.append({"role": role, "content": content})
        self.syntax_check()

    def insert_component(self, role, content, loc):
        """Insert a message at index *loc*, then merge same-role neighbours.

        A negative *loc* counts from the end so that ``-1`` appends, ``-2``
        inserts before the last message, and so on. An index past the end is
        clamped to an append.

        Raises:
            ValueError: if *role* is not 'system', 'user' or 'assistant'.
        """
        self._require_valid_role(role)
        size = len(self._conversation)
        if loc < 0:
            # Shift by one extra so that -1 means "after the last element".
            loc += size + 1
        if loc > size:
            loc = size
        self._conversation.insert(loc, {"role": role, "content": content})
        self.syntax_check()

    def append_content(self, additional_content, pos):
        """Concatenate *additional_content* onto the content of message *pos*.

        Negative *pos* indexes from the end, as with a list.

        Raises:
            IndexError: if *pos* does not address an existing message.
        """
        size = len(self._conversation)
        if pos < 0:
            pos += size
        if not 0 <= pos < size:
            raise IndexError("Position out of range.")
        self._conversation[pos]["content"] += additional_content
        self.syntax_check()

    def syntax_check(self):
        """Collapse runs of consecutive same-role messages into the first of the run.

        Contents are joined by plain concatenation (no separator); the merged
        dict is mutated in place and the absorbed entries are removed.
        """
        turns = self._conversation
        idx = 1
        while idx < len(turns):
            current, previous = turns[idx], turns[idx - 1]
            if current["role"] != previous["role"]:
                idx += 1
                continue
            previous["content"] += current["content"]
            del turns[idx]

    def count_tokens(self, text, tokenizer):
        """Return the number of tokens *tokenizer* produces for *text*.

        Meant for spotting prompt overflow; *tokenizer* only needs an
        ``encode`` method whose result supports ``len()``.
        """
        return len(tokenizer.encode(text))

    def to_str(self):
        """Render the transcript as ``role: content`` lines joined by newlines."""
        lines = [f'{turn["role"]}: {turn["content"]}' for turn in self._conversation]
        return '\n'.join(lines)


# System prompt describing the dosing task: observation cadence, the 0-0.5 U/min
# basal-insulin action range, likely meal windows and hypoglycemia safety notes.
SYSTEM_PROMPT = ("You are a clinical specialist managing patients with Type-1 Diabetes. "
                 "Your primary objective is to maintain each patient's blood glucose levels within the range "
                 "of 70-180 mg/dL. Blood glucose levels are observed every 5 minutes, "
                 "and insulin is administered accordingly. "
                 "Insulin is dosed in U/min, ranging from 0 to 0.5, and is adjusted per 5 minutes. "

                 "[State]: We can observe the patient's blood glucose level and the insulin dose administered. "

                 "[Action]: Actionable drug is Basal insulin. Insulin reduces blood glucose levels, "
                 "but there is a time delay before its effect is observable. "
                 "No other drugs or insulin regimes are available. "
                 # Trailing space required: adjacent literals are concatenated verbatim, and without
                 # it this sentence ran straight into "[Hidden variables]".
                 "Standard total daily insulin requirement is 0.4-0.6 units/kg. The patient's weight is not provided. "

                 "[Hidden variables]: Food consumption, which increases blood glucose levels, is not directly "
                 "observable. Patients are likely to eat during the following periods: "
                 "Morning: 6:00-9:00, "
                 "Noon: 11:00-13:00, "
                 "Night: 17:00-19:00. "
                 "Occasionally, patients may consume small snacks at other times. "

                 "[Safety Considerations]: Hypoglycemia (low blood glucose levels) is particularly dangerous. "
                 "Extra caution is necessary to avoid administering excessive insulin. Insulin has a long half-life, "
                 "so the effects of previous doses may still be present. Pay attention to the accumulated insulin dose "
                 "to prevent Hypoglycemia.")
# Per-step instruction asking the model for a bare numeric dose.
LLM_INFERENCE_INSTRUCTION_PROMPT = ("[Instruction]: Please generate the insulin dosage rate in U/min "
                                    "for the next 5 minutes. Only provide a numerical value "
                                    "between 0 and 0.5 without any additional information."
                                    )
# Follow-up sent when the previous reply could not be parsed into a valid action.
LLM_INFERENCE_RETRY_PROMPT = ("Your previous answer cannot be converted to a valid action. "
                              "[Instruction]: Please provide a numerical value between 0 and 0.5 "
                              "without any additional information.")
# Instruction asking the model to condense the dosing history into a short summary.
SUMMARY_INSTRUCTION_PROMPT = (
    "[Instruction]: Please summarize information such as indications of food intake, patient's response to insulin, "
    "glucose record trend, drug dosage history, abnormal glucose signs and possible misuse of insulin. "
    "Summarize as much information as possible while keeping the answer short.")


def obs2text(batch: Union[Batch, BatchProtocol]) -> str:
    """Describe a window of glucose/insulin observations as plain text for the LLM.

    ``batch.obs`` is read column-wise: ``obs[:, 0]`` is glucose and
    ``obs[:, 1]`` is insulin. Row ``k`` of an ``n``-row window is stamped
    ``(n - 1 - k) * 5`` minutes before ``batch.info["time"]``. The last row is
    reported as the current time with its insulin left as "TBD". Rows whose
    glucose equals -1 are skipped (presumably padding -- confirm with the
    environment wrapper).

    Raises:
        AssertionError: if every row is skipped, leaving nothing to describe.
    """
    observations = batch.obs
    n_steps = observations.shape[0]
    now = batch.info["time"]
    glucose_col, insulin_col = observations[:, 0], observations[:, 1]
    last_idx = n_steps - 1

    def stamp(offset_minutes):
        # Label = "day<day-of-month> HH:MM:SS" of the shifted timestamp.
        moment = now + timedelta(minutes=offset_minutes)
        return moment.strftime(f"day{moment.day} %H:%M:%S")

    pieces = []
    for idx, (bg, dose) in enumerate(zip(glucose_col, insulin_col)):
        if bg == -1:
            continue
        offset = -(last_idx - idx) * 5
        if idx == 0:
            # The first row additionally gets an insulin-only entry at the same timestamp.
            pieces.append(f"Time:{stamp(offset)}, insulin:{insulin_col[0]:.3f}; ")
        if idx == last_idx:
            pieces.append(f"Current time: {stamp(0)}, glucose:{bg:.1f}, insulin: TBD. ")
        else:
            pieces.append(f"Time:{stamp(offset)}, glucose:{bg:.1f}, insulin:{dose:.3f}; ")
    assert pieces
    return " ".join(pieces)


def text2act(text, action_space):
    """Parse the first number in an LLM reply and accept it only if it is a valid action.

    Args:
        text: Raw model output. ``None`` or empty text yields ``None``.
        action_space: Object exposing ``contains(x)``; the candidate is passed
            as a one-element list ``[value]``.

    Returns:
        The parsed float when ``action_space.contains([value])`` is truthy,
        otherwise ``None`` (no number found, or the number is out of bounds).
    """
    if not text:
        return None
    # Only the FIRST number in the reply is considered. The second alternative
    # accepts leading-dot decimals such as ".25"; the old pattern matched only
    # the "25" part and then rejected it as out of range.
    match = re.search(r'-?(?:\d+\.?\d*|\.\d+)', text)
    if match is None:
        return None
    value = float(match.group())
    # NOTE(review): action_space is presumably a gym(nasium) Box — confirm that
    # its contains() accepts a plain list for this 1-D space.
    return value if action_space.contains([value]) else None


def get_patient_info_prompt(age, CR, CF, TDI) -> str:
    """Build the per-patient ``[Patient]: ...`` paragraph for the LLM prompt.

    Args:
        age: Patient age in years; truncated with ``int()`` for display.
        CR: Carbohydrate ratio (grams of carbohydrate covered by 1 unit);
            truncated with ``int()`` for display.
        CF: Correction factor, the expected glucose drop in mg/dL per unit of
            insulin; shown with one decimal.
        TDI: Total daily insulin requirement in units per 24 hours; shown with
            one decimal.

    Returns:
        The formatted patient description string.
    """
    age = int(age)
    CR = int(CR)
    # The "meaning" clause restates CF itself. It previously printed 1700/TDI
    # (a rule-of-thumb estimate), which generally differs from the CF quoted in
    # the same sentence, so the prompt contradicted itself. It also raised
    # ZeroDivisionError when TDI == 0.
    prompt = (f"[Patient]: You are treating a {age}-year-old patient with a Total Daily Insulin (TDI) requirement "
              f"of {TDI:.1f} units over 24 hours. "

              f"The patient's Carbohydrate Ratio (CR) is {CR}, "
              f"meaning 1 unit of insulin covers {CR} grams of carbohydrate. "
              "A higher CR indicates less insulin is needed for a given amount of carbohydrates, and vice versa. "

              f"The Correction Factor (CF) for this patient is {CF:.1f}, "
              f"meaning 1 unit of insulin is expected to lower blood glucose by {CF:.1f} mg/dL.")
    return prompt