#!/usr/bin/env python3

import argparse
import sys
from pathlib import Path
import numpy as np
import random
import os

# Default seed used whenever no explicit seed is supplied.
GLOBAL_SEED = 42


def set_global_seed(seed=GLOBAL_SEED):
    """Seed the Python, NumPy and (if importable) PyTorch random generators.

    Progress messages are printed to stdout. PyTorch is optional: when it
    cannot be imported, its seeding is skipped with a notice.

    Args:
        seed: Value handed to every seeding call; defaults to ``GLOBAL_SEED``.
    """
    print(f"Setting global random seed: {seed}")

    # Python's and NumPy's global generators.
    for seed_rng in (random.seed, np.random.seed):
        seed_rng(seed)

    # NOTE(review): per the Python docs PYTHONHASHSEED is read at interpreter
    # startup, so this presumably only affects child processes -- confirm.
    os.environ['PYTHONHASHSEED'] = str(seed)

    try:
        import torch

        # CPU generator first, then the current and all CUDA devices.
        for seed_torch in (torch.manual_seed,
                           torch.cuda.manual_seed,
                           torch.cuda.manual_seed_all):
            seed_torch(seed)

        # Request deterministic cuDNN behaviour and turn off its benchmark mode.
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False
        print("PyTorch random seed set")
    except ImportError:
        print("  PyTorch not installed, skipping PyTorch random seed setup")

    print("Global random seed setup completed")

# Seed every RNG at import time with the default seed, before the project
# modules imported further down are loaded.
set_global_seed()

# Put the "src" directory that sits next to this file first on sys.path.
# Presumably this is where the data / classifiers / config modules imported
# below live -- verify against the repository layout.
src_path = Path(__file__).parent / "src"
sys.path.insert(0, str(src_path))

from data import set_active_dataset, WSIDataProcessor
from classifiers import get_classifier, ClassifierFactory
from config import get_classifier_config





def _collect_bags(labels_dict, features_data):
    """Pair every labelled WSI that has features with its label.

    Returns:
        ``(bags, labels, n_missing)`` where ``bags`` is a list of
        ``(features, wsi_name)`` tuples, ``labels`` the matching raw labels,
        and ``n_missing`` the number of labelled WSIs absent from
        ``features_data``.
    """
    bags = []
    labels = []
    n_missing = 0
    for wsi_name, label in labels_dict.items():
        if wsi_name in features_data:
            bags.append((features_data[wsi_name], wsi_name))
            labels.append(label)
        else:
            n_missing += 1
    return bags, labels, n_missing


def prepare_bags_and_labels(data_processor: "WSIDataProcessor",
                           feature_type: str) -> tuple:
    """Load labels and features and assemble the train/val/test bags.

    Labels come from ``data_processor.load_labels_from_excel()`` and features
    from ``data_processor.load_features(feature_type=...)``. A WSI that has a
    label but no entry in the loaded features is left out of its split; the
    number of such WSIs is reported per split.

    Class names are the sorted unique labels seen in the training and
    validation splits. Training and validation labels are returned as indices
    into that list, whereas test labels are returned unchanged.

    Args:
        data_processor: Object providing ``load_labels_from_excel`` and
            ``load_features``.
        feature_type: Forwarded to ``load_features``.

    Returns:
        ``(train_bags, train_labels_idx, val_bags, val_labels_idx,
        test_bags, test_labels_original, class_names)``; every bag is a
        ``(features, wsi_name)`` tuple.
    """
    print(f"Loading data and features...")

    train_labels_dict, val_labels_dict, test_labels_dict = data_processor.load_labels_from_excel()
    features_data = data_processor.load_features(feature_type=feature_type)

    train_bags, train_labels, train_missing = _collect_bags(train_labels_dict, features_data)
    val_bags, val_labels, val_missing = _collect_bags(val_labels_dict, features_data)
    test_bags, test_labels_original, test_missing = _collect_bags(test_labels_dict, features_data)

    # The label vocabulary is defined by train + val only; test labels are
    # deliberately not consulted and are handed back in their raw form.
    class_names = sorted(set(train_labels + val_labels))
    label_to_idx = {label: idx for idx, label in enumerate(class_names)}

    train_labels_idx = [label_to_idx[label] for label in train_labels]
    val_labels_idx = [label_to_idx[label] for label in val_labels]

    print(f"Data loading completed:")
    print(f"  Training: {len(train_bags)} bags")
    print(f"  Validation: {len(val_bags)} bags")
    print(f"  Testing: {len(test_bags)} bags")
    print(f"  Classes: {class_names}")

    # Make dropped WSIs visible instead of discarding them silently.
    for split_name, n_missing in (("training", train_missing),
                                  ("validation", val_missing),
                                  ("testing", test_missing)):
        if n_missing:
            print(f"  Warning: {n_missing} {split_name} WSIs have labels but no "
                  f"'{feature_type}' features and were skipped")

    return (train_bags, train_labels_idx, val_bags, val_labels_idx,
            test_bags, test_labels_original, class_names)


def run_single_experiment(dataset_name: str, classifier_type: str,
                         feature_type: str, use_cv: bool = True,
                         n_folds: int = 5, gpu_id: int = None) -> dict:

    print(f"\n{'='*60}")
    print(f"Starting experiment")
    print(f"  Dataset: {dataset_name}")
    print(f"  Classifier: {classifier_type}")
    print(f"  Feature type: {feature_type}")
    print(f"  Cross validation: {use_cv} ({'folds: ' + str(n_folds) if use_cv else ''})")
    print(f"{'='*60}")

    set_active_dataset(dataset_name)
    data_processor = WSIDataProcessor(dataset_name)

    classifier_config = get_classifier_config(classifier_type)

    if gpu_id is not None:
        try:
            import torch
            if torch.cuda.is_available():
                if gpu_id >= torch.cuda.device_count():
                    print(f"Warning: GPU {gpu_id} does not exist, system has {torch.cuda.device_count()} GPUs")
                    print(f"   Using default device settings")
                else:
                    classifier_config['device'] = f'cuda:{gpu_id}'
                    print(f"Using GPU device: cuda:{gpu_id}")
            else:
                print(f"Warning: CUDA not available, ignoring GPU setting, will use CPU")
        except ImportError:
            print(f"Warning: PyTorch not installed, ignoring GPU setting")

    (train_bags, train_labels, val_bags, val_labels,
     test_bags, test_labels, class_names) = prepare_bags_and_labels(data_processor, feature_type)



    print(f"Starting experiment: {dataset_name} - {classifier_type} - {feature_type}")
    print(f"Data statistics: Training({len(train_bags)}) + Validation({len(val_bags)}) + Testing({len(test_bags) if test_bags else 0})")
    print(f"Number of classes: {len(class_names)}")
    print(f"Class names: {class_names}")
    print("=" * 80)
    
    if use_cv:
        from sklearn.model_selection import StratifiedKFold

        all_bags = train_bags + val_bags
        all_labels = train_labels + val_labels

        skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=None)
        cv_results = []
        best_model = None
        best_accuracy = 0.0

        print(f"\nStarting {n_folds}-fold cross validation...")

        global_kde_features = None
        if classifier_type == 'kde_quantile_pooling':
            print(f"Precomputing KDE features for all WSIs for KDE classifier...")
            temp_classifier = get_classifier(classifier_type, classifier_config)
            temp_classifier._dataset_name = dataset_name
            temp_classifier._feature_type = feature_type
            temp_classifier.precompute_all_kde_features(all_bags)
            global_kde_features = temp_classifier._global_kde_features

        for fold, (train_idx, val_idx) in enumerate(skf.split(all_bags, all_labels)):
            print(f"\n--- Fold {fold + 1}/{n_folds} ---")

            fold_train_bags = [all_bags[i] for i in train_idx]
            fold_train_labels = [all_labels[i] for i in train_idx]
            fold_val_bags = [all_bags[i] for i in val_idx]
            fold_val_labels = [all_labels[i] for i in val_idx]

            classifier = get_classifier(classifier_type, classifier_config)
            if classifier_type in ['cluster_sampling', 'flexible_cluster_sampling', 'kde_quantile_pooling', 'cluster_covariance']:
                classifier._dataset_name = dataset_name
                classifier._feature_type = feature_type

            if classifier_type == 'kde_quantile_pooling' and global_kde_features is not None:
                classifier._global_kde_features = global_kde_features

            training_result = classifier.fit(
                fold_train_bags, fold_train_labels,
                fold_val_bags, fold_val_labels,
                class_names
            )
            
            cv_results.append({
                'fold': fold + 1,
                'val_accuracy': training_result.final_accuracy,
                'val_auc': training_result.final_auc,
                'val_f1': training_result.final_f1,
                'training_time': training_result.training_time,
                'best_epoch': training_result.best_epoch
            })

            if training_result.final_accuracy > best_accuracy:
                best_accuracy = training_result.final_accuracy
                best_model = classifier
                print(f"Fold {fold + 1} validation accuracy: {training_result.final_accuracy:.4f}, AUC: {training_result.final_auc:.4f}, F1: {training_result.final_f1:.4f} (best)")
            else:
                print(f"Fold {fold + 1} validation accuracy: {training_result.final_accuracy:.4f}, AUC: {training_result.final_auc:.4f}, F1: {training_result.final_f1:.4f}")

        cv_accuracies = [result['val_accuracy'] for result in cv_results]
        cv_aucs = [result['val_auc'] for result in cv_results]
        cv_f1s = [result['val_f1'] for result in cv_results]
        mean_cv_accuracy = np.mean(cv_accuracies)
        std_cv_accuracy = np.std(cv_accuracies)
        mean_cv_auc = np.mean(cv_aucs)
        std_cv_auc = np.std(cv_aucs)
        mean_cv_f1 = np.mean(cv_f1s)
        std_cv_f1 = np.std(cv_f1s)

        print(f"\nCross validation results:")
        print(f"  Fold accuracies: {[f'{acc:.4f}' for acc in cv_accuracies]}")
        print(f"  Mean accuracy: {mean_cv_accuracy:.4f} ± {std_cv_accuracy:.4f}")
        print(f"  Fold AUCs: {[f'{auc:.4f}' for auc in cv_aucs]}")
        print(f"  Mean AUC: {mean_cv_auc:.4f} ± {std_cv_auc:.4f}")
        print(f"  Fold F1s: {[f'{f1:.4f}' for f1 in cv_f1s]}")
        print(f"  Mean F1: {mean_cv_f1:.4f} ± {std_cv_f1:.4f}")

        final_classifier = best_model if best_model is not None else classifier
        print(f"\nUsing best model (accuracy: {best_accuracy:.4f}) for subsequent analysis")
        
    else:
        print(f"\nStarting single training...")

        classifier = get_classifier(classifier_type, classifier_config)
        if classifier_type in ['cluster_sampling', 'flexible_cluster_sampling', 'kde_quantile_pooling', 'cluster_covariance']:
            classifier._dataset_name = dataset_name
            classifier._feature_type = feature_type

        if classifier_type == 'kde_quantile_pooling':
            all_bags = train_bags + val_bags
            classifier.precompute_all_kde_features(all_bags)

        training_result = classifier.fit(
            train_bags, train_labels,
            val_bags, val_labels,
            class_names
        )
        
        mean_cv_accuracy = training_result.final_accuracy
        mean_cv_auc = training_result.final_auc
        mean_cv_f1 = training_result.final_f1
        std_cv_accuracy = 0.0
        std_cv_auc = 0.0
        std_cv_f1 = 0.0
        cv_results = []
        final_classifier = classifier

        print(f"Validation accuracy: {training_result.final_accuracy:.4f}, AUC: {training_result.final_auc:.4f}, F1: {training_result.final_f1:.4f}")




    if test_bags and len(test_bags) > 0:
        print(f"\nTest set prediction...")
        test_wsi_names = [metadata for _, metadata in test_bags]
        test_result = final_classifier.predict(test_bags, test_wsi_names)

        if all(isinstance(label, str) for label in test_labels):
            label_to_idx = {label: idx for idx, label in enumerate(class_names)}
            test_labels_idx = [label_to_idx.get(label, -1) for label in test_labels]
            valid_mask = np.array(test_labels_idx) >= 0
            if np.sum(valid_mask) > 0:
                test_accuracy = np.mean(test_result.predictions[valid_mask] == np.array(test_labels_idx)[valid_mask])
                print(f"Test accuracy: {test_accuracy:.4f} (based on {np.sum(valid_mask)}/{len(test_labels)} valid samples)")
            else:
                test_accuracy = None
                print("No valid labels in test set for accuracy calculation")
        else:
            test_accuracy = np.mean(test_result.predictions == np.array(test_labels))
            print(f"Test accuracy: {test_accuracy:.4f}")
    else:
        test_result = None
        test_accuracy = None
        print("No test data (predefined split test data moved to validation set)")
    
    result = {
        'dataset': dataset_name,
        'classifier': classifier_type,
        'feature_type': feature_type,
        'use_cv': use_cv,
        'n_folds': n_folds if use_cv else 1,
        'mean_cv_accuracy': mean_cv_accuracy,
        'mean_cv_auc': mean_cv_auc,
        'mean_cv_f1': mean_cv_f1,
        'std_cv_accuracy': std_cv_accuracy,
        'std_cv_auc': std_cv_auc,
        'std_cv_f1': std_cv_f1,
        'test_accuracy': test_accuracy,
        'cv_results': cv_results,
        'test_predictions': test_result,
        'classifier_config': classifier_config,
        'training_history': getattr(final_classifier, 'training_history', None)
    }



    return result


def main():
    """Parse the command line and run the requested experiment mode.

    Modes:
        single: one call to ``run_single_experiment`` with the given options.
        comparison: one experiment per classifier reported by
            ``ClassifierFactory.list_available()``; a failing classifier is
            reported with its traceback and the remaining ones still run.
        sweep: only prints a pointer to the dedicated sweep script.

    Returns:
        Process exit code: ``0`` on success, ``1`` when the experiment raised
        or when, in comparison mode, every attempted classifier failed.
    """
    import traceback

    parser = argparse.ArgumentParser(description='MIL classification experiment system')

    parser.add_argument('--dataset', type=str, default='camelyon16',
                       help='Dataset name (hupo_cancer, camelyon16)')

    parser.add_argument('--classifier', type=str, default='homil',
                       help='Classifier type (homil, attention, mean_pooling, max_pooling, clam_sb, clam_mb, s4mil, transmil)')
    parser.add_argument('--feature_type', type=str, default='original',
                       help='Feature type (original, pca_64d, ae_64d)')
    parser.add_argument('--experiment_mode', type=str, default='single',
                       choices=['single', 'comparison', 'sweep'],
                       help='Experiment mode (single: single experiment, comparison: classifier comparison, sweep: parameter sweep)')

    parser.add_argument('--no_cv', action='store_true',
                       help='Disable cross validation')
    parser.add_argument('--n_folds', type=int, default=5,
                       help='Number of cross validation folds')
    parser.add_argument('--seed', type=int, default=GLOBAL_SEED,
                       help=f'Random seed (default: {GLOBAL_SEED})')

    parser.add_argument('--gpu', type=int, default=None,
                       help='Specify GPU device ID to use (e.g.: 0, 1, 2), if not specified use default device settings')

    args = parser.parse_args()

    # The default seed was already applied at import time; re-seed only when
    # a different one is requested.
    if args.seed != GLOBAL_SEED:
        set_global_seed(args.seed)

    print("Available options:")
    print(f"  Classifiers: {ClassifierFactory.list_available()}")

    try:
        if args.experiment_mode == 'single':
            result = run_single_experiment(
                dataset_name=args.dataset,
                classifier_type=args.classifier,
                feature_type=args.feature_type,
                use_cv=not args.no_cv,
                n_folds=args.n_folds,
                gpu_id=args.gpu
            )

            print(f"\nExperiment completed!")
            print(f"Final results: CV Accuracy = {result['mean_cv_accuracy']:.4f}, CV AUC = {result['mean_cv_auc']:.4f}, CV F1 = {result['mean_cv_f1']:.4f}")
            if result['test_accuracy'] is not None:
                print(f"Test accuracy = {result['test_accuracy']:.4f}")

        elif args.experiment_mode == 'comparison':
            classifiers_to_compare = ClassifierFactory.list_available()
            comparison_results = []
            failed_classifiers = []

            print(f"\nStarting classifier comparison experiment...")
            for classifier_type in classifiers_to_compare:
                print(f"\n--- Testing classifier: {classifier_type} ---")
                try:
                    result = run_single_experiment(
                        dataset_name=args.dataset,
                        classifier_type=classifier_type,
                        feature_type=args.feature_type,
                        use_cv=not args.no_cv,
                        n_folds=args.n_folds,
                        gpu_id=args.gpu
                    )
                    comparison_results.append(result)
                except Exception as e:
                    # Keep comparing the remaining classifiers, but leave a
                    # full traceback so the failure can be diagnosed.
                    print(f"Error: Classifier {classifier_type} failed: {e}")
                    traceback.print_exc()
                    failed_classifiers.append(classifier_type)

            print(f"\nClassifier comparison results:")
            for result in comparison_results:
                print(f"  {result['classifier']}: CV Acc={result['mean_cv_accuracy']:.4f}±{result['std_cv_accuracy']:.4f}, "
                      f"CV AUC={result['mean_cv_auc']:.4f}±{result['std_cv_auc']:.4f}, "
                      f"CV F1={result['mean_cv_f1']:.4f}±{result['std_cv_f1']:.4f}")

            print(f"\nComparison completed for {len(comparison_results)} classifiers")
            if failed_classifiers:
                print(f"Failed classifiers: {failed_classifiers}")
                # A run in which nothing succeeded must not report success.
                if not comparison_results:
                    return 1

        elif args.experiment_mode == 'sweep':
            print("Please use auto_train_v2.py for parameter sweep mode")

    except Exception as e:
        print(f"Error: Experiment failed: {e}")
        traceback.print_exc()
        return 1

    return 0


if __name__ == "__main__":
    # Use sys.exit instead of the ``exit`` helper that the ``site`` module
    # injects for interactive sessions; that helper may be missing (for
    # example under ``python -S``), whereas sys.exit is always available.
    sys.exit(main())
