from safetensors import safe_open
import os, pickle, torch, io
from hf_qwen3_gate import chunked_batch_varied_top_k_p_in_logits, get_lang_masks_lowres, get_lang_masks
from tqdm import tqdm
from torch import nn
from modelscope import AutoTokenizer
import queue
import threading
import time
from typing import Generator, Any

class CPUUnpickler(pickle.Unpickler):
    """Unpickler that redirects torch's byte-storage loader to the CPU.

    Every lookup except ``torch.storage._load_from_bytes`` is delegated to
    ``pickle.Unpickler``. That one callable is replaced by a wrapper around
    ``torch.load(..., map_location='cpu')``.

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """

    def find_class(self, module, name):
        """Resolve ``module.name``, substituting the CPU-mapped storage loader."""
        wants_storage_loader = (
            module == 'torch.storage' and name == '_load_from_bytes'
        )
        if not wants_storage_loader:
            return super().find_class(module, name)

        def load_bytes_on_cpu(raw_bytes):
            # Same call torch would make, but pinned to the CPU.
            return torch.load(io.BytesIO(raw_bytes), map_location='cpu')

        return load_bytes_on_cpu

def load_lm_head(st_path=None):
    """Build a bfloat16 ``nn.Linear`` holding a checkpoint's ``lm_head.weight``.

    Args:
        st_path: Path to the safetensors shard containing the
            ``lm_head.weight`` tensor. When ``None`` (the default) the
            checkpoint shard that was previously hard-coded is used, so
            existing callers are unaffected.

    Returns:
        A bias-free ``nn.Linear(hidden_size, vocab_size)`` in bfloat16 whose
        weight is a copy of the loaded tensor.
    """
    if st_path is None:
        # Alternative checkpoint, kept for reference:
        # st_path = '/cpfs02/user/jiawei.lyt/ckpt/verl_checkpoints/lyt-rl-gen/qwen3-ppp-nothink-fh0713-non_reason105step-rm_v1_addreason-GenRM-32B-sentcs-GSPO-ref-turbopp-static-USE_DYNAMIC_REF_ANSWER0-LENGTH_FLIP_THRESHOLD2.0-LENGTH_FLIP_PROB0.75-REF_ANSWER_POSITION-A-expert-12k_bs256_minibs128_n8/global_step_80/actor_hf/model-00118-of-00118.safetensors'
        st_path = '/cpfs01/user/jiawei.lyt/ckpt/verl_checkpoints/lyt-rl-gen/qwen3-tpp-nothink-0721-distilled-data0706-recitex1-bothtrans-mixlangx2-GenRM-32B-sentcs-GSPO-ref-turbopp-LENGTH_FLIP_THRESHOLD1.3-LENGTH_FLIP_PROB0.75-REF_ANSWER_POSITION-A-expert-12k_bs512_minibs128_n8/global_step_60/actor_hf/model-00016-of-00016.safetensors'

    with safe_open(st_path, framework="pt") as f:
        weight = f.get_tensor('lm_head.weight')  # Shape: [vocab_size, hidden_size]
    vocab_size, hidden_size = weight.shape

    # nn.Linear stores its weight as [out_features, in_features], which is
    # already the layout of the loaded tensor, so no transpose is needed.
    lm_head_layer = nn.Linear(hidden_size, vocab_size, bias=False, dtype=torch.bfloat16)
    with torch.no_grad():
        # copy_ casts to the layer's bfloat16 dtype if the checkpoint differs.
        lm_head_layer.weight.copy_(weight)
    return lm_head_layer

def load_hs_dataset(folder_path=None, batch_num: int = 64, queue_size: int = 10,
                    drop_last: bool = False) -> Generator[Any, None, None]:
    """Stream batches from pickled cache files, loading them in a background thread.

    Every regular file in ``folder_path`` is unpickled with ``CPUUnpickler``
    and each element ``d`` of its ``'hidden_states'`` entry is queued. The
    generator concatenates ``d[0]`` and ``d[1]`` of successive elements into
    two parallel lists (presumably hidden states and masks; confirm against
    the code that writes the cache) and yields them once the first list holds
    at least ``batch_num`` items.

    Args:
        folder_path: Directory containing the pickled cache files. Defaults
            to the cache directory that was previously hard-coded.
        batch_num: Minimum number of items accumulated before a batch is
            yielded.
        queue_size: Maximum number of loaded elements buffered ahead of the
            consumer.
        drop_last: If true, discard a trailing batch smaller than
            ``batch_num`` instead of yielding it.

    Yields:
        ``[list_0, list_1]``: the accumulated first and second components.
    """
    if folder_path is None:
        # Earlier cache, kept for reference:
        # './cs_gate_train/data/235_hs_cache-2025-08-16-14:18:26'
        folder_path = './cs_gate_train/data/30_nothink_hs_cache-2025-08-17-21:06:37'

    batch_queue = queue.Queue(maxsize=queue_size)
    # Sorted so the file order does not depend on os.listdir's arbitrary order.
    file_list = sorted(
        f for f in os.listdir(folder_path)
        if os.path.isfile(os.path.join(folder_path, f))
    )

    def background_loader():
        """Load files in background and put their elements in the queue."""
        try:
            for filename in file_list:
                file_path = os.path.join(folder_path, filename)
                try:
                    # SECURITY: unpickling executes arbitrary code; the cache
                    # directory must only contain trusted files.
                    with open(file_path, 'rb') as f:
                        res = CPUUnpickler(f).load()
                    print(file_path, 'loaded')
                    for d in res['hidden_states']:
                        batch_queue.put(d, block=True)  # Block if queue is full
                except Exception as e:
                    # Best effort: a broken file must not stop the whole run.
                    print(f"Error loading file {filename}: {e}")
                    continue
        finally:
            # Always signal end of data, even if the loader dies unexpectedly,
            # so the consumer below can never block forever.
            batch_queue.put(None, block=True)

    # Daemon thread: it must not keep the process alive if the consumer stops.
    loader_thread = threading.Thread(target=background_loader, daemon=True)
    loader_thread.start()

    batch_res = [[], []]
    while True:
        # Blocks until the loader provides an element or the end sentinel.
        batch_data = batch_queue.get()
        if batch_data is None:  # End of data signal
            break
        batch_res[0].extend(batch_data[0])
        batch_res[1].extend(batch_data[1])
        if len(batch_res[0]) >= batch_num:
            yield batch_res
            batch_res = [[], []]

    # Flush whatever is left so the tail of the dataset is not silently lost.
    if batch_res[0] and not drop_last:
        yield batch_res

class CodeSwitchGate(nn.Module):
    """Bias-free two-layer bfloat16 MLP: Linear -> ReLU -> Linear.

    Args:
        num_cs: Width of the output layer (number of logits produced).
        hidden_size: Width of the input and of the intermediate layer.

    The constructor also loads the ``Qwen/Qwen3-8B`` tokenizer and stores it
    on ``self.tokenizer``; ``forward`` does not use it.
    """

    def __init__(self, num_cs, hidden_size):
        super().__init__()
        linear_opts = dict(bias=False, dtype=torch.bfloat16)
        # Creation order is kept so parameter initialisation and
        # ``parameters()`` ordering stay the same.
        self.code_switch_pre = nn.Linear(hidden_size, hidden_size, **linear_opts)
        self.code_switch_act = nn.ReLU()
        self.code_switch_head = nn.Linear(hidden_size, num_cs, **linear_opts)
        self.tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-8B')

    def forward(self, hidden_states):
        """Return gate logits whose last dimension has ``num_cs`` entries."""
        activated = self.code_switch_act(self.code_switch_pre(hidden_states))
        return self.code_switch_head(activated)

from tqdm import tqdm

def train_gate():
    """Train a ``CodeSwitchGate`` on labels derived from the frozen LM head.

    For each batch from ``load_hs_dataset`` the mask-selected hidden states
    are passed through the frozen LM head. The resulting logits are turned
    into multi-label targets by ``chunked_batch_varied_top_k_p_in_logits``
    (top-k 20, top-p 0.95) and the gate is fitted to them with
    ``BCEWithLogitsLoss``. After the data is exhausted the two gate weight
    matrices are saved under ``./models/gate-qwen3-30b_<timestamp>/``.

    Raises:
        NotImplementedError: If a batch has neither 2 nor 3 elements.
    """
    from datetime import datetime

    lm_head = load_lm_head()
    # Alternative mask set: get_lang_masks_lowres(lm_head.weight.shape[0])
    lang_masks = get_lang_masks(lm_head.weight.shape[0])
    print('num_cs', lang_masks.shape[0])
    gate = CodeSwitchGate(lang_masks.shape[0], lm_head.in_features)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    lang_masks = lang_masks.to(device)
    gate = gate.to(device)
    lm_head = lm_head.to(device)

    print('lm_head.weight.shape', lm_head.weight.shape)

    hs_data = load_hs_dataset()

    # Freeze the LM head: it only produces targets and is never updated.
    lm_head.eval()
    for param in lm_head.parameters():
        param.requires_grad = False

    optimizer = torch.optim.Adam(gate.parameters(), lr=1e-3)
    loss_fct = torch.nn.BCEWithLogitsLoss()
    gate.train()

    # tqdm finds the length itself when the iterable has one; a generator
    # simply gets an open-ended progress bar.
    pbar = tqdm(hs_data, desc="Training Gate")

    for pair in pbar:
        # Batches are (hs, mask) or (hs, mask, tokens); tokens are not used.
        if len(pair) not in (2, 3):
            raise NotImplementedError(
                f"expected a batch of 2 or 3 elements, got {len(pair)}"
            )
        hs, mask = pair[0], pair[1]

        hs = torch.stack(hs).to(device)      # presumably [batch, seq_len, hidden_size]
        mask = torch.stack(mask).to(device)  # presumably [batch, seq_len] or [batch, seq_len, 1]

        # Flatten to one row per position; reshape(-1) also absorbs a
        # trailing singleton dimension on the mask.
        hs_flat = hs.reshape(-1, hs.shape[-1])   # [B*T, D]
        mask_flat = mask.reshape(-1)             # [B*T]

        active_indices = mask_flat.bool()
        if not active_indices.any():
            pbar.set_postfix(loss="skipped")
            continue

        hs_active = hs_flat[active_indices]  # [N_active, D]

        # Forward through the frozen LM head; no graph is needed for targets.
        with torch.no_grad():
            logits_batch = lm_head(hs_active)

        # Per-row top-k / top-p settings expected by the labelling helper.
        top_k_t = torch.full((logits_batch.shape[0],), 20, device=logits_batch.device)
        top_p_t = torch.full((logits_batch.shape[0],), 0.95, device=logits_batch.device)

        # Must match the gate output shape [N_active, num_cs] for the loss.
        lang_labels = chunked_batch_varied_top_k_p_in_logits(
            logits_batch, lang_masks, top_k_t, top_p_t
        )

        # Forward through the trainable gate.
        code_switch_logits = gate(hs_active)
        loss = loss_fct(code_switch_logits, lang_labels.float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix(loss=f"{loss.item():.4f}")

    formatted_str = datetime.now().strftime("%Y-%m-%d-%H:%M:%S")
    output_dir = f"./models/gate-qwen3-30b_{formatted_str}"
    # makedirs: a missing ./models directory must not cost the trained weights.
    os.makedirs(output_dir, exist_ok=True)
    torch.save(gate.code_switch_pre.weight, os.path.join(output_dir, 'code_switch_pre.pth'))
    torch.save(gate.code_switch_head.weight, os.path.join(output_dir, 'code_switch_head.pth'))
    print("Training completed.")

if __name__ == "__main__":
    # Script entry point: announce start-up, then run the full gate training.
    print("start")
    train_gate()