import matplotlib.pyplot as plt  # for display purposes

from scipy import signal
import numpy as np
# import stumpy
from sklearn.linear_model import Ridge
# from fastdtw import fastdtw
import random
import torch
from utils.models import resnet34, BiLSTMModel, TransformerModel
from scipy.integrate import simpson

from captum.attr import IntegratedGradients, Saliency
from hausdorff import hausdorff_distance
# import pyhomogeneity as hg
MIN_NUM_SEGMENTS = 2
MIN_VALUE = -1000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def explain2attributionmap(visual_instance, segment_explain):
    """Expand per-segment scores into a per-sample attribution map.

    Args:
        visual_instance: Array-like whose ``shape[0]`` gives the length of
            the returned map.
        segment_explain: Pair ``(scores_map, segment_map)``. ``segment_map``
            maps a key to ``[start, end)`` sample bounds and ``scores_map``
            maps the same key to that segment's score.

    Returns:
        1-D ``numpy`` array in which every sample holds the largest
        ``abs(score)`` among the segments covering it (0 where no segment
        covers the sample).
    """
    attribution_map = np.zeros((visual_instance.shape[0]))
    scores_map, segment_map = segment_explain[0], segment_explain[1]
    for key in segment_map:
        start, stop = segment_map[key][0], segment_map[key][1]
        if stop <= start:
            # Empty segment: nothing to write.
            continue
        # The magnitude is constant over the segment, so compute it once and
        # update the whole slice instead of looping sample by sample.
        magnitude = abs(scores_map[key])
        covered = attribution_map[start:stop]
        np.maximum(covered, magnitude, out=covered)
    return attribution_map

def segment_union(segment_1, segment_2):
    """Return the smallest ``[start, end]`` span that covers both segments."""
    (first_start, first_end), (second_start, second_end) = segment_1, segment_2
    lower = min(first_start, second_start)
    upper = max(first_end, second_end)
    return [lower, upper]


def Intial_Segment(t, window_size, step_size, index_0, index_1):
    """Build sliding-window segments over the range ``[index_0, index_1)``.

    Window starts are ``index_0, index_0 + step_size, ...`` (all below
    ``index_1``); each window is ``window_size`` long but is clipped so it
    never extends past ``index_1``.

    Args:
        t: Not used by the computation; kept for interface compatibility.
        window_size: Nominal length of each window.
        step_size: Distance between consecutive window starts.
        index_0: First index of the range to segment.
        index_1: End (exclusive) of the range to segment.

    Returns:
        Dict mapping the window ordinal (0, 1, ...) to ``[start, end]``.
    """
    return {
        ordinal: [start, min(start + window_size, index_1)]
        for ordinal, start in enumerate(range(index_0, index_1, step_size))
    }

def get_segment_score(exp_matrix, index_0, index_1):
    """Return the sum of ``exp_matrix[index_0:index_1]`` as a Python scalar.

    ``exp_matrix`` is expected to be a torch tensor (the result is moved to
    the CPU before being converted with ``.item()``).
    """
    window = exp_matrix[index_0:index_1]
    total = torch.sum(window)
    return total.cpu().item()

def get_exp_scores(exp_matrix, pattern_dict):
    """Sum the attributions inside every segment of ``pattern_dict``.

    Scores are signed sums (no absolute value is taken), so the scores of
    adjacent segments stay additive.

    Args:
        exp_matrix: Torch tensor of attributions, indexed along its first
            dimension.
        pattern_dict: Mapping ``key -> [start, end)`` bounds.

    Returns:
        Dict mapping each key to the summed attribution as a Python scalar.
    """
    return {
        key: torch.sum(exp_matrix[bounds[0]:bounds[1]]).cpu().item()
        for key, bounds in pattern_dict.items()
    }

def reverse_segment(ts, index0, index1):
    """Return a copy of ``ts`` with ``ts[index0:index1]`` flipped.

    The input array is left untouched.
    """
    window = slice(index0, index1)
    flipped_ts = ts.copy()
    flipped_ts[window] = np.flip(ts[window])
    return flipped_ts

def perturb_segment(ts, indices, pertubation):
    """Return a copy of ``ts`` whose entries at ``indices`` come from ``pertubation``.

    ``ts`` itself is not modified; ``pertubation`` is indexed with the same
    ``indices``.
    """
    result = ts.copy()
    replacement = pertubation[indices]
    result[indices] = replacement
    return result

def perturb_segment_with_mask(ts, mask, pertubation):
    """Blend ``ts`` and ``pertubation`` according to ``mask``.

    Where ``mask`` is 1 the original value is kept; where it is 0 the value
    from ``pertubation`` is used. The result is a new array, so no copy of
    ``ts`` is needed beforehand.
    """
    return mask * ts + (1 - mask) * pertubation

import operator
# def instance_faithfulness(ts, label, model, explain, pattern_dict):
#     top_index = max(explain.items(), key=operator.itemgetter(1))[0]
#     index_0 = pattern_dict[top_index][0]
#     index_1 = pattern_dict[top_index][1]
#     reversed_sample = reverse_segment(ts, index_0, index_1)
#     # ts = ts.reshape([1, ts.shape[0], 1])
#     ts = ts[np.newaxis, :][:, np.newaxis, :]
#     # reversed_sample = reversed_sample.reshape([1, reversed_sample.shape[0], 1])
#     reversed_sample = reversed_sample[np.newaxis, :][:, np.newaxis, :]
#     predictions = model(torch.Tensor(ts).to(device))
#     predictions = torch.nn.Softmax(dim=1)(predictions)
#     reversed_predictions = model(torch.Tensor(np.asarray(reversed_sample)).to(device))
#     reversed_predictions = torch.nn.Softmax(dim=1)(reversed_predictions)
#     change_score = torch.abs(reversed_predictions[0][label] - predictions[0][label])
#     return change_score.cpu().item(),top_index

def twostage_attr_instance_faithfulness(ts, label, model, explain, pattern_dict, perturbation):
    """Score a segment explanation by perturbing its top 10% and top 5% samples.

    The segment scores are expanded into a per-sample attribution map, the
    highest-attributed samples are replaced with the matching values from
    ``perturbation``, and the drop in the softmax probability of ``label``
    is measured.

    Args:
        ts: Input series as a numpy array (a batch and a channel axis are
            prepended before it is fed to the model).
        label: Index of the class whose probability is tracked.
        model: Callable returning class logits for a batch tensor.
        explain: Mapping ``segment key -> score``.
        pattern_dict: Mapping ``segment key -> [start, end)``.
        perturbation: Array indexed like ``ts`` providing replacement values.

    Returns:
        ``(change_score, top_features_10)`` where ``change_score`` is
        ``(p_orig - p_top10) + (p_orig - p_top5)`` and ``top_features_10``
        holds the indices of the top 10% samples.
    """
    attribution = explain2attributionmap(ts, (explain, pattern_dict))
    feature_order = np.argsort(attribution)
    n_features = len(attribution)

    def top_fraction(fraction):
        # Slice from an explicit start index: ``feature_order[-0:]`` would
        # select *every* sample when the fraction rounds down to zero.
        count = int(n_features * fraction)
        return feature_order[n_features - count:]

    top_features_10 = top_fraction(0.1)
    top_features_5 = top_fraction(0.05)

    def as_batch(array):
        # Prepend batch and channel axes, then convert to a float tensor.
        return torch.Tensor(np.asarray(array[np.newaxis, :][:, np.newaxis, :]))

    samples = torch.concat(
        [
            as_batch(ts),
            as_batch(perturb_segment(ts, top_features_10, perturbation)),
            as_batch(perturb_segment(ts, top_features_5, perturbation)),
        ],
        dim=0,
    )

    # Pure inference: gradients are not needed for the probability drop.
    with torch.no_grad():
        predictions = torch.nn.Softmax(dim=1)(model(samples.to(device)))

    change_score = (
        predictions[0][label] - predictions[1][label]
        + predictions[0][label] - predictions[2][label]
    )
    return change_score.cpu().item(), top_features_10

def multithresholds_attr_instance_faithfulness(ts, label, model, explain, pattern_dict, perturbation):
    """Average probability drop after perturbing the top 5%, 10% and 20% samples.

    The segment scores are expanded into a per-sample attribution map; for
    each threshold the highest-attributed samples are replaced with the
    matching values from ``perturbation`` and the softmax probability of
    ``label`` is compared against the unperturbed series.

    Args:
        ts: Input series as a numpy array (a batch and a channel axis are
            prepended before it is fed to the model).
        label: Index of the class whose probability is tracked.
        model: Callable returning class logits for a batch tensor.
        explain: Mapping ``segment key -> score``.
        pattern_dict: Mapping ``segment key -> [start, end)``.
        perturbation: Array indexed like ``ts`` providing replacement values.

    Returns:
        ``(change_score, 0)`` where ``change_score`` is the mean over the
        thresholds of ``p_orig - p_perturbed``; the second element is a
        constant placeholder.
    """
    attribution = explain2attributionmap(ts, (explain, pattern_dict))
    feature_order = np.argsort(attribution)
    n_features = len(feature_order)
    stages = [0.05, 0.1, 0.2]

    perturbed_samples = []
    for stage in stages:
        num_top_features = int(n_features * stage)
        # Explicit start index: ``feature_order[-0:]`` would select every
        # sample when the threshold rounds down to zero features.
        top_features = feature_order[n_features - num_top_features:]
        perturbed_sample = perturb_segment(ts, top_features, perturbation)
        perturbed_sample = perturbed_sample[np.newaxis, :][:, np.newaxis, :]
        perturbed_samples.append(torch.Tensor(np.asarray(perturbed_sample)))

    # The unperturbed series does not depend on the stage; build it once and
    # repeat it so both batches keep the same size.
    original_ts = torch.Tensor(ts[np.newaxis, :][:, np.newaxis, :])
    samples = torch.cat([original_ts] * len(stages), dim=0)
    perturbed_samples = torch.cat(perturbed_samples, dim=0)

    # Pure inference: gradients are not needed for the probability drop.
    with torch.no_grad():
        predictions = torch.nn.Softmax(dim=1)(model(samples.to(device)))
        perturb_predictions = torch.nn.Softmax(dim=1)(model(perturbed_samples.to(device)))

    change_score = torch.mean(predictions[:, label] - perturb_predictions[:, label])
    return change_score.cpu().item(), 0

def twostage_attrmap_instance_faithfulness(ts, label, model, attr_map, perturbation):
    """Probability drop after perturbing the top 10% samples of ``attr_map``.

    Args:
        ts: Input series as a numpy array (a batch and a channel axis are
            prepended before it is fed to the model).
        label: Index of the class whose probability is tracked.
        model: Callable returning class logits for a batch tensor.
        attr_map: Per-sample attribution values, same length as ``ts``.
        perturbation: Array indexed like ``ts`` providing replacement values.

    Returns:
        ``(change_score, top_features)`` where ``change_score`` is
        ``p_orig - p_perturbed`` for ``label`` and ``top_features`` holds the
        indices of the perturbed samples.
    """
    n_features = len(attr_map)
    num_top_features = int(n_features * 0.1)
    feature_order = np.argsort(attr_map)
    # Explicit start index: ``feature_order[-0:]`` would select every sample
    # when fewer than ten samples are available.
    top_features = feature_order[n_features - num_top_features:]

    perturbed_sample = perturb_segment(ts, top_features, perturbation)
    original_batch = torch.Tensor(ts[np.newaxis, :][:, np.newaxis, :])
    perturbed_batch = torch.Tensor(
        np.asarray(perturbed_sample[np.newaxis, :][:, np.newaxis, :])
    )

    # Pure inference: gradients are not needed for the probability drop.
    with torch.no_grad():
        predictions = torch.nn.Softmax(dim=1)(model(original_batch.to(device)))
        perturbed_predictions = torch.nn.Softmax(dim=1)(model(perturbed_batch.to(device)))

    change_score = predictions[0][label] - perturbed_predictions[0][label]
    return change_score.cpu().item(), top_features

def multithresholds_attrmap_instance_faithfulness(ts, label, model, attr_map, perturbation):
    """Average probability drop after perturbing the top 5%, 10% and 20% samples.

    Samples are ranked by ``attr_map``; for each threshold the highest-ranked
    samples are replaced with the matching values from ``perturbation`` and
    the softmax probability of ``label`` is compared against the unperturbed
    series.

    Args:
        ts: Input series as a numpy array (a batch and a channel axis are
            prepended before it is fed to the model).
        label: Index of the class whose probability is tracked.
        model: Callable returning class logits for a batch tensor.
        attr_map: Per-sample attribution values, same length as ``ts``.
        perturbation: Array indexed like ``ts`` providing replacement values.

    Returns:
        ``(change_score, 0)`` where ``change_score`` is the mean over the
        thresholds of ``p_orig - p_perturbed``; the second element is a
        constant placeholder.
    """
    feature_order = np.argsort(attr_map)
    n_features = len(feature_order)
    stages = [0.05, 0.1, 0.2]

    perturbed_samples = []
    for stage in stages:
        num_top_features = int(n_features * stage)
        # Explicit start index: ``feature_order[-0:]`` would select every
        # sample when the threshold rounds down to zero features.
        top_features = feature_order[n_features - num_top_features:]
        perturbed_sample = perturb_segment(ts, top_features, perturbation)
        perturbed_sample = perturbed_sample[np.newaxis, :][:, np.newaxis, :]
        perturbed_samples.append(torch.Tensor(np.asarray(perturbed_sample)))

    # The unperturbed series does not depend on the stage; build it once and
    # repeat it so both batches keep the same size.
    original_ts = torch.Tensor(ts[np.newaxis, :][:, np.newaxis, :])
    samples = torch.cat([original_ts] * len(stages), dim=0)
    perturbed_samples = torch.cat(perturbed_samples, dim=0)

    # Pure inference: gradients are not needed for the probability drop.
    with torch.no_grad():
        predictions = torch.nn.Softmax(dim=1)(model(samples.to(device)))
        perturb_predictions = torch.nn.Softmax(dim=1)(model(perturbed_samples.to(device)))

    change_score = torch.mean(predictions[:, label] - perturb_predictions[:, label])
    return change_score.cpu().item(), 0

def AUCStop_attrmap_instance(ts, label, model, attr_map, perturbation):
    """Area under the relative probability-drop curve for top-attributed samples.

    For each ratio in ``[0.05, 0.10, 0.15, 0.20]`` the samples whose
    attribution is at or above the ``1 - ratio`` quantile of ``attr_map`` are
    replaced with ``perturbation``. The drop in the softmax probability of
    ``label``, divided by the original probability, is then integrated with
    Simpson's rule over the fraction of samples actually perturbed.

    Args:
        ts: Input series as a numpy array (a batch and a channel axis are
            prepended before it is fed to the model).
        label: Index of the class whose probability is tracked.
        model: Callable returning class logits for a batch tensor.
        attr_map: Numpy array of per-sample attribution values.
        perturbation: Array broadcastable against ``ts`` providing
            replacement values.

    Returns:
        ``(AUCStop_score, 0)``; the second element is a constant placeholder.
    """
    ratios = [0.05, 0.10, 0.15, 0.20]
    perturbed_samples = []
    n_features_ratios = []
    for ratio in ratios:
        threshold = np.quantile(attr_map, 1 - ratio)
        # 1 = keep the original value, 0 = replace with the perturbation.
        mask = 1 - (attr_map >= threshold).astype(int)
        # Ties at the threshold can perturb more than ``ratio`` of the
        # samples, so record the fraction that was really replaced.
        n_features_ratios.append(float(np.sum(mask == 0) / mask.shape[0]))
        perturbed_sample = perturb_segment_with_mask(ts, mask, perturbation)
        perturbed_sample = perturbed_sample[np.newaxis, :][:, np.newaxis, :]
        perturbed_samples.append(torch.Tensor(np.asarray(perturbed_sample)))

    # The unperturbed series does not depend on the ratio; build it once and
    # repeat it so both batches keep the same size.
    original_ts = torch.Tensor(ts[np.newaxis, :][:, np.newaxis, :])
    samples = torch.cat([original_ts] * len(ratios), dim=0)
    perturbed_samples = torch.cat(perturbed_samples, dim=0)

    with torch.no_grad():
        predictions = torch.nn.Softmax(dim=1)(model(samples.to(device)))
        perturb_predictions = torch.nn.Softmax(dim=1)(model(perturbed_samples.to(device)))

    original_pred = predictions[0, label].cpu().item()
    change_scores = (
        (predictions[:, label] - perturb_predictions[:, label]) / original_pred
    ).cpu().numpy()
    # Pass the sample points by keyword: some SciPy releases deprecate
    # passing ``x`` positionally, while the keyword form works everywhere.
    AUCStop_score = simpson(change_scores, x=np.array(n_features_ratios))
    return AUCStop_score, 0


def discrimination_gain_cal(index, example, pattern_dict, sorted_key_list):
    """Gain in Hausdorff separation from merging segment ``index`` with its successor.

    With segments 0 (left neighbour), 1 (``index``), 2 (right neighbour) and
    3 (the one after that), the gain is
    ``3 * (d(0, 1+2) + d(1+2, 3)) - 2 * (d(0, 1) + d(1, 2) + d(2, 3))``,
    where ``1+2`` is the union of segments 1 and 2. Distances involving a
    missing segment 0 or 3 count as 0.

    Args:
        index: Key of the segment to merge with its right neighbour.
        example: 1-D sequence sliced by the segment bounds.
        pattern_dict: Mapping ``key -> [start, end)``.
        sorted_key_list: Keys of ``pattern_dict`` in segment order.

    Raises:
        ValueError: If ``index`` is not in ``sorted_key_list``.
        KeyError: If ``index`` has no right neighbour (the sentinel key is
            then looked up in ``pattern_dict``).
    """
    position = sorted_key_list.index(index)
    n_keys = len(sorted_key_list)

    def neighbour_key(offset):
        # Key at ``position + offset`` or the MIN_VALUE sentinel if absent.
        candidate = position + offset
        if 0 <= candidate < n_keys:
            return sorted_key_list[candidate]
        return MIN_VALUE

    def segment_values(key):
        bounds = pattern_dict[key]
        return example[bounds[0]:bounds[1]]

    def distance(left, right):
        # Each sample becomes a 1-D point for the Hausdorff distance.
        return hausdorff_distance(
            np.expand_dims(left, axis=1),
            np.expand_dims(right, axis=1),
            distance='euclidean',
        )

    previous_key = neighbour_key(-1)
    next_key = neighbour_key(1)
    after_next_key = neighbour_key(2)

    current = segment_values(index)
    following = segment_values(next_key)
    merged_bounds = segment_union(pattern_dict[index], pattern_dict[next_key])
    merged = example[merged_bounds[0]: merged_bounds[1]]

    d_prev_current = 0
    d_prev_merged = 0
    if previous_key in pattern_dict:
        previous = segment_values(previous_key)
        d_prev_current = distance(previous, current)
        d_prev_merged = distance(previous, merged)

    d_following_after = 0
    d_merged_after = 0
    if after_next_key in pattern_dict:
        after_next = segment_values(after_next_key)
        d_following_after = distance(following, after_next)
        d_merged_after = distance(merged, after_next)

    d_current_following = distance(current, following)

    merged_separation = d_prev_merged + d_merged_after
    split_separation = d_prev_current + d_current_following + d_following_after
    return 3*merged_separation - 2*split_separation

def Pytorch_IGSegment_Fadix_Overlap(example, model, label, window_size=None, step_size=None, f=None, part_lengths=None, pertub_serie=None, wavelet_level=0):
    """Explain ``example`` with Integrated Gradients summed over sliding windows.

    Integrated Gradients is computed against ``pertub_serie`` as baseline.
    The flattened attributions are then summed over overlapping windows in
    two regions of the series: the first ``part_lengths[0]`` samples (keys
    prefixed ``cA_``) and the last ``part_lengths[-1]`` samples (keys
    prefixed ``cD_``).

    Args:
        example: Input tensor passed to the model (flattened for
            segmentation).
        model: Wrapper module exposing the underlying network as
            ``model.model``.
        label: Target class for the attribution.
        window_size: Length of each sliding window.
        step_size: Stride between window starts.
        f: Unused; kept for interface compatibility.
        part_lengths: Sequence whose first and last entries give the lengths
            of the leading and trailing regions.
        pertub_serie: Baseline series for Integrated Gradients.
        wavelet_level: Unused; kept for interface compatibility.

    Returns:
        ``((merged_scores, merged_segments), abs_attr)`` where
        ``merged_scores`` maps ``"cA_<i>"`` / ``"cD_<i>"`` to the signed
        attribution sum of the window, ``merged_segments`` maps the same keys
        to ``[start, end]`` bounds, and ``abs_attr`` is the absolute value of
        the flattened attribution tensor.
    """
    model.eval()
    torch_example = example.clone()
    example = example.flatten().cpu().numpy().astype(np.float64)

    # Integrated Gradients
    IG = IntegratedGradients(model)
    baseline = torch.Tensor(pertub_serie).unsqueeze(0).unsqueeze(0).to(device)
    if isinstance(model.model, BiLSTMModel):
        # NOTE(review): presumably the recurrent backward pass needs training
        # mode here - confirm. Evaluation mode is restored right after.
        model.train()
    attr = IG.attribute(torch_example, baselines=baseline, target=label, n_steps=50, method='gausslegendre')
    model.eval()
    attr = attr.flatten()

    series_length = len(example)
    regions = (
        ("cA", 0, part_lengths[0]),
        ("cD", series_length - part_lengths[-1], series_length),
    )

    # The previous implementation also scored every pairwise merge of
    # neighbouring windows here, but never read those results; that dead
    # quadratic pass has been removed without affecting the output.
    merged_scores = {}
    merged_segments = {}
    for prefix, start_index, end_index in regions:
        exp_patterns = Intial_Segment(example, window_size=window_size, step_size=step_size, index_0=start_index, index_1=end_index)
        explains = get_exp_scores(attr, pattern_dict=exp_patterns)
        for key, bounds in exp_patterns.items():
            merged_segments[f"{prefix}_{key}"] = bounds
        for key, score in explains.items():
            merged_scores[f"{prefix}_{key}"] = score

    return (merged_scores, merged_segments), torch.abs(attr)


def backgroundIdentification(original_signal,f=40):
    """Reconstruct the signal from its single most stable STFT frequency band.

    The short-time Fourier transform of the signal is computed, each
    frequency band is scored by ``mean(|Z|) / std(|Z|)`` over time, and only
    the highest-scoring band is kept before inverting the transform.

    Args:
        original_signal: Numpy array holding ``original_signal.shape[0]``
            samples in total (e.g. shape ``(n,)`` or ``(n, 1)``).
        f: STFT segment length (``nperseg``). ``None`` falls back to SciPy's
            default.

    Returns:
        Array with the same shape as ``original_signal`` containing the
        reconstructed background component.
    """
    flat_signal = original_signal.reshape(original_signal.shape[0])
    # Clamp the segment length to the signal length (SciPy would do the same
    # with a warning) so the identical value can be handed to ``istft``.
    nperseg = f if f is None else min(f, flat_signal.shape[0])
    freqs, times, Zxx = signal.stft(flat_signal, 1, nperseg=nperseg)

    magnitudes = np.abs(Zxx)
    measures = [np.mean(band) / np.std(band) for band in magnitudes]
    selected_frequency = measures.index(max(measures))

    band_mask = np.zeros((len(freqs), len(times)))
    band_mask[selected_frequency, :] = 1
    # Option to admit information from the other frequency bands: start from
    # ``np.ones`` and scale row ``i`` by ``1 - measures[i] / sum(measures)``.

    background_frequency = Zxx * band_mask
    # Pass ``nperseg`` explicitly: left to itself ``istft`` assumes an even
    # segment length, which is wrong when ``f`` is odd.
    _, xrec = signal.istft(background_frequency, 1, nperseg=nperseg)
    xrec = xrec[:original_signal.shape[0]]
    return xrec.reshape(original_signal.shape)



