import re
import os
import heapq
import logging
import json
from .api import ChatAPI

def extract_proxies(markdown_text):
    '''
    Extract proxy descriptions and code snippets from an LLM markdown response.

    The text is split at every ``### proxy <n>`` heading. A section is kept
    only when it contains both a ``**Description**:`` label (the
    ``**Description:**`` variant is also accepted) and a fenced ```python
    code block. Headings, the label and the fence tag are matched
    case-insensitively, because LLM output is not consistent about
    capitalization. For example, the prompt built in ``get_new_pop`` lists
    existing proxies with a lowercase ``**description**:``.

    Args:
        markdown_text: raw response text. ``None`` or ``''`` yields ``[]``.

    Returns:
        A list of dicts ``{'description': str, 'code': str, 'score': None}``,
        in the order the sections appear in the text.
    '''
    if not markdown_text:
        return []

    flags = re.IGNORECASE | re.DOTALL
    proxies = []

    # The zero-width lookahead keeps each "### proxy N" heading at the start
    # of its own chunk instead of consuming it.
    blocks = re.split(r'(?=### proxy \d+)', markdown_text, flags=re.IGNORECASE)

    for block in blocks:
        # Skip empty chunks and any preamble before the first proxy heading.
        if not block.strip() or not block.startswith('###'):
            continue

        # The description runs from the label up to the python fence, or to
        # the end of the chunk when no fence follows.
        desc_match = re.search(
            r'\*\*Description(?:\*\*:|:\*\*)\s*(.*?)(?=\s*```python|$)',
            block,
            flags,
        )

        code_match = re.search(
            r'```python\s*(.*?)\s*```',
            block,
            flags,
        )

        if desc_match and code_match:
            proxies.append({
                'description': desc_match.group(1).strip(),
                'code': code_match.group(1).strip(),
                'score': None
            })

    return proxies


def read_prompts(folder_path, args):
    """Load the prompt templates for ``args.benchmark``.

    Every regular file in ``<folder_path>/<args.benchmark>`` is read as UTF-8
    text. ``initialization.md``, if present, is placed at index 0. The
    remaining files follow in sorted filename order, so ``prompts[i]`` refers
    to the same file on every machine. (``os.listdir`` order is arbitrary.)

    Args:
        folder_path: root directory holding one sub-folder per benchmark.
        args: object with a ``benchmark`` attribute; it must be one of
            ``'nasbench201'``, ``'transbench101'`` or ``'nasbench101'``.

    Returns:
        A list of prompt strings.

    Raises:
        ValueError: if ``args.benchmark`` is not a known benchmark.
    """
    known_benchmarks = ('nasbench201', 'transbench101', 'nasbench101')
    if args.benchmark not in known_benchmarks:
        # Raise rather than ``assert`` so the check survives ``python -O``.
        raise ValueError(f'Unknown benchmark: {args.benchmark!r}')
    folder_path = os.path.join(folder_path, args.benchmark)

    names = sorted(
        name for name in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, name))
    )
    # Stable sort: False < True, so initialization.md moves to the front and
    # the other files keep their sorted order.
    names.sort(key=lambda name: name != 'initialization.md')

    prompts = []
    for name in names:
        with open(os.path.join(folder_path, name), 'r', encoding='utf-8') as f:
            prompts.append(f.read())

    return prompts


def get_new_pop(cur_pop, api : ChatAPI, action, prompts, args):
    """Ask the configured LLM for new proxies and parse them out of its reply.

    Args:
        cur_pop: current population; each entry needs ``'description'`` and
            ``'code'`` keys.
        api: client exposing ``call_<provider>(messages, model=..., temperature=...)``.
        action: index into ``prompts``. For action 0 the prompt is sent
            as-is. Otherwise the current population is appended to it as
            markdown ``### proxy N`` sections.
        prompts: list of prompt templates.
        args: needs ``llm``, ``model`` and ``temperature`` attributes.

    Returns:
        The list produced by ``extract_proxies`` on the LLM response.

    Raises:
        ValueError: if ``args.llm`` is not a supported provider.
    """
    prompt = prompts[action]
    if action != 0:
        # Collect the pieces and join once instead of repeated string +=.
        sections = [prompt, '\n###Existing proxies:\n']
        for i, individual in enumerate(cur_pop):
            sections.append(
                f"\n### proxy {i + 1}\n\n**description**:{individual['description']}"
                f"\n\n```python\n{individual['code']}\n```\n\n---\n"
            )
        prompt = ''.join(sections)

    # Map each provider to a method *name*. The method is only looked up for
    # the selected provider, so the api object need not implement every one.
    method_names = {
        'deepseek': 'call_deepseek',
        'chatgpt': 'call_gpt',
        'claude': 'call_claude',
        'grok': 'call_grok',
        'gemini': 'call_gemini',
        'llama': 'call_llama',
        'qwen': 'call_qwen',
    }
    if args.llm not in method_names:
        # Previously this fell through and failed later with UnboundLocalError.
        raise ValueError(f'Unknown llm: {args.llm!r}')

    call = getattr(api, method_names[args.llm])
    response = call([{'role' : 'user', 'content' : prompt}], model=args.model, temperature=args.temperature)

    new_pop = extract_proxies(response)
    return new_pop

def topk(pop, new_pop, k, decay):
    """Decay the existing population's scores, then keep the ``k`` strongest.

    Every dict in ``pop`` has its ``'score'`` multiplied by ``decay`` in
    place, so the caller's objects are modified. Entries of ``new_pop`` are
    not decayed. Candidates from both lists are ranked by the absolute value
    of ``'score'``, largest first.
    """
    def _strength(candidate):
        return abs(candidate['score'])

    for member in pop:
        member['score'] *= decay

    candidates = pop + new_pop
    return heapq.nlargest(k, candidates, key=_strength)

def get_logger(file_path):
    """Return the shared ``'swap'`` logger, logging to ``file_path`` and the console.

    ``logging.getLogger('swap')`` returns the same object on every call. The
    original code therefore stacked another file handler and console handler
    each time it was called, which printed every message several times and
    leaked file handles. Handlers are now only added when an equivalent one
    is not already attached. A new ``file_path`` still gets its own file
    handler.

    Args:
        file_path: path of the log file. It is opened in append mode by
            ``logging.FileHandler``.

    Returns:
        The configured ``logging.Logger`` at level INFO.
    """
    # [!] Since tensorboardX use default logger (e.g. logging.info()), we should use custom logger
    logger = logging.getLogger('swap')
    log_format = '%(asctime)s | %(message)s'
    formatter = logging.Formatter(log_format, datefmt='%m/%d %I:%M:%S %p')

    # FileHandler stores the absolute path in ``baseFilename``.
    target = os.path.abspath(file_path)
    has_file_handler = any(
        isinstance(h, logging.FileHandler) and h.baseFilename == target
        for h in logger.handlers
    )
    if not has_file_handler:
        file_handler = logging.FileHandler(file_path)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    # FileHandler subclasses StreamHandler, so match the exact type here.
    has_stream_handler = any(type(h) is logging.StreamHandler for h in logger.handlers)
    if not has_stream_handler:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    logger.setLevel(logging.INFO)

    return logger

def save_pop(pop, folderpath, episode):
    """Write ``pop`` to ``<folderpath>/episode<episode>.json`` as 4-space-indented JSON."""
    target = os.path.join(folderpath, f'episode{episode}.json')
    with open(target, 'w') as out_file:
        json.dump(pop, out_file, indent=4)

if __name__ == '__main__':
    # Manual smoke test: parse a saved LLM response and print the first proxy's code.
    with open('./output.txt', 'r', encoding='utf-8') as response_file:
        raw_response = response_file.read()
    parsed = extract_proxies(raw_response)
    first_proxy = parsed[0]
    print(first_proxy['code'])

