import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset, Dataset
from collections import defaultdict
import pickle
import os
import requests
import zipfile
from PIL import Image


class TinyImageNetDataset(Dataset):
    """Map-style dataset over an extracted ``tiny-imagenet-200`` directory.

    Samples are kept as image file paths and decoded lazily in
    ``__getitem__``. Class indices follow the line order of ``wnids.txt``.
    File listings are sorted so that sample indices are the same on every
    filesystem (``os.listdir`` returns entries in arbitrary order).

    Args:
        root: Directory that contains (or will receive) ``tiny-imagenet-200``.
        train: If true, index ``train/<wnid>/images``; otherwise index
            ``val/images`` labelled through ``val/val_annotations.txt``.
        transform: Optional callable applied to the decoded RGB PIL image.
        download: If true, fetch and extract the archive when the dataset
            directory is not present under ``root``.

    Raises:
        FileNotFoundError: If the dataset directory or its index files are
            missing after the optional download step.
    """

    def __init__(self, root, train=True, transform=None, download=False):
        self.root = root
        self.train = train
        self.transform = transform

        if download:
            self.download()

        # Parallel lists: data[i] is an image path, targets[i] its class index.
        self.data, self.targets = self._load_data()

    def download(self):
        """Download and extract the TinyImageNet archive into ``self.root``.

        Does nothing when ``<root>/tiny-imagenet-200`` already exists. The
        zip file is removed afterwards, whether or not the download and
        extraction succeeded, so a truncated archive is never left behind.

        Raises:
            requests.HTTPError: If the server answers with an error status.
            requests.RequestException: On connection failures or timeouts.
        """
        url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
        filename = "tiny-imagenet-200.zip"

        os.makedirs(self.root, exist_ok=True)
        filepath = os.path.join(self.root, filename)

        # Check if already downloaded
        # NOTE(review): a partially extracted directory also satisfies this
        # check; delete it manually if an earlier extraction was interrupted.
        if os.path.exists(os.path.join(self.root, "tiny-imagenet-200")):
            print("TinyImageNet dataset already exists.")
            return

        print("Downloading TinyImageNet dataset...")
        try:
            # (connect, read) timeouts keep a stalled server from hanging the
            # run forever; the context manager releases the connection.
            with requests.get(url, stream=True, timeout=(10, 60)) as response:
                # Fail here rather than writing an HTML error page to disk and
                # tripping over it later as a corrupt zip file.
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))

                with open(filepath, 'wb') as f:
                    downloaded = 0
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            f.write(chunk)
                            downloaded += len(chunk)
                            if total_size > 0:
                                percent = (downloaded / total_size) * 100
                                print(f"\rDownload progress: {percent:.1f}%", end="")

            print("\nExtracting dataset...")
            with zipfile.ZipFile(filepath, 'r') as zip_ref:
                zip_ref.extractall(self.root)
        finally:
            # Clean up zip file (also on failure, so a retry starts fresh)
            if os.path.exists(filepath):
                os.remove(filepath)

        print("Download and extraction complete!")

    def _load_data(self):
        """Index the image files of the selected split.

        Returns:
            A ``(data, targets)`` pair of equally long lists: image paths and
            their integer class indices. Train samples are grouped by class in
            ``wnids.txt`` order and sorted by file name within a class;
            validation samples are sorted by file name.

        Raises:
            FileNotFoundError: If the dataset directory or an index file is
                missing.
            KeyError: If ``val_annotations.txt`` names a class that is not
                listed in ``wnids.txt``.
        """
        dataset_path = os.path.join(self.root, "tiny-imagenet-200")
        if not os.path.isdir(dataset_path):
            raise FileNotFoundError(
                f"TinyImageNet directory not found: {dataset_path} "
                "(pass download=True to fetch it)"
            )

        # Load class names and create mapping; blank lines are ignored so a
        # trailing newline cannot introduce an empty-string class.
        with open(os.path.join(dataset_path, "wnids.txt"), 'r') as f:
            class_names = [line.strip() for line in f if line.strip()]

        class_to_idx = {cls_name: idx for idx, cls_name in enumerate(class_names)}

        data = []
        targets = []

        if self.train:
            # Load training data
            train_path = os.path.join(dataset_path, "train")
            for class_name in class_names:
                class_path = os.path.join(train_path, class_name, "images")
                if not os.path.isdir(class_path):
                    continue
                # Sorted: os.listdir order is arbitrary, and index-based
                # partitions built on this dataset must be reproducible.
                for img_file in sorted(os.listdir(class_path)):
                    if img_file.endswith('.JPEG'):
                        data.append(os.path.join(class_path, img_file))
                        targets.append(class_to_idx[class_name])
        else:
            # Load validation data
            val_path = os.path.join(dataset_path, "val")

            # Read validation annotations: tab-separated, file name first,
            # class id second; any further columns are not used here.
            val_annotations = {}
            with open(os.path.join(val_path, "val_annotations.txt"), 'r') as f:
                for line in f:
                    parts = line.strip().split('\t')
                    if len(parts) < 2:
                        # Blank or malformed line: nothing to map.
                        continue
                    val_annotations[parts[0]] = class_to_idx[parts[1]]

            # Load validation images (only those that have an annotation)
            val_images_path = os.path.join(val_path, "images")
            for img_file in sorted(os.listdir(val_images_path)):
                if img_file.endswith('.JPEG') and img_file in val_annotations:
                    data.append(os.path.join(val_images_path, img_file))
                    targets.append(val_annotations[img_file])

        return data, targets

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        The image is decoded as RGB and passed through ``self.transform``
        when one was supplied.
        """
        img_path = self.data[idx]
        target = self.targets[idx]

        # convert() yields an independent in-memory image, so the underlying
        # file can be closed right away instead of leaking one handle per item.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        return image, target


def _sample_task_indices(indices_by_class, class_ids, total):
    """Pick up to ``total`` dataset indices spread evenly over ``class_ids``.

    Each class contributes ``total // len(class_ids)`` samples and the first
    ``total % len(class_ids)`` classes contribute one more. A class that has
    fewer samples than its quota contributes everything it has, so the result
    may be shorter than ``total``.

    Draws from the global NumPy RNG and shuffles the arrays stored in
    ``indices_by_class`` in place.
    """
    per_class, remainder = divmod(total, len(class_ids))

    picked = []
    for position, class_id in enumerate(class_ids):
        quota = per_class + (1 if position < remainder else 0)
        class_indices = indices_by_class[class_id]
        # In-place on purpose: a class shared by several clients is reshuffled
        # starting from its previous order, which is what makes the seeded
        # partition come out the same as it always has.
        np.random.shuffle(class_indices)
        picked.extend(class_indices[:quota])

    # The quotas already add up to `total`; the slice is only a safety net.
    return picked[:total]


def setup_tinyimagenet_loaders(opt):
    """Create non-IID TinyImageNet data loaders with class sharing between clients.

    Every client receives ``opt.num_task`` tasks of ``opt.class_per_task``
    classes each. Classes are distinct within a client but may be shared
    between clients. Each task gets a train and a test ``DataLoader`` whose
    samples are distributed across the task's classes.

    Side effects: seeds the global NumPy RNG with ``opt.seed``, downloads the
    dataset into ``opt.data_dir`` if it is missing, and writes the class
    assignment to ``./dump/tinyimagenet_partitioning_seed<seed>.pkl``.

    Args:
        opt: Options object providing ``num_clients``, ``num_task``,
            ``class_per_task``, ``batch_size``, ``data_dir``, ``seed``,
            ``num_workers`` and ``pin_memory``.

    Returns:
        A mapping ``client_id -> task_id -> {'train': DataLoader,
        'test': DataLoader}``.

    Raises:
        ValueError: If ``opt.class_per_task`` is smaller than 1, or if a
            client would need more than 200 distinct classes (raised by
            ``np.random.choice``).
    """
    client_num = opt.num_clients
    tasks_per_client = opt.num_task
    classes_per_task = opt.class_per_task
    batch_size = opt.batch_size
    data_dir = opt.data_dir

    if classes_per_task < 1:
        # Previously surfaced as a bare ZeroDivisionError deep in the sampling.
        raise ValueError(
            f"class_per_task must be at least 1, got {classes_per_task}"
        )

    num_classes = 200  # 200 classes in TinyImageNet
    # Per-task sample budgets. They are defined before the loops so the
    # summary below still works when there are no clients or no tasks.
    target_samples = 800  # e.g. 100 per class when a task has 8 classes
    # Upper bound only: a class contributes at most the validation images it
    # has (presumably 50 per class in the standard release — TODO confirm), so
    # tasks usually end up with fewer test samples than this.
    test_target_samples = 1000

    os.makedirs('./dump', exist_ok=True)

    # 1. Load TinyImageNet dataset
    imagenet_norm = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.ToTensor(),
        transforms.Normalize(**imagenet_norm),  # ImageNet normalization
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(**imagenet_norm),
    ])

    train_set = TinyImageNetDataset(data_dir, train=True, download=True, transform=transform_train)
    test_set = TinyImageNetDataset(data_dir, train=False, download=True, transform=transform_test)

    # 2. Create client-task-class assignments with class sharing
    np.random.seed(opt.seed)
    all_classes = np.arange(num_classes)

    client_task_classes = {}
    for client_id in range(client_num):
        # Sample without replacement within a client, so no class repeats
        # across that client's tasks.
        selected = np.random.choice(all_classes,
                                    size=tasks_per_client * classes_per_task,
                                    replace=False)

        # Consecutive slices of the draw become the client's tasks.
        client_task_classes[client_id] = [
            selected[i * classes_per_task:(i + 1) * classes_per_task]
            for i in range(tasks_per_client)
        ]

    # 3. Create data loaders
    client_loaders = defaultdict(dict)

    # Get indices for each class. The target lists are converted to arrays
    # once instead of once per class.
    train_targets = np.asarray(train_set.targets)
    test_targets = np.asarray(test_set.targets)
    y_ind_dict = {y: np.where(train_targets == y)[0] for y in range(num_classes)}
    y_test_ind_dict = {y: np.where(test_targets == y)[0] for y in range(num_classes)}

    for client_id, task_list in client_task_classes.items():
        for task_id, class_ids in enumerate(task_list):
            # Train is sampled before test so the RNG is consumed in the same
            # order as before and a given seed yields the same partition.
            train_indices = _sample_task_indices(y_ind_dict, class_ids, target_samples)
            test_indices = _sample_task_indices(y_test_ind_dict, class_ids, test_target_samples)

            client_loaders[client_id][task_id] = {
                'train': DataLoader(
                    Subset(train_set, train_indices),
                    batch_size=batch_size,
                    shuffle=True,
                    num_workers=opt.num_workers,
                    pin_memory=opt.pin_memory,
                ),
                'test': DataLoader(
                    Subset(test_set, test_indices),
                    batch_size=batch_size,
                    shuffle=False,
                    num_workers=opt.num_workers,
                    pin_memory=opt.pin_memory,
                )
            }

    # 4. Save partitioning
    partitioning = {
        'client_task_classes': client_task_classes,
        'num_clients': client_num,
        'tasks_per_client': tasks_per_client,
        'classes_per_task': classes_per_task,
        'samples_per_task': target_samples,
        'seed': opt.seed,
        'dataset': 'TinyImageNet'
    }

    with open(f'./dump/tinyimagenet_partitioning_seed{opt.seed}.pkl', 'wb') as f:
        pickle.dump(partitioning, f)

    print("Created TinyImageNet dataloaders:")
    print(f"  - {client_num} clients")
    print(f"  - {tasks_per_client} tasks per client")
    print(f"  - {classes_per_task} classes per task")
    print(f"  - {target_samples} training samples per task")
    print(f"  - {test_target_samples} test samples per task")

    return client_loaders


def read_pickle(name):
    """Return the object stored in the pickle file at path ``name``."""
    with open(name, "rb") as handle:
        return pickle.load(handle)

def write_pickle(data, name):
    """Pickle ``data`` into the file at path ``name``, replacing any existing file."""
    with open(name, "wb") as handle:
        pickle.dump(data, handle)