import pandas as pd
import torch
import wandb
from tqdm import tqdm
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import Adam
from sklearn.model_selection import train_test_split
import numpy as np


wandb.init(project="matrix_factorization", name="mf_run_50_epochs")

# Pin the process to one specific GPU, but only when that device exists:
# torch.cuda.set_device raises on hosts without CUDA or with fewer GPUs,
# which would abort the run before any data is loaded.
# NOTE(review): nothing in the visible script moves the model or the tensors
# onto the GPU, so training runs on the CPU either way -- confirm the intent.
cuda_device_index = 7
if torch.cuda.is_available() and cuda_device_index < torch.cuda.device_count():
    torch.cuda.set_device(cuda_device_index)
    print(torch.cuda.current_device())
else:
    print(f"CUDA device {cuda_device_index} is not available; default device left unchanged")

# Specify the file paths for the dataset files
users_path = "../datasets/ml-1m/users.dat"
ratings_path = "../datasets/ml-1m/ratings.dat"
movies_path = "../datasets/ml-1m/movies.dat"
# Column layout of each .dat file (the files carry no header row)
users_cols = ['user_id', 'gender', 'age', 'occupation', 'zip_code']
ratings_cols = ['user_id', 'movie_id', 'rating', 'timestamp']
movies_cols = ['movie_id', 'title', 'genres']

# All three files are read the same way: '::' as field separator, no header,
# latin-1 decoding, and the python parser engine.
_dat_format = {'sep': '::', 'header': None, 'encoding': 'latin-1', 'engine': 'python'}
users_df = pd.read_csv(users_path, names=users_cols, **_dat_format)
ratings_df = pd.read_csv(ratings_path, names=ratings_cols, **_dat_format)
movies_df = pd.read_csv(movies_path, names=movies_cols, **_dat_format)

# Raw NumPy snapshots of the frames, taken before ratings_df is reshaped below
users_array = users_df.values
ratings_array = ratings_df.values
movies_array = movies_df.values

# Attach the movie title to every rating (inner join on the shared movie_id
# column) and keep only the columns used from here on.
ratings_df = pd.merge(ratings_df, movies_df, on='movie_id')[
    ['user_id', 'title', 'rating', 'timestamp']
]
ratings_df["user_id"] = ratings_df["user_id"].astype(str)

# Number of ratings per user and per title
ratings_per_user = ratings_df.groupby('user_id')['rating'].count()
ratings_per_item = ratings_df.groupby('title')['rating'].count()

# Dense 1-based integer ids in order of first appearance; 0 is never assigned
user_lookup = {
    user: index
    for index, user in enumerate(ratings_df['user_id'].unique(), start=1)
}
movie_lookup = {
    title: index
    for index, title in enumerate(ratings_df['title'].unique(), start=1)
}

# From here on 'movie_id' holds the dense title index, not the id from the file
ratings_df['movie_id'] = ratings_df['title'].map(movie_lookup)
ratings_df['user_int'] = ratings_df['user_id'].map(user_lookup)

# One [user index, movie index, rating] list per rating row
user_item_rating_tuples = ratings_df[['user_int', 'movie_id', 'rating']].values.tolist()

import torch
from torch import nn

class MfDotBias(nn.Module):
    """Matrix-factorization scorer built on user and item embeddings.

    The score of a (user, item) pair is the dot product of the two embedding
    vectors, optionally plus a learned scalar bias per user and per item, and
    optionally squashed into a fixed range with a scaled sigmoid.

    Args:
        n_factors: Width of the user and item embedding vectors.
        n_users: Largest user index the model has to accept.
        n_items: Largest item index the model has to accept.
        ratings_range: Optional ``(low, high)`` pair. When given, the raw
            score is mapped into that range via ``sigmoid``; when ``None``
            the raw score is returned.
        use_biases: Whether to add per-user and per-item bias terms.
    """

    def __init__(
        self, n_factors, n_users, n_items, ratings_range=None, use_biases=False
    ):
        super().__init__()
        self.bias = use_biases
        self.y_range = ratings_range
        # Each table gets one extra row because index 0 is declared as the
        # padding index: its vector starts at zero and receives no gradient.
        self.user_embedding = nn.Embedding(n_users + 1, n_factors, padding_idx=0)
        self.item_embedding = nn.Embedding(n_items + 1, n_factors, padding_idx=0)

        if use_biases:
            self.user_bias = nn.Embedding(n_users + 1, 1, padding_idx=0)
            self.item_bias = nn.Embedding(n_items + 1, 1, padding_idx=0)

    def forward(self, inputs):
        """Score user/item index pairs.

        Args:
            inputs: A ``(users, items)`` pair of integer index tensors with
                matching shapes (scalar, 1-D batch, or higher rank).

        Returns:
            A tensor with the same shape as the index tensors holding one
            score per pair.
        """
        users, items = inputs
        interaction = self.user_embedding(users) * self.item_embedding(items)
        # Reduce over the factor axis, which is always the last one; a fixed
        # axis of 1 only works for 1-D index batches.
        result = interaction.sum(-1)
        if self.bias:
            # Drop only the trailing size-1 bias axis so the batch shape is
            # kept intact whatever the batch size or rank.
            user_term = self.user_bias(users).squeeze(-1)
            item_term = self.item_bias(items).squeeze(-1)
            result = result + user_term + item_term

        if self.y_range is None:
            return result

        low, high = self.y_range[0], self.y_range[1]
        return torch.sigmoid(result) * (high - low) + low
        
# Hold out 20% of the rows for validation (fixed seed, so the split is
# repeatable) and turn both partitions into integer tensors in one pass.
train_data, valid_data = (
    torch.tensor(partition, dtype=torch.long)
    for partition in train_test_split(
        user_item_rating_tuples, test_size=0.2, random_state=42
    )
)

# Column 0 = user index, column 1 = item index, column 2 = rating (the target)
train_users, train_items, train_ratings = (train_data[:, col] for col in range(3))
valid_users, valid_items, valid_ratings = (valid_data[:, col] for col in range(3))

# Set random seed for reproducibility
torch.manual_seed(42)

# Instantiate the matrix factorization model
n_factors = 100  # Adjust as needed
# Size the embedding tables from the largest index in BOTH splits: an id that
# only occurs in the validation split must still be a valid row, otherwise the
# validation forward pass can index past the end of the table.
# NOTE(review): the "+ 1" is kept from the original sizing. MfDotBias already
# adds its own extra row for padding index 0, so the printed values are one
# larger than the highest id actually used.
n_users = max(train_users.max().item(), valid_users.max().item()) + 1
print("Number of users:", n_users)
n_items = max(train_items.max().item(), valid_items.max().item()) + 1
print("Number of items:", n_items)

model = MfDotBias(n_factors=n_factors, n_users=n_users, n_items=n_items)

# Define loss function and optimizer
criterion = nn.MSELoss()
# Single source of truth for the value used by Adam and reported to wandb
learning_rate = 0.001
optimizer = Adam(model.parameters(), lr=learning_rate)

# Convert datasets to DataLoader for batch training
batch_size = 64  # Adjust as needed
train_dataset = TensorDataset(train_users, train_items, train_ratings)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
print("Shape of dataset: ", len(train_dataset))
print("Length of train loader: ", len(train_loader))
# Training loop
n_epochs = 50  # Adjust as needed

# Early-stopping state: best validation loss so far, and how many consecutive
# epochs have passed without improving on it.
best_valid_loss = float('inf')
patience = 3  # Adjust as needed
counter = 0

wandb.watch(model)
wandb.config.update({
    "n_factors": n_factors,
    "batch_size": batch_size,
    "learning_rate": learning_rate,
    "n_epochs": n_epochs,
    "patience": patience
})


# Copy of the weights from the epoch with the lowest validation loss so far
best_state = None

for epoch in range(n_epochs):
    model.train()
    # Sample-weighted running sum, so the reported training loss is the mean
    # over the whole epoch rather than the loss of whichever batch came last.
    running_loss = 0.0
    n_seen = 0
    for batch_users, batch_items, batch_ratings in tqdm(train_loader):
        optimizer.zero_grad()
        predictions = model((batch_users, batch_items))
        loss = criterion(predictions, batch_ratings.float())
        loss.backward()
        optimizer.step()

        batch_len = batch_ratings.size(0)
        running_loss += loss.item() * batch_len
        n_seen += batch_len
    train_loss = running_loss / max(n_seen, 1)

    # Validation: one forward pass over the full validation split, no gradients
    model.eval()
    with torch.no_grad():
        valid_predictions = model((valid_users, valid_items))
        # Keep a plain float so the early-stopping state holds no tensors
        valid_loss = criterion(valid_predictions, valid_ratings.float()).item()

    wandb.log({"Training Loss": train_loss, "Validation Loss": valid_loss})
    print(f'Epoch {epoch + 1}/{n_epochs}, Loss: {train_loss}, Validation Loss: {valid_loss}')

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        counter = 0  # Reset counter if validation loss improves
        # state_dict() returns references to the live parameters, so clone
        # them; otherwise later optimizer steps would overwrite the snapshot.
        best_state = {
            key: value.detach().clone()
            for key, value in model.state_dict().items()
        }
    else:
        counter += 1

    if counter >= patience:
        print(f'Validation loss did not improve for {patience} epochs. Stopping training.')
        break

# Save the weights of the best validation epoch. Without the restore, an
# early stop would persist the weights from `patience` epochs after the best.
if best_state is not None:
    model.load_state_dict(best_state)
torch.save(model.state_dict(), 'model_50_epochs.pth')
wandb.finish()

