import os
import torch
import plot
import solvers.Solver as Solver

from tqdm import tqdm
from torch.utils.data import Dataset, Subset


class TensorDataset(Dataset):
    """Sliding-window dataset over a directory of ``.pt`` sequence tensors.

    Every ``.pt`` file in ``data_dir`` is loaded and the results are stacked
    into ``raw_data`` of shape ``(n, t, d, d)``. Each frame is projected with
    the matrix from ``Solver.generate_mask_matrix`` to a ``(q, q)`` frame,
    where ``q = d // mask_dim``. A sample is a window of ``history_length``
    consecutive masked frames. Its label is either the next masked frame
    (``tokens=True``) or the raw ``(d, d)`` frame at offset
    ``history_length // 2`` inside the window (``tokens=False``).

    All tensors are kept in host memory. ``__getitem__`` moves each sample to
    ``device`` on access.
    """

    def __init__(self, data_dir, mask_dim, history_length, device, tokens=True):
        """Load, mask and window every sequence found in ``data_dir``.

        Args:
            data_dir: Directory scanned (non-recursively) for ``*.pt`` files.
                Every file must hold a tensor of the same ``(t, d, d)`` shape
                so that they can be stacked.
            mask_dim: Divisor ``c`` used to derive the masked side ``q = d // c``.
            history_length: Number of consecutive frames in each input window.
            device: Device each sample is moved to in ``__getitem__``.
            tokens: If True, labels are the masked frame following the window.
                Otherwise they are the raw frame at offset
                ``history_length // 2`` inside the window.

        Raises:
            FileNotFoundError: If ``data_dir`` contains no ``.pt`` files.
            ValueError: If the stacked data is not ``(n, t, d, d)`` with square
                frames, or ``history_length`` is not in ``[0, t)``.
        """
        # Sort because os.listdir order is arbitrary. A fixed order makes
        # sample indices (and randperm-based splits) reproducible.
        file_names = sorted(f for f in os.listdir(data_dir) if f.endswith(".pt"))
        if not file_names:
            raise FileNotFoundError(f"No tensors found in data_dir: {data_dir}")
        # map_location keeps everything on the CPU, matching the CPU-side mask
        # and the host-memory sample tensors built below.
        raw_data = torch.stack([
            torch.load(os.path.join(data_dir, f), map_location="cpu")
            for f in tqdm(file_names, desc="Loading tensors")
        ])
        if raw_data.dim() != 4 or raw_data.shape[-1] != raw_data.shape[-2]:
            raise ValueError(
                "Expected every file to hold a (t, d, d) tensor with square frames, "
                f"got stacked shape {tuple(raw_data.shape)}"
            )

        self.n, self.t, self.d, _ = raw_data.shape
        if not 0 <= history_length < self.t:
            raise ValueError(
                f"history_length must be in [0, {self.t}), got {history_length}"
            )
        self.l = self.t - history_length  # number of windows per sequence
        self.c = mask_dim
        self.q = self.d // self.c
        self.length = self.l * self.n

        # The matmul and view below only succeed if the mask is (q*q, d*d),
        # i.e. it maps each flattened d x d frame to a flattened q x q frame.
        mask = torch.tensor(Solver.generate_mask_matrix(self.d, self.c, self.q), dtype=torch.float32)
        # matmul does not promote mixed dtypes, so cast the frames to the
        # mask's dtype (e.g. uint8 or float64 inputs would otherwise fail).
        masked_raw = torch.matmul(raw_data.reshape(self.n, self.t, -1).to(mask.dtype), mask.t())
        masked_raw = masked_raw.view(self.n, self.t, self.q, self.q)

        data, labels = self._build_samples(masked_raw, raw_data, history_length, tokens)

        self.device = device
        self.history_length = history_length
        self.raw_data = raw_data
        self.masked_raw = masked_raw
        self.data = data.unsqueeze(1)  # Add channels dimension.
        self.labels = labels.unsqueeze(1)  # Add channels dimension.

    @staticmethod
    def _build_samples(masked_raw, raw_data, history_length, tokens):
        """Cut every sequence into sliding windows and gather their targets.

        Sample ``j * l + i`` (with ``l = t - history_length``) is the window
        ``masked_raw[j, i:i + history_length]``. Its target is
        ``masked_raw[j, i + history_length]`` when ``tokens`` is true, otherwise
        ``raw_data[j, i + history_length // 2]``.

        Args:
            masked_raw: Tensor indexed as ``(n, t, ...)`` holding masked frames.
            raw_data: Tensor indexed as ``(n, t, ...)`` holding raw frames;
                only read when ``tokens`` is false.
            history_length: Window length, with ``0 <= history_length < t``.
            tokens: Selects which target is gathered (see above).

        Returns:
            ``(data, labels)`` as freshly allocated float32 tensors of shape
            ``(n * l, history_length, *frame)`` and ``(n * l, *frame)``.
        """
        n, t = masked_raw.shape[:2]
        n_windows = t - history_length
        # offsets[i, k] == i + k, so masked_raw[:, offsets][j, i, k] is
        # masked_raw[j, i + k]. The advanced index always returns a copy.
        offsets = (
            torch.arange(n_windows, device=masked_raw.device).unsqueeze(1)
            + torch.arange(history_length, device=masked_raw.device).unsqueeze(0)
        )
        windows = masked_raw[:, offsets]
        data = windows.reshape(n * n_windows, history_length, *masked_raw.shape[2:]).to(torch.float32)

        if tokens:
            targets = masked_raw[:, history_length:]
        else:
            start = history_length // 2
            targets = raw_data[:, start:start + n_windows]
        # copy=True so labels never alias masked_raw / raw_data, even when the
        # reshape happens to return a view (e.g. n == 1).
        labels = targets.reshape(n * n_windows, *targets.shape[2:]).to(torch.float32, copy=True)
        return data, labels

    def train_test_split(self, test_ratio=0.1, generator=None):
        """Randomly split the dataset into disjoint train and test subsets.

        Args:
            test_ratio: Fraction of samples assigned to the test set. The
                count is truncated with ``int()``.
            generator: Optional ``torch.Generator`` for a reproducible
                permutation. The global RNG is used when None.

        Returns:
            ``(train_set, test_set)`` as ``torch.utils.data.Subset`` views of self.
        """
        indices = torch.randperm(len(self), generator=generator).tolist()
        test_size = int(len(indices) * test_ratio)
        train_set = Subset(self, indices[test_size:])
        test_set = Subset(self, indices[:test_size])
        return train_set, test_set

    def save_example(self, t, name, out_dir='data/', fps=60):
        """Save the first ``t`` raw frames of the first sequence via ``plot.save_video``."""
        frames = self.raw_data[0, :t].detach().cpu().numpy()
        plot.save_video(frames, out_dir, name, fps=fps)

    def __len__(self):
        """Return the total number of windows, ``n * (t - history_length)``."""
        return self.length

    def __getitem__(self, idx):
        """Return ``(window, label)`` for sample ``idx``, both moved to ``self.device``."""
        return self.data[idx].to(self.device), self.labels[idx].to(self.device)
