# Modified from: https://github.com/pliang279/LG-FedAvg/blob/master/utils/train_utils.py
# credit goes to: Paul Pu Liang

import json
import os
import sys

import numpy as np
import pandas as pd
import torch.utils.data as data
from torchvision import datasets, transforms

from models.Nets import CNN_FEMNIST, CNN_FAMNIST
from models.Transformer import TransformerEncoder
from utils.sampling import noniid

# Preprocessing pipelines. The *_train variants apply random-crop (32px with
# 4px padding) and horizontal-flip augmentation before ToTensor + Normalize;
# the *_val variants only apply ToTensor + Normalize.
_MNIST_MEAN, _MNIST_STD = (0.1307,), (0.3081,)
# NOTE(review): these CIFAR-10 values appear to be the commonly used ImageNet
# channel statistics rather than CIFAR-10's own -- confirm this is intended.
_CIFAR10_MEAN, _CIFAR10_STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
_CIFAR100_MEAN, _CIFAR100_STD = [0.507, 0.487, 0.441], [0.267, 0.256, 0.276]


def _cifar_pipeline(mean, std, augment):
    """Compose a CIFAR-style transform: optional crop/flip, then ToTensor + Normalize."""
    steps = []
    if augment:
        steps.append(transforms.RandomCrop(32, padding=4))
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    # Fresh list copies so no two Normalize instances share a mean/std object.
    steps.append(transforms.Normalize(mean=list(mean), std=list(std)))
    return transforms.Compose(steps)


trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize(_MNIST_MEAN, _MNIST_STD)])
trans_cifar10_train = _cifar_pipeline(_CIFAR10_MEAN, _CIFAR10_STD, augment=True)
trans_cifar10_val = _cifar_pipeline(_CIFAR10_MEAN, _CIFAR10_STD, augment=False)
trans_cifar100_train = _cifar_pipeline(_CIFAR100_MEAN, _CIFAR100_STD, augment=True)
trans_cifar100_val = _cifar_pipeline(_CIFAR100_MEAN, _CIFAR100_STD, augment=False)


def _split_noniid(dataset_train, dataset_test, args, **noniid_kwargs):
    """Partition a train/test pair across ``args.num_users`` users with ``noniid``.

    The ``rand_set_all`` returned for the train split is passed back in for the
    test split, presumably so each user's test shard draws from the same
    classes as its train shard (confirm against ``utils.sampling.noniid``).
    Extra keyword arguments (e.g. ``color``) are forwarded to both calls.

    Returns:
        (dict_users_train, dict_users_test) as produced by ``noniid``.
    """
    common = (args.num_users, args.shard_per_user, args.num_classes)
    dict_users_train, rand_set_all = noniid(dataset_train, *common, **noniid_kwargs)
    dict_users_test, _ = noniid(dataset_test, *common, rand_set_all=rand_set_all, **noniid_kwargs)
    return dict_users_train, dict_users_test


def _get_famnist_gray_color(args):
    """Load the FashionMNIST edge (gray) and color splits, each partitioned non-IID.

    Returns:
        (dataset_train_g, dataset_test_g, dict_users_train_g, dict_users_test_g,
         dataset_train_c, dataset_test_c, dict_users_train_c, dict_users_test_c)
    """
    # Project-local modules; only this branch uses them, so import lazily.
    from utils import femnist
    from utils import GrayscaleToRGB as grgb

    # Edge images pass through GrayscaleToRgb before ToTensor; color images
    # are converted directly.
    apply_transform = transforms.Compose([
        grgb.GrayscaleToRgb(),
        transforms.ToTensor()])
    apply_transform_c = transforms.Compose([
        transforms.ToTensor()])

    data_dir = '../Datasets/images/FashionMNIST/Edge'
    cdata_dir = '../Datasets/images/FashionMNIST/Color'

    # All four datasets are built before any partitioning happens.
    dataset_train_g = femnist.FAMNISTGC(data_dir, train=True, download=False,
                                        transform=apply_transform, color=False)
    dataset_test_g = femnist.FAMNISTGC(data_dir, train=False, download=False,
                                       transform=apply_transform, color=False)
    dataset_train_c = femnist.FAMNISTGC(cdata_dir, train=True, download=False,
                                        transform=apply_transform_c, color=True)
    dataset_test_c = femnist.FAMNISTGC(cdata_dir, train=False, download=False,
                                       transform=apply_transform_c, color=True)

    dict_users_train_g, dict_users_test_g = _split_noniid(dataset_train_g, dataset_test_g, args, color=False)
    dict_users_train_c, dict_users_test_c = _split_noniid(dataset_train_c, dataset_test_c, args, color=True)

    return (dataset_train_g, dataset_test_g, dict_users_train_g, dict_users_test_g,
            dataset_train_c, dataset_test_c, dict_users_train_c, dict_users_test_c)


def get_data(args):
    """Load the dataset named by ``args.dataset`` and partition it non-IID across users.

    Args:
        args: namespace providing ``dataset`` ('cifar10', 'cifar100' or
            'femnist'), ``num_users``, ``shard_per_user`` and ``num_classes``.

    Returns:
        'cifar10' / 'cifar100': ``(dataset_train, dataset_test,
        dict_users_train, dict_users_test)``.
        'femnist': an 8-tuple -- those four items for the gray/edge split
        followed by the same four for the color split. NOTE: despite the key
        name, this branch currently loads FashionMNIST from
        ``../Datasets/images/FashionMNIST``.

    Raises:
        SystemExit: if ``args.dataset`` is not recognized.
    """
    if args.dataset == 'cifar10':
        dataset_train = datasets.CIFAR10('data/cifar10', train=True, download=True, transform=trans_cifar10_train)
        dataset_test = datasets.CIFAR10('data/cifar10', train=False, download=True, transform=trans_cifar10_val)
    elif args.dataset == 'cifar100':
        dataset_train = datasets.CIFAR100('data/cifar100', train=True, download=True, transform=trans_cifar100_train)
        dataset_test = datasets.CIFAR100('data/cifar100', train=False, download=True, transform=trans_cifar100_val)
    elif args.dataset == 'femnist':
        return _get_famnist_gray_color(args)
    else:
        # sys.exit rather than the site-module ``exit`` helper, which is meant
        # for the interactive shell and is absent under ``python -S``.
        sys.exit('Error: unrecognized dataset')

    dict_users_train, dict_users_test = _split_noniid(dataset_train, dataset_test, args)
    return dataset_train, dataset_test, dict_users_train, dict_users_test

def _iter_json_dir(data_dir):
    '''Yield the parsed contents of each ``*.json`` file in data_dir, in sorted filename order.'''
    # Sorting makes the load order (and thus client order / duplicate
    # resolution) independent of the filesystem's os.listdir order.
    for name in sorted(os.listdir(data_dir)):
        if name.endswith('.json'):
            with open(os.path.join(data_dir, name), 'r', encoding='utf-8') as inf:
                yield json.load(inf)


def read_data(train_data_dir, test_data_dir):
    '''Parse the JSON data files in the given train and test data directories.

    Assumes:
    - every ``*.json`` file in each directory is an object with a
      'user_data' mapping (and optionally 'hierarchies'); files are
      expected to also carry 'users'
    - the set of train set users is the same as the set of test set users

    Files are read as UTF-8 in sorted filename order, so the result
    (including client order) does not depend on the platform. If several
    files define the same user, the later file's entry wins.

    Return:
        clients: list of client ids (the keys of train_data, in load order)
        groups: list of group ids from the train files' 'hierarchies';
            empty list if none found
        train_data: dictionary mapping client id -> that client's train data
        test_data: dictionary mapping client id -> that client's test data
    '''
    groups = []
    train_data = {}
    for cdata in _iter_json_dir(train_data_dir):
        groups.extend(cdata.get('hierarchies', []))
        train_data.update(cdata['user_data'])

    test_data = {}
    for cdata in _iter_json_dir(test_data_dir):
        test_data.update(cdata['user_data'])

    # Client ids are taken from the train 'user_data' keys, not the 'users' lists.
    clients = list(train_data.keys())

    return clients, groups, train_data, test_data


def get_model(args):
    """Build the CNN backbone and transformer encoder used for training.

    Only ``args.model == 'cnn'`` with a dataset name containing 'femnist' is
    supported; in that case the FashionMNIST CNN (``CNN_FAMNIST``) is built.

    Args:
        args: namespace providing ``model``, ``dataset`` and ``device`` (and
            whatever ``CNN_FAMNIST`` reads from it).

    Returns:
        (net_glob, net_trans): the CNN and a ``TransformerEncoder(128, 512, 8)``,
        both moved to ``args.device``. ``net_glob`` is also printed.

    Raises:
        SystemExit: for any other model/dataset combination.
    """
    if not (args.model == 'cnn' and 'femnist' in args.dataset):
        # sys.exit rather than the interactive-shell ``exit`` helper.
        sys.exit('Error: unrecognized model')

    # FashionMNIST variant of the CNN (CNN_FEMNIST is the FEMNIST alternative).
    # Only one model is constructed, so no throwaway network is allocated on
    # args.device.
    net_glob = CNN_FAMNIST(args=args).to(args.device)
    # NOTE(review): 128/512/8 are positional TransformerEncoder arguments whose
    # meaning is defined in models.Transformer -- confirm before changing.
    net_trans = TransformerEncoder(128, 512, 8).to(args.device)
    print(net_glob)

    return net_glob, net_trans
