import torch
import torch.nn as nn
from tqdm import tqdm
from utils.funcs import *

class Occulusion_base_Time_Weight:
    """Occlusion-based attribution of a model's response over the time axis.

    At construction time a "pyramid" of per-window scores is built:
    pyramid level ``k`` (window size ``k + 1``) holds one score per start
    position ``idx``, equal to ``original_response`` minus the value returned
    by ``evaluate_time_weight`` when time steps ``[idx, idx + k + 1)`` (along
    dimension 2 of the input) are masked with zeros.  The ``Get_*`` methods
    fold selected pyramid levels back into a single weight per time step.
    """

    def __init__(self, model, input_image, img_label, task="C", device="cuda"):
        """Compute the full occlusion pyramid for ``input_image``.

        Args:
            model: callable network.  For task "C" its output on
                ``input_image.to(device)`` is softmaxed over dim 1 and entry
                ``[0][img_label]`` is read; for task "R" the output must hold a
                single element (``.item()`` is called on it).
            input_image: tensor whose dimension 2 is treated as the time axis
                (``self.len_step = input_image.shape[2]``).
            img_label: class index read from the softmaxed output for task "C";
                also forwarded to ``evaluate_time_weight``.
            task: "C" (classification) or "R" (regression).
            device: device the input is moved to for the baseline forward pass;
                also forwarded to ``evaluate_time_weight``.

        Raises:
            ValueError: if ``task`` is neither "C" nor "R".

        After construction ``self.time_weight_pyramid[k]`` is a 1-D tensor of
        length ``len_step - k``.  Each level is divided by its sum (+eps) and
        then has negative entries clamped to 0, so a level need not sum to 1.
        """
        super().__init__()
        self.model = model
        self.input_image = input_image
        self.img_label = img_label
        self.time_weight_pyramid = []
        self.task = task
        eps = 1e-7

        if task not in ("C", "R"):
            # Previously an unknown task left original_response unbound and
            # failed later with an unrelated error.
            raise ValueError(f"task must be 'C' or 'R', got {task!r}")

        # Baseline (un-occluded) response; no gradients are needed for either task.
        with torch.no_grad():
            output = model(input_image.to(device))
            if task == "C":
                original_response = nn.Softmax(dim=1)(output)[0][img_label].item()
            else:
                original_response = output.detach().cpu().item()

        self.len_step = input_image.shape[2]
        for window_size in tqdm(range(1, self.len_step + 1), desc="Generating pyramid..."):
            n_positions = self.len_step - (window_size - 1)
            tmp_time_weight = torch.zeros(n_positions)

            for idx in range(n_positions):
                # Mask of shape (len_step, 1): zeros over [idx, idx + window_size), ones elsewhere.
                occlusion_mask = torch.ones(self.len_step, 1)
                occlusion_mask[idx:idx + window_size] = 0
                # NOTE(review): evaluate_time_weight comes from utils.funcs; presumably it
                # returns the model response for the input masked by occlusion_mask — confirm.
                tmp_time_weight[idx] = original_response - evaluate_time_weight(
                    model, img_label, occlusion_mask, input_image,
                    task=self.task, device=device,
                ).detach().item()

            if task == "R":
                # NOTE(review): regression scores are normalised twice (here and below).
                # Kept to preserve existing outputs — confirm whether intended.
                tmp_time_weight = tmp_time_weight / (tmp_time_weight.sum() + eps)

            level = tmp_time_weight / (tmp_time_weight.sum() + eps)
            # Clamp once per new level; the old code re-clamped every level on every
            # iteration element by element, which gave the same result far more slowly.
            level.clamp_(min=0)
            self.time_weight_pyramid.append(level)

    def Get_time_weight(self, max_step=3):
        """Fuse pyramid levels ``0 .. max_step`` into one weight per time step.

        Each level-``step`` score is added to every time step its window
        ``[idx, idx + step + 1)`` covers.  The sums are divided by ``count``,
        which is 1 for level 0 plus ``count_window_passes(len_step, window_size)``
        for each further level (presumably the number of windows covering each
        position — see utils.funcs).  The result is normalised to sum to ~1.

        Args:
            max_step: highest pyramid level to include (inclusive); must be
                smaller than ``len(self.time_weight_pyramid)``.

        Returns:
            1-D tensor of length ``self.len_step``.
        """
        eps = 1e-7
        # Level 0 (window size 1) already has exactly one entry per time step.
        tmp_time_weight = self.time_weight_pyramid[0].clone()
        count = torch.ones(self.len_step)

        for step in range(1, max_step + 1):
            window_size = step + 1
            for idx, weight in enumerate(self.time_weight_pyramid[step]):
                tmp_time_weight[idx:idx + window_size] += weight
            count += count_window_passes(self.len_step, window_size)

        final_time_weight = tmp_time_weight / count
        return final_time_weight / (torch.sum(final_time_weight) + eps)

    def Get_time_weight_threshold(self, ratio=0.9, cmd="sharpness"):
        """Fuse only the pyramid levels whose selection metric is high enough.

        A metric is computed per level: ``calculate_sharpness`` for
        ``cmd="sharpness"``, or the population standard deviation
        (``np.std``, ddof=0) for ``cmd="std"``.  A level is selected when its
        metric is >= ``ratio * max(metric)``.  Selected levels are spread over
        their windows, divided by ``count_window_passes`` coverage counts, and
        normalised to sum to ~1.

        Args:
            ratio: fraction of the best metric a level must reach to be kept.
            cmd: "sharpness" or "std".

        Returns:
            ``(weights, index)``.  ``weights`` is a 1-D tensor of length
            ``self.len_step``.  ``index`` lists the selected level indices
            (window size = index + 1).

        Raises:
            ValueError: if ``cmd`` is not "sharpness" or "std".
        """
        import numpy as np  # local: np is otherwise only in scope if utils.funcs re-exports it

        eps = 1e-7
        n_levels = len(self.time_weight_pyramid)
        sel_met = torch.zeros(n_levels)
        if cmd == "sharpness":
            for i, level in enumerate(self.time_weight_pyramid):
                sel_met[i] = calculate_sharpness(level.tolist())
        elif cmd == "std":
            for i, level in enumerate(self.time_weight_pyramid):
                sel_met[i] = np.std(np.array(level.tolist()))
        else:
            raise ValueError(f"cmd must be 'sharpness' or 'std', got {cmd!r}")

        threshold = torch.max(sel_met) * ratio

        # Select with the same metric the threshold was derived from (the old code
        # always used sharpness here, even for cmd="std").
        index = [i for i in range(n_levels) if sel_met[i] >= threshold]

        tmp_time_weight = torch.zeros(self.len_step)
        count = torch.zeros(self.len_step)
        for level_idx in index:
            window_size = level_idx + 1
            for idx, weight in enumerate(self.time_weight_pyramid[level_idx].tolist()):
                tmp_time_weight[idx:idx + window_size] += weight
            count += count_window_passes(self.len_step, window_size)

        final_time_weight = tmp_time_weight / count
        return final_time_weight / (torch.sum(final_time_weight) + eps), index

    def Get_custom_time_weight(self, arr):
        """Fuse an explicit list of pyramid levels into one weight per time step.

        Args:
            arr: iterable of pyramid level indices; level ``k`` has window size
                ``k + 1``.

        Returns:
            1-D tensor of length ``self.len_step``, normalised to sum to ~1
            (all zeros when ``arr`` is empty).
        """
        eps = 1e-7
        tmp_time_weight = torch.zeros(self.len_step)
        # NOTE(review): count starts at 1 for every position although no level-0
        # scores are added unless 0 is in ``arr``; kept as-is (it also keeps the
        # division defined for an empty ``arr``) — confirm intent.
        count = torch.ones(self.len_step)

        for step in arr:
            window_size = step + 1
            for idx, weight in enumerate(self.time_weight_pyramid[step]):
                # The score at idx was measured by occluding [idx, idx + window_size);
                # the previous "+ 1" leaked it onto one extra time step.
                tmp_time_weight[idx:idx + window_size] += weight
            count += count_window_passes(self.len_step, window_size)

        final_time_weight = tmp_time_weight / count
        return final_time_weight / (torch.sum(final_time_weight) + eps)