from tasks.modadd.contrastive.datasets import Data
import torch, random, torch

class DatasetSuperMisAligned(torch.utils.data.Dataset):
    """Dataset of (symbol, masked context) pairs with a resampled context.

    For item ``idx`` a position is chosen at random in row ``idx`` and the
    value found there is returned as ``symbol``.  The returned ``context``
    is NOT row ``idx`` itself: it is a row re-drawn from the whole split,
    with the chosen position overwritten by ``data.mask_symbol``.

    The re-draw gives half of the probability mass to rows that carry the
    same symbol at the chosen position ("strong" candidates) and half to
    rows that do not ("weak" candidates), uniformly within each group.
    """

    def __init__(self, data: "Data", split="train", device="cuda:0"):
        """Store the split to sample from.

        Args:
            data: Object exposing the split as an attribute named ``split``
                and a ``mask_symbol`` attribute.  The split is indexed as
                ``dataset[:, position]``, so it is assumed to be a 2-D
                tensor of shape (rows, positions) -- TODO confirm.
            split: Name of the attribute of ``data`` holding the rows.
            device: Device the collated batches are moved to.
        """
        super().__init__()
        self.data = data
        self.device = device
        self.dataset = getattr(data, split)

    def __len__(self):
        """Return the number of rows in the selected split."""
        return self.dataset.size(0)

    def __getitem__(self, idx):
        """Return ``(symbol, context)`` for row ``idx``.

        ``symbol`` is the value of row ``idx`` at a random position;
        ``context`` is a copy of a re-sampled row with that position set to
        ``self.data.mask_symbol``.  The stored split is never modified.
        """
        anchor = self.dataset[idx]
        # randint is inclusive on both ends, so the final position is never
        # selected (NOTE(review): presumably it holds the target -- confirm).
        symbol_idx = random.randint(0, anchor.size(0) - 2)
        symbol = anchor[symbol_idx].clone()

        same_symbol = self.dataset[:, symbol_idx] == symbol
        strong_rows = torch.nonzero(same_symbol, as_tuple=True)[0]
        weak_rows = torch.nonzero(~same_symbol, as_tuple=True)[0]

        # Two-stage draw: pick a group with probability 1/2 each, then a row
        # uniformly inside it.  This is the same distribution as weighting
        # every row by 0.5 / group_size, without building a per-row weight
        # vector.  If one group is empty (e.g. every row shares the symbol at
        # this position) all of the mass goes to the other group instead of
        # dividing by zero.
        if weak_rows.numel() == 0:
            pool = strong_rows
        elif strong_rows.numel() == 0:
            pool = weak_rows
        elif random.random() < 0.5:
            pool = strong_rows
        else:
            pool = weak_rows

        row = int(pool[random.randrange(pool.numel())])

        # Clone so that masking does not write through to the stored split.
        context = self.dataset[row].clone()
        context[symbol_idx] = self.data.mask_symbol
        return symbol, context

    def collate(self, data):
        """Stack a list of ``(symbol, context)`` pairs into a batch dict.

        Returns:
            ``{"symbol": ..., "context": ...}`` where each value is the
            stacked tensor of the corresponding pair element, moved to
            ``self.device``.
        """
        return {
            "symbol": torch.stack([pair[0] for pair in data]).to(self.device),
            "context": torch.stack([pair[1] for pair in data]).to(self.device),
        }


