""" Explanation results

reference:
[CVPR 2024] Comparing the Decision-Making Mechanisms by Transformers and CNNs via Explanation Methods
https://mingqij.github.io/projects/cdmmtc/
[github] Advanced AI explainability for PyTorch
https://github.com/jacobgil/pytorch-grad-cam

Copyright (c) 2025 Anonymous Authors
"""
import os
import contextlib
import cv2 # pip install opencv-python (check requirements.txt)
import time
from copy import deepcopy
import pygraphviz as pgv # pip install pygraphviz (check requirements.txt)
import numpy as np
from functools import partial
from pytorch_grad_cam import (
    GradCAM, FEM, HiResCAM, ScoreCAM, GradCAMPlusPlus,
    AblationCAM, XGradCAM, EigenCAM, EigenGradCAM,
    LayerCAM, FullGrad, GradCAMElementWise, KPCA_CAM, ShapleyCAM,
    FinerCAM
)
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

import torch

from .sag import get_topn_categories_probabilities_pairs, save_perturbation_heatmap, create_minsufexp_gif
from .sag import Get_blurred_img, Integrated_Mask, Deletion_Insertion_Comb_withOverlay, showimage
from .sag import beamSearch_topKSuccessors_roots
from .sag import maximal_overlapThresh_set
from .sag import get_patch_boolean, gridimage, get_conjuncts_set, build_tree


# Registry of gradient-based CAM methods: maps the method name received by
# get_explanation_by_gradient() (its `gradient_base_model_explanation`
# argument) to the pytorch_grad_cam class that implements it.
# Entries that are commented out are not selectable; requesting one of them
# makes get_explanation_by_gradient() raise "Unknown CAM method".
cam_registry = {
    "gradcam": GradCAM,
    "hirescam": HiResCAM,
    # "scorecam": ScoreCAM,
    "gradcam++": GradCAMPlusPlus,
    # "ablationcam": AblationCAM,
    "xgradcam": XGradCAM,
    "eigencam": EigenCAM,
    "eigengradcam": EigenGradCAM,
    "layercam": LayerCAM,
    # "fullgrad": FullGrad,
    "fem": FEM,
    "gradcamelementwise": GradCAMElementWise,
    'kpcacam': KPCA_CAM,
    # 'shapleycam': ShapleyCAM,
    # 'finercam': FinerCAM
}


def _compose_sag_figure(dot_path, dag_png_path, final_png_path, img_tensor, pad=50):
    """Render the SAG ``.dot`` file and save it next to the input image.

    The graph in ``dot_path`` is laid out with Graphviz ``dot`` and drawn to
    ``dag_png_path``. That picture is then concatenated (left: input image on a
    black square canvas, right: tree) and written to ``final_png_path``.

    Args:
        dot_path: path of the ``.dot`` file to render.
        dag_png_path: where the rendered tree image is written.
        final_png_path: where the combined image is written.
        img_tensor: image tensor; squeezed on dim 0 and transposed from
            channels-first to channels-last before being pasted.
        pad: offset (in pixels) of the pasted image from the top-left corner.

    Raises:
        RuntimeError: if the rendered tree image cannot be read back.
    """
    # stderr is silenced for the whole step, as in the original code
    # (presumably to hide Graphviz layout warnings -- TODO confirm).
    with open(os.devnull, 'w') as devnull, contextlib.redirect_stderr(devnull):
        graph = pgv.AGraph(dot_path)
        graph.layout(prog='dot')
        graph.draw(dag_png_path)
        graph.close()

        img_tree = cv2.imread(dag_png_path)
        if img_tree is None:
            raise RuntimeError(f"could not read rendered SAG image: {dag_png_path}")
        tree_height, _, channels = img_tree.shape

        tmp_img_ori = img_tensor.squeeze(0).cpu().numpy().transpose(1, 2, 0)
        tmp_img_ori = cv2.cvtColor(tmp_img_ori, cv2.COLOR_RGB2BGR)

        # square black canvas as tall as the tree image
        # NOTE(review): the image values are pasted without rescaling while the
        # tree image read by cv2 is 8-bit -- confirm the intended value range.
        img_ori_padded = np.zeros((tree_height, tree_height, channels))
        try:
            img_ori_padded[pad:pad + tmp_img_ori.shape[0], pad:pad + tmp_img_ori.shape[1], :] = tmp_img_ori
        except ValueError:
            # the image does not fit on the canvas; keep the canvas black (best effort)
            print(f"failed img_ori_padded.shape {img_ori_padded.shape} tmp_img_ori off {pad} tmp_img_ori.shape {tmp_img_ori.shape}")

        # concatenate image and explanation tree, then save the result
        img_sag_final = np.concatenate((img_ori_padded, img_tree), axis=1)
        cv2.imwrite(final_png_path, img_sag_final)


def get_explanation_by_perturbation(perturbation_base_model_explanation,
                                    input_img,
                                    img_label,
                                    model,
                                    amp_autocast,
                                    reduce_factor,
                                    ups,
                                    prob_thresh,
                                    numCategories,
                                    node_prob_thresh,
                                    beam_width,
                                    max_num_roots,
                                    overlap_thresh,
                                    numSuccessors,
                                    num_roots_sag,
                                    maxRootSize,
                                    output_path,
                                    batch_idx):
    """Generate and save perturbation-based (SAG) explanations for one input.

    For each of the top ``numCategories`` predicted categories the function
    computes a perturbation mask, searches for root masks with beam search,
    keeps a set filtered by ``overlap_thresh``, writes per-root result files
    (heatmaps, deletion/insertion figures, conjunction text), writes the DNF
    expression and a GIF, builds the patch-deletion tree and finally saves the
    tree rendered next to the input image. Nothing is returned; all results
    are written below ``output_path``.

    Args:
        perturbation_base_model_explanation: falsy value disables the function.
        input_img: input image tensor; its device decides ``use_cuda``.
        img_label: label tensor; ``img_label.item()`` is used in file names.
        model: model being explained (passed through to the ``sag`` helpers).
        amp_autocast: autocast context factory passed through to the helpers.
        reduce_factor: passed through to the ``sag`` helpers.
        ups: passed to ``Integrated_Mask`` and ``build_tree``.
        prob_thresh: passed to the beam search; also part of the directory name.
        numCategories: number of top predicted categories to explain.
        node_prob_thresh: passed to ``build_tree``.
        beam_width: beam width of the root search.
        max_num_roots: passed to the beam search.
        overlap_thresh: threshold passed to ``maximal_overlapThresh_set``.
        numSuccessors: passed to the beam search.
        num_roots_sag: maximum number of roots kept for the SAG.
        maxRootSize: passed to the beam search and the deletion/insertion step.
        output_path: non-empty output directory (a trailing '/' is added).
        batch_idx: used as file-name prefix.
    """
    if not perturbation_base_model_explanation:
        return

    if output_path[-1] != '/':
        output_path = output_path + '/'

    # the sag helpers take an int flag
    use_cuda = 1 if input_img.device.type == 'cuda' else 0
    imgprefix = str(batch_idx)

    # get low probability blurred image
    img, blurred_img = Get_blurred_img(
        input_img,
        img_label,
        model,
        resize_shape=(224, 224),
        Gaussian_param=[51, 50],
        Median_param=11,
        blur_type='Black',
        use_cuda=use_cuda
    )

    # get top "numCategories" predicted categories with their probabilities
    top_cp = get_topn_categories_probabilities_pairs(img, model, numCategories, amp_autocast, reduce_factor, use_cuda=use_cuda)

    for category, probability in top_cp:
        # get perturbation mask
        start = time.time()
        mask, upsampled_mask = Integrated_Mask(
            ups,
            img,
            blurred_img,
            model,
            category,
            amp_autocast,
            reduce_factor,
            max_iterations=2,
            integ_iter=20,
            tv_beta=2,
            l1_coeff=0.01 * 100,
            tv_coeff=0.2 * 100,
            size_init=28,
            use_cuda=use_cuda
        )

        # get all DISTINCT roots found via beam search
        roots_mp = beamSearch_topKSuccessors_roots(mask, beam_width, numSuccessors, img, blurred_img, model, category,
                                                   prob_thresh, probability, max_num_roots, maxRootSize,
                                                   amp_autocast, reduce_factor, use_cuda=use_cuda)

        numRoots = len(roots_mp)
        print('numRoots_all = ', numRoots)
        # get maximal set of non-overlapping roots
        maximal_Overlap_mp = maximal_overlapThresh_set(roots_mp, overlap_thresh) if numRoots > 0 else []
        numRoots_Overlap = len(maximal_Overlap_mp)
        print('numRoots_Overlap = ', numRoots_Overlap)
        if numRoots_Overlap == 0:
            continue

        # prune number of roots to be shown in the sag
        if numRoots_Overlap > num_roots_sag:
            maximal_Overlap_mp = maximal_Overlap_mp[:num_roots_sag]
            numRoots_Overlap = num_roots_sag

        # deletion insertion on filtered set of masks - just to generate result figures
        category_name = str(img_label.item())
        conjunctions = []
        # The directory name contains the elapsed time. It is fixed after the
        # first root so that all roots of this category share one directory
        # with dnf.txt, the GIF and the SAG files.
        output_path_img = None
        for root_mask, ins_prob, rel_prob in maximal_Overlap_mp:
            _, _, minsufexpmask_upsampled, showimg_buffer = Deletion_Insertion_Comb_withOverlay(
                maxRootSize,
                root_mask,
                model,
                imgprefix + '_',
                img,
                blurred_img,
                category=category,
                line_i=category_name,
                use_cuda=use_cuda,
                blur_mask=0,
                outputfig=1)

            if output_path_img is None:
                time_taken = int(time.time() - start)
                output_path_img = (output_path + imgprefix + "_timetaken_" + str(time_taken)
                                   + "_category_" + category_name + "_probthresh_" + str(prob_thresh) + "/")
                os.makedirs(output_path_img, exist_ok=True)
                # the perturbation heatmap is the same for every root: save it once
                save_perturbation_heatmap(output_path_img + 'perturbation_', upsampled_mask, img * 255, blurred_img, blur_mask=0)

            # str() is kept on purpose: it defines the directory names
            output_path_count = output_path_img + '_insprob_' + str(ins_prob) + '_relprob_' + str(rel_prob) + "/"
            outvideo_path = output_path_count + 'VIDEO/'
            os.makedirs(outvideo_path, exist_ok=True)

            # create the conjunction of this root for the MDNF expression
            conjunction = ' & '.join('P' + str(b) for b in get_patch_boolean(root_mask))
            conjunctions.append(conjunction)

            # unpack result images
            insertion_img = None
            for deletion_img, insertion_img, del_curve, insert_curve, out_pathx, xtick, line_i in showimg_buffer:
                showimage(deletion_img, insertion_img, del_curve, insert_curve, outvideo_path + out_pathx, xtick, line_i)

            # save perturbation minsufexp heatmaps
            save_perturbation_heatmap(output_path_count + imgprefix + '_perturbationminsufexp_',
                                      minsufexpmask_upsampled, img * 255, blurred_img, blur_mask=0)
            # the last insertion image of the buffer is saved (skipped if the buffer is empty)
            if insertion_img is not None:
                insertion_img = cv2.cvtColor(insertion_img, cv2.COLOR_RGB2BGR)
                cv2.imwrite(output_path_count + imgprefix + 'InsertionImg.png', insertion_img * 255)

            # write root conjunctions
            with open(output_path_count + imgprefix + 'conjunction.txt', 'w') as conjunction_file:
                conjunction_file.write(conjunction)

        # write MDNF expression
        dnf = ' | '.join(conjunctions)
        with open(output_path_img + 'dnf.txt', 'w') as dnf_file:
            dnf_file.write(dnf)

        # save SAG roots as a GIF
        create_minsufexp_gif(output_path_img)

        ## build patch deletion tree ##

        # create patchImages folder if not exists
        current_patchImages_path = output_path_img + 'SAG_PatchImages_' + str(numRoots_Overlap) + 'roots'
        os.makedirs(current_patchImages_path, exist_ok=True)

        # create and save grid image
        # NOTE(review): (img + 1) / 2 suggests values in [-1, 1] while the heatmaps
        # use img * 255 -- confirm the value range of `img`.
        grid_img = (img.cpu().squeeze(0).numpy().transpose(1, 2, 0) + 1) / 2
        gridimage(grid_img, output_path_img + 'gridimage.png')

        # get set of conjunctions from DNF expression
        conjuncts = get_conjuncts_set(dnf)

        # prune conjunctions to required number of roots in SAG
        if len(conjuncts) > numRoots_Overlap:
            conjuncts = conjuncts[:numRoots_Overlap]

        # build tree
        img_ori = img.clone()
        sag_tree = build_tree(conjuncts, ups, img_ori, blurred_img, model, category, current_patchImages_path, node_prob_thresh, probability, amp_autocast, reduce_factor)

        # book-keeping and save generated result files
        dot_path = output_path_img + 'SAG_' + str(numRoots_Overlap) + 'roots.dot'
        sag_tree.write(dot_path)
        # TODO : adjust node size
        # `img` must stay the image tensor here: it is needed by the next category.
        _compose_sag_figure(
            dot_path,
            output_path_img + 'SAG_dag_' + str(numRoots_Overlap) + 'roots.png',
            output_path_img + 'SAG_final_' + str(numRoots_Overlap) + 'roots.png',
            img_ori,
        )


def reshape_transform(tensor, pre_token=0, ratio_h=256, ratio_w=256):
    """Reshape a token sequence into a channels-first 2-D feature map.

    The first ``pre_token`` entries along dim 1 are dropped. The remaining
    tokens are arranged on a ``height x width`` grid whose aspect ratio equals
    ``ratio_h : ratio_w``, and the channel dim is moved to position 1.

    Args:
        tensor: 3-D tensor indexed as (batch, tokens, channels).
        pre_token: number of leading non-spatial tokens to skip.
        ratio_h: reference height fixing the grid aspect ratio.
        ratio_w: reference width fixing the grid aspect ratio.

    Returns:
        Tensor of shape (batch, channels, height, width).

    Raises:
        RuntimeError: raised by ``reshape`` when the spatial tokens do not
            fill a grid with the requested aspect ratio.
    """
    num_spatial_tokens = tensor.shape[1] - pre_token
    scale = (num_spatial_tokens / (ratio_h * ratio_w)) ** 0.5
    # round instead of truncating: the square root can land a hair below the
    # exact integer (e.g. 15.999999999999998) and int() would then give a grid
    # that is one row/column too small.
    height = round(ratio_h * scale)
    width = round(ratio_w * scale)
    grid = tensor[:, pre_token:, :].reshape(tensor.size(0), height, width, tensor.size(2))
    return grid.permute(0, 3, 1, 2)


def get_explanation_by_gradient(gradient_base_model_explanation,
                                input_tensor,
                                targets,
                                model,
                                target_layers,
                                amp_autocast,
                                mean,
                                std,
                                output_path,
                                batch_idx,
                                aug_smooth=False,
                                eigen_smooth=False,
                                return_instead_save=False):
    """Compute CAM explanations for a batch and save (or return) them.

    Args:
        gradient_base_model_explanation: key of ``cam_registry`` selecting the
            CAM method; an empty/falsy value disables the function.
        input_tensor: normalised input batch, indexed (batch, channel, H, W).
        targets: CAM targets passed to the CAM object.
        model: model being explained. If it has non-None ``reg_token`` /
            ``class_token`` attributes, those tokens are skipped when the
            activations are reshaped.
        target_layers: layers passed to the CAM object.
        amp_autocast: autocast context factory.
        mean: per-channel mean used to undo the input normalisation.
        std: per-channel std used to undo the input normalisation.
        output_path: output directory; images go to ``<output_path>/cam``.
        batch_idx: batch index used in the file names.
        aug_smooth: forwarded to the CAM call.
        eigen_smooth: forwarded to the CAM call.
        return_instead_save: if True, return the grayscale CAMs instead of
            writing images.

    Returns:
        The grayscale CAMs when ``return_instead_save`` is True, else None.

    Raises:
        ValueError: if the requested CAM method is not in ``cam_registry``.
    """
    # same "disabled" convention as get_explanation_by_perturbation
    if not gradient_base_model_explanation:
        return

    # gradient_base_model_explanation -> registered CAM class
    CAM = cam_registry.get(gradient_base_model_explanation)
    if CAM is None:
        raise ValueError(f"Unknown CAM method: {gradient_base_model_explanation}")

    # count the leading non-spatial tokens that reshape_transform must skip
    pre_token = 0
    if getattr(model, 'reg_token', None) is not None:
        pre_token += model.reg_token.shape[1]
    if getattr(model, 'class_token', None) is not None:
        pre_token += 1
    reshape_fn = partial(
        reshape_transform,
        pre_token=pre_token,
        ratio_h=input_tensor.shape[2],
        ratio_w=input_tensor.shape[3]
    )

    cam_dir = os.path.join(output_path, 'cam')
    os.makedirs(cam_dir, exist_ok=True)

    with amp_autocast():
        with CAM(model=model, target_layers=target_layers, reshape_transform=reshape_fn) as cam:
            # You can also pass aug_smooth=True and eigen_smooth=True, to apply smoothing.
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets,
                                aug_smooth=aug_smooth,
                                eigen_smooth=eigen_smooth)
            if return_instead_save:
                return grayscale_cam

            # undo the normalisation to get channels-last RGB images
            mean_t = torch.tensor(mean, device=input_tensor.device).view(1, -1, 1, 1)
            std_t = torch.tensor(std, device=input_tensor.device).view(1, -1, 1, 1)
            rgb_img = input_tensor * std_t + mean_t
            rgb_img = rgb_img.permute(0, 2, 3, 1).detach().float().cpu().numpy()
            # show_cam_on_image rejects values above 1; de-normalisation can
            # overshoot the [0, 1] range slightly, so clip it.
            rgb_img = np.clip(rgb_img, 0.0, 1.0)

            # one pair of files per image of the batch
            for iteration_idx, (cur_rgb_img, cur_grayscale_cam) in enumerate(zip(rgb_img, grayscale_cam)):
                file_stem = f'batch{batch_idx}-iteration{iteration_idx}'

                # CAM overlay
                cam_image = show_cam_on_image(cur_rgb_img, cur_grayscale_cam, use_rgb=True, image_weight=0.7)
                cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR)
                cv2.imwrite(os.path.join(cam_dir, f'{file_stem}-{gradient_base_model_explanation}.jpg'), cam_image)

                # image_weight=1.0 gives the image without heatmap; it is RGB
                # as well, so it needs the same conversion before cv2.imwrite
                # (which expects BGR).
                rgb_image = show_cam_on_image(cur_rgb_img, cur_grayscale_cam, use_rgb=True, image_weight=1.0)
                rgb_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
                cv2.imwrite(os.path.join(cam_dir, f'{file_stem}.jpg'), rgb_image)
