import numpy as np
import networkx as nx
import torch


import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd


# Calculate the shapley value for a range of nodes to one output token
# Encoder only models
def calculate_rollout_encoder_only(
    attention_dict,
    input_token_list,
    layers,
    head,
    output_token=-1,
    show=False,
    plot=False,
):
    """Attention rollout for an encoder-only model and a single attention head.

    The layer matrices ``attention_dict["enc_attn"][head][0 .. layers-1]`` are
    multiplied left to right. Layer 0 is always used, even when ``layers`` is
    0 or 1. Row 0 of the product is the score vector.

    Args:
        attention_dict: mapping whose ``"enc_attn"`` entry is indexed as
            ``[head][layer]``; each layer entry must support
            ``.detach().numpy()`` (e.g. a ``torch.Tensor``).
        input_token_list: tokens whose positions are read from row 0.
        layers: number of layers to roll out.
        head: index of the attention head.
        output_token, show, plot: unused; kept for interface compatibility.

    Returns:
        dict mapping ``(token, position)`` to the row-0 rollout score of
        that position.
    """
    print("Head " + str(head))
    per_layer = attention_dict["enc_attn"][head]
    rolled = np.array(per_layer[0].detach().numpy())
    for layer_idx in range(1, layers):
        next_layer = np.array(per_layer[layer_idx].detach().numpy())
        rolled = np.matmul(rolled, next_layer)
    print(rolled)
    first_row = rolled[0]
    return {
        (token, position): first_row[position]
        for position, token in enumerate(input_token_list)
    }


# Calculate the shapley value for a range of nodes to one output token
# Decoder only models
def calculate_rollout_decoder_only(
    attention_dict,
    input_token_list,
    output_token_list,
    layers,
    head,
    output_token=-1,
    show=False,
    plot=False,
):
    """Attention rollout for a decoder-only model and a single attention head.

    The layer matrices ``attention_dict["dec_attn"][head][0 .. layers-1]`` are
    multiplied left to right (layer 0 is always used) and row 0 of the product
    is taken as the score vector. Positions ``0 .. n_in + output_token - 1``
    (``n_in = len(input_token_list)``) are then scaled by
    ``1 / (2 + n_in + output_token - position)``.

    Args:
        attention_dict: mapping whose ``"dec_attn"`` entry is indexed as
            ``[head][layer]``; each layer entry must support
            ``.detach().numpy()`` (e.g. a ``torch.Tensor``).
        input_token_list: prompt tokens (any sequence).
        output_token_list: generated tokens (any sequence).
        layers: number of layers to roll out.
        head: index of the attention head.
        output_token: offset that bounds the normalization. ``-1`` is used
            literally in the arithmetic above; it is not resolved to the last
            output token.
        show, plot: unused; kept for interface compatibility.

    Returns:
        dict mapping ``(token, position)`` to that position's score for every
        input token followed by every output token.
    """
    print("Head " + str(head))
    dec_dict = attention_dict["dec_attn"][head]
    # np.array(...) copies, so the in-place scaling below never touches the
    # caller's tensors.
    matrix = np.array(dec_dict[0].detach().numpy())
    print(matrix)
    for x in range(1, layers):
        layer_matrix = np.array(dec_dict[x].detach().numpy())
        print(layer_matrix)
        matrix = np.matmul(matrix, layer_matrix)
    print(matrix)

    num_input_tokens = len(input_token_list)
    scores = matrix[0]  # view into `matrix`; scaled in place
    # NOTE(review): this normalization bound/denominator differs from the
    # encoder-decoder variant (1 / (1 + output_token - k)); kept unchanged —
    # confirm the intended formula.
    for x in range(num_input_tokens + output_token):
        scores[x] = scores[x] * (1 / (2 + num_input_tokens + output_token - x))

    all_tokens = list(input_token_list) + list(output_token_list)
    return {(token, i): scores[i] for i, token in enumerate(all_tokens)}


# Calculate the shapley value for a range of nodes to one output token
# Encoder-Decoder models
def _sum_path_products(graph, source, target):
    """Sum the product of edge ``capacity`` values over every simple path.

    Returns 0.0 when *source* or *target* is not in *graph* or no path
    connects them.
    """
    if source not in graph or target not in graph:
        return 0.0
    total = 0.0
    for path in nx.all_simple_paths(graph, source, target):
        product = 1.0
        for u, v in zip(path, path[1:]):
            product *= graph[u][v]["capacity"]
        total += product
    return total


def calculate_rollout(
    attention_dict,
    layers_enc,
    layers_dec,
    input_token_list,
    output_token_list,
    head,
    output_token=-1,
    show=False,
    plot=False,
):
    """Attention rollout for an encoder-decoder model and a single head.

    Builds a layered ``networkx.DiGraph`` with ``(layer, position)`` nodes.

    - Encoder positions ``0 .. n_in-1`` live on layers ``0 .. layers_enc``.
    - Decoder positions ``n_in .. n_in+n_out-1`` live on layers
      ``layers_enc+1 .. layers_enc+layers_dec+1``.
    - For self-attention, entry ``[layer][z][y]`` weights the edge from
      position ``y`` to position ``z`` in the next layer. Residual self-edges
      get ``+1``, and decoder edges only go to ``z >= y``.
    - Cross-attention ``[layer][out][in]`` links the last encoder layer to
      each decoder layer's output.
    - The terminal ``t`` is fed by the last decoder layer. It uses every
      output position when ``output_token == -1``, otherwise only
      ``output_token``. When ``layers_dec == 0`` it is fed by the last
      encoder layer instead.

    Each token's score is the sum, over all simple paths from its first-layer
    node to ``t``, of the product of edge capacities. Note the number of
    paths grows exponentially with the number of layers. Afterwards:

    - encoder scores are scaled by ``1 / (n_in + 1)``;
    - decoder score ``k`` (for ``k <= output_token``) is scaled by
      ``1 / (1 + output_token - k)``.

    Args:
        attention_dict: mapping with ``"enc_attn"`` and, when
            ``layers_dec > 0``, ``"dec_attn"`` / ``"enc_dec_attn"``, each
            indexed as ``[head][layer][row][col]``.
        layers_enc, layers_dec: number of encoder / decoder layers.
        input_token_list, output_token_list: token sequences.
        head: index of the attention head.
        output_token: output position to explain, or -1 for all outputs.
        show: draw the flow graph and a heatmap of the scores.
        plot: unused; kept for interface compatibility.

    Returns:
        dict mapping ``(token, position)`` to its score for every input token
        followed by every output token. Output tokens score 0.0 when
        ``layers_dec == 0``.
    """
    g = nx.DiGraph()
    num_input_tokens = len(input_token_list)
    num_output_tokens = len(output_token_list)
    first_dec_layer = layers_enc + 1
    last_dec_layer = layers_enc + layers_dec + 1
    # s sits left of the input column; t sits right of the last decoder
    # layer, level with the output positions (coordinates matter for drawing).
    s = (-1, int(num_input_tokens / 2))
    t = (layers_enc + layers_dec + 2, int(num_input_tokens + num_output_tokens / 2))
    g.add_node(s)
    g.add_node(t)

    # Encoder self-attention.
    enc_dict = attention_dict["enc_attn"][head]
    for x in range(layers_enc):
        for y in range(num_input_tokens):
            for z in range(num_input_tokens):
                capacity = float(enc_dict[x][z][y])
                if z == y:
                    capacity += 1  # residual connection
                g.add_edge((x, y), (x + 1, z), capacity=capacity)

    # Connect the last layer to the terminal node.
    if layers_dec == 0:
        for y in range(num_input_tokens):
            g.add_edge((layers_enc, y), t, capacity=1)
    elif output_token == -1:
        for y in range(num_output_tokens):
            g.add_edge((last_dec_layer, y + num_input_tokens), t, capacity=1)
    else:
        g.add_edge((last_dec_layer, output_token + num_input_tokens), t, capacity=1)

    if layers_dec > 0:
        # Decoder self-attention (causal: y only feeds z >= y).
        dec_dict = attention_dict["dec_attn"][head]
        for x in range(layers_dec):
            for y in range(num_output_tokens):
                for z in range(y, num_output_tokens):
                    capacity = float(dec_dict[x][z][y])
                    if z == y:
                        capacity += 1  # residual connection
                    g.add_edge(
                        (first_dec_layer + x, y + num_input_tokens),
                        (first_dec_layer + x + 1, z + num_input_tokens),
                        capacity=capacity,
                    )
        # Cross-attention from the last encoder layer into every decoder layer.
        cross_dict = attention_dict["enc_dec_attn"][head]
        for x in range(layers_dec):
            for y in range(num_output_tokens):
                for z in range(num_input_tokens):
                    g.add_edge(
                        (layers_enc, z),
                        (first_dec_layer + x + 1, y + num_input_tokens),
                        capacity=float(cross_dict[x][y][z]),
                    )

    if show:
        # Must be set before drawing to size the figure that is drawn.
        plt.rcParams["figure.figsize"] = [
            (layers_enc + layers_dec + 5) * 2,
            num_input_tokens + num_output_tokens + 5,
        ]
        pos = {(x, y): (x + 3, -y - 2) for x, y in g.nodes()}
        nx.draw(g, pos, with_labels=True, node_size=600)
        nx.draw_networkx_edge_labels(
            g, pos, edge_labels=nx.get_edge_attributes(g, "capacity"), font_color="red"
        )
        plt.show()

    scores = []
    for y in range(num_input_tokens):
        # Unit-capacity source edge so the path products are unchanged.
        g.add_edge(s, (0, y), capacity=1)
        scores.append(_sum_path_products(g, s, t))
        g.remove_edge(s, (0, y))
    for y in range(num_output_tokens):
        scores.append(
            _sum_path_products(g, (first_dec_layer, y + num_input_tokens), t)
        )

    # Normalize encoder connections.
    for y in range(num_input_tokens):
        scores[y] *= 1 / (num_input_tokens + 1)
    # Normalize auto-regression (no-op when output_token == -1).
    for k in range(output_token + 1):
        scores[k + num_input_tokens] *= 1 / (1 + output_token - k)

    all_tokens = list(input_token_list) + list(output_token_list)
    result_dict = {(token, i): scores[i] for i, token in enumerate(all_tokens)}
    if show:
        data = pd.DataFrame({"attention rollout": scores}, all_tokens)
        sns.heatmap(data, annot=True, cmap="YlGnBu")
        plt.show()
    return result_dict


# Compute one attention tensor per attention type from multiple heads by averaging the values of all heads (sum over heads divided by the number of heads).
def sum_up_heads(attention_dict):
    """Average the attention values of all heads for every attention type.

    Each value of *attention_dict* is indexed as ``[head][layer][row][col]``.
    The layer/row/column extents are taken from head 0. Every entry is
    converted with ``float``, summed over heads, and divided by the number
    of heads.

    Returns:
        dict with the same keys, each mapped to a float64 ``torch.Tensor`` of
        shape ``(layers, rows, cols)``. The head axis is dropped.
    """
    averaged = {}
    for key in attention_dict:
        heads = attention_dict[key]
        head_count = len(heads)
        template = heads[0]
        layers = []
        for layer_idx, template_layer in enumerate(template):
            rows = []
            for row_idx, template_row in enumerate(template_layer):
                row = []
                for col_idx in range(len(template_row)):
                    total = 0
                    for single_head in heads:
                        total = total + float(single_head[layer_idx][row_idx][col_idx])
                    row.append(total / head_count)
                rows.append(row)
            layers.append(rows)
        averaged[key] = torch.from_numpy(np.array(layers))
    return averaged


# Calculates attention rollout for every head separately and for the head-averaged attention (see sum_up_heads), for every input/output token pair.
def _rollout_token_scores(scores, tokens):
    """Return ``scores[(token, position)]`` for every token of *tokens*, in order."""
    return [scores[(token, position)] for position, token in enumerate(tokens)]


def _plot_rollout_column(values, labels, column, height):
    """Show *values* as a one-column annotated heatmap with row labels *labels*."""
    data = pd.DataFrame({column: values}, labels)
    sns.set(font="Times New Roman", font_scale=1.3)
    plt.subplots(figsize=(3, height / 2), dpi=600)
    sns.heatmap(data, annot=True, linewidth=0.2, cmap="crest", cbar=False, fmt=".9f")
    plt.show()


def _plot_rollout_scores(
    values, tokens, num_input, last_output, decoder_only, encoder_only, single_label
):
    """Plot one rollout result using slices, so *values*/*tokens* are never mutated."""
    decoder_end = num_input + last_output + 1
    if decoder_only:
        shown = tokens[:decoder_end]
        _plot_rollout_column(values[:decoder_end], shown, "attention flow", len(shown))
    elif encoder_only:
        shown = tokens[:num_input]
        _plot_rollout_column(values[:num_input], shown, single_label, len(shown))
    else:
        _plot_rollout_column(
            values[:num_input], tokens[:num_input], "attention rollout encoder", num_input
        )
        _plot_rollout_column(
            values[num_input:decoder_end],
            tokens[num_input:decoder_end],
            "attention rollout decoder",
            last_output + 1,
        )


def _rollout_single_head(
    attention_dict,
    input_tokens,
    pred_tokens,
    layers_enc,
    layers_dec,
    head,
    output_token,
    decoder_only,
    encoder_only,
):
    """Dispatch one head to the rollout function matching the architecture."""
    if decoder_only:
        return calculate_rollout_decoder_only(
            attention_dict,
            input_tokens,
            pred_tokens,
            layers_dec,
            head,
            output_token=output_token,
            show=False,
            plot=False,
        )
    if encoder_only:
        return calculate_rollout_encoder_only(
            attention_dict,
            input_tokens,
            layers_enc,
            head,
            output_token=output_token,
            show=False,
            plot=False,
        )
    return calculate_rollout(
        attention_dict,
        layers_enc,
        layers_dec,
        input_tokens,
        pred_tokens,
        head,
        output_token=output_token,
        show=False,
        plot=False,
    )


def calculate_rollout_sum(
    attention_dict_list,
    input_tokens,
    pred_tokens,
    layers_enc=0,
    layers_dec=0,
    output_token=-1,
    plot=True,
    decoder_only=False,
    encoder_only=False,
):
    """Run attention rollout for every head and for the head-averaged attention.

    Each head of ``attention_dict_list`` (indexed ``[key][head][layer]...``)
    is rolled out individually and optionally plotted. The heads are then
    averaged with :func:`sum_up_heads` and rolled out once more; that result
    is returned.

    ``output_token == -1`` handling:

    - decoder-only models: resolved to ``len(pred_tokens)`` for every call;
    - encoder-decoder models: passed through, since ``calculate_rollout``
      treats it as "all output tokens";
    - plots: shown up to the last predicted token.

    Args:
        attention_dict_list: attention values keyed by ``"enc_attn"``,
            ``"dec_attn"`` and/or ``"enc_dec_attn"``.
        input_tokens, pred_tokens: input and predicted token sequences.
        layers_enc, layers_dec: number of encoder / decoder layers.
        output_token: output position to explain, or -1.
        plot: draw heatmaps. This never changes the returned list.
        decoder_only, encoder_only: select the model architecture
            (``decoder_only`` wins if both are set).

    Returns:
        list of head-averaged scores in token order. For encoder-only models
        it covers ``input_tokens``; otherwise ``input_tokens + pred_tokens``.
    """
    num_input = len(input_tokens)
    if encoder_only:
        num_heads = len(attention_dict_list["enc_attn"])
        all_tokens = list(input_tokens)
    else:
        num_heads = len(attention_dict_list["dec_attn"])
        all_tokens = list(input_tokens) + list(pred_tokens)

    # Last output position to display in plots.
    last_output = len(pred_tokens) if output_token == -1 else output_token
    # Value handed to the rollout functions (see docstring).
    calc_output_token = last_output if decoder_only else output_token

    for head in range(num_heads):
        head_scores = _rollout_single_head(
            attention_dict_list,
            input_tokens,
            pred_tokens,
            layers_enc,
            layers_dec,
            head,
            calc_output_token,
            decoder_only,
            encoder_only,
        )
        if plot:
            if not (decoder_only or encoder_only):
                print("Head " + str(head))
            _plot_rollout_scores(
                _rollout_token_scores(head_scores, all_tokens),
                all_tokens,
                num_input,
                last_output,
                decoder_only,
                encoder_only,
                "attention rollout",
            )

    # sum_up_heads drops the head axis; wrap each value in a one-element head
    # list so that head index 0 selects the averaged [layer][row][col] data.
    averaged = {key: [value] for key, value in sum_up_heads(attention_dict_list).items()}
    summed_scores = _rollout_single_head(
        averaged,
        input_tokens,
        pred_tokens,
        layers_enc,
        layers_dec,
        0,
        calc_output_token,
        decoder_only,
        encoder_only,
    )
    result_list = _rollout_token_scores(summed_scores, all_tokens)
    if plot:
        _plot_rollout_scores(
            result_list,
            all_tokens,
            num_input,
            last_output,
            decoder_only,
            encoder_only,
            "attention flow",
        )
    return result_list
