import pandas as pd
import numpy as np
import gzip
import os
import subprocess
import random
from Bio import SeqIO
from config import *

# 1. Load physical interactions whose 'experimental' evidence score exceeds a threshold (default 0)
def load_physical_ppis(file_path, experiment_threshold=0):
    """Read a space-delimited interaction table and keep experimentally supported pairs.

    Args:
        file_path: Path to a space-separated text file with a header row that
            includes at least the 'protein1', 'protein2' and 'experimental' columns.
        experiment_threshold: Rows are kept only when 'experimental' is strictly
            greater than this value.

    Returns:
        DataFrame with just the 'protein1' and 'protein2' columns of the kept rows.
    """
    links = pd.read_csv(file_path, sep=' ')
    has_evidence = links['experimental'] > experiment_threshold
    physical = links.loc[has_evidence].copy()
    print(f"Filtered physical PPIs with experimental evidence > {experiment_threshold}: {len(physical)}")
    return physical[['protein1', 'protein2']]

# 2. Create filtered FASTA with proteins 50-800 AAs only
def create_filtered_fasta(interactions_df, fasta_path, output_fasta, min_len=50, max_len=800):
    """Write FASTA records of interacting proteins whose length lies in [min_len, max_len].

    Args:
        interactions_df: DataFrame with 'protein1' and 'protein2' columns; every ID
            appearing in either column is requested.
        fasta_path: Input FASTA file, either gzip-compressed or plain text
            (detected from the gzip magic bytes, not the file extension).
        output_fasta: Path of the FASTA file to write.
        min_len: Minimum sequence length to keep (inclusive).
        max_len: Maximum sequence length to keep (inclusive).

    Returns:
        Dict mapping each written record ID to its sequence length.
        Each ID is written at most once; later duplicates in the input are skipped.
    """
    interaction_ids = set(interactions_df['protein1']).union(interactions_df['protein2'])
    seq_lengths = {}

    # Sniff the gzip magic number so both compressed and plain FASTA work.
    with open(fasta_path, 'rb') as probe:
        is_gzipped = probe.read(2) == b'\x1f\x8b'
    opener = gzip.open if is_gzipped else open

    with opener(fasta_path, 'rt') as input_handle, open(output_fasta, 'w') as output_handle:
        for record in SeqIO.parse(input_handle, 'fasta'):
            id_parts = record.id.split()
            if not id_parts:
                # A header with no ID (e.g. a bare '>') cannot match any protein.
                continue
            rec_id = id_parts[0]
            if rec_id not in interaction_ids or rec_id in seq_lengths:
                # Skip unrequested IDs and duplicates already written.
                continue
            seq_len = len(record.seq)
            if min_len <= seq_len <= max_len:
                SeqIO.write(record, output_handle, 'fasta')
                seq_lengths[rec_id] = seq_len

    written = len(seq_lengths)
    print(f"Written {written} filtered FASTA records out of {len(interaction_ids)} requested.")
    missing = interaction_ids - set(seq_lengths)
    # Sort (by str, tolerant of mixed types) so the logged examples are reproducible.
    print(f"IDs not found in FASTA: {len(missing)} (examples: {sorted(missing, key=str)[:5]})")
    return seq_lengths

# 3. Parse CD-HIT cluster file
def parse_cdhit_clusters(clstr_file):
    """Parse a CD-HIT ``.clstr`` file into ``{cluster_index: [member_id, ...]}``.

    Lines starting with '>Cluster' open a new cluster. Any other line containing
    '>' is a member line; its ID is the text after the first '>' up to the first
    '...'. Clusters are numbered 0, 1, 2, ... in file order, and headers with no
    member lines are skipped without consuming an index.

    NOTE(review): IDs are taken exactly as CD-HIT printed them. If CD-HIT was run
    with an ID-length limit, they may be truncated relative to the FASTA IDs.
    Confirm the CD-HIT invocation.
    """
    clusters = {}
    members = []

    def _flush():
        # Store the pending members under the next sequential index, if any.
        if members:
            clusters[len(clusters)] = list(members)
            members.clear()

    with open(clstr_file, 'r') as handle:
        for raw in handle:
            if raw.startswith('>Cluster'):
                _flush()
                continue
            if '>' not in raw:
                continue
            after_marker = raw.split('>')[1]
            members.append(after_marker.split('...')[0].strip())
    _flush()

    print(f"Parsed {len(clusters)} clusters.")
    return clusters

# 4. Remove redundant PPIs based on CD-HIT clusters
def remove_sequence_redundant_ppis(df, clusters):
    """Keep only the first PPI for each unordered pair of sequence clusters.

    Args:
        df: DataFrame with 'protein1' and 'protein2' columns; other columns are
            carried through unchanged.
        clusters: Mapping of cluster id -> iterable of member protein IDs
            (the shape returned by ``parse_cdhit_clusters``).

    Returns:
        Row subset of ``df`` with the same columns, dtypes and index labels.
        Rows whose proteins are not in any cluster are dropped. Among rows that
        map to the same cluster pair (in either order), only the first is kept.
        An empty result still has ``df``'s columns.
    """
    protein_to_cluster = {p: cid for cid, plist in clusters.items() for p in plist}
    seen = set()
    keep = []
    for p1, p2 in zip(df['protein1'], df['protein2']):
        c1 = protein_to_cluster.get(p1)
        c2 = protein_to_cluster.get(p2)
        if c1 is None or c2 is None:
            # At least one protein is absent from every cluster.
            keep.append(False)
            continue
        # frozenset makes (c1, c2) and (c2, c1) the same key.
        key = frozenset((c1, c2))
        if key in seen:
            keep.append(False)
            continue
        seen.add(key)
        keep.append(True)

    # Mask the original frame instead of rebuilding it from rows. This preserves
    # dtypes (iterrows upcasts to object) and columns when nothing survives.
    filtered = df.loc[np.array(keep, dtype=bool)].copy()
    print(f"Non-redundant PPIs: {len(filtered)}")
    return filtered

# 5. Generate negative samples
def generate_negative_ppis(positive_df, valid_proteins, ratio=10, seed=None):
    """Sample random protein pairs that are not known positives, labelled 0.

    Args:
        positive_df: DataFrame with 'protein1' and 'protein2' columns of known
            interactions; these pairs (in either order) are never emitted.
        valid_proteins: Iterable of protein IDs eligible to appear in negatives.
        ratio: Target number of negatives per positive row.
        seed: Optional seed for a private ``random.Random``. When None, the
            module-level ``random`` state is used, so ``random.seed`` still applies.

    Returns:
        DataFrame with columns 'protein1', 'protein2', 'label' (always 0). Within
        each pair protein1 <= protein2, and no pair repeats. It may hold fewer
        than ``len(positive_df) * ratio`` rows if sampling stops after
        ``10 * target`` attempts.

    Raises:
        ValueError: if negatives are requested but fewer than two distinct
            proteins are available.
    """
    rng = random if seed is None else random.Random(seed)
    # Sort so a fixed seed gives reproducible samples. Set iteration order of
    # strings varies between interpreter runs (hash randomization).
    proteins = sorted(set(valid_proteins))
    # Unordered pairs that may not be emitted: positives, then accepted negatives.
    taken = {
        tuple(sorted(pair))
        for pair in zip(positive_df['protein1'], positive_df['protein2'])
    }
    max_neg = len(positive_df) * ratio
    if max_neg > 0 and len(proteins) < 2:
        raise ValueError(
            f"Need at least 2 distinct valid proteins to sample negatives, got {len(proteins)}."
        )

    negatives = []
    attempts = 0
    while len(negatives) < max_neg and attempts < max_neg * 10:
        pair = tuple(sorted(rng.sample(proteins, 2)))
        if pair not in taken:
            negatives.append({'protein1': pair[0], 'protein2': pair[1], 'label': 0})
            taken.add(pair)
        attempts += 1
    print(f"Generated {len(negatives)} negative samples.")
    # Explicit columns keep the schema even when no negatives were produced.
    return pd.DataFrame(negatives, columns=['protein1', 'protein2', 'label'])

# 6. Miscellaneous helpers
def normalize_pairs(df, col1, col2):
    """Return an order-independent view of a pair of columns.

    For each row, the smaller of the two values goes into ``col1`` and the larger
    into ``col2``, so (a, b) and (b, a) become identical rows. The result is a
    new DataFrame holding only these two columns, indexed like ``df``.
    """
    pair_columns = [col1, col2]
    ordered = {
        col1: df[pair_columns].min(axis=1),
        col2: df[pair_columns].max(axis=1),
    }
    return pd.DataFrame(ordered)