import os
import yaml
import pandas as pd
from pathlib import Path
from collections import defaultdict

def load_class_names(config_path='../config/constants.yml'):
    """Load the per-dataset class-name tables from the constants YAML file.

    Args:
        config_path: Path to the YAML constants file. Defaults to the
            original location, ``../config/constants.yml``, which is resolved
            relative to the current working directory.

    Returns:
        dict mapping each dataset name ('ENZYMES', 'COLLAB', 'Motif') to the
        value stored under ``<DATASET>_CLASS_NAMES`` in the YAML file. The
        shape of that value depends on the YAML contents. Presumably it is a
        list of names or a dict keyed by class index; verify against the
        config.

    Raises:
        OSError: if the file cannot be opened.
        ValueError: if the YAML document is empty or not a mapping.
        KeyError: if one of the expected ``*_CLASS_NAMES`` keys is missing.
    """
    # Explicit encoding so the result does not depend on the platform locale.
    with open(config_path, 'r', encoding='utf-8') as f:
        constants = yaml.safe_load(f)

    # safe_load returns None for an empty file; fail with a clear message
    # instead of an opaque TypeError on subscripting.
    if not isinstance(constants, dict):
        raise ValueError(f"{config_path} does not contain a YAML mapping")

    # Keys follow the '<DATASET>_CLASS_NAMES' convention, e.g. 'Motif_CLASS_NAMES'.
    return {
        dataset: constants[f'{dataset}_CLASS_NAMES']
        for dataset in ('ENZYMES', 'COLLAB', 'Motif')
    }

def _lookup_class_name(mapping, token):
    """Return the display name for one class token, falling back to the token.

    ``mapping`` may be a dict keyed by class index or a list indexed by class
    position. Non-numeric tokens, and indices the mapping does not contain,
    fall back to the raw token instead of raising.
    """
    try:
        idx = int(token)
    except ValueError:
        return token
    if isinstance(mapping, dict):
        # Prefer integer keys; also accept string keys as a fallback.
        return mapping.get(idx, mapping.get(token, token))
    if isinstance(mapping, (list, tuple)) and 0 <= idx < len(mapping):
        return mapping[idx]
    return token


def _class_label(class_idx, mapping):
    """Map a class directory name such as ``"0"`` or boundary ``"0_1"`` to a label."""
    parts = class_idx.split('_')
    if len(parts) == 1:
        return _lookup_class_name(mapping, class_idx)
    # Boundary classes: translate each underscore-separated index separately.
    return '_'.join(str(_lookup_class_name(mapping, part)) for part in parts)


def _subdirs(path):
    """Return the immediate sub-directories of *path*, sorted by name."""
    return sorted((p for p in path.iterdir() if p.is_dir()), key=lambda p: p.name)


def _collect_counts(base_dir):
    """Count ``.pt`` files per ``(strategy, dataset)`` pair and class directory.

    The expected layout is
    ``base_dir/<strategy>/<dataset>/<class>/<dataset>/<class>/*.pt``.
    Class directories without that nested layout are skipped. Traversal is
    sorted so the output order is deterministic.
    """
    counts = {}
    for strategy_path in _subdirs(base_dir):
        for dataset_path in _subdirs(strategy_path):
            for class_path in _subdirs(dataset_path):
                nested_class_path = class_path / dataset_path.name / class_path.name
                if not nested_class_path.is_dir():
                    continue
                num_graphs = sum(
                    1 for f in nested_class_path.iterdir() if f.name.endswith('.pt')
                )
                key = (strategy_path.name, dataset_path.name)
                counts.setdefault(key, {})[class_path.name] = num_graphs
    return counts


def count_graphs(base_dir='../graphs', output_file='graph_counts.csv', class_names=None):
    """Count saved ``.pt`` graphs per strategy, dataset and class.

    Writes a CSV with one row per class (index = class label, columns
    ``Strategy``, ``Dataset``, ``Graphs``) and prints a per-dataset summary.

    Args:
        base_dir: Root of the graph tree. Defaults to ``../graphs``, which is
            resolved relative to the current working directory.
        output_file: Destination CSV path.
        class_names: Optional pre-loaded mapping from dataset name to its
            class-name table. When ``None``, ``load_class_names()`` is called.

    Returns:
        The combined DataFrame, or ``None`` if ``base_dir`` is not a
        directory or no graphs were found.
    """
    base_dir = Path(base_dir)
    if not base_dir.is_dir():
        print("No 'graphs' directory found!")
        return None

    if class_names is None:
        class_names = load_class_names()
    counts = _collect_counts(base_dir)

    if not counts:
        print("No graphs found!")
        return None

    rows, index = [], []
    for (strategy, dataset), class_counts in counts.items():
        mapping = class_names.get(dataset, {})
        for class_idx, num_graphs in class_counts.items():
            index.append(_class_label(class_idx, mapping))
            rows.append({'Strategy': strategy, 'Dataset': dataset, 'Graphs': num_graphs})

    final_df = pd.DataFrame(rows, index=index, columns=['Strategy', 'Dataset', 'Graphs'])
    final_df.to_csv(output_file)

    print("\nGraph Count Summary:")
    print("===================")
    for (strategy, dataset), class_counts in counts.items():
        print(f"\n{strategy} - {dataset}:")
        for class_idx, count in class_counts.items():
            if dataset in class_names:
                class_name = _class_label(class_idx, class_names[dataset])
                print(f"  {class_name}: {count} graphs")
            else:
                print(f"  Class {class_idx}: {count} graphs")
        print(f"  Total: {sum(class_counts.values())} graphs")

    print(f"\nDetailed counts saved to {output_file}")
    return final_df

if "__main__" == __name__:
    # Script entry point: run the graph count when executed directly.
    count_graphs()
