"""
"""

from accelerate import Accelerator, PartialState
import os
import torch
import numpy as np
import ujson as json
import click
from tqdm import tqdm
from src.base_clustering import cluster, get_small_entailment_model


# Shared, module-level distributed-run helpers (created when this module is
# imported). `accelerator` prepares the model and supplies the device used in
# main(); `partial_state` splits the data between processes and supplies the
# process index that is embedded in the output file names.
accelerator = Accelerator()
partial_state = PartialState()


@click.command()
@click.option("--run-name", type=click.STRING, help="Name of the run.", required=True)
@click.option("--input-data-path", type=click.STRING, help="Path to the input data.", required=True)
@click.option("--output-dir", type=click.Path(exists=False, dir_okay=True, file_okay=False), help="Path to the output directory.", required=True)
def main(
    run_name,
    input_data_path,
    output_dir
):
    """Cluster the answers of every example in a JSONL file, one shard per process.

    The input file is read as JSON Lines; every record must provide the keys
    "answers" and "example_id". Blank lines are ignored. The examples are split
    between the running processes and each process calls `cluster` on the
    answers of its own examples.

    Every process writes three files into the output directory, named after
    the run name and the zero-padded process index: a "result-*.jsonl" file
    with one clustering result per line, an "idmap-*.json" file mapping each
    example id to its position within this process's shard, and an
    "entailment-*.npy" file holding the vertically stacked entailment matrices
    (an empty array when the process had no examples).
    """
    model, tokenizer = get_small_entailment_model()
    model = accelerator.prepare(model)

    # Skip blank lines (e.g. a stray empty line at the end of the file) so a
    # single empty line does not abort the whole run inside json.loads.
    with open(input_data_path, 'r', encoding='utf-8') as file_:
        records = [json.loads(line) for line in file_ if line.strip()]
    data = [{"sentences": item['answers'], "_id": item["example_id"]} for item in records]

    # Declared up front so the writers below never depend on names that were
    # first bound inside the nested context managers.
    results = []
    entailment_mat_store = []
    _id_to_index = {}

    with partial_state.split_between_processes(data) as splitted_data, torch.no_grad():
        for idx, instance in enumerate(tqdm(splitted_data)):
            result, entailment_mat = cluster(
                instance['sentences'], model, tokenizer, accelerator.device
            )
            results.append(result)
            entailment_mat_store.append(entailment_mat)
            # Position of this example within the current process's shard.
            _id_to_index[instance['_id']] = idx

    os.makedirs(output_dir, exist_ok=True)
    # Every output file of this process shares the same "<run>-<process index>" suffix.
    suffix = f"{run_name}-{partial_state.process_index:02d}"

    with open(os.path.join(output_dir, f"result-{suffix}.jsonl"), 'w', encoding='utf-8') as file_:
        for r in results:
            file_.write(json.dumps(r) + "\n")

    with open(os.path.join(output_dir, f"idmap-{suffix}.json"), 'w', encoding='utf-8') as file_:
        json.dump(_id_to_index, file_)

    # Save the entailment matrices as a single .npy file.
    # np.vstack raises ValueError on an empty sequence, so an empty shard
    # (e.g. an empty input file) gets a placeholder empty array instead; this
    # keeps the set of output files identical for every process.
    # NOTE(review): np.vstack needs all matrices to agree on every dimension
    # except the first, and the id map only stores each example's position in
    # the shard -- confirm `cluster` returns compatible shapes and that
    # consumers can recover per-example rows from the stacked array.
    if entailment_mat_store:
        stacked = np.vstack(entailment_mat_store)
    else:
        stacked = np.empty((0, 0))
    np.save(os.path.join(output_dir, f"entailment-{suffix}.npy"), stacked)


# Script entry point: invoke the click command only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()