import re, glob
import pandas as pd
from helpers import create_folder_by_date
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle
import os


# Set KMP_DUPLICATE_LIB_OK=TRUE in the process environment
# (presumably to silence duplicate OpenMP runtime errors - confirm if still needed)
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

# Base directory under which the result folders are written
out_dst = r'../results/'

# Learning-rate / momentum values of the best cell of each experiment,
# filled in by the main script
lr_values, mm_values = [], []


# Make sure the base output directory is present; a failure to create it is
# reported on stdout but does not stop the script
if os.path.exists(out_dst):
    print(f"Folder '{out_dst}' already exists.")
else:
    try:
        os.makedirs(out_dst)
        print(f"Folder '{out_dst}' created successfully.")
    except OSError as err:
        print(f"Error occurred while creating folder: {err}")

# Regular expression matching decimal numbers: digits, a mandatory decimal
# point, then more digits (plain integers such as "100" are NOT matched)
pattern = r'\d+\.\d+'


def plot(df, title_label, out):
    """
    Plot a heatmap of the data, outline the ten highest cells (green) and the best cell(s) (blue),
    and save the figure to a file.

    Missing cells (NaN) are ignored when ranking: NumPy sorts NaN to the end, so without this
    filtering the "top" cells of a sparse grid would be the empty ones. If the grid holds fewer
    than ten valid cells, all of them are treated as the top group.

    Parameters:
    df (DataFrame): The data to be plotted. Columns are reported as learning rates and the
        index as momentum values. The colour scale is fixed to 0-100.
    title_label (str): The label to be used in the title of the plot. The mean and standard
        deviation of the top-10 values are appended to it.
    out (str): The path where the plot will be saved.

    Returns:
    tuple: (top_ten_lr, top_ten_mm, top_one_lr, top_one_mm). Each element is a pandas Index
        taken from df.columns (lr) or df.index (mm). The top-10 entries are ordered by
        ascending cell value; the top-1 entries contain every cell tied for the maximum.
    """
    scores = df.to_numpy(dtype=float)
    flat = scores.ravel()

    # argsort places NaN after every real number, so drop missing cells
    # before taking the tail of the ranking
    order = np.argsort(flat)
    order = order[~np.isnan(flat[order])]
    top_ten_flat = order[-10:]
    max_ten = flat[top_ten_flat]

    if max_ten.size:
        mean_value = np.mean(max_ten)
        std_value = np.std(max_ten)
        # max_ten is in ascending order, so its last entry is the overall maximum
        max_value = max_ten[-1]
        max_rows, max_cols = np.where(scores == max_value)
    else:
        # No valid cell at all: nothing to summarise or highlight
        mean_value = std_value = float('nan')
        max_rows = max_cols = np.array([], dtype=int)

    # Positional (row, col) coordinates of the top-10 cells
    top_ten_rows, top_ten_cols = np.unravel_index(top_ten_flat, scores.shape)

    plt.figure()
    plt.rcParams['font.size'] = 8
    ax = sns.heatmap(df, annot=True, fmt=".2f", vmin=0.0, vmax=100.0, linewidth=.5)
    ax.set(xlabel="Learning rate",
           ylabel="Momentum",
           title=f"{title_label} (Mean: {mean_value:.2f}, Std: {std_value:.2f})")

    # Highlight the top-10 values with rectangles; heatmap cell (row, col)
    # has its corner at x=col, y=row and is one unit wide and high
    for row, col in zip(top_ten_rows, top_ten_cols):
        ax.add_patch(Rectangle((col, row), 1, 1, fill=False, edgecolor='green', lw=2))

    # Highlight the top-1 value(s) last so the blue outline is drawn over the green one
    for row, col in zip(max_rows, max_cols):
        ax.add_patch(Rectangle((col, row), 1, 1, fill=False, edgecolor='blue', lw=2))

    plt.savefig(out, dpi=500)
    plt.close()

    # Return both top-10 and top-1 lr and mm values
    top_ten_lr = df.columns[top_ten_cols]
    top_ten_mm = df.index[top_ten_rows]
    top_one_lr = df.columns[max_cols]
    top_one_mm = df.index[max_rows]
    return top_ten_lr, top_ten_mm, top_one_lr, top_one_mm


def extract_numbers(text):
    """
    Collect every match of the module-level regular expression `pattern` in a text.

    Parameters:
    text (str): The text to scan.

    Returns:
    list: The matched substrings converted to floats, in order of appearance
        (an empty list when nothing matches).
    """
    return list(map(float, re.findall(pattern, text)))


def extract_values_from_filename(filename):
    """
    Function to extract learning rate and momentum values from a filename using regular expressions.

    The name is expected to start with 'mnist_lr<float>momentum<float>', where both values
    contain a decimal point. A leading directory part is ignored, whether it uses '/' or '\\'
    as separator, so full paths can be passed on any platform.

    Parameters:
    filename (str): The filename (or path) from which to extract values.

    Returns:
    tuple: (lr, momentum) as floats, or (None, None) if the name does not match.
    """
    # Keep only the last path component; the match below is anchored at the
    # start of the string and would otherwise fail on any directory prefix
    base_name = re.split(r'[\\/]', filename)[-1]

    match = re.match(r'mnist_lr(\d+\.\d+)momentum(\d+\.\d+)', base_name)
    if match is None:
        return None, None
    return float(match.group(1)), float(match.group(2))


if __name__ == '__main__':
    # Main entry point of the script. It processes a list of items, each naming one experiment
    # folder under '../runs/'. For each experiment it reads the result text files, takes the
    # learning rate and momentum from each file name, builds a momentum x learning-rate grid
    # and plots it as a heatmap with the top-10 and top-1 cells highlighted.
    # The heatmaps are saved to a new folder created in the base output directory.

    lst = ['adam_mnist_100', 'adam_mnist_50', 'sgd_mnist_100', 'sgd_mnist_50']

    # Create a new folder in the base output directory with the current date
    new_folder = create_folder_by_date(out_dst)
    print(f"New folder created: {new_folder}")

    # Lists to store the top learning rate and momentum values for all experiments
    all_top_lr, all_top_mm, all_top_one_lr, all_top_one_mm = [], [], [], []

    # Process each experiment
    for item in lst:
        print(item)
        folder_path = '../runs/' + item + '/'

        # Create empty lists to store extracted data
        data = []
        filenames = []

        # Result files are looked up exactly one directory level below the experiment folder
        txt_files = glob.glob(folder_path + '/*/mnist_lr*.txt', recursive=True)
        for filename in txt_files:
            if not filename.endswith('.txt'):
                continue
            with open(filename, 'r') as file:
                text = file.read()
            numbers = extract_numbers(text)
            # os.path.basename strips the directory part with the platform's own separator;
            # splitting on '\\' only worked on Windows and left lr/momentum as None elsewhere
            lr, momentum = extract_values_from_filename(os.path.basename(filename))
            data.append(numbers)
            filenames.append((filename, lr, momentum))

        # Nothing to plot for this experiment: report it and move on instead of crashing
        if not data:
            print(f"No result files found in '{folder_path}', skipping.")
            continue

        # Create a DataFrame from the collected data
        df = pd.DataFrame(data, columns=['test_mnist'])

        # Add columns from filename data
        df['lr'] = [entry[1] for entry in filenames]
        df['mm'] = [entry[2] for entry in filenames]

        # Create a pivot table and sort the index in descending order
        name = folder_path.split('/')[-2]
        temp = df.pivot_table(values="test_mnist", index="mm", columns="lr")
        temp = temp.sort_index(ascending=False)
        temp = temp.round(2)

        # No file name yielded a usable (lr, momentum) pair
        if temp.empty:
            print(f"No usable results in '{folder_path}', skipping.")
            continue

        # Plot the heatmap and get the top 10 and top 1 learning rate and momentum values
        top_ten_lr, top_ten_mm, top_one_lr, top_one_mm = plot(temp, 'Accuracy', new_folder + '/' + name + '_mnist.jpg')

        # Add the top values to the lists
        all_top_lr.extend(top_ten_lr)
        all_top_mm.extend(top_ten_mm)
        all_top_one_lr.extend(top_one_lr)
        all_top_one_mm.extend(top_one_mm)

        # Get the top 1 value and its location; nanmax skips grid cells that have
        # no result (a plain sort would put NaN last and report it as the "best")
        top1 = np.nanmax(temp.to_numpy())
        loc = np.where(temp.to_numpy() == top1)

        # Add the learning rate and momentum values at the top 1 location to the lists
        lr_values.extend(temp.columns[loc[1]])
        mm_values.extend(temp.index[loc[0]])