# import libraries
import torch
from transformers import AutoTokenizer
from typing import List, Optional
from dataclasses import dataclass
from tqdm import tqdm
import os
import argparse
import numpy as np
from enum import Enum
from .utils.types import CompressionType

# LlamaIndex related
from llama_index.core import (
    VectorStoreIndex,
    SimpleDirectoryReader,
    ServiceContext,
    PromptHelper,
    Document,
    StorageContext,
    load_index_from_storage,
)
from llama_index.core.node_parser import SimpleNodeParser
from llama_index.core.schema import BaseNode, TextNode
from llama_index.core.text_splitter import TokenTextSplitter
from llama_index.core.indices.query.schema import QueryBundle
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core.vector_stores import SimpleVectorStore

from .omp import omp_v0
from .kSVD import kSVD
from .utils.utils import (
    get_device, 
    load_model_and_tokenizer, 
    load_embedding_model, 
    load_documents, 
    get_cache_directory_path,
    pack_tensor, 
    unpack_tensor
)
from .config import config

def _initialize_config():
    """Load the shared ``config`` object from the file named on the command line.

    Only ``--config_file`` is recognised by this parser (it defaults to
    ``configs/config.yaml``). Every other command-line token is collected and
    handed to ``config.load`` unchanged as its second argument.

    Returns:
        The module-level ``config`` object, after ``load`` has been called on it.
    """
    cli = argparse.ArgumentParser(
        description="Run TurboRAG experiments with a specified configuration."
    )
    cli.add_argument(
        '--config_file',
        default='configs/config.yaml',
        type=str,
        help='Path to the configuration YAML file.',
    )
    # parse_known_args does not abort on options it does not recognise; it
    # returns them as a list alongside the parsed namespace.
    known_args, passthrough_args = cli.parse_known_args()
    config.load(known_args.config_file, passthrough_args)
    return config
# Parse the command line and load the configuration once, at import time.
config = _initialize_config()

# For cache creation, the batched variants ('batched_ksvd' and
# 'layer_batched_ksvd') are handled exactly like plain 'ksvd'.
creation_cache_type = config.CACHE_TYPE
if creation_cache_type in ('batched_ksvd', 'layer_batched_ksvd'):
    creation_cache_type = 'ksvd'

# Load model and tokenizer
device = get_device()
model, tokenizer = load_model_and_tokenizer(config.MODEL_NAME, use_modified=True)
embed_model = load_embedding_model(config.EMBEDDING_MODEL_NAME)

# Load the k-SVD dictionary only when the cache will be k-SVD compressed.
# Both names default to None so later code can test them unconditionally.
ksvd_compressor = None
num_layer_to_merge = None
if creation_cache_type == 'ksvd':
    # Deserialize onto the CPU first and only then move to `device`: without
    # map_location, torch.load restores tensors to the device they were saved
    # from, which fails when that device is not available on this machine.
    dictionary = torch.load(
        config.DICT_FILE_PATH,
        map_location="cpu",
        weights_only=True,
    ).to(device)
    # The compressor is built with the dtype of the model's first parameter.
    model_dtype = next(model.parameters()).dtype
    ksvd_compressor = kSVD(dictionary, model.config, model_dtype)
    num_layer_to_merge = ksvd_compressor.num_layer_to_merge

# SPARSITY is treated as optional: fall back to None when it cannot be read.
try:
    sparsity = config.SPARSITY
except Exception:
    # `except Exception` instead of a bare `except:` so KeyboardInterrupt and
    # SystemExit are no longer swallowed here.
    # NOTE(review): the exception type `config` raises for a missing key is not
    # visible from this file; narrow this further once it is confirmed.
    sparsity = None
# Setup directories
# Both locations are derived from the same model/compression settings and
# differ only in the base folder they are placed under.
_cache_dir_options = dict(
    is_instruct='instruct' in config.MODEL_NAME.lower(),
    is_modified=True,
    use_sink=False,
    is_new=True,
    use_drop=True,
    sparsity=sparsity,
    merge=num_layer_to_merge,
)

# Directory the per-chunk KV cache files are written to.
output_path = get_cache_directory_path(
    config.MODEL_NAME,
    base_folder=os.path.join(config.BASE_KV_DIR, f"cache_{creation_cache_type}"),
    **_cache_dir_options,
)

# Directory the vector index is persisted to.
persist_dir = get_cache_directory_path(
    config.MODEL_NAME,
    base_folder=os.path.join(config.BASE_KV_DIR, f"doc_emb_{creation_cache_type}"),
    **_cache_dir_options,
)

# Token-based splitter shared by the node parser. Chunk lengths are measured
# with the model tokenizer's `encode`. Chunks are requested 10 tokens longer
# than SEQ_LEN and overlap by 10 tokens.
# NOTE(review): the two 10s look related (overlap compensation) - confirm.
splitter = TokenTextSplitter(
    chunk_overlap=10,
    chunk_size=config.SEQ_LEN + 10,
    tokenizer=tokenizer.encode,
)

def _compute_kvcache(text, max_length):
    """Run the model once on ``text`` and return its ``past_key_values``."""
    inputs = tokenizer(
        text,
        return_tensors="pt",
        # No padding="max_length": padding would make the tokenized length
        # differ from the raw text.
        max_length=max_length,
        truncation=True,
    ).to(model.device)

    with torch.no_grad():
        outputs = model(**inputs, use_cache=True)

    return outputs.past_key_values


def process_chunk(chunk_text, chunk_id):
    """Build the ``TextNode`` for one chunk and, if enabled, its on-disk KV cache.

    The chunk is cut to at most ``config.SEQ_LEN - 2`` tokens (as counted by
    ``tokenizer.encode``), then wrapped in ``<|doc_start|>`` / ``<|doc_end|>``.
    When ``config.USE_CHUNK_CACHE`` is set, the KV cache for the wrapped text
    is written under ``output_path``:

    * ``'ksvd'``: the cache is compressed with ``ksvd_compressor`` and saved as
      an indices file (int16) and a values file (float16). If both files
      already exist, nothing is recomputed - the model is not run at all.
    * anything else: the raw cache is saved as a list of CPU ``(key, value)``
      pairs, overwriting any existing file.

    Args:
        chunk_text: Raw text of the chunk.
        chunk_id: Identifier used in the node id (``chunk_<chunk_id>``) and in
            the cache file names.

    Returns:
        A ``TextNode`` whose text is the wrapped chunk and whose metadata holds
        the cache file paths (empty when ``USE_CHUNK_CACHE`` is off).
    """
    # First tokenize to check length.
    # NOTE(review): `encode` is called with default arguments, so the count
    # (and the decoded text after truncation) may include tokenizer-added
    # special tokens - confirm this is intended.
    chunk_tokens = tokenizer.encode(chunk_text)
    target_chunk_length = config.SEQ_LEN - 2

    if len(chunk_tokens) > target_chunk_length:
        chunk_text = tokenizer.decode(chunk_tokens[:target_chunk_length])

    # We always include special tokens in the node text for consistency
    chunk_text_with_tokens = "<|doc_start|>" + chunk_text + "<|doc_end|>"
    metadata = {}

    # Only generate and process KV cache if we intend to use it in experiments
    if config.USE_CHUNK_CACHE:
        # +2 leaves room for the two wrapper markers added above.
        max_length = target_chunk_length + 2

        if creation_cache_type == 'ksvd':
            indices_file_path = f'{output_path}/compressed_kvcache_chunk_{chunk_id}_indices.pt'
            values_file_path = f'{output_path}/compressed_kvcache_chunk_{chunk_id}_values.pt'

            # Check for existing output BEFORE running the model: the forward
            # pass is the expensive step and its result would be discarded if
            # both files are already on disk (e.g. when resuming a run).
            already_cached = (
                os.path.exists(indices_file_path)
                and os.path.exists(values_file_path)
            )
            if not already_cached:
                kvcache = _compute_kvcache(chunk_text_with_tokens, max_length)
                indices, values = ksvd_compressor.compress(kvcache, config.SPARSITY)
                # NOTE(review): int16 storage assumes every index fits in the
                # int16 range - confirm against the dictionary size.
                torch.save(indices.to(torch.int16).cpu(), indices_file_path)
                torch.save(values.to(torch.float16).cpu(), values_file_path)

            metadata.update({
                "indices_file_path": indices_file_path,
                "values_file_path": values_file_path,
                "compression_type": 'ksvd',
            })

        else:  # Uncompressed KV cache ('none' compression)
            kvcache = _compute_kvcache(chunk_text_with_tokens, max_length)
            kvcache_cpu = [(k.cpu(), v.cpu()) for (k, v) in kvcache]
            kvcache_file_path = f'{output_path}/kvcache_chunk_{chunk_id}.pt'
            torch.save(kvcache_cpu, kvcache_file_path)

            metadata.update({
                "kvcache_file_path": kvcache_file_path
            })

    torch.cuda.empty_cache()
    node = TextNode(
        text=chunk_text_with_tokens,
        id_=f"chunk_{chunk_id}",
        metadata=metadata
    )

    return node

class KVCachedNodeParser(SimpleNodeParser):
    """Node parser that chunks documents with the module-level ``splitter``
    and turns each chunk into a node via ``process_chunk``."""

    def get_nodes_from_documents(
        self,
        documents: List[Document],
        **kwargs,
    ) -> List[BaseNode]:
        """Split every document into chunks and build one node per kept chunk.

        Chunks whose ``tokenizer.encode`` length is below 512 are skipped.
        Node ids are derived from ``"<doc index>_<chunk index>"``; the chunk
        index counts all chunks of the document, including skipped ones.

        Args:
            documents: Documents to process, in order.
            **kwargs: Accepted for signature compatibility; not used here.

        Returns:
            The nodes produced by ``process_chunk``, in document/chunk order.
        """
        nodes = []
        # Wrap the document list itself rather than enumerate(...): enumerate
        # has no length, so tqdm could not show a total or an ETA.
        for doc_id, document in enumerate(tqdm(documents)):
            doc_text = document.get_content()
            chunk_texts = splitter.split_text(doc_text)

            for chunk_id, chunk_text in enumerate(chunk_texts):
                # Drop chunks shorter than 512 tokens.
                if len(tokenizer.encode(chunk_text)) < 512:
                    continue
                node = process_chunk(chunk_text, f"{doc_id}_{chunk_id}")
                nodes.append(node)
        return nodes


if __name__ == "__main__":
    # The vector index is considered complete when its docstore file exists.
    index_is_complete = os.path.exists(os.path.join(persist_dir, "docstore.json"))

    # The `output_path` (for KV cache files) is only relevant if `USE_CHUNK_CACHE` is True.
    cache_is_complete = True  # Assume true if we're not using chunk cache
    if config.USE_CHUNK_CACHE:
        # For cached runs, check that the cache directory also exists.
        # This does not verify that the directory is non-empty or complete.
        cache_is_complete = os.path.exists(output_path)

    if index_is_complete and cache_is_complete:
        print(f"Vector index found at '{persist_dir}'.")
        if config.USE_CHUNK_CACHE:
            print(f"KV cache found at '{output_path}'.")
        # NOTE(review): no --force option is defined or checked in this file;
        # confirm it is handled elsewhere (e.g. by config.load) or fix the text.
        print("Skipping creation. Use the --force flag to override.")
        # Raise SystemExit directly: the bare `exit` name is injected by the
        # `site` module for interactive use and is not guaranteed to exist.
        raise SystemExit(0)

    # Make sure the KV cache directory exists before any chunk file is saved;
    # torch.save does not create missing parent directories.
    if config.USE_CHUNK_CACHE:
        os.makedirs(output_path, exist_ok=True)

    # load dataset
    vector_store = SimpleVectorStore()
    node_parser = KVCachedNodeParser()
    documents = load_documents('documents')

    # Prepare for compression only if using kSVD
    if creation_cache_type == 'ksvd' and ksvd_compressor:
        print("Pre-computing DTD matrix for k-SVD compression...")
        ksvd_compressor.prepare_for_compression()

    try:
        # Process documents to create nodes. This will conditionally create KV cache.
        if config.USE_CHUNK_CACHE:
            print("Processing documents and creating KV cache...")
        else:
            print("Processing documents to create vector index (no KV cache)...")
        nodes = node_parser.get_nodes_from_documents(documents)

        print("Creating and persisting vector index...")
        index = VectorStoreIndex(
            nodes=nodes,
            embed_model=embed_model,
            vector_store=vector_store,
        )
        index.storage_context.persist(persist_dir=persist_dir)
        print("Done.")
    finally:
        # Release the compressor's resources whether or not processing succeeded.
        if creation_cache_type == 'ksvd' and ksvd_compressor:
            print("Clearing k-SVD compression resources...")
            ksvd_compressor.clear_compression_resources()