import torch
from torch_geometric.data import Batch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

class NASDataset(Dataset):
    """Index-aligned view over the per-architecture fields of a NAS dataset.

    ``data_dict`` must provide the keys ``"arch_ids"``, ``"normal_cells"``,
    ``"reduce_cells"`` and ``"metrics"``; each value is stored unchanged as
    an attribute of the same name and is indexed with the same position in
    ``__getitem__``.
    """

    def __init__(self, data_dict):
        # Bind every required field as an attribute; a missing key raises
        # KeyError right here rather than on first access.
        for field in ("arch_ids", "normal_cells", "reduce_cells", "metrics"):
            setattr(self, field, data_dict[field])

    def __len__(self):
        # The dataset size is defined by the number of architecture ids.
        n_archs = len(self.arch_ids)
        return n_archs

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` as a dict.

        Keys: ``'norm_cell'``, ``'reduce_cell'``, ``'metrics'``, ``'arch_id'``.
        """
        sample = {}
        sample['norm_cell'] = self.normal_cells[idx]
        sample['reduce_cell'] = self.reduce_cells[idx]
        sample['metrics'] = self.metrics[idx]
        sample['arch_id'] = self.arch_ids[idx]
        return sample


def nas_collate_fn(batch):
    """Collate a list of ``NASDataset``-style samples into one batch dict.

    The ``'norm_cell'`` and ``'reduce_cell'`` entries are merged with
    ``Batch.from_data_list``; the ``'metrics'`` and ``'arch_id'`` entries are
    combined with ``torch.stack`` (so they are presumably tensors of equal
    shape across samples -- confirm against the stored dataset).

    Returns a dict with keys ``'norm_cell'``, ``'reduce_cell'``,
    ``'metrics'`` and ``'arch_ids'`` (note the plural on the last key).
    """
    norm_cells, reduce_cells, metric_rows, ids = [], [], [], []
    # Single pass over the samples, gathering each field into its own list.
    for sample in batch:
        norm_cells.append(sample['norm_cell'])
        reduce_cells.append(sample['reduce_cell'])
        metric_rows.append(sample['metrics'])
        ids.append(sample['arch_id'])

    collated = {}
    collated['norm_cell'] = Batch.from_data_list(norm_cells)
    collated['reduce_cell'] = Batch.from_data_list(reduce_cells)
    collated['metrics'] = torch.stack(metric_rows)
    collated['arch_ids'] = torch.stack(ids)
    return collated


def get_nas_dataloader(dataset_path, batch_sizes=(32, 64), ratios=(0.8, 0.2)):
    """Load a serialized NAS dataset and build its train/test DataLoaders.

    Args:
        dataset_path: Location handed to ``torch.load``. The loaded object
            must be a mapping containing ``"metric_names"`` plus the keys
            consumed by ``NASDataset``.
        batch_sizes: ``(train_batch_size, test_batch_size)``.
        ratios: Split ratios forwarded unchanged to ``split_nas_dataset``.

    Returns:
        ``((train_loader, test_loader), metric_names)``. Only the train
        loader shuffles.
    """
    # map_location="cpu": without it, tensors saved from a GPU are restored
    # onto CUDA. That fails on CPU-only hosts and is incompatible with both
    # the NumPy-based split and pin_memory=True (only CPU tensors can be
    # pinned).
    # NOTE(review): torch.load unpickles arbitrary objects -- only load
    # dataset files from a trusted source.
    dataset_dict = torch.load(dataset_path, map_location="cpu")
    metric_names = dataset_dict["metric_names"]
    full_dataset = NASDataset(dataset_dict)
    train_set, test_set = split_nas_dataset(full_dataset, ratios)

    # Pair each split with an explicit shuffle flag instead of deriving it
    # from an equality comparison between dataset objects.
    loaders = tuple(
        DataLoader(
            subset,
            batch_size=batch_size,
            collate_fn=nas_collate_fn,
            shuffle=shuffle,
            num_workers=0,
            pin_memory=True,
        )
        for subset, batch_size, shuffle in zip(
            (train_set, test_set), batch_sizes, (True, False)
        )
    )

    return loaders, metric_names


def split_nas_dataset(dataset, ratios=(0.8, 0.2), seed=42, stratify_col=3):
    """Split ``dataset`` into train/test subsets, stratified on one metric.

    The metric column ``stratify_col`` is cut into ten quantile bins and the
    bin index of each sample is used as the stratification label, so both
    subsets cover the metric's range in similar proportions.

    Args:
        dataset: Object exposing ``metrics`` (2-D, one row per sample) and
            ``__len__``.
        ratios: ``(train, test)`` fractions. Only ``ratios[1]`` is read (as
            ``test_size``); the train subset receives all remaining samples.
        seed: Random state passed to ``train_test_split``.
        stratify_col: Column of ``metrics`` used to build the strata.

    Returns:
        ``(train_subset, test_subset)`` as ``torch.utils.data.Subset`` views
        of ``dataset``.

    Raises:
        ValueError: Propagated from ``train_test_split`` when the split
            cannot be stratified (e.g. a bin holds too few samples).
    """
    metrics = dataset.metrics
    if isinstance(metrics, torch.Tensor):
        # A bare .numpy() raises for CUDA tensors and for tensors that
        # require grad; detach and move to host memory first.
        metrics = metrics.detach().cpu().numpy()
    else:
        metrics = np.asarray(metrics)

    target = metrics[:, stratify_col]
    # 11 quantile edges (0%, 10%, ..., 100%); dropping the last edge makes
    # the maximum fall into the top bin, giving labels 1..10.
    edges = np.quantile(target, np.linspace(0, 1, 11))
    strat_labels = np.digitize(target, edges[:-1])

    idx = np.arange(len(dataset))

    train_idx, test_idx = train_test_split(
        idx,
        test_size=ratios[1],
        stratify=strat_labels,
        random_state=seed
    )

    return Subset(dataset, train_idx), Subset(dataset, test_idx)

