import torch
import os
import matplotlib.pyplot as plt


def save_l1_norm_plot(value_tensor, save_dir, filename_prefix, block_num, sample_idx):
    v_cpu = value_tensor.detach().cpu()

    # 각 토큰의 특성(head_dim) 차원에 대해 L1 Norm 계산
    # (B, H, N, D_h) -> (B, H, N)
    l1_norm = torch.linalg.norm(v_cpu, ord=1, dim=-1)

    # 모든 헤드에 대해 반복
    num_heads = v_cpu.shape[1]
    for head_idx in range(num_heads):
        # 1. 헤드별로 저장할 디렉토리 생성
        head_save_dir = os.path.join(save_dir, f"head_{head_idx}")
        os.makedirs(head_save_dir, exist_ok=True)

        # 2. 현재 헤드의 데이터 선택 (배치의 첫 번째 샘플 사용)
        l1_norm_sample = l1_norm[sample_idx, head_idx, :]

        # 3. 그래프 생성
        plt.figure(figsize=(12, 6))
        plt.plot(l1_norm_sample.numpy())
        # 제목에 현재 헤드 번호 추가
        plt.title(f'Block {block_num} - Value L1 Norm per Token (Batch 0, Head {head_idx})')
        plt.xlabel("Token Index")
        plt.ylabel("L1 Norm")
        plt.grid(True)

        # 4. 가장 작은 L1 norm 값 3개를 가진 토큰의 인덱스 찾기
        _, smallest_indices = torch.topk(l1_norm_sample, k=3, largest=False)
        indices_str = ', '.join(map(str, smallest_indices.tolist()))
        info_text = f"Top 3 Smallest L1 Norm Indices: {indices_str}"
        plt.figtext(0.5, 0.01, info_text, ha="center", fontsize=9, bbox={"facecolor":"orange", "alpha":0.5, "pad":5})

        # 5. 헤드별 디렉토리에 그래프 저장
        save_path = os.path.join(head_save_dir, f"{filename_prefix}_block_{block_num}.png")
        plt.savefig(save_path)
        plt.close() # 메모리 해제를 위해 figure를 닫아줍니다.

def save_l1_norm_plot(value_tensor, save_dir, filename_prefix, block_num, sample_idx):
    v_cpu = value_tensor.detach().cpu()

    l1_norm = torch.linalg.norm(v_cpu, ord=1, dim=-1)

    num_heads = v_cpu.shape[1]
    for head_idx in range(num_heads):

        head_save_dir = os.path.join(save_dir, f"head_{head_idx}")
        os.makedirs(head_save_dir, exist_ok=True)

        l1_norm_sample = l1_norm[sample_idx, head_idx, :]

        # 3. 그래프 생성
        plt.figure(figsize=(12, 6))
        plt.plot(l1_norm_sample.numpy())
        # 제목에 현재 헤드 번호 추가
        plt.title(f'Block {block_num} - Value L1 Norm per Token (Batch 0, Head {head_idx})')
        plt.xlabel("Token Index")
        plt.ylabel("L1 Norm")
        plt.grid(True)

        # 4. 가장 작은 L1 norm 값 3개를 가진 토큰의 인덱스 찾기
        _, smallest_indices = torch.topk(l1_norm_sample, k=3, largest=False)
        indices_str = ', '.join(map(str, smallest_indices.tolist()))
        info_text = f"Top 3 Smallest L1 Norm Indices: {indices_str}"
        plt.figtext(0.5, 0.01, info_text, ha="center", fontsize=9, bbox={"facecolor":"orange", "alpha":0.5, "pad":5})

        # 5. 헤드별 디렉토리에 그래프 저장
        save_path = os.path.join(head_save_dir, f"{filename_prefix}_block_{block_num}.png")
        plt.savefig(save_path)
        plt.close() 

#l2 norm으로 구하기 



def log_tensor_statistics(x: torch.Tensor, log_file_path: str, prefix: str, block_num: int, sample_idx: int):
    """
    텐서 x의 max 및 median 값을 계산하여 지정된 로그 파일에 append합니다.
    """
    x = x[sample_idx, :, :]
    x = x.abs()
    x_flat = x.reshape(-1).float()  # flatten & ensure float for safety
    max_val = x_flat.max().item()
    median_val = x_flat.median().item()

    if median_val != 0:
        share = max_val / median_val
        share_str = f"{share:.6f}"
        medianXscaleFactor = median_val * 256 / max_val
    else:
        share_str = "inf"
        medianXscaleFactor = 0.0  # 또는 float('inf') 등 선택 가능

    os.makedirs(os.path.dirname(log_file_path), exist_ok=True)

    with open(log_file_path, "a") as f:
        f.write(
            f"Block Num: {block_num} || Max: {max_val:.2f} | Median: {median_val:.6f} | "
            f"Max/Median: {share_str:} | median * scale factor: {medianXscaleFactor:.6f} \n"
        )

# Max, Median 통계와 동시에 상위 100개 값 ranking 
def log_tensor_statistics_with_rank(x: torch.Tensor, csv_file_path: str, block_num: int):
    """
    텐서 x에서 가장 큰 값 100개를 찾아 지정된 CSV 파일에 append합니다.
    CSV 파일에는 block_num, 순위(rank), 값(value)이 기록됩니다.
    """
    # --- 1. 데이터 준비 ---
    # 텐서는 (batch, tokens, dim) 형태라고 가정하고, 첫 번째 배치 아이템만 사용합니다.
    x = x[0, :, :]
    x = x.abs()
    x_flat = x.reshape(-1).float()  # 텐서를 1차원으로 만들고 float 타입으로 변환

    # 텐서에 최소 100개의 원소가 있는지 확인합니다.
    if x_flat.numel() < 100:
        print(f"Warning: Block {block_num} has fewer than 100 values. Skipping CSV write.")
        return

    # --- 2. 상위 100개 값 추출 ---
    # torch.topk를 사용하여 가장 큰 100개의 값을 효율적으로 찾습니다.
    # .tolist()로 변환하여 Python 리스트로 만듭니다.
    top_100_values = torch.topk(x_flat, 100).values.tolist()

    # --- 3. CSV 파일에 데이터 기록 ---
    # CSV 파일 저장을 위한 디렉토리 생성
    os.makedirs(os.path.dirname(csv_file_path), exist_ok=True)
    
    # 파일이 없는 경우 헤더를 추가하기 위해 미리 파일 존재 여부를 확인합니다.
    file_exists = os.path.exists(csv_file_path)

    # 'a' (append) 모드로 파일을 열어 기존 데이터를 유지하고 새 데이터를 추가합니다.
    # newline='' 옵션은 CSV 파일에서 불필요한 빈 줄이 생기는 것을 방지합니다.
    with open(csv_file_path, "a", newline='') as f:
        writer = csv.writer(f)

        # 파일이 새로 생성된 경우에만 헤더를 씁니다.
        if not file_exists:
            writer.writerow(["block_num", "rank", "value"])

        # 상위 100개 값을 한 줄씩 CSV에 기록합니다.
        # enumerate를 사용하여 1부터 시작하는 순위(rank)를 만듭니다.
        for i, value in enumerate(top_100_values):
            rank = i + 1
            writer.writerow([block_num, rank, value])

def get_unique_filepath(base_path: str) -> str:
    """
    동일한 경로의 파일이 존재할 경우, 숫자를 증가시켜 고유한 경로를 반환합니다.
    예: "file.png" → "file_1.png" → "file_2.png" ...
    """
    if not os.path.exists(base_path):
        return base_path

    base, ext = os.path.splitext(base_path)
    idx = 1
    while True:
        new_path = f"{base}_{idx}{ext}"
        if not os.path.exists(new_path):
            return new_path
        idx += 1

def plot_tensor_views(x: torch.Tensor, save_dir: str, prefix: str, block_num: int, sample_idx: int):
    """
    3차원 텐서의 token-wise 및 channel-wise max 값을 bar plot으로 시각화하여 저장합니다.
    
    Args:
        x (torch.Tensor): 입력 텐서, shape (1, seq_len, channel_dim).
        save_dir (str): 그림을 저장할 디렉토리.
        prefix (str): 파일명 앞에 붙일 선택적인 문자열 (예: "Layer3_").
    """
    token_dir = os.path.join(save_dir, "token_wise")
    os.makedirs(token_dir, exist_ok=True)
    channel_dir = os.path.join(save_dir, "channel_wise")
    os.makedirs(channel_dir, exist_ok=True)

    x = x[sample_idx,:,:]
    x = x.abs()

    x = x.squeeze(0).detach().cpu().float()  # shape: (seq_len, channel_dim)
    seq_len, channel_dim = x.shape

    # Token-wise max (각 token에서 max(channel))
    token_max_vals = x.max(dim=1).values  # shape: (seq_len,)
    plt.figure(figsize=(10, 6))
    plt.bar(range(seq_len), token_max_vals.numpy(), color='skyblue')
    plt.title("Token-wise Max per Token")
    plt.xlabel("Token Index")
    plt.ylabel("Max Value Across Channels")
    plt.grid(True, axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    save_path = get_unique_filepath(os.path.join(save_dir, f"token_wise/{prefix}token_wise_max_{block_num}.png"))
    plt.savefig(save_path)
    plt.close()

    # Channel-wise max (각 channel에서 max(token))
    channel_max_vals = x.max(dim=0).values  # shape: (channel_dim,)
    plt.figure(figsize=(10, 6))
    plt.bar(range(channel_dim), channel_max_vals.numpy(), color='salmon')
    plt.title("Channel-wise Max per Channel")
    plt.xlabel("Channel Index")
    plt.ylabel("Max Value Across Tokens")
    plt.grid(True, axis='y', linestyle='--', alpha=0.7)
    plt.tight_layout()
    save_path = get_unique_filepath(os.path.join(save_dir, f"channel_wise/{prefix}channel_wise_max_{block_num}.png"))
    plt.savefig(save_path)
    plt.close()


def log_top3_tokens(x, save_dir, tag, block_num):
    """
    x: [B, T, D] - input tensor
    save_dir: str - output directory
    tag: str - identifier for filename
    block_num: int - used in filename
    """
    os.makedirs(save_dir, exist_ok=True)

    if x.dim() != 3:
        raise ValueError(f"x must be 3D tensor [B, T, D], but got shape {x.shape}")
    
    # 분석 대상: 0번째 이미지 (batch의 첫 번째 샘플)
    x0 = x[0].abs()  # shape: [T, D]

    # Step 1: 각 token의 최대값 계산 → shape: [T]
    token_max_vals, _ = torch.max(x0, dim=1)

    # Step 2: top-3 token 값 및 index 추출
    top3_vals, top3_indices = torch.topk(token_max_vals, k=3, dim=0)  # shape: [3]

    # Step 3: 결과 저장
    result_lines = []
    for i in range(3):
        val = top3_vals[i].item()
        idx = top3_indices[i].item()
        result_lines.append(f"Top {i+1}: Value = {val:.6f}, Token Index = {idx}")

    filename = os.path.join(save_dir, f"{tag}_block{block_num}_sample0.txt")
    with open(filename, "w") as f:
        f.write("\n".join(result_lines))

def save_attention_heatmap(attn: torch.Tensor, save_root: str, prefix: str, block_num: int, sample_idx: int):
    """
    Attention heatmap을 시각화하여 저장합니다.
    - 각 헤드의 attention map을 이미지로 저장 (0번째 sample만 사용)
    - 블록 번호별로 폴더를 생성하여 저장
    
    Args:
        attn (torch.Tensor): shape (B, num_heads, N, N)
        save_root (str): 저장할 루트 디렉토리
        prefix (str): 파일명 접두어
        block_num (int): 블록 번호
    """
    block_dir = os.path.join(save_root, f"block_{block_num}")
    os.makedirs(block_dir, exist_ok=True)

    attn = attn.detach().cpu()
    B, num_heads, N, _ = attn.shape
    attn0 = attn[sample_idx]  # 첫 번째 sample만 사용 → shape: (num_heads, N, N)

    for head in range(num_heads):
        plt.figure(figsize=(6, 5))
        plt.imshow(attn0[head], cmap='viridis', interpolation='nearest')
        plt.colorbar()
        plt.title(f"{prefix} Attention Head {head} (Block {block_num})")
        plt.xlabel("Key Token Index")
        plt.ylabel("Query Token Index")
        plt.tight_layout()
        filename = os.path.join(block_dir, f"{prefix}_head{head}.png")
        plt.savefig(filename)
        plt.close()

def save_tensor_distribution(tensor, block_num, save_dir="/home/user/regcache/probe_result/probe_without_reg_image1/graph", name="fc2_input", log_scale=False):
    import os
    import numpy as np
    import matplotlib.pyplot as plt

    os.makedirs(save_dir, exist_ok=True)

    data = tensor.detach().cpu().numpy()
    flat_data = data.flatten()

    max_val = np.max(flat_data)
    min_val = np.min(flat_data)
    std_val = np.std(flat_data)

    abs_sorted = np.sort(np.abs(flat_data))
    max_abs = abs_sorted[-1] if len(abs_sorted) > 0 else 1.0
    scale = 32.0 / max_abs
    quantized = np.round(abs_sorted * scale)
    zero_mapped_ratio = np.mean(quantized == 0)

    print(f"[Block {block_num}] {name} - Max: {max_val:.6e}, Min: {min_val:.6e}, Std: {std_val:.6e}, Zero-Mapped Ratio: {zero_mapped_ratio:.4%}")

    plt.figure()
    plt.hist(flat_data, bins=100, color='green', alpha=0.6)
    if log_scale:
        plt.yscale('log')
    plt.title(f'{name} Full Distribution (Block {block_num})')
    plt.xlabel('Value')
    plt.ylabel('Frequency')
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'{name}_block{block_num}_full.png'))
    plt.close()

    cropped_data = flat_data[np.abs(flat_data) < 5]

    plt.figure()
    plt.hist(cropped_data, bins=100, color='blue', alpha=0.7)
    if log_scale:
        plt.yscale('log')
    plt.title(f'{name} Cropped Distribution (Block {block_num})')
    plt.xlabel('Value (Cropped to [-5, 5])')
    plt.ylabel('Frequency')
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, f'{name}_block{block_num}_cropped.png'))
    plt.close()



def print_l2_norm_diff(x, quant_x, block_num):
    # x와 quant_x가 tensor라고 가정
    l2_norm_x = torch.norm(x, p=2).item()
    l2_norm_quant_x = torch.norm(quant_x, p=2).item()
    diff = abs(l2_norm_x - l2_norm_quant_x)

    print(f"[Block {block_num}] L2 Norm - Original: {l2_norm_x:.6e}, Quantized: {l2_norm_quant_x:.6e}, Diff: {diff:.2e}")


import csv


def save_top5_abs_tokens(x, file_path_prefix, block_num):
    """
    각 토큰의 채널 차원에서 L-inf norm (최대 절댓값)을 계산한 후,
    그 값이 가장 큰 top 5 토큰을 찾아 저장합니다.

    x: Tensor of shape [B, T, C]
    file_path_prefix: CSV 파일을 저장할 디렉토리 또는 파일명 접두사
    block_num: 몇 번째 블록인지 식별하기 위한 번호
    """
    x_abs = x.abs()  # [B, T, C]

    # 1. 각 토큰 내에서 채널 간 최대 절댓값을 찾습니다. (L-inf norm)
    # max_val_per_token의 shape: [B, T]
    max_val_per_token, _ = torch.max(x_abs, dim=1)

    # 2. 최대 절댓값을 기준으로 top-5 토큰을 찾습니다.
    # topk_values (최대값), topk_token_indices (토큰 인덱스)
    # 둘 다 shape: [B, 5]
    topk_values, topk_token_indices = torch.topk(max_val_per_token, k=5, dim=-1)

    # 블록 번호를 포함한 파일 경로를 생성합니다.
    file_path = os.path.join(file_path_prefix, f"block_{block_num}_token_max.csv")

    # 디렉토리가 존재하지 않으면 생성합니다.
    os.makedirs(os.path.dirname(file_path), exist_ok=True)

    with open(file_path, mode='a', newline='') as csvfile:
        writer = csv.writer(csvfile)
        # 각 배치 샘플에 대해 반복합니다.
        for b in range(x.size(0)):
            row = [f"block_{block_num}", f"sample_{b}"]
            # top 5 결과에 대해 반복합니다.
            for val, token_idx in zip(topk_values[b], topk_token_indices[b]):
                # row에 (토큰 인덱스, 해당 토큰의 최대 절댓값)을 추가합니다.
                row += [f"token_{token_idx.item()}", f"{val.item():.6f}"]
            writer.writerow(row)
