#!/usr/bin/env python3
import argparse
import numpy as np
import time
import sys
from pathlib import Path
from typing import Dict, Any, Optional

src_path = Path(__file__).parent / "src"
sys.path.insert(0, str(src_path))

from utils.clustering_handlers import get_clustering_handler
from utils.clustering_handlers.base_handler import BaseClusteringHandler


def cluster(data: np.ndarray, cluster_handler: BaseClusteringHandler) -> np.ndarray:
    """Run ``cluster_handler`` on ``data`` and hand back whatever it produces.

    Single call-site seam between this script and the handler's ``cluster``
    method; no pre- or post-processing happens here.
    """
    run_clustering = cluster_handler.cluster
    result = run_clustering(data)
    return result





def process_single_wsi(features: np.ndarray,
                       wsi_name: str,
                       cluster_handler: BaseClusteringHandler,
                       min_patches_threshold: int = 5) -> Optional[np.ndarray]:
    """Cluster the patch features of a single WSI and log a summary.

    Args:
        features: Feature matrix whose first axis is the patch axis
            (``features.shape[0]`` is used as the patch count).
        wsi_name: WSI identifier, used only in log messages.
        cluster_handler: Handler whose ``cluster`` method returns the
            clustered feature array.
        min_patches_threshold: WSIs with fewer patches than this are not
            clustered.

    Returns:
        The clustered features, or ``None`` if the WSI had too few patches
        or the clustering call raised an exception.
    """
    n_patches = features.shape[0]

    if n_patches < min_patches_threshold:
        print(f"WARNING: WSI {wsi_name}: patch count ({n_patches}) below threshold ({min_patches_threshold}), skipping clustering")
        return None

    # Guard only the clustering call itself: a problem while *reporting*
    # (e.g. a stats dict missing a key) must not discard a result that was
    # already computed successfully.
    try:
        start_time = time.time()
        clustered_features = cluster(features, cluster_handler)
        clustering_time = time.time() - start_time
    except Exception as e:
        print(f"ERROR: WSI {wsi_name}: clustering failed - {str(e)}")
        return None

    try:
        _log_clustering_summary(wsi_name, n_patches, clustered_features,
                                clustering_time, cluster_handler)
    except Exception as e:  # reporting is best-effort only
        print(f"WARNING: WSI {wsi_name}: could not report clustering statistics - {str(e)}")

    return clustered_features


def _log_clustering_summary(wsi_name: str,
                            n_patches: int,
                            clustered_features: np.ndarray,
                            clustering_time: float,
                            cluster_handler: BaseClusteringHandler) -> None:
    """Print the per-WSI clustering summary line(s).

    Uses the handler's private ``_last_clustering_stats`` dict when present
    (NOTE(review): assumed to describe the call that just finished — confirm
    every handler refreshes it on each ``cluster`` call). Otherwise it falls
    back to the wall-clock time and output row count measured by the caller.
    May raise if the stats dict lacks an expected key; callers treat that as
    non-fatal.
    """
    clustering_stats = getattr(cluster_handler, '_last_clustering_stats', None)

    if clustering_stats and 'timing' in clustering_stats:
        timing = clustering_stats['timing']
        print(f"INFO: WSI {wsi_name}: granular ball clustering completed - "
              f"patches: {clustering_stats['n_patches']}, "
              f"balls: {clustering_stats['n_clusters']}, "
              f"noise points: {clustering_stats['n_noise_points']}")
        print(f"      timing analysis: graph construction {timing['graph_construction']:.2f}s, "
              f"clustering {timing['clustering']:.2f}s, "
              f"noise reassignment {timing['noise_reassignment']:.2f}s, "
              f"total {timing['total']:.2f}s")
    elif clustering_stats:
        print(f"INFO: WSI {wsi_name}: granular ball clustering completed - "
              f"patches: {clustering_stats['n_patches']}, "
              f"balls: {clustering_stats['n_clusters']}, "
              f"noise points: {clustering_stats['n_noise_points']}, "
              f"time: {clustering_time:.2f}s")
    else:
        print(f"INFO: WSI {wsi_name}: clustering completed - "
              f"patches: {n_patches}, clusters: {clustered_features.shape[0]}, "
              f"time: {clustering_time:.2f}s")


def process_directory(input_dir: str,
                      output_dir: str,
                      cluster_handler: BaseClusteringHandler,
                      file_pattern: str = "*.npy",
                      min_patches_threshold: int = 5) -> Dict[str, Any]:
    """Cluster every feature file matching ``file_pattern`` in ``input_dir``.

    Each matching file is loaded with ``np.load``. It must be 2-D and is passed
    to ``process_single_wsi``. A non-``None`` result is saved under the same
    file name in ``output_dir``, which is created if needed.

    Args:
        input_dir: Directory containing the per-WSI feature files.
        output_dir: Directory the clustered arrays are written to.
        cluster_handler: Handler forwarded to ``process_single_wsi``.
        file_pattern: ``Path.glob`` pattern selecting input files.
        min_patches_threshold: Forwarded to ``process_single_wsi``.

    Returns:
        Statistics dict with keys ``total_files``, ``processed_files``,
        ``failed_files``, ``skipped_files``, ``total_patches``,
        ``total_clusters``, ``clustered_patches`` and ``processing_time``.
        The same keys are returned when no file matches.
        ``skipped_files`` counts WSIs for which ``process_single_wsi``
        returned ``None`` (too few patches or a clustering error).
        ``clustered_patches`` counts only patches of successfully clustered
        and saved WSIs.

    Raises:
        FileNotFoundError: if ``input_dir`` does not exist.
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)

    if not input_path.exists():
        raise FileNotFoundError(f"Input directory does not exist: {input_dir}")

    output_path.mkdir(parents=True, exist_ok=True)

    # Sorted so processing order (and log output) is deterministic;
    # raw glob order depends on the filesystem.
    npy_files = sorted(input_path.glob(file_pattern))

    stats = {
        "total_files": len(npy_files),
        "processed_files": 0,
        "failed_files": 0,
        "skipped_files": 0,
        "total_patches": 0,
        "total_clusters": 0,
        "clustered_patches": 0,
        "processing_time": 0
    }

    if not npy_files:
        print(f"WARNING: No matching files found in directory {input_dir}: {file_pattern}")
        return stats

    print(f"INFO: Starting to process {len(npy_files)} WSI feature files")
    print(f"INFO: Input directory: {input_dir}")
    print(f"INFO: Output directory: {output_dir}")
    print(f"INFO: Clustering algorithm: {cluster_handler.get_algorithm_name()}")
    print("INFO: Clustering mode: adaptive cluster number determination")

    n_loaded = 0  # files whose features loaded as a valid 2-D array
    start_time = time.time()

    for i, npy_file in enumerate(npy_files, 1):
        wsi_name = npy_file.stem
        output_file = output_path / npy_file.name

        try:
            print(f"INFO: [{i}/{len(npy_files)}] Processing WSI {wsi_name}...")
            features = np.load(npy_file)

            if features.ndim != 2:
                print(f"ERROR: WSI {wsi_name}: Feature data dimension error, expected 2D, got {features.ndim}D")
                stats["failed_files"] += 1
                continue

            n_patches = features.shape[0]
            n_loaded += 1
            stats["total_patches"] += n_patches

            clustered_features = process_single_wsi(
                features, wsi_name, cluster_handler, min_patches_threshold
            )

            if clustered_features is not None:
                np.save(output_file, clustered_features)
                stats["processed_files"] += 1
                stats["total_clusters"] += clustered_features.shape[0]
                stats["clustered_patches"] += n_patches

                print(f"INFO: WSI {wsi_name}: Clustering results saved to {output_file}")
            else:
                stats["skipped_files"] += 1

        except Exception as e:
            print(f"ERROR: WSI {wsi_name}: Processing failed - {str(e)}")
            stats["failed_files"] += 1

    stats["processing_time"] = time.time() - start_time

    print("=" * 60)
    print("Processing completed! Statistics:")
    print(f"  Total files: {stats['total_files']}")
    print(f"  Successfully processed: {stats['processed_files']}")
    print(f"  Skipped files: {stats['skipped_files']}")
    print(f"  Failed files: {stats['failed_files']}")
    print(f"  Total patches: {stats['total_patches']}")
    print(f"  Total clusters: {stats['total_clusters']}")
    print(f"  Processing time: {stats['processing_time']:.2f}s")

    if stats['total_patches'] > 0:
        # Each figure is averaged over the files it actually describes.
        # Unloadable files contribute no patches, and skipped WSIs
        # contribute no clusters, so neither may dilute the denominators.
        # n_loaded >= 1 whenever total_patches > 0.
        avg_patches_per_wsi = stats['total_patches'] / n_loaded
        avg_clusters_per_wsi = (stats['total_clusters'] / stats['processed_files']
                                if stats['processed_files'] > 0 else 0)
        compression_ratio = (stats['total_clusters'] / stats['clustered_patches']
                             if stats['clustered_patches'] > 0 else 0)

        print(f"  Average patches per WSI: {avg_patches_per_wsi:.1f}")
        print(f"  Average clusters per WSI: {avg_clusters_per_wsi:.1f}")
        print(f"  Compression ratio: {compression_ratio:.4f}")

    return stats


def copy_features_directly(input_dir: str,
                           output_dir: str,
                           file_pattern: str = "*.npy") -> Dict[str, Any]:
    """Copy 2-D feature files from ``input_dir`` to ``output_dir`` unchanged.

    Each matching file is loaded with ``np.load`` and checked to be 2-D. Valid
    arrays are re-saved with ``np.save`` under the same file name. Arrays that
    are not 2-D are counted as failures and not written.

    Args:
        input_dir: Directory containing the per-WSI feature files.
        output_dir: Destination directory, created if needed.
        file_pattern: ``Path.glob`` pattern selecting input files.

    Returns:
        Statistics dict with keys ``total_files``, ``processed_files``,
        ``failed_files``, ``total_patches`` and ``processing_time``.
        The same keys are returned when no file matches.

    Raises:
        FileNotFoundError: if ``input_dir`` does not exist.
    """
    input_path = Path(input_dir)
    output_path = Path(output_dir)

    if not input_path.exists():
        raise FileNotFoundError(f"Input directory does not exist: {input_dir}")

    output_path.mkdir(parents=True, exist_ok=True)

    # Sorted for a deterministic processing order independent of the filesystem.
    npy_files = sorted(input_path.glob(file_pattern))

    stats = {
        "total_files": len(npy_files),
        "processed_files": 0,
        "failed_files": 0,
        "total_patches": 0,
        "processing_time": 0
    }

    if not npy_files:
        print(f"WARNING: No matching files found in directory {input_dir}: {file_pattern}")
        return stats

    print(f"INFO: Starting to copy {len(npy_files)} WSI feature files (no clustering)")
    print(f"INFO: Input directory: {input_dir}")
    print(f"INFO: Output directory: {output_dir}")

    n_loaded = 0  # files whose features loaded as a valid 2-D array
    start_time = time.time()

    for i, npy_file in enumerate(npy_files, 1):
        wsi_name = npy_file.stem
        output_file = output_path / npy_file.name

        try:
            print(f"INFO: [{i}/{len(npy_files)}] Copying WSI {wsi_name}...")

            features = np.load(npy_file)

            if features.ndim != 2:
                print(f"ERROR: WSI {wsi_name}: Feature data dimension error, expected 2D, got {features.ndim}D")
                stats["failed_files"] += 1
                continue

            n_loaded += 1
            stats["total_patches"] += features.shape[0]

            np.save(output_file, features)
            stats["processed_files"] += 1

            print(f"INFO: WSI {wsi_name}: Original features copied to {output_file}")

        except Exception as e:
            print(f"ERROR: WSI {wsi_name}: Copy failed - {str(e)}")
            stats["failed_files"] += 1

    stats["processing_time"] = time.time() - start_time

    print("=" * 60)
    print("Copy completed! Statistics:")
    print(f"  Total files: {stats['total_files']}")
    print(f"  Successfully copied: {stats['processed_files']}")
    print(f"  Failed files: {stats['failed_files']}")
    print(f"  Total patches: {stats['total_patches']}")
    print(f"  Processing time: {stats['processing_time']:.2f}s")

    if stats['total_patches'] > 0:
        # Average only over files that actually contributed patches;
        # n_loaded >= 1 whenever total_patches > 0.
        avg_patches_per_wsi = stats['total_patches'] / n_loaded
        print(f"  Average patches per WSI: {avg_patches_per_wsi:.1f}")

    return stats



def main():
    """Command-line entry point: cluster (or copy) the features of one dataset.

    Reads features from ``datasets/<dataset>/raw_features0`` and writes to
    ``datasets/<dataset>/raw_features``. With ``--algorithm none`` the
    features are copied unchanged. Otherwise they are clustered with the
    handler returned by ``get_clustering_handler``.

    Returns:
        Process exit code: 0 on success; 1 on an invalid ``--n_clusters``,
        when no file was processed, or when an exception escaped processing.
    """
    parser = argparse.ArgumentParser(description="WSI feature clustering tool")

    parser.add_argument(
        '--algorithm',
        type=str,
        default="dbscan",
        help="Clustering algorithm (dbscan, kmeans); use 'none' to copy the "
             "original features without clustering"
    )

    parser.add_argument(
        '--eps-percentile',
        type=int,
        default=75,
        choices=range(1, 100),
        metavar='[1-99]',
        help='Percentile for estimating eps parameter in DBSCAN algorithm (1-99, default 75)'
    )

    parser.add_argument(
        '--min-samples',
        type=int,
        default=4,
        help='minPts parameter for DBSCAN algorithm, minimum neighbors for core points (default 4)'
    )

    parser.add_argument(
        '--dataset',
        type=str,
        default="camelyon16",
        help='Dataset to use'
    )

    parser.add_argument(
        '--n_clusters',
        type=int,
        default=None,
        help='Manually specify number of clusters, only for kmeans algorithm'
    )

    args = parser.parse_args()

    if args.n_clusters is not None and args.n_clusters <= 1:
        print("Error: --n_clusters must be a positive integer greater than 1")
        return 1

    # argparse can never yield None for --algorithm (it has a string default),
    # so an explicit 'none' (or an empty value) selects the copy-only path.
    algorithm = args.algorithm
    if not algorithm or algorithm.strip().lower() == "none":
        algorithm = None

    dataset = args.dataset
    random_state = 42
    min_patches = 5

    input_dir = f"datasets/{dataset}/raw_features0"
    output_dir = f"datasets/{dataset}/raw_features"

    try:
        if algorithm is None:
            print("INFO: No clustering algorithm specified, will directly copy original feature data")
            stats = copy_features_directly(
                input_dir=input_dir,
                output_dir=output_dir,
                file_pattern="*.npy"
            )

            if stats['processed_files'] == 0:
                print("Warning: No files were successfully copied")
                return 1

            print("Successfully completed WSI feature copying!")
            print(f"   Copied {stats['processed_files']}/{stats['total_files']} files")

        else:
            print(f"INFO: Using {algorithm} algorithm for clustering processing")

            if algorithm == 'dbscan' and args.n_clusters is not None:
                print(f"WARNING: {algorithm} algorithm does not support manually specified cluster numbers, --n_clusters parameter will be ignored")

            clustering_kwargs = {
                'random_state': random_state
            }

            if algorithm == 'dbscan':
                clustering_kwargs['eps_percentile'] = args.eps_percentile
                if args.min_samples is not None:
                    clustering_kwargs['min_samples'] = args.min_samples
                min_samples_info = (f"minPts={args.min_samples}" if args.min_samples is not None
                                    else "minPts=handler default")
                print(f"INFO: DBSCAN will automatically determine optimal cluster number ({min_samples_info}, eps percentile={args.eps_percentile})")
            elif algorithm == 'kmeans':
                if args.n_clusters is not None:
                    clustering_kwargs['fixed_n_clusters'] = args.n_clusters
                    print(f"INFO: {algorithm} will use fixed cluster number: {args.n_clusters}")
                else:
                    print(f"INFO: {algorithm} will automatically determine optimal cluster number (search range: 2-30)")

            cluster_handler = get_clustering_handler(algorithm, **clustering_kwargs)

            stats = process_directory(
                input_dir=input_dir,
                output_dir=output_dir,
                cluster_handler=cluster_handler,
                file_pattern="*.npy",
                min_patches_threshold=min_patches
            )

            if stats['processed_files'] == 0:
                print("Warning: No files were successfully processed")
                return 1

            print("Successfully completed WSI feature clustering processing!")
            print(f"   Processed {stats['processed_files']}/{stats['total_files']} files")

        return 0

    except Exception as e:
        print(f"Error occurred during processing: {str(e)}")
        return 1


if __name__ == "__main__":
    # sys.exit rather than the site-provided exit(): the latter is absent
    # under `python -S` and in some frozen/embedded interpreters.
    sys.exit(main())
