import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score
import random
import pickle
import os
import argparse

# Seed every RNG the script draws from (Python's `random`, NumPy, PyTorch CPU and,
# when present, all CUDA devices) so augmentation and weight init are repeatable.
random_seed = 42
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)

# Command-line interface: the model key selects the per-model subdirectory
# (embeddings/models), the dataset selects the top-level dataset directory.
parser = argparse.ArgumentParser(description='Process model key and dataset.')
parser.add_argument(
    '--model_key',
    type=str,
    default='llama3-8b',
    help='Model key to use for loading models and directories',
)
parser.add_argument(
    '--dataset',
    type=str,
    default='animal',
    help='Dataset to use for loading data and directories',
)
args = parser.parse_args()

# Input locations, derived from the chosen dataset and model key.
hypernyms_file = f'./datasets/{args.dataset}/hypernyms.txt'
object_embeddings_file = f'./datasets/{args.dataset}/{args.model_key}/embeddings/object_embeddings.pkl'

# Load pairs from the hypernyms file: one comma-separated pair per line, each
# field stripped of surrounding whitespace. Presumably the order is
# (object, hypernym) -- TODO confirm against the dataset files.
# Lines that do not split into exactly two fields (e.g. blank lines) are skipped.
# The file is decoded as UTF-8 explicitly so parsing does not depend on the
# platform's locale encoding.
hypernyms = []
with open(hypernyms_file, 'r', encoding='utf-8') as f:
    for line in f:
        parts = line.strip().split(',')
        if len(parts) == 2:
            hypernyms.append((parts[0].strip(), parts[1].strip()))

# Load the pre-computed embeddings: a mapping from object name to its vector.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this at embedding files you produced or trust.
with open(object_embeddings_file, 'rb') as emb_file:
    object_embeddings = pickle.load(emb_file)

# Keep only the pairs whose two objects both have an embedding.
embedded_objects = set(object_embeddings.keys())
filtered_hypernyms = [
    pair
    for pair in hypernyms
    if pair[0] in embedded_objects and pair[1] in embedded_objects
]

print(f"Total pairs after filtering: {len(filtered_hypernyms)}")

# Split the filtered pairs 80/20 into train and test sets. The split reuses the
# script-wide `random_seed` instead of a separately hard-coded value, so changing
# the seed in one place changes every source of randomness consistently.
if not filtered_hypernyms:
    # train_test_split would otherwise fail with an opaque ValueError.
    raise SystemExit(
        f"No hypernym pairs left after filtering against {object_embeddings_file}; nothing to train on."
    )
train_hypernyms, test_hypernyms = train_test_split(filtered_hypernyms, test_size=0.2, random_state=random_seed)

print(f"Training pairs: {len(train_hypernyms)}")
print(f"Testing pairs: {len(test_hypernyms)}")

# Double the samples by adding one negative pair per positive pair.
def augment_with_negatives(hypernyms, objects, positives=None, max_tries=100):
    """Interleave each positive pair with one sampled negative pair.

    For every ``(a, b)`` in ``hypernyms`` the output gets ``(a, b, 1)`` followed by
    a negative ``(a', b, 0)`` or ``(a, b', 0)``. With probability 1/2, ``a`` is
    replaced by a random element of ``objects``; otherwise ``b`` is replaced.

    A draw that lands on a known positive pair (including the pair itself, which
    happens whenever ``random.choice`` returns the original object) is rejected
    and redrawn. Without that check such a pair would be labelled 0 and inject
    label noise.

    Args:
        hypernyms: Sequence of positive ``(a, b)`` pairs.
        objects: Non-empty sequence of candidate replacement objects.
        positives: Optional collection of pairs that must never be emitted as
            negatives. Defaults to ``hypernyms`` itself.
        max_tries: Maximum draws per negative. If every draw is a known
            positive (e.g. a tiny ``objects`` pool), the last draw is kept so
            the call always terminates with one negative per positive.

    Returns:
        List of ``(a, b, label)`` triples, twice as long as ``hypernyms``.
    """
    known = set(hypernyms) if positives is None else set(positives)
    augmented = []
    for a, b in hypernyms:
        augmented.append((a, b, 1))  # Positive pair
        for _ in range(max(1, max_tries)):
            if random.random() > 0.5:
                negative = (random.choice(objects), b)
            else:
                negative = (a, random.choice(objects))
            if negative not in known:
                break
        augmented.append((negative[0], negative[1], 0))  # Negative pair
    return augmented

# Candidate pool for negative sampling: every object that appears in a kept pair.
# sorted() fixes the order. Iterating a set of strings follows hash order,
# which differs between interpreter runs (PYTHONHASHSEED), so random.choice over
# an unsorted list would draw different negatives each run despite the fixed seed.
objects = sorted({obj for pair in filtered_hypernyms for obj in pair})
train_data = augment_with_negatives(train_hypernyms, objects)
test_data = augment_with_negatives(test_hypernyms, objects)

# Define the MLP model
class MLP(nn.Module):
    """One-hidden-layer classifier over a concatenated pair of embeddings.

    The input has ``2 * input_size`` features (two embeddings side by side).
    The output is a single sigmoid unit, i.e. a probability in (0, 1), which is
    why it is paired with ``nn.BCELoss`` rather than a logits-based loss.

    Args:
        input_size: Dimensionality of ONE object embedding.
        hidden_size: Width of the hidden layer. Defaults to 1000, the previously
            hard-coded value. Submodule names are unchanged, so existing saved
            state dicts still load.
    """

    def __init__(self, input_size, hidden_size=1000):
        super().__init__()
        self.hidden = nn.Linear(input_size * 2, hidden_size)
        self.output = nn.Linear(hidden_size, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map features with last dim ``2 * input_size`` to probabilities with last dim 1."""
        x = self.relu(self.hidden(x))
        x = self.sigmoid(self.output(x))
        return x

# Build the classifier input for a pair of objects.
def get_feature_vector(object_a, object_b, object_embeddings):
    """Return ``object_a``'s embedding followed by ``object_b``'s as one array.

    Raises ``KeyError`` if either object is missing from ``object_embeddings``.
    """
    first = object_embeddings[object_a]
    second = object_embeddings[object_b]
    return np.concatenate((first, second))

# Size the classifier from the embedding length of an arbitrary entry
# (assumes every embedding has the same length -- TODO confirm).
embedding_dim = len(next(iter(object_embeddings.values())))
model = MLP(embedding_dim)
# MLP.forward already applies a sigmoid, so plain BCE on probabilities is used.
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop: one Adam step per (a, b, label) sample, for a fixed number of epochs.
# model.train() has no effect on this architecture (no dropout / batch-norm) but
# makes the mode explicit, mirroring the model.eval() call in evaluate().
num_epochs = 10
model.train()
for epoch in range(num_epochs):
    total_loss = 0
    # Visit samples in a fresh random order each epoch (shuffled copy;
    # train_data itself is untouched). Without this, every epoch would replay
    # the same positive/negative alternation built by augment_with_negatives.
    epoch_samples = random.sample(train_data, len(train_data))
    for a, b, label in epoch_samples:
        optimizer.zero_grad()
        features = torch.tensor(get_feature_vector(a, b, object_embeddings), dtype=torch.float32)
        output = model(features)
        loss = criterion(output, torch.tensor([label], dtype=torch.float32))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        print(f"Epoch {epoch+1}, Sample ({a}, {b}), Label: {label}, Output: {output.item()}, Loss: {loss.item()}")
    # The reported value is the mean per-sample loss, so label it accordingly;
    # max(..., 1) avoids ZeroDivisionError on an empty training set.
    print(f"Epoch {epoch+1}, Average Loss: {total_loss / max(len(train_data), 1)}")

# Save the trained weights (state dict only) under the dataset/model directory.
model_path = f'./datasets/{args.dataset}/{args.model_key}/models/llm_baseline_model_0.8.pth'
# torch.save does not create missing parent directories, so make sure the
# models/ directory exists before writing.
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)
print("Model saved.")

# Evaluation function
def evaluate(data):
    """Score the module-level ``model`` on labelled pairs.

    Args:
        data: Iterable of ``(object_a, object_b, label)`` triples with 0/1
            labels, as produced by ``augment_with_negatives``.

    Returns:
        ``(f1, precision, recall)`` for the positive class, where a pair is
        predicted positive when the model's output probability exceeds 0.5.

    Side effects: switches the global ``model`` to eval mode and prints one
    line per sample.
    """
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for a, b, label in data:
            features = torch.tensor(get_feature_vector(a, b, object_embeddings), dtype=torch.float32)
            prob = model(features).item()
            pred = prob > 0.5
            all_preds.append(pred)
            all_labels.append(label)
            print(f"Test Sample ({a}, {b}), Label: {label}, Prediction: {pred}, Output: {prob}")
    # zero_division=0: if the model predicts no positives (or the data has none),
    # report 0.0 explicitly instead of going through UndefinedMetricWarning.
    precision = precision_score(all_labels, all_preds, zero_division=0)
    recall = recall_score(all_labels, all_preds, zero_division=0)
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    return f1, precision, recall

# Report held-out metrics on the augmented test set, three decimals each.
f1, precision, recall = evaluate(test_data)
for metric_name, metric_value in (('Precision', precision), ('Recall', recall), ('F1 Score', f1)):
    print(f'Test {metric_name}: {metric_value:.3f}')
