import logging
import math
import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.metrics import confusion_matrix, classification_report, precision_score, recall_score, \
    f1_score, roc_auc_score
from tabulate import tabulate


def calculate_accuracy(outputs, labels, threshold=None):
    """
    Calculate element-wise accuracy for multilabel classification.

    outputs: Model output logits [batch_size, num_labels]
    labels: Ground truth labels [batch_size, num_labels]
    threshold: Probability cut-off used to binarize predictions; an entry is
        predicted positive when sigmoid(logit) > threshold. When None (the
        default), the mean predicted probability over the whole batch is used
        as an adaptive cut-off.

    Returns:
        overall_accuracy: Python float, fraction of all (sample, label) entries
            whose prediction equals the label.
        accuracy_per_label: tensor [num_labels], accuracy of each label column
            (mean over dim 0).
        preds: float tensor of 0/1 predictions, same shape as outputs.
        probs: sigmoid probabilities, same shape as outputs.
    """
    # Convert logits to probabilities
    probs = torch.sigmoid(outputs)  # [batch_size, num_labels]

    # Previously any caller-supplied threshold was overwritten by the batch
    # mean; only fall back to that adaptive cut-off when none is given.
    if threshold is None:
        threshold = torch.mean(probs).item()
    preds = (probs > threshold).float()  # [batch_size, num_labels]

    # 1.0 where the prediction matches the ground truth, 0.0 otherwise
    correct = (preds == labels).float()  # [batch_size, num_labels]

    accuracy_per_label = correct.mean(dim=0)  # [num_labels]
    overall_accuracy = correct.mean().item()

    return overall_accuracy, accuracy_per_label, preds, probs


def retain_first_jump(label_token):
    """
    Keep only the first 1 along the sequence axis of each feature column, and
    zero out columns that are constant (entirely 0 or entirely 1).

    label_token: tensor indexed as [batch, seq, feature] (exactly three dims
        are unpacked). Presumably binary 0/1 labels; TODO confirm with callers.

    Returns a float tensor with the same shape. For binary input, each
    (batch, feature) column keeps only its first 1 along the seq axis. A column
    whose sum is 0 or equals seq_len becomes all zeros.
    """
    _, seq_len, _ = label_token.shape

    # Running count of ones along the seq axis equals 1 from the first 1 up to
    # (but not including) the second 1. Multiplying by the label inside that
    # window keeps exactly the first 1.
    ones_so_far = (label_token == 1).to(torch.int).cumsum(dim=1)
    kept = (label_token * (ones_so_far == 1)).float()

    # Constant columns (no jump at all) are cleared entirely.
    column_sum = label_token.sum(dim=1)
    constant_column = (column_sum == 0) | (column_sum == seq_len)
    return kept.masked_fill(constant_column.unsqueeze(1), 0.0)



def retain_first_jump_v2(label_token, passthrough_features=(0, 1, 15, 16)):
    """
    Keep only the first valid 1 of each feature's sequence, leaving padding
    entries and selected pass-through features untouched.

    label_token: tensor indexed as [batch, seq, feature] (exactly three dims
        are unpacked). Entries equal to -1 are treated as padding. Assumes the
        non-padding entries are 0/1; TODO confirm with callers.
    passthrough_features: indices along the feature axis whose values are copied
        through unchanged. The default (0, 1, 15, 16) reproduces the indices this
        function always skipped, and it requires at least 17 features. Pass an
        empty tuple to process every feature.

    Returns a new tensor with the same shape and dtype as label_token. For every
    non-pass-through feature, only the first non-padding 1 along the seq axis is
    kept and the other non-padding entries become 0. Every -1 entry remains -1.
    The input tensor is not modified.
    """
    _, _, num_features = label_token.shape

    processed_label = label_token.clone()

    padding_mask = label_token == -1

    # True for features that get the first-jump treatment.
    feature_mask = torch.ones(num_features, dtype=torch.bool, device=label_token.device)
    passthrough = list(passthrough_features)
    if passthrough:
        feature_mask[passthrough] = False

    # The running count of valid ones is 1 from the first valid 1 until the next
    # one. Multiplying by label_token keeps just that first 1. Padding positions
    # may pick up -1 or 0 here; they are restored to -1 below.
    first_one_pos = ((label_token == 1) & ~padding_mask).to(torch.int).cumsum(dim=1) == 1

    processed_label[:, :, feature_mask] = (label_token * first_one_pos)[:, :, feature_mask]

    processed_label[padding_mask] = -1

    return processed_label

def plot_lines(lines_data,
               xlabel="",
               ylabel_left="",
               ylabel_right=None,
               title="",
               figsize=(8, 6),
               label_size=14,
               tick_size=12,
               title_size=16,
               legend_size=15,
               dpi=500,
               save_string="xts_plot_lines.jpg",
               show_plot=True):
    """
    Plot one or more line series, optionally split across a left and a right y-axis.

    Parameters:
    - lines_data: iterable of dicts, one per series. Recognized keys:
        'x'           x-axis data points (required)
        'y'           y-axis data points (required)
        'style'       matplotlib format string such as 'o-g' (default '*-r')
        'label'       legend label (default '')
        'y_axis'      'left' or 'right' (default 'left'); any value other than
                      'left' is drawn on the right axis
        'marker_size' marker size, converted with int() (default 5)
    - xlabel: label for the x-axis.
    - ylabel_left: label for the primary (left) y-axis.
    - ylabel_right: label for the secondary (right) y-axis. A truthy value enables
      the right axis; it is required when any series targets the right axis.
    - title: plot title.
    - figsize: (width, height) of the figure in inches.
    - label_size: font size of the axis labels.
    - tick_size: font size of the axis tick labels.
    - title_size: font size of the title.
    - legend_size: font size of the legend.
    - dpi: resolution used when saving the image.
    - save_string: file path or name the image is saved to.
    - show_plot: if True, call plt.show(); otherwise the figure is closed after
      saving so repeated calls do not accumulate open figures.

    Raises:
    - ValueError: a series targets the right axis but ylabel_right is not set.

    Returns:
    None. The plot is saved to save_string and optionally displayed.

    Examples:
    Single y-axis:
        lines_data = [
            {'x': range(1, 4), 'y': [5, 9, 4], 'label': 'Data Set 1', 'style': 'o-r', 'marker_size': 5},
            {'x': range(1, 6), 'y': [3, 5, 8, 3, 5], 'label': 'Data Set 2', 'style': 'x--b', 'marker_size': 6},
            {'x': range(1, 6), 'y': [2, 3, 2, 4, 3], 'label': 'Data Set 3', 'style': 's:g', 'marker_size': 4},
        ]
        plot_lines(lines_data, xlabel="X Axis", ylabel_left="Y Axis",
                   title="Complex Line Chart", save_string="complex_line_chart.jpg")

    Dual y-axis (set 'y_axis' to 'left' or 'right' on each series and pass ylabel_right):
        lines_data = [
            {'x': range(1, 4), 'y': [5, 9, 4], 'label': 'Data Set 1', 'style': 'o-r', 'marker_size': 5, 'y_axis': 'left'},
            {'x': range(1, 6), 'y': [3, 5, 8, 3, 5], 'label': 'Data Set 2', 'style': 'x--b', 'marker_size': 6, 'y_axis': 'right'},
            {'x': range(1, 6), 'y': [2, 3, 2, 4, 3], 'label': 'Data Set 3', 'style': 's:g', 'marker_size': 4, 'y_axis': 'left'},
        ]
        plot_lines(lines_data, xlabel="X Axis", ylabel_left="Y Axis", ylabel_right="Y test",
                   title="Complex Line Chart", save_string="complex_line_chart.jpg")
    """
    # Materialize once: the series are scanned for validation and then plotted.
    lines_data = list(lines_data)

    # Fail before creating a figure instead of crashing on a None axis mid-plot.
    if not ylabel_right:
        for line in lines_data:
            if line.get('y_axis', 'left') != 'left':
                raise ValueError(
                    "series %r targets the right y-axis but ylabel_right was not provided"
                    % (line.get('label', ''),))

    # Create the figure and primary y-axis, plus a secondary one if requested
    fig, ax_left = plt.subplots(figsize=figsize)
    ax_right = None
    if ylabel_right:
        ax_right = ax_left.twinx()
        ax_right.set_ylabel(ylabel_right, size=label_size)

    # Plot each series and collect handles/labels for a single combined legend
    handles, labels = [], []
    for line in lines_data:
        ax = ax_left if line.get('y_axis', 'left') == 'left' else ax_right
        label = line.get('label', '')
        ln, = ax.plot(line['x'], line['y'], line.get('style', '*-r'), label=label,
                      markersize=int(line.get('marker_size', 5)))
        handles.append(ln)
        labels.append(label)

    # Axis labels and tick label sizes
    ax_left.set_xlabel(xlabel, size=label_size)
    ax_left.set_ylabel(ylabel_left, size=label_size)
    ax_left.tick_params(axis='x', labelsize=tick_size)
    ax_left.tick_params(axis='y', labelsize=tick_size)
    if ax_right is not None:
        ax_right.tick_params(axis='y', labelsize=tick_size)

    # One legend on the primary axis covering both axes' lines
    ax_left.legend(handles, labels, fontsize=legend_size)

    # Target this figure explicitly rather than pyplot's "current" figure/axes
    ax_left.set_title(title, size=title_size)
    fig.savefig(save_string, dpi=dpi, bbox_inches="tight")

    if show_plot:
        plt.show()
    else:
        # Release the figure; otherwise repeated calls keep every figure alive.
        plt.close(fig)
