import os
import numpy as np
import torch
from typing import Tuple, List, Optional, Dict, Any
from loguru import logger
from pathlib import Path
class DataProcessor:
    """Prepare paired embeddings from two models for one dataset.

    The full pipeline (``process_all``) runs these steps in order:

    1. Loads the dataset through an evaluator.
    2. Loads cached embeddings for ``model_name_1`` and ``model_name_2``, or
       generates them if the cache misses.
    3. Extracts query/positive/negative indices.
    4. Drops masked queries from the query embeddings.
    5. Zero-pads every embedding matrix to a common width.
    6. Splits the corpus into ``d0``/``d1``/``d2`` index sets.

    The actual work is delegated to callables supplied by the caller. This
    class only orchestrates them and keeps intermediate results on ``self``.
    """

    def __init__(
        self,
        model_name_1: str,
        model_name_2: str,
        dataset_name: str,
        cache_dir: str = "./cached_output/embeddings/",
        device: str = "cuda",
        batch_size: int = 8,
        d0_ratio: float = 1/3,
        reference_creation_method: str = "random",
        split_strategy: str = "random"
    ):
        """Store configuration; no I/O happens here.

        Args:
            model_name_1: First model name. Used verbatim in its cache keys.
            model_name_2: Second model name. Used verbatim in its cache keys.
            dataset_name: Dataset identifier. Used in logs and cache keys.
            cache_dir: Passed to the cache loader and to the embedding
                generator factory.
            device: Forwarded to ``get_embedding_generator_func``.
            batch_size: Forwarded to ``get_embedding_generator_func``.
            d0_ratio: Forwarded to the split function as the D0 ratio.
            reference_creation_method: ``"ours"`` (hierarchical k-means
                sampling) or ``"random"``; see ``split_data``.
            split_strategy: Forwarded to ``split_data_func``. Only used by the
                ``"random"`` method.
        """
        self.model_name_1 = model_name_1
        self.model_name_2 = model_name_2
        self.dataset_name = dataset_name
        self.cache_dir = cache_dir
        self.device = device
        self.batch_size = batch_size
        self.d0_ratio = d0_ratio
        self.reference_creation_method = reference_creation_method
        self.split_strategy = split_strategy
        # Set by load_dataset().
        self.dataset = None
        # Set by load_or_generate_embeddings(); later filtered by clean_data()
        # (queries only) and widened by align_dimensions().
        self.corpus_emb_1 = None
        self.corpus_emb_2 = None
        self.query_emb_1 = None
        self.query_emb_2 = None
        # Set by extract_positive_negative_indices().
        self.p_index_list = None
        self.q_index_list = None
        self.n_index_list = None
        # Set by split_data().
        self.d0 = None
        self.d1 = None
        self.d2 = None

    def load_dataset(self, evaluator) -> Any:
        """Load the dataset via ``evaluator.load_data()``, store it on ``self.dataset`` and return it."""
        logger.info(f"Loading dataset: {self.dataset_name}")
        self.dataset = evaluator.load_data()
        logger.info("✅ Dataset loaded successfully")
        return self.dataset

    def _generate_cache_keys(self) -> Tuple[str, str, str, str]:
        """Return the cache file names ``(corpus_1, corpus_2, query_1, query_2)``.

        Model and dataset names are inserted verbatim. A name containing
        ``/`` therefore yields a key with a path separator in it.
        NOTE(review): confirm ``load_cache_func`` handles that. Do not change
        the format here without invalidating existing caches.
        """
        corpus_emb_1_key = f"corpus_embeddings_{self.model_name_1}_{self.dataset_name}.npy"
        corpus_emb_2_key = f"corpus_embeddings_{self.model_name_2}_{self.dataset_name}.npy"
        query_emb_1_key = f"query_embeddings_{self.model_name_1}_{self.dataset_name}.npy"
        query_emb_2_key = f"query_embeddings_{self.model_name_2}_{self.dataset_name}.npy"
        return corpus_emb_1_key, corpus_emb_2_key, query_emb_1_key, query_emb_2_key

    def load_or_generate_embeddings(self, get_embedding_generator_func, load_cache_func) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Load corpus/query embeddings for both models from cache, generating any that are missing.

        Args:
            get_embedding_generator_func: Called as ``(model_name, dataset_name,
                cache_dir, device, batch_size)``. It must return an object with
                ``generate_corpus_embeddings(corpus, key)`` and
                ``generate_query_embeddings(queries, key)``. Presumably those
                methods also write the cache under ``key``; confirm.
            load_cache_func: Called as ``(cache_dir, key)``. A ``None`` return
                is treated as a cache miss.

        Returns:
            ``(corpus_emb_1, corpus_emb_2, query_emb_1, query_emb_2)``, also
            stored on ``self``.

        Note:
            If any embedding must be generated, this reads
            ``self.dataset.corpus`` and ``self.dataset.queries``, so
            ``load_dataset()`` must have been called first.
        """
        logger.info("Loading/generating embeddings...")
        corpus_emb_1_key, corpus_emb_2_key, query_emb_1_key, query_emb_2_key = self._generate_cache_keys()
        self.corpus_emb_1 = load_cache_func(self.cache_dir, corpus_emb_1_key)
        self.corpus_emb_2 = load_cache_func(self.cache_dir, corpus_emb_2_key)
        self.query_emb_1 = load_cache_func(self.cache_dir, query_emb_1_key)
        self.query_emb_2 = load_cache_func(self.cache_dir, query_emb_2_key)
        # If either half is missing for a model, both corpus and query
        # embeddings are regenerated for that model.
        if self.corpus_emb_1 is None or self.query_emb_1 is None:
            logger.info(f"Generating embeddings for {self.model_name_1}")
            embedding_generator_1 = get_embedding_generator_func(
                self.model_name_1, self.dataset_name, self.cache_dir, self.device, self.batch_size
            )
            self.corpus_emb_1 = embedding_generator_1.generate_corpus_embeddings(self.dataset.corpus, corpus_emb_1_key)
            self.query_emb_1 = embedding_generator_1.generate_query_embeddings(self.dataset.queries, query_emb_1_key)
        if self.corpus_emb_2 is None or self.query_emb_2 is None:
            logger.info(f"Generating embeddings for {self.model_name_2}")
            embedding_generator_2 = get_embedding_generator_func(
                self.model_name_2, self.dataset_name, self.cache_dir, self.device, self.batch_size
            )
            self.corpus_emb_2 = embedding_generator_2.generate_corpus_embeddings(self.dataset.corpus, corpus_emb_2_key)
            self.query_emb_2 = embedding_generator_2.generate_query_embeddings(self.dataset.queries, query_emb_2_key)
        logger.info("✅ Embeddings loaded/generated successfully")
        return self.corpus_emb_1, self.corpus_emb_2, self.query_emb_1, self.query_emb_2

    def extract_positive_negative_indices(self, find_positive_negative_indices_func) -> Tuple[List, Tuple[List, List, List], List]:
        """Run the injected index finder and store the query/positive/negative index lists on ``self``.

        ``find_positive_negative_indices_func`` is called with the dataset,
        dataset name and all four embedding matrices. It must return
        ``(q_p_n_index_list, (q_index_list, p_index_list, n_index_list), mask_list)``.
        That tuple is returned unchanged. ``process_all`` feeds ``mask_list``
        to ``clean_data``.
        """
        logger.info("Extracting positive/negative indices...")
        q_p_n_index_list, (q_index_list, p_index_list, n_index_list), mask_list = find_positive_negative_indices_func(
            self.dataset, self.dataset_name, self.corpus_emb_1, self.corpus_emb_2, self.query_emb_1, self.query_emb_2
        )
        self.q_index_list = q_index_list
        self.p_index_list = p_index_list
        self.n_index_list = n_index_list
        logger.info(f"✅ Found {len(p_index_list)} positive samples")
        return q_p_n_index_list, (q_index_list, p_index_list, n_index_list), mask_list

    @staticmethod
    def _rows_to_keep(num_rows: int, mask_list) -> List[int]:
        """Return ``range(num_rows)`` minus the indices in ``mask_list``, in ascending order."""
        # Build the exclusion set once: membership is O(1) instead of a
        # linear scan of mask_list for every row.
        masked = set(mask_list)
        return [i for i in range(num_rows) if i not in masked]

    def clean_data(self, mask_list: List[int]) -> None:
        """Drop the query rows listed in ``mask_list`` from both query embedding matrices.

        Only ``query_emb_1``/``query_emb_2`` are filtered. Corpus embeddings
        and the stored index lists are left as they are.
        NOTE(review): presumably the index finder already excluded these
        queries from its lists; confirm.
        """
        # len() rather than truthiness so a NumPy array mask also works.
        if len(mask_list) > 0:
            logger.warning(f"Removing {len(mask_list)} queries from embeddings")
            select_index = self._rows_to_keep(len(self.query_emb_1), mask_list)
            self.query_emb_1 = self.query_emb_1[select_index]
            self.query_emb_2 = self.query_emb_2[select_index]
            logger.info("✅ Data cleaning completed")

    @staticmethod
    def _pad_columns(embeddings: np.ndarray, target_dim: int) -> np.ndarray:
        """Right-pad a 2-D array with zero columns up to ``target_dim``; return it unchanged if already wide enough."""
        missing = target_dim - embeddings.shape[1]
        if missing <= 0:
            return embeddings
        # np.pad defaults to mode="constant" with value 0.
        return np.pad(embeddings, ((0, 0), (0, missing)))

    def align_dimensions(self) -> int:
        """Zero-pad all four embedding matrices on the right to the widest column count.

        Returns:
            The common embedding width after padding.
        """
        logger.info("Aligning embedding dimensions...")
        max_dim = max(
            self.corpus_emb_1.shape[1], self.corpus_emb_2.shape[1],
            self.query_emb_1.shape[1], self.query_emb_2.shape[1]
        )
        self.corpus_emb_1 = self._pad_columns(self.corpus_emb_1, max_dim)
        self.corpus_emb_2 = self._pad_columns(self.corpus_emb_2, max_dim)
        self.query_emb_1 = self._pad_columns(self.query_emb_1, max_dim)
        self.query_emb_2 = self._pad_columns(self.query_emb_2, max_dim)
        logger.info(f"✅ Dimensions aligned to {max_dim}")
        logger.info(f"Shapes: corpus_1={self.corpus_emb_1.shape}, corpus_2={self.corpus_emb_2.shape}, "
                   f"query_1={self.query_emb_1.shape}, query_2={self.query_emb_2.shape}")
        return max_dim

    def split_data(self, hierarchical_kmeans_sampling_func=None, split_data_func=None) -> Tuple[List[int], List[int], List[int]]:
        """Split the corpus into ``d0``/``d1``/``d2`` using ``self.reference_creation_method``.

        - ``"ours"`` calls ``hierarchical_kmeans_sampling_func`` with
          ``layer=3`` and ``branch_num=3``.
        - ``"random"`` calls ``split_data_func`` with ``self.split_strategy``.

        Raises:
            ValueError: if the method is unknown, or if the callable that the
                method requires was not supplied.
        """
        logger.info(f"Splitting data using method: {self.reference_creation_method}")
        method = self.reference_creation_method
        if method == "ours":
            if hierarchical_kmeans_sampling_func is None:
                raise ValueError(
                    "reference_creation_method='ours' requires hierarchical_kmeans_sampling_func"
                )
            self.d0, self.d1, self.d2 = hierarchical_kmeans_sampling_func(
                self.corpus_emb_1, self.p_index_list,
                d0_ratio=self.d0_ratio, layer=3, branch_num=3,
                corpus_emb_2=self.corpus_emb_2, query_emb_2=self.query_emb_2
            )
        elif method == "random":
            if split_data_func is None:
                raise ValueError(
                    "reference_creation_method='random' requires split_data_func"
                )
            self.d0, self.d1, self.d2 = split_data_func(
                len(self.corpus_emb_1), self.p_index_list, self.d0_ratio,
                self.corpus_emb_2, self.query_emb_2, self.split_strategy
            )
        else:
            raise ValueError(f"Invalid reference creation method: {self.reference_creation_method}")
        logger.info(f"✅ Data split completed - D0: {len(self.d0)}, D1: {len(self.d1)}, D2: {len(self.d2)}")
        return self.d0, self.d1, self.d2

    def get_processing_info(self) -> Dict[str, Any]:
        """Return a summary of configuration and result sizes/shapes; entries not yet computed are ``None``."""
        return {
            "model_name_1": self.model_name_1,
            "model_name_2": self.model_name_2,
            "dataset_name": self.dataset_name,
            "d0_ratio": self.d0_ratio,
            "d0_size": len(self.d0) if self.d0 is not None else None,
            "d1_size": len(self.d1) if self.d1 is not None else None,
            "d2_size": len(self.d2) if self.d2 is not None else None,
            "corpus_emb_1_shape": self.corpus_emb_1.shape if self.corpus_emb_1 is not None else None,
            "corpus_emb_2_shape": self.corpus_emb_2.shape if self.corpus_emb_2 is not None else None,
            "query_emb_1_shape": self.query_emb_1.shape if self.query_emb_1 is not None else None,
            "query_emb_2_shape": self.query_emb_2.shape if self.query_emb_2 is not None else None,
            "positive_samples": len(self.p_index_list) if self.p_index_list is not None else None,
        }

    def process_all(
        self,
        evaluator,
        get_embedding_generator_func,
        load_cache_func,
        find_positive_negative_indices_func,
        hierarchical_kmeans_sampling_func=None,
        split_data_func=None
    ) -> Dict[str, Any]:
        """Run the whole pipeline and return every intermediate result in one dict.

        Steps, in order: load dataset, load/generate embeddings, extract
        indices, drop masked queries, align dimensions, split data. See the
        individual methods for the contracts of the injected callables.
        """
        logger.info("🚀 Starting complete data processing pipeline...")
        self.load_dataset(evaluator)
        self.load_or_generate_embeddings(get_embedding_generator_func, load_cache_func)
        # The (q, p, n) index lists are also stored on self, so the tuple is not needed here.
        q_p_n_index_list, _, mask_list = self.extract_positive_negative_indices(
            find_positive_negative_indices_func
        )
        self.clean_data(mask_list)
        max_dim = self.align_dimensions()
        self.split_data(hierarchical_kmeans_sampling_func, split_data_func)
        logger.info("✅ Complete data processing pipeline finished successfully!")
        return {
            "dataset": self.dataset,
            "corpus_emb_1": self.corpus_emb_1,
            "corpus_emb_2": self.corpus_emb_2,
            "query_emb_1": self.query_emb_1,
            "query_emb_2": self.query_emb_2,
            "q_p_n_index_list": q_p_n_index_list,
            "q_index_list": self.q_index_list,
            "p_index_list": self.p_index_list,
            "n_index_list": self.n_index_list,
            "d0": self.d0,
            "d1": self.d1,
            "d2": self.d2,
            "max_dim": max_dim,
            "processing_info": self.get_processing_info()
        }

    def __repr__(self) -> str:
        return (
            f"DataProcessor(\n"
            f"  model_1='{self.model_name_1}',\n"
            f"  model_2='{self.model_name_2}',\n"
            f"  dataset='{self.dataset_name}',\n"
            f"  d0_ratio={self.d0_ratio},\n"
            f"  method='{self.reference_creation_method}'\n"
            f")"
        )
