import h5py
import numpy as np
import os
import subprocess
from Bio import SeqIO
import pandas as pd
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity
import torch
from config import *

# Pooling function
def pool_embedding(emb):
    """Reduce an embedding to a single fixed-size vector.

    Accepted layouts:
      * 3-D with a leading axis of size 1 -- the leading axis is dropped and
        the result is pooled as a 2-D array (the same squeeze that
        ``load_embeddings`` applies before pooling).
      * 2-D -- mean over axis 0 (one row per residue, presumably).
      * 1-D -- treated as already pooled and returned as-is.

    Args:
        emb: Array-like embedding data.

    Returns:
        A 1-D ``numpy.ndarray``, or ``None`` when the shape is not one of
        the layouts listed above.
    """
    emb = np.array(emb)
    # Without this squeeze a (1, L, D) array would be rejected as invalid,
    # even though load_embeddings accepts exactly that layout.
    if emb.ndim == 3 and emb.shape[0] == 1:
        emb = emb[0]
    if emb.ndim == 2:
        return emb.mean(axis=0)  # mean pooling over residues
    if emb.ndim == 1:
        return emb
    return None  # invalid shape

def load_sequences(fasta_file, min_len=50, max_len=800):
    """Read a FASTA file and keep sequences within a length window.

    Args:
        fasta_file: Path to the FASTA file.
        min_len: Minimum sequence length to keep (inclusive).
        max_len: Maximum sequence length to keep (inclusive).

    Returns:
        Dict mapping record id to the sequence as a plain string. When an
        id occurs more than once, the last matching record wins.
    """
    with open(fasta_file, 'r') as handle:
        return {
            rec.id: str(rec.seq)
            for rec in SeqIO.parse(handle, "fasta")
            if min_len <= len(rec.seq) <= max_len
        }

# Load embeddings from .h5 file
def load_embeddings(h5_file):
    """Load every embedding in an HDF5 file as a pooled 1-D vector.

    Each top-level key is read fully into memory. A leading singleton axis
    is dropped, 2-D arrays are mean-pooled over axis 0, and 1-D arrays are
    kept unchanged. Any other shape is reported on stdout and skipped.

    Args:
        h5_file: Path to the HDF5 file to read.

    Returns:
        Dict mapping dataset key to a 1-D ``numpy.ndarray``.
    """
    embeddings = {}
    with h5py.File(h5_file, "r") as h5:
        for key in tqdm(h5.keys(), desc="Processing embeddings"):
            data = h5[key][:]
            if data.ndim == 3 and data.shape[0] == 1:
                data = data.squeeze(0)
            if data.ndim == 2:
                embeddings[key] = data.mean(axis=0)
            elif data.ndim == 1:
                # generate_embedding stores already-pooled vectors in the
                # global file; accept them so they are not regenerated on
                # every run.
                embeddings[key] = data
            else:
                print(f"Skipping {key}, invalid shape: {data.shape}")
    return embeddings

# Load a single embedding with pooling
def load_single_embedding(h5_path, protein_name):
    """Fetch one embedding from an HDF5 file and pool it.

    Args:
        h5_path: Path to the HDF5 file.
        protein_name: Key of the dataset to read.

    Returns:
        The result of ``pool_embedding`` on the stored array, or ``None``
        when the key is not present in the file.
    """
    with h5py.File(h5_path, "r") as store:
        if protein_name not in store:
            return None
        # The slice copies the dataset into memory, so pooling can happen
        # after the file is closed.
        raw = store[protein_name][:]
    return pool_embedding(raw)

# Generate embedding and save it to global file
def generate_embedding(protein_name, sequence,
                       tmp_fasta="tmp_single.fasta",
                       tmp_out="tmp_embedding.h5",
                       global_h5=EMBEDDING_FILE):
    """Embed a single sequence via the external script and persist it.

    Writes the sequence to a temporary FASTA file, runs ``DSCRIPT_EMBED``
    as a subprocess, pools the result with ``load_single_embedding`` and
    stores the pooled vector in ``global_h5`` under ``protein_name``,
    replacing any existing entry.

    Args:
        protein_name: Identifier used as FASTA header and HDF5 key.
        sequence: Sequence text written to the FASTA file.
        tmp_fasta: Path of the temporary FASTA file.
        tmp_out: Path of the temporary HDF5 output of the embed script.
        global_h5: HDF5 file the pooled embedding is appended to.

    Returns:
        The pooled embedding, or ``None`` if no valid embedding could be
        read back from the script's output.

    Raises:
        subprocess.CalledProcessError: If the embed script exits non-zero.
    """
    try:
        # Save temporary FASTA
        with open(tmp_fasta, "w") as f:
            f.write(f">{protein_name}\n{sequence}\n")

        # Run embed.py
        subprocess.run([
            "python", DSCRIPT_EMBED,
            "--seqs", tmp_fasta,
            "-o", tmp_out,
            "-d", str(DEVICE)
        ], check=True)

        # Load result
        new_emb = load_single_embedding(tmp_out, protein_name)
    finally:
        # Always clean up, including when the script fails or no valid
        # embedding is produced; otherwise stale temp files are left behind
        # and may be picked up by a later call.
        for tmp_path in (tmp_fasta, tmp_out):
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    if new_emb is None:
        print(f"Failed to generate valid embedding for {protein_name}")
        return None

    # Save to global .h5
    with h5py.File(global_h5, "a") as f:
        if protein_name in f:
            print(f"Overwriting existing embedding: {protein_name}")
            del f[protein_name]
        f.create_dataset(protein_name, data=new_emb)

    return new_emb

# Full helper: try to get from memory, else generate
def get_embedding(protein_name, embeddings_dict, seq_dict):
    """Return the embedding for a protein, generating it on a cache miss.

    Args:
        protein_name: Identifier of the protein.
        embeddings_dict: In-memory cache of embeddings; updated in place
            when a new embedding is generated successfully.
        seq_dict: Mapping from protein identifier to sequence.

    Returns:
        The cached or newly generated embedding, or ``None`` when the
        sequence is unknown or generation fails.
    """
    # Cache hit: nothing to compute.
    if protein_name in embeddings_dict:
        return embeddings_dict[protein_name]

    if protein_name in seq_dict:
        print(f"Generating embedding for {protein_name}...")
        fresh = generate_embedding(protein_name, seq_dict[protein_name])
        if fresh is None:
            return None
        embeddings_dict[protein_name] = fresh
        return fresh

    print(f"Sequence not found for {protein_name}")
    return None
