import numpy as np
import pandas as pd
import argparse
from scipy.special import expit as sigmoid  # Sigmoid function

def compute_subsumption_score(p1, p2, tau, alpha, delta, method="projection"):
    """
    Compute the subsumption score for a pair of concepts.

    Every attribute value is first squashed with
    ``sigmoid(alpha * (value - threshold))``. Writing ``s1`` and ``s2`` for the
    squashed versions of ``p1`` and ``p2``:

    * ``"projection"``: fraction of attributes where ``s1 <= s2 + delta``.
    * ``"soft"``: mean over attributes of ``sigmoid(alpha * (s2 - s1 + delta))``.

    Args:
        p1: 1-D sequence of attribute values for the first concept.
        p2: 1-D sequence of attribute values for the second concept.
        tau: 1-D sequence of per-attribute thresholds.
        alpha: Scaling factor applied inside each sigmoid.
        delta: Margin of tolerance added to the comparison.
        method: Either ``"projection"`` or ``"soft"``.

    Returns:
        The score as a float, averaged over the ``len(p1)`` attributes.

    Raises:
        ValueError: If ``method`` is not recognised, if ``p1`` is empty or not
            1-D, or if ``p2`` / ``tau`` have fewer entries than ``p1``.
    """
    # Reject unknown methods up front instead of failing on an unbound local.
    if method not in ("projection", "soft"):
        raise ValueError(
            f"Unknown method {method!r}; expected 'projection' or 'soft'."
        )

    k = len(p1)
    if k == 0:
        raise ValueError("Cannot compute a subsumption score for empty vectors.")

    # Entries of p2 and tau beyond the first len(p1) are ignored.
    v1 = np.asarray(p1, dtype=float)
    v2 = np.asarray(p2, dtype=float)[:k]
    thr = np.asarray(tau, dtype=float)[:k]
    if v1.ndim != 1 or v2.shape != v1.shape or thr.shape != v1.shape:
        raise ValueError(
            "p1 must be 1-D and p2/tau must provide at least len(p1) values."
        )

    # Squash both concepts once; both scoring methods reuse these vectors.
    s1 = sigmoid(alpha * (v1 - thr))
    s2 = sigmoid(alpha * (v2 - thr))

    if method == "projection":
        return float(np.count_nonzero(s1 <= s2 + delta) / k)
    return float(np.mean(sigmoid(alpha * (s2 - s1 + delta))))

def load_ground_truth(filepath):
    """
    Load ground-truth subsumption pairs from hypernyms.txt.

    Each non-blank line must hold two concept names separated by ``', '``.
    Blank lines (for example a trailing newline at the end of the file) are
    skipped.

    Args:
        filepath: Path to the text file to read.

    Returns:
        A list of ``(a, b)`` string tuples in file order.

    Raises:
        ValueError: If a non-blank line does not split into exactly two parts.
    """
    ground_truth = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            entry = line.strip()
            if not entry:
                # Tolerate empty lines instead of failing on unpacking.
                continue
            parts = entry.split(', ')
            if len(parts) != 2:
                raise ValueError(
                    f"{filepath}:{line_no}: expected 'A, B' but got {entry!r}"
                )
            ground_truth.append((parts[0], parts[1]))
    return ground_truth

def compute_mrr(formal_context, thresholds, ground_truth, alpha, delta, method="projection"):
    """
    Compute the Mean Reciprocal Rank (MRR) for the subsumption task.

    For every ground-truth pair ``(a, b)`` whose two names are both present in
    ``formal_context.index``, the score of ``(a, b)`` is compared with the
    score of ``(a, b1)`` for every other index entry ``b1 != b``. The rank of
    the true pair is one plus the number of candidates scoring strictly
    higher, so ties do not push the true pair down.

    Args:
        formal_context: DataFrame indexed by concept name, one row of attribute
            values per concept.
        thresholds: Per-attribute thresholds passed to
            ``compute_subsumption_score``.
        ground_truth: Iterable of ``(a, b)`` concept-name pairs.
        alpha: Scaling factor passed to ``compute_subsumption_score``.
        delta: Margin of tolerance passed to ``compute_subsumption_score``.
        method: Scoring method passed to ``compute_subsumption_score``.

    Returns:
        The mean of the reciprocal ranks over all evaluated pairs, or ``0.0``
        when no ground-truth pair could be evaluated.
    """
    concepts = formal_context.index
    # Read the row matrix once rather than doing a label lookup per candidate.
    rows = formal_context.values
    reciprocal_ranks = []

    for a, b in ground_truth:
        if a not in concepts or b not in concepts:
            continue

        # Attribute vectors for the true pair
        p1 = formal_context.loc[a].values
        p2 = formal_context.loc[b].values
        true_score = compute_subsumption_score(p1, p2, thresholds, alpha, delta, method)

        # Count candidates <A, B^1> that beat the true pair; only the count
        # matters for the rank, so the candidates are never sorted.
        # NOTE(review): the candidate set includes `a` itself - confirm whether
        # the pair <A, A> is meant to compete with the true pair.
        better = 0
        for b1, p2_alt in zip(concepts, rows):
            if b1 == b:
                continue
            candidate = compute_subsumption_score(p1, p2_alt, thresholds, alpha, delta, method)
            if candidate > true_score:
                better += 1

        reciprocal_ranks.append(1 / (1 + better))

    # No pair could be evaluated: avoid dividing by zero.
    if not reciprocal_ranks:
        return 0.0

    return sum(reciprocal_ranks) / len(reciprocal_ranks)

if __name__ == "__main__":
    # Command-line interface: dataset/model selection plus scoring parameters.
    cli = argparse.ArgumentParser(description="Predict subsumption scores and compute MRR.")
    cli.add_argument(
        "--dataset",
        type=str,
        required=True,
        help="Dataset to use for loading data and directories.",
    )
    cli.add_argument(
        "--model_key",
        type=str,
        required=True,
        help="Model key to use for loading models and directories.",
    )
    cli.add_argument(
        "--embedding_method",
        type=str,
        default="lda",
        choices=["lda", "random", "mean"],
        help="Method to estimate attribute embeddings.",
    )
    cli.add_argument(
        "--alpha",
        type=float,
        default=1.0,
        help="Scaling factor for sigmoid function.",
    )
    cli.add_argument(
        "--delta",
        type=float,
        default=0.1,
        help="Margin of tolerance.",
    )
    cli.add_argument(
        "--method",
        type=str,
        choices=["projection", "soft"],
        default="projection",
        help="Method to compute subsumption score.",
    )
    args = cli.parse_args()

    # Input locations are derived from the dataset, model key and embedding method.
    hypernyms_file = f'./datasets/{args.dataset}/hypernyms.txt'
    formal_context_path = f'./datasets/{args.dataset}/{args.model_key}/results/formal_context_{args.embedding_method}.csv'
    thresholds_path = f'./datasets/{args.dataset}/{args.model_key}/results/attribute_thresholds_{args.embedding_method}.csv'

    # Read the formal context (first CSV column is the index), the thresholds
    # flattened to a 1-D array, and the ground-truth pairs.
    context_table = pd.read_csv(formal_context_path, index_col=0)
    threshold_table = pd.read_csv(thresholds_path, index_col=0)
    threshold_vector = threshold_table.values.flatten()
    true_pairs = load_ground_truth(hypernyms_file)

    # Evaluate and report the result.
    result = compute_mrr(
        context_table,
        threshold_vector,
        true_pairs,
        args.alpha,
        args.delta,
        args.method,
    )
    print(f"Mean Reciprocal Rank (MRR): {result:.4f}")