import os
import pandas as pd

def _collect_best_accuracies(result_dir):
    """Scan ``result_dir`` for CSV result files and keep the best accuracy per pair.

    A CSV is used only if it has ``dataset``, ``dnn_type`` and
    ``test_accuracy`` columns; only its first row is read. Unreadable files,
    files with no data rows, and rows whose accuracy is missing or
    non-numeric are skipped with a message.

    Args:
        result_dir: Directory to scan (not recursive).

    Returns:
        dict of the form ``{dataset_name: {dnn_type: max_accuracy}}``.
    """
    required = {"dataset", "dnn_type", "test_accuracy"}
    results = {}

    # os.listdir order is arbitrary; sorting makes the scan (and therefore
    # the order of the reported rows) deterministic.
    for file in sorted(os.listdir(result_dir)):
        if not file.endswith(".csv"):
            continue
        file_path = os.path.join(result_dir, file)
        try:
            df = pd.read_csv(file_path)
        except (OSError, ValueError) as e:
            # ValueError covers pandas' EmptyDataError / ParserError and
            # UnicodeDecodeError. A bad file is skipped rather than fatal.
            print(f"Could not read {file}: {e}")
            continue

        if not required.issubset(df.columns):
            continue
        if df.empty:
            print(f"Skipping {file}: no data rows")
            continue

        dataset = df["dataset"].iloc[0]
        dnn_type = df["dnn_type"].iloc[0]
        # Coerce so a string/blank value becomes NaN instead of raising on
        # comparison. NaN must be skipped: `x > nan` is always False, so a
        # stored NaN could never be replaced by a real value.
        accuracy = pd.to_numeric(df["test_accuracy"].iloc[0], errors="coerce")
        if pd.isna(accuracy):
            print(f"Skipping {file}: test_accuracy is missing or non-numeric")
            continue

        best = results.setdefault(dataset, {})
        if dnn_type not in best or accuracy > best[dnn_type]:
            best[dnn_type] = accuracy

    return results


def extract_high_accuracy_results(result_dir, threshold):
    """Report datasets where both FCN and PatchTST exceed an accuracy threshold.

    For every dataset, the best ``test_accuracy`` seen across the CSV files in
    ``result_dir`` is taken per ``dnn_type``. A dataset matches only when it
    has results for both ``FCN`` and ``PatchTST`` and both are strictly
    greater than ``threshold``. The matches (or a "none found" message) are
    printed.

    Args:
        result_dir: Directory containing per-run CSV result files.
        threshold: Accuracy cut-off as a fraction, e.g. 0.6 for 60%.

    Returns:
        A DataFrame with columns ``Dataset``, ``FCN_Accuracy`` and
        ``PatchTST_Accuracy`` (empty if nothing matches), which callers may
        save with ``to_csv``. Returns ``None`` if ``result_dir`` is not a
        directory.
    """
    if not os.path.isdir(result_dir):
        print(f"Directory not found: {result_dir}")
        return None

    results = _collect_best_accuracies(result_dir)

    matches = []
    for dataset, arch_data in results.items():
        fcn_acc = arch_data.get("FCN")
        patch_acc = arch_data.get("PatchTST")
        # A dataset lacking either architecture never matches, whatever the
        # threshold (a 0 default would let negative thresholds through).
        if fcn_acc is None or patch_acc is None:
            continue
        if fcn_acc > threshold and patch_acc > threshold:
            matches.append({
                "Dataset": dataset,
                "FCN_Accuracy": fcn_acc,
                "PatchTST_Accuracy": patch_acc,
            })

    final_df = pd.DataFrame(
        matches, columns=["Dataset", "FCN_Accuracy", "PatchTST_Accuracy"]
    )

    # ':g' prints 0.6 -> "60" and 0.655 -> "65.5" without float noise.
    pct = f"{threshold * 100:g}"
    if final_df.empty:
        print(f"No datasets found with both architectures exceeding {pct}%.")
    else:
        print(f"Datasets with > {pct}% accuracy in both FCN and PatchTST:")
        print(final_df.to_string(index=False))

    return final_df

if __name__ == "__main__":
    import argparse

    # Defaults reproduce the original hard-coded run; pass a directory and/or
    # --threshold on the command line to analyse other results.
    TARGET_DIR = "/home/40430660@eeecs.qub.ac.uk/InterpretGatedNetwork/result/DNN"

    parser = argparse.ArgumentParser(
        description=(
            "List datasets where both FCN and PatchTST exceed an accuracy "
            "threshold."
        )
    )
    parser.add_argument(
        "result_dir",
        nargs="?",
        default=TARGET_DIR,
        help="Directory containing per-run CSV result files.",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.6,
        help="Accuracy cut-off as a fraction (default: 0.6).",
    )
    args = parser.parse_args()
    extract_high_accuracy_results(args.result_dir, threshold=args.threshold)
