import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import os
from PIL import Image

class CustomDataset(Dataset):
    """Dataset over in-memory samples and labels with an optional transform.

    NumPy samples are converted to PIL images before the transform is applied;
    samples of any other type are handed to the transform unchanged.
    """

    def __init__(self, data, labels, transform=None):
        """Store the samples, their labels and the optional transform.

        Args:
            data: Indexable, sized collection of samples.
            labels: Indexable collection of labels aligned with ``data``.
            transform: Optional callable applied to each sample on access.
        """
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    @staticmethod
    def _to_channels_last(image):
        """Return ``image`` with channels on the last axis where applicable.

        A 3-D array whose last axis is neither 1 nor 3 is treated as
        channels-first (C, H, W) and transposed to (H, W, C). Arrays of any
        other rank are returned as they are; in particular a 2-D (H, W)
        grayscale array must not be transposed with three axes, which would
        raise ``ValueError``.
        """
        if image.ndim == 3 and image.shape[-1] not in (1, 3):
            return image.transpose(1, 2, 0)
        return image

    def __getitem__(self, idx):
        """Return ``(sample, label)`` at ``idx`` with the transform applied."""
        image = self.data[idx]
        label = self.labels[idx]

        # Array samples are wrapped in a PIL image; anything else (for example
        # a path or an already-built image) is passed through unchanged.
        if isinstance(image, np.ndarray):
            image = self._to_channels_last(image)

            # Single-channel data is handed to PIL as a 2-D 'L' (grayscale) image.
            squeezed = image.squeeze()
            if squeezed.ndim == 2:
                image = Image.fromarray(squeezed, mode='L')
            else:
                image = Image.fromarray(image)

        if self.transform is not None:
            image = self.transform(image)

        return image, label

def get_tiny_imagenet(path, transform):
    """Load the Tiny-ImageNet train and validation splits into memory.

    Expects ``path`` to contain ``wnids.txt``, ``train/<wnid>/images/*`` and
    ``val/images/*`` together with ``val/val_annotations.txt`` (tab-separated,
    image name first and wnid second).

    Args:
        path: Root directory of the extracted dataset.
        transform: Unused; kept so existing call sites keep working.

    Returns:
        Tuple ``(train_data, train_labels, val_data, val_labels)`` of NumPy
        arrays. Images are decoded as RGB; labels are the zero-based position
        of the image's wnid in ``wnids.txt``.
    """
    # Map each wnid to an integer label; blank lines would otherwise become
    # a bogus class named ''.
    with open(os.path.join(path, 'wnids.txt'), 'r') as f:
        wnids = [line.strip() for line in f if line.strip()]
    wnid_to_label = {wnid: i for i, wnid in enumerate(wnids)}

    def load_rgb(img_path):
        # The context manager releases the file handle as soon as the pixels
        # have been copied into the array.
        with Image.open(img_path) as img:
            return np.array(img.convert('RGB'))

    # Training split. Directory listings are sorted so that sample order does
    # not depend on the filesystem.
    train_data, train_labels = [], []
    train_path = os.path.join(path, 'train')
    for wnid in sorted(os.listdir(train_path)):
        label = wnid_to_label.get(wnid)
        if label is None:
            continue
        class_path = os.path.join(train_path, wnid, 'images')
        for img_name in sorted(os.listdir(class_path)):
            train_data.append(load_rgb(os.path.join(class_path, img_name)))
            train_labels.append(label)

    # Validation split, labelled through the annotation file.
    val_data, val_labels = [], []
    val_images_path = os.path.join(path, 'val', 'images')
    with open(os.path.join(path, 'val', 'val_annotations.txt'), 'r') as f:
        for line in f:
            parts = line.strip().split('\t')
            # Skip blank or malformed lines instead of failing on parts[1].
            if len(parts) < 2:
                continue
            img_name, wnid = parts[0], parts[1]
            label = wnid_to_label.get(wnid)
            if label is None:
                continue
            val_data.append(load_rgb(os.path.join(val_images_path, img_name)))
            val_labels.append(label)

    return np.array(train_data), np.array(train_labels), np.array(val_data), np.array(val_labels)

def get_dataset(dataset_name, root='./data'):
    """
    Loads the specified dataset, downloading it when torchvision supports that.

    Args:
        dataset_name: One of 'FashionMNIST', 'CIFAR10', 'CIFAR100' or
            'Tiny-Imagenet'.
        root: Directory the datasets are stored in. Tiny-Imagenet is not
            downloaded and must already exist at ``<root>/tiny-imagenet-200``.

    Returns:
        Tuple ``(train_data, train_labels, test_data, test_labels, transform)``
        where the first four items are NumPy arrays and ``transform`` is the
        torchvision transform to apply to individual samples.

    Raises:
        RuntimeError: Tiny-Imagenet was requested but its directory is missing.
        ValueError: ``dataset_name`` is not one of the supported names.
    """
    def _normalized(mean, std, *leading):
        # Optional leading steps, then ToTensor and per-channel normalization.
        return transforms.Compose([*leading, transforms.ToTensor(), transforms.Normalize(mean, std)])

    if dataset_name == 'FashionMNIST':
        train_dataset = torchvision.datasets.FashionMNIST(root=root, train=True, download=True)
        test_dataset = torchvision.datasets.FashionMNIST(root=root, train=False, download=True)
        return (
            train_dataset.data.numpy(),
            train_dataset.targets.numpy(),
            test_dataset.data.numpy(),
            test_dataset.targets.numpy(),
            _normalized((0.5,), (0.5,)),
        )

    if dataset_name in ('CIFAR10', 'CIFAR100'):
        if dataset_name == 'CIFAR10':
            dataset_cls = torchvision.datasets.CIFAR10
            transform = _normalized((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        else:
            dataset_cls = torchvision.datasets.CIFAR100
            transform = _normalized((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
        train_dataset = dataset_cls(root=root, train=True, download=True)
        test_dataset = dataset_cls(root=root, train=False, download=True)
        # Both label sets are returned as arrays so every dataset branch hands
        # back the same types.
        return (
            train_dataset.data,
            np.array(train_dataset.targets),
            test_dataset.data,
            np.array(test_dataset.targets),
            transform,
        )

    if dataset_name == 'Tiny-Imagenet':
        path = os.path.join(root, 'tiny-imagenet-200')
        if not os.path.exists(path):
            # Report the location that was actually checked, which differs
            # from './data/...' whenever a custom root is passed.
            raise RuntimeError(f"Tiny-Imagenet not found. Please download and place it in '{path}'")
        transform_tiny = _normalized((0.485, 0.456, 0.406), (0.229, 0.224, 0.225), transforms.Resize(64))
        train_data, train_labels, test_data, test_labels = get_tiny_imagenet(path, transform_tiny)
        return train_data, train_labels, test_data, test_labels, transform_tiny

    raise ValueError(f"Unknown dataset: {dataset_name}")


def _split_class_among_clients(class_idx, n_clients, alpha, min_samples_per_class):
    """Split one class's (already shuffled) sample indices into one list per client."""
    shares = [[] for _ in range(n_clients)]
    guaranteed_count = n_clients * min_samples_per_class

    # Too few samples to honour the per-client minimum: deal them out round-robin.
    if len(class_idx) < guaranteed_count:
        for position, sample_idx in enumerate(class_idx):
            shares[position % n_clients].append(sample_idx)
        return shares

    # Every client first receives its guaranteed slice of this class.
    for client_id, share in enumerate(shares):
        start = client_id * min_samples_per_class
        share.extend(class_idx[start:start + min_samples_per_class])

    # The rest is divided according to a Dirichlet draw over clients.
    remaining = class_idx[guaranteed_count:]
    if len(remaining) > 0:
        proportions = np.random.dirichlet(np.repeat(alpha, n_clients))
        boundaries = np.cumsum(proportions)
        # Pin the final boundary so rounding can never leave samples unassigned.
        boundaries[-1] = 1.0
        cut_points = (boundaries * len(remaining)).astype(int)

        start = 0
        for share, end in zip(shares, cut_points):
            share.extend(remaining[start:end])
            start = end

    return shares


def generate_federated_incremental_dataloader(dataset_name, n_tasks, n_clients, alpha, batch_size, root='./data', min_samples_per_class=2):
    """Build per-task, per-client data loaders for federated class-incremental learning.

    The dataset's classes are randomly permuted and split into ``n_tasks``
    equally sized groups. Within a task, each class's training samples are
    divided among ``n_clients``: every client first gets
    ``min_samples_per_class`` samples of the class (when enough exist), and
    the remainder is split by proportions drawn from a symmetric Dirichlet
    distribution with concentration ``alpha``.

    Args:
        dataset_name: Dataset identifier forwarded to ``get_dataset``.
        n_tasks: Number of tasks; must divide the number of classes.
        n_clients: Number of clients per task.
        alpha: Dirichlet concentration parameter.
        batch_size: Batch size for every loader.
        root: Dataset root directory forwarded to ``get_dataset``.
        min_samples_per_class: Samples of each task class guaranteed to every
            client when the class has enough samples.

    Returns:
        A list with one dict per task. ``'train_loaders'`` holds one entry per
        client (a ``DataLoader``, or ``None`` for a client without samples)
        and ``'test_loader'`` is a ``DataLoader`` over the task's test samples.

    Raises:
        ValueError: On non-positive ``n_tasks``/``n_clients``, negative
            ``min_samples_per_class``, or a class count that is not divisible
            by ``n_tasks``.
    """
    # Validate before any dataset is downloaded or loaded.
    if n_tasks <= 0:
        raise ValueError("n_tasks must be a positive integer.")
    if n_clients <= 0:
        raise ValueError("n_clients must be a positive integer.")
    if min_samples_per_class < 0:
        raise ValueError("min_samples_per_class must be non-negative.")

    train_data, train_labels, test_data, test_labels, transform = get_dataset(dataset_name, root)

    train_labels = np.array(train_labels)
    test_labels = np.array(test_labels)

    class_ids = np.unique(train_labels)
    n_classes = len(class_ids)
    if n_classes % n_tasks != 0:
        raise ValueError("Number of classes must be divisible by the number of tasks.")

    classes_per_task = n_classes // n_tasks
    # Permute the label values actually present rather than assuming they are
    # exactly 0..n_classes-1.
    class_order = np.random.permutation(class_ids)

    tasks_data = []

    for task_id in range(n_tasks):
        # 1. Select classes for the current task
        start_class = task_id * classes_per_task
        task_classes = class_order[start_class:start_class + classes_per_task]

        print(f"Task {task_id+1}/{n_tasks}: classes {task_classes}")

        # 2. Filter dataset for the classes in the current task
        train_mask = np.isin(train_labels, task_classes)
        task_train_data = train_data[train_mask]
        task_train_labels = train_labels[train_mask]

        test_mask = np.isin(test_labels, task_classes)
        task_test_data = test_data[test_mask]
        task_test_labels = test_labels[test_mask]

        # 3. Distribute the task's training data among clients, class by class.
        #    Indices are relative to the task's subset of the data.
        client_indices = [[] for _ in range(n_clients)]
        required = n_clients * min_samples_per_class
        for cls in task_classes:
            class_idx = np.where(task_train_labels == cls)[0]
            num_class_samples = len(class_idx)
            np.random.shuffle(class_idx)

            if num_class_samples < required:
                print(f"Warning: Class {cls} has only {num_class_samples} samples, less than the required minimum "
                      f"({n_clients * min_samples_per_class}). Distributing samples one-by-one to clients.")

            shares = _split_class_among_clients(class_idx, n_clients, alpha, min_samples_per_class)
            for bucket, share in zip(client_indices, shares):
                bucket.extend(share)

        client_train_loaders = []
        for indices in client_indices:
            if len(indices) == 0:
                # A client without data gets a placeholder instead of a loader.
                client_train_loaders.append(None)
                continue

            dataset = CustomDataset(task_train_data[indices], task_train_labels[indices], transform=transform)

            # drop_last=True keeps a trailing single-sample batch away from
            # BatchNorm, but it would discard everything a client owns when it
            # holds fewer than batch_size samples. Such clients keep their one
            # partial batch as long as it contains at least two samples.
            n_samples = len(dataset)
            drop_last = n_samples >= batch_size or n_samples < 2
            loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=drop_last)
            client_train_loaders.append(loader)

        # 4. Create a single test loader for the current task
        test_dataset = CustomDataset(task_test_data, task_test_labels, transform=transform)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        tasks_data.append({
            'train_loaders': client_train_loaders,
            'test_loader': test_loader
        })

    return tasks_data


if __name__ == '__main__':
    # --- Example Usage ---
    # Builds the federated incremental loaders for one dataset and prints a
    # short summary of the first task.
    DATASET = 'CIFAR10'  # Choose from 'FashionMNIST', 'CIFAR10', 'CIFAR100', 'Tiny-Imagenet'
    N_TASKS = 5
    N_CLIENTS = 10
    ALPHA = 0.5  # Lower alpha means more heterogeneity
    BATCH_SIZE = 32
    MIN_SAMPLES = 2  # Minimum samples per class for each client

    print(f"Generating data for {DATASET}...")
    print(f" - Tasks: {N_TASKS}, Clients: {N_CLIENTS}, Alpha: {ALPHA}, Min Samples/Class: {MIN_SAMPLES}")

    # To run with Tiny-Imagenet, first download and extract it to './data/tiny-imagenet-200'
    # wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
    # unzip tiny-imagenet-200.zip -d ./data/

    task_dataloaders = generate_federated_incremental_dataloader(
        dataset_name=DATASET,
        n_tasks=N_TASKS,
        n_clients=N_CLIENTS,
        alpha=ALPHA,
        batch_size=BATCH_SIZE,
        min_samples_per_class=MIN_SAMPLES
    )

    print(f"\nGenerated {len(task_dataloaders)} tasks.")

    # Inspect the first task
    first_task = task_dataloaders[0]
    print("\n--- Task 1 Details ---")
    clients_with_data = sum(loader is not None for loader in first_task['train_loaders'])
    print(f"Number of clients with data: {clients_with_data}/{N_CLIENTS}")

    # Print number of samples per client in the first task
    for i, loader in enumerate(first_task['train_loaders']):
        # Compare against None explicitly: a DataLoader's truth value is its
        # batch count, so a loader that yields zero batches would otherwise be
        # reported as having no samples.
        if loader is not None:
            print(f"Client {i+1}: {len(loader.dataset)} samples")
        else:
            print(f"Client {i+1}: 0 samples")

    # Inspect the test loader for the first task
    test_loader = first_task['test_loader']
    print(f"Test set for Task 1 has {len(test_loader.dataset)} samples.")

    # Verify classes in the first batch of the test loader (skipped when the
    # test set is empty, where next() would raise StopIteration)
    first_batch = next(iter(test_loader), None)
    if first_batch is not None:
        first_batch_labels = first_batch[1]
        print(f"Labels in first test batch: {torch.unique(first_batch_labels)}")
