from abc import ABC, abstractmethod
import numpy as np
from typing import List, Tuple
import argparse
from loguru import logger
from src.data_splitting.split_functions import split_data, hierarchical_kmeans_sampling
class DataSplitter(ABC):
    """Abstract interface for strategies that split embedding data.

    Concrete subclasses implement :meth:`split`; the parsed command-line
    namespace is kept on ``self.args`` so implementations can read their
    configuration from it.
    """

    def __init__(self, args: argparse.Namespace):
        """Store the configuration namespace for later use by ``split``."""
        self.args = args

    @abstractmethod
    def split(
        self,
        corpus_emb_1: np.ndarray,
        corpus_emb_2: np.ndarray,
        query_emb_2: np.ndarray,
        p_index_list: List[int],
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Split the data; must be overridden by every concrete subclass.

        Args:
            corpus_emb_1: First corpus embedding array.
            corpus_emb_2: Second corpus embedding array.
            query_emb_2: Query embedding array paired with ``corpus_emb_2``.
            p_index_list: List of integer indices (semantics are defined by
                the concrete splitting helper -- confirm against it).

        Returns:
            A tuple of three arrays, as declared by the annotation; their
            exact meaning is determined by the concrete implementation.
        """
class RandomDataSplitter(DataSplitter):
    """Splitter that delegates to ``split_data``.

    Reads ``d0_ratio`` and ``split_strategy`` from ``self.args``.
    """

    def split(
        self,
        corpus_emb_1: np.ndarray,
        corpus_emb_2: np.ndarray,
        query_emb_2: np.ndarray,
        p_index_list: List[int],
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Split the data by forwarding to ``split_data``.

        Only the length of ``corpus_emb_1`` is used here; its contents are
        not passed on.

        Args:
            corpus_emb_1: First corpus embedding array (only ``len`` is taken).
            corpus_emb_2: Second corpus embedding array, forwarded as-is.
            query_emb_2: Query embedding array, forwarded as-is.
            p_index_list: List of integer indices, forwarded as-is.

        Returns:
            Whatever ``split_data`` returns (annotated as a 3-tuple of arrays).
        """
        corpus_size = len(corpus_emb_1)
        config = self.args
        ratio = config.d0_ratio
        strategy = config.split_strategy
        return split_data(
            corpus_size,
            p_index_list,
            ratio,
            corpus_emb_2,
            query_emb_2,
            strategy,
        )
class HierarchicalDataSplitter(DataSplitter):
    """Splitter that delegates to ``hierarchical_kmeans_sampling``.

    Reads ``d0_ratio`` from ``self.args``. The ``layer`` and ``branch_num``
    values forwarded to the helper are class attributes so that a subclass
    (or an individual instance) can override them without re-implementing
    :meth:`split`; the defaults reproduce the previously hard-coded values.
    """

    # Forwarded as ``layer=`` / ``branch_num=`` to the sampling helper.
    # NOTE(review): presumably the depth of the cluster tree and the number
    # of children per node -- confirm against hierarchical_kmeans_sampling.
    layer: int = 3
    branch_num: int = 3

    def split(
        self,
        corpus_emb_1: np.ndarray,
        corpus_emb_2: np.ndarray,
        query_emb_2: np.ndarray,
        p_index_list: List[int],
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Split the data by forwarding to ``hierarchical_kmeans_sampling``.

        Args:
            corpus_emb_1: First corpus embedding array, passed as the first
                positional argument to the helper.
            corpus_emb_2: Second corpus embedding array, forwarded by keyword.
            query_emb_2: Query embedding array, forwarded by keyword.
            p_index_list: List of integer indices, passed as the second
                positional argument to the helper.

        Returns:
            Whatever ``hierarchical_kmeans_sampling`` returns (annotated as a
            3-tuple of arrays).
        """
        return hierarchical_kmeans_sampling(
            corpus_emb_1,
            p_index_list,
            d0_ratio=self.args.d0_ratio,
            layer=self.layer,
            branch_num=self.branch_num,
            corpus_emb_2=corpus_emb_2,
            query_emb_2=query_emb_2,
        )
