from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
from torchvision import datasets, transforms
from collections import defaultdict
import random
import numpy as np
import torch
# from .so_tag_utils import *
import math
import sys
import six
from PIL import Image
from scipy import special

# Placeholders initialised to 0; nothing in the visible part of this file reads them.
# NOTE(review): presumably leftovers from the commented-out so_tag_utils import --
# confirm no other module relies on them before removing.
VOCAB_DIR = 0
emb_array = 0
vocab = 0
embed_dim = 0
# Directory that load_image() joins with an image file name.
IMAGES_DIR = "data/celeba/data/raw/img_align_celeba"
# Side length in pixels: load_image() resizes every image to IMAGE_SIZE x IMAGE_SIZE.
IMAGE_SIZE = 84


def batch_data(data, batch_size, rng=None, shuffle=True, eval_mode=False, full=False, malicious=False, dataset='celeba'):
    """
    Load the images named in ``data``, optionally shuffle the samples and split them into batches.

    data is a dict := {'x': [list], 'y': [list], 'a': [list]} with optional fields
    'x_true', 'y_true' and 'a_true'. If eval_mode, each '*_true' field is used
    instead of its plain counterpart whenever that field exists.

    Args:
        data: dict as described above; the 'x' entries are passed to process_x().
        batch_size: maximum number of samples per batch (ignored when ``full``).
        rng: accepted for interface compatibility; currently unused.
        shuffle: if True, samples are permuted with ``np.random.shuffle``.
        eval_mode: prefer the '*_true' fields when they are present.
        full: if True, return everything as one single batch.
        malicious: accepted for interface compatibility; currently unused.
        dataset: accepted for interface compatibility; currently unused.

    Returns:
        (batched_x, batched_y, batched_a): three lists of equal length. Each
        element of batched_x is a tensor whose last axis (as produced by
        process_x) has been moved in front of the two preceding axes, each
        element of batched_y is a LongTensor, and each element of batched_a is
        the matching slice of 'a'.
    """
    x = data['x_true'] if eval_mode and 'x_true' in data else data['x']
    y = data['y_true'] if eval_mode and 'y_true' in data else data['y']
    # Fixed: this previously returned data['y_true'] when 'a_true' was present.
    a = data['a_true'] if eval_mode and 'a_true' in data else data['a']

    indices = np.arange(len(x))
    if shuffle:
        np.random.shuffle(indices)

    # Fall back to the CPU instead of crashing on hosts without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    x = process_x(x)
    # Dividing by 1. promotes the integer pixel values to floating point.
    x = torch.tensor((x / 1.)).to(device)
    # (..., H, W, C) -> (..., C, H, W)
    x = torch.transpose(torch.transpose(x, -1, -2), -2, -3)
    y = torch.LongTensor(y).to(device)

    # A plain Python sequence cannot be indexed with an index array.
    if isinstance(a, (list, tuple)):
        a = np.asarray(a)

    raw_x = x[indices]
    raw_y = y[indices]
    raw_a = a[indices]

    batched_x, batched_y, batched_a = [], [], []
    if not full:
        for i in range(0, len(raw_x), batch_size):
            batched_x.append(raw_x[i:i + batch_size])
            batched_y.append(raw_y[i:i + batch_size])
            batched_a.append(raw_a[i:i + batch_size])
    else:
        batched_x.append(raw_x)
        batched_y.append(raw_y)
        batched_a.append(raw_a)
    return batched_x, batched_y, batched_a

def load_image(img_name):
    """Load one image from IMAGES_DIR as an RGB numpy array.

    Args:
        img_name: file name of the image, relative to IMAGES_DIR.

    Returns:
        numpy array of the image after resizing it to IMAGE_SIZE x IMAGE_SIZE
        and converting it to RGB.
    """
    img_path = os.path.join(IMAGES_DIR, img_name)
    # The context manager releases the file handle even if decoding or
    # resizing raises; resize() returns an independent image, so the result
    # stays valid after the source file is closed.
    with Image.open(img_path) as img:
        rgb = img.resize((IMAGE_SIZE, IMAGE_SIZE)).convert('RGB')
    return np.array(rgb)

def process_x(raw_x_batch):
    """Load every image named in ``raw_x_batch`` and stack them into one array.

    Args:
        raw_x_batch: iterable of image file names understood by load_image().

    Returns:
        numpy array built from the list of loaded images, in input order.
    """
    images = []
    for img_name in raw_x_batch:
        images.append(load_image(img_name))
    return np.array(images)

def process_y(raw_y_batch):
    """Return the label batch unchanged (labels need no preprocessing)."""
    labels = raw_y_batch
    return labels


def read_so_data():
    """Fetch the centralized train/test datasets and list their users.

    NOTE(review): get_centralized_datasets is neither defined nor imported in
    the visible part of this file (the so_tag_utils star-import is commented
    out) -- presumably it came from there; confirm before calling this.

    Returns:
        clients: dict with the keys 'train_users' and 'test_users', each a
            list of the keys of the corresponding dataset
        groups: always an empty list
        train_data: first value returned by get_centralized_datasets()
        test_data: second value returned by get_centralized_datasets()
    """
    train_data, test_data = get_centralized_datasets()
    train_users = list(train_data.keys())
    test_users = list(test_data.keys())
    clients = {'train_users': train_users, 'test_users': test_users}
    return clients, [], train_data, test_data

def _iter_json_files(data_dir):
    '''yields the parsed content of every .json file in data_dir, in sorted file-name order'''
    # os.listdir returns entries in arbitrary order; sorting keeps the result
    # reproducible across runs and file systems.
    for fname in sorted(os.listdir(data_dir)):
        if not fname.endswith('.json'):
            continue
        with open(os.path.join(data_dir, fname), 'r') as inf:
            yield json.load(inf)


def read_data(train_data_dir, test_data_dir):
    '''parses data in given train and test data directories
    assumes:
    - the data in the input directories are .json files with
        the key 'user_data' (and optionally 'hierarchies' in train files)
    - the set of train set users is the same as the set of test set users

    Files are processed in sorted file-name order, so when the same user
    appears in several files the entry from the last file wins.

    Return:
        clients: list of client ids (the keys of train_data)
        groups: list of group ids; empty list if none found
        train_data: dictionary of train data
        test_data: dictionary of test data
    '''
    groups = []
    train_data = {}
    test_data = {}

    for cdata in _iter_json_files(train_data_dir):
        if 'hierarchies' in cdata:
            groups.extend(cdata['hierarchies'])
        train_data.update(cdata['user_data'])

    for cdata in _iter_json_files(test_data_dir):
        test_data.update(cdata['user_data'])

    # The client list is derived from the merged train data; the per-file
    # 'users' lists were previously collected and then discarded.
    clients = list(train_data.keys())

    return clients, groups, train_data, test_data