import argparse
import json
import random
from collections import Counter

def load_data(file_path):
    """Load and return the parsed contents of the JSON file at ``file_path``.

    The file is decoded explicitly as UTF-8 (the encoding JSON text is
    required to use) rather than the platform's locale default, so
    non-ASCII text in the data loads identically on every machine.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        return json.load(file)

def save_data(data, file_path):
    """Serialize ``data`` to ``file_path`` as JSON indented by two spaces.

    Any existing file at ``file_path`` is overwritten.
    """
    with open(file_path, mode='w') as out_file:
        json.dump(data, out_file, indent=2)


# Per-mode predicate selecting which samples to keep. 'glue' and 'advglue' partition
# the shared annotation file by each sample's 'method' field; 'advgluepp' keeps every
# sample from its own file (and never reads 'method').
_SAMPLE_FILTERS = {
    'glue': lambda sample: sample['method'] == 'glue',
    'advglue': lambda sample: sample['method'] != 'glue',
    'advgluepp': lambda sample: True,
}


def _parse_args():
    """Parse and validate the command-line arguments.

    Returns:
        argparse.Namespace with ``dataset``, ``seed``, ``max_samples``,
        ``train_split`` and ``holdout_split``. Invalid values terminate the
        program through ``parser.error``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='glue', choices=['glue', 'advglue', 'advgluepp'],
                        help='Dataset mode: glue, advglue or advgluepp')
    parser.add_argument('--seed', type=int, default=None,
                        help='Random seed for reproducible sampling and splits (default: unseeded)')
    parser.add_argument('--max-samples', type=int, default=10000,
                        help='Maximum number of samples kept per task (default: 10000)')
    parser.add_argument('--train-split', type=float, default=0.4,
                        help='Fraction of each task used for the train split (default: 0.4)')
    parser.add_argument('--holdout-split', type=float, default=0.4,
                        help='Fraction of each task used for the holdout split (default: 0.4); '
                             'the test split receives the remainder')
    args = parser.parse_args()

    if args.max_samples < 0:
        # A negative cap would silently turn samples[:max_samples] into "drop the last N".
        parser.error('--max-samples must be non-negative')
    if not (0 <= args.train_split <= 1 and 0 <= args.holdout_split <= 1):
        parser.error('--train-split and --holdout-split must be between 0 and 1')
    # Small tolerance so float rounding in sums like 0.7 + 0.3 is not rejected.
    if args.train_split + args.holdout_split > 1 + 1e-9:
        parser.error('--train-split + --holdout-split must not exceed 1')
    return args


def _split_samples(samples, train_split, test_split):
    """Slice ``samples`` into contiguous ``(train, test, holdout)`` lists.

    The first ``int(n * train_split)`` items form train, items up to
    ``int(n * (train_split + test_split))`` form test, and the rest form
    holdout. The input is not shuffled here; callers shuffle beforehand.
    """
    n = len(samples)
    train_end = int(n * train_split)
    test_end = int(n * (train_split + test_split))
    return samples[:train_end], samples[train_end:test_end], samples[test_end:]


def _label_distribution(samples):
    """Return the fraction of ``samples`` per ``'label'`` value, rounded to 4 decimals.

    Labels appear in order of first occurrence; an empty input yields ``{}``.
    """
    counts = Counter(sample['label'] for sample in samples)
    total = len(samples)
    return {label: round(count / total, 4) for label, count in counts.items()}


def main():
    """Build and save train/test/holdout splits plus summary info for one dataset mode.

    Loads the source annotation file, filters each task for the selected
    mode, keeps at most ``--max-samples`` randomly chosen samples per task,
    splits each task, writes ``data/<dataset>_{train,test,holdout}.json`` and
    ``data/<dataset>_info.json`` (sample counts and train label fractions),
    and prints the same statistics.
    """
    args = _parse_args()
    dataset = args.dataset
    # One dedicated generator drives all randomness; seed=None seeds it from OS entropy,
    # matching the previous unseeded behaviour.
    rng = random.Random(args.seed)

    # GLUE and AdvGLUE samples share one annotation file; AdvGLUE++ has its own.
    filepath = 'data/advglue_plus_plus.json' if dataset == 'advgluepp' else 'data/test_ann.json'
    data = load_data(filepath)

    keep = _SAMPLE_FILTERS[dataset]
    for key in data:
        samples = [sample for sample in data[key] if keep(sample)]
        # Shuffle before capping so the cap keeps a random subset, not the file's first entries.
        rng.shuffle(samples)
        data[key] = samples[:args.max_samples]

    num_samples = {'total': {}, 'train': {}, 'test': {}, 'holdout': {}}

    print('Number of samples in:')
    for key, samples in data.items():
        print(f'\t{key}\t= {len(samples)}')
        num_samples['total'][key] = len(samples)

    # The test split receives whatever train and holdout leave; clamp float noise below zero.
    test_split = max(0.0, 1 - args.train_split - args.holdout_split)
    splits = {'train': {}, 'test': {}, 'holdout': {}}
    for key, samples in data.items():
        # Samples were shuffled above, so contiguous slices are already random splits.
        train, test, holdout = _split_samples(samples, args.train_split, test_split)
        splits['train'][key] = train
        splits['test'][key] = test
        splits['holdout'][key] = holdout

    # Output names follow the --dataset value, e.g. data/advgluepp_train.json.
    for split_name, split_data in splits.items():
        save_data(split_data, f'data/{dataset}_{split_name}.json')

    headers = {'train': '\nTrain split:', 'test': 'Test split:', 'holdout': 'Holdout split:'}
    for split_name, header in headers.items():
        print(header)
        for key, samples in splits[split_name].items():
            print(f'\t{key}\t= {len(samples)}')
            num_samples[split_name][key] = len(samples)

    # Print the fraction of samples per label for each task in the train split as a json string.
    print('\nTrain split label distribution:')
    class_wts = {key: _label_distribution(samples) for key, samples in splits['train'].items()}
    print(json.dumps(class_wts, indent=4))

    save_data({'num_samples': num_samples, 'class_wts': class_wts}, f'data/{dataset}_info.json')


if __name__ == '__main__':
    main()