import asyncio
import json
import math
import os
import time
from dataclasses import dataclass
from typing import Optional

import aiohttp
import numpy as np
from quart import Quart, make_response, request
from torch.distributed import TCPStore

# Total client-side timeout for one proxied backend call (6 hours), so long
# generations are not cut off by aiohttp's default total timeout.
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)

app = Quart(__name__)

# Port of the TCPStore from which the decode instance's status is read.
# Fail with an explicit message instead of int(None)'s opaque TypeError.
_frontend_store_port = os.getenv('FRONTEND_TCPSTORE_PORT')
if _frontend_store_port is None:
    raise RuntimeError(
        "FRONTEND_TCPSTORE_PORT environment variable must be set "
        "(port of the decode instance status TCPStore)")
DECODE_INSTANCE_STORE_PORT = int(_frontend_store_port)

# Backend ports are derived from this (8100/8200 + SLURM_JOB_ID % 1000).
SLURM_JOB_ID = int(os.getenv('SLURM_JOB_ID', '0'))

PREFILL_HOST = os.getenv('PREFILL_HOST', 'localhost')
DECODE_HOST = os.getenv('DECODE_HOST', 'localhost')

# Client connection (is_master defaults to False, world_size=2) to the store
# holding the decode instance's JSON-encoded 'status' record.
decode_status_store = TCPStore(DECODE_HOST, DECODE_INSTANCE_STORE_PORT, 2)

# At most two requests may be in the status-read + prefill phase at once.
# NOTE(review): on Python < 3.10 an asyncio.Semaphore created at import time
# binds to the loop current at that moment; confirm it matches the loop the
# server runs on.
sem_prefill = asyncio.Semaphore(2)

@dataclass
class Param:
    """Coefficients of the linear latency models evaluated by ``cal_L_p``.

    A device profile may leave coefficients as ``None`` when that model was
    not fitted for it; ``cal_L_p`` only reads the fields of the role a
    profile is passed in (chunked-prefill or standalone-prefill).
    """

    # Chunked-prefill model (read from ``param_chunked``):
    k_ctxp: Optional[float]  # coefficient on the average prefill context term
    k_ctxd: Optional[float]  # coefficient on the decode context length L_ctxd
    b_c: Optional[float]     # constant term per chunked-prefill iteration
    # Standalone-prefill model (read from ``param_prefill``):
    k_p: Optional[float]     # coefficient on the prefilled token count L_p
    b_p: Optional[float]     # constant term of the prefill estimate

MODEL_NAME = os.getenv("MODEL_NAME", "8B")
DECODE_DEVICE = os.getenv("DECODE_DEVICE", "A10")

# Validate with explicit raises: ``assert`` is stripped under ``python -O``,
# which would let an unsupported DECODE_DEVICE through silently (the request
# handler would then fall back to the A30 coefficients).
if MODEL_NAME not in ("8B", "7B"):
    raise RuntimeError(
        f"Unknown model name {MODEL_NAME!r}; expected '8B' or '7B'")
if DECODE_DEVICE not in ("A10", "A30"):
    raise RuntimeError(
        f"Unknown decode device {DECODE_DEVICE!r}; expected 'A10' or 'A30'")

# Fitted latency-model coefficients per model and device profile. ``None``
# marks a model that was not fitted for that profile.
# NOTE(review): names appear to encode GPU type and memory (e.g. A100 80G);
# PARAM_A10040G exists only for the 8B model.
if MODEL_NAME == "8B":
    PARAM_A10040G = Param(
        k_ctxp=2.06624628e-06,
        k_ctxd=1.04443892e-07,
        b_c=0.037658844225332555,
        k_p=7.05190459e-05,
        b_p=-0.005132039478080219,
    )
    PARAM_A10080G = Param(
        k_ctxp=2.22787682e-06,
        k_ctxd=9.39950020e-08,
        b_c=0.04197140604981528,
        k_p=None,
        b_p=None,
    )
    PARAM_A3032G = Param(
        k_ctxp=3.39497552e-06,
        k_ctxd=2.23915077e-07,
        b_c=0.07603107133391875,
        k_p=0.00015465,
        b_p=-0.016427026179893667,
    )
    PARAM_A1032G = Param(
        k_ctxp=None,
        k_ctxd=None,
        b_c=None,
        k_p=0.00023478,
        b_p=-0.016552281338470753,
    )
elif MODEL_NAME == "7B":
    PARAM_A10080G = Param(
        k_ctxp=1.77710142e-06,
        k_ctxd=4.75208922e-08,
        b_c=0.03944704604077646,
        k_p=None,
        b_p=None,
    )
    PARAM_A3032G = Param(
        k_ctxp=None,
        k_ctxd=None,
        b_c=None,
        k_p=0.00013972,
        b_p=-0.002145770276050507,
    )
    PARAM_A1032G = Param(
        k_ctxp=None,
        k_ctxd=None,
        b_c=None,
        k_p=0.00021635,
        b_p=-0.011237839322866705,
    )

def cal_L_p(
        param_prefill: "Param",
        param_chunked: "Param",
        L_in: int, n_d: int, L_ctxd: int,
        *,
        num_candidates: int = 512,
        chunk_budget: int = 512,
    ) -> int:
    """Choose ``L_p``, the number of prompt tokens to prefill separately.

    The remaining ``L_in - L_p`` tokens are modelled as chunked prefill
    sharing a per-iteration token budget with ``n_d`` decode requests.
    Candidate values of ``L_p`` are ``num_candidates`` integers spread evenly
    over ``[1, L_in]``. The candidate minimising ``|t_prefill - t_chunked|``
    is returned, where, with ``b = chunk_budget - n_d`` and
    ``L_c = L_in - L_p``:

    * ``t_prefill = k_p * L_p + b_p`` (coefficients from ``param_prefill``)
    * ``num_iter = ceil(L_c / b)``
    * ``L_last = L_p + floor(L_c / b) * b``
    * ``t_chunked = num_iter * (k_ctxp * (L_in + L_last) / 2)
      + num_iter * (k_ctxd * L_ctxd + b_c)`` (coefficients from
      ``param_chunked``)

    Args:
        param_prefill: provides ``k_p`` and ``b_p``.
        param_chunked: provides ``k_ctxp``, ``k_ctxd`` and ``b_c``.
        L_in: prompt length in tokens.
        n_d: number of decode requests; subtracted from ``chunk_budget`` to
            get the prompt tokens processed per chunked iteration.
        L_ctxd: decode context length used in the ``k_ctxd`` term.
        num_candidates: number of candidate split points evaluated.
        chunk_budget: total tokens per chunked-prefill iteration.

    Returns:
        The chosen ``L_p`` as an ``int``. If ``L_in <= 1``, ``L_in`` is
        returned unchanged. If ``chunk_budget - n_d <= 0`` (no room for
        prompt tokens in a chunked iteration), the whole prompt ``L_in`` is
        returned.
    """
    if L_in <= 1:
        return int(L_in)

    num_prefill_tokens = chunk_budget - n_d
    if num_prefill_tokens <= 0:
        # The formulas below would divide by zero (NaN) or by a negative
        # chunk size; chunked prefill cannot progress, so keep everything on
        # the separate prefill.
        return int(L_in)

    K_CTXP = param_chunked.k_ctxp
    K_CTXD = param_chunked.k_ctxd
    B_C = param_chunked.b_c
    K_P = param_prefill.k_p
    B_P = param_prefill.b_p

    L_p = np.linspace(1, L_in, num_candidates, dtype=int)
    t_prefill = K_P * L_p + B_P

    L_c = L_in - L_p
    num_iter = np.ceil(L_c / num_prefill_tokens)
    L_last = L_p + (L_c // num_prefill_tokens) * num_prefill_tokens
    t_chunked = num_iter * (K_CTXP * (L_in + L_last) / 2) + num_iter * (K_CTXD * L_ctxd + B_C)

    t_diff = t_prefill - t_chunked
    chosen_idx = np.argmin(np.abs(t_diff))
    return int(L_p[chosen_idx])

async def forward_request(url, data):
    """POST ``data`` as JSON to ``url`` and stream back the response body.

    Yields the body as ``bytes`` chunks of at most 1024 bytes, in arrival
    order. If ``OPENAI_API_KEY`` is set, it is sent as a Bearer token.

    Raises:
        aiohttp.ClientResponseError: if the backend answers with any status
            other than 200. Raising, rather than yielding nothing, keeps a
            failed backend call from looking like an empty successful one.
    """
    headers = {}
    api_key = os.environ.get('OPENAI_API_KEY')
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        async with session.post(url=url, json=data,
                                headers=headers) as response:
            if response.status != 200:
                detail = await response.text(errors="replace")
                raise aiohttp.ClientResponseError(
                    response.request_info,
                    response.history,
                    status=response.status,
                    message=f"{response.reason}: {detail}",
                    headers=response.headers,
                )
            async for chunk_bytes in response.content.iter_chunked(1024):
                yield chunk_bytes


# Upper bound on the number of prompt tokens delegated to the prefill instance.
_MAX_PREFILL_COMPUTE_LEN = 8000 - 1


def _choose_prefill_len(chunked_status, input_length):
    """Decide how many prompt tokens the prefill instance should compute.

    ``chunked_status`` is the decoded 'status' record from the decode
    instance; its ``num_decode``, ``ctx_decode`` and ``free_tokens`` entries
    are read. If ``free_tokens - num_decode`` exceeds ``input_length``, the
    split is chosen by ``cal_L_p``; otherwise the whole prompt is assigned to
    the prefill instance. The result is capped at ``_MAX_PREFILL_COMPUTE_LEN``.
    """
    num_decode = chunked_status['num_decode']
    ctx_decode = chunked_status['ctx_decode']
    free_tokens = chunked_status['free_tokens']

    if free_tokens - num_decode <= input_length:
        return min(input_length, _MAX_PREFILL_COMPUTE_LEN)

    # NOTE(review): the prefill-side coefficients are selected by
    # DECODE_DEVICE while the chunked side is fixed to PARAM_A10080G; confirm
    # this device mapping is intended.
    compute_len = cal_L_p(
        param_prefill=PARAM_A1032G if DECODE_DEVICE == "A10" else PARAM_A3032G,
        param_chunked=PARAM_A10080G,
        L_in=input_length, n_d=num_decode, L_ctxd=ctx_decode)
    return min(compute_len, _MAX_PREFILL_COMPUTE_LEN)


@app.route('/v1/completions', methods=['POST'])
async def handle_request():
    """Proxy a completion request through the prefill and decode instances.

    The JSON body must carry ``prompt_len``. A copy with ``max_tokens=1`` and
    ``truncate_prompt_tokens=<n>`` is sent to the prefill instance and fully
    drained. Then a copy with ``computed_len=<n>`` is sent to the decode
    instance, and its response is streamed back to the client. The status
    read and the prefill run under ``sem_prefill``.

    Any exception raised before streaming starts is printed to stdout and
    answered with HTTP 500 and a JSON ``{"error": ...}`` body.
    """
    try:
        original_request_data = await request.get_json()

        async with sem_prefill:
            # TCPStore.get is a blocking network call; run it in the default
            # executor so it does not stall the event loop.
            loop = asyncio.get_running_loop()
            raw_status = await loop.run_in_executor(
                None, decode_status_store.get, 'status')
            chunked_status = json.loads(raw_status)
            input_length = original_request_data['prompt_len']

            compute_len = _choose_prefill_len(chunked_status, input_length)

            # max_tokens=1 lets the prefill instance do only the prompt pass.
            # NOTE(review): upstream vLLM's truncate_prompt_tokens keeps the
            # *last* k prompt tokens; confirm the backend here treats it as a
            # prefix length.
            prefill_request = original_request_data.copy()
            prefill_request['max_tokens'] = 1
            prefill_request['truncate_prompt_tokens'] = compute_len

            decode_request = original_request_data.copy()
            decode_request['computed_len'] = compute_len

            # Drain the prefill response; only its completion matters.
            async for _ in forward_request(
                    f'http://{PREFILL_HOST}:{8100 + SLURM_JOB_ID % 1000}/v1/completions',
                    prefill_request):
                pass

        # Stream the decode instance's response back to the client.
        generator = forward_request(
            f'http://{DECODE_HOST}:{8200 + SLURM_JOB_ID % 1000}/v1/completions',
            decode_request)
        response = await make_response(generator)
        response.timeout = None  # disable Quart's response timeout for long streams
        return response

    except Exception as e:
        import traceback
        print("Error occurred in disagg prefill proxy server")
        print(e)
        print(traceback.format_exc())
        # Returning None is not a valid Quart response; send an explicit 500.
        return (json.dumps({"error": str(e)}), 500,
                {"Content-Type": "application/json"})


if __name__ == '__main__':
    # Listen on 8000 + PORT_OFFSET. An unset offset defaults to 0 (port 8000)
    # instead of crashing on int(None).
    app.run(port=8000 + int(os.getenv('PORT_OFFSET', '0')))
