import os
from pathlib import Path
import argparse
from sklearn import metrics
import numpy as np
import pandas as pd
from data_utils import load_data, DATA_MAP


def parse_arguments():
    """
    Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``dataset``, ``exp_dir``,
        ``setting``, ``exp_base_dir``, ``data_dir``, ``n_splits`` and
        ``split_idx``.
    """
    # Dataset names are accepted in lower case only.
    dataset_choices = [name.lower() for name in DATA_MAP.keys()]

    # (flag, keyword arguments) pairs, registered in this order.
    option_specs = (
        ("--dataset", dict(type=str, default='wine',
                           choices=dataset_choices,
                           help="Dataset name")),
        ("--exp_dir", dict(type=str, default=None,
                           help="Experiment directory path")),
        ("--setting", dict(type=str, default='semi_supervised',
                           choices=['semi_supervised', 'unsupervised'])),
        ("--exp_base_dir", dict(type=str, default='exp',
                                help="Base directory for experiments")),
        ("--data_dir", dict(type=str, default='data')),
        ("--n_splits", dict(type=int, default=5)),
        ("--split_idx", dict(type=int, default=None)),
    )

    cli = argparse.ArgumentParser(
        description="Compute evaluation metrics for DiSPaT experiments"
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()


def compute_detection_metrics(y_true, y_scores):
    """
    Compute anomaly detection evaluation metrics.

    The F1 / precision / recall values are computed on a top-k prediction:
    the ``k`` highest-scoring samples are labelled anomalous, where ``k`` is
    the number of true anomalies in ``y_true``.

    Args:
        y_true: Ground truth labels (0=normal, 1=anomaly). Any array-like;
            it is flattened to 1-D.
        y_scores: Anomaly scores (higher = more anomalous). Any array-like
            with one score per label; it is flattened to 1-D.

    Returns:
        Tuple of (AUC-ROC, AUC-PR, F1, Precision, Recall)

    Raises:
        ValueError: If ``y_true`` and ``y_scores`` do not contain the same
            number of elements. scikit-learn may raise further errors of its
            own (for example when only one class is present in ``y_true``).
    """
    # Coerce up front so lists / pandas objects can be indexed positionally.
    y_true = np.asarray(y_true).ravel()
    y_scores = np.asarray(y_scores).ravel()

    n_samples = len(y_true)
    if len(y_scores) != n_samples:
        # Without this check a longer score vector would be silently
        # truncated by the fancy indexing below.
        raise ValueError(
            f"y_true has {n_samples} samples but y_scores has {len(y_scores)}"
        )

    # Samples are shuffled (global NumPy RNG) before selecting the top-k;
    # presumably so that ties at the top-k boundary are not resolved by the
    # original sample order.
    shuffled_idx = np.random.permutation(n_samples)
    y_true_shuffled = y_true[shuffled_idx].astype(int)
    y_scores_shuffled = y_scores[shuffled_idx]

    n_anomalies = int(np.count_nonzero(y_true_shuffled == 1))
    y_pred = np.zeros(n_samples, dtype=int)
    if n_anomalies > 0:
        # Guarded because ``arr[-0:]`` is the whole array: with no true
        # anomalies every sample would otherwise be predicted anomalous.
        top_k_indices = np.argpartition(
            y_scores_shuffled, -n_anomalies
        )[-n_anomalies:]
        y_pred[top_k_indices] = 1

    precision, recall, f1, _ = metrics.precision_recall_fscore_support(
        y_true_shuffled, y_pred, average='binary'
    )

    auc_roc = metrics.roc_auc_score(y_true_shuffled, y_scores_shuffled)
    auc_pr = metrics.average_precision_score(y_true_shuffled, y_scores_shuffled)

    return auc_roc, auc_pr, f1, precision, recall


def extract_method_metrics(args, filter_raw=False, filter_normalized=False, filter_ordinal=False):
    """
    Extract and compute metrics for all methods from score files.

    Every ``*.npy`` file in ``<exp_dir>/scores`` is treated as the score
    vector of one method (the method name is the file name without the
    ``.npy`` extension). Files are processed in sorted order so that the
    printed table, the rank assigned to tied methods and the sequence of
    random draws made while computing metrics do not depend on the order in
    which the file system happens to list the directory.

    Side effects:
        ``args.exp_dir`` is overwritten with a ``Path`` (derived from
        ``exp_base_dir/dataset/setting/split<n_splits>/split<split_idx>``
        when it was ``None``), and a results table is printed.

    Args:
        args: Parsed arguments
        filter_raw: Skip score files whose name starts with ``raw``
        filter_normalized: For baseline methods, keep only files whose name
            contains ``normalized``
        filter_ordinal: For baseline methods, keep only files whose name
            contains ``ordinal``

    Returns:
        Dictionary mapping method names to lists
        ``[AUC-ROC, AUC-PR, F1, Precision, Recall]``. Methods whose scores
        contain NaN or infinity get all-zero metrics.

    Raises:
        ValueError: If the score directory does not exist.
    """
    # Only the test labels are needed here.
    _, _, _, y_test = load_data(args)
    # Positional indexing is required downstream, so drop any pandas index.
    y_test = np.asarray(y_test)

    if args.exp_dir is None:
        args.exp_dir = Path(args.exp_base_dir) / args.dataset / args.setting / \
                      f"split{args.n_splits}" / f"split{args.split_idx}"
    else:
        args.exp_dir = Path(args.exp_dir)

    score_dir = args.exp_dir / 'scores'
    if not score_dir.exists():
        raise ValueError(f"Score directory {score_dir} does not exist")

    extension = '.npy'
    method_metrics = {}
    # iterdir() yields entries in arbitrary order; sort for reproducibility.
    for score_file in sorted(score_dir.iterdir()):
        file_name = score_file.name
        if not file_name.endswith(extension):
            continue

        if filter_raw and file_name.startswith('raw'):
            continue

        if is_baseline_method(file_name):
            if filter_normalized and 'normalized' not in file_name:
                continue
            if filter_ordinal and 'ordinal' not in file_name:
                continue

        method_name = file_name[:-len(extension)]
        # Scores stored under the name 'rdp' are deliberately excluded.
        if method_name == 'rdp':
            continue

        scores = np.load(score_file)
        if np.isnan(scores).any() or np.isinf(scores).any():
            print(f"Invalid scores detected for {method_name}, setting metrics to zero")
            method_metrics[method_name] = [0, 0, 0, 0, 0]
        else:
            method_metrics[method_name] = list(
                compute_detection_metrics(y_test, scores)
            )

    rankings = compute_method_rankings(method_metrics)
    print_results_table(method_metrics, rankings)

    return method_metrics


def is_baseline_method(filename):
    """Return True unless the (case-insensitive) name mentions 'dispat' or 'anollm'."""
    lowered = filename.lower()
    own_method_tags = ('dispat', 'anollm')
    return not any(tag in lowered for tag in own_method_tags)


def compute_method_rankings(method_metrics):
    """
    Rank the methods separately for every metric (1 = highest value).

    Args:
        method_metrics: Mapping of method name to a sequence of metric
            values; the number of metrics is taken from the first entry.

    Returns:
        List with one integer array per metric. Position ``i`` of each array
        is the rank of the ``i``-th method in the mapping's iteration order.
        An empty mapping yields an empty list.
    """
    if not method_metrics:
        return []

    per_method = list(method_metrics.values())
    n_methods = len(per_method)
    n_metrics = len(per_method[0])

    def _ranks_for(column):
        # Negate so that an ascending sort puts the best method first.
        negated = [-values[column] for values in per_method]
        order = np.argsort(negated)
        # Invert the sort permutation: the method at order[r] has rank r + 1.
        ranks = np.empty_like(order)
        ranks[order] = np.arange(n_methods)
        return ranks + 1

    return [_ranks_for(column) for column in range(n_metrics)]


def print_results_table(method_metrics, rankings):
    """
    Print one line per method with its five metrics and their ranks.

    Args:
        method_metrics: Mapping of method name to a five-element sequence
            ordered as (AUC-ROC, AUC-PR, F1, Precision, Recall).
        rankings: Sequence indexed first by metric position and then by
            method position (in the mapping's iteration order).
    """
    labels = ("AUC-ROC", "AUC-PR", "F1", "P", "R")
    print("-" * 100)
    for row, (method, values) in enumerate(method_metrics.items()):
        # Unpacking enforces that exactly five metrics are present.
        auc_roc, auc_pr, f1, precision, recall = values
        ordered = (auc_roc, auc_pr, f1, precision, recall)
        cells = [
            f"{label}: {value:.4f} ({rankings[col][row]:2d})"
            for col, (label, value) in enumerate(zip(labels, ordered))
        ]
        print(f"{method:30s}: " + ", ".join(cells))


def filter_method_results(results_dict):
    """
    Collapse ``_lora`` variants of non-baseline methods onto their base name.

    For a non-baseline method whose name contains ``_lora``: the entry is
    dropped when the name without ``_lora`` is also present in
    ``results_dict``; otherwise it is kept under that shortened name. All
    other entries are copied unchanged.

    Args:
        results_dict: Mapping of method name to its metrics.

    Returns:
        New dictionary; ``results_dict`` itself is not modified.
    """
    cleaned = {}
    for name, values in results_dict.items():
        is_lora_variant = '_lora' in name and not is_baseline_method(name)
        if not is_lora_variant:
            cleaned[name] = values
            continue

        stripped = name.replace('_lora', '')
        if stripped in results_dict:
            # The plain variant exists, so it takes precedence.
            continue
        cleaned[stripped] = values
    return cleaned


def aggregate_cross_validation_results(results_list):
    """
    Aggregate metrics across multiple cross-validation splits.

    Only methods that appear in the first split are considered. A method
    that is missing from any split is reported once, excluded from the
    aggregate, and removed from every dictionary in ``results_list``
    (the input is modified in place). All remaining methods receive exactly
    one value per split for each metric, so an incomplete method no longer
    causes the methods listed after it to lose data.

    Args:
        results_list: List of result dictionaries, one per split. Each maps
            a method name to a sequence ordered as
            (AUC-ROC, AUC-PR, F1, Precision, Recall).

    Returns:
        Tuple of (aggregated_metrics_dict, rankings_dict).
        ``aggregated_metrics_dict`` maps method name to a dictionary with
        the keys 'AUC-ROC', 'AUC-PR', 'F1', 'P' and 'R', each holding the
        list of per-split values. Both are empty dictionaries when
        ``results_list`` is empty.
    """
    if not results_list:
        return {}, {}

    # Position of each metric inside the per-split sequences.
    metric_names = ('AUC-ROC', 'AUC-PR', 'F1', 'P', 'R')

    aggregated = {}
    # Snapshot the keys: the split dictionaries are pruned inside the loop.
    for method in list(results_list[0].keys()):
        if not all(method in split_results for split_results in results_list):
            print(f"Incomplete results for {method}")
            for split_results in results_list:
                split_results.pop(method, None)
            continue

        aggregated[method] = {
            name: [split_results[method][position] for split_results in results_list]
            for position, name in enumerate(metric_names)
        }

    print("-" * 100)
    rankings = compute_aggregate_rankings(aggregated)
    print_aggregate_table(aggregated, rankings)

    return aggregated, rankings


def compute_aggregate_rankings(aggregated_metrics):
    """
    Rank methods by the mean of each aggregated metric (1 = highest mean).

    Args:
        aggregated_metrics: Mapping of method name to a dictionary holding a
            list of values under each of 'AUC-ROC', 'AUC-PR', 'F1', 'P', 'R'.

    Returns:
        Dictionary mapping each metric name to an integer array whose
        ``i``-th entry is the rank of the ``i``-th method in the mapping's
        iteration order. Empty input yields an empty dictionary.
    """
    if not aggregated_metrics:
        return {}

    per_method = list(aggregated_metrics.values())

    def _rank_by_mean(metric_name):
        # Negated means: ascending sort order then runs from best to worst.
        negated_means = [-np.mean(entry[metric_name]) for entry in per_method]
        order = np.argsort(negated_means)
        return order.argsort() + 1

    return {
        metric_name: _rank_by_mean(metric_name)
        for metric_name in ('AUC-ROC', 'AUC-PR', 'F1', 'P', 'R')
    }


def print_aggregate_table(aggregated_metrics, rankings):
    """
    Print one line per method showing mean ± standard deviation and rank.

    Args:
        aggregated_metrics: Mapping of method name to a dictionary holding a
            list of values under each of 'AUC-ROC', 'AUC-PR', 'F1', 'P', 'R'.
        rankings: Mapping of metric name to a sequence of ranks indexed by
            method position (in ``aggregated_metrics`` iteration order).
    """
    metric_names = ('AUC-ROC', 'AUC-PR', 'F1', 'P', 'R')
    for row, (method, per_metric) in enumerate(aggregated_metrics.items()):
        cells = []
        for name in metric_names:
            values = per_metric[name]
            cells.append(
                f"{name}: {np.mean(values):.4f} ± {np.std(values):.4f} "
                f"({rankings[name][row]:2d})"
            )
        print(f"{method:30s}: " + ", ".join(cells))


def main():
    """
    Command-line entry point.

    With ``--split_idx`` given, the parsed arguments are printed and that
    single split is evaluated. Without it, splits ``0 .. n_splits - 1`` are
    evaluated in turn and their results are aggregated.
    """
    args = parse_arguments()

    if args.split_idx is not None:
        print(args)
        extract_method_metrics(args)
        return

    per_split_results = []
    for fold in range(args.n_splits):
        args.split_idx = fold
        # extract_method_metrics() stores the resolved directory on args, so
        # clear it to have the path recomputed for the current split.
        args.exp_dir = None
        per_split_results.append(extract_method_metrics(args))
    aggregate_cross_validation_results(per_split_results)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
