from torch.utils.data import Dataset
import torch
import os
from torch.nn.utils.rnn import pad_sequence
import random

class ProcessedDataset(Dataset):
    """Dataset over the preprocessed ``.pt`` sample files found in a directory.

    Each file is loaded with ``torch.load`` and must yield a mapping with the
    keys ``"anchor"``, ``"positives"`` and ``"negatives"``. ``"positives"``
    and ``"negatives"`` must be mutable sequences because they are shuffled
    in place when ``train`` is true.
    """

    def __init__(self, preprocessed_data_dir: str, train: bool) -> None:
        """Index every ``*.pt`` file directly inside ``preprocessed_data_dir``.

        Args:
            preprocessed_data_dir: Directory that contains the sample files.
            train: When true, positives and negatives are shuffled on every
                access.
        """
        super().__init__()
        self.train = train
        # os.listdir() returns entries in arbitrary order; sort so a given
        # index maps to the same file on every run and every machine.
        self.file_list = sorted(
            os.path.join(preprocessed_data_dir, name)
            for name in os.listdir(preprocessed_data_dir)
            if name.endswith('.pt')
        )

    def __len__(self) -> int:
        """Return the number of indexed sample files."""
        return len(self.file_list)

    def __getitem__(self, idx: int) -> dict:
        """Load sample ``idx`` from disk.

        Returns:
            A dict with the ``"anchor"``, ``"positives"`` and ``"negatives"``
            entries of the loaded file.
        """
        # NOTE(review): torch.load unpickles the file; only point this
        # dataset at directories whose contents are trusted.
        data = torch.load(self.file_list[idx])

        if self.train:
            # The lists were just loaded from disk, so shuffling in place
            # only affects this returned sample.
            random.shuffle(data["positives"])
            random.shuffle(data["negatives"])

        return {
            "anchor": data["anchor"],
            "positives": data["positives"],
            "negatives": data["negatives"],
        }
        
def _pad_with_mask(sequences, empty_like=None):
    """Pad variable-length tensors into one batch and build its validity mask.

    Args:
        sequences: List of tensors whose first dimension is the sequence
            length; remaining dimensions must match across the list.
        empty_like: Optional padded tensor of shape ``(N, T, *)``. When
            ``sequences`` is empty, its trailing dimensions, dtype and device
            are used to build an empty ``(0, 0, *)`` result instead of
            letting ``pad_sequence`` raise on the empty list.

    Returns:
        ``(padded, mask)`` where ``padded`` has shape ``(N, max_len, *)`` and
        ``mask`` is a bool tensor of shape ``(N, max_len)`` that is True at
        real positions and False at padding.
    """
    if not sequences and empty_like is not None:
        padded = empty_like.new_zeros((0, 0, *empty_like.shape[2:]))
        mask = torch.zeros((0, 0), dtype=torch.bool)
        return padded, mask

    padded = pad_sequence(sequences, batch_first=True)
    # One all-True vector per sequence; padding them with False produces the
    # mask aligned with ``padded``.
    valid = [torch.ones(seq.size(0), dtype=torch.bool) for seq in sequences]
    mask = pad_sequence(valid, batch_first=True, padding_value=False)
    return padded, mask


def collate_fn(batch):
    """Collate samples into padded anchor/positive/negative tensors and masks.

    Each item in ``batch`` must provide ``item['anchor']['hidden_states']``
    plus iterables ``item['positives']`` and ``item['negatives']`` whose
    elements expose ``['hidden_states']``. Positives and negatives of all
    items are flattened, in batch order, into one list per group before
    padding.

    If no item in the batch has any positives (or negatives), that group is
    returned as empty tensors of shape ``(0, 0, *)`` / ``(0, 0)`` rather than
    raising.

    Returns:
        Dict with ``anchor_emb``, ``positive_emb``, ``negative_emb`` (padded,
        batch-first) and the matching bool masks ``anchor_mask``,
        ``positive_mask``, ``negative_mask``.
    """
    anchor_seqs = [item['anchor']['hidden_states'] for item in batch]
    positive_seqs = [
        pos['hidden_states'] for item in batch for pos in item['positives']
    ]
    negative_seqs = [
        neg['hidden_states'] for item in batch for neg in item['negatives']
    ]

    anchor_emb, anchor_mask = _pad_with_mask(anchor_seqs)
    # Anchors supply the feature shape/dtype if a group is empty batch-wide.
    pos_emb, pos_mask = _pad_with_mask(positive_seqs, empty_like=anchor_emb)
    neg_emb, neg_mask = _pad_with_mask(negative_seqs, empty_like=anchor_emb)

    return {
        "anchor_emb": anchor_emb,
        "anchor_mask": anchor_mask,
        "positive_emb": pos_emb,
        "positive_mask": pos_mask,
        "negative_emb": neg_emb,
        "negative_mask": neg_mask,
    }