import os
import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision import datasets


class OfficeDataset(Dataset):
    """Office-Caltech-10 images for a single domain ("site").

    The image paths and text labels come from ``{base_path}/{site}_train.pkl``
    (``train=True``) or ``{base_path}/{site}_test.pkl`` (``train=False``).
    Each image path is joined onto the parent directory of ``base_path``.
    Text labels are mapped to integer class ids via ``LABEL_DICT``.
    """

    # Text label -> integer class id for the 10 classes used by this dataset.
    LABEL_DICT = {'back_pack': 0, 'bike': 1, 'calculator': 2, 'headphones': 3, 'keyboard': 4,
                  'laptop_computer': 5, 'monitor': 6, 'mouse': 7, 'mug': 8, 'projector': 9}

    def __init__(self, base_path, site, train=True, transform=None):
        """
        Args:
            base_path: directory holding the ``{site}_{train|test}.pkl`` index files
            site: domain name used in the index file name (e.g. 'amazon')
            train: load the train index if True, otherwise the test index
            transform: optional callable applied to each PIL image in __getitem__

        Raises:
            KeyError: if the index file contains a label not in ``LABEL_DICT``.
        """
        split = 'train' if train else 'test'
        # SECURITY: allow_pickle=True unpickles the index file, which can execute
        # arbitrary code -- only point base_path at trusted data.
        self.paths, self.text_labels = np.load(f'{base_path}/{site}_{split}.pkl', allow_pickle=True)

        self.labels = [self.LABEL_DICT[text] for text in self.text_labels]
        self.transform = transform
        # Paths from the index file are resolved against the parent of base_path.
        self.base_path = os.path.dirname(base_path)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``; image is transformed if a transform is set."""
        img_path = os.path.join(self.base_path, self.paths[idx])
        label = self.labels[idx]

        # copy() forces the pixel data into memory, so the image stays valid after
        # the context manager closes the underlying file handle.
        with Image.open(img_path) as raw:
            image = raw.copy()

        # getbands() reports the channel count without split()'s per-band copies.
        # Anything that is not 3-band (e.g. L, P, RGBA) becomes 3-channel grayscale.
        if len(image.getbands()) != 3:
            image = transforms.Grayscale(num_output_channels=3)(image)

        if self.transform is not None:
            image = self.transform(image)

        return image, label


def prepare_data_multi(data_base_path, batchsize, train_ratio=0.5, clients_per_domain=2):
    """
    Prepare Office-Caltech data for a multi-client setup by splitting each domain
    into ``clients_per_domain`` clients.

    Every domain uses the same number of training samples:
    ``int(min_train_len * train_ratio)``, where ``min_train_len`` is the size of
    the smallest domain's training set. The first that-many samples of each
    domain are shuffled with a fixed per-domain seed and dealt out in contiguous
    chunks. The last client of a domain also receives any remainder.

    Args:
        data_base_path: data path (directory with the ``{domain}_{train|test}.pkl`` files)
        batchsize: batch size for every DataLoader
        train_ratio: fraction of the smallest domain's training set used for each domain
        clients_per_domain: number of clients per domain, default is 2

    Returns:
        (train_loaders, test_loaders): two lists of length ``4 * clients_per_domain``,
        ordered domain by domain (amazon, caltech, dslr, webcam). Clients of the same
        domain share one DataLoader over that domain's full test set.

    Raises:
        ValueError: if ``clients_per_domain`` < 1, or if the per-domain training size
            is smaller than ``clients_per_domain`` (some client would get no data).
    """
    if clients_per_domain < 1:
        raise ValueError(f"clients_per_domain must be >= 1, got {clients_per_domain}")

    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]
    transform_office = transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.RandomHorizontalFlip(),
            transforms.CenterCrop([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
    ])

    transform_test = transforms.Compose([
            transforms.Resize([256, 256]),
            transforms.CenterCrop([224, 224]),
            transforms.ToTensor(),
            transforms.Normalize(mean=mean, std=std)
    ])

    # Create datasets for each domain
    domain_names = ['amazon', 'caltech', 'dslr', 'webcam']
    domain_trainsets = []
    domain_testsets = []

    for domain in domain_names:
        domain_trainsets.append(OfficeDataset(data_base_path, domain, transform=transform_office))
        domain_testsets.append(OfficeDataset(data_base_path, domain, transform=transform_test, train=False))

    # Every domain contributes the same number of training samples, derived from
    # the smallest domain, so clients are balanced across domains.
    min_data_len = int(min(len(trainset) for trainset in domain_trainsets) * train_ratio)
    if min_data_len < clients_per_domain:
        raise ValueError(
            f"Per-domain training size {min_data_len} is smaller than "
            f"clients_per_domain={clients_per_domain}; some clients would get no data"
        )
    data_per_client = min_data_len // clients_per_domain

    print(f"Train len per client: {data_per_client}")
    print(f"Clients per domain: {clients_per_domain}")
    print(f"Total clients: {len(domain_names) * clients_per_domain}")

    train_loaders = []
    test_loaders = []

    # Create specified number of client data loaders for each domain
    for i, (domain_name, trainset, testset) in enumerate(zip(domain_names, domain_trainsets, domain_testsets)):
        # NOTE: this reseeds the *global* torch RNG, for reproducibility. After
        # this function returns, later random ops (e.g. the shuffling of these
        # train loaders) start from seed 42 + len(domain_names) - 1.
        torch.manual_seed(42 + i)
        shuffled_indices = torch.randperm(min_data_len).tolist()

        # The test set is not split; all clients of this domain share one loader.
        test_loader = torch.utils.data.DataLoader(testset, batch_size=batchsize, shuffle=False, pin_memory=True)

        domain_client_sizes = []
        for client_idx in range(clients_per_domain):
            start_idx = client_idx * data_per_client
            # Last client gets all remaining data
            if client_idx == clients_per_domain - 1:
                end_idx = min_data_len
            else:
                end_idx = start_idx + data_per_client

            client_trainset = torch.utils.data.Subset(trainset, shuffled_indices[start_idx:end_idx])
            train_loaders.append(
                torch.utils.data.DataLoader(client_trainset, batch_size=batchsize, shuffle=True, pin_memory=True)
            )
            test_loaders.append(test_loader)
            domain_client_sizes.append(len(client_trainset))

        print(f"Domain {domain_name}: {clients_per_domain} clients with sizes {domain_client_sizes}, Test={len(testset)}")

    return train_loaders, test_loaders


