# import logging
# import re
# import inspect
# import hydra
# import os

# def init_client(cfg):
#     global client
#     if cfg.get("model", None): # for compatibility
#         model: str = cfg.get("model")
#         temperature: float = cfg.get("temperature", 1.0)
#         if model.startswith("gpt"):
#             from utils.llm_client.openai import OpenAIClient
#             client = OpenAIClient(model, temperature)
#         elif cfg.model.startswith("GLM"):
#             from utils.llm_client.zhipuai import ZhipuAIClient
#             client = ZhipuAIClient(model, temperature)
#         else: # fall back to Llama API
#             from utils.llm_client.llama_api import LlamaAPIClient
#             client = LlamaAPIClient(model, temperature)
#     else:
#         client = hydra.utils.instantiate(cfg.llm_client)
#     return client
    

# def file_to_string(filename):
#     with open(filename, 'r') as file:
#         return file.read()
    
    
# def print_hyperlink(path, text=None):
#     """Print hyperlink to file or folder for convenient navigation"""
#     # Format: \033]8;;file:///path/to/file\033\\text\033]8;;\033\\
#     text = text or path
#     full_path = f"file://{os.path.abspath(path)}"
#     return f"\033]8;;{full_path}\033\\{text}\033]8;;\033\\"


# def filter_traceback(s):
#     lines = s.split('\n')
#     filtered_lines = []
#     for i, line in enumerate(lines):
#         if line.startswith('Traceback'):
#             for j in range(i, len(lines)):
#                 if "Set the environment variable HYDRA_FULL_ERROR=1" in lines[j]:
#                     break
#                 filtered_lines.append(lines[j])
#             return '\n'.join(filtered_lines)
#     return ''  # Return an empty string if no Traceback is found

# def block_until_running(stdout_filepath, log_status=False, iter_num=-1, response_id=-1):
#     # Ensure that the evaluation has started before moving on
#     while True:
#         log = file_to_string(stdout_filepath)
#         if  len(log) > 0:
#             if log_status and "Traceback" in log:
#                 logging.warning(
#                     f"Iteration {iter_num}: Code Run {response_id} execution error! (see {print_hyperlink(stdout_filepath, 'stdout')}))"
#                 )
#             else:
#                 logging.info(
#                     f"Iteration {iter_num}: Code Run {response_id} successful! (see {print_hyperlink(stdout_filepath, 'stdout')})"
#                 )
#             break


# def extract_description(response: str) -> tuple[str, str]:
#     # Regex patterns to extract code description enclosed in GPT response, it starts with ‘<start>’ and ends with ‘<end>’
#     pattern_desc = [r'<start>(.*?)```python', r'<start>(.*?)<end>']
#     for pattern in pattern_desc:
#         desc_string = re.search(pattern, response, re.DOTALL)
#         desc_string = desc_string.group(1).strip() if desc_string is not None else None
#         if desc_string is not None:
#             break
#     return desc_string


# def extract_code_from_generator(content):
#     """Extract code from the response of the code generator."""
#     pattern_code = r'```python(.*?)```'
#     code_string = re.search(pattern_code, content, re.DOTALL)
#     code_string = code_string.group(1).strip() if code_string is not None else None
#     if code_string is None:
#         # Find the line that starts with "def" and the line that starts with "return", and extract the code in between
#         lines = content.split('\n')
#         start = None
#         end = None
#         for i, line in enumerate(lines):
#             if line.startswith('def'):
#                 start = i
#             if 'return' in line:
#                 end = i
#                 break
#         if start is not None and end is not None:
#             code_string = '\n'.join(lines[start:end+1])
    
#     if code_string is None:
#         return None
#     # Add import statements if not present
#     if "np" in code_string:
#         code_string = "import numpy as np\n" + code_string
#     if "torch" in code_string:
#         code_string = "import torch\n" + code_string
#     return code_string


# def filter_code(code_string):
#     """Remove lines containing signature and import statements."""
#     lines = code_string.split('\n')
#     filtered_lines = []
#     for line in lines:
#         if line.startswith('def'):
#             continue
#         elif line.startswith('import'):
#             continue
#         elif line.startswith('from'):
#             continue
#         elif line.startswith('return'):
#             filtered_lines.append(line)
#             break
#         else:
#             filtered_lines.append(line)
#     code_string = '\n'.join(filtered_lines)
#     return code_string


# def get_heuristic_name(module, possible_names: list[str]):
#     for func_name in possible_names:
#         if hasattr(module, func_name):
#             if inspect.isfunction(getattr(module, func_name)):
#                 return func_name

import inspect
import logging
import os
import re
import time
from typing import Optional

import hydra

def init_client(cfg):
    """Build the LLM client described by ``cfg`` and store it in the module-level ``client``.

    When ``cfg`` has a truthy ``model`` entry (legacy configuration style), the
    client class is chosen from the model-name prefix:

    * ``agicto*`` -> ``AgictoClient``
    * ``gpt*``    -> ``OpenAIClient``
    * ``GLM*``    -> ``ZhipuAIClient``
    * otherwise   -> ``LlamaAPIClient``

    Otherwise the client is instantiated by Hydra from ``cfg.llm_client``.

    Args:
        cfg: Configuration object exposing ``get(key, default)``; in the
            non-legacy path it must also expose an ``llm_client`` attribute.
            Recognised legacy keys: ``model``, ``temperature`` (default 1.0),
            ``api_key`` (default ``None``) and ``base_url``.

    Returns:
        The constructed client (also assigned to the global ``client``).
    """
    global client
    model = cfg.get("model", None)
    if not model:
        client = hydra.utils.instantiate(cfg.llm_client)
        return client

    # Legacy configuration path, kept for compatibility.
    temperature: float = cfg.get("temperature", 1.0)
    # The credential is looked up under the "api_key" config key. Secrets must
    # come from configuration, never be embedded in source code.
    api_key = cfg.get("api_key", None)
    # If the configuration does not specify base_url, fall back to the Agicto endpoint.
    # NOTE(review): this default is passed to every provider branch below,
    # not only Agicto -- confirm that is intended.
    base_url: str = cfg.get("base_url", "https://api.agicto.cn/v1")

    # Provider modules are imported inside the selected branch only.
    if model.startswith("agicto"):
        from utils.llm_client.agicto import AgictoClient
        client = AgictoClient(model, temperature, api_key=api_key, base_url=base_url)
    elif model.startswith("gpt"):
        from utils.llm_client.openai import OpenAIClient
        client = OpenAIClient(model, temperature, api_key=api_key, base_url=base_url)
    elif model.startswith("GLM"):
        from utils.llm_client.zhipuai import ZhipuAIClient
        client = ZhipuAIClient(model, temperature, api_key=api_key, base_url=base_url)
    else:  # fall back to Llama API
        from utils.llm_client.llama_api import LlamaAPIClient
        client = LlamaAPIClient(model, temperature, api_key=api_key, base_url=base_url)
    return client

    

def file_to_string(filename):
    """Return the entire content of ``filename`` decoded as UTF-8 text."""
    with open(filename, mode='r', encoding='utf-8') as stream:
        contents = stream.read()
    return contents

def filter_traceback(s):
    """Extract the first traceback found in ``s``.

    The result starts at the first line that begins with ``Traceback`` and
    runs up to (but excluding) the first following line that contains the
    Hydra "HYDRA_FULL_ERROR" hint, or to the end of the text if the hint is
    absent. An empty string is returned when no traceback line exists.
    """
    rows = s.split('\n')
    first = next(
        (idx for idx, row in enumerate(rows) if row.startswith('Traceback')),
        None,
    )
    if first is None:
        return ''  # no traceback present

    kept = []
    for row in rows[first:]:
        if "Set the environment variable HYDRA_FULL_ERROR=1" in row:
            break
        kept.append(row)
    return '\n'.join(kept)

# def block_until_running(stdout_filepath, log_status=False, iter_num=-1, response_id=-1):
#     # Ensure that the evaluation has started before moving on
#     while True:
#         log = file_to_string(stdout_filepath)
#         if  len(log) > 0:
#             if log_status and "Traceback" in log:
#                 logging.info(f"Iteration {iter_num}: Code Run {response_id} execution error!")
#             else:
#                 logging.info(f"Iteration {iter_num}: Code Run {response_id} successful!")
#             break

def block_until_running(stdout_filepath, iter_num, response_id, log_status=False, poll_interval=0.1):
    """Block until the stdout file of a code run is non-empty, then log its status.

    The file is re-read until it contains at least one character. Between
    reads the function sleeps for ``poll_interval`` seconds so that waiting
    does not spin a CPU core or hammer the filesystem.

    Args:
        stdout_filepath: Path of the file that receives the run's output.
        iter_num: Iteration number, used only in the log message.
        response_id: Identifier of the code run, used only in the log message.
        log_status: When true and the output contains ``"Traceback"``, an
            "execution error" message is logged; in every other case a
            "successful" message is logged.
        poll_interval: Seconds to sleep between reads while the file is empty.
    """
    # Ensure that the evaluation has started before moving on
    while True:
        log = file_to_string(stdout_filepath)
        if log:
            break
        time.sleep(poll_interval)

    if log_status and "Traceback" in log:
        logging.info(f"Iteration {iter_num}: Code Run {response_id} execution error!")
    else:
        logging.info(f"Iteration {iter_num}: Code Run {response_id} successful!")

def extract_description(response: str) -> Optional[str]:
    """Extract the textual description of the code from an LLM response.

    The description is the text after ``<start>``, terminated either by the
    opening of a Python code fence or by ``<end>``; the code-fence pattern is
    tried first.

    Args:
        response: Raw response text.

    Returns:
        The description with surrounding whitespace stripped, or ``None``
        when neither pattern matches.
    """
    # Patterns are tried in order; the first one that matches wins.
    for pattern in (r'<start>(.*?)```python', r'<start>(.*?)<end>'):
        match = re.search(pattern, response, re.DOTALL)
        if match is not None:
            return match.group(1).strip()
    return None


# def extract_code_from_generator(content):
#     """Extract code from the response of the code generator."""
#     pattern_code = r'```python(.*?)```'
#     code_string = re.search(pattern_code, content, re.DOTALL)
#     code_string = code_string.group(1).strip() if code_string is not None else None
#     if code_string is None:
#         # Find the line that starts with "def" and the line that starts with "return", and extract the code in between
#         lines = content.split('\n')
#         start = None
#         end = None
#         for i, line in enumerate(lines):
#             if line.startswith('def'):
#                 start = i
#             if 'return' in line:
#                 end = i
#                 break
#         if start is not None and end is not None:
#             code_string = '\n'.join(lines[start:end+1])
    
#     if code_string is None:
#         return None
#     # Add import statements if not present
#     if "np" in code_string:
#         code_string = "import numpy as np\n" + code_string
#     if "torch" in code_string:
#         code_string = "import torch\n" + code_string
#     return code_string

# Each rule pairs a pattern that detects use of a name with the import
# statement that provides it. The alias patterns require attribute access
# (e.g. ``np.``) that is not preceded by an identifier character or a dot, so
# ordinary text such as "inputs" or "False" does not trigger an import.
# Rule order determines the order of the emitted imports.
_AUTO_IMPORT_RULES = (
    (re.compile(r'(?<![\w.])np\.'), "import numpy as np"),
    (re.compile(r'(?<![\w.])torch\.'), "import torch"),
    # Categorical is used without being imported.
    (re.compile(r'\bCategorical\b'), "from torch.distributions import Categorical"),
    # NOTE(review): a bare ``distance_matrix(...)`` call is not bound by
    # ``import scipy.spatial``; kept as in the original heuristic.
    (re.compile(r'\b(?:scipy|distance_matrix)\b'), "import scipy.spatial"),
    (re.compile(r'(?<![\w.])nn\.'), "import torch.nn as nn"),
    (re.compile(r'(?<![\w.])F\.'), "import torch.nn.functional as F"),
    (
        re.compile(r'\b(?:multi_head_attention|reshape_by_heads)\b'),
        "from TSPModel import multi_head_attention, reshape_by_heads",
    ),
)


def _find_unfenced_function(content):
    """Return the lines from the last ``def`` line up to the first line containing ``return``, or None."""
    lines = content.split('\n')
    start = None
    for index, line in enumerate(lines):
        if line.startswith('def'):
            start = index
        if 'return' in line:
            # Scanning stops at the first "return"; without a preceding
            # ``def`` line there is nothing to extract.
            if start is None:
                return None
            return '\n'.join(lines[start:index + 1])
    return None


def _missing_imports(code_string):
    """Return the import statements that ``code_string`` uses but does not already contain as a line."""
    existing = {line.strip() for line in code_string.split('\n')}
    return [
        statement
        for pattern, statement in _AUTO_IMPORT_RULES
        if statement not in existing and pattern.search(code_string)
    ]


def extract_code_from_generator(content):
    """Extract code from the response of the code generator.

    The first fenced Python code block in ``content`` is used when present.
    Otherwise the text from the last line starting with ``def`` up to the
    first line containing ``return`` is taken. Import statements for commonly
    used names (numpy, torch, scipy, TSPModel helpers) are then prepended
    when the code uses them and does not already contain that exact import
    line; each import is added at most once.

    Args:
        content: Raw response text.

    Returns:
        The extracted code, preceded by any added imports and a blank line,
        or ``None`` when no code could be located.
    """
    fenced = re.search(r'```python(.*?)```', content, re.DOTALL)
    if fenced is not None:
        code_string = fenced.group(1).strip()
    else:
        code_string = _find_unfenced_function(content)

    if code_string is None:
        return None

    # Prepend all required import statements in one go.
    imports_to_add = _missing_imports(code_string)
    if imports_to_add:
        code_string = "\n".join(imports_to_add) + "\n\n" + code_string

    return code_string


def filter_code(code_string):
    """Drop signature and import lines, keeping the rest up to the first top-level return.

    Lines that start (at column 0) with ``def``, ``import`` or ``from`` are
    discarded. Every other line is kept; processing stops right after the
    first line that starts with ``return``, which is itself kept.
    """
    dropped_prefixes = ('def', 'import', 'from')
    kept = []
    for row in code_string.split('\n'):
        if row.startswith(dropped_prefixes):
            continue
        kept.append(row)
        if row.startswith('return'):
            break
    return '\n'.join(kept)


def get_heuristic_name(module, possible_names: list[str]):
    """Return the first name in ``possible_names`` that is a plain function on ``module``.

    Returns ``None`` when none of the names refers to a function attribute.
    """
    for candidate in possible_names:
        attribute = getattr(module, candidate, None)
        if inspect.isfunction(attribute):
            return candidate
    return None

def get_pick_move_name(module, possible_func_names):
    """
    Return the name under which ``module`` exposes the pick_move function.

    The first entry of ``possible_func_names`` that exists as an attribute of
    ``module`` is returned; when none exists, the first entry is returned.
    """
    found = next(
        (name for name in possible_func_names if hasattr(module, name)),
        None,
    )
    if found is None:
        return possible_func_names[0]  # fall back to the first name
    return found

            
def print_hyperlink(path, text=None):
    """Return a terminal hyperlink escape sequence pointing at a file or folder.

    The link target is the absolute form of ``path`` prefixed with
    ``file://``; the visible label is ``text``, or ``path`` when ``text`` is
    empty or omitted. The string is returned, not printed.
    """
    # Format: \033]8;;file:///path/to/file\033\\text\033]8;;\033\\
    if not text:
        text = path
    full_path = "file://" + os.path.abspath(path)
    return f"\033]8;;{full_path}\033\\{text}\033]8;;\033\\"