from typing import Dict, Tuple, Optional
import fasttext
import os
import logging
import tqdm
from pathlib import Path
from dataclasses import dataclass
from beir.datasets.data_loader import GenericDataLoader
from beir import util as beir_util
from beir.retrieval import models
from beir import LoggingHandler
import numpy as np
# Configure the root logger at import time: timestamped messages at INFO level,
# routed through the LoggingHandler imported from beir rather than a default
# handler.
logging.basicConfig(
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
    handlers=[LoggingHandler()]
)
# Module-level logger used by the classes and the main() pipeline below.
logger = logging.getLogger(__name__)
@dataclass
class DatasetConfig:
    """Describe one downloadable dataset archive.

    Attributes:
        name: Dataset identifier; also the stem of the ``.zip`` archive name.
        path: Local directory associated with the dataset (not read by this
            class itself).
        split: Name of the split to load; defaults to ``"test"``.
        base_url: URL prefix under which ``<name>.zip`` is expected to live.
    """

    name: str
    path: str
    split: str = "test"
    base_url: str = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets"

    @property
    def url(self) -> str:
        """Return the archive URL, i.e. ``<base_url>/<name>.zip``."""
        return "{}/{}.zip".format(self.base_url, self.name)
class FastTextTrainer:
    """Train fastText skip-gram models and embed document corpora with them."""

    def __init__(self, work_dir: str):
        """Remember the working directory and create it if it is missing.

        Args:
            work_dir: Directory used for intermediate artefacts.
        """
        self.work_dir = Path(work_dir)
        self.work_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _document_text(doc: Dict) -> str:
        """Return a document's title and body as one whitespace-normalised line."""
        text = f"{doc.get('title', '')} {doc.get('text', '')}"
        # fastText's sentence API rejects input containing "\n", so collapse
        # every whitespace run (newlines, tabs, ...) into a single space.
        return " ".join(text.split())

    def train_model(
        self,
        corpus_path: str,
        model_path: str,
        dim: int = 300,
        lr: float = 0.025,
        window_size: int = 5,
        epochs: int = 10,
        min_count: int = 1
    ) -> None:
        """Train an unsupervised skip-gram model and save it to ``model_path``.

        Args:
            corpus_path: Plain-text training file handed to fastText.
            model_path: Destination file for the trained model.
            dim: Embedding dimensionality (fastText ``dim``).
            lr: Learning rate (fastText ``lr``).
            window_size: Context window size (fastText ``ws``).
            epochs: Number of training epochs (fastText ``epoch``).
            min_count: Minimum word frequency (fastText ``minCount``).
        """
        # Create the output directory before the expensive training step so
        # that saving cannot fail afterwards just because the folder is missing.
        Path(model_path).parent.mkdir(parents=True, exist_ok=True)
        model = fasttext.train_unsupervised(
            corpus_path,
            model='skipgram',
            dim=dim,
            lr=lr,
            ws=window_size,
            epoch=epochs,
            minCount=min_count
        )
        model.save_model(model_path)
        logger.info(f"Model saved to {model_path}")

    def generate_embeddings(self, model_path: str, corpus: Dict, embeddings_path: str) -> None:
        """Embed every document of ``corpus`` and save the matrix with ``np.save``.

        Rows follow the iteration order of ``corpus``; document ids are not
        stored alongside the vectors.

        Args:
            model_path: Path of a model previously written by ``train_model``.
            corpus: Mapping of document id to a dict that may hold ``title``
                and ``text`` entries.
            embeddings_path: Destination ``.npy`` file.
        """
        model = fasttext.load_model(model_path)
        # A document is many words: get_sentence_vector combines the vectors of
        # its tokens, whereas get_word_vector would treat the whole text as one
        # (out-of-vocabulary) word.
        embeddings = [
            model.get_sentence_vector(self._document_text(doc))
            for doc in tqdm.tqdm(corpus.values())
        ]
        # np.save does not create missing directories.
        Path(embeddings_path).parent.mkdir(parents=True, exist_ok=True)
        np.save(embeddings_path, np.array(embeddings))
        logger.info(f"Embeddings saved to {embeddings_path}")
class DatasetLoader:
    """Download (when needed) and load datasets stored under one root folder."""

    # Dataset name -> split name. load_dataset only consults the keys, as an
    # allow-list of dataset names.
    SUPPORTED_DATASETS = {
        "scifact": "train",
        "nfcorpus": "test",
        "cqadupstack": "test",
        "arguana": "test",
        "scidocs": "test",
        "fiqa": "test",
        "signal1m": "test",
        "fever": "test"
    }

    def __init__(self, dataset_path: str):
        """Store the root directory that holds the dataset folders."""
        self.dataset_path = Path(dataset_path)

    def load_dataset(self, dataset_name: str, split: str = "test") -> Tuple[Dict, Dict, Dict]:
        """Load one split of a supported dataset.

        Args:
            dataset_name: A key of ``SUPPORTED_DATASETS``.
            split: Split name forwarded to ``GenericDataLoader.load``.

        Returns:
            Whatever ``GenericDataLoader.load`` returns for that split.

        Raises:
            ValueError: If the dataset name is not supported, or if the root
                directory does not exist.
        """
        supported = self.SUPPORTED_DATASETS
        if dataset_name not in supported:
            known_names = list(supported.keys())
            raise ValueError(
                f"Dataset {dataset_name} is not supported. "
                f"Supported datasets: {known_names}"
            )

        config = DatasetConfig(
            name=dataset_name,
            path=str(self.dataset_path),
            split=split
        )
        logger.info(f"Loading dataset: {dataset_name} from {config.path}")

        folder = self._prepare_dataset_path(config)
        loader = GenericDataLoader(data_folder=folder)
        return loader.load(split=config.split)

    def _prepare_dataset_path(self, config: DatasetConfig) -> str:
        """Fetch and unpack the archive, returning the folder to load from.

        Raises:
            ValueError: If the root directory does not exist.
        """
        root = self.dataset_path
        if not root.exists():
            raise ValueError(f"Dataset directory {root} does not exist.")

        extracted = beir_util.download_and_unzip(config.url, str(root))
        # For cqadupstack the data is read from the "english" sub-folder.
        if config.name != "cqadupstack":
            return extracted
        return os.path.join(extracted, "english")
def clean_corpus(corpus: Dict[str, Dict[str, str]]) -> str:
    """Flatten a corpus into training text with exactly one document per line.

    Each document contributes its ``title`` and ``text`` (either may be
    missing) joined by a space. Runs of whitespace inside a document,
    including embedded newlines, are collapsed to single spaces so a document
    can never spill onto several lines of the training file.

    Args:
        corpus: Mapping of document id to a dict that may hold ``title`` and
            ``text`` strings.

    Returns:
        The documents joined by ``"\\n"`` in corpus iteration order; an empty
        string for an empty corpus. A document with no content yields an
        empty line.
    """
    lines = []
    for doc in corpus.values():
        text = f"{doc.get('title', '')} {doc.get('text', '')}"
        # split()/join trims the ends and removes embedded line breaks.
        lines.append(" ".join(text.split()))
    return "\n".join(lines)
def main():
    """Train one fastText model per training dataset, then embed each target.

    Phase 1 trains (or reuses) a model for every entry of ``TRAIN_DATASETS``.
    Phase 2 loads each target corpus once and embeds it with every model,
    writing ``fasttext_<train>_on_<target>_embeddings.npy`` files.
    """
    TRAIN_DATASETS = ["fiqa", "scifact", "nfcorpus", "arguana"]
    TARGET_DATASETS = ["fever"]
    DATASET_PATH = "data/raw/beir/"
    WORK_PATH = "data/processed/"
    work_dir = Path(WORK_PATH)
    models_dir = work_dir / "models"
    vectors_dir = work_dir / "embeddings"
    # Neither fastText's save_model nor np.save creates missing directories.
    for out_dir in (models_dir, vectors_dir):
        out_dir.mkdir(parents=True, exist_ok=True)
    dataset_loader = DatasetLoader(DATASET_PATH)
    trainer = FastTextTrainer(str(work_dir))

    # Phase 1: make sure a trained model exists for every training dataset.
    model_paths = {}
    for train_dataset in TRAIN_DATASETS:
        logger.info(f"Processing training dataset: {train_dataset}")
        model_path = models_dir / f"fasttext_{train_dataset}.model"
        corpus_path = work_dir / f"fasttext_{train_dataset}_corpus.txt"
        if not model_path.exists():
            if not corpus_path.exists():
                # Only the corpus is needed here, and BEIR serves the same
                # corpus for every split. Ask for the split registered for
                # the dataset, because not every dataset ships a "train"
                # split (ArguAna, for example, only has "test").
                split = DatasetLoader.SUPPORTED_DATASETS.get(train_dataset, "train")
                train_corpus, _, _ = dataset_loader.load_dataset(train_dataset, split=split)
                cleaned_corpus = clean_corpus(train_corpus)
                corpus_path.write_text(cleaned_corpus, encoding='utf-8')
                logger.info(f"Training corpus saved to {corpus_path}")
            trainer.train_model(str(corpus_path), str(model_path))
        # The text dump is only an intermediate artefact of training.
        if corpus_path.exists():
            corpus_path.unlink()
        model_paths[train_dataset] = model_path

    # Phase 2: load each target corpus once and reuse it for every model,
    # instead of reloading it for each training dataset.
    for target_dataset in TARGET_DATASETS:
        logger.info(f"Generating embeddings for target dataset: {target_dataset}")
        target_corpus, _, _ = dataset_loader.load_dataset(target_dataset)
        for train_dataset, model_path in model_paths.items():
            embeddings_path = vectors_dir / f"fasttext_{train_dataset}_on_{target_dataset}_embeddings.npy"
            trainer.generate_embeddings(
                str(model_path),
                target_corpus,
                str(embeddings_path)
            )
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
