""" """

import glob
import numpy as np
import ujson as json
import os
import copy
from overrides import overrides
from typing import (
    Text, List, Dict, Any
)
from tasker import BaseTask
from rapidfuzz.distance import Levenshtein
from langchain_interface.steps import (
    AnswerShorteningStep,
    AnchoredClusteringStep
)
from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.runnables.config import RunnableConfig
from langchain_openai import ChatOpenAI
from langchain_interface.steps import Step
from langchain_core.globals import set_llm_cache
from langchain_community.cache import SQLiteCache

from ..data_readers import HyperClusteredDataReader


class IterativeClusterRunner:
    """Grow an ordered selection out of a candidate pool, a few picks per round.

    Every input group becomes a state whose first item is the initial
    selection and whose remaining items are the candidates. Each round asks
    the anchored-clustering chain which candidates to move into the selection,
    and a snapshot of all states is recorded after every round.
    """

    def __init__(
        self,
        step_size: int,
        max_num_step: int,
        llm: BaseLanguageModel,
        exact_match,
        runnable_config,
    ):
        """Store the settings and build the anchored-clustering chain.

        Args:
            step_size: Value passed to the chain as ``num_select`` each round.
            max_num_step: Upper bound on the number of rounds.
            llm: Language model the clustering chain is bound to.
            exact_match: If truthy, chain picks are matched to candidates by
                string equality; otherwise by smallest Levenshtein distance.
            runnable_config: Config forwarded to the chain's ``batch`` call.
        """
        self._step_size = step_size
        self._max_num_step = max_num_step
        self._llm = llm
        self._exact_match = exact_match
        self._runnable_config = runnable_config

        self._anchored_clustering_chain = AnchoredClusteringStep().chain_llm(self._llm)

    def _resolve_picks(self, increments, candidates):
        """Map the chain's picks onto candidate positions, keeping pick order.

        Returns a list of ``(candidate_index, pick_text)`` pairs. Each
        candidate index appears at most once; a pick that matches no
        candidate (exact mode) or that resolves to an already-used candidate
        is dropped, so the caller can keep answers and multiplicities aligned.
        """
        resolved = []
        used = set()
        for text in increments:
            if self._exact_match:
                index = next(
                    (i for i, cand in enumerate(candidates) if cand == text and i not in used),
                    None,
                )
            else:
                # fuzzy matching: closest candidate by edit distance
                # (first one wins on ties, as with np.argmin).
                index = np.argmin([Levenshtein.distance(text, cand) for cand in candidates]).item()

            if index is None or index in used:
                continue

            used.add(index)
            resolved.append((index, text))

        return resolved

    def run(self, inputs: List[List[Dict[Text, Any]]]) -> List[List[Dict[Text, Any]]]:
        """Run the iterative clustering over every input group.

        Args:
            inputs: One list per group; each element is a dict with at least
                the keys ``"shortened_answer"`` and ``"cluster_size"``. The
                first element of a group seeds the selection.

        Returns:
            A list of snapshots, the initial states first and then one per
            executed round. Each snapshot holds one state dict per input
            group, in input order, with the keys ``selected``,
            ``candidates``, ``selected_multiplicity``,
            ``candidates_multiplicity`` and ``num_select``.
        """

        def _convert_to_state(lidct: List[Dict[Text, Any]]) -> Dict[Text, Any]:
            return {
                "selected": [lidct[0]["shortened_answer"]],
                "candidates": [item["shortened_answer"] for item in lidct[1:]],
                "candidates_multiplicity": [item["cluster_size"] for item in lidct[1:]],
                "selected_multiplicity": [lidct[0]["cluster_size"]],
                "num_select": self._step_size
            }

        states = [_convert_to_state(group) for group in inputs]
        # Snapshots are deep-copied so later rounds cannot alter earlier ones.
        states_memory = [copy.deepcopy(states)]

        for _ in range(self._max_num_step):
            # Only states that still have candidates are sent to the chain.
            filtered_state_ids = [idx for idx, state in enumerate(states) if len(state["candidates"]) > 0]
            if not filtered_state_ids:
                break

            # The chain only receives the textual part of each state;
            # multiplicities are tracked on the side.
            filtered_states = [
                {
                    "selected": states[idx]["selected"],
                    "candidates": states[idx]["candidates"],
                    "num_select": states[idx]["num_select"],
                }
                for idx in filtered_state_ids
            ]
            response_on_filtered = self._anchored_clustering_chain.batch(
                filtered_states, config=self._runnable_config
            )

            for fidx, rsf, fstate in zip(filtered_state_ids, response_on_filtered, filtered_states):
                candidates = fstate["candidates"]
                candidates_multiplicity = states[fidx]["candidates_multiplicity"]
                increments = rsf.increments

                if increments is None:
                    # Exceptional case: fall back to taking the first candidate.
                    increments = candidates[:1]

                picks = self._resolve_picks(increments, candidates)
                picked = {index for index, _ in picks}

                # Both "selected" lists are extended from the same ordered
                # picks so that entry i of one always describes entry i of
                # the other.
                states[fidx] = {
                    "selected": fstate["selected"] + [text for _, text in picks],
                    "candidates": [cand for i, cand in enumerate(candidates) if i not in picked],
                    "num_select": self._step_size,
                    "selected_multiplicity": (
                        states[fidx]["selected_multiplicity"]
                        + [candidates_multiplicity[index] for index, _ in picks]
                    ),
                    "candidates_multiplicity": [
                        candidates_multiplicity[i] for i in range(len(candidates)) if i not in picked
                    ],
                }

            states_memory.append(copy.deepcopy(states))

        return states_memory
        

@BaseTask.register("cluster-ordering-task")
class ClusterOrderingTask(BaseTask):
    """Given a set of clusters, order them in a way
    that the cluster is ordered based on similarity.

    The task reads hyper-clustered items from ``*.jsonl`` files, shortens
    every cluster claim with an LLM chain, merges clusters whose shortened
    answers are identical up to case, and then runs
    ``IterativeClusterRunner`` over each remaining group. The per-round
    states are written to ``iterative_clustering.jsonl``.
    """

    __VERSION__ = "0.1.2"

    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
        step_size: int,
        max_num_step: int,
        num_samples: int,
        exact_match: bool,
        cache_path: Text
    ):
        """Configure the LLM, the shortening chain and the LLM cache.

        Args:
            input_dir: Directory searched for ``*.jsonl`` input files.
            output_dir: Forwarded to ``BaseTask``; the result file is
                written under ``self._output_dir``.
            step_size: Forwarded to ``IterativeClusterRunner``.
            max_num_step: Forwarded to ``IterativeClusterRunner``.
            num_samples: Number of eligible items kept (list-slice bound).
            exact_match: Forwarded to ``IterativeClusterRunner``.
            cache_path: Database path handed to ``SQLiteCache``.
        """
        super().__init__(output_dir=output_dir)
        self._llm = ChatOpenAI(
            temperature=0,
            top_p=1,
            model="gpt-4o",
            max_tokens=None,
            verbose=True,
        )
        self._runnable_config = RunnableConfig(max_concurrency=1)

        self._input_dir = input_dir
        self._num_samples = num_samples
        self._step_size = step_size
        self._max_num_step = max_num_step
        self._exact_match = exact_match

        self._answer_shortening_chain = AnswerShorteningStep().chain_llm(self._llm)

        # Installs a process-wide LLM cache backed by the given SQLite file.
        set_llm_cache(SQLiteCache(cache_path))

    @overrides
    def _run(self):
        """Shorten the cluster answers and run the iterative clustering.

        Returns:
            A ``(clustering_result, list_of_shortened)`` tuple:
            ``clustering_result`` is what ``IterativeClusterRunner.run``
            returned, and ``list_of_shortened`` is the list of groups that
            was passed to it.
        """
        # TODO: taking the partial entailment judgment into account.
        # glob.glob yields paths in arbitrary order; sort them so that the
        # first `num_samples` eligible items are the same on every run.
        filepaths = sorted(glob.glob(os.path.join(self._input_dir, "*.jsonl")))
        iterator = HyperClusteredDataReader(filepaths)

        # Keep only items with 6 to 10 clusters, then cap at num_samples.
        eligible_items = [
            item for item in iterator if 6 <= len(item.clusters) <= 10
        ][:self._num_samples]

        # One record per (item, cluster) pair.
        inputs_ = [
            {
                "question": item.question,
                "iter_idx": idx,
                "_id": item._id,
                "cluster_id": cluster._id,
                "answer": cluster.claim,
                "cluster_size": cluster.size
            }
            for idx, item in enumerate(eligible_items)
            for cluster in item.clusters
        ]  # Making this deterministic for task consistency.

        # Step 1: Shorten the answers
        shortening_results = self._answer_shortening_chain.batch(
            [{"question": ipt['question'], "answer": ipt["answer"]} for ipt in inputs_],
            config=self._runnable_config
        )

        # Group the shortened records by the item they came from.
        _id_to_shortened = {}

        for response, ipt in zip(shortening_results, inputs_):
            shortened = {
                "question": ipt['question'],
                "cluster_id": ipt['cluster_id'],
                "answer": ipt['answer'],
                "shortened_answer": response.short_answer,
                "cluster_size": ipt['cluster_size']
            }
            _id_to_shortened.setdefault(ipt['_id'], []).append(shortened)

        # Step 2: incremental_clustering
        # TODO: extend to more general filtering.

        def _shrink(v):
            """ remove literal repetitions """
            # The first record seen for a lower-cased answer is kept and
            # absorbs the cluster sizes of later duplicates.
            visited = {}

            for item in v:
                key = item["shortened_answer"].lower()
                if key not in visited:
                    visited[key] = item
                else:
                    visited[key]["cluster_size"] += item["cluster_size"]

            return list(visited.values())

        list_of_shortened = [_shrink(v) for v in _id_to_shortened.values()]
        # Groups with three or fewer distinct answers are dropped.
        list_of_shortened = [v for v in list_of_shortened if len(v) > 3]

        runner = IterativeClusterRunner(
            step_size=self._step_size,
            max_num_step=self._max_num_step,
            llm=self._llm,
            exact_match=self._exact_match,
            runnable_config=self._runnable_config
        )

        clustering_result = runner.run(list_of_shortened)
        return clustering_result, list_of_shortened

    @overrides
    def _write(self, outputs):
        """Write one JSON line per group to ``iterative_clustering.jsonl``.

        Args:
            outputs: The ``(clustering_result, list_of_shortened)`` tuple
                returned by ``_run``.
        """
        clustering_result, list_of_shortened = outputs

        # Step 3: Save the results
        with open(
            os.path.join(self._output_dir, "iterative_clustering.jsonl"),
            'w', encoding='utf-8'
        ) as file_:
            for lidx, lsrt in enumerate(list_of_shortened):
                # clustering_result is indexed [round][group]; collect this
                # group's state from every round.
                cr = [crr[lidx] for crr in clustering_result]

                file_.write(json.dumps({
                    "question": lsrt[0]["question"],
                    "answers": [srt['answer'] for srt in lsrt],
                    "rounds": [
                        {
                            "round_idx": ridx,
                            "round_result": round_result
                        } for ridx, round_result in enumerate(cr)
                    ]
                }, ensure_ascii=False) + "\n")