import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm
import wandb

import pandas as pd

wandb.init(project="my_rnn", name="my_rnn_1")

# GPU index this script was set up for. torch.cuda.set_device raises when the
# device does not exist, and the rest of this script never moves tensors to
# CUDA, so only select it when it is actually present.
CUDA_DEVICE = 7
if torch.cuda.is_available() and torch.cuda.device_count() > CUDA_DEVICE:
    torch.cuda.set_device(CUDA_DEVICE)

# Specify the file paths for the dataset files
users_path = "../datasets/ml-1m/users.dat"
ratings_path = "../datasets/ml-1m/ratings.dat"
movies_path = "../datasets/ml-1m/movies.dat"


# Define column names for each dataset (the .dat files are read without a header row)
users_cols = ['user_id', 'gender', 'age', 'occupation', 'zip_code']
ratings_cols = ['user_id', 'movie_id', 'rating', 'timestamp']
movies_cols = ['movie_id', 'title', 'genres']

# Load data into Pandas DataFrames. The files use the multi-character
# separator '::', which needs the python parsing engine.
users_df = pd.read_csv(users_path, sep='::', header=None, names=users_cols, encoding='latin-1', engine='python')
ratings_df = pd.read_csv(ratings_path, sep='::', header=None, names=ratings_cols, encoding='latin-1', engine='python')
movies_df = pd.read_csv(movies_path, sep='::', header=None, names=movies_cols, encoding='latin-1', engine='python')

# Display the first few rows of each DataFrame
print("Users DataFrame:")
print(users_df.head())

print("\nRatings DataFrame:")
print(ratings_df.head())

print("\nMovies DataFrame:")
print(movies_df.head())

# NumPy copies of the raw tables (not used further in this script)
users_array = users_df.values
ratings_array = ratings_df.values
movies_array = movies_df.values

# Attach titles to ratings, then re-identify users and movies with dense
# 1-based integer ids (YourRecommender subtracts 1 before indexing its tables).
# The original movie_id column is dropped here and rebuilt from the title lookup.
ratings_df = pd.merge(ratings_df, movies_df, on='movie_id')[['user_id', 'title', 'rating', 'timestamp']]
ratings_df["user_id"] = ratings_df["user_id"].astype(str)
user_lookup = {v: i + 1 for i, v in enumerate(ratings_df['user_id'].unique())}
movie_lookup = {v: i + 1 for i, v in enumerate(ratings_df['title'].unique())}
ratings_df['movie_id'] = ratings_df['title'].map(movie_lookup)
ratings_df['user_int'] = ratings_df['user_id'].map(user_lookup)


def get_last_n_ratings_by_user(
    df, n, min_ratings_per_user=1, user_colname="user_id", timestamp_colname="timestamp"
):
    """Return each user's ``n`` most recent rows, ordered by user.

    Users with fewer than ``min_ratings_per_user`` rows are dropped first.
    Recency is decided by sorting ``timestamp_colname`` ascending and keeping
    the tail of each user's group. ``sort_values`` uses its default algorithm,
    so rows with equal timestamps may be ordered arbitrarily.

    :param df: DataFrame of ratings
    :param n: number of most recent rows to keep per user
    :param min_ratings_per_user: minimum row count for a user to be kept
    :param user_colname: name of the user id column
    :param timestamp_colname: name of the timestamp column
    :return: a DataFrame with the selected rows, original index preserved
    """
    eligible = df.groupby(user_colname).filter(
        lambda user_rows: len(user_rows) >= min_ratings_per_user
    )
    oldest_first = eligible.sort_values(timestamp_colname)
    most_recent = oldest_first.groupby(user_colname).tail(n)
    return most_recent.sort_values(user_colname)
    
def mark_last_n_ratings_as_validation_set(
    df, n, min_ratings=1, user_colname="user_id", timestamp_colname="timestamp"
):
    """
    Mark the chronologically last n ratings of each user as the validation set.

    The marking is done in place: a boolean 'is_valid' column is added to df,
    True for the rows selected by get_last_n_ratings_by_user and False for all
    other rows.

    :param df: a DataFrame containing user item ratings; it is modified in place
    :param n: the number of ratings per user to include in the validation set
    :param min_ratings: only mark ratings of users with at least this many
        ratings; all rows of users below the threshold stay is_valid=False
    :param user_colname: the name of the column containing user ids
    :param timestamp_colname: the name of the column containing the timestamps
    :return: the same df object (not a copy) with the 'is_valid' column added
    """
    df["is_valid"] = False
    validation_index = get_last_n_ratings_by_user(
        df,
        n,
        min_ratings_per_user=min_ratings,
        user_colname=user_colname,
        timestamp_colname=timestamp_colname,
    ).index
    df.loc[validation_index, "is_valid"] = True

    return df

mark_last_n_ratings_as_validation_set(ratings_df, 1)

# 'is_valid' is a boolean column, so it can be used directly as a row mask.
# Only the columns the model consumes are kept.
train_df = ratings_df.loc[~ratings_df["is_valid"], ["user_int", "movie_id", "rating"]]
valid_df = ratings_df.loc[ratings_df["is_valid"], ["user_int", "movie_id", "rating"]]



class YourRecommender(nn.Module):
    """GRU rating predictor driven by per-user item-count histories.

    For each (user, item) pair, the user's history vector (one count slot per
    item) is passed through a single-step GRU to obtain a user embedding. The
    predicted rating is the dot product of that embedding with a learned item
    embedding.

    User and item ids are 1-based; they are shifted to 0-based rows internally.
    """

    def __init__(self, num_users, num_items, embedding_dim, lambda_reg):
        """
        :param num_users: number of user rows (valid ids are 1..num_users)
        :param num_items: number of item rows (valid ids are 1..num_items)
        :param embedding_dim: GRU hidden size and item embedding size
        :param lambda_reg: regularization strength; stored on the module for
            the training loop to read, not applied inside ``forward``
        """
        super().__init__()

        self.embedding_dim = embedding_dim

        # User embedding: a GRU over the user's item-count history vector.
        self.rnn = nn.GRU(input_size=num_items, hidden_size=embedding_dim, batch_first=True)

        # Item embedding
        self.item_embedding = nn.Embedding(num_items, embedding_dim)

        # NOTE: not used by forward(); kept so saved state_dicts still contain
        # the ``linear.*`` keys and existing checkpoints stay loadable.
        self.linear = nn.Linear(embedding_dim, 1)

        # Running per-user item counts. This is registered as a non-persistent
        # buffer so it follows .to(device) together with the parameters while
        # staying out of state_dict, as the earlier plain attribute did.
        self.register_buffer(
            "user_history",
            torch.zeros((num_users, num_items), dtype=torch.float32),
            persistent=False,
        )

        # Regularization strength
        self.lambda_reg = lambda_reg

    def forward(self, user, item):
        """Predict ratings for a batch of (user, item) pairs.

        :param user: 1-D LongTensor of 1-based user ids, shape ``(B,)``
        :param item: 1-D LongTensor of 1-based item ids, shape ``(B,)``
        :return: predicted ratings, shape ``(B,)``

        The current item is counted in the user's history before predicting.
        That count is written back to ``user_history`` only in training mode,
        so evaluation never changes the stored history.
        """
        user_idx = user - 1
        item_idx = item - 1
        seen = F.one_hot(item_idx, num_classes=self.user_history.size(1)).float()

        # History as seen by this prediction: stored counts plus the current item.
        history = self.user_history[user_idx] + seen
        if self.training:
            # index_add_ accumulates correctly even when a user repeats in the batch.
            self.user_history.index_add_(0, user_idx, seen)

        # (B, num_items) -> (B, 1, num_items): each example is its own length-1
        # sequence, so examples in a batch do not feed into each other's GRU state.
        rnn_out, _ = self.rnn(history.unsqueeze(1))
        user_emb = rnn_out[:, -1, :]  # (B, embedding_dim)

        item_emb = self.item_embedding(item_idx)  # (B, embedding_dim)

        # Per-pair dot product between user and item embeddings.
        return (user_emb * item_emb).sum(dim=-1)

# train_df / valid_df hold one rating per row with columns
# ['user_int', 'movie_id', 'rating']; both id columns are 1-based.

# Constants
num_users = 6040  # Adjust according to your dataset
num_items = 3952  # Adjust according to your dataset
embedding_dim = 64  # Adjust as needed
lambda_reg = 0.001  # Adjust as needed

# The model indexes fixed-size tables with (id - 1), so every id produced by
# the lookups must fit. Fail here with a clear message rather than with an
# IndexError part-way through the first epoch.
if len(user_lookup) > num_users or len(movie_lookup) > num_items:
    raise ValueError(
        f"Model sizes too small: num_users={num_users}, num_items={num_items}, "
        f"but the data has {len(user_lookup)} users and {len(movie_lookup)} items"
    )

# Initialize your model
model = YourRecommender(num_users, num_items, embedding_dim, lambda_reg)

# Loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    # The model accumulates per-user item counts as it sees ratings. Start
    # every epoch from an empty history so the counts do not grow with the
    # epoch number and earlier validation items cannot leak into training.
    model.user_history.zero_()

    model.train()
    train_losses = []
    for row in tqdm(train_df.itertuples(index=False), total=len(train_df)):
        # One example per optimizer step (batch size 1).
        user_seq = torch.LongTensor([row.user_int])
        item_seq = torch.LongTensor([row.movie_id])

        # Forward pass
        predicted_rating = model(user_seq, item_seq)

        # Compute loss
        loss = criterion(predicted_rating, torch.FloatTensor([row.rating]))

        # L2 penalty on the GRU's hidden-to-hidden and input-to-hidden weights
        reg_loss = 0.5 * model.lambda_reg * (
            torch.norm(model.rnn.weight_hh_l0) ** 2 + torch.norm(model.rnn.weight_ih_l0) ** 2
        )
        loss = loss + reg_loss

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
    train_loss = sum(train_losses) / len(train_losses)

    model.eval()
    val_losses = []
    with torch.no_grad():
        for row in valid_df.itertuples(index=False):
            val_user_seq = torch.LongTensor([row.user_int])
            val_item_seq = torch.LongTensor([row.movie_id])

            # Forward pass
            val_predicted_rating = model(val_user_seq, val_item_seq)

            # Compute validation loss (without the regularization term)
            val_loss = criterion(val_predicted_rating, torch.FloatTensor([row.rating]))
            val_losses.append(val_loss.item())
    val_loss_mean = sum(val_losses) / len(val_losses)

    # A single log call per epoch keeps both curves on the same wandb step.
    wandb.log({"Train Loss": train_loss, "Validation Loss": val_loss_mean, "epoch": epoch + 1})

    print(f'Epoch {epoch + 1}/{num_epochs} - '
          f'Train Loss: {train_loss:.4f}, '
          f'Validation Loss: {val_loss_mean:.4f}')


# Remember to save your model after training
torch.save(model.state_dict(), 'your_model.pth')
