from __future__ import annotations

import json
import numpy as np
import numpy.random as random
# import pickle
import torch
import torch.utils.data as data
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# from multiprocessing import Pool
# from pathos.multiprocessing import ProcessingPool as Pool
from PIL import Image
from os import listdir
from torchvision.io import read_image


def _load_cifar(dataset_cls, root: str):
    """Download a torchvision CIFAR dataset and materialise both splits as tensors.

    ``dataset_cls`` is called as ``dataset_cls(root=..., train=..., download=True,
    transform=...)``; the train and test splits share the same transform.
    Returns ``(tr_data, tr_labels, te_data, te_labels)``.
    """
    # The same normalisation constants are applied to every CIFAR variant.
    # NOTE(review): these look like CIFAR-10 statistics -- confirm whether
    # CIFAR-100 is meant to reuse them.
    cifar_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    tr_dataset = dataset_cls(root=root, train=True, download=True, transform=cifar_transform)
    te_dataset = dataset_cls(root=root, train=False, download=True, transform=cifar_transform)

    tr_data_tensor, tr_label_tensor = concat_dataset(dataset=tr_dataset)
    te_data_tensor, te_label_tensor = concat_dataset(dataset=te_dataset)
    return tr_data_tensor, tr_label_tensor, te_data_tensor, te_label_tensor


def load_data(dataset: str, root: str, args):
    """Load a dataset and return per-client training tensors plus the test set.

    Args:
        dataset: One of ``'cifar10'``, ``'cifar100'``, ``'imagenet12'`` or
            ``'femnist'``.
        root: Directory the dataset is read from (and, for CIFAR, downloaded to).
        args: Configuration object. ``args.n_classes`` is read for
            ``'imagenet12'``; for every dataset except ``'femnist'`` the object
            is also forwarded to ``partition_data``.

    Returns:
        ``(user_tr_data_tensors, user_tr_label_tensors, te_data_tensor,
        te_label_tensor)``. The first two are lists with one entry per client.
        FEMNIST keeps the per-user split produced by ``load_femnist``; the
        other datasets are split by ``partition_data``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """
    if dataset == 'cifar10':
        tr_data_tensor, tr_label_tensor, te_data_tensor, te_label_tensor = _load_cifar(datasets.CIFAR10, root)
    elif dataset == 'cifar100':
        tr_data_tensor, tr_label_tensor, te_data_tensor, te_label_tensor = _load_cifar(datasets.CIFAR100, root)
    elif dataset == 'imagenet12':
        tr_data_tensor, tr_label_tensor = load_imagenet(root=root, train=True, n_classes=args.n_classes)
        te_data_tensor, te_label_tensor = load_imagenet(root=root, train=False, n_classes=args.n_classes)
    elif dataset == 'femnist':
        # FEMNIST is already split per user, so no further partitioning is done.
        user_tr_data_tensors, user_tr_label_tensors = load_femnist(root=root, train=True, )
        te_data_tensor, te_label_tensor = load_femnist(root=root, train=False, )
    else:
        # Fail early with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f"unsupported dataset: {dataset!r} "
            "(expected 'cifar10', 'cifar100', 'imagenet12' or 'femnist')"
        )

    print('data loaded')

    if dataset != 'femnist':
        user_tr_data_tensors, user_tr_label_tensors = partition_data(tr_data_tensor, tr_label_tensor, args)

    print('datalens:', [len(i) for i in user_tr_data_tensors])

    return user_tr_data_tensors, user_tr_label_tensors, te_data_tensor, te_label_tensor

def partition_data(X: torch.Tensor, Y: torch.Tensor, args)->tuple[list[torch.Tensor],list[torch.Tensor]]:
    """Shuffle a training set and split it across ``args.n_clients`` clients.

    Args:
        X: Sample tensor, indexed along its first dimension.
        Y: Integer class labels aligned with ``X``.
        args: Configuration object. Always read: ``partition`` (``'iid'`` or
            ``'dirichlet'``) and ``n_clients``. Additionally read for
            ``'dirichlet'``: ``n_classes``, ``beta`` (Dirichlet concentration)
            and ``min_n_samples`` (smallest acceptable client size).

    Returns:
        ``(user_tr_data_tensors, user_tr_label_tensors)``: two lists with one
        tensor per client.

        * ``'iid'``: every client gets ``len(X) // n_clients`` consecutive
          samples of the shuffled data; any remainder is dropped.
        * ``'dirichlet'``: for each class, client shares are drawn from a
          Dirichlet distribution. The whole draw is repeated until every
          client holds at least ``min_n_samples`` samples.

    Raises:
        ValueError: If ``args.partition`` is unknown, or if the Dirichlet
            minimum-size constraint cannot be satisfied by any split.
    """
    if args.partition not in ('iid', 'dirichlet'):
        raise ValueError(
            f"unsupported partition scheme: {args.partition!r} (expected 'iid' or 'dirichlet')"
        )

    n_samples = len(X)
    if args.partition == 'dirichlet' and args.min_n_samples * args.n_clients > n_samples:
        # The retry loop below could never terminate in this situation.
        raise ValueError(
            f"cannot give {args.n_clients} clients at least {args.min_n_samples} "
            f"samples each from only {n_samples} samples"
        )

    # shuffle training set
    all_indices = np.arange(n_samples)
    random.shuffle(all_indices)
    X = X[all_indices]
    Y = Y[all_indices]

    if args.partition == 'iid':
        user_tr_len = n_samples // args.n_clients
        bounds = [(user_tr_len * i, user_tr_len * (i + 1)) for i in range(args.n_clients)]
        user_tr_data_tensors = [X[start:stop] for start, stop in bounds]
        user_tr_label_tensors = [Y[start:stop] for start, stop in bounds]
        return user_tr_data_tensors, user_tr_label_tensors

    # 'dirichlet': a client that already holds this many samples receives no
    # further share of the remaining classes.
    per_client_cap = n_samples / args.n_clients
    while True:  # runs at least once, even when min_n_samples <= 0
        idx_batch = [[] for _ in range(args.n_clients)]
        for k in range(args.n_classes):
            # Column 0 of nonzero() holds the sample indices. Slicing instead of
            # squeeze() keeps the result 1-D when the class has exactly one
            # sample, which np.random.shuffle requires.
            idx_k = torch.nonzero(Y == k)[:, 0].numpy()
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(args.beta, args.n_clients))
            if len(idx_k) == 0:
                # Nothing to hand out. Skipping only after the draw above keeps
                # the random stream unchanged and avoids a 0/0 when every
                # client is already at capacity.
                continue
            # Balance
            proportions = np.array([p * (len(idx_j) < per_client_cap) for p, idx_j in zip(proportions, idx_batch)])
            proportions = proportions / proportions.sum()

            # Cumulative shares -> split positions inside idx_k.
            split_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]

            idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, split_points))]
        if min(len(idx_j) for idx_j in idx_batch) >= args.min_n_samples:
            break

    index_tensors = [torch.as_tensor(idx_j, dtype=torch.long) for idx_j in idx_batch]
    user_tr_data_tensors = [X[idx] for idx in index_tensors]
    # reshape(-1) rather than squeeze(): a single-sample client keeps a 1-D label tensor.
    user_tr_label_tensors = [Y[idx].reshape(-1) for idx in index_tensors]

    return user_tr_data_tensors, user_tr_label_tensors

class TensorDataset(data.Dataset):
    """Map-style dataset pairing a sample container with a label container.

    Item ``i`` is the tuple ``(data_tensor[i], label_tensor[i])``. The two
    containers are stored as given, on the ``data`` and ``target`` attributes.
    """

    def __init__(self, data_tensor, label_tensor):
        self.data, self.target = data_tensor, label_tensor

    def __len__(self):
        # The length comes from the samples only; the labels are not consulted.
        return len(self.data)

    def __getitem__(self, index):
        sample = self.data[index]
        label = self.target[index]
        return sample, label

def load_femnist(root: str, train: bool, ):
    """Load per-user FEMNIST data from the JSON files under ``root``.

    Every file in ``<root>/train`` (or ``<root>/test`` when ``train`` is
    false) is parsed as JSON. Each file must provide a ``'users'`` list and a
    ``'user_data'`` mapping whose entries hold the keys ``'x'`` (samples) and
    ``'y'`` (labels).

    Args:
        root: Directory containing the ``train`` and ``test`` sub-folders.
        train: Selects the ``train`` folder when true, ``test`` otherwise.

    Returns:
        For ``train=True``: ``(user_data, user_labels)``, two lists with one
        float32 sample tensor and one int64 label tensor per user.
        For ``train=False``: ``(data_tensor, label_tensor)``, all users
        concatenated along the first dimension.
    """
    split_dir = '/'.join([root, 'train' if train else 'test'])

    user_data = []
    user_labels = []

    # listdir() returns entries in arbitrary order. Sorting makes the order of
    # users (and therefore client indices) reproducible across runs and machines.
    for file_name in sorted(listdir(split_dir)):
        with open('/'.join([split_dir, file_name]), 'r', encoding='utf-8') as shard_file:
            shard = json.load(shard_file)

        for user in shard['users']:
            samples = shard['user_data'][user]
            user_data.append(torch.tensor(samples['x'], dtype=torch.float32))
            # Build labels as int64 directly instead of via a float32 tensor.
            user_labels.append(torch.tensor(samples['y'], dtype=torch.long))

    if train:
        return user_data, user_labels

    te_data_tensor = torch.cat(user_data)
    te_label_tensor = torch.cat(user_labels)
    return te_data_tensor, te_label_tensor


def load_imagenet(root: str, train: bool, n_classes: int):
    """Load classes ``0 .. n_classes - 1`` and concatenate them.

    Each class is read with ``load_imagenet_class``. The per-class data and
    label tensors are then joined with ``torch.cat`` in ascending class order.

    Returns:
        ``(data_tensor, label_tensor)``.
    """
    per_class = [
        load_imagenet_class(class_idx=class_idx, train=train, root=root)
        for class_idx in range(n_classes)
    ]
    class_data = [data_i for data_i, _ in per_class]
    class_labels = [label_i for _, label_i in per_class]

    return torch.cat(class_data), torch.cat(class_labels)

def load_imagenet_class(class_idx: int, root: str, train: bool, ):
    """Load every image of one class folder as a normalised tensor batch.

    Images are read from ``<root>/train/<class_idx>`` (or
    ``<root>/val/<class_idx>`` when ``train`` is false). Each one is converted
    to RGB, resized to 256, centre-cropped to 224, turned into a tensor and
    normalised with the mean/std constants below.

    Args:
        class_idx: Class index; also the name of the class folder.
        root: Directory containing the ``train`` and ``val`` sub-folders.
        train: Selects the ``train`` folder when true, ``val`` otherwise.

    Returns:
        ``(data, label)``: the stacked image tensors and an int64 tensor with
        ``class_idx`` repeated once per image.

    Raises:
        RuntimeError: If the class folder contains no files.
    """
    sub_folder = 'train' if train else 'val'
    class_path = '/'.join([root, sub_folder, str(class_idx)])
    # listdir() order is arbitrary; sort so the sample order is reproducible.
    image_relative_paths = sorted(listdir(path=class_path))
    if not image_relative_paths:
        raise RuntimeError(f"no images found in {class_path!r}")

    imagenet_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    images = []
    for image_relative_path in image_relative_paths:
        image_path = '/'.join([class_path, image_relative_path])
        # The context manager releases the file handle promptly.
        with Image.open(image_path) as image:
            # Force three channels: grayscale or CMYK files would otherwise
            # break the 3-channel Normalize and the torch.stack below.
            images.append(imagenet_transform(image.convert('RGB')))
    image_batch = torch.stack(images)

    label = torch.full((len(image_relative_paths),), class_idx, dtype=torch.long)
    return image_batch, label

def concat_dataset(dataset):
    X=[]
    Y=[]
    for i in range(len(dataset)):
        X.append(dataset[i][0])
        Y.append(dataset[i][1])
    X=torch.stack(X)
    Y=torch.Tensor(Y,).long()
    return X, Y
