import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import sys
import pandas as pd
import wandb

wandb.init(project="recurrent_recommender", name="lstm_run_1")
# Hard-coded GPU index: every later `.cuda()` call resolves to this device.
torch.cuda.set_device(7)
users_path = "../datasets/ml-1m/users.dat"
ratings_path = "../datasets/ml-1m/ratings.dat"
movies_path = "../datasets/ml-1m/movies.dat"


# Column names for each `::`-separated, header-less MovieLens-1M file
users_cols = ['user_id', 'gender', 'age', 'occupation', 'zip_code']
ratings_cols = ['user_id', 'movie_id', 'rating', 'timestamp']
movies_cols = ['movie_id', 'title', 'genres']

# Load data into Pandas DataFrames. The multi-character separator '::' is only
# supported by the python parser engine.
users_df = pd.read_csv(users_path, sep='::', header=None, names=users_cols, encoding='latin-1', engine='python')
ratings_df = pd.read_csv(ratings_path, sep='::', header=None, names=ratings_cols, encoding='latin-1', engine='python')
movies_df = pd.read_csv(movies_path, sep='::', header=None, names=movies_cols, encoding='latin-1', engine='python')


# Optionally, convert DataFrames to NumPy arrays/matrices
# NOTE(review): these arrays are not referenced later in this script as shown.
users_array = users_df.values
ratings_array = ratings_df.values
movies_array = movies_df.values

# Inner-join ratings with movie titles (the only shared column is movie_id) and
# keep just the columns needed downstream.
ratings_df = pd.merge(ratings_df, movies_df, on='movie_id')[['user_id', 'title', 'rating', 'timestamp']]
ratings_df["user_id"] = ratings_df["user_id"].astype(str)
# Dense 0-based ids, assigned in order of first appearance in the merged frame.
user_lookup = {v: i for i, v in enumerate(ratings_df['user_id'].unique())}
movie_lookup = {v: i for i, v in enumerate(ratings_df['title'].unique())}
ratings_df['movie_id'] = ratings_df['title'].map(movie_lookup)
ratings_df['user_int'] = ratings_df['user_id'].map(user_lookup)

# Replay interactions chronologically. Many ratings share a timestamp, so use a
# stable sort: the default (unstable) quicksort leaves the order of ties
# implementation-defined. That order can vary across numpy versions and CPUs,
# and it changes which interactions appear in each user/item history.
ratings_df = ratings_df.sort_values('timestamp', kind='stable')
ratings_df = ratings_df[['user_int', 'movie_id', 'rating']]

# One single-layer LSTM per side of the interaction. Each time step is a
# 2-feature vector ([counterpart id, rating]); the hidden state has 50 units.
user_lstm = nn.LSTM(input_size=2, hidden_size=50).cuda()
item_lstm = nn.LSTM(input_size=2, hidden_size=50).cuda()
# Squared error between the predicted and the observed rating.
loss_fn = nn.MSELoss().cuda()
# A single Adam instance jointly optimizes both encoders (user params first).
optimizer = optim.Adam([*user_lstm.parameters(), *item_lstm.parameters()])

# Number of rating predictions whose gradients are accumulated into one
# optimizer step; the logged "train_loss" is the mean loss over such a window.
ACCUM_STEPS = 2000
# Largest row index (0-based) of the time-sorted rating stream processed per epoch.
MAX_ROW_INDEX = 300000

for epoch in range(1, 2):
    # Per-epoch interaction histories. Each is a list of 2-element float tensors:
    #   user_int -> [movie_id, rating] for movies the user rated earlier in the stream
    #   movie_id -> [user_int, rating] for users who rated the movie earlier
    # NOTE(review): raw integer ids are fed to the LSTMs as unnormalized
    # continuous features; consider embeddings or scaling.
    user_history_dict = {}
    item_history_dict = {}
    # Detached sum of the losses in the current window. It stays on the device,
    # so logging does not force a GPU sync on every row.
    running_loss = torch.zeros((), device="cuda")
    # Losses whose gradients have been accumulated but not yet applied.
    pending = 0
    optimizer.zero_grad()

    for counter_var, row in enumerate(ratings_df.itertuples(index=False)):
        if counter_var > MAX_ROW_INDEX:
            break
        user_id = row.user_int
        movie_id = row.movie_id
        rating_value = float(row.rating)
        rating = torch.tensor([rating_value], dtype=torch.float32, device="cuda")

        user_history = user_history_dict.setdefault(user_id, [])
        item_history = item_history_dict.setdefault(movie_id, [])

        # A prediction needs at least one earlier interaction on both sides.
        # Otherwise (cold start) the row only seeds the histories below.
        if user_history and item_history:
            # Shape (seq_len, batch=1, 2).
            user_seq = torch.stack(user_history).view(len(user_history), 1, -1)
            item_seq = torch.stack(item_history).view(len(item_history), 1, -1)

            # h_n is (num_layers, batch, hidden); [-1] selects the last layer.
            _, (user_vector, _) = user_lstm(user_seq)
            _, (item_vector, _) = item_lstm(item_seq)
            pred_rating = torch.dot(user_vector[-1].view(-1), item_vector[-1].view(-1))
            loss = loss_fn(pred_rating.view(1), rating)

            # Backpropagate each loss immediately, scaled so the window's total
            # gradient equals the gradient of its mean loss. The previous code
            # summed the loss tensors and called backward once, which kept up to
            # ACCUM_STEPS forward graphs alive; this frees each graph at once.
            (loss / ACCUM_STEPS).backward()
            running_loss += loss.detach()
            # Count real losses rather than rows, so cold-start rows can no longer
            # shrink a window or skip the step on a window boundary.
            pending += 1

            if pending == ACCUM_STEPS:
                wandb.log({"train_loss": (running_loss / ACCUM_STEPS).item()})
                optimizer.step()
                optimizer.zero_grad()
                running_loss.zero_()
                pending = 0

        # Record this interaction so later rows can condition on it.
        user_history.append(torch.tensor([float(movie_id), rating_value], device="cuda"))
        item_history.append(torch.tensor([float(user_id), rating_value], device="cuda"))

    if pending:
        # Apply the final, partial window instead of discarding it. Its losses
        # were scaled by 1/ACCUM_STEPS; rescale to 1/pending so the gradient is
        # that of the window's mean loss.
        scale = ACCUM_STEPS / pending
        for module in (user_lstm, item_lstm):
            for param in module.parameters():
                if param.grad is not None:
                    param.grad.mul_(scale)
        wandb.log({"train_loss": (running_loss / pending).item()})
        optimizer.step()
        optimizer.zero_grad()

torch.save(user_lstm.state_dict(), 'user_model_small.pth')
torch.save(item_lstm.state_dict(), 'item_model_small.pth')