import os
import warnings

# Silence every DeprecationWarning for the rest of the process.
warnings.filterwarnings('ignore', category=DeprecationWarning)

import os

# These variables are set before the remaining imports run.
# NOTE(review): presumably MKL and the MuJoCo/dm_control stack read them at
# import time, which is why they sit above those imports — confirm before
# reordering this section.
os.environ['MKL_SERVICE_FORCE_INTEL'] = '1'
os.environ['MUJOCO_GL'] = 'egl'

from pathlib import Path

import hydra
import numpy as np
import torch
from dm_env import specs

import dmc
import utils
from logger import Logger
from replay_buffer import make_replay_loader
from video import VideoRecorder
import wandb
import omegaconf
import yaml
from agent import mdp
from transformers import AutoTokenizer, AutoModel, utils
from bertviz import head_view
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import torch
import torch.nn as nn

from transformers import BertTokenizer, BertForQuestionAnswering, BertConfig

from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients, LayerActivation
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer
import utils
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import matplotlib.backends.backend_pdf
import matplotlib.colors as mcolors


# Drop any cached CUDA allocations before the models are loaded.
torch.cuda.empty_cache()

# Colour ramp used for the attention heat maps. Each entry is
# (position in [0, 1], colour): white -> gold -> dark orange -> red -> dark
# red, with four of the five stops placed in the lower half of the range.
cmap_segments = mcolors.LinearSegmentedColormap.from_list(
    'my_cmap',
    [
        (0.0, '#FFFFFF'),
        (0.1, '#FFD700'),
        (0.3, "#FF8C00"),
        (0.5, '#FF0000'),
        (1.0, '#8B0000'),
    ],
)

# Choose the norm implementation by feature detection instead of comparing
# version strings: a lexicographic test such as '1.10.0' >= '1.7.0' is False,
# so the string comparison wrongly fell back to torch.norm on torch 1.10-1.13.
# torch.linalg.norm is used whenever the installed torch provides it;
# torch.norm remains the fallback for older releases.
if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'norm'):
    norm_fn = torch.linalg.norm
else:
    norm_fn = torch.norm

# Run identifiers ("<date>/<time>_mdp") of the pretrained models to analyse.
# Each entry is paired by position with the label in `mask_types` below.
mdp_times = [
    "2023.05.02/204940_mdp",
    "2023.05.08/195155_mdp",
    "2023.05.05/174221_mdp",
    "2023.05.09/212024_mdp",
    "2023.05.09/212127_mdp",
    "2023.05.10/235155_mdp",
    "2023.05.10/235255_mdp",
    "2023.05.12/000716_mdp",
    "2023.05.12/000816_mdp",
]
# Label used in output file names for the run at the same position in
# `mdp_times`.
mask_types = [
    "random",
    "AutoCurr2_mixed",
    "Mixed_masking",
    "random_fixed_seq_masking2-5",
    "random_fixed_seq_masking2-10",
    "random_fixed_seq_masking2-15",
    "random_fixed_seq_masking2-20",
    "MixedProg_4p",
    "MixedInv_4p",
]
domain = 'jaco'
dataset_dir = Path(
    "/data/your_entity_hdd/currmask_data/jaco/expert/reach_top_left/092834_ddpg_jaco_reach_top_left_0.5_states_7/buffer"
)
# Alternative settings for the walker domain, kept as an inert string literal.
# To use them, swap them with the four assignments above.
'''
mdp_times=["2023.05.11/095134_mdp","2023.05.11/095334_mdp","2023.05.11/095234_mdp","2023.05.15/003722_mdp","2023.05.15/003822_mdp","2023.05.15/003922_mdp","2023.05.15/004022_mdp","2023.05.14/235908_mdp","2023.05.15/000006_mdp"]
mask_types=["random" ,"AutoCurr2_mixed","Mixed_masking","random_fixed_seq_masking2-5","random_fixed_seq_masking2-10","random_fixed_seq_masking2-15","random_fixed_seq_masking2-20","MixedProg_4p","MixedInv_4p"]
domain ='walker'
dataset_dir = Path('/data/your_entity_hdd/currmask_data/walker/expert/run/125347_ddpg_walker_run_0.5_states_7/buffer')
'''
# Evaluation-time masking schemes and the mask length passed with each one
# (paired by position).
eval_mask_types = ['action_masking', 'prompt_masking', 'goal_masking']
eval_mask_lens = [
    16 + (128 - 32),
    128 - 8,
    128 - 2,
]
# Number of trajectories from the sampled batch whose attention is averaged.
num_of_traj = 10

def visualize_token2token_scores(Scores_mat, decoder_tokens, x_label_name='Head', y_label_name='Layer', figname='none', output_dir=None):
    """Save one token-to-token heat map per layer of ``Scores_mat`` as a PDF.

    Every layer is plotted as a single "head" and written to
    ``<output_dir>/<figname>_layer<L>_head<H>.pdf`` (both indices 1-based).

    Args:
        Scores_mat: Array-like with shape ``(n_layers, token_len, token_len)``.
        decoder_tokens: Tick labels used on both axes, one per token.
        x_label_name: Prefix of the x-axis label (followed by the head index).
        y_label_name: Prefix of the y-axis label (followed by the layer index).
        figname: Prefix of the generated file names.
        output_dir: Directory that receives the PDFs. Defaults to the
            project's ``final_maps/<domain>`` folder, using the module-level
            ``domain``. The directory is created when it does not exist.
    """
    # As in the original unpacking, the last dimension defines token_len.
    n_layers, _, token_len = Scores_mat.shape
    n_head = 1

    if output_dir is None:
        output_dir = '/home/your_entity/currmask/new_currmask/currmask_public/final_maps/' + domain
    # Saving used to fail with FileNotFoundError on a fresh checkout.
    os.makedirs(output_dir, exist_ok=True)

    tick_positions = range(len(decoder_tokens))

    for layer in range(n_layers):
        scores_mat = Scores_mat[layer].reshape(n_head, token_len, token_len)

        for idx in range(scores_mat.shape[0]):  # num of heads
            scores = scores_mat[idx]

            fig = plt.figure(figsize=(10, 10))
            try:
                ax = fig.add_subplot(111)
                im = ax.imshow(scores, cmap=cmap_segments)
                # Fixed highlight box anchored at (5, 5), 4 wide and 26 tall,
                # in data coordinates.
                rect = patches.Rectangle((5, 5), 4, 26, linewidth=2, edgecolor='k', facecolor='none')
                ax.add_patch(rect)

                ax.set_xlabel('{} {}'.format(x_label_name, idx+1))
                ax.set_xticks(tick_positions)
                ax.set_xticklabels(decoder_tokens, rotation=90)
                ax.set_ylabel('{} {}'.format(y_label_name, layer+1))
                ax.set_yticks(tick_positions)
                ax.set_yticklabels(decoder_tokens)
                fig.colorbar(im, fraction=0.046, pad=0.04)

                fig.subplots_adjust(wspace=0.4, hspace=0.4)

                pdf_path = os.path.join(
                    output_dir,
                    '{}_layer{}_head{}.pdf'.format(figname, layer+1, idx+1),
                )
                with matplotlib.backends.backend_pdf.PdfPages(pdf_path) as pdf:
                    pdf.savefig(fig)
            finally:
                # Always release the figure, even when plotting or saving
                # fails, so figures do not pile up in pyplot's registry.
                plt.close(fig)

def _label_decoder_tokens(token_mask):
    """Return one display label per position of a 1-D token mask.

    Positions whose mask value is 1 become '*'. The remaining positions
    alternate between states and actions: even index -> 's<k>', odd index ->
    'a<k>', with k counted from 1.
    """
    labels = []
    for pos, flag in enumerate(token_mask.tolist()):
        if int(flag) == 1:
            labels.append('*')  # [MASK]
        elif pos % 2 == 0:
            labels.append('s' + str((pos + 2) // 2))
        else:
            labels.append('a' + str((pos + 1) // 2))
    return labels


# Side length (in tokens) of the top-left attention window that is averaged,
# saved and plotted.
attn_window = 32

project_root = Path('/home/your_entity/currmask/new_currmask/currmask_public')
maps_root = project_root / 'final_maps'
# torch.save below writes into this folder; create it up front so a fresh
# checkout does not fail after the (expensive) forward pass.
(maps_root / domain).mkdir(parents=True, exist_ok=True)

# The configuration lists are consumed pairwise; fail loudly on a mismatch
# instead of silently dropping entries.
if len(mdp_times) != len(mask_types):
    raise ValueError('mdp_times and mask_types must have the same length')
if len(eval_mask_types) != len(eval_mask_lens):
    raise ValueError('eval_mask_types and eval_mask_lens must have the same length')

device = torch.device('cuda')
replay_dir = dataset_dir.resolve()

for eval_mask_type, eval_mask_len in zip(eval_mask_types, eval_mask_lens):
    for mdp_time, mask_type in zip(mdp_times, mask_types):
        run_dir = project_root / 'output' / 'final_mt' / mdp_time
        snapshot_dir = run_dir / 'snapshot' / domain / '1' / 'snapshot_300000.pt'
        config_dir = run_dir / '.hydra' / 'config.yaml'

        print('loading existing model...')

        payload = torch.load(snapshot_dir)
        with open(config_dir, "r") as f:
            # NOTE(review): FullLoader is only appropriate for trusted files;
            # consider yaml.safe_load if these configs can come from elsewhere.
            cfg = yaml.load(f, Loader=yaml.FullLoader)
        cfg = omegaconf.OmegaConf.create(cfg)
        env = dmc.make(cfg['task'], seed=cfg['seed'])
        print("task:",cfg['task'])

        # Rebuild the agent from the run's own config, then load its weights.
        agent = mdp.MaskedDPAgent(name=cfg['agent']['name'],
                            obs_shape=env.observation_spec().shape,
                            action_shape=env.action_spec().shape,
                            new_mask_ratio=cfg['mask_ratio'],
                            mask_type=cfg['mask_type'],
                            mask_len=cfg['mask_len'],
                            device=device,
                            lr=cfg['agent']['lr'],
                            batch_size=cfg['agent']['batch_size'],
                            use_tb=cfg['agent']['use_tb'],
                            mask_ratio=cfg['mask_ratio'],
                            transformer_cfg=cfg['agent']['transformer_cfg']
                            )
        agent.model.load_state_dict(payload['model'])
        # NOTE(review): the model is not switched to eval mode here; if it
        # contains dropout the attention maps may be stochastic — confirm
        # whether agent.model.eval() should be called.

        # NOTE(review): the literal 'walker' is passed even when `domain` is
        # 'jaco'; kept as-is because the loader's use of this argument is not
        # visible here — confirm whether it should be `domain`.
        train_loader = make_replay_loader(env, replay_dir, cfg.replay_buffer_size,
                                            cfg.batch_size,
                                            cfg.replay_buffer_num_workers,
                                            cfg.discount,
                                            'walker',
                                            cfg.agent.transformer_cfg.traj_length,
                                            relabel=False)
        batch = next(iter(train_loader))

        # Evaluate one batch. No gradients are needed, so skip building the
        # autograd graph.
        states, actions, _, _, _, _ = utils.to_torch(batch, device)
        with torch.no_grad():
            latent, mask, ids_restore = agent.model.forward_encoder(
                states, actions, mask_ratio=0.1, mask_type=eval_mask_type,
                mask_len=eval_mask_len, current_step=1, total_step=1)
            # The predictions themselves are unused; the call is kept because
            # the attention maps and decoder input read below are presumably
            # populated by this pass — TODO confirm.
            agent.model.forward_decoder(latent, ids_restore)

        decoder_blocks = agent.model.decoder_blocks
        # att_map is unpacked as (batch, heads, seq, seq).
        batch_size, n_head, _, dec_seq_len = decoder_blocks[0].attn.att_map.shape
        if not 1 <= num_of_traj <= batch_size:
            raise ValueError(
                'num_of_traj must be between 1 and the batch size ({}), got {}'.format(
                    batch_size, num_of_traj))
        window = min(attn_window, dec_seq_len)

        # Per layer: L2 norm across heads of the top-left window, averaged
        # over the first `num_of_traj` trajectories of the batch. Layer count
        # and window size come from the model instead of hard-coded 2/4/128.
        final_decoder_attn_blk = torch.zeros(len(decoder_blocks), window, window)
        for traj in range(num_of_traj):
            per_layer = torch.stack([
                blk.attn.att_map[traj, :, :window, :window].detach().cpu().float()
                for blk in decoder_blocks
            ])  # (n_layers, n_head, window, window)
            final_decoder_attn_blk += norm_fn(per_layer, dim=1)
        final_decoder_attn_blk /= num_of_traj

        # Axis labels are derived from the first trajectory's mask only.
        # NOTE(review): this assumes the mask is identical across the batch
        # for these evaluation mask types — confirm.
        n_tokens = agent.model.decoder_input[0].shape[0]
        final_decoder_tokens = _label_decoder_tokens(mask[0][:n_tokens])

        figname = mask_type + '_' + eval_mask_type + '_' + domain
        attn_name = str(maps_root / domain / (figname + '.pth'))
        torch.save(final_decoder_attn_blk, attn_name)

        # One label per line; this file is rewritten for every run that
        # shares the same evaluation mask type.
        token_path = str(maps_root / (eval_mask_type + '.txt'))
        with open(token_path, 'w') as file:
            file.writelines(label + '\n' for label in final_decoder_tokens[:window])

        visualize_token2token_scores(final_decoder_attn_blk,
                                     decoder_tokens=final_decoder_tokens[:window],
                                     x_label_name='Head', figname=figname)

