

from jsonargparse import ArgumentParser
import numpy as np
import torch
from tqdm import tqdm

from utils import get_model, get_image_data, save_client_data, get_seeds, set_seeds, CustomDataset

def main(args):
    """Create and save federated client splits for every requested image dataset.

    For each dataset name the data is split into ``num_clients +
    num_test_clients`` clients via ``get_image_data``; the client ids are then
    shuffled and divided into a training group and a test group, each of which
    is written with ``save_client_data``. When ``use_pretrained_features`` is
    set, the raw client inputs are first replaced by features computed with the
    model returned by ``get_model``.

    Args:
        args: Parsed command-line namespace. Reads ``init_seed``,
            ``image_datasets``, ``use_pretrained_features``, ``num_clients``,
            ``num_test_clients`` and ``dataset_folder``; additionally
            ``pretrained_model_name``, ``data_dims`` and ``add_ones_to_data``
            when pretrained features are requested.

    Raises:
        ValueError: If a dataset name is not supported, or if
            ``'cifar10-pretrained'`` is requested without
            ``use_pretrained_features``.
    """
    set_seeds(args.init_seed)

    for dataset_name in args.image_datasets:
        print(f'\nProcessing dataset: {dataset_name}')
        if dataset_name == 'cifar10-pretrained':
            dataset_name = 'cifar10'
            # Raise explicitly instead of asserting: assert statements are
            # removed under ``python -O``, which would let this run silently
            # produce raw (non-pretrained) cifar10 splits.
            if not args.use_pretrained_features:
                raise ValueError('Pretrained features must be used with cifar10-pretrained dataset')

        seed = get_seeds(n_rng=1, initial_seed=args.init_seed)[0]

        if dataset_name in ('cifar10', 'fashion_mnist'):
            client_data, N_is, data_props = get_image_data(
                num_clients=args.num_clients + args.num_test_clients,
                dataset_name=dataset_name,
                data_folder=args.dataset_folder,
                random_seed=seed,
            )
        elif dataset_name == 'income':
            raise ValueError('Income data has inherent splits!')
        else:
            raise ValueError(f'Unknown dataset name: {dataset_name}')

        # Randomly assign clients to the train and test groups; this draws
        # from numpy's global RNG state.
        client_ids = np.arange(len(client_data), dtype=int)
        np.random.shuffle(client_ids)
        train_client_ids = client_ids[:args.num_clients]
        test_client_ids = client_ids[args.num_clients:]

        if args.use_pretrained_features:
            model = get_model(
                model_name=args.pretrained_model_name,
                data_dims=args.data_dims,
                add_ones_to_data=args.add_ones_to_data,
            )
            client_data = create_pretrained_features(client_data, model)

        # Pretrained-feature splits go to their own folder so they never
        # overwrite the raw-image splits of the same dataset.
        split_tag = 'iid-pretrained-clients' if args.use_pretrained_features else 'iid-clients'
        out_folder = (
            f"{args.dataset_folder}/{dataset_name}/federated/"
            f"{split_tag}{args.num_clients}-testclients{args.num_test_clients}"
        )

        # save clients
        save_client_data(
            [client_data[i] for i in train_client_ids],
            folder_name=out_folder,
            data_props=[data_props[i] for i in train_client_ids],
            N_is=[N_is[i] for i in train_client_ids],
        )
        # save test clients
        save_client_data(
            [client_data[i] for i in test_client_ids],
            folder_name=out_folder + '-testset',
            data_props=[data_props[i] for i in test_client_ids],
            N_is=[N_is[i] for i in test_client_ids],
        )
        print(f'Done processing dataset: {dataset_name}.')

def create_pretrained_features(client_data, model):
    """Replace each client's inputs with the outputs of ``model``.

    Args:
        client_data: Iterable of per-client dicts with keys ``'x'`` (inputs,
            converted to float32) and ``'y'`` (labels, converted to int64 for
            the data loader).
        model: Torch module used as the feature extractor. It is moved to the
            GPU when one is available and evaluated in eval mode; its previous
            train/eval mode is restored before returning.

    Returns:
        A list with one dict per client, where ``'x'`` is a numpy array of the
        concatenated model outputs (in the original sample order) and ``'y'``
        is the client's original ``data['y']`` object.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu' # gpu is only used for extracting features from pretrained model
    model = model.to(device)

    # Feature extraction must run in eval mode: in train mode, layers such as
    # BatchNorm/Dropout make the features depend on batch composition and
    # BatchNorm running statistics would be updated as a side effect.
    was_training = model.training
    model.eval()

    client_features = []
    try:
        with torch.no_grad():
            # Wrap the data itself (not the enumerate object) so tqdm can
            # report the total number of clients when it is known.
            for data in tqdm(client_data):
                inputs = torch.from_numpy(np.array(data['x'], dtype=np.float32))
                labels = torch.from_numpy(np.array(data['y'], dtype=np.int64))
                dataset = CustomDataset(x=inputs, y=labels, data_transforms=None)
                # shuffle=False keeps features aligned with the original labels.
                loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=False, num_workers=1)
                # Labels are not needed for the forward pass, so only the
                # inputs are moved to the device.
                batch_features = [model(x.to(device)).cpu().numpy() for x, _ in loader]
                client_features.append({'x': np.concatenate(batch_features, axis=0), 'y': data['y']})
    finally:
        # Leave the caller's model in the mode it was handed over in.
        model.train(was_training)

    print('All client data transformed')
    return client_features
    
if __name__ == '__main__':
    # Command-line entry point: define the options, parse them and run main().
    parser = ArgumentParser(description="parse args")
    parser.add_argument('--pretrained_model_name', default='resnext29', help="Model when creating splits with pretrained model.")
    parser.add_argument('--data_dims', default=(3,32,32), type=tuple, help="Number of data dimensions, e.g. (3,32,32) for CIfAR-10. Used only with pretrained model.")
    parser.add_argument('--add_ones_to_data', default=False, action='store_true', help="Add ones to data dims, only used with pretrained model.")
    parser.add_argument('--use_pretrained_features', default=False, action='store_true', help="Use pretrained model to extract features from image data sets.")
    parser.add_argument('--image_datasets', default=["fashion_mnist"], help="List of image data sets to generate. Possible data sets 'fashion_mnist', 'cifar10', 'cifar10-pretrained', 'income'.")
    parser.add_argument('--dataset_folder', default='data', type=str, help='Main data folder, individual data sets will create subfolders under the main folder.')
    parser.add_argument('--num_clients', default=10, type=int, help='Number of clients for training/val, should be even.')
    parser.add_argument('--num_test_clients', default=0, type=int, help='Number of clients for testing. Should be 0.')
    parser.add_argument('--init_seed', default=42, type=int, help='Random seed for data splitting.')
    args = parser.parse_args()
    # A single dataset given as a plain string is wrapped into a list so that
    # main() can always iterate over dataset names.
    if isinstance(args.image_datasets, str):
        args.image_datasets = [args.image_datasets]
    main(args)