from __future__ import annotations

import logging
from collections import Counter, defaultdict
from typing import Any

import numpy as np
import tqdm
import json
import logging
import os
from collections import defaultdict
from pathlib import Path
from time import time
import time as time_module
from typing import Any

from mteb.encoder_interface import Encoder

from ..evaluation.evaluators import (
    kNNClassificationEvaluator,
    kNNClassificationEvaluatorPytorch,
    logRegClassificationEvaluator,
)
from ..load_results.mteb_results import HFSubset, ScoresDict
from .AbsTask import AbsTask, DescriptiveStatistics

import random
import numpy as np
from pyserini.search import SimpleSearcher
from pyserini.index.lucene import LuceneIndexer

# Module-level logger, named after this module, used for progress messages during evaluation.
logger = logging.getLogger(__name__)


class ClassificationDescriptiveStatistics(DescriptiveStatistics):
    """Descriptive statistics for Classification

    Attributes:
      num_samples: number of samples in the dataset.
      average_text_length: Average length of text, i.e. the mean of ``len(text)``
        over all samples.
      unique_labels: Number of unique labels
      labels: Mapping from each label (converted to ``str``) to a dict of the
        form ``{"count": <number of occurrences>}``.
    """

    num_samples: int
    average_text_length: float
    unique_labels: int
    labels: dict[str, dict[str, int]]


class AbsTaskClassification(AbsTask):
    """Abstract class for kNN classification tasks
    The similarity is computed between pairs and the results are ranked.

    self.load_data() must generate a huggingface dataset with a split matching self.metadata_dict["eval_splits"], and assign it to self.dataset. It
    must contain the following columns:
        text: str
        label: int
    """

    def __init__(
        self,
        method: str = "logReg",
        n_experiments: int | None = None,
        samples_per_label: int | None = None,
        k: int = 3,
        index_data_path=None, custom_index_dir=None, random_ic_prompt=False, doc_only=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.method = method

        # Bootstrap parameters
        self.n_experiments: int = (  # type: ignore
            n_experiments
            if n_experiments is not None
            else self.metadata_dict.get("n_experiments", 10)
        )
        self.samples_per_label: int = (  # type: ignore
            samples_per_label
            if samples_per_label is not None
            else self.metadata_dict.get("samples_per_label", 8)
        )

        # kNN parameters
        self.k = k

        self.index_data_path = index_data_path
        self.custom_index_dir = custom_index_dir
        self.random_ic_prompt = random_ic_prompt
        self.doc_only = doc_only

    def _add_main_score(self, scores: dict[HFSubset, ScoresDict]) -> None:
        scores["main_score"] = scores[self.metadata.main_score]

    def evaluate(
        self,
        model,
        eval_split: str = "test",
        train_split: str = "train",
        *,
        encode_kwargs: dict[str, Any] = {},
        **kwargs,
    ) -> dict[HFSubset, ScoresDict]:
        if not self.data_loaded:
            self.load_data()

        scores = {}
        hf_subsets = list(self.dataset) if self.is_multilingual else ["default"]

        for hf_subset in hf_subsets:
            logger.info(
                f"\nTask: {self.metadata.name}, split: {eval_split}, subset: {hf_subset}. Running..."
            )

            if hf_subset not in self.dataset and hf_subset == "default":
                ds = self.dataset
            else:
                ds = self.dataset[hf_subset]
            scores[hf_subset] = self._evaluate_subset(
                model,
                ds,
                eval_split,
                train_split,
                encode_kwargs=encode_kwargs,
                **kwargs,
            )
            self._add_main_score(scores[hf_subset])

        return scores

    def _evaluate_subset(
        self,
        model: Encoder,
        dataset,
        eval_split: str = "test",
        train_split: str = "train",
        encode_kwargs: dict[str, Any] = {},
        **kwargs,
    ) -> ScoresDict:

        if self.n_ic_examples > 0:
            index_dir = os.path.join(os.environ.get("TRANSFORMERS_CACHE", "temp_index_eval"), "temp_index_eval", self.metadata_dict["dataset"]["path"].split("/")[-1])
            dataset = self.encode_queries_with_ic_data(dataset, index_dir)

        train_split = dataset[train_split]
        eval_split = dataset[eval_split]
        params = {"k": self.k}
        params.update(kwargs)
        
        scores = []
        test_cache, idxs = (
            None,
            None,
        )  # we store idxs to make the shuffling reproducible
        for i in range(self.n_experiments):
            logger.info(
                "=" * 10 + f" Experiment {i+1}/{self.n_experiments} " + "=" * 10
            )
            # Bootstrap `self.samples_per_label` samples per label for each split
            X_sampled, y_sampled, idxs = self._undersample_data(
                train_split["text"],  # type: ignore
                train_split["label"],  # type: ignore
                self.samples_per_label,
                idxs,
            )

            if self.method == "kNN":
                evaluator = kNNClassificationEvaluator(
                    X_sampled,
                    y_sampled,
                    eval_split["text"],  # type: ignore
                    eval_split["label"],  # type: ignore
                    task_name=self.metadata.name,
                    encode_kwargs=encode_kwargs,
                    **params,
                )
            elif self.method == "kNN-pytorch":
                evaluator = kNNClassificationEvaluatorPytorch(
                    X_sampled,
                    y_sampled,
                    eval_split["text"],  # type: ignore
                    eval_split["label"],  # type: ignore
                    task_name=self.metadata.name,
                    encode_kwargs=encode_kwargs,
                    **params,
                )
            elif self.method == "logReg":
                evaluator = logRegClassificationEvaluator(
                    X_sampled,
                    y_sampled,
                    eval_split["text"],  # type: ignore
                    eval_split["label"],  # type: ignore
                    task_name=self.metadata.name,
                    encode_kwargs=encode_kwargs,
                    **params,
                )
            else:
                raise ValueError(f"Method {self.method} not supported")

            scores_exp, test_cache = evaluator(model, test_cache=test_cache)
            scores.append(scores_exp)

        avg_scores: dict[str, Any] = {
            k: np.mean([s[k] for s in scores]) for k in scores[0].keys()
        }
        avg_scores["scores_per_experiment"] = scores
        return avg_scores

    def _undersample_data(self, X, y, samples_per_label: int, idxs=None):
        """Undersample data to have samples_per_label samples of each label"""
        X_sampled = []
        y_sampled = []
        if idxs is None:
            idxs = np.arange(len(y))
        np.random.shuffle(idxs)
        label_counter = defaultdict(int)
        for i in idxs:
            if label_counter[y[i]] < samples_per_label:
                X_sampled.append(X[i])
                y_sampled.append(y[i])
                label_counter[y[i]] += 1
        return X_sampled, y_sampled, idxs

    def calculate_metadata_metrics(
        self,
    ) -> dict[
        str,
        ClassificationDescriptiveStatistics
        | dict[str, ClassificationDescriptiveStatistics],
    ]:
        self.load_data()

        # same function from parent class, but added explicitly train to splits

        all_details = {}
        pbar_split = tqdm.tqdm(
            self.metadata.eval_splits + ["train"], desc="Processing Splits..."
        )
        for split in pbar_split:
            pbar_split.set_postfix_str(f"Split: {split}")
            print(f"Processing metadata for split {split}")
            if self.is_multilingual:
                all_details[split] = self._calculate_metrics_from_split(
                    split, compute_overall=True
                )
                all_details[split]["hf_subset_descriptive_stats"] = {}

                pbar_subset = tqdm.tqdm(
                    self.metadata.eval_langs, desc="Processing Languages..."
                )
                for hf_subset in pbar_subset:
                    pbar_subset.set_postfix_str(f"Language: {hf_subset}")
                    print(f"Processing metadata for language {hf_subset}")
                    split_details = self._calculate_metrics_from_split(split, hf_subset)
                    all_details[split][hf_subset] = split_details
            else:
                split_details = self._calculate_metrics_from_split(split)
                all_details[split] = split_details

        return all_details

    def _calculate_metrics_from_split(
        self, split: str, hf_subset: str | None = None, compute_overall: bool = False
    ) -> ClassificationDescriptiveStatistics:
        if hf_subset:
            text = self.dataset[hf_subset][split]["text"]
            label = self.dataset[hf_subset][split]["label"]
        elif compute_overall:
            text = []
            label = []
            for hf_subset in self.metadata.eval_langs:
                text.extend(self.dataset[hf_subset][split]["text"])
                label.extend(self.dataset[hf_subset][split]["label"])
        else:
            text = self.dataset[split]["text"]
            label = self.dataset[split]["label"]

        total_text_len = sum([len(t) for t in text])
        label_count = Counter(label)
        return ClassificationDescriptiveStatistics(
            num_samples=len(text),
            average_text_length=total_text_len / len(text),
            unique_labels=len(label_count),
            labels={
                str(label): {"count": count} for label, count in label_count.items()
            },
        )
    
    def encode_queries_with_ic_data(self, dataset, index_dir, **kwargs):
        """
        Encode queries with in-context data.

        Args:
        queries: list of queries [str]
        ic_data: list of in-context data [list of dicts]
        searcher: SimpleSearcher instance with pre-encoded index
        top_n: number of similar in-context examples to select

        Returns:
        encoded_queries: list of encoded queries [np.ndarray]
        """
        ic_data = self.construct_ic_dataset(dataset)
        
        if self.random_ic_prompt:
            print("Using random/fixed in-context prompts...")
            encoded_queries = self.process_ic_data(dataset, ic_data, searcher=None, top_n=self.n_ic_examples)
        else:
            searcher = self.encode_ic_data(ic_data, index_dir=index_dir)
            encoded_queries = self.process_ic_data(dataset, ic_data, searcher=searcher, top_n=self.n_ic_examples)

        return encoded_queries
    
    def construct_ic_dataset(self, dataset):
        dataset_train = dataset["train"]
        ic_data = {}
        for idx, row in enumerate(dataset_train):
            query = row['text']
            ic_example = {}
            ic_example["id"] = str(idx)
            ic_example["positive_ctxs"] = []
            max_rel_doc = {}
            max_rel_doc["title"] = ''
            max_rel_doc["text"] = str(row["label"])
            ic_example["positive_ctxs"].append(max_rel_doc)
            ic_example["negative_ctxs"] = []

            min_rel_row = random.choice(dataset_train)
            min_rel_doc = {}
            min_rel_doc["title"] = ''
            min_rel_doc["text"] = str(min_rel_row['label'])
            ic_example["negative_ctxs"].append(min_rel_doc)
            ic_example["question"] = "Query: " + query
            ic_data[str(idx)] = ic_example 
        
        return ic_data
    
    def encode_ic_data(self, ic_data, index_dir='temp_index'):
        """
        Encode the in-context data using pyserini and keep the index in memory.

        Args:
        ic_data: list of in-context data [list of dicts]

        Returns:
        searcher: SimpleSearcher instance with in-memory index
        ic_questions: List of questions from the in-context data
        """
        #ic_questions = [item["question"] for item in ic_data]
        
        documents = [{"id": f"doc{idx}", "contents": item["question"]} for idx, item in ic_data.items()]
        if os.path.exists(index_dir):
            searcher = SimpleSearcher(index_dir)
        else:
            os.makedirs(index_dir, exist_ok=True)
            max_wait_time = 60  # Maximum wait time in seconds

            start_time = time()

            # Loop to handle write lock with timeout
            while True:
                try:
                    # Attempt to create an indexer
                    indexer = LuceneIndexer(index_dir=index_dir, threads=1)
                    
                    # Add documents to the indexer
                    indexer.add_batch_dict(documents)
                    
                    # Close the indexer after adding documents
                    indexer.close()
                    
                    break  # Exit loop if successful
                except Exception as e:
                    # Check if the error is due to the write lock
                    if "Lock obtain timed out" in str(e):
                        elapsed_time = time() - start_time
                        if elapsed_time >= max_wait_time:
                            raise TimeoutError(f"Exceeded maximum wait time of {max_wait_time} seconds for lock release.")
                        print("Write lock detected, waiting for release...")
                        time_module.sleep(1)  # Wait for 1 second before retrying
                    else:
                        raise e  # Raise other exceptions

            # Proceed to search after the index is updated
            searcher = SimpleSearcher(index_dir)
        
        return searcher

    def process_ic_data(self, dataset_dict, ic_data, searcher=None, top_n=5, qrels=None):
        """
        Add in-context data to the test split of a DatasetDict based on similarity using BM25.
        Args:
        dataset_dict: HuggingFace DatasetDict containing 'text' field and 'test' split
        ic_data: dict of in-context data [dict of dicts]
        searcher: SimpleSearcher instance with pre-encoded index
        top_n: number of similar in-context examples to select
        qrels: query relevance judgments (optional)
        Returns:
        new_dataset_dict: DatasetDict with modified text field in test split, other splits unchanged
        """
        def process_example(example):
            query = example['text']
            new_query = ""
            prefix = ""
            
            if top_n > 0:
                if searcher is None:
                    keys = list(ic_data.keys())
                    random.shuffle(keys)
                    top_idxs = keys[:top_n]
                else:
                    hits = searcher.search(query, k=top_n+100)
                    hits = [hit for hit in hits if hit.docid.split('doc')[-1]]
                    top_idxs = [hit.docid.split('doc')[-1] for hit in hits]
                    if qrels is not None:
                        relevant_docs = qrels.get(str(example.get('idx', '')), {})
                        relevant_docs = [doc for doc in relevant_docs.keys()]
                        doc_ids = {idx: ic_data[idx]["id"] for idx in top_idxs}
                        top_idxs_filtered = [idx for idx in top_idxs if doc_ids[idx] not in relevant_docs]
                        print(f"Initial: {len(top_idxs)}")
                        print(f"Diff: {len(top_idxs) - len(top_idxs_filtered)}")
                        top_idxs = top_idxs_filtered
                
                top_idxs = top_idxs[:top_n]
                top_idxs = top_idxs[::-1]
                
                for k, idx in enumerate(top_idxs):
                    idx = str(idx)
                    question = ic_data[idx]["question"]
                    positive_ctx = ic_data[idx]["positive_ctxs"][0]["title"] + ". " + ic_data[idx]["positive_ctxs"][0]["text"]
                    if k == 0:
                        new_query += f"{question}\nLabel: {positive_ctx}\n"
                    else:
                        if self.doc_only:
                            new_query += f"Label: {positive_ctx}\n"
                        else:
                            new_query += f"Query: {question}\nLabel: {positive_ctx}\n"
            
            new_query += "Query: " + query
            
            # Create a new dictionary with all original fields
            updated_example = dict(example)
            # Update only the text field
            updated_example['text'] = new_query
            return updated_example

        # Create a copy of the dataset dictionary
        new_dataset_dict = dataset_dict.copy()
        
        new_dataset_dict['test'] = new_dataset_dict['test'].map(process_example)
        
        return new_dataset_dict
