###############
#   Package   #
###############
import os
import time
import math
import argparse
import importlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import List, Dict, Tuple, Optional
from torch import Tensor

#############################
#   Project-local imports   #
#############################
from datasets.ehr_dataset import *
from utils.util import *

#################
#   Functions   #
#################
def GetDevice(device_id: int = 0) -> torch.device:
    """Return the device to run on, preferring CUDA when it is available.

    Args:
        device_id: index of the CUDA device to make current. Ignored when
            CUDA is not available.

    Returns:
        ``torch.device('cuda')`` (with ``device_id`` set as the current CUDA
        device) when CUDA is available, otherwise ``torch.device('cpu')``.
    """
    # Check availability first: selecting a CUDA device on a machine without
    # CUDA raises, which would make the CPU fallback unreachable.
    if not torch.cuda.is_available():
        return torch.device('cpu')

    torch.cuda.set_device(device_id)
    return torch.device('cuda')

#############
#   Class   #
#############
class Visualizer():
    """Run a trained model on a test set and plot rollout-attention heatmaps.

    For every row of ``pid_label_df`` one folder is created under
    ``plot_save_path`` and an ``all_attn.png`` heatmap is written into it.
    The heatmap shows the last rolled-out attention layer averaged over its
    first axis and reshaped to (summarization steps, features).
    """

    # Number of patients rolled out per call; keeps the working set on the
    # compute device bounded.
    _ROLLOUT_CHUNK = 1000

    def __init__(self,
                 model: torch.nn.Module,
                 testing_dataloader: torch.utils.data.DataLoader,
                 pid_label_df: pd.DataFrame,
                 resume_model_path: str = None,
                 plot_save_path: str = '',
                 device: torch.device = torch.device("cpu"),
                 *args,
                 **kwargs,
                ):
        """Load the checkpoint into ``model`` and store the plotting context.

        Args:
            model: network whose weights are restored from ``resume_model_path``.
            testing_dataloader: iterable of test batches (see ``_inference_epoch``
                for the expected batch layout).
            pid_label_df: table with a ``pid`` and a ``group`` column, one row
                per plotted patient.
            resume_model_path: path of the checkpoint (.pth) file.
            plot_save_path: existing directory that receives the figures.
            device: device used for inference and for the rollout computation.
            *args, **kwargs: accepted for call compatibility and ignored.

        Raises:
            FileNotFoundError: if ``resume_model_path`` is missing or not a file.
            NotADirectoryError: if ``plot_save_path`` is not an existing directory.
        """
        # Explicit exceptions instead of ``assert`` so the checks survive ``python -O``
        # and a missing path (None) gives a readable error.
        if resume_model_path is None or not os.path.isfile(resume_model_path):
            raise FileNotFoundError('resumed model file error. ({!r})'.format(resume_model_path))
        if not os.path.isdir(plot_save_path):
            raise NotADirectoryError('plot saving dictionary does not exist. ({!r})'.format(plot_save_path))

        # define the variables of the tester
        self.model = model
        # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
        self.model.load_state_dict(torch.load(resume_model_path, map_location=device))
        self.testing_dataloader = testing_dataloader
        self.pid_label_df = pid_label_df

        self.plot_save_path = plot_save_path

        self.device = device

        # numerical value
        NUM_COLS = ["AFP", "ALB", "ALP", "ALT", "AST", "BUN", "CRE", "D_BIL", "GGT", "GlucoseAC", "HB",
                    "HBVDNA", "HCVRNA", "HbA1c", "Lym", "Na", "PLT", "PT", "PT_INR", "Seg", "T_BIL", "TP", "WBC"
                   ]
        NUM_COLS += ["HEIGHT", "WEIGHT", "fatty_liver", "parenchymal_liver_disease", "age", "hosp_days", "seg_entry_cnt"]
        # categorical value
        CAT_COLS = ["Anti_HBc", "Anti_HBe", "Anti_HBs", "Anti_HCV", "HBeAg", "HBsAg"]
        CAT_COLS += ["sex", "sono"]
        self.feature = NUM_COLS + CAT_COLS

        # Feature names repeated once per summarization step ("_1" .. "_4").
        self.feature_152 = [name + '_{:d}'.format(step) for step in range(1, 5) for name in self.feature]

    @staticmethod
    def _check_attn_maps(attn_maps: Tensor) -> None:
        """Validate that ``attn_maps`` is 4-D with square trailing matrices."""
        if attn_maps.dim() != 4:
            raise ValueError("the number of the dimension of the input attention maps should be 4.")
        if attn_maps.size(-2) != attn_maps.size(-1):
            raise ValueError("the number of the last two dimension should be equal.")

    @staticmethod
    def _descending_rank(array: np.ndarray) -> np.ndarray:
        """Rank values from largest (rank 1) to smallest.

        A 1-D input yields an array of shape (1, n). For higher-dimensional
        input the ranking is applied to each slice along the first axis and
        the results are concatenated along axis 0.
        """
        if len(array.shape) == 1:
            order = array.argsort()[::-1]  # indices from largest to smallest
            rank = np.zeros_like(order)
            rank[order] = np.arange(1, len(order) + 1)
            return np.expand_dims(rank, axis=0)
        return np.concatenate([Visualizer._descending_rank(row) for row in array], axis=0)

    @staticmethod
    def _concat_batches(seed: Tensor, batches: list, dim: int = 0) -> Tensor:
        """Concatenate ``batches`` after ``seed`` in a single call.

        ``seed`` is an empty typed tensor; keeping it in the concatenation
        preserves the dtype promotion of incremental ``torch.cat`` calls.
        When there are no batches the seed itself is returned.
        """
        if not batches:
            return seed
        return torch.cat([seed] + batches, dim=dim)

    @staticmethod
    def RolloutAttention(attn_maps: Tensor, device: torch.device = torch.device("cpu")) -> Tensor:
        """Compute attention rollout for every patient and layer.

        Starting from the identity, each layer multiplies the running product
        by ``0.5 * attention + 0.5 * identity``.

        Args:
            attn_maps: tensor indexed as (layer, patient, seq_len, seq_len).
            device: device on which the products are computed.

        Returns:
            CPU tensor of shape (patient, layer, seq_len, seq_len) holding the
            rolled-out map after each layer.

        Raises:
            ValueError: if ``attn_maps`` is not 4-D or its last two sizes differ.
        """
        Visualizer._check_attn_maps(attn_maps)
        with torch.no_grad():
            attn_maps = attn_maps.to(device)
            n_layers, n_patients = attn_maps.shape[0], attn_maps.shape[1]
            seq_len = attn_maps.size(-1)
            identity_matrix = torch.eye(seq_len).float().to(device)

            # Preallocate the result instead of growing it with torch.cat,
            # which re-copies everything accumulated so far on every step.
            rolled_maps = torch.empty((n_patients, n_layers, seq_len, seq_len),
                                      dtype=torch.float32, device=device)

            for patient in range(n_patients):
                A = identity_matrix
                for layer in range(n_layers):
                    A = torch.matmul((0.5 * attn_maps[layer, patient, :, :] + 0.5 * identity_matrix), A)
                    rolled_maps[patient, layer] = A

            rolled_maps = rolled_maps.cpu()
            torch.cuda.empty_cache()

        return rolled_maps

    @staticmethod
    def MaskedRolloutAttention(attn_maps: Tensor, masks: Tensor, device: torch.device = torch.device("cpu")) -> Tensor:
        """Compute attention rollout whose first layer uses a masked identity.

        For the first layer the identity is multiplied column-wise by the
        flattened mask, added to the raw attention, and every row whose L1
        norm exceeds 1 is rescaled to norm 1. Later layers use
        ``0.5 * attention + 0.5 * identity`` as in ``RolloutAttention``.

        Args:
            attn_maps: tensor indexed as (layer, patient, seq_len, seq_len).
            masks: per-patient masks; ``masks[patient]`` is flattened and must
                have ``seq_len`` elements.
            device: device on which the products are computed.

        Returns:
            CPU tensor of shape (patient, layer, seq_len, seq_len).

        Raises:
            ValueError: if ``attn_maps`` is not 4-D or its last two sizes differ.
        """
        Visualizer._check_attn_maps(attn_maps)
        with torch.no_grad():
            attn_maps = attn_maps.to(device)
            masks = masks.to(device)
            n_layers, n_patients = attn_maps.shape[0], attn_maps.shape[1]
            seq_len = attn_maps.size(-1)
            identity_matrix = torch.eye(seq_len).float().to(device)

            def _masked_raw_attn(attn_map: Tensor, mask: Tensor) -> Tensor:
                """
                    Apply this function on the first layer.

                    Input:
                        attn_map: (seq_len, seq_len)
                        mask: (summarization times, number of features)
                """
                flatten_mask = torch.flatten(mask)
                masked_identity = flatten_mask.unsqueeze(0) * identity_matrix
                return torch.renorm((masked_identity + attn_map), p=1, dim=0, maxnorm=1)

            rolled_maps = torch.empty((n_patients, n_layers, seq_len, seq_len),
                                      dtype=torch.float32, device=device)

            for patient in range(n_patients):
                A = _masked_raw_attn(attn_maps[0, patient, :, :], masks[patient])
                # NOTE(review): the stored first-layer entry is the raw attention,
                # not the masked map ``A`` used for the products below. Kept as in
                # the original behavior - confirm this is intended.
                rolled_maps[patient, 0] = attn_maps[0, patient, :, :]

                for layer in range(1, n_layers):
                    A = torch.matmul((0.5 * attn_maps[layer, patient, :, :] + 0.5 * identity_matrix), A)
                    rolled_maps[patient, layer] = A

            rolled_maps = rolled_maps.cpu()
            torch.cuda.empty_cache()

        return rolled_maps

    def _attn_map_plot(self, attn_maps: Tensor, predict_prob: list, masks: Tensor, masked_rollout: bool = False) -> None:
        """Roll out the attention maps and save one heatmap per patient.

        Args:
            attn_maps: raw attention, indexed as (layer, patient, seq_len, seq_len).
            predict_prob: model outputs, one entry per patient; used only to
                build the output folder name.
            masks: per-patient masks, used when ``masked_rollout`` is true.
            masked_rollout: choose ``MaskedRolloutAttention`` over
                ``RolloutAttention``.

        Raises:
            ValueError: if the sequence length is not a positive multiple of
                the number of features.
        """
        #plt.rcParams["font.family"] = "Times New Roman"
        n_patients = attn_maps.shape[1]
        rolled_chunks = []
        # Step by chunk so an exact multiple of the chunk size does not produce
        # a trailing empty slice; max(..., 1) keeps one (empty) chunk when
        # there are no patients so the concatenation below stays valid.
        for start in range(0, max(n_patients, 1), self._ROLLOUT_CHUNK):
            stop = start + self._ROLLOUT_CHUNK
            if masked_rollout:
                rolled_chunks.append(self.MaskedRolloutAttention(attn_maps[:, start:stop, :, :], masks[start:stop, :, :], self.device))
            else:
                rolled_chunks.append(self.RolloutAttention(attn_maps[:, start:stop, :, :], self.device))
        attn_maps = torch.cat(rolled_chunks, dim=0)

        feature = np.array(self.feature)
        n_features = len(self.feature)
        pid_list = self.pid_label_df['pid'].tolist()
        label = self.pid_label_df['group'].tolist()
        for idx, (pid, group) in enumerate(zip(pid_list, label)):
            # Folder name: <pid>_<label>_<characters 2..7 of the printed probability>.
            # "{}" also accepts non-string pids (e.g. integers parsed from CSV).
            folder_name = "{}_{:d}_{:s}".format(pid, group, str(predict_prob[idx])[2:8])
            fig_path = os.path.join(self.plot_save_path, folder_name)
            os.makedirs(fig_path, exist_ok=True)
            sub_attn_map = attn_maps[idx, :, :, :].cpu().numpy()

            # Last rolled-out layer, averaged over its first axis.
            ave_attn_map = np.mean(sub_attn_map[-1, :, :], axis=0)
            n_steps, remainder = divmod(ave_attn_map.shape[0], n_features)
            if n_steps == 0 or remainder != 0:
                raise ValueError("sequence length {:d} is not a positive multiple of the number of features ({:d}).".format(ave_attn_map.shape[0], n_features))

            # Rank over all (step, feature) cells, shown as the cell annotation.
            global_rank = self._descending_rank(ave_attn_map).reshape((n_steps, n_features))
            ave_attn_map = ave_attn_map.reshape((n_steps, n_features))
            x_tick_pos = np.arange(n_features) + 0.5
            y_tick_pos = np.arange(n_steps) + 0.5

            # Order the columns by the mean of their per-step ranks.
            rank = np.argsort(np.mean(self._descending_rank(ave_attn_map), axis=0))
            ave_attn_map = ave_attn_map[:, rank]
            fig = plt.figure(figsize=(8, 6), dpi=480)
            hp = sns.heatmap(ave_attn_map, cmap="gray", center=0.01,  annot=global_rank[:, rank], square=True, robust=True, annot_kws={"fontsize":4}, fmt="g", cbar_kws={"shrink":0.12})
            cbar = hp.collections[0].colorbar
            cbar.ax.tick_params(labelsize=4)
            plt.xlabel("features", fontsize=5, fontweight="bold")
            plt.ylabel("the order of summarization", fontsize=4, fontweight="bold")
            plt.title("Attention of all features, Revised Rollout Attention", fontsize=7, fontweight="bold")
            plt.xticks(x_tick_pos, labels=feature[rank], fontsize=4, rotation=35,  ha='right', rotation_mode="anchor")
            plt.yticks(y_tick_pos, labels=[str(step + 1) for step in range(n_steps)], fontsize=4)
            fig.savefig(os.path.join(fig_path, "all_attn.png"))
            plt.close(fig)

    def _inference_epoch(self) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        """Run the model over the whole test dataloader without gradients.

        Each batch is unpacked as ``(x_num_idx, x_num, x_num_mask, x_cat_idx,
        x_cat, x_cat_mask, target, day_delta)``; ``day_delta`` is not used.

        Returns:
            ``(outputs, targets, attn_maps, masks)``. Outputs and targets are
            concatenated along dim 0 on ``self.device``; attention maps are
            concatenated along dim 1 on the CPU; masks are the numeric and
            categorical masks joined along dim 2, concatenated along dim 0.
        """
        with torch.no_grad():
            self.model.eval()
            self.model.to(self.device)

            # Collect per-batch results and concatenate once at the end rather
            # than re-copying the growing tensors on every batch.
            output_batches, target_batches, attn_batches, mask_batches = [], [], [], []

            for batch_idx, (x_num_idx, x_num, x_num_mask, x_cat_idx, x_cat, x_cat_mask, target, day_delta) in enumerate(self.testing_dataloader):
                # the combined mask is built before the inputs are moved to the device
                mask = torch.cat([x_num_mask, x_cat_mask], dim=2)

                # put all variables to the appropriate device
                x_num_idx = x_num_idx.to(self.device)
                x_num = x_num.to(self.device)
                x_num_mask = x_num_mask.to(self.device)
                x_cat_idx = x_cat_idx.to(self.device)
                x_cat = x_cat.to(self.device)
                x_cat_mask = x_cat_mask.to(self.device)
                target = target.to(self.device)

                # feed the data to the model and get the output
                output, attn_map = self.model(x_num_idx, x_num, x_num_mask, x_cat_idx, x_cat, x_cat_mask)

                # collect the target and output
                mask_batches.append(mask)
                output_batches.append(output)
                target_batches.append(target)
                attn_batches.append(attn_map.cpu())

            outputs = self._concat_batches(torch.FloatTensor().to(self.device), output_batches)
            targets = self._concat_batches(torch.FloatTensor().to(self.device), target_batches)
            attn_maps = self._concat_batches(torch.FloatTensor(), attn_batches, dim=1)
            masks = self._concat_batches(torch.LongTensor(), mask_batches)

        return (outputs, targets, attn_maps, masks)

    def plot(self, masked_rollout: bool = False) -> None:
        """Run inference on the test set and save the attention heatmaps.

        Args:
            masked_rollout: use the masked rollout variant when true.
        """
        # get the output and target
        testing_outputs, testing_targets, testing_attn_maps, testing_masks = self._inference_epoch()

        predict_prob = testing_outputs.squeeze().cpu().numpy().tolist()
        # A single sample squeezes down to a scalar; keep a list so it can be indexed.
        if not isinstance(predict_prob, list):
            predict_prob = [predict_prob]

        self._attn_map_plot(testing_attn_maps,
                            predict_prob,
                            testing_masks,
                            masked_rollout,
                           )


if __name__ == '__main__':
    # Script entry point: read the config, build the test dataset, dataloader
    # and model it describes, then plot the attention maps of the resumed model.

    def _str2bool(value: str) -> bool:
        """Parse a command-line boolean.

        ``type=bool`` would treat every non-empty string, including "False",
        as True, so the flag could never be switched off.
        """
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))

    # variables from command
    parser = argparse.ArgumentParser()
    parser.add_argument('-c', '--config', type=str, help='the path of config(.json) file.', required=True)
    parser.add_argument('-r', '--resume', type=str, help='checkpoint path of the model.', required=True)
    parser.add_argument('-p', '--path', type=str, help='the path to save attnetion map', required=True)
    parser.add_argument('-t', '--rollout_type', type=_str2bool, help='the type of rollout attention (masked or not)', default=True)
    command_variable = parser.parse_args()

    # device
    device = GetDevice()

    # get config
    # parser.error prints the usage and exits; the previous assert used
    # print() as its message, which evaluates to None.
    if not os.path.isfile(command_variable.config):
        parser.error('config file is not exist.')
    config = GetConfigDict(command_variable.config)

    # check plot_save path
    os.makedirs(command_variable.path, exist_ok=True)

    # get dataset
    dataset_module = importlib.import_module(name=config['dataset']['module'])
    dataset_factory = getattr(dataset_module, config['dataset']['type'])
    dataset_list = dataset_factory(data_dir=config['dataset']['data_path'],
                                   base_statistic_info_kwargs=config['dataset']['BasicStatisticInfo'],
                                   **config['dataset']['test']['GetDataset'])
    testing_dataset = dataset_list[0][0]

    # create dataloader
    testing_dataloader = DataLoader(testing_dataset, **config['dataloader']['test'])

    # create model
    model_module = importlib.import_module(name=config['model']['module'])
    model = getattr(model_module, config['model']['type'])(**config['model']['kwargs'])

    # get pid_label_df
    pid_label_df = pd.read_csv(os.path.join(config['dataset']['data_path'], 'test', 'y.csv'))

    # create tester
    tester = Visualizer(model = model,
                        testing_dataloader = testing_dataloader,
                        pid_label_df = pid_label_df,
                        resume_model_path = command_variable.resume,
                        plot_save_path = command_variable.path,
                        device = device
                        )
    tester.plot(command_variable.rollout_type)
