import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score, average_precision_score

import random
import pickle
import os
import argparse

# Seed every random number generator used by this script so runs are repeatable.
random_seed = 42
torch.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)
if torch.cuda.is_available():
    # Also seed the generators of all visible CUDA devices.
    torch.cuda.manual_seed_all(random_seed)

# Parse command line arguments
def _parse_bool_flag(value):
    """Convert a textual command-line boolean to the canonical string 'True' or 'False'.

    ``type=bool`` is not usable with argparse: ``bool("False")`` is True, so every
    non-empty value would enable the flag. The result is kept as a canonical
    string rather than a bool so that a textual check against 'True' further
    down the script keeps working.

    Raises:
        argparse.ArgumentTypeError: If the value is not a recognised boolean spelling.
    """
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return 'True'
    if text in ('false', 'f', 'no', 'n', '0'):
        return 'False'
    raise argparse.ArgumentTypeError(f'expected a boolean value (True/False), got {value!r}')


parser = argparse.ArgumentParser(description='Process model key and dataset.')
parser.add_argument('--model_key', type=str, default='llama3-8b', help='Model key to use for loading models and directories')
parser.add_argument('--dataset', type=str, default='animal', help='Dataset to use for loading data and directories')
parser.add_argument('--embedding_method', type=str, default='lda', choices=['lda', 'random', 'mean'], help='Method to estimate attribute embeddings: lda, random or mean')
# The string default is passed through _parse_bool_flag by argparse, like any user-supplied value.
parser.add_argument('--training', type=_parse_bool_flag, default='True', help='If True, train the model. If False, predict subsumption directly.')
# parser.add_argument('--scoring_method', type=str, default='v1', choices=['v1', 'v2'], help='Scoring method to use: v1 or v2')
args = parser.parse_args()

# All inputs live under ./datasets/<dataset>/; embeddings are additionally
# grouped per model key.
dataset_dir = f'./datasets/{args.dataset}'
embeddings_dir = f'{dataset_dir}/{args.model_key}/embeddings'

# Hypernym pairs for the chosen dataset
hypernyms_file = f'{dataset_dir}/hypernyms.txt'

# Pickled embeddings and thresholds for the chosen model key and embedding method
object_embeddings_file = f'{embeddings_dir}/object_embeddings.pkl'
attribute_embeddings_file = f'{embeddings_dir}/attribute_embeddings_train_{args.embedding_method}.pkl'
attribute_thresholds_file = f'{embeddings_dir}/attribute_thresholds_train_{args.embedding_method}.pkl'
# Read the hypernym pairs: one comma-separated pair per line.
# Lines that do not split into exactly two fields are ignored.
with open(hypernyms_file, 'r') as f:
    split_lines = (raw_line.strip().split(',') for raw_line in f)
    hypernyms = [
        (fields[0].strip(), fields[1].strip())
        for fields in split_lines
        if len(fields) == 2
    ]

# Load embeddings from files
def _read_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    NOTE: unpickling can execute arbitrary code; only point this at trusted files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


object_embeddings = _read_pickle(object_embeddings_file)
attribute_embeddings = _read_pickle(attribute_embeddings_file)
attribute_thresholds = _read_pickle(attribute_thresholds_file)

# Keep only the hypernym pairs whose two objects both have an embedding
embedded_objects = set(object_embeddings)
filtered_hypernyms = [pair for pair in hypernyms if embedded_objects.issuperset(pair)]

# Report how many pairs survived the filter
print(f"Total pairs after filtering: {len(filtered_hypernyms)}")

# Hold out a fixed fraction of the pairs for testing
test_split = 0.2
train_hypernyms, test_hypernyms = train_test_split(
    filtered_hypernyms,
    test_size=test_split,
    random_state=42,
)

# Report the size of each split
print(f"Training pairs: {len(train_hypernyms)}")
print(f"Testing pairs: {len(test_hypernyms)}")

# Double the samples by adding negative pairs
def augment_with_negatives(hypernyms, objects):
    """Interleave each positive pair with one randomly corrupted negative pair.

    For every ``(a, b)`` in ``hypernyms`` the result contains ``(a, b, 1)``
    followed by a negative ``(a', b', 0)`` obtained by replacing either ``a`` or
    ``b`` (chosen with equal probability) by a random element of ``objects``.
    A corrupted pair that equals one of the given positive pairs is rejected and
    redrawn, so a known positive is never emitted with label 0.

    Args:
        hypernyms: Iterable of positive ``(a, b)`` pairs.
        objects: Non-empty sequence of object names to draw replacements from.

    Returns:
        list: ``(a, b, label)`` triples. If no valid negative is found for a pair
        within the retry budget, that pair contributes only its positive triple.

    Raises:
        IndexError: If ``objects`` is empty while ``hypernyms`` is not
            (raised by ``random.choice``).
    """
    positives = {(a, b) for a, b in hypernyms}
    # Bound the rejection sampling so degenerate inputs cannot loop forever.
    max_attempts = 100

    augmented = []
    for a, b in hypernyms:
        augmented.append((a, b, 1))  # Positive pair
        for _ in range(max_attempts):
            if random.random() > 0.5:
                candidate = (random.choice(objects), b)
            else:
                candidate = (a, random.choice(objects))
            if candidate not in positives:
                augmented.append((candidate[0], candidate[1], 0))  # Negative pair
                break
    return augmented

# Every object that occurs in a retained pair. The names are sorted because the
# iteration order of a set of strings changes between interpreter runs (hash
# randomisation); an unsorted list would make random.choice(objects) draw
# different negatives on each run even though the random seed is fixed.
objects = sorted({name for pair in filtered_hypernyms for name in pair})
train_data = augment_with_negatives(train_hypernyms, objects)
test_data = augment_with_negatives(test_hypernyms, objects)

# Construct embedding context from object and attribute embeddings

# Define the path to the embedding context file
embedding_context_path = f'./datasets/{args.dataset}/{args.model_key}/results/embedding_context_{args.embedding_method}.csv'

try:
    # Reuse a previously computed context when one is cached on disk.
    embedding_context = pd.read_csv(embedding_context_path, sep='\t', index_col=0)
    print("Loaded embedding_context from file.")
except FileNotFoundError:
    # No cache yet. Each cell is the length of the object's projection onto the
    # attribute vector, minus that attribute's threshold.
    # NOTE(review): the threshold is subtracted here and again by the
    # normalize_embedding_context call that follows this block - confirm that
    # the double shift is intended.
    object_names = list(object_embeddings.keys())
    attribute_names = list(attribute_embeddings.keys())

    # Flatten and measure every attribute vector once instead of once per object.
    attribute_specs = []
    for attribute_name in attribute_names:
        attribute_vector = attribute_embeddings[attribute_name].flatten()
        attribute_specs.append(
            (attribute_vector, np.linalg.norm(attribute_vector), attribute_thresholds[attribute_name])
        )

    # Fill a plain array and wrap it afterwards; per-cell DataFrame.loc writes are slow.
    context_values = np.zeros((len(object_names), len(attribute_names)), dtype=float)
    for row, object_name in enumerate(object_names):
        object_vector = object_embeddings[object_name].flatten()
        for col, (attribute_vector, attribute_norm, threshold) in enumerate(attribute_specs):
            if attribute_norm > 0:  # Avoid division by zero
                projection_length = np.dot(object_vector, attribute_vector) / attribute_norm
            else:
                projection_length = 0.0
            context_values[row, col] = projection_length - threshold

    embedding_context = pd.DataFrame(context_values, index=object_names, columns=attribute_names)

    # Save the calculated embedding_context; create the results directory on first use.
    os.makedirs(os.path.dirname(embedding_context_path), exist_ok=True)
    embedding_context.to_csv(embedding_context_path, sep='\t')
    print("Calculated and saved embedding_context to file.")

# Normalize the embedding context based on the thresholds
def normalize_embedding_context(embedding_context, thresholds):
    """
    Return a copy of the context with each attribute column shifted down by its threshold.

    Thresholds whose attribute is not a column of the context are ignored, and
    columns without a threshold are left unchanged. The input frame is not modified.

    Args:
        embedding_context (pd.DataFrame): One row per concept, one column per attribute.
        thresholds (dict): Maps attribute name to the value subtracted from its column.

    Returns:
        pd.DataFrame: The shifted copy.
    """
    shifted = embedding_context.copy()
    applicable = [
        (name, value) for name, value in thresholds.items() if name in shifted.columns
    ]
    for name, value in applicable:
        shifted[name] = shifted[name] - value
    return shifted

# Shift every attribute column of the context by its threshold
embedding_context = normalize_embedding_context(
    embedding_context=embedding_context,
    thresholds=attribute_thresholds,
)

# Define the MLP model
class MLP(nn.Module):
    """Single-hidden-layer perceptron that scores a pair of feature vectors.

    The network expects the two feature vectors concatenated, i.e. an input of
    width ``2 * input_size``, and outputs one sigmoid-activated value in (0, 1).

    Args:
        input_size (int): Width of ONE feature vector of the pair.
        hidden_size (int): Number of hidden units. Defaults to 1000.
    """

    def __init__(self, input_size, hidden_size=1000):
        super().__init__()
        # Attribute names are part of the saved state dict; keep them stable.
        self.hidden = nn.Linear(input_size * 2, hidden_size)
        self.output = nn.Linear(hidden_size, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return sigmoid(output(relu(hidden(x)))) for the input ``x``."""
        hidden_activation = self.relu(self.hidden(x))
        return self.sigmoid(self.output(hidden_activation))

# Prepare feature vectors
def get_feature_vector(object, embedding_context):
    """Return the row of ``embedding_context`` labelled ``object`` as an array of values."""
    row = embedding_context.loc[object]
    return row.values

# Where the trained MLP weights are stored; the file name encodes the test
# split fraction and the embedding method.
models_dir = f'./datasets/{args.dataset}/{args.model_key}/models'
model_path = f'{models_dir}/mlp_model_{test_split}_{args.embedding_method}.pth'


def _inclusion_score(projection_child, projection_parent):
    """Return the mean sigmoid of the child's values over the parent's positive attributes.

    Attributes where the parent's value is > 0 are treated as required. When the
    parent has no positive attribute the score is 1.0.
    """
    required = projection_parent > 0
    if not np.any(required):
        return 1.0
    return np.mean(1 / (1 + np.exp(-projection_child[required])))


def evaluate_inclusion(hypernyms, embedding_context, score_threshold=0.60):
    """
    Evaluate subclass prediction from the sign pattern of the embedding context.

    Every true pair (A, B) whose concepts are both rows of the context is scored
    as a positive example. For every pair, one negative example (A, C) is added,
    where C is drawn at random from the context rows, C != A, and (A, C) is not a
    listed hypernym pair. A pair is predicted "subclass" when its soft score is
    strictly greater than ``score_threshold`` (the same rule for both classes).

    Args:
        hypernyms (list): List of true hypernym pairs (A, B).
        embedding_context (pd.DataFrame): One row per concept, one column per attribute.
        score_threshold (float): Decision threshold on the soft score. Defaults to 0.60.

    Returns:
        tuple: (precision, recall, f1, 0), all macro-averaged. The last element
        is a placeholder kept for callers that unpack four values.
    """
    y_true = []
    y_pred = []

    # Map each concept to its row of the context.
    concept_to_projection = dict(zip(embedding_context.index, embedding_context.values))

    # Positive examples: the listed hypernym pairs.
    for child, parent in hypernyms:
        if child in concept_to_projection and parent in concept_to_projection:
            soft_score = _inclusion_score(concept_to_projection[child], concept_to_projection[parent])
            y_true.append(1)
            y_pred.append(1 if soft_score > score_threshold else 0)

    # Negative examples: (A, C) where C is not a listed superclass of A.
    # A set makes the "is this a known pair" test O(1) instead of a list scan.
    known_pairs = {(a, b) for a, b in hypernyms}
    all_concepts = list(embedding_context.index)
    for child, _ in hypernyms:
        candidates = [
            concept for concept in all_concepts
            if concept != child and (child, concept) not in known_pairs
        ]
        if not candidates:
            # Nothing to sample from (every other concept is a known superclass).
            continue
        sampled = random.choice(candidates)
        if child in concept_to_projection and sampled in concept_to_projection:
            soft_score = _inclusion_score(concept_to_projection[child], concept_to_projection[sampled])
            y_true.append(0)
            y_pred.append(1 if soft_score > score_threshold else 0)

    # Calculate precision, recall, and F1 score
    precision = precision_score(y_true, y_pred, average='macro')
    recall = recall_score(y_true, y_pred, average='macro')
    f1 = f1_score(y_true, y_pred, average='macro')
    return precision, recall, f1, 0



# The flag may arrive as a real bool or as text such as "True"/"False", so
# compare its normalised textual form. Comparing a bool with the string
# "True" is always False and would make the training branch unreachable.
training_enabled = str(args.training).strip().lower() in ('true', 't', 'yes', 'y', '1')

if training_enabled:
    model = MLP(len(embedding_context.columns))

    if os.path.exists(model_path):
        # Reuse previously trained weights instead of training again.
        print("Loading existing model...")
        model.load_state_dict(torch.load(model_path))
    else:
        # Train the model
        criterion = nn.BCELoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        # Training loop: one optimiser step per (a, b, label) sample.
        for epoch in range(10):  # Number of epochs
            model.train()
            total_loss = 0
            for a, b, label in train_data:
                a_vector = get_feature_vector(a, embedding_context)
                b_vector = get_feature_vector(b, embedding_context)
                input_vector = torch.tensor(np.concatenate([a_vector, b_vector]), dtype=torch.float32)
                label_tensor = torch.tensor([label], dtype=torch.float32)

                optimizer.zero_grad()
                output = model(input_vector)
                loss = criterion(output, label_tensor)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
            # max(..., 1) avoids a ZeroDivisionError when there is no training data.
            print(f'Epoch {epoch+1}, Loss: {total_loss/max(len(train_data), 1)}')

        # Save the model
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        torch.save(model.state_dict(), model_path)
        print("Model saved.")

    # Evaluate the model
    def evaluate(data):
        """Return (f1, precision, recall) of the model on ``data``.

        ``data`` is an iterable of (a, b, label) triples; a pair is predicted
        positive when the model output exceeds 0.5.
        """
        model.eval()
        y_true = []
        y_pred = []
        with torch.no_grad():
            for a, b, label in data:
                a_vector = get_feature_vector(a, embedding_context)
                b_vector = get_feature_vector(b, embedding_context)
                input_vector = torch.tensor(np.concatenate([a_vector, b_vector]), dtype=torch.float32)
                output = model(input_vector)
                y_true.append(label)
                y_pred.append(1 if output.item() > 0.5 else 0)
        return f1_score(y_true, y_pred), precision_score(y_true, y_pred), recall_score(y_true, y_pred)

    f1, precision, recall = evaluate(test_data)
    print(f'Test Precision: {precision:.3f}')
    print(f'Test Recall: {recall:.3f}')
    print(f'Test F1 Score: {f1:.3f}')
else:
    # Prediction mode: score subclass relationships directly from the embedding
    # context, without the MLP.
    print("Evaluating subclass relationships with ranking metrics (batch computation)...")

    # Precision, recall and F1; the fourth returned value is not reported.
    precision, recall, f1, _ = evaluate_inclusion(filtered_hypernyms, embedding_context)

    print(f"Precision: {precision:.3f}")
    print(f"Recall: {recall:.3f}")
    print(f"F1 Score: {f1:.3f}")



