import ast
import json
import os
import re

import numpy as np
from openai import OpenAI

from llm_client.prompot_bank import EFS_SYSTEM_PROMPT

# API credentials are read from the environment so real keys never have to be
# committed to source control; the original placeholder values are kept as
# fallbacks so existing setups keep importing this module unchanged.
DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY", "sk-keys")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "sk-keys")


def call_llm(prompt, system_prompt="You are a helpful assistant", model="deepseek-chat"):
    """Send one system + user message pair to a chat model and return its reply.

    Models whose name starts with ``"gpt-"`` are sent to the OpenAI proxy
    endpoint and are asked for a JSON-object response; every other model name
    is sent to the DeepSeek endpoint without a response format.

    Args:
        prompt: Content of the user message.
        system_prompt: Content of the system message.
        model: Model name forwarded to the chat completions API.

    Returns:
        The message content of the first choice in the completion.
    """
    use_openai = model.startswith("gpt-")

    if use_openai:
        client = OpenAI(
            api_key=OPENAI_API_KEY, base_url="https://api.openai-proxy.com/v1"
        )
    else:
        client = OpenAI(api_key=DEEPSEEK_API_KEY, base_url="https://api.deepseek.com")

    # Both providers share the same request shape; only the GPT path adds
    # the JSON response format.
    request = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        "stream": False,
    }
    if use_openai:
        request["response_format"] = {'type': 'json_object'}

    completion = client.chat.completions.create(**request)
    return completion.choices[0].message.content


def parse_factors_string(output_str):
    """Build executable factor functions from a mapping holding their source.

    NOTE: a second ``parse_factors_string`` is defined later in this module and
    rebinds the name at import time, so this implementation is not reachable
    through the module-level name.

    Args:
        output_str: Object indexable with ``'codes'``; that entry is passed to
            ``ast.literal_eval`` and is expected to evaluate to an iterable of
            function source strings.

    Returns:
        Dict mapping each function name to ``{"func": callable, "raw": source}``.
        Entries that cannot be loaded are reported with ``print`` and skipped.
    """
    output_list = ast.literal_eval(output_str['codes'])
    function_dict = {}

    for func_str in output_list:
        # Reset per entry so a failure before the name is known is reported
        # as unknown instead of crashing the whole parse.
        func_name = "<unknown>"
        try:
            func_name = func_str.split("def ")[1].split("(")[0].strip()

            # SECURITY: exec runs arbitrary code. Only feed this function
            # source text from a trusted origin, or run it in a sandbox.
            namespace = {"np": np}
            exec(func_str, namespace)

            # Build the entry in one step so a missing name in the namespace
            # cannot leave a half-filled record behind.
            function_dict[func_name] = {
                "func": namespace[func_name],
                "raw": func_str,
            }
        except Exception as e:
            # The executed source can raise anything; skip the bad entry.
            print(f"Error loading function {func_name}: {str(e)}")
            continue

    return function_dict


def _strip_code_fence(text):
    """Return *text* without a surrounding Markdown code fence, if it has one."""
    if not isinstance(text, str):
        return text
    fenced = re.match(r"\s*```[\w+-]*\s*(.*?)\s*```\s*\Z", text, re.DOTALL)
    return fenced.group(1) if fenced else text


def parse_factors_string(output_str, max_retries=3, min_functions=5):
    """
    Parse LLM output into executable functions.

    The input must be a JSON object (optionally wrapped in a Markdown code
    fence) whose ``"codes"`` entry is a list of function source strings. Each
    source string is executed with ``np`` in scope and the function it defines
    is collected.

    Args:
        output_str: Raw string output from LLM
        max_retries: Kept for backward compatibility. Parsing the same string
            again cannot change the outcome, so the input is parsed at most
            once; values below 1 skip parsing and return None.
        min_functions: Minimum number of distinct valid functions required

    Returns:
        Dictionary mapping function name to ``{"func": callable, "raw": source}``,
        or None if the input could not be parsed or yielded too few functions
    """
    if max_retries < 1:
        print(f"Failed after {max_retries} attempts. Last error: None")
        return None

    try:
        output_list = json.loads(_strip_code_fence(output_str))['codes']

        # Verify we got a list of reasonable length
        if not isinstance(output_list, list) or len(output_list) < min_functions:
            raise ValueError(
                f"Expected list of at least {min_functions} functions, got {len(output_list) if isinstance(output_list, list) else type(output_list)}"
            )
    except (ValueError, KeyError, TypeError) as e:
        # json.JSONDecodeError is a ValueError; KeyError/TypeError cover a
        # missing "codes" key and non-object / non-string payloads.
        print(f"Failed to parse factor functions: {e}")
        return None

    function_dict = {}
    for func_str in output_list:
        # Reset per entry so a failure is never reported under the name of
        # the previous function.
        func_name = "<unknown>"
        try:
            func_name = func_str.split("def ")[1].split("(")[0].strip()

            # SECURITY: exec runs arbitrary, LLM-generated code. Run this only
            # on trusted output or inside a sandbox.
            namespace = {"np": np}
            exec(func_str, namespace)

            function_dict[func_name] = {
                "func": namespace[func_name],
                "raw": func_str,
            }
        except Exception as e:
            # The executed source can raise anything; skip the bad entry.
            print(f"Warning: Error loading function {func_name}: {str(e)}")

    # Count distinct names: duplicate definitions overwrite each other in the
    # result, so they must not count twice toward the minimum.
    if len(function_dict) < min_functions:
        print(
            "Failed to parse factor functions: "
            f"Only {len(function_dict)} valid functions parsed (need at least {min_functions})"
        )
        return None

    return function_dict


def call_llm_with_retry(prompts_gen, model, max_retries=3, min_functions=5):
    """
    Wrapper for LLM calls with retry mechanism for factor generation.

    Each attempt calls the LLM and parses its reply. After a failed attempt the
    original prompt is re-sent with the failure reason appended, so the
    feedback never piles up across retries.

    Args:
        prompts_gen: Prompt text, or a zero-argument callable returning it
        model: LLM model to use
        max_retries: Maximum number of attempts
        min_functions: Minimum required valid functions

    Returns:
        Tuple ``(functions, raw_llm_output)`` on success, where ``functions``
        is the dictionary of parsed functions; ``(None, None)`` if every
        attempt failed
    """
    # Resolve the prompt once; the prompt sent to the LLM must always be text,
    # never a function object.
    base_prompt = prompts_gen() if callable(prompts_gen) else prompts_gen
    prompt = base_prompt

    for attempt in range(1, max_retries + 1):
        try:
            llm_result = call_llm(prompt, system_prompt=EFS_SYSTEM_PROMPT, model=model)
            llm_expression = parse_factors_string(
                llm_result, max_retries=1, min_functions=min_functions
            )
        except Exception as e:
            # Retry boundary: any client or parsing failure triggers another
            # attempt instead of aborting the whole generation.
            last_error = str(e)
        else:
            valid_count = len(llm_expression) if llm_expression is not None else 0
            if valid_count >= min_functions:
                return llm_expression, llm_result
            last_error = f"Only got {valid_count} valid functions"

        print(f"LLM call attempt {attempt} failed: {last_error}")

        # Enhance prompt with error feedback if we're retrying
        if attempt < max_retries:
            prompt = f"{base_prompt}\n\nPrevious attempt failed because: {last_error}\nPlease provide at least {min_functions} complete, syntactically correct Python functions that implement financial factors."
            print("Retrying with enhanced prompt...")

    print(f"Failed to get sufficient valid functions after {max_retries} attempts")
    return None, None

