from openai import OpenAI
import nltk
import gzip
import io
from Levenshtein import ratio


def embedding_similarity(model, originals, suspects):
    """
    Return the cosine similarity between one original text and each suspect text.

    Both inputs are embedded with ``model.encode`` (on ``model.device``, as
    tensors, without a progress bar). The original's embedding gets a leading
    axis added so it can be matrix-multiplied against the suspect embeddings.

    Args:
        model: Object exposing ``encode(...)`` and ``device``.
            # presumably a sentence-transformers model -- confirm with caller
        originals: Text to embed as the reference.
            # assumes a single string, so that encode() yields a 1-D tensor
            # and unsqueeze(0) makes it (1, dim) -- TODO confirm
        suspects: Texts to compare against the reference.
            # assumes a list of strings producing a (count, dim) tensor

    Returns:
        A flattened NumPy array of similarity scores, one per suspect.
    """
    encode_options = {
        "show_progress_bar": False,
        "device": model.device,
        "convert_to_tensor": True,
    }
    reference = model.encode(originals, **encode_options).unsqueeze(0)
    candidates = model.encode(suspects, **encode_options)

    # Raw dot products first, then divide by the outer product of the norms.
    dot_products = reference @ candidates.T
    reference_norms = reference.norm(dim=1)[:, None]
    candidate_norms = candidates.norm(dim=1)[None, :]
    cosine = dot_products / (reference_norms * candidate_norms)

    return cosine.cpu().numpy().flatten()


def detect_paraphrase_gpt(original, suspect):
    """
    Ask an OpenAI chat model whether two texts are effectively the same text.

    A new ``OpenAI()`` client is created on every call and a single chat
    completion is requested from the "gpt-4.1-nano" model. The instructions
    are sent with the "developer" role and the two texts with the "user" role.

    Args:
        original: The reference text (string; it is concatenated into the prompt).
        suspect: The text to compare against the reference (string).

    Returns:
        The content of the first completion choice, unmodified. The prompt asks
        the model to answer 'True' or 'False', but the reply is not parsed or
        validated here, so callers receive whatever string the model produced.
    """
    # NOTE: backslashes are written as "\\" on purpose. The original literals
    # used "\{" / "\}", which are invalid escape sequences (SyntaxWarning on
    # Python 3.12+). Doubling the backslash keeps the text sent to the model
    # byte-for-byte the same while making the literal valid.
    sys_prompt = """
    I will now give you two texts. I will enclose the two texts with curly braces \\{\\}.
    Please help me determine if the following two texts are the same.
    Disregard the names and minor changes in word order that appear within.
    If they are, please answer 'True', otherwise answer 'False'. Do not respond with anything else.
    If their contents are very similar and they contain the same information, we consider them to be the same text.
    """
    client = OpenAI()
    completion = client.chat.completions.create(
        model="gpt-4.1-nano",
        messages=[
            {"role": "developer", "content": sys_prompt},
            {
                "role": "user",
                "content": "part1: \\{\n" + original + "\n\\}\npart2: \\{\n" + suspect + "\n\\}",
            }
        ]
    )

    return completion.choices[0].message.content


def identity_similarity(str1, str2):
    """
    Report whether the two inputs compare equal with ``==``.

    Returns:
        The result of ``str1 == str2`` (a bool for plain strings).
    """
    is_identical = str1 == str2
    return is_identical

def longest_common_subsequence_length(seq1, seq2):
    """
    Return the length of the longest common subsequence of two sequences.

    Elements of the subsequence must appear in the same order in both inputs
    but do not have to be consecutive. Elements are compared with ``==``, so
    the inputs may be strings, lists of words, lists of tokens, etc.

    Only two rows of the dynamic-programming table are kept at a time, so the
    extra memory is proportional to ``len(seq2)`` instead of
    ``len(seq1) * len(seq2)``.

    Args:
        seq1: First sequence (must support iteration).
        seq2: Second sequence (must support ``len`` and iteration).

    Returns:
        The subsequence length as an int; 0 when either input is empty.
    """
    width = len(seq2)
    # previous_row[j] holds the LCS length of the seq1 prefix processed so far
    # and the first j elements of seq2.
    previous_row = [0] * (width + 1)

    for item1 in seq1:
        current_row = [0] * (width + 1)
        for j, item2 in enumerate(seq2, start=1):
            if item1 == item2:
                # Extend the best subsequence that ends before both elements.
                current_row[j] = previous_row[j - 1] + 1
            else:
                # Drop one element from either side and keep the better result.
                current_row[j] = max(previous_row[j], current_row[j - 1])
        previous_row = current_row

    return previous_row[width]

def longest_common_substring_length(seq1, seq2):
    """
    Return the length of the longest run of consecutive elements shared by two sequences.

    Works on any sequences whose elements can be compared with ``==``
    (characters of a string, lists of words, lists of tokens, ...).

    Only two rows of the dynamic-programming table are kept at a time, so the
    extra memory is proportional to ``len(seq2)`` instead of
    ``len(seq1) * len(seq2)``.

    Args:
        seq1: First sequence (must support iteration).
        seq2: Second sequence (must support ``len`` and iteration).

    Returns:
        The length of the longest common contiguous run as an int; 0 when the
        inputs share no element or either input is empty.
    """
    width = len(seq2)
    # previous_row[j] is the length of the common run ending at the previously
    # processed element of seq1 and at seq2[j - 1].
    previous_row = [0] * (width + 1)
    longest = 0

    for item1 in seq1:
        # Cells start at 0, which is also the value for a mismatch.
        current_row = [0] * (width + 1)
        for j, item2 in enumerate(seq2, start=1):
            if item1 == item2:
                run_length = previous_row[j - 1] + 1
                current_row[j] = run_length
                if run_length > longest:
                    longest = run_length
        previous_row = current_row

    return longest

def jaccard_index(set1, set2):
    """
    Compute the Jaccard index of two sets: |intersection| / |union|.

    Both arguments must be ``set`` instances (checked with assertions).

    Returns:
        The ratio as a float, or 0.0 when both sets are empty.
    """
    assert isinstance(set1, set), "set1 must be a set"
    assert isinstance(set2, set), "set2 must be a set"

    combined = set1 | set2
    if not combined:
        # Two empty sets: defined as 0.0 rather than dividing by zero.
        return 0.0

    shared = set1 & set2
    return len(shared) / len(combined)

def normalized_compression_distance(seq1, seq2):
    """
    Compute the Normalized Compression Distance (NCD) between two strings.

    NCD approximates Kolmogorov complexity with a real compressor (gzip here):
    each text is compressed on its own and the two are compressed concatenated.
    Texts that compress well together are considered more similar.

        NCD(A, B) = (Z(AB) - min(Z(A), Z(B))) / max(Z(A), Z(B))

    where Z(X) is the size in bytes of the gzip-compressed UTF-8 encoding of X.

    Args:
        seq1: First text (a string; it is UTF-8 encoded before compression).
        seq2: Second text (a string).

    Returns:
        The NCD as a float, or 0.0 if the larger compressed size is not positive.
    """

    def gzip_size(text):
        """
        Gzip the UTF-8 bytes of ``text`` in memory and return the output size.
        """
        with io.BytesIO() as sink:
            with gzip.GzipFile(fileobj=sink, mode='wb') as writer:
                writer.write(text.encode('utf-8'))
            # The gzip stream is finalized once the inner context exits, so the
            # buffer position now equals the total compressed size.
            return sink.tell()

    size_first = gzip_size(seq1)
    size_second = gzip_size(seq2)
    size_joint = gzip_size(seq1 + seq2)

    smaller, larger = sorted((size_first, size_second))
    if larger > 0:
        return (size_joint - smaller) / larger
    return 0.0


def sequence_handler(seq1, seq2, lower=False, units="char", tokenizer=None):
    """
    Preprocess two strings into the unit representation used for comparison.

    Args:
        seq1: First input string.
        seq2: Second input string.
        lower: When True, both strings are lowercased before splitting.
        units: How to split the strings:
            "full_string" - leave each string as is;
            "char"        - list of characters;
            "treebank"    - ``nltk.word_tokenize(..., language="english")``;
            "token"       - ``tokenizer.tokenize(..., add_special_tokens=False)``.
        tokenizer: Required when ``units == "token"``; must expose ``tokenize``.

    Returns:
        A ``(seq1, seq2)`` tuple of the processed inputs.

    Raises:
        AssertionError: If an argument fails the type/value checks.
        ValueError: If ``units`` is unrecognised (only reachable when the
            assertions are disabled).
    """
    assert isinstance(seq1, str), "seq1 must be a string"
    assert isinstance(seq2, str), "seq2 must be a string"
    assert isinstance(lower, bool), "lower must be a boolean"
    assert units in ["full_string", "char", "treebank", "token"]
    assert units != "token" or tokenizer is not None, "Tokenizer must be provided when type is 'token'"

    if lower:
        first, second = seq1.lower(), seq2.lower()
    else:
        first, second = seq1, seq2

    if units == "full_string":
        return first, second

    if units == "char":
        return list(first), list(second)

    if units == "treebank":
        first_words = nltk.word_tokenize(first, language="english")
        second_words = nltk.word_tokenize(second, language="english")
        return first_words, second_words

    if units == "token":
        first_tokens = tokenizer.tokenize(first, add_special_tokens=False)
        second_tokens = tokenizer.tokenize(second, add_special_tokens=False)
        return first_tokens, second_tokens

    raise ValueError(f"Invalid units: {units}. Shouldn't have even reached here.")


def compare(str1, str2, units="char", tokenizer=None, embedding_model=None, comparison_strategy="longest_common_substring", lower=False):
    """
    Compare two strings with the selected strategy.

    The inputs are first preprocessed by ``sequence_handler`` (lowercasing and
    splitting into ``units``); the chosen strategy is then applied.

    Strategies and the units they accept:
        "identity"                        - "full_string"; equality check.
        "normalized_compression_distance" - "full_string"; gzip-based NCD.
        "llm_gpt4"                        - "full_string"; LLM paraphrase verdict.
        "embedding"                       - "full_string"; cosine similarity
                                            computed with ``embedding_model``.
        "levenshtein_distance"            - "char"/"treebank"/"token";
                                            ``nltk.edit_distance``.
        "indel_similarity"                - "full_string"; ``Levenshtein.ratio``.
        "longest_common_subsequence"      - "char"/"treebank"/"token".
        "longest_common_substring"        - "char"/"treebank"/"token".
        "jaccard_index"                   - "char"/"treebank"/"token"; dict that
                                            maps each n-gram size from 1 to the
                                            shorter sequence length to its
                                            Jaccard index.

    Note:
        The NCD, LLM and embedding strategies receive the raw ``str1``/``str2``
        rather than the preprocessed values, so ``lower`` has no effect on them.

    Returns:
        Whatever the selected strategy returns (bool, int, float, str, array
        or dict depending on the strategy).

    Raises:
        AssertionError: If ``units`` is not valid for the chosen strategy.
        ValueError: If ``comparison_strategy`` is not recognised.
    """
    left, right = sequence_handler(str1, str2, lower=lower, units=units, tokenizer=tokenizer)
    is_full_string = units == "full_string"
    is_sequence_unit = units in ["char", "treebank", "token"]

    if comparison_strategy == "identity":
        assert is_full_string, "identity comparison only works with full strings"
        return identity_similarity(left, right)

    if comparison_strategy == "normalized_compression_distance":
        assert is_full_string, "normalized_compression_distance only works with full strings"
        return normalized_compression_distance(str1, str2)

    if comparison_strategy == "llm_gpt4":
        assert is_full_string, "llm_gpt4 comparison only works with full strings"
        return detect_paraphrase_gpt(str1, str2)

    if comparison_strategy == "embedding":
        assert is_full_string, "embedding comparison only works with full strings"
        return embedding_similarity(embedding_model, str1, [str2])

    if comparison_strategy == "levenshtein_distance":
        assert is_sequence_unit, "levenshtein only works with char, treebank, or token units"
        return nltk.edit_distance(left, right)

    if comparison_strategy == "indel_similarity":
        assert is_full_string, "indel_similarity only works with full strings"
        return ratio(left, right)

    if comparison_strategy == "longest_common_subsequence":
        assert is_sequence_unit, "longest_common_subsequence only works with char, treebank, or token units"
        return longest_common_subsequence_length(left, right)

    if comparison_strategy == "longest_common_substring":
        assert is_sequence_unit, "longest_common_substring only works with char, treebank, or token units"
        return longest_common_substring_length(left, right)

    if comparison_strategy == "jaccard_index":
        assert is_sequence_unit, "jaccard_index only works with char, treebank, or token units"
        # One score per n-gram size, up to the length of the shorter sequence.
        largest_ngram = min(len(left), len(right))
        scores_by_size = {}
        for size in range(1, largest_ngram + 1):
            left_ngrams = set(nltk.ngrams(left, size))
            right_ngrams = set(nltk.ngrams(right, size))
            scores_by_size[size] = jaccard_index(left_ngrams, right_ngrams)
        return scores_by_size

    raise ValueError(f"Invalid comparison strategy: {comparison_strategy}")
    
def batch_compare(str1, str2_batch, units="char", tokenizer=None, embedding_model=None, comparison_strategy="longest_common_substring", lower=False):
    """
    Compare one string against every string in a batch.

    For the "embedding" strategy the whole batch is passed to
    ``embedding_similarity`` in a single call (``units`` and ``lower`` are not
    consulted on that path). Every other strategy is delegated to ``compare``
    once per batch element.

    Args:
        str1: The reference string.
        str2_batch: List of strings to compare against ``str1``.
        units, tokenizer, comparison_strategy, lower: Forwarded to ``compare``.
        embedding_model: Model used by the "embedding" strategy.

    Returns:
        The output of ``embedding_similarity`` for the "embedding" strategy,
        otherwise a list with one ``compare`` result per batch element, in
        batch order.

    Raises:
        AssertionError: If ``str1`` is not a string or ``str2_batch`` is not a
            list of strings.
    """
    assert isinstance(str1, str), "str1 must be a string"
    assert isinstance(str2_batch, list), "str2_batch must be a list of strings"
    assert all(isinstance(s, str) for s in str2_batch), "All elements in str2_batch must be strings"

    if comparison_strategy != "embedding":
        return [
            compare(str1, candidate, units=units, tokenizer=tokenizer, comparison_strategy=comparison_strategy, lower=lower)
            for candidate in str2_batch
        ]

    # Embeddings are computed for the whole batch at once.
    return embedding_similarity(embedding_model, str1, str2_batch)
