import pandas as pd
import numpy as np
from config_base import *

# Cut-off depths K at which every "@K" metric (Recall, Precision, Hit, MAP,
# nDCG, Success) is evaluated.
TOP_KS = [
    5,
    10,
    50,
    100,
    200,
    500,
]

def _protein_metrics(p1, total_possible, ranks, top_ks):
    """Return the metric row for one protein.

    Args:
        p1: protein identifier, copied into the row.
        total_possible: number of relevant partners (``p2_in_v11``); must be > 0.
        ranks: non-empty list of the rediscovered partners' ranks, sorted
            ascending (1 = best).
        top_ks: cut-offs K; assumed to be positive.
    """
    result = {
        "protein1": p1,
        "p2_in_v11": total_possible,
        "n_rediscovered": len(ranks),
        "rediscovered_ratio": len(ranks) / total_possible,
        "Avg_Rank": np.mean(ranks),
        # ranks is sorted and non-empty, so ranks[0] is the best-ranked hit.
        "MRR": 1 / ranks[0],
    }

    for k in top_ks:
        in_top_k = [r for r in ranks if r <= k]
        hit_count = len(in_top_k)

        result[f"Recall@{k}"] = hit_count / total_possible
        result[f"Precision@{k}"] = hit_count / k
        result[f"Hit@{k}"] = 1.0 if hit_count > 0 else 0.0

        # AP@K: sum of the precision at each hit's rank (hits seen so far / rank),
        # normalized by the best achievable number of hits within the cut-off.
        # NOTE(review): if Avg_Rank holds averaged/tied (non-integer) ranks, a
        # precision term can exceed 1 — confirm ranks are distinct positions.
        ap = sum(n_hits / r for n_hits, r in enumerate(in_top_k, start=1))
        result[f"MAP@{k}"] = ap / min(total_possible, k)

        # nDCG@K with binary relevance; the ideal ranking places all relevant
        # partners (not only the rediscovered ones) at the top positions.
        dcg = sum(1 / np.log2(r + 1) for r in in_top_k)
        ideal_dcg = sum(1 / np.log2(pos + 1) for pos in range(1, min(total_possible, k) + 1))
        result[f"nDCG@{k}"] = dcg / ideal_dcg if ideal_dcg > 0 else 0.0

        # Success@K: every relevant partner was rediscovered AND all of them
        # lie within the top-K.
        all_found = len(ranks) >= total_possible
        result[f"Success@{k}"] = 1.0 if all_found and ranks[-1] <= k else 0.0

    return result


def _print_summary(metrics_df, summary_df, ranks_df, top_ks):
    """Print global averages of the per-protein metrics plus coverage counts."""
    if metrics_df.empty:
        print("Global averages: none (no protein has rediscovered partners)")
    else:
        print("Global averages:")
        for metric in ("Recall", "Precision", "MAP", "nDCG", "Success"):
            print(metrics_df[[f"{metric}@{k}" for k in top_ks]].mean())
        print(f"MRR: {metrics_df['MRR'].mean():.4f}")
        print(f"Avg. Rank: {metrics_df['Avg_Rank'].mean():.2f}")
        print(f"Rediscovered Ratio: {metrics_df['rediscovered_ratio'].mean():.4f}")

    n_candidates = int((summary_df["p2_in_v11"] > 0).sum())
    # Proteins with at least one row in the rank table.
    n_with_rediscovery = ranks_df["protein1"].nunique()
    print(f"Candidates (p2_in_v11>0): {n_candidates}")
    print(f"With at least one rediscovery: {n_with_rediscovery}")
    print(f"Avg Rank Intersection: {len(metrics_df)}")


def compute_metrics(rank_file=None, summary_file=None, output_file=None, top_ks=None):
    """Compute per-protein ranking metrics, save them as TSV and print averages.

    Every ``protein1`` in the summary table that has ``p2_in_v11 > 0`` (used
    as the number of relevant partners) and at least one row in the rank table
    is scored. The ascending ``Avg_Rank`` values of its rediscovered partners
    are used for Recall, Precision, Hit, MAP, nDCG and Success at each cut-off,
    plus MRR, mean rank and rediscovered ratio.

    Args:
        rank_file: TSV with ``protein1`` and ``Avg_Rank`` columns.
            Defaults to ``RANK_FILE`` from the config.
        summary_file: TSV with ``protein1`` and ``p2_in_v11`` columns.
            Defaults to ``SUMMARY_FILE``.
        output_file: destination TSV. Defaults to ``RECOMMENDATION_METRICS_FILE``.
        top_ks: positive cut-offs K. Defaults to the module-level ``TOP_KS``.

    Returns:
        The per-protein metrics DataFrame at full precision. The file written
        to ``output_file`` is rounded to 2 decimals.
    """
    # Resolve defaults lazily so explicit arguments never touch the config names.
    if rank_file is None:
        rank_file = RANK_FILE
    if summary_file is None:
        summary_file = SUMMARY_FILE
    if output_file is None:
        output_file = RECOMMENDATION_METRICS_FILE
    if top_ks is None:
        top_ks = TOP_KS

    ranks_df = pd.read_csv(rank_file, sep="\t")
    summary_df = pd.read_csv(summary_file, sep="\t")

    # Group once instead of filtering the whole rank table for every protein.
    ranks_by_protein = {
        p1: group.tolist() for p1, group in ranks_df.groupby("protein1")["Avg_Rank"]
    }

    metrics = []
    for p1, p2_count in zip(summary_df["protein1"], summary_df["p2_in_v11"]):
        total_possible = int(p2_count)
        if total_possible <= 0:
            continue
        ranks = sorted(ranks_by_protein.get(p1, []))
        if not ranks:
            continue
        metrics.append(_protein_metrics(p1, total_possible, ranks, top_ks))

    # Explicit columns keep the schema (and the summary printing) valid even
    # when no protein qualifies.
    columns = ["protein1", "p2_in_v11", "n_rediscovered", "rediscovered_ratio", "Avg_Rank", "MRR"]
    for k in top_ks:
        columns += [
            f"Recall@{k}",
            f"Precision@{k}",
            f"Hit@{k}",
            f"MAP@{k}",
            f"nDCG@{k}",
            f"Success@{k}",
        ]
    metrics_df = pd.DataFrame(metrics, columns=columns)

    # Round only the persisted copy; averages below use full precision.
    metrics_df.round(2).to_csv(output_file, sep="\t", index=False)

    _print_summary(metrics_df, summary_df, ranks_df, top_ks)
    return metrics_df

if __name__ == "__main__":
    # Script entry point: run the whole pipeline, then report the output path.
    compute_metrics()
    output_path = RECOMMENDATION_METRICS_FILE
    print(f"Metrics computed and saved to {output_path}")
