from torch.utils.data import Dataset
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

class VAEDataset(Dataset):
    """Dataset of observed interactions taken from a {-1, 0, 1} interaction matrix.

    Every cell equal to 1 (correct) or -1 (wrong) becomes one sample.
    A sample is an int64 row ``[row_idx, col_idx, label]``, where ``label``
    is 1 for a cell equal to 1 and 0 for a cell equal to -1. Cells with any
    other value (e.g. 0 = unobserved) are skipped.

    All samples with label 1 come first, then all samples with label 0.
    Within each group, samples follow ``torch.argwhere`` order (row-major).
    """

    def __init__(self, interact_mat: torch.Tensor) -> None:
        """Index the non-zero entries of ``interact_mat``.

        Args:
            interact_mat: 2-D matrix whose entries are 1 (correct),
                -1 (wrong) or anything else (ignored). When built via
                ``from_df_arr``, rows are user ids and columns are item ids.
        """
        super().__init__()
        self.interact_mat = interact_mat
        # (k, 2) coordinate tensors of positive / negative cells.
        self.interact_arr_pos = torch.argwhere(self.interact_mat == 1)
        self.interact_arr_neg = torch.argwhere(self.interact_mat == -1)
        # Create the label column on the same device as the coordinates.
        # Otherwise hstack fails when interact_mat is not a CPU tensor.
        device = self.interact_mat.device
        labels = torch.vstack([
            # Label 1: the cell held 1 (answered correctly).
            torch.ones(size=(len(self.interact_arr_pos), 1), dtype=torch.int64, device=device),
            # Label 0: the cell held -1 (answered wrongly).
            torch.zeros(size=(len(self.interact_arr_neg), 1), dtype=torch.int64, device=device),
        ])
        # (num_pos + num_neg, 3) tensor: [row_idx, col_idx, label].
        self.interact_arr = torch.hstack([
            torch.vstack([self.interact_arr_pos, self.interact_arr_neg]),
            labels,
        ])

    @classmethod
    def from_df_arr(cls, cfg, interact_df: pd.DataFrame):
        """Build the dataset from a long-format interaction DataFrame.

        Args:
            cfg: config object; reads ``cfg.data_cfg['dt_info']['user_count']``
                and ``['item_count']`` to size the matrix.
            interact_df: DataFrame with ``uid``, ``iid`` and ``label``
                columns. ``uid`` / ``iid`` are used directly as matrix
                indices, so they are presumably 0-based integer ids — confirm
                against the preprocessing step.

        Returns:
            A ``VAEDataset`` over a ``(user_count, item_count)`` int8 matrix.
        """
        user_count = cfg.data_cfg['dt_info']['user_count']
        item_count = cfg.data_cfg['dt_info']['item_count']

        interact_mat = torch.zeros((user_count, item_count), dtype=torch.int8)
        idx = interact_df[interact_df['label'] == 1][['uid', 'iid']].to_numpy()
        interact_mat[idx[:, 0], idx[:, 1]] = 1  # rows with label == 1 (correct) are stored as 1
        idx = interact_df[interact_df['label'] != 1][['uid', 'iid']].to_numpy()
        # Every other label (wrong) is stored as -1.
        # This assignment runs second, so a (uid, iid) pair that appears
        # with both labels ends up as -1.
        interact_mat[idx[:, 0], idx[:, 1]] = -1

        return cls(interact_mat)

    def __getitem__(self, index):
        """Return the int64 row ``[row_idx, col_idx, label]`` at ``index``."""
        return self.interact_arr[index]

    def __len__(self):
        """Return the number of observed (positive + negative) interactions."""
        return self.interact_arr.shape[0]

