import torch
import os
import pandas as pd
import numpy as np
import yaml
import mne
import random

from sklearn.metrics import mean_absolute_error, r2_score, balanced_accuracy_score, roc_auc_score
from sklearn.metrics import precision_recall_curve, auc, hamming_loss
from sklearn.model_selection import train_test_split, StratifiedKFold
from models.models import *
from datasets.datasets import *
from copy import deepcopy

def score(ys_true, ys_pred, test_ids, n_classes, convert_to_subject=True, logits=False):
    """Score predictions, optionally after averaging them per subject.

    ys_true and ys_pred should be 2-dimensional (1-D inputs are reshaped to a
    single column).
    Unilabel: (n_samples, 1).
    Multilabel: (n_samples, n_classes)

    Args:
        ys_true: Targets, one row per sample.
        ys_pred: Predictions, one row per sample.
        test_ids: Subject ID of every sample; used for grouping when
            ``convert_to_subject`` is True.
        n_classes: 1 -> regression, metrics are [MAE, R2] (both 0. if any
            prediction is NaN); 2 -> binary classification, metrics are
            [balanced accuracy at threshold 0.5, ROC-AUC]; otherwise
            multilabel, metrics are [mean precision-recall AUC over labels,
            Hamming loss at threshold 0.5].
        convert_to_subject: If True, targets and predictions are averaged per
            subject before scoring.
        logits: If True, ``ys_pred`` holds logits and is passed through a
            sigmoid exactly once before averaging/scoring.

    Returns:
        (sub_ys_true, sub_ys_pred, metrics): the (possibly subject-averaged)
        arrays that were scored and a two-element metric list.
    """
    if ys_true.ndim == 1:
        ys_true = ys_true.reshape(-1, 1)
    if ys_pred.ndim == 1:
        ys_pred = ys_pred.reshape(-1, 1)

    if logits:
        # Convert logits to probabilities exactly once. Applying the sigmoid
        # inside the per-label loop re-applied it for every label in the
        # multilabel case, and skipping it when convert_to_subject=False made
        # the 0.5 thresholds below operate on raw logits.
        ys_pred = torch.sigmoid(torch.tensor(ys_pred, dtype=torch.float)).numpy()

    # One output column for regression / binary classification, one per label otherwise.
    dims = 1 if n_classes < 3 else n_classes

    if convert_to_subject:
        n_subjects = len(np.unique(test_ids))
        sub_ys_true = np.empty((n_subjects, dims))
        sub_ys_pred = np.empty((n_subjects, dims))
        for label in range(dims):
            # Average per subject (regression target or, for classification,
            # probabilities). groupby sorts by subject_id, so the true and
            # predicted columns stay row-aligned.
            df = pd.DataFrame({"y_true": ys_true[:, label], "y_pred": ys_pred[:, label], "subject_id": test_ids})
            df_mean = df.groupby("subject_id").mean()
            sub_ys_true[:, label] = df_mean["y_true"].values
            sub_ys_pred[:, label] = df_mean["y_pred"].values
    else:
        sub_ys_true = ys_true
        sub_ys_pred = ys_pred

    if n_classes == 1:
        if np.isnan(np.asarray(sub_ys_pred)).any():
            print("Scoring: NANs")
            metrics = [0., 0.]
        else:
            metrics = [mean_absolute_error(sub_ys_true[:, 0], sub_ys_pred[:, 0]),
                       r2_score(sub_ys_true[:, 0], sub_ys_pred[:, 0])]
    elif n_classes == 2:
        metrics = [balanced_accuracy_score(sub_ys_true[:, 0], (sub_ys_pred[:, 0] > 0.5).astype(float)),
                   roc_auc_score(sub_ys_true[:, 0], sub_ys_pred[:, 0])]
    else:
        # Multilabel: average the per-label precision-recall AUCs.
        pr_aucs = []
        for label in range(dims):
            prec, rec, _ = precision_recall_curve(sub_ys_true[:, label].astype(int), sub_ys_pred[:, label])
            pr_aucs.append(auc(rec, prec))
        metrics = [np.mean(pr_aucs), hamming_loss(sub_ys_true, (sub_ys_pred > 0.5).astype(float))]

    return sub_ys_true, sub_ys_pred, metrics
    

def best_hp(path: str, ncv_i: int, fold: int, n_train: int) -> dict:
    """Return the hyperparameters of the trial with the lowest validation loss.

    The score file is ``hp_ncv-<ncv_i>_fold-<fold>_ntrain-<n_train>.csv`` in the
    parent directory of ``path`` (a trailing "/" on ``path`` is ignored). The
    ``val_loss`` and ``val_metric`` columns are dropped from the result, so the
    returned dict maps hyperparameter name -> value for the best row.
    """
    score_dir = os.path.dirname(path.rstrip("/"))
    scores = pd.read_csv(f"{score_dir}/hp_ncv-{ncv_i}_fold-{fold}_ntrain-{n_train}.csv")

    # Row label of the lowest validation loss (first one on ties).
    best_row = scores["val_loss"].idxmin()
    hyperparams = scores.drop(columns=["val_loss", "val_metric"])
    return hyperparams.loc[best_row].to_dict()

def set_hp(cfg: dict, hp_key: dict, ncv_i: int, fold: int, n_train: int) -> dict:
    """Write one hyperparameter combination into the config, in place.

    Each key of ``hp_key`` is looked up, in this order, in ``cfg["model"]``,
    ``cfg["training"]`` and ``cfg["model"]["ELM"]``. The first section that
    already contains the key receives the value. Values destined for the
    ``ELM`` section are stored as the path ``reports/<value>.json``.

    The cross-validation bookkeeping (``ncv``, ``fold``, ``n_train``) and the
    ``hp_key`` dict itself are also stored under ``cfg["training"]``.

    Args:
        cfg: Config dict with at least ``"model"`` and ``"training"`` sections.
        hp_key: Mapping hyperparameter name -> value.
        ncv_i: Index of the nested-CV repetition.
        fold: Fold index.
        n_train: Number of training subjects.

    Returns:
        The same ``cfg`` object, modified in place.

    Raises:
        ValueError: If a key of ``hp_key`` is not present in any of the
            sections above (also when the config has no ``ELM`` section).
    """
    model_cfg = cfg["model"]
    training_cfg = cfg["training"]

    for k, value in hp_key.items():
        if k in model_cfg:
            model_cfg[k] = value
        elif k in training_cfg:
            training_cfg[k] = value
        # The ELM section is optional; a missing one must yield the documented
        # ValueError below instead of a KeyError.
        elif k in (model_cfg.get("ELM") or {}):
            model_cfg["ELM"][k] = f"reports/{value}.json"
        else:
            raise ValueError(f"Hyper-grid contains unknown parameters: {k!r}.")

    training_cfg["ncv"] = ncv_i
    training_cfg["fold"] = fold
    training_cfg["n_train"] = n_train
    training_cfg["hp_key"] = hp_key

    return cfg

def update_score_file(val_loss: float, val_metric: float, hp_key: dict, ncv_i: int, fold: int, n_train: int, path: str) -> None:
    """Append one tested hyperparameter combination and its validation scores to the score file.

    The score file is ``hp_ncv-<ncv_i>_fold-<fold>_ntrain-<n_train>.csv`` in the
    parent directory of ``path`` (a trailing "/" on ``path`` is ignored). It is
    created on the first call. Each row holds the ``hp_key`` values plus
    ``val_loss`` and ``val_metric``.
    """
    score_dir = os.path.dirname(path.rstrip("/"))
    file_name = f"{score_dir}/hp_ncv-{ncv_i}_fold-{fold}_ntrain-{n_train}.csv"

    new_row = pd.DataFrame({**hp_key, "val_loss": val_loss, "val_metric": val_metric}, index=[0])

    if os.path.exists(file_name):
        # Append to the trials already recorded for this fold.
        df = pd.concat([pd.read_csv(file_name), new_row], ignore_index=True)
    else:
        # First trial: write the row directly instead of concatenating with an
        # empty DataFrame (pandas' dtype handling for empty frames in concat is
        # deprecated/ambiguous).
        df = new_row

    df.to_csv(file_name, index=False)



def dict_from_yaml(file_path: str) -> dict:
    """Load a YAML config file (via ``yaml.safe_load``) into a dict.

    ``model.convert_to_TF`` is set to False when the file does not define it.
    """
    with open(file_path) as handle:
        config = yaml.safe_load(handle)

    model_section = config["model"]
    if "convert_to_TF" not in model_section:
        model_section["convert_to_TF"] = False

    return config

def split_indices_and_prep_dataset(
        cfg, subjects, dataset, test_dataset, n_train, n_val, n_test, setting, world_size, n_folds, fold, ncv_i):
    """Split subjects into train/validation/test sets and register them with the dataset(s).

    The split strategy depends on the configuration:

    * ``setting`` in ``["SSL_PRE", "GEN_EMB"]``: all ``subjects`` are used for
      training; validation and test get the placeholder ``np.array([1])``.
    * ``cfg["dataset"]["val_subsample"]`` set: all ``subjects`` are used for
      training and validation indices are loaded from
      ``<dataset path>/indices/<val_subsample>_indices.npy``.
    * ``n_folds == 1``: one stratified train/validation split of ``subjects``.
    * otherwise: stratified K-fold over ``subjects``. Fold ``fold`` is taken,
      the training part is subsampled to ``n_train`` subjects if it is larger,
      and the held-out part is split into validation (``n_val``) and test
      (``n_test``) unless a separate test set is configured.

    If ``cfg["dataset"]["test_subsample"]`` or ``test_dataset`` is set, test
    indices are loaded from ``<dataset path>/indices/<test_subsample>_indices.npy``.

    Returns:
        (train_ind, val_ind, test_ind, dataset, test_dataset, sub_ids), where
        ``sub_ids`` is ``dataset.get_subject_ids(world_size)`` with its
        ``"test"`` entry taken from ``test_dataset`` when one is given.

    Raises:
        ValueError: If ``fold`` is out of range for the K-fold split, or if no
            test split can be determined (``val_subsample`` / ``n_folds == 1``
            without a test subsample or test dataset).
    """
    train_ss = cfg["dataset"]["train_subsample"]
    val_ss = cfg["dataset"]["val_subsample"]
    test_ss = cfg["dataset"]["test_subsample"]
    target = cfg["training"]["target"]
    # Offset the seed per nested-CV repetition so each repetition draws different splits.
    salt = cfg["training"]["random_seed"] + 4999*ncv_i

    to_stratify = get_stratification_vector(dataset, target, n_train, subset=subjects)
    # Set by every branch that defines its own test split; validated below.
    test_ind = None

    if setting in ["SSL_PRE", "GEN_EMB"]:  # Do not subsample and use the complete training set.
        train_ind, val_ind, test_ind = subjects, np.array([1]), np.array([1])

    elif val_ss:  # Validation set is provided manually: skip folding the data.
        train_ind = subjects
        ind_path = os.path.join(cfg['dataset']['path'], 'indices', f"{val_ss}_indices.npy")
        val_ind = np.load(ind_path)

    elif n_folds == 1:  # Skip cross-validation: a single stratified train/val split.
        train_ind, val_ind = train_test_split(subjects, test_size=n_val, stratify=to_stratify,
                                              random_state=9*n_train + salt)

    else:  # Stratified K-fold; random_state=salt makes the folds reproducible.
        skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=salt)
        splits = list(skf.split(np.arange(len(to_stratify)), to_stratify))
        if not 0 <= fold < len(splits):
            raise ValueError(f"fold={fold} is out of range for n_folds={n_folds}.")
        # Positions into `subjects`; converted to subject IDs further below.
        train_ind, val_ind = splits[fold]

        if n_train < len(train_ind):  # If necessary, subsample the training set.
            to_stratify_train = to_stratify[train_ind]
            # Stratification fails for classes with a single member, so relabel
            # each singleton class as a randomly chosen class occurring more than once.
            unique_str, counts = np.unique(to_stratify_train, return_counts=True)
            more_than_once = unique_str[counts > 1]
            only_once = unique_str[counts == 1]
            if more_than_once.size > 0:
                replacement_dict = {key: np.random.choice(more_than_once) for key in only_once}
            else:
                # Nothing to merge into; the stratified split below will fail
                # and fall back to random subsampling.
                replacement_dict = {}
            to_stratify_train = np.array([replacement_dict.get(s, s) for s in to_stratify_train])
            try:
                train_ind, _ = train_test_split(train_ind, train_size=n_train, stratify=to_stratify_train,
                                                random_state=99*fold + 9*n_train + salt)
            except ValueError:
                # Stratified subsampling impossible (e.g. fewer samples than
                # classes): fall back to plain random subsampling.
                train_ind = np.random.choice(train_ind, replace=False, size=n_train)

        # If requested, subsample the test set from the held-out fold (rather
        # than using a predefined hold-out test set).
        if test_ss or test_dataset:
            test_ind = np.array([1])
        else:
            assert (n_val+n_test) <= len(val_ind), "Reduce n_val+n_test; too large for number of folds!"
            val_ind, test_ind = train_test_split(val_ind, train_size=n_val, test_size=n_test, stratify=to_stratify[val_ind],
                                                 random_state=99*fold + 9*n_train + salt)

        # From positions (np.arange) to subject IDs.
        train_ind = subjects[train_ind]
        val_ind = subjects[val_ind]
        test_ind = subjects[test_ind]

    if test_ss or test_dataset:  # A separate test dataset or a predefined test subsample is used.
        ind_path = os.path.join(cfg['dataset']['path'], 'indices', f"{test_ss}_indices.npy")
        test_ind = np.sort(np.load(ind_path))

    if test_ind is None:
        raise ValueError("No test split defined: with val_subsample or n_folds == 1, "
                         "set dataset.test_subsample or provide a test dataset.")

    dataset.set_epoch_indices(train_ind, val_ind, test_ind)
    sub_ids = dataset.get_subject_ids(world_size)

    if test_dataset:
        test_dataset.test_ind = test_ind
        test_dataset.set_epoch_indices(np.arange(1), np.arange(1), test_ind)
        test_sub_ids = test_dataset.get_subject_ids(world_size)
        sub_ids["test"] = test_sub_ids["test"]

    print(len(train_ind), len(val_ind), len(test_ind))

    return train_ind, val_ind, test_ind, dataset, test_dataset, sub_ids

def get_stratification_vector(dataset, target: list, n_train: int, subset: "np.ndarray | None" = None):
    """Return one stratification label (as a string) per subject.

    The integer pathology labels are averaged per subject. Each mean is returned
    as a string (e.g. "0.0", "0.5", "1.0"), ordered by ascending subject ID,
    because pandas' groupby sorts its keys.

    Args:
        dataset: Object exposing row-aligned arrays ``subject_ids`` and ``pathology``.
        target: Currently unused.
        n_train: Currently unused.
        subset: Optional subject IDs. If given and non-empty, only rows of these
            subjects are used. ``None`` (default) or an empty array means all rows.
    """
    subject_ids = dataset.subject_ids
    pathology = dataset.pathology
    if subset is not None and len(subset) > 0:
        matches = np.isin(subject_ids, subset)
        subject_ids = subject_ids[matches]
        pathology = pathology[matches]

    df = pd.DataFrame({"PAT": pathology.astype(int).squeeze(), "subject_id": subject_ids})
    per_subject = df.groupby("subject_id").mean()
    return np.array(per_subject["PAT"].values.astype(str))


def save_cv_results(setting, cfg, ys_true, ys_pred, test_metric, test_loss, hp, n_train, fold, ncv_i):
    """Write the test results of one CV fold to a CSV file.

    The file is ``<results_save_path>/<setting>/<model_name>_ncv_<ncv_i>_fold_<fold>_ntrain_<n_train>.csv``.
    Scalar entries (the two metrics, fold, ncv_i, stringified ``hp``, test
    loss) are repeated on every row. For each column ``i`` of ``ys_true``/``ys_pred``,
    the columns ``ys_true_sub_l<i>`` and ``ys_pred_sub_l<i>`` hold the values row by row.

    Args:
        test_metric: Sequence whose first two entries are stored as
            ``MAE/BACC`` and ``R2/AUC``.
        ys_true, ys_pred: 2-D arrays with the same number of rows.
    """
    results = {
        "MAE/BACC": test_metric[0],
        "R2/AUC": test_metric[1],
        "fold": fold,
        "ncv_i": ncv_i,
        "best_hp": str(hp),
        "test_loss": test_loss
    }
    for i in range(ys_true.shape[1]):
        results[f"ys_true_sub_l{i}"] = ys_true[:, i]
        results[f"ys_pred_sub_l{i}"] = ys_pred[:, i]

    rp = os.path.join(cfg['training']['results_save_path'], setting)
    # exist_ok avoids the check-then-create race when several processes save concurrently.
    os.makedirs(rp, exist_ok=True)

    file_name = f"{cfg['model']['model_name']}_ncv_{ncv_i}_fold_{fold}_ntrain_{n_train}.csv"
    pd.DataFrame(results).to_csv(os.path.join(rp, file_name))


def load_DDP_state_dict(model, path, device, DDP=False):
    """Load the state dict stored at ``path`` into ``model`` and return ``model``.

    ``device`` is forwarded to ``torch.load`` as ``map_location``. The ``DDP``
    flag is accepted for API compatibility but is currently not used.
    """
    model.load_state_dict(torch.load(path, map_location=device))
    return model

def load_data(cfg, setting):
    """Create the dataset for ``setting`` and, if configured, a separate test dataset.

    A separate test dataset is built when ``cfg["dataset"]["test_name"]`` is
    truthy. It uses a deep copy of ``cfg`` whose ``dataset.name`` is replaced
    by ``test_name``.

    Returns:
        (dataset, test_dataset); ``test_dataset`` is None when no separate test
        dataset is configured. For ``setting == "SSL_LIN"`` the dataset is ``[]``.

    Raises:
        ValueError: If ``setting`` is unknown, or if ``setting`` is
            ``"SSL_PRE"``/``"GEN_EMB"`` and the loss function name contains
            neither ``"MIL"`` nor ``"ELM"``.
    """

    def find_correct_dataset(cfg, setting):
        """Return the dataset object matching ``setting`` and the configured loss function."""
        print(cfg["training"]["loss_function"])
        loss_function = cfg["training"]["loss_function"]
        if setting in ["SSL_PRE", "GEN_EMB"]:  # SSL channel-wise Pretraining: [n_epochs*n_channels, n_EEG_samples]
            if "MIL" in loss_function:
                return H5_MIL(cfg, setting)
            if "ELM" in loss_function:
                return H5_ELM(cfg, setting)
            raise ValueError(
                f"Setting {setting!r} requires a loss function containing 'MIL' or 'ELM', got {loss_function!r}.")
        if setting in ["SSL_FT", "SV"]:  # Finetune or Supervise: [n_epochs, n_channels, n_EEG_samples]
            return TUAB_H5(cfg, setting)
        if setting in ["SSL_NL"]:  # Nonlinear eval: [n_epochs, n_channels, n_embedding_samples]
            return TUAB_H5_features(cfg, setting)
        if setting in ["SSL_LIN"]:
            return []
        raise ValueError(f"Unknown setting {setting!r}.")

    dataset = find_correct_dataset(cfg, setting)

    # Check whether a separate test dataset is used.
    if cfg["dataset"].get("test_name"):
        cfg_test = deepcopy(cfg)
        cfg_test["dataset"]["name"] = cfg_test["dataset"]["test_name"]
        test_dataset = find_correct_dataset(cfg_test, setting)
    else:
        test_dataset = None

    return dataset, test_dataset

def filter_data(x: torch.Tensor, freq: int=200, bands: tuple=((1, 7), (8, 30), (31, 49)), axis: int=1,
                norm_stats: tuple=(0.2, 0.085, 0.045)) -> torch.Tensor:
    """Band-pass filter ``x`` into several frequency bands and stack the results.

    Unoptimized band-pass filtering. Don't use this for more than an experiment.

    Args:
        x: Signal tensor. ``mne.filter.filter_data`` filters along the last axis.
        freq: Sampling frequency in Hz passed to MNE.
        bands: Sequence of (low, high) cut-off frequencies in Hz; one filtered
            copy of ``x`` is produced per band.
        axis: Output axis along which the per-band copies are stacked.
        norm_stats: Per-band divisor applied to each filtered copy; needs at
            least one entry per band.

    Returns:
        Tensor with the dtype and device of ``x`` and one extra axis of size
        ``len(bands)`` inserted at ``axis``.
    """
    # MNE works on float64 NumPy arrays; detach/cpu also allow tensors that
    # require grad or live on a GPU.
    x_np = x.detach().cpu().numpy().astype(np.float64)
    f = [mne.filter.filter_data(x_np, freq, low, high, verbose="critical", n_jobs=1) / norm_stats[i]
         for i, (low, high) in enumerate(bands)]
    return torch.from_numpy(np.stack(f, axis=axis)).to(device=x.device, dtype=x.dtype)

def set_seeds(random_seed):
    """Seed PyTorch, NumPy's global RNG and Python's ``random`` with ``random_seed``."""
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(random_seed)
    
