import os
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision import datasets


class DomainNetDataset(Dataset):
    """Map-style dataset over one DomainNet site (domain) and split.

    The sample list is read from ``<base_path>/<site>_train.pkl`` or
    ``<base_path>/<site>_test.pkl``, which must unpack into a pair
    ``(paths, text_labels)``. Text labels are mapped to integer class ids
    through ``LABEL_DICT``; a label missing from it raises ``KeyError``.

    Args:
        base_path: directory holding the ``.pkl`` index files and the images;
            ``None`` falls back to ``'../data'``.
        site: domain name used to build the index file name.
        train: load the train index when True, the test index otherwise.
        transform: optional callable applied to each PIL image.
    """

    # Fixed mapping from the text class names found in the index files to
    # integer targets.
    LABEL_DICT = {'bird':0, 'feather':1, 'headphones':2, 'ice_cream':3, 'teapot':4, 'tiger':5, 'whale':6, 'windmill':7, 'wine_glass':8, 'zebra':9}

    def __init__(self, base_path, site, train=True, transform=None):
        # Resolve the default BEFORE building any path: previously the
        # fallback was applied only after base_path had been concatenated
        # into the index file name, so base_path=None raised TypeError.
        self.base_path = base_path if base_path is not None else '../data'

        split = 'train' if train else 'test'
        index_file = os.path.join(self.base_path, '{}_{}.pkl'.format(site, split))
        # NOTE: allow_pickle=True unpickles arbitrary objects; only point this
        # at index files from a trusted source.
        self.paths, self.text_labels = np.load(index_file, allow_pickle=True)

        self.labels = [self.LABEL_DICT[text] for text in self.text_labels]
        self.transform = transform

    def __len__(self):
        """Return the number of samples listed in the index file."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is a PIL image, or whatever ``transform`` returns when a
        transform was given; the label is the integer class id.
        """
        # The first 10 characters of each stored path are dropped before
        # joining with base_path -- presumably a dataset-root prefix such as
        # 'DomainNet/'; TODO confirm against the .pkl contents.
        img_path = os.path.join(self.base_path, self.paths[idx][10:])
        label = self.labels[idx]
        image = Image.open(img_path)

        # Anything that is not 3-band is turned into a 3-channel grayscale
        # image so every sample has three channels. getbands() only reads the
        # band names, unlike split(), which copies every band.
        if len(image.getbands()) != 3:
            image = transforms.Grayscale(num_output_channels=3)(image)

        if self.transform is not None:
            image = self.transform(image)

        return image, label



def prepare_data_multi(data_base_path, batch, min_data_ratio=0.1, clients_per_domain=2):
    """
    Prepare DomainNet data for multi-client setup: split each domain into specified number of clients

    Every domain contributes the same number of training samples: the first
    ``int(min_len * min_data_ratio)`` samples of its train set, where
    ``min_len`` is the size of the smallest domain train set. Those samples
    are shuffled with a per-domain seed and cut into ``clients_per_domain``
    contiguous chunks; the last client also receives the remainder.
    The validation set of a domain is the last ``int(min_len * 0.25)``
    samples of its train set.

    NOTE: training indices come from the front of the train set and
    validation indices from its back, so for the smallest domain they overlap
    once ``min_data_ratio`` exceeds 0.75.

    NOTE: this function calls ``torch.manual_seed`` and therefore reseeds the
    global torch RNG as a side effect.

    Args:
        data_base_path: data path
        batch: batch size
        min_data_ratio: fraction of the smallest domain's train set used per domain, in [0, 1]
        clients_per_domain: number of clients per domain, default is 2

    Returns:
        ``(train_loaders, val_loaders, test_loaders)``: three lists with one
        entry per client, ordered domain by domain. Clients of the same domain
        share the same validation and test loader objects.

    Raises:
        ValueError: if ``clients_per_domain`` is smaller than 1 or
            ``min_data_ratio`` lies outside [0, 1].
    """
    if clients_per_domain < 1:
        raise ValueError(f"clients_per_domain must be >= 1, got {clients_per_domain}")
    if not 0 <= min_data_ratio <= 1:
        raise ValueError(f"min_data_ratio must be within [0, 1], got {min_data_ratio}")

    transform_train = transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
    ])

    transform_test = transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor(),
    ])

    # Create datasets for each domain
    domain_names = ['clipart', 'infograph', 'painting', 'quickdraw', 'real', 'sketch']
    domain_trainsets = [
        DomainNetDataset(data_base_path, domain, transform=transform_train)
        for domain in domain_names
    ]
    domain_testsets = [
        DomainNetDataset(data_base_path, domain, transform=transform_test, train=False)
        for domain in domain_names
    ]

    # Size everything by the smallest domain so all domains contribute equally
    smallest_len = min(len(trainset) for trainset in domain_trainsets)
    val_len = int(smallest_len * 0.25)
    min_data_len = int(smallest_len * min_data_ratio)
    data_per_client = min_data_len // clients_per_domain

    print(f"Train len per client: {min_data_len // clients_per_domain}")
    print(f"Clients per domain: {clients_per_domain}")
    print(f"Total clients: {len(domain_names) * clients_per_domain}")

    train_loaders = []
    val_loaders = []
    test_loaders = []

    # Create specified number of client data loaders for each domain
    for i, (trainset, testset) in enumerate(zip(domain_trainsets, domain_testsets)):
        domain_name = domain_names[i]

        # Per-domain seed keeps the client split reproducible
        # (this reseeds the global torch RNG).
        torch.manual_seed(42 + i)
        shuffled_indices = torch.randperm(min_data_len).tolist()

        # Validation set: the last val_len samples of the original dataset.
        # Explicit bounds are required here: slicing with [-val_len:] returns
        # the WHOLE list when val_len == 0.
        total_len = len(trainset)
        val_indices = list(range(total_len - val_len, total_len))
        valset = torch.utils.data.Subset(trainset, val_indices)

        val_loader = torch.utils.data.DataLoader(valset, batch_size=batch, shuffle=False)
        test_loader = torch.utils.data.DataLoader(testset, batch_size=batch, shuffle=False)

        # Create multiple clients for current domain
        domain_client_sizes = []
        for client_idx in range(clients_per_domain):
            start_idx = client_idx * data_per_client
            if client_idx == clients_per_domain - 1:  # Last client gets all remaining data
                end_idx = min_data_len
            else:
                end_idx = start_idx + data_per_client

            client_indices = shuffled_indices[start_idx:end_idx]
            client_trainset = torch.utils.data.Subset(trainset, client_indices)
            client_train_loader = torch.utils.data.DataLoader(client_trainset, batch_size=batch, shuffle=True)

            train_loaders.append(client_train_loader)
            val_loaders.append(val_loader)  # Validation set can be shared
            test_loaders.append(test_loader)  # Test set can be shared

            domain_client_sizes.append(len(client_trainset))

        print(f"Domain {domain_name}: {clients_per_domain} clients with sizes {domain_client_sizes}, Val={len(valset)}, Test={len(testset)}")

    return train_loaders, val_loaders, test_loaders
