"""Tasker task that runs LLM-based iterative clustering over SimpleQA answer clusters."""

import glob
import numpy as np
import ujson as json
import os
import copy
from overrides import overrides
from typing import (
    Text, List, Dict, Any, Optional
)
from tasker import BaseTask
from rapidfuzz.distance import Levenshtein
from langchain_interface.steps import (
    AnchoredClusteringStep
)
from ..cluster_runner import LLMIterativeClusterRunner
from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.runnables.config import RunnableConfig
# from langchain_openai import ChatOpenAI
from langchain_interface.models import ChatOpenAIWithBatchAPI
from langchain_interface.steps import Step
from langchain_core.globals import set_llm_cache
from langchain_community.cache import SQLiteCache
from ..data_readers import SimpleQAAttachedDataReader


@BaseTask.register("simpleqa-iterative-clustering")
class SimpleQAIterativeClusteringTask(BaseTask):
    """Iteratively group the answer clusters of each SimpleQA item with an LLM.

    For every item read from ``input_dir`` that has at least one cluster, the
    clusters are ranked by multiplicity (largest first). The top cluster seeds
    the "selected" list and the rest become "candidates".
    ``LLMIterativeClusterRunner`` then produces a sequence of rounds per item.
    If the runner's last round still has candidates, one extra round is
    appended that moves all of them into "selected".
    """

    __VERSION__ = "0.2.5"

    def __init__(
        self,
        input_dir: Text,
        output_dir: Text,
        max_step_size: int,
        threshold_multiplicities: List[int],
        exact_match: bool,
        cache_path: Text,
        model_name: Text = "gpt-4o",
        max_concurrency: int = 128,
    ):
        """
        Args:
            input_dir: Directory whose ``*.jsonl`` files are read with
                ``SimpleQAAttachedDataReader``.
            output_dir: Passed to ``BaseTask``; ``_write`` creates
                ``iterative_clustering.jsonl`` inside it.
            max_step_size: Forwarded to ``LLMIterativeClusterRunner``.
            threshold_multiplicities: Forwarded to the runner as ``targets``.
            exact_match: Forwarded to the runner.
            cache_path: Path of the SQLite database used as LangChain's LLM
                cache.
            model_name: Chat model to query (previously hard-coded to
                ``"gpt-4o"``).
            max_concurrency: ``max_concurrency`` of the ``RunnableConfig``
                handed to the runner (previously hard-coded to 128).
        """
        super().__init__(output_dir=output_dir)

        self._llm = ChatOpenAIWithBatchAPI(
            temperature=0,
            top_p=1,
            model=model_name,
            max_tokens=None,
            verbose=True,
        )
        self._runnable_config = RunnableConfig(max_concurrency=max_concurrency)
        self._max_step_size = max_step_size

        self._input_dir = input_dir
        self._exact_match = exact_match
        self._threshold_multiplicities = threshold_multiplicities

        # NOTE: `langchain_core.globals.set_llm_cache` installs the cache
        # globally, so it is shared beyond this task instance.
        set_llm_cache(SQLiteCache(cache_path))

    @staticmethod
    def _build_init_state(clusters) -> Dict[Text, Any]:
        """Build the runner's starting state from one item's non-empty clusters.

        The highest-multiplicity cluster is selected. All other clusters
        become candidates, kept in descending multiplicity order.
        """
        ranked = sorted(clusters, key=lambda cluster: cluster.multiplicity, reverse=True)
        head, rest = ranked[0], ranked[1:]  # slicing yields [] for one cluster
        return {
            # TODO: also include answer type and question
            "selected": [head.cluster_name],
            "candidates": [cluster.cluster_name for cluster in rest],
            "num_select": 1,
            "selected_multiplicity": [head.multiplicity],
            "candidates_multiplicity": [cluster.multiplicity for cluster in rest],
        }

    @staticmethod
    def _closing_rounds(last_round: Dict[Text, Any]) -> List[Dict[Text, Any]]:
        """Return a final round that selects every leftover candidate, or ``[]``.

        If the runner stopped while candidates remain, all of them are moved
        into ``selected`` in one extra round, so the item ends with none left.
        """
        leftover = last_round["candidates"]
        if not leftover:
            return []
        return [{
            "selected": last_round["selected"] + leftover,
            "candidates": [],
            "num_select": len(leftover),
            "selected_multiplicity": (
                last_round["selected_multiplicity"]
                + last_round["candidates_multiplicity"]
            ),
            "candidates_multiplicity": [],
        }]

    @overrides
    def _run(self):
        """Run iterative clustering for every item that has at least one cluster.

        Returns:
            One dict per processed item with keys ``question``,
            ``answer_type``, ``gold_answer`` and ``rounds``. ``rounds`` holds
            the runner's rounds for that item, plus a closing round if
            candidates were left over.

        Raises:
            ValueError: if the runner does not return exactly one result per
                item.
        """
        # glob's result order depends on the filesystem; sort it so the item
        # order (and hence the output order) is reproducible across runs.
        filepaths = sorted(glob.glob(os.path.join(self._input_dir, "*.jsonl")))

        # Items without clusters cannot seed a selection, so skip them.
        selections = [
            item for item in SimpleQAAttachedDataReader(filepaths) if item.clusters
        ]

        init_states = [self._build_init_state(item.clusters) for item in selections]

        runner = LLMIterativeClusterRunner(
            targets=self._threshold_multiplicities,
            max_step_size=self._max_step_size,
            llm=self._llm,
            exact_match=self._exact_match,
            runnable_config=self._runnable_config
        )

        clustering_result = runner.run(init_states)

        # Results are paired with items positionally; a length mismatch
        # would silently misattribute rounds to the wrong question.
        if len(clustering_result) != len(selections):
            raise ValueError(
                f"Runner returned {len(clustering_result)} results for "
                f"{len(selections)} items; expected exactly one per item."
            )

        return [
            {
                "question": item.question,
                "answer_type": item.answer_type,
                "gold_answer": item.gold_answer,
                "rounds": rounds + self._closing_rounds(rounds[-1]),
            }
            for item, rounds in zip(selections, clustering_result)
        ]

    @overrides
    def _write(self, outputs):
        """Write ``outputs`` as JSON Lines to ``<output_dir>/iterative_clustering.jsonl``."""
        with open(
            os.path.join(self._output_dir, "iterative_clustering.jsonl"),
            "w", encoding="utf-8",
        ) as file_:
            file_.writelines(json.dumps(item) + "\n" for item in outputs)