
import pandas as pd
import torch
from torch.utils.data import Dataset


# Load pretrained word vectors from a text file.
def load_word_embedding(word2vec_file):
    """Read a space-separated text embedding file into a vector list and a vocabulary.

    Each non-blank line is expected to look like ``word v1 v2 ... vd`` with the
    fields separated by single spaces (GloVe-style text format — assumption,
    confirm against the data actually used). Row 0 is reserved for ``'<UNK>'``
    and filled with zeros of the same width as the first vector read.

    If a word occurs more than once, only its first vector is kept, so
    ``word_dict[w]`` is always the row of ``word_emb`` holding ``w``'s vector.
    A ``'<UNK>'`` line in the file is therefore ignored and ``'<UNK>'`` stays 0.

    Args:
        word2vec_file: path to the UTF-8 encoded text file.

    Returns:
        (word_emb, word_dict): ``word_emb`` is a list of float lists where
        ``word_emb[i]`` is the vector of the word mapped to ``i``;
        ``word_dict`` maps word -> row index, with ``'<UNK>'`` -> 0.

    Raises:
        ValueError: if the file contains no word vectors.
    """
    word_emb = [[0]]          # row 0: placeholder for <UNK>, resized once the width is known
    word_dict = {'<UNK>': 0}  # unknown tokens map to row 0

    with open(word2vec_file, encoding='utf-8') as f:
        for line in f:  # stream the file instead of materializing readlines()
            if not line.strip():
                continue  # skip blank lines instead of creating an empty-vector entry
            # rstrip() drops the newline and any trailing spaces, which would
            # otherwise yield an empty last token and make float() raise.
            tokens = line.rstrip().split(' ')
            word = tokens[0]
            if word in word_dict:
                # Keep the first vector: re-assigning the index here would
                # desynchronize word_dict from word_emb for every later word.
                continue
            word_emb.append([float(value) for value in tokens[1:]])
            word_dict[word] = len(word_emb) - 1

    if len(word_emb) == 1:
        raise ValueError(f"no word vectors found in {word2vec_file!r}")

    word_emb[0] = [0] * len(word_emb[1])
    return word_emb, word_dict


# Evaluation: MSE for models called as model(user_reviews, item_reviews).
def predict_mse(model, dataloader, device):
    """Return the mean squared error of ``model`` over all samples in ``dataloader``.

    Each batch must unpack into ``(user_reviews, item_reviews, ratings)``. The
    model is called as ``model(user_reviews, item_reviews)``, and the summed
    squared error (``mse_loss(reduction='sum')``) is accumulated over all
    batches. The result is the mean over the whole loader, not a mean of
    per-batch means.

    The model is put in eval mode with gradients disabled while evaluating.
    Its previous train/eval mode is restored afterwards, so a caller in the
    middle of training does not have to remember to call ``model.train()``.

    Args:
        model: a ``torch.nn.Module``.
        dataloader: iterable of batches as described above.
        device: device the batch tensors are moved to.

    Returns:
        float: summed squared error divided by the number of samples.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    total_se, sample_count = 0.0, 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in dataloader:
                user_reviews, item_reviews, ratings = (x.to(device) for x in batch)
                predict = model(user_reviews, item_reviews)

                # Sum (not mean) so batches of different sizes are weighted correctly.
                total_se += torch.nn.functional.mse_loss(predict, ratings, reduction='sum').item()
                sample_count += len(ratings)
    finally:
        model.train(was_training)

    if sample_count == 0:
        raise ValueError("dataloader yielded no samples; cannot compute MSE")
    return total_se / sample_count


# Evaluation: MSE for TransNet-style models that return (aux_output, prediction).
def predict_mse_for_transnet(model, dataloader, device):
    """Return the mean squared error of a TransNet-style ``model`` over ``dataloader``.

    Each batch must unpack into ``(user_reviews, item_reviews, target_review, ratings)``.
    The target review is ignored here. The model is called as
    ``model(user_reviews, item_reviews)`` and must return a pair whose second
    element is the rating prediction. Squared errors are summed over all
    batches and divided by the total sample count.

    The model is put in eval mode with gradients disabled while evaluating.
    Its previous train/eval mode is restored afterwards.

    Args:
        model: a ``torch.nn.Module``.
        dataloader: iterable of batches as described above.
        device: device the batch tensors are moved to.

    Returns:
        float: summed squared error divided by the number of samples.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    total_se, sample_count = 0.0, 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in dataloader:
                user_reviews, item_reviews, _, ratings = (x.to(device) for x in batch)
                _, predict = model(user_reviews, item_reviews)

                # Sum (not mean) so batches of different sizes are weighted correctly.
                total_se += torch.nn.functional.mse_loss(predict, ratings, reduction='sum').item()
                sample_count += len(ratings)
    finally:
        model.train(was_training)

    if sample_count == 0:
        raise ValueError("dataloader yielded no samples; cannot compute MSE")
    return total_se / sample_count


# Evaluation: MSE for models that also take user/item IDs.
def predict_mse_with_IDemb(model, dataloader, device):
    """Return the mean squared error of an ID-aware ``model`` over ``dataloader``.

    Each batch must unpack into ``(uid, iid, user_reviews, item_reviews, ratings)``.
    The model is called as ``model(uid, iid, user_reviews, item_reviews)``.
    Squared errors are summed over all batches and divided by the total
    sample count.

    The model is put in eval mode with gradients disabled while evaluating.
    Its previous train/eval mode is restored afterwards.

    Args:
        model: a ``torch.nn.Module``.
        dataloader: iterable of batches as described above.
        device: device the batch tensors are moved to.

    Returns:
        float: summed squared error divided by the number of samples.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    total_se, sample_count = 0.0, 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in dataloader:
                uid, iid, user_reviews, item_reviews, ratings = (x.to(device) for x in batch)
                predict = model(uid, iid, user_reviews, item_reviews)

                # Sum (not mean) so batches of different sizes are weighted correctly.
                total_se += torch.nn.functional.mse_loss(predict, ratings, reduction='sum').item()
                sample_count += len(ratings)
    finally:
        model.train(was_training)

    if sample_count == 0:
        raise ValueError("dataloader yielded no samples; cannot compute MSE")
    return total_se / sample_count


    

# Dataset: concatenated user/item reviews and ratings (optionally with user/item IDs).
class Review_Dataset(Dataset):
    """Map-style dataset built from a pickled DataFrame of concatenated reviews.

    Every sample is ``(user_reviews, item_reviews, rating)``. When
    ``config.with_idemb == True`` it is ``(uid, iid, user_reviews, item_reviews, rating)``
    instead. Review tensors hold exactly ``max_length`` token ids. Ratings and
    IDs have shape ``(1,)`` per sample.
    """

    def __init__(self, data_path, word_dict, config, max_length=400):
        """Load the DataFrame and pre-convert all reviews to token-id tensors.

        Args:
            data_path: path to a pickled DataFrame with columns ``userID``, ``itemID``,
                ``user_reviews_concat``, ``item_reviews_concat`` and ``rating``.
                NOTE: ``pd.read_pickle`` can execute arbitrary code while loading;
                only use trusted files.
            word_dict: mapping word -> embedding row index; must contain ``'<UNK>'``.
            config: object with a ``with_idemb`` attribute. When it equals ``True``,
                the user and item IDs are also returned.
            max_length: number of token ids every review is truncated/padded to.
        """
        self.word_dict = word_dict
        self.get_id_emb = config.with_idemb
        self.max_length = max_length

        # Unknown words and padding share the '<UNK>' index.
        self.PAD_WORD_idx = self.word_dict['<UNK>']

        df = pd.read_pickle(data_path)[['userID', 'itemID', 'user_reviews_concat', 'item_reviews_concat', 'rating']]

        # Ratings as float tensor of shape (N, 1).
        self.rating = torch.Tensor(df['rating'].to_list()).view(-1, 1)

        if self.get_id_emb == True:
            # IDs as long tensors of shape (N, 1).
            self.uid = torch.LongTensor(df['userID'].to_list()).view(-1, 1)
            self.iid = torch.LongTensor(df['itemID'].to_list()).view(-1, 1)

        # Reviews as long tensors of shape (N, max_length). Every row has the
        # same length, which is required to build a single 2-D tensor.
        self.user_reviews = torch.LongTensor([self._review2id(r) for r in df['user_reviews_concat']])
        self.item_reviews = torch.LongTensor([self._review2id(r) for r in df['item_reviews_concat']])

    def __getitem__(self, idx):
        if self.get_id_emb == True:
            return self.uid[idx], self.iid[idx], self.user_reviews[idx], self.item_reviews[idx], self.rating[idx]
        return self.user_reviews[idx], self.item_reviews[idx], self.rating[idx]

    def __len__(self):
        return self.rating.shape[0]

    def _review2id(self, reviews):
        """Convert a whitespace-separated review into exactly ``max_length`` token ids.

        Words missing from ``word_dict`` map to the ``'<UNK>'`` index. The result
        is truncated to ``max_length`` or right-padded with ``PAD_WORD_idx``.
        Non-string input (e.g. NaN for a user/item without reviews) becomes an
        all-padding row. Returning a shorter list here would make the rows
        ragged and break the ``torch.LongTensor`` construction in ``__init__``.
        """
        if not isinstance(reviews, str):
            return [self.PAD_WORD_idx] * self.max_length

        unk = self.word_dict['<UNK>']
        # Only the first max_length words survive truncation, so look up just those.
        rids = [self.word_dict.get(word, unk) for word in reviews.split()[:self.max_length]]
        return rids + [self.PAD_WORD_idx] * (self.max_length - len(rids))
    


# Dataset for TransNet: reviews and ratings, optionally with the target review text.
class Transnet_Dataset(Dataset):
    """Map-style dataset built from a pickled DataFrame, for TransNet-style models.

    Returned tuples depend on the flags:

    * ``get_1_review`` true, ``with_idemb`` not true:
      ``(user_reviews, item_reviews, review, rating)``
    * ``get_1_review`` true, ``with_idemb == True``:
      ``(uid, iid, user_reviews, item_reviews, rating)``
    * ``get_1_review`` not true: ``(user_reviews, item_reviews, rating)``

    Review tensors hold exactly ``max_length`` token ids.
    """

    def __init__(self, data_path, word_dict, config, max_length=400, get_1_review=True):
        """Load the DataFrame and pre-convert reviews to token-id tensors.

        Args:
            data_path: path to a pickled DataFrame with columns ``userID``, ``itemID``,
                ``user_reviews_concat``, ``item_reviews_concat``, ``review`` and ``rating``.
                NOTE: ``pd.read_pickle`` can execute arbitrary code while loading;
                only use trusted files.
            word_dict: mapping word -> embedding row index; must contain ``'<UNK>'``.
            config: object with a ``with_idemb`` attribute (see class docstring).
            max_length: number of token ids every review is truncated/padded to.
            get_1_review: when true, also convert the per-sample ``review`` column
                (the target review used by TransNet).
        """
        self.word_dict = word_dict
        self.get_id_emb = config.with_idemb
        self.max_length = max_length
        self.get_1_review = get_1_review

        # Unknown words and padding share the '<UNK>' index.
        self.PAD_WORD_idx = self.word_dict['<UNK>']

        df = pd.read_pickle(data_path)[['userID', 'itemID', 'user_reviews_concat', 'item_reviews_concat', 'review', 'rating']]

        # Ratings as float tensor of shape (N, 1).
        self.rating = torch.Tensor(df['rating'].to_list()).view(-1, 1)

        if self.get_id_emb == True:
            # IDs as long tensors of shape (N, 1).
            self.uid = torch.LongTensor(df['userID'].to_list()).view(-1, 1)
            self.iid = torch.LongTensor(df['itemID'].to_list()).view(-1, 1)

        if self.get_1_review == True:
            # Target review for TransNet, shape (N, max_length).
            self.review = torch.LongTensor([self._review2id(r) for r in df['review']])

        # Reviews as long tensors of shape (N, max_length).
        self.user_reviews = torch.LongTensor([self._review2id(r) for r in df['user_reviews_concat']])
        self.item_reviews = torch.LongTensor([self._review2id(r) for r in df['item_reviews_concat']])

    def __getitem__(self, idx):
        if self.get_1_review == True:
            if self.get_id_emb == True:
                # NOTE(review): the target review is not returned in this branch
                # even though it was loaded — confirm this is intended.
                return self.uid[idx], self.iid[idx], self.user_reviews[idx], self.item_reviews[idx], self.rating[idx]
            return self.user_reviews[idx], self.item_reviews[idx], self.review[idx], self.rating[idx]

        # NOTE(review): uid/iid are not returned here even when with_idemb == True — confirm.
        return self.user_reviews[idx], self.item_reviews[idx], self.rating[idx]

    def __len__(self):
        return self.rating.shape[0]

    def _review2id(self, reviews):
        """Convert a whitespace-separated review into exactly ``max_length`` token ids.

        Words missing from ``word_dict`` map to the ``'<UNK>'`` index. The result
        is truncated to ``max_length`` or right-padded with ``PAD_WORD_idx``.
        Non-string input (e.g. NaN) becomes an all-padding row. Returning a
        shorter list here would make the rows ragged and break the
        ``torch.LongTensor`` construction in ``__init__``.
        """
        if not isinstance(reviews, str):
            return [self.PAD_WORD_idx] * self.max_length

        unk = self.word_dict['<UNK>']
        # Only the first max_length words survive truncation, so look up just those.
        rids = [self.word_dict.get(word, unk) for word in reviews.split()[:self.max_length]]
        return rids + [self.PAD_WORD_idx] * (self.max_length - len(rids))