import torch
import os
import pandas as pd
import torch.utils.data as data
import numpy as np
import scipy.sparse as sparse
from sklearn.preprocessing import normalize
import torch.nn.functional as F
from scipy.stats import percentileofscore, rankdata

# import gpustat


def get_freer_gpu():
    '''
    Return the index of the visible GPU with the most free memory.

    Runs ``nvidia-smi -q -d Memory`` through the shell, keeps the "Free" lines
    selected by the grep pipeline and reads the third whitespace-separated field
    of each as the free-memory amount.

    :return: index (int) of the GPU whose "Free" value is largest.
    :raises ValueError: if no "Free" lines were produced (e.g. nvidia-smi missing).
    '''
    import subprocess  # local import: only this helper shells out

    # Capture the pipeline's stdout directly instead of redirecting into a "tmp"
    # file in the working directory (and then leaking its open handle).
    result = subprocess.run('nvidia-smi -q -d Memory |grep -A4 GPU|grep Free',
                            shell=True, stdout=subprocess.PIPE, universal_newlines=True)
    memory_available = [int(line.split()[2]) for line in result.stdout.splitlines()]
    if not memory_available:
        raise ValueError('could not read free GPU memory from nvidia-smi output')
    return int(np.argmax(memory_available))


class TrainDatasetNeuMF(data.Dataset):
    '''
    Training dataset of observed (item, user) interactions.

    Each sample is one interaction plus ``max_negs`` item ids drawn with
    ``np.random.choice`` from ``item_list`` as negatives, together with the rows
    of ``user_stat`` / ``item_stat`` indexed by those ids.
    '''

    def __init__(self, home_dir, user_stat, item_stat, dataset, cold, n_users, n_items, scale=1.0):
        self.cold = cold
        self.dataset = dataset
        # Number of negative items sampled per positive interaction.
        self.max_negs = 15
        self.user_stat = user_stat
        self.item_stat = item_stat
        self.n_users = n_users
        # NOTE(review): only whether scale differs from 1.0 is used; any such value
        # enables the fixed 10% subsample in _load_train_data.
        self.fine_tune = bool(scale != 1.0)
        self.n_items = n_items
        self.data_dir = os.path.join('{}/data'.format(home_dir), dataset)
        self.train_df_ui, self.train_data_ui = self._load_train_data()

    def __len__(self):
        return len(self.train_data_ui)

    def _load_train_data(self):
        '''
        Read the training CSV and return ``(full_frame, interaction_array)``.

        The cold setting reads ``w_w_train.csv`` and restricts negatives to the
        items seen there; otherwise ``train.csv`` is read and every item id in
        ``range(n_items)`` is a candidate. Also sets ``item_list`` and
        ``user_list``. In fine-tune mode only a seeded 10% sample of the
        interactions is kept.
        '''
        csv_name = 'w_w_train.csv' if self.cold else 'train.csv'
        frame = pd.read_csv(os.path.join(self.data_dir, csv_name))
        self.item_list = frame.itemID.unique() if self.cold else list(range(self.n_items))
        self.user_list = frame.userID.unique()

        pairs = frame[['itemID', 'userID']]
        if self.fine_tune:
            pairs = pairs.sample(frac=0.1, random_state=1)
        print("# train users", self.n_users, "# items", self.n_items, "# train interactions", len(pairs))
        return frame, pairs.values

    def __getitem__(self, index):
        '''
        Return the positive pair at ``index`` with freshly sampled negatives.

        Tuple layout: (user, item, user ids repeated max_negs times, negative item
        ids, user stat row, item stat row, user stat rows for the repeated users,
        item stat rows for the negatives).
        '''
        item, user = self.train_data_ui[index]
        # Negatives are drawn with replacement and may include the positive item.
        sampled_negatives = torch.LongTensor(np.random.choice(self.item_list, self.max_negs))
        repeated_users = torch.LongTensor([user]).repeat(self.max_negs)
        return (user, item, repeated_users, sampled_negatives,
                self.user_stat[user], self.item_stat[item],
                self.user_stat[repeated_users], self.item_stat[sampled_negatives])


class EvalDatasetNeuMF(data.Dataset):
    '''
    Load a validation/test dataset for ranking evaluation.

    Each stored row holds one positive (item, user) pair plus a textual list of
    sampled negative items (column ``curr_neg``). ``__getitem__`` expands a row
    into a candidate list whose first entry is the positive item.
    '''

    # Columns read from the evaluation CSV, in the order __getitem__ unpacks them.
    EVAL_COLUMNS = ['itemID', 'userID', 'neg_user', 'neg_item', 'curr_neg']

    def __init__(self, home_dir, user_stat, item_stat, dataset, cold, n_users, n_items, datatype='test', split='w_w'):
        self.cold = cold
        self.n_users = n_users
        self.n_items = n_items
        self.user_stat = user_stat
        self.item_stat = item_stat
        self.dataset = dataset
        self.datatype = datatype
        self.split = split
        self.data_dir = os.path.join('{}/data'.format(home_dir), dataset)
        self.data_te = self._load_eval_data(datatype)  # object array of evaluation rows (EVAL_COLUMNS)

    def __len__(self):
        return len(self.data_te)

    def __getitem__(self, index):
        '''
        Return the candidate list for evaluation row ``index``.

        :return: 5-tuple of (user id repeated once per candidate, candidate item
            ids with the positive first, labels with 1 for the positive and 0 for
            negatives, user stat rows, item stat rows for the candidates).
        '''
        item, user, _, _, neg = self.data_te[index]
        neg_items = self._parse_neg_items(neg)
        ret_user = torch.LongTensor([user] * (len(neg_items) + 1))
        item_list = torch.LongTensor([item] + neg_items)
        labels = torch.LongTensor([1] + [0] * len(neg_items))
        return ret_user, item_list, labels, self.user_stat[ret_user], self.item_stat[item_list]

    @staticmethod
    def _parse_neg_items(neg):
        '''Parse a bracketed id list such as "[3 17 42]" or "[3, 17, 42]" into ints.'''
        # Accept both whitespace-separated (numpy repr) and comma-separated (list
        # repr) forms; the previous neg[1:-1].split() failed on commas.
        body = str(neg).strip().strip('[]').replace(',', ' ')
        return [int(token) for token in body.split()]

    # Load data for validation/testing.
    def _load_eval_data(self, datatype):
        '''
        Read the evaluation CSV, set ``df_eval_ui`` / ``item_list`` and return the rows.

        File chosen:
          * cold and split == 'w_w': ``<split>_<datatype>.csv``
          * cold and any other split: ``<split>_test.csv`` (datatype is ignored)
          * not cold: ``<datatype>.csv``
        '''
        if self.cold:
            suffix = datatype if self.split == 'w_w' else 'test'
            file_name = '{}_{}.csv'.format(self.split, suffix)
        else:
            file_name = '{}.csv'.format(datatype)
        self.df_eval_ui = pd.read_csv(os.path.join(self.data_dir, file_name))[self.EVAL_COLUMNS]

        # Mirror TrainDatasetNeuMF: cold uses the items present, otherwise all ids.
        if self.cold:
            self.item_list = self.df_eval_ui.itemID.unique()
        else:
            self.item_list = list(range(self.n_items))

        print("# {} users".format(self.datatype), self.n_users, "# items", self.n_items,
              "# {} interactions".format(self.datatype), len(self.df_eval_ui))
        return self.df_eval_ui.values


def load_feature(home_dir, dataset, feature, k=10):
    '''
    Load per-user / per-item statistic tables and popularity features.

    Reads from ``<home_dir>/data/<dataset>/``:
      * ``user_stat_100.npy`` and ``item_stat_100.npy``: column 1 is divided by 10
        and then has its ``torch.median`` subtracted;
      * ``user_popularity_new<suffix>.npy`` and ``item_popularity_new<suffix>.npy``,
        where suffix is empty for ``k == 10`` and ``'_<k>'`` otherwise.

    ``feature`` is accepted for interface compatibility but not used.

    :return: (save_path, user_feature, item_feature, user_stat, item_stat); the
        save path is also printed.
    '''
    suffix = '' if k == 10 else '_{}'.format(k)
    dataset_dir = os.path.join('{}/data/'.format(home_dir), dataset)

    def _centred_stat(file_name):
        # Scale column 1 by 1/10, then subtract its median.
        with open(os.path.join(dataset_dir, file_name), 'rb') as handle:
            table = torch.FloatTensor(np.load(handle))
        scaled = table[:, 1] / 10
        table[:, 1] = scaled - scaled.median()
        return table

    user_stat = _centred_stat('user_stat_100.npy')
    item_stat = _centred_stat('item_stat_100.npy')

    save = '{}/data/uni_recsys/ours/{}_popularity{}.pt'.format(home_dir, dataset, suffix)
    user_path = os.path.join(dataset_dir, 'user_popularity_new{}.npy'.format(suffix))
    item_path = os.path.join(dataset_dir, 'item_popularity_new{}.npy'.format(suffix))
    user_feature = torch.FloatTensor(np.load(user_path))
    item_feature = torch.FloatTensor(np.load(item_path))
    print(save)
    return save, user_feature, item_feature, user_stat, item_stat


def load_pre_stat(home_dir, dataset, train_dataset='epinion'):
    '''
    Load pre-computed train/test statistics as float tensors.

    :param home_dir: project root; files are read from ``<home_dir>/data/<name>/``.
    :param dataset: dataset whose ``test_stat.npy`` is loaded.
    :param train_dataset: dataset whose ``train_stat.npy`` is loaded. Defaults to
        ``'epinion'``, the value this function previously hard-coded.
    :return: (train_stat, test_stat) as ``torch.FloatTensor``.
    '''
    data_root = '{}/data/'.format(home_dir)
    train_stat = torch.FloatTensor(
        np.load(os.path.join(data_root, train_dataset, 'train_stat.npy')))
    test_stat = torch.FloatTensor(
        np.load(os.path.join(data_root, dataset, 'test_stat.npy')))
    return train_stat, test_stat


def cross(source, model, user, item, user_stat, item_stat):
    '''
    Score user/item pairs by combining ``model``'s features with ``source``.

    Features come from ``model.user_feature`` / ``model.item_feature``. When
    ``source.train_stat`` is truthy, both sides are passed through
    ``source.item_percentile`` and the summed elementwise product is returned.
    Otherwise embeddings are built with ``source.get_user_embedding`` /
    ``source.get_item_embedding`` (using column 0 of the stats). The summed
    product is then scaled by ``sigmoid`` of column 1 of ``item_stat``.

    :return: scores with a trailing singleton axis (``unsqueeze(-1)``).
    '''
    u_feat = model.user_feature(user)
    i_feat = model.item_feature(item)

    # Stats with rank > 2 (presumably batch x candidates x n_stats -- confirm with
    # callers) are sliced on their third axis. NOTE(review): the rank test looks
    # only at user_stat, and item_stat is sliced to match.
    has_extra_axis = len(user_stat.shape) > 2
    if has_extra_axis:
        user_self_stat, item_self_stat = user_stat[:, :, 0], item_stat[:, :, 0]
    else:
        user_self_stat, item_self_stat = user_stat[:, 0], item_stat[:, 0]

    if source.train_stat:
        # Both sides go through the same item_percentile module in this mode.
        pair_scores = (source.item_percentile(u_feat) * source.item_percentile(i_feat)).sum(dim=-1)
        return pair_scores.unsqueeze(-1)

    percentile_basis = source.item_percentile.weight.permute(1, 0)
    # Note the transform/transform2 argument order is swapped between the two calls.
    user_vec = source.get_user_embedding(u_feat, percentile_basis, source.transform,
                                         user_self_stat, source.transform2, source.u_pos)
    item_vec = source.get_item_embedding(i_feat, percentile_basis, source.transform2,
                                         item_self_stat, source.transform, source.i_pos)
    raw_scores = (user_vec * item_vec).sum(dim=-1).unsqueeze(-1)
    gate_stat = item_stat[:, :, 1] if has_extra_axis else item_stat[:, 1]
    return raw_scores * torch.sigmoid(gate_stat).unsqueeze(-1)


