import os
import glob
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from datetime import datetime
import torch
import math
import random

def get_latest_folder(folder_names):
    """Return the base name of the most recent ``testing-<timestamp>`` folder.

    Args:
        folder_names: iterable of folder paths (e.g. the result of ``glob.glob``);
            only the base name of each path is inspected.

    Returns:
        The base name whose timestamp (format
        ``'testing-%Y-%m-%d %H:%M:%S.%f'``) is the newest, or ``None`` when no
        name matches that format. On equal timestamps the first one seen wins.
    """
    latest_name = None
    latest_time = None
    for path in folder_names:
        name = os.path.basename(path)
        try:
            folder_time = datetime.strptime(name, 'testing-%Y-%m-%d %H:%M:%S.%f')
        except ValueError:
            # Folders matched by a loose glob (e.g. "testing-test_sample") do not
            # carry a timestamp; skip them instead of aborting the whole scan.
            continue
        if latest_time is None or folder_time > latest_time:
            latest_name, latest_time = name, folder_time
    return latest_name

# Run-name stems of the experiments to analyse. Most helpers below read
# outputs/test_<model_name>_<corruption>/<testing-timestamp>/stats.json.
# model_name = "cifar_c_cnn_ligru_act_0.0002"
# NOTE: this CIFAR-100 selection is dead as written -- `model_names` is
# reassigned to the CIFAR-10 list further down before anything reads it.
model_names = [
    # "cifar100_c_resnet_26_width_128",
    # "cifar100_c_resnet_26_width_256",
    # "cifar100_c_cnn_gru_width_128_act_0.1",
    # "logs_cifar100_c_vit",
    # "cifar100_c_resnet_26_width_256",
    # "cifar100_c_resnet_cifar_width_128_depth_30",
    # "cifar100_c_cnn_ligru_width_128",
    "cifar100_c_cnn_gru_gn_width_128",
]
# Alternative Tiny-ImageNet-C selection:
# model_names = [
#     "tiny_imagenet_c_resnet_tiny_imagenet",
#     "tiny_imagenet_c_cnn_gru_gn",
# ]

# Active selection (CIFAR-10-C runs): the module-level list read by most
# functions below and used as a default argument.
model_names = [
    # "resnet_ttt",
    # "cifar_c_resnet_26_width_128",
    # "cifar_c_dt_net_recall_2d_alpha_0.0",
    # "cifar_c_cnn_gru_width_128_act_0.5",
    "cifar_c_cnn_gru_gn_width_128_act_0.5",
    "cifar_c_cnn_gru_gn_width_128_act_0.01",
    # "cifar_c_cnn_gru_width_128_alpha_0.1",
    # "cifar_c_cnn_gru_width_128_alpha_0.5",
    # "cifar100_c_cnn_ligru_width_128_iter_100",
    # "cifar_c_cnn_gru_width_128_alpha_0.1",
    # "cifar_c_cnn_gru_width_128_alpha_0.5",
    # "cifar100_c_cnn_gru_width_256",
    # "cifar100_c_cnn_gru_gn_width_128",
    # "cifar_c_cnn_gru_width_128_self_verify",
    # "cifar_c_cnn_gru_width_128_alpha_0.5"
    # "cifar_c_cnn_ligru_act_0.0002",
    # "cifar_c_cnn_ligru_alpha_0.0",
    # "cifar_c_cnn_gru_width_128",
    "cifar_c_cnn_gru_width_128_thresh_0.03",
    # "cifar_c_cnn_gru_gn_width_128_thresh_0.01",
    # "cifar_c_cnn_gru_width_128_thresh_0.003",
    # "cifar_c_cnn_gru_width_128_thresh_0.001",
    # "cifar_c_cnn_gru_width_128_thresh_0.2",
    # "logs_cifar10_c_vit",
    # "cifar_c_resnet_26_width_128",
    # "cifar_c_resnet_cifar_width_128_depth_30",
    # "cifar_c_cnn_ligru_width_128",
    "cifar_c_cnn_gru_width_128",
    # "cifar_c_resnet_cifar_width_128_depth_30_randomseed",
    # "cifar_c_cnn_ligru_width_128_depth_30_randomseed",
    # "cifar_c_cnn_gru_gn_width_128_depth_30_randomseed",
]

# Corruption names; used verbatim as output-folder suffixes and as plot labels.
corruptions = [
    # 'Total Noise',
    'Gaussian Noise',
    'Shot Noise',
    'Impulse Noise',
    'Defocus Blur',
    'Glass Blur',
    'Motion Blur',
    'Zoom Blur',
    'Snow',
    'Frost',
    'Fog',
    'Brightness',
    'Contrast',
    'Elastic Transform',
    'Pixelate',
    'JPEG Compression',
    # "stl"
]
# Decimal places used when comparing accuracies in estimate_t_opt / estimate_max_acc.
round_level = 1

def estimate_t_opt(test_ssh_acc, test_acc=None, acc_thresh=100):
    """Pick a test-time iteration from the auxiliary (SSH) accuracy curve.

    Iterates ``test_ssh_acc`` in insertion order and moves the selection to an
    iteration only when its accuracy, rounded to ``round_level`` digits, beats the
    rounded current best by more than 0.1. When ``test_acc`` is supplied and
    ``acc_thresh`` is below 100, the scan stops right after selecting an iteration
    whose main-task accuracy exceeds ``acc_thresh``.

    Args:
        test_ssh_acc: mapping iteration key -> auxiliary-task accuracy.
        test_acc: optional mapping iteration key -> main-task accuracy.
        acc_thresh: early-stop threshold on main-task accuracy; 100 disables it.

    Returns:
        The selected iteration key, or the int ``0`` if no entry qualified.
    """
    best_key = 0
    best_acc = 0
    early_stop = test_acc is not None and acc_thresh < 100
    for key, ssh_acc in test_ssh_acc.items():
        improved = round(ssh_acc, round_level) > round(best_acc, round_level) + 0.1
        if not improved:
            continue
        best_key, best_acc = key, ssh_acc
        if early_stop and test_acc[key] > acc_thresh:
            break
    return best_key

def estimate_max_acc(test_acc, acc_thresh=100):
    """Find the peak main-task accuracy and the iteration where it first occurs.

    Accuracies are compared after rounding to ``round_level`` digits, so a later
    iteration replaces the current best only if it is strictly larger once
    rounded. With ``acc_thresh`` below 100 the scan stops as soon as the new best
    exceeds the threshold.

    Args:
        test_acc: mapping iteration key -> main-task accuracy.
        acc_thresh: early-stop threshold; 100 disables it.

    Returns:
        ``(max_acc, t_max)``; ``(0, 0)`` if no accuracy beats 0 after rounding.
    """
    best_acc, best_key = 0, 0
    for key, acc in test_acc.items():
        if round(acc, round_level) > round(best_acc, round_level):
            best_acc, best_key = acc, key
            if acc_thresh < 100 and best_acc > acc_thresh:
                break
    return best_acc, best_key

def create_table():
    """Tabulate per-corruption accuracies for a fixed set of models.

    For each model/corruption pair the latest timestamped run under
    ``outputs/test_<model>_<corruption>/`` is loaded. The test-time iteration is
    chosen with :func:`estimate_t_opt` on the auxiliary (SSH) accuracy curve, and
    the main-task accuracy at that iteration is stored in the table.

    Per model, the mean and sample variance of the chosen iterations are printed.
    The table is written to ``model_corruption_comparison.csv`` and also rendered
    to LaTeX (not printed).

    Raises:
        FileNotFoundError: if a model/corruption pair has no timestamped run.
    """
    model_names = [
        "logs_cifar10_c_vit",
        "cifar_c_resnet_26_width_128",
        "cifar_c_cnn_gru_width_128",
    ]
    # model_names = [
    #     "cifar100_c_resnet_cifar_alpha_0.0",
    #     "cifar100_c_cnn_gru_gn_alpha_0.0",
    #     "cifar100_c_cnn_ligru_alpha_0.0",
    # ]
    accuracy_data = np.zeros((len(corruptions), len(model_names)))
    for i, model_name in enumerate(model_names):
        print("model name: ", model_name)
        t_opts = []
        for j, corruption in enumerate(corruptions):
            run_root = os.path.join("outputs", "test_" + model_name + "_" + corruption)
            latest = get_latest_folder(glob.glob(os.path.join(run_root, "test*")))
            if latest is None:
                raise FileNotFoundError(f"No timestamped test run found under {run_root}")
            with open(os.path.join(run_root, latest, "stats.json")) as f:
                stats = json.load(f)
            test_acc = stats["test_acc"]

            t_opt = estimate_t_opt(test_ssh_acc=stats["test_ssh_acc"])
            t_opts.append(int(t_opt))
            accuracy_data[j, i] = test_acc[t_opt]
        t_opts = np.array(t_opts)
        print("mean: ", np.mean(t_opts))
        print("var: ", np.var(t_opts, ddof=1))
        print("_" * 50)

    model_names = [name.replace("cifar100_c_", "").replace("_alpha_0.0", "").replace("_", "") for name in model_names]
    df = pd.DataFrame(accuracy_data, columns=model_names, index=corruptions)

    # Turn the corruption index into a regular, named column.
    df.reset_index(inplace=True)
    df.rename(columns={'index': 'Corruption Type'}, inplace=True)

    df.to_csv('model_corruption_comparison.csv', index=False)

    # One left-aligned label column plus one centred column per model, built from
    # the table shape so the tabular spec always matches the column count.
    column_format = "|" + "|".join(["l"] + ["c"] * len(model_names)) + "|"
    latex_table = df.to_latex(index=False, float_format="%.2f", column_format=column_format)
    # print(latex_table)
    
# create_table()

def create_bar_chart():
    """Compare models across corruptions: print summaries and LaTeX rows, save a bar chart.

    How each entry of the module-level ``model_names`` is read:

    * ``"resnet_ttt"``: results come from ``read_ttt()``.
    * Names containing ``"log"``: results come from ``read_feedforward(model_name)``.
    * Names containing ``"resnet"``: the accuracy at the last recorded iteration.
    * All other (recurrent) models: the bar value is the peak main-task accuracy
      from :func:`estimate_max_acc`. The printed "Avg acc" uses the iteration
      chosen by :func:`estimate_t_opt`.

    Outputs:

    * The first three models are drawn as grouped bars into ``bar_chart.pdf``.
    * One LaTeX row per model is printed, with the best value per corruption in bold.
    * The rounded peak accuracies are printed.
    * The mean and variance of ``gap_accs`` are printed.

    Raises:
        FileNotFoundError: if a model/corruption pair has no timestamped run.
    """
    acc_list = {}
    acc_list_round = {}
    # Rounded peak accuracies of the recurrent models, appended model after model.
    acc_max_list = []
    for model_name in model_names:
        if model_name == "resnet_ttt":
            # NOTE(review): read_ttt is not defined in this file's visible source.
            acc_list[model_name] = read_ttt()
            # The LaTeX loop below indexes acc_list_round for every model.
            acc_list_round[model_name] = [round(x, 1) for x in acc_list[model_name]]
            avg_acc = sum(acc_list[model_name]) / len(acc_list[model_name])
            print("Model name {}, Avg acc: {}, Max acc: {}".format(model_name, str(avg_acc), str(avg_acc)))
            continue
        if model_name.find("log") >= 0:
            # NOTE(review): read_feedforward is not defined in this file's visible source.
            acc_list[model_name] = read_feedforward(model_name)
            acc_list_round[model_name] = [round(x, 1) for x in acc_list[model_name]]
            avg_acc = sum(acc_list[model_name]) / len(acc_list[model_name])
            print("Model name {}, Avg acc: {}, Max acc: {}".format(model_name, str(avg_acc), str(avg_acc)))
            continue
        acc_list[model_name] = []
        acc_list_round[model_name] = []
        acc_avg = 0
        iter_avg = 0
        acc_max_avg = 0

        for corruption in corruptions:
            run_root = os.path.join("outputs", "test_" + model_name + "_" + corruption)
            latest = get_latest_folder(glob.glob(os.path.join(run_root, "test*")))
            if latest is None:
                raise FileNotFoundError(f"No timestamped test run found under {run_root}")
            with open(os.path.join(run_root, latest, "stats.json")) as f:
                stats = json.load(f)
            test_acc = stats["test_acc"]
            test_ssh_acc = stats["test_ssh_acc"]
            if model_name.find("resnet") >= 0:
                # Feed-forward baselines: report the last recorded iteration.
                t_opt = str(max(int(x) for x in test_acc.keys()))
                max_acc = test_acc[t_opt]
            else:
                t_opt = estimate_t_opt(test_ssh_acc=test_ssh_acc, test_acc=test_acc, acc_thresh=100)
                max_acc, _ = estimate_max_acc(test_acc)
                acc_max_list.append(round(max_acc, 1))
            iter_avg += int(t_opt)
            acc_max_avg += max_acc
            acc_list[model_name].append(max_acc)
            acc_list_round[model_name].append(round(max_acc, 1))
            acc_avg += test_acc[t_opt]
        print("Model name {}, Avg acc: {}, Max acc: {}, Iter Avg: {}".format(model_name, str(acc_avg / len(corruptions)), str(acc_max_avg / len(corruptions)), str(iter_avg / len(corruptions))))

    x = np.arange(len(corruptions))  # X-axis locations for bars
    width = 0.25 * 0.75  # Width of each bar

    plt.figure(figsize=(15, 6))
    # NOTE(review): assumes at least three models, ordered to match these
    # hard-coded labels.
    plt.bar(x - width, acc_list[model_names[0]], width, label='ViT')
    plt.bar(x, acc_list[model_names[1]], width, label='Resnet-30')
    plt.bar(x + width, acc_list[model_names[2]], width, label='Conv-LiGRU')
    # plt.bar(x + width * 2, acc_list[model_names[3]], width, label='Conv-GRU')

    plt.xlabel("Corruption Types")
    plt.ylabel("Accuracy (%)")
    plt.title("Accuracy on Common Corruptions Benchmark")
    plt.xticks(x, corruptions, rotation=45, ha='right')
    # plt.ylim(0, 100)  # <-- fix the y axis to the range 0-100
    plt.legend(loc="upper center", bbox_to_anchor=(0.4, 1.0))
    plt.tight_layout()
    plt.savefig("bar_chart.pdf", dpi=300, bbox_inches='tight')
    plt.close()

    gap_accs = []
    for model_name in model_names:
        print(model_name)
        acc_latx = ""
        for i, acc in enumerate(acc_list_round[model_name]):
            best = max(acc_list_round[name][i] for name in model_names)
            if acc == best:
                acc_latx += " & \\textbf{" + f"{acc}" + "}"
            else:
                acc_latx += f" & {acc}"
            if not model_name.find("thresh") >= 0:
                # NOTE(review): acc_max_list is flat across all recurrent models,
                # so index i is the first recurrent model's value for corruption i.
                gap_accs.append(acc_max_list[i] - acc)
        print(acc_latx)
        print("-" * 50)
    print("".join(f" & {acc}" for acc in acc_max_list))
    print("-" * 50)

    gap_accs = np.array(gap_accs)
    print(np.mean(gap_accs))
    print(np.var(gap_accs))

def create_bar_chart_iterations():
    """Plot, per corruption, the iteration at which each model's accuracy peaks.

    For every model in the module-level ``model_names`` and every corruption, the
    latest timestamped run is loaded. The peak iteration ``t_max`` (from
    :func:`estimate_max_acc`) and the estimated iteration ``t_opt`` are printed
    with their accuracies.

    Models whose name contains ``"resnet"`` use their last recorded iteration as
    ``t_opt``. The ``t_max`` values are drawn as grouped bars and saved to
    ``bar_chart.pdf``, the same file :func:`create_bar_chart` writes.

    Raises:
        FileNotFoundError: if a model/corruption pair has no timestamped run.
    """
    t_max_info = {}
    for model_name in model_names:
        print("model name: ", model_name)
        t_max_list = []
        acc_avg = 0
        acc_max_avg = 0
        for corruption in corruptions:
            run_root = os.path.join("outputs", "test_" + model_name + "_" + corruption)
            latest = get_latest_folder(glob.glob(os.path.join(run_root, "test*")))
            if latest is None:
                raise FileNotFoundError(f"No timestamped test run found under {run_root}")
            with open(os.path.join(run_root, latest, "stats.json")) as f:
                stats = json.load(f)
            test_acc = stats["test_acc"]
            test_ssh_acc = stats["test_ssh_acc"]
            if model_name.find("resnet") >= 0:
                t_opt = str(max(int(x) for x in test_acc.keys()))
            else:
                t_opt = estimate_t_opt(test_ssh_acc=test_ssh_acc, test_acc=test_acc, acc_thresh=100)
            max_acc, t_max = estimate_max_acc(test_acc, acc_thresh=100)
            acc_max_avg += max_acc
            t_max_list.append(int(t_max))
            print(f"corruption: {corruption} est_acc: {round(test_acc[t_opt], 2)} max_acc: {round(max_acc, 2)} miss: {round(max_acc - test_acc[t_opt], 2)} t_max: {t_max}, t_opt: {t_opt}")

            acc_avg += test_acc[t_opt]
        t_max_info[model_name] = t_max_list
        t_max_arr = np.array(t_max_list)
        print("mean: ", np.mean(t_max_arr))
        print("var: ", np.var(t_max_arr, ddof=1))
        print("Model name {}, Avg acc: {}, Max acc: {}".format(model_name, str(acc_avg / len(corruptions)), str(acc_max_avg / len(corruptions))))
        print("_" * 50)

    x = np.arange(len(corruptions))  # X-axis locations for bars
    width = 0.25  # Width of each bar

    plt.figure(figsize=(15, 6))
    for i, model_name in enumerate(model_names):
        plt.bar(x + width * i, t_max_info[model_name], width, label=" ".join(model_name.split("_")[2:4]))

    plt.xlabel("Corruption Types")
    plt.ylabel("Num of thinking steps")
    plt.title("Number of thinking steps at Test Time on Common Corruptions Benchmark")
    plt.xticks(x, corruptions, rotation=45, ha='right')
    plt.legend()
    plt.tight_layout()
    plt.savefig("bar_chart.pdf", dpi=300, bbox_inches='tight')
    plt.close()
    

def create_bar_chart_to_check_overthinking():
    """Compare last-iteration accuracy with estimated-iteration accuracy, per model.

    For every model in ``model_names`` and every corruption, the latest
    timestamped run is loaded and two accuracies are recorded:

    * (a) the main-task accuracy at the iteration chosen by :func:`estimate_t_opt`.
      Models whose name contains ``"resnet"`` use the fixed iteration ``"4"``.
    * (b) the accuracy at the last recorded iteration.

    One bar chart per model is saved as ``bar_chart_<model_name>.pdf``, with (b)
    labelled 'baseline' and (a) labelled 'proposal'.

    Raises:
        FileNotFoundError: if a model/corruption pair has no timestamped run.
    """
    acc_list = {}
    acc_final_list = {}
    for model_name in model_names:
        acc_list[model_name] = []
        acc_final_list[model_name + "_final"] = []
        acc_avg = 0
        acc_max_avg = 0
        for corruption in corruptions:
            run_root = os.path.join("outputs", "test_" + model_name + "_" + corruption)
            latest = get_latest_folder(glob.glob(os.path.join(run_root, "test*")))
            if latest is None:
                raise FileNotFoundError(f"No timestamped test run found under {run_root}")
            with open(os.path.join(run_root, latest, "stats.json")) as f:
                stats = json.load(f)
            test_acc = stats["test_acc"]
            test_ssh_acc = stats["test_ssh_acc"]
            if model_name.find("resnet") >= 0:
                # NOTE(review): hard-coded iteration for feed-forward models; confirm "4".
                t_opt = "4"
            else:
                t_opt = estimate_t_opt(test_ssh_acc=test_ssh_acc)
            # estimate_max_acc returns (max_acc, t_max); only the accuracy is summed.
            max_acc, _ = estimate_max_acc(test_acc)
            acc_max_avg += max_acc
            acc_list[model_name].append(test_acc[t_opt])

            last_iter = str(max(int(x) for x in test_acc.keys()))
            acc_final_list[model_name + "_final"].append(test_acc[last_iter])
            acc_avg += test_acc[t_opt]
        print("Model name {}, Avg acc: {}, Max acc: {}".format(model_name, str(acc_avg / len(corruptions)), str(acc_max_avg / len(corruptions))))

    x = np.arange(len(corruptions))  # X-axis locations for bars
    width = 0.25  # Width of each bar

    for model_name in model_names:
        plt.figure(figsize=(15, 6))
        plt.bar(x - width, acc_final_list[model_name + "_final"], width, label='baseline')
        plt.bar(x, acc_list[model_name], width, label='proposal')

        plt.xlabel("Corruption Types")
        plt.ylabel("Accuracy (%)")
        plt.title("Accuracy on Common Corruptions Benchmark")
        plt.xticks(x, corruptions, rotation=45, ha='right')
        plt.legend()
        plt.tight_layout()
        plt.savefig(f"bar_chart_{model_name}.pdf", dpi=300, bbox_inches='tight')
        # Close each figure so looping over many models does not accumulate them.
        plt.close()



# Per-iteration accuracy curves (main and auxiliary task) for a single model,
# read from the saved test runs; one subplot per corruption.
def acc_during_iter(model_name, corruptions=corruptions):
    """Plot main/auxiliary accuracy vs. iteration for one model, one subplot per corruption.

    The grid has at most 5 columns. In each subplot the span from iteration 1 to
    the iteration with the best auxiliary (SSH) accuracy is shaded. Accuracies are
    compared after rounding to 2 decimals.

    Saves ``<model_name>_visualize_acc_iter_overthinking.pdf``. Prints the mean
    main-task accuracy at the SSH-selected iterations and the mean peak
    accuracy.

    Args:
        model_name: run name. Results are read from
            ``outputs/test_<model_name>_<corruption>/<latest testing-* folder>/stats.json``.
        corruptions: corruption names to plot (default: the module-level list).
            The special name ``"stl"`` is titled ``"Original"``.

    Raises:
        ValueError: if ``corruptions`` is empty.
        FileNotFoundError: if a corruption has no timestamped run.
    """
    if not corruptions:
        raise ValueError("corruptions must not be empty")

    n_corruptions = len(corruptions)
    n_cols = min(n_corruptions, 5)
    n_rows = math.ceil(n_corruptions / n_cols)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(int(24 * (n_rows / 3)), int(12 * (n_cols / 5))))
    # fig, axes = plt.subplots(n_rows, n_cols, figsize=(int(12 * (n_rows / 3)), int(24 * (n_cols / 5))))  # For STL-10
    # plt.subplots returns a bare Axes for a 1x1 grid and an ndarray otherwise.
    axes = [axes] if n_corruptions == 1 else axes.flatten()

    round_level = 2

    sum_acc = 0
    sum_peak_acc = 0

    for ax, corruption in zip(axes, corruptions):
        run_root = os.path.join("outputs", "test_" + model_name + "_" + corruption)
        latest = get_latest_folder(glob.glob(os.path.join(run_root, "test*")))
        if latest is None:
            raise FileNotFoundError(f"No timestamped test run found under {run_root}")
        with open(os.path.join(run_root, latest, "stats.json")) as f:
            stats = json.load(f)
        test_acc = stats["test_acc"]
        test_ssh_acc = stats["test_ssh_acc"]

        # Points are placed at 1..N in stored order; presumably the JSON keys are
        # "1".."N" in that order (the span below relies on it) -- TODO confirm.
        x = range(1, len(test_acc) + 1)
        ax.plot(x, test_acc.values(), marker='o', label='Main task', linestyle='-', color='blue')
        ax.plot(x, test_ssh_acc.values(), marker='s', label='Auxiliary task', linestyle='--', color='green')

        title = "Original" if corruption == "stl" else corruption
        ax.set_title(title, fontsize=15)
        ax.set_xlabel('Iterations', fontsize=14)
        ax.set_ylabel('Accuracy', fontsize=14)
        ax.tick_params(axis='x', labelsize=14)
        ax.tick_params(axis='y', labelsize=14)
        ax.grid(True)

        # Iteration with the best auxiliary accuracy (the estimated stopping point).
        max_ssh_acc = 0
        max_ssh_iter = 0
        for it, acc in test_ssh_acc.items():
            if round(acc, round_level) > round(max_ssh_acc, round_level):
                max_ssh_iter = it
                max_ssh_acc = acc
        sum_acc += float(test_acc[max_ssh_iter])

        # Iteration with the best main-task accuracy (the oracle).
        max_acc = 0
        max_iter = 0
        for it, acc in test_acc.items():
            if round(acc, round_level) > round(max_acc, round_level):
                max_iter = it
                max_acc = acc
        sum_peak_acc += max_acc

        # JSON keys are strings; axvspan needs a numeric upper bound.
        ax.axvspan(1, int(max_ssh_iter), color='darkblue', alpha=0.5, label='Estimated iteration')
        # Build the legend after the span so its label is included.
        ax.legend(fontsize=14)

        print("max iter: ", max_iter)
        print("max ssh iter: ", max_ssh_iter)

    # Hide grid cells that have no corruption to show.
    for ax in axes[n_corruptions:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.savefig(f'{model_name}_visualize_acc_iter_overthinking.pdf', dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"avg acc = {sum_acc / len(corruptions)}")
    print(f"avg peak acc = {sum_peak_acc / len(corruptions)}")
    print("model name: ", model_name)
    
def check_acc_during_iteretate_on_multi_model(model_names=model_names, corruptions=corruptions):
    """Overlay several models' main-task accuracy curves, one subplot per corruption.

    Each curve is labelled by dataset only: ``"cifar100-c"`` if the model name
    contains ``"cifar100"``, otherwise ``"cifar10-c"``. Saves
    ``visualize_acc_iter_overthinking.pdf``.

    Args:
        model_names: run names to compare (default: the module-level list).
        corruptions: corruption names, one subplot each (default: the module-level list).

    Raises:
        ValueError: if ``corruptions`` is empty.
        FileNotFoundError: if a model/corruption pair has no timestamped run.
    """
    if not corruptions:
        raise ValueError("corruptions must not be empty")

    n_corruptions = len(corruptions)
    n_cols = min(n_corruptions, 5)
    n_rows = math.ceil(n_corruptions / n_cols)

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(int(24 * (n_rows / 3)), int(12 * (n_cols / 5))))
    # plt.subplots returns a bare Axes for a 1x1 grid and an ndarray otherwise.
    axes = [axes] if n_corruptions == 1 else axes.flatten()

    # The first three colours match the original palette; further models cycle
    # through the list instead of raising IndexError.
    colors = ["red", "green", "blue", "purple", "orange", "brown"]

    for ax, corruption in zip(axes, corruptions):
        for k, model_name in enumerate(model_names):
            test_dir = "test_" + model_name + "_" + corruption
            print(test_dir)
            run_root = os.path.join("outputs", test_dir)
            latest = get_latest_folder(glob.glob(os.path.join(run_root, "test*")))
            if latest is None:
                raise FileNotFoundError(f"No timestamped test run found under {run_root}")
            with open(os.path.join(run_root, latest, "stats.json")) as f:
                stats = json.load(f)
            test_acc = stats["test_acc"]

            x = range(1, len(test_acc) + 1)
            label = "cifar100-c" if model_name.find("cifar100") >= 0 else "cifar10-c"
            ax.plot(x, test_acc.values(), marker='o', label=label, linestyle='-', color=colors[k % len(colors)])

        ax.set_title(corruption, fontsize=15)
        ax.set_xlabel('Iterations', fontsize=14)
        ax.set_ylabel('Accuracy', fontsize=14)
        ax.tick_params(axis='x', labelsize=14)
        ax.tick_params(axis='y', labelsize=14)
        ax.legend(fontsize=14)
        ax.grid(True)

    # Hide grid cells that have no corruption to show.
    for ax in axes[n_corruptions:]:
        ax.set_visible(False)

    plt.tight_layout()
    plt.savefig('visualize_acc_iter_overthinking.pdf', dpi=300, bbox_inches='tight')
    plt.close(fig)
    
def draw_acc_across_iterations(model_name, corruption, task="test_acc", color='blue'):
    """Draw one accuracy curve on the current matplotlib figure and save it.

    The curve is ``stats[task]`` from the latest timestamped run under
    ``outputs/test_<model_name>_<corruption>/``. Tick values on both axes are
    removed. The current figure is saved to
    ``outputs/<model_name>/draw_<task>_cross_<corruption>.pdf``. No new figure is
    created, so repeated calls overlay their curves.

    Args:
        model_name: run name.
        corruption: corruption name.
        task: key in ``stats.json`` holding the curve (e.g. ``"test_acc"`` or
            ``"test_ssh_acc"``).
        color: matplotlib colour for the curve.

    Raises:
        FileNotFoundError: if no timestamped run exists.
    """
    run_root = os.path.join("outputs", "test_" + model_name + "_" + corruption)
    # Use the newest run (glob order is arbitrary, so indexing [0] was not).
    latest = get_latest_folder(glob.glob(os.path.join(run_root, "test*")))
    if latest is None:
        raise FileNotFoundError(f"No timestamped test run found under {run_root}")
    with open(os.path.join(run_root, latest, "stats.json")) as f:
        stats = json.load(f)
    test_acc = stats[task]

    x = range(1, len(test_acc) + 1)
    plt.plot(x, test_acc.values(), marker='o', label='main task', linestyle='-', color=color)

    # plt.title(corruption)
    plt.xlabel('Iterations', fontsize=45)
    plt.ylabel('Accuracy', fontsize=45)

    # Remove the tick values on both axes.
    plt.xticks([])
    plt.yticks([])

    plt.tight_layout()
    out_dir = os.path.join("outputs", model_name)
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, f'draw_{task}_cross_{corruption}.pdf'), dpi=300, bbox_inches='tight')

def read_json(test_dir):
    """Load the main-task and auxiliary accuracy curves from ``<test_dir>/stats.json``.

    Args:
        test_dir: directory containing a ``stats.json`` file.

    Returns:
        ``(test_acc, ssh_acc)``: the ``"test_acc"`` and ``"test_ssh_acc"`` entries
        of the JSON object.

    Raises:
        FileNotFoundError: if ``stats.json`` does not exist.
        KeyError: if either entry is missing.
    """
    # Context manager so the file handle is closed immediately after parsing.
    with open(os.path.join(test_dir, "stats.json")) as f:
        stats = json.load(f)
    return stats["test_acc"], stats["test_ssh_acc"]

def plot_adaptive_ability_gru_resnet(hard_result_path, medium_result_path, resnet_hard_path=None, resnet_medium_path=None, num_iterations=31):
    """Plot Conv-GRU (and optionally ResNet-30) accuracy against test-time iteration.

    Each curve is the ``"test_acc"`` entry read by :func:`read_json` from the
    given run directory. Curves are drawn for iterations
    ``1 .. num_iterations - 1`` and the span 1-30 is shaded. The "medium" paths
    are labelled Gaussian noise level 3 and the "hard" paths level 5. Saves
    ``visualize_adaptive.pdf``.

    Args:
        hard_result_path: Conv-GRU run directory for the level-5 curve.
        medium_result_path: Conv-GRU run directory for the level-3 curve.
        resnet_hard_path: optional ResNet-30 run directory for level 5. The curve
            is skipped when ``None``.
        resnet_medium_path: optional ResNet-30 run directory for level 3. The
            curve is skipped when ``None``.
        num_iterations: one past the last iteration to plot.
    """
    iterations = list(range(1, num_iterations))

    def _curve(path):
        # Main-task accuracy at each plotted iteration (stats keys are strings).
        acc, _ = read_json(path)
        return np.array([acc[str(it)] for it in iterations])

    # Start a fresh figure so earlier plots in the session are not drawn over.
    plt.figure()

    # Shaded training regime.
    plt.axvspan(1, 30, color='lightblue', alpha=0.3, label='Number of ResNet Layers')

    plt.plot(iterations, _curve(medium_result_path), marker='x', label='Conv-GRU, Gaussian Noise level 3', color='purple')
    plt.plot(iterations, _curve(hard_result_path), marker='^', label='Conv-GRU, Gaussian Noise level 5', color='red')

    if resnet_medium_path is not None:
        plt.plot(iterations, _curve(resnet_medium_path), linestyle='dashed', label='Resnet-30, Gaussian Noise level 3', color='purple')
    if resnet_hard_path is not None:
        plt.plot(iterations, _curve(resnet_hard_path), linestyle='dotted', label='Resnet-30, Gaussian Noise level 5', color='red')

    plt.xlabel('Test-Time Iterations')
    plt.ylabel('Accuracy (%)')
    plt.title('Models Trained With 30 Iterations')
    plt.legend()
    plt.grid(alpha=0.3)

    plt.savefig("visualize_adaptive.pdf", dpi=300, bbox_inches='tight')
    plt.close()

def plot_adaptive_compute(
    result_paths,
    model_names,
    num_iterations=31,
    title="Models Trained With 30 Iterations",
    label_layer="Number of Renset Layers",
    iteration_shade=25,
    save_path="visualize_adaptive.pdf",
    maker_list=['o', 's', '^', 'x', 'D', '*'],
    color_list=['red', 'green', 'blue', 'purple', 'orange', 'brown', 'pink', 'gray'],
):
    """Plot main-task accuracy against test-time iteration for several runs.

    Every curve covers iterations ``1 .. num_iterations - 1`` and the span
    ``1 .. iteration_shade`` is shaded. Per run, the function prints:

    * the peak accuracy, or the last value for names containing ``"thresh"``;
    * for non-"thresh" runs, also the accuracy at iteration ``iteration_shade``.

    Args:
        result_paths: run directories, each readable by :func:`read_json`.
        model_names: legend label for each run (same order as ``result_paths``).
        num_iterations: one past the last iteration to plot.
        title: plot title.
        label_layer: currently unused; kept for backward compatibility.
        iteration_shade: last iteration of the shaded span. Also the iteration
            reported as "est acc".
        save_path: output file.
        maker_list: currently unused; every curve is drawn with the '*' marker.
        color_list: colour for each run, indexed by position.
    """
    iterations = list(range(1, num_iterations))
    results = [read_json(path)[0] for path in result_paths]
    result_arrays = [np.array([result[str(it)] for it in iterations]) for result in results]

    # Start a fresh figure so earlier plots in the session are not drawn over.
    plt.figure()

    # Shaded training regime.
    plt.axvspan(1, iteration_shade, color='lightblue', alpha=0.3)

    for i, result_array in enumerate(result_arrays):
        plt.plot(iterations, result_array, marker="*", label=model_names[i], color=color_list[i])
        print("model name: ", model_names[i])
        if model_names[i].find("thresh") >= 0:
            print("max acc: ", result_array[-1])
        else:
            print("max acc: ", max(result_array))
            # result_array[k] holds iteration k + 1, so the last shaded iteration
            # (iteration_shade) lives at index iteration_shade - 1.
            print("est acc:", result_array[iteration_shade - 1])
        print("-" * 50)

    plt.xlabel('Test-Time Iterations')
    plt.ylabel('Accuracy (%)')
    plt.title(title)
    plt.legend()
    plt.grid(alpha=0.3)

    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.close()

def draw_training_loss_acc(log_file):
    """Plot the training SSH (auxiliary-task) accuracy curve parsed from a log file.

    Lines matching ``Training SSH accuracy at epoch <n>: <float>`` are parsed in
    file order. The values are plotted against 1, 2, ..., one point per match.
    Tick values are removed and the figure is saved to ``training_ssh_task.pdf``.

    Args:
        log_file: path to the training log.
    """
    import re

    ssh_pattern = re.compile(r'Training SSH accuracy at epoch (\d+): (\d+\.\d+)')

    # Read and parse the log file.
    train_ssh_accuracy = []
    with open(log_file, 'r') as f:
        for line in f:
            match = ssh_pattern.search(line)
            if match:
                train_ssh_accuracy.append(float(match.group(2)))

    # One x value per parsed point, derived from the data so the x and y lengths
    # always agree (a fixed range crashes whenever the log length differs).
    epochs = list(range(1, len(train_ssh_accuracy) + 1))

    plt.figure(figsize=(10, 6))
    plt.plot(epochs, train_ssh_accuracy, label='Train SSH Accuracy', color='orange')

    # Axis labels only; tick values on both axes are removed.
    # plt.title('Training Accuracy over Epochs')
    plt.xlabel('Epoch', fontsize=45)
    plt.ylabel('Accuracy', fontsize=45)
    plt.xticks([])
    plt.yticks([])

    # Save the figure as a PDF.
    plt.savefig('training_ssh_task.pdf')
    plt.close()

def visualize_loss():
    """Plot per-iteration classification and self-supervision losses side by side.

    For every entry in the module-level ``model_names`` list this reads
    ``outputs/<model>/testing-test_sample/loss.json``, which must provide the
    keys ``"cls"`` and ``"ssh"`` (each a sequence of loss values), and draws
    one line per model on each of two subplots. The figure is written to
    ``visualize_loss.pdf`` in the current working directory and then closed.
    """
    models = {}
    for model_name in model_names:
        loss_path = os.path.join("outputs", model_name, "testing-test_sample", "loss.json")
        with open(loss_path) as f:
            losses = json.load(f)
        # Strip run-specific tokens so the legend shows a short label.
        label = (
            model_name.replace("alpha_0.0", "")
            .replace("cifar_c", "")
            .replace("_", " ")
            .replace("gn", "")
            .strip()
        )
        models[label] = losses

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    panels = (("cls", "Classification Loss"), ("ssh", "Self-Supervision Loss"))
    for ax, (key, title) in zip(axes, panels):
        ax.set_title(title, fontsize=20)
        ax.set_xlabel("Iterations", fontsize=18)
        ax.set_ylabel("Loss", fontsize=18)
        for label, losses in models.items():
            ax.plot(losses[key], label=label)
        ax.legend(fontsize=15)
        ax.grid()
        # Set tick sizes on each axes explicitly; plt.tick_params would only
        # affect the current (last-created) axes.
        ax.tick_params(axis='x', labelsize=13)
        ax.tick_params(axis='y', labelsize=13)

    fig.tight_layout()
    fig.savefig("visualize_loss.pdf")
    plt.close(fig)
    
def visualize_norm_accuracy():
    """Overlay hidden-state norm differences and test accuracy for Conv-GRU vs Conv-LiGRU.

    Norm differences are read from
    ``outputs/<model>/testing-test_sample/diff_norms.json``. Per-iteration
    accuracy comes from the ``"test_acc"`` mapping in the ``stats.json`` of a
    fixed Glass Blur test run for each model. Only the first ``max_iters``
    iterations are plotted: norms on the left y-axis, accuracies on a twin
    right y-axis. Dashed vertical lines mark hand-picked peak-accuracy
    iterations. The figure is saved to
    ``visualize_norm_with_accuracy_convgru.pdf`` and then closed.
    """
    max_iters = 40  # number of leading iterations shown for every series

    model_name = "cifar_c_cnn_gru_gn_width_128"  # norm-difference source for Conv-GRU
    model_real_name = "Conv-GRU"
    # NOTE(review): the norms come from the `gru_gn` model while this accuracy
    # run is `cnn_gru_width_128` (no `gn`); confirm the pairing is intended.
    corruption = "outputs/test_cifar_c_cnn_gru_width_128_Glass Blur/testing-2025-04-22 21:52:22.254495"
    second_model_name = "cifar_c_cnn_ligru_width_128"  # norm-difference source for Conv-LiGRU
    second_model_real_name = "Conv-LiGRU"
    second_corruption = "outputs_iclr/test_cifar_c_cnn_ligru_width_128_Glass Blur/testing-2025-09-19 09:23:29.975224"

    def _load_json(path):
        # Context manager so each file handle is closed immediately.
        with open(path) as f:
            return json.load(f)

    # Load the data for the first model.
    norm_diffs = _load_json(
        os.path.join("outputs", model_name, "testing-test_sample", "diff_norms.json"))[:max_iters]
    test_acc_data = _load_json(os.path.join(corruption, "stats.json"))
    test_acc = list(test_acc_data["test_acc"].values())[:max_iters]

    # Load the data for the second model.
    norm_diffs_second = _load_json(
        os.path.join("outputs", second_model_name, "testing-test_sample", "diff_norms.json"))[:max_iters]
    test_acc_data_second = _load_json(os.path.join(second_corruption, "stats.json"))
    test_acc_second = list(test_acc_data_second["test_acc"].values())[:max_iters]

    fig, ax1 = plt.subplots(figsize=(8, 5))

    # Norm differences on the primary y-axis.
    ax1.set_xlabel("Iterations (t)", fontsize=18)
    ax1.set_ylabel("Norm Difference", fontsize=18)
    norm_line, = ax1.plot(norm_diffs, label=f"{model_real_name} Norm Difference", color='r')
    norm_line_second, = ax1.plot(norm_diffs_second, label=f"{second_model_real_name} Norm Difference", color='b')

    # Accuracy on a secondary y-axis sharing the same x-axis.
    ax2 = ax1.twinx()
    ax2.set_ylabel("Accuracy", fontsize=18)
    accuracy_line, = ax2.plot(test_acc, label=f"{model_real_name} Accuracy", color='r', marker="x")
    accuracy_line_second, = ax2.plot(test_acc_second, label=f"{second_model_real_name} Accuracy", color='b', marker="x")

    # Dashed vertical lines at the hand-picked peak-accuracy iterations
    # (26 for Conv-LiGRU, 15 for Conv-GRU); colours match each model's lines.
    iter_peak_acc1 = ax1.axvline(x=26, color='blue', linestyle='--', label='Peak Acc Iter (Conv-LiGRU)')
    iter_peak_acc2 = ax1.axvline(x=15, color='red', linestyle='--', label='Peak Acc Iter (Conv-GRU)')

    ax1.tick_params(axis='x', labelsize=13)
    ax1.tick_params(axis='y', labelsize=13)  # norm-difference axis
    ax2.tick_params(axis='y', labelsize=13)  # accuracy axis

    # One combined legend, because lines live on two different axes.
    ax1.legend(
        handles=[norm_line, accuracy_line, norm_line_second, accuracy_line_second,
                 iter_peak_acc1, iter_peak_acc2],
        fontsize=12,
        loc="center right",
    )

    fig.savefig("visualize_norm_with_accuracy_convgru.pdf")
    plt.close(fig)

def visualize_norm():
    """Plot per-iteration hidden-state norm differences for Conv-GRU and Conv-LiGRU.

    Reads ``outputs/<model>/testing-test_sample/diff_norms.json`` for each
    model and draws one labelled line per model on a single axis. The figure
    is saved to ``visualize_norm.pdf`` and then closed. Accuracy is
    deliberately not drawn here, because this axis shows norm differences only.
    """
    # (output directory, display label). Each directory is paired with its own
    # architecture name: the `gru` run is Conv-GRU, the `ligru` run is Conv-LiGRU.
    runs = [
        ("cifar_c_cnn_gru_gn_width_128", "Conv-GRU"),
        ("cifar_c_cnn_ligru_width_128", "Conv-LiGRU"),
    ]

    norm_diffs_by_label = {}
    for run_dir, label in runs:
        path = os.path.join("outputs", run_dir, "testing-test_sample", "diff_norms.json")
        with open(path) as f:
            norm_diffs_by_label[label] = json.load(f)

    fig = plt.figure(figsize=(8, 5))
    plt.xlabel("Iterations (t)", fontsize=18)
    plt.ylabel("Norm Difference", fontsize=18)

    for label, norm_diffs in norm_diffs_by_label.items():
        plt.plot(norm_diffs, label=label)

    plt.tick_params(axis='x', labelsize=13)
    plt.tick_params(axis='y', labelsize=13)

    plt.legend(fontsize=15)
    plt.grid()
    fig.savefig("visualize_norm.pdf")
    plt.close(fig)

def read_ttt(log_dir="data/C10C_layer2_slow_gn_expand", corruption_types=None):
    """Collect adapted test-time-training results, one per corruption type.

    For each corruption name this loads ``<log_dir>/<name>_5_ada.pth``.
    ``<name>`` is the corruption lower-cased with spaces replaced by
    underscores. The function returns ``(1 - res["cls_adapted"]) * 100`` for
    each file, in order. ``cls_adapted`` is presumably an error rate in
    [0, 1], which would make each value an accuracy percentage; confirm
    against the producer.

    Args:
        log_dir: Directory containing the ``*_5_ada.pth`` result files.
        corruption_types: Optional iterable of corruption names. Defaults to
            the module-level ``corruptions`` list.

    Returns:
        list: One converted value per corruption, in iteration order.
    """
    if corruption_types is None:
        corruption_types = corruptions

    res_list = []
    for corrup_type in corruption_types:
        file_name = corrup_type.lower().replace(" ", "_") + "_5_ada.pth"
        # SECURITY: weights_only=False unpickles arbitrary objects; only load
        # result files produced by a trusted source.
        res = torch.load(os.path.join(log_dir, file_name), weights_only=False)
        res_list.append((1 - res["cls_adapted"]) * 100)
    return res_list

def read_feedforward(log_dir, corruption_types=None):
    """Parse the final reported value from each corruption's test log.

    For each corruption name this opens ``outputs/<log_dir>/<name>.log``.
    ``<name>`` is the corruption lower-cased with spaces replaced by
    underscores. The value parsed is the last whitespace-separated token of
    the last non-blank line, with any ``%`` sign removed.

    Args:
        log_dir: Sub-directory of ``outputs`` holding the ``.log`` files.
        corruption_types: Optional iterable of corruption names. Defaults to
            the module-level ``corruptions`` list.

    Returns:
        list[float]: One parsed value per corruption, in iteration order.

    Raises:
        ValueError: If a log file has no non-blank line, or its last token
            is not numeric.
    """
    if corruption_types is None:
        corruption_types = corruptions

    res_list = []
    for corrup_type in corruption_types:
        file_name = corrup_type.lower().replace(" ", "_") + ".log"
        last_line = ""
        with open(os.path.join("outputs", log_dir, file_name), "r") as f:
            for line in f:
                # Skip trailing blank lines so a final newline doesn't hide the result.
                if line.strip():
                    last_line = line
        if not last_line:
            raise ValueError(f"No result line found in {file_name!r} under {log_dir!r}")
        # split() on any whitespace tolerates trailing spaces, tabs and newlines.
        last_token = last_line.split()[-1]
        res_list.append(float(last_token.replace("%", "")))
    return res_list
    
    
def plot_accuracy(models, accuracies, title='Model Accuracy on STL-10 Dataset',
                  save_path="stl_bar_chart.pdf"):
    """
    Plot a bar chart of per-model accuracy and save it to disk.

    Parameters:
    - models: List of model names (strings), one bar per name.
    - accuracies: List of accuracy values in percent (floats), aligned with ``models``.
    - title: Chart title (defaults to the STL-10 title).
    - save_path: File path the figure is written to.

    If the two lists differ in length, an error is printed and the function
    returns None without plotting. The figure is closed after saving.
    """
    if len(models) != len(accuracies):
        print("Error: The number of models and accuracies do not match.")
        return

    fig = plt.figure(figsize=(8, 6))
    plt.bar(models, accuracies, color=['blue', 'green', 'red'])

    plt.xlabel('Models', fontsize=14)
    plt.ylabel('Accuracy (%)', fontsize=14)
    plt.title(title, fontsize=16)

    # Annotate each bar with its value, just above the bar top.
    for i, accuracy in enumerate(accuracies):
        plt.text(i, accuracy + 1, f'{accuracy:.2f}%', ha='center', fontsize=12)

    plt.tight_layout()
    fig.savefig(save_path)
    plt.close(fig)

# Example usage
    
# for model_name in model_names:
#     print(model_name)
#     acc_during_iter(model_name=model_name)
#     print("-" * 100)
# acc_during_iter(model_name="cifar_c_cnn_gru_width_128")

# check_acc_during_iteretate_on_multi_model(model_names, corruptions=corruptions)

# plot_adaptive_ability_gru_resnet("outputs/test_cifar_c_cnn_gru_width_128_level_5_Gaussian Noise/testing-2025-04-30 11:33:49.526349", 
#                       "outputs/test_cifar_c_cnn_gru_width_128_level_3_Gaussian Noise/testing-2025-04-30 11:34:18.443551",
#                       "outputs/test_cifar_c_resnet_cifar_width_128_depth_30_Gaussian Noise/testing-2025-05-10 18:11:33.538084",
#                       "outputs/test_cifar_c_resnet_cifar_width_128_depth_30_level_3_Gaussian Noise/testing-2025-05-13 10:25:44.576215",
#                       )

# plot_adaptive_compute([
#                     "outputs_iclr/test_cifar_c_cnn_gru_gn_width_128_thresh_0.03_stl/testing-2025-09-22 09:57:21.084896",
#                     # "outputs_iclr/test_cifar_c_cnn_gru_gn_width_128_thresh_0.003_stl/testing-2025-09-22 08:40:24.814893",
#                     # "outputs/test_cifar_c_cnn_gru_width_128_stl/testing-2025-05-05 15:52:33.241376", 
#                     "outputs_iclr/test_cifar_c_cnn_gru_gn_width_128_stl/testing-2025-09-22 09:48:15.621751",
#                     #   "outputs_iclr/test_cifar_c_cnn_gru_gn_width_128_act_0.5_stl/testing-2025-09-22 09:34:05.988144",
#                     "outputs_iclr/test_cifar_c_cnn_gru_gn_width_128_act_0.5_stl/testing-2025-09-22 09:40:26.298545",
#                       ],
#                     ["Conv-GRU norm threshold", "Conv-GRU unhalted", "Conv-GRU with ACT",
#                          ],
#                         title="Conv-GRU Models on STL-10",
#                         save_path="visualize_adaptive_convgru_stl.pdf",
#                         num_iterations=30,
#                       )

# Module-level call: renders the adaptive-compute comparison of three Conv-GRU
# variants (norm threshold, unhalted, ACT) on the Glass Blur runs whenever this
# file is executed or imported.
plot_adaptive_compute(
    [
        "outputs/test_cifar_c_cnn_gru_width_128_thresh_0.03_Glass Blur/testing-2025-05-09 15:44:50.034544",
        "outputs/test_cifar_c_cnn_gru_width_128_Glass Blur/testing-2025-04-22 21:52:22.254495",
        "outputs_iclr/test_cifar_c_cnn_gru_gn_width_128_act_0.5_Glass Blur/testing-2025-09-22 10:35:32.320985",
    ],
    ["Conv-GRU norm threshold", "Conv-GRU unhalted", "Conv-GRU with ACT"],
    title="Conv-GRU Models on glass blur corruption level 5",
    save_path="visualize_adaptive_convgru_glass_blur.pdf",
    iteration_shade=20,
    num_iterations=30,
)

# create_bar_chart()
# create_bar_chart_iterations()
# print(read_res50())
# draw_acc_across_iterations("cifar_c_cnn_ligru_alpha_0.0", corruption="None", task="test_acc", color="red")
# draw_training_loss_acc("/home/tranhieu/workdir/research/deep-thinking/outputs/cifar_c_cnn_ligru_alpha_0.0/training-2025-01-10 15:33:38.798950/train.log")
# visualize_loss()
# visualize_norm()
# visualize_norm_accuracy()
# create_table()
# plot_accuracy(["CNN model", "Conv-LiGRU", "Conv-GRU"], [36.4, 40.38, 44.67])