from dataclasses import dataclass, field
import multiprocessing
from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset, Sampler

from .wrappers.data_sample import DataSample

n_gpus = torch.cuda.device_count()
n_cpus = multiprocessing.cpu_count()
# Worker budget for data loading:
#   * with GPUs, split the CPU cores evenly between them;
#   * without GPUs, leave one CPU free for the processing work.
if n_gpus > 0:
    n_available_cpus = n_cpus // n_gpus
else:
    n_available_cpus = n_cpus - 1

@dataclass
class DataLoaderConfig:
    """Options forwarded to ``torch.utils.data.DataLoader`` by ``load``."""

    # Number of samples per batch; required (no default).
    batch_size: int
    # Number of loader workers; defaults to the module-level
    # ``n_available_cpus`` computed at import time.
    num_workers: int = field(default=n_available_cpus)
    # Forwarded as the DataLoader's ``pin_memory`` argument.
    pin_memory: bool = field(default=False)
    # Optional custom sampler; when set, ``load`` turns shuffling off.
    sampler: Optional[Sampler] = field(default=None)


def collate_data_samples(batch: list[DataSample]) -> DataSample:
    """Merge a list of per-item samples into one batched ``DataSample``.

    Args:
        batch: The individual samples to combine. Each must expose an
            ``input`` tensor and a ``target`` value.

    Returns:
        A ``DataSample`` whose ``input`` is the per-item inputs stacked
        along a new leading dimension and whose ``target`` is a
        ``torch.long`` tensor of the per-item targets. Only ``input`` and
        ``target`` are populated on the returned sample.
    """
    stacked_inputs = torch.stack([sample.input for sample in batch])

    labels = [sample.target for sample in batch]
    target_tensor = torch.tensor(
        labels,
        dtype=torch.long,
        requires_grad=False,
    )

    return DataSample(input=stacked_inputs, target=target_tensor)


def load(
    dataset: Dataset, train: bool, config: DataLoaderConfig,
) -> DataLoader:
    """Build a ``DataLoader`` over ``dataset`` using ``config``.

    Args:
        dataset: The dataset to iterate over.
        train: Used as the ``shuffle`` flag when no sampler is configured.
        config: Batch size, worker count, pin-memory flag and optional
            sampler for the loader.

    Returns:
        A ``DataLoader`` that batches items with ``collate_data_samples``.
    """
    sampler = config.sampler
    # An explicit sampler decides the ordering itself, so shuffling is
    # only requested when none is configured.
    if sampler is None:
        shuffle = train
    else:
        shuffle = False

    loader_kwargs = {
        "batch_size": config.batch_size,
        "shuffle": shuffle,
        "num_workers": config.num_workers,
        "pin_memory": config.pin_memory,
        "sampler": sampler,
        "collate_fn": collate_data_samples,
    }
    return DataLoader(dataset, **loader_kwargs)
