""" Instead of cluster answer clusters, we cluster cluster representatives. """

import click

# from accelerate import Accelerator, PartialState
import pickle
import os
import torch
import numpy as np
import ujson as json
import click
from tqdm import tqdm
from src.base_clustering import cluster, get_small_entailment_model


# accelerator = Accelerator()
# partial_state = PartialState()


@click.command()
@click.option("--run-name", type=click.STRING, help="Name of the run.", required=True)
@click.option("--input-data-path", type=click.STRING, help="Path to the input data.", required=True)
@click.option("--output-dir", type=click.Path(exists=False, dir_okay=True, file_okay=False), help="Path to the output directory.", required=True)
def main(
    run_name,
    input_data_path,
    output_dir
):
    """Cluster the representative claims of every example in a JSONL file.

    Each input line must be a JSON object with an ``example_id`` and a
    ``clusters`` list whose items carry a ``claim``. The claims of each
    example are passed to ``cluster`` and three files are written to the
    output directory:

    \b
    result-<run-name>.jsonl    first value returned by ``cluster``, one per line
    idmap-<run-name>.json      mapping from example id to its line index
    entailment-<run-name>.pkl  pickled list of the second values returned
    """
    # NOTE(review): a CUDA device is required; there is no CPU fallback.
    device = torch.device("cuda")
    model, tokenizer = get_small_entailment_model()
    model.to(device)

    # Keep only the fields needed downstream: the example id and the claim
    # that represents each pre-computed cluster.
    data = []
    with open(input_data_path, 'r', encoding='utf-8') as file_:
        for line in file_:
            item = json.loads(line)
            data.append({
                "_id": item['example_id'],
                "claims": [entry['claim'] for entry in item['clusters']],
            })

    entailment_mat_store = []
    _id_to_index = {}
    results = []

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        for idx, instance in enumerate(tqdm(data)):
            result, entailment_mat = cluster(instance['claims'], model, tokenizer, device)
            results.append(result)
            entailment_mat_store.append(entailment_mat)
            # Position in `results` / `entailment_mat_store` for this example;
            # a repeated id keeps the index of its last occurrence.
            _id_to_index[instance['_id']] = idx

    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(output_dir, f"result-{run_name}.jsonl"), 'w', encoding='utf-8') as file_:
        for r in results:
            file_.write(json.dumps(r) + "\n")

    with open(os.path.join(output_dir, f"idmap-{run_name}.json"), 'w', encoding='utf-8') as file_:
        json.dump(_id_to_index, file_)

    # The matrices are pickled as a plain list rather than stacked into one
    # array — presumably because their shapes differ per example (TODO confirm).
    # The context manager guarantees the file is flushed and closed.
    with open(os.path.join(output_dir, f"entailment-{run_name}.pkl"), 'wb') as file_:
        pickle.dump(entailment_mat_store, file_)


# Script entry point: invoke the click command only when executed directly,
# not when this module is imported.
if "__main__" == __name__:
    main()