import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
import numpy as np
import fire
from utils import CosineSimilarityNet


def main(data_path='./generate_embd/qwen2.5-1.5b_emb_res.pt',
         epochs=100, batch_size=256, learning_rate=1e-3, test_split=0.2):
    """Train a ``CosineSimilarityNet`` classifier on saved embeddings.

    The embedding file is loaded, randomly split into train and test subsets,
    and the model is optimised with Adam on a cross-entropy loss. After every
    epoch the summed training loss and the test accuracy are printed. When
    training finishes the model's ``state_dict`` is written to
    ``./<model_name>_cosine_similarity_net.pt``, where ``<model_name>`` is the
    file name of ``data_path`` up to its first ``.pt``.

    Args:
        data_path: Path of the ``torch.save``-d embedding file.
        epochs: Number of passes over the training subset.
        batch_size: Mini-batch size for both the train and test loaders.
        learning_rate: Learning rate handed to Adam.
        test_split: Fraction of the dataset held out for testing; must
            satisfy ``0 <= test_split < 1``.

    Raises:
        ValueError: if ``test_split`` is out of range or the file holds no
            items.
    """
    # Fail early with a clear message instead of deep inside random_split or
    # the sampler when the split would leave no training data.
    if not 0 <= test_split < 1:
        raise ValueError(
            f"test_split must satisfy 0 <= test_split < 1, got {test_split!r}"
        )

    input_data, labels = _load_embeddings(data_path)
    model_name = data_path.split('/')[-1].split('.pt')[0]

    dataset = TensorDataset(input_data, labels)
    test_size = int(len(dataset) * test_split)
    train_size = len(dataset) - test_size
    train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model = CosineSimilarityNet(input_dim=input_data.shape[1])
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        model.train()
        # Sum of the per-batch losses (not an average over batches).
        total_loss = 0.0

        for batch_data, batch_labels in train_loader:
            optimizer.zero_grad()
            outputs = model(batch_data)
            loss = criterion(outputs, batch_labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch + 1}/{epochs}, Training Loss: {total_loss:.4f}")

        accuracy = _evaluate_accuracy(model, test_loader)
        if accuracy is None:
            # An empty test subset (e.g. test_split=0) has no accuracy to report.
            print("Test Accuracy: n/a (test set is empty)")
        else:
            print(f"Test Accuracy: {accuracy * 100:.2f}%")

    save_path = './' + model_name + '_cosine_similarity_net.pt'
    torch.save(model.state_dict(), save_path)
    # Report the path that was actually written, not a fixed file name.
    print(f"Model saved to {save_path}")


def _load_embeddings(data_path):
    """Load the embedding file and return ``(features, labels)`` tensors.

    The file is expected to deserialize to an iterable of mappings, each with
    a ``'tensor'`` entry (tensor or NumPy array) and a ``'label'`` entry.
    Features are converted to float32 and stacked along a new first
    dimension; labels become a ``torch.long`` tensor.

    Raises:
        ValueError: if the file contains no items.
    """
    # NOTE(review): torch.load unpickles the file; only use it on trusted data.
    # map_location='cpu' keeps loaded tensors on the CPU, where main() builds
    # the model (it is never moved to another device).
    records = torch.load(data_path, map_location='cpu')

    features = []
    labels = []
    for record in records:
        # as_tensor accepts tensors and NumPy arrays alike and normalises the
        # dtype to float32 -- presumably what the default-constructed model
        # expects; confirm against CosineSimilarityNet.
        features.append(torch.as_tensor(record['tensor'], dtype=torch.float32))
        labels.append(record['label'])

    if not features:
        raise ValueError(f"No samples found in {data_path!r}")

    return torch.stack(features), torch.tensor(labels, dtype=torch.long)


def _evaluate_accuracy(model, loader):
    """Return the fraction of correct predictions over ``loader``.

    Puts ``model`` in eval mode and runs without gradient tracking. Returns
    ``None`` when ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch_data, batch_labels in loader:
            outputs = model(batch_data)
            _, predicted = torch.max(outputs, dim=1)
            total += batch_labels.size(0)
            correct += (predicted == batch_labels).sum().item()

    if total == 0:
        return None
    return correct / total

# Script entry point: python-fire builds the command-line interface from
# main's parameters, so each keyword argument can be overridden as a flag
# (e.g. --epochs=50 --batch_size=128).
if __name__ == "__main__":
    fire.Fire(main)

