import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from ..models.qwen2 import Qwen2ModifiedForCausalLM
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.core import SimpleDirectoryReader
import os
import re
import numpy as np
import json
from datetime import datetime
import csv


def get_device():
    """Pick the torch device to run on.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` reports a
        usable GPU, otherwise ``cpu``.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')

def load_model_and_tokenizer(model_name: str, use_modified: bool = True, use_flash_attn: bool = False):
    """
    Load a causal LM plus its tokenizer and move the model to ``get_device()``.

    Args:
        model_name: checkpoint identifier or path, forwarded to ``from_pretrained``.
        use_modified: when True, instantiate ``Qwen2ModifiedForCausalLM`` (imported
            from the local ``..models.qwen2`` module); otherwise use
            ``AutoModelForCausalLM``.
        use_flash_attn: when True, request ``attn_implementation="flash_attention_2"``;
            otherwise ``attn_implementation=None`` is passed.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    target_device = get_device()

    if use_flash_attn:
        attention_backend = "flash_attention_2"
    else:
        attention_backend = None

    if use_modified:
        loader = Qwen2ModifiedForCausalLM
    else:
        loader = AutoModelForCausalLM

    pretrained = loader.from_pretrained(
        model_name,
        torch_dtype="auto",
        attn_implementation=attention_backend,
    )
    model = pretrained.to(target_device)

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

def load_embedding_model(model_name: str = "BAAI/bge-small-en-v1.5"):
    """Build a llama_index ``HuggingFaceEmbedding`` for ``model_name``.

    Args:
        model_name: Hugging Face model identifier of the embedding model.

    Returns:
        HuggingFaceEmbedding: the constructed embedding wrapper.
    """
    embedder = HuggingFaceEmbedding(model_name=model_name)
    return embedder

def load_documents(path: str = 'documents'):
    """Read every document under ``path`` with llama_index's ``SimpleDirectoryReader``.

    Args:
        path: directory to read from.

    Returns:
        The result of ``SimpleDirectoryReader(path).load_data()``.
    """
    reader = SimpleDirectoryReader(path)
    documents = reader.load_data()
    return documents

def get_cache_directory_path(model_name: str, 
                             base_folder: str, 
                             is_instruct: bool, 
                             is_modified: bool, 
                             use_sink: bool, 
                             is_new: bool, 
                             use_drop: bool,
                             sparsity = None,
                             merge = None) -> str:
    """Build (and create, if missing) the cache directory for a configuration.

    The path is ``f"{base_folder}_{short_name}_{size}"`` followed by one suffix
    per enabled option, in this fixed order: ``_inst``, ``_reorder``, ``_sink``,
    ``_new``, ``_drop``, ``_merged_{merge}layers`` (only when ``merge >= 1``) and
    ``_s_{sparsity}`` (only when ``sparsity`` is truthy).

    Args:
        model_name: model identifier; the part after the last ``/`` is used as
            ``short_name``.
        base_folder: path prefix of the directory.
        is_instruct, is_modified, use_sink, is_new, use_drop: flags that each add
            their suffix when truthy.
        sparsity: optional value appended as ``_s_{sparsity}`` when truthy.
        merge: optional layer count appended when not None and ``>= 1``.

    Returns:
        str: the directory path, which is guaranteed to exist on return.
    """
    # NOTE: the first "<digits>B" in the name is taken as the size tag, so a name
    # like "Qwen2.5-0.5B" yields "5B". Kept as-is because existing cache
    # directories are named this way; the full short name is in the path anyway.
    match = re.search(r'\d+B', model_name)
    size = match.group(0) if match else "unknown_size"

    parts = [f"{base_folder}_{model_name.split('/')[-1]}_{size}"]
    flag_suffixes = (
        (is_instruct, '_inst'),
        (is_modified, '_reorder'),
        (use_sink, '_sink'),
        (is_new, '_new'),
        (use_drop, '_drop'),
    )
    parts.extend(suffix for enabled, suffix in flag_suffixes if enabled)
    if merge is not None and merge >= 1:
        parts.append(f'_merged_{merge}layers')
    if sparsity:
        parts.append(f'_s_{sparsity}')
    path = "".join(parts)

    # exist_ok avoids the check-then-create race when several runs share a cache.
    os.makedirs(path, exist_ok=True)
    return path

def create_result_folder(base_dir: str, method: str, dict_type: str = None, sparsity: int = None) -> str:
    """Create (if needed) and return the result folder for ``method``.

    Methods whose name contains ``"ksvd"`` get the folder name
    ``f"{method}_{dict_type}_{sparsity}"``; every other method uses ``method``
    itself. The folder lives directly under ``base_dir``.

    Args:
        base_dir: parent directory; created along with the result folder if absent.
        method: method name.
        dict_type: only used in the folder name of "ksvd" methods.
        sparsity: only used in the folder name of "ksvd" methods.

    Returns:
        str: ``os.path.join(base_dir, folder_name)``, guaranteed to exist.
    """
    if "ksvd" in method:
        folder_name = f"{method}_{dict_type}_{sparsity}"
    else:
        folder_name = method

    result_dir = os.path.join(base_dir, folder_name)
    # makedirs also creates base_dir; exist_ok removes the check-then-create race.
    os.makedirs(result_dir, exist_ok=True)
    return result_dir

def save_results(result_dir: str, method: str, num_docs: int, model_forward_time: float, reconstruction_time: float, to_gpu_time: float, to_ram_time: float, stacking_time: float, context_lengths: list, iteration: int):
    """Append one row of aggregated timings to ``<result_dir>/<method>_<num_docs>docs.csv``.

    The header is written only when the file does not exist yet. Falsy values
    (e.g. ``None``) of ``reconstruction_time``, ``to_gpu_time`` and ``to_ram_time``
    are recorded as 0. ``TTFT`` is ``model_forward + reconstruction + to_gpu +
    stacking``; ``to_ram`` is recorded but not part of TTFT. ``context_length``
    is the mean of ``context_lengths`` (0 when empty). ``iteration`` is accepted
    but not written.
    """
    csv_path = os.path.join(result_dir, f"{method}_{num_docs}docs.csv")

    reconstruction = reconstruction_time or 0
    upload = to_gpu_time or 0
    offload = to_ram_time or 0

    ttft = model_forward_time + reconstruction + upload + stacking_time
    if context_lengths:
        mean_context = sum(context_lengths) / len(context_lengths)
    else:
        mean_context = 0

    metrics = {
        'model_forward': f"{model_forward_time:.6f}",
        'reconstruction': f"{reconstruction:.6f}",
        'to_gpu': f"{upload:.6f}",
        'to_ram': f"{offload:.6f}",
        'stacking': f"{stacking_time:.6f}",
        'TTFT': f"{ttft:.6f}",
        'context_length': f"{mean_context:.2f}",
    }

    needs_header = not os.path.exists(csv_path)
    with open(csv_path, 'a', newline='') as handle:
        writer = csv.DictWriter(handle, fieldnames=list(metrics))
        if needs_header:
            writer.writeheader()
        writer.writerow(metrics)

def stack_past_key_values(past_key_values_list):
    """Concatenate several per-chunk KV caches into one along dim 2.

    Args:
        past_key_values_list: list of caches; each cache is indexable by layer and
            each layer entry holds ``(keys, values)`` tensors. All caches are
            assumed to have the same number of layers as the first one.

    Returns:
        tuple: ``(stacked, chunk_lengths)`` where ``stacked[layer]`` is a
        ``(keys, values)`` pair concatenated along dim 2, and
        ``chunk_lengths[i]`` is ``shape[2]`` of cache ``i``'s layer-0 keys.
        An empty input returns ``((), [])``, matching
        ``merge_multiple_reconstructions``.
    """
    if not past_key_values_list:
        return (), []

    num_layers = len(past_key_values_list[0])
    stacked = []
    for layer_idx in range(num_layers):
        keys = torch.cat([cache[layer_idx][0] for cache in past_key_values_list], dim=2)
        values = torch.cat([cache[layer_idx][1] for cache in past_key_values_list], dim=2)
        stacked.append((keys, values))

    chunk_lengths = [cache[0][0].shape[2] for cache in past_key_values_list]
    return tuple(stacked), chunk_lengths

def qa_to_prompt(prefix: str, chunk_list: list, query: str) -> str:
    """Assemble the QA prompt.

    Layout: ``prefix``, the chunks joined by single newlines, three newlines,
    ``"Question: <query>"``, then the ``<|im_end|><|im_start|>assistant`` tail
    followed by a newline.
    """
    context = "\n".join(chunk_list)
    tail = f"\n\n\nQuestion: {query}\n<|im_end|><|im_start|>assistant\n"
    return f"{prefix}{context}{tail}"

def load_kvcache(cache_file_path: str):
    """Load a KV cache saved with ``torch.save``.

    ``weights_only=True`` restricts unpickling to tensors and plain containers.

    Args:
        cache_file_path: path of the saved file.

    Returns:
        Whatever object was saved (as restored by ``torch.load``).
    """
    cache = torch.load(cache_file_path, weights_only=True)
    return cache

def pack_tensor(tensor: torch.Tensor, bits: int = 13):
    """
    Pack each element of an integer tensor into a ``bits``-wide big-endian bit
    field and return the concatenated bit stream as bytes.

    Only the low ``bits`` bits of every element are kept: values outside
    ``[0, 2**bits)`` (including negatives, taken in two's complement) are
    silently truncated. The last byte is zero-padded when ``numel * bits`` is
    not a multiple of 8.

    Args:
        tensor: integer tensor of any shape (torch.int32 in the original use).
        bits: field width, 1..32 (13 in our case).
    Returns:
        packed_bytes: bytes object of length ``ceil(numel * bits / 8)``
        orig_shape: ``tensor.shape``, needed by ``unpack_tensor``
    Raises:
        ValueError: if ``bits`` is outside 1..32.
    """
    if not 1 <= bits <= 32:
        raise ValueError(f"bits must be in [1, 32], got {bits}")

    # Narrowest unsigned type that holds `bits` bits: uint16 keeps the original
    # footprint for <=16-bit fields, uint32 stops wider fields losing high bits.
    work_dtype = np.uint16 if bits <= 16 else np.uint32
    arr = tensor.cpu().numpy().astype(work_dtype).ravel()

    # Shift amounts bits-1 ... 0, so each field's most-significant bit comes first.
    shifts = np.arange(bits - 1, -1, -1, dtype=work_dtype)
    bits_mat = ((arr[:, None] >> shifts) & 1).astype(np.uint8)

    packed = np.packbits(bits_mat.ravel(), bitorder='big')
    return packed.tobytes(), tensor.shape

def unpack_tensor(packed_bytes: bytes, orig_shape: tuple, bits: int = 13) -> torch.Tensor:
    """Unpack a big-endian ``bits``-wide bit stream back into an int32 tensor.

    Inverse of ``pack_tensor``: reads ``prod(orig_shape)`` consecutive fields
    (trailing padding bits are ignored) and reshapes them to ``orig_shape``.
    Field values ``>= 2**31`` (only possible when ``bits == 32``) wrap to
    negative int32, mirroring how negative int32 inputs were packed.

    Args:
        packed_bytes: the bytes produced by ``pack_tensor``.
        orig_shape: shape of the tensor to rebuild.
        bits: field width used when packing, 1..32.

    Returns:
        torch.Tensor: int32 tensor of shape ``orig_shape``.

    Raises:
        ValueError: if ``bits`` is outside 1..32 or ``packed_bytes`` holds fewer
            than ``prod(orig_shape) * bits`` bits.
    """
    if not 1 <= bits <= 32:
        raise ValueError(f"bits must be in [1, 32], got {bits}")

    n_values = int(np.prod(orig_shape))
    packed_arr = np.frombuffer(packed_bytes, dtype=np.uint8)
    bits_flat = np.unpackbits(packed_arr, bitorder='big')

    needed = n_values * bits
    if bits_flat.size < needed:
        raise ValueError(
            f"packed data holds {bits_flat.size} bits but {needed} are required "
            f"for shape {tuple(orig_shape)} at {bits} bits per value"
        )

    bits_mat = bits_flat[:needed].reshape(n_values, bits)
    # int64 weights: a uint16 weight vector overflows as soon as bits > 16.
    weights = 2 ** np.arange(bits - 1, -1, -1, dtype=np.int64)
    arr = bits_mat.dot(weights).astype(np.int32).reshape(orig_shape)

    return torch.from_numpy(arr)

def merge_multiple_reconstructions(reconstructions: list):
    """Merge several reconstructed KV caches into one along dim 2.

    Args:
        reconstructions: list of ``(past_key_values, group_past_kv)`` pairs, where
            ``past_key_values[layer]`` holds ``(keys, values)`` tensors and
            ``group_past_kv`` is a list. The layer count is taken from the first
            element.

    Returns:
        tuple: ``(combined_past_key_values, combined_group_past_kv)``.
        ``combined_past_key_values[layer]`` is a ``(keys, values)`` pair made by
        concatenating every reconstruction's tensors for that layer in input
        order. ``combined_group_past_kv`` is all group lists chained together.
        An empty input returns ``((), [])``. The result never aliases the inputs,
        because ``torch.cat`` always allocates new storage.
    """
    if not reconstructions:
        return (), []

    num_layers = len(reconstructions[0][0])

    # Collect all pieces first and concatenate once per layer. Pairwise
    # concatenation inside the loop re-copies the growing prefix for every
    # reconstruction (quadratic in their number).
    key_parts = [[] for _ in range(num_layers)]
    value_parts = [[] for _ in range(num_layers)]
    combined_group = []
    for past_key_values, group in reconstructions:
        for layer_idx in range(num_layers):
            layer_kv = past_key_values[layer_idx]
            key_parts[layer_idx].append(layer_kv[0])
            value_parts[layer_idx].append(layer_kv[1])
        combined_group.extend(group)

    combined_past_key_values = tuple(
        (torch.cat(key_parts[i], dim=2), torch.cat(value_parts[i], dim=2))
        for i in range(num_layers)
    )
    return combined_past_key_values, combined_group

def merge_multiple_layer_reconstructions(reconstructions: list, num_layers: int, sequence_length: int):
    """
    Merge a base KV cache with further reconstructions whose keys/values are
    stacked along a leading layer axis.

    Args:
        reconstructions: ``reconstructions[0]`` is a ``(past_key_values,
            group_past_kv)`` pair with ``past_key_values[i] = (keys, values)``
            for layer ``i``. Every later element is a ``(keys, values)`` pair
            whose axis 0 enumerates layer slices: slice ``idx`` belongs to layer
            ``idx % num_layers``, so ``m * num_layers`` slices carry ``m``
            chunks. Their shape is documented as (Layer, Head, Seq_len, Dim).
        num_layers: number of layers to merge.
        sequence_length: value appended to the group list once per added chunk.

    Returns:
        tuple: ``(combined_past_key_values, combined_group)``.
        ``combined_past_key_values`` is a tuple with one ``[keys, values]`` list
        per layer, each concatenated along dim -2 in input order. For every later
        reconstruction, ``combined_group`` is ``group_past_kv`` extended by
        ``sequence_length`` ``(slices // num_layers)`` times. An empty input
        returns ``((), [])``.

    NOTE(review): ``keys[idx]`` drops axis 0 before being concatenated with the
    base layer tensor, and ``torch.cat`` requires equal rank. Confirm the base
    cache is stored with the same rank as one layer slice.
    """
    if len(reconstructions) == 0:
        return (), []

    past_key_values_0, group_past_kv_0 = reconstructions[0]
    combined_group = list(group_past_kv_0)

    # No clone needed: torch.cat below always allocates fresh storage.
    accumulated = [
        [[past_key_values_0[i][0]], [past_key_values_0[i][1]]]
        for i in range(num_layers)
    ]

    # Count chunks over *every* later reconstruction, not just the last one, so
    # each appended chunk gets a group entry. Starting at 0 also keeps a
    # single-element input well-defined.
    added_chunks = 0
    for keys, values in reconstructions[1:]:
        num_slices = keys.shape[0]
        added_chunks += num_slices // num_layers
        for idx in range(num_slices):
            layer = idx % num_layers
            accumulated[layer][0].append(keys[idx])
            accumulated[layer][1].append(values[idx])

    for layer_kv in accumulated:
        layer_kv[0] = torch.cat(layer_kv[0], dim=-2).contiguous()
        layer_kv[1] = torch.cat(layer_kv[1], dim=-2).contiguous()

    combined_group += [sequence_length] * added_chunks
    return tuple(accumulated), combined_group


def generate_dct_basis_1d(N):
    """
    Generate the orthonormal 1-D DCT-II (Discrete Cosine Transform Type II) basis.

    Args:
    N (int): signal length (and number of basis vectors). Python ints and
        NumPy integer scalars are accepted.

    Returns:
    numpy.ndarray: N x N float array whose row k is the k-th basis vector;
                    basis_matrix[k, n] = c_k * cos(pi * k * (2n + 1) / (2N)),
                    with c_0 = sqrt(1/N) and c_k = sqrt(2/N) for k > 0. The rows
                    are orthonormal, so basis_matrix @ basis_matrix.T == I.

    Raises:
    ValueError: if N is not a positive integer.
    """
    if not isinstance(N, (int, np.integer)) or N <= 0:
        raise ValueError("N must be a positive integer.")
    N = int(N)

    # Sample index n varies along columns; frequency index k is a (N, 1) column,
    # so the expression below broadcasts to the full (N, N) grid.
    n = np.arange(N)
    k = np.arange(N)[:, np.newaxis]

    # Unnormalized DCT-II kernel: cos(pi * k * (2n + 1) / (2N))
    basis_matrix = np.cos(np.pi * k * (2 * n + 1) / (2 * N))

    # Normalization that makes the basis orthonormal:
    # sqrt(1/N) for k = 0, sqrt(2/N) for k > 0.
    factors = np.sqrt(2.0 / N) * np.ones(N)
    factors[0] = np.sqrt(1.0 / N)

    # The (N, 1) factor column broadcasts across each row, scaling row k by factors[k].
    return factors[:, np.newaxis] * basis_matrix