from libs import *


class TrainTinyImageNet(data.Dataset):
    """Training split of Tiny ImageNet read directly from disk.

    Images are discovered with the pattern ``<root>train/*/*/*.JPEG``; the
    name of the directory two levels above each image is looked up in the
    ``id`` mapping to obtain the integer label.

    Args:
        root: Path prefix of the dataset. It is concatenated as-is with
            ``"train/..."``, so it must end with a path separator.
        id: Mapping from class directory name to integer label.
        transform: Optional callable applied to the image before it is
            returned.
    """

    def __init__(self, root, id, transform=None) -> None:
        super().__init__()
        # glob returns files in a filesystem-dependent order; sorting keeps
        # the index -> image mapping stable across machines and runs.
        self.filenames = sorted(glob.glob(root + "train/*/*/*.JPEG"))
        self.transform = transform
        self.id_dict = id

    def __len__(self):
        """Return the number of images found under ``root``."""
        return len(self.filenames)

    def _label_of(self, img_path):
        """Return the integer label for the image stored at ``img_path``."""
        # .../train/<class>/<subdir>/<file>.JPEG -> <class>, independent of
        # which path separator the platform uses.
        class_dir = os.path.basename(os.path.dirname(os.path.dirname(img_path)))
        return self.id_dict[class_dir]

    def __getitem__(self, idx: Any) -> Any:
        """Return ``(image, label)`` for the image at position ``idx``.

        Raises:
            KeyError: If the class directory is not a key of the id mapping.
        """
        img_path = self.filenames[idx]
        image = Image.open(img_path)
        # Anything that is not already RGB (greyscale, palette, CMYK, ...)
        # is converted so downstream 3-channel transforms always apply.
        if image.mode != "RGB":
            image = image.convert("RGB")
        label = self._label_of(img_path)
        if self.transform:
            image = self.transform(image)
        return image, label


class ValTinyImageNet(data.Dataset):
    """Validation split of Tiny ImageNet read directly from disk.

    Images are discovered with ``<root>val/images/*.JPEG``. Labels come from
    the tab-separated file ``<root>val/val_annotations.txt`` whose first two
    columns are the image file name and the class name; the class name is
    translated to an integer through the ``id`` mapping.

    Args:
        root: Path prefix of the dataset. It is concatenated as-is with
            ``"val/..."``, so it must end with a path separator.
        id: Mapping from class name to integer label.
        transform: Optional callable applied to the image before it is
            returned.

    Raises:
        KeyError: If an annotation names a class missing from ``id``.
    """

    def __init__(self, root, id, transform=None):
        super().__init__()
        # Sorted so the index -> image mapping does not depend on the
        # filesystem's directory listing order.
        self.filenames = sorted(glob.glob(root + "val/images/*.JPEG"))
        self.transform = transform
        self.id_dict = id
        # image file name -> integer label
        self.cls_dic = {}
        with open(root + "val/val_annotations.txt", "r", encoding="utf-8") as annotations:
            for line in annotations:
                line = line.rstrip("\r\n")
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                fields = line.split("\t")
                img, cls_id = fields[0], fields[1]
                self.cls_dic[img] = self.id_dict[cls_id]

    def __len__(self):
        """Return the number of validation images found under ``root``."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the image at position ``idx``.

        Raises:
            KeyError: If the image has no entry in the annotation file.
        """
        img_path = self.filenames[idx]
        image = Image.open(img_path)
        # Anything that is not already RGB (greyscale, palette, CMYK, ...)
        # is converted so downstream 3-channel transforms always apply.
        if image.mode != "RGB":
            image = image.convert("RGB")
        label = self.cls_dic[os.path.basename(img_path)]
        if self.transform:
            image = self.transform(image)
        return image, label


class DatasetObject:
    """Build, cache and load a dataset partitioned across clients.

    On construction the object looks for the directory
    ``<data_path>Data/<name>``. When it exists, the four cached ``.npy``
    files in it are loaded. Otherwise the raw dataset is loaded, the
    training set is shuffled and split across ``n_client`` clients according
    to ``rule``, and the result is written to that directory.

    After construction the following attributes are available:
    ``client_x``, ``client_y`` (one entry per client), ``test_x``,
    ``test_y``, and, for the known dataset names, ``channels``, ``height``,
    ``width`` and ``n_cls``.

    Args:
        dataset: One of ``"mnist"``, ``"CIFAR10"``, ``"CIFAR100"`` or
            ``"tinyimagenet"``.
        n_client: Number of clients to split the training data across.
        seed: Integer that becomes part of the cache name.
            NOTE(review): seeding of the random generators is disabled in
            this file, so the seed does not currently make the split
            reproducible - confirm whether that is intended.
        rule: ``"Drichlet"`` (sic - the spelling is part of the interface
            and of the cache name) or ``"iid"``.
        unbalanced_sgm: Sigma of the log-normal distribution the client
            sizes are drawn from; ``0`` gives equally sized clients.
        rule_arg: Value repeated ``n_cls`` times as the Dirichlet ``alpha``
            when ``rule == "Drichlet"``.
        data_path: Prefix concatenated directly in front of ``"Data/..."``.

    Raises:
        ValueError: If data has to be prepared and ``dataset`` or ``rule``
            is not one of the supported values.
    """

    # dataset name -> (channels, height, width, n_cls)
    _DATASET_SPECS = {
        "mnist": (1, 28, 28, 10),
        "CIFAR10": (3, 32, 32, 10),
        "CIFAR100": (3, 32, 32, 100),
        "tinyimagenet": (3, 64, 64, 200),
    }
    # Partitioning rules implemented by the preparation path.
    _RULES = ("Drichlet", "iid")

    def __init__(
        self, dataset, n_client, seed, rule, unbalanced_sgm=0, rule_arg="", data_path=""
    ):
        self.dataset = dataset
        self.n_client = n_client
        self.rule = rule
        self.rule_arg = rule_arg
        self.seed = seed
        rule_arg_str = rule_arg if isinstance(rule_arg, str) else "%.3f" % rule_arg
        # The name doubles as the cache directory name.
        self.name = "%s_%d_%d_%s_%s" % (
            self.dataset,
            self.n_client,
            self.seed,
            self.rule,
            rule_arg_str,
        )
        if unbalanced_sgm != 0:
            self.name += "_%f" % unbalanced_sgm
        self.unbalanced_sgm = unbalanced_sgm
        self.data_path = data_path
        self.set_data()

    def set_data(self):
        """Load the cached split if present, otherwise build and cache it."""
        cache_dir = "%sData/%s" % (self.data_path, self.name)
        if os.path.exists(cache_dir):
            self._load_cached(cache_dir)
        else:
            self._prepare_and_cache(cache_dir)

    def _apply_dataset_spec(self):
        """Set channels/height/width/n_cls for a known dataset name."""
        spec = self._DATASET_SPECS.get(self.dataset)
        if spec is not None:
            self.channels, self.height, self.width, self.n_cls = spec

    def _load_cached(self, cache_dir):
        """Load client and test arrays previously saved in ``cache_dir``."""
        print("Data is already downloaded")
        # NOTE: allow_pickle=True is needed for ragged (object) client
        # arrays; only load cache directories this code produced itself.
        self.client_x = np.load(cache_dir + "/client_x.npy", allow_pickle=True)
        self.client_y = np.load(cache_dir + "/client_y.npy", allow_pickle=True)
        self.n_client = len(self.client_x)

        self.test_x = np.load(cache_dir + "/test_x.npy", allow_pickle=True)
        self.test_y = np.load(cache_dir + "/test_y.npy", allow_pickle=True)

        self._apply_dataset_spec()

    def _prepare_and_cache(self, cache_dir):
        """Load raw data, split it across clients and save it to ``cache_dir``."""
        # Fail early with a clear message instead of an unbound-variable
        # error after the raw data has already been downloaded.
        if self.dataset not in self._DATASET_SPECS:
            raise ValueError(
                "Unknown dataset %r; expected one of %s"
                % (self.dataset, sorted(self._DATASET_SPECS))
            )
        if self.rule not in self._RULES:
            raise ValueError(
                "Unknown rule %r; expected one of %s" % (self.rule, list(self._RULES))
            )

        train_load, test_load = self._raw_loaders()
        self._apply_dataset_spec()

        # Each loader uses a batch size covering the whole split, so a
        # single batch is the complete data.
        train_itr = iter(train_load)
        test_itr = iter(test_load)
        train_x, train_y = next(train_itr)
        test_x, test_y = next(test_itr)

        # Shuffle the training data (labels become column vectors).
        if self.dataset == "tinyimagenet":
            train_y = train_y.reshape(-1, 1)
            test_y = test_y.reshape(-1, 1)
            rand_perm = torch.randperm(len(train_y))
        else:
            train_x = np.array(train_x)
            train_y = np.array(train_y).reshape(-1, 1)
            test_x = np.array(test_x)
            test_y = np.array(test_y).reshape(-1, 1)
            rand_perm = np.random.permutation(len(train_y))
        train_x = train_x[rand_perm]
        train_y = train_y[rand_perm]

        client_data_list = self._draw_client_sizes(len(train_y))

        if self.rule == "Drichlet":
            client_x, client_y = self._partition_dirichlet(
                train_x, train_y, client_data_list
            )
        else:
            client_x, client_y = self._partition_iid(
                train_x, train_y, client_data_list
            )

        self.client_x = client_x
        self.client_y = client_y
        self.test_x = test_x
        self.test_y = test_y

        # Save data
        os.makedirs(cache_dir, exist_ok=True)
        np.save(cache_dir + "/client_x.npy", client_x)
        np.save(cache_dir + "/client_y.npy", client_y)
        np.save(cache_dir + "/test_x.npy", test_x)
        np.save(cache_dir + "/test_y.npy", test_y)

    def _torchvision_loaders(self, dataset_cls, transform, train_batch_size):
        """Return full-batch (train, test) loaders for a torchvision dataset."""
        raw_root = "%sData/Raw" % self.data_path
        trainset = dataset_cls(
            root=raw_root, train=True, download=True, transform=transform
        )
        testset = dataset_cls(
            root=raw_root, train=False, download=True, transform=transform
        )
        train_load = torch.utils.data.DataLoader(
            trainset, batch_size=train_batch_size, shuffle=False, num_workers=1
        )
        test_load = torch.utils.data.DataLoader(
            testset, batch_size=10000, shuffle=False, num_workers=1
        )
        return train_load, test_load

    def _tinyimagenet_loaders(self):
        """Return (train, test) loaders for a locally stored Tiny ImageNet."""
        # Tiny ImageNet is too large to be fetched automatically; it has to
        # be downloaded manually into the directory below.
        # NOTE(review): unlike the other datasets this path puts a "/"
        # between data_path and "Data" - confirm which form callers rely on.
        root_dir = self.data_path + "/Data/Raw/tiny-imagenet-200/"
        id_dic = {}
        with open(root_dir + "wnids.txt", "r") as wnids:
            for i, line in enumerate(wnids):
                id_dic[line.replace("\n", "")] = i
        data_transform = {
            "train": transforms.Compose(
                [
                    transforms.RandomHorizontalFlip(),
                    transforms.RandomCrop(64, padding=4),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
                    ),
                ]
            ),
            "val": transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(
                        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
                    ),
                ]
            ),
        }
        train_dataset = TrainTinyImageNet(
            root_dir, id=id_dic, transform=data_transform["train"]
        )
        test_dataset = ValTinyImageNet(
            root_dir, id=id_dic, transform=data_transform["val"]
        )

        train_load = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=100000,
            shuffle=False,
            pin_memory=True,
            num_workers=4,
        )
        test_load = torch.utils.data.DataLoader(
            test_dataset,
            batch_size=10000,
            shuffle=False,
            pin_memory=True,
            num_workers=1,
        )

        print("trainset batchsize is : ", len(train_dataset))
        print("testset batchsize is : ", len(test_dataset))
        return train_load, test_load

    def _raw_loaders(self):
        """Return the (train, test) loaders for ``self.dataset``."""
        if self.dataset == "mnist":
            transform = transforms.Compose(
                [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
            )
            return self._torchvision_loaders(
                torchvision.datasets.MNIST, transform, 60000
            )

        if self.dataset == "CIFAR10":
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                    ),
                ]
            )
            return self._torchvision_loaders(
                torchvision.datasets.CIFAR10, transform, 50000
            )

        if self.dataset == "CIFAR100":
            # NOTE(review): the random flip is applied once while the data
            # is materialised, and to the test split as well - confirm this
            # is intended.
            transform = transforms.Compose(
                [
                    transforms.RandomHorizontalFlip(),
                    transforms.ToTensor(),
                    transforms.Normalize(
                        (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
                    ),
                ]
            )
            return self._torchvision_loaders(
                torchvision.datasets.CIFAR100, transform, 50000
            )

        return self._tinyimagenet_loaders()

    def _draw_client_sizes(self, n_samples):
        """Draw the number of training samples assigned to each client."""
        n_data_per_client = int(n_samples / self.n_client)
        # Draw from a log-normal distribution centred (in log space) on the
        # balanced per-client amount.
        client_data_list = np.random.lognormal(
            mean=np.log(n_data_per_client),
            sigma=self.unbalanced_sgm,
            size=self.n_client,
        )
        # Rescale so the sizes sum to (at most) the number of samples.
        client_data_list = (
            client_data_list / np.sum(client_data_list) * n_samples
        ).astype(int)
        diff = int((np.sum(client_data_list) - n_samples) / self.n_client)

        # Add/Subtract the excess number starting from first client
        if diff != 0:
            for client_i in range(self.n_client):
                if client_data_list[client_i] > diff:
                    client_data_list[client_i] -= diff
        return client_data_list

    @staticmethod
    def _stack_client_arrays(arrays):
        """Combine per-client arrays into a single NumPy array.

        Equal shapes are stacked into one regular array. Different shapes
        (unbalanced clients) are stored in a 1-D object array, because
        recent NumPy versions refuse to build ragged arrays implicitly.
        """
        arrays = [np.asarray(a) for a in arrays]
        if len({a.shape for a in arrays}) <= 1:
            return np.asarray(arrays)
        ragged = np.empty(len(arrays), dtype=object)
        for i, a in enumerate(arrays):
            ragged[i] = a
        return ragged

    def _partition_dirichlet(self, train_x, train_y, client_data_list):
        """Split the data so each client follows its own Dirichlet class prior.

        ``client_data_list`` is consumed: every entry is counted down to
        zero while samples are handed out.
        """
        cls_priors = np.random.dirichlet(
            alpha=[self.rule_arg] * self.n_cls, size=self.n_client
        )
        os.makedirs("results", exist_ok=True)
        np.save(
            "results/heterogeneity_distribution_{:s}.npy".format(self.dataset),
            cls_priors,
        )
        prior_cumsum = np.cumsum(cls_priors, axis=1)
        # Row indices of the samples of each class, and how many of them
        # are still unassigned.
        idx_list = [np.where(train_y == i)[0] for i in range(self.n_cls)]
        cls_amount = [len(idx) for idx in idx_list]

        client_x = [
            np.zeros(
                (client_data_list[client__], self.channels, self.height, self.width),
                dtype=np.float32,
            )
            for client__ in range(self.n_client)
        ]
        client_y = [
            np.zeros((client_data_list[client__], 1), dtype=np.int64)
            for client__ in range(self.n_client)
        ]

        while np.sum(client_data_list) != 0:
            curr_client = np.random.randint(self.n_client)
            # If current node is full resample a client
            if client_data_list[curr_client] <= 0:
                continue
            client_data_list[curr_client] -= 1
            curr_prior = prior_cumsum[curr_client]
            while True:
                if max(cls_amount) != 0 and max(cls_amount) == np.sum(cls_amount):
                    # Only one class has samples left: take it directly
                    # instead of redrawing until the prior happens to hit it.
                    cls_label = np.argmax(cls_amount)
                else:
                    # Inverse-CDF sampling from this client's class prior.
                    cls_label = np.argmax(np.random.uniform() <= curr_prior)
                # Redraw class label if train_y is out of that class
                if cls_amount[cls_label] <= 0:
                    continue
                cls_amount[cls_label] -= 1

                sample_idx = idx_list[cls_label][cls_amount[cls_label]]
                client_x[curr_client][client_data_list[curr_client]] = train_x[
                    sample_idx
                ]
                client_y[curr_client][client_data_list[curr_client]] = train_y[
                    sample_idx
                ]
                break

        # Report how far the realised class frequencies are from the priors.
        cls_means = np.zeros((self.n_client, self.n_cls))
        for client in range(self.n_client):
            for cls in range(self.n_cls):
                cls_means[client, cls] = np.mean(client_y[client] == cls)
        prior_real_diff = np.abs(cls_means - cls_priors)
        print("--- Max deviation from prior: %.4f" % np.max(prior_real_diff))
        print("--- Min deviation from prior: %.4f" % np.min(prior_real_diff))

        return (
            self._stack_client_arrays(client_x),
            self._stack_client_arrays(client_y),
        )

    def _partition_iid(self, train_x, train_y, client_data_list):
        """Give each client a contiguous slice of the shuffled data."""
        bounds = np.concatenate(([0], np.cumsum(client_data_list)))
        client_x = [
            train_x[bounds[i] : bounds[i + 1]] for i in range(self.n_client)
        ]
        client_y = [
            train_y[bounds[i] : bounds[i + 1]] for i in range(self.n_client)
        ]
        return (
            self._stack_client_arrays(client_x),
            self._stack_client_arrays(client_y),
        )


class Dataset(torch.utils.data.Dataset):
    """In-memory dataset holding its samples as tensors on ``device``.

    For ``"mnist"`` the samples are stored as given. For ``"CIFAR10"``,
    ``"CIFAR100"`` and ``"tinyimagenet"`` they are bilinearly resized to
    224x224 once, at construction time.

    Args:
        data_x: Samples; anything ``torch.as_tensor`` accepts.
        data_y: Labels, or a ``bool`` to build an unlabeled dataset (items
            are then returned without a label).
        train: Accepted for interface compatibility; currently unused.
        dataset_name: One of ``"mnist"``, ``"CIFAR10"``, ``"CIFAR100"`` or
            ``"tinyimagenet"``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """

    # Datasets whose images are upsampled to 224x224 at construction.
    _RESIZED_DATASETS = ("tinyimagenet", "CIFAR100", "CIFAR10")
    _RESIZE_TO = (224, 224)

    def __init__(self, data_x, data_y=True, train=False, dataset_name=""):
        self.name = dataset_name

        # Reject unknown names up front; otherwise the object would be left
        # without X_data and fail later in __len__ / __getitem__.
        if self.name != "mnist" and self.name not in self._RESIZED_DATASETS:
            raise ValueError(
                "Unsupported dataset_name %r; expected 'mnist' or one of %s"
                % (self.name, list(self._RESIZED_DATASETS))
            )

        samples = torch.as_tensor(data_x, device=device)
        if self.name in self._RESIZED_DATASETS:
            samples = F.interpolate(
                samples,
                size=self._RESIZE_TO,
                mode="bilinear",
                align_corners=False,
            )
        self.X_data = samples

        # A bool in place of labels marks the dataset as unlabeled.
        self.y_data = data_y
        if not isinstance(data_y, bool):
            self.y_data = torch.as_tensor(data_y, device=device)

    def __len__(self):
        """Return the number of samples."""
        return len(self.X_data)

    def __getitem__(self, idx):
        """Return ``X`` for an unlabeled dataset, else ``(X, y)``."""
        X = self.X_data[idx, :]
        if isinstance(self.y_data, bool):
            return X
        return X, self.y_data[idx]

    def preprocess_data(self, data):
        """Randomly flip and crop a batch of square images.

        Each image is independently mirrored left-right with probability
        0.5 and, also with probability 0.5, replaced by a random crop of
        its zero-padded version. The input tensor is not modified.

        The image side length is taken from ``self.image_size`` and the
        padding from ``self.pad`` when those attributes exist; otherwise
        the last dimension of ``data`` and one eighth of it are used.

        Args:
            data: Tensor indexed as (batch, channels, height, width) with
                height equal to width.

        Returns:
            A new tensor with the same shape as ``data``.
        """
        batch_size, channels = data.shape[0], data.shape[1]
        image_size = getattr(self, "image_size", data.shape[-1])
        pad = getattr(self, "pad", image_size // 8)
        dev = data.device
        padded_size = image_size + pad * 2

        # Horizontal flip with 50% probability. The width axis of a
        # (B, C, H, W) batch is dim 3; dim 2 would flip vertically.
        flip_mask = (torch.rand(batch_size) > 0.5).to(dev)
        data = torch.where(
            flip_mask.view(-1, 1, 1, 1), torch.flip(data, dims=[3]), data
        )

        # Random cropping with 50% probability
        extended_data = torch.zeros(
            (batch_size, channels, padded_size, padded_size),
            dtype=data.dtype,
            device=dev,
        )
        # Explicit end indices so that pad == 0 is handled as well.
        extended_data[
            :, :, pad : pad + image_size, pad : pad + image_size
        ] = data
        crop_mask = (torch.rand(batch_size) > 0.5).to(dev)

        # Generate random crop coordinates: a top-left offset per image,
        # expanded to the rows / columns the crop window covers.
        row_start = torch.randint(pad * 2 + 1, (batch_size,), device=dev)
        col_start = torch.randint(pad * 2 + 1, (batch_size,), device=dev)
        window = torch.arange(image_size, device=dev)
        rows = row_start.unsqueeze(1) + window
        cols = col_start.unsqueeze(1) + window
        row_index = (
            rows.unsqueeze(1).unsqueeze(-1).expand(-1, channels, -1, image_size)
        )
        col_index = (
            cols.unsqueeze(1).unsqueeze(2).expand(-1, channels, padded_size, -1)
        )

        # Select the columns first, then the rows, of each crop window.
        cropped = torch.gather(extended_data, 3, col_index)
        cropped = torch.gather(cropped, 2, row_index)
        return torch.where(crop_mask.view(-1, 1, 1, 1), cropped, data)
