import numpy as np
import pandas as pd
from sklearn.datasets import make_classification, make_blobs
from sklearn.model_selection import train_test_split, GridSearchCV, KFold, StratifiedKFold, StratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from tqdm import tqdm
# from sklearn.feature_selection import mutual_info_classif, f_classif
# from skfeature.function.similarity_based import fisher_score, lap_score, reliefF, SPEC
# from skfeature.function.statistical_based import gini_index, t_score
# from skfeature.utility.construct_W import construct_W
import os
import scipy.io
from joblib import Parallel, delayed
from itertools import product
from Functions_gpu import Kernel_matrix, LG_sym, calc_differential_vec
from ELVES_gpu import Differential_method, Shared_space, Multiple_latent_variables
import cupy as cp
from cuml.svm import SVC as cuSVC
from cuml.model_selection import GridSearchCV as cuGridSearchCV
from cupyx.scipy.spatial.distance import pdist
from cuml.preprocessing import MinMaxScaler
import json
from ast import literal_eval
import gc   

# ================== Cache management module ==================
import hashlib
import pickle
from filelock import FileLock

# Directory holding the pickled ELVES score caches read and written by
# save_cache / load_cache. It is created when this module is imported, so the
# cache helpers never need to check for it.
CACHE_DIR = "/Prostate/Cache"
os.makedirs(name=CACHE_DIR, exist_ok=True)

def get_train_cache_key(params, seeds, outer_iter, fold_idx):
    """Return the MD5 hex digest naming a training-phase (per-fold) cache entry.

    Every argument is rendered with ``format()`` into a ``train_...`` string,
    so distinct params/seeds/iteration/fold combinations map to distinct keys.
    """
    raw_key = "train_{}_{}_iter{}_fold{}".format(params, seeds, outer_iter, fold_idx)
    digest = hashlib.md5(raw_key.encode("utf-8"))
    return digest.hexdigest()

def get_test_cache_key(params, seeds, outer_iter):
    """Return the MD5 hex digest naming a full-training-split cache entry.

    The key text is ``test_<params>_<seeds>_iter<outer_iter>`` with each part
    rendered through ``format()``.
    """
    raw_key = "test_{}_{}_iter{}".format(params, seeds, outer_iter)
    digest = hashlib.md5(raw_key.encode("utf-8"))
    return digest.hexdigest()

def save_cache(scores, cache_key):
    """Persist a score array to ``CACHE_DIR/<cache_key>.pkl``.

    The array is copied to host memory with ``cp.asnumpy`` before pickling.
    The data is first written to a temporary file and then moved into place
    with ``os.replace``. A process killed mid-write therefore never leaves a
    truncated ``.pkl`` behind for later runs to trip over. A ``FileLock`` on
    ``<path>.lock`` serializes writers and readers across processes.

    Args:
        scores: array-like accepted by ``cp.asnumpy`` (cupy or numpy).
        cache_key: file stem, e.g. from ``get_test_cache_key``.
    """
    cache_path = os.path.join(CACHE_DIR, f"{cache_key}.pkl")
    host_scores = cp.asnumpy(scores)  # GPU -> host copy before taking the lock
    tmp_path = f"{cache_path}.{os.getpid()}.tmp"
    with FileLock(cache_path + ".lock"):
        with open(tmp_path, "wb") as f:
            pickle.dump(host_scores, f)
        # Atomic rename: readers see either the old file or the complete new one.
        os.replace(tmp_path, cache_path)

def load_cache(cache_key):
    """Load a cached score array written by ``save_cache``.

    Returns:
        A cupy array with the cached scores. Returns ``None`` when the entry
        does not exist or cannot be unpickled. Treating a corrupt entry as a
        miss lets the caller recompute the scores and overwrite the file,
        instead of failing on that key in every future run.

    Note: ``pickle.load`` is only safe because these files are produced
    locally by ``save_cache``; never point ``CACHE_DIR`` at untrusted data.
    """
    cache_path = os.path.join(CACHE_DIR, f"{cache_key}.pkl")
    if not os.path.exists(cache_path):
        return None
    with FileLock(cache_path + ".lock"):
        try:
            with open(cache_path, "rb") as f:
                host_scores = pickle.load(f)
        except (OSError, EOFError, ValueError, pickle.UnpicklingError) as exc:
            print(f"Ignoring unreadable cache file {cache_path}: {exc}")
            return None
    return cp.array(host_scores)


def task_to_hashable(task):
    """Turn a ``(params, seeds, n_feat, outer_iter)`` task into a hashable tuple.

    Any cupy array inside ``params`` is converted with ``float()``. The seeds
    dict is reduced to ``(seed1, seed3)``. This is the same shape the
    checkpoint loader rebuilds from JSON, so the two can be compared directly.
    """
    raw_params, seed_map, n_feat, outer_iter = task
    frozen_params = tuple(
        float(value) if isinstance(value, cp.ndarray) else value
        for value in raw_params
    )
    frozen_seeds = (seed_map['seed1'], seed_map['seed3'])
    return frozen_params, frozen_seeds, n_feat, outer_iter

def save_checkpoint(results, processed_tasks, result_path, checkpoint_path):
    """Append new result rows to a CSV and persist the processed-task set.

    Rows whose ``(n_features, params, seeds, outer_iter)`` identity already
    appears in ``result_path`` are skipped, so re-running a batch after a
    crash does not duplicate rows.

    Args:
        results: list of result dicts (as returned by ``process_iteration``).
            Keys outside the CSV schema are dropped and missing ones become
            NaN. The dicts themselves are not modified. May be empty.
        processed_tasks: iterable of hashable task tuples
            ``(params, seeds, n_features, outer_iter)``, where params and
            seeds are sequences.
        result_path: CSV file accumulating results (created with a header
            if absent).
        checkpoint_path: JSON file listing every processed task. It is
            written to a temporary file and then atomically renamed. An
            interrupted or failed write therefore never corrupts the
            previous checkpoint.
    """
    columns = ['n_features', 'params', 'seeds', 'outer_iter',
               'test_accuracy', 'test_error', 'best_svm_params']
    key_cols = ['n_features', 'params', 'seeds', 'outer_iter']

    def _composite_key(df):
        # Stringify the identity columns: values read back from the CSV are
        # already strings, while fresh rows hold tuples/ints.
        if df.empty:
            # agg(..., axis=1) on an empty frame returns a DataFrame rather
            # than a Series, which cannot be used as a key; short-circuit.
            return pd.Series([], index=df.index, dtype=object)
        return df[key_cols].astype(str).agg('|'.join, axis=1)

    # columns= fills absent keys with NaN without touching the input dicts.
    new_df = pd.DataFrame(results, columns=columns)

    file_exists = os.path.exists(result_path)
    if file_exists and not new_df.empty:
        existing_keys = set(_composite_key(pd.read_csv(result_path)))
        new_df = new_df.loc[~_composite_key(new_df).isin(existing_keys)]

    new_df.to_csv(result_path, mode='a' if file_exists else 'w',
                  header=not file_exists, index=False)

    # Build the JSON payload before touching the file, then swap it in
    # atomically so a failure leaves the old checkpoint intact.
    serializable_tasks = [
        [list(task[0]), list(task[1]), task[2], task[3]]
        for task in processed_tasks
    ]
    tmp_path = checkpoint_path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(serializable_tasks, f)
    os.replace(tmp_path, checkpoint_path)



def ELVES(X1, X2, N=2, K=200, k=500, k0=400, w1=0.5):
    """Compute a per-feature score contrasting two groups of samples.

    ``Multiple_latent_variables`` is evaluated in both directions (X1 against
    X2, then X2 against X1). Within each direction the last column of the
    returned matrix is squared, which is the intra-category aggregation.
    The two directions are then merged element-wise with ``np.maximum``,
    the cross-category aggregation.

    Args:
        X1, X2: row-subsets of one feature matrix (the two classes).
        N, K, k, k0: forwarded unchanged to ``Multiple_latent_variables``.
        w1: accepted so a full ``(N, K, k, k0, w1)`` tuple can be passed
            positionally, but it is not used by this function.

    Returns:
        Element-wise maximum of the two squared last columns. This is
        presumably one score per feature column of ``X1``/``X2``. The return
        value of ``Multiple_latent_variables`` is not visible here, so
        confirm that assumption there.
    """
    shared_kwargs = {"N": N, "K": K, "k": k, "k0": k0}
    forward = Multiple_latent_variables(X1, X2, **shared_kwargs)
    backward = Multiple_latent_variables(X2, X1, **shared_kwargs)
    return np.maximum(forward[:, -1] ** 2, backward[:, -1] ** 2)


# ==================Result processing function ==================
def analyze_results(df, min_accuracy=0.95):
    """Summarize repeated outer-loop results per configuration.

    Rows are grouped by ``n_features``, ``params`` and ``seeds``. Each group
    gets the mean and standard deviation of ``test_accuracy`` and
    ``test_error``, in flattened columns such as ``test_accuracy_mean``.

    Args:
        df: results table with the columns above.
        min_accuracy: threshold on the mean test accuracy. Accuracy is a
            fraction in [0, 1] (``test_error = 1 - test_accuracy``). The
            former threshold of 95 could therefore never be met.

    Returns:
        ``(grouped, qualified)``, where ``qualified`` holds the rows of
        ``grouped`` whose ``test_accuracy_mean`` exceeds ``min_accuracy``.
    """
    grouped = (
        df.groupby(['n_features', 'params', 'seeds'])
        .agg({
            'test_accuracy': ['mean', 'std'],
            'test_error': ['mean', 'std'],
        })
        .reset_index()
    )
    # ('n_features', '') -> 'n_features'; ('test_accuracy', 'mean') -> 'test_accuracy_mean'
    grouped.columns = ['_'.join(col).strip('_') for col in grouped.columns.values]

    qualified = grouped[grouped['test_accuracy_mean'] > min_accuracy]
    return grouped, qualified


# ================== Nested Cross Validation ==================
def process_iteration(params, seeds, n_features, outer_iter):
    """Evaluate one ``(params, seeds, n_features, outer_iter)`` task on the GPU.

    Steps:
      1. Stratified 90/10 train/test split of the global ``X_gpu``/``y_gpu``,
         seeded with ``seeds['seed1'] + outer_iter``.
      2. ELVES feature scores on the training split (class 0 vs class 1).
         They are reused from, or written to, the on-disk cache keyed by
         ``params``, ``seeds['seed1']`` and ``outer_iter``.
      3. Keep the ``n_features`` highest-scoring feature columns.
      4. Grid-search an RBF cuML ``SVC`` over the global ``svm_param_grid``.
         The search uses stratified 5-fold CV (seeded with
         ``seeds['seed3']``) on the training split. The fitted search is then
         scored on the test split.

    Relies on module-level globals ``X_gpu``, ``y_gpu`` and ``svm_param_grid``
    defined in the ``__main__`` section.

    Args:
        params: positional ELVES parameters ``(N, K, k, k0, w1)``.
        seeds: dict with ``'seed1'`` (outer split) and ``'seed3'`` (inner CV).
        n_features: number of top-ranked features to keep.
        outer_iter: outer repetition index; offsets the split seed.

    Returns:
        dict with the task identifiers plus ``test_accuracy``, ``test_error``
        (derived as ``1 - test_accuracy``) and stringified ``best_svm_params``.
        Errors in steps 2-4 are printed with a traceback rather than raised.
        Metrics that were not reached stay NaN.
    """
    import traceback

    global X_gpu, y_gpu  # preloaded on the GPU by the __main__ section
    mem_pool = cp.get_default_memory_pool()
    pinned_pool = cp.get_default_pinned_memory_pool()

    mem_pool.free_all_blocks()
    cp.cuda.Device(0).use()

    # sklearn splitters only read the sample count from X, so a host-side
    # placeholder is enough (no GPU allocation needed).
    y_cpu = cp.asnumpy(y_gpu)
    train_idx, test_idx = next(
        StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=seeds['seed1'] + outer_iter)
        .split(np.zeros(len(y_cpu)), y_cpu)
    )

    X_train_gpu = X_gpu[train_idx]
    y_train_gpu = y_gpu[train_idx]
    X_test_gpu = X_gpu[test_idx]
    y_test_gpu = y_gpu[test_idx]

    # ===== Test set evaluation =====
    test_acc = cp.nan
    best_svm_params = None
    cv_splits = None
    selected_full = None
    try:
        test_cache_key = get_test_cache_key(params, seeds['seed1'], outer_iter)

        scores_full = load_cache(test_cache_key)
        if scores_full is None:
            X1_full = X_train_gpu[y_train_gpu == 0]
            X2_full = X_train_gpu[y_train_gpu == 1]
            scores_full = ELVES(X1_full, X2_full, *params)
            save_cache(scores_full, test_cache_key)
            del X1_full, X2_full
            mem_pool.free_all_blocks()
            gc.collect()

        # argsort is ascending, so the top-n scores are the last n entries.
        # Use an explicit start index: ``[-0:]`` would keep EVERY feature
        # when n_features == 0.
        order = cp.argsort(scores_full)
        selected_full = order[max(order.size - n_features, 0):]

        y_train_cpu = cp.asnumpy(y_train_gpu).astype(np.int32)
        cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=seeds['seed3'])
        # StratifiedKFold ignores X's values; don't copy features to the host.
        cv_splits = list(cv.split(np.zeros(len(y_train_cpu)), y_train_cpu))

        del scores_full, order, y_train_cpu
        mem_pool.free_all_blocks()
        pinned_pool.free_all_blocks()

        # SVM parameter tuning
        grid_search = cuGridSearchCV(
            cuSVC(kernel='rbf'),
            svm_param_grid,
            cv=cv_splits,
        )
        grid_search.fit(X_train_gpu[:, selected_full], y_train_gpu)
        test_acc = grid_search.score(X_test_gpu[:, selected_full], y_test_gpu)

        best_svm_params = {
            name: value.item() if isinstance(value, cp.ndarray) else value
            for name, value in grid_search.best_params_.items()
        }

        del grid_search
        mem_pool.free_all_blocks()
        pinned_pool.free_all_blocks()

    except Exception as e:
        print(f"Testing phase error: {str(e)}")
        traceback.print_exc()

    # Derive the error from the accuracy so the two columns always agree,
    # even if a step after scoring (e.g. reading best_params_) failed.
    test_error = 1.0 - test_acc if not cp.isnan(test_acc) else cp.nan

    # Every name below is bound on all paths (initialized before the try).
    del X_train_gpu, X_test_gpu, y_train_gpu, y_test_gpu, cv_splits, selected_full
    mem_pool.free_all_blocks()
    pinned_pool.free_all_blocks()
    gc.collect()

    return {
        'n_features': n_features,
        'params': params,
        'seeds': str(seeds),
        'outer_iter': outer_iter,
        'test_accuracy': test_acc,
        'test_error': test_error,
        'best_svm_params': str(best_svm_params)
    }


# ================== Performing parallel computations ==================
if __name__ == "__main__":
    import sys  # sys.exit works even without the site-provided exit()

    # ---- Load the dataset and move it to the GPU once ----
    mat_data = scipy.io.loadmat('Prostate_GE.mat')
    X_gpu = cp.asarray(mat_data['X'], dtype=cp.float32)
    y_gpu = cp.asarray(mat_data['Y'].ravel(), dtype=cp.int32)
    # Binarize the labels: original classes < 2 map to 1, all others to 0.
    y_gpu = (y_gpu < 2).astype(cp.int32)

    # experimental parameters
    param_configs = [
        {
            'n_features': [120],
            'elves_params': {
                'N': [5], 'K': [500], 'k': [5800], 'k0': [5700], 'w1': [0.5]
            },
            'seeds': {
                'seed1': [56],
                'seed3': [42]
            }
        },
    ]

    # RBF-SVM search space; read as a global by process_iteration.
    svm_param_grid = {
        'C': [2 ** i for i in [-5, -2, 1, 4, 7, 10, 13]],
        'gamma': [2 ** i for i in [-15, -12, -9, -6, -3, 0, 3]]
    }
    n_outer_iter = 30  # outer loop times (30 independent divisions)
    # NOTE(review): unused -- process_iteration hard-codes a 5-fold inner CV.
    n_inner_folds = 10

    # Generate every (params, seeds, n_features, outer_iter) task.
    # The ELVES parameter order follows the dict order N, K, k, k0, w1,
    # which is how ELVES receives them positionally.
    tasks = []
    for config in param_configs:
        elves_param_grid = config['elves_params']
        seeds_grid = config['seeds']
        elves_combinations = list(product(*[
            [float(p) if isinstance(p, cp.ndarray) else p for p in values]
            for values in elves_param_grid.values()
        ]))
        seeds_combinations = [
            {'seed1': s1, 'seed3': s3}
            for s1, s3 in product(seeds_grid['seed1'], seeds_grid['seed3'])
        ]
        for n_feat in config['n_features']:
            for params in elves_combinations:
                for seeds in seeds_combinations:
                    for outer_iter in range(n_outer_iter):
                        tasks.append((tuple(params), seeds, n_feat, outer_iter))

    cp.cuda.Device(0).use()

    # ================== Breakpoint resume logic ==================
    result_path = '/Prostate/partial_results.csv'
    checkpoint_path = '/Prostate/processed_tasks.json'
    # Rebuild the processed-task set in the shape produced by task_to_hashable.
    processed_tasks = set()
    if os.path.exists(checkpoint_path):
        with open(checkpoint_path, 'r') as f:
            processed_tasks = {
                (tuple(task[0]), tuple(task[1]), task[2], task[3])
                for task in json.load(f)
            }
    remaining_tasks = [
        task for task in tasks
        if task_to_hashable(task) not in processed_tasks
    ]
    print(f"Number of remaining tasks：{len(remaining_tasks)}/{len(tasks)}")

    # ================== Batch Processing Logic ==================
    batch_size = 30
    for i in range(0, len(remaining_tasks), batch_size):
        batch_tasks = remaining_tasks[i:i + batch_size]
        current_batch_results = []
        try:
            with Parallel(n_jobs=4, prefer="processes", verbose=10) as parallel:
                batch_results = parallel(
                    delayed(process_iteration)(params, seeds, n_feat, outer_iter)
                    for params, seeds, n_feat, outer_iter in tqdm(batch_tasks)
                )
                current_batch_results.extend(batch_results)
            processed_tasks.update(task_to_hashable(task) for task in batch_tasks)

            save_checkpoint(
                current_batch_results,
                processed_tasks,
                result_path,
                checkpoint_path
            )
        except KeyboardInterrupt:
            print("Server or user interruption, saving processed tasks...")
            # Results exist only if the interrupt arrived after the batch
            # finished. Otherwise the last saved checkpoint is already current.
            if current_batch_results:
                save_checkpoint(
                    current_batch_results,
                    processed_tasks,
                    result_path,
                    checkpoint_path
                )
            sys.exit()
        finally:
            cp.get_default_memory_pool().free_all_blocks()
            cp.get_default_pinned_memory_pool().free_all_blocks()
            gc.collect()

    # ================== Final result ==================
    if os.path.exists(result_path):
        final_df = pd.read_csv(result_path)
        grouped_df, qualified_df = analyze_results(final_df)
        final_df.to_csv('/Prostate/full_results.csv', index=False)
        grouped_df.to_csv('/Prostate/grouped_results.csv', index=False)
    else:
        print(f"No results file found at {result_path}; nothing to summarise.")

   

