from torchvision import transforms
from utils.sampling import dirichlet_noniid
import numpy as np
import os
import pickle
from pathlib import Path
from typing import Any, Callable, List, Optional, Tuple, Union
from PIL import Image
from datasets import load_dataset
from torch.utils.data import Dataset, Subset
from collections import Counter

class TinyImageNet(Dataset):
    """Tiny-ImageNet-200 split with every image decoded to RGB and kept in memory.

    Integer labels are the zero-based positions of the wnids listed in
    ``words200.txt`` (only lines that contain a tab-separated second field
    are counted).

    Args:
        transform: Optional callable applied to the stored PIL image each time
            a sample is fetched.
        is_train: ``True`` reads ``train/<wnid>/images/``; ``False`` reads
            ``val/images/`` using the labels in ``val/val_annotations.txt``.
        data_dir: Root directory of the dataset. ``None`` (the default) falls
            back to ``DEFAULT_DATA_DIR``.

    Attributes:
        img_files: Source path of each sample (``None`` for samples appended
            through ``add_data``).
        image: Decoded RGB ``PIL.Image`` objects, aligned with ``img_files``.
        targets: Integer labels, aligned with ``img_files``.
    """

    # Location used when the caller does not pass ``data_dir``.
    DEFAULT_DATA_DIR = r"/home/hadoop-mtaigc-live/dolphinfs_hdd_hadoop-mtaigc-live/wangyuqi/Dataset/tiny-imagenet-200/"

    def __init__(self, transform=None, is_train=True, data_dir=None):
        self.data_dir = self.DEFAULT_DATA_DIR if data_dir is None else str(data_dir)
        self.transform = transform

        wnid_to_label = self._read_wnid_labels(os.path.join(self.data_dir, 'words200.txt'))

        self.img_files = []
        self.image = []
        self.targets = []
        if is_train:
            self._load_train(wnid_to_label)
        else:
            self._load_val(wnid_to_label)

    @staticmethod
    def _read_wnid_labels(words_path):
        """Map each wnid in ``words_path`` to its zero-based position."""
        wnids = []
        with open(words_path, 'r') as f:
            for line in f:
                parts = line.split('\t')
                if len(parts) > 1:
                    wnids.append(parts[0].strip())
        return {wnid: i for i, wnid in enumerate(wnids)}

    @staticmethod
    def _load_rgb(path):
        """Decode the image at ``path`` and return it as an RGB PIL image."""
        with open(path, 'rb') as f:
            # convert() forces the pixel data to load before the file closes.
            return Image.open(f).convert('RGB')

    def _load_train(self, wnid_to_label):
        """Load every image under ``train/<wnid>/images/`` for each known wnid."""
        for wnid, label in wnid_to_label.items():
            images_dir = os.path.join(self.data_dir, 'train', str(wnid), 'images')
            # os.listdir returns entries in arbitrary order; sort so that sample
            # indices (and any index-based partition built on them) are stable
            # across runs and machines.
            for name in sorted(os.listdir(images_dir)):
                path = os.path.join(images_dir, name)
                self.img_files.append(path)
                self.targets.append(label)
                self.image.append(self._load_rgb(path))

    def _load_val(self, wnid_to_label):
        """Load the validation images listed in ``val/val_annotations.txt``."""
        annotations = os.path.join(self.data_dir, 'val', 'val_annotations.txt')
        with open(annotations, 'r') as f:
            for line in f:
                fields = line.strip().split('\t')
                if len(fields) < 2:
                    # Blank or malformed line: nothing to load.
                    continue
                img_file, wnid = fields[:2]
                path = os.path.join(self.data_dir, 'val', 'images', img_file)
                self.img_files.append(path)
                self.targets.append(wnid_to_label[wnid])
                self.image.append(self._load_rgb(path))

    def __getitem__(self, index):
        """Return ``(image, label)`` with ``transform`` applied when set."""
        img = self.image[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, self.targets[index]

    def __len__(self):
        """Return the number of samples (tracked through ``img_files``)."""
        return len(self.img_files)

    def add_data(self, img, label):
        """Append one in-memory sample; its ``img_files`` entry is ``None``."""
        self.image.append(img)
        self.targets.append(label)
        # Keep img_files aligned because __len__ is derived from it.
        self.img_files.append(None)


class CINIC10(Dataset):
    """CINIC-10 split read from a parquet shard and held in memory as uint8 RGB.

    The shard is expected at ``<root>/data/{train,test}-00000-of-00001.parquet``
    and is read through ``datasets.load_dataset`` using its ``image`` and
    ``label`` columns. Every image is resized to 32x32 and stored in
    ``self.data`` with shape ``(N, 32, 32, 3)``.

    Args:
        root: Directory that contains the ``data/`` folder with the shards.
        train: ``True`` loads the train shard, ``False`` the test shard.
        transform: Optional callable applied to the PIL image in
            ``__getitem__``.
    """

    def __init__(
        self,
        root: Union[str, Path],
        train: bool = True,
        transform: Optional[Callable] = None,
    ) -> None:
        self.root = Path(root)
        self.transform = transform
        self.train = train
        default_map = {
            "train": self.root / "data" / "train-00000-of-00001.parquet",
            "test": self.root / "data" / "test-00000-of-00001.parquet",
        }
        self.key = 'train' if self.train else 'test'
        self.parquet_path = Path(default_map[self.key])

        self._backend = "datasets"
        self._load_with_datasets()

        imgs = [self._to_rgb_array(self.data[i]) for i in range(len(self.targets))]
        if imgs:
            self.data = np.stack(imgs, axis=0)
        else:
            # np.stack rejects an empty list; keep the documented layout so
            # __len__ and add_data still work on an empty split.
            self.data = np.empty((0, 32, 32, 3), dtype=np.uint8)

        self.classes = [str(i) for i in range(10)]
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

    def _to_rgb_array(self, item: Any) -> np.ndarray:
        """Convert one raw image payload into a ``(32, 32, 3)`` uint8 array.

        Raises:
            ValueError: If the payload is unsupported or the decoded array
                does not end up with three channels.
        """
        img = self._ensure_pil(item)
        if img.mode not in ("RGB", "RGBA", "L"):
            # Palette, CMYK, LA, 1-bit, ... do not map onto the channel
            # handling below (e.g. a palette image would yield colour indices,
            # not intensities), so let PIL convert them to RGB first.
            img = img.convert("RGB")
        img = img.resize((32, 32), Image.BILINEAR)
        arr = np.asarray(img, dtype=np.uint8)
        if arr.ndim == 2:  # grayscale -> replicate into 3 channels
            arr = np.repeat(arr[..., None], 3, axis=2)
        elif arr.ndim == 3 and arr.shape[2] == 4:  # RGBA -> drop alpha
            arr = arr[..., :3]
        elif arr.ndim != 3 or arr.shape[2] != 3:
            raise ValueError(f"Unexpected shape {arr.shape}")
        return arr

    def _ensure_pil(self, item: Any) -> "Image.Image":
        """Return ``item`` as a PIL image.

        Accepted payloads: a PIL image (returned unchanged), a dict carrying
        ``"bytes"`` or an existing ``"path"``, raw bytes, a 2-D or
        ``(H, W, 1|3|4)`` numpy array, or a path string to an existing file.
        Everything except an already-PIL input is converted to RGB.

        Raises:
            ValueError: If none of the accepted payload forms matches.
        """
        from io import BytesIO

        if isinstance(item, Image.Image):
            return item
        if isinstance(item, dict):
            b = item.get("bytes", None)
            if b is not None:
                return Image.open(BytesIO(b)).convert("RGB")
            path = item.get("path", None)
            if path is not None and os.path.isfile(path):
                return Image.open(path).convert("RGB")
        if isinstance(item, (bytes, bytearray, memoryview)):
            return Image.open(BytesIO(item)).convert("RGB")
        if isinstance(item, np.ndarray):
            if item.ndim == 3 and item.shape[2] in (1, 3, 4):
                if item.shape[2] == 1:
                    return Image.fromarray(item[:, :, 0], mode="L").convert("RGB")
                if item.shape[2] == 4:
                    return Image.fromarray(item, mode="RGBA").convert("RGB")
                return Image.fromarray(item, mode="RGB")
            if item.ndim == 2:
                return Image.fromarray(item, mode="L").convert("RGB")
        if isinstance(item, str) and os.path.isfile(item):
            return Image.open(item).convert("RGB")
        raise ValueError("Unsupported image payload")

    def _load_with_datasets(self) -> None:
        """Read the shard and populate ``self.data`` (raw images) and ``self.targets``."""
        split_name = "train" if self.train else 'test'
        ds = load_dataset(
            "parquet",
            data_files={split_name: str(self.parquet_path)},
            split=split_name
        )
        images = ds["image"]
        labels = ds["label"]
        self.targets = list(map(int, labels))
        self.data = images

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` with ``transform`` applied when set."""
        img = self.data[index]
        if isinstance(img, np.ndarray):
            img = Image.fromarray(img)
        else:
            img = self._ensure_pil(img)
        target = self.targets[index]
        if self.transform is not None:
            img = self.transform(img)
        return img, target

    def __len__(self) -> int:
        """Return the number of samples (length of ``targets``)."""
        return len(self.targets)

    def add_data(self, image: "Image.Image", label: int):
        """Append one ``(32, 32, 3)`` image and its label.

        Raises:
            ValueError: If ``np.array(image)`` is not of shape ``(32, 32, 3)``.
        """
        arr = np.array(image)
        if arr.shape != (32, 32, 3):
            raise ValueError("Image must be of shape (32, 32, 3)")
        if isinstance(self.data, np.ndarray):
            self.data = np.vstack([self.data, arr[np.newaxis, ...]])
        else:
            self.data.append(arr)
        self.targets.append(int(label))




class CIFAR10(Dataset):
    """CIFAR-10 loaded from the pickled python batches found under ``root``.

    Batches are read from ``<root>/<base_folder>/``; pixel data ends up in
    ``self.data`` as an ``(N, 32, 32, 3)`` array and labels in
    ``self.targets``. Class names come from the ``meta`` file.

    Args:
        root: Directory containing ``base_folder``.
        train: ``True`` reads ``train_list``, ``False`` reads ``test_list``.
        transform: Optional callable applied to the PIL image in
            ``__getitem__``.
        download: Accepted for signature compatibility; not used here.
    """

    base_folder = "cifar-10-batches-py"
    filename = "cifar-10-python.tar.gz"

    # [file name, md5] pairs; only the file names are consumed by __init__.
    train_list = [
        ["data_batch_1", "c99cafc152244af753f735de768cd75f"],
        ["data_batch_2", "d4bba439e000b95fd0a9bffe97cbabec"],
        ["data_batch_3", "54ebc095f3ab1f0389bbae665268c751"],
        ["data_batch_4", "634d18415352ddfa80567beed471001a"],
        ["data_batch_5", "482c414d41f54cd18b22e5b47cb7c3cb"],
    ]
    test_list = [["test_batch", "40351d587109b95175f43aff81a1287e"]]

    meta = {
        "filename": "batches.meta",
        "key": "label_names",
        "md5": "5ff9c542aee3614f3951f8cda6e48888",
    }

    def __init__(
        self,
        root: Union[str, Path],
        train: bool = True,
        transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        self.transform = transform
        self.root = root
        self.train = train  # selects the training or the test batches

        batch_specs = self.train_list if self.train else self.test_list

        self.data: Any = []
        self.targets = []
        for batch_name, _md5 in batch_specs:
            batch_path = os.path.join(self.root, self.base_folder, batch_name)
            # NOTE: pickle can execute arbitrary code; only point ``root`` at
            # trusted dataset files.
            with open(batch_path, "rb") as fh:
                batch = pickle.load(fh, encoding="latin1")
            self.data.append(batch["data"])
            label_key = "labels" if "labels" in batch else "fine_labels"
            self.targets.extend(batch[label_key])

        # Flat rows -> (N, C, H, W) -> (N, H, W, C).
        channels_first = np.vstack(self.data).reshape(-1, 3, 32, 32)
        self.data = channels_first.transpose((0, 2, 3, 1))
        self._load_meta()

    def _load_meta(self) -> None:
        """Read class names from the meta file and build ``class_to_idx``."""
        meta_path = os.path.join(self.root, self.base_folder, self.meta["filename"])
        with open(meta_path, "rb") as fh:
            meta_blob = pickle.load(fh, encoding="latin1")
        self.classes = meta_blob[self.meta["key"]]
        self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, target)``; the image is a PIL image before ``transform``."""
        pixels = self.data[index]
        target = self.targets[index]
        img = Image.fromarray(pixels)
        if self.transform is None:
            return img, target
        return self.transform(img), target

    def __len__(self) -> int:
        """Return the number of stored images."""
        return len(self.data)

    def extra_repr(self) -> str:
        """Return a short description of which split is loaded."""
        if self.train is True:
            return "Split: Train"
        return "Split: Test"

    def add_data(self, image, label):
        """Append one ``(32, 32, 3)`` image and its label.

        Raises:
            ValueError: If ``np.array(image)`` is not of shape ``(32, 32, 3)``.
        """
        pixels = np.array(image)
        if pixels.shape != (32, 32, 3):
            raise ValueError("Image must be of shape (32, 32, 3)")
        self.data = np.concatenate((self.data, pixels[np.newaxis, ...]), axis=0)
        self.targets.append(label)


class CIFAR100(CIFAR10):
    """CIFAR-100 variant of the `CIFAR10` dataset class.

    See https://www.cs.toronto.edu/~kriz/cifar.html for the dataset itself.
    Only class-level constants are overridden here: the folder name, the
    per-split file lists, and the meta file description (whose key selects
    the fine label names). All loading logic is inherited from `CIFAR10`.
    """

    base_folder = "cifar-100-python"
    filename = "cifar-100-python.tar.gz"
    tgz_md5 = "eb9058c3a382ffc7106e4002c42a8d85"

    # One pickle file per split, stored as [file name, md5].
    train_list = [["train", "16019d7e3df5f24257cddd939b257f8d"]]
    test_list = [["test", "f0ef6b0ae62326f3e7ffdfab6717acfc"]]

    meta = dict(
        filename="meta",
        key="fine_label_names",
        md5="7973b15100ade9c7d40fb424638fde48",
    )


def print_label_distribution(dataset):
    """Print how many samples of each label a subset contains.

    Args:
        dataset: Object exposing ``indices`` and ``dataset.targets`` (the
            attributes a ``torch.utils.data.Subset`` of the datasets in this
            module provides). Labels are looked up as
            ``dataset.dataset.targets[idx]`` for every ``idx`` in ``indices``.

    Prints a header line followed by one line per label with its count and
    its share of the subset, in order of first appearance.
    """
    all_targets = dataset.dataset.targets
    histogram = Counter(all_targets[i] for i in dataset.indices)
    n_total = sum(histogram.values())

    print("Label Distribution for new_test_dataset:")
    for cls, n_cls in histogram.items():
        share = n_cls / n_total
        print(f"Label {cls}: {n_cls} samples, Proportion: {share:.4f}")



def remake_test_dataset(train_user_groups, train_dataset, test_dataset, name):
    """Sub-sample the test set so its label mix follows the users' training data.

    The labels of all training indices assigned to users are counted. For
    every label, ``count // divisor`` test samples of that label are drawn
    without replacement via ``np.random.choice`` (all available samples are
    used, with a warning, when the test set has fewer).

    Args:
        train_user_groups: Mapping whose values are iterables of indices into
            ``train_dataset``.
        train_dataset: Dataset exposing an indexable ``targets``.
        test_dataset: Dataset exposing ``targets``.
        name: Dataset name selecting the train-to-test divisor.

    Returns:
        A ``Subset`` of ``test_dataset`` restricted to the selected indices.

    Raises:
        ValueError: If ``name`` has no configured divisor.
    """
    # How much smaller the per-label test count is than the train count.
    divisors = {
        'cifar-10': 3,
        'cifar-100': 6,
        'tiny-imagenet': 6,
        # NOTE(review): 'cinic-10' previously had no entry and crashed with an
        # unbound ``adjusted_count``. The cifar-10 factor is reused here as a
        # placeholder — confirm the intended ratio.
        'cinic-10': 3,
    }
    if name not in divisors:
        raise ValueError(
            f"Unsupported dataset name {name!r}; expected one of {sorted(divisors)}"
        )
    divisor = divisors[name]

    train_labels = []
    for user in train_user_groups.values():
        train_labels.extend([train_dataset.targets[idx] for idx in user])

    label_distribution = Counter(train_labels)
    test_labels = np.array(test_dataset.targets)
    new_test_indices = []

    for label, count in label_distribution.items():
        adjusted_count = count // divisor
        label_indices = np.where(test_labels == label)[0]

        if len(label_indices) < adjusted_count:
            print(f"Warning: Not enough samples for label {label}. Selecting all available samples.")
            selected_indices = label_indices
        else:
            selected_indices = np.random.choice(label_indices, size=adjusted_count, replace=False)

        new_test_indices.extend(selected_indices)

    new_test_dataset = Subset(test_dataset, new_test_indices)
    print_label_distribution(new_test_dataset)
    return new_test_dataset


def get_dataset(args):
    """Build the datasets and the per-user training partition for ``args.dataset``.

    Supported names are ``'cifar-10'``, ``'cifar-100'``, ``'tiny-imagenet'``
    and ``'cinic-10'``. The training indices are split across users with
    ``dirichlet_noniid`` and the test set is re-sampled with
    ``remake_test_dataset``.

    Args:
        args: Namespace providing ``dataset``, ``num_users``, ``alpha`` and
            ``seed``; it is also forwarded as a whole to ``dirichlet_noniid``.

    Returns:
        ``(train_dataset, new_test_dataset, train_user_groups)`` where
        ``train_user_groups`` is the first value returned by
        ``dirichlet_noniid``.

    Raises:
        NotImplementedError: If ``args.dataset`` is not one of the supported
            names.
    """

    def _to_tensor_half_norm():
        # ToTensor followed by (x - 0.5) / 0.5 on every channel.
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    which = args.dataset

    if which == 'cifar-10':
        data_dir = '/home/hadoop-mtaigc-live/dolphinfs_hdd_hadoop-mtaigc-live/wangyuqi/Dataset/cifar'
        shared_tf = _to_tensor_half_norm()
        train_dataset = CIFAR10(data_dir, train=True, download=True, transform=shared_tf)
        test_dataset = CIFAR10(data_dir, train=False, download=True, transform=shared_tf)

    elif which == 'cifar-100':
        data_dir = '/home/hadoop-mtaigc-live/dolphinfs_hdd_hadoop-mtaigc-live/wangyuqi/Dataset/cifar100'
        augment_tf = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        eval_tf = _to_tensor_half_norm()
        train_dataset = CIFAR100(data_dir, train=True, transform=augment_tf)
        test_dataset = CIFAR100(data_dir, train=False, transform=eval_tf)

    elif which == 'tiny-imagenet':
        # No resize/crop is applied. The earlier variants (left disabled) used
        # RandomResizedCrop/Resize to 64 for resnet50 and to 224 for vgg.
        augment_tf = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(15),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
        ])
        eval_tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.507, 0.487, 0.441], std=[0.267, 0.256, 0.276])
        ])
        train_dataset = TinyImageNet(transform=augment_tf, is_train=True)
        test_dataset = TinyImageNet(transform=eval_tf, is_train=False)

    elif which == 'cinic-10':
        data_dir = "/home/hadoop-mtaigc-live/dolphinfs_hdd_hadoop-mtaigc-live/wangyuqi/Dataset/cinic10"
        shared_tf = _to_tensor_half_norm()
        train_dataset = CINIC10(data_dir, train=True, transform=shared_tf)
        test_dataset = CINIC10(data_dir, train=False, transform=shared_tf)

    else:
        raise NotImplementedError

    # The partitioner also returns test groups; only the train groups are used.
    train_user_groups, _test_user_groups = dirichlet_noniid(
        train_dataset, test_dataset, args.num_users, args.alpha, args.seed, args
    )
    new_test_dataset = remake_test_dataset(
        train_user_groups, train_dataset, test_dataset, args.dataset
    )
    return train_dataset, new_test_dataset, train_user_groups



