import matplotlib.pyplot as plt  # for display purposes

from scipy import signal
import numpy as np
# import stumpy
from sklearn.linear_model import Ridge
# from fastdtw import fastdtw
import random
import torch
from utils.models import resnet34, BiLSTMModel, TransformerModel
from NoMerge_Wave_Fadix_Overlap_v6 import multithresholds_attr_instance_faithfulness

from captum.attr import IntegratedGradients, Saliency
from hausdorff import hausdorff_distance
# import pyhomogeneity as hg
# Minimum number of segments a segmentation may be merged down to. Not referenced by the
# live code visible here (the commented-out merge loop hard-codes 2) -- TODO confirm usage elsewhere.
MIN_NUM_SEGMENTS = 2
# Sentinel key meaning "no such neighbouring segment" in discrimination_gain_cal, and the
# initial running maximum in the commented-out merge loop.
MIN_VALUE = -1000
# Device for model inputs, IG baselines and masks: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def explain2attributionmap(visual_instance, segment_explain):
    """Project segment-level scores back onto a per-sample attribution map.

    Args:
        visual_instance: Array whose first dimension gives the length of the map.
        segment_explain: Pair ``(scores, segments)``. ``scores`` maps a segment key to its score.
            ``segments`` maps the same key to ``[start, end)`` sample bounds.

    Returns:
        1-D float numpy array of length ``visual_instance.shape[0]``. Each position holds the
        largest absolute score of any segment covering it, or 0 where no segment does.
    """
    attribution_map = np.zeros((visual_instance.shape[0]))
    scores_map = segment_explain[0]
    segment_map = segment_explain[1]
    for key, bounds in segment_map.items():
        start, end = bounds[0], bounds[1]
        # Hoisted: the magnitude is the same for every sample of the segment.
        magnitude = abs(scores_map[key])
        for h in range(start, end):
            attribution_map[h] = max(attribution_map[h], magnitude)
    return attribution_map

def segment_union(segment_1, segment_2):
    """Return the smallest ``[start, end]`` span covering both two-element segments."""
    (first_start, first_end), (second_start, second_end) = segment_1, segment_2
    lowest_start = min(first_start, second_start)
    highest_end = max(first_end, second_end)
    return [lowest_start, highest_end]


def Intial_Segment(t, window_size, step_size, index_0, index_1):
    """Cover ``[index_0, index_1)`` with sliding windows.

    Windows start at ``index_0, index_0 + step_size, ...`` (every start is ``< index_1``). Each
    window is ``window_size`` samples long and is clipped so it never ends past ``index_1``.

    Args:
        t: Unused; kept so existing call sites keep working.
        window_size: Nominal window length.
        step_size: Stride between consecutive window starts (passed to ``range``).
        index_0: First index of the region (inclusive).
        index_1: End of the region (exclusive).

    Returns:
        dict mapping ``0, 1, 2, ...`` (in start order) to ``[start, end]`` lists.
        It is empty when the region is empty.
    """
    starts = range(index_0, index_1, step_size)
    return {key: [start, min(start + window_size, index_1)] for key, start in enumerate(starts)}

def get_segment_score(exp_matrix, index_0, index_1):
    """Return the summed attribution over ``exp_matrix[index_0:index_1]`` as a 0-d tensor.

    No absolute value is taken, so positive and negative attributions cancel.
    """
    window = exp_matrix[index_0:index_1]
    return window.sum()

def get_exp_scores(exp_matrix, pattern_dict):
    """Summarise the attributions inside every segment.

    Args:
        exp_matrix: 1-D attribution tensor.
        pattern_dict: Mapping of segment key to ``[start, end)`` bounds.

    Returns:
        ``(sums, means)``: two dicts keyed like ``pattern_dict``, holding the signed sum and mean
        of each segment's attributions as Python floats. No absolute value is taken.
    """
    sums = {}
    means = {}
    for key, bounds in pattern_dict.items():
        window = exp_matrix[bounds[0]:bounds[1]]
        window_sum = window.sum()
        window_mean = window.mean()
        sums[key] = window_sum.cpu().item()
        means[key] = window_mean.cpu().item()
    return sums, means

def reverse_segment(ts, index0, index1):
    """Return a copy of ``ts`` in which ``ts[index0:index1]`` is reversed; ``ts`` is not modified."""
    window = slice(index0, index1)
    result = ts.copy()
    result[window] = np.flip(ts[window])
    return result

def perturb_segment(ts, indices, pertubation):
    """Return a copy of ``ts`` whose entries at ``indices`` are taken from ``pertubation``.

    ``indices`` is used as-is for numpy indexing (list, slice, boolean mask, ...).
    """
    replaced = ts.copy()
    replacement_values = pertubation[indices]
    replaced[indices] = replacement_values
    return replaced

import operator
def instance_faithfulness(ts, label, model, explain, pattern_dict):
    """Measure how much the model's probability for ``label`` changes when the top segment is reversed.

    Args:
        ts: 1-D numpy series (presumably; it is fed to the model as ``(1, 1, L)``).
        label: Class index whose softmax probability is compared.
        model: Callable torch model returning class logits with classes on dim 1.
        explain: Mapping of segment key to score. The key with the largest score is used
            (the first one on ties).
        pattern_dict: Mapping of segment key to ``[start, end)`` bounds.

    Returns:
        ``(change_score, top_index)``. ``change_score`` is
        ``|p(label | reversed) - p(label | original)|`` as a float.
    """
    top_index = max(explain.items(), key=operator.itemgetter(1))[0]
    index_0, index_1 = pattern_dict[top_index][0], pattern_dict[top_index][1]
    reversed_sample = reverse_segment(ts, index_0, index_1)

    softmax = torch.nn.Softmax(dim=1)
    # Only the probabilities are needed, so skip building autograd graphs for both passes.
    with torch.no_grad():
        original_input = torch.Tensor(ts[np.newaxis, np.newaxis, :]).to(device)
        reversed_input = torch.Tensor(reversed_sample[np.newaxis, np.newaxis, :]).to(device)
        predictions = softmax(model(original_input))
        reversed_predictions = softmax(model(reversed_input))
    change_score = torch.abs(reversed_predictions[0][label] - predictions[0][label])
    return change_score.cpu().item(), top_index


def discrimination_gain_cal(index, example, pattern_dict, sorted_key_list):
    """Score merging segment ``index`` with its right neighbour using Hausdorff distances.

    Let S1 be segment ``index``, S2 its right neighbour, S0 its left neighbour and S3 the segment
    after S2, in ``sorted_key_list`` order. Let S12 be the union span of S1 and S2. The result is::

        3 * (d(S0, S12) + d(S12, S3)) - 2 * (d(S0, S1) + d(S1, S2) + d(S2, S3))

    where ``d`` is the Euclidean Hausdorff distance between the 1-D value sets. Terms involving a
    missing S0 or S3 contribute 0.

    Args:
        index: Key of S1; must appear in ``sorted_key_list``.
        example: 1-D numpy series the segment bounds index into.
        pattern_dict: Mapping of segment key to ``[start, end)`` bounds.
        sorted_key_list: Segment keys in positional order.

    Raises:
        KeyError: if ``index`` is the last key, since there is no right neighbour to merge with.
        ValueError: if ``index`` is not in ``sorted_key_list``.
    """
    def _subsequence(bounds):
        return example[bounds[0]:bounds[1]]

    def _distance(first, second):
        # hausdorff_distance expects 2-D point sets, so each sample becomes a 1-D point.
        return hausdorff_distance(np.expand_dims(first, axis=1), np.expand_dims(second, axis=1), distance='euclidean')

    position = sorted_key_list.index(index)
    n_keys = len(sorted_key_list)
    if position + 1 >= n_keys:
        raise KeyError(f"segment {index!r} has no right neighbour to merge with")
    key_0 = sorted_key_list[position - 1] if position >= 1 else MIN_VALUE
    key_2 = sorted_key_list[position + 1]
    key_3 = sorted_key_list[position + 2] if position + 2 < n_keys else MIN_VALUE

    subsequence_1 = _subsequence(pattern_dict[index])
    subsequence_2 = _subsequence(pattern_dict[key_2])
    subsequence_12 = _subsequence(segment_union(pattern_dict[index], pattern_dict[key_2]))

    if key_0 in pattern_dict:
        subsequence_0 = _subsequence(pattern_dict[key_0])
        d_0_1 = _distance(subsequence_0, subsequence_1)
        d_0_12 = _distance(subsequence_0, subsequence_12)
    else:
        d_0_1 = d_0_12 = 0

    if key_3 in pattern_dict:
        subsequence_3 = _subsequence(pattern_dict[key_3])
        d_2_3 = _distance(subsequence_2, subsequence_3)
        d_12_3 = _distance(subsequence_12, subsequence_3)
    else:
        d_2_3 = d_12_3 = 0

    d_1_2 = _distance(subsequence_1, subsequence_2)
    return 3 * (d_0_12 + d_12_3) - 2 * (d_0_1 + d_1_2 + d_2_3)

def top_segments_overlap_cA(scores, segments, pos_set, top_k=2):
    """Return the Jaccard overlap between the top-scoring segments and a reference position set.

    Args:
        scores: Mapping of segment key to score. Keys are ranked by descending score,
            keeping the original order on ties.
        segments: Mapping of segment key to ``[start, end)`` bounds. It is not modified.
        pos_set: Iterable of reference sample positions (set, list, ...).
        top_k: Number of top-ranked segments whose positions are unioned (default 2).

    Returns:
        ``|covered & reference| / |covered | reference|``. It is 0.0 when both sets are empty,
        where the original raised ZeroDivisionError.
    """
    top_keys = sorted(scores, key=scores.get, reverse=True)[:top_k]
    covered = set()
    for key in top_keys:
        start, end = segments[key][0], segments[key][1]
        covered.update(range(start, end))
    reference = set(pos_set)
    union = len(covered | reference)
    if union == 0:
        return 0.0
    return len(covered & reference) / union


# def Pytorch_IGSegment_Fadix_Overlap_WithLeader(example, model, label, model_type='class', distance='dtw', n=100, window_size=None, step_size=None, f=None, part_lengths=None, position_set=None, pertubation=None):
#     model.eval()
#     torch_example = example.clone()
#     example = example.flatten().cpu().numpy().astype(np.float64)
    
#     # Itegrated Gradient
#     IG = IntegratedGradients(model)
#     # print(torch_example.shape)
#     # torch_example = torch_example.unsqueeze(dim=0).to(device)
#     # print(torch_example.shape)
#     # exit()
#     baseline = torch.Tensor(pertubation).unsqueeze(0).unsqueeze(0).to(device)
#     if isinstance(model.model, BiLSTMModel):
#         model.train()
#     attr = IG.attribute(torch_example, baselines= baseline, target = label, n_steps=50, method='gausslegendre')
#     model.eval()
#     # attr = torch.abs(attr)
#     attr = attr.flatten()
#     attr_numpy = attr.cpu().numpy().astype(np.float64).reshape(attr.shape[0])

#     explain_cAcD_sets = []
    
#     example =  example.reshape(example.shape[0])
#     for order in range(2):
#         if order == 0: # process cA
#             start_index, end_index = 0, part_lengths[0]
#         else:
#             start_index, end_index = len(example)- part_lengths[-1], len(example)
#         exp_patterns = Intial_Segment(example, window_size=window_size, step_size=step_size, index_0=start_index, index_1=end_index)
#         # print(exp_patterns)
#         # exit()
#         explains, _ = get_exp_scores(attr, pattern_dict=exp_patterns)
#         # print(start_index, end_index, exp_patterns, explains)
#         # continue
#         faithfullness_score, _ = attr_instance_faithfulness(example, label, model, {key: abs(val) for key, val in explains.items()}, exp_patterns, pertubation=pertubation)
#         if order == 0:
#             overlap_score = top_segments_overlap_cA({key: abs(val) for key, val in explains.items()}, exp_patterns, position_set)
#         Overlap_gains = {}
#         Score_gains = {}
#         structure_list = {}
#         explain_list = {}
#         discrimination_gains = {}
#         sorted_key_list = sorted(list(exp_patterns.keys()))
#         # average_scores_dict = {}
#         ######################### Initial all explain dictionar (scores and segments) first!!! ########################
#         for index in exp_patterns:
#             if index >= len(exp_patterns)-1:
#                 break
#             if index not in exp_patterns:
#                 continue
#             new_exp_patterns = exp_patterns.copy()
#             index_of_index = sorted_key_list.index(index)
#             new_exp_patterns[index] = segment_union(new_exp_patterns[index], new_exp_patterns[sorted_key_list[index_of_index+1]])
#             del new_exp_patterns[sorted_key_list[index_of_index+1]]
#             # print(new_exp_patterns)
#             new_explains, _ = get_exp_scores(attr, pattern_dict=new_exp_patterns)
#             # print(new_average_scores)
#             new_change_score, _ = attr_instance_faithfulness(example, label, model, {key: abs(val) for key, val in new_explains.items()}, new_exp_patterns, pertubation)
#             score_gain = new_change_score - faithfullness_score

#             if order == 0:
#                 new_overlap_score = top_segments_overlap_cA({key: abs(val) for key, val in new_explains.items()}, new_exp_patterns, position_set)
#                 overlap_gain = new_overlap_score - overlap_score
#                 Overlap_gains[index] = overlap_gain
            
#             Score_gains[index] = score_gain
#             structure_list[index] = new_exp_patterns
#             explain_list[index] = new_explains
#             # average_scores_dict[index] = new_average_scores

#             discrimination_gain = discrimination_gain_cal(index, example, pattern_dict=exp_patterns, sorted_key_list=sorted_key_list)
#             discrimination_gains[index] = discrimination_gain
#         # print(structure_list)
#         # exit()
        
#         d_top_index = -1
#         # repeat = 0
#         for i in range(100):
#             # print("iteration ...", i)
#             if len(exp_patterns) == 2:
#                 break
#             max_key = max(list(exp_patterns.keys()))
#             sorted_key_list = sorted(list(exp_patterns.keys())) # update sorted_key_list
#             # Score_gains = {}
#             for index in exp_patterns:
#                 if index >= max_key:
#                     break
                
#                 new_exp_patterns = exp_patterns.copy()
#                 index_of_index = sorted_key_list.index(index)
#                 new_exp_patterns[index] = segment_union(new_exp_patterns[index],  new_exp_patterns[sorted_key_list[index_of_index+1]])

#                 del new_exp_patterns[sorted_key_list[index_of_index+1]]
                
#                 new_explains = explains.copy()
#                 new_explains[index] = get_segment_score(attr, new_exp_patterns[index][0], new_exp_patterns[index][1])
#                 # new_explains[index] = abs(new_explains[index] + new_explains[sorted_key_list[index_of_index+1]])
#                 del new_explains[sorted_key_list[index_of_index+1]]
                            
#                 structure_list[index] = new_exp_patterns
#                 explain_list[index] = new_explains
#                 # average_scores_dict[index] = new_average_scores
#                 discrimination_gain = discrimination_gain_cal(index=index, example=example, pattern_dict=exp_patterns, sorted_key_list=sorted_key_list)
#                 discrimination_gains[index] = discrimination_gain

#                 abs_new_explains = {key: abs(val) for key, val in new_explains.items()}
#                 # new_top_index = max(abs_new_explains.items(), key=operator.itemgetter(1))[0]

#                 # if new_top_index != top_index or (new_top_index==index and top_index==index):
#                 new_change_score, _ = attr_instance_faithfulness(example, label, model, {key: abs(val) for key, val in new_explains.items()}, new_exp_patterns, pertubation)
#                 Score_gains[index] = new_change_score - faithfullness_score

#                 if order == 0:
#                     new_overlap_score = top_segments_overlap_cA({key: abs(val) for key, val in new_explains.items()}, new_exp_patterns, position_set)
#                     # print(new_overlap_score, overlap_score)
#                     overlap_gain = new_overlap_score - overlap_score
#                     Overlap_gains[index] = overlap_gain

#             max_value = MIN_VALUE
#             prev_d_top = d_top_index
#             d_top_index = -1
#             # print(Overlap_gains)
#             # exit()
#             if order == 0: 
#                 for index, value in discrimination_gains.items():
#                     if value + 1000*Overlap_gains[index] > max_value and value>=0 and Score_gains[index]>0:
#                     # if value > max_value and value>=0:
#                         max_value = value
#                         d_top_index = index
#             else:
#                 for index, value in discrimination_gains.items():
#                     if value > max_value and value>=0 and Score_gains[index]>0:
#                     # if value > max_value and value>=0:
#                         max_value = value
#                         d_top_index = index

#             if d_top_index == -1:
#                 break

#             d_top_index_index = sorted_key_list.index(d_top_index)
#             exp_patterns = structure_list[d_top_index]
#             explains = explain_list[d_top_index]
            
#             if d_top_index_index < len(sorted_key_list)-1:
#                 if sorted_key_list[d_top_index_index+1] in structure_list:
#                     del structure_list[sorted_key_list[d_top_index_index+1]]
#                 if sorted_key_list[d_top_index_index+1] in explain_list:
#                     del explain_list[sorted_key_list[d_top_index_index+1]]
#             else:
#                 del structure_list[sorted_key_list[d_top_index_index]]
#                 del explain_list[sorted_key_list[d_top_index_index]]
#                 del discrimination_gains[d_top_index]

#             # Update matrix
#             if d_top_index_index < len(sorted_key_list)-1:
#                 if sorted_key_list[d_top_index_index+1] in discrimination_gains:
#                     del discrimination_gains[sorted_key_list[d_top_index_index+1]]
            
#             sorted_key_list = sorted(list(exp_patterns.keys())) # update sorted_key_list
#             d_top_index_index = sorted_key_list.index(d_top_index)

#             faithfullness_score, _ = attr_instance_faithfulness(example, label, model, {key: abs(val) for key, val in explains.items()}, exp_patterns, pertubation)
#             if order == 0:
#                 overlap_score = top_segments_overlap_cA({key: abs(val) for key, val in explains.items()}, exp_patterns, position_set)
#         # print(exp_patterns, explains)
#         explain_cAcD_sets.append((exp_patterns, explains))
#     cA_explain = explain_cAcD_sets[0]
#     cD_explain = explain_cAcD_sets[1]
#     cA_segments, cA_scores = cA_explain[0], cA_explain[1]
#     cA_segments = {f"cA_{key}": value for key, value in cA_segments.items()}
#     cA_scores = {f"cA_{key}": value for key, value in cA_scores.items()}
#     cD_segments, cD_scores = cD_explain[0], cD_explain[1]
#     cD_segments = {f"cD_{key}": value for key, value in cD_segments.items()}
#     cD_scores = {f"cD_{key}": value for key, value in cD_scores.items()}
#     merged_segments = cA_segments | cD_segments
#     merged_scores = cA_scores | cD_scores
#     # print(exp_patterns)
#     # print(explains)
#     # exit()
#     # cp_indexes = []
#     # for key, value in exp_patterns.items():
#     #     cp_indexes.append(value[0])
#     # print(cp_indexes, explains)
#     # exit()
#     return merged_scores, merged_segments


def _ig_attributions_for_label(model, torch_example, baseline, label, position_set):
    """Run Integrated Gradients on ``model`` for class ``label`` and return the raw attribution tensor.

    If ``position_set`` is non-empty, the path is split in two:
    1. From ``baseline`` to a masked input that keeps the samples at ``position_set`` (last axis)
       and uses ``baseline`` elsewhere.
    2. From that masked input to ``torch_example``.
    The two attributions are summed. The model is put back into eval mode even if captum raises.
    """
    integrated_gradients = IntegratedGradients(model)
    is_bilstm = isinstance(model.model, BiLSTMModel)
    try:
        if position_set is not None and len(position_set) > 0:
            mask = torch.zeros(baseline.shape).to(device)
            mask[:, :, list(position_set)] = 1
            masked_example = mask * torch_example + (1 - mask) * baseline
            if is_bilstm:
                # NOTE(review): presumably needed because cuDNN RNN backward only runs in training mode.
                model.model.train()
            attr_to_masked = integrated_gradients.attribute(masked_example, baselines=baseline, target=label, n_steps=50, method='gausslegendre')
            attr_to_full = integrated_gradients.attribute(torch_example, baselines=masked_example, target=label, n_steps=50, method='gausslegendre')
            return attr_to_masked + attr_to_full
        if is_bilstm:
            model.train()
        return integrated_gradients.attribute(torch_example, baselines=baseline, target=label, n_steps=50, method='gausslegendre')
    finally:
        model.eval()


def Pytorch_IGSegment_Fadix_Overlap_WithLeader_2(example, model, label, window_size=None, step_size=None, f=None, part_lengths=None, position_set=None, pertub_serie=None, wavelet_level=0):
    """Compute IG attributions for ``example`` and score sliding-window segments of two bands.

    Args:
        example: Input tensor for ``model``. It is cloned for IG and flattened to a 1-D float64
            numpy copy for segmentation.
        model: Wrapper module exposing the wrapped network as ``model.model``.
        label: Target class for Integrated Gradients.
        window_size, step_size: Sliding-window parameters passed to ``Intial_Segment``.
        f: Unused; kept for call-site compatibility.
        part_lengths: Sequence whose first entry is the length of the leading "cA" band and whose
            last entry is the length of the trailing "cD" band of the flattened example.
        position_set: Optional positions used for the two-stage masked IG path. ``None`` or empty
            means plain IG from the baseline.
        pertub_serie: Baseline series; it becomes a ``(1, 1, ...)`` tensor via two ``unsqueeze(0)``.
        wavelet_level: Unused; kept for call-site compatibility.

    Returns:
        ``((merged_scores, merged_segments), attr)``:
        - ``merged_scores`` and ``merged_segments`` are keyed ``"cA_<i>"`` / ``"cD_<i>"`` and
          hold each window's signed summed attribution and its ``[start, end]`` bounds.
        - ``attr`` is the flattened attribution tensor.
    """
    model.eval()
    torch_example = example.clone()
    example = example.flatten().cpu().numpy().astype(np.float64)

    baseline = torch.Tensor(pertub_serie).unsqueeze(0).unsqueeze(0).to(device)
    attr = _ig_attributions_for_label(model, torch_example, baseline, label, position_set).flatten()

    n_samples = len(example)
    bands = (
        ("cA", 0, part_lengths[0]),
        ("cD", n_samples - part_lengths[-1], n_samples),
    )
    merged_scores = {}
    merged_segments = {}
    for prefix, start_index, end_index in bands:
        exp_patterns = Intial_Segment(example, window_size=window_size, step_size=step_size, index_0=start_index, index_1=end_index)
        explains, _ = get_exp_scores(attr, pattern_dict=exp_patterns)
        merged_segments.update({f"{prefix}_{key}": value for key, value in exp_patterns.items()})
        merged_scores.update({f"{prefix}_{key}": value for key, value in explains.items()})
    return (merged_scores, merged_segments), attr


def backgroundIdentification(original_signal, f=40):
    """Reconstruct a signal's "background" from its single most stable STFT frequency band.

    1. The signal is reshaped to length ``original_signal.shape[0]`` and transformed with
       ``scipy.signal.stft`` (fs=1, ``nperseg=f``).
    2. For every frequency bin, mean(|Zxx|) / std(|Zxx|) over time is computed.
    3. Only the bin with the largest ratio is kept and inverted with ``scipy.signal.istft``.

    Args:
        original_signal: Array whose total size equals ``original_signal.shape[0]``
            (e.g. shape ``(L,)`` or ``(L, 1)``).
        f: STFT segment length (``nperseg``).

    Returns:
        The reconstruction truncated to the input length and reshaped to ``original_signal.shape``.
    """
    _, _, Zxx = signal.stft(original_signal.reshape(original_signal.shape[0]), 1, nperseg=f)
    magnitudes = np.abs(Zxx)
    # All-zero bins give 0/0 = NaN; silence the warnings and handle them explicitly below.
    with np.errstate(divide='ignore', invalid='ignore'):
        measures = magnitudes.mean(axis=1) / magnitudes.std(axis=1)
    # NaN compares false with everything, so a plain max() picked a bin depending on where the
    # NaNs sat. Treat NaN as -inf, so the choice is well defined (bin 0 if every bin is degenerate).
    selected_frequency = int(np.argmax(np.where(np.isnan(measures), -np.inf, measures)))

    # Alternative (disabled): instead of a single band, keep every band weighted by
    # 1 - measures / measures.sum().
    keep = np.zeros(Zxx.shape)
    keep[selected_frequency, :] = 1
    background_frequency = Zxx * keep

    _, xrec = signal.istft(background_frequency, 1)
    xrec = xrec[:original_signal.shape[0]]
    return xrec.reshape(original_signal.shape)



