import json
import numpy as np
import os
from collections import defaultdict
import torch
from PIL import Image
from torchvision import transforms

def load_image(img_name_list, raw_dir='./leaf/data/celeba/data/raw/img_align_celeba/',
               size=(84, 84)):
    """Load, resize and stack a list of images into a single tensor.

    Each file is opened from ``raw_dir``, resized to ``size`` with
    ``transforms.Resize`` and converted with ``transforms.ToTensor``. The
    per-image tensors are stacked along a new leading (batch) dimension.

    Args:
        img_name_list: image file names, relative to ``raw_dir``.
        raw_dir: directory holding the raw images. Defaults to the LEAF
            CelebA ``img_align_celeba`` directory.
        size: target ``(height, width)`` passed to ``transforms.Resize``.

    Returns:
        A tensor of shape ``(len(img_name_list), C, *size)``. ``C`` depends on
        the image mode (presumably 3 for the RGB CelebA JPEGs — TODO confirm).

    Raises:
        RuntimeError: from ``torch.stack`` if ``img_name_list`` is empty or
            the images produce tensors of different shapes.
    """
    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
    tensors = []
    for img_name in img_name_list:
        # Close the file handle as soon as the image is converted; the
        # resulting tensor does not reference the PIL image.
        with Image.open(os.path.join(raw_dir, img_name)) as img:
            tensors.append(transform(img))
    return torch.stack(tensors, dim=0)

def batch_data(data, batch_size, seed, device=None):
    """Yield shuffled ``(x, y)`` mini-batches for one client's data.

    Images named in ``data['x']`` are loaded with ``load_image``. Labels in
    ``data['y']`` become an int64 tensor. Both tensors are reordered with the
    same seeded permutation and then cut into consecutive chunks of
    ``batch_size``; the last chunk may be smaller.

    Args:
        data: mapping with keys ``'x'`` (image file names) and ``'y'``
            (integer labels). Presumably both have the same length — TODO
            confirm at the caller.
        batch_size: maximum number of samples per yielded batch.
        seed: seed for the shuffling permutation. It uses a local NumPy
            ``RandomState``, so the global NumPy RNG is left untouched.
        device: device that both tensors are moved to. ``None`` prefers
            ``cuda:6`` when at least 7 GPUs exist, otherwise ``cuda``, and
            falls back to CPU when CUDA is unavailable.

    Yields:
        ``(batched_x, batched_y)`` tensor pairs, both on ``device``.
    """
    if not data['x']:
        # Nothing to batch; load_image would fail on an empty list.
        return

    if device is None:
        if torch.cuda.is_available():
            device = "cuda:6" if torch.cuda.device_count() > 6 else "cuda"
        else:
            device = "cpu"
    device = torch.device(device)

    data_x = load_image(data['x']).to(device)
    data_y = torch.tensor(data['y'], dtype=torch.int64, device=device)

    # np.random.shuffle on a torch tensor swaps rows through views, which
    # duplicates rows instead of permuting them. Build one permutation and
    # index both tensors with it so x and y stay aligned and intact.
    perm = np.random.RandomState(seed).permutation(len(data_x))
    perm = torch.from_numpy(perm).to(device)
    data_x = data_x[perm]
    data_y = data_y[perm]

    for start in range(0, len(data_x), batch_size):
        stop = start + batch_size
        yield data_x[start:stop], data_y[start:stop]


def read_dir(data_dir):
    """Read every ``.json`` file in ``data_dir`` and merge their contents.

    Each file must contain a ``'user_data'`` mapping and may contain a
    ``'hierarchies'`` list. Files are processed in sorted name order. This
    makes the concatenated ``groups`` list deterministic, instead of depending
    on ``os.listdir`` order. It also fixes which file wins when a user id
    appears in more than one file: the later file, in sorted order, wins.

    Args:
        data_dir: directory containing the JSON shards. Non-``.json`` files
            are ignored.

    Returns:
        clients: sorted list of the user ids found across all ``user_data``.
        groups: concatenation of every file's ``'hierarchies'`` list; empty
            if no file has one.
        data: ``defaultdict`` mapping user id to that user's data. Looking up
            a missing id returns ``None`` and inserts that key.
    """
    groups = []
    data = defaultdict(lambda: None)

    json_files = sorted(f for f in os.listdir(data_dir) if f.endswith('.json'))
    for name in json_files:
        with open(os.path.join(data_dir, name), 'r', encoding='utf-8') as inf:
            cdata = json.load(inf)
        if 'hierarchies' in cdata:
            groups.extend(cdata['hierarchies'])
        data.update(cdata['user_data'])

    clients = sorted(data.keys())
    return clients, groups, data


def read_data(train_data_dir, test_data_dir):
    '''Parse data in the given train and test data directories.

    assumes:
    - the data in the input directories are .json files with
        keys 'users' and 'user_data'
    - the set of train set users is the same as the set of test set users

    Return:
        clients: list of client ids
        groups: list of group ids; empty list if none found
        train_data: dictionary of train data
        test_data: dictionary of test data

    Raises:
        ValueError: if the train and test directories disagree on the client
            ids or on the group ids.
    '''
    train_clients, train_groups, train_data = read_dir(train_data_dir)
    test_clients, test_groups, test_data = read_dir(test_data_dir)

    # Explicit checks rather than ``assert`` so the validation still runs
    # under ``python -O``.
    if train_clients != test_clients:
        raise ValueError(
            f"train and test client ids differ "
            f"({len(train_clients)} train vs {len(test_clients)} test)"
        )
    if train_groups != test_groups:
        raise ValueError(
            f"train and test group ids differ "
            f"({len(train_groups)} train vs {len(test_groups)} test)"
        )

    return train_clients, train_groups, train_data, test_data


if __name__ == '__main__':
    # Smoke test: load the first client's test split and print the shape and
    # dtype of its first image batch.
    train_data_dir = './leaf/data/celeba/data/train'
    test_data_dir = './leaf/data/celeba/data/test'
    train_clients, train_groups, train_data, test_data = read_data(train_data_dir, test_data_dir)

    data = batch_data(test_data[train_clients[0]], 10, 1)
    for images, labels in data:
        print(images.shape, images.dtype)
        # Only the first batch is inspected; stop cleanly instead of raising.
        break
