import os
import json
import argparse
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import torch
from openai import OpenAI

# Command-line interface: where the reference texts live, which embedding
# backend to use, and the n-gram / threshold settings of the similarity metric.
parser = argparse.ArgumentParser()
parser.add_argument('--input_dir', type=str, required=True, help='Directory containing reference text files')
parser.add_argument('--embedding_model', type=str, default='text-embedding-3-small', help='Embedding model to use')
parser.add_argument('--batch_size', type=int, default=16, help='Batch size for embedding computation')
parser.add_argument('--min_n', type=int, default=2, help='Minimum n-gram length')
parser.add_argument('--max_n', type=int, default=5, help='Maximum n-gram length')
# The default mirrors calculate_soft_n_gram_similarity's own default. Without
# one, omitting the flag yields None, which that function rejects with a
# ValueError.
parser.add_argument('--similarity_threshold', type=float, default=0.8, help='Cosine similarity threshold for phrase matching')
parser.add_argument('--cache_dir', type=str, default='/data/assets/hub', help='Cache directory for models')

args = parser.parse_args()

# Module-level shortcuts read by the backend-loading and embedding code below.
embedding_model = args.embedding_model
cache_dir = args.cache_dir


# Load the selected embedding backend once at start-up. Each branch defines the
# module-level names (client, or tokenizer / model / max_length) that the
# embedding helpers further down read.
if embedding_model in ("text-embedding-3-small", "text-embedding-3-large"):
    # Hosted OpenAI embeddings; the API key is taken from the environment.
    openai_key = os.getenv("OPENAI_API_KEY")
    client = OpenAI(api_key=openai_key)

elif embedding_model == "specter2":

    from transformers import AutoTokenizer
    from adapters import AutoAdapterModel

    # Base checkpoint plus the "specter2" adapter, activated and moved to GPU 0.
    tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_base', cache_dir=cache_dir)
    model = AutoAdapterModel.from_pretrained('allenai/specter2_base', cache_dir=cache_dir)
    model.load_adapter("allenai/specter2", source="hf", load_as="specter2", set_active=True)
    model = model.to("cuda:0")
    max_length = 512

elif embedding_model == "linq-embed-mistral":
    import torch
    import torch.nn.functional as F
    from torch import Tensor
    from transformers import AutoTokenizer, AutoModel, BitsAndBytesConfig

    def last_token_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
        """Pick, for every row, the hidden state at its last attended position."""
        num_rows = attention_mask.shape[0]
        # If the final column of the mask is set for every row, the batch is
        # left-padded and the last position is the right one for all rows.
        if attention_mask[:, -1].sum() == num_rows:
            return last_hidden_states[:, -1]

        # Otherwise index each row at (number of attended tokens - 1).
        last_positions = attention_mask.sum(dim=1) - 1
        row_index = torch.arange(last_hidden_states.shape[0], device=last_hidden_states.device)
        return last_hidden_states[row_index, last_positions]

    # 4-bit NF4 weights with nested (double) quantization; bfloat16 compute.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype="bfloat16",
    )

    tokenizer = AutoTokenizer.from_pretrained(
        'Linq-AI-Research/Linq-Embed-Mistral',
        cache_dir=cache_dir,
    )
    model = AutoModel.from_pretrained(
        'Linq-AI-Research/Linq-Embed-Mistral',
        quantization_config=bnb_config,
        cache_dir=cache_dir,
    )
    model = model.to("cuda:0")
    max_length = 4096

else:
    raise ValueError("Embedding model not implemented")

def embed_text_single_pass(texts):
    """Embed one batch of texts with the globally selected embedding backend.

    The whole list goes through the backend in a single call, so the caller is
    responsible for keeping the batch small enough to fit in memory.

    Args:
        texts: List of strings to embed.

    Returns:
        A list holding one embedding (a list of floats) per input text.

    Raises:
        ValueError: If the module-level ``embedding_model`` is not one of the
            backends handled here.
    """
    if embedding_model in ("text-embedding-3-small", "text-embedding-3-large"):
        response = client.embeddings.create(
            model=embedding_model,
            input=texts
        )
        return [item.embedding for item in response.data]

    if embedding_model == "specter2":
        inputs = tokenizer(
            texts,
            padding=True,
            truncation=True,
            return_tensors="pt",
            return_token_type_ids=False,
            # Same limit embed_texts uses when it sizes a batch.
            max_length=max_length
        ).to("cuda:0")
        # Inference only: skip autograd bookkeeping so no graph is kept alive
        # on the GPU for tensors that are detached right afterwards anyway.
        with torch.no_grad():
            output = model(**inputs)
        # First position along the sequence axis (presumably the [CLS] token).
        embeddings = output.last_hidden_state[:, 0, :]
        return embeddings.detach().cpu().numpy().tolist()

    if embedding_model == "linq-embed-mistral":
        batch_dict = tokenizer(
            texts,
            max_length=max_length,
            padding=True,
            truncation=True,
            return_tensors="pt"
        ).to("cuda:0")
        with torch.no_grad():
            outputs = model(**batch_dict)
        embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])

        # L2-normalize each embedding.
        embeddings = F.normalize(embeddings, p=2, dim=1)
        # .float() is a lossless upcast that guards against bfloat16 outputs,
        # which Tensor.numpy() cannot convert.
        return embeddings.detach().float().cpu().numpy().tolist()

    # Previously fell through to an UnboundLocalError on the return statement.
    raise ValueError("Embedding model not implemented")

def embed_texts(texts, batch_size=16, max_batch_tokens=6000):
    """Embed ``texts`` batch by batch and return the embeddings in one list.

    For the locally hosted models ("linq-embed-mistral", "specter2") each batch
    is tokenized first; while the padded batch holds more than
    ``max_batch_tokens`` token slots (rows x padded length) the batch is shrunk
    by 20%, down to a single text at the least. The hosted OpenAI models have
    no local tokenizer, so their batches are sent at the requested size.

    Args:
        texts: List of strings to embed.
        batch_size: Maximum number of texts per batch; must be at least 1.
        max_batch_tokens: Token-slot budget per batch for the local models.

    Returns:
        A list with one embedding per entry of ``texts``.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # A non-positive batch size would never advance the cursor below.
        raise ValueError("batch_size must be a positive integer")

    # The original test `embedding_model == "linq-embed-mistral" or "specter2"`
    # was always true, so the OpenAI path crashed on the undefined tokenizer.
    uses_local_model = embedding_model in ("linq-embed-mistral", "specter2")

    multi_pass_embeddings = []
    batches_done = 0
    i = 0
    while i < len(texts):
        cur_batch_size = batch_size
        batch_texts = texts[i:i + cur_batch_size]

        if uses_local_model:
            # A single text is always accepted: it is truncated to max_length,
            # and shrinking further would leave an empty batch.
            while cur_batch_size > 1:
                tokenized_batch_input_ids = tokenizer(
                    batch_texts,
                    max_length=max_length,
                    padding=True,
                    truncation=True,
                    return_tensors="pt"
                )['input_ids']
                batch_shape = tokenized_batch_input_ids.shape
                if batch_shape[0] * batch_shape[1] <= max_batch_tokens:
                    break
                cur_batch_size = max(1, int(cur_batch_size * 0.8))
                batch_texts = texts[i:i + cur_batch_size]

        multi_pass_embeddings.extend(embed_text_single_pass(batch_texts))

        # Release cached GPU memory every fifth batch. Counting batches rather
        # than testing `i` keeps the cadence when batches have been shrunk.
        if batches_done % 5 == 0:
            torch.cuda.empty_cache()
        batches_done += 1

        i += cur_batch_size

    assert len(multi_pass_embeddings) == len(texts)

    return multi_pass_embeddings

def get_overlapping_phrases(text: str, min_n: int, max_n: int) -> list[str]:
    """Return every contiguous word n-gram of ``text`` with min_n <= n <= max_n.

    Words are obtained with ``str.split()``. Both bounds are capped at the
    number of words, so a text shorter than ``min_n`` still yields itself as a
    single phrase. Phrases are ordered by length first, then by position.

    Args:
        text: Text to split into phrases.
        min_n: Smallest phrase length in words; values below 1 are treated as 1.
        max_n: Largest phrase length in words.

    Returns:
        The list of phrases; empty when ``text`` contains no words or when the
        bounds leave no valid length.
    """
    words = text.split()
    if not words:
        # Without this guard both bounds clamp to 0 and a single empty-string
        # "phrase" is produced, which callers then treat as real content.
        return []

    word_count = len(words)
    shortest = max(1, min(min_n, word_count))
    longest = min(max_n, word_count)

    return [
        " ".join(words[start:start + n])
        for n in range(shortest, longest + 1)
        for start in range(word_count - n + 1)
    ]

def calculate_soft_n_gram_similarity(
    source_text: str,
    edited_text: str,
    min_n: int = 2,
    max_n: int = 5,
    similarity_threshold: "float | list[float]" = 0.8,
) -> "float | dict[float, float]":
    """
    Computes the soft n-gram similarity between a source and an edited text.

    This function implements the precision-based metric described in the EDITLENS paper (Section 3.3).
    A higher score indicates greater similarity: the score is the fraction of
    edited-text phrases whose cosine similarity to at least one source phrase
    is strictly above the threshold.

    Args:
        source_text: The original, human-written text.
        edited_text: The AI-edited version of the source text.
        min_n: The minimum length of n-grams to consider.
        max_n: The maximum length of n-grams to consider.
        similarity_threshold: The cosine similarity threshold (τ) for a phrase to be
            considered a match. Either a single number, or a list/tuple of
            numbers to evaluate several thresholds on one set of embeddings.

    Returns:
        The soft n-gram similarity score as a float for a single threshold, or
        a dict mapping each threshold to its score for a list/tuple. If either
        text yields no phrases, the score is 1.0 when the texts are identical
        and 0.0 otherwise.

    Raises:
        ValueError: If phrases exist and ``similarity_threshold`` is neither a
            number nor a list/tuple.
    """
    multiple_thresholds = isinstance(similarity_threshold, (list, tuple))
    # bool is an int subclass but is never a meaningful threshold.
    single_threshold = (
        isinstance(similarity_threshold, (int, float))
        and not isinstance(similarity_threshold, bool)
    )

    # 1. Enumerate all phrases for both texts
    source_phrases = get_overlapping_phrases(source_text, min_n, max_n)
    edited_phrases = get_overlapping_phrases(edited_text, min_n, max_n)

    if not edited_phrases or not source_phrases:
        degenerate_score = 1.0 if source_text == edited_text else 0.0
        if multiple_thresholds:
            # Keep the return shape consistent with the regular list path.
            return {threshold: degenerate_score for threshold in similarity_threshold}
        return degenerate_score

    # Reject an unusable threshold before paying for the embeddings.
    if not (multiple_thresholds or single_threshold):
        raise ValueError("similarity_threshold must be a float or a list of floats")

    # 2. Compute embeddings for all phrases
    source_embeddings = embed_texts(source_phrases, batch_size=args.batch_size)
    edited_embeddings = embed_texts(edited_phrases, batch_size=args.batch_size)

    # 3. Pairwise cosine similarity: rows are edited phrases, columns are source phrases
    similarity_matrix = cosine_similarity(edited_embeddings, source_embeddings)

    total_edited_phrases = len(edited_phrases)

    def _score_at(threshold):
        # 4. Count the edited phrases with at least one source phrase above the threshold.
        matched_phrase_count = int(np.count_nonzero((similarity_matrix > threshold).any(axis=1)))
        # 5. Dividing by the number of edited phrases makes this a precision-based metric.
        return matched_phrase_count / total_edited_phrases

    if multiple_thresholds:
        return {threshold: _score_at(threshold) for threshold in similarity_threshold}
    return _score_at(similarity_threshold)

# Read every file ending with "_ref.txt" in the input directory. A file named
# "<variant>_ref.txt" is stored under the key "<variant>", with surrounding
# whitespace stripped from its content.
input_dir = args.input_dir

references = dict()

for f in os.listdir(input_dir):
    if f.endswith('_ref.txt'):
        # removesuffix drops only the trailing marker; split('_ref.txt')[0]
        # would also truncate a name that contains the marker earlier on.
        prompt_variant = f.removesuffix('_ref.txt')
        # Explicit encoding: the platform default is locale-dependent.
        with open(os.path.join(input_dir, f), 'r', encoding='utf-8') as file:
            references[prompt_variant] = file.read().strip()
               
# Build the 4x4 matrix of soft n-gram similarities between the references of
# prompts P1..P4 (the data behind the heatmap). Entry (row i, column j) uses
# P{i} as the source text and P{j} as the edited text; the diagonal is 1.0.

# Fail fast: every pair is expensive to embed, so report a missing reference
# up front instead of hitting a bare KeyError after earlier pairs have run.
missing_prompts = [f'P{k}' for k in range(1, 5) if f'P{k}' not in references]
if missing_prompts:
    raise FileNotFoundError(
        f"Missing reference file(s) in {args.input_dir}: "
        + ", ".join(f"{name}_ref.txt" for name in missing_prompts)
    )

similarity_matrix = np.zeros((4, 4))

for i in range(1, 5):
    for j in range(1, 5):
        if i == j:
            similarity_matrix[i-1, j-1] = 1.0
            continue

        prompt_i = f'P{i}'
        prompt_j = f'P{j}'
        similarity_score = calculate_soft_n_gram_similarity(
            source_text=references[prompt_i],
            edited_text=references[prompt_j],
            min_n=args.min_n,
            max_n=args.max_n,
            similarity_threshold=args.similarity_threshold
        )
        print(f"Soft n-gram similarity between {prompt_i} and {prompt_j}: {similarity_score}")
        similarity_matrix[i-1, j-1] = similarity_score

# Render the pairwise similarity matrix as an annotated heatmap and save it as
# a PNG inside the input directory; the threshold is part of the file name.
import matplotlib.pyplot as plt
import seaborn as sns

prompt_labels = [f'P{i}' for i in range(1, 5)]
heatmap_path = os.path.join(
    args.input_dir,
    f"soft_n_gram_sim_hmp_tau_{args.similarity_threshold}.png",
)

plt.figure(figsize=(8, 6))
sns.heatmap(
    similarity_matrix,
    annot=True,
    fmt=".2f",
    cmap="YlGnBu",
    xticklabels=prompt_labels,
    yticklabels=prompt_labels,
)
plt.title("Soft N-Gram Similarity Heatmap")
# Columns hold the edited (candidate) text, rows the source (reference) text.
plt.xlabel("Candidate Review")
plt.ylabel("Reference Review")
plt.savefig(heatmap_path)