import os
import torch
import pickle
import random
from tqdm import tqdm
from collections import defaultdict
from .tools import worker_init

# Seed the global `random` module at import time so that `random.sample` in
# create_val_set draws the same validation subset on every run. Note that this
# reseeds the process-wide RNG, affecting any other code that uses `random`.
random.seed(42)


def create_val_loader(train_loader, val_ratio, saved_dir):
    """Build a class-stratified validation DataLoader from a training loader.

    The number of samples per class is counted by one pass over
    ``train_loader``. The counts are cached as a pickle at
    ``<saved_dir>/samples_per_class.pkl``, and later calls reuse that file.
    A second pass over ``train_loader`` then copies, for each class, the first
    ``int(count * val_ratio)`` samples encountered into an in-memory list that
    backs the returned loader.

    Args:
        train_loader: DataLoader yielding ``(images, labels)`` batches; each
            label element must support ``.item()`` (e.g. a tensor of class ids).
        val_ratio: Fraction of each class's samples to put in the validation set.
        saved_dir: Directory for the per-class count cache; created if it does
            not exist (an empty string means the current directory).

    Returns:
        A shuffled ``torch.utils.data.DataLoader`` over ``[image, label]`` pairs
        with the same ``batch_size`` as ``train_loader``.
    """
    if saved_dir:
        os.makedirs(saved_dir, exist_ok=True)
    file_dir = os.path.join(saved_dir, 'samples_per_class.pkl')
    if os.path.exists(file_dir):
        # NOTE(review): the cache is keyed only by file name, not by dataset or
        # loader, so a stale file from a different dataset is reused silently.
        # The file is unpickled, so saved_dir must be trusted.
        with open(file_dir, 'rb') as file:
            samples_per_class = pickle.load(file)
    else:
        samples_per_class = defaultdict(int)
        for _, labels in tqdm(train_loader):
            for label in labels:
                samples_per_class[label.item()] += 1
        with open(file_dir, 'wb') as file:
            pickle.dump(samples_per_class, file)

    print('load samples per class')
    # Per-class validation quota, kept separate from the cached counts. A
    # defaultdict gives unseen labels a quota of 0 instead of raising KeyError.
    quota = defaultdict(int)
    for k, v in samples_per_class.items():
        quota[k] = int(v * val_ratio)
    sub_sample_num = sum(quota.values())
    saved_sample_per_class = defaultdict(int)

    val_set = []
    for images, labels in tqdm(train_loader):
        for image, label in zip(images, labels):
            label = label.item()
            if saved_sample_per_class[label] < quota[label]:
                saved_sample_per_class[label] += 1
                # Indexing a batch tensor yields a view that shares the whole
                # batch's storage; clone so only this sample is kept alive.
                if torch.is_tensor(image):
                    image = image.clone()
                val_set.append([image, label])
        # Stop early once every class has reached its quota.
        if sub_sample_num == len(val_set):
            break

    return torch.utils.data.DataLoader(
        val_set,
        batch_size=train_loader.batch_size,
        shuffle=True,
        worker_init_fn=worker_init
    )


def create_val_set(dataset, eval_ratio):
    """Shrink ``dataset`` in place to a class-stratified random subset.

    For every class id in ``range(dataset.num_classes)``,
    ``int(n_c * eval_ratio)`` of that class's image indices are drawn without
    replacement with the global ``random`` module. The per-image attributes
    ``img_id_to_path``, ``is_poison`` and ``img_labels`` are re-indexed to the
    chosen subset, which is ordered class by class. ``num_imgs`` is updated
    to match.

    Args:
        dataset: Object exposing ``num_classes``, ``img_labels`` (a tensor of
            class ids, presumably 1-D), ``img_id_to_path``, ``is_poison`` and
            ``num_imgs``. It is mutated in place.
        eval_ratio: Fraction of each class to keep.

    Returns:
        The same ``dataset`` object, after mutation.
    """
    sample_indices = []
    for cls_idx in range(dataset.num_classes):
        indices = torch.nonzero(dataset.img_labels == cls_idx).flatten().tolist()
        num_sample_per_class = int(len(indices) * eval_ratio)
        sample_indices.extend(random.sample(indices, num_sample_per_class))

    dataset.img_id_to_path = [dataset.img_id_to_path[i] for i in sample_indices]
    dataset.is_poison = [dataset.is_poison[i] for i in sample_indices]
    # Labels must be re-indexed with the same indices as the paths; otherwise
    # position i of img_labels would no longer describe img_id_to_path[i].
    index = torch.as_tensor(
        sample_indices, dtype=torch.long, device=dataset.img_labels.device
    )
    dataset.img_labels = dataset.img_labels[index]
    dataset.num_imgs = len(sample_indices)

    return dataset
