# similarity_search.py
import os
import h5py
import pandas as pd
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
from Bio import SeqIO
from tqdm import tqdm
import tempfile
import torch
from pathlib import Path
from config import *
from embedding_utils import *

def compute_euclidean_similarity(embeddings, target_id):
    """Rank every other protein by Euclidean closeness to a target protein.

    Distances ``d`` are mapped to a similarity-like score ``1 / (1 + d)``, so
    an identical vector scores 1.0 and the score decays towards 0 as the
    distance grows.

    Args:
        embeddings: Mapping of protein ID to an embedding vector (array-like).
        target_id: Key of the query protein in ``embeddings``.

    Returns:
        DataFrame with columns ``protein_id`` and ``euclidean_similarity``,
        sorted by descending similarity. The frame is empty (same columns)
        when no other protein has a finite embedding.

    Raises:
        ValueError: If ``target_id`` is not a key of ``embeddings``.
    """
    if target_id not in embeddings:
        raise ValueError(f"Target protein {target_id} not found in embeddings.")

    target_vec = np.asarray(embeddings[target_id], dtype=np.float64).reshape(1, -1)

    # Collect every other protein whose embedding is fully finite
    ids = []
    vectors = []
    for pid, vec in embeddings.items():
        if pid == target_id:
            continue
        vec = np.asarray(vec, dtype=np.float64).reshape(-1)
        if not np.isfinite(vec).all():
            continue
        ids.append(pid)
        vectors.append(vec)

    columns = ["protein_id", "euclidean_similarity"]
    if not vectors:
        # np.stack([]) raises; an empty ranking is the meaningful answer here
        return pd.DataFrame(columns=columns)

    matrix = np.stack(vectors)  # shape: (N, D)

    # Direct norm of the difference: avoids the cancellation error of the
    # ||a||^2 - 2ab + ||b||^2 expansion for nearly identical vectors
    dists = np.linalg.norm(matrix - target_vec, axis=1)  # shape: (N,)

    # Convert to similarity: lower distance = higher similarity
    similarities = 1 / (1 + dists)

    df = pd.DataFrame({
        "protein_id": ids,
        "euclidean_similarity": similarities,
    })
    return df.sort_values(by="euclidean_similarity", ascending=False)

def compute_cosine_similarity(embeddings, target_id):
    """Rank every other protein by cosine similarity to a target protein.

    Utility intended for testing purposes. Besides the raw cosine similarity,
    each row gets a z-score of that similarity across all candidates (using
    the sample standard deviation, pandas' default).

    Args:
        embeddings: Mapping of protein ID to an embedding vector (array-like).
            Vectors are flattened before comparison.
        target_id: Key of the query protein in ``embeddings``.

    Returns:
        DataFrame with columns ``protein_id``, ``cosine_similarity`` and
        ``score_z``, sorted by descending ``score_z``. Candidates containing
        NaN/Inf are skipped. ``score_z`` is 0.0 for every row when the spread
        is undefined or zero (a single candidate, or identical similarities).

    Raises:
        ValueError: If ``target_id`` is missing from ``embeddings`` or its
            embedding contains NaN/Inf values.
    """
    if target_id not in embeddings:
        raise ValueError(f"Target protein {target_id} not found in embeddings.")

    target_vec = np.asarray(embeddings[target_id], dtype=np.float64).reshape(-1)  # shape: (D,)
    if not np.isfinite(target_vec).all():
        raise ValueError(f"Target protein {target_id} embedding contains NaN or Inf values.")

    # Gather all usable candidates so the similarity is one matrix product
    ids = []
    vectors = []
    for prot_id, vec in embeddings.items():
        if prot_id == target_id:
            continue
        vec = np.asarray(vec, dtype=np.float64).reshape(-1)
        if not np.isfinite(vec).all():
            continue
        ids.append(prot_id)
        vectors.append(vec)

    columns = ["protein_id", "cosine_similarity", "score_z"]
    if not vectors:
        return pd.DataFrame(columns=columns)

    matrix = np.stack(vectors)  # shape: (N, D)
    denom = np.linalg.norm(matrix, axis=1) * np.linalg.norm(target_vec)
    # A zero vector has a zero dot product too; dividing by 1 yields similarity 0
    denom[denom == 0] = 1.0
    sims = (matrix @ target_vec) / denom

    df = pd.DataFrame({"protein_id": ids, "cosine_similarity": sims})

    # z-score normalisation; guard the degenerate cases that would give NaN
    std = df["cosine_similarity"].std()
    if np.isfinite(std) and std > 0:
        df["score_z"] = (df["cosine_similarity"] - df["cosine_similarity"].mean()) / std
    else:
        df["score_z"] = 0.0

    return df.sort_values(by="score_z", ascending=False)

def similar_above_threshold_zscore(known_embedding, all_embeddings_dict, z_threshold=0.5, batch_size=2000):
    """Find all proteins whose cosine similarity z-score reaches a threshold.

    Cosine similarities to ``known_embedding`` are computed in batches to
    limit memory usage, then standardised with the population mean and
    standard deviation of all computed similarities.

    Args:
        known_embedding: Query embedding (array-like or tensor); flattened to
            one dimension before use.
        all_embeddings_dict: Mapping of name to embedding. Entries are skipped
            when they are not a NumPy array or torch tensor, are not 1-D, do
            not match the query length, or contain NaN/Inf.
        z_threshold: Minimum z-score (inclusive) for a protein to be kept.
        batch_size: Number of dictionary entries processed per batch.

    Returns:
        List of ``(name, cosine_similarity)`` tuples, in dictionary order, for
        proteins with ``z >= z_threshold``. When all similarities are
        identical every z-score is treated as 0.
    """
    # Flatten so a (1, D) query behaves like a (D,) one; normalise only once
    query = torch.as_tensor(known_embedding, dtype=torch.float32).flatten()
    query_unit = torch.nn.functional.normalize(query, dim=0)
    dim = query.shape[0]

    names = list(all_embeddings_dict.keys())
    all_scores = []
    all_names = []

    # First pass: compute all cosine similarities
    for start in range(0, len(names), batch_size):
        batch_vectors = []
        valid_names = []

        # Gather valid embeddings in the batch
        for name in names[start:start + batch_size]:
            vec = all_embeddings_dict[name]
            if isinstance(vec, np.ndarray):
                vec = torch.tensor(vec, dtype=torch.float32)
            elif isinstance(vec, torch.Tensor):
                # Mixed dtypes would break stacking / matmul
                vec = vec.to(torch.float32)
            else:
                continue

            if vec.ndim != 1 or vec.shape[0] != dim:
                continue
            # One NaN similarity would turn the global mean/std into NaN
            if not torch.isfinite(vec).all():
                continue

            batch_vectors.append(vec)
            valid_names.append(name)

        if not batch_vectors:
            continue

        # Row-normalised matrix times unit query = cosine similarities
        matrix_unit = torch.nn.functional.normalize(torch.stack(batch_vectors), dim=1)
        sims = torch.matmul(matrix_unit, query_unit)

        all_scores.extend(sims.tolist())
        all_names.extend(valid_names)

    # Convert to z-scores, guarding the empty and zero-spread cases
    scores_np = np.array(all_scores, dtype=np.float64)
    if scores_np.size == 0:
        z_scores = scores_np
    else:
        std = scores_np.std()
        if std == 0:
            z_scores = np.zeros_like(scores_np)
        else:
            z_scores = (scores_np - scores_np.mean()) / std

    # Filter by z-threshold
    results = [
        (name, float(score))
        for name, score, z in zip(all_names, all_scores, z_scores)
        if z >= z_threshold
    ]

    print(f"Z-score filtering complete: retained {len(results)} of {len(all_scores)} proteins with z ≥ {z_threshold}")
    return results


def top_k_similar(known_embedding, all_embeddings_dict, k=20, batch_size=2000):
    """Find the top-k most similar proteins using cosine similarity.

    Embeddings are compared in batches to limit memory usage.

    Args:
        known_embedding: Query embedding (array-like or tensor); flattened to
            one dimension before use.
        all_embeddings_dict: Mapping of name to embedding. Each embedding is
            flattened; entries whose length differs from the query or that
            contain NaN/Inf are skipped.
        k: Maximum number of results to return.
        batch_size: Number of dictionary entries processed per batch.

    Returns:
        List of up to ``k`` ``(name, cosine_similarity)`` tuples sorted by
        descending similarity.
    """
    # Flatten and normalise the query once instead of once per batch
    query = torch.as_tensor(known_embedding, dtype=torch.float32).flatten()
    query_unit = torch.nn.functional.normalize(query, dim=0)
    dim = query.shape[0]

    names = list(all_embeddings_dict.keys())
    scored = []

    # Process in batches
    for start in range(0, len(names), batch_size):
        batch_vectors = []
        valid_names = []  # names aligned one-to-one with batch_vectors

        # Gather valid embeddings in the batch
        for name in names[start:start + batch_size]:
            raw = all_embeddings_dict[name]
            if isinstance(raw, torch.Tensor):
                vec = raw.detach().to(torch.float32).flatten()
            else:
                vec = torch.tensor(np.asarray(raw), dtype=torch.float32).flatten()

            # Wrong-length vectors cannot be stacked; NaN scores break sorting
            if vec.shape[0] != dim or not torch.isfinite(vec).all():
                continue

            batch_vectors.append(vec)
            valid_names.append(name)

        if not batch_vectors:
            continue

        # Row-normalised matrix times unit query = cosine similarities
        matrix_unit = torch.nn.functional.normalize(torch.stack(batch_vectors), dim=1)
        sims = torch.matmul(matrix_unit, query_unit)

        # Pair scores with the names that were actually kept, not the raw batch
        scored.extend(zip(valid_names, sims.tolist()))

    # Sort and keep top-k overall
    scored.sort(key=lambda item: item[1], reverse=True)
    return scored[:k]

def similar_above_threshold(known_embedding, all_embeddings_dict, threshold=0.5, batch_size=2000):
    """Find all proteins with cosine similarity at or above a threshold.

    Intended for debugging: progress is printed after every batch.

    Args:
        known_embedding: Query embedding (array-like or tensor); flattened to
            one dimension before use.
        all_embeddings_dict: Mapping of name to embedding. Entries are skipped
            when they are not a NumPy array or torch tensor, are not 1-D, or
            do not match the query length.
        threshold: Minimum cosine similarity (inclusive) to keep a protein.
        batch_size: Number of dictionary entries processed per batch.

    Returns:
        List of ``(name, cosine_similarity)`` tuples in dictionary order.
    """
    # Flatten so a (1, D) query behaves like a (D,) one; normalise only once
    query = torch.as_tensor(known_embedding, dtype=torch.float32).flatten()
    query_unit = torch.nn.functional.normalize(query, dim=0)
    dim = query.shape[0]

    names = list(all_embeddings_dict.keys())
    results = []
    total_checked = 0
    total_retained = 0

    for start in range(0, len(names), batch_size):
        batch_vectors = []
        valid_names = []

        for name in names[start:start + batch_size]:
            vec = all_embeddings_dict[name]
            if isinstance(vec, np.ndarray):
                vec = torch.tensor(vec, dtype=torch.float32)
            elif isinstance(vec, torch.Tensor):
                # Mixed dtypes would break stacking / matmul
                vec = vec.to(torch.float32)
            else:
                continue

            if vec.ndim != 1 or vec.shape[0] != dim:
                continue

            batch_vectors.append(vec)
            valid_names.append(name)

        if not batch_vectors:
            continue

        # Row-normalised matrix times unit query = cosine similarities
        matrix_unit = torch.nn.functional.normalize(torch.stack(batch_vectors), dim=1)
        sims = torch.matmul(matrix_unit, query_unit)

        # Compare on the tensor (float32 semantics) in one vectorised step,
        # then walk plain Python values instead of 0-d tensors
        keep_mask = (sims >= threshold).tolist()
        total_checked += len(valid_names)
        for name, score, keep in zip(valid_names, sims.tolist(), keep_mask):
            if keep:
                results.append((name, float(score)))
                total_retained += 1

        print(f"Batch {start//batch_size + 1}: checked {len(valid_names)} proteins, retained {total_retained} so far")

    print(f"Finished similarity filtering: total checked = {total_checked}, retained = {total_retained}")
    return results


def run_similarity_search(p1_name, k=20, all_embeddings=None):
    """Find proteins similar to each known partner of ``p1_name`` and save them.

    Reads ``{OUTPUT_DIR}/{p1_name}/v11_partners.tsv``, maps every entry of its
    ``known_partner`` column to a STRING ID through the alias table at
    ``ALIAS_V11``, and collects all proteins whose cosine similarity to that
    partner's embedding is at least 0.5. The rows are written to
    ``{OUTPUT_DIR}/{p1_name}/top_similars.tsv``.

    Args:
        p1_name: Name of the query protein; also the name of its output folder.
        k: Currently unused; kept so existing callers keep working.
        all_embeddings: Mapping of STRING ID to embedding vector. Required as
            soon as the protein has at least one known partner.

    Returns:
        None. The function returns early (after printing a message) when the
        partner file is missing or empty, or when no result row is produced.

    Raises:
        ValueError: If ``all_embeddings`` is None although partners exist.
    """
    print(f"Processing {p1_name}...")

    partner_file = f"{OUTPUT_DIR}/{p1_name}/v11_partners.tsv"
    if not os.path.exists(partner_file):
        print(f"No known partners file found for {p1_name}. Skipping.")
        return

    partners_df = pd.read_csv(partner_file, sep="\t")
    if partners_df.empty:
        print(f"{p1_name} has no known partners in v11. Skipping.")
        return

    partners = partners_df["known_partner"].dropna().tolist()
    if not partners:
        print(f"{p1_name} has no valid known partners. Skipping.")
        return

    if all_embeddings is None:
        # Fail with a clear message instead of AttributeError on None.get(...)
        raise ValueError("all_embeddings must be provided to run the similarity search.")

    alias_df = pd.read_csv(ALIAS_V11, sep="\t")
    name_to_string_id = dict(zip(alias_df["preferred_name"], alias_df["protein_external_id"]))

    # NOTE(review): candidates are every key of all_embeddings; they are not
    # restricted to proteins listed in v11_named.tsv. Add a filter here if
    # that restriction is required.

    results = []

    # for each known partner -> find similar proteins (using embeddings)
    for kp in partners:
        string_id = name_to_string_id.get(kp)
        if string_id is None:
            print(f"No STRING ID found for known partner {kp}. Skipping.")
            continue

        emb = all_embeddings.get(string_id)
        if emb is None:
            print(f"No embedding found for {string_id}")
            continue

        # when using cosine similarity
        similar = similar_above_threshold(emb, all_embeddings, threshold=0.5)

        # save results
        results.extend(
            {
                "protein1": p1_name,
                "known_partner": kp,
                "known_partner_id": string_id,
                "similar_protein": sim_name,
                "similarity_score": score,
            }
            for sim_name, score in similar
        )

    if not results:
        print(f"No similarity results generated for {p1_name}")
        return

    out_path = f"{OUTPUT_DIR}/{p1_name}/top_similars.tsv"
    pd.DataFrame(results).to_csv(out_path, sep="\t", index=False)
    print(f"Saved similarity results to {out_path}")