import ast
import functools
import os
import re
import threading

import numpy as np
import pandas as pd

# Path of the accuracy log file. A ".csv" suffix selects the pandas CSV
# reader/writer; any other suffix is handled with the pandas Excel functions.
acc_excel_path = "acc.csv"

# Single module-wide lock shared by every function decorated with
# ``synchronized``; it serializes access to the accuracy log file between
# threads of this process.
_lock = threading.Lock()


def synchronized(func):
    """Return ``func`` wrapped so that each call runs while holding ``_lock``.

    The lock is acquired before ``func`` is invoked and released when it
    returns or raises.  It is a plain (non-reentrant) ``threading.Lock``, so a
    synchronized function must not call another synchronized function.

    Args:
        func: The callable to protect.

    Returns:
        A wrapper with the same call signature that forwards all positional
        and keyword arguments and returns ``func``'s result.
    """

    # functools.wraps keeps the wrapped function's __name__, __doc__ and
    # __wrapped__ so introspection, logging and debugging still identify it.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with _lock:
            return func(*args, **kwargs)

    return wrapper

# Columns of the accuracy log, in the order they are written to disk.
_ACC_COLUMNS = [
    "data_name",
    "model_name",
    "method_name",
    "domain_num",
    "corruption",
    "other_dict",
    "acc_list",
    "num",
    "mean",
    "std",
]

# Dtypes enforced on the summary columns right after the log is loaded.
_NUMERIC_DTYPES = {"num": "int64", "mean": "float64", "std": "float64"}

# NumPy >= 2 prints scalars as e.g. ``np.float64(0.9)``.  Lists containing
# such scalars may already be stored in existing logs, so the wrapper is
# stripped before the text is handed to ``ast.literal_eval``.
_NUMPY_SCALAR_RE = re.compile(r"\b(?:np|numpy)\.\w+\(([^()]*)\)")


def _read_table():
    """Load the log at ``acc_excel_path`` (CSV for ``.csv``, otherwise Excel)."""
    if acc_excel_path.endswith(".csv"):
        return pd.read_csv(acc_excel_path)
    return pd.read_excel(acc_excel_path)


def _write_table(df):
    """Write ``df`` to ``acc_excel_path`` without the index column."""
    if acc_excel_path.endswith(".csv"):
        df.to_csv(acc_excel_path, index=False)
    else:
        df.to_excel(acc_excel_path, index=False)


def _with_schema(df):
    """Return ``df`` with every log column present, in canonical order.

    Missing columns are added filled with NaN; columns that are not part of
    the schema are kept after the known ones.
    """
    extras = [name for name in df.columns if name not in _ACC_COLUMNS]
    return df.reindex(columns=_ACC_COLUMNS + extras)


def _parse_acc_list(raw):
    """Turn one stored ``acc_list`` cell into a Python list of accuracies.

    Accepts an actual list/tuple, the string form of a list (``"[0.9, 0.8]"``),
    a single number (as text or as a numeric cell) or a missing cell, which
    yields an empty list.
    """
    if isinstance(raw, (list, tuple)):
        return list(raw)
    if isinstance(raw, str):
        text = raw.strip()
        if text.startswith("[") and text.endswith("]"):
            # literal_eval only accepts Python literals, so nothing stored in
            # the log file can execute code (unlike eval).
            return list(ast.literal_eval(_NUMPY_SCALAR_RE.sub(r"\1", text)))
        return [float(text)]
    if pd.isna(raw):
        return []
    return [float(raw)]


def _column_equals(series, value):
    """Element-wise ``series == value`` where a missing value matches NaN cells.

    ``None`` is stored as an empty cell and read back as NaN, which never
    compares equal to anything, so it needs an explicit ``isna`` test.
    """
    if value is None or (isinstance(value, float) and value != value):
        return series.isna()
    return series == value


@synchronized
def acclog(acc, base_dict, other_dict):
    """Record one accuracy value in the log file at ``acc_excel_path``.

    Rows are identified by ``data_name``, ``model_name``, ``method_name``,
    ``domain_num`` and ``corruption`` from ``base_dict`` plus
    ``str(other_dict)``.  If a matching row exists (the first one is used),
    ``acc`` is appended to its ``acc_list`` and ``num``/``mean``/``std`` are
    recomputed, with mean and std rounded to 4 decimals.  Otherwise a new row
    is appended with ``num=1``, ``mean=acc`` and ``std=0.0``.  The file is
    created if it does not exist yet and is rewritten on every call.

    Args:
        acc: Accuracy to record; anything ``float()`` accepts.
        base_dict: Mapping providing the five identifying keys listed above.
        other_dict: Extra settings; compared and stored via ``str()``.

    Raises:
        KeyError: If ``base_dict`` lacks one of the identifying keys.
    """
    if os.path.exists(acc_excel_path):
        df = _read_table()
    else:
        df = pd.DataFrame(columns=_ACC_COLUMNS)

    # Coerce only the columns actually present in the file, then restore any
    # column that an earlier write dropped so the lookups below cannot fail.
    for column, dtype in _NUMERIC_DTYPES.items():
        if column in df.columns:
            df[column] = df[column].astype(dtype)
    df = _with_schema(df)
    # acc_list cells are written as text; make sure the column can hold it.
    df["acc_list"] = df["acc_list"].astype(object)

    # Plain floats keep the stored list readable regardless of the scalar type
    # the caller passes in.
    acc = float(acc)
    key = {
        "data_name": base_dict["data_name"],
        "model_name": base_dict["model_name"],
        "method_name": base_dict["method_name"],
        "domain_num": base_dict["domain_num"],
        "corruption": base_dict["corruption"],
        "other_dict": str(other_dict),
    }

    mask = pd.Series(True, index=df.index, dtype=bool)
    for column, value in key.items():
        mask = mask & _column_equals(df[column], value)
    matches = df.index[mask]

    if len(matches) > 0:
        idx = matches[0]
        values = _parse_acc_list(df.at[idx, "acc_list"])
        values.append(acc)

        # Store the list in its string form: this is what the CSV writer
        # would emit anyway, and it also works for Excel output.
        df.at[idx, "acc_list"] = str(values)
        df.at[idx, "num"] = len(values)
        df.at[idx, "mean"] = float(np.mean(values).round(4))
        df.at[idx, "std"] = float(np.std(values).round(4))
    else:
        new_row = pd.DataFrame(
            [{**key, "acc_list": str([acc]), "num": 1, "mean": acc, "std": 0.0}]
        )
        # All-NA columns are dropped before concatenating and put back by
        # _with_schema afterwards, so the file always keeps its full layout.
        df = pd.concat(
            [df.dropna(axis=1, how="all"), new_row.dropna(axis=1, how="all")],
            ignore_index=True,
        )
        df = _with_schema(df)

    _write_table(df)


@synchronized
def flash_acc():
    """Recompute ``num``, ``mean`` and ``std`` of every row from ``acc_list``.

    Reads the log at ``acc_excel_path``, re-derives the three summary columns
    from each row's stored accuracy list and writes the file back.  Unlike
    ``acclog``, the mean and standard deviation are not rounded here.  Rows
    with no stored accuracies get ``num=0`` and NaN statistics.
    """
    df = _read_table()
    parsed = df["acc_list"].apply(_parse_acc_list)

    nan = float("nan")
    df["num"] = parsed.apply(len)
    df["mean"] = parsed.apply(lambda xs: float(np.mean(xs)) if xs else nan)
    df["std"] = parsed.apply(lambda xs: float(np.std(xs)) if xs else nan)
    # Written back in string form, which is what the CSV writer emits for a
    # list and also works for Excel output.
    df["acc_list"] = parsed.apply(str)

    _write_table(df)

# When run as a script, recompute the summary columns of the existing log.
if __name__ == "__main__":
    # Example acclog calls, kept for reference. acclog also reads
    # base_dict["corruption"], so that key is included here.
    # acclog(0.9, {"data_name": "mnist", "model_name": "cnn", "method_name": "GST", "domain_num": 2, "corruption": "none"}, {"other": 3})
    # acclog(0.95, {"data_name": "mnist", "model_name": "cnn", "method_name": "GST", "domain_num": 2, "corruption": "none"}, {"other": 3})
    # acclog(0.8, {"data_name": "mnist", "model_name": "cnn", "method_name": "GST", "domain_num": 2, "corruption": "none"}, {"other": 3})
    flash_acc()