import torch
from typing import Dict, Any, List, Tuple, Optional, Callable
import matplotlib.pyplot as plt
import os

def apply_prefix_cache(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    sweep_config: Dict[str, Any],
    block_num: int,
    qkv_cache_path_template: str,
    qkv_cache_files: List[Tuple[str, int]]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    설정값(sweep_config)에 따라 캐시된 Q, K, V 접두사를 불러와 텐서에 추가합니다.

    Args:
        q (torch.Tensor): 원본 Query 텐서
        k (torch.Tensor): 원본 Key 텐서
        v (torch.Tensor): 원본 Value 텐서
        sweep_config (Dict[str, Any]): 접두사 추가 여부, 경로, 옵션 등이 담긴 설정 딕셔너리
        block_num (int): 현재 처리 중인 블록 번호
        qkv_cache_path_template (str): 캐시 파일 경로 템플릿
        qkv_cache_files (List[Tuple[str, int]]): 불러올 캐시 파일 이름 목록

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 접두사가 추가된 q, k, v 텐서 튜플
    """
    # 현재 블록에 접두사를 추가해야 하는지 확인. 아니라면 원본 텐서를 그대로 반환합니다.
    if not (sweep_config.get('prefix_add') and block_num in sweep_config.get('prefix_add_block', [])):
        return q, k, v

    B = q.shape[0]      # 배치 사이즈
    device = q.device   # 디바이스 정보
    dtype = q.dtype     # 데이터 타입

    all_q_prefixes, all_k_prefixes, all_v_prefixes = [], [], []

    # 캐시 파일이 저장된 기본 경로를 설정합니다.
    current_qkv_path = qkv_cache_path_template.format(
        sweep_config['target_layer'], sweep_config['target_block']
    )

    # 지정된 캐시 파일들을 순회하며 접두사를 불러옵니다.
    for filename, _ in qkv_cache_files:
        path = current_qkv_path + filename.format(sweep_config['global_prefix_rank'], block_num)
        qkv_cache = torch.load(path, map_location=device)

        cache_option = sweep_config.get('cache_option')
        prefix_number = sweep_config.get('prefix_number', 1)

        # 'KV_Cache' 또는 'Token_Insert' 옵션일 때 K, V 접두사를 처리합니다.
        if cache_option in ['KV_Cache', 'Token_Insert']:
            if 'k' in qkv_cache and 'v' in qkv_cache:
                k_prefix = qkv_cache['k'].to(dtype).repeat(1, 1, prefix_number, 1)
                v_prefix = qkv_cache['v'].to(dtype).repeat(1, 1, prefix_number, 1)
                all_k_prefixes.append(k_prefix)
                all_v_prefixes.append(v_prefix)

        # 'Token_Insert' 옵션일 때 Q 접두사를 처리합니다.
        if cache_option == 'Token_Insert':
            if 'q' in qkv_cache:
                q_prefix = qkv_cache['q'].to(dtype).repeat(1, 1, prefix_number, 1)
                all_q_prefixes.append(q_prefix)

    # 불러온 모든 접두사를 합치고 원본 텐서의 앞부분에 연결합니다.
    if all_q_prefixes:
        q_total_prefix = torch.cat(all_q_prefixes, dim=2).expand(B, -1, -1, -1)
        q = torch.cat((q_total_prefix, q), dim=2)

    if all_k_prefixes:
        k_total_prefix = torch.cat(all_k_prefixes, dim=2).expand(B, -1, -1, -1)
        v_total_prefix = torch.cat(all_v_prefixes, dim=2).expand(B, -1, -1, -1)
        k = torch.cat((k_total_prefix, k), dim=2)
        v = torch.cat((v_total_prefix, v), dim=2)

    return q, k, v

def calculate_prediction_blocks(sweep_config: Dict[str, Any]) -> List[int]:
    """
    설정값에 따라 토큰 삭제 예측을 수행할 블록 번호 리스트를 계산합니다.

    Args:
        sweep_config (Dict[str, Any]): 'token_delete_method', 'token_delete_block', 
                                     'token_delete_previous_layer' 키를 포함하는 설정 딕셔너리.

    Returns:
        List[int]: 예측을 수행할 블록 번호 리스트.

    Raises:
        ValueError: 계산된 블록 번호가 음수일 경우 발생.
    """
    # 토큰 삭제 방법이 지정된 ['score', 'attn', 'value'] 중 하나가 아니면 빈 리스트 반환
    if sweep_config.get('token_delete_method') not in ['score', 'attn', 'value']:
        return []
    
    # 예측을 수행할 실제 블록 위치 계산
    # (예: 5번 블록의 결과를 보고 4번 블록에서 예측 -> 5 - 1 = 4)
    prediction_blocks = [
        b - sweep_config.get('token_delete_previous_layer', 0)
        for b in sweep_config.get('token_delete_block', [])
    ]
    
    # 계산된 블록 번호가 유효한지 검사
    for num in prediction_blocks:
        if num < 0:
            # 에러 발생 시 프로그램을 강제 종료하는 대신, 호출한 쪽에서 처리할 수 있도록 ValueError 발생
            raise ValueError(
                f"계산된 예측 블록 번호가 음수({num})입니다. "
                f"'token_delete_block'과 'token_delete_previous_layer' 설정을 확인해주세요."
            )
            
    return prediction_blocks

def predict_tokens_to_delete(
    attn: torch.Tensor,
    v: torch.Tensor,
    sweep_config: Dict[str, Any],
    block_num: int,
    prediction_blocks: List[int],
    dynamic_delete_index: List[int],
) -> Optional[torch.Tensor]:
    if sweep_config.get('token_delete') and block_num > min(prediction_blocks):
        return dynamic_delete_index
    elif not (sweep_config.get('token_delete') and block_num in prediction_blocks):
        return None

    prefix_len = sweep_config.get('prefix_number', 0)
    cache_option = sweep_config.get('cache_option')
    
    attn_k, v_k = attn, v
    if sweep_config.get('prefix_add') and prefix_len > 0:
        prefix_add_blocks = sweep_config.get('prefix_add_block', [])
        if (cache_option == 'Token_Insert' and block_num >= min(prefix_add_blocks)) or \
           (cache_option == 'KV_Cache' and block_num in prefix_add_blocks):
            
            attn_k = attn[:, :, prefix_len+1:, prefix_len+1:] if cache_option == 'Token_Insert' else attn[:, :, :, prefix_len+1:]
            v_k = v[:, :, prefix_len+1:, :]

    method = sweep_config['token_delete_method']
    if method == 'attn':
        final_scores = torch.sum(attn_k, dim=2)
    elif method == 'value':
        value_norms = torch.norm(v_k, p=1, dim=-1)
        final_scores = 1 / (value_norms + 1e-9)  
    else:
        return None
    k = sweep_config.get('token_delete_number', 0)
    if k == 0:
        return None
        
    _, predicted_delete_index = torch.topk(final_scores, k, dim=-1, largest=True)
    
    head_index = sweep_config.get('head_index_for_score', 0)
    predicted_delete_index = predicted_delete_index[:, head_index]

    return predicted_delete_index


def apply_token_deletion(
    attn: torch.Tensor,
    v: torch.Tensor,
    sweep_config: Dict[str, Any],
    block_num: int,
    index_correction: int,
    random_outlier_index1: Dict[str, list],
    random_outlier_index2: Dict[str, list], 
    random_outlier_index3: Dict[str, list],
    predicted_delete_index: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    제공된 코드의 로직을 수정 없이 그대로 사용하여 토큰 삭제를 수행합니다.

    Args:
        attn (torch.Tensor): 원본 어텐션 텐서.
        v (torch.Tensor): 원본 밸류 텐서.
        q_device (torch.device): 텐서 생성을 위한 디바이스 정보 (원본의 q.device).
        sweep_config (Dict[str, Any]): 설정 딕셔너리.
        block_num (int): 현재 블록 번호 (원본의 self.block_num).
        index_correction (int): 인덱스 보정값 (원본의 self.index_correction).
        static_indices_map (Dict[str, list]): 정적 삭제에 사용할 인덱스 맵.
        predicted_delete_index (Optional[torch.Tensor]): 예측된 삭제 인덱스.
        dynamic_delete_index (Optional[torch.Tensor]): 동적 삭제 인덱스.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, int]:
            - 필터링된 attention 텐서
            - 필터링된 value 텐서
            - 삭제 후 남은 토큰의 개수 (새로운 시퀀스 길이 N)
    """
    # 원본 코드의 시작 조건
    q_device = v.device
    if sweep_config['token_delete'] and block_num in sweep_config['token_delete_block']:
        B, H, N_q, N_k = attn.shape
        batch_keep_mask_q = torch.ones(B, N_q, dtype=torch.bool, device=q_device)
        batch_keep_mask_k = torch.ones(B, N_k, dtype=torch.bool, device=q_device)

        if sweep_config['token_delete_method'] in ['random1', 'random2', 'random3', 'frequent_index']:
            if sweep_config['token_delete_method'] == 'random1':
                delete_index = random_outlier_index1
            elif sweep_config['token_delete_method'] == 'random2':
                delete_index = random_outlier_index2
            elif sweep_config['token_delete_method'] == 'random3':
                delete_index = random_outlier_index3

            delete_index = delete_index[:sweep_config['token_delete_number']]
            if sweep_config['cache_option'] == 'KV_Cache':
                indices_to_delete_list_q = [idx for idx in delete_index]
                indices_to_delete_list_k = [idx + index_correction for idx in delete_index]
                indices_to_delete_q = torch.tensor(indices_to_delete_list_q, device=q_device, dtype=torch.long)
                indices_to_delete_k = torch.tensor(indices_to_delete_list_k, device=q_device, dtype=torch.long)
                if indices_to_delete_q.numel() > 0: batch_keep_mask_q[:, indices_to_delete_q] = False
                if indices_to_delete_k.numel() > 0: batch_keep_mask_k[:, indices_to_delete_k] = False
            elif sweep_config['cache_option'] == 'Token_Insert':
                indices_to_delete_list_q = [idx + index_correction for idx in delete_index]
                indices_to_delete_list_k = [idx + index_correction for idx in delete_index]
                indices_to_delete_q = torch.tensor(indices_to_delete_list_q, device=q_device, dtype=torch.long)
                indices_to_delete_k = torch.tensor(indices_to_delete_list_k, device=q_device, dtype=torch.long)
                if indices_to_delete_q.numel() > 0: batch_keep_mask_q[:, indices_to_delete_q] = False
                if indices_to_delete_k.numel() > 0: batch_keep_mask_k[:, indices_to_delete_k] = False

        elif sweep_config['token_delete_method'] in ['attn', 'value']:
            if sweep_config['token_delete_previous_layer'] == 0:
                if sweep_config['cache_option'] == 'KV_Cache':
                    per_sample_indices_q = (predicted_delete_index).squeeze(1)
                    per_sample_indices_k = (predicted_delete_index + index_correction).squeeze(1)
                elif sweep_config['cache_option'] == 'Token_Insert':
                    per_sample_indices_q = (predicted_delete_index).squeeze(1)
                    per_sample_indices_k = (predicted_delete_index + index_correction).squeeze(1)

            K_q = per_sample_indices_q.shape[-1] if per_sample_indices_q.dim() == 2 else 1
            K_k = per_sample_indices_k.shape[-1] if per_sample_indices_k.dim() == 2 else 1
            flat_token_idx_q = per_sample_indices_q.reshape(-1)
            flat_token_idx_k = per_sample_indices_k.reshape(-1)
            flat_batch_idx_q = torch.arange(B, device=q_device).repeat_interleave(K_q)
            flat_batch_idx_k = torch.arange(B, device=q_device).repeat_interleave(K_k)
            batch_keep_mask_q[flat_batch_idx_q, flat_token_idx_q] = False
            batch_keep_mask_k[flat_batch_idx_k, flat_token_idx_k] = False

        query_mask = batch_keep_mask_q.unsqueeze(1).unsqueeze(3)
        key_mask = batch_keep_mask_k.unsqueeze(1).unsqueeze(2)
        final_4d_mask = query_mask & key_mask
        num_kept_q = batch_keep_mask_q[0].sum()
        num_kept_k = batch_keep_mask_k[0].sum()
        attn = attn.masked_select(final_4d_mask).view(B, H, num_kept_q, num_kept_k)
        v_mask = key_mask.transpose(2, 3)
        v = v.masked_select(v_mask).view(B, H, num_kept_k, v.shape[-1])
        attn = attn / (torch.sum(attn, dim=-1, keepdim=True) + 1e-9)
        N = num_kept_q
        
        return attn, v, N

    # 토큰 삭제 조건이 아닐 경우, 원본 텐서와 시퀀스 길이를 그대로 반환
    else:
        return attn, v, attn.shape[2]


def add_residual_prefix(
    residual_1: torch.Tensor,
    sweep_config: Dict[str, Any],
    block_num: int,
    x: torch.Tensor,
    q_block_input_path_template: str,
    q_block_prefix_files: List[Tuple[str, int]]
) -> torch.Tensor:
    """
    원본 코드의 로직을 그대로 사용하여 residual 텐서에 접두사를 추가합니다.

    Args:
        residual_1 (torch.Tensor): 접두사를 추가할 원본 텐서.
        sweep_config (Dict[str, Any]): 설정 딕셔너리 (원본의 self.sweep_config).
        block_num (int): 현재 블록 번호 (원본의 self.block_num).
        x (torch.Tensor): .device 정보 확인을 위한 텐서.
        q_block_input_path_template (str): 접두사 파일 경로 템플릿.
        q_block_prefix_files (List[Tuple[str, int]]): 불러올 접두사 파일 목록.

    Returns:
        torch.Tensor: 접두사가 추가되었거나, 조건이 맞지 않으면 원본 그대로의 residual_1 텐서.
    """
    # <<<--- 시작: 보내주신 코드 (self.만 수정) ---<<<
    if sweep_config['prefix_add'] and sweep_config['cache_option'] == 'Token_Insert' and block_num in sweep_config['prefix_add_block']:
        B, N, C = residual_1.shape
        prefix_vectors = []
        current_block_path = q_block_input_path_template.format(sweep_config['target_layer'], sweep_config['target_block'])
        for filename, num_copies in q_block_prefix_files:
            path = current_block_path + filename.format(sweep_config['global_prefix_rank'], block_num)
            loaded_tensor = torch.load(path, map_location=x.device)
            repeated_tensor = loaded_tensor.unsqueeze(0).expand(sweep_config['prefix_number'], -1)
            prefix_vectors.append(repeated_tensor)
        prefix_tensor_single = torch.cat(prefix_vectors, dim=0)
        prefix_tensor_batch = prefix_tensor_single.unsqueeze(0).expand(B, -1, -1)
        residual_1 = torch.cat((prefix_tensor_batch, residual_1), dim=1)
    # <<<--- 종료: 보내주신 코드 ---<<<
    
    return residual_1


def delete_residual_tokens(
    residual_1: torch.Tensor,
    sweep_config: Dict[str, Any],
    block_num: int,
    index_correction: int,
    x: torch.Tensor,
    random_outlier_index1: List[int],
    random_outlier_index2: List[int],
    random_outlier_index3: List[int],
    predicted_delete_index: Optional[torch.Tensor],
    dynamic_delete_index: Optional[torch.Tensor],
    error_log: Callable
) -> torch.Tensor:
    """
    원본 코드의 로직을 그대로 사용하여 residual 텐서에서 토큰을 삭제합니다.

    Args:
        residual_1 (torch.Tensor): 토큰을 삭제할 원본 텐서.
        sweep_config (Dict[str, Any]): 설정 딕셔너리 (원본의 self.sweep_config).
        block_num (int): 현재 블록 번호 (원본의 self.block_num).
        index_correction (int): 인덱스 보정값 (원본의 self.index_correction).
        x (torch.Tensor): .device 정보 확인을 위한 텐서.
        RANDOM_OUTLIER_INDEX1 (List[int]): 정적 인덱스 리스트 1.
        RANDOM_OUTLIER_INDEX2 (List[int]): 정적 인덱스 리스트 2.
        RANDOM_OUTLIER_INDEX3 (List[int]): 정적 인덱스 리스트 3.
        FREQUENT_OUTLIER_INDEX (List[int]): 정적 인덱스 리스트 4.
        predicted_delete_index (Optional[torch.Tensor]): 예측된 삭제 인덱스.
        dynamic_delete_index (Optional[torch.Tensor]): 동적 삭제 인덱스.
        error_log (Callable): 에러 발생 시 호출할 로깅 함수.

    Returns:
        torch.Tensor: 토큰이 삭제되었거나, 조건이 맞지 않으면 원본 그대로의 residual_1 텐서.
    """
    if sweep_config['token_delete'] and block_num in sweep_config['token_delete_block']:
        if sweep_config['token_delete_method'] in ['random1', 'random2', 'random3', 'frequent_index']:
            if sweep_config['token_delete_method'] == 'random1':
                delete_index = random_outlier_index1
            elif sweep_config['token_delete_method'] == 'random2':
                delete_index = random_outlier_index2
            elif sweep_config['token_delete_method'] == 'random3':
                delete_index = random_outlier_index3
            delete_index = delete_index[:sweep_config['token_delete_number']]
            delete_index = [idx + index_correction for idx in delete_index]

        elif sweep_config['token_delete_method'] in ['attn', 'value']:
            if sweep_config['token_delete_previous_layer'] == 0:
                dynamic_delete_index = predicted_delete_index
            if dynamic_delete_index is None:
                print("dynamic_delete_index is None. Error.")
                error_log(sweep_config)
                exit()
            delete_index = dynamic_delete_index
            delete_index = delete_index + index_correction
        B, N, C = residual_1.shape
        batch_keep_mask = torch.ones(B, N, dtype=torch.bool, device=x.device)

        if sweep_config['token_delete_method'] in ['random1', 'random2', 'random3', 'frequent_index']:
            indices_to_delete = torch.tensor(delete_index, device=x.device, dtype=torch.long)
            batch_keep_mask[:, indices_to_delete] = False

        elif sweep_config['token_delete_method'] in ['attn', 'value']:
            per_sample_indices_to_delete = delete_index.to(device=x.device, dtype=torch.long)  # shape: [B, 8]
            batch_idx = torch.arange(B, device=x.device).unsqueeze(1)
            batch_keep_mask[batch_idx, per_sample_indices_to_delete] = False
        kept_elements = residual_1.masked_select(batch_keep_mask.unsqueeze(-1))
        num_kept_tokens = N - sweep_config['token_delete_number']
        x_mask = kept_elements.view(B, num_kept_tokens, C)
        residual_1 = x_mask
    
    return residual_1


import torch

def delete_residual_tokens2(residual_2: torch.Tensor, delete_index: torch.Tensor) -> torch.Tensor:
    """
    주어진 delete_index를 사용하여 residual 텐서에서 토큰을 삭제합니다.

    Args:
        residual_2 (torch.Tensor): 토큰을 삭제할 원본 텐서. Shape: (B, N, C)
        delete_index (torch.Tensor): 삭제할 토큰들의 인덱스. Shape: (B, k)
                                     (B: 배치 크기, k: 삭제할 토큰의 수)

    Returns:
        torch.Tensor: 해당 인덱스의 토큰이 삭제된 새로운 residual 텐서. Shape: (B, N-k, C)
    """
    # 1. residual_2 텐서의 배치(B), 토큰 수(N), 차원(C) 정보를 가져옵니다.
    B, N, C = residual_2.shape
    
    # 2. 삭제할 토큰의 수(k)를 delete_index의 shape에서 가져옵니다.
    k = delete_index.shape[1]

    # 3. 삭제하지 않을 토큰을 표시하기 위한 boolean 마스크를 생성합니다.
    # 우선 모든 토큰을 유지(True)하도록 (B, N) 크기의 마스크를 생성합니다.
    keep_mask = torch.ones(B, N, dtype=torch.bool, device=residual_2.device)

    # 4. 삭제할 토큰의 위치를 마스크에서 False로 변경합니다.
    # scatter_ 함수를 사용해 각 배치(B)마다 delete_index에 해당하는 위치를 False로 설정합니다.
    keep_mask.scatter_(1, delete_index, False)
    
    # 5. 마스크를 사용해 토큰을 실제로 삭제하고, 텐서의 shape을 재구성합니다.
    # keep_mask를 (B, N, 1)로 확장하여 residual_2 (B, N, C)에 적용합니다.
    # 그 후, 남은 토큰의 수(N - k)에 맞게 view를 통해 shape을 (B, N-k, C)로 변경합니다.
    num_kept_tokens = N - k
    residual_2_deleted = torch.masked_select(residual_2, keep_mask.unsqueeze(-1)).view(B, num_kept_tokens, C)

    return residual_2_deleted