# -*- Encoding:UTF-8 -*-

import os
import sys

import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf
import wandb

tf.disable_v2_behavior()


class Data:
    """Load a MovieLens dataset and split its ratings chronologically.

    Only ``ml-1m`` (headerless, ``::``-separated, latin-1 ``users.dat`` /
    ``movies.dat`` / ``ratings.dat``) is supported.  After construction:

    * ``UserInfo``  -- DataFrame[UserID, Gender, Age, JobID]; Gender is coded
      F=0 / M=1 and Age as the rank of its value (0 = smallest Age value).
    * ``MovieInfo`` -- DataFrame[MovieID, Genres]; Genres is a list of integer
      codes (alphabetical rank of each genre name).
    * ``data``      -- ``(training_set, validation_set)``: ratings merged with
      both profiles and sorted by TimeStamp; the newest ``valid_ratio`` share
      is held out, then restricted to users that also appear in training.
    """

    SUPPORTED = ("ml-1m",)

    def __init__(self, name='ml-1m', root="../datasets/", valid_ratio=0.3):
        """Load and split the dataset.

        Args:
            name: dataset folder under *root*; only "ml-1m" is supported.
            root: directory that contains the dataset folders.
            valid_ratio: fraction of the time-sorted ratings (taken from the
                most recent end) used for validation.

        Raises:
            ValueError: if *name* is unsupported or *valid_ratio* is not in [0, 1).
        """
        if name not in self.SUPPORTED:
            raise ValueError(f"Unsupported dataset {name!r}; expected one of {self.SUPPORTED}")
        if not 0.0 <= valid_ratio < 1.0:
            raise ValueError(f"valid_ratio must be in [0, 1), got {valid_ratio!r}")
        self.dataName = name
        # Trailing "" keeps the separator so filenames can be appended directly.
        self.dataPath = os.path.join(root, self.dataName, "")
        self.valid_ratio = valid_ratio
        # Static Profile
        self.UserInfo = self.getUserInfo()
        self.MovieInfo = self.getMovieInfo()

        self.data = self.getData()

    def _read_dat(self, filename, columns):
        """Read one headerless, ``::``-separated, latin-1 file from the dataset folder."""
        return pd.read_table(self.dataPath + filename, sep='::', header=None, names=columns,
                             engine='python', encoding='latin-1')

    def getUserInfo(self):
        """Return the user profile table with Gender and Age integer-coded."""
        users = self._read_dat("users.dat", ['UserID', 'Gender', 'Age', 'JobID', 'Zip-code'])
        users = users[['UserID', 'Gender', 'Age', 'JobID']].copy()

        users['Gender'] = users['Gender'].map({'F': 0, 'M': 1})
        # sorted(): codes are reproducible and follow the order of the Age values.
        age_map = {val: idx for idx, val in enumerate(sorted(users['Age'].unique()))}
        users['Age'] = users['Age'].map(age_map)
        return users

    def getMovieInfo(self):
        """Return the movie table with Genres turned into lists of genre codes."""
        movies = self._read_dat("movies.dat", ['MovieID', 'Title', 'Genres'])
        movies = movies[['MovieID', 'Genres']].copy()

        split_genres = movies['Genres'].str.split('|')
        # Sorting matters: iterating a set of strings depends on the per-process
        # hash seed, so unsorted codes would change from run to run.
        genres2int = {g: idx for idx, g in enumerate(sorted({g for gs in split_genres for g in gs}))}
        movies['Genres'] = split_genres.map(lambda gs: [genres2int[g] for g in gs])
        return movies

    def getData(self):
        """Merge ratings with the profiles and split them by time.

        Returns:
            ``(training_set, validation_set)``.  The validation set holds the
            newest ``int(valid_ratio * n)`` ratings, minus those whose user
            never appears in the training set.
        """
        ratings = self._read_dat("ratings.dat", ['UserID', 'MovieID', 'Rating', 'TimeStamp'])

        data = pd.merge(pd.merge(ratings, self.UserInfo), self.MovieInfo)
        # Stable sort so ratings sharing a timestamp keep a deterministic order.
        data = data.sort_values(by=['TimeStamp'], kind='mergesort')

        total_ratings = len(data)
        split_at = total_ratings - int(self.valid_ratio * total_ratings)
        training_set = data.iloc[:split_at]
        validation_set = data.iloc[split_at:]

        # Drop validation rows for users the model never saw in training.
        train_users = training_set['UserID'].unique()
        validation_set_filtered = validation_set[validation_set['UserID'].isin(train_users)]
        return training_set, validation_set_filtered

def main():
    """Open the named W&B run, then build the RRN model and train it once."""
    wandb.init(project="recurrent_recommender", name="rrn_fourth_run")
    RRN().run()

class RRN:
    """Recurrent recommender that regresses MovieLens ratings (TF1 graph mode).

    Graph: one-hot user / movie IDs -> 128-d ReLU dense embeddings -> GRUs
    (the user GRU reads movie embeddings, the movie GRU reads user embeddings)
    -> 64-d linear projections -> dot product = predicted rating.  Trained
    with Adam on mean squared error.
    """

    # One-hot depths for the 0-based user / movie IDs.  tf.one_hot turns an
    # out-of-range index into an all-zero vector, so these must exceed the
    # largest ID in the data (presumably ml-1m's max UserID / MovieID).
    n_users = 6040
    n_movies = 3952

    def __init__(self, batch_size=50, lr=0.0001, verbose=100, eval_batch_size=10000):
        """Load ml-1m, build the graph and start a session.

        Args:
            batch_size: training rows per optimizer step.
            lr: Adam learning rate.
            verbose: report progress (and validate) every this many batches.
            eval_batch_size: rows per ``sess.run`` when computing validation
                loss.  The one-hot inputs cost ``rows * (n_users + n_movies)``
                floats, so a whole validation split is too large to feed at once.
        """
        # params parser
        self.batch_size = batch_size
        self.n_step = 1
        self.lr = lr
        self.verbose = verbose
        self.eval_batch_size = eval_batch_size
        # Data: each row starts with (UserID, MovieID, Rating, ...)
        dataSet = Data("ml-1m")
        train_df, valid_df = dataSet.data
        self.train = train_df.values
        self.valid = valid_df.values
        # Reuse an already-open run (main() opens a named one).  Calling
        # wandb.init() again may, depending on wandb's reinit behaviour,
        # replace it with an anonymous run.
        if wandb.run is None:
            wandb.init()
        # Model
        self.add_placeholder()
        self.add_embedding_layer()
        self.add_rnn_layer()
        self.add_pred_layer()
        self.add_loss()
        self.add_train_step()
        self.init_session()
        self.saver = tf.train.Saver()

    def save_model(self, model_path="rrn_model_3.ckpt"):
        """Write the session's variables to the checkpoint at *model_path*."""
        self.saver.save(self.sess, model_path)
        print(f"Model saved at {model_path}")

    def add_placeholder(self):
        """Create the input placeholders; IDs and ratings are ``[batch, 1]``."""
        # user placeholder
        self.userID = tf.placeholder(tf.int32, shape=[None, 1], name="userID")
        # movie placeholder
        self.movieID = tf.placeholder(tf.int32, shape=[None, 1], name="movieID")
        # target
        self.rating = tf.placeholder(tf.float32, shape=[None, 1], name="rating")
        # Fed by createFeedDict but not used by any layer of the graph.
        self.dropout = tf.placeholder(tf.float32, name='dropout')

    def add_embedding_layer(self):
        """Embed IDs as one-hot -> dense(128, relu), reshaped to ``[batch, n_step, 128]``."""
        with tf.name_scope("userID_embedding"):
            uid_onehot = tf.reshape(tf.one_hot(self.userID, self.n_users), shape=[-1, self.n_users])
            uid_layer = tf.layers.dense(uid_onehot, units=128, activation=tf.nn.relu)
            self.uid_layer = tf.reshape(uid_layer, [-1, self.n_step, 128])

        with tf.name_scope("movie_embedding"):
            mid_onehot = tf.reshape(tf.one_hot(self.movieID, self.n_movies), shape=[-1, self.n_movies])
            mid_layer = tf.layers.dense(mid_onehot, units=128, activation=tf.nn.relu)
            self.mid_layer = tf.reshape(mid_layer, shape=[-1, self.n_step, 128])

    def add_rnn_layer(self):
        """Run one GRU per side and keep the output of the last time step.

        The inputs are transposed to time-major ``[n_step, batch, 128]``, so
        dynamic_rnn must be told ``time_major=True``.  Without it, dynamic_rnn
        treats the batch axis as time, and the GRU carries state from one
        sample to the next inside a batch.
        """
        with tf.variable_scope("user_rnn_cell"):
            userCell = tf.nn.rnn_cell.GRUCell(num_units=128)
            userInput = tf.transpose(self.mid_layer, [1, 0, 2])
            userOutputs, _ = tf.nn.dynamic_rnn(userCell, userInput, dtype=tf.float32, time_major=True)
            self.userOutput = userOutputs[-1]  # last step: [batch, 128]
        with tf.variable_scope("movie_rnn_cell"):
            movieCell = tf.nn.rnn_cell.GRUCell(num_units=128)
            movieInput = tf.transpose(self.uid_layer, [1, 0, 2])
            movieOutputs, _ = tf.nn.dynamic_rnn(movieCell, movieInput, dtype=tf.float32, time_major=True)
            self.movieOutput = movieOutputs[-1]  # last step: [batch, 128]

    def add_pred_layer(self):
        """Project both GRU outputs to 64-d and predict their dot product, ``[batch, 1]``."""
        W = {
            'userOutput': tf.Variable(tf.random_normal(shape=[128, 64], stddev=0.1)),
            'movieOutput': tf.Variable(tf.random_normal(shape=[128, 64], stddev=0.1))
        }
        b = {
            'userOutput': tf.Variable(tf.random_normal(shape=[64], stddev=0.1)),
            'movieOutput': tf.Variable(tf.random_normal(shape=[64], stddev=0.1))
        }
        userVector = tf.add(tf.matmul(self.userOutput, W['userOutput']), b['userOutput'])
        movieVector = tf.add(tf.matmul(self.movieOutput, W['movieOutput']), b['movieOutput'])

        self.pred = tf.reduce_sum(tf.multiply(userVector, movieVector), axis=1, keepdims=True)

    def add_loss(self):
        """Mean squared error between ratings and predictions (already a scalar)."""
        self.loss = tf.losses.mean_squared_error(self.rating, self.pred)

    def add_train_step(self):
        """Minimise the loss with Adam at ``self.lr``."""
        optimizer = tf.train.AdamOptimizer(self.lr)
        self.train_op = optimizer.minimize(self.loss)

    def init_session(self):
        """Open a session (GPU memory growth, soft placement) and initialise variables."""
        self.config = tf.ConfigProto()
        self.config.gpu_options.allow_growth = True
        self.config.allow_soft_placement = True
        self.sess = tf.Session(config=self.config)
        self.sess.run(tf.global_variables_initializer())

    @staticmethod
    def _batch_bounds(length, size):
        """Yield ``(start, stop)`` pairs covering ``range(length)`` in steps of *size* (no empty tail)."""
        for start in range(0, length, size):
            yield start, min(start + size, length)

    def _chunk_feeds(self, data, chunk_size):
        """Split *data* into ``(feed_dict, n_rows)`` pairs of at most *chunk_size* rows."""
        return [(self.createFeedDict(data[start:stop]), stop - start)
                for start, stop in self._batch_bounds(len(data), chunk_size)]

    def _mean_loss(self, feeds):
        """Row-weighted mean of ``self.loss`` over ``_chunk_feeds`` output; nan if there are no rows.

        Each chunk's loss is the mean over its rows, so weighting by chunk size
        gives the mean over all rows.
        """
        n_rows = sum(n for _, n in feeds)
        if n_rows == 0:
            return float('nan')
        total = sum(self.sess.run(self.loss, feed_dict=fd) * n for fd, n in feeds)
        return total / n_rows

    def run(self):
        """Train for one pass over ``self.train``, validating every ``verbose`` batches.

        Prints RMSE progress, logs train / validation RMSE to W&B, reports the
        final validation RMSE and saves a checkpoint.
        """
        bounds = list(self._batch_bounds(len(self.train), self.batch_size))
        batches = len(bounds)
        # The validation data never changes, so build its feeds once.
        valid_feeds = self._chunk_feeds(self.valid, self.eval_batch_size)

        train_loss = []
        valid_loss = []

        for i, (minIdx, maxIdx) in enumerate(bounds):
            feed_dict_train = self.createFeedDict(self.train[minIdx:maxIdx])

            # One forward pass.  The fetched loss comes from the same
            # activations the gradients use, i.e. before this update is applied.
            tmpLoss, _ = self.sess.run([self.loss, self.train_op], feed_dict=feed_dict_train)
            train_loss.append(tmpLoss)

            if i % self.verbose == 0:
                recent_rmse = np.sqrt(np.mean(train_loss[-20:]))
                sys.stdout.write('\rTraining: Batch {}/{} - Loss: {:.4f}'.format(
                    i, batches, recent_rmse
                ))
                sys.stdout.flush()

                # Check validation loss every verbose steps (not at batch 0)
                if i != 0:
                    valid_loss.append(np.sqrt(self._mean_loss(valid_feeds)))
                    wandb.log({"train_loss": recent_rmse, "valid_loss": valid_loss[-1]})

                    sys.stdout.write(' - Validation Loss: {:.4f}'.format(valid_loss[-1]))
                    sys.stdout.flush()

        print("\nTraining Finish, Last 2000 batches loss is {}.".format(
            np.sqrt(np.mean(train_loss[-2000:]))
        ))
        print("Validation Loss: {:.4f}".format(np.sqrt(self._mean_loss(valid_feeds))))
        self.save_model()

    def createFeedDict(self, data, dropout=1.):
        """Build a feed dict from rows whose first three fields are UserID, MovieID, Rating.

        IDs are shifted to 0-based for the one-hot layers.  Every array is
        shaped ``(len(data), 1)``, including for an empty batch, to match the
        ``[None, 1]`` placeholders.
        """
        userID = []
        movieID = []
        ratings = []
        for row in data:
            userID.append(row[0] - 1)
            movieID.append(row[1] - 1)
            ratings.append(float(row[2]))
        return {
            self.userID: np.array(userID, dtype=np.int32).reshape(-1, 1),
            self.movieID: np.array(movieID, dtype=np.int32).reshape(-1, 1),
            self.rating: np.array(ratings, dtype=np.float32).reshape(-1, 1),
            self.dropout: dropout
        }


if __name__ == '__main__':
    # RRN() loads and splits ml-1m itself.  Building a Data() here parsed the
    # whole dataset a second time only to print len(data.data), which is always
    # 2 (the (train, valid) tuple).
    main()