import re
from typing import List, Dict
from templates import *
import numpy as np
import torch
import torch.nn.functional as F
import os
import json
import pickle
# import matplotlib.pyplot as plt
from tqdm import tqdm
import math
import random
import time

# Template type -> list of template strings. The *_templates lists are
# provided by the star import from ``templates``. Insertion order is the
# order in which the template types are compiled and later matched.
TEMPLATES = dict(
    birth_date=birth_date_templates,
    birth_city=birth_city_templates,
    employer=employer_templates,
    company_city=company_city_templates,
    university=university_templates,
    major=major_templates,
)

# -----------------
# 2. REGEX BUILDER
# -----------------
#
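# This function takes a single template string like
# "{pronoun} was born on {birth_date}."
# and returns a compiled, case-insensitive regex with named capturing groups
# for the placeholders, roughly:
# r"(?P<pronoun>[^.]+) was born on (?P<birth_date>[^.]+)\."
# Each placeholder matches a run of non-period characters, and the pattern is
# not anchored, so it can be found anywhere inside a longer text.
#
# Adjust the placeholder sub-pattern if field values may contain periods
# (e.g. "St. Louis").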

def template_to_regex(template: str) -> re.Pattern:
    """
    Convert a template with {placeholders} into a named-group regex.

    Literal text between placeholders is passed through ``re.escape``, so
    characters such as '.', '(' or '?' in a template match only themselves.
    Each ``{name}`` placeholder becomes the named group ``(?P<name>[^.]+)``:
    one or more characters that are not a period.

    A regex cannot define the same group name twice. A placeholder that
    appears more than once therefore gets a distinct group name per
    occurrence: the first keeps ``name``, the second becomes
    ``name_duplicate``, the third ``name_duplicate2``, and so on.

    Args:
        template: Template string, e.g. "{pronoun} was born on {birth_date}."

    Returns:
        A compiled, case-insensitive pattern. It is not anchored, so it is
        meant for ``search``/``finditer`` over longer text.

    Raises:
        re.error: If a placeholder name is not a valid regex group name.
    """
    # With one capturing group, re.split alternates literal text (even
    # indices) and placeholder names (odd indices).
    parts = re.split(r"\{(.*?)\}", template)

    occurrences: Dict[str, int] = {}
    pieces: List[str] = []
    for index, part in enumerate(parts):
        if index % 2 == 0:
            # Literal template text: must match verbatim.
            pieces.append(re.escape(part))
            continue

        seen = occurrences.get(part, 0)
        occurrences[part] = seen + 1
        if seen == 0:
            group_name = part
        elif seen == 1:
            group_name = f"{part}_duplicate"
        else:
            group_name = f"{part}_duplicate{seen}"

        # Capture any run of non-period characters for the placeholder.
        pieces.append(rf"(?P<{group_name}>[^.]+)")

    return re.compile("".join(pieces), re.IGNORECASE)

# Template type -> list of compiled regexes, built once at import time so
# extract_info can report which template type each match came from.
COMPILED_PATTERNS: Dict[str, List[re.Pattern]] = {}
for _template_type, _raw_templates in TEMPLATES.items():
    _compiled = []
    for _raw in _raw_templates:
        try:
            _compiled.append(template_to_regex(_raw))
        except Exception as e:
            # Best effort: report the offending template and keep compiling
            # the rest instead of failing the whole import.
            print(f"Error compiling template {_raw}: {e}")
    COMPILED_PATTERNS[_template_type] = _compiled
# -----------------
# 3. EXTRACT FUNCTION
# -----------------

def extract_info(s: str, patterns=None) -> Dict[str, str]:
    """
    Extract template fields from ``s`` and merge them into a single dict.

    Every compiled pattern is run over the whole text with ``finditer``, in
    order: template type, then pattern, then match. The named groups of each
    match are folded into one dictionary, so a later match overwrites an
    earlier value for the same key.

    Merging rules, applied to each whitespace-stripped group value:
      * A ``pronoun`` value that is not a subject pronoun (he/she/they) is
        taken to be the person's name and stored under ``"name"``.
      * Any key containing ``"pronoun"`` whose value is not a possessive,
        object or reflexive pronoun is dropped. This also drops ``pronoun``
        itself when it holds a subject pronoun, so subject pronouns never
        appear in the result.
      * Every other key is stored unchanged.
      * Groups that did not take part in a match (value ``None``) are skipped.

    Args:
        s: Text to search.
        patterns: Optional mapping of template type -> list of compiled
            regexes with named groups. Defaults to the module-level
            ``COMPILED_PATTERNS``.

    Returns:
        dict mapping field name (e.g. ``"name"``) to the extracted string;
        empty if nothing matched.
    """
    if patterns is None:
        patterns = COMPILED_PATTERNS

    subject_pronouns = {"he", "she", "they"}
    # Possessive, object and reflexive forms, built once per call rather than
    # once per extracted key.
    other_pronoun_forms = {
        "his", "her", "their",          # possessive
        "him", "them",                  # object ("her" already above)
        "himself", "herself", "themselves",  # reflexive
    }

    aggregated_results: Dict[str, str] = {}
    for template_patterns in patterns.values():
        for pattern in template_patterns:
            for match in pattern.finditer(s):
                for key, value in match.groupdict().items():
                    if value is None:
                        # Optional group that did not participate in the match.
                        continue
                    value = value.strip()
                    lowered = value.lower()
                    if key == "pronoun" and lowered not in subject_pronouns:
                        # A "pronoun" slot filled with a non-pronoun is a name.
                        aggregated_results["name"] = value
                    elif "pronoun" in key and lowered not in other_pronoun_forms:
                        continue
                    else:
                        aggregated_results[key] = value

    return aggregated_results


def extract_model_hidden_states(model, tokenizer, text, device="cuda"):
    """
    Run ``model`` on a single text and return its hidden states and tokens.

    The model is switched to eval mode, and the forward pass runs under
    ``torch.no_grad()`` with ``output_hidden_states=True``.

    Args:
        model: Model whose forward output exposes ``.hidden_states``.
        tokenizer: Tokenizer providing ``encode(..., return_tensors="pt")``
            and ``convert_ids_to_tokens``.
        text: Input string.
        device: Device the encoded ids are moved to before the forward pass.

    Returns:
        (hidden_states, tokens): the model's ``hidden_states`` output
        unchanged, and the token strings for the single encoded sequence.
    """
    model.eval()
    input_ids = tokenizer.encode(text, return_tensors="pt").to(device)

    with torch.no_grad():
        forward_out = model(input_ids, output_hidden_states=True)
    layers = forward_out.hidden_states

    # Row 0 is the only sequence, since a single text was encoded.
    token_strings = tokenizer.convert_ids_to_tokens(input_ids[0])
    return layers, token_strings

def extract_model_hidden_states_batch(model, tokenizer, texts, device="cuda", batch_size=32):
    """
    Run ``model`` over ``texts`` in chunks and split the outputs per text.

    Texts are tokenized ``batch_size`` at a time with ``padding=True``. With
    a Hugging Face-style tokenizer (assumed here; verify for other
    tokenizers), every text in a chunk is padded to that chunk's longest
    sequence. The returned tokens and hidden states then include those pad
    positions.

    Args:
        model: Model called as ``model(**encoded, output_hidden_states=True)``.
        tokenizer: Callable tokenizer returning an object with ``.to(device)``
            and ``.input_ids``; also provides ``convert_ids_to_tokens``.
        texts: Sliceable sequence of input strings.
        device: Device each encoded chunk is moved to.
        batch_size: Number of texts per forward pass.

    Returns:
        (all_hidden_states, all_tokens), both with one entry per text in input
        order. ``all_hidden_states[k]`` is a list with one slice per element of
        the model's ``hidden_states`` output; ``all_tokens[k]`` is the token
        strings of text ``k``'s (padded) input ids.
    """
    model.eval()
    per_text_states = []
    per_text_tokens = []

    for start in tqdm(range(0, len(texts), batch_size)):
        chunk = texts[start:start + batch_size]
        batch_inputs = tokenizer(chunk, return_tensors="pt", padding=True).to(device)

        with torch.no_grad():
            layers = model(**batch_inputs, output_hidden_states=True).hidden_states

        for row in range(len(chunk)):
            per_text_tokens.append(tokenizer.convert_ids_to_tokens(batch_inputs.input_ids[row]))
            # For torch tensors, integer indexing returns a view, so each slice
            # keeps its whole batch tensor alive (on ``device``).
            per_text_states.append([layer[row] for layer in layers])

    return per_text_states, per_text_tokens

def extract_token_embeddings(hidden_states, layer_idx=-1, token_idx=None):
    """
    Select one layer, and optionally one token position, from hidden states.

    Args:
        hidden_states: Indexable collection of per-layer tensors, e.g. the
            ``hidden_states`` returned by ``extract_model_hidden_states``.
        layer_idx: Which layer to take (default: the last one).
        token_idx: If given, the position to take along axis 1 of the
            selected 3-D layer tensor. Presumably (batch, seq, hidden) —
            confirm against the caller.

    Returns:
        The whole layer tensor when ``token_idx`` is None, otherwise
        ``layer[:, token_idx, :]``.
    """
    layer = hidden_states[layer_idx]
    if token_idx is not None:
        return layer[:, token_idx, :]
    return layer

def extract_token_embeddings_batch(all_hidden_states, layer_idx=-1, token_idx=None):
    """
    Apply the layer/token selection to every example in a batch.

    Args:
        all_hidden_states: Iterable of per-example layer lists, e.g. the first
            value returned by ``extract_model_hidden_states_batch``.
        layer_idx: Which layer to take from each example (default: the last).
        token_idx: If given, index the selected layer once more with it
            (first axis of the per-example layer).

    Returns:
        list with one selected layer (or token row) per example, in input order.
    """
    if token_idx is None:
        return [per_example[layer_idx] for per_example in all_hidden_states]
    return [per_example[layer_idx][token_idx] for per_example in all_hidden_states]


def calculate_cosine_similarity(a, b):
    """
    Return the cosine similarity of two vectors as a Python float.

    Intended for 1-D tensors: a leading batch axis is added so the result has
    exactly one element, which ``.item()`` then extracts.

    Args:
        a: First vector.
        b: Second vector.

    Returns:
        float cosine similarity of ``a`` and ``b``.
    """
    # cosine_similarity reduces over dim=1 by default, so wrap each vector in
    # a length-1 batch axis first.
    similarity = F.cosine_similarity(a[None], b[None])
    return similarity.item()

def calculate_cosine_similarity_batch(a, b):
    """
    Row-wise cosine similarity between two batches of vectors.

    Args:
        a: First batch, vectors laid out along dim 1.
        b: Second batch, same layout (broadcastable against ``a``).

    Returns:
        Tensor of similarities, reduced over dim 1.
    """
    # dim=1 and eps=1e-8 are torch's defaults, spelled out for readability.
    return F.cosine_similarity(a, b, dim=1, eps=1e-8)

def calculate_euclidean_distance(a, b):
    """
    Return the Euclidean (L2) distance between two tensors as a float.

    The norm is taken over all elements of ``a - b``.

    Args:
        a: First vector.
        b: Second vector.

    Returns:
        float distance between ``a`` and ``b``.
    """
    difference = a - b
    return difference.norm().item()

def calculate_euclidean_distance_batch(a, b):
    """
    Row-wise Euclidean distance between two batches of vectors.

    Args:
        a: First batch, vectors laid out along dim 1.
        b: Second batch, same layout (broadcastable against ``a``).

    Returns:
        Tensor of L2 norms of ``a - b`` taken over dim 1.
    """
    difference = a - b
    return difference.norm(dim=1)
def save_embeddings_to_file(embeddings, filename):
    """
    Pickle ``embeddings`` into ``filename``.

    The file is opened in binary write mode, so an existing file is
    truncated and overwritten.

    Args:
        embeddings: Any picklable object.
        filename: Destination path.
    """
    with open(filename, mode="wb") as handle:
        pickle.dump(embeddings, handle)

def load_embeddings_from_file(filename):
    """
    Unpickle and return the object stored in ``filename``.

    WARNING: unpickling can execute arbitrary code embedded in the file.
    Only load files produced by a trusted source, e.g. this module's
    ``save_embeddings_to_file``.

    Args:
        filename: Path of the pickle file to read.

    Returns:
        The unpickled object.
    """
    with open(filename, mode="rb") as handle:
        embeddings = pickle.load(handle)
    return embeddings

def save_data_to_json(data, filename):
    """
    Serialize ``data`` as JSON into ``filename``, overwriting it.

    ``json.dump`` is called with default settings: compact separators and
    ASCII-only output.

    Args:
        data: JSON-serializable object.
        filename: Destination path.
    """
    with open(filename, mode="w") as out_file:
        json.dump(data, fp=out_file)

def load_data_from_json(filename):
    """
    Load data from a JSON file.

    The file is decoded as UTF-8, the encoding RFC 8259 requires for JSON,
    rather than the platform's locale encoding. Non-ASCII text therefore
    loads identically on every OS.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        The deserialized Python object.

    Raises:
        FileNotFoundError: If ``filename`` does not exist.
        UnicodeDecodeError: If the file is not valid UTF-8.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        return json.load(f)

def visualize_embeddings_in_2d(embeddings, labels=None, title="Embeddings Visualization", figsize=(10, 8)):
    """
    Project embeddings to 2-D with PCA and show a scatter plot.

    Args:
        embeddings: 2-D array-like (n_samples, n_features) accepted by
            sklearn's ``PCA.fit_transform``.
        labels: Optional sequence of n_samples hashable labels. Points that
            share a label get the same colour and one legend entry. Labels
            are coloured in order of first appearance, so colours are stable
            across runs.
        title: Plot title.
        figsize: Figure size passed to ``plt.figure``.
    """
    # Imported locally, like sklearn, because the module has no top-level
    # matplotlib import; without this, ``plt`` is an undefined name.
    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA

    # Reduce dimensionality to 2D
    pca = PCA(n_components=2)
    embeddings_2d = pca.fit_transform(embeddings)

    plt.figure(figsize=figsize)

    if labels is not None:
        # dict.fromkeys keeps first-appearance order. A set's iteration order
        # (e.g. for str labels under hash randomisation) can change between
        # runs and would reshuffle colours and legend order.
        unique_labels = list(dict.fromkeys(labels))
        colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels)))

        for color, label in zip(colors, unique_labels):
            indices = [j for j, lbl in enumerate(labels) if lbl == label]
            plt.scatter(embeddings_2d[indices, 0], embeddings_2d[indices, 1],
                        label=label, color=color, alpha=0.7)
        plt.legend()
    else:
        plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], alpha=0.7)

    plt.title(title)
    plt.xlabel("PC1")
    plt.ylabel("PC2")
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.show()

def visualize_embeddings_in_3d(embeddings, labels=None, title="Embeddings Visualization", figsize=(12, 10)):
    """
    Project embeddings to 3-D with PCA and show a 3-D scatter plot.

    Args:
        embeddings: 2-D array-like (n_samples, n_features) accepted by
            sklearn's ``PCA.fit_transform``.
        labels: Optional sequence of n_samples hashable labels. Points that
            share a label get the same colour and one legend entry. Labels
            are coloured in order of first appearance, so colours are stable
            across runs.
        title: Plot title.
        figsize: Figure size passed to ``plt.figure``.
    """
    # Imported locally, like sklearn, because the module has no top-level
    # matplotlib import; without this, ``plt`` is an undefined name.
    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA
    # Unused by name; importing it registers the '3d' projection on older
    # matplotlib versions.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    # Reduce dimensionality to 3D
    pca = PCA(n_components=3)
    embeddings_3d = pca.fit_transform(embeddings)

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection='3d')

    if labels is not None:
        # dict.fromkeys keeps first-appearance order. A set's iteration order
        # (e.g. for str labels under hash randomisation) can change between
        # runs and would reshuffle colours and legend order.
        unique_labels = list(dict.fromkeys(labels))
        colors = plt.cm.rainbow(np.linspace(0, 1, len(unique_labels)))

        for color, label in zip(colors, unique_labels):
            indices = [j for j, lbl in enumerate(labels) if lbl == label]
            ax.scatter(embeddings_3d[indices, 0], embeddings_3d[indices, 1], embeddings_3d[indices, 2],
                       label=label, color=color, alpha=0.7)
        ax.legend()
    else:
        ax.scatter(embeddings_3d[:, 0], embeddings_3d[:, 1], embeddings_3d[:, 2], alpha=0.7)

    ax.set_title(title)
    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    ax.set_zlabel("PC3")
    plt.tight_layout()
    plt.show()

def normalize_vectors(embeddings):
    """
    Scale each row of a 2-D tensor to unit L2 norm.

    Args:
        embeddings: Tensor whose vectors lie along dim 1.

    Returns:
        Tensor of the same shape with every dim-1 slice L2-normalized.
    """
    # p=2.0 and eps=1e-12 are the torch defaults, written out explicitly.
    return F.normalize(embeddings, p=2.0, dim=1, eps=1e-12)

def process_data_in_batches(data, batch_size, process_fn):
    """
    Apply ``process_fn`` to consecutive slices of ``data`` and flatten the results.

    Args:
        data: Sliceable sequence supporting ``len``.
        batch_size: Number of items per slice; the last slice may be shorter.
        process_fn: Callable taking one slice and returning an iterable of
            results.

    Returns:
        list concatenating every batch's results, in order.
    """
    collected = []
    batch_starts = range(0, len(data), batch_size)
    for offset in tqdm(batch_starts):
        collected.extend(process_fn(data[offset:offset + batch_size]))
    return collected

def set_random_seed(seed):
    """
    Seed Python's ``random``, NumPy's global RNG and torch, for reproducibility.

    All CUDA devices are seeded as well when CUDA is available.

    Args:
        seed: The seed value passed to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def measure_execution_time(func):
    """
    Decorator that prints how long each call to ``func`` took.

    Timing uses ``time.perf_counter()``, a monotonic high-resolution clock
    meant for measuring intervals. ``time.time()`` is wall-clock time and can
    jump when the system clock is adjusted. The wrapper is decorated with
    ``functools.wraps``, so the decorated function keeps its ``__name__``,
    ``__doc__`` and other metadata.

    Args:
        func: The function to measure.

    Returns:
        wrapper: A function with the same call signature. It returns
        ``func``'s result and prints the elapsed time after each successful
        call.
    """
    import functools

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"{func.__name__} took {elapsed:.4f} seconds to execute.")
        return result
    return wrapper

# Short aliases for the functions above; either name can be used.

# Hidden-state / embedding extraction
extract_hidden_states, extract_hidden_states_batch = (
    extract_model_hidden_states,
    extract_model_hidden_states_batch,
)
get_embeddings, get_embeddings_batch = (
    extract_token_embeddings,
    extract_token_embeddings_batch,
)

# Similarity and distance
cos_sim, cos_sim_batch = calculate_cosine_similarity, calculate_cosine_similarity_batch
euclidean_distance, euclidean_distance_batch = (
    calculate_euclidean_distance,
    calculate_euclidean_distance_batch,
)

# Persistence
save_embeddings, load_embeddings = save_embeddings_to_file, load_embeddings_from_file
save_json, load_json = save_data_to_json, load_data_from_json

# Plotting and miscellaneous helpers
plot_embeddings_2d, plot_embeddings_3d = visualize_embeddings_in_2d, visualize_embeddings_in_3d
normalize_embeddings = normalize_vectors
batch_process = process_data_in_batches
timer = measure_execution_time


if __name__ == "__main__":
    from pprint import pprint

    # Sample paragraph to run through extract_info.
    sample_text = (
        "Isabelle Kai Travis started Isabelle Kai Travis's journey on May 27, 1903. She contributes to the success of Cognizant Technology Solutions. "
        "She is engaged in the job market of Teaneck. "
        "Indianapolis, IN is where her story began. "
        "She gained valuable knowledge and experience at Washington and Lee University. "
        "She majored in Education."
    )

    # Parse the paragraph and pretty-print the merged field dictionary.
    pprint(extract_info(sample_text))