import numpy as np
import pandas as pd
from sklearn.datasets import make_classification, make_blobs
from sklearn.model_selection import train_test_split, GridSearchCV, KFold, StratifiedKFold, StratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from tqdm import tqdm
from sklearn.feature_selection import mutual_info_classif, f_classif
from skfeature.function.similarity_based import fisher_score, lap_score, reliefF, SPEC
from skfeature.function.statistical_based import gini_index, t_score
from skfeature.utility.construct_W import construct_W
import os
import scipy.io
from joblib import Parallel, delayed
from itertools import product
from Functions_gpu import Kernel_matrix, LG_sym, calc_differential_vec
from ELVES_gpu import Differential_method, Shared_space, Multiple_latent_variables
import cupy as cp
from cuml.svm import SVC as cuSVC
from cuml.model_selection import GridSearchCV as cuGridSearchCV
from cupyx.scipy.spatial.distance import pdist
from cuml.preprocessing import MinMaxScaler
import json
from ast import literal_eval
import gc   
from filelock import FileLock
from ManiFeSt_gpu import ManiFeSt_gpu



def calculate_sigma(X, percentile):
    """Return the given percentile of all pairwise Euclidean distances in X.

    The callers in this file pass the result as the heat-kernel width ``t``
    to ``construct_W``.

    Args:
        X: 2-D (n_samples, n_features) array, either NumPy or CuPy.
        percentile: Percentile in [0, 100] to take over the distances.

    Returns:
        The distance percentile as a Python float.
    """
    # ``pdist`` is CuPy's GPU implementation (cupyx.scipy.spatial.distance),
    # so put X on the device before calling it rather than moving it to host.
    X_dev = cp.asarray(X)
    distances = pdist(X_dev, metric='euclidean')
    return float(cp.percentile(distances, percentile))


# Create a dedicated CuPy device memory pool and install it as this process's
# allocator. process_iteration re-installs the same pool at the start of each
# task and calls mem_pool.free_all_blocks() when the task ends.
mem_pool = cp.cuda.MemoryPool()
cp.cuda.set_allocator(mem_pool.malloc)

# ================== Cache management module ==================
import hashlib
import pickle

# Directory holding the pickled feature-score vectors written by save_cache
# and read by load_cache. The "_Lap" suffix presumably ties it to the
# Laplacian-score selector. NOTE(review): point it elsewhere when switching
# selectors, because the cache keys only encode params/seeds/iteration/fold,
# not which selector produced the scores.
CACHE_DIR = "/Prostate/baseline/Cache_Lap"  # absolute path; adjust per machine
os.makedirs(CACHE_DIR, exist_ok=True)  # created at import time if missing

def get_train_cache_key(params, seeds, outer_iter, fold_idx):
    """Return the md5 hex digest naming one inner-fold score vector.

    The digest is computed over the text
    ``train_<params>_<seeds>_iter<outer_iter>_fold<fold_idx>``, with every
    field rendered through ``format()``.
    """
    raw_key = "train_{}_{}_iter{}_fold{}".format(params, seeds, outer_iter, fold_idx)
    digest = hashlib.md5(raw_key.encode())
    return digest.hexdigest()

def get_test_cache_key(params, seeds, outer_iter):
    """Return the md5 hex digest naming a full-training-split score vector.

    The hashed text is ``test_<params>_<seeds>_iter<outer_iter>``; unlike the
    per-fold key it carries no fold index.
    """
    label = "_".join(["test", format(params), format(seeds), "iter" + format(outer_iter)])
    return hashlib.md5(label.encode()).hexdigest()

def save_cache(scores, cache_key):
    """Persist a score vector to ``CACHE_DIR/<cache_key>.pkl``.

    The array is copied to host memory with ``cp.asnumpy`` and pickled. The
    pickle is first written to a per-process temporary file and then moved
    into place with ``os.replace``, so an interrupted run cannot leave a
    truncated ``.pkl`` that later loads would fail on. The ``.lock`` file is
    held while writing, matching ``load_cache`` which holds it while reading.

    Args:
        scores: Array of per-feature scores (CuPy or NumPy).
        cache_key: File stem, e.g. from ``get_train_cache_key``.
    """
    cache_path = os.path.join(CACHE_DIR, f"{cache_key}.pkl")
    tmp_path = f"{cache_path}.tmp.{os.getpid()}"
    # Do the device->host copy before taking the lock to keep it short.
    host_scores = cp.asnumpy(scores)
    with FileLock(cache_path + ".lock"):
        try:
            with open(tmp_path, "wb") as f:
                pickle.dump(host_scores, f)
            os.replace(tmp_path, cache_path)
        finally:
            # Only left behind if dump/replace failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

def load_cache(cache_key):
    """Load a cached score vector as a CuPy array, or return None on a miss.

    A missing file is a miss. An unreadable or truncated pickle (for example
    one left by an interrupted write) is also treated as a miss, with a
    printed warning. The caller then recomputes and overwrites the entry
    instead of failing on it in every later run.

    Args:
        cache_key: File stem, e.g. from ``get_train_cache_key``.

    Returns:
        ``cp.ndarray`` with the cached scores, or ``None``.
    """
    cache_path = os.path.join(CACHE_DIR, f"{cache_key}.pkl")
    if not os.path.exists(cache_path):
        return None
    # NOTE: pickle.load can execute arbitrary code from the file; only point
    # CACHE_DIR at a directory written by this job.
    with FileLock(cache_path + ".lock"):
        try:
            with open(cache_path, "rb") as f:
                host_scores = pickle.load(f)
        except (OSError, EOFError, pickle.UnpicklingError, AttributeError,
                ImportError, IndexError, ValueError) as e:
            print(f"Ignoring unreadable cache file {cache_path}: {e}")
            return None
    return cp.array(host_scores)

def task_to_hashable(task):
    """Reduce a task tuple to a hashable ``(n_feat, seeds, outer_iter)`` key.

    Only the first three task fields are used. The seeds dict is frozen into a
    tuple of ``(name, value)`` pairs ordered by name.
    """
    n_feat, seeds, outer_iter, *_ = task
    frozen_seeds = tuple((name, seeds[name]) for name in sorted(seeds))
    return n_feat, frozen_seeds, outer_iter




def evaluate_params(params, X_train, y_train, n_feat, seeds, outer_iter, inner_cv):
    """Score one feature-selector hyper-parameter setting by inner cross-validation.

    For every fold produced by ``inner_cv``:
      - feature scores are computed on the fold's training rows, or reloaded
        from the on-disk cache;
      - the ``n_feat`` highest-scoring features are kept;
      - an RBF ``cuSVC`` (C=1.0, gamma='scale') is fit on them and scored on
        the fold's validation rows.

    Args:
        params: Tuple of selector hyper-parameters. The active Laplacian-score
            selector reads ``params[0]`` as the distance percentile passed to
            ``calculate_sigma``.
        X_train: CuPy feature matrix of the outer training split.
        y_train: CuPy label vector aligned with ``X_train``.
        n_feat: Number of top-scoring features to keep.
        seeds: Dict mapping seed names to values; part of the cache key.
        outer_iter: Outer-loop index; part of the cache key.
        inner_cv: Splitter exposing a scikit-learn style ``split(X, y)``.

    Returns:
        The mean validation accuracy over the inner folds. Returns 0.0 if any
        step raised (the error and traceback are printed), so a failed setting
        ranks as low as possible in the caller's argmax.
    """
    try:
        # Keep seed names paired with their values. Sorting bare values would
        # map e.g. {seed1: 1, seed2: 2} and {seed1: 2, seed2: 1} to the same
        # key even though they produce different splits. This changes the key
        # text, so entries written under the old format are not reused.
        seeds_key = tuple(sorted(seeds.items()))

        y_train_cpu = y_train.get()
        # scikit-learn's K-fold splitters only read the row count from X, so a
        # host placeholder avoids copying the whole matrix off the GPU.
        # NOTE(review): assumes inner_cv does not inspect X values.
        X_placeholder = np.zeros(X_train.shape[0])

        fold_scores = []
        for fold_idx, (train_idx, val_idx) in enumerate(inner_cv.split(X_placeholder, y_train_cpu)):
            cache_key = get_train_cache_key(params, seeds_key, outer_iter, fold_idx)

            X_tr = X_train[train_idx]
            y_tr = y_train[train_idx]

            scores = load_cache(cache_key)
            if scores is None:
                # feature selection: manifest
                # scores = ManiFeSt_gpu(X_tr, y_tr, *params)

                X_tr_cpu = X_tr.get()
                y_tr_cpu = y_tr.get()  # only used by the commented-out selectors

                # feature selection: ReliefF
                # scores = reliefF.reliefF(X_tr_cpu, y_tr_cpu, k=params[0])

                # feature selection: IG
                # scores = mutual_info_classif(X_tr_cpu, y_tr_cpu, n_neighbors=params[0])

                # feature selection: Laplacian Score
                sigma = calculate_sigma(X_tr_cpu, params[0])
                W = construct_W(X_tr_cpu, weight_mode='heat_kernel', metric='euclidean', t=float(sigma))
                # Negate so the tail of argsort below (largest values) picks
                # the features with the smallest Laplacian scores.
                scores = -lap_score.lap_score(X_tr_cpu, W=W)

                scores = cp.asarray(scores)
                save_cache(scores, cache_key)
                del X_tr_cpu, y_tr_cpu, W

            selected = cp.argsort(scores)[-n_feat:]

            svm = cuSVC(kernel='rbf', C=1.0, gamma='scale')
            svm.fit(X_tr[:, selected], y_tr)
            acc = svm.score(X_train[val_idx][:, selected], y_train[val_idx])
            fold_scores.append(acc)

            del X_tr, y_tr, svm, scores, selected
            cp.get_default_memory_pool().free_all_blocks()

        return np.mean(fold_scores)
    except Exception as e:
        # Best-effort: a failing setting is reported and scored as 0.0 so the
        # remaining settings of this grid can still be compared.
        import traceback
        print(f"Parameter evaluation failed: {str(e)}")
        traceback.print_exc()
        return 0.0

def process_iteration(n_feat, seeds, outer_iter, baseline_combinations):
    """Run one outer-loop repetition for a given feature budget.

    Steps:
      1. Make a stratified 90/10 train/test split of the global
         ``X_gpu``/``y_gpu``, seeded by ``seeds['seed1'] + outer_iter``.
      2. Pick the setting in ``baseline_combinations`` with the best
         inner-CV accuracy (``evaluate_params`` over 16 joblib workers).
      3. Recompute (or reload) feature scores on the whole training split
         with that setting and keep the top ``n_feat`` features.
      4. Tune an RBF ``cuSVC`` over the global ``svm_param_grid`` with 5-fold
         CV and score it on the held-out split.

    Relies on the module globals ``X_gpu``, ``y_gpu``, ``n_inner_folds`` and
    ``svm_param_grid``, which are created in the ``__main__`` section.

    Returns:
        A flat dict forming one results-CSV row, or None if any step raised
        (the traceback is printed).
    """
    try:
        cp.cuda.Device(0).use()
        cp.cuda.set_allocator(mem_pool.malloc)

        # (1) outer split. The splitter only reads the row count from X, so a
        # host placeholder is enough; no GPU allocation needed.
        y_all_cpu = y_gpu.get()
        splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.1,
                                          random_state=seeds['seed1'] + outer_iter)
        train_idx, test_idx = next(splitter.split(np.zeros(len(y_all_cpu)), y_all_cpu))

        X_train = X_gpu[train_idx]
        y_train = y_gpu[train_idx]
        X_test = X_gpu[test_idx]
        y_test = y_gpu[test_idx]

        # (2) inner CV over the selector hyper-parameters.
        inner_cv = StratifiedKFold(n_splits=n_inner_folds, shuffle=True, random_state=seeds['seed2'])
        with Parallel(n_jobs=16, verbose=0) as parallel:
            scores = parallel(
                delayed(evaluate_params)(params, X_train, y_train, n_feat, seeds, outer_iter, inner_cv)
                for params in baseline_combinations
            )

        # evaluate_params returns 0.0 on failure; ties resolve to the first
        # setting in baseline_combinations.
        best_idx = np.argmax(scores)
        best_score = scores[best_idx]
        best_params = baseline_combinations[best_idx]

        # (3) scores on the full training split. Only seed1 and outer_iter
        # determine X_train, so only they go into the key.
        test_cache_key = get_test_cache_key(best_params, seeds['seed1'], outer_iter)
        scores_full = load_cache(test_cache_key)
        if scores_full is None:
            # feature selection: manifest
            # scores_full = ManiFeSt_gpu(X_train, y_train, *best_params)

            X_train_cpu = X_train.get()
            y_train_cpu = y_train.get()  # only used by the commented-out selectors
            # feature selection: ReliefF
            # scores_full = reliefF.reliefF(X_train_cpu, y_train_cpu, k=best_params[0])

            # feature selection: IG
            # scores_full = mutual_info_classif(X_train_cpu, y_train_cpu, n_neighbors=best_params[0])

            # feature selection: Laplacian Score
            sigma = calculate_sigma(X_train_cpu, best_params[0])
            W = construct_W(X_train_cpu, weight_mode='heat_kernel', metric='euclidean', t=float(sigma))
            # Negate so the argsort tail selects the smallest Laplacian scores.
            scores_full = -lap_score.lap_score(X_train_cpu, W=W)

            scores_full = cp.asarray(scores_full)
            save_cache(scores_full, test_cache_key)
            del X_train_cpu, y_train_cpu, W
        selected = cp.argsort(scores_full)[-n_feat:]

        # (4) SVM hyper-parameter search on the selected features.
        y_train_cpu = cp.asnumpy(y_train).astype(np.int32)
        cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=seeds['seed3'])    # random_state=42
        # Folds are precomputed on the host. StratifiedKFold ignores X values,
        # so a placeholder avoids copying the selected columns off the GPU.
        cv_splits = list(cv.split(np.zeros(len(y_train_cpu)), y_train_cpu))

        grid_search = cuGridSearchCV(
            cuSVC(kernel='rbf'),
            svm_param_grid,
            cv=cv_splits
        )
        grid_search.fit(X_train[:, selected], y_train)
        test_acc = grid_search.score(X_test[:, selected], y_test)
        test_error = 1.0 - test_acc if not cp.isnan(test_acc) else cp.nan

        # Unwrap 0-d CuPy arrays so the dict stringifies to plain numbers.
        best_svm_params = {
            k: v.item() if isinstance(v, cp.ndarray) else v
            for k, v in grid_search.best_params_.items()
        }

        del grid_search, X_train, y_train, X_test, y_test, scores_full, selected
        cp.get_default_memory_pool().free_all_blocks()

        return {
            'n_features': n_feat,
            'best_params': str(best_params),
            'seeds': json.dumps(seeds),
            'outer_iter': outer_iter,
            'val_accuracy': float(best_score),
            'test_accuracy': float(test_acc),
            'test_error': float(test_error),
            'best_svm_params': str(best_svm_params)
        }

    except Exception as e:
        print(f"Iteration Exception: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
    finally:
        # Collect cyclic garbage first so any CuPy buffers it held go back to
        # the pools, then release cached blocks from the custom pool and from
        # CuPy's default device/pinned pools.
        gc.collect()
        mem_pool.free_all_blocks()
        cp.get_default_memory_pool().free_all_blocks()
        cp.get_default_pinned_memory_pool().free_all_blocks()
        # cp.cuda.Stream.null.synchronize()

if __name__ == "__main__":
    # ---- Data: Prostate_GE; labels are binarised as (Y < 2) -> 1, else 0 ----
    mat_data = scipy.io.loadmat('Prostate_GE.mat')
    X_gpu = cp.asarray(mat_data['X'], dtype=cp.float32)
    y_gpu = cp.asarray(mat_data['Y'].ravel().astype(np.int32))
    y_gpu = (y_gpu < 2).astype(cp.int32)

    # ---- Experiment grid ----
    # NOTE: X_gpu, y_gpu, svm_param_grid and n_inner_folds are read as module
    # globals by process_iteration, so they must keep these names.
    param_config = {
        # Feature budgets: 8, 10, 20, then 50..200 step 50, then 300..500 step 100.
        'n_features': [8, 10, 20] + list(range(50, 201, 50)) + list(range(300, 501, 100)),
        'elves_params': {
            # Parameter grid for IG and ReliefF (k neighbours). Each value is a
            # 1-tuple because evaluate_params reads params[0].
            # 'k': [(x,) for x in [1,3,5,10,15,20,30,50,100]]
            # Parameter grid of Laplacian Score (distance percentile for sigma).
            'percentile': [(x,) for x in [1, 5, 10, 30, 50, 70, 90, 95, 99]],
        },
        # 'elves_params': {
        #         'kernel_scale_factor': [1.0],
        #         'use_spsd': [True],
        #         'percentile': [5, 10, 30, 50, 70, 90, 95],    # Parameter grid of ManiFeSt
        # },
        'seeds': {
            'seed1': [42],  # outer train/test split (offset by outer_iter)
            'seed2': [42],  # inner CV folds for selector tuning
            'seed3': [42],  # CV folds for SVM tuning
        },
    }
    svm_param_grid = {
        'C': [2 ** i for i in [-5, -2, 1, 4, 7, 10, 13]],
        'gamma': [2 ** i for i in [-15, -12, -9, -6, -3, 0, 3]]
    }
    n_outer_iter = 30  # outer loop times (30 independent divisions)
    n_inner_folds = 10  # inner cross validation folds

    result_path = '/Prostate/baseline/results.csv'
    checkpoint_path = '/Prostate/baseline/checkpoint.json'

    os.makedirs(os.path.dirname(result_path), exist_ok=True)

    # ---- Task list: every (n_feat, seed combination, outer iteration) ----
    elves_combos = param_config['elves_params']['percentile']
    seed_grid = [dict(zip(param_config['seeds'].keys(), vals))
                 for vals in product(*param_config['seeds'].values())]
    tasks = []
    for n_feat in param_config['n_features']:
        for seeds in seed_grid:
            # Named outer_iter (not `iter`) to avoid shadowing the builtin.
            for outer_iter in range(n_outer_iter):
                tasks.append((n_feat, seeds, outer_iter, elves_combos))

    # ---- Resume: skip tasks already recorded in the checkpoint ----
    processed = set()
    if os.path.exists(checkpoint_path):
        try:
            with open(checkpoint_path, 'r') as f:
                processed_data = json.load(f)
                processed = {tuple(literal_eval(h)) for h in processed_data}
                print(f"{len(processed)} checkpoints loaded")
        except Exception as e:
            print(f"Loading checkpoint failed: {str(e)}，will restart")
            processed = set()

    pending_tasks = [task for task in tasks if task_to_hashable(task) not in processed]
    print(f"Total number of tasks: {len(tasks)} | Remaining tasks: {len(pending_tasks)}")

    for task in tqdm(pending_tasks, desc="Processing progress", total=len(pending_tasks)):
        task_hash = task_to_hashable(task)

        result = process_iteration(*task)
        if not result:
            continue

        # Bind the temp path before the try so the cleanup below can never hit
        # an unbound name (e.g. when the CSV append fails on the first task).
        temp_checkpoint = checkpoint_path + ".tmp." + str(os.getpid())
        try:
            # Append the row; write the header only when the file is new/empty.
            with FileLock(result_path + ".lock"):
                write_header = not os.path.exists(result_path) or os.stat(result_path).st_size == 0
                with open(result_path, 'a') as f:
                    df = pd.DataFrame([result])
                    df.to_csv(f, header=write_header, index=False)

            # Merge this task into the on-disk checkpoint and swap it in with
            # os.replace so a crash never leaves a half-written checkpoint.
            with FileLock(checkpoint_path + ".lock"):
                current_processed = set()
                if os.path.exists(checkpoint_path):
                    with open(checkpoint_path, 'r') as f:
                        current_processed = {tuple(literal_eval(h)) for h in json.load(f)}
                current_processed.add(task_hash)

                with open(temp_checkpoint, 'w') as f:
                    json.dump([str(h) for h in current_processed], f)
                os.replace(temp_checkpoint, checkpoint_path)

            processed = current_processed

        except Exception as e:
            print(f"Save failed: {task_hash}，error: {str(e)}")
        finally:
            # Only present if the write succeeded but os.replace did not run.
            if os.path.exists(temp_checkpoint):
                os.remove(temp_checkpoint)

    # Results Analysis
    try:
        final_df = pd.read_csv(result_path)
        summary = (
            final_df.groupby(['n_features', 'seeds'])
            .agg({
                'val_accuracy': ['mean', 'std'],
                'test_accuracy': ['mean', 'std'],
                'test_error': ['mean', 'std']
            })
            .reset_index()
        )
        print("\nThe average validation accuracy and test accuracy under each combination:")
        print(summary.to_string(index=False))
        summary.to_csv('/Prostate/baseline/analysis.csv', index=False)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        # Missing file, or one created but never written to.
        print("No result file found, please check the run log")

