import inspect

import numpy as np
from torch.utils.data import IterableDataset, DataLoader
from typing import Callable

import random

from mqar.generators import generate_mqar_batch


class DynamicMQARBatchDataset(IterableDataset):
    """
    IterableDataset that generates MQAR samples on the fly in batches.
    Batches do not repeat and randomness is reproducible via a global seed.
    """
    def __init__(
        self,
        V: int,
        L: int,
        N_facts: int | list[int],
        dataset_size: int,
        batch_size: int,
        seed: int = 0,
        set_special_tokens_to_0: bool = True,
        power_a: float = 1.0,  # for non-uniform distribution set to 0.01 or other values
        random_non_queries: bool = False,
        include_slices: bool = False,
):

        self.V = V
        self.L = L

        # print(f"DynamicMQARBatchDataset.__init__(): {seed = }")
        # print(f"DynamicMQARBatchDataset.__init__(): {batch_size = }")

        self.N_facts = N_facts

        # if type(N_facts) is int:
        #     self.N_facts_list = [N_facts]
        # elif type(N_facts) is list:
        #     assert all([(type(n) is int) for n in N_facts])
        #     # self.N_facts_list = N_facts  # TODO
        #     self.N_facts_list = np.arange(min(N_facts), max(N_facts) + 1)   # TODO
        # else:
        #     raise ValueError("N_facts should be either an in or a list of ints")

        self.dataset_size = dataset_size
        self.batch_size = min(batch_size, self.dataset_size)

        self.seed = seed

        self.power_a = power_a
        self.set_special_tokens_to_0 = set_special_tokens_to_0
        self.random_non_queries = random_non_queries
        self.include_slices = include_slices

        self.num_batches = (self.dataset_size + self.batch_size - 1) // self.batch_size

        # # each draw is uniform over N_facts_list
        # self.N_facts_per_batch = np.random.choice(self.N_facts_list, size=self.num_batches)
        #
        # # # each draw is non-uniform over N_facts_list
        # # n = np.array(self.N_facts_list)
        # # p = n / n.sum()
        # # self.N_facts_per_batch = np.random.choice(self.N_facts_list, size=self.num_batches, p=p)

    def __iter__(self):

        # generate one batch of data
        def _generate_batch(n_facts: int | list[int]):
            return generate_mqar_batch(
                V=self.V,
                L=self.L,
                N_facts=n_facts,
                batch_size=self.batch_size,
                seed=random.randint(0, 2 ** 31),
                power_a=self.power_a,
                random_non_queries=self.random_non_queries,
                include_slices=self.include_slices,
            )

        for batch_idx in range(self.num_batches):

            # # N_facts = self.N_facts  # original
            # N_facts = self.N_facts_per_batch[batch_idx]

            batch = _generate_batch(self.N_facts)

            # print(f"{batch.size = }")

            yield batch

    def __len__(self):
        return self.num_batches


def get_mqar_dynamic_dataloader(
    V: int,
    L: int,
    N_facts: int | list[int],
    dataset_size: int,
    batch_size: int,
    **kwargs
) -> DataLoader:
    """
    Build a DataLoader over a newly constructed DynamicMQARBatchDataset.

    All named arguments and every extra keyword argument in ``**kwargs`` are
    forwarded to the ``DynamicMQARBatchDataset`` constructor (e.g. ``seed``,
    ``power_a``). None of them are passed to ``DataLoader``; the loader is
    always created with its default options apart from ``batch_size=None``.

    Because the dataset already produces whole batches, automatic batching is
    disabled and each item the loader yields corresponds to one batch produced
    by the dataset's ``__iter__``.

    Args:
        V: forwarded to the dataset as ``V``.
        L: forwarded to the dataset as ``L``.
        N_facts: forwarded to the dataset as ``N_facts`` (int or list of ints).
        dataset_size: forwarded to the dataset as ``dataset_size``.
        batch_size: forwarded to the dataset as ``batch_size`` (not to DataLoader).
        **kwargs: additional dataset constructor arguments.

    Returns:
        A ``DataLoader`` wrapping the dataset with automatic batching disabled.
    """

    dataset = DynamicMQARBatchDataset(
        V=V,
        L=L,
        N_facts=N_facts,
        dataset_size=dataset_size,
        batch_size=batch_size,
        **kwargs,
    )
    # batch_size=None disables DataLoader's own batching, so each pre-built
    # batch from the dataset is passed through as a single item.
    return DataLoader(dataset, batch_size=None)


def _get_kwargs(function: Callable, config: dict):
    kwargs = {k: config[k] for k in inspect.signature(function).parameters if k in config}
    return kwargs


def get_mqar_dynamic_dataloaders_by_split(dataset_config_by_split: dict[str, dict]) -> dict:
    """
    Build one dynamic MQAR DataLoader per split.

    ``dataset_config_by_split`` maps a split name to that split's config dict.
    For every split, only the config entries whose keys match parameters of
    ``DynamicMQARBatchDataset.__init__`` are kept; they are then passed as
    keyword arguments to ``get_mqar_dynamic_dataloader``. All other config
    keys are ignored.

    Returns a dict with the same keys as the input, mapping each split name
    to its DataLoader.
    """
    loaders_by_split = {}
    for split_name, split_config in dataset_config_by_split.items():
        # Keep only what the dataset constructor can accept.
        loader_kwargs = _get_kwargs(DynamicMQARBatchDataset.__init__, split_config)
        loaders_by_split[split_name] = get_mqar_dynamic_dataloader(**loader_kwargs)
    return loaders_by_split


if __name__ == '__main__':
    # Demo: build train/val/test loaders from one shared config and print the
    # first few batches of each split.

    example_config = {
        # dynamic MQAR dataloader parameters
        "V": 100,
        "L": 40,
        "N_facts": 10,
        "dataset_size": 1000,
        "train_frac": 0.8,
        "val_frac": 0.1,
        "test_frac": 0.1,
        "batch_size": 64,
        "shuffle": True,
        "drop_last": False,
        "seed": 42,
        # any other config keys are ignored by the loader builder
    }

    # The builder takes a mapping {split name -> config dict}; it does not read
    # the *_frac keys itself, so the per-split dataset sizes are derived here.
    split_fracs = {
        "train": example_config["train_frac"],
        "val": example_config["val_frac"],
        "test": example_config["test_frac"],
    }
    dataset_config_by_split = {
        split: {**example_config, "dataset_size": round(example_config["dataset_size"] * frac)}
        for split, frac in split_fracs.items()
    }

    # build train/val/test dynamic DataLoaders
    loaders = get_mqar_dynamic_dataloaders_by_split(dataset_config_by_split)

    for split, loader in loaders.items():
        # len(loader.dataset) is the number of batches, so the sample count is
        # read from dataset_size instead.
        print(f"{split}: batches={len(loader)}, samples={loader.dataset.dataset_size}")

        # NOTE(review): assumes each batch unpacks into (x_ids, y_true_ids) tensors
        # with the default include_slices=False - confirm against generate_mqar_batch.
        for i, (x_ids, y_true_ids) in enumerate(loader):
            print(f"{split.capitalize()} Batch #{i}:")
            print("\tx_ids[0]: ", list(x_ids[0].numpy()))
            print("\ty_true_ids[0]: ", list(y_true_ids[0].numpy()))
            if i >= 3:
                break
        print()

