import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
from multiprocessing import Manager, Lock, Condition
from value_network_lgy import GatedBetaMLP as SimpleRegressor
class LLMWrapper(nn.Module):
    """Causal-LM feature extractor with a small trainable regressor head.

    One copy of ``model_name`` is loaded onto every device in ``devices``.
    :meth:`inference` runs a prompt through one free copy and returns the
    last-token vector of the final ``hidden_states`` entry. :meth:`push_sample`
    buffers ``(text, label)`` pairs and, once ``buffer_size`` pairs are
    collected, :meth:`train_on_buffer` takes one optimizer step on the
    regressor.

    GPU-slot flags, a FIFO wait queue and an ``is_training`` flag live in a
    ``multiprocessing.Manager`` stored on the class; it is created by the first
    instance and reused by later ones. (Presumably it is meant to be shared
    with worker processes that inherit these class attributes — TODO confirm.)
    """

    _manager = None        # multiprocessing.Manager, created lazily by the first instance
    _shared_state = None   # dict of coordination objects built in __init__

    def __init__(self, model_name, devices=None, torch_dtype=torch.float16,
                 reg_hidden_dim=1024, lr=1e-4, buffer_size=32):
        """Load the tokenizer, one model copy per device, and the regressor.

        Args:
            model_name: Model id/path passed to ``from_pretrained``.
            devices: Device strings to load a model copy onto; defaults to
                every visible CUDA device.
            torch_dtype: dtype passed to ``AutoModelForCausalLM.from_pretrained``.
            reg_hidden_dim: Not used by this constructor (the regressor is
                built from the LM's ``hidden_size`` only); kept for
                interface compatibility.
            lr: Adam learning rate for the regressor.
            buffer_size: Number of buffered samples that triggers training.
        """
        super().__init__()
        if devices is None:
            devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())]
        self.devices = devices
        self.n_gpu = len(devices)
        self.device = devices[0]  # regressor and training tensors live here
        self.model_name = model_name
        self.torch_dtype = torch_dtype
        self.buffer_size = buffer_size
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
        if LLMWrapper._manager is None:
            LLMWrapper._manager = Manager()
            LLMWrapper._shared_state = {
                # NOTE(review): sized from the FIRST instance's device list; a
                # later instance with a different device count shares it as-is.
                "gpu_busy": LLMWrapper._manager.list([False]*len(devices)),
                "wait_queue": LLMWrapper._manager.list(),  # PIDs waiting for a slot, FIFO
                "lock": Lock(),  # not used by the methods of this class
                "condition": Condition(Lock()),  # guards gpu_busy / wait_queue
                "is_training": LLMWrapper._manager.Value('b', False),
            }
        self.shared = LLMWrapper._shared_state
        self.is_training = self.shared["is_training"]
        self.device_models = []
        for dev in self.devices:
            model = AutoModelForCausalLM.from_pretrained(model_name,
                                                         torch_dtype=torch_dtype,
                                                         device_map={"": dev})
            model.eval()
            self.device_models.append(model)
        hidden_dim = self.device_models[0].config.hidden_size
        self.regressor = SimpleRegressor(hidden_dim)
        self.regressor.to(self.device)
        self.optimizer = torch.optim.Adam(self.regressor.parameters(), lr=lr)
        self.buffer = []

    def _acquire_gpu(self):
        """Block until this process heads the wait queue and a slot is free.

        Waiters are served FIFO by PID (threads of one process share a PID
        and are not distinguished). The chosen slot is marked busy until
        :meth:`_release_gpu` is called.

        Returns:
            Index into ``self.devices`` / ``self.device_models``.
        """
        pid = os.getpid()
        shared = self.shared
        with shared["condition"]:
            shared["wait_queue"].append(pid)
            try:
                while True:
                    if shared["wait_queue"][0] == pid:
                        for i in range(self.n_gpu):
                            if not shared["gpu_busy"][i]:
                                shared["gpu_busy"][i] = True
                                shared["wait_queue"].pop(0)
                                # Wake the next waiter immediately: another slot
                                # may still be free, and otherwise it would only
                                # re-check after its wait timeout.
                                shared["condition"].notify_all()
                                return i
                    # The timeout doubles as a polling fallback.
                    shared["condition"].wait(timeout=0.5)
            except BaseException:
                # A stale PID left at the head of the queue would block every
                # other waiter forever, so drop ours before propagating.
                try:
                    shared["wait_queue"].remove(pid)
                except ValueError:
                    pass
                shared["condition"].notify_all()
                raise

    def _release_gpu(self, idx):
        """Mark slot ``idx`` free and wake all waiters."""
        shared = self.shared
        with shared["condition"]:
            shared["gpu_busy"][idx] = False
            shared["condition"].notify_all()

    @torch.no_grad()
    def inference(self, prompt, use_chat_template=True, is_training=False):
        """Return the final-token vector of the last ``hidden_states`` entry.

        Args:
            prompt: Text of a single user message.
            use_chat_template: Render ``prompt`` as one user turn with the
                tokenizer's chat template (no generation prompt appended).
            is_training: Passed as True by :meth:`train_on_buffer` so feature
                extraction for training is not blocked by its own flag.

        Returns:
            A 1-D CPU tensor of length ``hidden_size``, in the dtype the model
            produced (typically ``torch_dtype``).
        """
        if not is_training and self.is_training.value:
            print("[LLMWrapper] Training in progress, waiting for inference ...")
            while self.is_training.value:
                time.sleep(0.5)
        gpu_idx = self._acquire_gpu()
        try:
            # Looked up inside try so the slot is released even if this fails.
            dev = self.devices[gpu_idx]
            model = self.device_models[gpu_idx]
            if use_chat_template:
                messages = [{"role": "user", "content": prompt}]
                chat_text = self.tokenizer.apply_chat_template(
                    messages, tokenize=False, add_generation_prompt=False)
                # A rendered chat template normally already contains the
                # model's special tokens (e.g. BOS); tokenizing it with the
                # defaults would add them a second time.
                inputs = self.tokenizer(chat_text, return_tensors="pt",
                                        add_special_tokens=False)
            else:
                inputs = self.tokenizer(prompt, return_tensors="pt")
            inputs = inputs.to(dev)
            outputs = model(**inputs, output_hidden_states=True)
            # Batch of one prompt: (1, seq, hidden) -> last position -> (hidden,)
            last_token_hidden = outputs.hidden_states[-1][:, -1, :]
            return last_token_hidden.squeeze(0).detach().cpu()
        finally:
            self._release_gpu(gpu_idx)

    def train_on_buffer(self):
        """Take one optimizer step on the regressor over every buffered sample.

        While this runs, the shared ``is_training`` flag makes other callers
        of :meth:`inference` wait. The flag is cleared and the regressor's
        previous train/eval mode restored even if training raises.

        Returns:
            The total loss (BCE + regularizer) as a float, or ``None`` if the
            buffer is empty.
        """
        if not self.buffer:
            return None
        was_training = self.regressor.training
        self.is_training.value = True
        print("[LLMWrapper] >>> Training started...")
        try:
            self.regressor.train()
            # inference() is already decorated with torch.no_grad().
            feats = [self.inference(text, is_training=True) for text, _ in self.buffer]
            labels = [torch.tensor(y, dtype=torch.float32) for _, y in self.buffer]
            # Features arrive in the LM's dtype (float16 by default) while the
            # regressor is built without an explicit dtype; match its params.
            reg_dtype = next(self.regressor.parameters()).dtype
            xs = torch.stack(feats, dim=0).to(self.device, dtype=reg_dtype)
            ys = torch.stack(labels, dim=0).to(self.device)
            y_hat, reg = self.regressor(xs, return_reg=True)
            bce_loss = F.binary_cross_entropy(y_hat, ys)
            loss = bce_loss + reg
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            print(f"[train_on_buffer] B={len(self.buffer)}, BCE={bce_loss.item():.4f}, Reg={reg.item():.4f}")
        finally:
            self.regressor.train(was_training)
            self.is_training.value = False
        print("[LLMWrapper] <<< Training finished.")
        return loss.item()

    def push_sample(self, text, label):
        """Buffer one ``(text, label)`` pair; train and clear once the buffer is full.

        ``label`` later becomes a float32 target for
        ``F.binary_cross_entropy``. If training raises, the buffer is not
        cleared and the exception propagates.
        """
        self.buffer.append((text, label))
        if len(self.buffer) >= self.buffer_size:
            print(f"[LLMWrapper] Buffer full ({len(self.buffer)}), training...")
            self.train_on_buffer()
            self.buffer.clear()