import os
import pandas as pd
import numpy as np
import torch
import time
import logging
import sys
import re

from sklearn.decomposition import PCA
from sklearn.model_selection import StratifiedShuffleSplit
from scipy.stats import kendalltau
import hydra
import wandb
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate
from hydra.core.hydra_config import HydraConfig
import rootutils
# Locate the project root (the directory containing .project-root)
root = rootutils.find_root(__file__, ".project-root")
# Set the project root and add it to sys.path for imports
rootutils.setup_root(root, pythonpath=True)

from src.utils.dataprep import dataprep
from src.metrics.correlations import feature_correlations
from src.metrics.feature_importances_embedding import cls_importances, structure_importances, knn_emb_accuracy
from src.metrics.structure_preservation_metrics import compute_mst_geodesics


def parse_value(value):
    """Convert a command-line style string to ``None``, ``int``, ``float`` or leave it as a string.

    Conversion rules, applied in order:

    * the literal ``'None'`` becomes ``None``;
    * anything ``int()`` accepts becomes an ``int`` (e.g. ``'5'``, ``'-3'``);
    * anything ``float()`` accepts *and* that contains at least one digit becomes a
      ``float`` (e.g. ``'0.5'``, ``'1e-3'``);
    * everything else is returned unchanged.

    Args:
        value: The raw string to interpret.

    Returns:
        The converted value, or ``value`` itself when no conversion applies.
    """
    if value == 'None':
        return None

    try:
        return int(value)
    except ValueError:
        pass

    # Require a digit before trying float() so that words float() would accept
    # ('nan', 'inf', 'infinity') are kept as plain strings. This also lets
    # scientific notation without a dot ('1e-3') be parsed as a float.
    if any(ch.isdigit() for ch in value):
        try:
            return float(value)
        except ValueError:
            pass

    return value  # keep as string

# Datasets that are always scaled with a single global transform rather than
# feature by feature.
_GLOBAL_SCALING_DATASETS = frozenset({
    'optdigits', 'landsat', 'sign_mnist_cropped', 'mnist_test', 'fashion_mnist_test', 'usps',
    'bloodmnist', 'organcmnist', 'organsmnist', 'dermamnist', 'organmnist3d', 'fracturemnist3d',
})


def load_data(data_path, data_name, processing=True, transform='normalize', global_transform=False):
    """Load a dataset from CSV and optionally preprocess it.

    Two on-disk layouts are supported, checked in this order:

    1. ``<data_name>_train.csv`` and ``<data_name>_test.csv`` both exist: the two
       files are concatenated (train rows first) and the number of train rows is
       returned so the caller can reuse the predefined split.
    2. ``<data_name>.csv`` exists: the single file is loaded and no split is implied.

    The first CSV column is treated as the label column.

    Args:
        data_path: Directory containing the CSV file(s).
        data_name: Dataset base name (without the ``.csv`` / ``_train.csv`` suffix).
        processing: When True, run ``dataprep`` on the data and return NumPy arrays
            (when ``dataprep`` yields a DataFrame / Series). When False, return the
            raw pandas slices (features as a DataFrame, labels as a Series).
        transform: Forwarded to ``dataprep`` as its ``transform`` argument.
        global_transform: Request global scaling. Datasets listed in
            ``_GLOBAL_SCALING_DATASETS`` always use global scaling regardless of
            this flag.

    Returns:
        Tuple ``(X, y, n_train)`` where ``n_train`` is the number of rows coming
        from the train file, or ``None`` when a single CSV was loaded.

    Raises:
        FileNotFoundError: If neither layout is present in ``data_path``.
    """
    train_file = os.path.join(data_path, data_name + "_train.csv")
    test_file = os.path.join(data_path, data_name + "_test.csv")
    single_file = os.path.join(data_path, data_name + ".csv")

    if os.path.exists(train_file) and os.path.exists(test_file):
        # Case 1: Paired _train and _test CSVs
        train_data = pd.read_csv(train_file, sep=',')
        test_data = pd.read_csv(test_file, sep=',')
        # Train rows come first so that indices [0, n_train) address the train split
        data = pd.concat([train_data, test_data], axis=0).reset_index(drop=True)
        n_train = train_data.shape[0]
    elif os.path.exists(single_file):
        # Case 2: Single CSV file
        data = pd.read_csv(single_file, sep=',')
        n_train = None
    else:
        # Raise an error if neither case is satisfied
        raise FileNotFoundError(f"Neither '{data_name}.csv' nor '{data_name}_train.csv' and '{data_name}_test.csv' found in '{data_path}'.")

    if not processing:
        X, y = data.iloc[:, 1:], data.iloc[:, 0]
        return X, y, n_train

    # The explicit argument can only switch global scaling on; the known
    # datasets keep global scaling even when the flag is left at its default.
    use_global = bool(global_transform) or data_name in _GLOBAL_SCALING_DATASETS
    if use_global:
        logging.info(f"Applying global scaling for {data_name} dataset...")
    else:
        logging.info(f"Applying feature-wise scaling for {data_name} dataset...")
    X, y = dataprep(data, label_col_idx=0, transform=transform, global_transform=use_global, cat_to_numeric=True)

    # Convert X and y to NumPy arrays
    X = X.to_numpy() if isinstance(X, pd.DataFrame) else X
    y = y.to_numpy() if isinstance(y, pd.Series) else y

    return X, y, n_train

    
@hydra.main(version_base=None, config_path="config", config_name="run_model_scores")
def main(cfg: DictConfig):
    """Train every configured embedding model on one dataset and record its scores.

    The function performs the following steps:

    1. Select a compute device (SLURM local GPU, the ``cfg.Icewindale`` setting when
       ``/NOBACKUP/`` exists, otherwise CUDA if available, else CPU).
    2. Load the dataset with ``load_data``, optionally subsample it, and build the
       train/test split (predefined when paired CSVs were loaded, stratified otherwise).
    3. Reduce the features with PCA when there are more than ``cfg.evaluation.max_features``.
    4. When ``cfg.evaluation.evaluate`` is set, compute baseline classifier accuracies and
       feature importances once for the dataset.
    5. For each entry in ``cfg.models`` (``"name"`` or ``"name(key=value, ...)"``): load
       ``runner/config/model/<name>.yaml``, instantiate the model, embed train and test
       data, optionally evaluate the embedding, and optionally save embeddings/importances.
    6. After each model, rewrite ``scores.csv`` in the current working directory with all
       rows collected so far.

    Args:
        cfg: Hydra configuration. Keys read here: ``Icewindale``, ``random_state``,
            ``save_results``, ``models``, ``data.{path,name,transform,subsample,test_size}``,
            ``evaluation.{max_features,evaluate,n_neighbors,n_repeats,baseline_cls}`` and
            ``logger.wandb.{project,entity,tags}``.
    """

    # Check if the script is running in SLURM
    if "SLURM_LOCALID" in os.environ:   #using mila_cluster
        # SLURM assigns the GPU based on the local task ID
        gpu_id = int(os.environ["SLURM_LOCALID"])
        device = torch.device(f"cuda:{gpu_id}")
        logging.info(f"Running in SLURM with device: {device}")
    elif os.path.exists("/NOBACKUP/"):   # using Icewindale
        device = cfg.Icewindale
        logging.info(f"Running with manually assigned on device: {device}")
    else:   # Fallback to local execution
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"No SLURM or GPU-specific environment found. Running locally on device: {device}")

    # Load the data
    X, y, n_train = load_data(cfg.data.path, cfg.data.name, transform=cfg.data.transform)

    if cfg.data.name in ['zilionis', 'celegans_dropna', 'samusik']:
        # NOTE(review): n_train is not adjusted after subsampling; this is only safe
        # if these datasets are shipped as a single CSV (n_train is None) -- confirm.
        sss = StratifiedShuffleSplit(n_splits=1, test_size=1 - cfg.data.subsample, random_state=cfg.random_state)
        subsample_indices, _ = next(sss.split(X, y))
        X = X[subsample_indices]
        y = y[subsample_indices]
    n_samples = X.shape[0]
    n_features = X.shape[1]
    n_classes = len(np.unique(y))

    logging.info(f"Finish loading and preprocessing data {cfg.data.name}")

    # Names shared by the result columns, the alignment scores and the saved files.
    struct_metrics = ("qnx", "trust", "cont", "spear", "stress", "pearson")
    # (column suffix, key into cls_imp) for every importance vector compared against
    # the structural importances.
    alignment_sources = (("knn", "knn"), ("svm", "svm"), ("agg", "agg"), ("ens", "ensemble"))

    # Columns of scores.csv, in output order.
    meta_columns = ["model", "data", "n_samples", "test_pct", "n_features", "n_classes", "random_state"]
    score_columns = (
        [f"{metric}_{suffix}" for suffix, _ in alignment_sources for metric in struct_metrics]
        + ["knn_acc", "knn_emb_acc", "svm_acc", "mlp_acc", "ensemble_acc"]
        + [f"{metric}_unsup" for metric in struct_metrics]
    )
    result_columns = meta_columns + score_columns + ["training_time", "test_time"]
    # One dict per evaluated model; turned into a DataFrame each time scores.csv is written.
    score_rows = []

    # Use predefined split or create train/test split the data
    if n_train is None:
        sss = StratifiedShuffleSplit(n_splits=1, test_size=cfg.data.test_size, random_state=cfg.random_state)
        train_index, test_index = next(sss.split(X, y))
    else:
        train_index = np.arange(n_train)
        test_index = np.arange(n_train, n_samples)

    x_train, y_train = X[train_index, :], y[train_index]
    x_test, y_test = X[test_index, :], y[test_index]
    test_pct = round(len(test_index) / n_samples, 2)

    # Reduce dimensions with PCA if needed
    if cfg.evaluation.max_features < X.shape[1]:
        logging.info(f"Reducing data to {cfg.evaluation.max_features} features using PCA...")
        # NOTE(review): PCA is fitted on train and test rows together.
        pca = PCA(n_components=cfg.evaluation.max_features, random_state=cfg.random_state)
        X = pca.fit_transform(X)
        x_train, x_test = X[train_index, :], X[test_index, :]
        n_features = X.shape[1]

    # Compute feature correlations
    feature_corr_matrix = feature_correlations(X, device=device)

    # Compute cls feature importances using baseline classifiers
    if cfg.evaluation.evaluate:
        logging.info("Starting estimation of classification accuracy AND feature importances...")
        start_time = time.time()
        cls_results = cls_importances(x_train, y_train, x_test, y_test, feature_corr_matrix,
                                      n_neighbors=cfg.evaluation.n_neighbors, n_repeats=cfg.evaluation.n_repeats,
                                      classifiers=cfg.evaluation.baseline_cls,
                                      device=device, random_state=cfg.random_state)

        cls_acc = {}
        cls_imp = {}
        # Start from zeros (not uninitialised memory) so the sum below is well defined.
        agg_imp = np.zeros(n_features)
        n_aggregated = 0
        for cls_name in ("knn", "mlp", "svm"):
            if cls_name in cfg.evaluation.baseline_cls:
                cls_acc[cls_name] = cls_results[cls_name][0]
                cls_imp[cls_name] = cls_results[cls_name][1]
                agg_imp = agg_imp + cls_imp[cls_name]
                n_aggregated += 1
            else:
                cls_acc[cls_name] = np.nan
                cls_imp[cls_name] = np.full(n_features, np.nan)

        if 'ensemble' in cls_results.keys():
            cls_acc["ensemble"] = cls_results['ensemble'][0]
            cls_imp["ensemble"] = cls_results['ensemble'][1]
        else:
            cls_acc["ensemble"] = np.nan
            cls_imp["ensemble"] = np.full(n_features, np.nan)

        # Average over the classifiers that were actually summed; with none of them
        # enabled the aggregate is undefined and reported as NaN.
        if n_aggregated:
            cls_imp["agg"] = agg_imp / n_aggregated
        else:
            cls_imp["agg"] = np.full(n_features, np.nan)

        logging.info(f"Classification accuracies AND feature importances computed in {time.time() - start_time:.2f} seconds")

        # Save the indices and embeddings and feature importance (classwise, local, global)
        if cfg.save_results:
            save_dir = os.getcwd()

            # Create subfolder paths
            subfolder_path = os.path.join(save_dir, str(cfg.random_state))
            os.makedirs(subfolder_path, exist_ok=True)
            subsubfolder_path = os.path.join(subfolder_path, "feature_importances")
            os.makedirs(subsubfolder_path, exist_ok=True)

            # Save classification importances, one CSV per importance source
            for cls_name in ("knn", "svm", "mlp", "agg", "ensemble"):
                pd.DataFrame(cls_imp[cls_name]).to_csv(
                    subsubfolder_path + f"/0_{cls_name}_cls_importances_{cfg.random_state}.csv", index=False)

    # Matches "name(key=value, ...)"; plain "name" entries do not match.
    model_spec_pattern = re.compile(r"(\w+)\s*\((.*)\)")

    for model in cfg.models:
        # Extract model name and parameters
        param_dict = {}
        match = model_spec_pattern.match(model)
        if match:
            model = match.group(1)
            for param in match.group(2).split(','):
                if not param.strip():
                    continue  # tolerate "name()" and stray commas
                # Split on the first '=' only so values may themselves contain '='
                key, value = param.split('=', 1)
                param_dict[key.strip()] = parse_value(value.strip())

        model_config_path = os.path.join(root, f"runner/config/model/{model}.yaml")

        if not os.path.exists(model_config_path):
            logging.error(f"Model configuration file not found: {model_config_path}")
            continue

        # Load model configuration
        cfg.model = OmegaConf.load(model_config_path)

        # Update cfg.model with parsed parameters
        if param_dict:
            for param, value in param_dict.items():
                setattr(cfg.model, param, value)
            params_str = ', '.join([f"{k}={v}" for k, v in param_dict.items()])
            cfg.model.name = f"{cfg.model.name} ({params_str})"

        # Add device to the model configuration if the model supports it
        if "device" in cfg.model:
            cfg.model.device = str(device)

        if cfg.model.save_type == "checkpoint":

            config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
            out_dir = os.getcwd()
            config["out_dir"] = out_dir
            wandb.init(
                project=cfg.logger.wandb.project,
                entity=cfg.logger.wandb.entity,
                tags=cfg.logger.wandb.tags,
                reinit=True,
                config=config,
                settings=wandb.Settings(start_method="thread"),
            )
            logging.info("WandB initialized for tracking.")

        # Instantiate the model directly from the configuration
        model_config = {key: value for key, value in cfg.model.items() if key not in ["name", "save_type"]}
        if "random_state" in cfg.model:
            model_config['random_state'] = cfg.random_state
        elif "seed" in cfg.model:  # Case of Parametric t-SNE/UMAP ('random_state' is called 'seed')
            model_config['seed'] = cfg.random_state
        embedder = instantiate(model_config)

        # Train
        logging.info(f"Training {cfg.model.name} model...")
        start_time = time.time()
        emb_train = embedder.fit_transform(x_train, y_train)
        training_time = time.time() - start_time
        logging.info(f"Training time: {training_time:.2f} seconds")

        # Test
        logging.info(f"Test embedding using {cfg.model.name} model...")
        start_time = time.time()
        emb_test = embedder.transform(x_test)
        test_time = time.time() - start_time
        logging.info(f"Inference time: {test_time:.2f} seconds")

        if cfg.evaluation.evaluate:
            # Evaluate embedding classification performance using emb train and emb test
            logging.info(f"Evaluating KNN embedding classification accuracies using {cfg.model.name} test embedding...")
            start_time = time.time()
            knn_emb_acc = knn_emb_accuracy(emb_train, y_train, emb_test, y_test, device=device)
            logging.info(f"Embedding classification accuracies computed in {time.time() - start_time:.2f} seconds")

            # Evaluate structure importances
            logging.info(f"Evaluating structural importances for {cfg.model.name} model...")
            start = time.time()
            struct_results = structure_importances(x_train, emb_train, x_test, emb_test, feature_corr_matrix,
                                                   n_repeats=cfg.evaluation.n_repeats, device=device,
                                                   random_state=cfg.random_state)
            # Each entry is indexed as [0] -> score, [1] -> per-feature importances
            struct_scores = {metric: struct_results[metric][0] for metric in struct_metrics}
            struct_imp = {metric: struct_results[metric][1] for metric in struct_metrics}
            logging.info(f"Finish local/global structure importances in {time.time() - start:.2f} seconds")

            # Evaluate importance alignment scores
            logging.info(f"Evaluating importance alignment scores for {cfg.model.name} model...")
            start = time.time()
            # Kendall's tau between each classifier importance vector and each
            # structural importance vector, keyed like the scores.csv columns.
            alignment = {
                f"{metric}_{suffix}": kendalltau(cls_imp[imp_key], struct_imp[metric]).statistic
                for suffix, imp_key in alignment_sources
                for metric in struct_metrics
            }
            logging.info(f"Finish evaluating importance alignment scores in {time.time() - start:.2f} seconds")

        # Save the indices and embeddings
        if cfg.save_results:
            save_dir = os.getcwd()
            # Create subfolder path
            subfolder_path = os.path.join(save_dir, str(cfg.random_state))
            # Create the subfolder if it does not exist
            os.makedirs(subfolder_path, exist_ok=True)

            results_train = pd.DataFrame(emb_train, columns=[f"emb_{i}" for i in range(emb_train.shape[1])])
            results_train["train_index"] = train_index
            results_train.to_csv(subfolder_path + f"/{cfg.model.name}_emb_train_{cfg.random_state}.csv", index=False)

            results_test = pd.DataFrame(emb_test, columns=[f"emb_{i}" for i in range(emb_test.shape[1])])
            results_test["test_index"] = test_index  # Add test indices as a new column
            results_test.to_csv(subfolder_path + f"/{cfg.model.name}_emb_test_{cfg.random_state}.csv", index=False)
            logging.info(f"Save {cfg.model.name} embeddings for random state {cfg.random_state}")

            # Save structural importances
            if cfg.evaluation.evaluate:
                # Recompute the target folder here instead of relying on a variable
                # assigned in an earlier, separate branch.
                subsubfolder_path = os.path.join(subfolder_path, "feature_importances")
                os.makedirs(subsubfolder_path, exist_ok=True)
                for metric in struct_metrics:
                    pd.DataFrame(struct_imp[metric]).to_csv(
                        subsubfolder_path + f"/{cfg.model.name}_{metric}_imp_{cfg.random_state}.csv", index=False)
                logging.info(f"Save {cfg.model.name} local and global structure importances for random state {cfg.random_state}")

        # Assemble the scores.csv row for this model
        row = {
            "model": cfg.model.name,
            "data": cfg.data.name,
            "n_samples": n_samples,
            "test_pct": test_pct,
            "n_features": n_features,
            "n_classes": n_classes,
            "random_state": cfg.random_state,
        }
        if cfg.evaluation.evaluate:
            row.update(alignment)
            row.update({
                "knn_acc": cls_acc["knn"],
                "knn_emb_acc": knn_emb_acc,
                "svm_acc": cls_acc["svm"],
                "mlp_acc": cls_acc["mlp"],
                "ensemble_acc": cls_acc["ensemble"],
            })
            row.update({f"{metric}_unsup": struct_scores[metric] for metric in struct_metrics})
        else:
            # Without evaluation only the metadata and timings are meaningful
            row.update(dict.fromkeys(score_columns, np.nan))
        row["training_time"] = training_time
        row["test_time"] = test_time

        score_rows.append(row)

        logging.info(f"Finish random state {cfg.random_state}")

        # Save the scores (rewritten after every model so partial runs keep their results)
        save_dir = os.getcwd()
        pd.DataFrame(score_rows, columns=result_columns).to_csv(os.path.join(save_dir, "scores.csv"), index=False)



if __name__ == "__main__":
    main()