""" Attach answer to clusters with the highest similarity. """

import glob
import os
from tqdm import tqdm
import numpy as np
import ujson as json
from tqdm import tqdm
from overrides import overrides
from typing import Text, Dict, Any
from tasker import BaseTask
from ..data_readers import SimpleQAAnswerClusterDataReader
from sentence_transformers import SentenceTransformer


@BaseTask.register("simpleqa-attach-to-clusters")
class SimpleQAAttachToClusters(BaseTask):
    """ """
    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
    ):
        super().__init__(output_dir=output_dir)
        self._input_dir = input_dir

        # sentence_lm
        self._model = SentenceTransformer('paraphrase-MiniLM-L6-v2', device='cuda', local_files_only=False)

    @overrides
    def _run(self):
        """ Run the task """
        iterator = SimpleQAAnswerClusterDataReader(
            [filepath for filepath in glob.glob(os.path.join(self._input_dir, "*.jsonl"))]
        )

        results = []
        
        for item in tqdm(list(iterator)):
            answers = item.filtered_answers
            clusters = item.clusters

            if not clusters:
                results.append({
                    "question": item.question,
                    "answer_type": item.answer_type,
                    "gold_answer": item.gold_answer,
                    # "filtered_answers": item.filtered_answers,
                    "clusters": []
                })
                continue
            
            embeddings = self._model.encode(clusters + answers, convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=False)
            cluster_embeddings = embeddings[:len(clusters)]  # (n_clusters, embedding_dim)
            answer_embeddings = embeddings[len(clusters):]  # (n_answers, embedding_dim)
            
            # find the most similar cluster for each answer
            sim = np.matmul(answer_embeddings, cluster_embeddings.T) # (n_answers, n_clusters)
            cluster_indices = np.argmax(sim, axis=1).tolist() # (n_answers, )
            
            # init cluster dict
            cluster_dicts = [{"cluster": cluster, "answers": []} for cluster in clusters]

            # attach answers to clusters
            for aidx, (answer, cluster_idx) in enumerate(zip(answers, cluster_indices)):
                cluster_dicts[cluster_idx]["answers"].append({"answer": answer, "aidx": aidx})

            results.append({
                "question": item.question,
                "answer_type": item.answer_type,
                "gold_answer": item.gold_answer,
                # "sampled_answers": item.sampled_answers,
                "clusters": cluster_dicts
            })

        return results
    
    @overrides
    def _write(self, outputs):
        """ Write the outputs """
        with open(os.path.join(self._output_dir, "output.jsonl"), "w") as file_:
            for item in outputs:
                file_.write(json.dumps(item) + "\n")