import torch
import matplotlib.pyplot as plt
from PIL import Image
import seaborn as sns
import numpy as np
# Module-level default torch device: CUDA when available, otherwise CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import math

def visualize_mask(mask, path, ncols=4):
    """Plot every frame of ``mask`` as a heatmap in a grid, save it, and show it.

    Args:
        mask: torch tensor whose ``squeeze()`` yields ``(frames, H, W)``; a
            single ``(H, W)`` frame is also accepted. Values are rounded to two
            decimals and drawn on a colour scale fixed to ``[0, 1]``.
        path: destination passed to ``Figure.savefig`` (saved at 300 dpi).
        ncols: number of grid columns. The default of 4 reproduces the
            original 4x4 layout for 16 frames; rows are added as needed.
    """
    data = np.round(mask.detach().cpu().squeeze().numpy(), decimals=2)
    if data.ndim == 2:
        # A single frame squeezes down to 2-D; restore the frame axis.
        data = data[np.newaxis]
    n_frames = data.shape[0]
    nrows = max(1, math.ceil(n_frames / ncols))

    # squeeze=False keeps ``axes`` 2-D even for a 1-row / 1x1 grid.
    fig, axes = plt.subplots(nrows, ncols, figsize=(2.5 * ncols, 2.5 * nrows), squeeze=False)
    color_map = sns.diverging_palette(10, 133, as_cmap=True)

    for i, ax in enumerate(axes.flat):
        if i >= n_frames:
            ax.axis("off")  # unused grid cell
            continue
        sns.heatmap(data=data[i], cmap=color_map, yticklabels=False, xticklabels=False, vmin=0, vmax=1, ax=ax,
                    cbar=False)
        ax.set_title(f"Frame={i}")

    # Lay out the subplot grid first and only then add the hand-placed
    # colour-bar axes: tight_layout cannot handle manually added axes.
    fig.tight_layout(rect=[0, 0, 0.9, 1])
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    fig.colorbar(axes[0, 0].collections[0], cax=cbar_ax)

    fig.savefig(path, dpi=300)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def visualize_preservation(mask, img, size, ncols=4):
    """Display each frame of ``img`` multiplied by its mask in an image grid.

    Args:
        mask: torch tensor with one mask per frame; after ``squeeze()`` each
            frame's mask must contain ``size * size`` values.
        img: torch tensor that, after ``squeeze()``, is 4-D and is permuted
            with ``(1, 2, 3, 0)`` -- presumably ``(C, T, H, W)`` to
            ``(T, H, W, C)`` with RGB values in ``[0, 1]``; TODO confirm
            against callers.
        size: spatial side length; frames are treated as ``size x size``.
        ncols: number of grid columns (default 4, i.e. 4x4 for 16 frames).
    """
    # detach() + cpu() so CUDA tensors and tensors requiring grad also work.
    frames = img.detach().cpu().squeeze().permute(1, 2, 3, 0).numpy()
    masks = mask.detach().cpu().squeeze().numpy()

    n_frames = min(frames.shape[0], masks.shape[0])
    # Broadcast each frame's (size, size, 1) mask over its colour channels.
    # The product is a new array, so the input tensors are never modified.
    preserved = frames[:n_frames] * masks[:n_frames].reshape(n_frames, size, size, 1)

    nrows = max(1, math.ceil(n_frames / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(2.5 * ncols, 2.5 * nrows), squeeze=False)
    for i, ax in enumerate(axes.flat):
        if i < n_frames:
            # Clip before the uint8 cast; out-of-range values would wrap.
            pixels = (np.clip(preserved[i], 0.0, 1.0) * 255).astype(np.uint8)
            ax.imshow(Image.fromarray(pixels))
            ax.set_title(f"Image {i}")
        ax.axis('off')
    fig.tight_layout()
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


# import re
# def extract_mask_value(file_path):
#     with open(file_path, 'r') as file:
#         content = file.read()
#
#     mask_values = re.findall(r'Mask Value = tensor\((\[.*?\])\)', content)
#
#     mask_values = [eval(value) for value in mask_values]
#     mask_values = np.array(mask_values)
#
#     return mask_values

def evaluate_time_weight(model, label, time_weight, input_image, task="C", device="cuda"):
    """Run ``model`` on ``input_image`` re-weighted frame by frame.

    ``time_weight`` (length T) is reshaped to ``(1, 1, T, 1, 1)`` and broadcast
    against ``input_image``, so the input is presumably ``(1, C, T, H, W)``
    (TODO confirm against callers).

    Args:
        model: callable mapping the weighted input to predictions.
        label: class index whose probability is returned when ``task="C"``.
        time_weight: 1-D tensor of per-frame weights.
        input_image: tensor to weight; moved to ``device`` with the weights.
        task: ``"C"`` returns the softmax (over dim 1) probability of
            ``label`` for the first batch item; ``"R"`` returns the raw output.
        device: device on which the weighting and forward pass happen.

    Returns:
        A 0-d tensor for ``task="C"``; the model output for ``task="R"``.

    Raises:
        ValueError: if ``task`` is neither ``"C"`` nor ``"R"``.
    """
    # Validate before the (possibly expensive) forward pass instead of
    # silently returning None afterwards.
    if task not in ("C", "R"):
        raise ValueError(f"Unknown task {task!r}; expected 'C' or 'R'")

    weights = time_weight.reshape([1, 1, time_weight.shape[0], 1, 1]).to(device)
    # Move the input as well: multiplying tensors on different devices raises.
    weighted_img = weights * input_image.to(device)

    pred = model(weighted_img)
    if task == "C":
        return torch.softmax(pred, dim=1)[0][label]
    return pred

def count_window_passes(array_length, window_size):
    """Count, for every position, how many stride-1 sliding windows cover it.

    Windows of ``window_size`` consecutive elements are slid one step at a
    time across an array of ``array_length`` elements, staying fully inside
    it. There are ``array_length - window_size + 1`` such windows. Position
    ``i`` is covered by the minimum of:

    - ``i + 1``
    - ``array_length - i``
    - ``window_size``
    - the number of windows

    A window longer than the array is clamped to the array length, so a
    single pass covers every position.

    Args:
        array_length: number of elements (>= 0).
        window_size: window length (>= 1).

    Returns:
        float32 tensor of length ``array_length`` with the per-position counts.

    Raises:
        ValueError: if ``window_size`` is smaller than 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    window = min(window_size, array_length)
    # Without this bound, large windows (window > (array_length + 1) / 2)
    # would over-count the middle positions.
    n_windows = array_length - window + 1
    passes = [min(i + 1, array_length - i, window, n_windows) for i in range(array_length)]
    return torch.Tensor(passes)


def rescale_time_imp(tt, area, total_frame=16):
    """Scale per-frame importances so they sum to ``total_frame * area``.

    Every value is multiplied by the same ratio, so relative importances are
    preserved while the total becomes ``total_frame * area``. Presumably
    ``area`` is the fraction of frames to highlight; confirm against callers.

    Args:
        tt: sequence of per-frame importance values (list, array or 1-D tensor).
        area: multiplier applied to ``total_frame`` to obtain the target sum.
        total_frame: number of frames (default 16).

    Returns:
        list of rescaled values, in the same order as ``tt``.

    Raises:
        ValueError: if ``sum(tt)`` is zero (no meaningful ratio exists).
    """
    current_tt_sum = sum(tt)
    if current_tt_sum == 0:
        raise ValueError("Cannot rescale importances that sum to zero")

    highlight_frames = total_frame * area
    # One division: the result then sums to exactly highlight_frames.
    # Dividing by the sum a second time would make the output shrink as
    # ``tt`` grows.
    ratio = highlight_frames / current_tt_sum

    return [value * ratio for value in tt]


def calculate_sharpness(data):
    """Return the total absolute variation of ``data``.

    First differences are taken along the last axis (``np.diff``). The sum
    of their absolute values over all elements is returned as a NumPy scalar.
    """
    return np.abs(np.diff(data)).sum()

def process_activations(activations, targets, softmaxed=True):
    """Pick each row's target probability along with its arg-max prediction.

    ``activations`` may have an integer multiple of ``targets``'s length in
    rows. In that case ``targets`` is tiled with ``Tensor.repeat``, which
    repeats the whole sequence: ``[a, b]`` becomes ``[a, b, a, b]``.

    Args:
        activations: 2-D tensor ``(rows, classes)``; rank 2 is implied by the
            ``[row, target]`` indexing and ``max(dim=1)``.
        targets: 1-D tensor of class indices, one per original row.
        softmaxed: when False, a softmax over dim 1 is applied first.

    Returns:
        Tuple ``(probs, pred_labels, pred_label_probs)``:
        - ``probs``: each row's probability at its target class.
        - ``pred_labels``: each row's arg-max class.
        - ``pred_label_probs``: each row's maximum probability.
    """
    n_rows, n_targets = activations.shape[0], targets.shape[0]
    assert n_rows % n_targets == 0, \
        f"Check the batch size of activations and targets!"

    repeat_factor = n_rows // n_targets
    if repeat_factor > 1:
        targets = targets.repeat(repeat_factor)

    probabilities = activations if softmaxed else torch.nn.functional.softmax(activations, dim=1)

    rows = torch.arange(probabilities.shape[0], dtype=torch.long, device=activations.device)
    target_probs = probabilities[rows, targets]  # one value per row
    top_probs, top_labels = torch.max(probabilities, dim=1)

    return target_probs, top_labels, top_probs





