""" Score a claim using FActScore Agg. """

from tasker import BaseTask
import numpy as np
import os
from typing import Text, Union, List
try:
    import ujson as json
except ImportError:
    import json
from overrides import overrides
from ..factual_scorer import LLMSupportScorer
from ..retriever import WikiDocRetriever
from ..utils.instances import ScorerInstance


@BaseTask.register("factscore-claim")
class FactScoreClaimTask(BaseTask):
    """Score pre-extracted claims and aggregate them into one score per response.

    Each input file is read as JSON Lines. Every record is expected to carry
    the keys ``topic``, ``response`` and ``all_claims`` (a list of claim
    strings). Every claim is wrapped in a ``ScorerInstance`` and passed to an
    ``LLMSupportScorer``; the per-claim results are then averaged per record,
    with a penalty applied to records that have fewer than
    ``_LENGTH_PENALTY_GAMMA`` claims.

    Records whose ``all_claims`` is empty are skipped and do not appear in
    the output.
    """

    __VERSION__ = "0.0.1"

    # Records with fewer claims than this get their mean score multiplied by
    # exp(1 - gamma / num_claims), which is < 1 and shrinks with fewer claims.
    _LENGTH_PENALTY_GAMMA = 10

    def __init__(
        self,
        input_file_paths: Union[Text, List[Text]],
        output_dir: Text,
        model_name: Text = "meta-llama/Meta-Llama-3-8B-Instruct",
        base_url: Text = "http://localhost:22659/v1",
        api_key: Text = "token-abc123",
        db_path: Text = "db/enwiki-20230401.db",
        cache_path: Text = "db/.cache/bios.cache",
        embed_cache_path: Text = "db/.cache/bios.embed.cache",
        retriever_batch_size: int = 128
    ):
        """Build the task together with its scorer and retriever.

        Args:
            input_file_paths: One JSONL path or a list of JSONL paths.
            output_dir: Directory handed to ``BaseTask``; ``factscore.json``
                is written there.
            model_name: Forwarded to ``LLMSupportScorer`` as ``model_name``.
            base_url: Forwarded to ``LLMSupportScorer`` as ``base_url``.
            api_key: Forwarded to ``LLMSupportScorer`` as ``api_key``.
            db_path: Forwarded to ``WikiDocRetriever`` as ``db_path``.
            cache_path: Forwarded to ``WikiDocRetriever`` as ``cache_path``.
            embed_cache_path: Forwarded to ``WikiDocRetriever`` as
                ``embed_cache_path``.
            retriever_batch_size: Forwarded to ``WikiDocRetriever`` as
                ``batch_size``.

        All optional arguments default to the values that were previously
        hard-coded, so existing two-argument construction is unchanged.
        """
        super().__init__(output_dir=output_dir)

        # Normalise to a list so `_run` can always iterate over paths.
        if isinstance(input_file_paths, str):
            input_file_paths = [input_file_paths]
        self._input_file_paths = input_file_paths

        self._scorer = LLMSupportScorer(
            model_name=model_name,
            retriever=WikiDocRetriever(
                db_path=db_path,
                cache_path=cache_path,
                embed_cache_path=embed_cache_path,
                batch_size=retriever_batch_size
            ),
            base_url=base_url,
            api_key=api_key
        )

    def _length_penalized_mean(self, scores) -> float:
        """Return the mean of ``scores``, penalised when there are few of them.

        ``scores`` must be non-empty. With ``n = len(scores)`` and
        ``gamma = _LENGTH_PENALTY_GAMMA``, the result is ``mean(scores)`` when
        ``n >= gamma`` and ``exp(1 - gamma / n) * mean(scores)`` otherwise.
        The value is returned as a plain Python float so it serialises cleanly.
        """
        gamma = self._LENGTH_PENALTY_GAMMA
        num_scores = len(scores)
        mean_score = np.mean(scores)
        if num_scores < gamma:
            mean_score = np.exp(1 - gamma / num_scores) * mean_score
        return float(mean_score)

    @overrides
    def _run(self):
        """Score every claim of every input record.

        Returns:
            A list with one dict per record that has at least one claim, in
            input order. Each dict holds ``topic``, ``fs`` (the
            length-penalised mean of the claim scores) and ``claims`` (a list
            of ``{"claim", "score"}`` dicts, where ``score`` is the ``parsed``
            field of the scorer's raw result for that claim).

        Raises:
            ValueError: If the scorer does not return exactly one result per
                submitted claim, since scores could then not be attributed to
                the right record.
        """
        data = []
        for input_file_path in self._input_file_paths:
            with open(input_file_path, 'r', encoding='utf-8') as file_:
                # Skip blank lines (e.g. a trailing newline at end of file),
                # which are not valid JSON documents.
                data.extend(json.loads(line) for line in file_ if line.strip())

        # (index into `data`, number of claims) for every record that
        # contributed instances; used to cut the flat result list back up.
        claim_counts = []
        instances = []

        for didx, dp in enumerate(data):
            topic = dp['topic']
            source_text = dp['response']
            claims = dp['all_claims']

            # Records without claims contribute nothing to score and are
            # left out of the output entirely.
            if not claims:
                continue

            claim_counts.append((didx, len(claims)))
            instances.extend(
                ScorerInstance(
                    text=claim,
                    source_text=source_text,
                    topic=topic
                )
                for claim in claims
            )

        # Nothing to score: avoid calling the scorer at all.
        if not instances:
            return []

        scoring = self._scorer(instance=instances, return_raw=True)
        if len(scoring) != len(instances):
            raise ValueError(
                f"Scorer returned {len(scoring)} results "
                f"for {len(instances)} claims."
            )

        scored = []
        offset = 0  # start of the current record's slice within `scoring`

        for didx, count in claim_counts:
            claim_scores = [s['parsed'] for s in scoring[offset:offset + count]]
            offset += count

            scored.append({
                "topic": data[didx]['topic'],
                "fs": self._length_penalized_mean(claim_scores),
                "claims": [
                    {
                        "claim": claim,
                        "score": claim_score
                    }
                    for claim, claim_score in zip(data[didx]['all_claims'], claim_scores)
                ]
            })

        return scored

    @overrides
    def _write(self, outputs):
        """Write the per-record results and their overall mean as JSON.

        The file ``factscore.json`` is created inside the task's output
        directory with the keys ``total_score`` (mean of all ``fs`` values) and
        ``instances`` (``outputs`` unchanged). When ``outputs`` is empty,
        ``total_score`` is written as ``null`` rather than NaN, which is not
        valid JSON.
        """
        total_score = float(np.mean([s['fs'] for s in outputs])) if outputs else None

        # Make sure the destination exists so scored results are not lost.
        os.makedirs(self._output_dir, exist_ok=True)

        with open(os.path.join(self._output_dir, 'factscore.json'), 'w', encoding='utf-8') as file_:
            json.dump({
                "total_score": total_score,
                "instances": outputs
            }, file_, ensure_ascii=False, indent=2)