import argparse
import numpy as np
import pandas as pd
from data_utils import DATA_MAP, MIXED, ODDS
from compute_metrics import extract_method_metrics, aggregate_cross_validation_results, filter_method_results


def parse_arguments():
    """Build the command-line parser and return the parsed namespace.

    Options: ``--dataset`` (all / mixed / odds), ``--setting``, ``--exp_dir``,
    ``--metric``, ``--data_dir``, ``--n_splits`` and the two boolean flags
    ``--only_normalized`` / ``--only_ordinal`` (both default to False).
    """
    cli = argparse.ArgumentParser(
        description="Aggregate results across multiple datasets"
    )
    cli.add_argument(
        "--dataset",
        type=str,
        default='all',
        choices=['all', 'mixed', 'odds'],
        help="Dataset subset to process",
    )
    cli.add_argument(
        "--setting",
        type=str,
        default='semi_supervised',
        choices=['semi_supervised', 'unsupervised'],
    )
    cli.add_argument("--exp_dir", type=str, default=None)
    cli.add_argument(
        "--metric",
        type=str,
        default='AUC-ROC',
        choices=["AUC-ROC", "F1", "AUC-PR"],
    )
    cli.add_argument("--data_dir", type=str, default='data')
    cli.add_argument("--n_splits", type=int, default=5)
    # Both filters are plain opt-in switches.
    for flag in ("--only_normalized", "--only_ordinal"):
        cli.add_argument(flag, action='store_true', default=False)
    return cli.parse_args()


def select_datasets(dataset_arg):
    """Return the dataset names covered by a ``--dataset`` choice.

    ``'all'`` gives a new list of every key in ``DATA_MAP``; ``'odds'`` and
    ``'mixed'`` return the ``ODDS`` / ``MIXED`` collections themselves. Any
    other value gives an empty list.
    """
    # Lazily evaluated so only the requested collection is ever touched.
    subsets = {
        'all': lambda: list(DATA_MAP.keys()),
        'odds': lambda: ODDS,
        'mixed': lambda: MIXED,
    }
    chooser = subsets.get(dataset_arg)
    if chooser is None:
        return []
    return chooser()


def process_all_datasets(args):
    """Process every selected dataset and aggregate its cross-validation results.

    For each dataset chosen by ``args.dataset`` this does three things:
    - calls ``extract_method_metrics`` once per split (``args.n_splits``
      splits);
    - combines the splits with ``aggregate_cross_validation_results``;
    - keeps the methods returned by ``filter_method_results``.

    ``args`` is mutated in place: ``args.dataset``, ``args.split_idx`` and
    ``args.exp_dir`` are overwritten while iterating.

    Args:
        args: Parsed command-line namespace (see ``parse_arguments``).

    Returns:
        Tuple ``(metric_scores, ranking_dict, std_scores, all_metric_values,
        datasets)``:
        - ``metric_scores``, ``ranking_dict`` and ``std_scores`` map
          dataset -> {method -> value} for ``args.metric``.
        - ``all_metric_values`` maps method -> array of shape
          ``(len(datasets), args.n_splits)`` with per-split values. Entries
          are NaN where no value was produced, e.g. the method is absent for
          that dataset or split, or the dataset failed.
        - ``datasets`` is the sorted list of dataset names that was
          attempted.

        A dataset that raises is reported and skipped as a whole, so it never
        appears partially in the returned dicts.
    """
    datasets = sorted(select_datasets(args.dataset))
    # Position of the requested metric inside each per-split result entry.
    # NOTE(review): assumes extract_method_metrics yields values in
    # (AUC-ROC, AUC-PR, F1) order — confirm against compute_metrics.
    metric_position = ['AUC-ROC', 'AUC-PR', 'F1'].index(args.metric)

    metric_scores = {}
    ranking_dict = {}
    std_scores = {}
    all_metric_values = {}

    for dataset_idx, dataset in enumerate(datasets):
        print("*" * 100)
        print(dataset)
        try:
            args.split_idx = None
            args.dataset = dataset
            # Reset for every dataset. NOTE(review): this discards any
            # --exp_dir supplied on the command line — confirm intended.
            args.exp_dir = None

            results_list = []
            for split_idx in range(args.n_splits):
                args.split_idx = split_idx
                results_list.append(extract_method_metrics(
                    args,
                    filter_normalized=args.only_normalized,
                    filter_ordinal=args.only_ordinal,
                ))

            aggregated_metrics, rankings = aggregate_cross_validation_results(results_list)
            filtered_metrics = filter_method_results(aggregated_metrics)
            method_names = list(filtered_metrics.keys())

            means = {k: np.mean(filtered_metrics[k][args.metric]) for k in method_names}
            # NOTE(review): idx enumerates the *filtered* methods; this only lines
            # up with rankings if filter_method_results preserves the order and
            # membership of the ranked methods — confirm.
            ranks = {k: int(rankings[args.metric][idx]) for idx, k in enumerate(method_names)}
            stds = {k: np.std(filtered_metrics[k][args.metric]) for k in method_names}

            # NaN marks splits where the method produced no result, so it is
            # not mistaken for a genuine score of 0 downstream.
            split_values = {}
            for method_name in method_names:
                row = np.full(args.n_splits, np.nan)
                for split_idx, split_results in enumerate(results_list):
                    if method_name in split_results:
                        row[split_idx] = split_results[method_name][metric_position]
                split_values[method_name] = row
        except Exception as e:
            print(f"Error processing dataset {dataset}: {e}")
            continue

        # Commit only after every step for this dataset succeeded, so a failure
        # never leaves the dataset half-recorded across the result dicts.
        for method_name, row in split_values.items():
            if method_name not in all_metric_values:
                all_metric_values[method_name] = np.full((len(datasets), args.n_splits), np.nan)
            all_metric_values[method_name][dataset_idx] = row
        metric_scores[dataset] = means
        ranking_dict[dataset] = ranks
        std_scores[dataset] = stds

    return metric_scores, ranking_dict, std_scores, all_metric_values, datasets


def save_results_tables(metric_scores, std_scores, all_metric_values, datasets, args):
    """Print and save the aggregated mean / std tables as CSV files.

    Writes ``exp/{setting}_avg_{metric}.csv`` (per-dataset means) and
    ``exp/{setting}_std_{metric}.csv`` (per-dataset stds). Each file gets an
    extra ``'avg'`` row, and values are rounded to 3 decimals. The ``exp``
    directory is created if needed.

    The ``'avg'`` row of the std table is not the mean of the per-dataset
    stds. For each method, the per-split values are first averaged over
    datasets; the row holds the std of those per-split averages. Only
    datasets that appear in ``metric_scores`` (i.e. were processed
    successfully) enter that average, and NaN entries are ignored. A method
    with no entry in ``all_metric_values`` gets 0.

    Args:
        metric_scores: dataset -> {method -> mean score}.
        std_scores: dataset -> {method -> std of the score}.
        all_metric_values: method -> 2-D array whose row i belongs to
            ``datasets[i]`` and whose columns are splits.
        datasets: Dataset names in the row order of ``all_metric_values``.
        args: Namespace providing ``setting`` and ``metric`` (used in file
            names).
    """
    import os

    if not metric_scores:
        # An empty frame cannot take an 'avg' row (pandas raises ValueError).
        print("No dataset produced results; nothing to save.")
        return

    os.makedirs('exp', exist_ok=True)

    df_means = pd.DataFrame(metric_scores).T
    df_means.loc['avg'] = df_means.mean(axis=0)
    df_means = df_means.round(3)
    print(df_means)
    df_means.to_csv(f'exp/{args.setting}_avg_{args.metric}.csv')

    if not std_scores:
        print("No std results to save.")
        return

    # Rows of datasets that failed hold no real data; keep them out of the
    # average, consistent with the mean table above.
    kept_rows = [i for i, name in enumerate(datasets) if name in metric_scores]

    df_stds = pd.DataFrame(std_scores).T
    avg_std = []
    for col in df_stds.columns:
        values = all_metric_values.get(col)
        if values is None:
            avg_std.append(0)
            continue
        rows = values[kept_rows] if kept_rows else values
        per_split_avg = np.nanmean(rows, axis=0)
        avg_std.append(np.std(per_split_avg))
    df_stds.loc['avg'] = avg_std
    df_stds = df_stds.round(3)
    print(df_stds)
    df_stds.to_csv(f'exp/{args.setting}_std_{args.metric}.csv')


def main():
    """Script entry point: parse the CLI, aggregate all datasets, write the CSVs."""
    cli_args = parse_arguments()
    (metric_scores, _ranking_dict, std_scores,
     all_metric_values, datasets) = process_all_datasets(cli_args)
    save_results_tables(
        metric_scores, std_scores, all_metric_values, datasets, cli_args
    )


# Run the aggregation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
