import torch
from torch.utils.data import Dataset, DataLoader
import random


class SynteticAnomalyDetection(Dataset):
    """Synthetic anomaly-detection dataset built from random square matrices.

    Each sample is a stack of ``num_graphs`` noisy ``n x n`` matrices:

    * one base matrix ``A`` drawn element-wise from Bernoulli(0.5),
    * ``num_graphs - 2`` copies of ``A`` with rows and columns permuted by the
      same random node permutation (i.e. ``P.T @ A @ P``),
    * one independent Bernoulli(0.5) matrix ``C`` -- the anomaly -- inserted
      at a uniformly random position.

    Gaussian noise with standard deviation ``noise_level`` is added to every
    matrix. Samples are generated on the fly, so each ``__getitem__`` call
    returns a fresh random sample regardless of ``idx``.

    Args:
        n: Matrix side length (number of nodes); fixed for the whole dataset.
        num_samples: Value reported by ``__len__``.
        num_graphs: Number of matrices per sample; must be at least 2.
        noise_level: Standard deviation of the additive Gaussian noise.
        device: Device for the generated matrices. ``None`` selects
            ``'cuda'`` when available and ``'cpu'`` otherwise.

    Raises:
        ValueError: If ``n < 1`` or ``num_graphs < 2``.
    """

    def __init__(self, n, num_samples=1000, *, num_graphs=10, noise_level=0.2, device=None):
        if n < 1:
            raise ValueError(f"n must be >= 1, got {n}")
        if num_graphs < 2:
            # One normal graph plus the anomaly is the minimum meaningful sample.
            raise ValueError(f"num_graphs must be >= 2, got {num_graphs}")
        # TODO: supporting a variable n per sample would need padding/masking
        # (or a custom collate_fn), since default batching stacks fixed shapes.
        self.n = n
        self.num_graphs = num_graphs
        self.num_samples = num_samples
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.noise_level = noise_level

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Generate one random sample (``idx`` is ignored).

        Returns:
            X: Tensor of shape ``(num_graphs, 1, n, n)`` on ``self.device``.
            labels: Float tensor of shape ``(num_graphs,)`` (on CPU) with a
                single ``1.0`` marking the anomalous matrix, ``0.0`` elsewhere.
        """
        n, device = self.n, self.device
        A = torch.bernoulli(torch.full((n, n), 0.5, device=device))
        C = torch.bernoulli(torch.full((n, n), 0.5, device=device))

        graphs = []
        for _ in range(self.num_graphs - 2):
            perm = torch.randperm(n, device=device)
            # Equivalent to P.T @ A @ P with P = eye(n)[:, perm], but O(n^2)
            # indexing instead of two dense O(n^3) matmuls.
            graphs.append(A[perm][:, perm])
        graphs.append(A)

        # Use torch's RNG (not Python's `random`) so torch.manual_seed alone
        # makes a sample reproducible.
        place = int(torch.randint(self.num_graphs, (1,)).item())
        graphs.insert(place, C)

        X = torch.stack(graphs, dim=0)
        X = X + self.noise_level * torch.randn_like(X)
        labels = torch.zeros(self.num_graphs, dtype=torch.float)
        labels[place] = 1.0
        return X.unsqueeze(1), labels
