import csv
import logging
from collections import namedtuple

import click
import ujson as json

# Name the logger after the module (the standard convention) rather than the
# file path: a path-based name changes with how the script is invoked, and the
# dots in it are interpreted by `logging` as hierarchy separators.
logger = logging.getLogger(__name__)


def _build_lookup(base_data, key_list):
    """Return ``{key: {item['id']: item[key]}}`` for every key in *key_list*."""
    look_up = {key: {} for key in key_list}
    for item in base_data:
        for key in key_list:
            look_up[key][item['id']] = item[key]
    return look_up


def _read_paths(list_path):
    """Return the stripped, non-blank lines of the text file at *list_path*."""
    with open(list_path) as fd:
        stripped = (line.strip() for line in fd)
        # Skip blank lines (e.g. a trailing empty line) instead of trying open("").
        return [line for line in stripped if line]


def _mean(values):
    """Return the arithmetic mean of *values*, or ``""`` (an empty CSV cell) if empty."""
    if not values:
        return ""
    return sum(values) / len(values)


def run(args):
    """Write a CSV with one row per listed JSON file, averaging looked-up values.

    Args:
        args: object with attributes
            ``keys``    -- comma-separated field names to average;
            ``base_ds`` -- path to a JSON list of records, each carrying an
                           ``'id'`` plus every field named in ``keys``;
            ``input``   -- path to a text file listing one JSON file path per
                           line (blank lines are ignored);
            ``output``  -- path of the CSV file to write.

    Each listed JSON file must hold a list of items with an ``'id'``. For each
    key, the row gets the mean of the base-dataset values for those ids. A file
    with no items yields empty cells instead of raising ZeroDivisionError.

    Raises:
        KeyError: if an item id is missing from the base dataset, or a base
            record lacks one of the requested keys.
    """
    key_list = args.keys.split(",")
    with open(args.base_ds) as fd:
        base_data = json.load(fd)
    look_up = _build_lookup(base_data, key_list)

    # Compute every row before opening the output, so a failure part-way
    # through does not leave a truncated CSV behind.
    rows = []
    for path in _read_paths(args.input):
        with open(path) as fd:
            data = json.load(fd)
        row = [path]
        for key in key_list:
            values = [look_up[key][item['id']] for item in data]
            row.append(_mean(values))
        rows.append(row)

    # csv.writer quotes fields containing commas/quotes (e.g. odd file paths);
    # lineterminator keeps the original "\n" row endings.
    with open(args.output, "w", newline="") as fd:
        writer = csv.writer(fd, lineterminator="\n")
        writer.writerow(["file_path"] + key_list)
        writer.writerows(rows)
        


@click.command()
@click.option('-i', '--input', required=True)
@click.option('-k', '--keys', default="input_length,output_length,understandability,naturalness,coherence,reward,first_round_mtld,knn_6")
@click.option('-b', '--base_ds', required=True)
@click.option('-o', '--output', required=True)
def main(**kwargs):
    """Command-line entry point.

    Click passes the parsed options as keyword arguments. They are packed
    into an ``Arg`` namedtuple so ``run`` can read them as attributes
    (``args.input``, ``args.keys``, ...). The raw options are logged first.
    """
    options = namedtuple('Arg', list(kwargs))(**kwargs)
    logger.info(kwargs)
    run(options)


if __name__ == '__main__':
    # Append formatted records to "<this script>.log". The extra root-level
    # StreamHandler echoes records to stderr; it gets no formatter, so the
    # console shows only the bare message.
    logging.basicConfig(
        filename=__file__ + '.log',
        filemode='a',
        format='[%(levelname)-5.5s][%(asctime)s][%(filename)s %(lineno)d]: %(message)s',
        datefmt='%d-%m-%Y %H:%M:%S',
        level=logging.INFO,
    )
    logging.getLogger().addHandler(logging.StreamHandler())
    main()