from tasks.modadd.contrastive.datasets import Data
import torch, random, torch

class DatasetMisAligned(torch.utils.data.Dataset):
    """Dataset pairing a symbol with a masked, possibly different, context row.

    For item ``idx`` a random position (never the last column) of row ``idx``
    is chosen and its value is returned as ``symbol``. The returned context is
    sampled from every row sharing that symbol at that position. The row whose
    non-final columns have the largest sum is drawn with probability 0.9, and
    the remaining 0.1 is split uniformly over the other candidates. The chosen
    position of the sampled context is overwritten with ``data.mask_symbol``.
    """

    def __init__(self, data: "Data", split="train", device="cuda:0"):
        """
        Args:
            data: container exposing one tensor per split (read via
                ``getattr(data, split)``) and a ``mask_symbol`` attribute.
            split: name of the attribute of ``data`` holding the rows.
            device: device that ``collate`` moves each batch to.
        """
        self.data = data
        self.device = device
        # 2-D (num_rows, row_len) tensor; dtype presumably integer token ids
        # — TODO confirm against Data.
        self.dataset = getattr(data, split)

    def __len__(self):
        return self.dataset.size(0)

    def __getitem__(self, idx):
        """Return ``(symbol, context)``, where ``context`` has the symbol's slot masked."""
        row = self.dataset[idx]
        # randint is inclusive on both ends, so the last column is never picked.
        symbol_idx = random.randint(0, row.size(0) - 2)
        symbol = row[symbol_idx].clone()

        # Every row (row `idx` included) that holds `symbol` at `symbol_idx`.
        candidates = self.dataset[self.dataset[:, symbol_idx] == symbol]
        num_candidates = candidates.size(0)
        scores = candidates[:, :-1].sum(-1).float()
        best = scores.argmax(-1).item()

        if num_candidates == 1:
            # Only the row itself matches; the original .1 / (n - 1) split
            # would divide by zero here.
            weights = [1.0]
        else:
            weights = [.1 / (num_candidates - 1)] * num_candidates
            weights[best] = .9

        # Python's `random` is kept (not torch RNG) so seeding behaves as before.
        pick = random.choices(range(num_candidates), k=1, weights=weights)[0]
        context = candidates[pick].clone()  # clone: never mutate the dataset

        context[symbol_idx] = self.data.mask_symbol
        return symbol, context

    def collate(self, data):
        """Stack a list of ``(symbol, context)`` pairs into a batch on ``self.device``."""
        return {
            "symbol"  : torch.stack([d[0] for d in data]).to(self.device),
            "context" : torch.stack([d[1] for d in data]).to(self.device),
        }


