import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer
import pandas as pd
from sklearn.model_selection import train_test_split
import os
from utils import call_LLM, get_model


# Expose only GPU index 0 to this process.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})


def get_embeddings(text_batch):
    """Extract last-token, last-layer embeddings for a batch of text inputs.

    Relies on the module-level ``model``, ``tokenizer`` and ``device`` globals.

    Args:
        text_batch: The text(s) to encode, passed straight to ``tokenizer``.

    Returns:
        A tensor with one row per input text: the final layer's hidden state
        at that text's last non-padding token.
    """
    # Only fall back to EOS as the pad token when the tokenizer has none;
    # overwriting an existing pad token on every call would change padding ids.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenized_output = tokenizer(text_batch, return_tensors="pt", padding=True, truncation=True).to(device)
    attention_mask = tokenized_output["attention_mask"]

    with torch.no_grad():
        # Pass the mask so padding positions are never attended to
        # (essential when the tokenizer pads on the left).
        outputs = model(
            tokenized_output["input_ids"],
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
    last_hidden = outputs.hidden_states[-1]

    if tokenizer.padding_side == "left":
        # Left padding: every sequence's final real token is in the last column.
        return last_hidden[:, -1, :]

    # Right padding: the final real token sits just before the first pad.
    last_token_positions = (attention_mask.sum(dim=1) - 1).to(last_hidden.device)
    batch_rows = torch.arange(last_hidden.shape[0], device=last_hidden.device)
    return last_hidden[batch_rows, last_token_positions]


# Load the LLaMA-2 model and tokenizer through the project's get_model helper
llama_model_name = "Llama-2-7b"  # Replace with actual model name
# Read the access token from the environment instead of hard-coding it in source;
# an unset variable falls back to the previous empty-string default.
access_token = os.environ.get("HF_TOKEN", "")
model, tokenizer = get_model(llama_model_name, access_token, output_hidden_states=True)

# Move the model to the GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# The language model is only a frozen feature extractor: disable training-mode
# behaviour (e.g. dropout) and stop gradients from flowing into its weights.
model.eval()
model.requires_grad_(False)

# Load the Excel file
csv_file = "FoodDatabase.xlsx"  # Replace with your Excel file path (an .xlsx workbook, not a CSV)
df = pd.read_excel(csv_file)
# NOTE(review): read_excel already uses row 0 as the column header, so iloc[1:]
# additionally drops the first data row; keep only if the sheet has a second header row.
entities = df.iloc[1:, 0]  # First column: Food names (ignore first row)
sweetness_values = pd.to_numeric(df.iloc[1:, 5], errors='coerce')  # Sixth column: Sweetness values

# Keep only rows with both a food name and a numeric sweetness value; a missing
# name would otherwise be rendered as "nan" inside the prompts.
valid_data = sweetness_values.notna() & entities.notna()
entities = entities[valid_data].astype(str)
sweetness_values = sweetness_values[valid_data]

# Hold out 20% of the labelled foods for evaluation; the fixed seed keeps the split reproducible.
(train_entities, test_entities,
 train_sweetness, test_sweetness) = train_test_split(
    entities,
    sweetness_values,
    random_state=42,
    test_size=0.2,
)

# Seed torch so the random initial direction (and the later torch.randint pair
# sampling) is reproducible, consistent with the fixed random_state of the split.
torch.manual_seed(42)

# Learnable direction in the model's hidden space; nn.Parameter already enables gradients.
sweetness_direction = nn.Parameter(torch.randn(model.config.hidden_size, device=device))

# Optimizer for the direction vector (the language model itself stays frozen)
vector_optimizer = torch.optim.Adam([sweetness_direction], lr=1e-3)

# Training hyper-parameters
num_epochs = 20  # Adjust the number of epochs as needed
margin = 0.1  # Minimum projection gap demanded between the sweeter and the less sweet item

# Training loop: each step samples a random pair of training foods and applies a
# hinge (margin ranking) loss so the sweeter one projects further along the direction.
for epoch in range(num_epochs):
    total_loss = 0.0
    num_updates = 0
    for _ in range(len(train_entities)):
        idx1, idx2 = torch.randint(0, len(train_entities), (2,)).tolist()
        sweetness1, sweetness2 = train_sweetness.iloc[idx1], train_sweetness.iloc[idx2]
        # Equal sweetness (including idx1 == idx2) gives no ordering signal; without
        # this skip the else-branch would force an arbitrary item above the other.
        if sweetness1 == sweetness2:
            continue
        entity1, entity2 = train_entities.iloc[idx1], train_entities.iloc[idx2]

        # Generate prompts ("{{}}" renders as a literal "{}" in the prompt text)
        prompts = [f"{entity1} has sweetness {{}}", f"{entity2} has sweetness {{}}"]
        # Match the direction's dtype so torch.dot works even if the model runs in reduced precision.
        embeddings = get_embeddings(prompts).to(sweetness_direction.dtype)
        embedding1, embedding2 = embeddings[0], embeddings[1]

        # Project embeddings onto the (normalised) sweetness direction
        direction_norm = torch.norm(sweetness_direction)
        projection1 = torch.dot(embedding1, sweetness_direction) / direction_norm
        projection2 = torch.dot(embedding2, sweetness_direction) / direction_norm

        # Hinge loss on the projection gap, oriented so the sweeter item should score higher
        if sweetness1 < sweetness2:
            loss = torch.relu(margin - (projection2 - projection1))
        else:
            loss = torch.relu(margin - (projection1 - projection2))

        # Update weights
        vector_optimizer.zero_grad()
        loss.backward()
        vector_optimizer.step()

        total_loss += loss.item()
        num_updates += 1

    # Average over the steps actually taken (tie pairs are skipped); guard against zero.
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss / max(num_updates, 1):.4f}")

# Save the trained direction vector as a plain CPU tensor so the file can be
# loaded without a GPU and without the training graph attached.
torch.save(sweetness_direction.detach().cpu(), "sweetness_direction.pth")

# Evaluation: sample random pairs of held-out foods and measure how often the
# order of their projections agrees with the order of their true sweetness.
results = []
correct_order = 0
test_comparisons = 0
with torch.no_grad():
    for _ in range(len(test_entities)):
        idx1, idx2 = torch.randint(0, len(test_entities), (2,)).tolist()
        entity1, entity2 = test_entities.iloc[idx1], test_entities.iloc[idx2]
        sweetness1, sweetness2 = test_sweetness.iloc[idx1], test_sweetness.iloc[idx2]

        # Generate prompts ("{{}}" renders as a literal "{}" in the prompt text)
        prompts = [f"{entity1} has sweetness {{}}", f"{entity2} has sweetness {{}}"]
        # Match the direction's dtype so torch.dot works even if the model runs in reduced precision.
        embeddings = get_embeddings(prompts).to(sweetness_direction.dtype)
        embedding1, embedding2 = embeddings[0], embeddings[1]

        # Project embeddings
        direction_norm = torch.norm(sweetness_direction)
        projection1 = torch.dot(embedding1, sweetness_direction) / direction_norm
        projection2 = torch.dot(embedding2, sweetness_direction) / direction_norm

        # Every sampled item is recorded for the results sheet ...
        results.append((entity1, sweetness1, projection1.item()))
        results.append((entity2, sweetness2, projection2.item()))

        # ... but ties (equal true sweetness, including idx1 == idx2) have no correct
        # order, so counting them would only deflate the accuracy.
        if sweetness1 == sweetness2:
            continue

        # Check consistency
        if (sweetness1 < sweetness2 and projection1 < projection2) or (sweetness1 > sweetness2 and projection1 > projection2):
            correct_order += 1
        test_comparisons += 1

accuracy = correct_order / test_comparisons if test_comparisons > 0 else 0
print(f"Test Accuracy: {accuracy:.2f}")

# Write every evaluated (entity, true sweetness, projected sweetness) row to a workbook.
output_df = pd.DataFrame(
    data=results,
    columns=["Entity", "TrueSweetness", "ProjectedSweetness"],
)
output_df.to_excel(excel_writer="results.xlsx", index=False)
print("Processing complete. Results saved to 'results.xlsx'.")
