import pickle
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt

def save_saliency(attentions: list, model_type: str, save_path: Path):
    episode_state_t_token_attention = []
    for step, step_attentions in enumerate(attentions):
        if step % 50 == 0 or step == 19:
            first_attention = step_attentions[0]
            state_t_token_attention = first_attention[0, 0, -2]
            episode_state_t_token_attention.append(state_t_token_attention)
            if model_type == "dt":
                attribute_map = state_t_token_attention.reshape(-1, 3).transpose(1, 0)
            else:
                attribute_map = state_t_token_attention.reshape(-1, 1).transpose(1, 0)
            step_save_path = save_path / f"step_{step}.pdf"
            save_saliency_map(attribute_map, model_type, step_save_path)
    
    max_length = max(len(step_attention) for step_attention in episode_state_t_token_attention)
    episode_state_t_token_attention = [
        np.concatenate([np.zeros(max_length - len(step_attention)), step_attention])
        for step_attention in episode_state_t_token_attention
    ]
    
    average_attention = np.mean(np.stack(episode_state_t_token_attention), axis=0)
    if model_type == "dt":
        attribute_map = average_attention.reshape(-1, 3).transpose(1, 0)
    else:
        attribute_map = average_attention.reshape(-1, 1).transpose(1, 0)
    save_saliency_map(attribute_map, model_type, save_path / "average.pdf")
    

def save_saliency_map(attribute_map: np.ndarray, model_type: str, save_path: Path):
    """Render ``attribute_map`` as a heatmap and pickle the raw map next to it.

    The heatmap uses the ``'Reds'`` colormap. It is written to ``save_path``
    at 150 dpi, and ``attribute_map`` is pickled to
    ``save_path.with_suffix('.pkl')``.

    Args:
        attribute_map: 2-D array indexed as (row, timestep). Columns are
            labelled ``"t - k"``, with the last column labelled ``"t"``.
        model_type: ``"dt"`` labels rows R/s/a; any other value labels a
            single row R.
        save_path: Output image path; its parent directory must already exist.
    """
    ylabels = ['R', 's', 'a'] if model_type == "dt" else ['R']
    n_rows = attribute_map.shape[0]
    n_steps = attribute_map.shape[1]
    xlabels = [f"t - {n_steps - t - 1}" if t != n_steps - 1 else "t" for t in range(n_steps)]

    fig, ax = plt.subplots(figsize=(6, 2))
    try:
        fig.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.2)
        image = ax.imshow(attribute_map, cmap='Reds')
        ax.set_xticks(range(n_steps))
        ax.set_xticklabels(xlabels, rotation=45, ha='right')
        ax.set_yticks(list(range(n_rows)))
        ax.set_yticklabels(ylabels[:n_rows])
        fig.colorbar(image, ax=ax, orientation='horizontal', pad=0.4, fraction=0.04, aspect=30)
        fig.savefig(str(save_path), dpi=150)
    finally:
        # Each call creates a new pyplot figure. Without closing it, figures
        # accumulate in pyplot's registry (memory growth plus the
        # "too many open figures" warning over a long episode).
        plt.close(fig)

    with open(save_path.with_suffix('.pkl'), 'wb') as f:
        pickle.dump(attribute_map, f)
