# %%
import transformers
import torch
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.decomposition import PCA  # Add PCA import
import seaborn as sns  # Add seaborn import for heatmap
from sklearn.manifold import MDS  # Replace PCA import with MDS
from sklearn.preprocessing import normalize  # Add normalize import
from sklearn.model_selection import train_test_split  # Add train_test_split import
import random  # Add import for random


# %%

# Load the Llama 2 (7B) tokenizer and causal-LM weights from the Hugging Face Hub.
print("Loading LLaMA model using transformers pipeline...")
MODEL_NAME = "meta-llama/Llama-2-7b-hf"  # Hugging Face Hub model identifier

# Pick the best available accelerator. The CPU fallback matters: requesting
# "mps" on a machine without Apple-silicon support fails as soon as a tensor
# is moved to the device.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

tokenizer = transformers.LlamaTokenizer.from_pretrained(MODEL_NAME, legacy=False)
# The tokenizer is given the EOS token as its padding token so that
# `padding=True` can be used when tokenizing batches.
tokenizer.pad_token = tokenizer.eos_token
model = transformers.LlamaForCausalLM.from_pretrained(MODEL_NAME, device_map="auto", offload_folder="offload")
print("LLaMA model loaded successfully.")

def get_embeddings(texts):
    """Return one mean-pooled embedding per input text.

    The texts are tokenized as a padded batch (truncated to 512 tokens), run
    through the module-level ``model`` without gradient tracking, and the last
    entry of ``outputs.hidden_states`` is averaged over the sequence axis.

    Padding positions are excluded from both the forward pass (via
    ``attention_mask``) and the average, so a text's embedding does not depend
    on what else is in the batch. For a single text there is no padding and
    the result equals a plain mean over all tokens.

    Args:
        texts: A list of strings to embed.

    Returns:
        A tensor with one row per input text (the sequence axis is reduced).
    """
    encoded = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(device)

    with torch.no_grad():
        outputs = model(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            output_hidden_states=True,
        )
    final_hidden = outputs.hidden_states[-1]

    # Broadcastable 0/1 weights; moved to the output's device/dtype because a
    # model loaded with device_map="auto" may emit on a different device.
    weights = encoded["attention_mask"].unsqueeze(-1).to(
        device=final_hidden.device, dtype=final_hidden.dtype
    )
    token_counts = weights.sum(dim=1).clamp(min=1)  # guard against division by zero
    return (final_hidden * weights).sum(dim=1) / token_counts

# Load the animal-attribute pairs.
# Each usable line is "animal,attribute[,...]"; only the first two
# comma-separated fields are read, anything after them is ignored.
animal_attribute_file = './data/animal-habit.txt'
animal_attribute_pairs = defaultdict(list)  # attribute -> animals listed with it
all_animals = set()
all_attributes = set()

# Explicit encoding so parsing does not depend on the platform's locale default.
with open(animal_attribute_file, 'r', encoding='utf-8') as pairs_file:
    for line in pairs_file:
        parts = line.strip().split(',')
        if len(parts) < 2:
            continue
        # Strip each field so "Dog, swims" and "Dog,swims" map to the same keys.
        animal, attribute = (field.strip() for field in parts[:2])
        if not animal or not attribute:
            continue  # skip rows with an empty animal or attribute field
        animal_attribute_pairs[attribute].append(animal)
        all_animals.add(animal)
        all_attributes.add(attribute)

# Split the animal-attribute pairs into training and testing sets
# (per attribute, 10% of its animals are held out, with a fixed seed).
train_pairs = defaultdict(list)
test_pairs = defaultdict(list)

for attribute, animals in animal_attribute_pairs.items():
    if len(animals) >= 2:
        train_animals, test_animals = train_test_split(animals, test_size=0.1, random_state=42)
    else:
        # train_test_split raises ValueError when the split would leave the
        # training set empty (a single sample with test_size=0.1), so keep
        # such attributes entirely in the training set.
        train_animals, test_animals = list(animals), []
    train_pairs[attribute].extend(train_animals)
    test_pairs[attribute].extend(test_animals)

# Get embeddings for all animals (one forward pass per animal name)
animal_embeddings = {animal: get_embeddings([animal]) for animal in all_animals}

# Calculate attribute directions using training pairs:
# direction = mean(embeddings of training animals with the attribute)
#           - mean(embeddings of animals without the attribute)
attribute_directions = {}
for attribute, animals_with_attribute in train_pairs.items():
    # Negatives are animals never listed with this attribute. Held-out test
    # animals DO have the attribute, so they must not be counted as negatives.
    animals_without_attribute = all_animals - set(animal_attribute_pairs.get(attribute, ()))
    # torch.stack cannot handle an empty list, so skip attributes that lack
    # either positives or negatives.
    if not animals_with_attribute or not animals_without_attribute:
        continue

    embeddings_with_attribute = torch.stack(
        [animal_embeddings[animal] for animal in animals_with_attribute]
    ).mean(dim=0)
    embeddings_without_attribute = torch.stack(
        [animal_embeddings[animal] for animal in animals_without_attribute]
    ).mean(dim=0)
    attribute_directions[attribute] = embeddings_with_attribute - embeddings_without_attribute
    

# %%
# Select top 20 most frequent attributes (frequency = number of animals listed
# with the attribute in the full dataset)
top_attributes = sorted(
    attribute_directions,
    key=lambda attr: len(animal_attribute_pairs[attr]),
    reverse=True,
)[:20]
attribute_directions = {attr: attribute_directions[attr] for attr in top_attributes}

# Define the set of animals to use
specified_animals = {
    "Dog", "Cat", "Horse", "Rabbit", "Hamster", "Guinea Pig", "Parrot", "Goldfish", "Turtle", "Snake",
    "Cow", "Pig", "Sheep", "Goat", "Chicken", "Duck", "Turkey", "Mouse", "Rat", "Dolphin",
    "Whale", "Elephant", "Lion", "Tiger", "Bear", "Deer", "Wolf", "Peacock", "Shark", "Lizard", "Frog"
}

# Filter the specified animals that are present in the dataset
selected_animals = specified_animals.intersection(all_animals)

# Reduce dimensions to 2-D using MDS; each attribute direction is flattened
# into one row of the input matrix.
reduction_model = MDS(n_components=2, random_state=42)
directions_matrix = np.stack(
    [direction.cpu().numpy() for direction in attribute_directions.values()]
).reshape(len(attribute_directions), -1)
directions_reduced = reduction_model.fit_transform(directions_matrix)
attribute_directions_reduced = {
    attribute: directions_reduced[i] for i, attribute in enumerate(attribute_directions)
}

# Order attribute directions by attribute names
attribute_directions_reduced = dict(sorted(attribute_directions_reduced.items()))

# Plot attribute directions with different colors and labels at the end
plt.figure(figsize=(16, 10))
# tab20 provides 20 distinct colours, one per plotted attribute; a 10-colour
# palette would force attributes to share colours. Integer lookups index the
# palette directly.
colors = plt.colormaps.get_cmap('tab20')

for i, (attribute, direction) in enumerate(attribute_directions_reduced.items()):
    arrow_color = colors(i % colors.N)
    plt.arrow(0, 0, direction[0], direction[1], head_width=0.05, head_length=0.1, color=arrow_color)
    plt.text(direction[0], direction[1], attribute, fontsize=16, color=arrow_color)

# Adjust plot limits to ensure all labels are within the figure
# (extra room on the right/top because labels start at the arrow tips)
plt.xlim(directions_reduced[:, 0].min() - 1, directions_reduced[:, 0].max() + 10)
plt.ylim(directions_reduced[:, 1].min() - 1, directions_reduced[:, 1].max() + 3)

plt.xlabel('MDS Component 1')
plt.ylabel('MDS Component 2')
plt.title('Attribute Directions (MDS Reduced)')
plt.show()


# %%


# Calculate and plot the correlation matrix using heatmap.
# Rows of directions_matrix follow the insertion order of attribute_directions,
# so its keys label both axes.
correlation_matrix = np.corrcoef(directions_matrix)
attribute_labels = list(attribute_directions)  # concrete list: seaborn expects list-like tick labels
plt.figure(figsize=(12, 10))
# Pin the colour scale to the full correlation range so the neutral midpoint of
# the diverging 'coolwarm' map sits at zero correlation.
sns.heatmap(
    correlation_matrix,
    annot=False,
    xticklabels=attribute_labels,
    yticklabels=attribute_labels,
    cmap='coolwarm',
    vmin=-1,
    vmax=1,
)
plt.title('Correlation Matrix of Attribute Directions')
plt.show()








# %%
# Calculate attribute directions by averaging the embeddings of all animals that have the attributes
# (uses the full dataset, not only the training split)
attribute_directions_avg = {}
for attribute, animals in animal_attribute_pairs.items():
    if animals:
        animal_embeddings_list = [animal_embeddings[animal].squeeze(0).to(device) for animal in animals]
        animal_embeddings_tensor = torch.stack(animal_embeddings_list)
        attribute_directions_avg[attribute] = animal_embeddings_tensor.mean(dim=0)
    else:
        # Zero vector with the same squeezed shape as the averaged branch, so
        # np.stack below does not receive rows of mismatched shape.
        reference_embedding = next(iter(animal_embeddings.values())).squeeze(0)
        attribute_directions_avg[attribute] = torch.zeros_like(reference_embedding).to(device)

# Select top 20 most frequent attributes
top_attributes_avg = sorted(
    attribute_directions_avg,
    key=lambda attr: len(animal_attribute_pairs[attr]),
    reverse=True,
)[:20]
attribute_directions_avg = {attr: attribute_directions_avg[attr] for attr in top_attributes_avg}

# Reduce dimensions using MDS
reduction_model_avg = MDS(n_components=2, random_state=42)
directions_matrix_avg = np.stack(
    [direction.cpu().numpy() for direction in attribute_directions_avg.values()]
).reshape(len(attribute_directions_avg), -1)
directions_reduced_avg = reduction_model_avg.fit_transform(directions_matrix_avg)
attribute_directions_reduced_avg = {
    attribute: directions_reduced_avg[i] for i, attribute in enumerate(attribute_directions_avg)
}

# Order attribute directions by attribute names
attribute_directions_reduced_avg = dict(sorted(attribute_directions_reduced_avg.items()))

# Plot attribute directions with different colors and labels at the end
plt.figure(figsize=(16, 10))
# tab20 provides 20 distinct colours, one per plotted attribute; integer
# lookups index the palette directly.
colors = plt.colormaps.get_cmap('tab20')

for i, (attribute, direction) in enumerate(attribute_directions_reduced_avg.items()):
    arrow_color = colors(i % colors.N)
    plt.arrow(0, 0, direction[0], direction[1], head_width=0.05, head_length=0.1, color=arrow_color)
    plt.text(direction[0], direction[1], attribute, fontsize=16, color=arrow_color)

# Adjust plot limits to ensure all labels are within the figure
# (extra room on the right/top because labels start at the arrow tips)
plt.xlim(directions_reduced_avg[:, 0].min() - 1, directions_reduced_avg[:, 0].max() + 10)
plt.ylim(directions_reduced_avg[:, 1].min() - 1, directions_reduced_avg[:, 1].max() + 3)

plt.xlabel('MDS Component 1')
plt.ylabel('MDS Component 2')
plt.title('Attribute Directions by Averaging Animal Embeddings (MDS Reduced)')
plt.show()

# %%
