import os
import random
import pickle
import numpy as np

from collections import defaultdict
from torch.utils.data import Sampler


# Split into labeled and unlabeled data
def split_data(cfgs, 
               data, target, 
               num_labels, num_all_classes, seen_classes, 
               index=None, include_lb_to_ulb=True):
    """
    Split data into labeled and unlabeled sets for Open-set Semi-Supervised Learning (OSSL).

    Labels are first remapped by ``reassign_target`` (seen classes -> 0..K-1,
    unseen classes -> K..N-1, anything else -> -1). For datasets whose name does
    not contain ``'open'``, only samples with a label in 0..K-1 are kept.

    Args:
        cfgs: Dataset configuration (reads ``cfgs['dataset']``; also passed to
            ``sample_labeled_data``).
        data (np.ndarray): Full dataset.
        target (np.ndarray): Original class labels.
        num_labels (int): Number of labeled samples.
        num_all_classes (int): Total number of classes (seen + unseen).
        seen_classes (set): Set of class indices considered as seen (inliers).
        index (Optional[np.ndarray]): Predefined labeled indices.
        include_lb_to_ulb (bool): Whether to include labeled data in the unlabeled set.
    Returns:
        Tuple: 
            - Labeled data (np.ndarray)
            - Labeled targets (np.ndarray)
            - Unlabeled data (np.ndarray)
            - Unlabeled targets (np.ndarray)
    """
    
    data, target = np.array(data), np.array(target)
    target = reassign_target(target, num_all_classes, seen_classes)
    num_seen = len(seen_classes)

    if 'open' not in cfgs['dataset']:
        # Closed-set setting: keep inliers only. The lower bound also drops
        # samples reassign_target could not map (label -1), which a plain
        # `target < num_seen` test would wrongly keep.
        inlier_idx = np.where((target >= 0) & (target < num_seen))[0]
        data, target = data[inlier_idx], target[inlier_idx]

    lb_data, lb_targets, lb_idx = sample_labeled_data(cfgs, 
                                                      data, target, 
                                                      num_labels, 
                                                      num_classes=num_seen, 
                                                      index=index)

    if include_lb_to_ulb:
        return lb_data, lb_targets, data, target

    # Sorted integer indices of every sample that was not labeled. setdiff1d keeps
    # an integer dtype even when the result is empty (a float array cannot index).
    ulb_idx = np.setdiff1d(np.arange(len(data)), lb_idx)
    return lb_data, lb_targets, data[ulb_idx], target[ulb_idx]


# Create clients dataset
def split_clients(cfgs, ulb_data, ulb_targets):
    """
    Split the unlabeled dataset into iid or non-iid client datasets.

    The split is cached as a pickle at ``get_clients_save_path(cfgs)`` and
    reloaded from there when it already exists.

    Split mode (from ``cfgs``):
        - ``iid``: Dirichlet split with a very large concentration (1e5);
          ``alpha`` and ``num_k`` must both be None.
        - ``alpha``: Dirichlet split with that concentration.
        - ``num_k``: each client receives ``num_k`` classes.

    Args:
        cfgs: Configuration dict (reads ``iid``, ``alpha``, ``num_k``,
            ``num_clients``, ``seed`` and the keys used by the save path).
        ulb_data (np.ndarray): Unlabeled samples.
        ulb_targets (np.ndarray): Labels of the unlabeled samples.

    Returns:
        dict: Maps client id (int) -> {'data': ..., 'targets': ...}.

    Raises:
        ValueError: If the split is non-iid and neither ``alpha`` nor ``num_k`` is set.
    """
    # set 'save path'
    clients_path = get_clients_save_path(cfgs)
    
    # load the split unlabeled data idx if the 'save path' exists
    if os.path.exists(clients_path):
        # NOTE: pickle.load can execute arbitrary code; only load cache files
        # produced by this pipeline, never files from an untrusted source.
        with open(clients_path, "rb") as f: 
            clients = pickle.load(f)
        print(f"Load clients data from '{clients_path}'")
        return clients['clients_dict']

    ulb_targets = np.asarray(ulb_targets)

    # Group sample indices by class. Keys are positions 0..num_classes-1 over the
    # sorted distinct labels, which is what the split helpers iterate over. For
    # labels that are already 0..C-1 this equals keying by label; it also keeps
    # every class when labels are non-contiguous (e.g. contain -1).
    classes = np.unique(ulb_targets)
    num_classes = len(classes)
    class_indices = {c_id: np.where(ulb_targets == label)[0]
                     for c_id, label in enumerate(classes)}

    alpha, num_k = cfgs['alpha'], cfgs['num_k']

    # iid: emulated as a Dirichlet split with a huge concentration. A local is used
    # so the caller's cfgs is not mutated (a mutated cfgs would fail this assert on
    # a second call).
    if cfgs['iid']:
        assert (alpha is None) and (num_k is None), \
                "alpha and num_k should be null for iid setting"
        alpha = 1e5
        
    # non-iid: [alpha] for dirichlet distribution
    if alpha:
        assert (num_k is None), \
                "num_k should be null for dirichlet split"
        clients_dict, dist = get_clients_dict_using_dirichlet(float(alpha), num_classes, 
                                                              cfgs['num_clients'], 
                                                              class_indices=class_indices,
                                                              seed=cfgs['seed'])
    
    # non-iid: [num_k]
    elif num_k:
        clients_dict, dist = get_clients_dict_using_num_k(num_k, num_classes, 
                                                          cfgs['num_clients'], 
                                                          num_ulb=len(ulb_targets), 
                                                          class_indices=class_indices,
                                                          seed=cfgs['seed'])
    else:
        raise ValueError("non-iid split requires either 'alpha' or 'num_k' to be set")

    # get processed data and targets for each clients
    processed = process_clients_data(ulb_data, ulb_targets, clients_dict)
    
    # save
    clients = {'clients_dict': processed, 'clients_dist': dist}

    os.makedirs(os.path.dirname(clients_path), exist_ok=True)
    with open(clients_path, "wb") as f: 
        pickle.dump(clients, f)
    print(f"Save! {clients_path}")
            
    return processed


# OSSL ============================================================================================= #

def sample_labeled_data(cfgs, 
                        data, target, 
                        num_labels, num_classes, 
                        index=None):
    """
    Sample labeled data with class-balanced sampling.

    Draws ``num_labels // num_classes`` indices per class (classes 0..num_classes-1)
    without replacement using NumPy's global RNG. The chosen indices are cached to
    ``<base_dir>/<dataset>/lb<num_labels>_seed<seed>/labels_idx.npy`` and reloaded
    from there on later calls.

    Parameters:
        cfgs: Dataset configuration (reads ``base_dir``, ``dataset``, ``num_labels``, ``seed``).
        data (np.ndarray): Input data (e.g., images).
        target (np.ndarray): Corresponding labels for the data.
        num_labels (int): Total number of labeled samples to sample.
        num_classes (int): Number of classes (inlier).
        index (Optional[np.ndarray]): Predefined indices for labeled samples; when
            given, no sampling or caching happens.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: Sampled data, labels, and indices.

    Raises:
        AssertionError: If ``num_labels`` is not divisible by ``num_classes``.
        ValueError: (from NumPy) if a class has fewer samples than requested.
    """
    assert num_labels % num_classes == 0, \
        f"num_labels ({num_labels}) must be divisible by num_classes ({num_classes})"

    if index is not None:
        index = np.array(index, dtype=np.int32)
        return data[index], target[index], index

    # NOTE(review): the cache key uses cfgs['num_labels'], not the num_labels
    # argument, so a cached file is reused even if the two disagree.
    save_dir = os.path.join(cfgs['base_dir'], cfgs['dataset'], f"lb{cfgs['num_labels']}_seed{cfgs['seed']}")
    os.makedirs(save_dir, exist_ok=True)
    save_path = os.path.join(save_dir, 'labels_idx.npy')

    if os.path.exists(save_path):
        lb_idx = np.load(save_path)
        print(f"Load labeled data from '{save_path}'")
        return data[lb_idx], target[lb_idx], lb_idx

    samples_per_class = num_labels // num_classes

    lb_idx = []
    for c in range(num_classes):
        idx = np.where(target == c)[0]
        chosen = np.random.choice(idx, samples_per_class, replace=False)
        lb_idx.extend(chosen)

    # Explicit integer dtype: an empty list would otherwise become a float array,
    # which NumPy refuses to use as an index.
    lb_idx = np.array(lb_idx, dtype=np.intp)
    np.save(save_path, lb_idx)

    return data[lb_idx], target[lb_idx], lb_idx


def reassign_target(target, num_all_classes, seen_classes):
    """
    Reassign class labels for open-set learning.

    Parameters:
        target (array-like): Original class labels.
        num_all_classes (int): Total number of classes in the dataset.
        seen_classes (Iterable[int]): Class indices considered as seen (inlier).
            A set is iterated in its own order, a list/tuple in the given order;
            that order decides which seen class becomes 0, 1, ...
    Returns:
        np.array: Reassigned labels where seen classes are labeled from 0 to K-1, 
                  and unseen classes are labeled from K to N-1. Labels outside
                  range(num_all_classes) are marked -1. The dtype is
                  ``target.dtype`` promoted with int8, so -1 is representable
                  even for unsigned label arrays.
    """
    
    target = np.array(target)

    # Ordered, de-duplicated copy: accepts lists/tuples/iterators as well as sets,
    # and keeps a set's iteration order unchanged.
    seen_list = list(dict.fromkeys(seen_classes))
    unseen_classes = set(range(num_all_classes)) - set(seen_list)

    # A signed dtype is required for the -1 sentinel (e.g. uint8 labels cannot hold it).
    out_dtype = np.promote_types(target.dtype, np.int8)
    targets_new = np.full(target.shape, -1, dtype=out_dtype)

    for i, lbi in enumerate(seen_list):
        targets_new[target == lbi] = i
    offset = len(seen_list)
    for i, lbi in enumerate(unseen_classes):
        targets_new[target == lbi] = offset + i

    return targets_new

# ================================================================================================== #


# FED: Clients dataset ============================================================================= #

def get_clients_dict_using_dirichlet(alpha, num_classes, num_clients, class_indices, seed):
    """
    Split unlabeled data across clients using a Dirichlet distribution.

    Args:
        alpha (float): Dirichlet concentration parameter (the same value for every client).
        num_classes (int): Number of classes; ``class_indices`` must have keys 0..num_classes-1.
        num_clients (int): Number of clients.
        class_indices (dict): Maps class id -> array of sample indices of that class.
        seed (int): Seed applied to NumPy's global RNG via ``np.random.seed``.

    Returns:
        Tuple[dict, list]:
            - clients_dict: maps str(client id) -> list of sample indices.
            - dist: dist[n][c] is the number of class-c samples given to client n.
    """
    np.random.seed(seed)
    # Shape (num_clients, num_classes): column c is one Dirichlet draw over clients for class c.
    dist_of_client = np.random.dirichlet(np.repeat(alpha, num_clients), 
                                         size=num_classes).transpose()
    dist_of_client /= dist_of_client.sum()

    # Balance the matrix.
    # NOTE(review): s0 and s1 are both taken before either division, so this is not
    # a textbook Sinkhorn step and the overall scale alternates between iterations.
    # The `* num_classes` factor below appears to rely on the (even) iteration
    # count -- confirm before changing it.
    for _ in range(100):
        s0 = dist_of_client.sum(axis=0, keepdims=True)
        s1 = dist_of_client.sum(axis=1, keepdims=True)
        dist_of_client /= s0
        dist_of_client /= s1

    # Look classes up by id rather than by dict insertion order, so column c of
    # the matrix always lines up with class_indices[c] in the slicing loop below.
    num_samples = np.array([len(class_indices[c_id]) for c_id in range(num_classes)])
    samples_per_class = np.round(dist_of_client * num_samples[np.newaxis, :] * num_classes).astype(np.int32)
    
    # start_ids[n, c] is where client n's slice of class c starts; row n+1 is where it ends.
    start_ids = np.zeros((num_clients+1, num_classes), dtype=np.int32)
    for i in range(num_clients):
        start_ids[i+1] = start_ids[i] + samples_per_class[i]
    # Pin the final end to the class size so the last client absorbs any rounding remainder.
    start_ids[-1] = num_samples
    
    clients_dict = {str(client_num): list() for client_num in range(num_clients)}
    dist = []
    for n in range(num_clients):
        client_dist = []
        for c_id in range(num_classes):
            start, end = start_ids[n, c_id], start_ids[n+1, c_id]
            items = class_indices[c_id][start:end].tolist()
            clients_dict[str(n)].extend(items)
            client_dist.append(len(items))
        dist.append(client_dist)
    
    return clients_dict, dist


def get_clients_dict_using_num_k(num_k, num_classes, num_clients, num_ulb, class_indices, seed=None):
    """
    Split unlabeled data for clients using num_k (number of classes per client).
    Supports class-imbalanced unlabeled data.

    Each client is assigned exactly ``num_k`` distinct classes. Every class is
    assigned to the same number of clients (``num_clients * num_k / num_classes``)
    and split evenly (rounded) among them.

    Args:
        num_k (int): Number of classes per client.
        num_classes (int): Number of classes; ``class_indices`` must have keys 0..num_classes-1.
        num_clients (int): Number of clients.
        num_ulb (int): Number of unlabeled samples (not used by this function).
        class_indices (dict): Maps class id -> array of sample indices of that class.
        seed (Optional[int]): If given, seeds Python's ``random`` and NumPy's global RNG.

    Returns:
        Tuple[dict, list]:
            - clients_dict: maps str(client id) -> list of sample indices.
            - dist: dist[n][c] is the number of class-c samples given to client n.

    Raises:
        AssertionError: If the classes cannot be assigned evenly.
        RuntimeError: If the greedy assignment reaches a state where no client can
            accept the next class (the original loop would spin forever here).
    """
    assert num_classes % num_k == 0, \
        f"num_classes ({num_classes}) must be divisible by num_k ({num_k}) to assign evenly."
        
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    def generate_class_matrix(num_clients, num_classes, num_k, seed=None):
        # Return a 0/1 matrix (num_clients, num_classes) with num_k ones per row
        # and the same number of ones per column.
        if seed is not None:
            np.random.seed(seed)

        total_assignments = num_clients * num_k
        assert total_assignments % num_classes == 0, \
            "Cannot assign evenly to all classes. Choose compatible num_clients, num_k, num_classes."

        per_class_count = total_assignments // num_classes

        class_pool = []
        for c in range(num_classes):
            class_pool.extend([c] * per_class_count)
        np.random.shuffle(class_pool)

        class_matrix = np.zeros((num_clients, num_classes), dtype=int)
        client_ptr = 0

        # Greedy round-robin: hand each pooled class to the next client that has
        # room and does not already hold it.
        for c in class_pool:
            misses = 0
            while True:
                if class_matrix[client_ptr].sum() < num_k and class_matrix[client_ptr][c] == 0:
                    class_matrix[client_ptr][c] = 1
                    client_ptr = (client_ptr + 1) % num_clients
                    break
                client_ptr = (client_ptr + 1) % num_clients
                misses += 1
                # A full cycle without success means no client can ever take c
                # (the matrix does not change while we only skip).
                if misses >= num_clients:
                    raise RuntimeError(
                        f"Greedy class assignment got stuck on class {c}; "
                        f"try another seed or different num_clients/num_k/num_classes."
                    )

        return class_matrix

    dist_of_client = generate_class_matrix(num_clients=num_clients, num_classes=num_classes, num_k=num_k, seed=seed)
    # Divide by the number of clients sharing each class, so each share is 1/per_class_count.
    dist_of_client = dist_of_client / (num_clients / (num_classes / num_k))

    # Look classes up by id rather than by dict insertion order, so column c of
    # the matrix always lines up with class_indices[c] in the slicing loop below.
    num_samples = np.array([len(class_indices[c_id]) for c_id in range(num_classes)])
    samples_per_class = np.round(dist_of_client * num_samples[np.newaxis, :]).astype(np.int32)
    
    # start_ids[n, c] is where client n's slice of class c starts; row n+1 is where it ends.
    start_ids = np.zeros((num_clients+1, num_classes), dtype=np.int32)
    for i in range(num_clients):
        start_ids[i+1] = start_ids[i] + samples_per_class[i]
    # Pin the final end to the class size: the last client absorbs any rounding
    # remainder of every class, including classes it was not assigned.
    start_ids[-1] = num_samples
    
    clients_dict = {str(client_num): list() for client_num in range(num_clients)}
    dist = []
    for n in range(num_clients):
        client_dist = []
        for c_id in range(num_classes):
            start, end = start_ids[n, c_id], start_ids[n+1, c_id]
            items = class_indices[c_id][start:end].tolist()
            clients_dict[str(n)].extend(items)
            client_dist.append(len(items))
        dist.append(client_dist)
    
    return clients_dict, dist


def process_clients_data(data, targets, clients_dict):
    """
    Gather each client's samples and labels from its index list.

    Args:
        data (np.ndarray): Samples to select from.
        targets (np.ndarray): Labels aligned with ``data``.
        clients_dict (dict): Maps client id (str or int-convertible) -> list of indices.

    Returns:
        dict: Maps int client id (0..len(clients_dict)-1) -> {'data': ..., 'targets': ...}.
    """
    processed = {}
    for slot in range(len(clients_dict)):
        processed[slot] = {}

    for key, members in clients_dict.items():
        owner = int(key)
        selection = np.array(members, dtype=np.int32)
        picked_data = data[selection]
        processed[owner]['data'] = picked_data
        picked_targets = targets[selection]
        processed[owner]['targets'] = picked_targets

    return processed


def get_clients_save_path(cfgs):
    """
    Build the cache file path for the per-client split of the unlabeled data.

    The directory is ``<base_dir>/<dataset>/lb<num_labels>_seed<seed>``. The file
    name encodes the split: ``iid_clients<N>.pkl``, ``non-iid_clients<N>_alpha<a>.pkl``
    or ``non-iid_clients<N>_num_k<k>.pkl``. A non-iid config with neither alpha nor
    num_k yields ``non-iid_clients<N>`` with no extension.
    """
    save_dir = os.path.join(cfgs['base_dir'], cfgs['dataset'], f"lb{cfgs['num_labels']}_seed{cfgs['seed']}")

    if cfgs['iid']:
        return os.path.join(save_dir, f"iid_clients{cfgs['num_clients']}.pkl")

    file_name = f"non-iid_clients{cfgs['num_clients']}"
    if cfgs['alpha']:
        file_name = file_name + f"_alpha{cfgs['alpha']}.pkl"
    elif cfgs['num_k']:
        file_name = file_name + f"_num_k{cfgs['num_k']}.pkl"
    return os.path.join(save_dir, file_name)

# ================================================================================================== #


def get_onehot(num_classes, idx):
    """Return a float32 vector of length ``num_classes`` with 1.0 added at ``idx``."""
    vec = np.zeros((num_classes,), dtype=np.float32)
    vec[idx] = vec[idx] + 1.0
    return vec

# ================================================================================================== #


class BalancedBatchSampler(Sampler):
    """
    Batch sampler yielding batches of ``num_classes_per_batch`` randomly chosen
    classes with ``batch_size // num_classes_per_batch`` samples each.

    All batches are built once in ``__init__`` from Python's ``random`` module;
    each epoch (``__iter__``) only reshuffles the order of those fixed batches.
    The number of batches is limited by the smallest class. An actual batch holds
    ``num_classes_per_batch * (batch_size // num_classes_per_batch)`` indices,
    which is less than ``batch_size`` when it is not divisible.
    """

    def __init__(self, labels, batch_size, num_classes_per_batch):
        # Fail fast with a clear message; otherwise the per-class count is 0 and
        # _create_batches dies with ZeroDivisionError.
        if num_classes_per_batch <= 0 or batch_size < num_classes_per_batch:
            raise ValueError(
                f"batch_size ({batch_size}) must be >= num_classes_per_batch "
                f"({num_classes_per_batch}) and num_classes_per_batch must be positive"
            )

        self.labels = np.array(labels)
        self.batch_size = batch_size
        self.num_classes_per_batch = num_classes_per_batch

        # label -> shuffled list of dataset indices with that label
        self.label_to_indices = defaultdict(list)
        for idx, label in enumerate(self.labels):
            self.label_to_indices[label].append(idx)

        for label in self.label_to_indices:
            random.shuffle(self.label_to_indices[label])

        self.num_samples_per_class = batch_size // num_classes_per_batch
        self.classes = list(self.label_to_indices.keys())
        self.batches = self._create_batches()
    
    def _create_batches(self):
        """Build all batches, consuming indices from ``self.label_to_indices``."""
        if not self.label_to_indices:
            # No labels: an empty sampler instead of min() on an empty sequence.
            return []

        batches = []
        min_class_len = min(len(indices) for indices in self.label_to_indices.values())
        # Every class can appear in at most max_batches batches without running out.
        max_batches = min_class_len // self.num_samples_per_class

        for _ in range(max_batches):
            selected_classes = random.sample(self.classes, self.num_classes_per_batch)
            batch = []
            for cls in selected_classes:
                cls_idxs = self.label_to_indices[cls][:self.num_samples_per_class]
                self.label_to_indices[cls] = self.label_to_indices[cls][self.num_samples_per_class:]
                batch.extend(cls_idxs)
            random.shuffle(batch)
            batches.append(batch)
        return batches

    def __iter__(self):
        random.shuffle(self.batches)
        for batch in self.batches:
            yield batch

    def __len__(self):
        return len(self.batches)