import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, ConcatDataset
from collections import defaultdict
import pickle
import os

def _balanced_class_sample(class_to_indices, class_ids, total):
    """Draw ``total`` indices spread as evenly as possible over ``class_ids``.

    Each class contributes ``total // len(class_ids)`` indices and the first
    ``total % len(class_ids)`` classes contribute one extra, so the result has
    exactly ``total`` entries. Each class's index array in
    ``class_to_indices`` is shuffled in place with the global NumPy RNG
    before slicing, so the draw is reproducible under ``np.random.seed``.

    Args:
        class_to_indices: mapping ``class_id -> np.ndarray`` of dataset indices.
        class_ids: classes to sample from (must be non-empty).
        total: number of indices to return.

    Returns:
        List of dataset indices, grouped by class in ``class_ids`` order.

    Raises:
        ValueError: if ``class_ids`` is empty or a class holds fewer indices
            than its share requires.
    """
    if len(class_ids) == 0:
        raise ValueError("class_ids must be non-empty")
    base, remainder = divmod(total, len(class_ids))
    picked = []
    for i, class_id in enumerate(class_ids):
        share = base + (1 if i < remainder else 0)
        pool = class_to_indices[class_id]
        if len(pool) < share:
            raise ValueError(
                f"class {class_id} has only {len(pool)} samples but {share} "
                f"are needed to draw {total} samples over {len(class_ids)} classes"
            )
        # In-place shuffle (rather than a permuted copy) keeps the RNG stream
        # and resulting draws identical to the original implementation.
        np.random.shuffle(pool)
        picked.extend(pool[:share])
    return picked


def setup_cifar100_loaders(opt):
    """Create non-IID CIFAR-100 train/test loaders for every (client, task) pair.

    Each client independently draws ``num_task * class_per_task`` distinct
    classes, so no class repeats within a client but clients may share
    classes. The draw is split into ``num_task`` consecutive tasks of
    ``class_per_task`` classes. For each task, a class-balanced subset is
    drawn: 200 training samples and 2000 test samples, spread as evenly as
    possible over the task's classes.

    The class assignment is pickled to
    ``./dump/cifar100_partitioning_seed{opt.seed}.pkl``.

    Args:
        opt: options object; reads ``num_clients``, ``num_task``,
            ``class_per_task``, ``batch_size``, ``data_dir``, ``seed``,
            ``num_workers`` and ``pin_memory``.

    Returns:
        ``defaultdict`` mapping ``client_id -> task_id -> {'train': DataLoader,
        'test': DataLoader}``.

    Raises:
        ValueError: if a client needs more than 100 classes, or a task's
            classes cannot fill the per-task quota. The CIFAR-100 test split
            has 100 images per class, so 2000 test samples per task require
            ``class_per_task >= 20``.
    """
    client_num = opt.num_clients
    tasks_per_client = opt.num_task
    classes_per_task = opt.class_per_task
    batch_size = opt.batch_size
    data_dir = opt.data_dir

    num_classes = 100
    # Samples drawn per task from the training split.
    target_samples = 200
    # NOTE(review): the original comment called this a "smaller test set", yet
    # it is 10x the training quota and exceeds the test split's capacity when
    # class_per_task < 20 -- confirm the intended value.
    test_target_samples = 2000

    # Fail before downloading anything if the class draw is impossible.
    if tasks_per_client * classes_per_task > num_classes:
        raise ValueError(
            f"num_task * class_per_task = {tasks_per_client * classes_per_task} "
            f"exceeds the {num_classes} CIFAR-100 classes"
        )

    os.makedirs('./dump', exist_ok=True)

    # 1. Load CIFAR100 dataset with per-channel mean/std normalization.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    ])

    train_set = datasets.CIFAR100(data_dir, train=True, download=True, transform=transform)
    test_set = datasets.CIFAR100(data_dir, train=False, download=True, transform=transform)

    # 2. Client-task-class assignment. Seeding the global RNG keeps partitions
    # reproducible (and identical to previously saved ones) for a given seed.
    np.random.seed(opt.seed)
    all_classes = np.arange(num_classes)

    client_task_classes = {}
    for client_id in range(client_num):
        # Without replacement: no class repeats within a client.
        selected = np.random.choice(all_classes,
                                    size=tasks_per_client * classes_per_task,
                                    replace=False)
        client_task_classes[client_id] = [
            selected[t * classes_per_task:(t + 1) * classes_per_task]
            for t in range(tasks_per_client)
        ]

    # 3. Per-class index pools; each split's targets are converted once.
    train_targets = np.asarray(train_set.targets)
    test_targets = np.asarray(test_set.targets)
    y_ind_dict = {y: np.where(train_targets == y)[0] for y in range(num_classes)}
    y_test_ind_dict = {y: np.where(test_targets == y)[0] for y in range(num_classes)}

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=2 * opt.num_workers,
        pin_memory=opt.pin_memory,
    )

    client_loaders = defaultdict(dict)
    for client_id, task_list in client_task_classes.items():
        for task_id, class_ids in enumerate(task_list):
            # Train draws precede test draws, matching the original RNG order.
            train_indices = _balanced_class_sample(y_ind_dict, class_ids, target_samples)
            test_indices = _balanced_class_sample(y_test_ind_dict, class_ids, test_target_samples)

            client_loaders[client_id][task_id] = {
                'train': DataLoader(Subset(train_set, train_indices), shuffle=True, **loader_kwargs),
                'test': DataLoader(Subset(test_set, test_indices), shuffle=False, **loader_kwargs),
            }

    # 4. Save partitioning
    partitioning = {
        'client_task_classes': client_task_classes,
        'num_clients': client_num,
        'tasks_per_client': tasks_per_client,
        'classes_per_task': classes_per_task,
        'samples_per_task': target_samples,
        'seed': opt.seed
    }

    with open(f'./dump/cifar100_partitioning_seed{opt.seed}.pkl', 'wb') as f:
        pickle.dump(partitioning, f)

    return client_loaders

def read_pickle(name):
    """Load and return the object pickled in the file at path ``name``.

    Only use on trusted files: unpickling can execute arbitrary code.
    """
    with open(name, mode="rb") as handle:
        return pickle.load(handle)

def write_pickle(data, name):
    """Pickle ``data`` into the file at path ``name``, overwriting any existing file."""
    with open(name, mode="wb") as handle:
        pickle.dump(data, handle)