import torch
import json
import os
import numpy as np
import re
from transformers import AutoTokenizer
from sklearn.decomposition import PCA
from collections import defaultdict

# Load the tokenizer once at import time; it must match the model used during feature extraction.
model_path = "/path/to/model/weights" # NOTE: Replace with the local path to the Phi-3-medium-128k-instruct weights before running
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, local_files_only=True)

def quantile_clip(arr, lower_q=0.00, upper_q=0.99):
    """Limit every value of ``arr`` to its per-column quantile range.

    The ``lower_q`` and ``upper_q`` quantiles are computed along axis 0, and
    values falling outside ``[lower, upper]`` are replaced by the nearest bound.
    The returned array has the same shape as ``arr``.
    """
    # One call yields both bounds stacked along a new leading axis.
    bounds = np.quantile(arr, [lower_q, upper_q], axis=0, keepdims=True)
    floor = bounds[0]
    ceiling = bounds[1]
    return np.clip(arr, floor, ceiling)


# Single-character punctuation/delimiter tokens that are always shown on their own.
_PUNCT_TOKEN_RE = re.compile(r"[.,!?;:(){}\[\]<>/\\|\"'`~@#$%^&*_+=-]")


def tokenize_and_map(prompts, tok=None, skip_first=False):
    """Tokenize ``prompts`` and return one display string per token.

    Tokens are first grouped into words (a token starting with the
    SentencePiece marker "▁" opens a new word; special tokens and single
    punctuation characters stand alone), then every multi-token word is split
    back into its pieces so each token keeps its own entry.

    - Subword pieces stay separate for attention visualization.
    - The SentencePiece "▁" marker is removed from the display strings.
    - Punctuation and delimiters are separate entries.
    - Special tokens (including BOS) are kept as their own entries unless
      ``skip_first`` drops the very first token.

    Args:
        prompts: Prompt(s) passed straight to the tokenizer. Only the first
            sequence of the resulting batch is mapped.
        tok: Optional tokenizer to use. Defaults to the module-level
            ``tokenizer`` so the weights are not reloaded on every call.
        skip_first: When True, the first token (normally BOS) is left out of
            the result. Indices of the remaining tokens are unchanged.

    Returns:
        ``(mapped_words, mapped_token_indices)``: parallel lists where
        ``mapped_token_indices[i]`` is a one-element list holding the position
        of ``mapped_words[i]`` in the tokenized sequence.
    """
    active_tokenizer = tokenizer if tok is None else tok
    inputs = active_tokenizer(prompts, return_tensors="pt")

    # Index the batch dimension explicitly: .squeeze() would turn a
    # single-token prompt into a scalar instead of a one-element list.
    token_ids = inputs["input_ids"][0].tolist()
    tokens = active_tokenizer.convert_ids_to_tokens(token_ids)

    # Build the lookup once instead of once per token.
    special_tokens = set(active_tokenizer.all_special_tokens)

    words = []
    word_to_tokens = []
    pending_text = ""
    pending_indices = []

    def flush():
        """Emit the word collected so far (if any) and start a fresh one."""
        nonlocal pending_text, pending_indices
        if pending_indices:
            words.append(pending_text)
            word_to_tokens.append(pending_indices)
        pending_text, pending_indices = "", []

    for idx, token in enumerate(tokens):
        if skip_first and idx == 0:
            continue

        if token in special_tokens or _PUNCT_TOKEN_RE.fullmatch(token):
            # Special tokens and punctuation always stand alone.
            flush()
            words.append(token)
            word_to_tokens.append([idx])
        elif token.startswith("▁"):
            # SentencePiece word boundary (used in Phi-3): start a new word.
            flush()
            pending_text, pending_indices = token[1:], [idx]
        else:
            # Continuation piece of the current word.
            pending_text += token
            pending_indices.append(idx)

    flush()

    # Split multi-token words back into their pieces so that every token gets
    # its own entry (and therefore its own color) in the visualization.
    mapped_words = []
    mapped_token_indices = []
    for word, token_indices in zip(words, word_to_tokens):
        if len(token_indices) == 1:
            mapped_words.append(word)
            mapped_token_indices.append(token_indices)
        else:
            for idx in token_indices:
                mapped_words.append(tokens[idx].replace("▁", ""))
                mapped_token_indices.append([idx])

    return mapped_words, mapped_token_indices

def process_pth_files(embedding_dir, json_output_dir, prompts, k=3):
    """Convert every `.pth` embedding file in a directory to a color JSON file.

    Args:
        embedding_dir: Directory scanned (non-recursively) for `.pth` files.
        json_output_dir: Directory that receives the JSON files; created if
            it does not exist.
        prompts: Prompts forwarded to ``save_json_from_pth`` for tokenization.
        k: Number of PCA components, forwarded to ``save_json_from_pth``.
            Defaults to 3, which is that function's own default.
    """
    os.makedirs(json_output_dir, exist_ok=True)

    # Sorted so files are processed in a stable order.
    for filename in sorted(os.listdir(embedding_dir)):
        if not filename.endswith(".pth"):
            continue
        feature_path = os.path.join(embedding_dir, filename)
        print(f"🔄 Processing {feature_path} ...")
        save_json_from_pth(feature_path, prompts, json_output_dir, k=k)

    print(f"✅ All JSON files saved in {json_output_dir}")


def _minmax_normalize(values):
    """Rescale each column of ``values`` to [0, 1]; the epsilon guards constant columns."""
    lowest = values.min(axis=0, keepdims=True)
    highest = values.max(axis=0, keepdims=True)
    return (values - lowest) / (highest - lowest + 1e-10)


def save_json_from_pth(feature_path, prompts, output_dir, k=3):
    """Load `.pth` features, reduce them with PCA, map them to tokens, and save JSON.

    The features are projected onto ``k`` principal components, normalized
    per component to [0, 1] and scaled to 0-255 color channels. With ``k=3``
    the components become R, G and B. With ``k=2`` the components are first
    clipped at their 90th percentile, become R and G, and B is fixed at 0.

    The output is a JSON list of ``{"word": ..., "color": "#rrggbb"}`` entries,
    written to ``output_dir`` under the input file's name with a `.json`
    extension. Files whose payload is a dict are skipped with a warning.

    Args:
        feature_path: Path to the `.pth` file holding the feature tensor.
        prompts: Prompts passed to ``tokenize_and_map``.
        output_dir: Directory for the JSON file; created if missing.
        k: Number of PCA components, 2 or 3.

    Raises:
        ValueError: If ``k`` is neither 2 nor 3.
    """
    if k not in (2, 3):
        raise ValueError(f"k must be 2 or 3, got {k!r}")

    os.makedirs(output_dir, exist_ok=True)

    # NOTE(security): torch.load unpickles the file; only load files you trust.
    # map_location lets tensors saved on a GPU load on a CPU-only machine.
    data = torch.load(feature_path, map_location="cpu")
    if isinstance(data, dict):
        print(f"⚠️ Unexpected dictionary format in {feature_path}, skipping...")
        return

    print(f"✅ Loaded {feature_path} with shape {data.shape}")

    # Remove batch dimension -> (num_tokens, feature_dim)
    tensor = data.squeeze(0).cpu()
    if tensor.dtype == torch.bfloat16:
        # NumPy has no bfloat16, so .numpy() would raise on such tensors.
        tensor = tensor.float()
    features = tensor.numpy()

    # PCA runs for both supported values of k.
    reduced_features = PCA(n_components=k).fit_transform(features)

    if k == 2:
        # Clip the top 10% of each component to remove outliers before scaling.
        reduced_features = quantile_clip(reduced_features, lower_q=0.00, upper_q=0.90)

    reduced_features = _minmax_normalize(reduced_features)

    if k == 2:
        # Only two components are available: pad the blue channel with zeros.
        blue_channel = np.zeros((reduced_features.shape[0], 1))
        reduced_features = np.hstack([reduced_features, blue_channel])

    # Convert to RGB (0-255)
    rgb_colors = (reduced_features * 255).astype(int)

    words, word_to_tokens = tokenize_and_map(prompts)

    word_data = []
    for word, token_indices in zip(words, word_to_tokens):
        # Clamp to the last feature row in case tokenization yields more
        # tokens than the feature tensor has rows.
        token_index = min(token_indices[0], len(rgb_colors) - 1)
        print("token_index:", token_index)
        print("token_indices: ", token_indices)
        red, green, blue = (int(channel) for channel in rgb_colors[token_index])
        hex_color = "#{:02x}{:02x}{:02x}".format(red, green, blue)
        word_data.append({"word": word, "color": hex_color})

    # Swap only the final extension, not every ".pth" occurrence in the name.
    stem = os.path.splitext(os.path.basename(feature_path))[0]
    json_path = os.path.join(output_dir, f"{stem}.json")
    with open(json_path, "w") as f:
        json.dump(word_data, f, indent=2)

    print(f"Token colors saved in {json_path}")

def process_value_embeddings(value_embeddings_dir, json_output_dir, prompts, k=3):
    """Write one color JSON per `.pth` value-embedding file found in a directory.

    The output directory is created when missing. Files are handled in sorted
    name order, and ``k`` is handed on to ``save_json_from_pth``.
    """
    os.makedirs(json_output_dir, exist_ok=True)

    pth_names = [
        entry
        for entry in sorted(os.listdir(value_embeddings_dir))
        if entry.endswith(".pth")
    ]
    for entry in pth_names:
        feature_path = os.path.join(value_embeddings_dir, entry)
        print(f"🔄 Processing {feature_path} ...")
        save_json_from_pth(feature_path, prompts, json_output_dir, k=k)

    print(f"✅ All Value Embeddings JSON saved in {json_output_dir}")

# Input/output locations, keyed by embedding kind
# NOTE: Point these at where the embeddings were saved and where the JSON files should go
embedding_dirs = dict(
    input="phi3_visualizations/input_embeddings/",
    final="phi3_visualizations/final_embeddings/",
    value="phi3_visualizations/value_embeddings/",
)
json_dirs = dict(
    input="phi3_visualizations/json_input_embeddings/",
    final="phi3_visualizations/json_final_embeddings/",
    value="phi3_visualizations/json_value_embeddings/",
)

####################################################
# Prompt

## Input Prompts
# # Repetition prompts
# prompts = ["Hello hello hello how how how are are are you you you repeat repeat repeat stop stop stop."] # Exclude "the" as that is a common sink
# prompts = ["Hello hello hello the the the how how how are are are you you you repeat repeat repeat stop stop stop."]

# # COT prompts
# prompts = ["I had 4 oranges and bought 3 more. 4 + 3 = 7. Now, I have 3 apples and buy 2 more. How many apples do I have?"]
# prompts = ["The derivative of x² is 2x. What is the derivative of x³?"]

# # Zero Shot COT prompts
# prompts = ["I have 3 apples. I buy 2 more. How many do I have? Let's think step by step."]
# prompts = ["What is the derivative of x³? Let's think step by step."]

# # Referential prompts
# prompts = ["John met Mary at the park. She smiled. He waved back. They talked about it."]

# # Streaming numbers/numerical prompts
# prompts = ["Input: 12323212. Label: 2 Input: 7887689. Label: 8 Input: 32454323 Label: 3"]

# # Sentiment (the currently active prompt)
prompts = ["I laughed with friends until my stomach hurt. I wiped away tears of sadness and said nothing. I slammed my fist in anger on the table."]

## Other
# prompts = ["Classify each sentence as happy or sad. The sun shone brightly, warming my face as I walked through the park. The flowers I had planted last spring never bloomed. Children laughed and played on the swings. A cold wind swept through the empty streets, reminding me of nights spent alone. The rain poured relentlessly, drenching the wilted petals of forgotten flowers."]
# prompts = ["Read the following paragraph and determine if the hypothesis is true. \n \n Premise: A: Oh, oh yeah, and every time you see one hit on the side of the road you say is that my cat. B: Uh-huh. A: And you go crazy thinking it might be yours. B: Right, well I didn’t realize my husband was such a sucker for animals until I brought one home one night. Hypothesis: her husband was such a sucker for animals. Answer:"]

prompts = [f"{tokenizer.bos_token} {prompt}" for prompt in prompts]  # Prefix each prompt with the BOS token text. NOTE(review): the tokenizer may also add its own BOS when encoding — confirm this matches the prompts used during feature extraction

####################################################

# Run processing: convert each embedding family to token-color JSON with k=2
for _kind, _convert in (
    ("input", process_pth_files),
    ("final", process_pth_files),
    ("value", process_value_embeddings),
):
    _convert(embedding_dirs[_kind], json_dirs[_kind], prompts, k=2)

# NOTE: Please read the following before running this script.
# * Run this script only after extract_phi3_medium_features has produced the embeddings.
# * Update the path to the model weights (`model_path`, L11 near the top of this file).
# * Ensure the directory paths in `embedding_dirs` and `json_dirs` (L182-191) are accurate.
# * Pass k=2 for PCA with 2 principal components or k=3 for PCA with 3 principal components.

####################################################
