'''
Binary classification tasks built from pairs of ImageNet classes.

- Each task is a binary classification problem between two randomly chosen
  ImageNet classes.
- Classes within a task are balanced (equal number of samples per class).
- Standard ImageNet transforms and normalization are applied.
- Class selection is reproducible: the random seed is derived from task_id.
- The number of samples per class is configurable.
- The multi-class ImageNet labels are converted to binary (0/1) labels.
- The test set is smaller than the training set: it uses 1/5 as many samples
  per class, a common choice for a held-out split.
'''

import torch
from torchvision import datasets, transforms
from torch.utils.data import Dataset, Subset
import random
import numpy as np

class ContinualImageNetDataset(Dataset):
    """Balanced two-class subset of a labelled dataset with binary labels.

    A fixed number of samples is drawn for each of the two classes in
    ``class_pair``. Items are exposed as ``(image, binary_label)`` where the
    binary label is 0 for ``class_pair[0]`` and 1 otherwise. All samples of
    the first class come first, followed by all samples of the second class.
    """

    def __init__(self, base_dataset, class_pair, samples_per_class):
        """
        Args:
            base_dataset: The base ImageNet dataset (any indexable dataset
                yielding ``(image, label)`` pairs works)
            class_pair: Tuple of two class indices to use
            samples_per_class: Number of samples to use per class

        Raises:
            ValueError: If either class has fewer than ``samples_per_class``
                samples in ``base_dataset``.
        """
        self.base_dataset = base_dataset
        self.class_pair = class_pair
        self.samples_per_class = samples_per_class

        first_label, second_label = class_pair[0], class_pair[1]

        # Collect the positions of both classes in a single pass.
        class1_indices = []
        class2_indices = []
        for position, label in enumerate(self._iter_labels(base_dataset)):
            if label == first_label:
                class1_indices.append(position)
            if label == second_label:
                class2_indices.append(position)

        # A private generator seeded with 42 draws exactly the same samples as
        # seeding the global ``random`` module would, but leaves the
        # process-wide RNG state untouched for the rest of the program.
        rng = random.Random(42)
        self.class1_indices = self._draw(
            rng, class1_indices, samples_per_class, first_label
        )
        self.class2_indices = self._draw(
            rng, class2_indices, samples_per_class, second_label
        )

        # Combine indices
        self.selected_indices = self.class1_indices + self.class2_indices

    @staticmethod
    def _iter_labels(base_dataset):
        """Yield the label of every sample of ``base_dataset`` in index order.

        Uses the dataset's ``targets`` list when it is present, covers the
        whole dataset and no ``target_transform`` is set, so no image has to
        be loaded. Otherwise falls back to iterating the dataset itself.
        """
        targets = getattr(base_dataset, "targets", None)
        has_target_transform = (
            getattr(base_dataset, "target_transform", None) is not None
        )
        if targets is not None and not has_target_transform:
            try:
                aligned = len(targets) == len(base_dataset)
            except TypeError:
                aligned = False
            if aligned:
                return iter(targets)
        return (label for _, label in base_dataset)

    @staticmethod
    def _draw(rng, indices, count, label):
        """Sample ``count`` distinct entries of ``indices`` using ``rng``."""
        if count > len(indices):
            raise ValueError(
                f"Class {label} has only {len(indices)} samples, "
                f"but {count} were requested"
            )
        return rng.sample(indices, count)

    def __len__(self):
        """Return the total number of selected samples (both classes)."""
        return len(self.selected_indices)

    def __getitem__(self, idx):
        """Return ``(image, binary_label)`` for the ``idx``-th selected sample."""
        img, label = self.base_dataset[self.selected_indices[idx]]
        # Convert to binary classification (0 for first class, 1 for second class)
        binary_label = 0 if label == self.class_pair[0] else 1
        return img, binary_label

def get_imagenet_transforms():
    """Build the ImageNet preprocessing pipelines.

    Returns:
        A ``(train_transform, test_transform)`` tuple. The training pipeline
        applies a random resized crop to 224 and a random horizontal flip;
        the test pipeline resizes to 256 and center-crops to 224. Both end
        with tensor conversion and per-channel normalization.
    """

    def channel_normalize():
        # A fresh instance per pipeline, so the two pipelines share no objects.
        return transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])

    augmenting_steps = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        channel_normalize(),
    ]
    deterministic_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        channel_normalize(),
    ]

    return (
        transforms.Compose(augmenting_steps),
        transforms.Compose(deterministic_steps),
    )

def get_task_dataset(task_id, imagenet_path, samples_per_class=600):
    """
    Creates train and test datasets for a specific task

    The two classes of the task are chosen pseudo-randomly from the 1000
    ImageNet classes with a seed equal to ``task_id``, so the same
    ``task_id`` always yields the same class pair.

    Args:
        task_id: Integer identifying the task
        imagenet_path: Path to ImageNet dataset
        samples_per_class: Number of samples per class (default 600)

    Returns:
        train_dataset, test_dataset for the specified task. The test dataset
        requests ``samples_per_class // 5`` samples per class. If the
        validation split holds fewer images than that for either class, the
        request is reduced to the number available and a ``UserWarning`` is
        emitted.
    """
    import warnings

    # A task-local generator selects the same class pair as seeding the
    # global ``random`` module with task_id, without touching global state.
    rng = random.Random(task_id)

    # Get transforms
    train_transform, test_transform = get_imagenet_transforms()

    # Create base datasets
    train_dataset = datasets.ImageNet(
        root=imagenet_path,
        split='train',
        transform=train_transform
    )

    test_dataset = datasets.ImageNet(
        root=imagenet_path,
        split='val',
        transform=test_transform
    )

    # Select two random classes for this task
    all_classes = list(range(1000))  # ImageNet has 1000 classes
    class_pair = tuple(rng.sample(all_classes, 2))

    # Typically validation set is smaller
    test_samples_per_class = samples_per_class // 5

    # The ImageNet validation split is small per class (50 images in
    # ILSVRC2012), so the default request of 600 // 5 = 120 cannot be met.
    # Cap the request at what both classes can supply.
    val_targets = getattr(test_dataset, "targets", None)
    if val_targets is not None:
        available = min(
            sum(1 for target in val_targets if target == class_index)
            for class_index in class_pair
        )
        if 0 < available < test_samples_per_class:
            warnings.warn(
                f"Validation split has only {available} images per class for "
                f"classes {class_pair}; using {available} test samples per "
                f"class instead of {test_samples_per_class}",
                stacklevel=2,
            )
            test_samples_per_class = available

    # Create train and test datasets for these classes
    train_task_dataset = ContinualImageNetDataset(
        train_dataset,
        class_pair,
        samples_per_class
    )

    test_task_dataset = ContinualImageNetDataset(
        test_dataset,
        class_pair,
        test_samples_per_class
    )

    return train_task_dataset, test_task_dataset