import os
import json
import pandas as pd
from pathlib import Path
import argparse

def get_all_dataset_info(segmented_dir="segmented_datasets"):
    """
    Return information about every segmented dataset found in ``segmented_dir``.

    Each immediate sub-directory of ``segmented_dir`` that contains a
    ``metadata.json`` file is treated as one dataset. The rules for other
    entries are:

    * Plain files, and sub-directories without ``metadata.json``, are skipped
      silently.
    * Datasets whose metadata cannot be read, is not valid JSON, or is not a
      JSON object are reported on stdout and skipped.

    Args:
        segmented_dir (str | os.PathLike): Directory containing one
            sub-directory per segmented dataset.

    Returns:
        list[dict]: One dictionary per dataset, built as follows.

        * The well-known metadata fields are always present. Fields missing
          from the metadata get a default: the directory name for
          ``dataset_name``, ``"Unknown"``, ``0`` or ``None``.
        * ``segments_dir`` and ``cumulative_dir`` hold the paths of the
          dataset's ``segments`` and ``cumulative`` sub-directories.
        * Every other metadata key is copied through unchanged.

        The list is sorted by ``str(task_id)`` and then by ``dataset_name``.
        An empty list is returned when ``segmented_dir`` does not exist or is
        not a directory.
    """
    segmented_path = Path(segmented_dir)

    # Use is_dir() rather than exists(). A regular file would pass exists()
    # and then make iterdir() raise NotADirectoryError.
    if not segmented_path.is_dir():
        if segmented_path.exists():
            print(f"Error: '{segmented_dir}' is not a directory")
        else:
            print(f"Error: Directory '{segmented_dir}' does not exist")
        return []

    # Metadata fields that are mapped explicitly below (with defaults). Every
    # other metadata key is passed through verbatim. segments_dir and
    # cumulative_dir are deliberately NOT listed, so a metadata file that
    # defines them still overrides the computed paths.
    known_keys = frozenset({
        "dataset_name", "original_name", "task_id", "num_segments",
        "num_features", "num_instances", "target_type", "segment_size",
        "num_classes", "class_balance", "instances_per_segment",
        "cumulative_instances", "feature_types", "categorical_features",
        "numerical_features",
    })

    all_datasets = []

    for dataset_dir in segmented_path.iterdir():
        if not dataset_dir.is_dir():
            continue

        metadata_path = dataset_dir / "metadata.json"
        if not metadata_path.exists():
            continue

        # JSON text is UTF-8 by specification; do not depend on the locale.
        # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
        try:
            with open(metadata_path, "r", encoding="utf-8") as f:
                metadata = json.load(f)
        except (OSError, ValueError) as e:
            print(f"Error processing {dataset_dir}: {e}")
            continue

        if not isinstance(metadata, dict):
            print(f"Error processing {dataset_dir}: "
                  f"metadata.json does not contain a JSON object")
            continue

        # Key order here determines the column order of the CSV output.
        dataset_info = {
            "dataset_name": metadata.get("dataset_name", dataset_dir.name),
            "original_name": metadata.get("original_name", "Unknown"),
            "task_id": metadata.get("task_id", "Unknown"),
            "num_segments": metadata.get("num_segments", 0),
            "num_features": metadata.get("num_features", 0),
            "num_instances": metadata.get("num_instances", 0),
            "target_type": metadata.get("target_type", "Unknown"),
            "segments_dir": str(dataset_dir / "segments"),
            "cumulative_dir": str(dataset_dir / "cumulative"),
            "segment_size": metadata.get("segment_size", None),
            "num_classes": metadata.get("num_classes", None),
            "class_balance": metadata.get("class_balance", None),
            "instances_per_segment": metadata.get("instances_per_segment", None),
            "cumulative_instances": metadata.get("cumulative_instances", None),
            "feature_types": metadata.get("feature_types", None),
            "categorical_features": metadata.get("categorical_features", None),
            "numerical_features": metadata.get("numerical_features", None),
        }
        # Copy through any additional metadata fields.
        for key, value in metadata.items():
            if key not in known_keys:
                dataset_info[key] = value

        all_datasets.append(dataset_info)

    # Sort by task_id (as a string, since it may be numeric or "Unknown").
    # dataset_name breaks ties so the order never depends on iterdir().
    all_datasets.sort(key=lambda x: (str(x["task_id"]), str(x["dataset_name"])))

    return all_datasets

def get_summary_statistics(datasets):
    """
    Compute summary statistics over a list of dataset-information dictionaries.

    Only finite, non-zero real numbers are counted (``bool`` is excluded).
    Missing keys, ``None``, ``0``, NaN/inf and non-numeric placeholders such
    as ``"Unknown"`` are ignored instead of crashing the computation.

    For segment sizes, a usable ``segment_size`` is preferred. Otherwise the
    mean of the numeric entries of an ``instances_per_segment`` list is used,
    truncated to ``int``.

    Args:
        datasets (list[dict]): Dataset records, e.g. as returned by
            ``get_all_dataset_info``. The list is iterated several times.

    Returns:
        dict: ``{}`` for an empty input. Otherwise the dictionary has:

        * ``"num_datasets"``: the number of records.
        * ``"instances"``, ``"features"``, ``"segments"``, ``"segment_sizes"``
          and ``"classes"``: one stats dict each, with ``name``, ``count``,
          ``min``, ``max`` (truncated to ``int``), ``median``, ``mean`` and
          ``std`` (population std, ddof=0). A stats dict is ``{}`` when no
          usable values were found.
    """
    if not datasets:
        return {}

    import math
    import numbers

    import numpy as np

    def is_finite_number(value):
        # bool is a subclass of int but is never a meaningful count here.
        return (
            isinstance(value, numbers.Real)
            and not isinstance(value, bool)
            and math.isfinite(value)
        )

    def usable_values(key):
        # Zero is treated as "unknown", matching the 0 defaults in the records.
        return [
            d.get(key) for d in datasets
            if is_finite_number(d.get(key)) and d.get(key)
        ]

    instances = usable_values("num_instances")
    features = usable_values("num_features")
    segments = usable_values("num_segments")
    classes = usable_values("num_classes")

    segment_sizes = []
    for d in datasets:
        size = d.get("segment_size")
        if is_finite_number(size) and size:
            segment_sizes.append(size)
            continue
        per_segment = d.get("instances_per_segment")
        if isinstance(per_segment, list):
            sizes = [s for s in per_segment if is_finite_number(s)]
            if sizes:
                # Average segment size when only per-segment counts are known.
                segment_sizes.append(int(sum(sizes) / len(sizes)))

    def calc_stats(data, name):
        if not data:
            return {}
        return {
            "name": name,
            "count": len(data),
            "min": int(np.min(data)),
            "max": int(np.max(data)),
            "median": float(np.median(data)),
            "mean": float(np.mean(data)),
            "std": float(np.std(data)),
        }

    return {
        "num_datasets": len(datasets),
        "instances": calc_stats(instances, "Number of instances"),
        "features": calc_stats(features, "Number of features"),
        "segments": calc_stats(segments, "Number of segments"),
        "segment_sizes": calc_stats(segment_sizes, "Segment size"),
        "classes": calc_stats(classes, "Number of classes"),
    }

def main(argv=None):
    """
    Command-line entry point: collect dataset information and emit it.

    The ``--output_format`` option selects what is emitted:

    * ``json``: the list of dataset records.
    * ``csv``: the records as a CSV table.
    * ``summary``: the summary statistics.
    * ``all`` (default): both the records and the summary, as JSON.

    The result is printed to stdout, or written to ``--output_file`` when
    one is given.

    Args:
        argv (list[str] | None): Arguments to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the normal
            command-line behaviour.

    Returns:
        list[dict]: The dataset records returned by ``get_all_dataset_info``,
        so the function can also be used programmatically.
    """
    parser = argparse.ArgumentParser(description="Get all dataset information")
    parser.add_argument("--segmented_dir", type=str, default="segmented_datasets",
                        help="Directory containing segmented datasets")
    parser.add_argument("--output_format", type=str, choices=["json", "csv", "summary", "all"],
                        default="all", help="Output format")
    parser.add_argument("--output_file", type=str, default=None,
                        help="File to save output")

    args = parser.parse_args(argv)

    datasets = get_all_dataset_info(args.segmented_dir)

    if args.output_format == "json":
        output = json.dumps(datasets, indent=2)
    elif args.output_format == "csv":
        df = pd.DataFrame(datasets)
        output = df.to_csv(index=False)
    elif args.output_format == "summary":
        stats = get_summary_statistics(datasets)
        output = json.dumps(stats, indent=2)
    else:  # all
        result = {
            "datasets": datasets,
            "summary": get_summary_statistics(datasets)
        }
        output = json.dumps(result, indent=2)

    if args.output_file:
        # Write UTF-8 explicitly instead of relying on the locale default.
        # newline="" writes the line endings already in `output` unchanged.
        # Recent pandas versions end to_csv lines with os.linesep, and text-mode
        # translation would otherwise turn "\r\n" into "\r\r\n" on Windows.
        with open(args.output_file, "w", encoding="utf-8", newline="") as f:
            f.write(output)
        print(f"Output saved to {args.output_file}")
    else:
        print(output)

    # Always return the datasets for programmatic use
    return datasets

if __name__ == "__main__":
    # Keep main()'s returned records bound to a module-level name so they
    # stay inspectable afterwards (e.g. when run via `python -i`).
    datasets = main()