import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import re
import matplotlib.colors as mcolors
import matplotlib.cm as cm
# Training-data percentages whose results are kept by default; records for
# any other percentage are dropped by ``parse_text``.
PERCENTS = [20, 30, 40, 60, 80, 100]

# One result record: "traffic_<interto>", later "percentIs<pct>", later
# "mse:<float>". DOTALL lets the lazy ``.*?`` gaps span newlines.
_RESULT_RE = re.compile(r"traffic_(\d+).*?percentIs(\d*).*?mse:(\d+\.\d+)", re.DOTALL)


def parse_text(text, percents=None):
    """Extract ``(interto, percentage, mse)`` records from a results log.

    Args:
        text: Full contents of the results log.
        percents: Iterable of percentages to keep. Defaults to ``PERCENTS``.

    Returns:
        List of ``(int, int, float)`` tuples in the order they appear in
        ``text``, restricted to the requested percentages.
    """
    keep = set(PERCENTS if percents is None else percents)
    records = []
    for interto, percent, mse in _RESULT_RE.findall(text):
        # ``percentIs`` may be followed by no digits; int('') would raise and
        # abort the whole parse, so skip that record instead.
        if not percent:
            continue
        percent_value = int(percent)
        if percent_value in keep:
            records.append((int(interto), percent_value, float(mse)))
    return records

def read_data(file_path):
    """Load a results log and return its parsed records as a DataFrame.

    Args:
        file_path: Path to the text log to parse.

    Returns:
        DataFrame with columns ``Interto``, ``Percentage`` and ``MSE``, one
        row per record returned by ``parse_text`` (an empty frame with these
        columns if nothing matched).
    """
    # Explicit encoding: the default is locale-dependent, so the same log
    # could decode differently (or fail) on another machine.
    with open(file_path, 'r', encoding='utf-8') as file:
        text = file.read()
    results = parse_text(text)
    return pd.DataFrame(results, columns=['Interto', 'Percentage', 'MSE'])

def plot_data(df, output_path="newresult_iTF_traffic_horizonXDatascaling___byhorizon.png"):
    """Plot mean MSE against horizon, one line per training-data percentage.

    Each line is coloured by its percentage through a log-scaled viridis
    colour map, and a colorbar marks every distinct percentage. The figure
    is saved to ``output_path`` and then closed.

    Args:
        df: DataFrame with ``Interto`` (plotted as horizon), ``Percentage``
            and ``MSE`` columns, as produced by ``read_data``.
        output_path: Destination image file; defaults to the file name this
            script has always written.

    Raises:
        ValueError: If ``df`` is empty (nothing to plot, and the log colour
            norm would have no valid range).
    """
    if df.empty:
        raise ValueError("plot_data() received an empty DataFrame; nothing to plot")

    # LogNorm: colours are spaced by log(percentage) rather than linearly.
    norm = mcolors.LogNorm(vmin=df['Percentage'].min(), vmax=df['Percentage'].max())
    scalar_map = cm.ScalarMappable(norm=norm, cmap=cm.viridis)

    fig, ax = plt.subplots(figsize=(10, 6))

    for percentage, group in df.groupby('Percentage'):
        # Average repeated runs at the same horizon; groupby returns the
        # horizons sorted, so each line is drawn left to right.
        mean_mse = group.groupby('Interto')['MSE'].mean()
        ax.plot(mean_mse.index, mean_mse.values, 'o-',
                label=f'{percentage}%', color=scalar_map.to_rgba(percentage))

    ax.set_title('MSE by Horizon. Model: iTransformer. Dataset: traffic. Pred len: 192.')
    ax.set_xlabel('Horizon')
    ax.set_ylabel('MSE')
    ax.set_ylim(0.365, 0.45)
    ax.set_xticks([128, 256, 512, 768, 1024])
    ax.set_xticklabels(['128', '256', '512', '768', '1024'])

    # scalar_map is not attached to any Axes, so name the Axes the colorbar
    # should take space from; recent matplotlib raises instead of guessing.
    cbar = fig.colorbar(scalar_map, ax=ax, label='Percent of Training Data Used')
    tick_locs = np.unique(df['Percentage'])  # one tick per distinct percentage
    cbar.set_ticks(tick_locs)
    cbar.set_ticklabels(tick_locs)

    ax.grid(True)
    ax.legend(title='', loc='upper right')
    fig.savefig(output_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    

if __name__ == "__main__":
    # Parse the results log, then render and save the MSE-vs-horizon figure.
    log_path = 'newresult_iTF_traffic_HorizonXData.txt'
    results_df = read_data(log_path)
    plot_data(results_df)



# import matplotlib.pyplot as plt
# import pandas as pd
# import numpy as np
# import re
# import matplotlib.colors as mcolors
# import matplotlib.cm as cm

# def parse_text(text):
#     # Use regex to extract the interto value, percentage, and mse
#     pattern = r"interto(\d+).*?_p(\d+).*?mse:(\d+\.\d+)"
#     matches = re.findall(pattern, text, re.DOTALL)
#     final_list = [(int(m[0]), int(m[1]), float(m[2])) for m in matches]
#     new_final_list = []
#     for answer_tuple in final_list:
#         if answer_tuple[0] >= 128 and answer_tuple[0] <= 2000 and answer_tuple[1] >= 11 and answer_tuple[1] != 80 and (answer_tuple[1] >= 20 or answer_tuple[0] <= 1000) and (answer_tuple[1] >= 12 or answer_tuple[0] <= 800):
#         # if answer_tuple[0] >= 32 and answer_tuple[0] <= 768 and answer_tuple[0] != 3 and answer_tuple[0] != 6 and answer_tuple[0] != 336 and answer_tuple[0] != 512:
#             new_final_list.append(answer_tuple)
    
    
    
#     return new_final_list

# def read_data(file_path):
#     with open(file_path, 'r') as file:
#         text = file.read()
#     results = parse_text(text)
#     return pd.DataFrame(results, columns=['Interto', 'Percentage', 'MSE'])

# def plot_data(df):
#     # Setup the colormap
#     norm = mcolors.LogNorm(vmin=df['Percentage'].min(), vmax=df['Percentage'].max())
#     scalar_map = cm.ScalarMappable(norm=norm, cmap=cm.viridis)

#     plt.figure(figsize=(10, 6))
#     grouped = df.groupby('Percentage')
    
#     for percentage, group in grouped:
#         group.sort_values('Interto', inplace=True)
#         color = scalar_map.to_rgba(percentage)
#         plt.plot(group['Interto'], group['MSE'], label=f'Percentage {percentage}', marker='o', markersize=4, color=color)
    
#     plt.title('MSE by Percentage and horizon: etth1 Horizon -> 336')
#     plt.xlabel('horizon')
#     # use log scale for x-axis
#     plt.xscale('log')
#     plt.ylabel('MSE')
    
#     # Colorbar with custom ticks
#     cbar = plt.colorbar(scalar_map, label='Percentage of Data Used')
#     tick_locs = np.unique(df['Percentage'])  # Unique percentage values
#     cbar.set_ticks(tick_locs)
#     cbar.set_ticklabels(tick_locs)
    
#     plt.grid(True)
#     plt.legend(title='Percentage', loc='upper right')
#     plt.savefig("MSE_vs_horizon_ETTh1.png")

# if __name__ == "__main__":
#     file_path = 'newresult_ETTh1_horizon.txt'
#     df = read_data(file_path)
#     plot_data(df)


