"""
Backtest Strategy API Server

Provides an HTTP interface (FastAPI + uvicorn) for running strategy backtests
and for test-evaluating factor expressions against a sample data set.
"""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import Dict, Optional
import uvicorn
from datetime import datetime
import traceback
import pandas as pd
import asyncio
from concurrent.futures import ProcessPoolExecutor
from expression_manager.expr_parser import parse_expression
from expression_manager.function_lib import *
from jinja2 import Template


# Import backtest functions
from fast_backtest_api import backtest

# Create process pool executor (global resource, avoid repeated creation)
import os
# Use roughly half of the available cores so the API process and other workloads
# keep some headroom. os.cpu_count() may return None, and on a single-core host
# int(1 * 0.5) == 0 would make ProcessPoolExecutor raise ValueError, so always
# keep at least one worker.
MAX_WORKERS = max(1, (os.cpu_count() or 1) // 2)
process_executor = ProcessPoolExecutor(max_workers=MAX_WORKERS)

def run_backtest_process(backtest_params):
    """
    Entry point executed inside a ProcessPoolExecutor worker.

    Unpacks ``backtest_params`` as keyword arguments for ``backtest`` and
    returns whatever it produces. It must stay a module-level function so the
    executor can serialize it when dispatching work to a worker process.
    """
    outcome = backtest(**backtest_params)
    return outcome


def convert_to_dict(df):
    """Convert a DataFrame into a column-oriented, JSON-friendly dict.

    Each column name maps to a list of its values. A ``datetime`` column, when
    present, is rendered as ``'%Y-%m-%d'`` strings. NaN/NaT values become
    ``None``. The input frame is not modified.

    Args:
        df: Frame to convert.

    Returns:
        ``{column_name: [values...]}``, or ``{}`` for an empty frame.
    """
    import numpy as np  # local import: do not rely on names leaked by star imports

    if df.empty:
        return {}

    out = df.copy()
    if 'datetime' in out.columns:
        # Format before replacing missing values. Replacing NaT with None first
        # would turn the column into object dtype and break the `.dt` accessor.
        # strftime yields NaN for NaT, which the replace below turns into None.
        out['datetime'] = pd.to_datetime(out['datetime']).dt.strftime('%Y-%m-%d')

    out = out.replace({np.nan: None, pd.NaT: None})

    # Column name as key, column values as array.
    return {col: out[col].tolist() for col in out.columns}


app = FastAPI(
    version="1.0.0",
    title="Quantitative Backtest API",
    description="Provides HTTP interface for quantitative trading strategy backtesting",
)


def _load_debug_df():
    """Read the sample frame used by the /test_expr endpoint.

    Returns the loaded DataFrame, or None when the file is missing or cannot
    be parsed. Failures are reported on stdout and never abort startup.
    """
    try:
        frame = pd.read_csv('.debug/debug_df.csv', index_col=[0, 1], parse_dates=True)
        print(frame)
        print(f"Successfully loaded debug_df.csv, shape: {frame.shape}")
        return frame
    except FileNotFoundError:
        print("Warning: debug_df.csv not found, /test_expr interface will be unavailable")
        return None
    except Exception as exc:
        print(f"Warning: Failed to read debug_df.csv: {exc}")
        return None


# Load test data into memory at startup for expression testing interface
DEBUG_DF = _load_debug_df()


# Request body for POST /backtest. The /backtest handler copies these fields
# into the keyword arguments it passes to backtest().
# NOTE(review): `industry_neutralization` is accepted here but is not among the
# parameters the /backtest handler forwards to backtest() — confirm whether it
# should be passed through or removed.
class BacktestRequest(BaseModel):
    """Backtest request parameters"""
    # Required: factor name -> factor expression string.
    exprs: Dict[str, str] = Field(..., description="Factor expression dictionary, key is factor name, value is expression")
    # Required date strings; the /backtest handler validates them with strptime('%Y-%m-%d').
    backtest_start_time: str = Field(..., description="Backtest start time, format: YYYY-MM-DD")
    backtest_end_time: str = Field(..., description="Backtest end time, format: YYYY-MM-DD")
    start_cash: float = Field(default=1e7, description="Initial capital")
    update_freq: int = Field(default=4, description="Update frequency (days)")
    label_forward_days: int = Field(default=4, description="Label forward days")
    stock_pool: str = Field(default="CSI500", description="Stock pool")
    # The fields below are declared Optional, so a client may send an explicit
    # null. The handler passes such values through unchanged; whether backtest()
    # accepts None for each of them is not visible here — TODO confirm.
    stop_loss_rate: Optional[float] = Field(default=0.5, description="Stop loss rate")
    stop_profit_rate: Optional[float] = Field(default=0.4, description="Stop profit rate")
    position_size: Optional[float] = Field(default=1.0, description="Position size")
    max_pos_each_stock: Optional[float] = Field(default=0.2, description="Maximum position per stock")
    industry_neutralization: Optional[str] = Field(default="zscore", description="Industry neutralization method")
    use_cache: Optional[bool] = Field(default=False, description="Whether to use cache")
    # Layer bounds; their exact semantics are defined by backtest() (not visible here).
    layer_start: Optional[int] = Field(default=0, description="Layer start")
    layer_end: Optional[int] = Field(default=1, description="Layer end")
    pred_score_industry_neutralization: Optional[bool] = Field(default=False, description="Prediction score industry neutralization")

# Success envelope returned by POST /backtest. On failure the handler raises an
# HTTPException instead, whose `detail` carries success/message/error keys.
class BacktestResponse(BaseModel):
    """Backtest response result"""
    success: bool = Field(..., description="Whether request is successful")
    message: str = Field(..., description="Response message")
    # Result name -> value converted to a JSON-friendly form (to_dict/tolist/str fallback).
    data: Optional[Dict] = Field(None, description="Backtest result data")
    error: Optional[str] = Field(None, description="Error information")

# Request/response models for the expression test endpoint (POST /test_expr)
class ExprTestRequest(BaseModel):
    """Expression validity test request parameters"""
    # Rendered into the code template as `factor_name`.
    name: str = Field(..., description="Factor name, used to identify the expression")
    # Rendered into the code template as `expression` and then executed server-side.
    expr: str = Field(..., description="Expression to be tested, like 'TS_STD($close,5)'")

class ExprTestResponse(BaseModel):
    """Expression validity test return result"""
    # True whenever the endpoint itself ran, including when the expression failed
    # or the debug data is not loaded; inspect `exe_feedback` for the outcome.
    success: bool = Field(..., description="Whether API call is successful")
    message: str = Field(..., description="Return information")
    # Also carries the "debug data not loaded" message when no test data is available.
    exe_feedback: Optional[str] = Field(None, description="Execution feedback, None when successful, traceback when failed")
    code: Optional[str] = Field(None, description="Actually executed code")
    # Column name -> list of values, as produced by convert_to_dict().
    sample: Optional[Dict] = Field(None, description="Sample data (dict format)")

@app.get("/")
async def root():
    """Root path, returns API information and the list of available endpoints."""
    return {
        "message": "Quantitative Backtest API Service",
        "version": "1.0.0",
        # Keep in sync with the routes registered in this module.
        "endpoints": {
            "POST /backtest": "Execute backtest",
            "POST /test_expr": "Test factor expression validity",
            "GET /example": "Get example request parameters",
            "GET /health": "Health check"
        }
    }

@app.get("/health")
async def health_check():
    """Liveness probe: report a healthy status stamped with server-local time."""
    now_iso = datetime.now().isoformat()
    payload = dict(
        status="healthy",
        timestamp=now_iso,
        service="Quantitative Backtest API",
    )
    return payload

def _to_serializable(value):
    """Best-effort conversion of one backtest result value to a JSON-friendly object.

    Objects exposing ``to_dict`` (e.g. pandas frames/series) or ``tolist``
    (e.g. numpy arrays) are converted through those methods. Plain JSON scalars
    and containers pass through. Anything else, or any value whose conversion
    raises, falls back to ``str(value)``.
    """
    try:
        if hasattr(value, 'to_dict'):
            return value.to_dict()
        if hasattr(value, 'tolist'):
            return value.tolist()
        if isinstance(value, (int, float, str, bool, list, dict)):
            return value
        return str(value)
    except Exception:
        # Conversion is best-effort: one odd value must not fail the whole response.
        return str(value)


@app.post("/backtest", response_model=BacktestResponse)
async def run_backtest(request: BacktestRequest):
    """
    Execute backtest

    Parameters:
    - exprs: Factor expression dictionary
    - backtest_start_time: Backtest start time, format: YYYY-MM-DD
    - backtest_end_time: Backtest end time, format: YYYY-MM-DD
    - start_cash: Initial capital
    - update_freq: Update frequency
    - label_forward_days: Label forward days
    - stock_pool: Stock pool
    - Other optional parameters...

    Returns:
    - Backtest result data

    Raises:
    - HTTPException(400): a date is not in YYYY-MM-DD format
    - HTTPException(500): the backtest itself failed
    """
    # Fall back to default dates when an empty string is supplied, then validate.
    # This happens outside the generic handler below so a bad date yields a 400
    # instead of being re-wrapped as a 500.
    backtest_start_time = request.backtest_start_time or "2023-01-01"
    backtest_end_time = request.backtest_end_time or "2024-01-01"
    try:
        datetime.strptime(backtest_start_time, '%Y-%m-%d')
        datetime.strptime(backtest_end_time, '%Y-%m-%d')
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Date format error: {str(e)}") from e

    try:
        # Prepare backtest parameters, using the validated (possibly defaulted) dates.
        # NOTE(review): request.industry_neutralization is not forwarded —
        # confirm whether backtest() accepts it before adding it here.
        backtest_params = {
            'exprs': request.exprs,
            'backtest_start_time': backtest_start_time,
            'backtest_end_time': backtest_end_time,
            'start_cash': request.start_cash,
            'update_freq': request.update_freq,
            'label_forward_days': request.label_forward_days,
            'stock_pool': request.stock_pool,
            'stop_loss_rate': request.stop_loss_rate,
            'stop_profit_rate': request.stop_profit_rate,
            'position_size': request.position_size,
            'max_pos_each_stock': request.max_pos_each_stock,
            'use_cache': request.use_cache,
            'layer_start': request.layer_start,
            'layer_end': request.layer_end,
            'pred_score_industry_neutralization': request.pred_score_industry_neutralization,
        }

        # Run the CPU-bound backtest in the process pool so the event loop stays responsive.
        print(f"Starting backtest execution, parameters: {backtest_params}")
        loop = asyncio.get_running_loop()
        results = await loop.run_in_executor(process_executor, run_backtest_process, backtest_params)

        # Process results, converting non-serializable objects.
        serializable_results = {k: _to_serializable(v) for k, v in results.items()}

        return BacktestResponse(
            success=True,
            message="Backtest execution successful",
            data=serializable_results
        )

    except Exception as e:
        print(f"Error type: {type(e)}")
        error_msg = f"Backtest execution failed: {str(e)}"
        print(f"Error details: {traceback.format_exc()}")

        # Return 500 status code indicating server internal error
        raise HTTPException(
            status_code=500,
            detail={
                "success": False,
                "message": "Backtest execution failed",
                "error": error_msg
            }
        ) from e

@app.get("/example")
async def get_example_request():
    """Return a ready-to-send sample payload for POST /backtest."""
    sample_exprs = {
        "Smart_Volume_Cluster_Composite": "(TS_STD($close,5)/(TS_STD($close,20)+1e-8)) * ($volume > TS_QUANTILE($volume,20,0.9))",
    }
    example_payload = dict(
        exprs=sample_exprs,
        backtest_start_time="2020-01-01",
        backtest_end_time="2022-12-31",
        start_cash=10000000.0,
        update_freq=4,
        label_forward_days=4,
        stock_pool="CSI500",
        stop_loss_rate=0.5,
        stop_profit_rate=0.4,
        position_size=1.0,
        max_pos_each_stock=0.2,
        use_cache=False,
        layer_start=0,
        layer_end=1,
        pred_score_industry_neutralization=False,
    )
    return example_payload

@app.post("/test_expr", response_model=ExprTestResponse)
async def test_expression(request: ExprTestRequest):
    """
    Test factor expression validity

    Process:
    1. Get test data (a copy of the DEBUG_DF frame loaded at startup)
    2. Use Jinja2 template to render expression, generate executable Python code
    3. Use `exec` to execute generated code
    4. Read the result the generated code writes to .debug/result_df.pkl

    Return description:
    - success: Whether API call is successful (True even when the expression fails)
    - exe_feedback: Execution feedback, None when successful, complete traceback when failed
    - code: Actually executed code
    - sample: Result data as {column: [values...]}

    Raises:
    - HTTPException(500): unexpected failure outside the expression execution
      (e.g. missing template file or unreadable result pickle)
    """
    result_path = '.debug/result_df.pkl'
    try:
        # 1. Get test data
        if DEBUG_DF is None:
            return ExprTestResponse(
                success=True,
                message="Server has not loaded debug_df.csv, cannot test expression",
                exe_feedback="Server has not loaded debug_df.csv, cannot test expression",
                code=None,
                sample=None
            )

        # `df` looks unused, but exec() below runs with this function's locals,
        # so the rendered template code presumably reads it — confirm against
        # the template before renaming or removing it.
        df = DEBUG_DF.copy()

        # 2. Use template to render expression
        os.makedirs('.debug', exist_ok=True)  # Ensure directory exists

        with open('expression_manager/template.jinjia2', 'r', encoding='utf-8') as f:
            template_content = f.read()
        template = Template(template_content)
        rendered_code = template.render(
            expression=request.expr,
            factor_name=request.name
        )

        print(f"{'='*100}\n {rendered_code}\n {'='*100}")

        # Remove any result left by a previous request so a run that fails to
        # write a new result cannot return stale data as this expression's output.
        try:
            os.remove(result_path)
        except FileNotFoundError:
            pass

        # SECURITY: request.expr / request.name come from the client and are
        # rendered into Python source that exec() runs in this process. Unless the
        # template sanitizes them (not visible here), any caller can execute
        # arbitrary code; expose this endpoint only to trusted clients.
        # exec() also runs synchronously on the event loop, so a slow expression
        # blocks other requests while it runs.
        try:
            exec(rendered_code)
            print("Execution successful")
        except Exception:
            feedback = traceback.format_exc()
            print(feedback)
            return ExprTestResponse(
                success=True,
                message="Expression execution failed",
                exe_feedback=feedback,
                code=rendered_code,
                sample=None
            )

        # Read factor expression calculation results
        result_df = pd.read_pickle(result_path).reset_index()
        sample = convert_to_dict(result_df)

        # Execution successful
        return ExprTestResponse(
            success=True,
            message="Expression parsing and calculation successful",
            exe_feedback=None,
            code=rendered_code,
            sample=sample
        )

    except Exception as e:
        # Exception in API call itself
        raise HTTPException(
            status_code=500,
            detail=f"API call exception: {str(e)}"
        ) from e

@app.on_event("shutdown")
async def shutdown_event():
    """Stop the shared process pool, waiting for in-flight work, when the app stops."""
    pool = process_executor
    print("Shutting down process pool...")
    pool.shutdown(wait=True)
    print("Process pool has been shut down")

if __name__ == "__main__":
    # Startup banner
    banner_lines = (
        "Starting quantitative backtest API server...",
        f"Using process pool, maximum processes: {MAX_WORKERS}",
        "API documentation: http://localhost:8000/docs",
        "Health check: http://localhost:8000/health",
    )
    for banner_line in banner_lines:
        print(banner_line)

    # reload=True requires the app to be given as an "module:attribute" import string.
    server_options = dict(host="0.0.0.0", port=8000, reload=True, log_level="info")
    try:
        uvicorn.run("api_server_fast:app", **server_options)
    finally:
        # Ensure process pool is properly shut down even if the server exits abnormally
        process_executor.shutdown(wait=True)