import math
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.optim as optim
import sys
sys.path.append(".")
sys.path.append("..")
from TIEM.TIS_calculator import Occulusion_base_Time_Weight
from torchray.utils import imsmooth, imsc
from torchray.attribution.common import resize_saliency
from tqdm import tqdm
from utils.funcs import *

# Perturbation kinds accepted by ``Perturbation(type=...)``:
#   "blur" - pyramid levels are produced with ``imsmooth`` at decreasing sigma.
#   "fade" - pyramid levels are the input multiplied by a factor going 0 -> 1.
BLUR_PERTURBATION = "blur"
FADE_PERTURBATION = "fade"

# Optimisation variants used by ``simple_log_reward`` and ``TIEM``:
#   "preserve" - perturb with the mask itself, reward is -log(p_target).
#   "delete"   - perturb with (1 - mask), reward is -log(1 - p_target).
#   "dual"     - both perturbations are evaluated and the two rewards are summed.
PRESERVE_VARIANT = "preserve"
DELETE_VARIANT = "delete"
DUAL_VARIANT = "dual"

def simple_log_reward(activation, target, variant):
    """Return the negative-log reward of the target-class score per sample.

    Args:
        activation: 2-D tensor indexed as ``activation[row, class]``. The
            values are fed straight into ``torch.log``, so they are expected
            to be probabilities in (0, 1) (TIEM passes a softmax output).
        target: 1-D long tensor with N class indices. It is tiled
            ``rows // N`` times so that it lines up with the rows of
            ``activation``.
        variant: one of ``DELETE_VARIANT``, ``PRESERVE_VARIANT`` or
            ``DUAL_VARIANT``.

    Returns:
        A 1-D tensor. For "delete" it holds ``-log(1 - p)`` and for
        "preserve" ``-log(p)``, one value per row. For "dual" the first N
        rows are treated as the preserve pass and the rows from N onwards as
        the delete pass; the two terms are added element-wise.

    Raises:
        AssertionError: if ``variant`` is not one of the three known values.
    """
    num_targets = target.shape[0]
    num_rows = activation.shape[0]
    repeats = num_rows // num_targets

    # Pick, for every row, the score of that row's target class.
    rows = torch.arange(num_rows, dtype=torch.long, device=activation.device)
    cols = target.repeat(repeats)
    prob = activation[rows, cols]

    if variant == DELETE_VARIANT:
        return -torch.log(1 - prob)
    if variant == PRESERVE_VARIANT:
        return -torch.log(prob)
    if variant == DUAL_VARIANT:
        return (-torch.log(1 - prob[num_targets:])) + (-torch.log(prob[:num_targets]))

    # Explicit raise instead of ``assert False`` so the check still fires
    # under ``python -O``; AssertionError is kept for backward compatibility.
    raise AssertionError(f"Unknown variant {variant!r}")

class MaskGenerator:
    """Turn a coarse parameter grid into a smooth, full-resolution mask.

    Each cell of the coarse grid contributes a radial bump (``self.kernel``)
    to the output pixels around it. The per-pixel contributions of the
    ``(2 * radius + 1) ** 2`` neighbouring cells are precomputed in
    ``self.weight`` and pooled in ``generate``.

    Args:
        shape: ``(height, width)`` of the mask to produce.
        step: spacing, in output pixels, between coarse grid cells. Must be
            integer-valued.
        sigma: radius scale of the bump placed at every grid cell.
        batch_size: stored on the instance; not used by this class itself.
        clamp: when true, ``generate`` clamps the pooled mask to ``[0, 1]``.
        pooling_method: ``'softmax'``, ``'sigmoid'`` or ``'sum'``; selects how
            the per-cell contributions are combined in ``generate``.

    Raises:
        AssertionError: if ``step`` is not integer-valued.
    """

    def __init__(self, shape, step, sigma, batch_size=1, clamp=True, pooling_method='softmax'):
        self.shape = shape
        self.step = step
        self.sigma = sigma
        self.coldness = 20
        self.batch_size = batch_size
        self.clamp = clamp
        self.pooling_method = pooling_method

        # Explicit raise (not ``assert``) so the check survives ``python -O``.
        if int(step) != step:
            raise AssertionError("step must be an integer value")

        # Flat (value 1) up to 0.5, then a Gaussian-like fall-off.
        self.kernel = lambda z: torch.exp(-2 * ((z - .5).clamp(min=0)**2))

        self.margin = self.sigma
        self.padding = 1 + math.ceil((self.margin + sigma) / step)
        self.radius = 1 + math.ceil(sigma / step)
        self.shape_in = [math.ceil(z / step) for z in self.shape]
        self.shape_mid = [
            z + 2 * self.padding - (2 * self.radius + 1) + 1
            for z in self.shape_in
        ]
        self.shape_up = [self.step * z for z in self.shape_mid]
        self.shape_out = [z - step + 1 for z in self.shape_up]

        step_inv = [
            torch.tensor(zm, dtype=torch.float32) /
            torch.tensor(zo, dtype=torch.float32)
            for zm, zo in zip(self.shape_mid, self.shape_up)
        ]

        kernel_size = 2 * self.radius + 1
        self.weight = torch.zeros((
            1,
            kernel_size**2,
            self.shape_out[0],
            self.shape_out[1]
        ))

        # Output pixel coordinates as broadcastable column / row vectors.
        # They do not depend on (ky, kx), so they are built once here instead
        # of calling torch.meshgrid (whose default indexing is deprecated)
        # on every loop iteration.
        uy = torch.arange(self.shape_out[0], dtype=torch.float32).unsqueeze(1)
        ux = torch.arange(self.shape_out[1], dtype=torch.float32).unsqueeze(0)
        cell_y = torch.floor(step_inv[0] * uy)
        cell_x = torch.floor(step_inv[1] * ux)

        for ky in range(kernel_size):
            iy = cell_y + ky - self.padding
            dy_sq = (uy - (self.margin + self.step * iy))**2
            for kx in range(kernel_size):
                ix = cell_x + kx - self.padding
                dx_sq = (ux - (self.margin + self.step * ix))**2

                # Distance from every output pixel to the centre of the
                # grid cell addressed by kernel offset (ky, kx).
                delta = torch.sqrt(dy_sq + dx_sq)

                k = ky * kernel_size + kx
                self.weight[0, k] = self.kernel(delta / sigma)

    def generate(self, mask_in):
        """Render the coarse parameters ``mask_in`` into a smooth mask.

        Args:
            mask_in: tensor of shape ``(N, 1, *self.shape_in)``; the reshape
                below requires the unfolded grid to match ``self.shape_mid``.

        Returns:
            ``(cropped, mask)`` where ``mask`` is the pooled mask of spatial
            size ``self.shape_out`` and ``cropped`` is the window of size
            ``self.shape`` taken from it at offset ``round(self.margin)``.

        Raises:
            AssertionError: if ``self.pooling_method`` is not recognised.
        """
        kernel_size = 2 * self.radius + 1

        # One channel per neighbouring cell, then upsample every coarse
        # position to ``step`` output pixels.
        mask = F.unfold(mask_in,
                        (kernel_size,) * 2,
                        padding=(self.padding,) * 2)
        mask = mask.reshape(
            mask_in.shape[0], -1, self.shape_mid[0], self.shape_mid[1])
        mask = F.interpolate(mask, size=self.shape_up, mode='nearest')
        # Negative padding trims ``step - 1`` pixels from the bottom / right.
        mask = F.pad(mask, (0, -self.step + 1, 0, -self.step + 1))
        mask = self.weight * mask

        hard = self.coldness == float('+Inf')
        if self.pooling_method == 'sigmoid':
            total = mask.sum(dim=1, keepdim=True)
            if hard:
                mask = (total - 5 > 0).float()
            else:
                mask = torch.sigmoid(self.coldness * total - 3)
        elif self.pooling_method == 'softmax':
            if hard:  # max normalization
                mask = mask.max(dim=1, keepdim=True)[0]
            else:   # smax normalization
                mask = (
                    mask * F.softmax(self.coldness * mask, dim=1)
                ).sum(dim=1, keepdim=True)
        elif self.pooling_method == 'sum':
            mask = mask.sum(dim=1, keepdim=True)
        else:
            raise AssertionError(
                f"Unknown pooling method {self.pooling_method}")

        m = round(self.margin)
        if self.clamp:
            mask = mask.clamp(min=0, max=1)
        cropped = mask[:, :, m:m + self.shape[0], m:m + self.shape[1]]
        return cropped, mask

    def to(self, dev):
        """Move the precomputed weights to ``dev`` and return ``self``."""
        self.weight = self.weight.to(dev)
        return self

class Perturbation:
    """Pyramid of progressively perturbed copies of an input tensor.

    The pyramid is stacked along dim 1. Level 0 is the most perturbed copy
    and the last level the least perturbed one: for "fade" the input is
    multiplied by factors going from 0 to 1, for "blur" it is smoothed with
    ``imsmooth`` using a sigma going from ``max_blur`` down to 0.

    Args:
        input: tensor to perturb; dim 0 is treated as the sample dimension.
        num_levels: number of pyramid levels (at least 2).
        max_blur: sigma used for the most blurred level (blur type only).
        type: ``BLUR_PERTURBATION`` or ``FADE_PERTURBATION``.

    Raises:
        AssertionError: on ``num_levels < 2``, ``max_blur <= 0`` or an
            unknown ``type``.
    """

    def __init__(self, input, num_levels=8, max_blur=20, type=BLUR_PERTURBATION):
        self.type = type
        self.num_levels = num_levels
        self.pyramid = []

        # Explicit raises instead of ``assert`` so the validation is not
        # stripped under ``python -O``; the exception type is unchanged.
        if not num_levels >= 2:
            raise AssertionError("num_levels must be at least 2")
        if not max_blur > 0:
            raise AssertionError("max_blur must be positive")
        if type not in (BLUR_PERTURBATION, FADE_PERTURBATION):
            raise AssertionError(f"Unknown perturbation type {type!r}")

        with torch.no_grad():
            levels = []
            for sigma in torch.linspace(0, 1, self.num_levels):
                if type == BLUR_PERTURBATION:
                    y = imsmooth(input, sigma=(1 - sigma) * max_blur)
                else:  # FADE_PERTURBATION
                    y = input * sigma
                levels.append(y)
            self.pyramid = torch.stack(levels, dim=1)

    def apply(self, mask):
        """Blend pyramid levels per pixel according to ``mask``.

        A mask value ``m`` selects the fractional level ``m * (num_levels - 1)``
        and the result is the linear interpolation of the two neighbouring
        levels. Values outside ``[0, 1]`` yield out-of-range gather indices.

        Args:
            mask: tensor whose dim 0 matches the pyramid's dim 0
                (assumed ``(n, 1, H, W)`` so that it can be expanded to the
                pyramid's trailing dims -- TODO confirm against callers).

        Returns:
            Tensor of shape ``(n, 1, *pyramid.shape[2:])``.
        """
        n = mask.shape[0]
        level = mask.reshape(n, 1, *mask.shape[1:]) * (self.num_levels - 1)
        lower = level.floor()
        frac = level - lower
        lower = lower.long()

        pyramid = self.pyramid
        lower = lower.expand(n, 1, *pyramid.shape[2:])
        # The upper neighbour is clamped so that m == 1 stays inside the pyramid.
        upper = torch.clamp(lower + 1, max=self.num_levels - 1)
        y0 = torch.gather(pyramid, 1, lower)
        y1 = torch.gather(pyramid, 1, upper)

        return (1 - frac) * y0 + frac * y1

    def to(self, dev):
        """Move the pyramid to ``dev`` and return ``self``."""
        # Tensor.to is not in-place: the result has to be stored back,
        # otherwise the pyramid silently stays on its original device.
        self.pyramid = self.pyramid.to(dev)
        return self

    def __str__(self):
        return (
            f"Perturbation:\n"
            f"- type: {self.type}\n"
            f"- num_levels: {self.num_levels}\n"
            f"- pyramid shape: {list(self.pyramid.shape)}"
        )




def TIEM(model, input, target,
                        areas=0.1, perturb_type=FADE_PERTURBATION,
                        max_iter=2000, num_levels=8, step=7, sigma=11, variant=PRESERVE_VARIANT,
                        print_iter=None, reward_func="simple_log", resize=False,
                        resize_mode='bilinear', smooth=0,task="C",sigma_max=5,regul_weight=300,learning_rate=5e-2,reward_weight=100,alpha=0.8):
    """Optimise per-frame perturbation masks for ``model`` on ``input``.

    The per-frame mask budget comes from the temporal importance scores of
    ``Occulusion_base_Time_Weight`` (rescaled by ``rescale_time_imp``). Masks
    are optimised with SGD to minimise ``reward + regul``, where ``reward``
    is the negative-log reward of the target class on the perturbed input
    and ``regul`` pulls the sorted mask values of every frame towards a 0/1
    step template sized by that frame's budget.

    Args:
        model: network called as ``model(x)``; all its parameters get
            ``requires_grad_(False)`` as a side effect.
        input: tensor with batch on dim 0 and frames on dim 2
            (presumably ``(B, C, T, H, W)`` -- TODO confirm).
        target: class indices, forwarded to the reward function.
        areas: forwarded to ``rescale_time_imp``.
        perturb_type: ``FADE_PERTURBATION`` or ``BLUR_PERTURBATION``.
        max_iter: number of SGD iterations (at least 1).
        num_levels: number of levels of the perturbation pyramid.
        step, sigma: forwarded to ``MaskGenerator``.
        variant: ``PRESERVE_VARIANT``, ``DELETE_VARIANT`` or ``DUAL_VARIANT``.
        print_iter: currently ignored (kept for interface compatibility).
        reward_func: only ``"simple_log"`` is supported.
        resize, resize_mode: forwarded to ``resize_saliency``.
        smooth: if > 0, each frame mask is smoothed with ``imsmooth`` using
            ``sigma = smooth * min(H, W)``.
        task: forwarded to ``Occulusion_base_Time_Weight``.
        sigma_max: currently ignored (kept for interface compatibility).
        regul_weight: weight of the regulariser term.
        learning_rate: SGD learning rate.
        reward_weight: weight of the reward term.
        alpha: forwarded as ``ratio`` to ``Get_time_weight_threshold``.

    Returns:
        ``(masks, TIS_pyramid, TIS)``: the per-frame masks stacked along
        dim 2, and the two temporal-importance outputs of the occlusion
        helper.

    Raises:
        AssertionError: if ``variant`` is unknown.
        Exception: if ``reward_func`` is not ``"simple_log"``.
        ValueError: if ``max_iter`` is smaller than 1.
    """
    # Fail fast: these used to surface only after the expensive set-up (or,
    # for max_iter == 0, as an unbound-variable error at the very end).
    if variant not in (DELETE_VARIANT, PRESERVE_VARIANT, DUAL_VARIANT):
        raise AssertionError(f"Unknown variant {variant!r}")
    if reward_func != "simple_log":
        raise Exception("Do not support other reward function now.")
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")

    occ = Occulusion_base_Time_Weight(model, input, target, task=task)
    TIS_pyramid = occ.time_weight_pyramid
    TIS, _ = occ.Get_time_weight_threshold(ratio=alpha)
    # NOTE(review): total_frame is hard-coded to 16 while the frame count
    # below is read from input.shape[2] -- confirm they are meant to agree.
    Time_weight = rescale_time_imp(TIS, areas, total_frame=16)

    momentum = 0.9
    device = input.device
    batch_size = input.shape[0]
    num_frame = input.shape[2]

    for p in model.parameters():
        p.requires_grad_(False)

    # NOTE(review): the results of these two calls are not used; the calls
    # are kept because their side effects are not visible from this file.
    ori_y = model(input)
    process_activations(ori_y, target, softmaxed=True)

    # Fold the frame dimension into the batch so every frame gets its own
    # 2-D perturbation pyramid and mask.
    pmt_inp = input.transpose(1, 2).contiguous()
    pmt_inp = pmt_inp.view(batch_size * num_frame, *pmt_inp.shape[2:])

    perturbation = Perturbation(pmt_inp, num_levels=num_levels,
                                type=perturb_type).to(device)

    shape = perturbation.pyramid.shape[3:]
    mask_generator = MaskGenerator(shape, step, sigma, pooling_method='softmax').to(device)
    h, w = mask_generator.shape_in
    pmasks = torch.ones(batch_size * num_frame, 1, h, w).to(device)

    # Step template per frame: the lowest (1 - Time_weight[i]) fraction of
    # the sorted mask values should be 0, the rest 1.
    max_area = np.prod(mask_generator.shape_out)
    aref = torch.ones(batch_size, num_frame, max_area).to(device)
    for i in range(num_frame):
        aref[:, i, :int(max_area * (1 - (Time_weight[i])))] = 0

    optimizer = optim.SGD([pmasks],
                          lr=learning_rate,
                          momentum=momentum,
                          dampening=momentum)

    pmasks.requires_grad_(True)
    for _ in tqdm(range(max_iter)):
        masks, padded_masks = mask_generator.generate(pmasks)

        if variant == DELETE_VARIANT:
            perturb_x = perturbation.apply(1 - masks)
        elif variant == PRESERVE_VARIANT:
            perturb_x = perturbation.apply(masks)
        else:  # DUAL_VARIANT (validated above)
            perturb_x = torch.cat((
                perturbation.apply(masks),
                perturbation.apply(1 - masks),
            ), dim=1)

        # Unfold frames again and move the variant axis in front of the
        # batch axis, then merge the two into one model batch.
        perturb_x = perturb_x.view(batch_size, num_frame, *perturb_x.shape[1:])
        perturb_x = perturb_x.permute(2, 0, 3, 1, 4, 5).contiguous()
        perturb_x = perturb_x.view(perturb_x.shape[0] * perturb_x.shape[1], *perturb_x.shape[2:])

        masks = masks.view(batch_size, num_frame, *masks.shape[1:]).transpose(1, 2)
        padded_masks = padded_masks.view(
            batch_size, num_frame, *padded_masks.shape[1:]).transpose(1, 2)

        y = model(perturb_x)
        y = F.softmax(y, dim=1)

        # NOTE(review): return value unused (it only fed the removed progress
        # printing); call kept since the helper is defined outside this file.
        process_activations(y, target, softmaxed=True)

        reward = simple_log_reward(y, target, variant=variant)
        reward = reward.view(batch_size, -1).mean(dim=1) * reward_weight

        mask_sorted = padded_masks.squeeze(1).reshape(batch_size, num_frame, -1).sort(dim=2)[0]
        regul = ((mask_sorted - aref)**2).mean(dim=2).mean(dim=1) * regul_weight

        energy = (reward + regul).sum()

        optimizer.zero_grad()
        energy.backward()
        optimizer.step()

        # Keep the mask parameters in their valid range after the update.
        pmasks.data = pmasks.data.clamp(0, 1)

    # As before, the returned masks are the ones rendered in the last
    # iteration (i.e. before that iteration's parameter update).
    masks = masks.detach()
    list_mask = []
    for frame_idx in range(num_frame):
        mask = masks[:, :, frame_idx, :, :]
        mask = resize_saliency(pmt_inp, mask, resize, mode=resize_mode)

        if smooth > 0:
            mask = imsmooth(mask, sigma=smooth * min(mask.shape[2:]), padding_mode='constant')
        list_mask.append(mask)
    masks = torch.stack(list_mask, dim=2)

    return masks, TIS_pyramid, TIS