import os
import json
import numpy as np
import torch
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union, Any
from loguru import logger
from src.config.models import CrossTranslateConfig
def load_embedding_with_memmap(
    embedding_path: Union[str, Path],
    mmap_mode: str = 'r'
) -> np.ndarray:
    """Open a ``.npy`` embedding file through ``np.load`` (memory-mapped by default).

    Args:
        embedding_path: Location of the ``.npy`` file, as a string or ``Path``.
        mmap_mode: Forwarded unchanged to ``np.load``; the default ``'r'``
            maps the file read-only instead of reading it into RAM.

    Returns:
        Whatever ``np.load`` returns for the file with the given ``mmap_mode``.

    Raises:
        FileNotFoundError: If ``embedding_path`` does not exist.
    """
    path = Path(embedding_path)
    if path.exists():
        logger.info(f"Loading embedding with memmap from: {path}")
        array = np.load(str(path), mmap_mode=mmap_mode)
        logger.info(
            f"Successfully loaded embedding with shape: {array.shape}, dtype: {array.dtype}"
        )
        return array
    raise FileNotFoundError(f"Embedding file not found: {path}")
class MemmapEmbeddingDataset:
    """Memory-mapped source/target embeddings for one dataset plus cached splits.

    Corpus and query embeddings of ``source_model`` and ``target_model`` are
    opened with :func:`load_embedding_with_memmap` from
    ``config.dataset.embedding_path``. The split indices (``d0_index``,
    ``d1_index``, ``d2_index``) and the query-to-answer mapping (``q2a``) are
    read from ``<cache_path>/memmap_cache/<dataset>_<source>_<target>/metadata.json``
    when that file is complete and matches the requested ``d0_ratio`` /
    ``split_strategy`` and the current embedding sizes. Otherwise they are
    regenerated and the file is rewritten.
    """

    def __init__(
        self,
        config: CrossTranslateConfig,
        dataset_name: str,
        d0_ratio: float = 1/3,
        split_strategy: str = "random",
        source_model: Optional[str] = None,
        target_model: Optional[str] = None,
    ):
        """Open the embeddings and load or build the split metadata.

        Args:
            config: Project config. ``config.dataset.embedding_path`` and
                ``config.dataset.cache_path`` are always read.
                ``config.model.source_model`` / ``config.model.target_model``
                are read only when the corresponding argument is ``None``.
            dataset_name: Dataset identifier used in embedding file names and
                in the cache directory name.
            d0_ratio: Forwarded to ``split_data`` when splits are generated.
            split_strategy: Forwarded to ``split_data`` as ``strategy``.
            source_model: Overrides ``config.model.source_model`` when given.
            target_model: Overrides ``config.model.target_model`` when given.

        Raises:
            FileNotFoundError: If one of the four embedding files is missing.
        """
        self.config = config
        self.dataset_name = dataset_name
        self.d0_ratio = d0_ratio
        self.split_strategy = split_strategy
        self.embedding_path = config.dataset.embedding_path
        self.cache_dir = config.dataset.cache_path
        self.source_model = config.model.source_model if source_model is None else source_model
        self.target_model = config.model.target_model if target_model is None else target_model
        self.cache_dataset_dir = os.path.join(
            self.cache_dir,
            "memmap_cache",
            f"{dataset_name}_{self.source_model}_{self.target_model}"
        )
        Path(self.cache_dataset_dir).mkdir(parents=True, exist_ok=True)
        self.metadata_path = os.path.join(self.cache_dataset_dir, "metadata.json")
        if self._check_cache_exists():
            logger.info(f"Loading from cache: {self.cache_dataset_dir}")
            self._load_from_cache()
        else:
            logger.info("Cache not found, generating new data...")
            self._load_embeddings()
            self._generate_and_save_metadata()

    def _load_embeddings(self) -> None:
        """Open corpus/query embeddings for both models and record their widths.

        Query arrays of different lengths are truncated to the shorter one.
        A corpus length mismatch is only logged.
        """
        logger.info(f"Loading embeddings for dataset: {self.dataset_name}")
        self.source_embeddings = self._load_embedding('corpus', self.source_model)
        self.target_embeddings = self._load_embedding('corpus', self.target_model)
        self.source_query_embeddings = self._load_embedding('query', self.source_model)
        self.target_query_embeddings = self._load_embedding('query', self.target_model)

        n_src_queries = len(self.source_query_embeddings)
        n_tgt_queries = len(self.target_query_embeddings)
        if n_src_queries != n_tgt_queries:
            min_len = min(n_src_queries, n_tgt_queries)
            logger.warning(
                f"Query size mismatch: source={n_src_queries}, "
                f"target={n_tgt_queries}, using minimum: {min_len}"
            )
            self.source_query_embeddings = self.source_query_embeddings[:min_len]
            self.target_query_embeddings = self.target_query_embeddings[:min_len]

        if len(self.source_embeddings) != len(self.target_embeddings):
            # Source and target corpus rows are paired by index (e.g. by
            # MultiMemmapDatasetLoader.load_batch), so differing lengths
            # indicate the two files are not aligned.
            logger.warning(
                f"Corpus size mismatch: source={len(self.source_embeddings)}, "
                f"target={len(self.target_embeddings)}"
            )

        self.source_embedding_dim = self.source_embeddings.shape[1]
        self.target_embedding_dim = self.target_embeddings.shape[1]
        logger.info(
            f"Loaded embeddings: corpus={self.source_embeddings.shape}, "
            f"query={self.source_query_embeddings.shape}"
        )

    def _load_embedding(self, emb_type: str, model: str) -> np.ndarray:
        """Open ``<embedding_path>/<emb_type>_embeddings_<model>_<dataset>.npy`` via memmap."""
        return load_embedding_with_memmap(
            self.get_embedding_path(emb_type, model, self.dataset_name, self.embedding_path)
        )

    @classmethod
    def get_embedding_path(cls, emb_type: str, model: str, dataset_name: str, embedding_path: str) -> str:
        """Return the path of the ``.npy`` file for ``emb_type`` ('corpus'/'query'), ``model`` and ``dataset_name``."""
        return os.path.join(embedding_path, f"{emb_type}_embeddings_{model}_{dataset_name}.npy")

    def _check_cache_exists(self) -> bool:
        """Return True if ``metadata.json`` is readable, complete and built with the current split parameters.

        Read/parse errors are logged and treated as "no usable cache", so the
        metadata gets regenerated instead of raising.
        """
        if not os.path.exists(self.metadata_path):
            return False
        try:
            with open(self.metadata_path, 'r') as f:
                metadata = json.load(f)
        except Exception as e:
            logger.warning(f"Error reading cache metadata: {e}")
            return False
        if not isinstance(metadata, dict):
            logger.warning(
                f"Error reading cache metadata: expected a JSON object, "
                f"got {type(metadata).__name__}"
            )
            return False
        required_fields = ['d0_index', 'd1_index', 'd2_index', 'q2a',
                           'corpus_size', 'query_size']
        for field in required_fields:
            if field not in metadata:
                logger.warning(f"Cache missing field: {field}")
                return False
        return self._split_params_match(metadata)

    def _split_params_match(self, metadata: Dict[str, Any]) -> bool:
        """Return False if the cached splits were built with a different ``d0_ratio`` or ``split_strategy``.

        The cache directory name does not encode these parameters. Without
        this check, constructing the dataset with a new ratio or strategy
        would silently reuse splits made for the old ones. A parameter absent
        from the metadata is not treated as a mismatch.
        """
        import math

        cached_ratio = metadata.get('d0_ratio')
        cached_strategy = metadata.get('split_strategy')
        ratio_ok = True
        if cached_ratio is not None:
            try:
                ratio_ok = math.isclose(float(cached_ratio), float(self.d0_ratio))
            except (TypeError, ValueError):
                ratio_ok = False
        strategy_ok = cached_strategy is None or cached_strategy == self.split_strategy
        if ratio_ok and strategy_ok:
            return True
        logger.warning(
            f"Cached splits use d0_ratio={cached_ratio}, split_strategy={cached_strategy}, "
            f"but d0_ratio={self.d0_ratio}, split_strategy={self.split_strategy} was requested; "
            f"regenerating metadata"
        )
        return False

    def _load_from_cache(self) -> None:
        """Open the embeddings and restore the splits and ``q2a`` from ``metadata.json``.

        If the cached corpus/query sizes no longer match the embedding files
        (the cached indices would then be stale), the metadata is rebuilt via
        ``_generate_and_save_metadata`` instead.
        """
        self._load_embeddings()
        logger.info(f"Loading metadata from cache: {self.metadata_path}")
        with open(self.metadata_path, 'r') as f:
            metadata = json.load(f)

        corpus_size = len(self.source_embeddings)
        query_size = len(self.source_query_embeddings)
        if int(metadata['corpus_size']) != corpus_size or int(metadata['query_size']) != query_size:
            logger.warning(
                f"Cached sizes (corpus={metadata['corpus_size']}, query={metadata['query_size']}) "
                f"do not match embeddings (corpus={corpus_size}, query={query_size}); "
                f"regenerating metadata"
            )
            self._generate_and_save_metadata()
            return

        self.d0_index = np.array(metadata['d0_index'], dtype=np.int64)
        self.d1_index = np.array(metadata['d1_index'], dtype=np.int64)
        self.d2_index = np.array(metadata['d2_index'], dtype=np.int64)
        # JSON object keys are always strings; restore the integer query ids.
        self.q2a = {int(k): v for k, v in metadata['q2a'].items()}
        logger.info(
            f"Loaded from cache - d0: {len(self.d0_index)}, "
            f"d1: {len(self.d1_index)}, d2: {len(self.d2_index)}, "
            f"q2a: {len(self.q2a)} queries"
        )

    def _generate_and_save_metadata(self) -> None:
        """Build ``q2a`` and the d0/d1/d2 splits, then persist them to ``metadata.json``.

        ``q2a`` comes from ``vectormerge.dataset.load_dataset(..., split="test")``.
        If that fails, the failure is logged and an empty mapping is used. The
        splits come from ``split_data`` with every answer index in ``q2a`` as
        ``p_index_list``. The JSON is written to a temporary file in the cache
        directory and moved into place with ``os.replace``, so an interrupted
        write does not leave a truncated ``metadata.json`` behind.
        """
        logger.info("Generating metadata (q2a, d0, d1, d2)...")
        try:
            from vectormerge.dataset import load_dataset
            text_dataset = load_dataset(self.dataset_name, split="test")
            self.q2a = text_dataset.query_index2answer_index
            logger.info(f"Loaded q2a mapping: {len(self.q2a)} queries")
        except Exception as e:
            # Best effort: continue with an empty mapping rather than failing.
            logger.warning(f"Could not load q2a mapping: {e}")
            self.q2a = {}
        from src.data_splitting.split_functions import split_data

        # An answer entry may be a single index or a list of indices.
        p_index_list = []
        for answer_indices in self.q2a.values():
            if isinstance(answer_indices, list):
                p_index_list.extend(answer_indices)
            else:
                p_index_list.append(answer_indices)
        logger.info(
            f"Generating data splits with d0_ratio={self.d0_ratio}, "
            f"strategy={self.split_strategy}"
        )
        d0_index, d1_index, d2_index = split_data(
            corpus_size=len(self.source_embeddings),
            p_index_list=p_index_list,
            d0_ratio=self.d0_ratio,
            corpus_emb_2=self.target_embeddings,
            query_emb_2=self.target_query_embeddings,
            strategy=self.split_strategy
        )
        # Normalise to int64 arrays so freshly generated splits have the same
        # type as the ones restored by _load_from_cache.
        self.d0_index = np.asarray(d0_index, dtype=np.int64)
        self.d1_index = np.asarray(d1_index, dtype=np.int64)
        self.d2_index = np.asarray(d2_index, dtype=np.int64)
        logger.info(
            f"Data splits generated: d0={len(self.d0_index)}, "
            f"d1={len(self.d1_index)}, d2={len(self.d2_index)}"
        )
        metadata = {
            'dataset_name': self.dataset_name,
            'source_model': self.source_model,
            'target_model': self.target_model,
            'corpus_size': int(len(self.source_embeddings)),
            'query_size': int(len(self.source_query_embeddings)),
            'source_corpus_dim': int(self.source_embedding_dim),
            'target_corpus_dim': int(self.target_embedding_dim),
            'd0_index': self.d0_index.tolist(),
            'd1_index': self.d1_index.tolist(),
            'd2_index': self.d2_index.tolist(),
            'q2a': {str(k): v for k, v in self.q2a.items()},
            'd0_ratio': float(self.d0_ratio),
            'split_strategy': self.split_strategy,
            'cache_version': '1.0'
        }
        logger.info(f"Saving metadata to cache: {self.metadata_path}")
        tmp_path = f"{self.metadata_path}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, 'w') as f:
                json.dump(metadata, f, indent=2)
            os.replace(tmp_path, self.metadata_path)
        except BaseException:
            # Don't leave a half-written temp file behind; propagate the error.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        logger.info(f"✓ Cache saved successfully to: {self.cache_dataset_dir}")

    def clear_cache(self) -> None:
        """Delete this dataset's cache directory (in-memory attributes are left untouched)."""
        import shutil
        if os.path.exists(self.cache_dataset_dir):
            shutil.rmtree(self.cache_dataset_dir)
            logger.info(f"Cache cleared: {self.cache_dataset_dir}")
        else:
            logger.info("No cache to clear")

    def get_embedding_info(self) -> Dict[str, Any]:
        """Return a summary dict of array shapes/dtypes, split sizes and query count."""
        return {
            'dataset_name': self.dataset_name,
            'cache_dir': self.cache_dataset_dir,
            'source_embeddings_shape': self.source_embeddings.shape,
            'target_embeddings_shape': self.target_embeddings.shape,
            'source_query_embeddings_shape': self.source_query_embeddings.shape,
            'target_query_embeddings_shape': self.target_query_embeddings.shape,
            'source_dtype': self.source_embeddings.dtype,
            'target_dtype': self.target_embeddings.dtype,
            'd0_size': len(self.d0_index),
            'd1_size': len(self.d1_index),
            'd2_size': len(self.d2_index),
            'num_queries': len(self.q2a)
        }
class MultiMemmapDatasetLoader:
    """Batch paired (source, target) corpus embeddings across several datasets.

    The corpora of all datasets form one concatenated index space.
    ``offsets[i]`` is the global index of dataset ``i``'s first row, and
    ``offsets[-1] == total_samples``. Batches are contiguous ranges of that
    space. With ``shuffle=True`` only the order of batches is shuffled, not
    the rows inside a batch.
    """

    def __init__(
        self,
        config: CrossTranslateConfig,
        dataset_names: List[str],
        batch_size: int = 10000,
        shuffle: bool = False,
        num_workers: int = 0,
        device: str = 'cuda',
        pin_memory: bool = True,
        d0_ratio: float = 1/3,
        split_strategy: str = "random"
    ):
        """Open every dataset and check that they can be batched together.

        Args:
            config: Project config forwarded to each ``MemmapEmbeddingDataset``.
            dataset_names: Datasets to concatenate, in global-index order.
            batch_size: Rows per batch for ``__iter__`` and the default
                ``load_batch`` range.
            shuffle: Shuffle the order of batches on each iteration.
            num_workers: Accepted for API compatibility; not used.
            device: Device that yielded tensors are moved to. Replaced by
                ``'cpu'`` when CUDA is not available.
            pin_memory: Pin host tensors before transfer. Only honoured when
                the resolved device is a CUDA device.
            d0_ratio: Forwarded to each ``MemmapEmbeddingDataset``.
            split_strategy: Forwarded to each ``MemmapEmbeddingDataset``.

        Raises:
            ValueError: If ``dataset_names`` is empty, a dataset's target
                corpus has fewer rows than its source corpus, or embedding
                widths differ between datasets.
        """
        if not dataset_names:
            raise ValueError("dataset_names must contain at least one dataset")
        self.config = config
        self.dataset_names = dataset_names
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers  # stored for introspection only; unused
        self.device = device if torch.cuda.is_available() else 'cpu'
        # Compare the device *type* so 'cuda:1' or torch.device('cuda') also
        # qualify for pinned-memory transfers.
        on_cuda = torch.device(self.device).type == 'cuda'
        self.pin_memory = bool(pin_memory) and on_cuda and torch.cuda.is_available()
        logger.info(f"Loading {len(dataset_names)} datasets with memmap...")
        self.datasets: List[MemmapEmbeddingDataset] = []
        self.offsets: List[int] = [0]
        for name in dataset_names:
            dataset = MemmapEmbeddingDataset(
                config=config,
                dataset_name=name,
                d0_ratio=d0_ratio,
                split_strategy=split_strategy
            )
            n_src = len(dataset.source_embeddings)
            n_tgt = len(dataset.target_embeddings)
            if n_tgt < n_src:
                # Rows are paired by index; a shorter target cannot supply a
                # partner for every source row.
                raise ValueError(
                    f"Dataset '{name}': target corpus has {n_tgt} rows but source "
                    f"corpus has {n_src}; rows cannot be paired"
                )
            self.datasets.append(dataset)
            self.offsets.append(self.offsets[-1] + n_src)
        self.total_samples = self.offsets[-1]
        first_dataset = self.datasets[0]
        self.source_embedding_dim = first_dataset.source_embedding_dim
        self.target_embedding_dim = first_dataset.target_embedding_dim
        for dataset in self.datasets[1:]:
            if dataset.source_embedding_dim != self.source_embedding_dim:
                raise ValueError(
                    f"Source embedding dimension mismatch: "
                    f"{dataset.source_embedding_dim} vs {self.source_embedding_dim}"
                )
            if dataset.target_embedding_dim != self.target_embedding_dim:
                raise ValueError(
                    f"Target embedding dimension mismatch: "
                    f"{dataset.target_embedding_dim} vs {self.target_embedding_dim}"
                )
        logger.info(
            f"Initialized MultiMemmapDatasetLoader: "
            f"{len(self.datasets)} datasets, {self.total_samples} total samples, "
            f"dims: {self.source_embedding_dim} → {self.target_embedding_dim}"
        )

    def _find_dataset_idx(self, global_idx: int) -> Tuple[int, int]:
        """Map a global row index to ``(dataset_index, local_row_index)``.

        Raises:
            IndexError: If ``global_idx`` is outside ``[0, total_samples)``.
        """
        import bisect

        if not 0 <= global_idx < self.total_samples:
            raise IndexError(f"Global index {global_idx} out of range [0, {self.total_samples})")
        # bisect_right skips zero-length datasets whose offset equals global_idx.
        ds_idx = bisect.bisect_right(self.offsets, global_idx) - 1
        return ds_idx, global_idx - self.offsets[ds_idx]

    def _gather(
        self,
        start_ds_idx: int,
        start_local: int,
        end_ds_idx: int,
        end_local: int,
        use_target: bool,
    ) -> np.ndarray:
        """Collect rows from ``datasets[start_ds_idx][start_local:]`` through ``datasets[end_ds_idx][:end_local]``."""
        parts = []
        for ds_idx in range(start_ds_idx, end_ds_idx + 1):
            dataset = self.datasets[ds_idx]
            data = dataset.target_embeddings if use_target else dataset.source_embeddings
            lo = start_local if ds_idx == start_ds_idx else 0
            # Bound by the dataset's extent in the global index space (its
            # source row count) so source and target slices cover the same rows.
            hi = end_local if ds_idx == end_ds_idx else self.offsets[ds_idx + 1] - self.offsets[ds_idx]
            parts.append(data[lo:hi])
        if len(parts) == 1:
            return parts[0]
        return np.concatenate(parts, axis=0)

    def load_batch(
        self,
        start_idx: int,
        end_idx: Optional[int] = None,
        return_target: bool = True
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Return global rows ``[start_idx, end_idx)`` of the concatenated corpus.

        Args:
            start_idx: First global row (inclusive).
            end_idx: End global row (exclusive). Defaults to
                ``start_idx + batch_size``. Clamped to ``total_samples``.
            return_target: When False only the source rows are returned.

        Returns:
            ``(source_rows, target_rows)``, or just ``source_rows``. A range
            within one dataset is a slice of that dataset's array (no copy). A
            range spanning datasets is a new array built by ``np.concatenate``.

        Raises:
            IndexError: If ``start_idx`` is outside ``[0, total_samples)``.
            ValueError: If the clamped range is empty.
        """
        if end_idx is None:
            end_idx = start_idx + self.batch_size
        end_idx = min(end_idx, self.total_samples)
        start_ds_idx, start_local = self._find_dataset_idx(start_idx)
        if end_idx <= start_idx:
            raise ValueError(f"Empty batch range [{start_idx}, {end_idx})")
        end_ds_idx, last_local = self._find_dataset_idx(end_idx - 1)
        end_local = last_local + 1

        src_data = self._gather(start_ds_idx, start_local, end_ds_idx, end_local, use_target=False)
        if not return_target:
            return src_data
        tgt_data = self._gather(start_ds_idx, start_local, end_ds_idx, end_local, use_target=True)
        return src_data, tgt_data

    def _to_device(self, batch: np.ndarray) -> torch.Tensor:
        """Convert a NumPy batch to a tensor on ``self.device``, pinning first when enabled."""
        # Slices of read-only memmaps are non-writable. torch.from_numpy would
        # warn and return a tensor aliasing read-only memory (undefined
        # behaviour on write; on CPU, `.to()` does not copy). Copy those first.
        if not batch.flags.writeable:
            batch = np.array(batch)
        tensor = torch.from_numpy(batch)
        if self.pin_memory:
            tensor = tensor.pin_memory()
        return tensor.to(self.device, non_blocking=True)

    def __iter__(self):
        """Yield ``(source, target)`` tensor batches on ``self.device``.

        Batches cover ``[0, total_samples)`` in ``batch_size`` steps (the last
        may be shorter). Their order is shuffled with ``np.random.shuffle``
        when ``shuffle`` is set.
        """
        batch_starts = list(range(0, self.total_samples, self.batch_size))
        if self.shuffle:
            np.random.shuffle(batch_starts)
        for start_idx in batch_starts:
            end_idx = min(start_idx + self.batch_size, self.total_samples)
            src_batch, tgt_batch = self.load_batch(start_idx, end_idx, return_target=True)
            yield self._to_device(src_batch), self._to_device(tgt_batch)

    def get_dataset_info(self) -> Dict[str, Any]:
        """Return totals, embedding widths, batch size and a per-dataset summary."""
        return {
            'num_datasets': len(self.datasets),
            'total_samples': self.total_samples,
            'source_embedding_dim': self.source_embedding_dim,
            'target_embedding_dim': self.target_embedding_dim,
            'batch_size': self.batch_size,
            'datasets': [
                {
                    'name': dataset.dataset_name,
                    'samples': len(dataset.source_embeddings),
                    'offset': self.offsets[i],
                    'cache_dir': dataset.cache_dataset_dir,
                    'd0_size': len(dataset.d0_index),
                    'd1_size': len(dataset.d1_index),
                    'd2_size': len(dataset.d2_index)
                }
                for i, dataset in enumerate(self.datasets)
            ]
        }

    def get_dataset(self, dataset_name: str) -> MemmapEmbeddingDataset:
        """Return the first loaded dataset named ``dataset_name``.

        Raises:
            ValueError: If no loaded dataset has that name.
        """
        for dataset in self.datasets:
            if dataset.dataset_name == dataset_name:
                return dataset
        raise ValueError(f"Dataset '{dataset_name}' not found")
