"""HDF5 loader for precomputed hidden states with backward compatibility."""

import h5py
import torch
import pickle
import numpy as np
from typing import Dict, Any, Optional, List
from functools import lru_cache
from pathlib import Path

class HDF5PrecomputedData:
    """
    HDF5 loader that provides the same interface as pickle-loaded precomputed data.

    This class mimics the structure of pickle-loaded data to maintain backward compatibility:
    {
        'eval_dataset': dataset,
        'inputs_embeds': list_of_tensors,
        'args': args_object,
        'tokenizer_config': dict
    }

    Metadata is read eagerly in ``__init__``; the dataset, args, tokenizer
    config and embeddings are read lazily through a file handle that stays
    open until ``close()`` is called (the instance can also be used as a
    context manager).
    """

    # Keys exposed through the dict-like interface.
    _KEYS = ('eval_dataset', 'inputs_embeds', 'args', 'tokenizer_config')

    def __init__(self, hdf5_path: str, cache_size: int = 1000):
        """
        Args:
            hdf5_path: Path to HDF5 file
            cache_size: Number of embeddings to keep in memory cache
                (values <= 0 disable caching)
        """
        self.hdf5_path = hdf5_path
        self.cache_size = cache_size
        self._file_handle = None
        self._metadata = None
        self._eval_dataset = None
        self._args = None
        self._tokenizer_config = None
        # Proxy installed by ``self['inputs_embeds'] = value``; None means
        # "serve embeddings from the HDF5 file".
        self._inputs_embeds = None
        # Per-instance LRU cache: idx -> tensor, oldest entry first.
        self._embedding_cache = {}

        # Initialize metadata
        self._load_metadata()

    def __enter__(self):
        """Allow ``with HDF5PrecomputedData(path) as data:`` usage."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the file handle when leaving the ``with`` block."""
        self.close()
        return False

    def _get_file_handle(self):
        """Get file handle, opening if necessary."""
        if self._file_handle is None:
            self._file_handle = h5py.File(self.hdf5_path, 'r')
        return self._file_handle

    def _load_metadata(self):
        """Load metadata from HDF5 file.

        Reads the ``inputs_embeds`` group attributes first, then the
        file-level attributes, so a file-level attribute with the same name
        takes precedence.
        """
        metadata = {}
        with h5py.File(self.hdf5_path, 'r') as f:
            # Load embeddings metadata
            if 'inputs_embeds' in f:
                emb_attrs = f['inputs_embeds'].attrs
                metadata['num_examples'] = emb_attrs.get('num_examples', 0)
                metadata['sample_shape'] = emb_attrs.get('sample_shape', None)
                metadata['sample_dtype'] = emb_attrs.get('sample_dtype', None)

            # Load other metadata
            for key in f.attrs.keys():
                metadata[key] = f.attrs[key]
        self._metadata = metadata

    def _sample_dtype_name(self) -> str:
        """Return the stored ``sample_dtype`` attribute as a plain string.

        The attribute may be missing, explicitly ``None``, or bytes; all of
        these are normalised so substring checks never raise.
        """
        raw = (getattr(self, '_metadata', None) or {}).get('sample_dtype')
        if raw is None:
            return ''
        if isinstance(raw, bytes):
            return raw.decode('utf-8', errors='replace')
        return str(raw)

    def _read_embedding(self, idx: int) -> torch.Tensor:
        """Read one embedding from the file and restore its recorded dtype."""
        f = self._get_file_handle()

        # Load from HDF5
        embedding_data = f['inputs_embeds'][str(idx)][:]
        embedding_tensor = torch.from_numpy(embedding_data)

        # Auto-convert back to original dtype. 'bfloat16' must be tested
        # first because 'float16' is a substring of it.
        dtype_name = self._sample_dtype_name()
        if 'bfloat16' in dtype_name:
            embedding_tensor = embedding_tensor.to(dtype=torch.bfloat16)
        elif 'float16' in dtype_name:
            embedding_tensor = embedding_tensor.to(dtype=torch.float16)

        return embedding_tensor

    def _get_embedding(self, idx: int) -> torch.Tensor:
        """
        Load embedding by index with caching.

        The cache is per instance and bounded by ``cache_size`` with
        least-recently-used eviction, so it neither leaks instances nor mixes
        tensors from different files.

        Args:
            idx: Index of embedding to load

        Returns:
            Tensor with the embedding in original dtype
        """
        # Instances created without __init__ get their cache lazily.
        cache = self.__dict__.setdefault('_embedding_cache', {})

        if idx in cache:
            # Re-insert so this entry becomes the most recently used one.
            tensor = cache.pop(idx)
            cache[idx] = tensor
            return tensor

        tensor = self._read_embedding(idx)

        limit = getattr(self, 'cache_size', 0) or 0
        if limit > 0:
            cache[idx] = tensor
            if len(cache) > limit:
                # dicts keep insertion order, so the first key is the LRU one.
                del cache[next(iter(cache))]
        return tensor

    def _load_pickled(self, dataset_name: str):
        """Unpickle the byte dataset ``dataset_name``, or return None if absent."""
        f = self._get_file_handle()
        if dataset_name not in f:
            return None
        raw_bytes = f[dataset_name][:].tobytes()
        # NOTE(security): pickle.loads can execute arbitrary code. Only open
        # HDF5 files that come from a trusted source.
        return pickle.loads(raw_bytes)

    @property
    def eval_dataset(self):
        """Load and cache the dataset object (None if the file has none)."""
        if self._eval_dataset is None:
            self._eval_dataset = self._load_pickled('eval_dataset_pickle')
        return self._eval_dataset

    @property
    def args(self):
        """Load and cache the args object (None if the file has none)."""
        if self._args is None:
            self._args = self._load_pickled('args_pickle')
        return self._args

    @property
    def tokenizer_config(self):
        """Load and cache tokenizer config (empty dict if the file has none)."""
        if self._tokenizer_config is None:
            f = self._get_file_handle()
            if 'tokenizer_config' in f:
                attrs = f['tokenizer_config'].attrs
                self._tokenizer_config = {key: attrs[key] for key in attrs.keys()}
            else:
                self._tokenizer_config = {}
        return self._tokenizer_config

    @property
    def inputs_embeds(self):
        """
        Provide list-like interface to embeddings for backward compatibility.

        Returns a proxy object that behaves like a list but loads on-demand.
        If embeddings were reassigned via ``self['inputs_embeds'] = value``,
        the proxy wrapping that value is returned instead.
        """
        assigned = getattr(self, '_inputs_embeds', None)
        if assigned is not None:
            return assigned
        return HDF5EmbeddingsList(self)

    def __getitem__(self, key: str):
        """Provide dict-like access for backward compatibility.

        Raises:
            KeyError: If ``key`` is not one of the supported keys.
        """
        if key not in self._KEYS:
            raise KeyError(f"Key '{key}' not found")
        return getattr(self, key)

    def __setitem__(self, key: str, value):
        """Provide dict-like assignment for backward compatibility.

        Raises:
            KeyError: If ``key`` is not one of the supported keys.
            ValueError: If ``inputs_embeds`` is assigned a non list-like value.
        """
        if key == 'eval_dataset':
            self._eval_dataset = value
        elif key == 'inputs_embeds':
            if not (hasattr(value, '__len__') and hasattr(value, '__getitem__')):
                raise ValueError("inputs_embeds must be list-like")
            # Wrap the assigned value in a proxy; the inputs_embeds property
            # returns this proxy from now on.
            proxy = HDF5EmbeddingsList(self)
            proxy._override_data = value
            self._inputs_embeds = proxy
        elif key == 'args':
            self._args = value
        elif key == 'tokenizer_config':
            self._tokenizer_config = value
        else:
            raise KeyError(f"Key '{key}' not found")

    def __contains__(self, key: str) -> bool:
        """Check if key exists."""
        return key in self._KEYS

    def keys(self):
        """Return available keys."""
        return list(self._KEYS)

    def close(self):
        """Close file handle."""
        # getattr keeps this safe on partially-initialised instances, which
        # matters because __del__ calls it.
        handle = getattr(self, '_file_handle', None)
        if handle is not None:
            handle.close()
            self._file_handle = None

    def __del__(self):
        """Cleanup file handle."""
        self.close()


class HDF5EmbeddingsList:
    """
    List-like proxy for HDF5 embeddings that loads on-demand.

    Provides the same interface as a list of tensors for backward compatibility.
    When ``_override_data`` is set, every operation is delegated to that
    object instead of the HDF5-backed data.
    """

    def __init__(self, hdf5_data: "HDF5PrecomputedData"):
        """
        Args:
            hdf5_data: Owner object; its ``_metadata`` dict and
                ``_get_embedding`` method are used to serve items.
        """
        self.hdf5_data = hdf5_data
        self._override_data = None  # For when data gets reassigned

    def __len__(self) -> int:
        """Return number of embeddings."""
        if self._override_data is not None:
            return len(self._override_data)
        # Coerce to a builtin int in case the stored count is a NumPy scalar.
        return int(self.hdf5_data._metadata.get('num_examples', 0))

    def __getitem__(self, key):
        """Return one embedding for an index, or a list of them for a slice.

        Raises:
            IndexError: If an integer index falls outside the stored range.
        """
        if self._override_data is not None:
            return self._override_data[key]

        count = len(self)
        if isinstance(key, slice):
            # Handle slicing
            return [self.hdf5_data._get_embedding(i) for i in range(*key.indices(count))]

        # Handle single index, including negative indexing
        if key < 0:
            key = count + key

        if key < 0 or key >= count:
            raise IndexError(f"Index {key} out of range for {count} embeddings")

        return self.hdf5_data._get_embedding(key)

    def __iter__(self):
        """Iterate over all embeddings.

        This is deliberately a plain method rather than a generator function:
        a ``return <iterator>`` inside a generator only ends the generator, so
        a ``yield``-based version would produce nothing for override data.
        """
        if self._override_data is not None:
            return iter(self._override_data)

        return (self.hdf5_data._get_embedding(i) for i in range(len(self)))

    def __setitem__(self, idx: int, value):
        """Set embedding by index (only works with override data).

        Raises:
            ValueError: If no override data has been assigned.
        """
        if self._override_data is None:
            raise ValueError("Cannot modify HDF5 data directly. Assign new data first.")
        self._override_data[idx] = value


def load_precomputed_data(file_path: str) -> Dict[str, Any]:
    """
    Open a precomputed-data file, choosing the loader from the file extension.

    ``.h5`` files are wrapped in ``HDF5PrecomputedData``; ``.pt`` files are
    read with ``torch.load`` onto the CPU. The extension match is
    case-insensitive.

    Args:
        file_path: Path to precomputed data file

    Returns:
        Dict-like object with same interface regardless of file format

    Raises:
        ValueError: If the extension is neither ``.pt`` nor ``.h5``.
    """
    file_path = Path(file_path)
    extension = file_path.suffix.lower()

    # Reject unknown formats up front; the message shows the suffix as given.
    if extension not in ('.h5', '.pt'):
        raise ValueError(f"Unsupported file format: {file_path.suffix}. Use .pt or .h5")

    location = str(file_path)
    if extension == '.h5':
        print(f"Loading HDF5 precomputed data from: {file_path}")
        return HDF5PrecomputedData(location)

    print(f"Loading PyTorch pickle precomputed data from: {file_path}")
    return torch.load(location, map_location='cpu', weights_only=False)


# Convenience functions for backward compatibility
def is_hdf5_format(file_path: str) -> bool:
    """Return True when the path's extension is ``.h5`` (case-insensitive)."""
    extension = Path(file_path).suffix
    return extension.lower() == '.h5'


def get_embedding_count(precomputed_data) -> int:
    """Return how many embeddings the loaded data holds, for either format.

    Objects exposing a ``_metadata`` attribute are treated as HDF5-backed and
    report ``_metadata['num_examples']`` (0 when that key is absent); anything
    else is treated as pickle-style data and the length of its
    ``'inputs_embeds'`` entry is returned.
    """
    if not hasattr(precomputed_data, '_metadata'):
        # Pickle format: a mapping holding the embeddings directly.
        embeddings = precomputed_data['inputs_embeds']
        return len(embeddings)

    # HDF5 format: the count was recorded in the file metadata.
    metadata = precomputed_data._metadata
    return metadata.get('num_examples', 0)