from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn


__all__ = [
    "get_fewshot_indices",
    "restrict_class_layer"
]

def extract_targets_from_dataset(dataset):
    """Return the per-sample class labels of ``dataset``.

    Lookup order:

    1. ``dataset.targets`` (returned as-is).
    2. ``dataset._labels`` (returned as-is).
    3. For wrappers exposing ``dataset.dataset`` and ``dataset.indices``
       (``torch.utils.data.Subset``-like objects), the labels of the wrapped
       dataset are gathered at ``indices``. They are taken, in order, from
       ``dataset.dataset.y[i]``, then ``dataset.dataset.index[i][0]``, and
       finally by resolving the wrapped dataset recursively. The recursive
       step covers Subsets of datasets with ``targets``/``_labels`` and
       nested Subsets.

    Raises:
        RuntimeError: if no label field can be found.
    """
    if hasattr(dataset, 'targets'):
        return dataset.targets
    if hasattr(dataset, '_labels'):
        return dataset._labels
    if not hasattr(dataset, 'dataset'):
        raise RuntimeError("Class field not found")
    # Subset-like wrapper: labels live on the wrapped dataset.
    parent = dataset.dataset
    if hasattr(parent, 'y'):
        return [parent.y[index] for index in dataset.indices]
    if hasattr(parent, 'index'):
        return [parent.index[index][0] for index in dataset.indices]
    # Fallback: resolve the wrapped dataset's own labels. This raises the
    # same RuntimeError when nothing can be found.
    parent_targets = extract_targets_from_dataset(parent)
    if isinstance(parent_targets, torch.Tensor):
        # Plain Python scalars hash by value; 0-d tensors hash by identity.
        parent_targets = parent_targets.tolist()
    return [parent_targets[index] for index in dataset.indices]


def get_fewshot_indices(dataset, samples_per_class: int, rng=None):
    """Sample ``samples_per_class`` dataset indices from every class.

    Args:
        dataset: any object accepted by ``extract_targets_from_dataset``.
        samples_per_class: number of indices drawn (without replacement)
            from each class.
        rng: optional ``numpy.random.Generator`` or ``numpy.random.RandomState``
            used for sampling. Defaults to the global ``numpy.random`` state.

    Returns:
        A flat list of ``int`` indices. Classes appear in order of their first
        occurrence in the targets. Within a class, indices are in sampled order.

    Raises:
        RuntimeError: if the dataset's labels cannot be located.
        ValueError: (from numpy) if a class has fewer than
            ``samples_per_class`` samples.
    """
    targets = extract_targets_from_dataset(dataset)
    if isinstance(targets, torch.Tensor):
        # Iterating a tensor yields 0-d tensors, which hash by identity and
        # would put every sample in its own class; use Python scalars instead.
        targets = targets.tolist()
    # class_id -> list of sample positions with that class
    class2ids = defaultdict(list)
    for i, class_id in enumerate(targets):
        if isinstance(class_id, torch.Tensor):
            class_id = class_id.item()
        class2ids[class_id].append(i)
    sampler = np.random if rng is None else rng
    fewshot_indices = []
    for ids in class2ids.values():
        chosen = sampler.choice(ids, size=samples_per_class, replace=False)
        fewshot_indices.extend(chosen.tolist())
    return fewshot_indices

def restrict_class_layer(cls_layer, class_ranges):
    """Keep only a subset of the output units of a linear layer, in place.

    Rows of ``cls_layer.weight`` are selected with ``class_ranges``, and so
    are the entries of ``cls_layer.bias`` when present. ``out_features`` is
    then set to the number of rows actually kept.

    Args:
        cls_layer: the ``nn.Linear`` layer to modify in place.
        class_ranges: any index accepted by tensor indexing along dim 0, e.g.
            a list/LongTensor of class ids, a ``slice``, or a boolean mask.

    Raises:
        AssertionError: if ``cls_layer`` is not an ``nn.Linear``.

    NOTE: optimizers constructed before this call may hold per-parameter
    state (e.g. momentum) sized for the old shape; create them afterwards.
    """
    assert isinstance(cls_layer, nn.Linear), "restrict_class_layer expects an nn.Linear"
    # Index detached tensors so no autograd graph is recorded, and compute
    # everything before mutating so a bad index leaves the layer untouched.
    new_weight = cls_layer.weight.detach()[class_ranges]
    new_bias = None
    if cls_layer.bias is not None:
        new_bias = cls_layer.bias.detach()[class_ranges]

    cls_layer.weight.data = new_weight
    # Gradients of the old shape could not be accumulated into; drop them.
    cls_layer.weight.grad = None
    if new_bias is not None:
        cls_layer.bias.data = new_bias
        cls_layer.bias.grad = None
    # Derive the count from the result: len() is wrong for boolean masks
    # and undefined for slices.
    cls_layer.out_features = new_weight.shape[0]
