#!/usr/bin/env python3
"""
OpenAI API Proxy Server

This server acts as a proxy for the OpenAI API, allowing compute nodes
without internet access to make API calls through port forwarding.

Usage:
1. Run this on your internet-connected node:
   python openai_proxy.py

2. Set up port forwarding from your compute node to this server

3. In your OpenAI client code, change only the base_url (the server listens on
   port 8010 by default; use --port to change it):
   client = OpenAI(base_url="http://localhost:8010/v1")
"""

import argparse
import asyncio
import functools
import logging
from typing import Optional

import requests
import uvicorn
from fastapi import FastAPI, Request, Response, HTTPException
from fastapi.responses import JSONResponse

# Emit INFO-level log records (one line per proxied request) via the root
# logger's default stderr handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application that uvicorn serves; routes are attached below.
app = FastAPI()

# Upstream OpenAI endpoint and the per-request timeout, in seconds.
OPENAI_BASE_URL = "https://api.openai.com"
TIMEOUT = 5 * 60  # long completions can take several minutes

# Hop-by-hop headers (RFC 7230, section 6.1) describe a single connection and
# must not be relayed by a proxy in either direction.
_HOP_BY_HOP_HEADERS = frozenset({
    'connection',
    'keep-alive',
    'proxy-authenticate',
    'proxy-authorization',
    'te',
    'trailer',
    'trailers',
    'transfer-encoding',
    'upgrade',
})

# Request headers that must be regenerated for the upstream call.
_DROPPED_REQUEST_HEADERS = _HOP_BY_HOP_HEADERS | {
    'host',             # must name the upstream host, not this proxy
    'content-length',   # recomputed by requests from the body it sends
    # Let requests advertise only the codings it can decode itself. Otherwise
    # upstream may answer with e.g. "br" that requests leaves compressed,
    # while content-encoding is stripped from the response below.
    'accept-encoding',
}

# Response headers that no longer describe the body we send back.
_DROPPED_RESPONSE_HEADERS = _HOP_BY_HOP_HEADERS | {
    'content-length',    # Starlette recomputes it from the content
    'content-encoding',  # requests has already decoded response.content
}

# Methods forward_request relays (matches the catch-all route's methods).
_FORWARDED_METHODS = frozenset({'GET', 'POST', 'PUT', 'DELETE'})


def _filter_headers(headers, dropped: frozenset) -> dict:
    """Return a plain dict copy of *headers* without any name in *dropped*.

    Names are compared case-insensitively against the lowercase entries of
    *dropped*; the surviving names keep their original spelling.
    """
    return {
        name: value
        for name, value in dict(headers).items()
        if name.lower() not in dropped
    }


async def forward_request(request: Request, path: str) -> Response:
    """
    Forward the request to OpenAI API and return the response.

    The blocking ``requests`` call runs in the event loop's default thread
    pool, so a slow upstream response (up to ``TIMEOUT`` seconds) does not
    stall every other request this server is handling.

    Args:
        request: FastAPI Request object
        path: The API endpoint path (e.g., "/v1/chat/completions")

    Returns:
        FastAPI Response object with the OpenAI API response. On upstream
        timeout, a 504 JSON error; on any other ``requests`` transport
        error, a 502 JSON error.

    Raises:
        HTTPException: 405 if the method is not GET, POST, PUT or DELETE.

    NOTE(review): the upstream body is fully buffered before it is returned,
    so streamed (SSE) completions reach the client only once complete.
    """
    # Construct the target URL
    target_url = f"{OPENAI_BASE_URL}{path}"
    headers = _filter_headers(request.headers, _DROPPED_REQUEST_HEADERS)

    logger.info("%s %s -> %s", request.method, path, target_url)

    if request.method not in _FORWARDED_METHODS:
        raise HTTPException(
            status_code=405,
            detail=f"Method {request.method} not supported"
        )

    body = await request.body()

    send = functools.partial(
        requests.request,
        request.method,
        target_url,
        headers=headers,
        # The body is relayed for every method; an empty body sends nothing.
        data=body,
        # multi_items() keeps repeated keys (?a=1&a=2); passing the mapping
        # itself would keep only the last value of each key.
        params=request.query_params.multi_items(),
        timeout=TIMEOUT,
    )

    try:
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(None, send)
    except requests.exceptions.Timeout:
        logger.error("Timeout for %s %s", request.method, path)
        return JSONResponse(
            content={"error": {"message": "Request timeout", "type": "timeout"}},
            status_code=504
        )
    except requests.exceptions.RequestException as e:
        logger.error("Request failed for %s %s: %s", request.method, path, e)
        return JSONResponse(
            content={"error": {"message": f"Proxy error: {str(e)}", "type": "proxy_error"}},
            status_code=502
        )

    return Response(
        content=response.content,
        status_code=response.status_code,
        headers=_filter_headers(response.headers, _DROPPED_RESPONSE_HEADERS),
        # Only used when upstream sent no content-type header of its own.
        media_type=response.headers.get('content-type', 'application/json')
    )

# Health check. It must be registered BEFORE the catch-all route below:
# Starlette tries routes in registration order, and "/{path:path}" also
# matches "/" (its path convertor accepts the empty string). If registered
# after it, GET / would be proxied to OpenAI instead of answered here.
@app.get("/")
async def health_check():
    """Health check endpoint: report that the proxy is up and its target."""
    return {"status": "OpenAI Proxy Server is running", "target": OPENAI_BASE_URL}


# Catch-all route: every other path/method pair is relayed to OpenAI.
@app.api_route("/{path:path}", methods=['GET', 'POST', 'PUT', 'DELETE'])
async def proxy_openai(request: Request, path: str) -> Response:
    """Proxy all requests to OpenAI API, restoring the leading slash."""
    return await forward_request(request, f"/{path}")

if __name__ == '__main__':
    # Command-line entry point: parse options, print usage hints, serve.
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # Binding all interfaces lets any host that can reach this port use the
    # proxy; pass --host 127.0.0.1 when it is only reached via SSH forwarding.
    parser.add_argument("--host", default='0.0.0.0',
                        help="interface to bind (default: %(default)s)")
    parser.add_argument("--port", type=int, default=8010,
                        help="port to listen on (default: %(default)s)")
    args = parser.parse_args()

    print("Starting OpenAI API Proxy Server...")
    print(f"Proxying requests to: {OPENAI_BASE_URL}")
    print(f"Server will run on http://{args.host}:{args.port}")
    print("\nTo use with OpenAI Python SDK:")
    print(f"client = OpenAI(base_url='http://YOUR_PROXY_HOST:{args.port}/v1')")
    print("\nPress Ctrl+C to stop the server")

    # Run the server (blocks until interrupted)
    uvicorn.run(
        app,
        host=args.host,
        port=args.port
    )