import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import sys
import pandas as pd
import wandb



# Locations of the three ml-1m data files, relative to the working directory.
users_path = "../datasets/ml-1m/users.dat"
ratings_path = "../datasets/ml-1m/ratings.dat"
movies_path = "../datasets/ml-1m/movies.dat"


# Column names, in file order, for each (header-less) file.
users_cols = ['user_id', 'gender', 'age', 'occupation', 'zip_code']
ratings_cols = ['user_id', 'movie_id', 'rating', 'timestamp']
movies_cols = ['movie_id', 'title', 'genres']

# Parsing options shared by all three files: '::' separated, no header row,
# latin-1 encoded.  The multi-character separator is parsed by the python engine.
_read_options = dict(sep='::', header=None, encoding='latin-1', engine='python')

users_df = pd.read_csv(users_path, names=users_cols, **_read_options)
ratings_df = pd.read_csv(ratings_path, names=ratings_cols, **_read_options)
movies_df = pd.read_csv(movies_path, names=movies_cols, **_read_options)

# Preview the first rows of every frame.
for _heading, _frame in (
    ("Users DataFrame:", users_df),
    ("\nRatings DataFrame:", ratings_df),
    ("\nMovies DataFrame:", movies_df),
):
    print(_heading)
    print(_frame.head())

# Plain NumPy views of the raw frames (taken before ratings_df is rebuilt below).
users_array = users_df.values
ratings_array = ratings_df.values
movies_array = movies_df.values

# Attach titles to the ratings (the frames share the 'movie_id' column) and keep
# only the columns used below.  Done in one expression so the intermediate
# merged frame is not kept alive.
_kept_columns = ['user_id', 'title', 'rating', 'timestamp']
ratings_df = ratings_df.merge(movies_df)[_kept_columns]
ratings_df["user_id"] = ratings_df["user_id"].astype(str)

# Dense 1-based integer ids, numbered in order of first appearance.
user_lookup = {
    user: index
    for index, user in enumerate(ratings_df['user_id'].unique(), start=1)
}
movie_lookup = {
    title: index
    for index, title in enumerate(ratings_df['title'].unique(), start=1)
}
ratings_df['movie_id'] = ratings_df['title'].map(movie_lookup)
ratings_df['user_int'] = ratings_df['user_id'].map(user_lookup)

# Number of ratings recorded for each user and for each title.
ratings_per_user = ratings_df.groupby('user_id')['rating'].count()
ratings_per_item = ratings_df.groupby('title')['rating'].count()
# class Data:
#     def __init__(self, name='ml-1m'):
#         self.dataName = name
#         self.dataPath = "../datasets/" + self.dataName + "/"
#         # Static Profile
#         self.UserInfo = self.getUserInfo()
#         self.MovieInfo = self.getMovieInfo()

#         self.data = self.getData()

#     def getUserInfo(self):
#         if self.dataName == "ml-1m":
#             userInfoPath = self.dataPath + "users.dat"

#             users_title = ['UserID', 'Gender', 'Age', 'JobID', 'Zip-code']
#             users = pd.read_table(userInfoPath, sep='::', header=None, names=users_title, engine='python', encoding='latin-1')
#             users = users.filter(regex='UserID|Gender|Age|JobID')
#             users_orig = users.values

#             gender_map = {'F': 0, 'M': 1}
#             users['Gender'] = users['Gender'].map(gender_map)
#             age_map = {val: idx for idx, val in enumerate(set(users['Age']))}
#             users['Age'] = users['Age'].map(age_map)

#             return users

#     def getMovieInfo(self):
#         if self.dataName == "ml-1m":
#             movieInfoPath = self.dataPath + "movies.dat"

#             movies_title = ['MovieID', 'Title', 'Genres']
#             movies = pd.read_table(movieInfoPath, sep='::', header=None, names=movies_title, engine='python', encoding='latin-1')
#             movies = movies.filter(regex='MovieID|Genres')

#             genres_set = set()
#             for val in movies['Genres'].str.split('|'):
#                 genres_set.update(val)
#             genres2int = {val: idx for idx, val in enumerate(genres_set)}
#             genres_map = {val: [genres2int[row] for row in val.split('|')] for ii, val in enumerate(set(movies['Genres']))}
#             movies['Genres'] = movies['Genres'].map(genres_map)

#             return movies

#     def getData(self):
#         if self.dataName == "ml-1m":
#             dataPath = self.dataPath + "ratings.dat"

#             ratings_title = ['UserID', 'MovieID', 'Rating', 'TimeStamp']
#             ratings = pd.read_table(dataPath, sep='::', header=None, names=ratings_title, engine='python', encoding='latin-1')

#             data = pd.merge(pd.merge(ratings, self.UserInfo), self.MovieInfo)
#             data = data.sort_values(by=['TimeStamp'])

#             total_ratings = len(data)
#             validation_size = int(0.3 * total_ratings)

#             validation_set = data.tail(validation_size)
#             training_set = data.head(total_ratings - validation_size)

#             train_users = training_set['UserID'].unique()
#             validation_set_filtered = validation_set[validation_set['UserID'].isin(train_users)]

#             return training_set, validation_set_filtered

from torch.autograd import Variable

class RRN(nn.Module):
    """Rating predictor built from two embeddings, two GRUs and a dot product.

    A movie-id embedding is fed through ``user_rnn`` and a user-id embedding
    through ``movie_rnn``; each GRU output is projected to 64 dimensions and
    the prediction is the dot product of the two projections.

    The training rows are stored in ``train_data``.  They are deliberately not
    stored in an attribute called ``train``: that would shadow
    ``nn.Module.train`` and make ``model.train()`` / ``model.eval()`` fail.
    """

    def __init__(self, dataset=None, n_users=6040, n_movies=3706):
        """Build the layers, loss and optimizer and store the training rows.

        Args:
            dataset: Training rows, either an object exposing ``.values``
                (e.g. a pandas DataFrame) or an array-like.  Each row is read
                as ``(user id, movie id, rating)`` with 1-based ids.  When
                omitted, the module-level name ``data`` is used, which is what
                a bare ``RRN()`` call has always done.
            n_users: Number of rows in the user embedding table.
            n_movies: Number of rows in the movie embedding table.
        """
        super().__init__()

        # Hyperparameters
        self.batch_size = 500
        self.n_step = 1      # kept for compatibility; forward() does not read it
        self.lr = 0.001
        self.verbose = 100   # progress is reported every `verbose` batches

        # Data
        if dataset is None:
            dataset = data  # module-level frame bound by the script entry point
        self.train_data = np.asarray(getattr(dataset, "values", dataset))

        # Model
        self.add_embedding_layer(n_users, n_movies)

        # Loss and Optimizer
        self.criterion = nn.MSELoss()
        self.optimizer = optim.Adam(self.parameters(), lr=self.lr)

    def save_model(self, path="rrn_model_3.pth"):
        """Write the model's ``state_dict`` to ``path``."""
        torch.save(self.state_dict(), path)
        print("Model saved")

    def add_embedding_layer(self, n_users=6040, n_movies=3706):
        """Create the embedding tables, GRUs and output projections.

        Args:
            n_users: Number of rows in the user embedding table.
            n_movies: Number of rows in the movie embedding table.
        """
        self.user_embedding = nn.Embedding(n_users, 128)
        self.movie_embedding = nn.Embedding(n_movies, 128)
        self.user_rnn = nn.GRU(128, 128, batch_first=True)
        self.movie_rnn = nn.GRU(128, 128, batch_first=True)
        self.user_output_layer = nn.Linear(128, 64)
        self.movie_output_layer = nn.Linear(128, 64)

    def forward(self, userID, movieID):
        """Predict one rating per (user, movie) pair.

        Args:
            userID: LongTensor of 0-based user indices, shape ``(batch,)``
                (or ``(batch, steps)`` for explicit sequences).
            movieID: LongTensor of 0-based movie indices, same shape.

        Returns:
            Tensor of shape ``(batch, 1)`` holding the predicted ratings.
        """
        uid_embedding = self.user_embedding(userID)
        mid_embedding = self.movie_embedding(movieID)

        # The GRUs are batch_first, so they need (batch, steps, features).
        # A 2-D (batch, features) tensor would be read as ONE unbatched
        # sequence and the rows of a batch would share hidden state, so a
        # length-1 time axis is added to keep every row independent.
        if uid_embedding.dim() == 2:
            uid_embedding = uid_embedding.unsqueeze(1)
        if mid_embedding.dim() == 2:
            mid_embedding = mid_embedding.unsqueeze(1)

        # The user-side GRU consumes movie embeddings and vice versa.
        user_rnn_output, _ = self.user_rnn(mid_embedding)
        movie_rnn_output, _ = self.movie_rnn(uid_embedding)

        # Only the last time step is projected.
        user_output = self.user_output_layer(user_rnn_output[:, -1, :])
        movie_output = self.movie_output_layer(movie_rnn_output[:, -1, :])

        return torch.sum(user_output * movie_output, dim=1, keepdim=True)

    def _iter_batches(self):
        """Yield consecutive, non-empty slices of at most ``batch_size`` rows."""
        for start in range(0, len(self.train_data), self.batch_size):
            yield self.train_data[start:start + self.batch_size]

    def run(self):
        """Make one pass over the training rows, then save the model.

        Every ``verbose`` batches the square root of the mean of the last 20
        batch losses is printed and logged to wandb as ``train_loss``.
        """
        total_rows = len(self.train_data)
        # Ceiling division: an exact multiple must not add an empty batch,
        # whose loss would be NaN.
        batches = (total_rows + self.batch_size - 1) // self.batch_size

        train_loss = []

        for i, train_batch in enumerate(self._iter_batches()):
            feed_dict_train = self.create_feed_dict(train_batch)

            outputs = self(feed_dict_train['userID'], feed_dict_train['movieID'])
            loss = self.criterion(outputs, feed_dict_train['rating'].view(-1, 1))
            train_loss.append(loss.item())

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if i % self.verbose == 0:
                recent_rmse = np.sqrt(np.mean(train_loss[-20:]))
                print('\rTraining: Batch {}/{} - Loss: {:.4f}'.format(
                    i, batches, recent_rmse
                ), end='')
                wandb.log({"train_loss": recent_rmse})

        # With no training rows there is nothing to average; report NaN
        # without triggering numpy's empty-mean warning.
        final_rmse = np.sqrt(np.mean(train_loss[-2000:])) if train_loss else float('nan')
        print("\nTraining Finish, Last 2000 batches loss is {}.".format(final_rmse))
        self.save_model()

    def create_feed_dict(self, data):
        """Convert rows of ``(user id, movie id, rating)`` into model tensors.

        Args:
            data: 2-D array-like whose first three columns are the 1-based
                user id, the 1-based movie id and the rating.  May be empty.

        Returns:
            Dict with ``'userID'`` and ``'movieID'`` (int64, shifted to
            0-based embedding indices) and ``'rating'`` (float32).
        """
        rows = np.asarray(data)
        if rows.size == 0:
            return {
                'userID': torch.empty(0, dtype=torch.long),
                'movieID': torch.empty(0, dtype=torch.long),
                'rating': torch.empty(0, dtype=torch.float32),
            }

        # Ids are 1-based in the data; embedding tables are indexed from 0.
        userID = torch.from_numpy(rows[:, 0].astype(np.int64) - 1)
        movieID = torch.from_numpy(rows[:, 1].astype(np.int64) - 1)
        ratings = torch.from_numpy(rows[:, 2].astype(np.float32))
        return {
            'userID': userID,
            'movieID': movieID,
            'rating': ratings,
        }


# class RRN(nn.Module):
#     def __init__(self):
#         super(RRN, self).__init__()

#         self.batch_size = 50
#         self.n_step = 1
#         self.lr = 0.0001
#         self.verbose = 100

#         dataSet = Data("ml-1m")
#         a, b = dataSet.data
#         self.train = torch.tensor(a.values, dtype=torch.float32)
#         self.valid = torch.tensor(b.values, dtype=torch.float32)

#         self.add_embedding_layer()
#         self.add_rnn_layer()
#         self.add_pred_layer()
#         self.add_loss()
#         self.add_train_step()

#     def save_model(self):
#         # Save the model
#         torch.save(self.state_dict(), "rrn_model_3.pth")
#         print("Model saved")

#     def add_embedding_layer(self):
#         self.user_embedding = nn.Sequential(
#             nn.Linear(6040, 128),
#             nn.ReLU()
#         )

#         self.movie_embedding = nn.Sequential(
#             nn.Linear(3952, 128),
#             nn.ReLU()
#         )

#     def add_rnn_layer(self):
#         self.user_rnn_cell = nn.GRUCell(128, 128)
#         self.movie_rnn_cell = nn.GRUCell(128, 128)

#     def add_pred_layer(self):
#         self.user_output_fc = nn.Linear(128, 64)
#         self.movie_output_fc = nn.Linear(128, 64)

#     def forward(self, user_input, movie_input):
#         uid_onehot = self.user_embedding(user_input)
#         uid_layer = uid_onehot.view(-1, self.n_step, 128)

#         mid_onehot = self.movie_embedding(movie_input)
#         mid_layer = mid_onehot.view(-1, self.n_step, 128)

#         user_states = []
#         for i in range(self.n_step):
#             user_states.append(self.user_rnn_cell(mid_layer[:, i, :], user_states[-1] if i > 0 else None))
#         self.user_output = user_states[-1]

#         movie_states = []
#         for i in range(self.n_step):
#             movie_states.append(self.movie_rnn_cell(uid_layer[:, i, :], movie_states[-1] if i > 0 else None))
#         self.movie_output = movie_states[-1]

#         user_vector = self.user_output_fc(self.user_output)
#         movie_vector = self.movie_output_fc(self.movie_output)

#         self.pred = torch.sum(user_vector * movie_vector, dim=1, keepdim=True)

#         return self.pred

#     def add_loss(self):
#         self.loss_fn = nn.MSELoss()

#     def add_train_step(self):
#         self.optimizer = optim.Adam(self.parameters(), lr=self.lr)

#     def run(self):
#         length = len(self.train)
#         batches = length // self.batch_size + 1

#         train_loss = []
#         valid_loss = []

#         for i in range(batches):
#             min_idx = i * self.batch_size
#             max_idx = min(length, (i + 1) * self.batch_size)
#             train_batch = self.train[min_idx:max_idx]

#             tmp_loss = self.loss_fn(self.forward(train_batch[:, 0].long(), train_batch[:, 1].long()), train_batch[:, 2].view(-1, 1))
#             train_loss.append(tmp_loss.item())

#             self.optimizer.zero_grad()
#             tmp_loss.backward()
#             self.optimizer.step()

#             if i % self.verbose == 0:
#                 sys.stdout.write('\rTraining: Batch {}/{} - Loss: {:.4f}'.format(
#                     i, batches, np.sqrt(np.mean(train_loss[-20:]))
#                 ))
#                 sys.stdout.flush()

#             if i % self.verbose == 0 and i != 0:
#                 valid_loss_epoch = np.sqrt(self.loss_fn(self.forward(self.valid[:, 0].long(), self.valid[:, 1].long()), self.valid[:, 2].view(-1, 1)).item())
#                 valid_loss.append(valid_loss_epoch)
#                 wandb.log({"train_loss": np.sqrt(np.mean(train_loss[-20:])), "valid_loss": valid_loss[-1]})
#                 sys.stdout.write(' - Validation Loss: {:.4f}'.format(valid_loss[-1]))
#                 sys.stdout.flush()

#         print("\nTraining Finish, Last 2000 batches loss is {}.".format(
#             np.sqrt(np.mean(train_loss[-2000:]))
#         ))
#         valid_loss_epoch = np.sqrt(self.loss_fn(self.forward(self.valid[:, 0].long(), self.valid[:, 1].long()), self.valid[:, 2].view(-1, 1)).item())
#         print("Validation Loss: {:.4f}".format(valid_loss_epoch))
#         self.save_model()

if __name__ == "__main__":
    # A bare RRN() reads the module-level name `data`, so bind it before
    # constructing the model.
    feature_columns = ['user_int', 'movie_id', 'rating']
    data = ratings_df[feature_columns]

    wandb.init(name="rrn_torch_run_6", project="recurrent_recommender")

    model = RRN()
    model.run()
