import os
import json
from openai import OpenAI
import pandas as pd
import numpy as np
from datasets import Dataset
import random
import requests
import concurrent.futures
import pdb

# The OpenAI-compatible model endpoint and the backtest API are served from
# the same local host/port.
_LOCAL_SERVER = "http://localhost:8003"

openai_api_key = "EMPTY"  # placeholder key; presumably the local server ignores it — confirm
openai_api_base = f"{_LOCAL_SERVER}/v1"
backtest_api_url = f"{_LOCAL_SERVER}/backtest"

client = OpenAI(base_url=openai_api_base, api_key=openai_api_key)


# Output locations for the generated train/validation parquet files.
DATA_DIR = os.path.expanduser('~/data/factor')
_DATASET_VERSION = 'v5'
TRAIN_OUTPUT_FILE = os.path.join(DATA_DIR, f'factor_train_{_DATASET_VERSION}.parquet')
VAL_OUTPUT_FILE = os.path.join(DATA_DIR, f'factor_val_{_DATASET_VERSION}.parquet')
NUM_SAMPLES = 999  # number of samples to generate

# For a small test split, the previous setup used
# os.path.join(DATA_DIR, 'factor_test_v3.parquet') with NUM_SAMPLES = 10.

# Metric key requested from the backtest API and quoted in generated prompts.
METRIC = 'Information_Ratio_with_cost'



def _backtest_error_detail(response) -> str:
    """Return a readable error message for a non-200 backtest response.

    Prefers ``response.json()['detail']['error']``. If the body is not JSON or
    has a different shape, it falls back to the status code plus the raw body
    text, so a malformed error response still produces a status string instead
    of raising.
    """
    try:
        return str(response.json()['detail']['error'])
    except (ValueError, KeyError, TypeError):
        return f"HTTP {response.status_code}: {response.text}"


def backtest(factor: dict, metric: str = 'IC') -> tuple[float, str]:
    """
    Backtest one factor through the backtest API on the local server.

    The request covers 2023-01-01 to 2024-01-01 on the CSI500 pool. It
    rebalances every 5 days and scores the factor on ``factor["expr"]``.

    Args:
        factor: factor dict with at least "name" and "expr" keys.
        metric: key to read from the response's ``data["metrics"]`` mapping.
    Returns:
        ``(metric_value, status)``. On success, ``metric_value`` is the requested
        metric rounded to 4 decimals and ``status`` is ``"success"``. On any
        failure, ``metric_value`` is NaN and ``status`` is an error string. A
        2-tuple is always returned, so callers can unpack it unconditionally.
    """
    test_request = {
        "exprs": {
            factor["name"]: factor["expr"]
        },
        # "date_split": {
        #     "train_start_time": "2018-01-01",
        #     "train_end_time": "2020-12-31",
        #     "val_start_time": "2021-01-01",
        #     "val_end_time": "2021-12-31",
        #     "test_start_time": "2023-01-01",
        #     "test_end_time": "2024-12-31"
        # },
        "backtest_start_time": "2023-01-01",
        "backtest_end_time": "2024-01-01",
        "start_cash": 10000000.0,
        "update_freq": 5,
        "label_forward_days": 5,
        "stock_pool": "CSI500",
        "stop_loss_rate": 0.5,
        "stop_profit_rate": 0.5,
        "position_size": 1.0,
        "max_pos_each_stock": 0.2,
        "use_cache": True,
        "layer_start": 0,
        "layer_end": 1,
        "pred_score_industry_neutralization": False
    }

    try:
        response = requests.post(
            backtest_api_url,
            json=test_request,
            timeout=600,  # a single backtest can take minutes
        )

        if response.status_code != 200:
            print(f"error: {response.text}")
            return np.nan, _backtest_error_detail(response)

        result = response.json()
        data = result.get('data') if isinstance(result, dict) else None
        if not data:
            # An empty payload must still yield a (value, status) pair so
            # callers can unpack the result.
            message = result.get('message') if isinstance(result, dict) else None
            return np.nan, f"backtest returned no data: {message}"

        metrics = data.get('metrics') or {}
        if metric not in metrics:
            return np.nan, f"metric {metric!r} not found in backtest metrics"
        return np.round(metrics[metric], 4), "success"

    except Exception as e:  # best-effort: network errors, bad JSON, unexpected schema
        print(f"error: {e}")
        return np.nan, str(e)


def sample_base_factor_info(base_factors: list, num_factors: int = 1, metric: str = 'IC',
                            backtest_fn=None) -> tuple[str, float]:
    """
    Sample base factors that backtest with a positive metric, and format them as
    reference factors for the agent.

    Candidates are visited in random order, without replacement, and backtested
    one at a time. Sampling stops once ``num_factors`` of them succeed with
    ``metric > 0`` or the pool is exhausted. The call therefore always
    terminates, never returns duplicates, and never returns more than
    ``num_factors`` factors.

    Args:
        base_factors: JSON strings, each decoding to a factor dict with "name"
            and "expr"; blank strings are skipped.
        num_factors: number of qualifying factors to collect.
        metric: metric name passed to the backtest and stored on each factor.
        backtest_fn: callable ``(factor_dict, metric=...) -> (value, status)``;
            defaults to ``backtest``.
    Returns:
        ``(factors_info_str, init_metric_value)``: the prompt text listing the
        collected factors, and the best metric among them. The best metric is
        ``-inf`` if no factor qualified.
    """
    if backtest_fn is None:
        backtest_fn = backtest

    init_metric_value = -np.inf
    factors_info = []

    # Sampling the whole pool gives a shuffled copy. This avoids the ValueError
    # when num_factors > len(base_factors) and the endless resample loop.
    for factor in random.sample(base_factors, len(base_factors)):
        if len(factors_info) >= num_factors:
            break
        if not factor.strip():
            continue
        factor_info = json.loads(factor)
        metric_value, metric_status = backtest_fn(factor_info, metric=metric)
        if metric_status == "success" and metric_value > 0:
            factor_info[metric] = float(metric_value)
            init_metric_value = max(init_metric_value, metric_value)
            factors_info.append(factor_info)
        else:
            print(f">> factor {factor_info['name']} backtest failed: {metric_status} {metric_value}")

    if len(factors_info) < num_factors:
        print(f">> only {len(factors_info)} of {num_factors} requested factors passed the backtest")

    # Each factor is rendered as its Python dict repr, one per line.
    factors_base_info_str = "".join(f"{info}\n" for info in factors_info)

    factors_info_str = f"""
Here is some initial factors and their {metric} values you may follow their fundamental ideas to craft new factors:
{factors_base_info_str}
"""

    return factors_info_str, init_metric_value


# Build one factor-evolution training record.
def generate_sample(interface_manual, factor_info, index):
    """Build one training record that asks the model to evolve ``factor_info``.

    Args:
        interface_manual: text used as the system prompt.
        factor_info: factor dict; must contain "name", "expr" and the METRIC key.
        index: sample index stored in ``extra_info['index']``.
    Returns:
        dict with data_source / ability / prompt / reward_model / extra_info.
        The initial factor expression and metric also go to the
        ``evaluate_factor`` tool through ``extra_info['tools_kwargs']``.
    """
    user_content = (
        "\n"
        f"Here is an initial factor and its {METRIC} value. "
        f"Evolve this factor while trying to improve its {METRIC} value:\n"
        f"{factor_info}\n"
    )

    print(f" {factor_info['name']}  user_content: {user_content}")

    name = factor_info['name']
    expr = factor_info['expr']
    init_metric = factor_info[METRIC]

    messages = [
        {'role': 'system', 'content': interface_manual},
        {'role': 'user', 'content': user_content},
    ]

    answer = json.dumps({'name': name, 'expr': expr, METRIC: init_metric})
    tool_create_kwargs = {'init_factor_expr': expr, 'init_metric': init_metric}

    extra_info = dict(
        answer=answer,
        tools_kwargs={'evaluate_factor': {'create_kwargs': tool_create_kwargs}},
        index=index,
        interaction_kwargs={'dummy': None},
        init_metric=init_metric,
    )

    return dict(
        data_source='factor_generation',
        ability='factor_evolution',
        prompt=messages,
        reward_model={'ground_truth': 'NULL', 'name': 'DUMMY_NAME'},
        extra_info=extra_info,
    )

def backtest_and_save_factor(base_factors, output_file, max_workers: int = 8,
                             metric=None, backtest_fn=None):
    """
    Backtest every base factor in parallel and write the successful ones to JSONL.

    Args:
        base_factors: JSON strings (e.g. lines of a .jsonl file), each decoding
            to a factor dict with "name" and "expr"; blank lines are skipped.
        output_file: path of the JSONL file to (over)write. It gets one
            ``{"name", "expr", <metric>}`` object per successful factor, in
            input order.
        max_workers: number of concurrent backtest requests.
        metric: metric to request; defaults to the module-level ``METRIC``.
        backtest_fn: callable ``(factor_dict, metric=...) -> (value, status)``;
            defaults to ``backtest``.
    Returns:
        list of the factor dicts that were written, in input order.
    """
    if metric is None:
        metric = METRIC
    if backtest_fn is None:
        backtest_fn = backtest

    lines = [line for line in base_factors if line.strip()]

    def process_factor(factor):
        """Backtest one JSON factor line; return its output record or None on failure."""
        text = factor.strip()
        try:
            factor_info = json.loads(text)
            metric_value, metric_status = backtest_fn(factor_info, metric=metric)
            if metric_status != "success":
                print(f"{text} backtest failed: {metric_status} {metric_value}")
                return None
            factor_info_w_metric = {
                'name': factor_info["name"],
                'expr': factor_info["expr"],
                metric: float(metric_value),
            }
        except Exception as e:  # best-effort per factor: one bad line must not stop the run
            print(f"{text} backtest failed: {e}")
            return None
        print(f"{text} success: {metric_value}")
        return factor_info_w_metric

    base_factors_w_metric = []
    with open(output_file, "w", encoding="utf-8") as f:
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            # executor.map yields results in input order as they complete, so
            # each success is persisted promptly and an interrupted run keeps
            # its progress.
            for factor_info_w_metric in executor.map(process_factor, lines):
                if factor_info_w_metric is None:
                    continue
                base_factors_w_metric.append(factor_info_w_metric)
                f.write(json.dumps(factor_info_w_metric) + "\n")
                f.flush()

    return base_factors_w_metric


# Script entry point: backtest the base factor pool and save the metrics.
def main():
    """Backtest every factor in base_factors/base_factors5.jsonl and save the results.

    Successful factors are written with their METRIC values to
    base_factors/base_factor5_w_zz500_metric_23-24.jsonl by
    ``backtest_and_save_factor``. Total time and throughput are then printed.
    """
    import time

    os.makedirs(DATA_DIR, exist_ok=True)

    # NOTE(review): interface_manual is read but not used below (presumably left
    # over from the generate_sample dataset path). It is kept to preserve
    # behavior; confirm whether it is still needed.
    with open("interface_rl_v2.md", "r", encoding="utf-8") as f:
        interface_manual = f.read()

    with open("base_factors/base_factors5.jsonl", "r", encoding="utf-8") as f:
        base_factors = f.readlines()
    num_factors = sum(1 for line in base_factors if line.strip())

    start_time = time.perf_counter()  # monotonic clock for elapsed-time measurement
    print("starting backtest with multi-threading...")
    backtest_and_save_factor(base_factors, "base_factors/base_factor5_w_zz500_metric_23-24.jsonl")
    print("backtest completed")

    # Sample the clock once so the printed duration and rate agree.
    elapsed = time.perf_counter() - start_time
    rate = num_factors / elapsed if elapsed > 0 else float("inf")
    print(f"cost time: {elapsed:.2f}s , average {rate:.2f} factors per second")


if __name__ == '__main__':
    main()