import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, ConcatDataset
from collections import defaultdict
import pickle
import os

def _per_class_indices(targets, num_classes):
    """Return ``{class_id: array of dataset indices whose label equals class_id}``.

    ``targets`` is converted to a NumPy array once instead of once per class.
    """
    labels = np.asarray(targets)
    return {y: np.where(labels == y)[0] for y in range(num_classes)}


def _sample_balanced(indices_by_class, class_ids, total):
    """Pick exactly ``total`` dataset indices, split as evenly as possible over ``class_ids``.

    When ``total`` is not divisible by ``len(class_ids)``, the first
    ``total % len(class_ids)`` classes contribute one extra sample. Each class's
    index array in ``indices_by_class`` is shuffled IN PLACE with the global NumPy
    RNG before a prefix is taken, so later calls on the same class draw a
    different subset.

    Raises:
        ValueError: if a class has fewer indices than it must contribute.
    """
    per_class, remainder = divmod(total, len(class_ids))
    chosen = []
    for i, class_id in enumerate(class_ids):
        wanted = per_class + (1 if i < remainder else 0)
        pool = indices_by_class[class_id]
        if len(pool) < wanted:
            raise ValueError(
                f"Class {class_id} has only {len(pool)} samples; {wanted} requested"
            )
        np.random.shuffle(pool)
        chosen.extend(pool[:wanted])
    return chosen


def setup_emnist_loaders(opt, samples_per_task=500, test_samples_per_task=500):
    """Build non-IID EMNIST-letters loaders: one train and one test loader per (client, task).

    Every client gets ``opt.num_task`` tasks of ``opt.class_per_task`` letter
    classes each. A client's classes are drawn without replacement, so no class
    repeats across that client's tasks. Clients draw independently, so a class
    may be shared between clients.

    Each task's train loader holds exactly ``samples_per_task`` training images
    and its test loader exactly ``test_samples_per_task`` test images. Both are
    split as evenly as possible across the task's classes.

    Args:
        opt: Config object. Reads ``num_clients``, ``num_task``,
            ``class_per_task``, ``batch_size``, ``data_dir``, ``seed``,
            ``num_workers`` and ``pin_memory``.
        samples_per_task: Training samples per task. Defaults to 500, the value
            previously hard-coded.
        test_samples_per_task: Test samples per task. Defaults to 500.

    Returns:
        ``defaultdict`` mapping ``client_id -> task_id -> {'train': DataLoader,
        'test': DataLoader}``.

    Side effects:
        Constructs the torchvision EMNIST datasets with ``download=True`` under
        ``opt.data_dir``, and reseeds the global NumPy RNG with ``opt.seed``.
        Creates ``./dump`` and writes the class assignment to
        ``./dump/emnist_partitioning_seed{seed}.pkl``. Only classes are saved,
        not sample indices; those are reproducible from the seed.

    Raises:
        ValueError: if ``opt.class_per_task < 1``, if a client needs more than
            26 distinct classes, or if a class has too few samples for the
            requested per-task counts.
    """
    client_num = opt.num_clients
    tasks_per_client = opt.num_task
    classes_per_task = opt.class_per_task
    batch_size = opt.batch_size
    data_dir = opt.data_dir
    num_classes = 26  # EMNIST 'letters': A-Z

    # Validate before touching the network/disk so a bad config fails fast.
    classes_per_client = tasks_per_client * classes_per_task
    if classes_per_task < 1:
        raise ValueError(f"class_per_task must be >= 1, got {classes_per_task}")
    if classes_per_client > num_classes:
        raise ValueError(
            f"Each client needs {classes_per_client} distinct classes "
            f"({tasks_per_client} tasks x {classes_per_task}), but only {num_classes} exist"
        )

    os.makedirs('./dump', exist_ok=True)

    # 1. Load EMNIST-letters.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1722,), (0.3309,)),  # EMNIST-letters mean / std
    ])
    train_set = datasets.EMNIST(data_dir, split='letters', train=True, download=True, transform=transform)
    test_set = datasets.EMNIST(data_dir, split='letters', train=False, download=True, transform=transform)

    # The 'letters' split labels A-Z as 1-26; shift to 0-25 so labels index classes directly.
    train_set.targets = train_set.targets - 1
    test_set.targets = test_set.targets - 1

    # 2. Assign classes to (client, task).
    # This seeds the *global* NumPy RNG. Every draw below (class choice and
    # per-class shuffles) comes from it, and later code in the process sees
    # the reseeded state.
    np.random.seed(opt.seed)
    all_classes = np.arange(num_classes)

    client_task_classes = {}
    for client_id in range(client_num):
        selected = np.random.choice(all_classes, size=classes_per_client, replace=False)
        client_task_classes[client_id] = [
            selected[t * classes_per_task:(t + 1) * classes_per_task]
            for t in range(tasks_per_client)
        ]

    # 3. Build loaders with exactly the requested number of samples per task.
    train_by_class = _per_class_indices(train_set.targets, num_classes)
    test_by_class = _per_class_indices(test_set.targets, num_classes)

    # NOTE(review): workers are doubled relative to opt.num_workers; kept as-is — confirm intent.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=2 * opt.num_workers,
        pin_memory=opt.pin_memory,
    )

    client_loaders = defaultdict(dict)
    for client_id, task_list in client_task_classes.items():
        for task_id, class_ids in enumerate(task_list):
            # Train indices are drawn before test indices. This order determines
            # the RNG stream and hence which samples a given seed selects.
            train_indices = _sample_balanced(train_by_class, class_ids, samples_per_task)
            test_indices = _sample_balanced(test_by_class, class_ids, test_samples_per_task)
            client_loaders[client_id][task_id] = {
                'train': DataLoader(Subset(train_set, train_indices), shuffle=True, **loader_kwargs),
                'test': DataLoader(Subset(test_set, test_indices), shuffle=False, **loader_kwargs),
            }

    # 4. Persist the class partition.
    partitioning = {
        'client_task_classes': client_task_classes,
        'num_clients': client_num,
        'tasks_per_client': tasks_per_client,
        'classes_per_task': classes_per_task,
        'samples_per_task': samples_per_task,
        'test_samples_per_task': test_samples_per_task,
        'seed': opt.seed,
        'dataset': 'EMNIST-letter',
    }
    with open(f'./dump/emnist_partitioning_seed{opt.seed}.pkl', 'wb') as f:
        pickle.dump(partitioning, f)

    return client_loaders

def read_pickle(name):
    """Load and return the object pickled in the file at path ``name``.

    Security: unpickling can execute arbitrary code. Only call this on files
    from a trusted source.
    """
    with open(name, mode="rb") as handle:
        payload = handle.read()
    return pickle.loads(payload)

def write_pickle(data, name):
    """Pickle ``data`` (default protocol) into the file at path ``name``, replacing its contents.

    ``data`` is serialized *before* the file is opened. If serialization fails
    (e.g. an unpicklable object), the exception propagates and any existing
    file at ``name`` is left untouched rather than truncated to empty. This
    holds the full pickle in memory once before writing.
    """
    serialized = pickle.dumps(data)
    with open(name, "wb") as f:
        f.write(serialized)