import time
from typing import List, Tuple
import uhashring
from datetime import datetime, timedelta
import os
import numpy as np
import threading
import heapq
from collections import deque
from typing import Dict

from flex_attention_vllm.entities.request import Request
from flex_attention_vllm.entities.benchmark_utils_preble import RequestFuncOutput
from flex_attention_vllm.scheduler.global_scheduler.base_global_scheduler import BaseGlobalScheduler
from flex_attention_vllm.scheduler.utils.preble_global_scheduler_utils import PrebleGlobalSchedulerUtils
from flex_attention_vllm.scheduler.utils.shared import SharedState

from flex_attention_vllm.logger import init_logger

logger = init_logger(__name__)

class CacheAffinityGlobalScheduler(BaseGlobalScheduler):
    """Global scheduler that routes requests by consistent-hashing a session-id prefix.

    Each request is routed by hashing the first ``/``-separated segment of
    its ``_hash_session_id`` onto a consistent-hash ring whose nodes are the
    replica ids ``0 .. num_replicas - 1``. Requests whose session ids share
    that leading segment therefore always land on the same replica. Routing
    does not consult replica load or any per-request state, and
    :meth:`finish_request` is a no-op.
    """

    def __init__(self, num_replicas, window_duration: int, shared_state: SharedState, args):
        """Create the scheduler.

        Args:
            num_replicas: Number of replicas; forwarded to ``BaseGlobalScheduler``.
            window_duration: Forwarded to ``PrebleGlobalSchedulerUtils``; not
                read by the routing logic in this class.
            shared_state: Stored as ``self.shared_state``.
            args: Forwarded to ``PrebleGlobalSchedulerUtils``.
        """
        super().__init__(num_replicas)
        self.preble_schedule_util = PrebleGlobalSchedulerUtils(num_replicas, window_duration, args)
        self.shared_state = shared_state
        # Consistent-hash rings keyed by node count. The ring for a given node
        # count is always identical, so build it once rather than on every
        # scheduling decision.
        self._hash_rings: Dict[int, "uhashring.HashRing"] = {}

    def hash_function1(self, task_id, num_nodes):
        """Map ``task_id`` to a node index in ``[0, num_nodes)`` via consistent hashing.

        Args:
            task_id: Key to place on the ring; converted with ``str()``.
            num_nodes: Number of ring nodes, named ``"0" .. str(num_nodes - 1)``.

        Returns:
            The integer index of the node that ``task_id`` hashes to.

        Raises:
            ValueError: If ``num_nodes`` is not positive. An empty ring has no
                node to return, which previously surfaced as ``int(None)``.
        """
        if num_nodes <= 0:
            raise ValueError(f"num_nodes must be positive, got {num_nodes}")

        ring = self._hash_rings.get(num_nodes)
        if ring is None:
            nodes = [str(i) for i in range(num_nodes)]
            ring = uhashring.HashRing(nodes=nodes)
            self._hash_rings[num_nodes] = ring
        return int(ring.get_node(str(task_id)))

    async def schedule(self, request: Request) -> int:
        """Pick the replica for ``request`` by hashing its session-id prefix.

        The prefix is everything before the first ``/`` in
        ``request._hash_session_id`` (the whole id if it contains no ``/``).

        Returns:
            A replica id in ``[0, self._num_replicas)``.
        """
        shortest_prefix = request._hash_session_id.partition("/")[0]
        replica_id = self.hash_function1(shortest_prefix, self._num_replicas)

        # Lazy %-args: the message is only formatted when DEBUG is enabled,
        # which matters because this runs once per request.
        logger.debug(
            "_hash_session_id:%s, shortest_prefix:%s, replica_id:%s",
            request._hash_session_id,
            shortest_prefix,
            replica_id,
        )
        return replica_id

    def finish_request(
        self, func_output: RequestFuncOutput=None, text: str = None, input_ids=None
    ):
        """Do nothing when a request finishes.

        Routing in this scheduler depends only on the request's session id, so
        there is no state to update. All arguments are ignored; they are
        presumably accepted for signature compatibility with other global
        schedulers (TODO confirm against callers).
        """
        return