import torch
import numpy as np
import torchvision
from PIL import Image
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision import datasets
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split


class PACS(Dataset):
    """One PACS domain exposed as a map-style dataset of ``(image, label)``.

    The files and labels are discovered by a torchvision ``ImageFolder``
    rooted at ``<base_path>/<site>``; labels are the class indices that
    ``ImageFolder`` assigns to the class sub-directories.

    Args:
        base_path: directory that contains one sub-directory per domain.
        site: name of the domain sub-directory to load.
        transform: optional callable applied to every decoded image. When
            ``None`` the decoded image is returned unchanged.
    """

    def __init__(self, base_path, site, transform=None):
        self.dir_path = f"{base_path}/{site}"
        # ImageFolder is used for file discovery, label assignment and for
        # its image loader; items are served by __getitem__ below.
        self.dataset = ImageFolder(self.dir_path, transform=transform)
        self.transform = transform

    def __len__(self):
        """Return the number of images found for this domain."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        Raises:
            IndexError: if ``idx`` is outside the range of known samples.
        """
        img_path, label = self.dataset.imgs[idx]
        # Decode through the ImageFolder's loader rather than a bare
        # Image.open: torchvision's default PIL loader closes the file
        # handle and converts to RGB, so every sample has the same mode.
        image = self.dataset.loader(img_path)
        # transform defaults to None in __init__, so it must be optional here.
        if self.transform is not None:
            image = self.transform(image)
        return image, label


def prepare_PACS_multi(data_base_path, batch, train_data_ratio=0.2, clients_per_domain=2):
    """Build per-client train loaders and matching test loaders for PACS.

    Each of the four PACS domains is split into a train part and a test part
    by ``split_train_test_dataset``. Every domain then contributes the same
    number of training samples, ``int(min_train_len * train_data_ratio)``
    where ``min_train_len`` is the size of the smallest train part, and that
    budget is divided among ``clients_per_domain`` clients. The last client
    of a domain receives the remainder of the integer division.

    Args:
        data_base_path: directory containing one sub-directory per domain.
        batch: batch size used for every returned loader.
        train_data_ratio: fraction of the smallest train part used per
            domain; must be in ``(0, 1]``.
        clients_per_domain: number of clients per domain; must be >= 1.

    Returns:
        ``(train_loaders, test_loaders)``: two lists with one entry per
        client, ordered domain by domain. ``test_loaders[k]`` iterates the
        test part of the domain that ``train_loaders[k]`` was drawn from.

    Raises:
        ValueError: if ``clients_per_domain`` or ``train_data_ratio`` is out
            of range, or if the sample budget is too small to give every
            client at least one sample.
    """
    if clients_per_domain < 1:
        raise ValueError(f"clients_per_domain must be >= 1, got {clients_per_domain}")
    if not 0 < train_data_ratio <= 1:
        raise ValueError(f"train_data_ratio must be in (0, 1], got {train_data_ratio}")

    trans = transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
    ])

    # Prepare the train/test parts of every domain
    domain_names = ["art_painting", "cartoon", "sketch", "photo"]
    domain_trains = []
    domain_testsets = []
    for domain in domain_names:
        train_data, test_data = split_train_test_dataset(PACS(data_base_path, domain, trans))
        domain_trains.append(train_data)
        domain_testsets.append(test_data)

    # All domains contribute the same number of samples, bounded by the
    # smallest domain so that no index can fall outside any train part.
    min_data_len = min(len(train_data) for train_data in domain_trains)
    train_len = int(min_data_len * train_data_ratio)
    data_per_client = train_len // clients_per_domain
    if data_per_client == 0:
        raise ValueError(
            f"Only {train_len} training samples per domain for "
            f"{clients_per_domain} clients; every client needs at least one"
        )

    print(f"Train len per client: {data_per_client}")
    print(f"Clients per domain: {clients_per_domain}")
    print(f"Total clients: {len(domain_names) * clients_per_domain}")

    train_loaders = []
    test_loaders = []

    for i, (domain, train_data, test_data) in enumerate(zip(domain_names, domain_trains, domain_testsets)):
        # A private generator keeps the partition reproducible per domain
        # without reseeding the global RNG, which would silently override a
        # seed chosen by the caller.
        generator = torch.Generator().manual_seed(42 + i)
        shuffled_indices = torch.randperm(train_len, generator=generator).tolist()

        domain_client_sizes = []
        for client_idx in range(clients_per_domain):
            start_idx = client_idx * data_per_client
            if client_idx == clients_per_domain - 1:
                # Last client also takes the remainder of the division
                end_idx = train_len
            else:
                end_idx = start_idx + data_per_client

            client_trainset = torch.utils.data.Subset(train_data, shuffled_indices[start_idx:end_idx])

            train_loaders.append(
                torch.utils.data.DataLoader(client_trainset, batch_size=batch, shuffle=True)
            )
            # Clients of one domain are all evaluated on that domain's test part
            test_loaders.append(
                torch.utils.data.DataLoader(test_data, batch_size=batch, shuffle=False)
            )
            domain_client_sizes.append(len(client_trainset))

        print(f"Domain {domain}: {clients_per_domain} clients with sizes {domain_client_sizes}, Test={len(test_data)}")

    return train_loaders, test_loaders



def split_train_test_dataset(dataset, test_ratio=0.2, seed=None):
    """Randomly split ``dataset`` into a train part and a test part.

    Args:
        dataset: object supporting ``len()`` and integer indexing.
        test_ratio: fraction of the samples assigned to the test part; the
            resulting count is truncated toward zero. Defaults to 0.2.
        seed: optional integer. When given, the split is drawn from a
            private ``torch.Generator`` seeded with it, so the same seed
            yields the same split and the global RNG is not consumed. When
            ``None`` the global torch RNG is used.

    Returns:
        ``(trainset, testset)`` as produced by
        ``torch.utils.data.random_split``.

    Raises:
        ValueError: if ``test_ratio`` is outside ``[0, 1]``.
    """
    if not 0 <= test_ratio <= 1:
        raise ValueError(f"test_ratio must be in [0, 1], got {test_ratio}")

    data_len = len(dataset)
    test_len = int(data_len * test_ratio)
    train_len = data_len - test_len

    # Only pass a generator when a seed was requested so that the default
    # path keeps using random_split's own default generator.
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)

    trainset, testset = torch.utils.data.random_split(
        dataset, [train_len, test_len], **split_kwargs
    )
    return trainset, testset


class VLCS(Dataset):
    """One VLCS domain exposed as a map-style dataset of ``(image, label)``.

    The files and labels are discovered by a torchvision ``ImageFolder``
    rooted at ``<base_path>/<site>``. Images that cannot be decoded are
    skipped: the next decodable sample is returned in their place.

    Args:
        base_path: directory that contains one sub-directory per domain.
        site: name of the domain sub-directory to load.
        transform: optional callable applied to every decoded image. When
            ``None`` the decoded image is returned unchanged.
    """

    def __init__(self, base_path, site, transform=None):
        self.dir_path = f"{base_path}/{site}"
        # ImageFolder is used for file discovery, label assignment and for
        # its image loader; items are served by __getitem__ below.
        self.dataset = ImageFolder(self.dir_path, transform=transform)
        self.transform = transform

    def __len__(self):
        """Return the number of images found for this domain."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, skipping unreadable files.

        If the image at ``idx`` raises ``OSError`` while being decoded, the
        following samples are tried in order, wrapping around at the end,
        and the first one that decodes is returned with its own label.

        Raises:
            IndexError: if ``idx`` is outside the range of known samples.
            RuntimeError: if no image in the domain can be decoded.
        """
        samples = self.dataset.imgs
        # Plain indexing first so that an out-of-range idx still raises
        # IndexError instead of being wrapped by the fallback below.
        img_path, label = samples[idx]

        # Bounded retry: at most one attempt per sample, so a folder made
        # entirely of broken files ends in an error rather than a loop.
        for _ in range(len(samples)):
            try:
                # torchvision's default PIL loader closes the file handle
                # and converts to RGB, so no channel fix-up is needed.
                image = self.dataset.loader(img_path)
            except OSError as e:
                print(f"Cannot load image {img_path}: {e}")
                # Wrap around so a broken last sample does not run off the end
                idx = (idx + 1) % len(samples)
                img_path, label = samples[idx]
            else:
                if self.transform is not None:
                    image = self.transform(image)
                return image, label

        raise RuntimeError(f"No loadable image found under {self.dir_path}")
        



def prepare_VLCS_multi(data_base_path, batch, train_data_ratio=0.2, clients_per_domain=2):
    """Build per-client train loaders and matching test loaders for VLCS.

    Each of the four VLCS domains is split into a train part and a test part
    by ``split_train_test_dataset``. Every domain then contributes the same
    number of training samples, ``int(min_train_len * train_data_ratio)``
    where ``min_train_len`` is the size of the smallest train part, and that
    budget is divided among ``clients_per_domain`` clients. The last client
    of a domain receives the remainder of the integer division.

    Training samples are augmented with a random horizontal flip; test
    samples go through the same resize/crop/normalise steps without the
    flip so that evaluation is deterministic.

    Args:
        data_base_path: directory containing one sub-directory per domain.
        batch: batch size used for every returned loader.
        train_data_ratio: fraction of the smallest train part used per
            domain; must be in ``(0, 1]``.
        clients_per_domain: number of clients per domain; must be >= 1.

    Returns:
        ``(train_loaders, test_loaders)``: two lists with one entry per
        client, ordered domain by domain. ``test_loaders[k]`` iterates the
        test part of the domain that ``train_loaders[k]`` was drawn from.

    Raises:
        ValueError: if ``clients_per_domain`` or ``train_data_ratio`` is out
            of range, or if the sample budget is too small to give every
            client at least one sample.
    """
    if clients_per_domain < 1:
        raise ValueError(f"clients_per_domain must be >= 1, got {clients_per_domain}")
    if not 0 < train_data_ratio <= 1:
        raise ValueError(f"train_data_ratio must be in (0, 1], got {train_data_ratio}")

    mean = [0.5, 0.5, 0.5]
    std = [0.5, 0.5, 0.5]

    transform_train = transforms.Compose([
        transforms.Resize([256, 256]),
        transforms.RandomHorizontalFlip(),
        transforms.CenterCrop([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    # Same pipeline minus the random flip: evaluation must not be augmented
    transform_test = transforms.Compose([
        transforms.Resize([256, 256]),
        transforms.CenterCrop([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # Prepare the train/test parts of every domain
    domain_names = ["Caltech101", "LabelMe", "SUN09", "VOC2007"]
    domain_trains = []
    domain_testsets = []
    for domain in domain_names:
        train_data, test_data = split_train_test_dataset(VLCS(data_base_path, domain, transform_train))
        # Re-express the test split over a second view of the same folder
        # that uses the deterministic transform. ImageFolder lists files in
        # sorted order, so the split's indices address the same files.
        eval_view = VLCS(data_base_path, domain, transform_test)
        domain_trains.append(train_data)
        domain_testsets.append(torch.utils.data.Subset(eval_view, test_data.indices))

    # All domains contribute the same number of samples, bounded by the
    # smallest domain so that no index can fall outside any train part.
    min_data_len = min(len(train_data) for train_data in domain_trains)
    train_len = int(min_data_len * train_data_ratio)
    data_per_client = train_len // clients_per_domain
    if data_per_client == 0:
        raise ValueError(
            f"Only {train_len} training samples per domain for "
            f"{clients_per_domain} clients; every client needs at least one"
        )

    print(f"Train len per client: {data_per_client}")
    print(f"Clients per domain: {clients_per_domain}")
    print(f"Total clients: {len(domain_names) * clients_per_domain}")

    train_loaders = []
    test_loaders = []

    for i, (domain, train_data, test_data) in enumerate(zip(domain_names, domain_trains, domain_testsets)):
        # A private generator keeps the partition reproducible per domain
        # without reseeding the global RNG, which would silently override a
        # seed chosen by the caller.
        generator = torch.Generator().manual_seed(42 + i)
        shuffled_indices = torch.randperm(train_len, generator=generator).tolist()

        domain_client_sizes = []
        for client_idx in range(clients_per_domain):
            start_idx = client_idx * data_per_client
            if client_idx == clients_per_domain - 1:
                # Last client also takes the remainder of the division
                end_idx = train_len
            else:
                end_idx = start_idx + data_per_client

            client_trainset = torch.utils.data.Subset(train_data, shuffled_indices[start_idx:end_idx])

            train_loaders.append(
                torch.utils.data.DataLoader(client_trainset, batch_size=batch, shuffle=True)
            )
            # Clients of one domain are all evaluated on that domain's test part
            test_loaders.append(
                torch.utils.data.DataLoader(test_data, batch_size=batch, shuffle=False)
            )
            domain_client_sizes.append(len(client_trainset))

        print(f"Domain {domain}: {clients_per_domain} clients with sizes {domain_client_sizes}, Test={len(test_data)}")

    return train_loaders, test_loaders
