""" """

import asyncio
import glob
import copy
import numpy as np
import ujson as json
from overrides import overrides
from typing import (
    Text, List, Dict, Any, Optional
)
from rapidfuzz.distance import Levenshtein
from langchain_interface.steps import (
    AnchoredClusteringStep
)
from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.runnables.config import RunnableConfig
# from langchain_core.globals import set_llm_cache
# from langchain_community.cache import SQLiteCache
from .base_cluster_runner import BaseClusterRunner
from ..data_readers import SimpleQAAttachedDataReader


@BaseClusterRunner.register("llm-iterative-clusterer")
class LLMIterativeClusterRunner(BaseClusterRunner):
    """Grow each item's ``selected`` list round by round with LLM-picked candidates.

    Every round, each item that has not yet reached all of its multiplicity
    targets asks the anchored-clustering chain to pick ``num_select``
    candidates; the picks are mapped back to candidate indices (exactly or
    by Levenshtein distance) and moved from ``candidates`` to ``selected``.
    """

    def __init__(
        self,
        targets: List[int],
        max_step_size: int,
        llm: BaseLanguageModel,
        exact_match: bool,
        runnable_config: RunnableConfig,
    ):
        """Configure the runner.

        Args:
            targets: Total selected-multiplicity targets; passed to the base
                class in ascending order.
            max_step_size: Upper bound on candidates requested per round.
            llm: Language model the clustering chain is bound to.
            exact_match: If true, map LLM outputs to candidates by string
                equality; otherwise by smallest Levenshtein distance.
            runnable_config: Config forwarded to every batched chain call.
        """
        super().__init__(targets=sorted(targets))
        self._max_step_size = max_step_size
        self._llm = llm
        self._exact_match = exact_match
        self._runnable_config = runnable_config
        self._anchored_clustering_chain = AnchoredClusteringStep().chain_llm(self._llm)

    def _propose_step_size(self, history: List[Dict[Text, Any]]) -> Optional[int]:
        """Return how many candidates to request next, or None when done.

        The step is sized so that, at the mean candidate multiplicity, the
        first unmet target would be reached; it is capped by the maximum
        step size and by the number of remaining candidates.
        """
        current = history[-1]
        selected_total = sum(current["selected_multiplicity"])

        # NOTE(review): assumes the base class stores the (ascending) targets
        # it receives as ``self._targets`` -- confirm against BaseClusterRunner.
        for target in self._targets:
            if selected_total >= target:
                continue

            # `target` is the first target that is not yet satisfied.
            candidate_muls = current["candidates_multiplicity"]
            if not candidate_muls:
                return None

            shortfall = target - selected_total
            expected_mul = np.mean(candidate_muls).item()
            if expected_mul > 0:
                # Ceiling division with a small tolerance for float noise.
                proposal = int(shortfall // expected_mul) + (
                    1 if (shortfall % expected_mul) > 1e-9 else 0
                )
            else:
                # Zero mean multiplicity would divide by zero; no finite
                # step is "enough", so take the largest allowed step.
                proposal = len(current["candidates"])

            step = min(proposal, self._max_step_size, len(current["candidates"]))
            # A non-positive step can never make progress and would keep
            # the item active forever, so treat it as finished.
            return step if step > 0 else None

        return None

    @staticmethod
    def _select_exact(
        candidates: List[Text], increments: List[Text], num_select: int
    ) -> List[int]:
        """Pick candidate indices whose text appears verbatim in `increments`.

        Matches are capped at `num_select` (duplicate candidate strings can
        match more than once) and padded with the first unmatched candidates
        when there are too few.
        """
        matched = [idx for idx, cand in enumerate(candidates) if cand in increments]
        matched = matched[:num_select]
        if len(matched) < num_select:
            taken = set(matched)
            padding = [idx for idx in range(len(candidates)) if idx not in taken]
            matched.extend(padding[:num_select - len(matched)])
        return matched

    @staticmethod
    def _select_fuzzy(
        candidates: List[Text], increments: List[Text], num_select: int
    ) -> List[int]:
        """Pick, for each increment, the closest not-yet-taken candidate.

        Closeness is Levenshtein distance. The result is padded with the
        first untaken candidates when there are too few increments, and is
        returned in ascending index order.
        """
        remaining = list(range(len(candidates)))
        chosen: List[int] = []
        for text in increments:
            if not remaining:
                break
            position = int(np.argmin(
                [Levenshtein.distance(text, candidates[idx]) for idx in remaining]
            ))
            # `position` indexes `remaining`; the popped value is the
            # actual candidate index.
            chosen.append(remaining.pop(position))

        print(chosen)
        print([candidates[idx] for idx in chosen])

        if len(chosen) < num_select:
            chosen.extend(remaining[:num_select - len(chosen)])
        return sorted(chosen)

    @overrides
    def run(self, inputs: List[List[Dict[Text, Any]]]) -> List[List[Dict[Text, Any]]]:
        """ Given a set of inputs as described in the
        processed items, run the iterative clustering to
        form configurations.

        Each element of ``inputs`` is used as a dict of the form

        {
            "selected": List[Text],
            "candidates": List[Text],
            "selected_multiplicity": List[int],
            "candidates_multiplicity": List[int],
        }

        Returns one history per input: a list that starts with the input
        dict and gains one dict per round with the same four keys plus
        ``num_select`` (the number of candidates moved in that round).
        """
        states = [[ipt] for ipt in inputs]

        while True:
            targ_indices = []
            prepared_inputs = []
            for tidx, history in enumerate(states):
                step_size = self._propose_step_size(history)
                if step_size is None:
                    continue
                targ_indices.append(tidx)
                prepared_inputs.append({
                    "selected": history[-1]["selected"],
                    "num_select": step_size,
                    "candidates": history[-1]["candidates"],
                })

            if not targ_indices:
                break

            results = asyncio.run(
                self._anchored_clustering_chain.abatch(
                    prepared_inputs, config=self._runnable_config
                )
            )

            for tidx, pipt, result in zip(targ_indices, prepared_inputs, results):
                current = states[tidx][-1]
                candidates = pipt["candidates"]
                num_select = pipt["num_select"]
                candidate_muls = current["candidates_multiplicity"]

                print("-" * 20)
                print(candidates, len(candidates))
                print(result.messages)
                print(result.increments)

                increments = result.increments
                if increments is None:
                    # No usable LLM output: fall back to the first
                    # candidates of *this* item.
                    increments = candidates[:num_select]
                elif len(increments) > num_select:
                    print("Warning: More increments than expected")
                    increments = increments[:num_select]

                if self._exact_match:
                    sindices = self._select_exact(candidates, increments, num_select)
                else:
                    sindices = self._select_fuzzy(candidates, increments, num_select)

                assert len(sindices) == num_select, f"Expected {num_select} selections, but got {len(sindices)}"

                picked = set(sindices)
                leftover = [idx for idx in range(len(candidates)) if idx not in picked]

                print(sindices)
                print(candidates, len(candidates))
                print(candidate_muls, len(candidate_muls))
                print(candidates, len(candidates))
                print([candidates[idx] for idx in leftover])
                print("-" * 20)

                # Move the picked candidates (and their multiplicities)
                # into the selected lists as a new history entry.
                states[tidx].append({
                    "selected": pipt["selected"] + [candidates[idx] for idx in sindices],
                    "candidates": [candidates[idx] for idx in leftover],
                    "num_select": num_select,
                    "selected_multiplicity": current["selected_multiplicity"] + [candidate_muls[idx] for idx in sindices],
                    "candidates_multiplicity": [candidate_muls[idx] for idx in leftover],
                })

        return states