# %%

# %%
import argparse
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModel
import torch
from scipy.spatial.distance import cosine
from scipy.stats import pearsonr
import os
import pickle
import matplotlib.pyplot as plt
from sklearn.manifold import MDS 
from sklearn.covariance import ledoit_wolf
import seaborn as sns
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
# Command-line interface: which model, which dataset, and how attribute
# directions are estimated.
# NOTE: `--embedding_method` accepts five choices, although its help text
# names only three of them.
parser = argparse.ArgumentParser(
    description='Process model key, dataset, and embedding method.',
)
parser.add_argument(
    '--model_key',
    type=str,
    default='llama3-8b',
    help='Model key to use for loading models and directories',
)
parser.add_argument(
    '--dataset',
    type=str,
    default='animal',
    help='Dataset to use for loading data and directories',
)
parser.add_argument(
    '--embedding_method',
    type=str,
    default='mean',
    choices=['lda', 'svm', 'logistic', 'mean', 'random'],
    help='Method to estimate attribute embeddings: lda, mean, or random',
)
args = parser.parse_args()

# Everything for one (dataset, model) pair lives under
# ./datasets/<dataset>/<model_key>/.
dataset_dir = f'./datasets/{args.dataset}'
embeddings_dir = f'{dataset_dir}/{args.model_key}/embeddings'
results_dir = f'{dataset_dir}/{args.model_key}/results'
visualization_dir = f'{results_dir}/visualization'

# Create the output folders up front; an already-existing folder is not an error.
for output_dir in (embeddings_dir, results_dir, visualization_dir):
    os.makedirs(output_dir, exist_ok=True)

# Load the object-attribute relationships. The file is read as tab-separated
# text and is expected to provide 'object', 'attribute' and 'relation' columns.
relationship_path = f'{dataset_dir}/object_attribute_relationship.txt'
data = pd.read_csv(relationship_path, sep='\t', engine='python')

# Distinct objects / attributes, in order of first appearance in the file.
objects = data['object'].unique()
attributes = data['attribute'].unique()
relation_column = 'relation'  # column holding the relation value (0/1)
print(f"Number of objects: {len(objects)}")
print(f"Number of attributes: {len(attributes)}")

# Step 1: Build the formal context, an objects x attributes matrix that starts
# at 0 and is filled cell by cell from the relation column.
formal_context = pd.DataFrame(0, index=objects, columns=attributes)
for _, record in data.iterrows():
    row_label = record['object']
    column_label = record['attribute']
    formal_context.loc[row_label, column_label] = record[relation_column]

# Step 2: Resolve the model to use. Only the name and device are fixed here;
# the weights themselves are loaded later, on demand, by load_llm().
MODEL_MAP = {
    "llama3-8b": "meta-llama/Llama-3.1-8B",
    "mistral7b": "mistralai/Mistral-7B-v0.1",
    "gemma7b": "google/gemma-7b"
}

# Reject unknown keys with a message that lists the supported ones.
if not (args.model_key in MODEL_MAP):
    raise ValueError(f"Invalid model_key: {args.model_key}. Valid options are: {', '.join(MODEL_MAP.keys())}")

# Hugging Face model identifier for the selected key.
MODEL_NAME = MODEL_MAP[args.model_key]

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

def load_llm():
    """Load the tokenizer and model for ``MODEL_NAME`` into module globals.

    Sets the module-level names ``tokenizer`` and ``model``. The model is
    moved to ``device`` and switched to evaluation mode.
    """
    global tokenizer, model
    print(f"Loading model: {MODEL_NAME}...")

    loaded_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Padding is requested when tokenizing, so a pad token must exist;
    # reuse the end-of-sequence token when the tokenizer defines none.
    if loaded_tokenizer.pad_token is None:
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token
    tokenizer = loaded_tokenizer

    loaded_model = AutoModel.from_pretrained(MODEL_NAME)
    model = loaded_model.to(device)
    model.eval()

def get_embedding(texts):
    """Embed ``texts`` as the mean of the model's last-layer hidden states.

    Relies on the module-level ``tokenizer`` and ``model`` set by
    ``load_llm()``.

    Args:
        texts: Input forwarded to the tokenizer (this script passes one string
            per call). Inputs longer than 512 tokens are truncated.

    Returns:
        A tensor on ``device`` holding, for each input sequence, the average
        of its last-layer token vectors (presumably shaped
        ``(batch, hidden_size)`` - confirm against the model in use).

    Padding positions are excluded from both the forward pass (via the
    attention mask) and the average, so an embedding does not depend on what
    else is in the batch. For a single unpadded input the mask is all ones,
    i.e. a plain mean over its tokens.
    """
    tokenized_output = tokenizer(
        texts, return_tensors="pt", padding=True, truncation=True, max_length=512
    ).to(device)
    input_ids = tokenized_output["input_ids"]
    attention_mask = tokenized_output["attention_mask"]

    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
    last_hidden_states = outputs.hidden_states[-1]

    # Masked mean: zero out padded positions, then divide by the number of
    # real tokens (clamped so an all-padding row cannot divide by zero).
    token_mask = attention_mask.unsqueeze(-1).to(last_hidden_states.dtype)
    token_counts = token_mask.sum(dim=1).clamp(min=1)
    return (last_hidden_states * token_mask).sum(dim=1) / token_counts

# Load synonyms from file. Each usable line has the form
#   <object>: <synonym>, <synonym>, ...
# Lines that do not contain exactly one ':' are ignored.
synonym_file = f'{dataset_dir}/synonyms.txt'
synonyms = {}
with open(synonym_file, 'r') as f:
    for line in f:
        parts = line.strip().split(':')
        if len(parts) != 2:
            continue
        # Named `object_name` (not `object`) so the builtin is not rebound here.
        object_name = parts[0].strip()
        # Drop empty entries (e.g. from a trailing comma) so that an empty
        # string is never treated as a synonym.
        synonym_list = [syn.strip() for syn in parts[1].split(',') if syn.strip()]
        if synonym_list:
            # Objects whose line lists no usable synonym get no entry at all.
            synonyms[object_name] = synonym_list

# Step 3: Build one embedding per object, averaging over its synonyms when a
# synonym entry exists. The result is cached on disk, so the LLM is only
# loaded when the cache file is missing.
embeddings_file = f'{embeddings_dir}/object_embeddings.pkl'

if not os.path.exists(embeddings_file):
    print("Generating object embeddings...")
    load_llm()  # the model is only needed on a cache miss
    object_embeddings = {}
    for object_name in objects:
        if object_name not in synonyms:
            # No synonym entry: embed the object's own name.
            object_embeddings[object_name] = get_embedding(object_name).cpu().numpy().flatten()
            continue
        synonym_vectors = [
            get_embedding(surface_form).cpu().numpy().flatten()
            for surface_form in synonyms[object_name]
        ]
        # Element-wise average across the synonym vectors.
        object_embeddings[object_name] = np.mean(synonym_vectors, axis=0)
    with open(embeddings_file, 'wb') as cache_file:
        pickle.dump(object_embeddings, cache_file)
else:
    print("Loading object embeddings from file...")
    with open(embeddings_file, 'rb') as cache_file:
        object_embeddings = pickle.load(cache_file)


# %%

# Split objects into train and test sets (90% train, 10% test)
train_objects, test_objects = train_test_split(objects, test_size=0.1, random_state=42)

# Step 4: Estimate one direction vector and one decision threshold per
# attribute, using training objects only. Results are cached on disk per
# embedding method.
attribute_embeddings_file = f'{embeddings_dir}/attribute_embeddings_train_{args.embedding_method}.pkl'
attribute_thresholds_file = f'{embeddings_dir}/attribute_thresholds_train_{args.embedding_method}.pkl'

if os.path.exists(attribute_embeddings_file) and os.path.exists(attribute_thresholds_file):
    print("Loading attribute embeddings and thresholds from file...")
    # NOTE: unpickling can execute arbitrary code; only load cache files that
    # were written by this script.
    with open(attribute_embeddings_file, 'rb') as f:
        attribute_embeddings = pickle.load(f)
    with open(attribute_thresholds_file, 'rb') as f:
        attribute_thresholds = pickle.load(f)
else:
    print("Generating attribute embeddings and thresholds using training objects...")
    attribute_embeddings = {}
    attribute_thresholds = {}
    for attribute in attributes:
        positive_objects = [obj for obj in train_objects if formal_context.loc[obj, attribute] == 1]
        negative_objects = [obj for obj in train_objects if formal_context.loc[obj, attribute] == 0]

        if not (positive_objects and negative_objects):
            # A direction needs examples of both classes; such attributes get
            # no entry in either dictionary.
            print(f"Skipping attribute {attribute} due to insufficient positive/negative objects.")
            continue

        # One row per training object.
        positive_embeddings = np.array([object_embeddings[obj].flatten() for obj in positive_objects])
        negative_embeddings = np.array([object_embeddings[obj].flatten() for obj in negative_objects])

        if args.embedding_method == 'lda':
            # Class covariances estimated with Ledoit-Wolf shrinkage.
            positive_cov, _ = ledoit_wolf(positive_embeddings)
            negative_cov, _ = ledoit_wolf(negative_embeddings)
            positive_mean = positive_embeddings.mean(axis=0)
            negative_mean = negative_embeddings.mean(axis=0)

            # Solve (cov_pos + cov_neg + reg * I) w = mean_pos - mean_neg.
            # The small ridge term keeps the system non-singular, and solving
            # it directly avoids forming an explicit matrix inverse.
            reg_param = 1e-5
            pooled_cov = positive_cov + negative_cov + reg_param * np.eye(positive_cov.shape[0])
            attribute_direction = np.linalg.solve(pooled_cov, positive_mean - negative_mean)
        elif args.embedding_method == 'random':
            # Random baseline direction (drawn from the unseeded global NumPy RNG).
            attribute_direction = np.random.randn(object_embeddings[next(iter(object_embeddings))].shape[0])
        elif args.embedding_method == 'mean':
            # Difference between the class means.
            positive_mean = positive_embeddings.mean(axis=0)
            negative_mean = negative_embeddings.mean(axis=0)
            attribute_direction = positive_mean - negative_mean
        elif args.embedding_method in ('svm', 'logistic'):
            # Weight vector of a linear classifier separating positives
            # (label 1) from negatives (label 0); it points towards the
            # positive class.
            features = np.vstack([positive_embeddings, negative_embeddings])
            labels = np.concatenate([
                np.ones(len(positive_embeddings), dtype=int),
                np.zeros(len(negative_embeddings), dtype=int),
            ])
            if args.embedding_method == 'svm':
                classifier = LinearSVC(max_iter=10000)
            else:
                classifier = LogisticRegression(max_iter=1000)
            classifier.fit(features, labels)
            attribute_direction = classifier.coef_[0]
        else:
            raise ValueError(f"Unsupported embedding_method: {args.embedding_method}")

        # Store the attribute direction as a 1D array.
        attribute_direction = np.array(attribute_direction).flatten()
        attribute_embeddings[attribute] = attribute_direction

        # Threshold: midpoint between the mean projection of the positives and
        # the mean projection of the negatives onto the normalised direction.
        direction_norm = np.linalg.norm(attribute_direction)
        positive_projection = positive_embeddings.dot(attribute_direction) / direction_norm
        negative_projection = negative_embeddings.dot(attribute_direction) / direction_norm
        attribute_thresholds[attribute] = (positive_projection.mean() + negative_projection.mean()) / 2

    with open(attribute_embeddings_file, 'wb') as f:
        pickle.dump(attribute_embeddings, f)
    with open(attribute_thresholds_file, 'wb') as f:
        pickle.dump(attribute_thresholds, f)

# Steps 5 and 10: project every object onto every attribute direction
# (embedding_context), then threshold the test objects' projections to
# predict their attributes (test_classification_results).


def _unit_direction(attribute, dim):
    """Return the unit-length direction stored for ``attribute``.

    Falls back to a zero vector of length ``dim`` - so that every projection
    onto it is 0.0 - when the attribute has no stored direction (it was
    skipped during estimation) or its direction has no positive norm.
    """
    direction = attribute_embeddings.get(attribute)
    if direction is None:
        return np.zeros(dim)
    direction = np.asarray(direction, dtype=float).flatten()
    norm = np.linalg.norm(direction)
    if not norm > 0:  # zero-length (or NaN) direction: avoid dividing by it
        return np.zeros(dim)
    return direction / norm


# One row per object, one column per attribute direction.
object_matrix = np.stack(
    [np.asarray(object_embeddings[name], dtype=float).flatten() for name in objects]
)
direction_matrix = np.stack(
    [_unit_direction(attribute, object_matrix.shape[1]) for attribute in attributes],
    axis=1,
)

# Step 5: signed projection length of each object on each attribute direction.
embedding_context = pd.DataFrame(
    object_matrix @ direction_matrix, index=objects, columns=attributes
)

# Step 10: Inference - classify test objects' attributes.
test_projections = embedding_context.loc[test_objects]
test_classification_results = pd.DataFrame(0, index=test_objects, columns=attributes)
for attribute in attributes:
    threshold = attribute_thresholds.get(attribute)
    if attribute in attribute_embeddings and threshold is not None:
        # Positive when the projection reaches the learned threshold.
        test_classification_results[attribute] = (
            test_projections[attribute] >= threshold
        ).astype(int)
    else:
        # Nothing was learned for this attribute: predict the label that is
        # most frequent among the training objects.
        majority_label = int(formal_context.loc[train_objects, attribute].mean() >= 0.5)
        print(f"No learned direction for attribute {attribute}; predicting training majority label {majority_label}.")
        test_classification_results[attribute] = majority_label

# Save test classification results to a file
# test_classification_results.to_csv(f'./datasets/{args.dataset}/{args.model_key}/results/test_classification_results_{args.embedding_method}.csv', sep='\t')
# Score the thresholded test predictions (macro-averaged precision, recall and
# F1 over all flattened object/attribute cells) and measure how well the raw
# projection lengths correlate with the ground-truth context over all objects.
test_formal_context = formal_context.loc[test_objects]
test_formal_context_flat = test_formal_context.to_numpy().ravel()
test_classification_results_flat = test_classification_results.to_numpy().ravel()

formal_context_flat = formal_context.to_numpy().ravel()
embedding_context_flat = embedding_context.to_numpy().ravel()
# pearsonr returns (statistic, p-value); only the statistic is reported.
correlation = pearsonr(formal_context_flat, embedding_context_flat)[0]

macro_average = {'average': 'macro'}
f1 = f1_score(test_formal_context_flat, test_classification_results_flat, **macro_average)
precision = precision_score(test_formal_context_flat, test_classification_results_flat, **macro_average)
recall = recall_score(test_formal_context_flat, test_classification_results_flat, **macro_average)

print(f"Test Classification Precision: {precision:.3f}")
print(f"Test Classification Recall: {recall:.3f}")
print(f"Test Classification F1 Score: {f1:.3f}")
print(f"Pearson correlation coefficient: {correlation}")

# # Evaluate ranking performance using filtered MRR
# def calculate_filtered_mrr(objects, attributes, formal_context, embedding_context, train_objects):
#     reciprocal_ranks = []
#     for object in objects:
#         # Get true positive attributes for the object
#         true_positive_attributes = set(formal_context.loc[object][formal_context.loc[object] == 1].index)
        
#         # Filter out attributes seen in the training graph
#         train_positive_attributes = set(
#             formal_context.loc[train_objects][formal_context.loc[train_objects] == 1].columns
#         )
#         filtered_attributes = true_positive_attributes - train_positive_attributes
        
#         # Calculate projection lengths for all attributes
#         projection_lengths = embedding_context.loc[object]
        
#         # Rank attributes by projection length in descending order
#         ranked_attributes = projection_lengths.sort_values(ascending=False).index
        
#         # Find the rank of the first true positive attribute in the filtered set
#         for rank, attribute in enumerate(ranked_attributes, start=1):
#             if attribute in filtered_attributes:
#                 reciprocal_ranks.append(1 / rank)
#                 break
#         else:
#             reciprocal_ranks.append(0)  # No true positive attribute found in the filtered set
    
#     # Calculate the mean reciprocal rank
#     return np.mean(reciprocal_ranks)



# Step 7: Calculate attribute frequency
# attribute_frequency = formal_context.sum(axis=0)

# # Save attribute frequency to a file
# with open(f'{results_dir}/attribute_frequency.txt', 'w') as f:
#     for attribute, frequency in attribute_frequency.items():
#         f.write(f"{attribute}\t{frequency}\n")

# def calculate_regularized_threshold(mu0, sigma0, mu1, sigma1, variance_ratio_threshold=1.2, epsilon=1e-6):
#     # Midpoint between the means
#     midpoint = (mu0 + mu1) / 2
#     return midpoint

# # Step 8: Visualize projection length distribution for attributes using a distribution plot
# for attribute in attributes:
#     positive_objects = [object for object in objects if formal_context.loc[object, attribute] == 1]
#     negative_objects = [object for object in objects if formal_context.loc[object, attribute] == 0]

#     projection_lengths_positive = [
#         embedding_context.loc[object, attribute] for object in positive_objects
#     ]
#     projection_lengths_negative = [
#         embedding_context.loc[object, attribute] for object in negative_objects
#     ]

#     # Retrieve the attribute direction as a 1D array
#     attribute_direction = np.array(attribute_embeddings[attribute]).flatten()

#     # Project positive and negative samples onto the attribute direction (no normalization)
#     positive_embeddings = np.array([object_embeddings[object].flatten() for object in positive_objects])
#     negative_embeddings = np.array([object_embeddings[object].flatten() for object in negative_objects])
#     positive_projection = positive_embeddings.dot(attribute_direction)/ np.linalg.norm(attribute_direction)
#     negative_projection = negative_embeddings.dot(attribute_direction)/ np.linalg.norm(attribute_direction)

#     # Compute means and variances of projections
#     positive_mean_proj = positive_projection.mean()
#     negative_mean_proj = negative_projection.mean()
#     positive_var_proj = positive_projection.var()
#     negative_var_proj = negative_projection.var()

#     # Debug: Check means and variances
#     # print(f"Attribute: {attribute}")
#     # print(f"Positive Mean Projection: {positive_mean_proj}")
#     # print(f"Negative Mean Projection: {negative_mean_proj}")
#     # print(f"Positive Variance Projection: {positive_var_proj}")
#     # print(f"Negative Variance Projection: {negative_var_proj}")

#     # Compute optimal threshold (without class priors)
#     optimal_threshold = calculate_regularized_threshold(
#         negative_mean_proj, negative_var_proj, positive_mean_proj, positive_var_proj
#     )

#     # Plot the distribution
#     plt.figure(figsize=(6, 5))
#     sns.kdeplot(projection_lengths_positive, color='blue', fill=True, alpha=0.5, label='Positive')
#     sns.kdeplot(projection_lengths_negative, color='red', fill=True, alpha=0.5, label='Negative')
#     plt.axvline(optimal_threshold, color='green', linestyle='--', linewidth=1.5, label='Threshold')
#     plt.title(f"{attribute}", fontsize=22)
#     plt.xlabel("Projection", fontsize=20)
#     plt.ylabel("Density", fontsize=20)
#     # plt.xticks(fontsize=14)
#     # plt.yticks(fontsize=14)
#     if attribute == "Active at night":  # Add legend only for the first attribute
#         plt.legend(fontsize=20)
#     plt.tight_layout()
#     plt.savefig(f'{visualization_dir}/projection_{attribute.replace(" ", "_")}.pdf', dpi=300)
#     plt.close()

# # Step 9: Visualize attribute embeddings as directions

# Stack the stored attribute directions into a matrix (one row per attribute),
# keeping the names in the same order as the rows.
attribute_names = list(attribute_embeddings.keys())
attribute_embeddings_matrix = np.array(
    [attribute_embeddings[name].flatten() for name in attribute_names]
)
if not (attribute_embeddings_matrix.shape[0] >= 2 and attribute_embeddings_matrix.shape[1] >= 2):
    raise ValueError("PCA requires at least 2 attributes with valid embeddings for dimensionality reduction.")

# Reduce the directions to two principal components.
pca = PCA(n_components=2)
reduced_embeddings = pca.fit_transform(attribute_embeddings_matrix)

# Draw each reduced direction as an arrow from the origin, labelled slightly
# beyond its tip.
figure, axes = plt.subplots(figsize=(12, 8))
for attribute, (tip_x, tip_y) in zip(attribute_names, reduced_embeddings):
    axes.arrow(0, 0, tip_x, tip_y,
               head_width=0.05, head_length=0.1, fc='green', ec='green', alpha=0.7)
    axes.text(tip_x * 1.1, tip_y * 1.1, attribute, fontsize=9)

axes.set_xlabel("")
axes.set_ylabel("")
axes.grid(True)
# Thin reference lines through the origin.
axes.axhline(0, color='black', linewidth=0.5)
axes.axvline(0, color='black', linewidth=0.5)
figure.savefig(f'{results_dir}/attribute_embeddings_visualization.pdf')
plt.close(figure)


# # Step 11: Visualize the correlation matrix of attributes using a heatmap

# # Calculate the correlation matrix for attribute embeddings
# attribute_embeddings_matrix = np.array([embedding.flatten() for embedding in attribute_embeddings.values()])

# # Check for constant or NaN rows and remove them
# valid_indices = ~np.isnan(attribute_embeddings_matrix).any(axis=1) & (attribute_embeddings_matrix.var(axis=1) > 0)
# attribute_embeddings_matrix = attribute_embeddings_matrix[valid_indices]
# valid_attribute_names = [attribute for i, attribute in enumerate(attributes) if valid_indices[i]]

# if attribute_embeddings_matrix.shape[0] < 2:
#     raise ValueError("Not enough valid attributes to compute the correlation matrix.")

# # Compute the correlation matrix
# correlation_matrix = np.corrcoef(attribute_embeddings_matrix)

# # Plot the heatmap
# plt.figure(figsize=(25, 25))
# sns.heatmap(correlation_matrix, xticklabels=valid_attribute_names, yticklabels=valid_attribute_names, cmap='coolwarm', annot=False)
# plt.title("Correlation Matrix of Attributes")
# plt.xlabel("Attributes")
# plt.ylabel("Attributes")
# plt.xticks(rotation=90, fontsize=8)
# plt.yticks(fontsize=8)
# plt.tight_layout()
# plt.savefig('./datasets/object_attributes/results/attribute_correlation_heatmap.png')
# plt.close()
