from typing import List, Tuple, Union

import torch

from .omp import omp_v0  # Assuming omp_v0 is defined in omp.py


class kSVD:
    """Sparse-dictionary codec for a layer-wise KV cache.

    Keys and values are flattened into "signals": one signal is the
    concatenation of ``num_layer_to_merge`` consecutive layers' vectors for a
    single token position. Signals are sparse-coded with the project's
    ``omp_v0`` solver and rebuilt as ``dictionary @ codes``.

    ``dictionary`` is a 3-D tensor indexed as ``(kv, feature, atom)`` by the
    reconstruction einsums; slice 0 codes keys and slice 1 codes values.
    """

    def __init__(self, dictionary: torch.Tensor, model_config, dtype):
        """Bind a dictionary to a model configuration and validate compatibility.

        Args:
            dictionary: Dictionary tensor; ``shape[1]`` is its feature size and
                ``shape[2]`` its number of atoms.
            model_config: Config exposing ``num_hidden_layers``,
                ``num_key_value_heads``, ``hidden_size``, ``num_attention_heads``
                and, optionally, ``head_dim``.
            dtype: dtype of the reconstructed tensors.

        Raises:
            RuntimeError: If the dictionary feature size is not a positive
                multiple of ``n_head * h_dim``, or the layer count is not
                divisible by the resulting merge factor.
        """
        self.dictionary = dictionary
        self.layer_num = model_config.num_hidden_layers
        self.n_head = model_config.num_key_value_heads
        # Some configs declare head_dim explicitly, and it need not equal
        # hidden_size // num_attention_heads. Prefer it when present (assumed
        # to be the per-head KV size -- verify for new model families).
        head_dim = getattr(model_config, "head_dim", None)
        self.h_dim = head_dim or model_config.hidden_size // model_config.num_attention_heads
        self.dtype = dtype
        self.DTD = None  # cached dictionary^T @ dictionary, see prepare_for_compression()

        # --- Dictionary Compatibility Check & Setup ---
        model_feature_dim = self.h_dim * self.n_head
        dict_feature_dim = self.dictionary.shape[1]

        if dict_feature_dim % model_feature_dim != 0:
            raise RuntimeError(
                f"Dictionary's feature size ({dict_feature_dim}) is not a multiple of "
                f"the model's feature size ({model_feature_dim})."
            )

        # How many consecutive layers share one dictionary signal.
        self.num_layer_to_merge = dict_feature_dim // model_feature_dim
        if self.num_layer_to_merge == 0:
            # An empty feature axis would otherwise surface as a ZeroDivisionError below.
            raise RuntimeError(
                f"Dictionary's feature size ({dict_feature_dim}) must be positive."
            )

        if self.layer_num % self.num_layer_to_merge != 0:
            raise RuntimeError(
                f"Number of layers ({self.layer_num}) is not divisible by the "
                f"dictionary's merge factor ({self.num_layer_to_merge})."
            )

    def prepare_for_compression(self):
        """
        Pre-computes the DTD matrix for faster compression.
        This should be called once before running a batch of compressions.
        A no-op if DTD is already cached.
        """
        if self.DTD is None:
            # Batched dictionary^T @ dictionary (transpose of the last two dims); passed to omp_v0.
            self.DTD = torch.bmm(self.dictionary.permute(0, 2, 1), self.dictionary)

    def clear_compression_resources(self):
        """
        Drops the cached DTD and, when CUDA is available, releases cached GPU memory.
        """
        self.DTD = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def _kvcache_to_signals(self, kvcache) -> torch.Tensor:
        """Flatten a per-layer KV cache into the signal matrix consumed by ``omp_v0``.

        Args:
            kvcache: Indexable sequence of per-layer ``(key, value)`` pairs; each
                tensor must be 4-D ``(n, n_head, seq_len, h_dim)``. The code
                assumes n == 1 -- TODO confirm for batched caches.

        Returns:
            float32 tensor of shape ``(2, n_signals, num_layer_to_merge * n_head * h_dim)``.
            Index 0 holds keys and index 1 holds values. Signals are ordered
            position-major, layer-group-minor.

        Raises:
            RuntimeError: If the number of layers is not divisible by the merge factor.
        """
        layer_num = len(kvcache)
        if layer_num % self.num_layer_to_merge != 0:
            raise RuntimeError(
                f"Number of layers ({layer_num}) is not divisible by the "
                f"dictionary's merge factor ({self.num_layer_to_merge})."
            )

        def prepare_tensor(kv_part_index):
            # Stack all layers of either the key (0) or the value (1) part.
            Y_part = torch.cat([kvcache[i][kv_part_index] for i in range(layer_num)], dim=0)
            _, n_head, seq_len, h_dim_local = Y_part.size()

            # (seq_len, layer_groups, layers_to_merge, n_head * h_dim)
            Y_part = Y_part.permute(2, 0, 1, 3).contiguous()
            Y_part = Y_part.view(seq_len, -1, self.num_layer_to_merge, n_head * h_dim_local)

            # (seq_len * layer_groups, merged_feature_dim)
            return Y_part.view(-1, self.num_layer_to_merge * n_head * h_dim_local)

        Y_key = prepare_tensor(0)
        Y_value = prepare_tensor(1)

        Y = torch.cat([Y_key.unsqueeze(1), Y_value.unsqueeze(1)], dim=1).float()
        return Y.transpose(0, 1)

    def compress(self, kvcache, sparsity: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compress the KV cache into indices and values using the provided dictionary.

        Args:
            kvcache: Indexable sequence of per-layer ``(key, value)`` tensors,
                each ``(1, n_head, seq_len, h_dim)``.
            sparsity: Sparsity level forwarded to ``omp_v0``.

        Returns:
            ``(indices, values)`` from ``omp_v0``. ``values`` has its trailing
            dimension squeezed (presumably a singleton) and is moved to the CPU;
            ``indices`` stay on the device ``omp_v0`` produced them on.

        Raises:
            RuntimeError: If ``prepare_for_compression()`` has not been called,
                or the layer count is not divisible by the merge factor.
        """
        if self.DTD is None:
            raise RuntimeError(
                "DTD not pre-computed. Call prepare_for_compression() before compressing."
            )

        signals = self._kvcache_to_signals(kvcache)
        indices, values, _, _, _, _ = omp_v0(self.dictionary, signals, self.DTD, sparsity)
        values = values.squeeze(-1).cpu()
        return indices, values

    def _build_sparse_codes(self, indices: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
        """Scatter ``(indices, values)`` into a dense code tensor over all dictionary atoms.

        The result has shape ``indices.shape[:-1] + (n_atoms,)``, lives on
        ``indices.device`` and has dtype ``self.dtype``. ``values`` is moved to
        that device first, because ``compress()`` returns it on the CPU.
        """
        dict_size = self.dictionary.shape[2]
        sparse_matrix = torch.zeros(
            (*indices.shape[:-1], dict_size), device=indices.device, dtype=self.dtype
        )
        sparse_matrix.scatter_(
            -1,
            indices.to(torch.int64),
            values.to(device=indices.device, dtype=self.dtype),
        )
        return sparse_matrix

    def reconstruct_kSVD_with_omp_v0(
        self, indices: torch.Tensor, values: torch.Tensor
    ) -> Tuple[List[torch.Tensor], ...]:
        """
        Reconstruct the KV cache using the provided indices and values.

        Args:
            indices: ``(kv, n_signals, sparsity)`` atom indices.
            values: Coefficients matching ``indices`` (any device).

        Returns:
            Tuple of ``layer_num`` two-element lists ``[key, value]``, each of
            shape ``(1, n_head, seq_len, h_dim)`` and dtype ``self.dtype``.

        Raises:
            RuntimeError: If ``n_signals`` does not split evenly into layer groups.
        """
        device = indices.device
        kv, n_signals, _ = indices.shape

        # n_signals == seq_len * num_layer_groups (see _kvcache_to_signals).
        if (n_signals * self.num_layer_to_merge) % self.layer_num != 0:
            raise RuntimeError("Number of signals and layers are not compatible.")
        seq_len = (n_signals * self.num_layer_to_merge) // self.layer_num
        num_layer_groups = self.layer_num // self.num_layer_to_merge

        # --- Reconstruction ---
        sparse_matrix = self._build_sparse_codes(indices, values)
        dictionary = self.dictionary.to(device=device, dtype=self.dtype)
        # (n_signals, kv, dict_feature_dim)
        recon = torch.einsum('lmn,bln->blm', dictionary, sparse_matrix.transpose(0, 1))

        # --- Inverse of the compression layout ---
        # (seq_len, num_groups, kv, num_to_merge, n_head, h_dim)
        recon = recon.reshape(
            seq_len, num_layer_groups, kv, self.num_layer_to_merge, self.n_head, self.h_dim
        )
        # (num_groups, num_to_merge, kv, n_head, seq_len, h_dim)
        recon = recon.permute(1, 3, 2, 4, 0, 5)
        # (layer_num, kv, n_head, seq_len, h_dim); layer = group * num_to_merge + offset
        recon = recon.reshape(self.layer_num, kv, self.n_head, seq_len, self.h_dim)

        return tuple(
            [recon[layer, 0].unsqueeze(0), recon[layer, 1].unsqueeze(0)]  # [key, value]
            for layer in range(self.layer_num)
        )

    def reconstruct_kSVD_with_omp_v0_layer_wise_batch(
        self, indices: torch.Tensor, values: torch.Tensor, batch_layers: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Reconstruct a chunk of ``batch_layers`` layer groups and return stacked keys/values.

        Args:
            indices: ``(kv, n_signals, sparsity)`` atom indices, where
                ``n_signals == seq_len * batch_layers``.
            values: Coefficients matching ``indices`` (any device).
            batch_layers: Number of layer groups (each ``num_layer_to_merge``
                layers) encoded in this chunk.

        Returns:
            ``(keys, values)``, each of shape
            ``(batch_layers * num_layer_to_merge, 1, n_head, seq_len, h_dim)``.

        Raises:
            RuntimeError: If ``batch_layers`` is not positive or does not divide ``n_signals``.
        """
        if batch_layers <= 0:
            raise RuntimeError(f"batch_layers must be positive, got {batch_layers}.")

        device = indices.device
        kv, n_signals, _ = indices.shape

        if n_signals % batch_layers != 0:
            raise RuntimeError(f"Number of signals {n_signals} and layers {batch_layers} are not compatible.")
        seq_len = n_signals // batch_layers

        # --- Reconstruction ---
        sparse_matrix = self._build_sparse_codes(indices, values)
        dictionary = self.dictionary.to(device=device, dtype=self.dtype)
        # (n_signals, kv, dict_feature_dim)
        recon = torch.einsum('lmn,bln->blm', dictionary, sparse_matrix.transpose(0, 1))

        # --- Inverse of the compression layout ---
        # (seq_len, batch_layers, kv, num_to_merge, n_head, h_dim)
        recon = recon.reshape(
            seq_len, batch_layers, kv, self.num_layer_to_merge, self.n_head, self.h_dim
        )
        # (batch_layers, num_to_merge, kv, n_head, seq_len, h_dim)
        recon = recon.permute(1, 3, 2, 4, 0, 5)
        # (L, kv, n_head, seq_len, h_dim) with L = batch_layers * num_to_merge
        recon = recon.reshape(
            batch_layers * self.num_layer_to_merge, kv, self.n_head, seq_len, self.h_dim
        )

        # Insert a batch dimension of 1: (L, 1, n_head, seq_len, h_dim).
        key_states = recon[:, 0].unsqueeze(1)
        value_states = recon[:, 1].unsqueeze(1)
        return (key_states, value_states)

    @staticmethod
    def _relative_error(original, reconstructed) -> float:
        """Relative L2 error ``||orig - recon|| / ||orig||`` over all (key, value) tensors.

        Both arguments are nested per-layer ``(key, value)`` collections,
        flattened in the same order. The error is computed in float32. Returns
        0.0 when the original has zero norm.
        """
        original_flat = torch.cat([t.flatten() for layer_kv in original for t in layer_kv]).float()
        reconstructed_flat = torch.cat(
            [t.flatten() for layer_kv in reconstructed for t in layer_kv]
        ).float()

        original_norm = torch.norm(original_flat)
        if original_norm.item() == 0:
            return 0.0
        return (torch.norm(original_flat - reconstructed_flat) / original_norm).item()

    def test_compression_reconstruction(self, kvcache: Tuple[Tuple[torch.Tensor, torch.Tensor]], sparsity: int) -> float:
        """
        Performs a full compress-reconstruct cycle and returns the relative reconstruction error.

        This is a self-contained sanity check. The cached DTD is left exactly as
        it was found: it is computed for the check and released afterwards only
        if it was not already prepared, even when an exception is raised.

        Args:
            kvcache: The original KV cache tuple.
            sparsity: The sparsity level for compression.

        Returns:
            The relative reconstruction error as a float.
        """
        owns_dtd = self.DTD is None
        self.prepare_for_compression()
        try:
            # Run on the dictionary's device.
            device = self.dictionary.device
            kvcache_on_device = tuple((k.to(device), v.to(device)) for k, v in kvcache)

            indices, values = self.compress(kvcache_on_device, sparsity)
            reconstructed_kvcache = self.reconstruct_kSVD_with_omp_v0(
                indices.to(device), values.to(device)
            )
            return self._relative_error(kvcache_on_device, reconstructed_kvcache)
        finally:
            if owns_dtd:
                self.clear_compression_resources()

    def reconstruct_and_stack_kSVD_with_omp_v0_batched(
        self,
        batch_indices: torch.Tensor,
        batch_values: torch.Tensor,
        _for_test_return_unstacked: bool = False
    ) -> Union[torch.Tensor, Tuple[Tuple[Tuple[torch.Tensor, torch.Tensor], ...], List[int]]]:
        """
        Reconstruct several compressed KV caches at once and stack them along the sequence axis.

        Args:
            batch_indices: ``(batch, kv, n_signals, sparsity)`` atom indices.
            batch_values: Coefficients matching ``batch_indices`` (any device).
            _for_test_return_unstacked: If True, return the raw
                ``(batch, kv, layer_num, n_head, seq_len, h_dim)`` tensor instead.

        Returns:
            ``(batch_past_key_values, group_past_kv)``:
            - ``batch_past_key_values`` is a tuple of ``layer_num`` ``(key, value)``
              pairs, each of shape ``(1, n_head, batch * seq_len, h_dim)``, with
              batch items concatenated along the sequence axis.
            - ``group_past_kv`` is ``[seq_len] * batch``.

        Raises:
            RuntimeError: If ``n_signals`` does not split evenly into layer groups.
        """
        device = batch_indices.device
        batch_size, kv, n_signals, _ = batch_indices.shape

        if (n_signals * self.num_layer_to_merge) % self.layer_num != 0:
            raise RuntimeError("Number of signals and layers are not compatible.")
        seq_len = (n_signals * self.num_layer_to_merge) // self.layer_num
        num_layer_groups = self.layer_num // self.num_layer_to_merge

        # --- Batched Reconstruction ---
        sparse_matrix = self._build_sparse_codes(batch_indices, batch_values)
        dictionary = self.dictionary.to(device=device, dtype=self.dtype)
        # (batch, kv, n_signals, dict_feature_dim)
        recon = torch.einsum('bksa,kfa->bksf', sparse_matrix, dictionary)

        # --- Inverse of the compression layout ---
        # (batch, kv, seq_len, num_groups, num_to_merge, n_head, h_dim)
        recon = recon.reshape(
            batch_size, kv, seq_len, num_layer_groups,
            self.num_layer_to_merge, self.n_head, self.h_dim
        )
        # (batch, kv, num_groups, num_to_merge, n_head, seq_len, h_dim)
        recon = recon.permute(0, 1, 3, 4, 5, 2, 6)
        # (batch, kv, layer_num, n_head, seq_len, h_dim)
        recon = recon.reshape(batch_size, kv, self.layer_num, self.n_head, seq_len, self.h_dim)

        if _for_test_return_unstacked:
            return recon

        def stack_over_batch(part: torch.Tensor) -> torch.Tensor:
            # (B, L, H, S, D) -> (L, H, B, S, D) -> (L, H, B*S, D) -> (L, 1, H, B*S, D)
            return (
                part.permute(1, 2, 0, 3, 4)
                .reshape(self.layer_num, self.n_head, batch_size * seq_len, self.h_dim)
                .unsqueeze(1)
            )

        keys_all = stack_over_batch(recon[:, 0])
        values_all = stack_over_batch(recon[:, 1])
        batch_past_key_values = tuple(
            (keys_all[layer], values_all[layer]) for layer in range(self.layer_num)
        )
        group_past_kv = [seq_len] * batch_size

        return batch_past_key_values, group_past_kv

    def test_batched_compression_reconstruction(self, kvcache_batch: List[Tuple[Tuple[torch.Tensor, torch.Tensor]]], sparsity: int) -> float:
        """
        Performs a full compress-reconstruct cycle for a batch of KV caches and returns the average error.

        Each cache is compressed individually and all are reconstructed together
        through the batched path. Items whose original norm is zero contribute
        an error of 0. An empty batch returns 0.0. Like
        ``test_compression_reconstruction``, the cached DTD is restored to its
        prior state afterwards.
        """
        if not kvcache_batch:
            return 0.0

        owns_dtd = self.DTD is None
        self.prepare_for_compression()
        try:
            device = self.dictionary.device

            # 1. Compress each item in the batch.
            originals_on_device = []
            batch_indices = []
            batch_values = []
            for kvcache in kvcache_batch:
                kvcache_on_device = tuple((k.to(device), v.to(device)) for k, v in kvcache)
                indices, values = self.compress(kvcache_on_device, sparsity)
                originals_on_device.append(kvcache_on_device)
                batch_indices.append(indices)
                batch_values.append(values)

            stacked_indices = torch.stack(batch_indices, dim=0).to(device)
            stacked_values = torch.stack(batch_values, dim=0).to(device)

            # 2. Reconstruct the batch as (batch, kv, L, H, S, D).
            recon_batch_tensor = self.reconstruct_and_stack_kSVD_with_omp_v0_batched(
                stacked_indices, stacked_values, _for_test_return_unstacked=True
            )

            # 3. Compare each item with its original.
            total_error = 0.0
            for item_recon, original in zip(recon_batch_tensor, originals_on_device):
                per_layer = item_recon.permute(1, 0, 2, 3, 4)  # (L, kv, H, S, D)
                reconstructed_kvcache = tuple(
                    (per_layer[layer, 0].unsqueeze(0), per_layer[layer, 1].unsqueeze(0))
                    for layer in range(self.layer_num)
                )
                total_error += self._relative_error(original, reconstructed_kvcache)

            return total_error / len(kvcache_batch)
        finally:
            if owns_dtd:
                self.clear_compression_resources()


if __name__ == "__main__":
    from .config import config  # noqa: F401 -- unused here; kept in case the import has side effects (TODO confirm)
    from .utils.utils import load_model_and_tokenizer

    # --- Test Configuration ---
    # Set the paths to your model and dictionary for a realistic test
    MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"
    DICT_PATH = "/path/to/your/layer_merged_dictionary.pt"
    SPARSITY = 64
    SEQ_LEN = 10  # token positions in the dummy KV cache
    # --- End Test Configuration ---

    print("--- Starting Realistic Compression/Reconstruction Sanity Check ---")

    # 1. Load the real model only to obtain its configuration.
    print(f"Loading model: {MODEL_NAME} to get config...")
    model, tokenizer = load_model_and_tokenizer(MODEL_NAME, use_modified=False)
    model_config = model.config
    # The weights are not needed past this point; free them.
    del model
    del tokenizer

    # 2. Load the real dictionary.
    print(f"Loading dictionary from: {DICT_PATH}")
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dictionary = torch.load(DICT_PATH, weights_only=True).to(device)

    # 3. Initialize the kSVD codec with the real components.
    print("Initializing kSVD class...")
    ksvd = kSVD(dictionary, model_config, torch.float32)
    print(f"  - kSVD initialized with a layer merge factor of: {ksvd.num_layer_to_merge}")

    # 4. Build a dummy KV cache. Its per-layer shape is taken from the kSVD
    #    instance so it always matches what compress()/reconstruct expect.
    print("Creating a dummy KV cache with compatible dimensions...")
    kv_shape = (1, ksvd.n_head, SEQ_LEN, ksvd.h_dim)
    kvcache = tuple(
        (torch.randn(kv_shape), torch.randn(kv_shape))
        for _ in range(ksvd.layer_num)
    )

    # 5. Run the end-to-end test.
    print(f"Running test_compression_reconstruction with sparsity: {SPARSITY}...")
    reconstruction_error = ksvd.test_compression_reconstruction(kvcache, SPARSITY)

    print("\n--- Sanity Check Complete ---")
    print(f"  Relative Reconstruction Error: {reconstruction_error:.6f}")