import random
import torch
from torch.utils.data import Dataset
import os

from tqdm import tqdm

# Directory where generated datasets are cached (a relative path, so it is
# resolved against the current working directory).
CACHE_FOLDER = "cache"


class CustomDataset(Dataset):
    """Synthetic multi-hop associative-recall (indirection) dataset.

    Each sample draws ``D + 1`` levels of ``N`` random ``K``-dimensional symbol
    embeddings and ``D`` random permutations linking consecutive levels, so
    that ``symbols[i][j]`` points to ``symbols[i + 1][mappings[i][j]]``.

    * ``x`` has shape ``(2 * D * N + N, K)``: for every level (deepest level
      first) the ``N`` (key, value) link pairs, interleaved and shuffled within
      the level, followed by the ``N`` level-0 query symbols in random order.
    * ``y`` has one row per query. With ``full_chain=False`` it holds the
      symbol reached after all ``D`` hops, shape ``(N, K)``. With
      ``full_chain=True`` it holds every hop's symbol concatenated, shape
      ``(N, D * K)``.

    The whole dataset is generated once and cached under ``CACHE_FOLDER``.
    The cache key contains only ``(N, D, K, full_chain, size)``, so an existing
    cache file is reused regardless of the current random seed.

    Raises:
        ValueError: if ``N``, ``D`` or ``size`` is smaller than 1.
    """

    def __init__(self, N, D, K, full_chain, size):
        # generate_data needs at least one symbol and one hop, and an empty
        # dataset cannot be stacked; fail early with a clear message.
        if N < 1 or D < 1:
            raise ValueError(f"N and D must be >= 1, got N={N}, D={D}")
        if size < 1:
            raise ValueError(f"size must be >= 1, got size={size}")

        self.N = N  # Number of symbols per indirection level
        self.D = D  # Number of indirection levels (D = 1 is the basic induction head)
        self.K = K  # Symbol embedding dimension
        self.full_chain = full_chain
        self.size = size

        cache_file = os.path.join(CACHE_FOLDER, f"dataset_{N}_{D}_{K}_{full_chain}_{size}.pt")
        if os.path.exists(cache_file):
            self.x, self.y = torch.load(cache_file)
        else:
            samples = (self.generate_data() for _ in tqdm(range(size), desc="Generating data"))
            xs, ys = zip(*samples)
            self.x = torch.stack(xs)
            self.y = torch.stack(ys)

            os.makedirs(CACHE_FOLDER, exist_ok=True)
            # Save to a temporary file in the same directory, then rename it into
            # place. An interrupted run then never leaves a truncated cache file
            # that the exists() check above would accept.
            tmp_file = f"{cache_file}.{os.getpid()}.tmp"
            torch.save((self.x, self.y), tmp_file)
            os.replace(tmp_file, cache_file)

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]

    def generate_data(self):
        """Generate one ``(x, y)`` sample; see the class docstring for the layout.

        All randomness comes from torch's global RNG. It is drawn in a fixed
        order: link permutations, symbol embeddings, one shuffle per level
        (deepest first), then the query order.
        """
        N, D, K = self.N, self.D, self.K
        mappings = [torch.randperm(N) for _ in range(D)]
        symbols = torch.randn(D + 1, N, K)

        # Indirection chain:
        #   symbols[i][j] -> symbols[i + 1][mappings[i][j]]
        # Input:
        #   symbols[D - 1][j] symbols[D][mappings[D - 1][j]], ..., symbols[0][j] symbols[1][mappings[0][j]]
        # Query:
        #   symbols[0][j] -> ... -> symbols[D][?]

        # Context: for each level, a shuffled run of interleaved (key, value) rows.
        parts = []
        for i in reversed(range(D)):
            order = torch.randperm(N)
            keys = symbols[i][order]
            values = symbols[i + 1][mappings[i][order]]
            # (N, 2, K) -> (2N, K): key_0, value_0, key_1, value_1, ...
            parts.append(torch.stack((keys, values), dim=1).flatten(0, 1))

        # Queries: every level-0 symbol exactly once, in random order.
        queries = torch.randperm(N)
        parts.append(symbols[0][queries])
        x = torch.cat(parts)

        # Targets: follow all D hops for every query at once.
        idx = queries
        hops = []
        for i in range(D):
            idx = mappings[i][idx]
            hops.append(symbols[i + 1][idx])  # (N, K) symbols reached at hop i + 1
        y = torch.cat(hops, dim=1) if self.full_chain else hops[-1]

        return x, y
