""" Use back_ref_id etc. to propagate topic to all related objects. """

import click
from time import time
import os
from typing import List, Dict, Any, Text
import ujson as json
from tasker.caching.task_state_cache import TaskStateCache


def _check_downstream_validity(downstream: List[Dict[Text, Any]]) -> bool:
    return all(["back_ref_id" in item for item in downstream])


def _check_upstream_validity(upstream: List[Dict[Text, Any]]) -> bool:
    return all(["topic" in item for item in upstream])


# Module-level task state cache used by the command below; the cache directory is
# read from the TASK_STATE_CACHE_DIR environment variable and defaults to "./".
__TASK_STATE_CACHE__ = TaskStateCache(cache_dir=os.environ.get("TASK_STATE_CACHE_DIR", "./"))


@click.command()
@click.option('--input-path', type=click.Path(exists=True), help='input-path.', required=True)
@click.option('--output-path', type=click.Path(exists=True), help='output-path.', required=True)
@click.option('--config-path', type=click.Path(exists=True), help='config-path.', required=True)
def main(
    input_path,
    output_path,
    config_path
):
    """Propagate the upstream ``topic`` field onto the downstream records.

    The upstream records are read from INPUT_PATH and the downstream records
    from OUTPUT_PATH, one JSON object per line. OUTPUT_PATH is then rewritten
    in place with a ``topic`` key added to every downstream record.

    Each downstream record's ``back_ref_id`` is used as an index into the
    upstream records when all downstream records have one. Otherwise, if both
    files hold the same number of records, they are matched by position.

    Raises:
        ValueError: if any upstream record lacks ``topic``, or if the
            downstream records can be matched neither by ``back_ref_id``
            nor by position.
    """
    with open(input_path, 'r', encoding='utf-8') as file_:
        upstream = [json.loads(line) for line in file_]

    with open(output_path, 'r', encoding='utf-8') as file_:
        downstream = [json.loads(line) for line in file_]

    # query the task state cache (the stored last-accessed time is not needed:
    # it is replaced with the current time when the cache is refreshed below)
    task_name, _last_accessed, version = __TASK_STATE_CACHE__.query(config_path)

    # Both matching strategies read upstream[...]["topic"], so a missing topic
    # is reported up front as a ValueError instead of surfacing as a bare
    # KeyError from the positional fallback.
    if not _check_upstream_validity(upstream):
        raise ValueError("Every upstream record must contain a 'topic' field.")

    if _check_downstream_validity(downstream):
        # back_ref_id is used directly as an index into the upstream list
        topics = [upstream[item['back_ref_id']]["topic"] for item in downstream]
    elif len(upstream) == len(downstream):
        # if the upstream is the same length as the downstream, then we can
        # assume that the downstream is in the same order as the upstream
        topics = [item["topic"] for item in upstream]
    else:
        raise ValueError("The downstream and upstream are not compatible.")

    # All topics are resolved before the file is reopened for writing, so a
    # lookup failure above never leaves a truncated output file behind.
    with open(output_path, 'w', encoding='utf-8') as file_:
        for item, topic in zip(downstream, topics):
            file_.write(json.dumps({**item, "topic": topic}) + "\n")

    # refresh the task state cache
    if task_name is not None:
        print("Refreshing the cache.")
        __TASK_STATE_CACHE__.refresh_cache(
            config_path=config_path,
            task_name=task_name,
            # change the last_accessed time to the current time (propagation as pseudo-access)
            last_accessed=time(),
            version=version
        )
            
            
# Script entry point: invoke the click command, which parses its options from the command line.
if __name__ == "__main__":
    main()