import argparse
from data import get_data_loader
from model import get_model


def get_default_data_stats():
    """Return a hard-coded set of dataset statistics.

    The dict has exactly the keys that ``get_configs`` reads:
    ``max_command_length``, ``max_world_length`` and ``grid_size`` (ints), plus
    ``word_index``, ``action_index`` and ``manner_index`` (vocabularies mapping
    a token to its integer index). It presumably mirrors what the data
    loader's ``get_data_stats()`` returns for one split -- confirm before
    using it in place of real loader statistics.

    A new dict is built on every call, so callers may mutate the result freely.
    """
    word_index = {'walk': 0, 'to': 1, 'a': 2, 'yellow': 3,
                  'small': 4, 'cylinder': 5, 'hesitantly': 6,
                  'while spinning': 7, 'while zigzagging': 8,
                  'circle': 9, 'big': 10, 'green': 11,
                  'square': 12, 'red': 13, 'blue': 14,
                  'push': 15, 'pull': 16, 'cautiously': 17,
                  '<PAD>': 18}
    action_index = {'walk': 0, 'end': 1, 'push': 2, 'pull': 3}
    # '' stands for "no manner adverb". NOTE(review): 'while zigzagging' is in
    # word_index but has no manner index here -- confirm this is intended.
    manner_index = {'hesitantly': 0, 'while spinning': 1, '': 2,
                    'cautiously': 3}
    return {'max_command_length': 7, 'max_world_length': 12,
            'grid_size': 6,
            'word_index': word_index,
            'action_index': action_index,
            'manner_index': manner_index}


# Backward-compatible alias for the original (misspelled) public name.
get_defalt_data_stats = get_default_data_stats


def get_configs(data_stats):
    """Derive the model configuration dict from dataset statistics.

    Copies the sequence-length limits and grid size straight from
    ``data_stats``. It then adds vocabulary sizes (the lengths of the
    word/action/manner index dicts) and a few fixed dimensions.

    Args:
        data_stats: mapping with at least ``max_command_length``,
            ``max_world_length``, ``grid_size``, ``word_index``,
            ``action_index`` and ``manner_index``.

    Returns:
        A new dict of configuration values.
    """
    configs = {}
    for key in ('max_command_length', 'max_world_length', 'grid_size'):
        configs[key] = data_stats[key]
    configs.update(
        command_voc=len(data_stats['word_index']),
        n_directions=4,
        n_actions=len(data_stats['action_index']),
        n_manners=len(data_stats['manner_index']),
        position_vec_size=2,
        attribute_vec_size=11,
    )
    return configs


def experiment(args, configs, dl):
    """Build, initialize and train one model on the given data loader.

    Args:
        args: parsed command-line namespace; ``args.model_name`` selects the
            model and the whole namespace is forwarded to ``initialize``.
        configs: configuration dict forwarded to ``initialize``.
        dl: data loader passed to the model's ``training`` method.
    """
    name = args.model_name
    net = get_model(name)
    net.initialize(args, configs)
    net.training(dl)


def main(args):
    """Load the selected data split and run the requested number of experiments.

    Args:
        args: parsed command-line namespace. ``args.length`` chooses the data
            split and ``args.repeat_experiments`` sets how many runs to do.
            The whole namespace is also forwarded to the data loader and to
            each experiment.
    """
    # --length selects the target-length split; otherwise the compositional one.
    input_file = ('processed_data/target_length.txt' if args.length
                  else 'processed_data/compositional.txt')

    dl = get_data_loader('rotation', args)
    dl.load_saved(input_file)

    data_stats = dl.get_data_stats()
    print(data_stats)
    configs = get_configs(data_stats)

    # The same loader and configs are shared by every run; experiment() calls
    # get_model anew each time.
    for i in range(args.repeat_experiments):
        print('Running experiment', i)
        experiment(args, configs, dl)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--repeat_experiments', type=int, default=1,
                        help='number of repeating experiments.')
    parser.add_argument('--max_gradient_norm', type=float, default=1.0,
                        help='max gradient norm')
    parser.add_argument('--model_name', type=str, default='proposed',
                        help='model name')
    parser.add_argument('--n_queries', type=int, default=3,
                        help='number of queries for grounding network.')
    parser.add_argument('--remove_entreg', action='store_true', default=False,
                        help='Linear max.')
    parser.add_argument('--length', action='store_true', default=False,
                        help='Use target length split.')
    parser.add_argument('--k_shot', type=int, default=1,
                        help='k-shot.')
    args = parser.parse_args()
    main(args)
