import re, json
import httpx
import time, datetime, pytz
import openai
from openai import OpenAI
from zhipuai import ZhipuAI
from tqdm import tqdm

if __name__ == "__main__":
    from key import MODEL_MAP
else:
    from utils.key import MODEL_MAP

class AIClient:
    """Chat-completion client with retry handling and token accounting.

    The connection settings are looked up in ``MODEL_MAP`` by model name
    (falling back to the ``"default"`` entry).  A ``"zhipu"`` client type uses
    ``ZhipuAI``; every other client type uses ``OpenAI`` with the configured
    ``base_url`` and ``key``.
    """

    def __init__(self, model="gpt-3.5-turbo", max_tokens=None):
        """Create the underlying SDK client for *model*.

        Args:
            model: Model name; also the key used to look up ``MODEL_MAP``.
            max_tokens: Stored on the instance.
                NOTE(review): not forwarded to any request in this class.
        """
        self.model = model
        self.max_tokens = max_tokens

        # Pick the configuration for this model, falling back to "default".
        if self.model in MODEL_MAP:
            self.conf = MODEL_MAP[self.model]
        else:
            self.conf = MODEL_MAP["default"]

        # Initialise the SDK client according to the configured client type.
        if self.conf["client"] == "zhipu":
            self.client = ZhipuAI(api_key=self.conf["key"])
        else:
            # "local" and all remaining client types are built identically.
            self.client = OpenAI(
                base_url=self.conf["base_url"],
                api_key=self.conf["key"],
            )

        # Running totals updated by get_response().
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.model_calls = 0

    def get_response(self, messages, max_retries=5, retry_delay=3, model=None, temperature=0):
        """Send *messages* to the model and return the cleaned reply text.

        Args:
            messages: List of ``{"role": ..., "content": ...}`` dicts.  The
                caller's list and dicts are never modified.
            max_retries: Maximum number of attempts.
            retry_delay: Seconds to sleep after a failed attempt.
            model: Optional model name overriding ``self.model`` for this call.
                The connection settings still come from ``self.conf``.
            temperature: Sampling temperature (not sent on the llama-3.1
                completions path).

        Returns:
            The stripped reply text, or ``None`` when every attempt yielded a
            reply whose content was ``None`` (or ``max_retries`` is 0).

        Raises:
            RuntimeError: when the last attempt fails with an exception.
        """
        if model is None:
            model = self.model

        model_key = model.lower()
        is_baichuan = 'baichuan' in model_key
        is_qwen3 = 'qwen3' in model_key
        is_llama31 = 'llama-3.1' in model_key

        if is_baichuan or is_qwen3:
            # Work on a copy: editing the caller's first message in place made
            # the suffix pile up whenever the same list was sent again.
            messages = list(messages)
            first_message = dict(messages[0])
            if is_baichuan:
                first_message['role'] = 'user'
            if is_qwen3:
                # NOTE(review): this appends a literal backslash + "no_think";
                # confirm this is the switch the served Qwen3 model expects.
                first_message['content'] += "\\no_think"
            messages[0] = first_message

        for attempt in range(max_retries):
            try:
                # A "local" endpoint is addressed by the first model id it lists.
                if self.conf["client"] != "local":
                    request_model = model
                else:
                    request_model = self.client.models.list().data[0].id

                if is_baichuan:
                    completion = self.client.chat.completions.create(
                        model=request_model,
                        messages=messages,
                        temperature=temperature,
                        extra_body={"chat_template_kwargs": {"thinking_mode": "off"}}  # Close thinking mode for BaichuanM2
                    )
                elif is_llama31:
                    # Plain completions endpoint: one prompt per message content.
                    prompts = [msg['content'] for msg in messages]
                    completion = self.client.completions.create(
                        model=request_model,
                        prompt=prompts,
                    )
                else:
                    completion = self.client.chat.completions.create(
                        model=request_model,
                        messages=messages,
                        temperature=temperature,
                    )

                self.model_calls += 1
                # `usage` may be present but None; only count real usage data.
                usage = getattr(completion, 'usage', None)
                if usage is not None:
                    self.prompt_tokens += usage.prompt_tokens
                    self.completion_tokens += usage.completion_tokens

                if is_llama31:
                    content = completion.choices[0].text
                else:
                    content = completion.choices[0].message.content

                if content is not None:
                    if "huatuo" in model_key:
                        content = content.split("## Final Response")[-1].strip()  # Get final response for Huatuo
                    elif is_baichuan or is_qwen3 or 'gpt-5' in model_key:
                        # Keep only the text after any <think>...</think> section.
                        content = content.split("</think>")[-1].strip().split("<think>")[-1]
                    return content.strip()
                # content is None: fall through and try again.

            except (httpx.RequestError, RuntimeError) as e:
                print(f"Attempt {attempt + 1}/{max_retries} failed: {str(e)[:100]}")
                if attempt < max_retries - 1:
                    time.sleep(retry_delay)
                else:
                    raise RuntimeError("Max retries reached. Unable to get response from API.") from e

            except openai.APITimeoutError as e:
                print(f"API Timeout Error on attempt {attempt + 1}/{max_retries}: {str(e)}")
                if attempt < max_retries - 1:
                    time.sleep(retry_delay)
                else:
                    errorMessage(f"API Timeout Error {str(e)} in the request: {messages}\n")
                    raise RuntimeError("Max retries reached due to timeout. Unable to get response from API.") from e

            except Exception as e:
                print(f"Unexpected error on attempt {attempt + 1}/{max_retries}: {str(e)}")
                if attempt < max_retries - 1:
                    time.sleep(retry_delay)
                else:
                    raise RuntimeError(f"Unexpected error: {str(e)}") from e

        # No attempt produced usable content.
        return None

    def get_tokens(self):
        """Return ``(prompt_tokens, completion_tokens, total_tokens)`` so far."""
        return self.prompt_tokens, self.completion_tokens, self.prompt_tokens + self.completion_tokens

def extract_json_from_text(text):
    """Parse JSON from *text*, unwrapping a Markdown code fence if present.

    If *text* contains a fenced block (three backticks, optionally tagged
    ``json``), only the content of the first such block is parsed; otherwise
    the whole text is parsed.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: if parsing fails for any reason.  The message contains the
            text exactly as it was passed in.
    """
    # Parse a separate variable so the error message can still show the
    # untouched input rather than the extracted fragment.
    candidate = text
    try:
        match = re.search(r"```(?:json)?\s*(.*?)```", candidate, re.DOTALL)
        if match:
            candidate = match.group(1).strip()
        return json.loads(candidate)
    except Exception as e:
        raise ValueError(f"Failed to parse JSON: {e}\nOriginal text: {text}") from e

def select_basic_info(atom_list, client=None, max_retries=3):
    """Ask the model to pick the chief complaint out of *atom_list*.

    Args:
        atom_list: Candidate items; the reply must equal one of them exactly
            (after stripping surrounding whitespace) to be accepted.
        client: Object exposing ``get_response(messages)``.  When ``None`` an
            ``AIClient`` for ``"gpt-4o-mini"`` is created.
        max_retries: Number of times the model is asked before giving up.

    Returns:
        The selected item, or ``None`` if no attempt produced a reply that is
        a member of *atom_list*.
    """
    messages = [
        {
            "role": "system",
            "content": "You are tasked with selecting the chief complaint of a patient from a list of informtions. "
        },
        {
            "role": "user",
            "content": f"Here is the informtion list: {atom_list}. Please return one of them without any additional text."
        }
    ]

    if client is None:
        client = AIClient(model="gpt-4o-mini")

    for _ in range(max_retries):
        response = client.get_response(messages)
        if response is None:
            # The client can come back without content; treat that as a
            # failed attempt instead of crashing on `.strip()`.
            continue
        response = response.strip()
        if response in atom_list:
            return response
    return None


def atomize(info, client):
    """Split a patient profile into atomic facts using the model.

    The model is instructed to answer with a ';'-separated list.  Each piece is
    stripped of surrounding whitespace and of leading/trailing periods.

    Args:
        info: The patient profile text inserted into the prompt.
        client: Object exposing ``get_response(messages)``.

    Returns:
        A list of non-empty fact strings; an empty list when the client
        returns no content.
    """
    messages = [
        {
            "role": "system",
            "content": (
                "You will extract **atomic clinical facts** from a patient profile.\n\n"
                "Your task is to break down the information into small, non-overlapping units that each express a clear, self-contained fact.\n\n"
                "**Guidelines:**\n"
                "- Do NOT repeat information.\n"
                "- Each atomic unit should be **independently understandable** (no vague 'today', 'a day', 'since this morning' without a clear anchor).\n"
                "- Always keep **time/frequency expressions attached to the event** they describe (e.g., 'two episodes of red urine today' is valid; 'today' alone is not).\n"
                "- Omit generic phrases like 'presents to clinic' or irrelevant fillers.\n"
                "- The output should be a list separated by semicolons ';'.\n\n"
                "**Example:**\n"
                "Input: 'a white man of 22-year-old with painful, recurrent rash'\n"
                "Output: '22-year-old; male; white; rash identified; rash is painful; rash is recurrent'"
            )
        },
        {
            "role": "user",
            "content": f"The original full patient profile is: {info} Please extract atomic facts, separated by ';'."
        }
    ]

    response = client.get_response(messages)
    if response is None:
        # No content came back, so there is nothing to split.
        return []

    cleaned = (item.strip().strip('.') for item in response.split(';'))
    # A trailing or doubled ';' yields empty pieces, which are not facts.
    return [fact for fact in cleaned if fact]


from colorama import Fore, Style
if __name__ == "__main__":
    from config_loader import ConfigLoader
else:
    from utils.config_loader import ConfigLoader
# Read the "evaluation" section of config.cfg once at import time.
config = ConfigLoader('config.cfg')
evaluation = config.get_section('evaluation')
# Debug printing is enabled only when a "debug" entry exists and equals "True".
debug_mode = "debug" in evaluation and evaluation["debug"] == "True"

def systemMessage(info):
    """Echo a system-level message in yellow when debug mode is on; return it unchanged."""
    if not debug_mode:
        return info
    colored = Fore.YELLOW + info + Style.RESET_ALL
    tqdm.write(colored)
    return info

def doctorMessage(info):
    """Echo a doctor-level message in green when debug mode is on; return it unchanged."""
    if not debug_mode:
        return info
    colored = Fore.GREEN + info + Style.RESET_ALL
    tqdm.write(colored)
    return info

def patientMessage(info):
    """Echo a patient-level message in blue when debug mode is on; return it unchanged."""
    if not debug_mode:
        return info
    colored = Fore.BLUE + info + Style.RESET_ALL
    tqdm.write(colored)
    return info

def otherMessage(info):
    """Echo any other message without colouring when debug mode is on; return it unchanged."""
    if not debug_mode:
        return info
    tqdm.write(info)
    return info

def errorMessage(info, log_file="log/error.log"):
    """Print *info* in red and append it, timestamped, to *log_file*.

    The message is always printed (independent of debug mode).  Writing the
    log is best-effort: any failure is reported on the console and otherwise
    ignored.

    Args:
        info: The error text.
        log_file: Path of the log file to append to.  Its parent directory is
            created when missing.

    Returns:
        *info*, unchanged.
    """
    import os  # local import: only needed to create the log directory

    # Timestamps are rendered in the Asia/Shanghai time zone.
    timestamp = datetime.datetime.now(pytz.timezone('Asia/Shanghai')).strftime("%Y-%m-%d %H:%M:%S")
    log_entry = f"[{timestamp}] {info}\n"

    tqdm.write(Fore.RED + info + Style.RESET_ALL)

    try:
        # Without the directory every append would fail and nothing would
        # ever be persisted.
        log_dir = os.path.dirname(log_file)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        with open(log_file, "a", encoding="utf-8") as f:
            f.write(log_entry)
    except Exception as e:
        tqdm.write(Fore.RED + f"[LOG ERROR] Failed to write log: {e}" + Style.RESET_ALL)

    return info

if __name__ == "__main__":
    # Manual smoke test: send a single user message and print the reply.
    test_model = "gpt-3.5-turbo"
    messages = [{"role": "user", "content": "Can you tell me a joke?"}]

    client = AIClient(model=test_model)
    response = client.get_response(messages)

    print(f"Response from {test_model}:\n" + response)