# shared_state_pool.py
import threading
import torch
import torch.profiler

import sys
def debug_refcount(tag, tensor):
    """Print a one-line diagnostic for *tensor*, including its CPython refcount.

    Values that are not ``torch.Tensor`` instances are reported and skipped.
    Any exception raised while gathering the diagnostics is printed rather
    than propagated, since this is a best-effort debugging aid.

    The printed refcount includes the temporary references held by this call
    (the ``tensor`` parameter and the ``sys.getrefcount`` argument), so it is
    larger than the number of references the caller holds.
    """
    if isinstance(tensor, torch.Tensor):
        try:
            count = sys.getrefcount(tensor)
            print(
                f"[{tag}] tensor shape={tuple(tensor.shape)}, dtype={tensor.dtype}, "
                f"device={tensor.device}, refcount={count}"
            )
        except Exception as err:
            print(f"[{tag}] error: {err}")
    else:
        print(f"[{tag}] 不是 Tensor: {tensor}")
class SharedStatePool:
    """Process-wide singleton that shares per-channel state between threads.

    For each of ``num_channels`` channels it keeps a list of
    ``(hidden_state, metadata)`` entries and a list of ``(logits, metadata)``
    entries. It also keeps the most recent ``input_ids`` / ``attention_mask``
    and a ``generate_token_idx`` counter that every ``put_input`` call increments.

    Locking:
        * ``lock`` / ``condition`` guard ``hiddenstates``. ``clear_input``,
          ``reset`` and ``get_input_ids_and_attention_mask`` also use this lock
          for the input fields.
        * ``logits_lock`` / ``logits_condition`` guard ``logits``.
        * ``input_lock`` / ``input_condition`` are used by ``put_input``.

    NOTE(review): the input fields are written under ``input_lock`` in
    ``put_input`` but under ``lock`` elsewhere, so those paths do not exclude
    each other. Unifying them needs care, because external code may already
    hold one of these conditions when it calls into the pool.

    Every ``SharedStatePool(...)`` call returns the same object. Only the
    first construction's ``num_channels`` takes effect.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls, *args, **kwargs):
        # Create the single shared instance on first use; later calls return it.
        with cls._lock:
            if cls._instance is None:
                cls._instance = super(SharedStatePool, cls).__new__(cls)
                cls._instance._initialized = False
            return cls._instance

    def __init__(self, num_channels=3):
        """Initialize the pool on first construction; later calls are no-ops.

        Args:
            num_channels: number of independent channels for hidden states
                and logits.
        """
        # __new__ releases the class lock before __init__ runs, so take it
        # again here. Otherwise two threads constructing concurrently could
        # both see _initialized == False, and the second would wipe the
        # state the first had just set up.
        with type(self)._lock:
            if self._initialized:
                return
            self.num_channels = num_channels
            self.hiddenstates = [[] for _ in range(num_channels)]
            self.logits = [[] for _ in range(num_channels)]
            self.input_ids = None
            self.attention_mask = None

            self.lock = threading.Lock()
            self.condition = threading.Condition(self.lock)
            self.logits_lock = threading.Lock()
            self.logits_condition = threading.Condition(self.logits_lock)
            self.input_lock = threading.Lock()
            self.input_condition = threading.Condition(self.input_lock)

            # Not set to True anywhere in this class; presumably set externally.
            self.is_finished = False
            self.generate_token_idx = 0
            # Set last, so the flag is only True once every attribute exists.
            self._initialized = True

    def put_hidden(self, channel_id, hidden_state, metadata=None):
        """Append ``(hidden_state, metadata)`` to a channel and wake readers.

        ``metadata["step_idx"]`` is stamped with the current
        ``generate_token_idx``. A caller-supplied dict is updated in place.
        When *metadata* is None, a new dict is created for the entry.
        """
        if metadata is None:
            metadata = {}
        # `condition` wraps `lock`, so the append and the notify happen in one
        # critical section and waiters always observe the new entry.
        with self.condition:
            metadata["step_idx"] = self.generate_token_idx
            self.hiddenstates[channel_id].append((hidden_state, metadata))
            self.condition.notify_all()

    def get_hidden(self, channel_id, pos_id):
        """Return the entry at *pos_id* in a channel, or None if not yet present."""
        with self.lock:
            if pos_id < len(self.hiddenstates[channel_id]):
                return self.hiddenstates[channel_id][pos_id]
            return None

    def get_hidden_blocking(self, channel_id, pos_id, timeout=None):
        """Wait until entry *pos_id* exists in a channel and return it.

        Args:
            timeout: seconds to wait, or None (default) to wait indefinitely.

        Returns:
            The ``(hidden_state, metadata)`` entry, or None if *timeout*
            expired first.
        """
        with self.condition:
            ready = self.condition.wait_for(
                lambda: len(self.hiddenstates[channel_id]) > pos_id, timeout
            )
            if not ready:
                return None
            return self.hiddenstates[channel_id][pos_id]

    def clear_hidden(self, channel_id=None):
        """Drop hidden-state entries for one channel, or all channels if None."""
        with self.lock:
            if channel_id is None:
                self.hiddenstates = [[] for _ in self.hiddenstates]
            else:
                self.hiddenstates[channel_id].clear()

    def put_logits(self, channel_id, logits, metadata=None):
        """Append ``(logits, metadata)`` to a channel and wake logits waiters."""
        with self.logits_condition:
            self.logits[channel_id].append((logits, metadata))
            self.logits_condition.notify_all()

    def get_logits(self, channel_id, pos_id):
        """Return the logits entry at *pos_id* in a channel, or None if absent."""
        with self.logits_lock:
            if pos_id < len(self.logits[channel_id]):
                return self.logits[channel_id][pos_id]
            return None

    def clear_logits(self, channel_id=None):
        """Drop logits entries for one channel, or all channels if None."""
        # Must be logits_lock, the lock put_logits/get_logits/reset use for
        # `logits`; taking `lock` here would not exclude concurrent appends.
        with self.logits_lock:
            if channel_id is None:
                self.logits = [[] for _ in self.logits]
            else:
                self.logits[channel_id].clear()

    def put_input(self, input_ids, attention_mask):
        """Publish a new input, bump ``generate_token_idx`` and wake input waiters."""
        with self.input_condition:
            self.generate_token_idx += 1
            self.input_ids = input_ids
            self.attention_mask = attention_mask
            self.input_condition.notify_all()

    def has_input_for(self, token_idx):
        """Return True if *token_idx* equals the current ``generate_token_idx``.

        Reads without locking. NOTE(review): presumably used as a ``wait_for``
        predicate while the caller holds ``input_condition`` (confirm).
        """
        return token_idx == self.generate_token_idx

    def get_input(self):
        """Return ``(input_ids, attention_mask, generate_token_idx)`` without locking."""
        return self.input_ids, self.attention_mask, self.generate_token_idx

    def get_is_finished(self):
        """Return the ``is_finished`` flag."""
        return self.is_finished

    def clear_input(self):
        """Reset ``input_ids`` and ``attention_mask`` to None (under ``lock``)."""
        with self.lock:
            self.input_ids = None
            self.attention_mask = None

    def size_hidden(self, channel_id=None):
        """Return the entry count of one channel, or a per-channel list if None."""
        if channel_id is None:
            return [len(hs) for hs in self.hiddenstates]
        return len(self.hiddenstates[channel_id])

    def size_logits(self):
        """Return how many channels currently hold at least one logits entry."""
        return sum(1 for channel in self.logits if channel)

    def reset(self, mode="insequence"):
        """Drop all hidden states, logits and the current input.

        Args:
            mode: ``"newsequence"`` additionally resets ``generate_token_idx``
                to 0. Any other value (default ``"insequence"``) keeps it.
        """
        with self.lock:
            for channel_id in range(self.num_channels):
                self.hiddenstates[channel_id] = []
            self.input_ids = None
            self.attention_mask = None
            if mode == "newsequence":
                self.generate_token_idx = 0

        with self.logits_lock:
            for channel_id in range(self.num_channels):
                self.logits[channel_id] = []

    def get_input_ids_and_attention_mask(self):
        """Return ``(input_ids, attention_mask)`` read together under ``lock``."""
        with self.lock:
            return self.input_ids, self.attention_mask
            
# Shared module-level pool with four channels. SharedStatePool is a singleton,
# so later SharedStatePool(...) calls return this same object and ignore their
# num_channels argument.
global_state_pool = SharedStatePool(num_channels=4)
