import random
import torch
from utils.models import resnet34, BiLSTMModel, TransformerModel
from utils.data_loader import load_noise_data
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import argparse
import os
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import json
from NoMerge_Wave_Fadix_Overlap_v6 import Pytorch_IGSegment_Fadix_Overlap, AUCStop_attrmap_instance
from NoMerge_Wave_Fadix_Overlap_WithLeader_v6 import Pytorch_IGSegment_Fadix_Overlap_WithLeader_2
from metrics_overlap_dwt import pytorch_faithfulness_attribution_dwt, pytorch_AUCStop_attribution_dwt
import matplotlib.pyplot as plt  # for display purposes
from matplotlib.patches import Rectangle
import numpy as np
from pytorch_wavelets import DWT1DForward, DWT1DInverse
from torch.utils.data import DataLoader, Subset
import time

# Reproducibility and runtime configuration shared by the whole script.
random.seed(42)
torch.manual_seed(911)
torch.set_num_threads(32)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Wavelet_Model(torch.nn.Module):
    """Wrap a time-domain classifier so it can be fed concatenated wavelet coefficients.

    The last axis of the input holds the approximation band followed by each
    detail band, laid out back-to-back with the lengths listed in
    ``part_lengths``. ``forward`` slices the bands apart, reconstructs a signal
    with ``idwt1d((approx, [details...]))``, keeps the first ``data_size``
    samples and returns ``model`` applied to that signal.
    """

    def __init__(self, model, idwt1d, part_lengths, data_size):
        super().__init__()
        self.model = model                # wrapped classifier
        self.idwt1d = idwt1d              # inverse transform, called as idwt1d((yl, yh))
        self.part_lengths = part_lengths  # [len(approx band), len(detail band 0), ...]
        self.data_size = data_size        # number of samples kept after reconstruction

    def forward(self, x, captum_input=True):
        # captum_input is accepted for interface compatibility only; it is not used.
        bounds = [0]
        for band_length in self.part_lengths:
            bounds.append(bounds[-1] + band_length)
        bands = [x[:, :, lo:hi] for lo, hi in zip(bounds[:-1], bounds[1:])]
        approx, details = bands[0], bands[1:]
        signal = self.idwt1d((approx, details))[:, :, :self.data_size]
        return self.model(signal)

def visualize_cp(args, time_serie_data, explain, file_path, part_lengths, top_k=5):
    """Plot a 1-D series, mark band boundaries and highlight its top-|score| segments.

    Args:
        args: namespace providing ``dataset`` and ``architecture`` (used for the title).
        time_serie_data: array with ``shape[0]`` samples; reshaped to 1-D before plotting.
        explain: ``(scores, segments)`` pair; ``scores[key]`` is a number and
            ``segments[key]`` is a ``[start, end]`` pair of x positions.
        file_path: destination of the saved figure (written with ``dpi=600``).
        part_lengths: lengths of consecutive coefficient bands; a dashed vertical
            line is drawn at every cumulative band boundary.
        top_k: number of highest-|score| segments to draw and annotate (default 5).
    """
    scores, segments = explain[0], explain[1]
    time_serie_data = time_serie_data.reshape(time_serie_data.shape[0])
    fig, axs = plt.subplots(1, figsize=(6, 2))
    try:
        axs.plot(time_serie_data, alpha=0.7, linewidth=2.0, color='black')
        y_min = min(time_serie_data)
        height = max(time_serie_data) - y_min

        # Dashed line at the end of every band (running sum of part_lengths).
        boundary = 0
        for band_length in part_lengths:
            boundary += band_length
            axs.axvline(x=boundary, color='blue', linestyle='--', label='Vertical Line')

        top_keys = sorted(scores, key=lambda key: abs(scores[key]), reverse=True)[:top_k]
        count = 0
        for key in segments:
            # True for every real-valued score; it only filters out NaN scores.
            if abs(scores[key]) > -0.001:
                if key in top_keys:
                    start, end = segments[key][0], segments[key][1]
                    # Alternate colours between consecutive segments so overlaps stay visible.
                    if count % 2 == 0:
                        alpha, color = 0.7, 'green'
                    else:
                        alpha, color = 0.3, 'yellow'
                    rect = Rectangle((start, y_min), end - start, height, facecolor=color, alpha=alpha)
                    center = (start + (end - start) / 2.0, y_min + height / 2.0)
                    axs.annotate(f'{scores[key]:.2f}', center, color='blue', weight='bold',
                                 fontsize=10, ha='center', va='center')
                    axs.add_patch(rect)
                # Previously `count += 0`, which made the yellow branch unreachable.
                count += 1
        axs.set_title(args.dataset+'/'+args.architecture)
        fig.tight_layout()
        fig.savefig(file_path, dpi=600)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)


def evaluate_model(model, dataloader, num_classes):
    """Run ``model`` over every batch of ``dataloader`` and score predictions with F1.

    Args:
        model (torch.nn.Module): trained classifier; the predicted class is the
            index of the maximum of its output along dim 1.
        dataloader (torch.utils.data.DataLoader): yields ``(data, labels)`` batches.
        num_classes (int): currently unused; kept for call compatibility.

    Returns:
        tuple: ``(f1_per_class, f1_macro)``. ``f1_per_class`` is
        ``sklearn.metrics.f1_score(..., average=None)`` (one entry per label
        present in the labels or predictions); ``f1_macro`` is the macro-averaged F1.
    """
    model.eval()  # Set the model to evaluation mode
    all_labels = []
    all_predictions = []
    with torch.no_grad():
        for data, labels in dataloader:
            output = model(data.to(device))
            _, predicted = torch.max(output, 1)
            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())

    f1_per_class = f1_score(all_labels, all_predictions, average=None)
    f1_macro = f1_score(all_labels, all_predictions, average='macro')
    return f1_per_class, f1_macro

def parse_bool_arg(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` turns any non-empty string (including
    ``"False"``) into ``True``; this function interprets the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    parser.add_argument('--dataset', type=str, default='', help="Dataset to train on")
    # Required: any other value leaves no model to build.
    parser.add_argument('--architecture', type=str, required=True, choices=['resnet', 'transformer', 'bilstm'])

    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=2e-4)
    parser.add_argument('--n_epochs', type=int, default=200)
    parser.add_argument('--savedir', type=str, default="classification_models")
    parser.add_argument('--inplanes', type=int, default=64)
    parser.add_argument('--num_classes', type=int, default=4)

    # Transformer and BiLSTM hyper-parameters
    parser.add_argument('--use_transformer', type=parse_bool_arg, default=False)
    parser.add_argument('--num_layers', type=int, default=2)
    parser.add_argument('--d_model', type=int, default=64)
    parser.add_argument('--nhead', type=int, default=8)
    parser.add_argument('--dim_feedforward', type=int, default=256)
    parser.add_argument('--dropout', type=float, default=0.2)  # a probability, so parsed as float
    parser.add_argument('--timesteps', type=int, default=1024) #128 140 1639

    # Explanation: both sizes are given at wavelet level 0 and divided by 2**level (min 1) at higher levels.
    parser.add_argument("-WS", "--window_size", type=int, required=True,
                        help="explanation window size at wavelet level 0")
    parser.add_argument("-Step", "--step_size", type=int, required=True,
                        help="explanation window step at wavelet level 0")

    #################################### Setup parameters #########################################
    args = parser.parse_args()
    print("Model::::::::::::", args.architecture)
    print("Data::::::::::::", args.dataset)
    print("Windowsize, step size", args.window_size, args.step_size)

    model_root_dir = args.savedir
    model_dir = os.path.join(model_root_dir, args.dataset)
    # Checkpoint file is "<savedir>/<dataset>/<architecture>_.pth".
    model_path = os.path.join(model_dir, args.architecture+'_'+'.pth')

    if args.architecture == "resnet":
        if args.use_transformer:
            # resnet combined with use_transformer is not supported; stop (exit status 0).
            exit()
        net = resnet34(args, num_classes = args.num_classes).to(device)
    elif args.architecture == 'bilstm':
        net = BiLSTMModel(args, num_classes = args.num_classes).to(device)
    elif args.architecture == 'transformer':
        net = TransformerModel(args, num_classes = args.num_classes).to(device)
    else:
        # Fail clearly here instead of with a NameError on `net` below.
        raise ValueError(f"Unsupported architecture: {args.architecture!r}")

    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    state_dict = torch.load(model_path, map_location=device)
    net.load_state_dict(state_dict, strict=True)
    net.eval()

    _, _, test_dataset, mean_dict = load_noise_data(data_name=args.dataset)
    # Cast keys to int so the per-level perturbation baselines can be indexed by wavelet level.
    mean_dict = {int(k): v for k, v in mean_dict.items()}
    test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)
    scores, f1_macro = evaluate_model(model=net, dataloader=test_loader, num_classes=args.num_classes)
    print("Testing Scores::::::", scores)
    print("Data Size ##############", test_dataset[0][0].shape)
    log_interval = 5  # progress is printed every log_interval samples
    # Leader segments propagated in round 2: roughly 1/5 of the level-0 window count (length / step_size).
    num_top_segments = int((test_dataset[0][0].shape[-1]/args.step_size)/5)
    print("Num top segments #########", num_top_segments)

   
    # test_dataset = Subset(test_dataset, range(2)) # Set subset for quick test

    start_time = time.time()

    ###################################### Dataset-level accumulators ###################################################
    highest_level = 5  # explanations are produced for every wavelet level 0..highest_level
    levels = range(highest_level + 1)
    # Each dict maps wavelet level -> list with one entry per test sample.
    # Suffix "_1" = result after round 1; no suffix = result kept after round 2.
    total_IG_attr_map_dict_1 = {level: [] for level in levels}   # |IG| attribution maps
    total_IG_attr_map_dict = {level: [] for level in levels}
    total_wave_explain_dict_1 = {level: [] for level in levels}  # (scores, segments) explanations
    total_wave_explain_dict = {level: [] for level in levels}
    total_attr_map_dict_1 = {level: [] for level in levels}      # attribution maps built from segment scores
    total_attr_map_dict = {level: [] for level in levels}
    part_lengths_dict = {}  # level -> [len(approx band)] + [len of each detail band], filled from the first sample
    data_size = 0  # length of the time-domain series; set from the first sample
    def _attribution_map_from_segments(explain, length):
        """Turn a ``(scores, segments)`` explanation into a dense attribution map.

        ``segments[key]`` is a ``[start, end)`` index pair and ``scores[key]`` its
        score. Each covered position gets the largest |score| among the segments
        covering it; uncovered positions stay 0.
        """
        scores, segments = explain[0], explain[1]
        attribution = np.zeros((length))
        for key in segments:
            magnitude = abs(scores[key])
            for pos in range(segments[key][0], segments[key][1]):
                attribution[pos] = max(attribution[pos], magnitude)
        return attribution

    def _top_segment_keys(scores, k):
        """Return up to ``k`` keys of ``scores`` ordered by decreasing |score| (stable on ties)."""
        return sorted(scores, key=lambda key: abs(scores[key]), reverse=True)[:k]

    for data_index in range(len(test_dataset)):
        if data_index % log_interval == 0:
            print("Processing ...", data_index, "&"*30)
        instance_faithfulness_score_dict = {}  # round 1: level -> AUCStop score of the segment attribution map
        ig_instance_faithfulness_score_dict = {}  # round 1: level -> AUCStop score of the |IG| attribution map
        explain_dict = {}  # round 1: level -> (scores, segments) explanation
        attr_map_dict = {}  # round 1: level -> segment attribution map
        ig_attr_map_dict = {}  # round 1: level -> |IG| attribution map
        test_instance, test_label = test_dataset[data_index][0], test_dataset[data_index][1].cpu().item()
        test_instance = test_instance.unsqueeze(dim=0).to(device)

        ############################################# Round 1 ############################################
        # Explain the instance independently at every wavelet level.
        for wavelet_level in range(0, highest_level+1):
            if data_index % log_interval == 0:
                print("wavelet_level ....", wavelet_level)
            pertub_serie = np.array(mean_dict[wavelet_level])
            dwt1d = DWT1DForward(wave='haar', J=wavelet_level).to(device)
            idwt1d = DWT1DInverse(wave='haar').to(device)
            if data_index == 0:
                # Series length and band lengths are measured on the first sample
                # and reused for every later sample.
                data_size = test_instance.shape[-1]
                yl, yh = dwt1d(test_instance)
                part_lengths = [yl.shape[-1]] + [band.shape[-1] for band in yh]
                part_lengths_dict[wavelet_level] = part_lengths
            else:
                part_lengths = part_lengths_dict[wavelet_level]
            wave_model = Wavelet_Model(model=net, idwt1d=idwt1d, part_lengths=part_lengths, data_size=data_size)
            wave_model = wave_model.to(device)
            wave_model.eval()

            visual_dir = os.path.join('noise_nomerge_result_neurips_2025_v6', args.dataset, args.architecture,  'wavelet_'+str(wavelet_level))
            os.makedirs(visual_dir, exist_ok=True)

            # Approximation band followed by the detail bands, concatenated on the last axis.
            testl, testh = dwt1d(test_instance)
            wave_test_instance = torch.cat([testl] + testh, dim=-1)
            ig_segment_explain, IG_attr = Pytorch_IGSegment_Fadix_Overlap(wave_test_instance, wave_model, test_label,  window_size=max(int(args.window_size/2**wavelet_level), 1), step_size=max(int(args.step_size/2**wavelet_level),1), part_lengths=part_lengths, pertub_serie=pertub_serie)
            explain_dict[wavelet_level] = ig_segment_explain
            visual_instance = wave_test_instance.flatten().cpu().numpy()
            IG_attr = torch.abs(IG_attr).cpu().numpy()

            attr_map_dict[wavelet_level] = _attribution_map_from_segments(ig_segment_explain, visual_instance.shape[0])
            ig_attr_map_dict[wavelet_level] = IG_attr
            instance_faithfulness_score_dict[wavelet_level], _ = AUCStop_attrmap_instance(ts=wave_test_instance.cpu().flatten().numpy(), label=test_label, model=wave_model, attr_map=attr_map_dict[wavelet_level], perturbation=pertub_serie)
            ig_instance_faithfulness_score_dict[wavelet_level], _ = AUCStop_attrmap_instance(ts=wave_test_instance.cpu().flatten().numpy(), label=test_label, model=wave_model, attr_map=ig_attr_map_dict[wavelet_level], perturbation=pertub_serie)

            visual_file = os.path.join(visual_dir, args.architecture + '_' + str(data_index)+'_'+str(test_label)+'_'+args.dataset+'.png')
            if data_index % (log_interval*10) == 0:
                visualize_cp(args, visual_instance, ig_segment_explain, visual_file, part_lengths)
        # The level with the highest round-1 AUCStop score becomes the round-2 leader.
        max_wavelet_level = max(instance_faithfulness_score_dict, key=instance_faithfulness_score_dict.get)
        for wavelet_level in range(highest_level+1):
            total_wave_explain_dict_1[wavelet_level].append(explain_dict[wavelet_level])
            total_attr_map_dict_1[wavelet_level].append(attr_map_dict[wavelet_level])
            total_IG_attr_map_dict_1[wavelet_level].append(ig_attr_map_dict[wavelet_level])

        ############################################# Round 2 setup ############################################
        leader = max_wavelet_level
        leader_explain = explain_dict[leader]
        leader_part_lengths = part_lengths_dict[leader]
        leader_scores, leader_segments = leader_explain[0], leader_explain[1]
        # Leader's most important segments, ranked by |score| (the attribution maps also use |score|).
        top_keys = _top_segment_keys(leader_scores, num_top_segments)
        # Series length minus the length of the leader's last band.
        cD_shift = test_instance.shape[-1] - leader_part_lengths[-1]
        instance_faithfulness_score_dict_2 = dict()  # round 2: level -> AUCStop score of the kept segment map
        ig_instance_faithfulness_score_dict_2 = dict()  # round 2: level -> AUCStop score of the kept |IG| map
        explain_dict_2 = dict()
        attr_map_dict_2 = dict()
        ig_attr_map_dict_2 = dict()

        ############################################# Round 2 Down ############################################
        # Re-explain the leader level and every lower level, restricted to positions derived from top_keys.
        if data_index % log_interval == 0:
            print("£"*20, "Round 2 Down", leader, "£"*20)
        for wavelet_level in range(leader, -1, -1):
            if data_index % log_interval == 0:
                print("wavelet_level ....", wavelet_level)
            pertub_serie = np.array(mean_dict[wavelet_level])
            dwt1d = DWT1DForward(wave='haar', J=wavelet_level).to(device)
            idwt1d = DWT1DInverse(wave='haar').to(device)
            part_lengths = part_lengths_dict[wavelet_level]
            wave_model = Wavelet_Model(model=net, idwt1d=idwt1d, part_lengths=part_lengths, data_size=data_size)
            wave_model = wave_model.to(device)
            wave_model.eval()
            testl, testh = dwt1d(test_instance)
            wave_test_instance = torch.cat([testl] + testh, dim=-1)

            current_cD_shift = test_instance.shape[-1] - part_lengths[-1]
            # Rescale each leader segment by 2**(leader - level) onto this level's index space.
            position_set = set()
            for top_key in top_keys:
                segment = leader_segments[top_key].copy()
                if 'cA' in top_key:
                    segment[0] = min(test_instance.shape[-1]-1, segment[0]*2**(leader-wavelet_level))
                    segment[1] = min(test_instance.shape[-1]-1, segment[1]*2**(leader-wavelet_level))
                elif 'cD' in top_key:
                    if leader != wavelet_level:
                        segment[0] = min(test_instance.shape[-1]-1, (segment[0] - cD_shift)*2**(leader-wavelet_level))
                        segment[1] = min(test_instance.shape[-1]-1, (segment[1] - cD_shift)*2**(leader-wavelet_level))
                position_set.update(range(segment[0], segment[1]))
                if leader != wavelet_level and wavelet_level != 0:
                    # Also cover the same span shifted into this level's last band.
                    cd_pos_0 = min(test_instance.shape[-1]-1, segment[0]+current_cD_shift)
                    cd_pos_1 = min(test_instance.shape[-1]-1, segment[1]+current_cD_shift)
                    position_set.update(range(cd_pos_0, cd_pos_1))

            ig_segment_explain, IG_attr = Pytorch_IGSegment_Fadix_Overlap_WithLeader_2(wave_test_instance, wave_model, test_label, window_size=max(int(args.window_size/2**wavelet_level), 1), step_size=max(int(args.step_size/2**wavelet_level),1), part_lengths=part_lengths, position_set=position_set, pertub_serie=pertub_serie)
            IG_attr = torch.abs(IG_attr).cpu().numpy()
            visual_instance = wave_test_instance.flatten().cpu().numpy()
            attribution_map = _attribution_map_from_segments(ig_segment_explain, visual_instance.shape[0])

            instance_faithfulness_score, _ = AUCStop_attrmap_instance(ts=wave_test_instance.cpu().flatten().numpy(), label=test_label, model=wave_model, attr_map=attribution_map, perturbation=pertub_serie)
            ig_instance_faithfulness_score, _ = AUCStop_attrmap_instance(ts=wave_test_instance.cpu().flatten().numpy(), label=test_label, model=wave_model, attr_map=IG_attr, perturbation=pertub_serie)

            # Keep the round-2 result only when it beats round 1 at this level.
            if instance_faithfulness_score > instance_faithfulness_score_dict[wavelet_level]:
                instance_faithfulness_score_dict_2[wavelet_level] = instance_faithfulness_score
                ig_instance_faithfulness_score_dict_2[wavelet_level] = ig_instance_faithfulness_score
                explain_dict_2[wavelet_level] = ig_segment_explain
                attr_map_dict_2[wavelet_level] = attribution_map
                ig_attr_map_dict_2[wavelet_level] = IG_attr
            else:
                instance_faithfulness_score_dict_2[wavelet_level] = instance_faithfulness_score_dict[wavelet_level]
                ig_instance_faithfulness_score_dict_2[wavelet_level] = ig_instance_faithfulness_score_dict[wavelet_level]
                explain_dict_2[wavelet_level] = explain_dict[wavelet_level]
                attr_map_dict_2[wavelet_level] = attr_map_dict[wavelet_level]
                ig_attr_map_dict_2[wavelet_level] = ig_attr_map_dict[wavelet_level]

        ############################################# Round 2 Upper ############################################
        # Re-explain every level above the leader, restricted to positions derived from top_keys.
        if data_index % log_interval == 0:
            print("£"*20, "Round 2 Upper ", leader, "£"*20)
        for wavelet_level in range(leader+1, highest_level+1):
            if data_index % log_interval == 0:
                print("wavelet_level ....", wavelet_level)
            pertub_serie = np.array(mean_dict[wavelet_level])
            dwt1d = DWT1DForward(wave='haar', J=wavelet_level).to(device)
            idwt1d = DWT1DInverse(wave='haar').to(device)
            part_lengths = part_lengths_dict[wavelet_level]
            wave_model = Wavelet_Model(model=net, idwt1d=idwt1d, part_lengths=part_lengths, data_size=data_size)
            wave_model = wave_model.to(device)
            wave_model.eval()
            testl, testh = dwt1d(test_instance)
            wave_test_instance = torch.cat([testl] + testh, dim=-1)

            current_cD_shift = test_instance.shape[-1] - part_lengths[-1]
            # Shrink each leader segment by 2**(level - leader) onto this level's index space.
            # Here wavelet_level > leader always holds, so every 'cD' segment is rescaled.
            position_set = set()
            for top_key in top_keys:
                segment = leader_segments[top_key]
                if 'cA' in top_key:
                    segment_0 = min(test_instance.shape[-1]-1, int(segment[0]/2**(wavelet_level-leader)))
                    segment_1 = min(test_instance.shape[-1]-1, int(segment[1]/2**(wavelet_level-leader)))
                    position_set.update(range(segment_0, segment_1+1))
                    position_set.update(range(segment_0+current_cD_shift, segment_1+current_cD_shift))
                elif 'cD' in top_key:
                    segment_0 = min(test_instance.shape[-1]-1, int((segment[0] - cD_shift)/2**(wavelet_level-leader)))
                    segment_1 = min(test_instance.shape[-1]-1, int((segment[1] - cD_shift)/2**(wavelet_level-leader)))
                    position_set.update(range(segment_0, segment_1+1))

            ig_segment_explain, IG_attr = Pytorch_IGSegment_Fadix_Overlap_WithLeader_2(wave_test_instance, wave_model, test_label, window_size=max(int(args.window_size/2**wavelet_level), 1), step_size=max(int(args.step_size/2**wavelet_level),1), part_lengths=part_lengths, position_set=position_set, pertub_serie=pertub_serie, wavelet_level=wavelet_level)
            visual_instance = wave_test_instance.flatten().cpu().numpy()
            IG_attr = torch.abs(IG_attr).cpu().numpy()
            attribution_map = _attribution_map_from_segments(ig_segment_explain, visual_instance.shape[0])

            instance_faithfulness_score, _ = AUCStop_attrmap_instance(ts=wave_test_instance.cpu().flatten().numpy(), label=test_label, model=wave_model, attr_map=attribution_map, perturbation=pertub_serie)
            ig_instance_faithfulness_score, _ = AUCStop_attrmap_instance(ts=wave_test_instance.cpu().flatten().numpy(), label=test_label, model=wave_model, attr_map=IG_attr, perturbation=pertub_serie)

            # Keep the round-2 result only when it beats round 1 at this level.
            if instance_faithfulness_score > instance_faithfulness_score_dict[wavelet_level]:
                instance_faithfulness_score_dict_2[wavelet_level] = instance_faithfulness_score
                ig_instance_faithfulness_score_dict_2[wavelet_level] = ig_instance_faithfulness_score
                explain_dict_2[wavelet_level] = ig_segment_explain
                attr_map_dict_2[wavelet_level] = attribution_map
                ig_attr_map_dict_2[wavelet_level] = IG_attr
            else:
                instance_faithfulness_score_dict_2[wavelet_level] = instance_faithfulness_score_dict[wavelet_level]
                ig_instance_faithfulness_score_dict_2[wavelet_level] = ig_instance_faithfulness_score_dict[wavelet_level]
                explain_dict_2[wavelet_level] = explain_dict[wavelet_level]
                attr_map_dict_2[wavelet_level] = attr_map_dict[wavelet_level]
                ig_attr_map_dict_2[wavelet_level] = ig_attr_map_dict[wavelet_level]

        # Down (leader..0) and Upper (leader+1..highest_level) together cover every level.
        for wavelet_level in range(highest_level+1):
            total_wave_explain_dict[wavelet_level].append(explain_dict_2[wavelet_level])
            total_attr_map_dict[wavelet_level].append(attr_map_dict_2[wavelet_level])
            total_IG_attr_map_dict[wavelet_level].append(ig_attr_map_dict_2[wavelet_level])
    
    def _wave_model_and_baseline(level):
        """Return ``(eval-mode Wavelet_Model, perturbation baseline)`` for one wavelet level.

        The baseline is a fresh ``np.array`` copy of ``mean_dict[level]`` on every call.
        """
        baseline = np.array(mean_dict[level])
        level_model = Wavelet_Model(model=net, idwt1d=idwt1d, part_lengths=part_lengths_dict[level], data_size=data_size)
        level_model = level_model.to(device)
        level_model.eval()
        return level_model, baseline

    # Round-1 maps: dataset-level deletion faithfulness and the positions returned with it.
    delete_faithfulness_1, positions_dict_1 = {}, {}        # segment attribution maps
    ig_delete_faithfulness_1, ig_positions_dict_1 = {}, {}  # |IG| attribution maps
    for level in range(highest_level + 1):
        level_model, baseline = _wave_model_and_baseline(level)
        delete_faithfulness_1[level], positions_dict_1[level] = pytorch_faithfulness_attribution_dwt(
            total_attr_map_dict_1[level], test_dataset, model=level_model, wavelet_level=level, pertubation=baseline)
        ig_delete_faithfulness_1[level], ig_positions_dict_1[level] = pytorch_faithfulness_attribution_dwt(
            total_IG_attr_map_dict_1[level], test_dataset, model=level_model, wavelet_level=level, pertubation=baseline)

    # Round-2 maps: the same metrics, plus AUCStop / F1S for the segment maps.
    delete_faithfulness_2, positions_dict = {}, {}
    ig_delete_faithfulness_2, ig_positions_dict = {}, {}
    AUCStop_score_dict_2, F1S_score_dict_2 = {}, {}
    for level in range(highest_level + 1):
        level_model, baseline = _wave_model_and_baseline(level)
        delete_faithfulness_2[level], positions_dict[level] = pytorch_faithfulness_attribution_dwt(
            total_attr_map_dict[level], test_dataset, model=level_model, wavelet_level=level, pertubation=baseline)
        ig_delete_faithfulness_2[level], ig_positions_dict[level] = pytorch_faithfulness_attribution_dwt(
            total_IG_attr_map_dict[level], test_dataset, model=level_model, wavelet_level=level, pertubation=baseline)
        AUCStop_score_dict_2[level], F1S_score_dict_2[level] = pytorch_AUCStop_attribution_dwt(
            total_attr_map_dict[level], test_dataset, model=level_model, wavelet_level=level, pertubation=baseline)
        
     

    results_root = os.path.join('noise_nomerge_result_neurips_2025_v6', args.dataset, args.architecture)
    results_dir = os.path.join(results_root, 'results')
    os.makedirs(results_dir, exist_ok=True)
    run_tag = f"{len(test_dataset)}_{args.window_size}_{args.step_size}"

    # Output files in write order. Only the two deletion-faithfulness files go under
    # results/; the others are written directly into results_root.
    outputs = [
        (os.path.join(results_dir, run_tag + '_delete_faithfulness_round1.json'), delete_faithfulness_1),
        (os.path.join(results_dir, run_tag + '_delete_faithfulness.json'), delete_faithfulness_2),
        (os.path.join(results_root, run_tag + '_IG_faithfulness_round1.json'), ig_delete_faithfulness_1),
        (os.path.join(results_root, run_tag + '_IG_faithfulness.json'), ig_delete_faithfulness_2),
        (os.path.join(results_root, run_tag + '_AUCS_top.json'), AUCStop_score_dict_2),
        (os.path.join(results_root, run_tag + '_F1S_top.json'), F1S_score_dict_2),
        (os.path.join(results_root, run_tag + '_Xplanation.json'), positions_dict),
        (os.path.join(results_root, run_tag + '_IGXplanation.json'), ig_positions_dict),
    ]
    for out_path, payload in outputs:
        with open(out_path, "w") as json_file:
            json.dump(payload, json_file, indent=4)  # indent=4 keeps the JSON human-readable

    elapsed_time = time.time() - start_time
    print(ig_delete_faithfulness_2)
    print("Elapsed Time:", elapsed_time, "seconds")
