import os
import re
import sys

import matplotlib.pyplot as plt
import pandas as pd


# Every path is anchored at the directory containing this script, so the
# script behaves the same regardless of the current working directory.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))

DATA_DIR = os.path.join(PROJECT_ROOT, "data")        # input CSV files
FIGURES_DIR = os.path.join(PROJECT_ROOT, "figures")  # output PNG files

# Side effect at import time: ensure the output directory exists before any
# figure is saved into it.
os.makedirs(FIGURES_DIR, exist_ok=True)


# Global Matplotlib defaults applied to every figure this script creates.
#   - 'SimHei' (a CJK font) is tried first, with 'DejaVu Sans' as fallback.
#   - axes.unicode_minus=False draws minus signs with an ASCII hyphen-minus
#     instead of the Unicode minus character.
plt.rcParams.update({
    'font.sans-serif': ['SimHei', 'DejaVu Sans'],
    'axes.unicode_minus': False,
    'font.size': 12,
    'figure.figsize': (10, 6),
    'lines.linewidth': 2,
    'lines.markersize': 4,
})


# Per-algorithm plot style, keyed by algorithm class name (the <alg> part of
# each "<alg>_<game>.csv" file). Each value is a (color, linestyle, marker)
# tuple. Algorithms not listed here are drawn as a black solid line with 'x'
# markers by the plotting code.
ALG_STYLE = {
    # Fine-tuned RT variants
    "RTCFRPlus(finetune)":   ('#796c5b', '-', 'p'),
    "RTPCFRPlus(finetune)":  ('#ffbb00', '-', '<'),
    "RTPDCFRPlus(finetune)": ('#4b0082', '-', 'h'),

    # Base RT variants
    "RTCFRPlus":             ('#1f77b4', '-', 'o'),
    "RTPCFRPlus":            ('#2e8b57', '-', 's'),
    "RTPDCFRPlus":           ('#ff7f0e', '-', 'd'),

    # Baseline algorithms
    "CFRPlus":               ('#d62728', '-', '*'),
    "DCFR":                  ('#001a6f', '-', 'x'),
    "PCFRPlus":              ('#8e44ad', '-', 'D'),
    "Reg-CFR":               ('#ff69b4', '-', 'P'),
    "OMWU":                  ('#e74c3c', '-', 'H'),
    "OGDA":                  ('#00ced1', '-', 'X'),
}

# Legend label for each algorithm class name (the <alg> part of each
# "<alg>_<game>.csv" file). Names absent from this map are shown verbatim.
ALG_NAME_MAP = {
    # Base RT variants ("Plus" rendered as "+")
    "RTCFRPlus": "RTCFR+",
    "RTPCFRPlus": "RTPCFR+",
    "RTPDCFRPlus": "RTPDCFR+",
    # Fine-tuned RT variants
    "RTCFRPlus(finetune)": "RTCFR+ (fine-tuned)",
    "RTPCFRPlus(finetune)": "RTPCFR+ (fine-tuned)",
    "RTPDCFRPlus(finetune)": "RTPDCFR+ (fine-tuned)",
    # Baseline algorithms
    "CFRPlus": "CFR+",
    "DCFR": "DCFR",
    "PCFRPlus": "PCFR+",
    "Reg-CFR": "Reg-CFR",
    "OMWU": "OMWU",
    "OGDA": "OGDA",
}



def load_data_by_game(data_dir=None):
    """Load per-algorithm exploitability curves from CSV files, grouped by game.

    Every file in ``data_dir`` whose name matches ``<alg>_<game>.csv`` is
    read. ``<alg>`` is everything before the first underscore, so it cannot
    itself contain ``_``; ``<game>`` is the rest of the stem. Only the
    "Iteration" and "Exploitability" columns are kept. Rows with a missing
    value in either column are dropped, and the index is reset.

    Files are skipped, with a note on stderr, when:
      - they cannot be read or parsed;
      - they lack either required column;
      - no rows remain after dropping missing values.
    Files whose name does not match the pattern are ignored silently.

    Files are visited in sorted filename order. This makes the insertion
    order of the returned dicts, and therefore the plotting and legend order,
    reproducible across runs and platforms.

    Args:
        data_dir: Directory to scan. Defaults to the module-level DATA_DIR.

    Returns:
        dict mapping game name -> {algorithm name -> DataFrame}.

    Raises:
        FileNotFoundError: if ``data_dir`` does not exist (from os.listdir).
    """
    if data_dir is None:
        data_dir = DATA_DIR
    # The pattern already requires the ".csv" suffix, so no separate check is needed.
    file_pattern = re.compile(r'^(?P<alg>[^_]+)_(?P<game>.+?)\.csv$')

    game_alg_dict = {}
    for filename in sorted(os.listdir(data_dir)):
        match = file_pattern.match(filename)
        if match is None:
            continue

        csv_path = os.path.join(data_dir, filename)
        try:
            df = pd.read_csv(csv_path, usecols=["Iteration", "Exploitability"])
        except (OSError, ValueError) as err:
            # ValueError covers:
            #   - missing usecols columns;
            #   - pandas EmptyDataError and ParserError;
            #   - UnicodeDecodeError.
            # OSError covers unreadable paths, including directories.
            print(f"[skip] {filename}: {err}", file=sys.stderr)
            continue

        df = df.dropna().reset_index(drop=True)
        if df.empty:
            print(f"[skip] {filename}: no complete rows", file=sys.stderr)
            continue

        game_alg_dict.setdefault(match.group("game"), {})[match.group("alg")] = df
    return game_alg_dict


def _plot_alg_curve(ax, alg_cls_name, df):
    """Draw one algorithm's exploitability curve on ``ax`` using its configured style and label."""
    color, linestyle, marker = ALG_STYLE.get(alg_cls_name, ('#000000', '-', 'x'))
    ax.plot(
        df["Iteration"],
        df["Exploitability"],
        color=color,
        linestyle=linestyle,
        marker=marker,
        markevery=5,  # one marker every 5 points keeps dense curves readable
        label=ALG_NAME_MAP.get(alg_cls_name, alg_cls_name),
    )


def plot_game_curve(game_name, alg_data, *, top_alg="RTPCFRPlusGraph(finetune)", save_dir=None):
    """Plot exploitability vs. iteration for every algorithm of one game and save it as a PNG.

    Styling and labels:
      - Each algorithm is styled from ALG_STYLE, falling back to a black
        solid line with 'x' markers.
      - Each algorithm is labelled from ALG_NAME_MAP, falling back to its
        raw class name.
      - The y-axis is log-scaled.

    The curve for ``top_alg``, when present in ``alg_data``, is drawn after
    all other curves. It is therefore rendered above them and listed last in
    the legend.

    The figure is written to ``<save_dir>/<game_name>.png`` at 300 dpi and is
    always closed, even when drawing or saving raises.

    Args:
        game_name: Game identifier, used in the title and as the file stem.
        alg_data: Mapping of algorithm class name -> DataFrame with
            "Iteration" and "Exploitability" columns.
        top_alg: Algorithm to draw last (on top).
            NOTE(review): the default "RTPCFRPlusGraph(finetune)" has no entry
            in ALG_STYLE or ALG_NAME_MAP. Confirm it is the intended name.
        save_dir: Output directory, created if missing. Defaults to FIGURES_DIR.

    Returns:
        Path of the saved PNG file.
    """
    if save_dir is None:
        save_dir = FIGURES_DIR
    os.makedirs(save_dir, exist_ok=True)
    save_png = os.path.join(save_dir, f"{game_name}.png")

    fig, ax = plt.subplots()
    try:
        for alg_cls_name, df in alg_data.items():
            if alg_cls_name != top_alg:
                _plot_alg_curve(ax, alg_cls_name, df)
        # Drawn last so its line sits above the others.
        if top_alg in alg_data:
            _plot_alg_curve(ax, top_alg, alg_data[top_alg])

        ax.set_title(f'{game_name} - Exploitability of Different Algorithms', fontsize=14, fontweight='bold')
        ax.set_xlabel('Training Iterations', fontsize=12)
        ax.set_ylabel('Exploitability', fontsize=12)
        ax.set_yscale('log')
        ax.grid(True, alpha=0.3, linestyle='--')
        ax.legend(loc='upper right', fontsize=11)

        fig.savefig(save_png, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even on failure so a batch over many games does
        # not accumulate open figures.
        plt.close(fig)
    return save_png

if __name__ == "__main__":
    # Script entry point: load every CSV under DATA_DIR and save one figure per game.
    try:
        game_data = load_data_by_game()
    except FileNotFoundError:
        print(f"Data directory not found: {DATA_DIR}", file=sys.stderr)
        sys.exit(1)

    if not game_data:
        # Report the empty case rather than finishing without any output.
        print(f"No usable <alg>_<game>.csv files found in {DATA_DIR}", file=sys.stderr)
    else:
        for game_name, alg_data in game_data.items():
            print(f"\n start: {game_name}")
            plot_game_curve(game_name, alg_data)
        print("\n  save to figures/ ")