
import pandas as pd
import torch
from torch.utils.data import Dataset


# load_word_embedding
def load_word_embedding(word2vec_file):
    """Load a space-separated text embedding file into a list table and a word index.

    Each non-blank line is expected to look like ``word v1 v2 ... vD`` (single
    spaces between fields). Row 0 of the returned table is reserved for
    ``'<UNK>'`` and filled with zeros as wide as the first real vector. A
    word2vec-style header line (``<count> <dim>``) is NOT detected and would be
    loaded as an ordinary word.

    Args:
        word2vec_file: path to the UTF-8 embedding text file.

    Returns:
        ``(word_emb, word_dict)`` where ``word_emb`` is a list of float lists and
        ``word_dict`` maps word -> row index into ``word_emb`` (``'<UNK>'`` -> 0).
        If a word appears more than once, the last occurrence wins.

    Raises:
        ValueError: if the file contains no embedding lines.
    """
    word_emb = [[0]]            # placeholder row for '<UNK>'; resized once the dim is known
    word_dict = {'<UNK>': 0}    # word -> row index in word_emb

    with open(word2vec_file, encoding='utf-8') as f:
        for line in f:  # stream lines instead of materialising readlines()
            # Drop the newline and any trailing spaces so the last field parses cleanly.
            line = line.rstrip()
            if not line:
                continue
            tokens = line.split(' ')
            vector = [float(v) for v in tokens[1:]]
            # Use the row actually being appended: len(word_dict) drifts out of
            # sync with word_emb as soon as a word repeats.
            word_dict[tokens[0]] = len(word_emb)
            word_emb.append(vector)

    if len(word_emb) < 2:
        raise ValueError(f'no embeddings found in {word2vec_file!r}')

    word_emb[0] = [0] * len(word_emb[1])
    return word_emb, word_dict


# MSE over a dataloader of (user_vocablist, item_vocablist, ratings) batches
def predict_mse(model, dataloader, device):
    """Return the mean squared error of ``model`` over every sample in ``dataloader``.

    Each batch must unpack into ``(user_vocablist, item_vocablist, ratings)``; all
    three are moved to ``device`` and the model is called as
    ``model(user_vocablist, item_vocablist)``. The model is switched to eval mode
    (and left there) and gradients are disabled for the pass.
    An empty loader raises ZeroDivisionError.
    """
    model.eval()
    total_sq_err = 0
    n_samples = 0

    with torch.no_grad():
        for batch in dataloader:
            users, items, targets = (t.to(device) for t in batch)
            preds = model(users, items)
            # Per-batch SUM (not mean) so the final division weights every sample
            # equally even when the last batch is smaller.
            batch_sse = torch.nn.functional.mse_loss(preds, targets, reduction='sum')
            total_sq_err += batch_sse.item()
            n_samples += len(targets)

    return total_sq_err / n_samples


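# MSE for TransNet-style models: batches carry an extra target-review tensor and the model returns (_, rating prediction)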
def predict_mse_for_transnet(model, dataloader, device):
    """Compute the per-sample mean squared error of a TransNet-style model.

    Each batch must unpack into
    ``(user_vocablist, item_vocablist, <target review>, ratings)``. All four tensors
    are moved to ``device``, but the target review is not fed to the model. The
    model is called as ``model(user_vocablist, item_vocablist)`` and must return a
    pair whose second element is the rating prediction.

    The model is put in eval mode (not restored afterwards) and run under
    ``torch.no_grad()``. An empty loader raises ZeroDivisionError.
    """
    total, count = 0, 0
    model.eval()
    with torch.no_grad():
        device_batches = ([t.to(device) for t in batch] for batch in dataloader)
        for user_vl, item_vl, _review, target in device_batches:
            _, rating_pred = model(user_vl, item_vl)
            total += torch.nn.functional.mse_loss(rating_pred, target, reduction='sum').item()
            count += len(target)
    return total / count


# MSE for models that also take user/item ID tensors as inputs
def predict_mse_with_IDemb(model, dataloader, device):
    """Mean squared error of an ID-aware model over all samples of ``dataloader``.

    Batches must unpack into exactly
    ``(uid, iid, user_vocablist, item_vocablist, ratings)``. Every tensor is moved
    to ``device`` and the model is called with the first four, in that order.
    Eval mode is set on the model and not restored; gradients are disabled.
    An empty loader raises ZeroDivisionError.
    """
    running_sse = 0
    running_n = 0
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            uid, iid, user_vl, item_vl, target = [x.to(device) for x in batch]
            estimate = model(uid, iid, user_vl, item_vl)
            batch_sse = torch.nn.functional.mse_loss(estimate, target, reduction='sum')
            running_sse += batch_sse.item()
            running_n += len(target)
    return running_sse / running_n


def predict_mse_ournet(model, dataloader, device):
    """Mean squared error for a model returning a 5-tuple whose first item is the prediction.

    Batches must unpack into exactly
    ``(uid, iid, user_vocablist, item_vocablist, ratings)``; each tensor is moved
    to ``device``. The model must return exactly five values; only the first is
    scored. The model is left in eval mode and gradients are disabled.
    An empty loader raises ZeroDivisionError.
    """
    def _to_device(batch):
        # Move every tensor of one batch to the target device.
        return [tensor.to(device) for tensor in batch]

    sq_err, count = 0, 0
    model.eval()
    with torch.no_grad():
        for raw_batch in dataloader:
            uid, iid, user_vl, item_vl, target = _to_device(raw_batch)
            outputs = model(uid, iid, user_vl, item_vl)
            prediction, _, _, _, _ = outputs
            sq_err += torch.nn.functional.mse_loss(prediction, target, reduction='sum').item()
            count += len(target)
    return sq_err / count


def predict_mse_with_Attn(model, dataloader, device=None):
    """Run ``model`` over ``dataloader`` and return the inputs/outputs of the LAST batch.

    No MSE is computed here. Every batch is forwarded, but only the final batch's
    tensors are returned. To get results for all samples, use a loader that
    yields a single batch.

    Args:
        model: module called as ``model(uid, iid, user_vocablist, item_vocablist)``
            and returning exactly five values.
        dataloader: iterable of ``(uid, iid, user_vocablist, item_vocablist, ratings)``
            batches.
        device: optional device. When given, each batch tensor is moved there
            first. The default ``None`` keeps tensors where the loader put them,
            which matches the previous behaviour.

    Returns:
        ``(uid, iid, ratings, predict, score_u, score_i, first_u_gattn, first_i_gattn)``
        for the last batch. The last four names follow the unpacking of the
        model output.

    Raises:
        ValueError: if ``dataloader`` yields no batches. Previously this surfaced
            as an ``UnboundLocalError``.
    """
    last = None
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            if device is not None:
                batch = [x.to(device) for x in batch]
            uid, iid, user_vocablist, item_vocablist, ratings = batch
            predict, score_u, score_i, first_u_gattn, first_i_gattn = model(uid, iid, user_vocablist, item_vocablist)
            last = (uid, iid, ratings, predict, score_u, score_i, first_u_gattn, first_i_gattn)

    if last is None:
        raise ValueError('dataloader yielded no batches')
    return last



# Vocab
class Vocab_Dataset(Dataset):
    """Dataset of ``(uid, iid, user_vocab_ids, item_vocab_ids, rating)`` rows.

    Reads a pickled DataFrame with columns ``userID``, ``itemID``,
    ``user_<vocab_category_u>``, ``item_<vocab_category_i>`` and ``rating``.
    Each vocab cell is a sequence of words. It is truncated to ``config.top_k``
    words and mapped through ``word_dict``; unknown words map to
    ``word_dict['<UNK>']``. The result is right-padded with ``PAD_WORD_idx`` to
    ``top_k`` so all rows can be stacked into one tensor.

    Args:
        data_path: path to the pickled DataFrame. It is loaded with
            ``pd.read_pickle``, which can execute code, so only load trusted files.
        word_dict: mapping word -> embedding row index; must contain ``'<UNK>'``.
        config: object exposing ``vocab_category_u``, ``vocab_category_i`` and ``top_k``.
    """

    def __init__(self, data_path, word_dict, config):
        self.word_dict = word_dict  # {'word': idx}

        # Unknown words and padding both use the '<UNK>' index.
        self.PAD_WORD_idx = self.word_dict['<UNK>']

        self.vocab_category_u = config.vocab_category_u
        self.vocab_category_i = config.vocab_category_i
        print(f'processing dataset - vocab_category_u : {self.vocab_category_u}')
        print(f'processing dataset - vocab_category_i : {self.vocab_category_i}')

        self.top_k = config.top_k

        user_col = 'user_' + str(self.vocab_category_u)
        item_col = 'item_' + str(self.vocab_category_i)
        df = pd.read_pickle(data_path)[['userID', 'itemID', user_col, item_col, 'rating']]

        # Build the id lists straight from the columns rather than writing them
        # back into the column-selected slice (avoids SettingWithCopyWarning).
        user_ids = [self._vocab2id(v, self.top_k) for v in df[user_col]]
        item_ids = [self._vocab2id(v, self.top_k) for v in df[item_col]]

        self.rating = torch.Tensor(df['rating'].to_list()).view(-1, 1)   # (N, 1)

        self.uid = torch.LongTensor(df['userID'].to_list()).view(-1, 1)  # (N, 1)
        self.iid = torch.LongTensor(df['itemID'].to_list()).view(-1, 1)  # (N, 1)

        self.user_vocablist = torch.LongTensor(user_ids)  # (N, top_k)
        self.item_vocablist = torch.LongTensor(item_ids)  # (N, top_k)

    def __getitem__(self, idx):
        return self.uid[idx], self.iid[idx], self.user_vocablist[idx], self.item_vocablist[idx], self.rating[idx]

    def __len__(self):
        return self.rating.shape[0]

    def _vocab2id(self, vocab_list, top_k):
        """Map up to ``top_k`` words to ids and right-pad to ``top_k`` with ``PAD_WORD_idx``.

        A missing cell (``None``, or a float such as pandas' NaN) is treated as an
        empty list. Without padding, a row shorter than ``top_k`` makes
        ``torch.LongTensor`` reject the ragged batch. If ``top_k`` is None the list
        is neither truncated nor padded.
        """
        if vocab_list is None or isinstance(vocab_list, float):
            vocab_list = []

        unk = self.word_dict['<UNK>']
        vids = [self.word_dict.get(vocab, unk) for vocab in vocab_list[:top_k]]

        if top_k is not None:
            vids += [self.PAD_WORD_idx] * (top_k - len(vids))
        return vids
    

# reviews
class Review_Dataset(Dataset):
    """Dataset of concatenated user/item review text mapped to fixed-length word-id rows.

    Reads a pickled DataFrame with columns ``userID``, ``itemID``,
    ``user_reviews_concat``, ``item_reviews_concat`` and ``rating``. Each review
    cell is split on whitespace and mapped through ``word_dict``; unknown words map
    to ``word_dict['<UNK>']``. Each row is then truncated or right-padded to exactly
    ``max_length`` ids.

    ``__getitem__`` returns ``(uid, iid, user_reviews, item_reviews, rating)`` when
    ``config.with_idemb == True``, otherwise ``(user_reviews, item_reviews, rating)``.

    Args:
        data_path: path to the pickled DataFrame. It is loaded with
            ``pd.read_pickle``, which can execute code, so only load trusted files.
        word_dict: mapping word -> embedding row index; must contain ``'<UNK>'``.
        config: object exposing ``with_idemb``.
        max_length: number of word ids per review row.
    """

    def __init__(self, data_path, word_dict, config, max_length=400):
        self.word_dict = word_dict          # {'word': idx}
        self.get_id_emb = config.with_idemb  # compared with `== True` below
        self.max_length = max_length

        # Index used both for unknown words and for padding.
        self.PAD_WORD_idx = self.word_dict['<UNK>']

        df = pd.read_pickle(data_path)[['userID', 'itemID', 'user_reviews_concat', 'item_reviews_concat', 'rating']]

        # Convert straight from the columns instead of writing back into the
        # column-selected slice (avoids SettingWithCopyWarning).
        user_ids = [self._review2id(text) for text in df['user_reviews_concat']]
        item_ids = [self._review2id(text) for text in df['item_reviews_concat']]

        self.rating = torch.Tensor(df['rating'].to_list()).view(-1, 1)  # (N, 1)

        if self.get_id_emb == True:
            self.uid = torch.LongTensor(df['userID'].to_list()).view(-1, 1)  # (N, 1)
            self.iid = torch.LongTensor(df['itemID'].to_list()).view(-1, 1)  # (N, 1)

        self.user_reviews = torch.LongTensor(user_ids)  # (N, max_length)
        self.item_reviews = torch.LongTensor(item_ids)  # (N, max_length)

    def __getitem__(self, idx):
        if self.get_id_emb == True:
            return self.uid[idx], self.iid[idx], self.user_reviews[idx], self.item_reviews[idx], self.rating[idx]
        return self.user_reviews[idx], self.item_reviews[idx], self.rating[idx]

    def __len__(self):
        return self.rating.shape[0]

    def _review2id(self, reviews):
        """Convert one review string into exactly ``max_length`` word ids.

        Non-string cells (e.g. None/NaN for a missing review) become an all-padding
        row. Returning a shorter list would make the rows ragged, and
        ``torch.LongTensor`` rejects ragged rows.
        """
        if not isinstance(reviews, str):
            return [self.PAD_WORD_idx] * self.max_length

        unk = self.word_dict['<UNK>']
        # Mapping is per-word, so truncating before the lookup is equivalent.
        rids = [self.word_dict.get(word, unk) for word in reviews.split()[:self.max_length]]
        return rids + [self.PAD_WORD_idx] * (self.max_length - len(rids))
    


# for_transnet
class Transnet_Dataset(Dataset):
    """TransNet-style dataset: user/item review rows plus (optionally) the target review.

    Reads a pickled DataFrame with columns ``userID``, ``itemID``,
    ``user_reviews_concat``, ``item_reviews_concat``, ``review`` and ``rating``.
    Each text cell is split on whitespace and mapped through ``word_dict``; unknown
    words map to ``word_dict['<UNK>']``. Each row is then truncated or right-padded
    to exactly ``max_length`` ids.

    ``__getitem__`` returns:
      * ``get_1_review`` and ``with_idemb``: ``(uid, iid, user_reviews, item_reviews, rating)``
      * ``get_1_review`` only: ``(user_reviews, item_reviews, review, rating)``
      * otherwise: ``(user_reviews, item_reviews, rating)``

    Args:
        data_path: path to the pickled DataFrame. It is loaded with
            ``pd.read_pickle``, which can execute code, so only load trusted files.
        word_dict: mapping word -> embedding row index; must contain ``'<UNK>'``.
        config: object exposing ``with_idemb``.
        max_length: number of word ids per text row.
        get_1_review: also encode the per-row target ``review`` column.
    """

    def __init__(self, data_path, word_dict, config, max_length=400, get_1_review=True):
        self.word_dict = word_dict          # {'word': idx}
        self.get_id_emb = config.with_idemb  # compared with `== True` below

        self.max_length = max_length
        self.get_1_review = get_1_review

        # Index used both for unknown words and for padding.
        self.PAD_WORD_idx = self.word_dict['<UNK>']

        df = pd.read_pickle(data_path)[['userID', 'itemID', 'user_reviews_concat', 'item_reviews_concat', 'review', 'rating']]

        # Convert straight from the columns instead of writing back into the
        # column-selected slice (avoids SettingWithCopyWarning).
        user_ids = [self._review2id(text) for text in df['user_reviews_concat']]
        item_ids = [self._review2id(text) for text in df['item_reviews_concat']]

        self.rating = torch.Tensor(df['rating'].to_list()).view(-1, 1)  # (N, 1)

        if self.get_id_emb == True:
            self.uid = torch.LongTensor(df['userID'].to_list()).view(-1, 1)  # (N, 1)
            self.iid = torch.LongTensor(df['itemID'].to_list()).view(-1, 1)  # (N, 1)

        if self.get_1_review == True:
            self.review = torch.LongTensor([self._review2id(text) for text in df['review']])  # (N, max_length)

        self.user_reviews = torch.LongTensor(user_ids)  # (N, max_length)
        self.item_reviews = torch.LongTensor(item_ids)  # (N, max_length)

    def __getitem__(self, idx):
        if self.get_1_review == True:
            if self.get_id_emb == True:
                # NOTE(review): this branch omits self.review, unlike the id-less
                # branch below. Left unchanged because callers may rely on the
                # tuple arity; confirm whether the target review belongs here.
                return self.uid[idx], self.iid[idx], self.user_reviews[idx], self.item_reviews[idx], self.rating[idx]
            return self.user_reviews[idx], self.item_reviews[idx], self.review[idx], self.rating[idx]
        # NOTE(review): uid/iid are not returned here even when with_idemb is set.
        return self.user_reviews[idx], self.item_reviews[idx], self.rating[idx]

    def __len__(self):
        return self.rating.shape[0]

    def _review2id(self, reviews):
        """Convert one review string into exactly ``max_length`` word ids.

        Non-string cells (e.g. None/NaN for a missing review) become an all-padding
        row. Returning a shorter list would make the rows ragged, and
        ``torch.LongTensor`` rejects ragged rows.
        """
        if not isinstance(reviews, str):
            return [self.PAD_WORD_idx] * self.max_length

        unk = self.word_dict['<UNK>']
        # Mapping is per-word, so truncating before the lookup is equivalent.
        rids = [self.word_dict.get(word, unk) for word in reviews.split()[:self.max_length]]
        return rids + [self.PAD_WORD_idx] * (self.max_length - len(rids))