import torch

from torch.utils.data import Dataset

# Only supports linear time-stepping problems such as the heat equation (each step
# applies one fixed operator to the state), so the data-generation logic is hard-coded here.
class AdversarialDataset(Dataset):
    """Trajectories of a fixed linear operator, with a trainable initial state.

    ``generate_data`` repeatedly applies ``op`` to a state and projects every
    step through ``mask``. ``tokens_from_data`` slices such a trajectory into
    sliding windows of ``history_length`` frames paired with next-frame labels.
    ``initial`` is created with ``requires_grad=True`` so gradients of a
    downstream loss can be taken with respect to the initial condition.

    Args:
        u: Initial state (anything ``torch.tensor`` accepts); presumably a 1-D
            vector of length N -- TODO confirm against callers.
        t: Number of timesteps to simulate.
        op: Operator applied once per timestep (assumed N x N).
        mask: Observation matrix applied to every state (assumed M x N).
        history_length: Number of consecutive frames in one input window.
        device: Device on which all tensors are placed.
        dtype: Floating dtype of ``op``, ``mask`` and ``initial``
            (defaults to half precision).
    """

    def __init__(self, u, t, op, mask, history_length, device, dtype=torch.half):
        # Number of (window, label) pairs: the window starting at frame i spans
        # history_length frames and is labelled with frame i + history_length.
        self.length = t - history_length
        self.timesteps = t
        self.op = torch.tensor(op, dtype=dtype, device=device)
        self.mask = torch.tensor(mask, dtype=dtype, device=device)
        # Create directly on `device` so the tensor is a leaf: it can be handed
        # to an optimizer and accumulates `.grad` without retain_grad().
        # (Creating on the CPU and calling .to(device) yields a non-leaf copy.)
        self.initial = torch.tensor(u, dtype=dtype, device=device, requires_grad=True)

        self.device = device
        self.history_length = history_length

    def set_resolution_labels(self, u):
        """Return the middle frame of every history window as float64 labels.

        Window ``i`` covers frames ``i .. i + history_length - 1``; its label is
        frame ``i + history_length // 2``.

        Args:
            u: Frame sequence indexable along its first axis (tensor or
                array-like).

        Returns:
            A float64 copy of ``u[mid : mid + self.length]`` on ``self.device``,
            detached from any autograd graph.

        Raises:
            ValueError: If ``history_length`` is even.
        """
        # With an even history there is no single middle frame, so the input
        # would not be symmetric about the labelled frame.
        if self.history_length % 2 != 1:
            raise ValueError(
                f"history_length must be odd, got {self.history_length}"
            )

        middle = self.history_length // 2
        labels = u[middle:middle + self.length]
        if isinstance(labels, torch.Tensor):
            # torch.tensor(<tensor>) warns; detach + explicit copy is equivalent.
            return labels.detach().to(device=self.device, dtype=torch.float64, copy=True)
        return torch.tensor(labels, dtype=torch.float64, device=self.device)

    def generate_data(self, u):
        """Roll ``u`` forward ``self.timesteps`` steps and observe each state.

        Row ``k`` (0-based) is ``mask @ op^(k+1) @ u``; the initial state
        itself is not included.

        Args:
            u: Initial state, presumably a 1-D tensor of length N with the same
                dtype/device as ``self.op`` (e.g. ``self.initial``) -- TODO confirm.

        Returns:
            For a 1-D ``u`` and an (M, N) ``mask``, a (timesteps, M) tensor that
            keeps the autograd graph back to ``u``.
        """
        states = []
        current = u
        for _ in range(self.timesteps):
            current = torch.matmul(self.op, current)  # one operator application
            states.append(current)
        # Observe all states at once: (timesteps, N) @ (N, M) -> (timesteps, M).
        return torch.matmul(torch.stack(states, dim=0), self.mask.transpose(0, 1))

    def tokens_from_data(self, raw_data):
        """Slice a trajectory into sliding history windows and next-frame labels.

        Args:
            raw_data: 2-D tensor (frames, tokens_per_frame), e.g. the output of
                ``generate_data``. Needs at least ``self.length + history_length``
                frames.

        Returns:
            ``(data, labels)``, fresh float64 tensors on ``self.device``.
            ``data[i]`` is frames ``i .. i + history_length - 1`` flattened in
            frame-major order. ``labels[i]`` is frame ``i + history_length``.
            Both stay differentiable with respect to ``raw_data``.

        Raises:
            IndexError: If ``raw_data`` has too few frames.
        """
        h = self.history_length
        tokens_per_frame = raw_data.shape[1]
        needed = self.length + h
        if raw_data.shape[0] < needed:
            raise IndexError(
                f"raw_data has {raw_data.shape[0]} frames, need at least {needed}"
            )

        # unfold(0, h, 1) -> (num_windows, tokens_per_frame, h). Transpose each
        # window to (h, tokens_per_frame) so flattening matches
        # raw_data[i:i + h].flatten().
        windows = raw_data.unfold(0, h, 1)[:self.length].transpose(1, 2)
        data = windows.reshape(self.length, h * tokens_per_frame)
        labels = raw_data[h:h + self.length]
        # copy=True: the unfolded view overlaps in memory and `labels` is a
        # slice of raw_data; callers get independent tensors as before.
        return (
            data.to(device=self.device, dtype=torch.float64, copy=True),
            labels.to(device=self.device, dtype=torch.float64, copy=True),
        )