# -*- coding: UTF-8 -*-


import numpy as np
import torch
import torch.nn.functional as F
import math


class LimeNet():
    """Block-occlusion explainer that fits a linear surrogate by least squares.

    The image is split into square blocks of side ``block_size``.  A set of
    binary occlusion patterns (``flatten_mask``, one row per pattern, one
    column per block; 0 means "block zeroed out") is expanded into pixel-level
    ``masks``.  ``weight_matrix`` is the ridge-regularised pseudo-inverse
    ``(F^T F + lam * I)^-1 F^T`` of the pattern matrix ``F`` and is
    precomputed once so that ``explain`` is a single matrix product.

    Args:
        num_augu: Number of random occlusion patterns.  ``None`` selects the
            deterministic leave-one-block-out scheme with ``wm_length``
            patterns.
        wm_length: Number of blocks.  ``explain_image`` reshapes the result
            to ``(sqrt(wm_length), sqrt(wm_length), 1)``, so it must be a
            perfect square for that method to work.
        image_size: Side length of the (square) image in pixels.
        num_channels: Number of image channels.
        device: Torch device the masks and weight matrix are moved to.
        lam: Ridge coefficient added to the diagonal of ``F^T F``.  With the
            default ``0.0`` the matrix must be invertible on its own.
    """

    def __init__(self, num_augu, wm_length, image_size, num_channels, device, lam=0.0):
        self.num_augu = wm_length if num_augu is None else num_augu
        self.wm_length = wm_length
        self.image_size = image_size
        self.num_channels = num_channels
        self.block_size = int(image_size / math.sqrt(wm_length))
        self.lam = lam
        self.device = device

        blocks_per_side = int(image_size / self.block_size)

        if num_augu is None:
            # Leave-one-out: pattern i hides block i only.  Block i sits at
            # (row, col) = divmod(i, blocks_per_side), i.e. row-major order.
            patterns = np.ones((self.wm_length, self.wm_length)) - np.eye(self.wm_length)
            masks = [
                self._occlusion_mask([divmod(i, blocks_per_side)])
                for i in range(self.wm_length)
            ]
        else:
            # Random patterns: each block is hidden with probability 0.5.
            # At least one pattern is always drawn, matching the historical
            # behaviour of this constructor.
            patterns = np.stack([
                np.random.choice([0, 1], self.wm_length, p=[0.5, 0.5])
                for _ in range(max(self.num_augu, 1))
            ])
            # NOTE(review): here block j sits at (row, col) =
            # (j % blocks_per_side, j // blocks_per_side), i.e. column-major
            # order, whereas the leave-one-out branch and the reshape in
            # explain_image are row-major.  The layout is kept unchanged
            # because callers may rely on it -- confirm which is intended.
            masks = [
                self._occlusion_mask([
                    (j % blocks_per_side, j // blocks_per_side)
                    for j in range(self.wm_length)
                    if patterns[i][j] == 0
                ])
                for i in range(self.num_augu)
            ]

        self.masks = torch.from_numpy(np.array(masks)).to(device)
        self.flatten_mask = torch.from_numpy(np.array(patterns)).to(device)

        # Precompute (F^T F + lam * I)^-1 F^T.
        design = self.flatten_mask.type(torch.float32)
        gram = torch.mm(design.T, design)
        gram = gram + torch.eye(self.wm_length).to(device) * self.lam
        self.weight_matrix = torch.mm(gram.inverse(), design.T).to(self.device)

    def _occlusion_mask(self, blocks):
        """Return a (C, H, W) array of ones with the given (row, col) blocks set to 0."""
        mask = np.ones((self.num_channels, self.image_size, self.image_size))
        size = self.block_size
        for row, col in blocks:
            mask[:, row * size: (row + 1) * size, col * size: (col + 1) * size] = 0
        return mask

    def explain(self, model, image):
        """Return one least-squares weight per block for ``image``.

        Args:
            model: Callable mapping a batch of masked images to class logits
                (softmax is taken over dimension 1).
            image: Tensor broadcastable against ``self.masks``
                (presumably ``(C, H, W)`` -- confirm against callers).

        Returns:
            Tensor of shape ``(wm_length, 1)``.
        """
        masked_images = (self.masks * image).float().to(self.device)
        logits = model(masked_images)

        # Use log_softmax instead of log(softmax(.)): a probability that
        # underflows to exactly 0 would otherwise give 0 * -inf = NaN and
        # poison every weight.
        log_probs = torch.log_softmax(logits, dim=1)
        probs = log_probs.exp()
        entropy = -(probs * log_probs).sum(dim=1, keepdim=True)

        # Normalised confidence: 1 for a one-hot prediction, 0 for uniform.
        max_entropy = math.log(logits.shape[1])
        predictions = (max_entropy - entropy) / max_entropy
        return torch.mm(self.weight_matrix, predictions)

    def explain_image(self, model, image):
        """Return the block weights as a binary (0 / 255) map.

        The result is a NumPy array of shape
        ``(sqrt(wm_length), sqrt(wm_length), 1)``; positive weights become
        255 and all others 0.
        """
        side = int(math.sqrt(self.wm_length))
        weight = self.explain(model, image)
        weight = weight.reshape((side, side, 1)).cpu().detach().numpy()
        weight[weight > 0] = 255
        weight[weight <= 0] = 0
        return weight
        

class ModelDiffNet():
    """Scores images by how closely the model's output matches a reference.

    For each pair ``(images[k], ori_images[k])`` the cosine similarity of the
    two softmax outputs is computed and a fixed threshold is subtracted, so a
    positive score means "more similar than the threshold".

    Args:
        wm_length: Number of scores reshaped by ``explain_image``; it must be
            a perfect square and equal to the batch size passed to that
            method for the reshape to succeed.
        device: Torch device the model and inputs are moved to.
        threshold: Cosine-similarity threshold subtracted from every score.
            ``None`` (the default) keeps the historical value ``np.cos(75)``.
            NOTE(review): ``np.cos`` works in radians, so that default is
            about 0.9218, not cos(75 degrees) which is about 0.2588.  Pass
            ``math.cos(math.radians(75))`` explicitly if degrees were meant.
    """

    def __init__(self, wm_length, device, threshold=None):
        self.wm_length = wm_length
        self.device = device
        self.threshold = float(np.cos(75)) if threshold is None else threshold

    def explain(self, model, images, ori_images):
        """Return ``cosine_similarity(softmax(model(images)), softmax(model(ori_images))) - threshold``.

        Puts ``model`` in eval mode and moves it to ``self.device`` as a side
        effect.  The result has one entry per batch element.
        """
        model.eval()
        model = model.to(self.device)
        candidate = images.to(self.device)
        reference = ori_images.to(self.device)

        # Only the reference pass is excluded from autograd; the candidate
        # pass keeps its graph (presumably so callers can backpropagate
        # through the score -- confirm).
        candidate_probs = F.softmax(model(candidate), dim=1)
        with torch.no_grad():
            reference_probs = F.softmax(model(reference), dim=1)

        return F.cosine_similarity(candidate_probs, reference_probs, dim=1) - self.threshold

    def explain_image(self, model, images, ori_images):
        """Return the scores as a binary (0 / 255) map.

        The result is a NumPy array of shape
        ``(sqrt(wm_length), sqrt(wm_length), 1)``; positive scores become
        255 and all others 0.
        """
        side = int(math.sqrt(self.wm_length))
        weight = self.explain(model, images, ori_images)
        weight = weight.reshape((side, side, 1)).cpu().detach().numpy()
        weight[weight > 0] = 255
        weight[weight <= 0] = 0
        return weight
        
def js_div(p_output, q_output, get_softmax=True, reduction='mean'):
    """Jensen-Shannon divergence between two batches of distributions.

    Computes ``0.5 * KL(p || m) + 0.5 * KL(q || m)`` with ``m = (p + q) / 2``.

    Args:
        p_output: Logits (``get_softmax=True``) or probabilities
            (``get_softmax=False``), classes along dimension 1.
        q_output: Same layout as ``p_output``.
        get_softmax: Apply softmax over dimension 1 to both inputs first.
        reduction: Forwarded unchanged to ``F.kl_div``.

    Returns:
        The divergence, reduced as requested by ``reduction``.
    """
    if get_softmax:
        p_output = F.softmax(p_output, dim=1)
        q_output = F.softmax(q_output, dim=1)
    mean_output = (p_output + q_output) / 2
    # F.kl_div(input, target) evaluates KL(target || exp(input)), so the
    # mixture goes in as the log-space *input* and p / q as the targets.
    # The clamp only matters where p == q == 0; it keeps the log finite so
    # that the 0-weighted term is 0 rather than NaN.
    log_mean = mean_output.clamp_min(torch.finfo(mean_output.dtype).tiny).log()
    return (F.kl_div(log_mean, p_output, reduction=reduction)
            + F.kl_div(log_mean, q_output, reduction=reduction)) / 2


def entropy(preds):
    """Return the Shannon entropy (in nats) of each row of ``preds``.

    ``preds`` holds probabilities with classes along dimension 1; the result
    has that dimension summed out.  Zero probabilities contribute 0 (the
    usual ``0 * log 0 = 0`` convention) instead of producing NaN.
    """
    # xlogy(p, p) is p * log(p) with the value 0 where p == 0.
    return -torch.xlogy(preds, preds).sum(dim=1)

def normlized_entropy(preds):
    """Return ``1 - H(p) / log(K)`` for each row of ``preds``.

    ``preds`` holds probabilities with the ``K`` classes along dimension 1.
    The result has shape ``(N, 1)``: 1 for a one-hot row, 0 for a uniform
    row.  Zero probabilities contribute 0 to the entropy instead of NaN.
    """
    max_entropy = math.log(preds.shape[1])
    # xlogy(p, p) is p * log(p) with the value 0 where p == 0.
    row_entropy = -torch.xlogy(preds, preds).sum(dim=1).unsqueeze(-1)
    return (max_entropy - row_entropy) / max_entropy


def dec2bin(num, length):
    """Return the binary digits of ``num``, most significant bit first.

    The list is left-padded with zeros to at least ``length`` entries; it is
    longer when ``num`` needs more than ``length`` bits.

    Raises:
        ValueError: If ``num`` is negative (repeated floor division by 2
            never reaches 0 for negative numbers).
    """
    if num < 0:
        raise ValueError("dec2bin expects a non-negative number")
    bits = []
    while num != 0:
        num, rem = divmod(num, 2)
        bits.append(0 if int(rem) == 0 else 1)
    # Digits were collected least-significant first: pad on the high end,
    # then flip so the zero padding ends up in front of the MSB.
    bits.extend([0] * (length - len(bits)))
    bits.reverse()
    return bits

