import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer, pipeline
import pandas as pd
from sklearn.model_selection import train_test_split
import os
from utils import call_LLM, get_model
import logging

# Restrict this process to GPU 0. This is set before the model is loaded below,
# which is when CUDA is first used by this script.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# The root logger defaults to WARNING, which would silently drop the INFO
# progress messages below. basicConfig is a no-op if logging is already set up.
logging.basicConfig(level=logging.INFO)

llama_model_name = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # Define the model name

# Load the LLaMA-3 model using transformers pipeline
logging.info("Loading LLaMA model...")
llama_tokenizer = AutoTokenizer.from_pretrained(llama_model_name)
llama_generator = pipeline(
    "text-generation",
    model=llama_model_name,
    tokenizer=llama_tokenizer,
    device_map="auto",
    # Pad with the tokenizer's own end-of-sequence id. The previous hard-coded
    # 50256 is GPT-2's end-of-text id and has no such meaning for this model.
    pad_token_id=llama_tokenizer.eos_token_id,
)
logging.info("LLaMA model loaded successfully.")

def get_embeddings(text_batch):
    """Extract one embedding vector per text in ``text_batch``.

    Each text is tokenized and run through the language model held by
    ``llama_generator``; its embedding is the final-layer hidden state at the
    last token position. No text is generated and no gradients are tracked,
    so the returned vectors are constants with respect to any optimizer.

    Args:
        text_batch: Iterable of input strings.

    Returns:
        A float32 tensor of shape ``(len(text_batch), hidden_size)`` on the
        module-level ``device``. An empty batch yields shape
        ``(0, hidden_size)``.
    """
    tokenizer = llama_generator.tokenizer
    model = llama_generator.model

    embeddings = []
    with torch.no_grad():
        for text in text_batch:
            # Inputs must live on the device that holds the model's first layer.
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            outputs = model(**inputs, output_hidden_states=True)
            # hidden_states[-1] has shape (1, seq_len, hidden_size); keep the
            # vector of the last token.
            last_token_state = outputs.hidden_states[-1][0, -1, :]
            # The model may run in half precision; downstream math is float32.
            embeddings.append(last_token_state.to(torch.float32))

    if not embeddings:
        # torch.stack rejects an empty list, so build the empty result directly.
        return torch.empty(
            (0, model.config.hidden_size), dtype=torch.float32, device=device
        )
    return torch.stack(embeddings).to(device)

# Read the tab-separated taste dataset from disk.
csv_file = "./datasets/Taste/food_taste.txt"  # Update with the correct file path
df = pd.read_csv(csv_file, sep='\t')

# Column 0 holds the food names and column 5 the sweetness scores; the first
# row of the parsed frame is skipped in both.
# NOTE(review): read_csv already consumes the first line of the file as the
# header, so iloc[1:] also drops the first data row -- confirm this is intended.
entities = df.iloc[1:, 0]
sweetness_values = pd.to_numeric(df.iloc[1:, 5], errors='coerce')

# Keep only the rows whose sweetness parsed as a number (unparseable -> NaN).
valid_data = ~sweetness_values.isna()
entities, sweetness_values = entities[valid_data], sweetness_values[valid_data]

# Hold out 20% of the rows for evaluation; the fixed seed keeps the split
# reproducible.
train_entities, test_entities, train_sweetness, test_sweetness = train_test_split(
    entities,
    sweetness_values,
    random_state=42,
    test_size=0.2,
)

# Device for the direction vector and the embeddings returned by get_embeddings
# (GPU when available, CPU otherwise). This does not move the language model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Take the embedding width from the loaded model so the direction vector always
# matches it, whichever model name is configured.
hidden_size = llama_generator.model.config.hidden_size

# Learnable sweetness direction, randomly initialised. nn.Parameter enables
# gradient tracking on its own.
sweetness_direction = nn.Parameter(torch.randn(hidden_size, device=device))

# Optimizer for the direction vector (the only trainable tensor)
vector_optimizer = torch.optim.Adam([sweetness_direction], lr=1e-3)

# Training loop: learn a direction such that projecting an entity's embedding
# onto it orders entities by sweetness (pairwise margin ranking loss).
num_epochs = 20  # Adjust the number of epochs as needed
margin = 0.1  # Minimum projection gap required for a correctly ordered pair

# Only the direction vector is trained, so an entity's embedding is the same in
# every epoch: compute each one once instead of re-running the language model
# for every sampled pair.
with torch.no_grad():
    train_embeddings = get_embeddings([f"{entity}" for entity in train_entities])

num_train = len(train_entities)
for epoch in range(num_epochs):
    total_loss = 0.0
    num_pairs = 0
    for _ in range(num_train):
        # Sample a random pair (with replacement) from the training set.
        idx1, idx2 = torch.randint(0, num_train, (2,)).tolist()
        sweetness1, sweetness2 = train_sweetness.iloc[idx1], train_sweetness.iloc[idx2]

        # A tie (including the same item drawn twice) defines no ordering;
        # training on it would push the direction toward an arbitrary one.
        if sweetness1 == sweetness2:
            continue

        # Project embeddings onto the unit-length sweetness direction
        unit_direction = sweetness_direction / torch.norm(sweetness_direction)
        projection1 = torch.dot(train_embeddings[idx1], unit_direction)
        projection2 = torch.dot(train_embeddings[idx2], unit_direction)

        # Hinge loss: the sweeter entity must project at least `margin` higher.
        if sweetness1 < sweetness2:
            gap = projection2 - projection1
        else:
            gap = projection1 - projection2
        loss = torch.relu(margin - gap)

        # Update weights
        vector_optimizer.zero_grad()
        loss.backward()
        vector_optimizer.step()

        total_loss += loss.item()
        num_pairs += 1

    # Average over the pairs actually trained on (ties were skipped).
    mean_loss = total_loss / num_pairs if num_pairs else 0.0
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {mean_loss:.4f}")

# Save the trained direction vector
torch.save(sweetness_direction, "sweetness_direction.pth")

# Evaluation: sample random test pairs and check whether the learned projection
# orders them the same way as the true sweetness values.
results = []
correct_order = 0
test_comparisons = 0

with torch.no_grad():
    # Use the same prompt format as training (the bare entity name) so the
    # direction is evaluated on the embedding distribution it was fitted to.
    # Each test entity is embedded and projected once.
    test_embeddings = get_embeddings([f"{entity}" for entity in test_entities])
    unit_direction = sweetness_direction / torch.norm(sweetness_direction)
    test_projections = (test_embeddings @ unit_direction).tolist()

num_test = len(test_entities)
for _ in range(num_test):
    idx1, idx2 = torch.randint(0, num_test, (2,)).tolist()
    entity1, entity2 = test_entities.iloc[idx1], test_entities.iloc[idx2]
    sweetness1, sweetness2 = test_sweetness.iloc[idx1], test_sweetness.iloc[idx2]
    projection1, projection2 = test_projections[idx1], test_projections[idx2]

    # Every sampled entity is reported, whether or not the pair is scored.
    results.append((entity1, sweetness1, projection1))
    results.append((entity2, sweetness2, projection2))

    # Ties (including the same item drawn twice) have no correct ordering, so
    # they are left out of the accuracy instead of being counted as errors.
    if sweetness1 == sweetness2:
        continue

    # Check consistency
    if (sweetness1 < sweetness2 and projection1 < projection2) or (
        sweetness1 > sweetness2 and projection1 > projection2
    ):
        correct_order += 1
    test_comparisons += 1

accuracy = correct_order / test_comparisons if test_comparisons > 0 else 0
print(f"Test Accuracy: {accuracy:.2f}")

# Save results to Excel
output_df = pd.DataFrame(results, columns=["Entity", "TrueSweetness", "ProjectedSweetness"])
output_df.to_excel("results.xlsx", index=False)
print("Processing complete. Results saved to 'results.xlsx'.")
