from matplotlib import colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from pathlib import Path
from typing import List
from html2image import Html2Image


def plot_attribution_heatmap(attributions: np.ndarray, tokens: List[str], roi_names: List[str], output_file: Path):
    """
    Save an annotated seaborn heatmap of word-by-ROI attributions.

    Each cell is annotated with its value to two decimals using the
    "viridis" colormap.

    Args:
        attributions (np.ndarray): 2-D array whose rows are labelled with
            ``tokens`` and whose columns are labelled with ``roi_names``
            (presumably shape (len(tokens), len(roi_names)) — confirm with caller).
        tokens (List[str]): Row (y-axis) labels.
        roi_names (List[str]): Column (x-axis) labels.
        output_file (Path): Destination image path passed to ``savefig``.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        sns.heatmap(attributions, annot=True, fmt=".2f", cmap="viridis",
                    xticklabels=roi_names, yticklabels=tokens, ax=ax)
        ax.set_xlabel("ROI")
        ax.set_ylabel("Words")
        ax.set_title("Attribution Heatmap")
        fig.savefig(output_file)
    finally:
        # pyplot keeps every figure alive until closed; close it even when
        # drawing or saving fails so repeated calls do not leak figures.
        plt.close(fig)

def plot_token_attr(words_attrs: np.ndarray, tokens: List[str], roi_names: List[str], output_file: Path):
    """
    Save a diverging heatmap of per-word attributions for each ROI.

    Every cell is annotated with its value (4 decimals). The colour scale is
    symmetric around 0, i.e. ``[-max|attr|, +max|attr|]``, so positive and
    negative attributions stay distinguishable and 0 maps to white.

    Args:
        words_attrs (np.ndarray): 2-D array indexed as ``[row, column]``; rows
            are labelled with ``tokens`` and columns with ``roi_names``.
        tokens (List[str]): Row (y-axis) labels, one per row of ``words_attrs``.
        roi_names (List[str]): Column (x-axis) labels, one per column of
            ``words_attrs``.
        output_file (Path): Destination image path passed to ``savefig``
            (saved with ``bbox_inches='tight'``).
    """
    words_attrs = np.asarray(words_attrs)
    n_rows, n_cols = words_attrs.shape

    # Symmetric normalisation bound keeps 0 at the colormap mid-point.
    # For an all-zero input fall back to 1.0 so vmin/vmax do not collapse
    # onto the same value (a degenerate colour scale / colorbar).
    max_abs_attr_val = float(np.max(np.abs(words_attrs)))
    if max_abs_attr_val == 0.0:
        max_abs_attr_val = 1.0

    # Diverging palette from vmin to vmax. NOTE(review): the low (negative)
    # end is red and the high (positive) end is blue, which is the opposite of
    # save_highlighted_html in this module (red = positive) — confirm intended.
    colors = [
        "#93003a",
        "#d0365b",
        "#f57789",
        "#ffbdc3",
        "#ffffff",
        "#a4d6e1",
        "#73a3ca",
        "#4772b3",
        "#00429d",
    ]

    fig, ax = plt.subplots()
    try:
        ax.grid(False)
        # Grow the canvas with the number of columns/rows, never below the
        # default 6.4 x 4.8 inches.
        fig.set_size_inches(max(n_cols * 1.3, 6.4), max(n_rows / 2.5, 4.8))

        im = ax.imshow(
            words_attrs,
            vmax=max_abs_attr_val,
            vmin=-max_abs_attr_val,
            cmap=mcolors.LinearSegmentedColormap.from_list(name="colors", colors=colors),
            aspect="auto",
        )
        fig.set_facecolor("white")

        cbar = fig.colorbar(im, ax=ax)  # type: ignore
        cbar.ax.set_ylabel("Word Attribution", rotation=-90, va="bottom")

        # One tick per cell, labelled with the ROI / token names.
        ax.set_xticks(np.arange(n_cols), labels=roi_names)
        ax.set_yticks(np.arange(n_rows), labels=tokens)

        # Put the ROI labels on top of the heatmap, rotated for readability.
        ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)
        plt.setp(ax.get_xticklabels(), rotation=-30, ha="right", rotation_mode="anchor")

        # Annotate each cell; dark text on the light middle of the colormap,
        # white text on the saturated ends.
        for (i, j), val in np.ndenumerate(words_attrs):
            text_color = "black" if 0.2 < im.norm(val) < 0.8 else "white"
            ax.text(
                j,
                i,
                "%.4f" % val,
                horizontalalignment="center",
                verticalalignment="center",
                color=text_color,
            )

        ax.set_xlabel("ROI")
        ax.set_ylabel("Words")
        ax.set_title("Attribution Heatmap")
        fig.tight_layout()
        fig.savefig(output_file, bbox_inches='tight')
    finally:
        # Close exactly the figure created here, even if drawing fails.
        plt.close(fig)

def save_highlighted_html(attributions, tokens, roi_names, output_file):
    """
    Save an HTML page highlighting every token once per ROI.

    Each token's background is red for a positive attribution and blue for a
    negative one. Its opacity is ``|attr| / max|attr|``, where the maximum is
    taken over the whole array, so all ROIs share one scale. A gradient legend
    from ``-max|attr|`` (blue) to ``+max|attr|`` (red) is appended at the
    bottom. Token and ROI text is HTML-escaped.

    Args:
        attributions (np.array): Shape (L_ext, R); column ``r`` holds the
            per-token attributions for ROI ``r``.
        tokens (List[str]): Words of the sentence, aligned with the rows of
            ``attributions`` (extra rows/tokens are dropped by ``zip``).
        roi_names (List[str]): Names of each ROI (length R).
        output_file (str or Path): Path of the HTML file to write (UTF-8).
    """
    import html

    attributions = np.asarray(attributions)
    max_abs_attr = float(np.max(np.abs(attributions)))
    # Divisor for opacities: an all-zero input would otherwise compute 0/0 and
    # emit "nan" alpha values, which are invalid CSS.
    scale = max_abs_attr if max_abs_attr > 0 else 1.0

    parts = ["""<html>
    <head>
        <meta charset="utf-8">
        <style>
            body { font-family: Arial, sans-serif; font-size: 16px; } /* Increased base font size */
            .roi-section { margin-bottom: 10px; } /* Reduce spacing between ROIs */
            .roi-name { font-size: 18px; font-weight: bold; color: #2c3e50; margin-bottom: 3px; } /* Larger ROI titles */
            .word { padding: 2px; margin: 0px; border-radius: 2px; display: inline-block; font-size: 18px; } /* Larger words, tighter spacing */
            .sentence { line-height: 1.6; margin-bottom: 3px; } /* Reduce spacing between lines */
            .legend-container { margin-top: 15px; text-align: center; font-size: 14px; } /* Adjust legend size */
            .gradient-bar {
                width: 80%; height: 15px; background: linear-gradient(to right, blue, white, red);
                margin: auto; border: 1px solid #000;
            }
            .legend-labels {
                display: flex; justify-content: space-between; width: 80%; margin: auto;
            }
        </style>
    </head>
    <body>"""]

    # One section per ROI, each containing the full highlighted sentence.
    for roi_idx, roi_name in enumerate(roi_names):
        parts.append(
            f"<div class='roi-section'><div class='roi-name'>{html.escape(str(roi_name))}</div>"
        )
        parts.append("<p class='sentence'>")

        for token, attr in zip(tokens, attributions[:, roi_idx]):
            red = 255 if attr > 0 else 0
            blue = 255 if attr < 0 else 0
            alpha = abs(attr) / scale
            parts.append(
                f"<span class='word' style='background-color:rgba({red}, 0, {blue}, {alpha:.3f});'>"
                f"{html.escape(str(token))}</span> "
            )

        parts.append("</p></div>")

    # Continuous colour-scale legend (labels use the true maximum).
    parts.append(f"""<div class='legend-container'>
        <h4>Attribution Color Scale</h4>
        <div class="gradient-bar"></div>
        <div class="legend-labels">
            <span>-{max_abs_attr:.2f}</span>
            <span>0</span>
            <span>+{max_abs_attr:.2f}</span>
        </div>
    </div>""")
    parts.append("</body></html>")

    with open(output_file, "w", encoding="utf-8") as f:
        f.write("".join(parts))


def plot_grouped_bar_attributions_per_feature(discourse_features, brain_regions, total_attrs, positive_attrs, negative_attrs, output_file):
    """
    Creates grouped bar plots for discourse features showing total, positive, and negative attributions
    for all brain regions.

    One subplot is drawn per discourse feature, laid out two per row. Within a
    subplot, each brain region gets three bars (total / positive / negative):
    the mean over subjects (axis 0), with the standard deviation over subjects
    as error bars. If the number of features is odd, the spare subplot is hidden.

    Parameters:
    - discourse_features (list):    List of discourse feature names (e.g., ['NNSE', 'Speech', 'Emotion']).
    - brain_regions (list):         List of brain region names (e.g., ['Region1', 'Region2', 'Region3']).
    - total_attrs (np.array):       Array of shape (S, F, R) with total attributions per brain region for a
                                    certain discourse feature for a given subject.
    - positive_attrs (np.array):    Array of shape (S, F, R) with positive attributions per brain region for a
                                    certain discourse feature for a given subject.
    - negative_attrs (np.array):    Array of shape (S, F, R) with negative attributions per brain region for a
                                    certain discourse feature for a given subject.
    - output_file (str or Path):    Destination image path passed to ``savefig``.

    Raises:
    - ValueError: if ``discourse_features`` is empty (there is nothing to plot).
    """
    n_features = len(discourse_features)
    if n_features == 0:
        raise ValueError("discourse_features must not be empty")

    # Mean and standard deviation across subjects (axis 0) -> shape (F, R).
    total_attrs_mean = np.mean(total_attrs, axis=0)
    positive_attrs_mean = np.mean(positive_attrs, axis=0)
    negative_attrs_mean = np.mean(negative_attrs, axis=0)
    total_attrs_std = np.std(total_attrs, axis=0)
    positive_attrs_std = np.std(positive_attrs, axis=0)
    negative_attrs_std = np.std(negative_attrs, axis=0)

    bar_width = 0.2  # Width of each bar group
    x = np.arange(len(brain_regions))  # x positions for bars
    num_rows = (n_features + 1) // 2  # two subplots per row, rounded up

    # squeeze=False keeps `axes` 2-D even when there is a single row, so the
    # [row, col] indexing below also works for one or two features.
    fig, axes = plt.subplots(num_rows, 2, figsize=(14, num_rows * 5), squeeze=False)
    try:
        for feature_idx, feature in enumerate(discourse_features):
            ax = axes[feature_idx // 2, feature_idx % 2]
            ax.bar(x - bar_width, total_attrs_mean[feature_idx], bar_width, yerr=total_attrs_std[feature_idx], label='Total Attributions', color='lightgreen', capsize=5)
            ax.bar(x, positive_attrs_mean[feature_idx], bar_width * 0.5, yerr=positive_attrs_std[feature_idx], label='Positive Attributions', color='lightblue', capsize=5)
            ax.bar(x + bar_width, negative_attrs_mean[feature_idx], bar_width * 0.5, yerr=negative_attrs_std[feature_idx], label='Negative Attributions', color='lightcoral', capsize=5)

            ax.set_xlabel('Brain Regions', fontsize=12)
            ax.set_ylabel('Attribution', fontsize=12)
            ax.set_title(f'Attribution Values for {feature}', fontsize=14)
            ax.set_xticks(x)
            ax.set_xticklabels(brain_regions, rotation=45, ha='right')
            ax.legend(fontsize=10)
            ax.grid(axis='y', linestyle='--', alpha=0.7)

        # With an odd number of features the last grid cell has no data.
        for unused_ax in axes.ravel()[n_features:]:
            unused_ax.set_visible(False)

        fig.tight_layout()
        fig.savefig(output_file)
    finally:
        # Close exactly this figure, even if plotting or saving fails.
        plt.close(fig)