import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import sys
import pandas as pd
import wandb

# Start the experiment-tracking run and select GPU 0 for all CUDA work.
wandb.init(project="recurrent_recommender", name="lstm_run_valid")
torch.cuda.set_device(0)

# Locations of the three MovieLens-1M files.
users_path = "../datasets/ml-1m/users.dat"
ratings_path = "../datasets/ml-1m/ratings.dat"
movies_path = "../datasets/ml-1m/movies.dat"

# Column names for each file (the files themselves carry no header row).
users_cols = ['user_id', 'gender', 'age', 'occupation', 'zip_code']
ratings_cols = ['user_id', 'movie_id', 'rating', 'timestamp']
movies_cols = ['movie_id', 'title', 'genres']


def _read_dat(path, columns):
    """Read one header-less, '::'-separated, latin-1 encoded file into a DataFrame."""
    # The multi-character separator requires the (slower) Python parsing engine.
    return pd.read_csv(
        path, sep='::', header=None, names=columns, encoding='latin-1', engine='python'
    )


users_df = _read_dat(users_path, users_cols)
ratings_df = _read_dat(ratings_path, ratings_cols)
movies_df = _read_dat(movies_path, movies_cols)

# Raw NumPy views of the three tables as loaded (taken before any remapping).
users_array = users_df.to_numpy()
ratings_array = ratings_df.to_numpy()
movies_array = movies_df.to_numpy()

# Attach titles to the ratings (joined on the shared 'movie_id' column) and
# keep only the columns needed from here on.
ratings_df = ratings_df.merge(movies_df)[['user_id', 'title', 'rating', 'timestamp']]
ratings_df["user_id"] = ratings_df["user_id"].astype(str)

# Dense integer ids starting at 1, assigned in order of first appearance.
user_lookup = {
    raw_user: dense_id
    for dense_id, raw_user in enumerate(ratings_df['user_id'].unique(), start=1)
}
movie_lookup = {
    title: dense_id
    for dense_id, title in enumerate(ratings_df['title'].unique(), start=1)
}
ratings_df['movie_id'] = ratings_df['title'].map(movie_lookup)
ratings_df['user_int'] = ratings_df['user_id'].map(user_lookup)

# Chronological order: every later loop replays interactions oldest-first.
ratings_df = ratings_df.sort_values('timestamp')

def get_last_n_ratings_by_user(
    df, n, min_ratings_per_user=1, user_colname="user_int", timestamp_colname="timestamp"
):
    """
    Return the chronologically last n rows of every sufficiently active user.
    :param df: a DataFrame containing user item ratings
    :param n: the number of most recent rows to keep per user
    :param min_ratings_per_user: users with fewer rows than this are dropped entirely
    :param user_colname: the name of the column containing user ids
    :param timestamp_colname: the name of the column containing the timestamps
    :return: the selected rows (original index labels preserved), sorted by user
    """

    def has_enough_ratings(user_rows):
        return len(user_rows) >= min_ratings_per_user

    # Drop users below the activity threshold.
    eligible = df.groupby(user_colname).filter(has_enough_ratings)
    # Order by time so that "tail" means "most recent".
    chronological = eligible.sort_values(timestamp_colname)
    most_recent = chronological.groupby(user_colname).tail(n)
    return most_recent.sort_values(user_colname)

def mark_last_n_ratings_as_validation_set(
    df, n, min_ratings=1, user_colname="user_int", timestamp_colname="timestamp"
):
    """
    Flag each user's chronologically last n ratings as the validation set.
    The flag is stored in a boolean 'is_valid' column that is written onto
    the passed-in df itself (the frame is modified in place).
    :param df: a DataFrame containing user item ratings
    :param n: the number of ratings per user to put in the validation set
    :param min_ratings: only users with at least this many ratings get validation rows
    :param user_colname: the name of the column containing user ids
    :param timestamp_colname: the name of the column containing the timestamps
    :return: the same df with the additional 'is_valid' column added
    """
    # Default every row to "training" first, then flip the held-out ones.
    df["is_valid"] = False
    held_out_rows = get_last_n_ratings_by_user(
        df,
        n,
        min_ratings,
        user_colname=user_colname,
        timestamp_colname=timestamp_colname,
    )
    # Rows are matched back to df through their index labels.
    df.loc[held_out_rows.index, "is_valid"] = True
    return df

# Hold out each user's single most recent rating for validation; everything
# else is training data. Only the three model inputs are kept in each split.
ratings_df = mark_last_n_ratings_as_validation_set(ratings_df, 1)
_split_cols = ['user_int', 'movie_id', 'rating']
_held_out_mask = ratings_df["is_valid"]
train_df = ratings_df.loc[~_held_out_mask, _split_cols]
valid_df = ratings_df.loc[_held_out_mask, _split_cols]

# One LSTM per side: each consumes 2 features per step and keeps a
# 100-dimensional hidden state. Both live on the GPU.
user_lstm = nn.LSTM(input_size=2, hidden_size=100).cuda()
item_lstm = nn.LSTM(input_size=2, hidden_size=100).cuda()
loss_fn = nn.MSELoss().cuda()
# A single Adam optimizer updates the parameters of both LSTMs together.
optimizer = optim.Adam([*user_lstm.parameters(), *item_lstm.parameters()])

# ---------------------------------------------------------------------------
# Validation-time histories
# ---------------------------------------------------------------------------
# Record EVERY training interaction, so that each held-out rating is predicted
# from the complete training history of its user and of its movie. (Recording
# only the first-seen "cold start" events would leave the histories truncated.)
# Events are small CPU tensors; they are moved to the GPU only once stacked.
user_history_dict_val = {}
item_history_dict_val = {}
for _, row in train_df.iterrows():
    user_id = row['user_int']
    movie_id = row['movie_id']
    rating_value = float(row['rating'])
    user_history_dict_val.setdefault(user_id, []).append(
        torch.tensor([movie_id, rating_value], dtype=torch.float32)
    )
    item_history_dict_val.setdefault(movie_id, []).append(
        torch.tensor([user_id, rating_value], dtype=torch.float32)
    )

# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
ACCUMULATION_STEPS = 2000    # trained samples averaged into one optimizer step
VALIDATION_INTERVAL = 10000  # validate whenever the row counter is a multiple of this


def _predict_rating(user_events, item_events):
    """Predict a rating from a user's and a movie's event histories.

    Each history is a non-empty list of 2-element tensors. It is stacked into
    a (seq_len, 1, 2) sequence, run through its LSTM, and the prediction is
    the dot product of the two final hidden states.

    :return: a 1-element tensor on the GPU
    """
    user_seq = torch.stack(user_events).view(len(user_events), 1, -1).float().cuda()
    item_seq = torch.stack(item_events).view(len(item_events), 1, -1).float().cuda()
    _, (user_state, _) = user_lstm(user_seq)
    _, (item_state, _) = item_lstm(item_seq)
    return torch.dot(user_state[-1].view(-1), item_state[-1].view(-1)).view(1)


def _validation_loss():
    """Return the summed squared-error loss over all scorable validation rows.

    Rows whose user or movie never appears in the training data are skipped.
    Runs under ``torch.no_grad()`` so no autograd graph is built or retained.
    The result is a plain float (0.0 when no row could be scored).
    """
    total = 0.0
    with torch.no_grad():
        for _, val_row in valid_df.iterrows():
            val_user = val_row['user_int']
            val_movie = val_row['movie_id']
            if val_user not in user_history_dict_val or val_movie not in item_history_dict_val:
                continue
            target = torch.tensor([val_row['rating']], dtype=torch.float32).cuda()
            prediction = _predict_rating(
                user_history_dict_val[val_user], item_history_dict_val[val_movie]
            )
            total = total + loss_fn(prediction, target)
    # float() handles both the untouched 0.0 and an accumulated 0-dim tensor.
    return float(total)


def _flush_batch(loss_sum, num_samples):
    """Log the mean training loss and apply one averaged optimizer step.

    Gradients have been summed over ``num_samples`` per-sample backward
    passes; dividing them by that count yields the gradient of the mean loss,
    whatever the actual batch size turned out to be.
    """
    wandb.log({"train_loss": (loss_sum / num_samples).item()})
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is not None:
                param.grad.div_(num_samples)
    optimizer.step()
    optimizer.zero_grad()


for epoch in range(1, 2):
    user_history_dict = {}
    item_history_dict = {}
    counter_var = -1       # index of the current training row
    running_loss = 0.0     # detached sum of losses since the last step
    pending_samples = 0    # samples whose gradients are waiting to be applied
    optimizer.zero_grad()
    for _, row in train_df.iterrows():
        counter_var += 1
        user_id = row['user_int']
        movie_id = row['movie_id']
        rating_value = float(row['rating'])

        user_events = user_history_dict.setdefault(user_id, [])
        item_events = item_history_dict.setdefault(movie_id, [])

        # Cold start: with no history on either side there is nothing to feed
        # the LSTMs, so the row is only recorded, not trained on.
        if user_events and item_events:
            rating = torch.tensor([rating_value], dtype=torch.float32).cuda()
            loss = loss_fn(_predict_rating(user_events, item_events), rating)
            # Back-propagate immediately: gradients accumulate in .grad while
            # this sample's autograd graph is freed right away, instead of
            # keeping thousands of graphs alive until a single backward().
            loss.backward()
            running_loss = running_loss + loss.detach()
            pending_samples += 1

            if pending_samples == ACCUMULATION_STEPS:
                _flush_batch(running_loss, pending_samples)
                running_loss = 0.0
                pending_samples = 0

        if counter_var % VALIDATION_INTERVAL == 0:
            wandb.log({"valid_loss": _validation_loss()})

        # The current interaction becomes history only after it was predicted.
        user_events.append(torch.tensor([movie_id, rating_value], dtype=torch.float32))
        item_events.append(torch.tensor([user_id, rating_value], dtype=torch.float32))

    # Apply whatever is left from the final, partially filled batch.
    if pending_samples:
        _flush_batch(running_loss, pending_samples)

torch.save(user_lstm.state_dict(), 'user_model_val.pth')
torch.save(item_lstm.state_dict(), 'item_model_val.pth')