import numpy as np
import pandas as pd
import os
import torch.utils.data


class MovieLens20MDataset(torch.utils.data.Dataset):
    """
    MovieLens 20M Dataset

    Data preparation
        treat samples with a rating of 3 or less as negative samples (target 0)
        and samples with a rating above 3 as positive samples (target 1)

    ``field_dims`` is computed from the full ratings file at ``dataset_path``;
    the samples returned by the dataset come from ``<mode>.csv`` located in the
    same directory as that ratings file.

    :param dataset_path: path to the MovieLens ratings CSV (e.g. ``ratings.csv``)
    :param sep: column separator passed to ``pandas.read_csv``
    :param engine: parser engine passed to ``pandas.read_csv``
    :param header: header spec passed to ``pandas.read_csv``; when ``None`` the
        first line of each file is treated as a header line and skipped
    :param mode: split name; samples are loaded from ``<mode>.csv``
    Reference:
        https://grouplens.org/datasets/movielens
    """

    def __init__(self, dataset_path, sep=',', engine='c', header='infer', mode='train'):
        # If pandas is told there is no header, the header line arrives as a
        # data row and must be dropped here. Otherwise pandas has already
        # consumed it, and slicing would silently discard the first sample.
        skip = 1 if header is None else 0

        # Field dimensions come from the full ratings file rather than the
        # split, so they cover IDs appearing in any split (assuming the split
        # files are subsets of it).
        data = pd.read_csv(dataset_path, sep=sep, engine=engine, header=header)
        data = np.array(data)[skip:, :3]
        self.field_dims = np.max(data[:, :2].astype(int) - 1, axis=0) + 1

        # The split files live next to the ratings file.
        split_path = os.path.join(os.path.dirname(dataset_path), mode + '.csv')
        data = pd.read_csv(split_path, sep=sep, engine=engine, header=header)
        data = np.array(data)[skip:, :3]
        self.items = data[:, :2].astype(int) - 1  # -1 because ID begins from 1
        self.targets = self.__preprocess_target(data[:, 2].astype(float)).astype(np.float32)

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return self.targets.shape[0]

    def __getitem__(self, index):
        """Return ``(zero-based [userId, movieId], binary target)`` for ``index``."""
        return self.items[index], self.targets[index]

    def __preprocess_target(self, target):
        """Binarize ratings in place: ``<= 3`` -> 0, ``> 3`` -> 1."""
        target[target <= 3] = 0
        target[target > 3] = 1
        return target


class MovieLens1MDataset(torch.utils.data.Dataset):
    """
    MovieLens 1M Dataset

    Data preparation
        treat samples with a rating of 3 or less as negative samples (target 0)
        and samples with a rating above 3 as positive samples (target 1)

    Each file is expected to hold six feature columns (cast to int and shifted
    to start at 0) followed by the rating in the seventh column.
    ``field_dims`` is computed from the full file at ``dataset_path``; the
    samples returned by the dataset come from ``<mode>.csv`` located in the
    same directory as that file.

    :param dataset_path: path to the full MovieLens 1M CSV
    :param sep: column separator passed to ``pandas.read_csv``
    :param engine: parser engine passed to ``pandas.read_csv``
    :param header: header spec passed to ``pandas.read_csv``; when ``None`` (the
        default) the first line of each file is treated as a header line and
        skipped
    :param mode: split name; samples are loaded from ``<mode>.csv``
    Reference:
        https://grouplens.org/datasets/movielens
    """

    def __init__(self, dataset_path, sep=',', engine='c', header=None, mode='train'):
        # If pandas is told there is no header, the header line arrives as a
        # data row and must be dropped here. Otherwise pandas has already
        # consumed it, and slicing would silently discard the first sample.
        skip = 1 if header is None else 0

        # Field dimensions come from the full file rather than the split, so
        # they cover IDs appearing in any split (assuming the split files are
        # subsets of it).
        data = pd.read_csv(dataset_path, sep=sep, engine=engine, header=header)
        data = np.array(data)[skip:]
        self.field_dims = np.max(data[:, :6].astype(int) - 1, axis=0) + 1

        # The split files live next to the full data file.
        split_path = os.path.join(os.path.dirname(dataset_path), mode + '.csv')
        data = pd.read_csv(split_path, sep=sep, engine=engine, header=header)
        data = np.array(data)[skip:, :7]
        self.items = data[:, :6].astype(int) - 1  # -1 because ID begins from 1
        self.targets = self.__preprocess_target(data[:, 6].astype(float)).astype(np.float32)

    def __len__(self):
        """Return the number of samples in the loaded split."""
        return self.targets.shape[0]

    def __getitem__(self, index):
        """Return ``(zero-based feature IDs, binary target)`` for ``index``."""
        return self.items[index], self.targets[index]

    def __preprocess_target(self, target):
        """Binarize ratings in place: ``<= 3`` -> 0, ``> 3`` -> 1.

        Using ``<= 3`` (as MovieLens20MDataset does) ensures a rating of exactly
        3 is mapped to a label instead of being left as 3.0.
        """
        target[target <= 3] = 0
        target[target > 3] = 1
        return target

class MovieLens25MDataset(MovieLens20MDataset):
    """
    MovieLens 25M Dataset

    Thin wrapper over MovieLens20MDataset: the same rating binarization
    (ratings of 3 or less become 0, ratings above 3 become 1), with the CSV
    files read comma-separated, header inferred, using pandas' ``python``
    parser engine.

    :param dataset_path: MovieLens 25M ratings CSV path
    :param mode: split name; samples are loaded from ``<mode>.csv`` alongside
        the ratings file
    Reference:
        https://grouplens.org/datasets/movielens
    """

    def __init__(self, dataset_path, mode='train'):
        reader_options = {'sep': ',', 'engine': 'python', 'header': 'infer'}
        super().__init__(dataset_path, mode=mode, **reader_options)