from dataclasses import dataclass
from typing import List, Optional, Protocol
import numpy as np
from loguru import logger
from src.utils import load_cache
from .data_types import EmbeddingPair, EmbeddingDataset
class CacheLoader(Protocol):
    """Structural interface for objects that can fetch a cached array by key.

    Any object exposing a compatible ``load`` method satisfies this protocol;
    inheriting from it is not required.
    """

    def load(self, key: str) -> Optional[np.ndarray]:
        """Return the array stored under ``key``, or ``None``.

        NOTE(review): ``None`` presumably signals a cache miss -- confirm
        against the concrete loader implementations.
        """
@dataclass
class EmbeddingConfig:
    """Settings that tell ``EmbeddingLoader`` which cached embeddings to read."""

    # Directory handed to ``load_cache`` by the default cache loader.
    cache_dir: str

    # Names of the two models; each is interpolated into the cache keys.
    model_name_1: str
    model_name_2: str

    # Dataset names whose embeddings are loaded and merged into the train set.
    train_dataset_list: List[str]

    # Dataset name whose embeddings form the test set.
    test_dataset: str
class EmbeddingLoader:
    """Load cached embeddings for two models and assemble train/test datasets.

    Arrays are looked up by cache keys of the form
    ``{corpus|query}_embeddings_{model_name}_{dataset_name}.npy``. By default
    keys are resolved with ``load_cache`` against ``config.cache_dir``; pass a
    ``cache_loader`` to substitute a different source.
    """

    # Array attributes of an EmbeddingPair, in the order they are merged.
    _PAIR_FIELDS = ("corpus_emb_1", "corpus_emb_2", "query_emb_1", "query_emb_2")

    def __init__(self, config: EmbeddingConfig, cache_loader: Optional[CacheLoader] = None):
        """Store the configuration and choose the function used to read arrays.

        Args:
            config: Names of the cache directory, models and datasets.
            cache_loader: Optional object whose ``load(key)`` replaces the
                default ``load_cache``-based lookup.
        """
        self.config = config
        # Compare against None explicitly: a perfectly valid loader may be
        # falsy (e.g. an empty dict-like cache defining ``__len__``) and must
        # not be silently replaced by the default loader.
        self._load_cache = cache_loader.load if cache_loader is not None else self._default_load

    def _default_load(self, key: str) -> Optional[np.ndarray]:
        """Resolve ``key`` through ``load_cache`` in the configured cache dir."""
        return load_cache(self.config.cache_dir, key)

    def _load_embedding_pair(self, dataset_name: str) -> EmbeddingPair:
        """Load corpus and query embeddings of both models for one dataset.

        Whatever the cache loader returns (including ``None``) is stored in
        the pair unchanged.
        """
        model_1 = self.config.model_name_1
        model_2 = self.config.model_name_2
        return EmbeddingPair(
            corpus_emb_1=self._load_cache(f"corpus_embeddings_{model_1}_{dataset_name}.npy"),
            corpus_emb_2=self._load_cache(f"corpus_embeddings_{model_2}_{dataset_name}.npy"),
            query_emb_1=self._load_cache(f"query_embeddings_{model_1}_{dataset_name}.npy"),
            query_emb_2=self._load_cache(f"query_embeddings_{model_2}_{dataset_name}.npy"),
        )

    def _merge_pairs(self, pairs: List[EmbeddingPair]) -> EmbeddingPair:
        """Concatenate several pairs field by field along axis 0.

        Args:
            pairs: Pairs to merge, in the order their rows should appear.

        Returns:
            A new pair whose arrays are the axis-0 concatenation of the
            corresponding arrays of ``pairs``.

        Raises:
            ValueError: If ``pairs`` is empty, if any field of any pair is
                ``None``, or if ``np.concatenate`` rejects the arrays.
        """
        # np.concatenate would raise ValueError in both situations below, but
        # with messages that do not point at the real cause; fail early with
        # the same exception type and an actionable message instead.
        if not pairs:
            raise ValueError(
                "Cannot merge an empty list of embedding pairs; "
                "at least one training dataset is required."
            )

        merged = {}
        for field in self._PAIR_FIELDS:
            arrays = [getattr(pair, field) for pair in pairs]
            missing = [index for index, array in enumerate(arrays) if array is None]
            if missing:
                raise ValueError(
                    f"Cannot merge embedding pairs: '{field}' is None for "
                    f"pair(s) at index {missing}."
                )
            merged[field] = np.concatenate(arrays, axis=0)
        return EmbeddingPair(**merged)

    def load(self) -> tuple[EmbeddingDataset, EmbeddingDataset]:
        """Build the train and test datasets described by the configuration.

        Every dataset in ``config.train_dataset_list`` is loaded and merged
        into a single training pair; ``config.test_dataset`` is loaded on its
        own and is not validated here.

        Returns:
            ``(train_dataset, test_dataset)``.

        Raises:
            ValueError: If ``train_dataset_list`` is empty or a training
                embedding could not be loaded (see ``_merge_pairs``).
        """
        train_pairs = [self._load_embedding_pair(name) for name in self.config.train_dataset_list]
        train_pair = self._merge_pairs(train_pairs)
        test_pair = self._load_embedding_pair(self.config.test_dataset)
        return (
            EmbeddingDataset(pair=train_pair),
            EmbeddingDataset(pair=test_pair),
        )
def load_or_generate_embeddings(args) -> tuple[EmbeddingDataset, EmbeddingDataset]:
    """Build an ``EmbeddingLoader`` from parsed arguments and run it.

    Args:
        args: Object exposing ``embeddings_cache_dir``, ``model_name_1``,
            ``model_name_2``, ``train_dataset_list`` and ``test_dataset``
            attributes.

    Returns:
        The ``(train, test)`` datasets produced by ``EmbeddingLoader.load``.

    NOTE(review): despite the name, no generation step is visible in this
    function -- it only delegates to the loader.
    """
    loader = EmbeddingLoader(
        EmbeddingConfig(
            cache_dir=args.embeddings_cache_dir,
            model_name_1=args.model_name_1,
            model_name_2=args.model_name_2,
            train_dataset_list=args.train_dataset_list,
            test_dataset=args.test_dataset,
        )
    )
    return loader.load()
