# System imports
from pathlib import Path
from typing import Tuple

# Data science imports
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold
from sklearn.utils import shuffle
from catboost import CatBoostClassifier

from src.utils import set_seed, save_results_to_pickle
from src import logger
 
from src.dataset_inference.preprocessing import create_df, create_grouped_split, clean_outliers, create_split, scale_data
from src.dataset_inference.metrics import get_p_values, rank_candidates
from src.dataset_inference.attacks import get_mia

# Other imports
import glob

import traceback
import numpy as np

from pathlib import Path
import os
import glob
import hydra
from omegaconf import DictConfig, OmegaConf
from dotenv import load_dotenv
import numpy as np
from src import logger
from copy import deepcopy
from src import get_roc_auc
import shap
import warnings
import torch
from joblib import Parallel, delayed

# Initialize: these statements run once, at import time.
# Load variables from a .env file (if any) into the process environment.
load_dotenv()
# Project helper from src.utils, called with a fixed seed
# (presumably for reproducibility; which RNGs it covers is defined there).
set_seed(11)

# Silence FutureWarning messages that originate from sklearn modules.
warnings.filterwarnings("ignore", category=FutureWarning, module="sklearn")


def get_feature_importance(model):
    """Return per-feature importances for a fitted estimator, or ``None``.

    The estimator is probed for the following attributes, in this order:

    * ``feature_importances_`` - returned unchanged.
    * ``coef_`` - absolute coefficients (first row when 2-D).
    * ``theta_`` / ``var_`` - absolute difference between rows 0 and 1 of
      ``theta_`` divided by the square root of the summed rows 0 and 1 of
      ``var_``.
    * ``calibrated_classifiers_`` - average of the importances of each
      contained classifier (computed recursively).
    * ``estimator`` - importances of the wrapped estimator (recursively).

    When none of these attributes exists a warning is logged and ``None``
    is returned.
    """
    if hasattr(model, "feature_importances_"):
        return model.feature_importances_

    if hasattr(model, "coef_"):
        weights = model.coef_
        if weights.ndim > 1:
            weights = weights[0]
        return np.abs(weights)

    if hasattr(model, 'theta_'):
        mean_gap = np.abs(model.theta_[0] - model.theta_[1])
        spread = np.sqrt(model.var_[0] + model.var_[1])
        return mean_gap / spread

    if hasattr(model, "calibrated_classifiers_"):
        per_member = [
            get_feature_importance(member)
            for member in model.calibrated_classifiers_
        ]
        return sum(per_member, 0) / len(model.calibrated_classifiers_)

    if hasattr(model, "estimator"):
        return get_feature_importance(model.estimator)

    logger.warning(f'No feature importance found: {model}')
    return None

def train_and_evaluate(model, X_train, y_train, X_test, y_test, do_shap=False, features_names=None):
    """Fit a fresh copy of ``model`` and score the test split.

    Args:
        model: Unfitted estimator exposing ``fit`` and ``predict_proba``. It
            is deep-copied, so the passed instance is never fitted.
        X_train, y_train: Training features and labels (torch tensors are
            converted to numpy arrays).
        X_test: Test features (torch tensors are converted to numpy arrays).
        y_test: Unused; kept so existing callers keep working.
        do_shap: When true, summarise importances with ``shap.Explainer``
            instead of the estimator's own attributes.
        features_names: Names paired with the importances. When ``None``,
            generic ``feature_<i>`` names are generated.

    Returns:
        ``(feature_importance, y_pred)``. ``y_pred`` is the last column of
        ``predict_proba(X_test)``. ``feature_importance`` is a dict, or
        ``None`` when the estimator exposes no importance information.
    """
    def _as_numpy(data):
        # isinstance also matches Tensor subclasses; detach()/cpu() keep the
        # conversion valid for tensors that track gradients or live off-CPU.
        if isinstance(data, torch.Tensor):
            return data.detach().cpu().numpy()
        return data

    X_train, X_test = _as_numpy(X_train), _as_numpy(X_test)
    y_train = _as_numpy(y_train)

    cur_model = deepcopy(model)
    cur_model.fit(X_train, y_train)
    y_pred = cur_model.predict_proba(X_test)[..., 1]

    if do_shap:
        explainer = shap.Explainer(cur_model, X_train, feature_names=features_names)
        explanation = explainer(X_test)
        explained_names = getattr(explanation, "feature_names", None)
        feature_importance = {
            "mean_abs_values": float(np.mean(np.abs(explanation.values))),
            # feature_names may be absent or None; report None in both cases.
            "num_features": len(explained_names) if explained_names is not None else None,
        }
        return feature_importance, y_pred

    raw_importance = get_feature_importance(cur_model)
    if raw_importance is None:
        # get_feature_importance has already logged a warning; do not crash.
        return None, y_pred

    if features_names is None:
        features_names = [f"feature_{i}" for i in range(len(raw_importance))]

    named = dict(zip(features_names, map(float, raw_importance)))
    # Most important feature first.
    feature_importance = dict(sorted(named.items(), key=lambda item: item[1], reverse=True))

    return feature_importance, y_pred


def perform_dataset_inference(
    model,
    X,
    y,
    set_ids,
    n_splits,
    features_names,
) -> tuple:
    """Cross-validate ``model`` over groups and compute inference statistics.

    Groups (distinct values of ``set_ids``) are split with ``KFold`` so that
    all samples of a group land in the same fold. Every fold is scaled with
    ``scale_data``, the training part is shuffled, and the out-of-fold
    probabilities are collected. When ``set_ids`` holds a single distinct
    value, every sample is treated as its own group for the split and the
    data is handled as one group afterwards.

    Args:
        model: Unfitted estimator; a copy is fitted per fold.
        X: Feature matrix, indexable with a boolean mask.
        y: Labels (1 and 0 are the values compared against below).
        set_ids: Group id of each sample.
        n_splits: Maximum number of folds (capped by the number of groups).
        features_names: Feature names forwarded to ``train_and_evaluate``.

    Returns:
        A 6-tuple ``(auc, fpr, tpr, feature_importance_list, p_values,
        rankings_list)``. ``auc``, ``fpr`` and ``tpr`` come from
        ``get_roc_auc(y, y_oof)``, ``p_values`` from ``get_p_values`` and
        ``feature_importance_list`` has one entry per fold.
    """
    unique_groups = np.unique(set_ids)
    feature_importance_list = []

    copyright_traps = len(unique_groups) == 1
    if copyright_traps:
        # Only one group: split on individual samples instead.
        set_ids = np.arange(len(set_ids))
        unique_groups = set_ids

    skf = KFold(n_splits=min(n_splits, len(unique_groups)), shuffle=True, random_state=42)

    y_oof = np.zeros(len(y))
    for fold, (train_idx, test_idx) in enumerate(skf.split(unique_groups)):
        # KFold indexes into unique_groups; expand the chosen groups to samples.
        test_mask = np.isin(set_ids, unique_groups[test_idx])
        train_mask = ~test_mask
        X_train, X_test = X[train_mask], X[test_mask]
        y_train, y_test = y[train_mask], y[test_mask]

        logger.info(f"Fold {fold}: X_train {X_train.shape}, y_train {y_train.shape}, X_test {X_test.shape}, y_test {y_test.shape}")

        X_train, X_test, _ = scale_data(X_train, X_test, scaler_type="standard")

        X_train, y_train = shuffle(X_train, y_train, random_state=42)

        feature_importance, y_pred = train_and_evaluate(model, X_train, y_train, X_test, y_test, features_names=features_names, do_shap=False)
        feature_importance_list.append(feature_importance)

        y_oof[test_mask] = y_pred

    if copyright_traps:
        # Back to a single group for the ranking / p-value computations.
        unique_groups = [0]
        set_ids = np.zeros_like(set_ids)

    logger.info("Computing ranks")
    rankings_list = []
    for g in unique_groups:
        group_idx = np.where(set_ids == g)[0]
        y_group = y[group_idx]
        y_group_prob = y_oof[group_idx]
        # NOTE(review): max_rank is overwritten on every iteration, so the
        # value passed to get_p_values is that of the last group - confirm
        # this is intended.
        ranks, max_rank = rank_candidates(y_group, y_group_prob)
        rankings_list += ranks

    if not copyright_traps:
        # Per group, keep at most as many label-0 scores as there are
        # label-1 samples in that group.
        y_0 = []
        for group_id in np.unique(set_ids):
            cnt_1 = int(y[(set_ids==group_id)&(y==1)].sum())
            y_0 += list(y_oof[(y==0)&(set_ids==group_id)][:cnt_1])
        y_0 = np.array(y_0)
        y_1 = y_oof[y==1]
    else:
        y_0 = y_oof[y==0]
        y_1 = y_oof[y==1]

    auc, fpr, tpr = get_roc_auc(y, y_oof)
    p_values = get_p_values([10, 20, 30, 40, 50, 60, 70, 80, 90, 100], y_1, y_0, ranks=rankings_list, max_rank=max_rank)

    return auc, fpr, tpr, feature_importance_list, p_values, rankings_list


def run_eval(
    X,
    y,
    set_ids,
    n_splits,
    features_names,
    pickle_path,
    info_to_log,
    model_params,
    ):
    """Build the classifier, run dataset inference and pickle the results.

    Args:
        X, y, set_ids, n_splits, features_names: Forwarded unchanged to
            ``perform_dataset_inference``.
        pickle_path: Destination of the result pickle; its parent directory
            is created when missing.
        info_to_log: Dict of run metadata. It is updated in place with the
            keys ``auc``, ``fpr``, ``tpr``, ``feature_importance_list``,
            ``p_values`` and ``rankings_list`` and then saved.
        model_params: Dict with a ``model_type`` entry
            (``"LogisticRegression"`` or ``"CatBoostClassifier"``); all other
            entries are passed to the estimator constructor. The dict itself
            is left untouched.

    Raises:
        KeyError: ``model_params`` has no ``model_type`` entry.
        ValueError: ``model_type`` is not one of the supported names.
    """
    Path(pickle_path).parent.mkdir(parents=True, exist_ok=True)

    # Work on a copy: removing "model_type" from the caller's dict would make
    # a second call with the same dict fail.
    params = dict(model_params)
    model_type = params.pop('model_type')

    if model_type == "LogisticRegression":
        model = LogisticRegression(**params, random_state = 42, n_jobs=64, max_iter=1000, solver='newton-cholesky', class_weight='balanced')
    elif model_type == "CatBoostClassifier":
        model = CatBoostClassifier(**params, random_state = 42, verbose=500, thread_count=64, l2_leaf_reg=1000.0)
    else:
        raise ValueError(f"Unsupported model type: {model_type}")

    auc, fpr, tpr, feature_importance_list, p_values, rankings_list = perform_dataset_inference(
        model=model,
        X=X,
        y=y,
        set_ids=set_ids,
        n_splits=n_splits,
        features_names=features_names,
    )
    info_to_log.update({
        "auc": auc,
        "fpr": fpr,
        "tpr": tpr,
        "feature_importance_list": feature_importance_list,
        "p_values": p_values,
        "rankings_list": rankings_list,
    })
    logger.info(f"Saving results to {pickle_path}")

    save_results_to_pickle(pickle_path, info_to_log)

def merge_dicts(dicts):
    """Concatenate the values of several dicts key by key.

    The keys of the first dict decide which keys are merged; every other
    dict must contain them as well. Values that are torch tensors in the
    first dict are joined with ``torch.cat``, everything else with
    ``np.concatenate``.

    Args:
        dicts: Sequence of dicts holding array-like values.

    Returns:
        A new dict with one concatenated value per key, or an empty dict
        when ``dicts`` is empty.
    """
    if not dicts:
        # Nothing to merge; avoids an IndexError on dicts[0].
        return {}

    merged = {}
    for name, first_value in dicts[0].items():
        pieces = [entry[name] for entry in dicts]
        # isinstance (not type ==) so Tensor subclasses are handled too.
        if isinstance(first_value, torch.Tensor):
            merged[name] = torch.cat(pieces)
        else:
            merged[name] = np.concatenate(pieces)
    return merged

def main(model_list, files, files_nonmembers, number_of_nonmembers, number_of_members, key, filter_features, n_splits, clf, outliers, split_data, out_path, blind_models=None):
    """Run one dataset-inference experiment and pickle its results.

    The run is skipped when its result file already exists in ``out_path``,
    when ``files`` is empty, or when no non-member data could be loaded.

    Args:
        model_list: Names of the models whose scores are used as features.
        files: Pickle paths handed to ``create_df``.
        files_nonmembers: Extra pickle paths for non-member scores. When
            non-empty, the scores of every model are merged into a single
            group ``(0, 'all', 'all', model_name)``.
        number_of_nonmembers: Forwarded to ``create_grouped_split``; also the
            number of samples labelled 1 when ``split_data`` is ``'members'``
            or ``'nonmembers'``.
        number_of_members: Upper bound on the number of label-1 samples kept.
        key: Experiment key, used in the run name and stored in the results.
        filter_features: Predicate on an attack name selecting the features.
        n_splits: Number of cross-validation folds.
        clf: ``'lr'`` or ``'catboost'``.
        outliers: Outlier mode forwarded to ``clean_outliers``.
        split_data: ``'members'`` / ``'nonmembers'`` restrict the data to one
            class and relabel it randomly; any other value keeps the data.
        out_path: ``pathlib.Path`` of the output directory.
        blind_models: Unused. Accepted because the experiment dicts built by
            ``enumerate_experiments`` carry this key and are passed here as
            keyword arguments.
    """
    run_name = f"{key}.{clf.lower()}.{outliers}.{'_'.join(model_list).replace('/', '_')}.{number_of_members}.{number_of_nonmembers}.{n_splits}.{split_data}"
    results_path = out_path / f"{run_name}_results.pkl"

    out_path.mkdir(parents=True, exist_ok=True)

    if results_path.exists():
        logger.warning(100*'-')
        logger.warning(100*'!')
        logger.warning(f"Results already exist for {run_name}")
        logger.warning(100*'!')
        logger.warning(100*'-')
        return
    logger.info(f"Starting dataset inference ({key}, {clf}, {outliers}). Files: {len(files)}")
    if len(files) == 0:
        logger.warning(f"No files found for {run_name}")
        return

    # NOTE(review): create_df is unpacked into two values here but into three
    # a few lines below - one of the two call sites looks wrong; confirm
    # against create_df's actual return value.
    members_scores_string, nonmembers_scores_string = create_df(files)
    if len(files_nonmembers) > 0:
        logger.info(f"Nonmembers files: {len(files_nonmembers)}")
        _, nonmembers_scores_string, _ = create_df(files_nonmembers)

        # Keys are (group_id, secret_type, dataset_name, model_name) tuples.
        # Collapse all groups of a model into one group. The key's model is
        # compared with the model currently being built; reusing the name
        # `model_name` inside the inner comprehension would shadow the outer
        # loop variable and give every model the data of all models.
        members_scores_string = {
            (0, 'all', 'all', model_name): merge_dicts([
                scores
                for (_, _, _, key_model), scores in members_scores_string.items()
                if key_model == model_name
            ])
            for model_name in model_list
        }

        nonmembers_scores_string = {
            (0, 'all', 'all', model_name): merge_dicts([
                scores
                for (_, _, _, key_model), scores in nonmembers_scores_string.items()
                if key_model == model_name
            ])
            for model_name in model_list
        }

    if len(nonmembers_scores_string) == 0:
        logger.warning(f"No data found for {run_name}")
        return

    logger.info(f"dataframe: {len(nonmembers_scores_string)}") 

    dataset = create_grouped_split(model_list, members_scores_string, nonmembers_scores_string, filter_features, number_of_nonmembers)

    logger.info(f"Dataset: {len(dataset)} with {len(dataset.keys())} groups")

    features_names = [f'{model}_{attack}' for model in model_list for attack in get_mia(next(iter(members_scores_string.values()))).keys() if filter_features(attack)]

    x, y, set_ids = create_split(dataset)
    logger.info(f"Initial x: {x.shape}, y: {y.shape}, set_ids: {set_ids.shape}")

    # Split data: keep a single class and assign random labels within it.
    if split_data == 'members':
        x = x[y == 1]
        y = (torch.randperm(len(x)) < number_of_nonmembers).long()
        set_ids = torch.zeros_like(y)
    elif split_data == 'nonmembers':
        x = x[y == 0]
        y = (torch.randperm(len(x)) < number_of_nonmembers).long()
        set_ids = torch.zeros_like(y)

    # remove duplicates (keeps the first occurrence of every distinct row)
    _, unique_indices = np.unique(x, axis=0, return_index=True)
    x = x[unique_indices]
    y = y[unique_indices]
    set_ids = set_ids[unique_indices]
    logger.info(f"After cleaning and removing duplicates. x: {x.shape}, y: {y.shape}, set_ids: {set_ids.shape}")
    logger.info(f"[Before limit] Members: {(y==1).long().sum()}, Groups: {len(torch.unique(set_ids))}")

    # limit the number of members
    unique_ids = np.unique(set_ids, axis=0)
    if len(unique_ids) == 1:
        # Single group: keep a random subset of the members, all non-members.
        mask = y == 1
        order = np.random.permutation(len(y[mask]))[:number_of_members]
        x = torch.cat([x[mask][order], x[~mask]])
        set_ids = torch.cat([set_ids[mask][order], set_ids[~mask]])
        y = torch.cat([y[mask][order], y[~mask]])
        cnt_members = y.sum()
    else:
        unique_ids = shuffle(unique_ids)
        if y.sum() > number_of_members:
            # Add whole groups in random order while they still fit.
            cnt_members = 0
            pos = 0
            new_x = []
            new_y = []
            new_set_ids = []
            while cnt_members < number_of_members and pos < len(unique_ids):
                mask = set_ids == unique_ids[pos]
                new_members = (y[mask] == 1).sum()
                if cnt_members + new_members <= number_of_members:
                    cnt_members += new_members
                    new_x.append(x[mask])
                    new_y.append(y[mask])
                    new_set_ids.append(set_ids[mask])
                pos += 1
            x = torch.cat(new_x)
            y = torch.cat(new_y)
            set_ids = torch.cat(new_set_ids)
        else:
            cnt_members = y.sum()

    if cnt_members != number_of_members:
        logger.warning(f"Number of members is {cnt_members}, but should be {number_of_members}")

    logger.info(f"Members: {(y==1).long().sum()}, Groups: {len(torch.unique(set_ids))}")

    # remove outliers
    x = clean_outliers(x, remove_frac=0.05, outliers=outliers)

    clf_name = {
        'lr': 'LogisticRegression',
        'catboost': 'CatBoostClassifier',
    }[clf]

    run_eval(
        X=x,
        y=y,
        set_ids=set_ids,
        n_splits=n_splits,
        features_names=features_names,
        info_to_log={
            "model_list": model_list,
            "key": key,
            "number_of_nonmembers": number_of_nonmembers,
            "clf": clf,
            "outliers": outliers,
        },
        pickle_path=results_path,
        model_params={
            "model_type": clf_name,
        },
    )


def enumerate_experiments(cfg):
    """Yield one keyword-argument dict per experiment described by ``cfg``.

    An experiment is produced for every combination of classifier, outlier
    mode, config entry and secret type for which at least one member file
    matches. Each yielded dict holds the keys ``key``, ``model_list``,
    ``blind_models``, ``number_of_nonmembers``, ``number_of_members``,
    ``clf``, ``files``, ``files_nonmembers``, ``out_path``, ``n_splits``,
    ``outliers``, ``filter_features`` and ``split_data``.

    Args:
        cfg: Configuration object whose ``experiment`` node provides
            ``classifiers``, ``outliers``, ``configs``, ``secret_types``,
            ``file_dir``, ``out_dir`` and ``n_splits``.
    """
    # The directory listing does not depend on any loop variable: read it once.
    all_pickles = glob.glob(os.path.join(cfg.experiment.file_dir, '*.pkl'))
    secret_types = cfg.experiment.secret_types

    # Iterate over classifiers and experiment configs
    for clf in cfg.experiment.classifiers:
        for outliers in cfg.experiment.outliers:
            logger.info(f'Running for clf: {clf}, outliers: {outliers}')
            for config in cfg.experiment.configs:
                # SECURITY: the filter expressions below are evaluated with
                # eval(), i.e. the configuration can execute arbitrary code.
                # Only run this with trusted configuration files.
                filter_func = eval(f"{config['file_filter_condition']}")
                if 'file_filter_condition_nonmembers' in config:
                    filter_func_nonmembers = eval(f"{config['file_filter_condition_nonmembers']}")
                else:
                    filter_func_nonmembers = lambda x: False

                if 'filter_features' in config:
                    filter_features = eval(f"{config['filter_features']}")
                else:
                    filter_features = lambda x: True

                split_data = config.get('split_data', 'none')
                assert split_data in ['none', 'members', 'nonmembers'], f"split_data should be one of ['none', 'members', 'nonmembers']. Given: {split_data}"

                blind_models = config.get('blind_baseline', [])

                # Keep files whose name contains " <model>.pkl" for one of the
                # requested (or blind-baseline) models, case-insensitively.
                model_suffixes = [
                    ' ' + model.split('/')[-1].lower() + '.pkl'
                    for model in config['model_list'] + blind_models
                ]
                all_files = [
                    file_path
                    for file_path in all_pickles
                    if any(suffix in os.path.basename(file_path).lower() for suffix in model_suffixes)
                ]

                for secret_type in secret_types:
                    key = f"{config['key']}_{secret_type}"
                    secret_type = secret_type.split(',')
                    # 'all' selects every file; the same rule is applied to
                    # member and non-member files.
                    wants_all = 'all' in secret_type
                    files = [
                        file_path
                        for file_path in all_files
                        if filter_func(file_path)
                        and (wants_all or any(s.lower() in file_path.lower() for s in secret_type))
                    ]
                    files_nonmembers = [
                        file_path
                        for file_path in all_files
                        if filter_func_nonmembers(file_path)
                        and (wants_all or any(s.lower() in file_path.lower() for s in secret_type))
                    ]
                    if len(files) == 0:
                        logger.warning(f"No files found for {key} {secret_type}")
                        continue

                    # Guard only the construction of the experiment (e.g. a
                    # config entry missing a required key); the yield itself
                    # stays outside the try block.
                    try:
                        experiment = dict(
                            key=key,
                            model_list=config['model_list'],
                            blind_models=blind_models,
                            number_of_nonmembers=config['number_of_nonmembers'],
                            number_of_members=config['number_of_members'],
                            clf=clf,
                            files=sorted(files),
                            files_nonmembers=sorted(files_nonmembers),
                            out_path=Path(cfg.experiment.out_dir),
                            n_splits=cfg.experiment.n_splits,
                            outliers=outliers,
                            filter_features=filter_features,
                            split_data=split_data,
                        )
                    except Exception as e:
                        logger.error(f"Error processing {config['key']} {clf}: {e}")
                        traceback.print_exc()
                        continue

                    yield experiment


   
@hydra.main(config_path="config", config_name="gen_config", version_base=None)
def run_experiment(cfg: DictConfig):
    """Enumerate all configured experiments and run them with joblib threads.

    Args:
        cfg: Hydra configuration; ``cfg.experiment.n_jobs`` (default 1)
            limits the number of worker threads.
    """
    list_experiments = list(enumerate_experiments(cfg))
    if not list_experiments:
        # min(n_jobs, 0) would be 0, which joblib.Parallel rejects.
        logger.warning("No experiments to run.")
        return

    n_jobs = min(cfg.experiment.get('n_jobs', 1), len(list_experiments))
    Parallel(n_jobs=n_jobs, verbose=15, backend='threading')(delayed(main)(**exp) for exp in list_experiments)
    logger.info("Experiment finished.")


# Script entry point: run_experiment is wrapped by hydra.main, which supplies
# the configuration object when the function is called without arguments.
if __name__ == "__main__":
    run_experiment()
