import matplotlib.pyplot as plt
import seaborn as sns
import os
import numpy as np
import umap.umap_ as umap


def patchtst_map1(attns, model, batch_idx, save_dir, slurm_id, model_name, seq_len, pred_len, n_vars):
    """
    Visualize and save attention maps for all layers, heads, and channels.

    For every layer and attention type, heatmaps are written to
    ``<save_dir>/<attention type>/``:

    * one per (channel, head),
    * one per channel, averaged over heads,
    * one per layer, averaged over heads and channels.

    Only the first batch element (index 0) is plotted.

    Args:
        attns: One entry per layer; each entry is a sequence of 1 to 4 torch
            tensors of shape ``(B * n_vars, heads, Q, K)``. The number of
            tensors in ``attns[0]`` selects the attention-type names.
        model: Unused.
        batch_idx: Unused; batch element 0 is always plotted.
        save_dir: Root output directory; sub-folders are created on demand.
        slurm_id: Used only to build file names.
        model_name: Used only to build file names.
        seq_len: Used only to build file names.
        pred_len: Used only to build file names.
        n_vars: Number of channels folded into the leading tensor dimension.

    Raises:
        ValueError: If ``attns[0]`` does not contain 1 to 4 tensors.
    """
    type_names_by_count = {
        1: ("local",),
        2: ("local", "global"),
        3: ("local", "global", "local2"),
        4: ("local", "local_to_global", "global", "global_to_local"),
    }
    n_types = len(attns[0])
    if n_types not in type_names_by_count:
        # Fail early with a clear message instead of hitting an unbound
        # name deep inside the plotting loop.
        raise ValueError(f"Expected 1 to 4 attention tensors per layer, got {n_types}")
    dict_attn = dict(enumerate(type_names_by_count[n_types]))
    file_prefix = f"{slurm_id}_{model_name}_L{seq_len}_P{pred_len}"

    def save_heatmap(attn_map, title, folder, filename):
        # Render one (Q, K) map as a heatmap and write it to folder/filename.
        os.makedirs(folder, exist_ok=True)
        plt.figure(figsize=(6, 5))
        sns.heatmap(attn_map, cmap="viridis")
        plt.title(title)
        plt.xlabel("Key")
        plt.ylabel("Query")
        plt.tight_layout()
        plt.savefig(os.path.join(folder, filename))
        plt.close()

    for layer_idx, layer_attns in enumerate(attns):
        # token_idx -> list of head-averaged maps (one per channel) of this layer.
        head_avg_maps = {}

        for token_idx, attn_layer in enumerate(layer_attns):
            attn_name = dict_attn[token_idx]
            folder = os.path.join(save_dir, attn_name)

            leading, heads, q_len, k_len = attn_layer.shape
            n_batches = leading // n_vars
            if n_batches > 0:
                # Unfold (B * n_vars, ...) and keep the first batch element.
                attn_layer = attn_layer.view(n_batches, n_vars, heads, q_len, k_len)[0]
            # If the leading dimension is smaller than n_vars the tensor is used
            # as is; iterate over the rows it actually has so that indexing
            # cannot run out of bounds.
            n_channels = attn_layer.shape[0]

            for var_idx in range(n_channels):
                channel = var_idx + 1

                for head_idx in range(heads):
                    save_heatmap(
                        attn_layer[var_idx, head_idx].detach().cpu().numpy(),
                        f"Layer {layer_idx} | Head {head_idx} | Channel {channel} | {attn_name}",
                        folder,
                        f"{file_prefix}_C{channel}_attn_Layer{layer_idx}_H{head_idx}_{attn_name}.png",
                    )

                # Average over the head axis for this channel.
                head_avg = attn_layer[var_idx].mean(axis=0).detach().cpu().numpy()
                head_avg_maps.setdefault(token_idx, []).append(head_avg)
                save_heatmap(
                    head_avg,
                    f"Average Attention Map | Channel {channel} | Layer {layer_idx} | {attn_name}",
                    folder,
                    f"{file_prefix}_C{channel}_attn_Layer{layer_idx}_avg_attn_{attn_name}.png",
                )

        # One map per attention type, averaged over all channels of this layer.
        for token_idx, channel_maps in head_avg_maps.items():
            attn_name = dict_attn[token_idx]
            save_heatmap(
                np.mean(channel_maps, axis=0),
                f"Channel Average Attention Map | Layer {layer_idx} | {attn_name}",
                os.path.join(save_dir, attn_name),
                f"{file_prefix}_attn_Layer{layer_idx}_Channel_avg_attn_{attn_name}.png",
            )


def patchtst_map(attns, model, batch_idx, save_dir, slurm_id, model_name, seq_len, pred_len, n_vars):
    """
    Visualize and save attention maps for all layers, heads, and channels.

    For every layer and attention type, heatmaps are written to
    ``<save_dir>/<attention type>/``:

    * one per (channel, head),
    * one per channel, averaged over heads,
    * one per layer, averaged over heads and channels.

    The batch size is derived once from ``attns[0][0]`` and only the last
    batch element is plotted. A tensor whose leading dimension equals that
    batch size is treated as having a single channel.

    Args:
        attns: One entry per layer; each entry is a sequence of 1 to 4 torch
            tensors with four dimensions ``(leading, heads, Q, K)``. The
            number of tensors in ``attns[0]`` selects the attention-type names.
        model: Unused.
        batch_idx: Unused; the last batch element is always plotted.
        save_dir: Root output directory; sub-folders are created on demand.
        slurm_id: Used only to build file names.
        model_name: Used only to build file names.
        seq_len: Used only to build file names.
        pred_len: Used only to build file names.
        n_vars: Number of channels folded into the leading tensor dimension.

    Raises:
        ValueError: If ``attns[0]`` does not contain 1 to 4 tensors.
    """
    type_names_by_count = {
        1: ("local",),
        2: ("local", "global"),
        3: ("local", "global", "local2"),
        4: ("local", "local_to_global", "global", "global_to_local"),
    }
    n_types = len(attns[0])
    if n_types not in type_names_by_count:
        # Fail early with a clear message instead of hitting an unbound
        # name deep inside the plotting loop.
        raise ValueError(f"Expected 1 to 4 attention tensors per layer, got {n_types}")
    dict_attn = dict(enumerate(type_names_by_count[n_types]))
    file_prefix = f"{slurm_id}_{model_name}_L{seq_len}_P{pred_len}"

    # Batch size implied by the first tensor of the first layer.
    first_leading, _, _, _ = attns[0][0].shape
    batch = 1 if first_leading == n_vars else first_leading // n_vars

    def save_heatmap(attn_map, title, folder, filename):
        # Render one (Q, K) map as a heatmap and write it to folder/filename.
        os.makedirs(folder, exist_ok=True)
        plt.figure(figsize=(6, 5))
        sns.heatmap(attn_map, cmap="viridis")
        plt.title(title)
        plt.xlabel("Key")
        plt.ylabel("Query")
        plt.tight_layout()
        plt.savefig(os.path.join(folder, filename))
        plt.close()

    for layer_idx, layer_attns in enumerate(attns):
        # token_idx -> list of head-averaged maps (one per channel) of this layer.
        head_avg_maps = {}

        for token_idx, attn_layer in enumerate(layer_attns):
            attn_name = dict_attn[token_idx]
            folder = os.path.join(save_dir, attn_name)

            leading, heads, q_len, k_len = attn_layer.shape
            if batch != 1:
                if leading != batch:
                    # Unfold (B * n_vars, ...) and keep the last batch element.
                    attn_layer = attn_layer.view(batch, n_vars, heads, q_len, k_len)[-1]
                else:
                    # Leading dimension is the batch itself: keep the last
                    # element and expose it as a single channel.
                    attn_layer = attn_layer[-1].unsqueeze(0)
            n_channels = attn_layer.shape[0]

            for var_idx in range(n_channels):
                channel = var_idx + 1

                for head_idx in range(heads):
                    save_heatmap(
                        attn_layer[var_idx, head_idx].detach().cpu().numpy(),
                        f"Layer {layer_idx} | Head {head_idx} | Channel {channel} | {attn_name}",
                        folder,
                        f"{file_prefix}_C{channel}_attn_Layer{layer_idx}_H{head_idx}_{attn_name}.png",
                    )

                # Average over the head axis for this channel.
                head_avg = attn_layer[var_idx].mean(axis=0).detach().cpu().numpy()
                head_avg_maps.setdefault(token_idx, []).append(head_avg)
                save_heatmap(
                    head_avg,
                    f"Average Attention Map | Channel {channel} | Layer {layer_idx} | {attn_name}",
                    folder,
                    f"{file_prefix}_C{channel}_attn_Layer{layer_idx}_avg_attn_{attn_name}.png",
                )

        # One map per attention type, averaged over all channels of this layer.
        for token_idx, channel_maps in head_avg_maps.items():
            attn_name = dict_attn[token_idx]
            save_heatmap(
                np.mean(channel_maps, axis=0),
                f"Channel Average Attention Map | Layer {layer_idx} | {attn_name}",
                os.path.join(save_dir, attn_name),
                f"{file_prefix}_attn_Layer{layer_idx}_Channel_avg_attn_{attn_name}.png",
            )



def new_attention_map1(attns, model, batch_idx, save_dir, slurm_id, model_name, seq_len, pred_len, n_vars, num_global_tokens):
    """
    Save one full attention heatmap (local and global tokens together) per entry.

    ``attns`` is walked as layers -> entries. ``None`` entries are skipped; every
    other entry must be a 2-D ``(query_len, key_len)`` or 3-D
    ``(n_heads, query_len, key_len)`` tensor. 3-D entries are averaged over their
    first axis before plotting. Dashed red lines are drawn on both axes at
    ``query_len - num_global_tokens`` to mark the local/global boundary.

    ``model``, ``batch_idx`` and ``n_vars`` are not used. ``save_dir`` is created
    if it does not exist.

    Raises:
        ValueError: If an entry is neither 2-D nor 3-D.
    """
    os.makedirs(save_dir, exist_ok=True)
    file_prefix = f"{slurm_id}_{model_name}_L{seq_len}_P{pred_len}"

    for layer_idx, layer_attns in enumerate(attns):
        for attn_idx, attn_tensor in enumerate(layer_attns):
            if attn_tensor is None:
                continue

            rank = attn_tensor.dim()
            if rank not in (2, 3):
                raise ValueError(f"Unexpected attention shape: {attn_tensor.shape}")
            if rank == 3:
                # Collapse the leading axis (heads, per the expected layout).
                attn_tensor = attn_tensor.mean(0)
            attn_avg = attn_tensor.detach().cpu().numpy()

            # Position of the first global token, shifted to the pixel edge.
            boundary = attn_avg.shape[0] - num_global_tokens - 0.5

            fig, ax = plt.subplots(figsize=(8, 6))
            image = ax.imshow(attn_avg, aspect='auto', cmap='viridis')
            fig.colorbar(image)
            ax.set_title(f"Attention Map (Layer {layer_idx} - Channel {attn_idx})")
            ax.set_xlabel("Key Position")
            ax.set_ylabel("Query Position")

            # Local/global boundary on the key axis and on the query axis.
            ax.axvline(x=boundary, color='red', linestyle='--')
            ax.axhline(y=boundary, color='red', linestyle='--')

            filepath = os.path.join(save_dir, f"{file_prefix}_attn_Layer{layer_idx}_Channel{attn_idx}.png")
            fig.savefig(filepath, bbox_inches='tight')
            plt.close(fig)

            print(f"Saved attention map for Layer {layer_idx}, Channel {attn_idx} to {filepath}")

def global_token_attention_map(attns, model, batch_idx, save_dir, slurm_id, model_name, seq_len, pred_len, n_vars, num_global_tokens):
    """
    Save, per layer, a heatmap of the attention rows that belong to global-token queries.

    Only the last entry of each layer in ``attns`` is used; layers whose last
    entry is ``None`` are skipped. The entry is indexed with ``batch_idx`` and
    must then be 2-D ``(query_len, key_len)`` or 3-D
    ``(n_heads, query_len, key_len)``; 3-D input is averaged over its first
    axis. The rows from ``query_len - num_global_tokens`` onward are plotted.

    ``model`` and ``n_vars`` are not used. ``save_dir`` is created if it does
    not exist.

    Raises:
        ValueError: If the indexed entry is neither 2-D nor 3-D.
    """
    os.makedirs(save_dir, exist_ok=True)

    for layer_idx, layer_attns in enumerate(attns):
        last_attn = layer_attns[-1]
        if last_attn is None:
            continue

        attn_batch = last_attn[batch_idx]
        rank = attn_batch.dim()
        if rank == 2:
            attn_avg = attn_batch.detach().cpu().numpy()
        elif rank == 3:
            # Collapse the leading axis (heads, per the expected layout).
            attn_avg = attn_batch.mean(0).detach().cpu().numpy()
        else:
            raise ValueError(f"Unexpected attention shape: {attn_batch.shape}")

        # Keep only the query rows of the trailing global tokens.
        first_global_row = attn_avg.shape[0] - num_global_tokens
        global_queries = attn_avg[first_global_row:, :]

        fig, ax = plt.subplots(figsize=(8, 2))
        image = ax.imshow(global_queries, aspect='auto', cmap='plasma')
        fig.colorbar(image)
        ax.set_title(f"Global Token Query Attention (Layer {layer_idx})")
        ax.set_xlabel("Key Position")
        ax.set_ylabel("Global Query Position")

        filepath = os.path.join(
            save_dir,
            f"{slurm_id}_{model_name}_L{seq_len}_P{pred_len}_attn_Layer{layer_idx}_GlobalOnlyAttention.png",
        )
        fig.savefig(filepath, bbox_inches='tight')
        plt.close(fig)

        print(f"Saved GLOBAL-ONLY attention map to {filepath}")



def new_attention_map(attns, model, batch_idx, save_dir, slurm_id, model_name, seq_len, pred_len, n_vars, num_global_tokens):
    """
    attns: attention scores list (len = n_layers), each tensor (B, H, L, S)

    For every layer the last batch element is taken, averaged over the head
    axis and saved as a heatmap in ``save_dir`` (created if missing). The
    token axis is assumed to be laid out as ``n_vars`` consecutive groups, each
    made of ``L // n_vars - num_global_tokens`` patch tokens followed by
    ``num_global_tokens`` global tokens; dashed red lines separate the groups
    and dotted cyan lines separate patch tokens from global tokens.

    ``model`` and ``batch_idx`` are not used.
    """
    os.makedirs(save_dir, exist_ok=True)

    for layer_idx, layer_attn in enumerate(attns):
        _, _, L_q, _ = layer_attn.shape

        # Last batch element -> (H, L, S), then mean over heads -> (L, S).
        head_mean = layer_attn[-1, :, :, :].mean(dim=0)

        fig, ax = plt.subplots(figsize=(8, 8))
        image = ax.imshow(head_mean.detach().cpu().numpy(), cmap="viridis", interpolation="nearest")
        fig.colorbar(image)

        def mark(position, **line_style):
            # Draw the same boundary on the query axis and on the key axis.
            ax.axhline(y=position - 0.5, **line_style)
            ax.axvline(x=position - 0.5, **line_style)

        # Number of patch (local) tokens inside one variable's group.
        patches_per_var = (L_q // n_vars) - num_global_tokens
        tokens_per_var = patches_per_var + num_global_tokens

        # Boundaries between consecutive variables.
        for var_idx in range(1, n_vars):
            mark(var_idx * tokens_per_var, color="red", linestyle="--", linewidth=0.8)

        # Boundary between local and global tokens inside each variable.
        for var_idx in range(n_vars):
            local_end = var_idx * tokens_per_var + patches_per_var
            if local_end < L_q:
                mark(local_end, color="cyan", linestyle=":", linewidth=1.2)

        ax.set_title(f"Attention Map (Layer {layer_idx}, pred_len: {pred_len}, num_global_tokens: {num_global_tokens})")
        fig.tight_layout()
        filepath = os.path.join(
            save_dir,
            f"{slurm_id}_{model_name}_L{seq_len}_P{pred_len}_attn_Layer{layer_idx}_n{n_vars}vars_{num_global_tokens}global.png",
        )
        fig.savefig(filepath)
        plt.close(fig)

        print(f"Attention map saved at: {filepath}")

    
def umap_visual(enc_out, save_dir, slurm_id, model_name, seq_len, pred_len, n_vars, num_global_tokens):
    """
    Project local and global token embeddings to 2-D with UMAP and save scatter plots.

    enc_out: (B, n_vars, d_model, num_token)

    The last ``num_global_tokens`` positions along the token axis are treated
    as global tokens and the remaining ones as local tokens. One figure is
    written to ``save_dir`` for every combination of ``n_neighbors`` in
    (10, 20, 30, 40, 50) and ``min_dist`` in (0.1, 0.2, 0.3, 0.4, 0.5); points
    are coloured by channel, local tokens drawn as circles and global tokens
    as triangles.

    Args:
        enc_out: torch tensor of shape ``(B, n_vars, d_model, num_token)``.
        save_dir: Output directory; created if it does not exist.
        slurm_id: Used only to build file names.
        model_name: Used only to build file names.
        seq_len: Used only to build file names.
        pred_len: Used only to build file names.
        n_vars: Ignored; the channel count is read from ``enc_out.shape``.
        num_global_tokens: Number of trailing tokens per channel that are
            global tokens.

    Raises:
        ValueError: If ``num_global_tokens`` is negative or larger than
            ``num_token``.
    """
    enc_out = enc_out.detach().cpu().numpy()
    B, n_vars, d_model, num_token = enc_out.shape
    if not 0 <= num_global_tokens <= num_token:
        raise ValueError(
            f"num_global_tokens must be between 0 and {num_token}, got {num_global_tokens}"
        )
    os.makedirs(save_dir, exist_ok=True)
    num_local_token = num_token - num_global_tokens

    # Move the embedding axis last so that every row of the flattened arrays
    # is exactly one token embedding. This has to happen for the global tokens
    # as well: reshaping (B, n_vars, d_model, G) directly would interleave
    # embedding and token entries whenever G > 1.
    tokens = enc_out.transpose(0, 1, 3, 2)  # (B, n_vars, num_token, d_model)
    # Slice with num_local_token rather than -num_global_tokens so that
    # num_global_tokens == 0 yields "all local" instead of "all global".
    local_tokens = tokens[:, :, :num_local_token, :].reshape(-1, d_model)
    global_tokens = tokens[:, :, num_local_token:, :].reshape(-1, d_model)
    all_tokens = np.concatenate([local_tokens, global_tokens], axis=0)

    # Channel id and local/global flag of every row of all_tokens; rows are
    # ordered batch -> channel -> token, local block first.
    channel_ids = np.arange(n_vars)
    local_channels = np.tile(np.repeat(channel_ids, num_local_token), B)
    global_channels = np.tile(np.repeat(channel_ids, num_global_tokens), B)
    token_channel = np.concatenate([local_channels, global_channels])
    is_global = np.arange(token_channel.size) >= local_channels.size

    # The palette does not depend on the UMAP hyper-parameters.
    palette = sns.color_palette('hsv', n_colors=n_vars)

    for n_neighbors in [10, 20, 30, 40, 50]:
        for min_dist in [0.1, 0.2, 0.3, 0.4, 0.5]:
            reducer = umap.UMAP(n_neighbors=n_neighbors, min_dist=min_dist, metric='euclidean', random_state=42)
            embedding = reducer.fit_transform(all_tokens)

            plt.figure(figsize=(20, 10))
            for v in range(n_vars):
                color = palette[v]
                in_channel = token_channel == v
                local_idx = in_channel & ~is_global
                global_idx = in_channel & is_global

                plt.scatter(embedding[local_idx, 0], embedding[local_idx, 1],
                            label=f'local_{v}', marker='o', s=120, color=color, alpha=0.6)
                plt.scatter(embedding[global_idx, 0], embedding[global_idx, 1],
                            label=f'global_{v}', marker='^', s=140, color=color, edgecolors='black')

            plt.title(f'UMAP of Local (o) and Global (^) Tokens by Channel / n_neighbors: {n_neighbors}, min_dist: {min_dist}')
            plt.legend(loc='upper left', fontsize='small')
            filename = f"{slurm_id}_{model_name}_L{seq_len}_P{pred_len}_umap_n{n_neighbors}_d{min_dist}.png"
            filepath = os.path.join(save_dir, filename)
            plt.savefig(filepath)
            plt.close()

            print(f"UMAP saved at: {filepath}")

def mask_visual(mask, save_dir, slurm_id, model_name, seq_len, pred_len, n_vars, num_global_tokens):
    """
    mask: (B, H, L, S)

    Save a heatmap of the attention mask of the first batch element, averaged
    over the head axis. The token axis is assumed to be laid out as ``n_vars``
    consecutive groups, each made of ``L // n_vars - num_global_tokens`` patch
    tokens followed by ``num_global_tokens`` global tokens; dashed blue lines
    separate the groups and dotted white lines separate patch tokens from
    global tokens.

    Args:
        mask: torch tensor of shape ``(B, H, L, S)``.
        save_dir: Output directory; created if it does not exist.
        slurm_id: Used only to build the file name.
        model_name: Used only to build the file name.
        seq_len: Used in the title and the file name.
        pred_len: Used in the title and the file name.
        n_vars: Number of variable groups along the token axis.
        num_global_tokens: Number of global tokens per variable group.
    """
    # Create the target directory up front so savefig cannot fail on a
    # missing folder.
    os.makedirs(save_dir, exist_ok=True)

    B, H, L_q, L_k = mask.shape
    mask = mask.detach().cpu().numpy()
    mask = mask[0, :, :, :]   # first batch element -> (H, L, S)
    mask = mask.mean(axis=0)  # average over heads -> (L, S)

    plt.figure(figsize=(8, 8))
    # With this extent the pixel edges fall on integer coordinates, so the
    # boundary lines below can be drawn at integer token indices.
    plt.imshow(mask, cmap="Reds", interpolation='none', origin='upper', extent=[0, L_k, L_q, 0])
    plt.colorbar()
    plt.xlabel("Key")
    plt.ylabel("Query")
    plt.title(f"Attention Mask (seq_len: {seq_len}, pred_len: {pred_len}, n_vars: {n_vars}, num_global_tokens: {num_global_tokens})")

    patches_per_var = (L_q // n_vars) - num_global_tokens

    # Boundaries between variables (dashed blue lines).
    for var_idx in range(1, n_vars):
        end_idx = var_idx * (patches_per_var + num_global_tokens)
        plt.axhline(y=end_idx, color="blue", linestyle="--", linewidth=0.8)
        plt.axvline(x=end_idx, color="blue", linestyle="--", linewidth=0.8)

    # Boundaries between local and global tokens (dotted white lines).
    for var_idx in range(n_vars):
        local_end = var_idx * (patches_per_var + num_global_tokens) + patches_per_var
        if local_end < L_q:
            plt.axhline(y=local_end, color="white", linestyle=":", linewidth=1.2)
            plt.axvline(x=local_end, color="white", linestyle=":", linewidth=1.2)

    plt.tight_layout()
    filename = f"{slurm_id}_{model_name}_L{seq_len}_P{pred_len}_{n_vars}vars_{num_global_tokens}global_mask.png"
    filepath = os.path.join(save_dir, filename)
    plt.savefig(filepath)
    plt.close()
    print(f"Mask saved at: {filepath}")
