import argparse
import os
import torch
import numpy as np
from tqdm import tqdm

from .config import config
from .kSVD import kSVD
from .utils.utils import (
    load_model_and_tokenizer,
    load_documents,
)

# LlamaIndex specific for chunking
from llama_index.core.text_splitter import TokenTextSplitter

def _initialize_config():
    """Load the shared ``config`` object from the file named on the command line.

    Reads the required ``--config_file`` option and hands its value to
    ``config.load``. Arguments this parser does not recognize are ignored
    rather than rejected, because ``parse_known_args`` is used.
    """
    cli = argparse.ArgumentParser(description="Run k-SVD reconstruction error test.")
    cli.add_argument(
        '--config_file',
        required=True,
        type=str,
        help='Path to the configuration YAML file (e.g., configs/config.yaml).',
    )
    # parse_known_args returns (namespace, leftover_argv); only the namespace is needed.
    known_args, _unrecognized = cli.parse_known_args()
    config.load(known_args.config_file)

def main():
    """Measure the k-SVD reconstruction error over document chunks.

    Steps, in order:
      1. Load the config named on the command line; stop early if
         ``config.USE_KSVD`` is falsy.
      2. Load the model, tokenizer and the dictionary tensor (moved to CUDA
         when available, otherwise CPU) and build the ``kSVD`` helper.
      3. Split the documents under ``'documents'`` into token chunks and keep
         at most ``config.NUM_QUESTIONS`` of those that encode to at least
         512 tokens.
      4. Build a KV cache per chunk and compute the reconstruction error,
         either per batch of ``config.BATCH_SIZE`` caches (when
         ``config.CACHE_TYPE == "batched_ksvd"``) or per single cache.
      5. Print the mean of the collected errors (0 when nothing was tested).
    """
    _initialize_config()

    if not config.USE_KSVD:
        print("k-SVD is not enabled in the provided config file. Exiting.")
        return

    # --- 1. Load Model, Tokenizer, and Dictionary ---
    print("Loading resources...")
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    model, tokenizer = load_model_and_tokenizer(config.MODEL_NAME, use_modified=False)

    print(f"Loading dictionary from: {config.DICT_FILE_PATH}")
    loaded_dictionary = torch.load(config.DICT_FILE_PATH, weights_only=True)
    dictionary = loaded_dictionary.to(device)

    # --- 2. Initialize kSVD Class ---
    print("Initializing k-SVD...")
    ksvd = kSVD(dictionary, model.config, model.dtype)
    print(f"  - kSVD initialized with a layer merge factor of: {ksvd.num_layer_to_merge}")

    # --- 3. Load Documents and Prepare Chunks ---
    print("Loading and chunking documents...")
    documents = load_documents('documents')
    splitter = TokenTextSplitter(
        tokenizer=tokenizer.encode,
        chunk_size=512,
        chunk_overlap=10,
    )

    # Flatten the per-document chunk lists into one list, in document order.
    all_chunks = [
        piece
        for doc in documents
        for piece in splitter.split_text(doc.get_content())
    ]

    # Keep only chunks matching the length criteria used in create_cache.py,
    # stopping as soon as the configured number of chunks has been collected.
    chunks_to_test = []
    for candidate in all_chunks:
        if len(chunks_to_test) >= config.NUM_QUESTIONS:
            break
        if len(tokenizer.encode(candidate)) < 512:
            continue
        chunks_to_test.append(candidate)

    print(f"Found {len(chunks_to_test)} chunks meeting the length criteria to test.")

    # --- 4. Generate KV Cache and Test Reconstruction ---
    errors = []

    if config.CACHE_TYPE == "batched_ksvd":
        print("\n--- Testing in BATCHED mode ---")
        batch_size = config.BATCH_SIZE
        pending_caches = []

        for chunk_text in tqdm(chunks_to_test, desc="Processing Chunks for Batching"):
            cache_cpu, _seq_len = generate_kv_cache_from_chunk(chunk_text, tokenizer, model, device)
            pending_caches.append(cache_cpu)

            if len(pending_caches) != batch_size:
                continue

            # A full batch is ready: evaluate it and start a fresh list.
            errors.append(
                ksvd.test_batched_compression_reconstruction(pending_caches, config.SPARSITY)
            )
            pending_caches = []

        # Evaluate whatever is left over when the chunk count is not a
        # multiple of the batch size.
        if pending_caches:
            errors.append(
                ksvd.test_batched_compression_reconstruction(pending_caches, config.SPARSITY)
            )

    else:  # For standard 'ksvd' or other non-batched types
        print("\n--- Testing in NON-BATCHED (single item) mode ---")
        for chunk_text in tqdm(chunks_to_test, desc="Testing Reconstruction Error"):
            cache_cpu, _seq_len = generate_kv_cache_from_chunk(chunk_text, tokenizer, model, device)
            errors.append(ksvd.test_compression_reconstruction(cache_cpu, config.SPARSITY))

    # --- 5. Report Results ---
    if errors:
        average_error = np.mean(errors)
    else:
        average_error = 0

    print("\n--- Test Complete ---")
    print(f"Average Relative Reconstruction Error: {average_error:.6f}")
    print(f"Tested over {len(errors)} batches/items.")

def generate_kv_cache_from_chunk(chunk_text, tokenizer, model, device, chunk_size=512):
    """Generate a CPU-resident KV cache from a single text chunk.

    The chunk is wrapped in ``<|doc_start|>`` / ``<|doc_end|>`` markers,
    tokenized with padding and truncation to ``chunk_size + 2`` positions,
    and passed through the model once with caching enabled and gradient
    tracking disabled.

    Args:
        chunk_text: Raw text of the chunk (without document markers).
        tokenizer: Callable that accepts the text plus ``return_tensors``,
            ``padding``, ``max_length`` and ``truncation`` keyword arguments
            and returns a mapping exposing ``.to(device)`` and ``.input_ids``.
        model: Callable accepting the tokenized inputs as keyword arguments
            together with ``use_cache=True``; its result must expose
            ``past_key_values`` as an iterable of ``(key, value)`` pairs.
        device: Device the tokenized inputs are moved to before the forward
            pass.
        chunk_size: Token budget for the chunk text itself. Defaults to 512,
            the value that was previously hard-coded, so existing callers
            are unaffected.

    Returns:
        A tuple ``(kvcache_cpu, seq_len)`` where ``kvcache_cpu`` is a tuple of
        ``(key, value)`` pairs copied to the CPU and ``seq_len`` is the size
        of dimension 1 of the tokenized ``input_ids``.
    """
    chunk_text_with_tokens = "<|doc_start|>" + chunk_text + "<|doc_end|>"

    # NOTE(review): the "+ 2" presumably reserves one position for each of the
    # two document markers - confirm each marker encodes to a single token.
    max_length = chunk_size + 2

    inputs = tokenizer(
        chunk_text_with_tokens,
        return_tensors="pt",
        padding="max_length",
        max_length=max_length,
        truncation=True,
    ).to(device)

    # Inference only: no autograd graph is needed to populate the cache.
    with torch.no_grad():
        outputs = model(**inputs, use_cache=True)

    # Copy every (key, value) pair to the CPU so the cache does not stay on
    # the inputs' device.
    kvcache_gpu = outputs.past_key_values
    kvcache_cpu = tuple((k.cpu(), v.cpu()) for k, v in kvcache_gpu)

    return kvcache_cpu, inputs.input_ids.shape[1]

if __name__ == "__main__":
    # Script entry point. To run this test, use:
    #   ./scripts/run_reconstruction_test.sh config/{config file you want to test with}
    main() 