import numpy as np
import networkx as nx
import networkx.algorithms.flow as flow

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd


# Calculate the shapley value for a range of nodes to one output token
# Encoder only models
def calculate_shapley_encoder_only(
    attention_dict,
    input_token_list,
    layers,
    head,
    output_token=-1,
    show=False,
    plot=True,
):
    """Compute the attention flow from every input token of an encoder-only model.

    A layered flow network is built in which node ``(x, y)`` is token position
    ``y`` at layer ``x`` (``x = 0 .. layers``). For every layer ``x`` there is
    an edge ``(x, y) -> (x + 1, z)`` for every pair of positions, with capacity
    ``attention_dict["enc_attn"][head][x][z][y]``. The top-layer node of the
    selected output position (or of every position) is joined to a sink with
    infinite capacity. The maximum flow from each ``(0, y)`` to the sink is
    reported.

    Args:
        attention_dict: mapping with key ``"enc_attn"``; the entry
            ``[head][layer][z][y]`` must be convertible to ``float``.
            NOTE(review): presumably ``[z][y]`` means "query z attends to
            key y"; confirm against the producer of the attention dict.
        input_token_list: tokens, one per position.
        layers: number of encoder layers to include in the network.
        head: index into ``attention_dict["enc_attn"]``.
        output_token: position whose top-layer node feeds the sink; ``-1``
            connects every position.
        show: unused, kept for interface compatibility.
        plot: unused, kept for interface compatibility.

    Returns:
        dict mapping ``(token, position)`` to the raw (unnormalised) max-flow
        value, in input order.
    """
    num_tokens = len(input_token_list)
    enc_attn = attention_dict["enc_attn"][head]

    g = nx.DiGraph()
    # The sink's y coordinate only matters when the graph is drawn.
    sink = (layers + 2, num_tokens // 2)
    g.add_node(sink)

    for layer in range(layers):
        for src in range(num_tokens):
            for dst in range(num_tokens):
                # The residual edge (dst == src) carries the plain attention
                # weight like every other edge.
                # NOTE(review): calculate_shapley adds +1 to residual edges;
                # confirm the encoder-only variant intentionally does not.
                g.add_edge(
                    (layer, src),
                    (layer + 1, dst),
                    capacity=float(enc_attn[layer][dst][src]),
                )

    targets = range(num_tokens) if output_token == -1 else (output_token,)
    for pos in targets:
        g.add_edge((layers, pos), sink, capacity=np.inf)

    # Only the flow value is needed, so skip building the per-edge flow dict.
    flows = [
        nx.maximum_flow_value(g, (0, pos), sink, flow_func=flow.edmonds_karp)
        for pos in range(num_tokens)
    ]
    return {
        (token, pos): value
        for pos, (token, value) in enumerate(zip(input_token_list, flows))
    }


# Calculate the shapley value for a range of nodes to one output token
# Decoder only models
def calculate_shapley_decoder_only(
    attention_dict,
    input_token_list,
    output_token_list,
    layers,
    head,
    output_token=-1,
    show=False,
    plot=True,
):
    """Compute the attention flow of every token into one position of a decoder-only model.

    Prompt and generated tokens form one sequence
    ``all_tokens = input_token_list + output_token_list``. Node ``(x, y)`` is
    position ``y`` at layer ``x`` (``x = 0 .. layers``). Edges are causal:
    ``(x, y) -> (x + 1, z)`` exists only for ``z >= y``, with capacity
    ``attention_dict["dec_attn"][head][x][z][y]``. Only the top-layer node of
    the target position ``len(input_token_list) + output_token`` is joined to
    the sink, with infinite capacity. With the default ``output_token=-1`` the
    target is therefore the last prompt position.

    After the max flows are computed, the flow of each position ``p`` strictly
    before the target is scaled by ``1 / (2 + target - p)``.

    Args:
        attention_dict: mapping with key ``"dec_attn"``; the entry
            ``[head][layer][z][y]`` must be convertible to ``float``.
        input_token_list: prompt tokens.
        output_token_list: generated tokens.
        layers: number of decoder layers to include in the network.
        head: index into ``attention_dict["dec_attn"]``.
        output_token: target offset relative to the first generated token.
        show: unused, kept for interface compatibility.
        plot: unused, kept for interface compatibility.

    Returns:
        dict mapping ``(token, position)`` to the (rescaled as described)
        flow value for every position of ``all_tokens``.
        NOTE(review): a target beyond the last position yields all-zero
        flows, because its node has no incoming edges.
    """
    num_input = len(input_token_list)
    all_tokens = input_token_list + output_token_list
    num_tokens = len(all_tokens)
    target = num_input + output_token
    dec_attn = attention_dict["dec_attn"][head]

    g = nx.DiGraph()
    # Sink label is centred on the whole sequence; it only matters for drawing.
    sink = (layers + 2, num_tokens // 2)
    g.add_node(sink)

    for layer in range(layers):
        for src in range(num_tokens):
            # Causal mask: a position only passes flow to itself (residual)
            # and to later positions.
            for dst in range(src, num_tokens):
                g.add_edge(
                    (layer, src),
                    (layer + 1, dst),
                    capacity=float(dec_attn[layer][dst][src]),
                )

    g.add_edge((layers, target), sink, capacity=np.inf)

    flows = [
        nx.maximum_flow_value(g, (0, pos), sink, flow_func=flow.edmonds_karp)
        for pos in range(num_tokens)
    ]

    # Normalise auto-regression: down-weight earlier positions by their
    # distance to the target.
    for pos in range(target):
        flows[pos] *= 1 / (2 + target - pos)

    return {
        (token, pos): value
        for pos, (token, value) in enumerate(zip(all_tokens, flows))
    }


# Calculate the shapley value for a range of nodes to one output token
# Encoder-Decoder models
def calculate_shapley(
    attention_dict,
    layers_enc,
    layers_dec,
    input_token_list,
    output_token_list,
    head,
    output_token=-1,
    show=False,
    plot=True,
):
    """Compute attention flow for the tokens of an encoder-decoder model.

    Node layout of the flow network (``n_in`` / ``n_out`` = number of input /
    output tokens; ``enc_attn``, ``dec_attn`` and ``enc_dec_attn`` are
    ``attention_dict[...][head]``):

    * encoder: ``(x, y)`` for layer ``x = 0 .. layers_enc`` and input position
      ``y``. The edge ``(x, y) -> (x + 1, z)`` has capacity
      ``enc_attn[x][z][y]``, plus 1 on the residual edge ``z == y``.
    * decoder: ``(layers_enc + 1 + x, n_in + y)`` for
      ``x = 0 .. layers_dec``. Causal edges go to positions ``z >= y`` of the
      next layer with capacity ``dec_attn[x][z][y]``, plus 1 on the residual
      edge ``z == y``.
    * cross attention: every encoder output node ``(layers_enc, z)`` feeds
      decoder position ``y`` of layer ``layers_enc + 2 + x`` with capacity
      ``enc_dec_attn[x][y][z]``.

    The sink is fed with infinite capacity, by one of the following:

    * the encoder outputs, when ``layers_dec == 0``;
    * the last decoder layer at every output position, when
      ``output_token == -1``;
    * the last decoder layer only at ``output_token``, otherwise.

    Args:
        attention_dict: mapping with keys ``"enc_attn"``, ``"dec_attn"`` and
            ``"enc_dec_attn"``; entries must be convertible to ``float``.
        layers_enc: number of encoder layers.
        layers_dec: number of decoder layers (``0`` = encoder only).
        input_token_list: encoder input tokens.
        output_token_list: decoder output tokens.
        head: head index into each attention entry.
        output_token: output position connected to the sink; ``-1`` connects
            all of them.
        show: unused, kept for interface compatibility.
        plot: unused, kept for interface compatibility.

    Returns:
        dict mapping ``(token, position)`` to a flow value, inputs first.
        Encoder values are the max flow out of ``(0, y)``, scaled by
        ``1 / (n_in + 1)``. Decoder values are the max flow out of
        ``(layers_enc + 1, n_in + p)``. They are 0.0 when that node does not
        exist (e.g. ``layers_dec == 0``). Positions
        ``p <= output_token`` are additionally scaled by
        ``1 / (1 + output_token - p)``.

    Raises:
        IndexError: if ``output_token >= len(output_token_list)``.
    """
    num_input = len(input_token_list)
    num_output = len(output_token_list)
    if output_token >= num_output:
        # Fail fast: the normalisation below would index past the result list.
        raise IndexError(
            f"output_token {output_token} out of range for {num_output} output tokens"
        )
    dec_first = layers_enc + 1  # x coordinate of the decoder input layer
    dec_last = layers_enc + layers_dec + 1  # x coordinate of the decoder output layer

    g = nx.DiGraph()
    source = (-1, num_input // 2)
    # The sink's y coordinate only matters when the graph is drawn.
    sink = (layers_enc + layers_dec + 2, (num_input + num_output) // 2)
    g.add_node(source)
    g.add_node(sink)

    # Encoder self-attention (every input position to every input position).
    enc_attn = attention_dict["enc_attn"][head]
    for layer in range(layers_enc):
        for src in range(num_input):
            for dst in range(num_input):
                capacity = float(enc_attn[layer][dst][src])
                if dst == src:
                    capacity += 1  # residual connection
                g.add_edge((layer, src), (layer + 1, dst), capacity=capacity)

    # Connect the network outputs to the sink.
    if layers_dec == 0:
        for pos in range(num_input):
            g.add_edge((layers_enc, pos), sink, capacity=np.inf)
    elif output_token == -1:
        for pos in range(num_output):
            g.add_edge((dec_last, num_input + pos), sink, capacity=np.inf)
    else:
        g.add_edge((dec_last, num_input + output_token), sink, capacity=np.inf)

    # Decoder self-attention (causal: only to the same or later positions).
    dec_attn = attention_dict["dec_attn"][head]
    for layer in range(layers_dec):
        for src in range(num_output):
            for dst in range(src, num_output):
                capacity = float(dec_attn[layer][dst][src])
                if dst == src:
                    # Residual connection, +1 as in the encoder. A single
                    # add_edge per pair keeps a later call from overwriting it.
                    capacity += 1
                g.add_edge(
                    (dec_first + layer, num_input + src),
                    (dec_first + layer + 1, num_input + dst),
                    capacity=capacity,
                )

    # Cross attention from the encoder output into every decoder layer.
    cross_attn = attention_dict["enc_dec_attn"][head]
    for layer in range(layers_dec):
        for dst in range(num_output):
            for src in range(num_input):
                g.add_edge(
                    (layers_enc, src),
                    (dec_first + layer + 1, num_input + dst),
                    capacity=float(cross_attn[layer][dst][src]),
                )

    flows = []
    for pos in range(num_input):
        start = (0, pos)
        # networkx treats an edge without a capacity attribute as unbounded,
        # so this measures the flow leaving `start`. Routing through the
        # source also tolerates `start` not being in the graph (flow 0).
        g.add_edge(source, start)
        flows.append(
            nx.maximum_flow_value(g, source, sink, flow_func=flow.edmonds_karp)
        )
        g.remove_edge(source, start)
    for pos in range(num_output):
        start = (dec_first, num_input + pos)
        if start in g:
            flows.append(
                nx.maximum_flow_value(g, start, sink, flow_func=flow.edmonds_karp)
            )
        else:
            # No decoder layer was built (layers_dec == 0): nothing can flow.
            flows.append(0.0)

    # Normalise encoder connections.
    for pos in range(num_input):
        flows[pos] *= 1 / (num_input + 1)
    # Normalise auto-regression: output positions up to `output_token` are
    # down-weighted by their distance to it (no-op for output_token == -1).
    for step in range(output_token + 1):
        flows[num_input + step] *= 1 / (1 + output_token - step)

    all_tokens = input_token_list + output_token_list
    return {
        (token, pos): value
        for pos, (token, value) in enumerate(zip(all_tokens, flows))
    }


# Compute one attention matrix for multiple heads by summing the values of all heads
def sum_up_heads(attention_dict):
    """Collapse all attention heads into a single head by element-wise summation.

    Each value of ``attention_dict`` is treated as a sequence of heads indexed
    ``[head][layer][row][col]``. The number of layers, rows and columns is
    read from head 0. For every position, the entries of all heads are
    converted with ``float`` and added up.

    Args:
        attention_dict: mapping from attention kind (e.g. ``"enc_attn"``) to
            per-head attention values. Heads, layers and rows must support
            ``len`` and integer indexing; lists, arrays and int-keyed dicts
            (such as this function's own output) all qualify.

    Returns:
        dict with the same keys. Each value is a one-element list holding the
        summed head as nested dicts ``{layer: {row: {col: float}}}``, so it can
        be addressed as head ``0``.

    Raises:
        IndexError: if an entry is an empty list of heads.
    """
    summed = {}
    # Loop variable no longer shadows the builtin ``dict``.
    for attn_kind in attention_dict:
        heads = attention_dict[attn_kind]
        num_heads = len(heads)
        reference = heads[0]  # layer/row/column counts come from the first head
        layers_out = {}
        for layer in range(len(reference)):
            rows_out = {}
            for row in range(len(reference[layer])):
                rows_out[row] = {
                    col: sum(
                        float(heads[h][layer][row][col]) for h in range(num_heads)
                    )
                    for col in range(len(reference[layer][row]))
                }
            layers_out[layer] = rows_out
        summed[attn_kind] = [layers_out]
    return summed


# Calculates attention flow for every head separately and for the sum of all heads, for every input/output token pair
def calculate_shapley_sum(
    attention_dict_list,
    input_tokens,
    pred_tokens,
    layers_enc=0,
    layers_dec=0,
    output_token=-1,
    plot=False,
    decoder_only=False,
    encoder_only=False,
):
    """Compute attention flow for every head and, when plotting, for the sum of all heads.

    The architecture is selected by the flags: ``decoder_only`` takes
    precedence over ``encoder_only``, and neither means encoder-decoder. Each
    head is processed by ``calculate_shapley_decoder_only``,
    ``calculate_shapley_encoder_only`` or ``calculate_shapley`` accordingly.
    The per-head values are printed, optionally plotted as heatmaps, and
    returned.

    When ``plot`` is True, the heads are also summed with ``sum_up_heads`` and
    the flow of the summed attention is plotted. That result is never
    returned, so it is not computed at all when ``plot`` is False.

    Args:
        attention_dict_list: attention mapping with per-head entries under
            ``"enc_attn"``, ``"dec_attn"`` and/or ``"enc_dec_attn"``.
        input_tokens: input / prompt tokens.
        pred_tokens: predicted tokens (not part of the token list in
            encoder-only mode).
        layers_enc: number of encoder layers.
        layers_dec: number of decoder layers.
        output_token: target position forwarded unchanged to the per-head
            routines. For the summed-head plot, ``-1`` is resolved to the last
            predicted token (``len(pred_tokens) - 1``). In encoder-only mode it
            resolves to ``len(pred_tokens)``.
        plot: draw heatmaps (and compute the summed-head flow).
        decoder_only: treat the model as decoder-only.
        encoder_only: treat the model as encoder-only.

    Returns:
        list with one list of flow values per head, ordered like the tokens
        (inputs, then predictions unless ``encoder_only``). In decoder-only and
        encoder-only mode each list is truncated to the plotted tokens when
        ``plot`` is True.
    """
    num_input = len(input_tokens)
    if decoder_only:
        mode = "decoder"
    elif encoder_only:
        mode = "encoder"
    else:
        mode = "encoder-decoder"

    if encoder_only:
        num_heads = len(attention_dict_list["enc_attn"])
        # Copy so the token list never aliases the caller's list.
        all_tokens = list(input_tokens)
    else:
        num_heads = len(attention_dict_list["dec_attn"])
        all_tokens = input_tokens + pred_tokens

    def run_flow(attention, head, target):
        return _attention_flow_for_mode(
            mode,
            attention,
            head,
            input_tokens,
            pred_tokens,
            layers_enc,
            layers_dec,
            target,
        )

    def visible_count(target):
        # Number of leading tokens shown in the single-column plots.
        return num_input + target + 1 if mode == "decoder" else num_input

    return_list = []
    for head in range(num_heads):
        print("Head " + str(head))
        result_list = _flow_values(
            run_flow(attention_dict_list, head, output_token), all_tokens
        )
        if mode == "encoder-decoder":
            if plot:
                _plot_encoder_decoder_flow(
                    result_list, all_tokens, num_input, output_token
                )
        else:
            print(result_list)
            if plot:
                keep = visible_count(output_token)
                # The returned per-head list is truncated to the plotted
                # tokens (existing behaviour callers may rely on).
                del result_list[keep:]
                _plot_flow_column(result_list, all_tokens[:keep])
        print("result list in attention flow")
        print(result_list)
        return_list.append(result_list)

    if not plot:
        # The summed-head flow below is only visualised, never returned.
        return return_list

    summed_target = output_token
    if output_token == -1:
        if mode == "encoder":
            # NOTE(review): kept as before; in encoder-only mode this selects
            # input position len(pred_tokens). Confirm that is intended.
            summed_target = len(pred_tokens)
        else:
            # Last predicted token. `len(pred_tokens)` would point one past
            # the end (all-zero flow / IndexError in calculate_shapley).
            summed_target = len(pred_tokens) - 1

    summed_attention = sum_up_heads(attention_dict_list)
    result_list = _flow_values(run_flow(summed_attention, 0, summed_target), all_tokens)
    if mode == "encoder-decoder":
        _plot_encoder_decoder_flow(
            result_list, all_tokens, num_input, summed_target, encoder_cmap="coolwarm"
        )
    else:
        keep = visible_count(summed_target)
        _plot_flow_column(result_list[:keep], all_tokens[:keep])
    return return_list


def _attention_flow_for_mode(
    mode, attention, head, input_tokens, pred_tokens, layers_enc, layers_dec, output_token
):
    """Run the flow routine matching ``mode`` for one head of ``attention``."""
    if mode == "decoder":
        return calculate_shapley_decoder_only(
            attention,
            input_tokens,
            pred_tokens,
            layers_dec,
            head,
            output_token=output_token,
            show=False,
            plot=False,
        )
    if mode == "encoder":
        return calculate_shapley_encoder_only(
            attention,
            input_tokens,
            layers_enc,
            head,
            output_token=output_token,
            show=False,
            plot=False,
        )
    return calculate_shapley(
        attention,
        layers_enc,
        layers_dec,
        input_tokens,
        pred_tokens,
        head,
        output_token=output_token,
        show=False,
        plot=False,
    )


def _flow_values(shapley_dict, tokens):
    """Return the flow values of ``tokens``, in order, from a ``{(token, pos): value}`` dict."""
    return [shapley_dict[(token, pos)] for pos, token in enumerate(tokens)]


def _plot_flow_column(values, labels, column="attention flow", cmap="crest"):
    """Show ``values`` as an annotated single-column heatmap with ``labels`` as rows."""
    data = pd.DataFrame({column: values}, labels)
    sns.set(font="Times New Roman", font_scale=1.3)
    plt.subplots(figsize=(3, len(labels) / 2), dpi=200)
    sns.heatmap(data, annot=True, linewidth=0.2, cmap=cmap, cbar=False, fmt=".4f")
    plt.show()


def _plot_encoder_decoder_flow(
    result_list, all_tokens, num_input, output_token, encoder_cmap="crest"
):
    """Plot encoder-token flows and decoder-token flows (up to ``output_token``) separately."""
    _plot_flow_column(
        result_list[:num_input],
        all_tokens[:num_input],
        "attention flow encoder",
        encoder_cmap,
    )
    end = num_input + output_token + 1
    decoder_values = result_list[num_input:end]
    # Nothing to draw when no decoder token is selected (e.g. output_token == -1).
    if decoder_values:
        _plot_flow_column(
            decoder_values, all_tokens[num_input:end], "attention flow decoder"
        )
