import os
import json
import logging
from pathlib import Path
from typing import Any
from time import time
from collections import defaultdict

from mteb.abstasks.task_metadata import TaskMetadata
from aoeb.abstasks.LocalRetrieval import LocalRetrieval
from mteb.types import HFSubset, ScoresDict
from mteb._evaluators import RetrievalEvaluator
from mteb._evaluators.retrieval_metrics import make_score_dict
from mteb.models import (
    CrossEncoderProtocol,
    EncoderProtocol,
    MTEBModels,
    SearchCrossEncoderWrapper,
    SearchEncoderWrapper,
    SearchProtocol,
)
from mteb.abstasks.retrieval_dataset_loaders import RetrievalSplitData

logger = logging.getLogger(__name__)

PREFIX = os.environ.get("LOCAL_DATA_PREFIX", "local-data")


class SubsetRetrieval(LocalRetrieval):
    metadata = TaskMetadata(
        name="UselessTaskName",
        description=(
            "Instruction retrieval benchmark: queries include an instruction + query, "
            "corpus contains tool documentation passages."
        ),
        reference="https://example.com/toolret",
        type="Retrieval",
        category="t2t",
        modalities=["text"],
        eval_splits=["test"],
        eval_langs={
            "web": ["eng-Latn"],
            "code": ["eng-Latn"],
            "customized": ["eng-Latn"]
        },
        main_score="ndcg_at_10",
        date=("2024-01-01", "2024-12-31"),
        domains=["Programming", "Web"],
        task_subtypes=[],
        license="cc-by-nc-sa-4.0",
        annotations_creators="derived",
        dialect=[],
        sample_creation="found",
        bibtex_citation="""@inproceedings{your_citation_2024,
title={ToolRet: A Benchmark for Tool Retrieval},
author={...},
booktitle={...},
year={2024}
}""",
        dataset={
            "path": "../datasets/mangopy_toolret1",
            "revision": "1.0"
        },
    )

    def _evaluate_subset(
        self,
        model: MTEBModels,
        data_split: RetrievalSplitData,
        encode_kwargs: dict[str, Any],
        hf_split: str,
        hf_subset: str,
        prediction_folder: Path | None = None,
        **kwargs,
    ) -> ScoresDict:
        corpus = data_split["corpus"]
        queries = data_split["queries"]
        relevant_docs = data_split["relevant_docs"]
        
        original_k_values = list(self.k_values)
        extended_k_values = original_k_values + [len(corpus)]
        
        retriever = RetrievalEvaluator(
            corpus=corpus,
            queries=queries,
            task_metadata=self.metadata,
            hf_split=hf_split,
            hf_subset=hf_subset,
            top_ranked=data_split["top_ranked"],
            top_k=extended_k_values[-1], 
            **kwargs,
        )
        
        if isinstance(model, EncoderProtocol) and not isinstance(model, SearchProtocol):
            search_model = SearchEncoderWrapper(model)
        elif isinstance(model, CrossEncoderProtocol):
            search_model = SearchCrossEncoderWrapper(model)
        elif isinstance(model, SearchProtocol):
            search_model = model
        else:
            raise TypeError(
                f"RetrievalEvaluator expects a SearchInterface, Encoder, or CrossEncoder, got {type(model)}"
            )
        
        logger.info(f"Size of corpus: {len(corpus)}")
        logger.info(f"Size of queries: {len(queries)}")
        
        start_time = time()
        results = retriever(search_model, encode_kwargs=encode_kwargs)
        end_time = time()
        logger.info(f"Time taken to retrieve: {end_time - start_time:.2f} seconds")
        
        qid_docids_map = self._load_query_docids_mapping(hf_subset)
        
        filtered_results = self._filter_results_by_mapping(results, qid_docids_map)
        
        if prediction_folder:
            self._save_task_predictions(
                filtered_results,
                model,
                prediction_folder,
                hf_subset=hf_subset,
                hf_split=hf_split,
            )
        
        logger.info("Running retrieval task - Evaluating retrieval scores...")
        (
            all_scores,
            ndcg,
            _map,
            recall,
            precision,
            naucs,
            mrr,
            naucs_mrr,
            cv_recall,
        ) = retriever.evaluate(
            relevant_docs,
            filtered_results,
            original_k_values, 
            ignore_identical_ids=self.ignore_identical_ids,
            skip_first_result=self.skip_first_result,
        )
        
        task_specific_scores = self.task_specific_scores(
            all_scores,
            relevant_docs,
            filtered_results,
            hf_split=hf_split,
            hf_subset=hf_subset,
        )
        
        logger.info("Running retrieval task - Finished.")
        
        return make_score_dict(
            ndcg,
            _map,
            recall,
            precision,
            mrr,
            naucs,
            naucs_mrr,
            cv_recall,
            task_specific_scores,
            self._previous_results_model_meta,
        )
    
    def _load_query_docids_mapping(self, hf_subset: str) -> dict[str, list[str]]:
        qid_docids_map = {}
        # origin_data_file = self.metadata["dataset"].get("origin_data_file")
        origin_data_file = self.origin_data_file
        
        if not origin_data_file:
            logger.warning("No origin_data_file specified in metadata, skipping filtering")
            return qid_docids_map
        
        if origin_data_file.endswith(".jsonl"):
            file_path = os.path.join(
                PREFIX,
                self.metadata.dataset["path"],
                hf_subset,
                origin_data_file
            )
            
            with open(file_path, "r", encoding="utf-8") as f:
                for line in f:
                    data_point = json.loads(line)
                    qid_docids_map[data_point["query_id"]] = data_point["candidate_doc_ids"]
        else:
            file_path = os.path.join(
                PREFIX,
                self.metadata.dataset["path"],
                origin_data_file
            )
            
            with open(file_path, "r", encoding="utf-8") as f:
                origin_data = json.load(f)
                for data_point in origin_data:
                    qid_docids_map[data_point["question_id"]] = data_point["haystack_session_ids"]
        
        logger.info(f"Loaded {len(qid_docids_map)} query-to-docs mappings from {file_path}")
        return qid_docids_map
    
    def _filter_results_by_mapping(
        self,
        results: dict[str, dict[str, float]],
        qid_docids_map: dict[str, list[str]]
    ) -> dict[str, dict[str, float]]:
        if not qid_docids_map:
            return results
        
        filtered_results = {}
        empty_count = 0
        
        for qid, doc_scores in results.items():
            if qid not in qid_docids_map:
                logger.warning(f"Query ID {qid} not found in mapping, keeping all results")
                filtered_results[qid] = doc_scores
                continue
            
            candidate_docids = set(qid_docids_map[qid])
            
            if not candidate_docids:
                logger.warning(f"Query ID {qid} has empty candidate doc list")
                empty_count += 1
            
            filtered_results[qid] = {
                docid: score
                for docid, score in doc_scores.items()
                if docid in candidate_docids
            }
            
            if not filtered_results[qid]:
                logger.warning(f"Query ID {qid} has no matching documents after filtering")
        
        if empty_count > 0:
            logger.warning(f"Found {empty_count} queries with empty candidate lists")
        
        logger.info(f"Filtered results: {len(filtered_results)} queries")
        return filtered_results