from openai import OpenAI
import json
import pdb
import requests
import numpy as np
import random
import logging
import os
from datetime import datetime

# Connection settings for the OpenAI-compatible chat endpoint served on localhost.
# NOTE(review): "EMPTY" looks like a placeholder key for a local server that
# does not validate it — confirm before pointing this at a hosted endpoint.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8003/v1"
# URL that backtest() posts factor expressions to.
backtest_api_url = "http://localhost:8001/backtest"
# Model name passed to client.chat.completions.create().
model = 'Qwen3-4B-Thinking-2507'

# Shared chat client used by generate_factors_and_get_result().
client = OpenAI(
    api_key=openai_api_key,
    base_url=openai_api_base,
)


def execute_tool_calls(tool_calls, metric='IC'):
    """
    Execute the model's tool calls by backtesting each requested factor.

    Every tool call receives exactly one ``role="tool"`` reply, including
    calls that name an unknown tool or carry unusable arguments, so the
    assistant message that issued the calls never has an unanswered
    ``tool_call_id``.

    Args:
        tool_calls: iterable of tool-call objects exposing ``id``,
            ``function.name`` and ``function.arguments`` (a JSON string).
        metric: name of the backtest metric to report back to the model.

    Returns:
        ``(results, metric_values)`` where ``results`` is the list of tool
        messages to append to the conversation and ``metric_values`` holds
        one float per ``evaluate_factor`` call (``nan`` when that call
        could not be evaluated).
    """
    # `log_print` is only bound while the __main__ loop of this file is
    # running; fall back to plain print so the function also works when the
    # module is imported.
    log = globals().get("log_print", print)

    results = []
    metric_values = []

    def reply(tool_call, content, name="evaluate_factor"):
        """Queue the tool message answering *tool_call*."""
        results.append({
            "tool_call_id": tool_call.id,
            "role": "tool",
            "name": name,
            "content": content,
        })

    for tool_call in tool_calls:
        tool_name = tool_call.function.name
        if tool_name != "evaluate_factor":
            # Still answer the call so the conversation stays well-formed.
            log(f"Unknown tool requested: {tool_name}")
            reply(
                tool_call,
                f'Failed: unknown tool "{tool_name}". Only "evaluate_factor" is available.',
                name=tool_name,
            )
            continue

        # The arguments are model-generated text and may not be valid JSON.
        try:
            args = json.loads(tool_call.function.arguments)
            if not isinstance(args, dict):
                raise ValueError("arguments must be a JSON object")
        except (TypeError, ValueError) as err:
            log(f"Invalid evaluate_factor arguments: {err}")
            metric_values.append(np.nan)
            reply(tool_call, f"Failed: could not parse tool arguments. Reason: {err}")
            continue

        factor_name = args.get("factor_name")
        factor_expr = args.get("factor_expr")
        if not factor_name or not factor_expr:
            log(f"Incomplete evaluate_factor arguments: {args}")
            metric_values.append(np.nan)
            reply(
                tool_call,
                "Failed: both \"factor_name\" and \"factor_expr\" are required.",
            )
            continue

        metric_value, metric_status = backtest(
            {"name": factor_name, "expr": factor_expr},
            metric=metric,
        )
        metric_values.append(float(metric_value))

        log(f"Factor: {factor_name} expr: {factor_expr} {metric}: {metric_value}    {metric_status}")

        if metric_status == "success":
            content = f"Success: Evaluated factor \"{factor_name}\" with expression \"{factor_expr}\", {metric}={metric_value}"
        else:
            content = f"Failed: Factor {factor_name} with expression {factor_expr}. Reason: {metric_status}"
        reply(tool_call, content)

    if not metric_values or np.isnan(metric_values).all():
        log("No successful factor found, restarting this round of evolution...")

    return results, metric_values


def generate_factors_and_get_result(messages: list, metric='IC'):
    """
    Ask the model for new factors and backtest whatever it requests.

    Sends ``messages`` to the chat endpoint with the ``evaluate_factor`` tool
    enabled. When the reply contains tool calls, the assistant message and
    the resulting tool messages are appended to ``messages`` in place.

    Args:
        messages: conversation so far (mutated in place).
        metric: backtest metric forwarded to ``execute_tool_calls``.

    Returns:
        ``(messages, metric_values)``; ``metric_values`` is an empty list
        when the model issued no tool calls.
    """
    # `log_print` is only bound while the __main__ loop of this file is
    # running; fall back to plain print otherwise.
    log = globals().get("log_print", print)

    evaluate_factor_tool = {
        "type": "function",
        "function": {
            "name": "evaluate_factor",
            "description": "A tool for evaluating factors with backtesting. Returns a reward based on backtest success and metric performance.",
            "parameters": {
                "type": "object",
                "properties": {
                    "factor_name": {
                        "type": "string",
                        "description": "The name of the factor to evaluate",
                    },
                    "factor_expr": {
                        "type": "string",
                        "description": "The expression of the factor to evaluate",
                    },
                },
                "required": ["factor_name", "factor_expr"],
            },
        },
    }

    chat_response = client.chat.completions.create(
        model=model,
        messages=messages,
        max_tokens=25000,
        temperature=0.5,
        top_p=0.8,
        presence_penalty=1.5,
        extra_body={
            "top_k": 20,
            "chat_template_kwargs": {"enable_thinking": True},
        },
        tools=[evaluate_factor_tool],
        tool_choice="auto",
    )

    message = chat_response.choices[0].message
    # tool_calls is None (or absent) when the model answered in plain text.
    tool_calls = getattr(message, "tool_calls", None) or []

    log(f"Response\n{message.content}")
    log(f"Total {len(tool_calls)} factor backtest requests received")

    # Default result so the return below is defined when no tool was called.
    metric_values = []

    if tool_calls:
        # Record the assistant turn so the tool replies have a matching call.
        messages.append({
            "role": "assistant",
            "content": message.content,
            "tool_calls": [
                {
                    "id": tc.id,
                    "type": "function",
                    "function": {
                        "name": tc.function.name,
                        "arguments": tc.function.arguments,
                    },
                }
                for tc in tool_calls
            ],
        })

        log(f"Received tool calls: {tool_calls}")

        tool_results, metric_values = execute_tool_calls(tool_calls, metric=metric)
        messages.extend(tool_results)

    return messages, metric_values
        


def backtest(factor: dict, metric: str = 'IC') -> tuple[float, str]:
    """
    Backtest a factor through the backtest API at ``backtest_api_url``.

    Args:
        factor: factor dict with non-empty ``"name"`` and ``"expr"`` entries.
        metric: key to read from the ``metrics`` section of the response.

    Returns:
        ``(metric_value, status)``. On success ``metric_value`` is the metric
        rounded to 4 decimals and ``status`` is ``"success"``. On any failure
        ``metric_value`` is ``nan`` and ``status`` is a string describing the
        reason; the function never returns ``None`` and never raises for
        request or response errors.
    """
    factor_name = factor.get("name")
    factor_expr = factor.get("expr")
    if not factor_name or not factor_expr:
        return np.nan, "factor must provide a non-empty 'name' and 'expr'"

    test_request = {
        "exprs": {
            factor_name: factor_expr
        },
        "backtest_start_time": "2023-01-01",
        "backtest_end_time": "2024-01-01",
        "start_cash": 10000000.0,
        "update_freq": 5,
        "label_forward_days": 5,
        "stock_pool": "CSI1000",
        "stop_loss_rate": 0.5,
        "stop_profit_rate": 0.5,
        "position_size": 1.0,
        "max_pos_each_stock": 0.2,
        "use_cache": True,
        "layer_start": 0,
        "layer_end": 1,
        "pred_score_industry_neutralization": False
    }

    try:
        response = requests.post(
            backtest_api_url,
            json=test_request,
            timeout=600,  # seconds
        )

        if response.status_code != 200:
            print(f"Error: {response.text}")
            # Prefer the structured error; fall back to the raw body when the
            # response is not JSON or lacks the expected keys.
            try:
                reason = response.json()['detail']['error']
            except (ValueError, KeyError, TypeError):
                reason = response.text
            return np.nan, str(reason)

        result = response.json()
        data = result.get('data')
        if not data:
            # Previously this path fell off the end and returned None,
            # which broke tuple unpacking in every caller.
            return np.nan, "backtest returned no data"

        metric_value = np.round(data['metrics'][metric], 4)
        return metric_value, "success"

    except KeyError as e:
        print(f"Error: {e}")
        return np.nan, f"missing key {e} in backtest response"
    except Exception as e:
        # Best-effort boundary: report the failure instead of aborting the run.
        print(f"Error: {e}")
        return np.nan, str(e) or type(e).__name__


def sample_base_factor_info(base_factors: list, num_factors: int = 1, metric: str = 'IC') -> tuple[str, float]:
    """
    Sample base factors, backtest them and build the initial user prompt.

    Args:
        base_factors: list of JSON strings, each decoding to a factor dict.
        num_factors: number of factors to sample (without replacement).
        metric: backtest metric to attach to each factor.

    Returns:
        ``(user_content, init_metric_value)``. ``user_content`` lists every
        sampled factor whose backtest succeeded, one per line, each with its
        metric value added under the ``metric`` key. ``init_metric_value`` is
        the best metric among them, or ``-inf`` when none succeeded.

    Raises:
        ValueError: from ``random.sample`` when ``num_factors`` exceeds
            ``len(base_factors)``.
    """
    init_metric_value = -np.inf
    factors_info = []

    for factor in random.sample(base_factors, num_factors):
        factor_info = json.loads(factor)
        metric_value, metric_status = backtest(factor_info, metric=metric)
        if metric_status != "success":
            print(f">> initial factor {factor_info.get('name')} backtest failed: {metric_status}")
            continue

        factor_info[metric] = float(metric_value)
        factors_info.append(factor_info)
        init_metric_value = max(init_metric_value, float(metric_value))

    # Render the collected successes rather than the loop variable, which
    # only held the last sample (possibly a failed one) and was unbound when
    # nothing was sampled. With one success this matches the old output.
    rendered_factors = "\n".join(str(info) for info in factors_info)

    user_content = f"""
Here is an initial factor and its {metric} value. Evolve this factor while trying to improve its {metric} value:
{rendered_factors}
"""

    return user_content, init_metric_value


def get_factor_info(factor: "str | dict", metric: str = 'IC') -> "tuple[str | None, float | None]":
    """
    Backtest one initial factor and build the user prompt describing it.

    Args:
        factor: a JSON string decoding to a factor dict (as read from a
            ``.jsonl`` line), or an already-decoded factor dict.
        metric: backtest metric to attach to the factor.

    Returns:
        ``(user_content, metric_value)`` when the backtest succeeds, or
        ``(None, None)`` when the factor cannot be parsed or its backtest
        does not report ``"success"``.
    """
    if isinstance(factor, dict):
        # Copy so the metric key added below does not leak into the caller's dict.
        factor_info = dict(factor)
    else:
        try:
            factor_info = json.loads(factor)
        except (TypeError, ValueError) as err:
            # Covers blank lines and malformed JSON in the factor file.
            print(f">> initial factor could not be parsed: {err}")
            return None, None
        if not isinstance(factor_info, dict):
            print(f">> initial factor is not a JSON object: {factor_info!r}")
            return None, None

    metric_value, metric_status = backtest(factor_info, metric=metric)
    if metric_status != "success":
        print(f">> initial factor {factor_info.get('name')} backtest failed: {metric_status}")
        return None, None

    factor_info[metric] = float(metric_value)

    user_content = f"""
Here is an initial factor and its {metric} value. Evolve this factor while trying to improve its {metric} value:
{factor_info}
"""

    return user_content, float(metric_value)

if __name__ == "__main__":
    # Script entry point: for every factor in the test file, run several
    # generations of model-driven factor evolution and log each round to its
    # own file under logs/.

    # Metric names that can be selected here:
    # 'Annualized_Return_with_cost', 'Max_Drawdown_with_cost', 'Information_Ratio_with_cost',
    # 'Annualized_Return_without_cost', 'Information_Ratio_without_cost', 'Max_Drawdown_without_cost',
    # 'IC', 'ICIR', 'RankIC', 'RankICIR'
    metric = 'Information_Ratio_with_cost'
    retries = 3  # not used yet; see the TODO in the generation loop

    # System prompt sent at the start of every conversation.
    with open("interface_rl.md", "r", encoding="utf-8") as f:
        interface_manual = f.read()

    # One JSON-encoded factor per line; blank lines would fail to parse, so
    # they are dropped here.
    with open("base_factors/test_v2.jsonl", "r", encoding="utf-8") as f:
        test_factors = [line for line in f if line.strip()]

    os.makedirs('logs', exist_ok=True)

    # NOTE(review): the best/summary state and the counters below are
    # initialised here but never updated in the code visible in this block —
    # confirm whether later code consumes them.
    global_best_metric = -np.inf
    global_best_factor = None
    summary = []

    success_count = 0
    failure_no_factor = 0
    failure_backtest = 0

    for i, factor_line in enumerate(test_factors):
        round_num = i + 1
        round_log = f"logs/round_{round_num}_{datetime.now().strftime('%Y%m%d%H%M%S')}.log"
        # Explicit UTF-8: the logged text contains non-ASCII characters
        # ("↓"), which fails on platforms whose default encoding is not UTF-8.
        with open(round_log, 'w', encoding="utf-8") as log_file:
            # Defined at module level on purpose: execute_tool_calls() and
            # generate_factors_and_get_result() look up this global name.
            def log_print(*args, **kwargs):
                """Print to stdout and mirror the same output into this round's log file."""
                print(*args, **kwargs)
                print(*args, file=log_file, **kwargs)

            user_prompt, init_metric = get_factor_info(factor_line, metric=metric)
            log_print(f"\n===== round {round_num}   =====")
            log_print("initial factor information:\n", user_prompt)
            log_print("initial factor metric value:", init_metric)

            if user_prompt is None:
                # get_factor_info() signals a failed initial backtest with
                # (None, None); there is no prompt to send to the model.
                log_print(f"round {round_num} skipped: initial factor could not be backtested")
                continue

            messages = [{"role": "system", "content": interface_manual}, {"role": "user", "content": user_prompt}]

            round_best_metric = init_metric
            round_best_factor = None
            conversation_history = messages
            metric_trajectory = f" ↓ initial factor backtest result: {init_metric}\n"

            for n_gen in range(1, 10):

                log_print(f"\n===== round {n_gen}   =====")
                # TODO: wrap the generation call in a retry loop bounded by `retries`.
                conversation_history, metric_values_this_gen = generate_factors_and_get_result(
                    messages=conversation_history,
                    metric=metric,
                )

                # Drop assistant turns before the next request; system, user
                # and tool messages are kept.
                # NOTE(review): this leaves tool messages without the
                # assistant tool_calls message that produced them; some
                # OpenAI-compatible servers reject that — confirm.
                conversation_history = [msg for msg in conversation_history if msg['role'] != 'assistant']

                log_print(f"  {n_gen} g: {metric_values_this_gen}")
                # Record this generation so the trajectory printed below
                # shows more than the initial value.
                metric_trajectory += f" ↓ generation {n_gen}: {metric_values_this_gen}\n"

            log_print("metric trajectory:\n", metric_trajectory)
            log_print(f"round {round_num} evolution finished, log saved to: {round_log}")