import os

import aiohttp
from quart import Quart, make_response, request

from torch.distributed import TCPStore
import json
import numpy as np
import math
import time
import asyncio
from dataclasses import dataclass

# Total time budget for one proxied HTTP exchange: 6 * 60 * 60 seconds (6 hours).
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

# SLURM job id (0 when the variable is unset); the request handler uses
# ``SLURM_JOB_ID % 1000`` as an offset when building the backend ports.
SLURM_JOB_ID = int(os.getenv('SLURM_JOB_ID', '0'))

app = Quart(__name__)

# Required: there is no default, so int(None) raises TypeError at import time
# when FRONTEND_TCPSTORE_PORT is not set.
DECODE_INSTANCE_STORE_PORT = int(os.getenv('FRONTEND_TCPSTORE_PORT'))

# Hostnames of the prefill and decode backends; both default to localhost.
PREFILL_HOST = os.getenv('PREFILL_HOST', 'localhost')
DECODE_HOST = os.getenv('DECODE_HOST', 'localhost')

# Module-level flag flipped to True by the request handler when it processes
# its first request.
after_first_req = False

# Disabled: PREFILL_INSTANCE_STORE_PORT is not defined anywhere in the visible code.
# prefill_status_store = TCPStore(PREFILL_HOST, PREFILL_INSTANCE_STORE_PORT, 2)
# Created at import time. The third positional argument (2) is presumably the
# store's world_size -- TODO confirm against the torch.distributed docs.
# NOTE(review): nothing in the visible code reads or writes this store afterwards.
decode_status_store = TCPStore(DECODE_HOST, DECODE_INSTANCE_STORE_PORT, 2)

async def forward_request(url, data):
    """POST ``data`` as JSON to ``url`` and stream the response body.

    This is an async generator. A fresh ``aiohttp.ClientSession`` is opened
    for every call, using the module-level ``AIOHTTP_TIMEOUT``.

    Args:
        url: Full URL of the backend endpoint.
        data: JSON-serialisable request payload.

    Yields:
        Raw ``bytes`` chunks of the response body, read in pieces of at most
        1024 bytes. Nothing is yielded when the backend answers with a status
        other than 200; the status is logged instead.
    """
    # The bearer token comes from OPENAI_API_KEY; if the variable is unset the
    # header literally reads "Bearer None" (unchanged from previous behaviour).
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
    }
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        async with session.post(url=url, json=data,
                                headers=headers) as response:
            if response.status != 200:
                # Previously this case was dropped without a trace, leaving
                # the caller with an unexplained empty stream.
                print(f"forward_request: {url} returned HTTP "
                      f"{response.status}; no data forwarded")
                return
            # Always stream chunk by chunk, whatever Transfer-Encoding the
            # backend used.
            async for chunk_bytes in response.content.iter_chunked(1024):
                yield chunk_bytes


@app.route('/v1/completions', methods=['POST'])
async def handle_request():
    """Proxy a completion request to the prefill and decode backends.

    The JSON body must contain ``prompt_len``. Two shallow copies of the
    payload are derived from it:

    * the prefill copy gets ``max_tokens = 1`` and ``truncate_prompt_tokens``
      set to ``int(prompt_len)``;
    * the decode copy gets ``computed_len`` set to ``int(prompt_len)``.

    Returns:
        A streaming response that relays the decode backend's body, or an
        HTTP 500 JSON error response if anything raises while handling the
        request.
    """
    global after_first_req
    try:
        original_request_data = await request.get_json()

        # for test only
        input_length = original_request_data['prompt_len']
        prefill_instance_compute_len = int(input_length)

        prefill_request = original_request_data.copy()
        decode_request = original_request_data.copy()
        # change max_tokens = 1 to let it only do prefill
        prefill_request['max_tokens'] = 1
        prefill_request['truncate_prompt_tokens'] = prefill_instance_compute_len

        decode_request['computed_len'] = prefill_instance_compute_len

        # NOTE(review): only the first request handled by this process is
        # sent to the prefill backend; every later request goes straight to
        # decode. This looks like test-only behaviour -- confirm it is intended.
        if not after_first_req:
            after_first_req = True
            # Drain the prefill response completely before starting decode.
            async for _ in forward_request(
                f'http://{PREFILL_HOST}:{8100+SLURM_JOB_ID%1000}/v1/completions',
                prefill_request):
                continue

        # return decode
        generator = forward_request(
            f'http://{DECODE_HOST}:{8200+SLURM_JOB_ID%1000}/v1/completions',
            decode_request)
        response = await make_response(generator)
        # Disable the response timeout so long generations are not cut off.
        response.timeout = None

        return response

    except Exception as e:
        import traceback
        print("Error occurred in disagg prefill proxy server")
        print(e)
        print(traceback.format_exc())
        # Return an explicit error instead of falling through with None,
        # which Quart rejects as a view return value and which would mask
        # the original failure with a second, unrelated one.
        error_body = json.dumps({
            "error": "Error occurred in disagg prefill proxy server",
            "type": type(e).__name__,
        })
        return error_body, 500, {"Content-Type": "application/json"}


if __name__ == '__main__':
    # PORT_OFFSET shifts the listening port from the 8000 base. Default to 0
    # (as SLURM_JOB_ID does) so an unset variable no longer fails with an
    # opaque ``int(None)`` TypeError.
    app.run(port=8000 + int(os.getenv('PORT_OFFSET', '0')))
