import random
import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch.utils.data as data


class PointData(data.Dataset):
    def __init__(self, neg_set, is_training=True, neg_label_val=0.):
        """
        Point-wise dataset: one (user, item) -> label sample per interaction
        Parameters
        ----------
        neg_set : iterable of (user, item, rating, negative items) tuples, as produced by the Sampler
        is_training : boolean, when True every negative item becomes an extra sample labelled
                      `neg_label_val`; when False only the observed (user, item, rating) rows are kept
        neg_label_val : float, label assigned to each negative sample
        """
        super(PointData, self).__init__()
        pairs = []
        targets = []
        for user, pos_item, rating, negatives in neg_set:
            uid = int(user)
            pairs.append([uid, int(pos_item)])
            targets.append(rating)

            if not is_training:
                continue
            for neg_item in negatives:
                pairs.append([uid, int(neg_item)])
                targets.append(neg_label_val)

        self.features_fill = pairs
        # labels are stored as one float32 array so batches collate to a float tensor
        self.labels_fill = np.array(targets, dtype=np.float32)

    def __len__(self):
        return len(self.labels_fill)

    def __getitem__(self, idx):
        user, item = self.features_fill[idx]
        return user, item, self.labels_fill[idx]


class PairData(data.Dataset):
    def __init__(self, neg_set, is_training=True):
        """
        Dataset formatter adapted pair-wise algorithms
        Parameters
        ----------
        neg_set : iterable of (user, item, rating, negative items) tuples; the rating is
                  ignored and every sample is labelled 1
        is_training : bool, when True one (user, item, negative item) triple is emitted per
                      negative item; when False a single (user, item, item) triple is emitted
                      per row, with the positive item repeated in the negative slot
        """
        super(PairData, self).__init__()
        self.features_fill = []
        label = np.float32(1)

        for u, i, _r, js in neg_set:
            u, i = int(u), int(i)
            if is_training:
                # cast negatives like users/items (and like PointData does) so float ids
                # coming from pandas do not end up in an index column
                self.features_fill.extend([u, i, int(j), label] for j in js)
            else:
                self.features_fill.append([u, i, i, label])

    def __len__(self):
        return len(self.features_fill)

    def __getitem__(self, idx):
        user, item_i, item_j, label = self.features_fill[idx]
        return user, item_i, item_j, label


class UAEData(data.Dataset):
    def __init__(self, user_num, item_num, train_set, test_set):
        """
        user-level Dataset formatter adapted AutoEncoder-like algorithms
        Parameters
        ----------
        user_num : int, the number of users
        item_num : int, the number of items
        train_set : pd.DataFrame, training set with 'user' and 'item' columns
        test_set : pd.DataFrame, test set with 'user' and 'item' columns
        """
        super(UAEData, self).__init__()
        self.user_num = user_num
        self.item_num = item_num

        self.R = sp.dok_matrix((user_num, item_num), dtype=np.float32)  # true label (train + test)
        self.mask_R = sp.dok_matrix((user_num, item_num), dtype=np.float32)  # only concern interaction known (train)
        self.user_idx = np.arange(user_num)

        # iterate the two columns directly instead of iterrows(): same int() conversion,
        # without building a Series per row
        for user, item in zip(train_set['user'], train_set['item']):
            user, item = int(user), int(item)
            self.R[user, item] = 1.
            self.mask_R[user, item] = 1.

        for user, item in zip(test_set['user'], test_set['item']):
            self.R[int(user), int(item)] = 1.

    def __len__(self):
        return self.user_num

    def __getitem__(self, idx):
        u = self.user_idx[idx]
        # toarray() instead of the `.A` shorthand, which is deprecated/removed in recent SciPy;
        # ravel() keeps a 1-D vector even when item_num == 1 (squeeze() would give a 0-d array)
        ur = self.R[idx].toarray().ravel()
        mask_ur = self.mask_R[idx].toarray().ravel()

        return u, ur, mask_ur


class IAEData(data.Dataset):
    def __init__(self, user_num, item_num, train_set, test_set):
        """
        item-level Dataset formatter adapted AutoEncoder-like algorithms
        Parameters
        ----------
        user_num : int, the number of users
        item_num : int, the number of items
        train_set : pd.DataFrame, training set with 'user' and 'item' columns
        test_set : pd.DataFrame, test set with 'user' and 'item' columns
        """
        super(IAEData, self).__init__()
        self.user_num = user_num
        self.item_num = item_num

        # matrices are item x user: one row per item
        self.R = sp.dok_matrix((item_num, user_num), dtype=np.float32)  # true label (train + test)
        self.mask_R = sp.dok_matrix((item_num, user_num), dtype=np.float32)  # only concern interaction known (train)
        self.item_idx = np.arange(item_num)

        # iterate the two columns directly instead of iterrows(): same int() conversion,
        # without building a Series per row
        for user, item in zip(train_set['user'], train_set['item']):
            user, item = int(user), int(item)
            self.R[item, user] = 1.
            self.mask_R[item, user] = 1.

        for user, item in zip(test_set['user'], test_set['item']):
            self.R[int(item), int(user)] = 1.

    def __len__(self):
        return self.item_num

    def __getitem__(self, idx):
        i = self.item_idx[idx]
        # toarray() instead of the `.A` shorthand, which is deprecated/removed in recent SciPy;
        # ravel() keeps a 1-D vector even when user_num == 1 (squeeze() would give a 0-d array)
        ir = self.R[idx].toarray().ravel()
        mask_ir = self.mask_R[idx].toarray().ravel()

        return i, ir, mask_ir


class BuildCorpus(object):
    def __init__(self, corpus_df, window=None, max_item_num=20000, unk='<UNK>'):
        """
        Item2Vec Specific Process, building item-corpus by dataframe
        Parameters
        ----------
        corpus_df : pd.DataFrame, the whole dataset with 'user' and 'item' columns
        window : int, window size; when None, the longest per-user item sequence length is used
        max_item_num : int, the maximum item pool size (vocabulary size, including the unk token)
        unk : str, items outside the vocabulary (and window padding) are all mapped to this value
        """
        # if window is None (no timestamp), use the longest user sequence as the window size
        if window is None:
            window = corpus_df.groupby('user')['item'].count().max()
        self.window = window
        self.max_item_num = max_item_num
        self.unk = unk

        # build corpus: one row per user with the list of that user's items
        self.corpus = corpus_df.groupby('user')['item'].apply(lambda x: x.values.tolist()).reset_index()

        self.wc = None  # item -> occurrence count (unk seeded with 1), filled by build()
        self.idx2item = None  # index -> item, ordered by descending count, filled by build()
        self.item2idx = None  # item -> index, filled by build()
        self.vocab = None  # set of in-vocabulary items, filled by build()

    def skip_gram(self, record, i):
        """
        Return the center item record[i] and its context: `window` items on each side,
        padded with unk where the sequence is shorter than the window.
        """
        iitem = record[i]
        left = record[max(i - self.window, 0): i]
        right = record[i + 1: i + 1 + self.window]
        return iitem, [self.unk] * (self.window - len(left)) + left + right + \
            [self.unk] * (self.window - len(right))

    def build(self):
        """Count item occurrences in the corpus and build the index mappings."""
        print('building vocab...')
        self.wc = {self.unk: 1}
        for _, row in self.corpus.iterrows():
            for item in row['item']:
                self.wc[item] = self.wc.get(item, 0) + 1

        ranked = sorted(self.wc, key=self.wc.get, reverse=True)
        idx2item = ranked[:self.max_item_num]
        # skip_gram() always pads with unk, so convert() needs unk in the vocabulary even when
        # truncation would have dropped it; it takes the last slot so other indices are unchanged
        if self.unk not in idx2item:
            idx2item = idx2item[:max(self.max_item_num - 1, 0)] + [self.unk]
        self.idx2item = idx2item
        self.item2idx = {item: idx for idx, item in enumerate(idx2item)}
        self.vocab = set(self.item2idx)
        print('build done')

    def convert(self, corpus_train_df):
        """
        Turn a dataframe of interactions into skip-gram training pairs (build() must run first)
        Parameters
        ----------
        corpus_train_df : pd.DataFrame, interactions with 'user' and 'item' columns

        Returns
        -------
        dt : list of (center item index, list of 2 * window context item indices)
        """
        print('converting train by corpus build before...')
        dt = []
        corpus = corpus_train_df.groupby('user')['item'].apply(lambda x: x.values.tolist()).reset_index()
        for _, row in corpus.iterrows():
            sent = [item if item in self.vocab else self.unk for item in row['item']]
            for i in range(len(sent)):
                iitem, oitems = self.skip_gram(sent, i)
                dt.append((self.item2idx[iitem], [self.item2idx[oitem] for oitem in oitems]))

        print('conversion done')

        return dt


class PermutedSubsampledCorpus(data.Dataset):
    def __init__(self, dt, ws=None):
        """
        Item2Vec training samples, optionally subsampled
        Parameters
        ----------
        dt : list of (center item index, context item indices) pairs, e.g. the output of BuildCorpus.convert
        ws : indexable by center item index, or None; ws[idx] is the probability of dropping a
             pair whose center item is idx. When None, `dt` is used as-is (same list object).
        """
        if ws is None:
            self.dt = dt
        else:
            # one uniform draw per pair, in order: keep it only when the draw exceeds ws[center]
            self.dt = [(center, context) for center, context in dt if random.random() > ws[center]]

    def __len__(self):
        return len(self.dt)

    def __getitem__(self, idx):
        center, context = self.dt[idx]
        return center, np.array(context)


def get_weights(wc, idx2item, ss_t, whether_weights, return_ws=False):
    """
    Compute vocabulary size, optional frequency weights and (optionally) subsampling probabilities
    Parameters
    ----------
    wc : dict, item -> occurrence count; must contain every entry of idx2item
    idx2item : list, vocabulary ordered by index
    ss_t : float, subsampling threshold used for the drop probabilities
    whether_weights : bool, whether to return the normalized item frequencies as weights
    return_ws : bool, when True also return the per-index drop probabilities
                clip(1 - sqrt(ss_t / freq), 0, 1), suitable for PermutedSubsampledCorpus(ws=...)

    Returns
    -------
    vocab_size : int, len(idx2item)
    weights : np.ndarray of normalized frequencies, or None when whether_weights is falsy
    ws : np.ndarray, only present when return_ws is True
    """
    wf = np.array([wc[item] for item in idx2item])
    wf = wf / wf.sum()
    vocab_size = len(idx2item)
    weights = wf if whether_weights else None
    if not return_ws:
        # default keeps the original two-value return for existing callers
        return vocab_size, weights

    ws = np.clip(1 - np.sqrt(ss_t / wf), 0, 1)
    return vocab_size, weights, ws


def item2vec_data(train_set, test_set, window, item_num, batch_size, ss_t=1e-5, unk='<UNK>', weights=None):
    """
    Build the Item2Vec training DataLoader

    The vocabulary is built from train_set and test_set together, while training pairs are
    generated from train_set only.

    Parameters
    ----------
    train_set : pd.DataFrame, training interactions with 'user' and 'item' columns
    test_set : pd.DataFrame, test interactions with 'user' and 'item' columns
    window : int, rolling window size (None means the longest user sequence)
    item_num : int, the number of total items; the vocabulary holds up to item_num + 1 entries
    batch_size : int, batch size of the returned loader
    ss_t : float, subsampling threshold forwarded to get_weights
    unk : str, token for out-of-vocabulary items and window padding
    weights : whether get_weights should compute frequency weights

    Returns
    -------
    data_loader : torch.utils.data.DataLoader, shuffled batches of (center index, context indices)
    vocab_size : int, vocabulary size
    item2idx : dict, the mapping information for item to index code

    NOTE(review): the weights computed by get_weights are not returned, and the dataset is
    built without subsampling probabilities, so `ss_t` does not affect the loader — confirm
    this is intended.
    """
    whole_df = pd.concat([train_set, test_set], ignore_index=True)
    corpus_builder = BuildCorpus(whole_df, window, item_num + 1, unk)
    corpus_builder.build()

    train_pairs = corpus_builder.convert(train_set)
    vocab_size, _unused_weights = get_weights(corpus_builder.wc, corpus_builder.idx2item, ss_t, weights)
    loader = data.DataLoader(PermutedSubsampledCorpus(train_pairs), batch_size=batch_size, shuffle=True)

    return loader, vocab_size, corpus_builder.item2idx
