import asyncio
import aiohttp
import time
import json
import sys
import os
from typing import List, Optional
from datetime import datetime
from transformers import AutoTokenizer

from flex_attention_vllm.scheduler.utils.shared import SharedState
from flex_attention_vllm.scheduler.utils.lazy_prefix_table import LazyPrefixTable,HotPrefixDetector,LazyExpansionController
from flex_attention_vllm.entities.benchmark_utils_preble import RequestFuncOutput
from itertools import count
from flex_attention_vllm.entities.request import Request
from flex_attention_vllm.metrics.metrics_store import MetricsStore

from flex_attention_vllm.logger import init_logger

# Module-level logger used by every method in this file.
logger = init_logger(__name__)


# Module-level itertools.count(): the request builders below call next() on it
# to obtain a fresh, monotonically increasing request id (starting at 0).
request_id_counter = count()

class RequestGenerator:
    def __init__(self, shared_state: SharedState, tokenizer, num_replicas, args):
        self.shared_state = shared_state
        self._args = args
        self._num_replicas = num_replicas
        self._model = args.model
        self._model_path = args.model_path
        self._model_name = args.model_name
        self._max_model_len = args.max_model_len
        self._request_dataset_dir = args.request_dataset_dir
        self._request_dataset_file = args.request_dataset_file
        self._process_dataset_online = args.process_dataset_online
        self._ct_ratio = args.ct_ratio
        self._request_generate_qps = args.request_generate_qps
        self._request_native_qps = args.request_native_qps
        self._num_request = args.request_num
        self._warm_up_requests_num = args.warm_up_requests_num
        self._warm_up_qps = args.warm_up_qps
        self._requests_num_dataset_start = args.requests_num_dataset_start
        self._result_path = args.result_path
        self._cur_num_request = 0
        self.cur_num_request_lock = asyncio.Lock()
        self._num_request_completed = 0
        self.block_size = 512
        self.g_session_id_map = {} # native_session_id:g_session_id
        self.hash_session_id_map = {} # native_session_id:g_session_id
        self.session_id_map_lock = asyncio.Lock()
        self._active_timeout = args.request_active_timeout  #s
        self._is_finished = False
        self.tokenizer = tokenizer
        self.stable_tokens = self._build_stable_token_mapping()

        self.prefix_table = LazyPrefixTable()
        self.hot_prefix_detector = HotPrefixDetector()
        self.prefix_expansion_ctrl = LazyExpansionController(self.prefix_table, self.hot_prefix_detector)        
        

    def _build_stable_token_mapping(self) -> dict:
        stable_mapping = {}
        
        for token_id in range(1000, min(50000, self.tokenizer.vocab_size)):
            if token_id in self.tokenizer.all_special_ids:
                continue
                
            text = self.tokenizer.decode([token_id])
            reencoded = self.tokenizer.encode(text, add_special_tokens=False)
            
            if len(reencoded) == 1 and reencoded[0] == token_id:
                stable_mapping[len(stable_mapping)] = token_id 
                if len(stable_mapping) >= 5000:  
                    break
                    
        if not stable_mapping:
            raise ValueError("failed to build stable token mapping.")
        
        return stable_mapping

    def _get_token_for_hash(self, hash_id: int) -> int:
        return self.stable_tokens[hash_id % len(self.stable_tokens)]

    def is_request_active(self) -> bool:
        last_request_time = self.shared_state.last_request_time

        if last_request_time is None:
            return True

        if self._is_finished:
            logger.debug(f'is_request_active: request finish prefill, completed_num_req={self._num_request_completed}')
            return False
        if not last_request_time:
            last_request_time = time.perf_counter()
            self.shared_state.last_request_time = time.perf_counter() 

        return time.perf_counter() - last_request_time < self._active_timeout

    async def _generate_request_helper(self, request: Request, native_session_id: str):
        if request is None:
            return
        start = time.perf_counter()
        event = asyncio.Event()
        self.shared_state.runtime_events[request._id] = (event, None)
        await self.shared_state.runtime_request_queue.put((request))

        start = time.perf_counter()
        await self.shared_state.runtime_events[request._id][0].wait()
        replica_id = self.shared_state.runtime_events[request._id][1]
        self.shared_state.runtime_events.pop(request._id)

        if replica_id is None:
            raise RuntimeError("Runtime selection failed")
        if replica_id < 0:
           logger.debug(f"replica_id < 0")
           return 
        if replica_id >= 0 and replica_id < self._num_replicas:
            await self.shared_state.add_posting_request_tasks(replica_id, request)


    async def generate_request_offline(self, record: dict, time_interval, dataset_type) -> dict:

        prompts = record["prompts"]
        input_ids = record["input_ids"]
        output_len = record["output_len"]
        
        g_session_id = record["g_session_id"]
        hash_session_id = record["hash_session_id"]
        hash_ids = record["hash_ids"]

        if g_session_id == "" or hash_session_id == ""\
            or prompts == "" or len(input_ids) > self._max_model_len or len(input_ids) == 0\
            or output_len <= 0 or hash_ids is None or len(hash_ids) < 1:
            return

        parts_hash_session_id = hash_session_id.split("@")
        if len(parts_hash_session_id) < 2:
            return
        session_id = parts_hash_session_id[1]
        
        if len(input_ids) >= self._max_model_len - 2*1024:
            return

        parts_g_session_id = g_session_id.split("@")
        if len(parts_g_session_id) < 3:
            return
        
        hash_prefix_depth = self.prefix_table.lookup(hash_ids)
        self.prefix_expansion_ctrl.process(hash_ids)
        hash_prefix_str = "".join(map(str, hash_ids[:hash_prefix_depth]))


        request_id = next(request_id_counter)
        new_g_session_id = f"{parts_g_session_id[0]}@{parts_g_session_id[1]}@{request_id}"
        hash_session_id = f"{parts_hash_session_id[0]}@{hash_prefix_str}"
        logger.debug(f"generate_request_offline:request_id={request_id},new_g_session_id={new_g_session_id},"
                     f"hash_prefix_depth={hash_prefix_depth},hash_session_id={hash_session_id}")
        request = Request(
            request_id = int(request_id),
            dataset_type = dataset_type,
            native_session_id = int(session_id),
            session_id = new_g_session_id,
            hash_session_id = hash_session_id,
            round_id = 0, 
            prompts = prompts,
            input_ids = input_ids,
            num_prefill_tokens = len(input_ids),
            actual_num_prefill_tokens = len(input_ids),
            output_len = int(output_len),
            over_flow = False,
            n = 1,
            temperature = 0,
            top_p = 1,
            max_tokens = int(output_len),
            stream = True,
            arrived_at = time.perf_counter(),
            time_interval = time_interval
        )
        self._cur_num_request += 1
        await self._generate_request_helper(request, session_id)

    async def generate_request_online(self, record: dict, time_interval, dataset_type) -> dict:
        if record["timestamp"] < 0 or record["input_length"] > self._max_model_len or record["input_length"] <= 0 \
        or record["output_length"] < 0 or len(record["hash_ids"]) < 1:
            return

        token_ids = []
        remaining = record["input_length"]
        start = time.perf_counter()
        prompts = ""
        for hid in record["hash_ids"][:-1]:
            block_size = min(self.block_size, remaining)
            token_id = self._get_token_for_hash(hid)
            token_ids.extend([token_id] * block_size)

            # prompts += "" + f"{token_id}" * block_size

            remaining -= block_size
        
        if remaining > 0:
            token_id = self._get_token_for_hash(record["hash_ids"][-1])
            token_ids.extend([token_id] * remaining)

        start = time.perf_counter()
        prompts = self.tokenizer.decode(token_ids)

        start = time.perf_counter()
        tmp_input_ids = self.tokenizer.encode(prompts)

        input_ids = token_ids
        if len(input_ids) >= self._max_model_len - 2 * 1024:
            return

        request_id = next(request_id_counter)
        session_id = ""
        g_session_id = ""
        hash_session_id = ""
        now_time = datetime.now().timestamp()
        if len(record["hash_ids"]) >= 2:
            session_id = str(record["hash_ids"][0]) + str(record["hash_ids"][1])
        else:
            session_id = str(record["hash_ids"][0])

        if session_id in self.hash_session_id_map:
            hash_session_id = self.hash_session_id_map[session_id]
        else:
            hash_session_id = str(now_time) + "@" + str(session_id)
            self.hash_session_id_map[session_id] = hash_session_id
        g_session_id = hash_session_id + "@" + str(request_id)

        round_id = 0
        output_len = record["output_length"]
        logger.debug(f"generate_request_online:request_id={request_id},new_g_session_id={g_session_id}")
        request = Request(
            request_id = int(request_id),
            dataset_type = dataset_type,
            native_session_id = int(session_id),
            session_id = g_session_id,
            hash_session_id = hash_session_id,
            round_id = 0, 
            prompts = prompts,
            input_ids = input_ids,
            num_prefill_tokens = len(input_ids),
            actual_num_prefill_tokens = len(input_ids),
            output_len = int(output_len),
            over_flow = False,
            n = 1,
            temperature = 0,
            top_p = 1,
            max_tokens = int(output_len),
            stream = True,
            arrived_at = time.perf_counter(),
            time_interval = time_interval
        )
        self._cur_num_request += 1
        await self._generate_request_helper(request, session_id)


    async def _process_group(self, group: list, time_interval: float) -> list:
        for record in group:
            start = time.perf_counter()
            await self.generate_request_online(record, time_interval)
            await asyncio.sleep(time_interval)  
            logger.debug(f"generate_request: delay={time.perf_counter() - start:.4f}s, time_interval={time_interval}")


    def valid_record_online(self, record):
        valid = True
        if record["timestamp"] < 0 or record["input_length"] > 20480 or record["input_length"] <= 0 \
        or record["output_length"] < 0 or len(record["hash_ids"]) < 1:
            valid = False
        return valid


    def valid_record_offline(self, record):
        valid = True
        if record["timestamp"] < 0:
            valid = False
        return valid

    async def generate_from_file_online(self):
        with open(self._request_dataset_file, 'r') as f:
            current_group = []
            prev_timestamp = None
            time_interval = 2
            while True:
                line = f.readline()

                if self._cur_num_request >= self._num_request:
                    self._is_finished = True
                    break
                if not line:  
                    if current_group:
                        await self._process_group(current_group, time_interval)
                    break
                
                record = json.loads(line)
                if not self.valid_record_online(record):
                    continue

                current_timestamp = record["timestamp"]
                
                if prev_timestamp is None:
                    prev_timestamp = current_timestamp               
                
                if current_timestamp == prev_timestamp:
                    current_group.append(record)
                else:
                    if current_group:
                        native_time_interval = round((current_timestamp - prev_timestamp) / len(current_group) / 1000, 3)
                        if self._warm_up_requests_num > self._cur_num_request:
                            time_interval = 2 # warm up:qps=0.5
                        else:
                            time_interval = round(native_time_interval * self._request_native_qps / self._request_generate_qps, 3)
                        await self._process_group(current_group, time_interval)
                    current_group = [record]
                    prev_timestamp = current_timestamp


    async def generate_from_file(self):   
        conversation_dataset_file = ""
        toolagent_dataset_file = ""
        conversation_dataset_type = ""
        toolagent_dataset_type = ""

        if self._process_dataset_online:
            conversation_dataset_file = os.path.join(
                self._request_dataset_dir,
                "conversation_trace.jsonl"
            )        
            
            toolagent_dataset_file = os.path.join(
                self._request_dataset_dir,
                "toolagent_trace.jsonl"
            )  
            conversation_dataset_type = "mooncake-conversation-online"
            toolagent_dataset_type = "mooncake-toolagent-online"                      
        else:
            conversation_dataset_file = os.path.join(
                self._request_dataset_dir,
                "processed_conversation_trace_single_token.jsonl"
            )        
            
            toolagent_dataset_file = os.path.join(
                self._request_dataset_dir,
                "processed_toolagent_trace_single_token.jsonl"
            )
            conversation_dataset_type = "mooncake-conversation-offline"
            toolagent_dataset_type = "mooncake-toolagent-offline"


        conversation_turns = 10
        toolagent_turns = 10

        if self._ct_ratio > 1:
            toolagent_turns = int(conversation_turns/self._ct_ratio)
        elif self._ct_ratio >= 0 and self._ct_ratio <= 1:
            conversation_turns = int(toolagent_turns * self._ct_ratio)            


        with open(conversation_dataset_file, 'r') as conversation_file, open(toolagent_dataset_file, 'r') as toolagent_file:
            current_conversation_group = []
            prev_conversation_timestamp = None
            time_conversation_interval = 2
            conversation_request_native_qps = 3.34

            current_toolagent_group = []
            prev_toolagent_timestamp = None
            time_toolagent_interval = 2
            toolagent_request_native_qps = 6.5 / 2  

            cur_line_num = 0
            while True:
                if self._cur_num_request >= self._num_request:
                    break    

                for i in range(conversation_turns):
                    # for conversation_file
                    line = conversation_file.readline()
                    cur_line_num += 1
                    if self._requests_num_dataset_start >= cur_line_num:
                        continue

                    if not line:  
                        for record_in_group in current_conversation_group:
                            start = time.perf_counter()
                            if self._process_dataset_online:
                                await self.generate_request_online(record_in_group, time_conversation_interval, conversation_dataset_type)
                            else:
                                await self.generate_request_offline(record_in_group, time_conversation_interval, conversation_dataset_type)
                            delay = time.perf_counter() - start
                            await asyncio.sleep(max(0, time_conversation_interval - delay))
                            await asyncio.sleep(0)                     
                        break
                    
                    record = json.loads(line)
                    if not self.valid_record_offline(record):
                        continue

                    current_timestamp = record["timestamp"]
                    
                    if prev_conversation_timestamp is None:
                        prev_conversation_timestamp = current_timestamp               
                    
                    if current_timestamp == prev_conversation_timestamp:
                        current_conversation_group.append(record)
                    else:
                        if current_conversation_group:
                            native_time_interval = round((current_timestamp - prev_conversation_timestamp) / len(current_conversation_group) / 1000, 3)
                            if self._warm_up_requests_num > self._cur_num_request:
                                if self._warm_up_qps > 0:
                                    time_conversation_interval = 1/self._warm_up_qps # warm up tool-agent:qps=0.5, mooncake-conversation qps=0.3
                                else:
                                    time_conversation_interval = 2
                            else:
                                time_conversation_interval = round(native_time_interval * conversation_request_native_qps / self._request_generate_qps, 3)
                            for record_in_group in current_conversation_group:
                                start = time.perf_counter()
                                if self._process_dataset_online:
                                    await self.generate_request_online(record_in_group, time_conversation_interval, conversation_dataset_type)
                                else:
                                    await self.generate_request_offline(record_in_group, time_conversation_interval, conversation_dataset_type)
                                delay = time.perf_counter() - start
                                await asyncio.sleep(max(0, time_conversation_interval - delay))
                                await asyncio.sleep(0)  
                                

                        current_conversation_group = [record]
                        prev_conversation_timestamp = current_timestamp


                # for toolagent_file
                for i in range(toolagent_turns):
                    line = toolagent_file.readline()
                    cur_line_num += 1
                    if self._requests_num_dataset_start >= cur_line_num:
                        continue

                    if not line:  
                        for record_in_group in current_toolagent_group:
                            start = time.perf_counter()
                            if self._process_dataset_online:
                                await self.generate_request_online(record_in_group, time_toolagent_interval, toolagent_dataset_type)
                            else:
                                await self.generate_request_offline(record_in_group, time_toolagent_interval, toolagent_dataset_type)
                            delay = time.perf_counter() - start
                            await asyncio.sleep(max(0, time_toolagent_interval - delay))
                            await asyncio.sleep(0)                    
                        break
                    
                    record = json.loads(line)
                    if not self.valid_record_offline(record):
                        continue

                    current_timestamp = record["timestamp"]
                    
                    if prev_toolagent_timestamp is None:
                        prev_toolagent_timestamp = current_timestamp               
                    
                    if current_timestamp == prev_toolagent_timestamp:
                        current_toolagent_group.append(record)
                    else:
                        if current_toolagent_group:
                            native_time_interval = round((current_timestamp - prev_toolagent_timestamp) / len(current_toolagent_group) / 1000, 3)
                            if self._warm_up_requests_num > self._cur_num_request:
                                if self._warm_up_qps > 0:
                                    time_toolagent_interval = 1/self._warm_up_qps # warm up tool-agent:qps=0.5, mooncake-toolagent qps=0.3
                                else:
                                    time_toolagent_interval = 2
                            else:
                                time_toolagent_interval = round(native_time_interval * toolagent_request_native_qps / self._request_generate_qps, 3)
                            for record_in_group in current_toolagent_group:
                                start = time.perf_counter()
                                if self._process_dataset_online:
                                    await self.generate_request_online(record_in_group, time_toolagent_interval, toolagent_dataset_type)
                                else:
                                    await self.generate_request_offline(record_in_group, time_toolagent_interval, toolagent_dataset_type)
                                delay = time.perf_counter() - start
                                await asyncio.sleep(max(0, time_toolagent_interval - delay))
                                await asyncio.sleep(0)  

                        current_toolagent_group = [record]
                        prev_toolagent_timestamp = current_timestamp