import numpy as np
import os
import sys
import random
import torch
import torchvision
import torchvision.transforms as transforms
from utils.dataset_utils import check, separate_data, split_data, save_file
from torchvision.datasets import ImageFolder, DatasetFolder
import requests
import zipfile

# Seed Python's and NumPy's global RNGs with a fixed value.
random.seed(1)
np.random.seed(1)
# Number of clients passed to generate_dataset by the script entry point.
num_clients = 10
# Output directory for the generated dataset. The trailing "/" matters:
# generate_dataset builds config/train/test paths by plain string concatenation.
# NOTE(review): this is a relative path that looks machine-specific — confirm
# it matches the intended working directory before running.
dir_path = "../../../PycharmProjects/HtFLlib/dataset/TinyImagenet/"

# https://github.com/QinbinLi/MOON/blob/6c7a4ed1b1a8c0724fa2976292a667a828e3ff5d/datasets.py#L148
class ImageFolder_custom(DatasetFolder):
    """Image-folder dataset that can be restricted to a subset of samples.

    The directory is scanned with torchvision's ``ImageFolder``; its loader is
    reused and its sample list is stored as a NumPy array, optionally indexed
    by ``dataidxs``.

    Args:
        root: Directory handed to ``ImageFolder``.
        dataidxs: Optional index used to select rows of the sample array;
            ``None`` keeps every sample.
        train: Stored on the instance; not otherwise used by this class.
        transform: Optional callable applied to each loaded sample.
        target_transform: Optional callable applied to each integer target.
    """

    def __init__(self, root, dataidxs=None, train=True, transform=None, target_transform=None):
        # DatasetFolder.__init__ is not called here; the attributes this class
        # relies on are assigned directly.
        self.root = root
        self.dataidxs = dataidxs
        self.train = train
        self.transform = transform
        self.target_transform = target_transform

        scanned = ImageFolder(self.root, self.transform, self.target_transform)
        self.loader = scanned.loader
        all_samples = np.array(scanned.samples)
        self.samples = all_samples if self.dataidxs is None else all_samples[self.dataidxs]

    def __getitem__(self, index):
        """Load the sample at ``index`` and return ``(sample, target)``."""
        record = self.samples[index]
        path = record[0]
        # The label is read back from the NumPy sample array, so cast it to int.
        label = int(record[1])
        image = self.loader(path)
        image = image if self.transform is None else self.transform(image)
        label = label if self.target_transform is None else self.target_transform(label)
        return image, label

    def __len__(self):
        """Return the sample count, or ``len(dataidxs)`` when a subset was given."""
        return len(self.samples) if self.dataidxs is None else len(self.dataidxs)


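# NOTE: the block below is a commented-out, older generate_dataset that fetched
# the archive with wget/unzip; the active definition appears further down in this file.
# # Allocate data to users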
# def generate_dataset(dir_path, num_clients, niid, balance, partition):
#     if not os.path.exists(dir_path):
#         os.makedirs(dir_path)
#
#     # Setup directory for train/test data
#     config_path = dir_path + "config.json"
#     train_path = dir_path + "train/"
#     test_path = dir_path + "test/"
#
#     if check(config_path, train_path, test_path, num_clients, niid, balance, partition):
#         return
#
#     # Get data
#     if not os.path.exists(f'{dir_path}/rawdata/'):
#         os.system(f'wget --directory-prefix {dir_path}/rawdata/ http://cs231n.stanford.edu/tiny-imagenet-200.zip')
#         os.system(f'unzip {dir_path}/rawdata/tiny-imagenet-200.zip -d {dir_path}/rawdata/')
#     else:
#         print('rawdata already exists.\n')
#
#     transform = transforms.Compose(
#         [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
#
#     trainset = ImageFolder_custom(root=dir_path+'rawdata/tiny-imagenet-200/train/', transform=transform)
#     testset = ImageFolder_custom(root=dir_path+'rawdata/tiny-imagenet-200/val/', transform=transform)
#     trainloader = torch.utils.data.DataLoader(
#         trainset, batch_size=len(trainset), shuffle=False)
#     testloader = torch.utils.data.DataLoader(
#         testset, batch_size=len(testset), shuffle=False)
#
#     for _, train_data in enumerate(trainloader, 0):
#         trainset.data, trainset.targets = train_data
#     for _, test_data in enumerate(testloader, 0):
#         testset.data, testset.targets = test_data
#
#     dataset_image = []
#     dataset_label = []
#
#     dataset_image.extend(trainset.data.cpu().detach().numpy())
#     dataset_image.extend(testset.data.cpu().detach().numpy())
#     dataset_label.extend(trainset.targets.cpu().detach().numpy())
#     dataset_label.extend(testset.targets.cpu().detach().numpy())
#     dataset_image = np.array(dataset_image)
#     dataset_label = np.array(dataset_label)
#
#     num_classes = len(set(dataset_label))
#     print(f'Number of classes: {num_classes}')
#
#     # dataset = []
#     # for i in range(num_classes):
#     #     idx = dataset_label == i
#     #     dataset.append(dataset_image[idx])
#
#     X, y, statistic = separate_data((dataset_image, dataset_label), num_clients, num_classes,
#                                     niid, balance, partition, class_per_client=20)
#     train_data, test_data = split_data(X, y)
#     save_file(config_path, train_path, test_path, train_data, test_data, num_clients, num_classes,
#         statistic, niid, balance, partition)


def download_and_extract(url, download_path, extract_path):
    """Download the Tiny-ImageNet archive and unpack it, skipping finished steps.

    Args:
        url: HTTP(S) location of the zip archive.
        download_path: Directory that receives ``tiny-imagenet-200.zip``; the
            archive is also extracted into this directory.
        extract_path: Path whose existence signals that extraction already
            happened (only checked, never used as the extraction target).

    Raises:
        requests.HTTPError: If the server answers with an error status.
        zipfile.BadZipFile: If the archive on disk is not a valid zip file.
    """
    if download_path:
        os.makedirs(download_path, exist_ok=True)
    zip_path = os.path.join(download_path, "tiny-imagenet-200.zip")

    # Download the archive
    if not os.path.exists(zip_path):
        print("Downloading Tiny-ImageNet dataset...")
        # Stream into a temporary ".part" file and rename only on success, so
        # an interrupted download never leaves a truncated zip that a later
        # run would mistake for a complete one.
        partial_path = zip_path + ".part"
        try:
            with requests.get(url, stream=True) as r:
                r.raise_for_status()
                with open(partial_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            os.replace(partial_path, zip_path)
        finally:
            # After a successful rename the ".part" file is gone; this only
            # cleans up after a failed or interrupted transfer.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        print("Download completed.")
    else:
        print("Dataset zip file already exists.")

    # Extract the archive
    if not os.path.exists(extract_path):
        print("Extracting dataset...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(download_path)
        print("Extraction completed.")
    else:
        print("Dataset already extracted.")
# Modified generate_dataset function (fetches the raw data via download_and_extract)
def generate_dataset(dir_path, num_clients, niid, balance, partition):
    """Fetch Tiny-ImageNet, split it across clients and write the result to disk.

    Args:
        dir_path: Output directory. It is expected to end with a path
            separator because the config/train/test paths are built by plain
            string concatenation.
        num_clients: Number of clients the data is separated into.
        niid: Forwarded unchanged to ``check``, ``separate_data`` and ``save_file``.
        balance: Forwarded unchanged to ``check``, ``separate_data`` and ``save_file``.
        partition: Forwarded unchanged to ``check``, ``separate_data`` and ``save_file``.

    Returns:
        None. Returns early without doing any work when ``check`` is truthy.
    """
    os.makedirs(dir_path, exist_ok=True)

    # Setup directory for train/test data
    config_path = dir_path + "config.json"
    train_path = dir_path + "train/"
    test_path = dir_path + "test/"

    if check(config_path, train_path, test_path, num_clients, niid, balance, partition):
        return

    # Get data
    rawdata_path = os.path.join(dir_path, "rawdata")
    extracted_path = os.path.join(rawdata_path, "tiny-imagenet-200")
    tiny_imagenet_url = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
    # Key the decision on the extracted dataset rather than on the rawdata
    # directory: the directory may exist after a failed download, and the run
    # must then retry instead of assuming the data is present.
    if not os.path.exists(extracted_path):
        os.makedirs(rawdata_path, exist_ok=True)
        download_and_extract(tiny_imagenet_url, rawdata_path, extracted_path)
    else:
        print('rawdata already exists.\n')

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # NOTE(review): in the stock Tiny-ImageNet archive, val/ presumably holds a
    # single images/ subfolder, in which case ImageFolder would give every
    # validation image the same label — confirm, and remap via
    # val_annotations.txt if so.
    trainset = ImageFolder_custom(root=os.path.join(extracted_path, 'train'), transform=transform)
    testset = ImageFolder_custom(root=os.path.join(extracted_path, 'val'), transform=transform)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=len(trainset), shuffle=False)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=len(testset), shuffle=False)

    # batch_size equals the dataset length, so the first batch is the whole split.
    train_images, train_labels = next(iter(trainloader))
    test_images, test_labels = next(iter(testloader))

    # Concatenate directly as arrays; this avoids building an intermediate
    # Python list holding every image before converting back to an array.
    dataset_image = np.concatenate([train_images.numpy(), test_images.numpy()])
    dataset_label = np.concatenate([train_labels.numpy(), test_labels.numpy()])

    num_classes = len(np.unique(dataset_label))
    print(f'Number of classes: {num_classes}')

    X, y, statistic = separate_data((dataset_image, dataset_label), num_clients, num_classes,
                                    niid, balance, partition, class_per_client=20)
    train_data, test_data = split_data(X, y)
    save_file(config_path, train_path, test_path, train_data, test_data, num_clients, num_classes,
              statistic, niid, balance, partition)

if __name__ == "__main__":
    # Three positional arguments are required:
    #   1. "noniid" selects a non-IID split; any other value means IID.
    #   2. "balance" selects a balanced split; any other value means unbalanced.
    #   3. the partition name, or "-" for no partition (None).
    if len(sys.argv) < 4:
        # Fail with a usage hint instead of a bare IndexError.
        sys.exit(f"Usage: python {sys.argv[0]} <noniid|iid> <balance|-> <partition|->")

    niid = sys.argv[1] == "noniid"
    balance = sys.argv[2] == "balance"
    partition = sys.argv[3] if sys.argv[3] != "-" else None

    generate_dataset(dir_path, num_clients, niid, balance, partition)
