import io

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import seaborn as sns
import torch
from PIL import Image
from rdkit import Chem
from rdkit.Chem import Draw
from scipy.stats import pearsonr, spearmanr

from graphsmodel import SUBPLOT_HEIGHT, SUBPLOT_WIDTH, palette

from .text_utils import find_common_keywords


def correlation_plot(preds, y_test, title, ax=None):
    """Scatter predictions against targets and report their correlations.

    Each point is coloured by its column index in ``preds``, so points that
    belong to the same column (node) share a colour. The axes title is
    ``title`` followed by the Pearson and Spearman coefficients computed over
    all flattened values.

    Args:
        preds: 2-D array-like with ``.shape`` and ``.flatten()`` (e.g. a NumPy
            array or torch tensor); ``preds.shape[1]`` is the number of columns.
        y_test: Targets with the same number of elements as ``preds``.
        title: Title prefix; the correlation coefficients are appended.
        ax: Matplotlib axes to draw on. When omitted, the current axes
            (``plt.gca()``) are used.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        # The signature advertises ``ax`` as optional; fall back to the current
        # axes instead of failing with AttributeError on ``None.scatter``.
        ax = plt.gca()

    n_rows, num_nodes = preds.shape[0], preds.shape[1]
    pred_flat = preds.flatten()
    target_flat = y_test.flatten()

    pear = pearsonr(pred_flat, target_flat)
    spear = spearmanr(pred_flat, target_flat)

    # flatten() is row-major: row 0's columns come first, then row 1's, ...
    # Repeating 0..num_nodes-1 once per row therefore gives each point its
    # column index as a colour value.
    ax.scatter(
        pred_flat,
        target_flat,
        c=torch.arange(num_nodes).repeat(n_rows),
    )

    ax.set_title(f"{title} \nPearson: {pear[0]:.2f} - Spearman: {spear[0]:.2f}")
    return ax


def plot_smiles(smiles, ax):
    """Render the molecule described by a SMILES string onto ``ax``.

    The axes frame and ticks are turned off so only the molecule image shows.

    Args:
        smiles: SMILES string describing the molecule.
        ax: Matplotlib axes to draw the molecule image on.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # MolFromSmiles reports a parse failure by returning None; fail with
        # the offending input instead of an opaque error from the renderer.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")

    # MolToImage already returns a PIL image, which imshow accepts directly,
    # so there is no need to encode it to PNG and decode it again.
    ax.imshow(Draw.MolToImage(mol))
    ax.axis("off")


def plot_ego(
    edge_index,
    selected_node,
    train_mask,
    importance,
    ax,
    overlapping_nodes=None,
    target_node=None,
    node_size=50,
):
    """Draw the graph spanned by ``edge_index`` with per-node styling.

    Only nodes that appear in at least one edge are drawn. Every node is a
    circle patch, except ``target_node`` (if given), which is drawn as a star
    scatter marker on top.

    Fill colour:
      * training nodes (``train_mask[node]`` truthy): red when
        ``importance[node] <= 0``, green otherwise, with a dotted hatch;
      * other nodes: pink when listed in ``overlapping_nodes``, sky blue
        otherwise.

    ``selected_node`` gets a gold outline of width 1; other circles get a
    black outline of width 0.5.

    Args:
        edge_index: Object with a ``.t()`` method whose transpose lists one
            ``[source, target]`` pair per row (presumably a ``(2, E)`` torch
            tensor — confirm with caller).
        selected_node: Node to outline in gold.
        train_mask: Indexable by node id; truthy for training nodes.
        importance: Indexable by node id; its sign picks red/green for
            training nodes.
        ax: Matplotlib axes to draw on.
        overlapping_nodes: Optional container of nodes to colour pink when
            they are not training nodes.
        target_node: Optional node drawn as a star instead of a circle.
        node_size: Base size; the circle radius and star size derive from it.
    """
    graph = nx.Graph()
    graph.add_edges_from(edge_index.t().tolist())
    layout = nx.spring_layout(graph)

    nx.draw_networkx_edges(
        graph, layout, edge_color=palette["tail:grey"], alpha=0.5, ax=ax
    )

    # Circle radius from node_size via r = sqrt(A / pi), with a fixed 1e-4
    # scale factor on the area.
    circle_radius = np.sqrt(node_size * 1e-4 / np.pi)

    for node in graph.nodes():
        cx, cy = layout[node]

        is_selected = node == selected_node
        outline = palette["tail:gold"] if is_selected else "k"
        outline_width = 1 if is_selected else 0.5

        in_train = train_mask[node]
        hatch = "....." if in_train else ""

        if in_train:
            fill = (
                palette["tail:red"] if importance[node] <= 0 else palette["tail:green"]
            )
        else:
            fill = palette["tail:skyblue"]
            if overlapping_nodes is not None and node in overlapping_nodes:
                fill = palette["tail:pink"]

        if target_node is not None and node == target_node:
            # The target is emphasised with a large star above the circles.
            ax.scatter(
                cx,
                cy,
                s=node_size * 10,
                marker="*",
                color=fill,
                zorder=3,
                linewidths=0.3,
                edgecolors="k",
            )
            continue

        ax.add_patch(
            plt.Circle(
                (cx, cy),
                radius=circle_radius,
                color=fill,
                ec=outline,
                hatch=hatch,
                linewidth=outline_width,
            )
        )

    ax.axis("off")
    ax.axis("equal")


def plot_shortest_paths(selected_node, training_nodes, shortest_paths, importance, ax):
    """Draw the union of the given paths, colouring nodes by training importance.

    Nodes default to sky blue. Training nodes are red when their importance
    is <= 0 and green otherwise. ``selected_node`` is redrawn on top in sky
    blue with a gold outline, and a legend lists the four styles.

    Args:
        selected_node: Ego node to highlight.
        training_nodes: Index list/array of training node ids (used to index
            ``importance``).
        shortest_paths: Iterable of node sequences; consecutive entries are
            joined by an edge.
        importance: NumPy array indexed by node id; every drawn node must be a
            valid index into it.
        ax: Matplotlib axes to draw on.
    """
    G = nx.Graph()
    # Add the ego node explicitly so it always has a layout position, even if
    # no path edge touches it; otherwise highlighting it below fails because
    # the node has no position.
    G.add_node(selected_node)
    for path in shortest_paths:
        G.add_edges_from(zip(path, path[1:]))
    pos = nx.spring_layout(G)

    # dtype=object keeps each colour string intact. Without it, np.full infers
    # a fixed-width string dtype from the skyblue value, and any longer colour
    # string assigned below would be silently truncated.
    colors = np.full(importance.size, palette["tail:skyblue"], dtype=object)
    colors[training_nodes] = np.where(
        importance[training_nodes] <= 0, palette["tail:red"], palette["tail:green"]
    )
    # Reorder to G.nodes(), the order draw_networkx_nodes uses without a nodelist.
    colors = colors[list(G.nodes())]

    nx.draw_networkx_edges(G, pos, edge_color=palette["tail:grey"], alpha=0.5, ax=ax)
    nx.draw_networkx_nodes(
        G, pos, node_color=colors, node_size=50, edgecolors="k", linewidths=0.3, ax=ax
    )
    nx.draw_networkx_nodes(
        G,
        pos,
        nodelist=[selected_node],
        node_color=palette["tail:skyblue"],
        node_size=50,
        edgecolors=palette["tail:gold"],
        linewidths=1,
        ax=ax,
    )

    test_leg = mpatches.Patch(color=palette["tail:skyblue"], label="test")
    positive_leg = mpatches.Patch(color=palette["tail:green"], label="positive")
    negative_leg = mpatches.Patch(color=palette["tail:red"], label="negative")
    target_leg = mpatches.Patch(
        facecolor="white", edgecolor=palette["tail:gold"], label="ego"
    )

    ax.set_title("Shortest paths")
    ax.legend(handles=[test_leg, positive_leg, negative_leg, target_leg], fontsize=6)
    ax.axis("off")
    ax.axis("equal")


def plot_simple_paths(
    selected_node, training_nodes, shortest_paths, path_counts, importance, ax
):
    """Draw the path graph with training nodes sized by their path counts.

    Training nodes are drawn with marker size ``max(50, 2 * path_counts[n])``
    and filled red when their importance is <= 0, green otherwise.
    ``selected_node`` is drawn in sky blue with a gold outline. Nodes are
    labelled with ``path_counts``.

    Args:
        selected_node: Ego node to highlight.
        training_nodes: Training node ids to draw; each must occur in some
            path edge so that it has a layout position.
        shortest_paths: Iterable of node sequences; consecutive entries are
            joined by an edge.
        path_counts: Mapping node -> path count, used for sizes and labels.
        importance: NumPy array indexed by node id; its sign picks red/green.
        ax: Matplotlib axes to draw on.
    """
    G = nx.Graph()
    # Add the ego node explicitly so it always has a layout position, even if
    # no path edge touches it; otherwise drawing it below fails.
    G.add_node(selected_node)
    for path in shortest_paths:
        G.add_edges_from(zip(path, path[1:]))

    pos = nx.spring_layout(G)

    # Only training nodes are drawn with importance colours, so compute their
    # colours directly. np.where sizes its string dtype to fit both choices,
    # whereas a full array seeded from the skyblue value could truncate longer
    # colour strings (and its skyblue entries were never displayed).
    train_colors = np.where(
        importance[training_nodes] <= 0, palette["tail:red"], palette["tail:green"]
    )

    nx.draw_networkx_edges(G, pos, edge_color=palette["tail:grey"], alpha=0.5, ax=ax)

    nx.draw_networkx_nodes(
        G,
        pos,
        nodelist=training_nodes,
        node_size=[max(50, 2 * path_counts[n]) for n in training_nodes],
        node_color=train_colors,
        edgecolors="k",
        linewidths=0.3,
        alpha=0.8,
        ax=ax,
    )
    nx.draw_networkx_nodes(
        G,
        pos,
        nodelist=[selected_node],
        node_color=palette["tail:skyblue"],
        node_size=50,
        edgecolors=palette["tail:gold"],
        linewidths=1,
        ax=ax,
    )

    nx.draw_networkx_labels(G, pos, path_counts, font_size=3, font_color="k", ax=ax)

    positive_leg = mpatches.Patch(color=palette["tail:green"], label="positive")
    negative_leg = mpatches.Patch(color=palette["tail:red"], label="negative")
    target_leg = mpatches.Patch(
        facecolor="white", edgecolor=palette["tail:gold"], label="ego"
    )

    ax.set_title("Number of paths")
    ax.legend(handles=[positive_leg, negative_leg, target_leg], fontsize=6)
    ax.axis("off")
    ax.axis("equal")


def plot_simple_paths_classes(
    selected_node,
    training_nodes,
    shortest_paths,
    path_counts,
    classes,
    idx_to_class,
    ax,
):
    """Draw the path graph with training nodes coloured by class.

    Training nodes are sized ``max(50, 2 * path_counts[n])`` and filled with
    their class colour. ``selected_node`` is drawn in its own class colour
    with a gold outline. The legend lists every class in ``idx_to_class``
    plus the ego marker.

    Class colours are taken from ``palette`` in insertion order, cycling when
    there are more classes than palette entries.

    Args:
        selected_node: Ego node to highlight.
        training_nodes: Training node ids to draw; each must occur in some
            path edge so that it has a layout position.
        shortest_paths: Iterable of node sequences; consecutive entries are
            joined by an edge.
        path_counts: Mapping node -> path count, used for marker sizes.
        classes: Per-node class ids indexed by node id. Each element must
            provide ``.item()`` (e.g. a torch tensor or NumPy array).
        idx_to_class: Mapping class id -> display name for the legend.
        ax: Matplotlib axes to draw on.
    """
    G = nx.Graph()
    # Add the ego node explicitly so it always has a layout position, even if
    # no path edge touches it; otherwise drawing it below fails.
    G.add_node(selected_node)
    for path in shortest_paths:
        G.add_edges_from(zip(path, path[1:]))

    pos = nx.spring_layout(G)

    color_keys = list(palette.keys())
    class_to_color = {
        cls: palette[color_keys[i % len(color_keys)]]
        for i, cls in enumerate(idx_to_class)
    }

    # Colour lookup table indexed by node id.
    colors = np.array([class_to_color[cls.item()] for cls in classes])

    nx.draw_networkx_edges(G, pos, edge_color=palette["tail:grey"], alpha=0.5, ax=ax)

    nx.draw_networkx_nodes(
        G,
        pos,
        nodelist=training_nodes,
        node_size=[max(50, 2 * path_counts[n]) for n in training_nodes],
        node_color=colors[training_nodes],
        edgecolors="k",
        linewidths=0.3,
        alpha=0.8,
        ax=ax,
    )
    nx.draw_networkx_nodes(
        G,
        pos,
        nodelist=[selected_node],
        node_color=colors[selected_node],
        node_size=50,
        edgecolors=palette["tail:gold"],
        linewidths=1,
        ax=ax,
    )

    target_leg = mpatches.Patch(
        facecolor="white", edgecolor=palette["tail:gold"], label="ego"
    )

    ax.set_title("Classes")
    ax.legend(
        handles=[
            mpatches.Patch(color=color, label=f"{idx_to_class[cls]}")
            for cls, color in class_to_color.items()
        ]
        + [target_leg],
        fontsize=6,
        bbox_to_anchor=(1, 1),
    )
    ax.axis("off")
    ax.axis("equal")


def tsne_plot(df, axes, selected_node=None, train=False):
    """Scatter a 2-D embedding stored in ``df`` with seaborn.

    ``df`` must provide columns ``x``, ``y``, ``style``, ``colors`` and a
    boolean ``train`` column. Each value of ``colors`` must name a palette
    entry such that ``palette[f"tail:{value}"]`` exists.

    Drawing layers:
      * non-training rows: always drawn, faint (alpha 0.1); this layer
        produces the legend;
      * training rows: overlaid opaquely (alpha 0.9) only when ``train`` is
        true, without legend entries;
      * ``selected_node``: when given, that row is redrawn in grey with a gold
        outline.

    Args:
        df: pandas DataFrame with the columns above.
        axes: Matplotlib axes to draw on.
        selected_node: Key used to look up the highlighted row in the ``x``,
            ``y`` and ``style`` columns, or ``None`` to skip highlighting.
        train: Whether to overlay the training rows.
    """
    sns_palette = {color: palette[f"tail:{color}"] for color in df["colors"].unique()}
    # seaborn documents ``style_order`` as a list of levels. The previous
    # identity dict {style: style} only worked because iterating a dict yields
    # its keys; pass the levels as a plain list instead.
    style_order = list(df["style"].unique())

    # Settings shared by the non-training and training layers.
    layer_kwargs = dict(
        x="x",
        y="y",
        style="style",
        hue="colors",
        edgecolors="k",
        linewidths=0.3,
        palette=sns_palette,
        ax=axes,
        style_order=style_order,
    )

    sns.scatterplot(data=df[~df["train"]], alpha=0.1, **layer_kwargs)

    if train:
        sns.scatterplot(
            data=df[df["train"]], alpha=0.9, legend=False, **layer_kwargs
        )

    if selected_node is not None:
        sns.scatterplot(
            x=df["x"][selected_node],
            y=df["y"][selected_node],
            style=[df["style"][selected_node]],
            color=palette["tail:grey"],
            edgecolors=palette["tail:gold"],
            linewidths=1,
            alpha=0.9,
            ax=axes,
            style_order=style_order,
            legend=False,
        )


def plot_common_words_graph(
    training_nodes,
    training_shortest_paths,
    selected_node,
    x,
    idx_to_attr,
    colors,
    ax,
    top_k=10,
):
    """Draw a star graph linking ``selected_node`` to each training node.

    Each edge's width is ``0.3`` times the number of items that
    ``find_common_keywords`` returns for its two endpoints. Each training
    node is labelled with the value at the same position in
    ``training_shortest_paths``.

    Args:
        training_nodes: Training node ids to connect to ``selected_node``.
        training_shortest_paths: Labels aligned with ``training_nodes``
            (presumably shortest-path lengths — confirm with caller).
        selected_node: The ego node, drawn with a gold outline.
        x: Node data forwarded to ``find_common_keywords``.
        idx_to_attr: Forwarded to ``find_common_keywords``.
        colors: Node colours indexable by node id.
        ax: Matplotlib axes to draw on.
        top_k: Forwarded to ``find_common_keywords``.
    """
    common_words_graph = nx.Graph()
    # Add the ego node explicitly so it has a layout position even when
    # ``training_nodes`` is empty; otherwise drawing it below fails.
    common_words_graph.add_node(selected_node)
    for training_node in training_nodes:
        common_words_graph.add_edge(
            selected_node,
            training_node,
            common_keywords=find_common_keywords(
                selected_node, training_node, x, idx_to_attr, top_k
            ),
        )

    pos = nx.spring_layout(common_words_graph)

    nx.draw_networkx_nodes(
        common_words_graph,
        pos,
        nodelist=training_nodes,
        node_size=50,
        node_color=colors[training_nodes],
        edgecolors="k",
        linewidths=0.3,
        alpha=0.8,
        ax=ax,
    )
    nx.draw_networkx_nodes(
        common_words_graph,
        pos,
        nodelist=[selected_node],
        node_color=colors[selected_node],
        node_size=50,
        edgecolors=palette["tail:gold"],
        linewidths=1,
        ax=ax,
    )

    labels = dict(zip(training_nodes, training_shortest_paths))
    nx.draw_networkx_labels(common_words_graph, pos, labels, font_size=3, ax=ax)

    # Draw every edge in a single call with per-edge widths, instead of one
    # draw call (and one artist) per edge.
    edges = list(common_words_graph.edges(data="common_keywords"))
    if edges:
        nx.draw_networkx_edges(
            common_words_graph,
            pos,
            edgelist=[(u, v) for u, v, _ in edges],
            width=[len(keywords) * 0.3 for _, _, keywords in edges],
            edge_color="k",
            ax=ax,
        )


def plot_losses(model_id, train_trace, val_trace, best_epoch):
    """
    Create a single-axes figure with training and validation loss curves.

    A dashed red vertical line labelled "early stopping" marks ``best_epoch``,
    and the title shows ``model_id + 1``.

    Args:
        model_id (int): Model index; the title displays ``model_id + 1``
            (presumably ids are zero-based).
        train_trace (list): Training loss values, one per epoch.
        val_trace (list): Validation loss values, one per epoch.
        best_epoch (int): X position of the early-stopping marker.

    Returns:
        matplotlib.figure.Figure: The newly created figure (not closed here).
    """
    fig, ax = plt.subplots(figsize=get_figsize(ncols=1, nrows=1))

    for trace, name in ((train_trace, "train"), (val_trace, "val")):
        ax.plot(trace, label=name)
    ax.axvline(best_epoch, color="r", linestyle="--", label="early stopping")

    ax.set_xlabel("epoch")
    ax.set_ylabel("loss")
    ax.set_title(f"Model {model_id + 1}")
    ax.legend()

    return fig


def get_figsize(nrows, ncols, width_ratio_per_col=None, height_ratio=1):
    """
    Return a ``(width, height)`` figure size for an ``nrows`` x ``ncols`` grid.

    The width is the sum of the per-column width ratios times
    ``SUBPLOT_WIDTH``. The height is ``nrows * SUBPLOT_HEIGHT * height_ratio``
    further multiplied by the largest column ratio, so rows get taller along
    with the widest column and subplots look closer to square.

    Args:
        nrows (int): Number of subplot rows.
        ncols (int): Number of subplot columns.
        width_ratio_per_col (list or tuple, optional): Relative width of each
            column. Defaults to 1 for every column.
        height_ratio (float, optional): Extra multiplier on the row height.

    Returns:
        tuple: ``(fig_width, fig_height)``.

    Raises:
        ValueError: If ``width_ratio_per_col`` does not have ``ncols`` entries.
    """
    ratios = width_ratio_per_col if width_ratio_per_col is not None else [1] * ncols

    if len(ratios) != ncols:
        raise ValueError(
            "The length of width_ratio_per_col must match the number of columns"
        )

    row_height = SUBPLOT_HEIGHT * height_ratio
    return (
        sum(ratios) * SUBPLOT_WIDTH,
        nrows * row_height * max(ratios),
    )


def plot_performance(
    ax, data_keys, data_res, linestyles, title, xlabel, ylabel, ylim=(0, 1)
):
    """Draw one performance curve per key on ``ax`` and label the axes.

    Args:
        ax: Matplotlib axes to draw on.
        data_keys: Keys of ``data_res`` to plot, in drawing order. Each key
            is also used as the curve's legend label.
        data_res: Mapping from key to a mapping holding a ``"perf"`` sequence.
        linestyles: Mapping from the part of each key before its first
            ``"_"`` to a matplotlib linestyle.
        title: Axes title.
        xlabel: X-axis label.
        ylabel: Y-axis label.
        ylim: Positional arguments forwarded to ``ax.set_ylim``.
    """
    ax.set_title(title)

    for series in data_keys:
        curve = data_res[series]["perf"]
        family = series.partition("_")[0]
        ax.plot(curve, label=series, linestyle=linestyles[family])

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_ylim(*ylim)
