import copy
import numpy as np
from collections import *

# Seed NumPy's global RNG at import time so that code drawing from it
# (e.g. the np.random.shuffle call in sample_train_imba_data) is reproducible.
np.random.seed(7)


# Custom Datasets
def count_group_size(data):
    """Tally how many records fall into each group.

    Args:
        data: Iterable of mappings; each one is indexed with the key
            ``'group'``.

    Returns:
        A ``defaultdict(int)`` mapping each group value to the number of
        records that carry it.
    """
    counts = defaultdict(int)
    for record in data:
        group = record['group']
        counts[group] = counts[group] + 1
    return counts


def print_group_size(train_data, test_data):
    """Print the per-group record counts of the train and test splits.

    Each split is summarised with ``count_group_size`` and printed as a list
    of ``(group, count)`` pairs ordered by group. The train split is printed
    before the test split is processed.
    """
    splits = (
        ('Train group size:', train_data),
        ('Test group size:', test_data),
    )
    for title, split in splits:
        pairs = list(count_group_size(split).items())
        pairs.sort(key=lambda pair: pair[0])
        print(title, pairs)


def count_group_label_size(data):
    """Tally records per (group, label) combination.

    Args:
        data: Iterable of mappings; each one is indexed with the keys
            ``'group'`` and ``'label'``.

    Returns:
        A nested ``defaultdict``: ``result[group][label]`` is the number of
        records with that group and that label.
    """
    per_group = defaultdict(lambda: defaultdict(int))
    for record in data:
        # Look up (and create on first sight) the inner counter for the group.
        label_counts = per_group[record['group']]
        label_counts[record['label']] += 1
    return per_group


def print_group_label_size(train_data, test_data):
    """Print the per-group label counts of the train and test splits.

    For each split the output is a sorted list of
    ``(group, [(label, count), ...])`` tuples, with the inner pairs ordered
    by label. Counts come from ``count_group_label_size``.
    """

    def _summarize(split):
        # One (group, sorted label counts) tuple per group, then sort the tuples.
        rows = []
        for group, label_counts in count_group_label_size(split).items():
            rows.append((group, sorted(label_counts.items(), key=lambda pair: pair[0])))
        rows.sort()
        return rows

    print('Train group2label size:', _summarize(train_data))
    print('Test group2label size:', _summarize(test_data))


# Trivial Dataset
def count_class_size(num_classes, data):
    """Count how many entries of ``data.targets`` equal each class id.

    Args:
        num_classes: Number of class ids to report; ids are
            ``0 .. num_classes - 1``.
        data: Object exposing an iterable ``targets`` attribute of labels.

    Returns:
        A dict mapping every class id in ``range(num_classes)`` to the number
        of labels equal to it (``0`` when the class never occurs).
    """
    sizes = {}
    for cls in range(num_classes):
        # Compare with == rather than hashing labels, as the labels are only
        # required to support equality against an int.
        sizes[cls] = sum(1 for label in data.targets if label == cls)
    return sizes


def print_class_size(num_classes, train_data, test_data):
    """Print the per-class sample counts of the train and test splits.

    Each split is summarised with ``count_class_size`` and printed as a list
    of ``(class_id, count)`` pairs ordered by class id. The printed labels
    read "group size" even though the counts are per class.
    """
    splits = (
        ('Train group size:', train_data),
        ('Test group size:', test_data),
    )
    for title, split in splits:
        pairs = list(count_class_size(num_classes, split).items())
        pairs.sort(key=lambda pair: pair[0])
        print(title, pairs)


def sample_train_imba_data(num_classes, train_data, imba_class=8, keep_size=500):
    """Return a copy of ``train_data`` with one class randomly down-sampled.

    The samples labelled ``imba_class`` are shuffled with NumPy's global RNG
    (``np.random.shuffle``), the first ``keep_size`` of them are kept and the
    rest are removed from both ``targets`` and ``data``. Samples of all other
    classes are left in place. ``train_data`` itself is not modified.

    Args:
        num_classes: Number of classes; valid class ids are
            ``0 .. num_classes - 1``.
        train_data: Object exposing ``targets`` (iterable of labels) and
            ``data`` (array-like indexed along axis 0 in step with
            ``targets``). It must be deep-copyable.
        imba_class: Class id to down-sample. Defaults to ``8``.
        keep_size: Maximum number of ``imba_class`` samples to keep.
            Defaults to ``500``. If the class has fewer samples, nothing is
            removed.

    Returns:
        A deep copy of ``train_data`` whose ``targets`` and ``data`` are the
        arrays produced by ``np.delete`` after dropping the surplus samples.

    Raises:
        KeyError: If ``imba_class`` is not in ``range(num_classes)``.
        ValueError: If ``keep_size`` is negative.
    """
    if imba_class not in range(num_classes):
        # Same exception type as looking the class up in a per-class index dict.
        raise KeyError(imba_class)
    if keep_size < 0:
        raise ValueError(f'keep_size must be non-negative, got {keep_size}')

    # Only the down-sampled class needs an index list; other classes are untouched.
    imba_idx = [i for i, label in enumerate(train_data.targets) if label == imba_class]

    # Shuffle so that the retained subset is a random sample of the class.
    np.random.shuffle(imba_idx)
    drop_idx = imba_idx[keep_size:]

    train_imba_data = copy.deepcopy(train_data)
    # Delete the same rows from labels and samples so they stay aligned.
    train_imba_data.targets = np.delete(train_data.targets, drop_idx, axis=0)
    train_imba_data.data = np.delete(train_data.data, drop_idx, axis=0)

    return train_imba_data
