""" Iterative Clustering on the  FActScore dataset """

import glob
import numpy as np
import ujson as json
import os
import copy
from overrides import overrides
from typing import (
    Text, List, Dict, Any
)
from tasker import BaseTask
from rapidfuzz.distance import Levenshtein
from langchain_interface.steps import (
    AnswerShorteningStep,
    AnchoredClusteringStep
)
from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.runnables.config import RunnableConfig
from langchain_openai import ChatOpenAI
from langchain_interface.steps import Step
from langchain_core.globals import set_llm_cache
from langchain_community.cache import SQLiteCache

from ..data_readers import AnswerClusterAttachedDataReader


class IterativeClusterRunner:
    """Grow a selection of cluster names round by round with an LLM chain.

    Every state is a dict with the keys ``selected``, ``candidates``,
    ``num_select``, ``selected_multiplicity`` and ``candidates_multiplicity``.
    In each round the anchored-clustering chain is asked, for every state
    that still has candidates, which candidates to move into ``selected``;
    the names and their multiplicities are moved together so the two
    parallel lists stay index-aligned.
    """

    def __init__(
        self,
        step_size: int,
        max_num_step: int,
        llm: BaseLanguageModel,
        exact_match,
        runnable_config,
    ):
        """Store the configuration and bind the clustering chain to ``llm``.

        Args:
            step_size: Value written to ``num_select`` of every updated state.
            max_num_step: Upper bound on the number of rounds in ``run``.
            llm: Language model the anchored-clustering step is chained to.
            exact_match: If truthy, chain answers are matched to candidates
                by equality; otherwise by smallest Levenshtein distance.
            runnable_config: Passed as ``config`` to the chain's ``batch``.
        """
        self._step_size = step_size
        self._max_num_step = max_num_step
        self._llm = llm
        self._exact_match = exact_match
        self._runnable_config = runnable_config

        self._anchored_clustering_chain = AnchoredClusteringStep().chain_llm(self._llm)

    def run(self, inputs: List[Dict[Text, Any]]) -> List[List[Dict[Text, Any]]]:
        """Run the iterative clustering and return a snapshot per round.

        Args:
            inputs: One state dict per item (see the class docstring).
                The list itself is not modified.

        Returns:
            A list of snapshots. The first entry is a deep copy of
            ``inputs``; each following entry is a deep copy of all states
            after one round. The loop stops after ``max_num_step`` rounds or
            as soon as no state has candidates left.
        """
        # Shallow-copy so that replacing entries below does not rewrite the
        # caller's list.
        states = list(inputs)
        states_memory = [copy.deepcopy(states)]

        for _ in range(self._max_num_step):
            # Only states that still have candidates are sent to the chain.
            active_ids = [idx for idx, state in enumerate(states) if state["candidates"]]
            if not active_ids:
                break

            chain_inputs = [
                {
                    "selected": states[idx]["selected"],
                    "candidates": states[idx]["candidates"],
                    "num_select": states[idx]["num_select"],
                }
                for idx in active_ids
            ]
            responses = self._anchored_clustering_chain.batch(
                chain_inputs, config=self._runnable_config
            )

            for idx, response in zip(active_ids, responses):
                states[idx] = self._advance_state(states[idx], response.increments)

            states_memory.append(copy.deepcopy(states))

        return states_memory

    def _match_increments(self, candidates: List[Any], increments: List[Any]) -> List[int]:
        """Return candidate indices chosen by ``increments``.

        The indices follow the order of ``increments`` and contain no
        duplicates. In exact mode an increment that equals no candidate
        contributes nothing; in fuzzy mode every increment resolves to the
        candidate with the smallest Levenshtein distance.
        """
        picked: List[int] = []
        taken = set()
        for name in increments:
            if self._exact_match:
                hits = [idx for idx, cand in enumerate(candidates) if cand == name]
            else:
                distances = [Levenshtein.distance(name, cand) for cand in candidates]
                hits = [np.argmin(distances).item()]
            for idx in hits:
                if idx not in taken:
                    taken.add(idx)
                    picked.append(idx)
        return picked

    def _advance_state(self, state: Dict[Text, Any], increments) -> Dict[Text, Any]:
        """Build the next state by moving the chosen candidates to ``selected``."""
        candidates = state["candidates"]
        multiplicity = state["candidates_multiplicity"]

        if increments is None:
            # Exceptional case: the chain gave no answer, so move the first
            # candidate to keep the item progressing.
            increments = candidates[:1]

        picked = self._match_increments(candidates, increments)
        taken = set(picked)
        remaining = [idx for idx in range(len(candidates)) if idx not in taken]

        # Names and multiplicities are both derived from the same ordered
        # index list, which keeps the parallel lists aligned and stores the
        # candidate's own spelling rather than the chain's.
        return {
            "selected": state["selected"] + [candidates[idx] for idx in picked],
            "candidates": [candidates[idx] for idx in remaining],
            "num_select": self._step_size,
            "selected_multiplicity": state["selected_multiplicity"]
            + [multiplicity[idx] for idx in picked],
            "candidates_multiplicity": [multiplicity[idx] for idx in remaining],
        }
        

@BaseTask.register("iterative-clustering")
class IterativeClusteringTask(BaseTask):
    """Iteratively order the answer clusters of sampled items.

    For each sampled item the cluster with the highest multiplicity is the
    starting selection; the remaining clusters are moved into the selection
    round by round by an ``IterativeClusterRunner``. The per-round states
    are written out as one JSON line per item.
    """

    __VERSION__ = "0.2.1"

    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
        step_size: int,
        max_num_step: int,
        num_samples: int,
        exact_match: bool,
        cache_path: Text
    ):
        """Configure the LLM, the run parameters and the LLM cache.

        Args:
            input_dir: Directory whose ``*.jsonl`` files are read.
            output_dir: Forwarded to ``BaseTask``.
            step_size: ``num_select`` value used for every state.
            max_num_step: Maximum number of clustering rounds.
            num_samples: Maximum number of items taken per cluster count.
            exact_match: Forwarded to ``IterativeClusterRunner``.
            cache_path: Path handed to ``SQLiteCache`` for the LLM cache.
        """
        super().__init__(output_dir=output_dir)
        self._llm = ChatOpenAI(
            temperature=0,
            top_p=1,
            model="gpt-4o",
            max_tokens=None,
            verbose=True,
        )
        self._runnable_config = RunnableConfig(max_concurrency=8)

        self._input_dir = input_dir
        self._num_samples = num_samples
        self._step_size = step_size
        self._max_num_step = max_num_step
        self._exact_match = exact_match

        self._answer_shortening_chain = AnswerShorteningStep().chain_llm(self._llm)

        set_llm_cache(SQLiteCache(cache_path))

    @overrides
    def _run(self):
        """Sample items, cluster them iteratively and collect the rounds.

        Returns:
            One dict per sampled item with the keys ``question``, ``topic``,
            ``answer_template``, ``filtered_answers`` and ``rounds``. Each
            entry of ``rounds`` holds ``round_idx`` and ``round_result``;
            collection for an item stops at the first round whose state
            equals the previously recorded one.
        """
        # glob order depends on the file system; sort it so the first
        # ``num_samples`` matches are the same on every machine.
        filepaths = sorted(glob.glob(os.path.join(self._input_dir, "*.jsonl")))

        selections = []
        for num_clusters in [10, 20]:
            matching = (
                item
                for item in AnswerClusterAttachedDataReader(list(filepaths))
                if len(item.clusters) == num_clusters
            )
            # zip against a range caps the number of items taken.
            selections.extend(
                selected_item
                for _, selected_item in zip(range(self._num_samples), matching)
            )

        init_states = []
        for item in selections:
            sorted_clusters = sorted(item.clusters, key=lambda x: x.multiplicity, reverse=True)
            head, tail = sorted_clusters[0], sorted_clusters[1:]
            init_states.append({
                "selected": [head.cluster_name],
                "candidates": [cluster.cluster_name for cluster in tail],
                "num_select": self._step_size,
                "selected_multiplicity": [head.multiplicity],
                "candidates_multiplicity": [cluster.multiplicity for cluster in tail]
            })

        runner = IterativeClusterRunner(
            step_size=self._step_size,
            max_num_step=self._max_num_step,
            llm=self._llm,
            exact_match=self._exact_match,
            runnable_config=self._runnable_config
        )

        clustering_result = runner.run(init_states)

        compared_keys = (
            "selected",
            "candidates",
            "num_select",
            "selected_multiplicity",
            "candidates_multiplicity",
        )

        def _is_same(prev_round, current_round):
            """Return True when both states agree on every state key."""
            return all(prev_round[key] == current_round[key] for key in compared_keys)

        # Combine the clustering result with the selections.
        outputs = []
        for idx, item in enumerate(selections):
            rounds = []
            for ridx, round_states in enumerate(clustering_result):
                # Record a round only if it differs from the previous one;
                # the first unchanged round ends the item's history.
                if rounds and _is_same(rounds[-1]['round_result'], round_states[idx]):
                    break
                rounds.append({
                    "round_idx": ridx,
                    "round_result": round_states[idx]
                })

            outputs.append({
                "question": item.question,
                "topic": item.topic,
                "answer_template": item.answer_template,
                "filtered_answers": item.filtered_answers,
                "rounds": rounds
            })

        return outputs

    @overrides
    def _write(self, outputs):
        """Write ``outputs`` as JSON lines to ``iterative_clustering.jsonl``.

        The file is created inside ``self._output_dir`` and overwritten if
        it already exists.
        """
        with open(
            os.path.join(self._output_dir, "iterative_clustering.jsonl"),
            'w', encoding='utf-8'
        ) as file_:
            for item in outputs:
                file_.write(json.dumps(item) + "\n")