from typing import Dict, Tuple, Optional, List
import fasttext
import os
import logging
import tqdm
import argparse
from pathlib import Path
from dataclasses import dataclass
from beir.datasets.data_loader import GenericDataLoader
from beir import util as beir_util
from beir.retrieval import models
from beir import LoggingHandler
import numpy as np
from embedding_generator import get_embedding_generator
# Configure the root logger once, at import time. Records are emitted through
# the LoggingHandler imported from beir.
# NOTE: the keyword is ``datefmt``; logging.basicConfig raises ValueError for
# any keyword argument it does not recognise, so a misspelling aborts import.
logging.basicConfig(
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    level=logging.INFO,
    handlers=[LoggingHandler()]
)
# Module-level logger shared by every class and function in this file.
logger = logging.getLogger(__name__)
@dataclass
class DatasetConfig:
    """Settings describing a single downloadable dataset archive.

    Attributes:
        name: Dataset identifier; also the stem of the archive file name.
        path: Local path string associated with the dataset.
        split: Name of the split to load (defaults to ``"test"``).
        base_url: URL prefix under which ``<name>.zip`` is published.
    """

    name: str
    path: str
    split: str = "test"
    base_url: str = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets"

    @property
    def url(self) -> str:
        """Return the archive URL, i.e. ``<base_url>/<name>.zip``."""
        return "{}/{}.zip".format(self.base_url, self.name)
class FastTextTrainer:
    """Train unsupervised fastText models and embed a corpus with them."""

    def __init__(self, work_dir: str):
        """Remember the working directory and create it if it is missing.

        Args:
            work_dir: Directory path; created together with any missing
                parents.
        """
        self.work_dir = Path(work_dir)
        self.work_dir.mkdir(parents=True, exist_ok=True)

    def train_model(
        self,
        corpus_path: str,
        model_path: str,
        dim: int = 300,
        lr: float = 0.025,
        window_size: int = 5,
        epochs: int = 10,
        min_count: int = 1
    ) -> None:
        """Train a skip-gram fastText model and save it to ``model_path``.

        Args:
            corpus_path: Training input passed to
                ``fasttext.train_unsupervised``.
            model_path: Destination handed to ``model.save_model``.
            dim: Passed as ``dim``.
            lr: Passed as ``lr``.
            window_size: Passed as ``ws``.
            epochs: Passed as ``epoch``.
            min_count: Passed as ``minCount``.
        """
        model = fasttext.train_unsupervised(
            corpus_path,
            model='skipgram',
            dim=dim,
            lr=lr,
            ws=window_size,
            epoch=epochs,
            minCount=min_count
        )
        model.save_model(model_path)
        logger.info(f"Model saved to {model_path}")

    @staticmethod
    def _document_text(doc: Dict) -> str:
        """Return the document's title and text as one single-line string."""
        text = f"{doc.get('title', '')} {doc.get('text', '')}"
        # fastText's get_sentence_vector rejects input containing "\n", so
        # collapse every whitespace run (newlines included) to one space.
        return " ".join(text.split())

    def generate_embeddings(self, model_path: str, corpus: Dict, embeddings_path: str) -> None:
        """Embed every document of ``corpus`` and save the vectors with NumPy.

        Each document is embedded with ``get_sentence_vector`` rather than
        ``get_word_vector``: the latter treats the whole text as one token
        instead of combining the vectors of its words.

        Args:
            model_path: Path of a saved fastText model to load.
            corpus: Mapping of document id to a dict that may contain
                ``"title"`` and ``"text"`` entries.
            embeddings_path: Destination passed to ``np.save``; rows follow
                the iteration order of ``corpus``.
        """
        model = fasttext.load_model(model_path)
        embeddings = [
            model.get_sentence_vector(self._document_text(doc))
            for doc in tqdm.tqdm(corpus.values())
        ]
        np.save(embeddings_path, np.array(embeddings))
        logger.info(f"Embeddings saved to {embeddings_path}")
class DatasetLoader:
    SUPPORTED_DATASETS = {
        "scifact": "train",
        "nfcorpus": "test",
        "cqadupstack": "test",
        "arguana": "test",
        "scidocs": "test",
        "fiqa": "test",
        "signal1m": "test",
        "fever": "test"
    }
    def __init__(self, dataset_path: str):
        self.dataset_path = Path(dataset_path)
    def load_dataset(self, dataset_name: str, split: str = "test") -> Tuple[Dict, Dict, Dict]:
        if dataset_name not in self.SUPPORTED_DATASETS:
            raise ValueError(f"Dataset {dataset_name} is not supported. "
                           f"Supported datasets: {list(self.SUPPORTED_DATASETS.keys())}")
        config = DatasetConfig(
            name=dataset_name,
            path=str(self.dataset_path),
            split=split
        )
        logger.info(f"Loading dataset: {dataset_name} from {config.path}")
        data_path = self._prepare_dataset_path(config)
        data_loader = GenericDataLoader(data_folder=data_path)
        return data_loader.load(split=config.split)
    def _prepare_dataset_path(self, config: DatasetConfig) -> str:
        out_dir = self.dataset_path
        if not out_dir.exists():
            raise ValueError(f"Dataset directory {out_dir} does not exist.")
        data_path = beir_util.download_and_unzip(config.url, str(out_dir))
        if config.name == "cqadupstack":
            return os.path.join(data_path, "english")
        return data_path
def clean_corpus(corpus: Dict[str, Dict[str, str]]) -> str:
    """Flatten a corpus into newline-separated ``"<title> <text>"`` entries.

    Missing ``title``/``text`` entries count as empty strings, and each
    combined entry is stripped of surrounding whitespace before joining.

    Args:
        corpus: Mapping of document id to a document dict.

    Returns:
        One string with an entry per document, in the mapping's order.
    """
    return "\n".join(
        f"{doc.get('title', '')} {doc.get('text', '')}".strip()
        for doc in corpus.values()
    )
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the embedding-generation script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as the parameterless version of
            this function did.

    Returns:
        Namespace with ``model``, ``datasets``, ``dataset_path``,
        ``output_dir``, ``model_path``, ``batch_size``, ``device`` and
        ``dim``.
    """
    parser = argparse.ArgumentParser(description='Generate embeddings for datasets')
    parser.add_argument('--model', type=str, default="mistral", help='Name of the embedding model to use (e.g., fasttext, sbert)')
    # NOTE: with nargs='+' supplied values arrive as a list, while the default
    # stays a bare string; main() normalises both shapes before iterating.
    parser.add_argument('--datasets', type=str, nargs='+', default="scifact",
                      help='List of dataset names to process')
    parser.add_argument('--dataset-path', type=str, default='data/raw/beir/',
                      help='Path to the raw datasets directory')
    parser.add_argument('--output-dir', type=str, default='data/processed/embeddings/',
                      help='Directory to save the generated embeddings')
    parser.add_argument('--model-path', type=str,
                      help='Path to the pre-trained model (if required)')
    parser.add_argument('--batch-size', type=int, default=32,
                      help='Batch size for embedding generation')
    parser.add_argument('--device', type=str, default='cuda',
                      help='Device to use for computation (cuda/cpu)')
    parser.add_argument('--dim', type=int, default=300,
                      help='Dimension of embeddings (for FastText)')
    return parser.parse_args(argv)
def main():
    """Generate and cache corpus embeddings for every requested dataset.

    Parses the command line, creates the output directory, builds one
    embedding generator, then for each dataset loads the "test" split and
    passes its ``"<title> <text>"`` strings to the generator.
    """
    args = parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    dataset_loader = DatasetLoader(args.dataset_path)
    embedding_generator = get_embedding_generator(
        model_name=args.model,
        dataset_name=args.datasets,
        cache_dir=output_dir,
        device=args.device,
        batch_size=args.batch_size,
    )

    # --datasets may be a bare string (its default) or a list of names.
    dataset_names = [args.datasets] if isinstance(args.datasets, str) else args.datasets

    for dataset_name in dataset_names:
        logger.info(f"Processing dataset: {dataset_name}")
        corpus, _queries, _qrels = dataset_loader.load_dataset(dataset_name, split="test")

        text_list = [
            f"{doc.get('title', '')} {doc.get('text', '')}"
            for _doc_id, doc in tqdm.tqdm(corpus.items(), desc="Preparing texts")
        ]

        embedding_generator.generate_embeddings(
            text_list,
            cache_key=f"corpus_embeddings_{args.model}_{dataset_name}.npy",
        )
        logger.info(f"Successfully generated embeddings for {dataset_name}")
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
