#!/usr/bin/env python3
"""
Extract 20k samples from each of skywork and mixture datasets and combine them.
"""
import os
import torch
import pandas as pd
from safetensors.torch import safe_open, save_file
import numpy as np
from pathlib import Path

def load_safetensor_data(path: str):
    """Read the fixed set of dataset tensors from a safetensors file.

    Args:
        path: Location of the ``.safetensors`` file to read.

    Returns:
        A dict mapping each expected tensor name (embeddings, labels, token
        counts and correctness flags) to the tensor loaded on CPU.
    """
    # Names are read in this order; each one must exist in the file.
    tensor_names = (
        "embeddings_prompt",
        "embeddings_prompt_answer",
        "embeddings_answer_a",
        "embeddings_answer_b",
        "correct_labels",
        "consist_labels",
        "num_tokens_instruct",
        "num_tokens_reasoning",
        "correct_instruct",
        "correct_reasoning",
    )

    print(f"Loading from: {path}")
    with safe_open(path, framework="pt", device="cpu") as handle:
        data = {name: handle.get_tensor(name) for name in tensor_names}

    # Report the sample count and the second-axis size of the embeddings.
    n_loaded = len(data["embeddings_prompt"])
    prompt_dim = data["embeddings_prompt"].shape[1]
    answer_a_dim = data["embeddings_answer_a"].shape[1]
    answer_b_dim = data["embeddings_answer_b"].shape[1]
    print(f"  Loaded {n_loaded} samples")
    print(
        f"  Embedding dims: prompt={prompt_dim}, "
        f"answer_a={answer_a_dim}, "
        f"answer_b={answer_b_dim}"
    )
    return data

def extract_samples(data, n_samples=20000, seed=42):
    """Return a random subset of ``n_samples`` rows from every tensor in ``data``.

    One set of row indices is drawn and applied to every tensor, so the
    selected rows stay aligned across keys.

    Args:
        data: Dict of name -> tensor. Must contain ``'embeddings_prompt'``,
            whose length defines how many samples are available. The other
            tensors are presumably the same length along dim 0 (not checked).
        n_samples: Number of rows to keep.
        seed: Seed for the index permutation, for reproducibility.

    Returns:
        A new dict with the same keys. When fewer than ``n_samples`` rows are
        available, a shallow copy of ``data`` (all rows, original order) is
        returned after printing a warning.

    Raises:
        ValueError: If ``n_samples`` is negative.
    """
    # A negative count would otherwise slice "[:n]" from the end and silently
    # return the wrong number of rows.
    if n_samples < 0:
        raise ValueError(f"n_samples must be non-negative, got {n_samples}")

    total_samples = len(data['embeddings_prompt'])

    if total_samples < n_samples:
        print(f"  Warning: Only {total_samples} samples available, using all")
        # Shallow copy so both return paths hand back a fresh dict.
        return dict(data)

    # A private generator keeps the draw reproducible for a given seed without
    # reseeding torch's global RNG as a side effect on the caller.
    generator = torch.Generator().manual_seed(seed)
    indices = torch.randperm(total_samples, generator=generator)[:n_samples]

    extracted = {key: tensor[indices] for key, tensor in data.items()}

    print(f"  Extracted {n_samples} samples")
    return extracted

def combine_datasets(data_list):
    """Concatenate several dataset dicts key-by-key along dim 0.

    Args:
        data_list: Non-empty list of dicts mapping names to tensors. The keys
            of the first dict decide which entries are combined; every other
            dict must provide those keys too.

    Returns:
        A dict with one concatenated tensor per key, in list order.
    """
    reference_keys = data_list[0].keys()
    merged = {
        name: torch.cat([dataset[name] for dataset in data_list], dim=0)
        for name in reference_keys
    }

    n_rows = len(merged['embeddings_prompt'])
    print(f"Combined dataset size: {n_rows} samples")
    return merged

def save_combined_data(data, save_path):
    """Write the tensors in ``data`` to a safetensors file at ``save_path``.

    Args:
        data: Dict of name -> tensor. Must contain the four embedding entries
            (``embeddings_prompt``, ``embeddings_prompt_answer``,
            ``embeddings_answer_a``, ``embeddings_answer_b``), whose sizes are
            recorded in the file metadata.
        save_path: Destination file path. Its parent directory is created if
            it does not exist yet.
    """
    # os.path.dirname returns "" for a bare file name and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when one is named.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # safetensors refuses non-contiguous tensors; .contiguous() returns the
    # tensor itself when it is already contiguous, so this is free otherwise.
    save_dict = {key: tensor.contiguous() for key, tensor in data.items()}

    # safetensors metadata values must be strings.
    metadata = {
        "total_samples": str(len(data['embeddings_prompt'])),
        "embedding_prompt_dim": str(data['embeddings_prompt'].shape[1]),
        "embedding_prompt_answer_dim": str(data['embeddings_prompt_answer'].shape[1]),
        "embedding_answer_a_dim": str(data['embeddings_answer_a'].shape[1]),
        "embedding_answer_b_dim": str(data['embeddings_answer_b'].shape[1]),
    }

    save_file(save_dict, save_path, metadata=metadata)
    print(f"✅ Saved to: {save_path}")


def main():
    """Build the combined dataset from the Skywork and Mixture embedding files.

    Loads both source files, draws 20000 samples from each (seed 42 for
    Skywork, seed 43 for Mixture), concatenates the two subsets and writes the
    result next to the inputs. Prints an error and returns early if either
    input file is missing.
    """
    # Configuration: all three files live in the same directory and share a stem.
    embeddings_dir = "../embeddings/Qwen3-4B"
    instruct_name = "Qwen3-14B_instruct"
    reasoning_name = "Qwen3-14B_reasoning"
    stem = f"{instruct_name}_{reasoning_name}"

    skywork_path, mixture_path, output_path = (
        os.path.join(embeddings_dir, f"{stem}_{suffix}.safetensors")
        for suffix in ("skywork", "mixture", "combined")
    )

    # Bail out on the first missing input.
    for required in (skywork_path, mixture_path):
        if not os.path.exists(required):
            print(f"❌ Error: {required} does not exist!")
            return

    banner = "=" * 60
    print(banner)
    print("Extracting and Combining Datasets")
    print(banner)

    # Load both source datasets.
    print("\n1. Loading Skywork dataset...")
    skywork_full = load_safetensor_data(skywork_path)

    print("\n2. Loading Mixture dataset...")
    mixture_full = load_safetensor_data(mixture_path)

    # Subsample each source with its own seed.
    print("\n3. Extracting 20k samples from Skywork...")
    skywork_subset = extract_samples(skywork_full, n_samples=20000, seed=42)

    print("\n4. Extracting 20k samples from Mixture...")
    mixture_subset = extract_samples(mixture_full, n_samples=20000, seed=43)

    # Merge and persist.
    print("\n5. Combining datasets...")
    merged = combine_datasets([skywork_subset, mixture_subset])

    print("\n6. Saving combined dataset...")
    save_combined_data(merged, output_path)

    # Final report.
    print("\n" + banner)
    print("✅ Done! Combined dataset saved to:")
    print(f"   {output_path}")
    print("\n📊 Summary:")
    print(f"   - Total samples: {len(merged['embeddings_prompt'])}")
    print("   - Embedding types: 4 (prompt, prompt_answer, answer_a, answer_b)")
    concat_dim = sum(
        merged[name].shape[1]
        for name in ("embeddings_prompt", "embeddings_answer_a", "embeddings_answer_b")
    )
    print(f"   - Combined embedding dim (prompt+a+b): {concat_dim}")
    print(banner)

# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
