import random
import numpy as np
from collections import defaultdict
import time
import torch



def data_targets(data):
    """Flatten several datasets into two parallel lists of samples and labels.

    Args:
        data: an iterable of datasets; every item of every dataset must unpack
            into exactly two values ``(image, label)``.

    Returns:
        A tuple ``(all_data, all_targets)`` of lists, in dataset order and then
        item order.
    """
    # Unpacking in the comprehension enforces the (image, label) shape.
    pairs = [(image, label) for dataset in data for image, label in dataset]
    all_data = [image for image, _ in pairs]
    all_targets = [label for _, label in pairs]
    return all_data, all_targets


def split_idx(data):
    """Return cumulative boundary offsets for a sequence of datasets.

    The result starts at 0 and has one extra entry per dataset, each being the
    running total of ``len(dataset)``. Dataset ``k`` therefore occupies the
    half-open range ``[result[k], result[k + 1])`` of the concatenation.
    """
    boundaries = [0]
    for subset in data:
        boundaries.append(boundaries[-1] + len(subset))
    return boundaries


def dirichlet_split(dataset, num_clients, alpha):
    """Partition dataset indices across clients using per-class Dirichlet proportions.

    For every distinct label, the indices of that label are shuffled and divided
    among ``num_clients`` clients according to proportions drawn from
    ``Dirichlet([alpha] * num_clients)``. Randomness comes from NumPy's global
    RNG, so results are reproducible after ``np.random.seed``.

    Args:
        dataset: indexable object supporting ``len``; ``dataset[i][1]`` is the
            label of sample ``i``.
        num_clients: number of clients to split into (must be >= 1).
        alpha: Dirichlet concentration parameter passed to
            ``np.random.dirichlet``.

    Returns:
        ``defaultdict(list)`` mapping client id (``0 .. num_clients - 1``) to a
        list of sample indices. Every index appears in exactly one client list.

    Raises:
        ValueError: if ``num_clients < 1``.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    # Reads every sample just to get its label; this can be slow if samples
    # are loaded lazily.
    labels = np.array([dataset[i][1] for i in range(len(dataset))])
    client_indices = defaultdict(list)

    # np.unique yields classes in sorted order. Iterating a set() instead makes
    # the class order hash-dependent (e.g. string labels vary between runs),
    # which changes which RNG draws each class gets and breaks reproducibility.
    for c in np.unique(labels):
        class_idx = np.where(labels == c)[0]  # indices of the current class
        np.random.shuffle(class_idx)

        # Sample per-client proportions and convert to (floored) counts.
        proportions = np.random.dirichlet(alpha=[alpha] * num_clients)
        split_counts = (proportions * len(class_idx)).astype(int)

        # Fix rounding so the counts sum exactly to the class size.
        while split_counts.sum() > len(class_idx):
            split_counts[np.argmax(split_counts)] -= 1
        while split_counts.sum() < len(class_idx):
            split_counts[np.argmin(split_counts)] += 1

        # Hand consecutive slices of the shuffled class indices to each client.
        start = 0
        for client_id, num in enumerate(split_counts):
            client_indices[client_id] += class_idx[start:start + num].tolist()
            start += num

    return client_indices



def set_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and make cuDNN deterministic.

    Seeds the ``random`` module, NumPy's global RNG, and PyTorch's CPU and CUDA
    generators. It then sets cuDNN to deterministic mode and disables cuDNN
    autotuning (``benchmark``).
    """
    # Python / NumPy global generators.
    random.seed(seed)
    np.random.seed(seed)

    # PyTorch CPU and CUDA generators.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for run-to-run determinism.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
