import torch
import numpy as np
import torchvision
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision import datasets
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from utils.PACS_utils import split_train_test_dataset


class OfficeHome(Dataset):
    """Image dataset for a single OfficeHome domain stored in ImageFolder layout.

    The images are looked up under ``<base_path>/<site>`` with
    ``torchvision.datasets.ImageFolder``; the label returned for a sample is
    the class index that ``ImageFolder`` assigned to its directory.

    Args:
        base_path: directory containing one sub-directory per domain.
        site: name of the domain sub-directory to load.
        transform: optional callable applied to every loaded PIL image.
            When ``None`` the RGB PIL image is returned unchanged.
    """

    def __init__(self, base_path, site, transform=None):
        self.dir_path = base_path + "/{}".format(site)
        # ImageFolder is used for file discovery and label assignment; the
        # images themselves are loaded in __getitem__ from ``.imgs``.
        self.dataset = ImageFolder(self.dir_path, transform=transform)
        self.transform = transform

    def __len__(self):
        """Return the number of samples in this domain."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``."""
        img_path, label = self.dataset.imgs[idx]

        # Close the file handle deterministically and force three channels so
        # that grayscale / RGBA / CMYK files do not yield tensors whose channel
        # count differs from the rest of the batch.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # ``transform`` defaults to None; calling it unconditionally would
        # raise ``TypeError: 'NoneType' object is not callable``.
        if self.transform is not None:
            image = self.transform(image)

        return image, label

def prepare_OH_multi(data_base_path, batch, train_data_ratio=0.2, clients_per_domain=2):
    """
    Prepare OfficeHome data for multi-client setup: split each domain into specified number of clients

    Every domain is loaded, split into train/test parts with
    ``split_train_test_dataset``, and then ``train_len`` training samples
    (``train_data_ratio`` times the size of the smallest domain's train split)
    are partitioned into ``clients_per_domain`` disjoint client subsets. The
    last client of a domain also receives the remainder of the integer
    division.

    Args:
        data_base_path: data path (directory containing one folder per domain)
        batch: batch size used for every train and test loader
        train_data_ratio: training data ratio, must satisfy 0 < ratio <= 1
        clients_per_domain: number of clients per domain, default is 2

    Returns:
        A tuple ``(train_loaders, test_loaders)``. Both lists contain
        ``4 * clients_per_domain`` DataLoaders ordered domain by domain
        (Art, Clipart, Product, RealWorld) and client by client within a
        domain. All clients of a domain get a test loader over the same
        domain test split.

    Raises:
        ValueError: if ``clients_per_domain`` is smaller than 1 or
            ``train_data_ratio`` is outside ``(0, 1]``.
    """
    # Validate up front: 0 clients would otherwise end in ZeroDivisionError,
    # and a ratio above 1 would create Subset indices past the end of the
    # smallest domain's train split.
    if clients_per_domain < 1:
        raise ValueError(
            f"clients_per_domain must be >= 1, got {clients_per_domain}"
        )
    if not 0 < train_data_ratio <= 1:
        raise ValueError(
            f"train_data_ratio must be in (0, 1], got {train_data_ratio}"
        )

    trans = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
    ])

    # Prepare train/test splits for all domains
    domain_names = ["Art", "Clipart", "Product", "RealWorld"]
    domain_trains = []
    domain_testsets = []

    for domain in domain_names:
        dataset = OfficeHome(data_base_path, domain, trans)
        train_data, test_data = split_train_test_dataset(dataset)
        domain_trains.append(train_data)
        domain_testsets.append(test_data)

    # Every domain contributes the same number of training samples, derived
    # from the smallest domain so that the index range is valid everywhere.
    min_data_len = min(len(train_data) for train_data in domain_trains)
    train_len = int(min_data_len * train_data_ratio)

    # Data amount per client (the last client additionally gets the remainder)
    data_per_client = train_len // clients_per_domain

    print(f"Train len per client: {data_per_client}")
    print(f"Clients per domain: {clients_per_domain}")
    print(f"Total clients: {len(domain_names) * clients_per_domain}")

    train_loaders = []
    test_loaders = []

    # Create specified number of client data loaders for each domain
    for i, (domain, train_data, test_data) in enumerate(zip(domain_names, domain_trains, domain_testsets)):
        # Use a local generator seeded per domain: the permutation stays
        # reproducible, but the process-wide RNG (and therefore any seed the
        # caller has set for model init / loader shuffling) is left untouched.
        # NOTE(review): only positions [0, train_len) of the train split are
        # sampled; this presumably relies on split_train_test_dataset returning
        # a shuffled split - confirm against the helper.
        generator = torch.Generator()
        generator.manual_seed(42 + i)
        shuffled_indices = torch.randperm(train_len, generator=generator).tolist()

        # Create multiple clients for current domain
        domain_client_sizes = []
        for client_idx in range(clients_per_domain):
            start_idx = client_idx * data_per_client
            if client_idx == clients_per_domain - 1:  # Last client gets all remaining data
                end_idx = train_len
            else:
                end_idx = start_idx + data_per_client

            client_indices = shuffled_indices[start_idx:end_idx]
            client_trainset = torch.utils.data.Subset(train_data, client_indices)

            # Create data loaders
            client_train_loader = torch.utils.data.DataLoader(client_trainset, batch_size=batch, shuffle=True)
            test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch, shuffle=False)

            # Add to list
            train_loaders.append(client_train_loader)
            test_loaders.append(test_loader)  # Test set can be shared

            domain_client_sizes.append(len(client_trainset))

        print(f"Domain {domain}: {clients_per_domain} clients with sizes {domain_client_sizes}, Test={len(test_data)}")

    return train_loaders, test_loaders


# # Split dataset
# def id_not_target(dataset):
#     classes = {}
#     for index, x in enumerate(dataset):
#         _, label = x
#         if label in dataset:
#             dataset[label].append(index)
#         else:
#             dataset[label] = [index]

#     range_no_id = list(range(0, len(test_dataset)))
#     for image_ind in test_classes[5]:  # 5 target label
#         if image_ind in range_no_id:
#             range_no_id.remove(image_ind)
#     # poison_label_inds = test_classes[0]  
#     # print(range_no_id)
#     return range_no_id

