import bisect
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Tuple, Optional, Union

import numpy as np
import torch
from loguru import logger
from rich.progress import Progress

from src.config.models import CrossTranslateConfig

class ChunkedDatasetPreprocessor:
    """Split paired source/target embedding files into aligned chunk files.

    For each dataset, the source-model and target-model embedding arrays are cut
    into the same row ranges. Each range is sized to roughly ``chunk_size_mb``
    MiB per file. The chunks are written under
    ``<config.dataset.cache_path>/<dataset_name>/`` together with a
    ``<dataset_name>_metadata.json`` index that describes every chunk.
    """

    def __init__(
        self,
        chunk_size_mb: float = 400.0,
        dtype: np.dtype = np.float32,
        num_workers: int = 4,
        config: Optional["CrossTranslateConfig"] = None
    ):
        """Store chunking parameters and the model names used in chunk file names.

        Args:
            chunk_size_mb: Target size of one chunk file, in MiB.
            dtype: Expected element dtype. It is used to interpret headerless raw
                files, and its itemsize is the default element size for
                ``calculate_chunk_size``.
            num_workers: Stored on the instance; not used by this class.
            config: Project config. It is required despite the ``None`` default.
                ``model.source_model``, ``model.target_model``,
                ``dataset.cache_path`` and ``dataset.embedding_path`` are read.

        Raises:
            ValueError: If ``config`` is ``None``.
        """
        if config is None:
            raise ValueError("ChunkedDatasetPreprocessor requires a config (got None)")
        self.chunk_size_mb = chunk_size_mb
        self.dtype = dtype
        self.num_workers = num_workers
        self.element_size_bytes = np.dtype(dtype).itemsize
        self.config = config
        self.source_embedding_model_name = config.model.source_model
        self.target_embedding_model_name = config.model.target_model

    def _get_metadata_path(self, chunk_directory_name: str, dataset_name: str) -> str:
        """Return the path of ``<dataset_name>_metadata.json`` inside ``chunk_directory_name``."""
        return os.path.join(chunk_directory_name, f"{dataset_name}_metadata.json")

    def check_chunks_exist(self, dataset_names: List[str]) -> bool:
        """Return True if every dataset has a chunk directory containing its metadata file.

        Only the directory and the metadata file are checked. The individual chunk
        files listed in the metadata are not verified.
        """
        for dataset_name in dataset_names:
            chunk_directory_name = self._get_chunk_directory_name(dataset_name)
            if not os.path.exists(chunk_directory_name):
                return False
            if not os.path.exists(self._get_metadata_path(chunk_directory_name, dataset_name)):
                return False
        return True

    def _get_chunk_directory_name(self, dataset_name: str) -> str:
        """Return ``<config.dataset.cache_path>/<dataset_name>``."""
        return os.path.join(self.config.dataset.cache_path, dataset_name)

    def load_metadata(self, dataset_names: List[str]) -> List[Dict[str, Any]]:
        """Read and return the parsed metadata JSON of each dataset, in order.

        Raises:
            FileNotFoundError: If a dataset's metadata file does not exist.
        """
        metadata_list = []
        for dataset_name in dataset_names:
            chunk_directory_name = self._get_chunk_directory_name(dataset_name)
            metadata_path = self._get_metadata_path(chunk_directory_name, dataset_name)
            if not os.path.exists(metadata_path):
                raise FileNotFoundError(f"Metadata file not found: {metadata_path}")
            with open(metadata_path, 'r') as f:
                metadata_list.append(json.load(f))
        return metadata_list

    def calculate_chunk_size(
        self,
        total_samples: int,
        embedding_dim: int,
        element_size_bytes: Optional[int] = None
    ) -> int:
        """Return how many rows fit in one chunk of ``chunk_size_mb`` MiB (at least 1).

        Args:
            total_samples: Not used by the computation. It is kept for signature
                compatibility.
            embedding_dim: Number of elements per row.
            element_size_bytes: Bytes per element. Defaults to the configured
                dtype's itemsize. Pass the file's real itemsize when it differs.
        """
        item_bytes = self.element_size_bytes if element_size_bytes is None else element_size_bytes
        # max(1, ...) avoids a ZeroDivisionError for zero-width rows.
        bytes_per_sample = max(1, embedding_dim * item_bytes)
        target_chunk_bytes = self.chunk_size_mb * 1024 * 1024
        samples_per_chunk = max(1, int(target_chunk_bytes / bytes_per_sample))
        logger.info(
            f"Calculated chunk size: {samples_per_chunk} samples/chunk "
            f"(~{samples_per_chunk * bytes_per_sample / 1024 / 1024:.2f} MB)"
        )
        return samples_per_chunk

    def _load_memmap_with_shape(self, file_path: str) -> Tuple[np.ndarray, int, int, np.dtype]:
        """Open ``file_path`` read-only as a memory map.

        ``.npy`` files are opened with ``np.load(mmap_mode='r')``. This handles
        every header version and honours ``fortran_order``; the previous
        hand-rolled header parsing only read v1.0 headers, ignored Fortran order,
        and crashed when it failed. Any other file is treated as a headerless flat
        array of the configured dtype.

        Returns:
            ``(memmap, total_samples, embedding_dim, dtype)``. ``total_samples`` is
            the size of axis 0. ``embedding_dim`` is 1 for 1-D data, otherwise the
            size of axis 1.
        """
        if file_path.endswith('.npy'):
            memmap = np.load(file_path, mmap_mode='r', allow_pickle=False)
            actual_dtype = memmap.dtype
            if actual_dtype != self.dtype:
                logger.warning(f"File dtype ({actual_dtype}) differs from configured dtype ({self.dtype}), using file dtype")
        else:
            actual_dtype = np.dtype(self.dtype)
            memmap = np.memmap(file_path, dtype=actual_dtype, mode='r')
        total_samples = memmap.shape[0]
        embedding_dim = 1 if memmap.ndim == 1 else memmap.shape[1]
        return memmap, total_samples, embedding_dim, actual_dtype

    def _save_chunk(
        self,
        source_memmap: np.ndarray,
        start_idx: int,
        end_idx: int,
        output_path: str,
        embedding_dim: int
    ) -> Tuple[int, ...]:
        """Write rows ``[start_idx, end_idx)`` of ``source_memmap`` to ``output_path`` as ``.npy``.

        ``embedding_dim`` is unused. It is kept for signature compatibility.
        Returns the shape of the written array.
        """
        # Slicing along axis 0 works for any rank, so 1-D and 2-D need no separate branches.
        chunk_data = source_memmap[start_idx:end_idx]
        np.save(output_path, chunk_data)
        return chunk_data.shape

    def _build_metadata(
        self,
        dataset_name: str,
        original_shape: Tuple,
        total_samples: int,
        embedding_dim: int,
        samples_per_chunk: int,
        num_chunks: int,
        chunks_metadata: List[Dict[str, Any]]
    ) -> Dict[str, Any]:
        """Build a single-array metadata dict.

        It records the configured dtype, not the file's dtype. It is not used by
        ``split_paired_embedding_files``.
        """
        return {
            "dataset_name": dataset_name,
            "original_shape": original_shape,
            "total_samples": total_samples,
            "embedding_dim": embedding_dim,
            "dtype": str(self.dtype),
            "samples_per_chunk": samples_per_chunk,
            "num_chunks": num_chunks,
            "chunks": chunks_metadata
        }

    def _metadata_path(self, dataset_name: str) -> str:
        """Return the metadata JSON path for ``dataset_name`` inside its chunk directory."""
        return self._get_metadata_path(self._get_chunk_directory_name(dataset_name), dataset_name)

    def _embedding_input_path(self, model_name: str, dataset_name: str) -> str:
        """Return ``<embedding_path>/corpus_embeddings_<model>_<dataset>.npy``."""
        return os.path.join(
            self.config.dataset.embedding_path,
            f"corpus_embeddings_{model_name}_{dataset_name}.npy"
        )

    def split_datasets(self, dataset_names: List[str]) -> None:
        """Split the source and target embedding files of every dataset into chunks."""
        for dataset_name in dataset_names:
            self.split_paired_embedding_files(
                source_input_path=self._embedding_input_path(self.config.model.source_model, dataset_name),
                target_input_path=self._embedding_input_path(self.config.model.target_model, dataset_name),
                output_dir=self._get_chunk_directory_name(dataset_name),
                dataset_name=dataset_name
            )

    def split_paired_embedding_files(
        self,
        source_input_path: str,
        target_input_path: str,
        output_dir: str,
        dataset_name: str
    ) -> Dict[str, Any]:
        """Split a source/target embedding pair into row-aligned chunk files.

        Chunk ``k`` of the source and chunk ``k`` of the target always cover the
        same row range. The metadata JSON is written last, and atomically. A
        crash therefore never leaves a half-written metadata file, which
        ``check_chunks_exist`` would otherwise mistake for a finished split.

        Returns:
            The metadata dict that was written to ``<dataset_name>_metadata.json``.

        Raises:
            ValueError: If the two files have different row counts.
        """
        logger.info(f"Starting to split paired files: {source_input_path} and {target_input_path}")
        output_dir_path = Path(output_dir)
        output_dir_path.mkdir(parents=True, exist_ok=True)
        source_memmap, source_total_samples, source_embedding_dim, source_dtype = self._load_memmap_with_shape(source_input_path)
        target_memmap, target_total_samples, target_embedding_dim, target_dtype = self._load_memmap_with_shape(target_input_path)
        if source_total_samples != target_total_samples:
            raise ValueError(
                f"Sample count mismatch: source={source_total_samples}, target={target_total_samples}"
            )
        total_samples = source_total_samples
        # Size chunks by the wider row (dim * real itemsize), so that neither the
        # source nor the target chunk file exceeds chunk_size_mb.
        sizing_dim, sizing_itemsize = max(
            (source_embedding_dim, np.dtype(source_dtype).itemsize),
            (target_embedding_dim, np.dtype(target_dtype).itemsize),
            key=lambda dim_and_size: dim_and_size[0] * dim_and_size[1],
        )
        samples_per_chunk = self.calculate_chunk_size(total_samples, sizing_dim, sizing_itemsize)
        num_chunks = (total_samples + samples_per_chunk - 1) // samples_per_chunk
        logger.info(
            f"Total samples: {total_samples}, Source dim: {source_embedding_dim}, "
            f"Target dim: {target_embedding_dim}, Chunks: {num_chunks}"
        )
        source_chunks_metadata: List[Dict[str, Any]] = []
        target_chunks_metadata: List[Dict[str, Any]] = []
        with Progress() as progress:
            task = progress.add_task(f"[green]Splitting {dataset_name} (src+tgt)...", total=num_chunks)
            for chunk_idx in range(num_chunks):
                start_idx = chunk_idx * samples_per_chunk
                end_idx = min(start_idx + samples_per_chunk, total_samples)
                source_chunk_path = output_dir_path / f"corpus_embeddings_{self.source_embedding_model_name}_{dataset_name}_chunk_{chunk_idx:04d}.npy"
                target_chunk_path = output_dir_path / f"corpus_embeddings_{self.target_embedding_model_name}_{dataset_name}_chunk_{chunk_idx:04d}.npy"
                source_shape = self._save_chunk(source_memmap, start_idx, end_idx, str(source_chunk_path), source_embedding_dim)
                target_shape = self._save_chunk(target_memmap, start_idx, end_idx, str(target_chunk_path), target_embedding_dim)
                for chunks_list, shape, path in (
                    (source_chunks_metadata, source_shape, source_chunk_path),
                    (target_chunks_metadata, target_shape, target_chunk_path),
                ):
                    chunks_list.append({
                        "chunk_idx": chunk_idx,
                        "file_path": str(path),
                        "start_idx": start_idx,
                        "end_idx": end_idx,
                        "shape": shape,
                        "num_samples": shape[0]
                    })
                progress.update(task, advance=1)
        combined_metadata = {
            "dataset_name": dataset_name,
            "total_samples": total_samples,
            "samples_per_chunk": samples_per_chunk,
            "num_chunks": num_chunks,
            "source": {
                "original_shape": source_memmap.shape,
                "embedding_dim": source_embedding_dim,
                "dtype": str(source_dtype),
                "chunks": source_chunks_metadata
            },
            "target": {
                "original_shape": target_memmap.shape,
                "embedding_dim": target_embedding_dim,
                "dtype": str(target_dtype),
                "chunks": target_chunks_metadata
            },
        }
        metadata_path = self._get_metadata_path(output_dir, dataset_name)
        tmp_metadata_path = f"{metadata_path}.tmp"
        with open(tmp_metadata_path, 'w') as f:
            json.dump(combined_metadata, f, indent=2)
        # os.replace is atomic on POSIX and Windows when both paths are on the same filesystem.
        os.replace(tmp_metadata_path, metadata_path)
        logger.info(f"Successfully split paired files into {num_chunks} chunks. Metadata: {metadata_path}")
        return combined_metadata
class MultiChunkedDatasetLoader:
    """Batch loader over several chunked, paired source/target embedding datasets.

    The datasets in ``dataset_names`` are laid end to end in one global index
    space ``[0, total_samples)``. ``offsets[i]`` is the global index of the first
    row of dataset ``i``, and ``offsets[-1] == total_samples``. Chunk files are
    created on demand by ``ChunkedDatasetPreprocessor`` and located through its
    per-dataset metadata JSON.
    """

    def __init__(
        self,
        config: "CrossTranslateConfig",
        dataset_names: List[str],
        batch_size: int = 10000,
        shuffle: bool = False,
        num_workers: int = 4,
        use_memmap: bool = False,
        device: str = 'cuda',
        pin_memory: bool = True
    ):
        """Create the loader, splitting only those datasets that have no chunk cache yet.

        Args:
            config: Project config, passed to ``ChunkedDatasetPreprocessor``.
            dataset_names: Datasets to concatenate, in this order.
            batch_size: Default rows per batch for ``load_batch`` / iteration.
            shuffle: Shuffle the batch order in ``__iter__``.
            num_workers: Forwarded to the preprocessor; not otherwise used here.
            use_memmap: Open chunk files with ``mmap_mode='r'`` instead of
                reading each file fully.
            device: Torch device that ``__iter__`` moves batches to.
            pin_memory: Pin host tensors before transfer. Only honoured for CUDA
                devices.
        """
        self.config = config
        self.dataset_names = dataset_names
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.use_memmap = use_memmap
        self.device = device
        # Pinning only helps host->GPU copies. Accept 'cuda', 'cuda:1', torch.device('cuda'), ...
        self.pin_memory = pin_memory and str(device).startswith('cuda')
        self.shuffle = shuffle
        self.preprocessor = ChunkedDatasetPreprocessor(
            config=config,
            chunk_size_mb=400.0,
            num_workers=num_workers
        )
        # Split only the datasets whose cache is missing, instead of re-splitting all of them.
        missing = [name for name in dataset_names if not self.preprocessor.check_chunks_exist([name])]
        if missing:
            self.preprocessor.split_datasets(missing)
        self._parse_metadata(self.preprocessor.load_metadata(dataset_names))

    def _parse_metadata(self, metadata_list: List[Dict[str, Any]]) -> None:
        """Build per-dataset loader dicts and global offsets from paired metadata.

        Raises:
            ValueError: If an entry lacks ``source``/``target`` sections, or if
                embedding dimensions differ between datasets.
        """
        self.dataset_loaders: List[Dict[str, Any]] = []
        self.target_dataset_loaders: List[Dict[str, Any]] = []
        self.offsets: List[int] = [0]
        self.source_embedding_dim: Optional[int] = None
        self.target_embedding_dim: Optional[int] = None
        self.has_target = False
        for meta in metadata_list:
            current_offset = self.offsets[-1]
            if 'source' not in meta or 'target' not in meta:
                raise ValueError(f"Invalid metadata format for dataset '{meta.get('dataset_name', 'unknown')}'")
            self.has_target = True
            source_loader = self._create_loader_dict(
                dataset_name=meta['dataset_name'],
                total_samples=meta['total_samples'],
                data_meta=meta['source'],
                offset=current_offset
            )
            self.dataset_loaders.append(source_loader)
            if self.source_embedding_dim is None:
                self.source_embedding_dim = source_loader['embedding_dim']
            elif self.source_embedding_dim != source_loader['embedding_dim']:
                raise ValueError(
                    f"Source embedding dimension mismatch: {self.source_embedding_dim} vs {source_loader['embedding_dim']}"
                )
            target_loader = self._create_loader_dict(
                dataset_name=meta['dataset_name'],
                total_samples=meta['total_samples'],
                data_meta=meta['target'],
                offset=current_offset
            )
            self.target_dataset_loaders.append(target_loader)
            if self.target_embedding_dim is None:
                self.target_embedding_dim = target_loader['embedding_dim']
            elif self.target_embedding_dim != target_loader['embedding_dim']:
                raise ValueError(
                    f"Target embedding dimension mismatch: {self.target_embedding_dim} vs {target_loader['embedding_dim']}"
                )
            self.offsets.append(current_offset + meta['total_samples'])
        self.total_samples = self.offsets[-1]
        logger.info(
            f"Initialized MultiChunkedDatasetLoader: {len(self.dataset_loaders)} datasets, "
            f"total_samples={self.total_samples}, src_dim={self.source_embedding_dim}, "
            f"tgt_dim={self.target_embedding_dim}, has_target={self.has_target}"
        )

    def _create_loader_dict(
        self,
        dataset_name: str,
        total_samples: int,
        data_meta: Dict[str, Any],
        offset: int
    ) -> Dict[str, Any]:
        """Turn one ``source``/``target`` metadata section into a loader dict.

        Dtype strings written as ``"<class 'numpy.float32'>"`` (the ``str()`` of a
        numpy scalar type) are normalised to ``"float32"`` before ``np.dtype``.
        """
        dtype_str = data_meta['dtype']
        if dtype_str.startswith("<class 'numpy."):
            dtype_str = dtype_str.replace("<class 'numpy.", "").replace("'>", "")
        elif dtype_str.startswith("<class '"):
            dtype_str = dtype_str.replace("<class '", "").replace("'>", "")
        return {
            'dataset_name': dataset_name,
            'total_samples': total_samples,
            'embedding_dim': data_meta['embedding_dim'],
            'offset': offset,
            'chunks': data_meta['chunks'],
            'dtype': np.dtype(dtype_str)
        }

    def _find_dataset_idx(self, global_idx: int) -> Tuple[int, int]:
        """Map a global row index to ``(dataset_idx, local_idx)``.

        Raises:
            IndexError: If ``global_idx`` is outside ``[0, total_samples)``.
                Negative indices are rejected rather than being mapped to
                dataset 0.
        """
        if not 0 <= global_idx < self.total_samples:
            raise IndexError(f"Global index {global_idx} out of range [0, {self.total_samples})")
        # bisect_right skips empty datasets (repeated offsets) and lands on the one holding the row.
        dataset_idx = bisect.bisect_right(self.offsets, global_idx) - 1
        return dataset_idx, global_idx - self.offsets[dataset_idx]

    def _empty_rows(self, dataset_idx: int, use_target: bool) -> np.ndarray:
        """Return a zero-row array with the dataset's embedding width and dtype."""
        loader = self.target_dataset_loaders[dataset_idx] if use_target else self.dataset_loaders[dataset_idx]
        return np.empty((0, loader['embedding_dim']), dtype=loader['dtype'])

    def _load_from_dataset(
        self,
        dataset_idx: int,
        start_idx: int,
        end_idx: int,
        use_target: bool = False
    ) -> np.ndarray:
        """Return dataset-local rows ``[start_idx, end_idx)`` of one dataset as a fresh array.

        If no rows fall in the range, returns a zero-row array of the dataset's
        dtype and width.
        """
        loader = self.target_dataset_loaders[dataset_idx] if use_target else self.dataset_loaders[dataset_idx]
        parts = []
        for chunk in loader['chunks']:
            if chunk['end_idx'] <= start_idx or chunk['start_idx'] >= end_idx:
                continue
            load_start = max(chunk['start_idx'], start_idx) - chunk['start_idx']
            load_end = min(chunk['end_idx'], end_idx) - chunk['start_idx']
            # NOTE: with use_memmap=False, np.load reads the whole chunk file even for a few rows.
            chunk_data = np.load(chunk['file_path'], mmap_mode='r' if self.use_memmap else None)
            # .copy() detaches the rows from the mmap / full chunk so that it can be released.
            parts.append(chunk_data[load_start:load_end].copy())
        if not parts:
            return self._empty_rows(dataset_idx, use_target)
        return parts[0] if len(parts) == 1 else np.concatenate(parts, axis=0)

    def load_batch(
        self,
        start_idx: int,
        end_idx: Optional[int] = None,
        return_target: bool = False
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Load the contiguous global row range ``[start_idx, end_idx)``.

        ``end_idx`` defaults to ``start_idx + batch_size`` and is clamped to
        ``total_samples``. The range may span several datasets; empty datasets
        inside it are skipped. An empty range yields zero-row arrays.

        Returns:
            Source rows, or ``(source_rows, target_rows)`` if ``return_target``.

        Raises:
            IndexError: If ``start_idx`` is outside ``[0, total_samples)``.
        """
        if end_idx is None:
            end_idx = start_idx + self.batch_size
        end_idx = min(end_idx, self.total_samples)
        start_ds_idx, start_local = self._find_dataset_idx(start_idx)
        if end_idx <= start_idx:
            end_ds_idx, end_local = start_ds_idx, start_local
        else:
            end_ds_idx, _ = self._find_dataset_idx(end_idx - 1)
            end_local = end_idx - self.offsets[end_ds_idx]

        def load_range(use_target: bool) -> np.ndarray:
            loader_list = self.target_dataset_loaders if use_target else self.dataset_loaders
            parts = []
            for ds_idx in range(start_ds_idx, end_ds_idx + 1):
                lo = start_local if ds_idx == start_ds_idx else 0
                hi = end_local if ds_idx == end_ds_idx else loader_list[ds_idx]['total_samples']
                # Skipping empty slices avoids concatenating shape-incompatible placeholders.
                if hi > lo:
                    parts.append(self._load_from_dataset(ds_idx, lo, hi, use_target))
            if not parts:
                return self._empty_rows(start_ds_idx, use_target)
            return parts[0] if len(parts) == 1 else np.concatenate(parts, axis=0)

        src_data = load_range(False)
        return (src_data, load_range(True)) if return_target else src_data

    def __iter__(self):
        """Yield ``(source, target)`` tensors on ``self.device``, ``batch_size`` rows at a time.

        Batches are contiguous row ranges. If ``self.shuffle`` is set, only their
        order is shuffled.
        """
        batch_starts = list(range(0, self.total_samples, self.batch_size))
        if self.shuffle:
            np.random.shuffle(batch_starts)
        for start_idx in batch_starts:
            end_idx = min(start_idx + self.batch_size, self.total_samples)
            src_batch, tgt_batch = self.load_batch(start_idx, end_idx, return_target=True)
            src_tensor = torch.from_numpy(src_batch)
            tgt_tensor = torch.from_numpy(tgt_batch)
            if self.pin_memory:
                src_tensor = src_tensor.pin_memory()
                tgt_tensor = tgt_tensor.pin_memory()
            src_tensor = src_tensor.to(self.device, non_blocking=True)
            tgt_tensor = tgt_tensor.to(self.device, non_blocking=True)
            yield src_tensor, tgt_tensor

    def iter_batches(self, shuffle: bool = False, indices: Optional[np.ndarray] = None, return_target: bool = False):
        """Yield ``(batch_data, batch_indices)`` over ``indices`` (default: all rows).

        ``shuffle`` permutes the indices first. ``batch_data`` is source rows, or
        ``(source, target)`` if ``return_target``.
        """
        if indices is None:
            indices = np.arange(self.total_samples)
        if shuffle:
            indices = np.random.permutation(indices)
        for i in range(0, len(indices), self.batch_size):
            batch_indices = indices[i:i + self.batch_size]
            yield self._load_discrete_indices(batch_indices, return_target), batch_indices

    def _load_discrete_indices(self, indices: np.ndarray, return_target: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Gather arbitrary global rows in the order given by ``indices``.

        Rows are grouped by the chunk file that stores them, so each file is
        opened once per call. The previous per-row ``np.load`` re-read a whole
        chunk file for every index when ``use_memmap=False``.

        Raises:
            IndexError: If any index is outside ``[0, total_samples)``.
        """
        # Resolve global -> (dataset, local) once; the mapping is shared by source and target.
        locations = [self._find_dataset_idx(int(idx)) for idx in indices]

        def load_indices(use_target: bool) -> np.ndarray:
            loader_list = self.target_dataset_loaders if use_target else self.dataset_loaders
            embedding_dim = self.target_embedding_dim if use_target else self.source_embedding_dim
            result = np.zeros((len(indices), embedding_dim), dtype=loader_list[0]['dtype'])
            # (dataset idx, chunk position) -> (output rows, rows within that chunk file)
            rows_by_chunk: Dict[Tuple[int, int], Tuple[List[int], List[int]]] = {}
            for out_row, (ds_idx, local_idx) in enumerate(locations):
                for chunk_pos, chunk in enumerate(loader_list[ds_idx]['chunks']):
                    if chunk['start_idx'] <= local_idx < chunk['end_idx']:
                        out_rows, in_rows = rows_by_chunk.setdefault((ds_idx, chunk_pos), ([], []))
                        out_rows.append(out_row)
                        in_rows.append(local_idx - chunk['start_idx'])
                        break
            for (ds_idx, chunk_pos), (out_rows, in_rows) in rows_by_chunk.items():
                chunk = loader_list[ds_idx]['chunks'][chunk_pos]
                chunk_data = np.load(chunk['file_path'], mmap_mode='r' if self.use_memmap else None)
                # reshape keeps 1-D chunk files (embedding_dim == 1) compatible with the 2-D result.
                result[out_rows] = np.asarray(chunk_data[in_rows]).reshape(len(out_rows), -1)
            return result

        src_result = load_indices(False)
        return (src_result, load_indices(True)) if return_target else src_result

    def load_all(self, return_target: bool = False) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Concatenate every chunk of every dataset into memory.

        Raises ``ValueError`` (from ``np.concatenate``) when there are no chunks.
        """
        logger.warning("Loading all data - this may cause OOM for large datasets!")

        def load_all_data(use_target: bool) -> np.ndarray:
            loader_list = self.target_dataset_loaders if use_target else self.dataset_loaders
            # np.concatenate already copies into a fresh array, so a per-chunk
            # .copy() would only double the peak memory.
            parts = [
                np.load(chunk['file_path'], mmap_mode='r' if self.use_memmap else None)
                for loader in loader_list
                for chunk in loader['chunks']
            ]
            return np.concatenate(parts, axis=0)

        src_data = load_all_data(False)
        return (src_data, load_all_data(True)) if return_target else src_data

    def get_dataset_info(self) -> Dict[str, Any]:
        """Summarise the number of datasets, total rows, and each dataset's size and offset."""
        return {
            "num_datasets": len(self.dataset_loaders),
            "total_samples": self.total_samples,
            "datasets": [
                {"name": loader['dataset_name'], "samples": loader['total_samples'], "offset": loader['offset']}
                for loader in self.dataset_loaders
            ]
        }
