import argparse
import os
from .config import config

def _initialize_config():
    """Load the experiment configuration selected on the command line.

    Recognizes a single option, ``--config_file`` (default
    ``configs/config.yaml``). Every argument argparse does not recognize is
    collected and handed to ``config.load`` together with the chosen path.
    """
    arg_parser = argparse.ArgumentParser(
        description="Run TurboRAG experiments with a specified configuration."
    )
    arg_parser.add_argument(
        '--config_file',
        type=str,
        default='configs/config.yaml',
        help='Path to the configuration YAML file.',
    )
    # parse_known_args tolerates extra flags instead of exiting on them.
    known_args, leftover_args = arg_parser.parse_known_args()
    config.load(known_args.config_file, leftover_args)

# Load the configuration right away, before the remaining imports below run.
# NOTE(review): presumably so that modules imported afterwards already see the
# loaded config values — confirm.
_initialize_config()


import logging
import re
import glob
import sys
import torch
import json
import time
import numpy as np
from tqdm import tqdm
from tabulate import tabulate
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2ForCausalLM
import os
from datetime import datetime
import csv
from .kSVD import kSVD
from .config import config
from .utils.utils import (
    create_result_folder, 
    save_results, 
    stack_past_key_values, 
    qa_to_prompt, 
    load_kvcache, 
    unpack_tensor, 
    merge_multiple_reconstructions,
    merge_multiple_layer_reconstructions,
    generate_dct_basis_1d,
    load_model_and_tokenizer
)
from .utils.metrics import QueryMetrics
from .cache_handler import CacheHandler, BatchedKSVDCacheHandler, LayerBatchedKSVDCacheHandler

# Llama Index Related
from llama_index.core import Settings, load_index_from_storage, StorageContext, QueryBundle
from llama_index.embeddings.huggingface import HuggingFaceEmbedding

# Send INFO-level log records to stdout.
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# NOTE(review): if basicConfig above installed its own stdout handler, this
# second handler makes every record print twice — confirm that is intended.
logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

# Set up device globally
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load model and tokenizer globally
# (this runs at import time; both objects are module-level globals).
model, tokenizer = load_model_and_tokenizer(
    config.MODEL_NAME, 
    use_modified=True, 
    use_flash_attn=config.USE_FLASH_ATTN
)

# Dtype of the model's first parameter; handed to the kSVD compressor later in
# this module.
model_dtype = next(model.parameters()).dtype

# Set up embedding model and index
Settings.embed_model = HuggingFaceEmbedding(model_name=config.EMBEDDING_MODEL_NAME)

# Load the appropriate index based on the configuration
index_path = os.path.join(config.BASE_KV_DIR, config.INDEX_PERSIST_DIR)
print(f"Loading index from: {index_path}")

# Rebuild the persisted index from index_path and create the module-level
# retriever, configured with similarity_top_k=config.SIMILARITY_TOP_K.
storage_context = StorageContext.from_defaults(persist_dir=index_path)
index = load_index_from_storage(storage_context)
retriever = index.as_retriever(similarity_top_k=config.SIMILARITY_TOP_K)

# Pre-compute the KV cache of the configured prompt prefix once, at import
# time; it seeds the cache list of every query.
inputs_prefix = tokenizer([config.PREFIX], return_tensors="pt", padding=True)
# Inference only: run under no_grad so the cached tensors do not keep an
# autograd graph (and its intermediate activations) alive for the lifetime of
# the process. The generate call in this module is wrapped the same way.
with torch.no_grad():
    outputs_prefix = model(
        inputs_prefix['input_ids'].to(device),
        attention_mask=inputs_prefix['attention_mask'].to(device),
        use_cache=True,
    )
prefix_kvcache = outputs_prefix.past_key_values
# Size of dimension 2 of the first cached tensor.
# NOTE(review): assumes the usual [layer][key/value] layout with tensors shaped
# (batch, heads, seq, head_dim), making this the prefix token count — confirm.
prefix_seq_len = prefix_kvcache[0][0].shape[2]

# These flags are set only after the prefix cache above has been computed.
if config.USE_SEL_ATTN:
    model.config.use_sel_attn = True
    model.config.aggr_mode = ["head", "query", "layer"]

# Conditionally load k-SVD resources only if needed
key_dictionary = None
value_dictionary = None
# Defined unconditionally, like the two dictionaries above, so the name exists
# even when USE_KSVD is off.
dct_basis = None
if config.USE_KSVD:
    # Only using the '_dl' version of the dictionary
    print("Loading k-SVD dictionary...")
    # map_location puts the tensor on `device` while loading, so a dictionary
    # that was saved from a GPU still loads when the CPU fallback is in use.
    dict_data = torch.load(config.DICT_FILE_PATH, map_location=device, weights_only=True)
    # Index 0 holds the key dictionary, index 1 the value dictionary.
    key_dictionary = dict_data[0]
    value_dictionary = dict_data[1]

    if config.ADD_DCT:
        # The DCT basis is sized from the first dimension of the key dictionary.
        feature_size = key_dictionary.shape[0]
        dct_basis = torch.tensor(generate_dct_basis_1d(feature_size), dtype=torch.float32, device=device)
        dct_basis = dct_basis.T

def _build_cache_handler(cache_type):
    """Instantiate the cache handler for *cache_type*, or return ``None``.

    ``None`` is returned when *cache_type* is ``None``, when
    ``CacheHandler.get_handler`` yields nothing, or when the type is not one
    of the four names handled below.

    Raises:
        RuntimeError: a k-SVD cache type was requested but the module-level
            key/value dictionaries were never loaded.
    """
    if cache_type is None:
        # No cache type configured (the "no_cache" run): nothing to build.
        return None

    ksvd_compressor = None
    # k-SVD compressor is needed for both batched and non-batched k-SVD
    if "ksvd" in cache_type:
        if key_dictionary is None or value_dictionary is None:
            raise RuntimeError("k-SVD dictionary was not loaded. Ensure USE_KSVD is True and DICT_FILE_PATH is set.")
        # Key and value dictionaries are joined along a new leading dimension.
        dictionary = torch.cat(
            [key_dictionary.unsqueeze(0), value_dictionary.unsqueeze(0)], dim=0
        ).to(torch.float16).to(device)
        ksvd_compressor = kSVD(dictionary, model.config, model_dtype)

    handler_class = CacheHandler.get_handler(cache_type)
    if not handler_class:
        return None

    if cache_type == "batched_ksvd":
        # Batched handler needs a batch size
        return handler_class(device, ksvd_compressor, batch_size=config.BATCH_SIZE)
    if cache_type == "layer_batched_ksvd":
        return handler_class(
            device,
            ksvd_compressor,
            batch_size=config.BATCH_SIZE * config.NUM_LAYERS,
            seq_len=config.SEQ_LEN,
        )
    if cache_type == "ksvd":
        # Standard k-SVD handler does not take batch_size
        return handler_class(device, ksvd_compressor)
    if cache_type == "kvcache":
        # The plain kvcache handler is constructed without a compressor.
        return handler_class(device, ksvd_compressor=None)
    return None


def query_with_kvcache(query_text, use_chunk_cache=True, cache_type=None):
    """Run one retrieval-augmented query and return its timing metrics.

    Retrieves nodes for *query_text* with the module-level retriever,
    optionally fetches a cache per node through the handler selected by
    *cache_type*, combines those caches with the precomputed prefix cache,
    and runs a single-token ``model.generate`` on the assembled prompt. The
    decoded answer is printed, not returned. As a side effect,
    ``model.config.group_size`` is overwritten on every call.

    Args:
        query_text: The question to answer.
        use_chunk_cache: When true and a handler exists, each retrieved node's
            cache is fetched via ``handler.retrieve``. For the non-batched
            cache types it also decides whether the caches are stacked (true)
            or generation runs with ``past_key_values=None`` (false).
        cache_type: ``"kvcache"``, ``"ksvd"``, ``"batched_ksvd"``,
            ``"layer_batched_ksvd"``, or ``None`` for no cache handler.

    Returns:
        The ``QueryMetrics`` collected for this query ('context_lengths',
        'stacking' and 'model_forward' are recorded here; the handler receives
        the same object and may record more).

    Raises:
        RuntimeError: a k-SVD cache type was requested without a loaded
            dictionary.
    """
    metrics = QueryMetrics()
    retrieved_nodes = retriever.retrieve(QueryBundle(query_str=query_text))

    # cache_type=None used to crash on the `"ksvd" in cache_type` test; the
    # helper now treats it as "no handler".
    handler = _build_cache_handler(cache_type)
    is_batched = isinstance(
        handler, (BatchedKSVDCacheHandler, LayerBatchedKSVDCacheHandler)
    )

    chunk_list = []
    if is_batched:
        # Batched handlers expect a different cache format: (cache, group_info)
        kvcache_list = [(prefix_kvcache, [prefix_seq_len])]
    else:
        kvcache_list = [prefix_kvcache]

    for node_with_score in retrieved_nodes:
        node = node_with_score.node
        metrics.add('context_lengths', len(tokenizer.encode(node.text)))
        chunk_list.append(node.text)

        if use_chunk_cache and handler:
            cache = handler.retrieve(node, metrics)
            if cache:
                kvcache_list.append(cache)
            torch.cuda.empty_cache()

    # Flush any remaining items in the batch handler
    if is_batched:
        remaining_cache = handler.flush(metrics)
        if remaining_cache:
            kvcache_list.append(remaining_cache)
    torch.cuda.empty_cache()

    prompt = qa_to_prompt(config.PREFIX, chunk_list, query_text)
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(model.device)

    metrics.start('stacking')
    # The stacking/merging logic depends on the cache type
    if cache_type == "batched_ksvd":
        past_kvcache, past_group = merge_multiple_reconstructions(kvcache_list)
    elif cache_type == "layer_batched_ksvd":
        past_kvcache, past_group = merge_multiple_layer_reconstructions(
            kvcache_list, config.NUM_LAYERS, config.SEQ_LEN
        )
    elif use_chunk_cache:
        past_kvcache, past_group = stack_past_key_values(kvcache_list)
    else:
        past_kvcache, past_group = (None, None)
    metrics.stop('stacking')

    model.config.group_size = past_group

    # NOTE(review): hard-coded EOS token ids — presumably the Qwen2
    # end-of-turn / end-of-text tokens; confirm they match the loaded tokenizer.
    eos_token_ids = [151645, 151643]
    metrics.start('model_forward')
    with torch.no_grad():
        # Only a single new token is generated per query.
        outputs = model.generate(
            input_ids,
            max_new_tokens=1,
            past_key_values=past_kvcache,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=False,
            eos_token_id=eos_token_ids,
        )
    metrics.stop('model_forward')

    # Decode only the tokens produced after the prompt.
    new_tokens = outputs[0][input_ids.shape[1]:]
    answer = "".join(tokenizer.decode(token) for token in new_tokens)
    print(f'Answer\n{answer}')

    torch.cuda.empty_cache()

    return metrics

def run_experiment(questions, num_iterations=1):
    """
    Runs a single experiment based on the settings in the loaded config file.

    For each iteration every question is passed to ``query_with_kvcache``, the
    per-query metric totals are accumulated, and the per-question averages are
    saved with ``save_results`` and printed.

    Args:
        questions: Non-empty sequence of query strings.
        num_iterations: How many times the full question set is run.

    Raises:
        ValueError: ``questions`` is empty (the averages would divide by zero).
    """
    num_questions = len(questions)
    if num_questions == 0:
        # Fail before any result folder is created or any query is run,
        # instead of hitting a ZeroDivisionError when averaging.
        raise ValueError("run_experiment requires at least one question.")

    # Create the results folder based on the experiment type
    method_name = config.CACHE_TYPE if config.CACHE_TYPE else "no_cache"
    is_ksvd = config.CACHE_TYPE in ["ksvd", "batched_ksvd", "layer_batched_ksvd"]

    if is_ksvd:
        # k-SVD runs are additionally keyed by dictionary type and sparsity.
        result_dir = create_result_folder(config.BASE_DIR, method_name, config.DICT_TYPE, config.SPARSITY)
    else:
        result_dir = create_result_folder(config.BASE_DIR, method_name)

    print(f"Running Experiment: {method_name.replace('_', ' ').title()}")
    print(f"Results will be saved to: {result_dir}")

    # Run experiment for multiple iterations
    for iteration in range(num_iterations):
        print(f"\n====================\nIteration {iteration + 1}/{num_iterations}\n====================")

        total_metrics = QueryMetrics()
        all_context_lengths = []

        for query in tqdm(questions, desc=f"Running {method_name}"):
            query_metrics = query_with_kvcache(
                query,
                use_chunk_cache=config.USE_CHUNK_CACHE,
                cache_type=config.CACHE_TYPE,
            )
            # Aggregate metrics
            for key, value in query_metrics.totals.items():
                total_metrics.add(key, value)
            all_context_lengths.append(query_metrics.get('context_lengths'))

        # Calculate and save average results for the iteration
        avg_model_forward = total_metrics.get('model_forward') / num_questions
        avg_reconstruction = total_metrics.get('reconstruction') / num_questions
        avg_to_gpu = total_metrics.get('to_gpu') / num_questions
        avg_to_ram = total_metrics.get('to_ram') / num_questions
        avg_stacking = total_metrics.get('stacking') / num_questions

        save_results(
            result_dir,
            method_name,
            config.SIMILARITY_TOP_K,
            avg_model_forward,
            avg_reconstruction,
            avg_to_gpu,
            avg_to_ram,
            avg_stacking,
            all_context_lengths,
            iteration,
        )

        print(f"\nResults for Iteration {iteration + 1}:")
        print(f"  - Average Model Forward: {avg_model_forward:.6f} seconds")
        print(f"  - Average Reconstruction: {avg_reconstruction:.6f} seconds")
        print(f"  - Average GPU Transfer: {avg_to_gpu:.6f} seconds")
        print(f"  - Average RAM Transfer: {avg_to_ram:.6f} seconds")
        print(f"  - Average Stacking Time: {avg_stacking:.6f} seconds")


if __name__ == "__main__":
    # The query file is read as JSON Lines: one object with a "query" field
    # per line. UTF-8 is requested explicitly so the result does not depend on
    # the platform's default encoding, and blank lines are skipped instead of
    # crashing json.loads.
    with open(config.QUERY_FILE, encoding="utf-8") as query_file:
        questions = [
            json.loads(line)["query"]
            for line in query_file
            if line.strip()
        ]
    # Keep only the first NUM_QUESTIONS queries.
    questions = questions[:config.NUM_QUESTIONS]

    # The number of iterations is now a config parameter
    run_experiment(questions, num_iterations=config.NUM_ITERATIONS)


# python kv_cache_store.py --compression_type int16 --output_dir compressed_kvcache_drop_v_f32 --persist_dir compressed_doc_emb_drop_v_f32
