import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import numpy as np
import torch
from tqdm import tqdm

#Create Plot for a result-dict
def create_mse_plot_picture(result_path, result_dict, result_dict_method, trashhold_value=0, x_axis=0):
    """Save a line plot of the anomaly scores with a horizontal threshold line.

    The image is written to
    ``result_path + result_dict["dataset"] + "_" + str(result_dict["save_epoch"]) + "_line_plot.png"``.
    ``result_path`` is used as a plain string prefix, so it must already end
    with a path separator if it names a directory.

    Args:
        result_path: String prefix of the output file path.
        result_dict: Mapping providing ``"anomalie_score"`` (a torch tensor,
            since ``torch.mean``/``torch.std`` are applied to it),
            ``"dataset"`` and ``"save_epoch"``.
        result_dict_method: Mapping providing ``"trashhold_value"``, the
            threshold that is drawn as a dashed horizontal line.
        trashhold_value: Ignored. The value is always replaced by
            ``result_dict_method["trashhold_value"]``; the parameter is kept
            only so existing callers keep working.
        x_axis: Values for the x axis. ``None`` or the scalar ``0`` (the
            default) means "use the sample index".
    """
    scores = result_dict["anomalie_score"]

    # Caution: mean/std are computed over ALL scores here and are only
    # printed for information; do not build the threshold like this later.
    mean = torch.mean(scores)
    std = torch.std(scores)
    trashhold_value = result_dict_method["trashhold_value"]
    print("Mean: ", mean, " Std: ", std)

    # Only a scalar 0 (or None) selects the default index axis. Comparing an
    # array-like x_axis with `== 0` directly would be evaluated element-wise
    # and make the `if` raise for more than one element.
    if x_axis is None or (np.isscalar(x_axis) and x_axis == 0):
        x = range(len(scores))
    else:
        x = x_axis

    # matplotlib cannot convert tensors that track gradients or live on an
    # accelerator, so hand it a plain NumPy array.
    y = scores.detach().cpu().numpy()

    fig, ax = plt.subplots(figsize=(8, 5))
    try:
        ax.plot(x, y, label="Werte", linestyle="-", color="b")  # solid blue line
        ax.axhline(
            y=trashhold_value,
            color="r",
            linestyle="--",
            label=f"Threshold = {trashhold_value}",
        )  # dashed red threshold line
        ax.set_xlabel("Index")
        ax.set_ylabel("Wert")
        ax.set_title("Liniendiagramm mit Schwellenwert")
        ax.legend()
        fig.savefig(
            result_path + result_dict["dataset"] + "_" + str(result_dict["save_epoch"]) + "_line_plot.png",
            dpi=300,
            bbox_inches="tight",
        )
    finally:
        # Close the figure instead of only clearing it; otherwise every call
        # leaves one more open figure behind in pyplot's registry.
        plt.close(fig)
    print("Bild gespeichert'")


def create_plot_interactive(result_path, result_dict, trashhold_value=0, x_axis=0):
    """Write an interactive plotly line chart of the anomaly scores as HTML.

    The file is written to
    ``result_path + result_dict["dataset"] + "_" + str(result_dict["save_epoch"]) + "_interactive_plot.html"``.
    ``result_path`` is used as a plain string prefix.

    Args:
        result_path: String prefix of the output file path.
        result_dict: Mapping providing ``"anomalie_score"`` (the y values),
            ``"dataset"`` and ``"save_epoch"``.
        trashhold_value: Currently unused by this function; kept so existing
            callers keep working.
        x_axis: Values for the x axis. ``None`` or the scalar ``0`` (the
            default) means "use the sample index".
    """
    scores = result_dict["anomalie_score"]

    # Only a scalar 0 (or None) selects the default index axis. Comparing an
    # array-like x_axis with `== 0` directly would be evaluated element-wise
    # and make the `if` raise for more than one element.
    if x_axis is None or (np.isscalar(x_axis) and x_axis == 0):
        x = list(range(len(scores)))
    else:
        x = x_axis

    # Build the figure with a single line trace.
    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x, y=scores, mode='lines', name='Line'))

    # Titles for the chart and both axes.
    fig.update_layout(
        title="Interactive Line Chart",
        xaxis_title="X Axis",
        yaxis_title="Y Axis",
    )

    # Save as a standalone HTML file.
    fig.write_html(
        result_path + result_dict["dataset"] + "_" + str(result_dict["save_epoch"]) + "_interactive_plot.html"
    )

def create_plot_heatmap(result_path, result_dict):
    """Save a grayscale heatmap of the per-feature anomaly scores as a PNG.

    ``result_dict["anomalie_score_all_features"]`` is transposed before it is
    handed to seaborn. The image is written to
    ``result_path + result_dict["dataset"] + "_" + str(result_dict["save_epoch"]) + "_heatmap.png"``,
    with ``result_path`` used as a plain string prefix.

    No figure is created here, so the heatmap is drawn through pyplot's
    current-figure state, and that figure is cleared (not closed) afterwards.
    """
    # Transpose so the original rows become heatmap columns.
    scores_by_feature = result_dict["anomalie_score_all_features"].T
    sns.heatmap(scores_by_feature, cmap="gray_r")

    prefix = result_path + result_dict["dataset"]
    file_path = prefix + "_" + str(result_dict["save_epoch"]) + "_heatmap.png"
    plt.savefig(file_path, dpi=300, bbox_inches="tight")
    plt.clf()

