import struct
import string
import seaborn
import numpy as np
import matplotlib.pyplot as plt

# Module-level side effect: apply seaborn's default plot styling on import.
seaborn.set()
# Byte distance from the PE header offset (read at 0x3C) to the start of the
# optional header, as used by plot_header_contribution_histogram.
COFF_LENGTH = 24
# Size in bytes of one entry of the section table.
SECT_ENTRY_LEN = 40


def _pick_color(max_v, value):
	color = str.format("{:02x}", 255 - int(value / max_v * 255))
	rgb = "#ff{}ff".format(color)
	return rgb


def plot_bins(itgs: np.ndarray, n_bins: int = 256, cut_index: int = None, force_plot: bool = True):
	"""Plot the summed positive and negative contributions as a binned bar chart.

	The values are read from ``itgs[1, :].tolist()[0]`` and grouped into
	consecutive segments of ``len(values) // n_bins`` entries. For every
	segment the non-negative and the negative values are summed separately
	and drawn as two bars. Values left over after the last complete segment
	are not plotted.

	Parameters
	----------
	itgs : numpy array
		The array containing the results of Integrated Gradients.
		NOTE(review): ``itgs[1, :].tolist()[0]`` has to yield a flat list of
		numbers, i.e. ``itgs[1, :]`` is expected to be two-dimensional —
		confirm against the callers.
	n_bins : int, optional, default 256
		How many bins for the histogram. When there are fewer values than
		bins, one bin per value is used instead.
	cut_index : int, optional, default None
		Use only the first ``cut_index`` values. ``None`` or 0 means use all
		of them; a value larger than the number of values is clamped.
	force_plot : bool, optional, default True
		Should show plot on screen? Default True

	Raises
	------
	ValueError
		If ``n_bins`` is not positive or ``cut_index`` is negative.
	"""
	if n_bins <= 0:
		raise ValueError("n_bins must be a positive integer, got {}".format(n_bins))
	if cut_index is not None and cut_index < 0:
		raise ValueError("cut_index must not be negative, got {}".format(cut_index))

	attr = itgs[1, :].tolist()[0]
	n_bytes = len(attr) // n_bins
	if n_bytes == 0:
		# Fewer values than bins: a zero-width segment would make every bar
		# empty (and break the cut_index division), so use one value per bin.
		n_bytes = 1
		n_bins = len(attr)
	if cut_index:
		cut_index = min(cut_index, len(attr))
		# Never exceed the requested number of bins, so that the x positions
		# and the bar heights always have the same length.
		n_bins = min(n_bins, cut_index // n_bytes)

	# One row per bin; everything after the last complete bin is dropped.
	values = np.asarray(attr[: n_bins * n_bytes], dtype=float).reshape(n_bins, n_bytes)
	y_p = np.where(values >= 0, values, 0).sum(axis=1)
	y_n = np.where(values < 0, values, 0).sum(axis=1)
	x = range(n_bins)

	# NOTE(review): positives are blue and negatives red here, while the other
	# plots in this file use red for positives — confirm which is intended.
	plt.bar(x, y_p, color="b")
	plt.bar(x, y_n, color="r")
	plt.xlabel("Segment of {} bytes".format(n_bytes))
	plt.ylabel("Contribution")
	plt.title("Sum of contributions, divided in {} bins".format(n_bins))
	if force_plot:
		plt.show()


def plot_range_histogram(itgs: np.ndarray, start_idx: int = 0, how_many: int = None, force_plot: bool = True):
	"""Plot chunk of bytes with their relevance, according to the integrated gradient technique

	The values are read from ``itgs[1, :].tolist()[0]``; the slice
	``[start_idx, start_idx + how_many)`` is drawn with one bar per value,
	non-negative values in red and negative values in blue.

	Parameters
	----------

	itgs : numpy array
		array containing the result of Integrated gradient
		NOTE(review): ``itgs[1, :].tolist()[0]`` has to yield a flat list of
		numbers, i.e. ``itgs[1, :]`` is expected to be two-dimensional —
		confirm against the callers.
	start_idx : int
		starting index for visualizing the results, default 0
	how_many : int
		how many bytes represent inside the plot (default : None, means use
		everything from ``start_idx`` to the end)
	force_plot : bool, optional, default True
		Should show plot on screen? Default True

	Raises
	------
	ValueError
		If ``start_idx`` lies beyond the available values, if ``how_many`` is
		negative, or if the requested range ends after the last value.
	"""
	attr = itgs[1, :].tolist()[0]
	total = len(attr)
	if start_idx > total:
		raise ValueError(
			"start_idx ({}) is beyond the {} available values".format(start_idx, total)
		)
	stop_index = total if how_many is None else how_many + start_idx
	if stop_index < start_idx:
		raise ValueError("how_many must not be negative, got {}".format(how_many))
	if stop_index > total:
		raise ValueError(
			"requested range ends at {} but only {} values are available".format(stop_index, total)
		)
	attr = attr[start_idx:stop_index]
	positives = [i if i >= 0 else 0 for i in attr]
	negatives = [i if i < 0 else 0 for i in attr]
	# Derive the x positions from the slice itself: how_many is None by
	# default, so it cannot be used to size the axis.
	x = range(len(attr))
	plt.bar(x, positives, color="r")
	plt.bar(x, negatives, color="b")
	plt.title("Offset histogram, from {} to {} ".format(start_idx, stop_index))
	if force_plot:
		plt.show()


def plot_header_contribution_histogram(
		bytestring_program: bytearray,
		itgs: np.ndarray,
		percentage: bool = True,
		force_plot: bool = True,
):
	"""Plot integrated gradient results, divided by sections

	The PE layout is parsed from ``bytestring_program`` and the values in row
	0 of ``itgs`` are summed over the DOS header, the COFF header, the
	optional header, the section table and the raw data of every section.
	One bar is drawn per region, non-negative sums in red and negative sums
	in blue.

	Parameters
	----------
	bytestring_program: bytearray
		the program as bytearray
	itgs : numpy array
		array containing the result of Integrated gradient
		NOTE(review): slices of ``itgs`` must provide ``.tondarray()``, which
		plain numpy arrays do not have — confirm the actual type with callers.
	percentage :bool
		divide the sums by their L2 norm instead of plotting absolute values
		(default: True); skipped when the norm is zero
	force_plot : bool
		Should show the results? (default:True)

	Raises
	------
	struct.error
		If ``bytestring_program`` is too short to contain the header fields.
	"""
	pe_position = struct.unpack_from("<I", bytestring_program, 0x3C)[0]
	# Both fields are 16-bit little-endian values.
	n_sects = struct.unpack_from("<H", bytestring_program, pe_position + 6)[0]
	opt_h_len = struct.unpack_from("<H", bytestring_program, pe_position + 20)[0]
	opt_offset = pe_position + COFF_LENGTH
	sect_offset = opt_offset + opt_h_len
	sect_end = sect_offset + SECT_ENTRY_LEN * n_sects

	def _contribution(begin, end):
		# Sum of the attributions over the byte range [begin, end).
		return np.sum(itgs[0, begin:end].tondarray())

	names = ["DOS Header", "COFF Header", "Optional Header", "Section Headers"]
	# Every region is summed, matching the plot title; the DOS header used to
	# be averaged, which made it incomparable with the other bars.
	to_plot = [
		_contribution(0, pe_position),
		_contribution(pe_position, opt_offset),
		_contribution(opt_offset, sect_offset),
		_contribution(sect_offset, sect_end),
	]
	sect_name_length = 8
	for i in range(n_sects):
		entry = sect_offset + i * SECT_ENTRY_LEN
		# Raw-data size and raw-data offset are two consecutive 32-bit fields
		# located 8 bytes after the section name.
		size, offset = struct.unpack_from(
			"<II", bytestring_program, entry + sect_name_length + 8
		)
		raw_name = bytes(bytestring_program[entry: entry + sect_name_length])
		# Section names are arbitrary bytes; do not crash on invalid UTF-8.
		name = raw_name.decode("utf-8", errors="replace").rstrip("\x00")
		to_plot.append(_contribution(offset, offset + size))
		names.append(name)
	to_plot = np.asarray(to_plot, dtype=float)
	if percentage:
		norm = np.linalg.norm(to_plot, ord=2)
		# An all-zero attribution would otherwise turn every bar into NaN.
		if norm > 0:
			to_plot = to_plot / norm
	plt.figure()
	x = range(len(names))
	positives = [i if i >= 0 else 0 for i in to_plot]
	negatives = [i if i < 0 else 0 for i in to_plot]
	plt.bar(x, positives, width=0.2, color="r")
	plt.bar(x, negatives, width=0.2, color="b")
	plt.yticks(fontsize=22)
	plt.xticks(x, names, fontsize=22, rotation=45, ha="right")
	# Zero baseline spanning all bars, however many sections the file has.
	plt.plot([-1, len(names)], [0, 0], "k")
	ax = plt.gca()
	ax.set_facecolor((1.0, 1.0, 1.0))
	plt.title(
		"Sum of each contribution,divided into headers and sections\n", fontsize=25
	)
	if force_plot:
		plt.show()


def plot_code_segment(
		pe_file: list,
		start: int,
		stop: int,
		itgs: np.ndarray,
		title: str,
		show_positives: bool = True,
		show_negatives: bool = False,
		force_plot: bool = True,
		width: int = 16,
		percentage: bool = True,
):
	"""Plot contribution of chunks of bytes.

	The bytes ``pe_file[start:stop]`` are laid out in rows of ``width`` cells.
	Every cell is annotated with the byte (the character for ASCII letters,
	the hex value otherwise) and coloured by the matching value of
	``itgs[1, start:stop]``. The last row is padded with empty cells of
	value 0.

	Parameters
	----------
	pe_file : list
		list of bytes
	start : int
		starting index for segment to plot
	stop : int
		stop index for segment to plot
	itgs : numpy array
		array containing result of integrated gradients
		NOTE(review): ``itgs[1, start:stop].tolist()[0]`` has to yield a flat
		list with one value per byte of the segment — confirm with callers.
	title : str
		plot title
	show_positives : bool
		show positives contributes (default:True)
	show_negatives : bool
		show negative contributes (default:False)
	force_plot : bool
		show plot (default:True)
	width : int
		how many byte per row of the heatmap (default:16)
	percentage : bool
		divide the values by their L2 norm instead of plotting absolute
		values (default: True); skipped when the norm is zero

	Raises
	------
	ValueError
		If ``width`` is smaller than 1.
	"""
	if width < 1:
		raise ValueError("width must be a positive integer, got {}".format(width))

	grad_section = itgs[1, start:stop].tolist()[0]
	text = [
		chr(i) if chr(i) in string.ascii_letters else hex(i)
		for i in pe_file[start:stop]
	]
	# Pad the last row so that both arrays reshape to (rows, width).
	padding = -len(text) % width
	text.extend([""] * padding)
	grad_section.extend([0] * padding)
	rows = len(text) // width

	text = np.array(text)
	grad_section = np.array(grad_section, dtype=float)
	if percentage:
		norm = np.linalg.norm(grad_section)
		# An all-zero segment would otherwise become NaN after the division.
		if norm > 0:
			grad_section = grad_section / norm

	if not show_positives:
		grad_section[grad_section > 0] = 0
	if not show_negatives:
		grad_section[grad_section < 0] = 0

	grad_section = grad_section.reshape((rows, width))
	text = text.reshape((rows, width))
	plt.figure()
	hmap = seaborn.heatmap(
		grad_section,
		annot=text,
		fmt="",
		cmap="seismic",
		center=0,
		annot_kws={"size": 18},
	)
	hmap.collections[0].colorbar.ax.tick_params(labelsize=20)
	hmap.set_title(title, fontsize=25)
	if force_plot:
		plt.show()
