import os
import pandas as pd
import numpy as np
import torch
import logging
import re
import wandb

from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.neighbors import KNeighborsClassifier
import hydra
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate
from hydra.core.hydra_config import HydraConfig
import rootutils
from src.metrics.feature_importances_embedding import knn_emb_accuracy

# Setup root path: find the project root by searching for the ".project-root"
# marker starting from this file, then register it (pythonpath=True) so that
# `src.*` packages can be imported.
# NOTE(review): the `from src.metrics...` import above executes before this
# call, so it only resolves if the project root is already importable by some
# other means — confirm, or move that import below this setup.
root = rootutils.find_root(__file__, ".project-root")
rootutils.setup_root(root, pythonpath=True)

# Local imports
from src.utils.dataprep import dataprep

def parse_value(value):
    """Convert a textual parameter value into a Python object.

    Used to interpret the ``key=value`` overrides embedded in model specs.

    Conversion rules, applied in order:
    * non-string input is returned unchanged;
    * ``'None'``, ``'True'`` and ``'False'`` become the matching singletons
      (a raw ``'False'`` string would otherwise be truthy);
    * text containing ``.`` or an exponent marker (``e``/``E``) is parsed as
      ``float`` so that values such as ``1e-3`` are not left as strings;
    * any other text is parsed as ``int``;
    * text that fails numeric parsing is returned as the original string.

    Args:
        value: The raw value, normally a stripped string.

    Returns:
        ``None``, a ``bool``, an ``int``, a ``float`` or the untouched input.
    """
    if not isinstance(value, str):
        return value

    literals = {'None': None, 'True': True, 'False': False}
    if value in literals:
        return literals[value]

    # Plain words containing "e" (e.g. "euclidean") simply fail float() and
    # fall through to the string return below.
    looks_float = '.' in value or 'e' in value.lower()
    try:
        return float(value) if looks_float else int(value)
    except ValueError:
        return value

# Datasets for which ``dataprep`` is invoked with ``global_transform=True``;
# every other dataset uses ``global_transform=False``.
_GLOBAL_TRANSFORM_DATASETS = frozenset({
    'optdigits', 'landsat', 'sign_mnist_cropped', 'mnist_test', 'fashion_mnist_test', 'usps',
    'bloodmnist', 'organcmnist', 'organsmnist', 'dermamnist', 'organmnist3d', 'fracturemnist3d',
})


def load_data(data_path, data_name, processing=True, transform='normalize'):
    """Load a dataset from CSV and split it into features and labels.

    Looks for ``<data_name>_train.csv`` plus ``<data_name>_test.csv`` first;
    when both exist they are stacked (train rows first) and the number of
    training rows is reported. Otherwise ``<data_name>.csv`` is read as a
    single file. The label is taken from the first column.

    Args:
        data_path: Directory containing the CSV file(s).
        data_name: Dataset base name (without suffix or extension).
        processing: When true, run the data through ``dataprep`` and convert
            pandas results to NumPy arrays; when false, return the raw
            pandas slices.
        transform: Forwarded to ``dataprep`` as its ``transform`` argument.

    Returns:
        A tuple ``(X, y, n_train)``. ``n_train`` is the number of rows that
        came from the train file, or ``None`` when a single file was read.

    Raises:
        FileNotFoundError: If neither the train/test pair nor the single
            file exists.
    """
    train_file = os.path.join(data_path, data_name + "_train.csv")
    test_file = os.path.join(data_path, data_name + "_test.csv")
    single_file = os.path.join(data_path, data_name + ".csv")

    if os.path.exists(train_file) and os.path.exists(test_file):
        train_data = pd.read_csv(train_file, sep=',')
        test_data = pd.read_csv(test_file, sep=',')
        data = pd.concat([train_data, test_data], axis=0).reset_index(drop=True)
        n_train = train_data.shape[0]
    elif os.path.exists(single_file):
        data = pd.read_csv(single_file, sep=',')
        n_train = None
    else:
        # Name every path that was probed so a wrong data.path/data.name is
        # obvious from the error alone.
        raise FileNotFoundError(
            f"Could not find dataset files for '{data_name}' in '{data_path}': "
            f"expected '{train_file}' and '{test_file}', or '{single_file}'."
        )

    if not processing:
        return data.iloc[:, 1:], data.iloc[:, 0], n_train

    X, y = dataprep(
        data,
        label_col_idx=0,
        transform=transform,
        global_transform=data_name in _GLOBAL_TRANSFORM_DATASETS,
        cat_to_numeric=True,
    )

    # Downstream code indexes with NumPy semantics, so unwrap pandas objects.
    X = X.to_numpy() if isinstance(X, pd.DataFrame) else X
    y = y.to_numpy() if isinstance(y, pd.Series) else y

    return X, y, n_train

def _resolve_device(cfg):
    """Pick the compute device: SLURM local GPU, host-specific config value, or CUDA/CPU fallback."""
    if "SLURM_LOCALID" in os.environ:
        return torch.device(f"cuda:{int(os.environ['SLURM_LOCALID'])}")
    if os.path.exists("/NOBACKUP/"):
        # NOTE(review): `Icewindale` looks like a host-specific config entry
        # holding a device spec — confirm it exists in the Hydra config.
        return cfg.Icewindale
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _parse_model_spec(model):
    """Split a spec such as ``"name(a=1, b=x)"`` into ``(name, overrides)``.

    Returns ``(model, None)`` when the spec carries no parenthesised part.
    An empty parameter list (``"name()"``) yields an empty dict. Values are
    converted with ``parse_value``.

    Raises:
        ValueError: If a non-empty parameter lacks a ``=`` separator.
    """
    match = re.match(r"(\w+)\s*\((.*)\)", model)
    if not match:
        return model, None

    overrides = {}
    for item in match.group(2).split(','):
        if not item.strip():
            # Tolerate "name()" and stray/trailing commas.
            continue
        # Split on the first "=" only so values may themselves contain "=".
        key, sep, value = item.partition('=')
        if not sep:
            raise ValueError(f"Malformed model parameter {item!r} in {model!r}; expected key=value.")
        overrides[key.strip()] = parse_value(value.strip())
    return match.group(1), overrides


@hydra.main(version_base=None, config_path="config", config_name="run_noisy_tree")
def main(cfg: DictConfig):
    """Score embedding models while uniform noise features are appended.

    Workflow:
    1. Load the dataset and keep its first 40 feature columns.
    2. Use the file-provided train/test split, or a stratified shuffle split
       when only a single file was loaded.
    3. For each level in ``desired_snrs``, append
       ``round(n_features / snr)`` uniform-noise columns (none for ``inf``).
    4. For each entry of ``cfg.models``, load its YAML config, apply inline
       overrides, instantiate it, embed train/test data and compute the kNN
       embedding accuracy.
    5. Optionally save embeddings and training losses, and rewrite
       ``scores.csv`` in the current working directory after every level.

    Args:
        cfg: Hydra configuration. Keys read here: ``data`` (``path``,
            ``name``, ``transform``, ``test_size``), ``models``,
            ``random_state``, ``save_results``, and ``logger.wandb`` when a
            model's ``save_type`` is ``"checkpoint"``.
    """
    device = _resolve_device(cfg)

    # Load original data
    X, y, n_train = load_data(cfg.data.path, cfg.data.name, transform=cfg.data.transform)
    X = X[:, :40]  # Remove noise features (only the first 40 columns are kept)
    n_samples, n_features_orig = X.shape

    if n_train is None:
        sss = StratifiedShuffleSplit(n_splits=1, test_size=cfg.data.test_size, random_state=cfg.random_state)
        train_idx, test_idx = next(sss.split(X, y))
    else:
        # Train rows were stacked first by load_data, so the split is positional.
        train_idx = np.arange(n_train)
        test_idx = np.arange(n_train, n_samples)

    y_train = y[train_idx]
    y_test = y[test_idx]

    # Define desired SNR levels (ratio of informative to noise feature counts)
    desired_snrs = [np.inf, 10, 1, 0.1, 0.01, 0.001]
    informative_features = n_features_orig

    # One dict per (SNR, model) run; turned into a DataFrame only when saving.
    result_rows = []

    for snr in desired_snrs:
        if snr == np.inf:
            X_new = X           # no noise
        else:
            # --- draw uniform noise on [0, 1) --------------------
            target_noise_features = int(round(informative_features / snr))

            # Re-seeding here makes the noise reproducible for a given level.
            np.random.seed(cfg.random_state)
            new_noise = np.random.uniform(
                low=0.0,
                high=1.0,
                size=(n_samples, target_noise_features)
            )
            # ----------------------------------------------------------------

            # Concatenate informative + noise
            X_new = np.hstack([X, new_noise])

        for model in cfg.models:
            model_name, param_dict = _parse_model_spec(model)

            model_config_path = os.path.join(root, f"runner/config/model/{model_name}.yaml")
            if not os.path.exists(model_config_path):
                logging.error(f"Model configuration file not found: {model_config_path}")
                continue

            cfg.model = OmegaConf.load(model_config_path)

            if param_dict:
                for param, value in param_dict.items():
                    setattr(cfg.model, param, value)
                params_str = ', '.join([f"{k}={v}" for k, v in param_dict.items()])
                # Encode the overrides in the name so outputs stay distinguishable.
                cfg.model.name = f"{cfg.model.name} ({params_str})"

            if "device" in cfg.model:
                cfg.model.device = str(device)

            # "name" and "save_type" are bookkeeping keys, not constructor arguments.
            model_config = {key: value for key, value in cfg.model.items() if key not in ["name", "save_type"]}
            if "random_state" in cfg.model:
                model_config['random_state'] = cfg.random_state
            elif "seed" in cfg.model:
                model_config['seed'] = cfg.random_state

            model_inst = instantiate(model_config)

            # WandB init
            if cfg.model.save_type == "checkpoint":
                config = OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
                out_dir = os.getcwd()
                config["out_dir"] = out_dir
                wandb.init(
                    project=cfg.logger.wandb.project,
                    entity=cfg.logger.wandb.entity,
                    tags=cfg.logger.wandb.tags,
                    reinit=True,
                    config=config,
                    settings=wandb.Settings(start_method="thread"),
                )
                logging.info("WandB initialized for tracking.")

            x_train = X_new[train_idx]
            x_test = X_new[test_idx]

            emb_train = model_inst.fit_transform(x_train, y_train)
            emb_test = model_inst.transform(x_test)

            # Compute kNN embedding accuracy
            knn_emb_acc = knn_emb_accuracy(emb_train, y_train, emb_test, y_test, device=device)

            # Save embeddings and training losses
            if cfg.save_results:
                subfolder_path = os.path.join(os.getcwd(), str(cfg.random_state))
                os.makedirs(subfolder_path, exist_ok=True)

                results_train = pd.DataFrame(emb_train, columns=[f"emb_{i}" for i in range(emb_train.shape[1])])
                results_train["train_index"] = train_idx
                results_train.to_csv(
                    os.path.join(subfolder_path, f"{cfg.model.name}_emb_train_SNR{snr}.csv"), index=False
                )

                results_test = pd.DataFrame(emb_test, columns=[f"emb_{i}" for i in range(emb_test.shape[1])])
                results_test["test_index"] = test_idx
                results_test.to_csv(
                    os.path.join(subfolder_path, f"{cfg.model.name}_emb_test_SNR{snr}.csv"), index=False
                )

                # Only some models expose per-epoch losses.
                if hasattr(model_inst, 'epoch_losses_emb'):
                    training_losses = pd.DataFrame(model_inst.epoch_losses_emb, columns=["loss"])
                    training_losses["epoch"] = np.arange(len(model_inst.epoch_losses_emb))
                    training_losses.to_csv(
                        os.path.join(subfolder_path, f"{cfg.model.name}_train_loss_SNR{snr}.csv"), index=False
                    )

                logging.info(f"Saved {cfg.model.name} embeddings (SNR={snr}) for random state {cfg.random_state}")

            # Save kNN embedding accuracy
            result_rows.append({
                "model": cfg.model.name, "data": cfg.data.name, "SNR": snr,
                "n_samples": n_samples, "n_features": x_train.shape[1],
                "knn_emb_acc": knn_emb_acc, "random_state": cfg.random_state,
            })

        # Save all results so far; rewritten after every SNR level so partial
        # progress survives an interrupted run.
        pd.DataFrame(result_rows).to_csv(os.path.join(os.getcwd(), "scores.csv"), index=False)

# Script entry point: main() is called without arguments because the
# @hydra.main decorator builds and passes the `cfg` object.
if __name__ == "__main__":
    main()