import transformers
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns  # Add seaborn import for heatmap
from collections import defaultdict
from sklearn.manifold import MDS  # Replace PCA import with MDS
from sklearn.model_selection import train_test_split  # Add train_test_split import

# Load a LLaMA-family causal LM and its tokenizer from the Hugging Face Hub.
print("Loading LLaMA model using transformers pipeline...")
MODEL_NAME = "meta-llama/Llama-2-7b-hf"  # Replace with the actual model name on Hugging Face

# Pick the best available accelerator and fall back to CPU instead of assuming
# MPS exists: torch.device("mps") on a machine without MPS fails on first use.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

tokenizer = transformers.LlamaTokenizer.from_pretrained(MODEL_NAME, legacy=False)
# LLaMA tokenizers ship without a pad token; reuse EOS so batched padding works.
tokenizer.pad_token = tokenizer.eos_token
# device_map="auto" lets accelerate place (and, if needed, offload to disk) the weights.
model = transformers.LlamaForCausalLM.from_pretrained(MODEL_NAME, device_map="auto", offload_folder="offload")
print("LLaMA model loaded successfully.")

# Function to get embeddings of a list of texts
def get_embeddings(texts, layer=-1):
    """Return one mean-pooled hidden-state vector per input text.

    The texts are tokenized together as a single padded batch (truncated to
    512 tokens). The attention mask is passed to the model and used for
    pooling, so padding tokens neither influence attention nor the average.

    Args:
        texts: list of strings to embed.
        layer: index into the model's ``hidden_states`` tuple; the default -1
            is the last layer (the original behavior).

    Returns:
        A tensor of shape (len(texts), hidden_size).
    """
    tokenized_output = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
    # Send inputs to the device of the model's first weights (works with device_map="auto").
    tokenized_output = tokenized_output.to(model.device)
    input_ids = tokenized_output["input_ids"]
    attention_mask = tokenized_output["attention_mask"]
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
    hidden = outputs.hidden_states[layer]  # (batch, seq_len, hidden_size)
    # Masked mean over the sequence axis: zero out pad positions, divide by real-token count.
    mask = attention_mask.unsqueeze(-1).to(device=hidden.device, dtype=hidden.dtype)
    summed = (hidden * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp(min=1)
    return summed / counts

# Load the animal-attribute pairs.
# Each line is read as "animal,attribute[,...]"; extra columns are ignored and
# surrounding whitespace is stripped from every field.
animal_attribute_file = './data/animal-habit.txt'
animal_attribute_pairs = defaultdict(list)  # attribute -> list of distinct animals having it
all_animals = set()
all_attributes = set()

with open(animal_attribute_file, 'r', encoding='utf-8') as file:
    for line in file:
        parts = [field.strip() for field in line.split(',')]
        if len(parts) < 2:
            continue
        animal, attribute = parts[:2]
        # Skip blank fields (e.g. "dog," or ",barks") instead of recording empty names.
        if not animal or not attribute:
            continue
        # Avoid duplicates: a repeated line would otherwise let the same animal
        # land in both the training and testing split of this attribute.
        if animal not in animal_attribute_pairs[attribute]:
            animal_attribute_pairs[attribute].append(animal)
        all_animals.add(animal)
        all_attributes.add(attribute)

# Split each attribute's animals into training and testing sets (80/20, reproducible).
train_pairs = defaultdict(list)
test_pairs = defaultdict(list)

for attribute, animals in animal_attribute_pairs.items():
    if len(animals) < 2:
        # train_test_split raises ValueError when the training side would be empty
        # (one sample with test_size=0.2 -> 1 test, 0 train). Keep the lone animal
        # for training so the attribute still gets a direction; it has no test examples.
        train_pairs[attribute].extend(animals)
        continue
    train_animals, test_animals = train_test_split(animals, test_size=0.2, random_state=42)
    train_pairs[attribute].extend(train_animals)
    test_pairs[attribute].extend(test_animals)

# Get embeddings for all animals (one forward pass per animal name).
# get_embeddings returns one pooled row per input text; [0] takes the single row of
# this one-element batch so every animal maps to a 1-D vector. torch.dot, used when
# scoring the test set, only accepts 1-D tensors.
animal_embeddings = {animal: get_embeddings([animal])[0] for animal in all_animals}

# Calculate attribute directions using training pairs.
# direction = mean(embeddings of training animals with the attribute)
#           - mean(embeddings of animals NOT known to have the attribute)
attribute_directions = {}
for attribute, animals_with_attribute in train_pairs.items():
    if not animals_with_attribute:
        continue  # torch.stack cannot average an empty list

    # Exclude every animal known to have the attribute -- including held-out test
    # animals -- from the negative set. Otherwise test positives are averaged into
    # the "without" mean, pushing the direction away from the very animals the
    # evaluation later scores.
    known_positives = set(animal_attribute_pairs[attribute])
    animals_without_attribute = all_animals - known_positives
    if not animals_without_attribute:
        continue

    embeddings_with_attribute = torch.stack([animal_embeddings[animal] for animal in animals_with_attribute]).mean(dim=0)
    # sorted() fixes the summation order, so the float result does not depend on
    # per-run set iteration order.
    embeddings_without_attribute = torch.stack([animal_embeddings[animal] for animal in sorted(animals_without_attribute)]).mean(dim=0)
    attribute_directions[attribute] = embeddings_with_attribute - embeddings_without_attribute

# Keep only the 20 attributes that have the most training animals.
# sorted() stays stable with reverse=True, so attributes tied on count keep
# their current dict order.
ranked_attributes = sorted(
    attribute_directions,
    key=lambda attr_name: len(train_pairs[attr_name]),
    reverse=True,
)
top_attributes = ranked_attributes[:20]
attribute_directions = {attr_name: attribute_directions[attr_name] for attr_name in top_attributes}

# Reduce dimensions using MDS.
# Each direction becomes one flattened row of directions_matrix
# (n_attributes rows, in attribute_directions order).
direction_rows = [vec.cpu().numpy() for vec in attribute_directions.values()]
directions_matrix = np.stack(direction_rows).reshape(len(attribute_directions), -1)

# 2-D MDS embedding of those rows; random_state makes the result reproducible.
# NOTE(review): MDS coordinates are only defined up to rotation, reflection and
# translation, so arrows drawn from (0, 0) show relative layout, not absolute
# direction -- confirm this is the intended reading of the plot.
mds = MDS(n_components=2, random_state=42)
directions_reduced = mds.fit_transform(directions_matrix)

# Attribute name -> 2-D point, ordered alphabetically by name. Names are unique
# dict keys, so sorting the (name, point) pairs never compares the arrays.
attribute_directions_reduced = dict(sorted(zip(attribute_directions, directions_reduced)))

# Plot each attribute direction as an arrow from the origin, one distinct color per
# attribute, with the attribute name written at the arrow tip.
plt.figure(figsize=(10, 8))

# One RGBA color per attribute. Feeding the (n, 2) coordinate array through the
# colormap yields an (n, 2, 4) array, whose rows are not valid matplotlib colors.
colors = plt.cm.viridis(np.linspace(0, 1, len(attribute_directions_reduced)))

# Extent of the data including the origin (every arrow starts at (0, 0)).
extent = np.vstack([directions_reduced, np.zeros((1, 2))])
lo = extent.min(axis=0)
hi = extent.max(axis=0)
# MDS coordinates have no fixed scale, so size arrows and margins relative to the span.
span = max(float((hi - lo).max()), 1e-9)
margin = 0.1 * span

for color, (attribute, direction) in zip(colors, attribute_directions_reduced.items()):
    plt.arrow(
        0, 0, direction[0], direction[1],
        width=0.003 * span,
        head_width=0.03 * span,
        head_length=0.05 * span,
        length_includes_head=True,  # tip lands exactly on the point where the label sits
        color=color,
    )
    plt.text(direction[0], direction[1], attribute, fontsize=12, color=color)

# Adjust plot limits so the origin, all arrow tips and their labels are inside the figure.
plt.xlim(lo[0] - margin, hi[0] + margin)
plt.ylim(lo[1] - margin, hi[1] + margin)

plt.xlabel('MDS Component 1')
plt.ylabel('MDS Component 2')
plt.title('Attribute Directions (MDS Reduced)')
plt.show()

# Pairwise Pearson correlation between attribute directions (rows of directions_matrix),
# drawn as a heatmap.
correlation_matrix = np.corrcoef(directions_matrix)
# Rows of directions_matrix follow the insertion order of attribute_directions,
# so these labels line up with the matrix; pass them as a plain list.
attribute_labels = list(attribute_directions.keys())
plt.figure(figsize=(12, 10))
# Pin the diverging colormap to the full correlation range so that color encodes
# sign and magnitude (otherwise the smallest value is painted as "most negative").
sns.heatmap(
    correlation_matrix,
    annot=False,
    xticklabels=attribute_labels,
    yticklabels=attribute_labels,
    cmap='coolwarm',
    vmin=-1,
    vmax=1,
)
plt.title('Correlation Matrix of Attribute Directions')
plt.show()

# Test the accuracy on the testing set: a held-out animal counts as correct when its
# embedding has a positive projection onto its attribute's direction.
# NOTE(review): only positive test examples are scored, so this is effectively recall,
# and the threshold 0 ignores the embeddings' offset from the origin -- consider also
# scoring negatives and thresholding at the midpoint of the two class means.
correct_predictions = 0
total_predictions = 0

for attribute, animals in test_pairs.items():
    direction = attribute_directions.get(attribute)
    if direction is None:
        continue  # attribute was skipped or not among the selected top attributes
    # torch.dot only accepts 1-D tensors; embeddings computed as get_embeddings([animal])
    # keep a leading batch axis of size 1, so flatten both operands.
    direction = direction.flatten()
    for animal in animals:
        projection = torch.dot(animal_embeddings[animal].flatten(), direction)
        if projection.item() > 0:
            correct_predictions += 1
        total_predictions += 1

accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
print(f"Accuracy on the testing set: {accuracy:.2f}")


