from typing import Dict, List, Tuple
import numpy as np
import argparse
from loguru import logger
import wandb
from src.data_splitting.splitters import DataSplitter, RandomDataSplitter, HierarchicalDataSplitter
from src.clustering import Clusterer, KMeansClusterer, HypergraphClusterer, AvgLinkageClusterer
from src.mapper import PAVectorMapper, LinearMapper, VectorMapper
class VectorTranslationHandler:
    """Hold the embeddings for a run and drive its data-splitting phase.

    The handler stores the run configuration, two corpus embedding arrays,
    two query embedding arrays and a list of integer indices, and builds the
    ``DataSplitter`` selected by ``args.reference_creation_method``.
    """

    def __init__(
        self,
        args: argparse.Namespace,
        corpus_emb_1: np.ndarray,
        corpus_emb_2: np.ndarray,
        query_emb_1: np.ndarray,
        query_emb_2: np.ndarray,
        p_index_list: List[int],
    ):
        """Store the configuration and inputs, and build the data splitter.

        Args:
            args: Parsed options. ``reference_creation_method`` is read while
                choosing the splitter, and the namespace itself is handed to
                the splitter's constructor.
            corpus_emb_1: First corpus embedding array.
            corpus_emb_2: Second corpus embedding array.
            query_emb_1: First query embedding array.
            query_emb_2: Second query embedding array.
            p_index_list: List of integer indices, stored as-is.
        """
        # The splitter factory reads self.args, so args must be bound first.
        self.args = args
        self.data_splitter = self._create_data_splitter()
        self.corpus_emb_1, self.corpus_emb_2 = corpus_emb_1, corpus_emb_2
        self.query_emb_1, self.query_emb_2 = query_emb_1, query_emb_2
        self.p_index_list = p_index_list

    def _create_data_splitter(self) -> DataSplitter:
        """Return the splitter matching ``args.reference_creation_method``.

        The value ``"ours"`` selects ``HierarchicalDataSplitter``; every other
        value selects ``RandomDataSplitter``.
        """
        use_hierarchical = self.args.reference_creation_method == "ours"
        splitter_cls = HierarchicalDataSplitter if use_hierarchical else RandomDataSplitter
        return splitter_cls(self.args)

    def run_data_splitting_phase(
        self,
        corpus_emb_1: np.ndarray,
        corpus_emb_2: np.ndarray,
        query_emb_2: np.ndarray,
        p_index_list: List[int],
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Split the data into three parts and report the part sizes.

        The split uses the arguments passed to this method, not the copies
        stored on the instance.

        Args:
            corpus_emb_1: First corpus embedding array, forwarded to the splitter.
            corpus_emb_2: Second corpus embedding array, forwarded to the splitter.
            query_emb_2: Second query embedding array, forwarded to the splitter.
            p_index_list: List of integer indices, forwarded to the splitter.

        Returns:
            The three parts ``(d0, d1, d2)`` produced by the splitter, in that
            order.
        """
        logger.info("🔄 Starting data splitting phase...")

        d0, d1, d2 = self.data_splitter.split(
            corpus_emb_1, corpus_emb_2, query_emb_2, p_index_list
        )

        # Gather the reported figures once so the log line and the wandb
        # record are built from the same values.
        d0_ratio = self.args.d0_ratio
        d0_size, d1_size, d2_size = len(d0), len(d1), len(d2)

        logger.info(f"d0_ratio: {d0_ratio}, d0_size: {d0_size}, d1_size: {d1_size}, d2_size: {d2_size}")
        wandb.log({
            "d0_ratio": d0_ratio,
            "d0_size": d0_size,
            "d1_size": d1_size,
            "d2_size": d2_size,
        })

        return d0, d1, d2
