import pickle
import torch
from transformers import BertTokenizer, BertModel
import re
import itertools
from tqdm import tqdm

# Load the pre-trained BERT model and its matching tokenizer. Both objects are
# module-level globals that get_bert_embedding() reads on every call.
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertModel.from_pretrained(model_name)
# Switch the model to evaluation mode: this script only runs inference and
# never trains.
model.eval()

def get_bert_embedding(text):
    """Return the first-token BERT embedding of ``text`` as a NumPy array.

    ``text`` is tokenized with the module-level ``tokenizer`` (truncated to at
    most 256 tokens) and passed through the module-level ``model`` with
    gradient tracking disabled. The last-layer hidden state at sequence
    position 0 (the [CLS] position for BERT tokenizers) is returned after
    squeezing out every size-1 dimension.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256,
    )
    with torch.no_grad():
        model_output = model(**encoded)
    first_token_state = model_output.last_hidden_state[:, 0, :]
    return first_token_state.squeeze().numpy()

# Load the input pickle. `path` is a blank placeholder and must be set to the
# real file location before running; open('') raises FileNotFoundError.
# NOTE(review): pickle.load can execute arbitrary code while deserializing,
# so only point this at a file that comes from a trusted source.
path = ''
with open(path, 'rb') as f:
    questions = pickle.load(f)

# Debugging aids: uncomment one of the two lines below to process only the
# first 200 categories, or only the categories after the first 200.
#questions = dict(itertools.islice(questions.items(), 200))
#questions = dict(itertools.islice(questions.items(), 200, None))
# The processing loop further down iterates questions.items() and then
# answers.items(), i.e. it expects {category: {sample_id: answer dict}}.
print("Total question categories:", len(questions))

# Build the restructured embeddings dictionary:
#   embeddings_dict[question_category][sample_id] = {
#       "embeddings":        {name: BERT embedding},
#       "structure_str":     raw "structure_retrieve" value,
#       "sub_structure_str": same raw value,
#       "logprobs":          answer's "logprobs" (default []),
#       "probs":             answer's "probs" (default []),
#       "reasoning_string":  {Q&A key: Q&A value text},
#   }

# Separator for the "key: value" segments of string_q_a; compiled once here
# instead of being re-resolved for every sample.
_Q_A_SEPARATOR = re.compile(r'[;,]')


def _iter_q_a_pairs(string_q_a):
    """Yield (key, value) pairs from a ';'/','-separated list of 'key: value' segments."""
    # `or ""` also covers a key that is present but stored as None, which
    # would otherwise make the regex split raise TypeError.
    for segment in _Q_A_SEPARATOR.split(string_q_a or ""):
        # Split on the first ':' only, so values may themselves contain colons.
        key, colon, value = segment.partition(':')
        key, value = key.strip(), value.strip()
        # Segments without a colon or with an empty value have nothing to embed.
        if colon and value:
            yield key, value


def _iter_structure_elements(structure):
    """Yield every token starting with 'Node' or 'Edge' found in a structure_retrieve value."""
    # A missing or None value means there is nothing to extract (iterating
    # None would raise TypeError).
    if not structure:
        return
    # A bare string is treated as a single structure item; iterating it
    # directly would walk it character by character and never match anything.
    if isinstance(structure, str):
        structure = [structure]
    for item in structure:
        # Each item is rendered as text, stripped of surrounding brackets and
        # split on commas into candidate element names.
        for token in str(item).strip('[]').split(','):
            token = token.strip()
            if token.startswith(("Node", "Edge")):
                yield token


embeddings_dict = {}

# Process each question category and its answers
for question_category, answers in tqdm(questions.items(), desc="Processing categories"):
    category_entries = embeddings_dict.setdefault(question_category, {})

    for sample_id, answer_item in answers.items():
        # Looked up once; the raw value is stored for reference and also
        # parsed for Node/Edge elements below.
        structure = answer_item.get("structure_retrieve")

        embeddings = {}
        reasoning_strings = {}
        category_entries[sample_id] = {
            "embeddings": embeddings,
            "structure_str": structure,
            "sub_structure_str": structure,
            "logprobs": answer_item.get("logprobs", []),
            "probs": answer_item.get("probs", []),
            "reasoning_string": reasoning_strings,
        }

        # Embed the value text of every "key: value" segment in string_q_a
        for key, value in _iter_q_a_pairs(answer_item.get("string_q_a", "")):
            embeddings[key] = get_bert_embedding(value)
            reasoning_strings[key] = value

        # Embed Node/Edge element names from the structure. Names already
        # present (from the Q&A pass or an earlier structure item) are kept
        # as they are and not re-embedded.
        for element in _iter_structure_elements(structure):
            if element not in embeddings:
                embeddings[element] = get_bert_embedding(element)

        # Add response embedding
        response_str = answer_item.get("response_str", "")
        if response_str:
            embeddings["ResultEdge"] = get_bert_embedding(response_str)

        # Add raw question embedding
        if 'question' in answer_item:
            embeddings["NodeRaw"] = get_bert_embedding(answer_item['question'])

# Sanity-check printout: report overall size, then inspect the first category
# and its first answer (dict insertion order).
print("\nVerification of structure:")
question_categories = len(embeddings_dict)
print(f"Total question categories in embeddings_dict: {question_categories}")

if question_categories > 0:
    # First category key
    sample_category = next(iter(embeddings_dict))
    sample_answers = embeddings_dict[sample_category]
    answer_count = len(sample_answers)
    print(f"Category '{sample_category}' has {answer_count} answers")

    if answer_count > 0:
        # First answer of that category
        sample_answer_id = next(iter(sample_answers))
        sample_answer = sample_answers[sample_answer_id]
        embedding_keys = list(sample_answer["embeddings"])
        print(f"Answer ID '{sample_answer_id}' has {len(embedding_keys)} embeddings")

        # Confirm the probability data travelled along with the embeddings
        if "logprobs" in sample_answer:
            print(f"Answer includes logprobs of length: {len(sample_answer['logprobs'])}")
        if "probs" in sample_answer:
            print(f"Answer includes probs of length: {len(sample_answer['probs'])}")

# Save the embeddings. The output path is defined once so the file actually
# written and the confirmation message cannot disagree. The original opened
# "" (always FileNotFoundError, and only after every embedding had been
# computed) while the message named this file. Change output_path to write
# somewhere else.
output_path = "bert_embeddings_phi4_with_probs.pkl"
with open(output_path, "wb") as f:
    pickle.dump(embeddings_dict, f)

print(f"\nBERT embeddings with probability data saved to {output_path}")