# Import necessary libraries
import pandas as pd
import matplotlib.pyplot as plt

# Load the embedding context: one row per concept, first column used as the
# row index.
# NOTE(review): the file has a .csv extension but is parsed as tab-separated
# (sep='\t'); confirm this matches how the file was written.
embedding_context = pd.read_csv(
    './datasets/animal/gemma7b/results/embedding_context.csv',
    index_col=0,
    sep='\t',
)
# Load the per-attribute thresholds used to re-centre the context below.
# Presumably a Series keyed by the context's column labels — TODO confirm.
# NOTE(review): read_pickle can execute arbitrary code; only load trusted files.
attribute_thresholds = pd.read_pickle(
    './datasets/animal/gemma7b/embeddings/attribute_thresholds_train_lda.pkl'
)


# Concepts whose embedding profiles are compared and plotted.
concepts = ["bird", "predator"]

# Re-centre the context on the thresholds: with axis='columns' a Series of
# thresholds is aligned against the column labels, so afterwards positive
# values are above their threshold and negative values below it.
embedding_context = embedding_context.sub(attribute_thresholds, axis='columns')

# Keep only the rows belonging to the selected concepts.
embeddings = embedding_context.loc[concepts]



# Calculate the "intersection" profile of 'bird' and 'predator' over every
# embedding dimension (column of `embeddings`): where the two threshold-centred
# values have strictly opposite signs, use their midpoint; elsewhere use 0.
# NOTE(review): selecting dimensions where the concepts *disagree* in sign
# reads more like a symmetric difference than an intersection — confirm that
# this is the intended rule before relying on the label used in the plot.
bird_row = embeddings.loc['bird']
predator_row = embeddings.loc['predator']

# Strictly opposite signs only: zeros and NaNs never qualify, because every
# comparison involving them is False.
opposite_sign = ((bird_row > 0) & (predator_row < 0)) | ((bird_row < 0) & (predator_row > 0))
intersection_series = ((bird_row + predator_row) / 2).where(opposite_sign, 0)

# Separate dimensions and values for plotting. The dimension labels are kept
# as a list (not a tuple): pandas `.loc` treats a tuple as a multi-axis key,
# whereas a list selects several labels. An empty column set yields empty
# lists instead of crashing on unpacking.
intersection_dims = list(intersection_series.index)
intersection_values = intersection_series.tolist()

# Plot the embeddings.
# `.loc[row, cols]` needs a *list* of column labels: a tuple would be read as
# a multi-axis key rather than a multi-column selection.
plot_dims = list(intersection_dims)

_, ax = plt.subplots(figsize=(12, 6))

# Plot the intersection profile
ax.plot(plot_dims, intersection_values, label='intersection [bird, predator]', linestyle='--')

# Plot each concept in `concepts` over the same dimensions as the intersection
for concept in concepts:
    ax.plot(plot_dims, embeddings.loc[concept, plot_dims], label=concept)

# Add 'eagle' as a reference concept, read from the (re-centred) full context
eagle_embedding = embedding_context.loc['eagle', plot_dims]
ax.plot(plot_dims, eagle_embedding, label='eagle', linestyle=':')

# Customize the plot
ax.set_xlabel('Dimension')
ax.set_ylabel('Projection Length')
ax.set_title('Concept Embeddings Visualization')
ax.legend()
ax.grid(True)
ax.set_xticks([])  # Remove x-axis ticks (one position per dimension)
plt.show()
