import random
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from pathlib import Path
import os


class BufferCustom():
    """
    Buffer serving shuffled key/value (KV) batches for training the autoencoder.

    The key and value tensors are loaded from disk once and kept on the CPU.
    The leading ``train_test_ratio`` fraction of rows forms the training
    split and all remaining rows form the test split. Each split has its own
    shuffled index permutation and read pointer.
    """

    def __init__(self, cfg):
        """
        Load the key/value tensors and initialise both splits.

        Args:
            cfg: Dict-like configuration. ``"device"``, ``"batch_size"`` and
                ``"model_name_or_path"`` are required; ``"data_base_dir"``
                and ``"concat"`` are optional (see
                ``_construct_data_paths``).

        Raises:
            ValueError: If ``cfg["model_name_or_path"]`` does not contain
                ``"wiki"``.
        """
        self.cfg = cfg
        self.device = cfg["device"]
        self.batch_size = self.cfg["batch_size"]

        print('Loading key and value data...')
        key_data_path, value_data_path = self._construct_data_paths()
        self.key_data = torch.load(
            key_data_path,
            map_location=torch.device('cpu'),
            weights_only=True
        )
        print('value_data is loading')
        self.value_data = torch.load(
            value_data_path,
            map_location=torch.device('cpu'),
            weights_only=True
        )

        self.num_samples = self.key_data.shape[0]
        self.feature_dim = self.key_data.shape[1]

        self.train_test_ratio = 0.9
        self.num_train_samples = int(self.num_samples * self.train_test_ratio)
        # Take the test split as "everything that is not train".  Computing
        # int(num_samples * (1 - ratio)) truncates a float such as
        # 9.999999999999998 down to 9 and silently drops the last sample(s)
        # from both splits.
        self.num_test_samples = self.num_samples - self.num_train_samples
        self.refresh()
        self.test_refresh()
        self.ood_tokens = []

    def _construct_data_paths(self):
        """
        Construct data paths for key and value files based on configuration.

        The text following ``"wiki"`` in ``cfg["model_name_or_path"]`` is
        used as the data-capacity tag inside the file name. The first
        candidate file name that exists under the base directory is chosen;
        if none exists, the first candidate is returned anyway so that the
        subsequent load fails with an explicit "file not found" error.

        Returns:
            tuple: (key_data_path, value_data_path), both as ``str``.

        Raises:
            ValueError: If ``"wiki"`` does not occur in the model name.
        """
        model_name = self.cfg['model_name_or_path']

        # Extract data capacity from model name
        data_index = model_name.find('wiki')
        if data_index == -1:
            raise ValueError(f"Invalid model_name_or_path, 'wiki' not found: {model_name}")

        data_capa = model_name[data_index + len('wiki'):]

        # Get base directory from config or use default
        data_base_dir = Path(self.cfg.get('data_base_dir') or '/data/llm/tmp')

        # Construct filename components
        concat = self.cfg.get("concat", 1)
        model_prefix = "Qwen2.5-7B-Instruct-1M"

        # Candidate file names, in order of preference
        filename_patterns = [
            f"{model_prefix}_2nd_wiki_samples_1m_key_{data_capa}_all_{concat}layer.pt",
            f"{model_prefix}_wiki_samples_1m_key_{data_capa}_all_{concat}layer.pt"
        ]

        # Default to the first pattern; override with the first one that exists.
        key_data_path = data_base_dir / filename_patterns[0]
        for pattern in filename_patterns:
            candidate_path = data_base_dir / pattern
            if candidate_path.exists():
                key_data_path = candidate_path
                break

        # Handle concat features subdirectory.
        # NOTE(review): existence was checked on the un-redirected path
        # above, and this redirect only has an effect when the path contains
        # '/wiki_samples/' -- confirm this matches the on-disk layout.
        if concat > 1:
            key_data_path = Path(
                str(key_data_path).replace('/wiki_samples/', '/wiki_samples/concat_features/')
            )

        # The value file mirrors the key file with '_key_' -> '_value_'
        value_data_path = Path(str(key_data_path).replace('_key_', '_value_'))

        return str(key_data_path), str(value_data_path)

    def get_length(self):
        """Return the number of samples in the training split."""
        return self.num_train_samples

    def get_test_length(self):
        """Return the number of samples in the test split."""
        return self.num_test_samples

    def get_feature_dim(self):
        """Return the size of dimension 1 of the key tensor."""
        return self.feature_dim

    @torch.no_grad()
    def refresh(self):
        """Rewind the training pointer and draw a new training permutation."""
        self.pointer = 0
        # Indices are kept as a PyTorch tensor so they can index the data directly.
        self.shuffled_indices = torch.randperm(self.num_train_samples)

    @torch.no_grad()
    def test_refresh(self):
        """Rewind the test pointer and draw a new test permutation."""
        self.test_pointer = 0
        # Test rows are stored after the training rows, hence the offset.
        self.shuffled_test_indices = (
            torch.randperm(self.num_test_samples) + self.num_train_samples
        )

    def next(self):
        """
        Get next training batch.

        When fewer than ``batch_size`` unread training indices remain, the
        permutation is reshuffled first and the leftover indices are skipped.
        """
        if self.pointer + self.batch_size > self.num_train_samples:
            self.refresh()

        batch_indices = self.shuffled_indices[self.pointer:self.pointer + self.batch_size]
        batch_data = self._get_batch_data(batch_indices)
        self.pointer += self.batch_size
        return batch_data

    def test_next(self):
        """
        Get next test batch.

        When fewer than ``batch_size`` unread test indices remain, the
        permutation is reshuffled first and the leftover indices are skipped.
        """
        if self.test_pointer + self.batch_size > self.num_test_samples:
            self.test_refresh()

        batch_indices = self.shuffled_test_indices[self.test_pointer:self.test_pointer + self.batch_size]
        batch_data = self._get_batch_data(batch_indices)
        self.test_pointer += self.batch_size
        return batch_data

    def _get_batch_data(self, batch_indices):
        """
        Get batch data for given indices.

        Args:
            batch_indices: Tensor of indices to fetch

        Returns:
            torch.Tensor: float32 tensor on ``self.device`` in which the
            selected key rows and value rows are stacked along a new
            dimension 1 (index 0 = key, index 1 = value).
        """
        key_batch = self.key_data[batch_indices]
        value_batch = self.value_data[batch_indices]

        # torch.stack inserts the new dimension and always allocates a fresh
        # tensor, so the result never aliases the stored data and no explicit
        # clone()/detach() of the gathered rows is needed.
        out = torch.stack([key_batch, value_batch], dim=1)

        return out.float().to(self.device)

