import pandas as pd

from eval_pipeline.utils import metric_utils, get_intervention_pairs, get_train_singles_and_pairs
from eval_pipeline.explainers import ExclusiveExplainer, InclusiveExplainer

def cebab_pipeline(model, explainer, train_dataset, dev_dataset, seed, k, dataset_type='5-way', shorten_model_name=False, train_setting='exclusive', approximate=False):
    """Fit one model and one explainer, then build the CEBaB result tables for the pair.

    Args:
        model: object exposing ``fit(train_dataset)`` and
            ``predict_proba(dataset)``; the latter must return a
            ``(predictions, report)`` pair. For the dev set the report is
            indexed as ``report['macro avg']['f1-score']`` and
            ``report['accuracy']``.
        explainer: object exposing ``fit(...)`` and ``estimate_icace(pairs)``.
        train_dataset: for ``train_setting='exclusive'`` it is handed to the
            model as-is; for ``'inclusive'`` / ``'approximate'`` it must be
            indexable with ``[0]`` and ``[1]``, both of which are forwarded to
            ``get_train_singles_and_pairs``.
        dev_dataset: dev data. A ``'prediction'`` column is written onto it in
            place.
        seed: forwarded to ``get_train_singles_and_pairs`` (pair settings only).
        k: forwarded to ``get_train_singles_and_pairs`` (pair settings only).
        dataset_type: forwarded to the pair-building helpers.
        shorten_model_name: if true, label the model with
            ``str(model).split('.')[0]`` instead of ``str(model)``.
        train_setting: one of ``'exclusive'``, ``'inclusive'``,
            ``'approximate'``.
        approximate: forwarded to ``get_train_singles_and_pairs``.

    Returns:
        A 7-tuple ``(ATE, CEBaB_metrics, CEBaB_metrics_per_aspect_direction,
        CEBaB_metrics_per_aspect, CaCE_per_aspect_direction, ACaCE_per_aspect,
        performance_report)`` of pandas objects.

    Raises:
        ValueError: if ``train_setting`` is not one of the supported values.
    """
    # Fail fast: an unknown setting used to fall through every branch below,
    # which left the explainer unfitted before it was asked for explanations.
    uses_train_pairs = train_setting in ('inclusive', 'approximate')
    if not uses_train_pairs and train_setting != 'exclusive':
        raise ValueError(
            f"Unknown train_setting {train_setting!r}; "
            "expected 'exclusive', 'inclusive' or 'approximate'.")

    # TODO: add inclusive
    ## k training pairs (sample or get them from a pre-loaded sampled file?)
    ### How? k pairs == 2*k samples, with a maximum of k u's
    ## n-k training singles
    if uses_train_pairs:
        # NOTE: this can be moved to an outer loop for speed optimization.
        # NOTE: this should be done before the runs and just saved in some files
        # TODO: approx true
        train_dataset, train_pairs_dataset = get_train_singles_and_pairs(
            train_dataset[0], train_dataset[1], seed, k, dataset_type=dataset_type, approximate=approximate)

    # NOTE: we will only work with models that are fitted
    model.fit(train_dataset)

    # get predictions on train and dev
    train_predictions, _ = model.predict_proba(train_dataset)
    dev_predictions, dev_report = model.predict_proba(dev_dataset)

    if uses_train_pairs:
        # Attach the model predictions to the singles by id, and to both
        # sides of every training pair via the suffixed id columns.
        predictions = pd.DataFrame(
            data=zip(train_dataset['id'].to_numpy(), train_predictions), columns=['id', 'prediction'])

        train_dataset = train_dataset.merge(predictions, on='id')

        predictions_base = predictions.rename(lambda x: x+'_base', axis=1)
        predictions_counterfactual = predictions.rename(lambda x: x+'_counterfactual', axis=1)

        train_pairs_dataset = train_pairs_dataset.merge(predictions_base, on='id_base')
        train_pairs_dataset = train_pairs_dataset.merge(predictions_counterfactual, on='id_counterfactual')

    # append predictions to the dev set (mutates the frame passed in)
    dev_dataset['prediction'] = list(dev_predictions)

    # get intervention pairs
    # TODO: approx false
    pairs_dataset = get_intervention_pairs(dev_dataset, dataset_type=dataset_type)  # TODO why is the index not unique here?

    # fit explainer
    # TODO: add inclusive
    if uses_train_pairs:
        if isinstance(explainer, ExclusiveExplainer):
            explainer.fit(train_dataset, train_predictions, model, dev_dataset=None)
        else:
            explainer.fit(train_pairs_dataset, train_dataset, model, dev_dataset=pairs_dataset)
    else:
        explainer.fit(train_dataset, train_predictions, model, dev_dataset=dev_dataset)

    # mitigate possible data leakage: the explainer only gets to see these columns
    allowed_columns = [
        'description_base',
        'review_majority_base',
        'food_aspect_majority_base',
        'service_aspect_majority_base',
        'noise_aspect_majority_base',
        'ambiance_aspect_majority_base',
        'intervention_type',
        'intervention_aspect_base',
        'intervention_aspect_counterfactual',
        'opentable_metadata_base',
        'prediction_base'
    ]

    # Select first, then copy, so only the allowed columns are duplicated.
    pairs_dataset_no_leakage = pairs_dataset[allowed_columns].copy()

    # get explanations and append them to the pairs
    explanations = explainer.estimate_icace(pairs_dataset_no_leakage)
    pairs_dataset['EICaCE'] = explanations

    pairs_dataset = metric_utils._calculate_ite(pairs_dataset)  # effect of crowd-workers on other crowd-workers (no model, no explainer)
    pairs_dataset = metric_utils._calculate_icace(pairs_dataset)  # effect of concept on the model (with model, no explainer)
    pairs_dataset = metric_utils._calculate_estimate_loss(pairs_dataset)  # l2 CEBaB Score (model and explainer)

    # only keep columns relevant for metrics
    CEBaB_metrics_per_pair = pairs_dataset[[
        'intervention_type', 'intervention_aspect_base', 'intervention_aspect_counterfactual', 'ITE', 'ICaCE', 'EICaCE', 'ICaCE-L2', 'ICaCE-cosine', 'ICaCE-normdiff']].copy()
    CEBaB_metrics_per_pair['count'] = 1

    # get CEBaB tables
    metrics = ['count', 'ICaCE', 'EICaCE']

    groupby_aspect_direction = ['intervention_type', 'intervention_aspect_base', 'intervention_aspect_counterfactual']

    CaCE_per_aspect_direction = metric_utils._aggregate_metrics(CEBaB_metrics_per_pair, groupby_aspect_direction, metrics)
    CaCE_per_aspect_direction.columns = ['count', 'CaCE', 'ECaCE']
    CaCE_per_aspect_direction = CaCE_per_aspect_direction.set_index(['count'], append=True)

    # absolute effects, aggregated per intervention type
    ACaCE_per_aspect = metric_utils._aggregate_metrics(CaCE_per_aspect_direction.abs(), ['intervention_type'], ['CaCE', 'ECaCE'])
    ACaCE_per_aspect.columns = ['ACaCE', 'EACaCE']

    CEBaB_metrics_per_aspect_direction = metric_utils._aggregate_metrics(CEBaB_metrics_per_pair, groupby_aspect_direction, ['count', 'ICaCE-L2', 'ICaCE-cosine', 'ICaCE-normdiff'])
    CEBaB_metrics_per_aspect_direction.columns = ['count', 'ICaCE-L2', 'ICaCE-cosine', 'ICaCE-normdiff']
    CEBaB_metrics_per_aspect_direction = CEBaB_metrics_per_aspect_direction.set_index(['count'], append=True)

    CEBaB_metrics_per_aspect = metric_utils._aggregate_metrics(CEBaB_metrics_per_pair, ['intervention_type'], ['count', 'ICaCE-L2', 'ICaCE-cosine', 'ICaCE-normdiff'])
    CEBaB_metrics_per_aspect.columns = ['count', 'ICaCE-L2', 'ICaCE-cosine', 'ICaCE-normdiff']
    CEBaB_metrics_per_aspect = CEBaB_metrics_per_aspect.set_index(['count'], append=True)

    CEBaB_metrics = metric_utils._aggregate_metrics(CEBaB_metrics_per_pair, [], ['ICaCE-L2', 'ICaCE-cosine', 'ICaCE-normdiff'])

    # get ATE table
    ATE = metric_utils._aggregate_metrics(CEBaB_metrics_per_pair, groupby_aspect_direction, ['count', 'ITE'])
    ATE.columns = ['count', 'ATE']

    # add model and explainer information
    if shorten_model_name:
        model_name = str(model).split('.')[0]
    else:
        model_name = str(model)

    # deal with some idiosyncrasies: checkpoint-style names are replaced by a
    # label parsed out of the model path
    if "checkpoint" in model_name:
        model_name = model.model_path.split('__')[-1].split('/')[0]

    explainer_name = str(explainer)

    # Columns that do not involve the explainer (CaCE / ACaCE) get an empty
    # explainer level; every other column is tagged (model, explainer, metric).
    CaCE_per_aspect_direction.columns = pd.MultiIndex.from_tuples(
        [(model_name, explainer_name, col) if col != 'CaCE' else (model_name, '', col) for col in CaCE_per_aspect_direction.columns])
    ACaCE_per_aspect.columns = pd.MultiIndex.from_tuples(
        [(model_name, explainer_name, col) if col != 'ACaCE' else (model_name, '', col) for col in ACaCE_per_aspect.columns])
    CEBaB_metrics_per_aspect_direction.columns = pd.MultiIndex.from_tuples(
        [(model_name, explainer_name, col) for col in CEBaB_metrics_per_aspect_direction.columns])
    CEBaB_metrics_per_aspect.columns = pd.MultiIndex.from_tuples(
        [(model_name, explainer_name, col) for col in CEBaB_metrics_per_aspect.columns])
    CEBaB_metrics.index = pd.MultiIndex.from_product([[model_name], [explainer_name], CEBaB_metrics.index])

    # performance report of the model on the dev set
    performance_report_index = ['macro-f1', 'accuracy']
    performance_report_data = [dev_report['macro avg']['f1-score'], dev_report['accuracy']]
    performance_report_col = [model_name]
    performance_report = pd.DataFrame(data=performance_report_data, index=performance_report_index, columns=performance_report_col)

    return ATE, CEBaB_metrics, CEBaB_metrics_per_aspect_direction, CEBaB_metrics_per_aspect, CaCE_per_aspect_direction, ACaCE_per_aspect, performance_report

def run_pipelines(models, explanators, train, dev, seed, k, dataset_type='5-way', shorten_model_name=False, train_setting='exclusive', approximate=False):
    """Run ``cebab_pipeline`` for every (model, explainer) pair and merge the tables.

    ``models`` and ``explanators`` are walked in lockstep with ``zip``, so any
    surplus items in the longer of the two are ignored. Each run receives
    fresh copies of the train and dev data.

    Args:
        models: iterable of models, one per run.
        explanators: iterable of explainers, paired positionally with ``models``.
        train: training data. Copied via ``train.copy()`` in the exclusive
            setting; otherwise ``train[0]`` and ``train[1]`` are copied
            separately and passed on as a 2-tuple.
        dev: dev data; copied via ``dev.copy()`` for every run.
        seed, k, dataset_type, shorten_model_name, approximate: forwarded
            unchanged to ``cebab_pipeline``.
        train_setting: one of ``'exclusive'``, ``'inclusive'``,
            ``'approximate'``.

    Returns:
        A 7-tuple ``(ATE, CEBaB_metrics, CEBaB_per_aspect_direction,
        CEBaB_per_aspect, CaCE_per_aspect_direction, ACaCE_per_aspect,
        performance_report)``. ``ATE`` and ``performance_report`` are taken
        from the first run only; the remaining tables are concatenated over
        all runs.

    Raises:
        ValueError: if ``train_setting`` is not supported, or if there is no
            (model, explainer) pair to run.
    """
    # Validate once up front; previously an unknown setting left
    # `train_dataset` unbound inside the loop.
    if train_setting not in ('exclusive', 'inclusive', 'approximate'):
        raise ValueError(
            f"Unknown train_setting {train_setting!r}; "
            "expected 'exclusive', 'inclusive' or 'approximate'.")

    # run all (model, explainer) pairs
    runs = []
    for model, explainer in zip(models, explanators):
        print(f'Now running {explainer}')
        if train_setting == 'exclusive':
            train_dataset = train.copy()
        else:
            train_dataset = (train[0].copy(), train[1].copy())
        dev_dataset = dev.copy()

        runs.append(cebab_pipeline(
            model, explainer, train_dataset, dev_dataset, seed, k, dataset_type=dataset_type,
            shorten_model_name=shorten_model_name, train_setting=train_setting, approximate=approximate))

    if not runs:
        raise ValueError('run_pipelines needs at least one (model, explainer) pair.')

    # transpose: one sequence per returned table instead of one tuple per run
    (results_ATE,
     results_CEBaB_metrics,
     results_CEBaB_metrics_per_aspect_direction,
     results_CEBaB_metrics_per_aspect,
     results_CaCE_per_aspect_direction,
     results_ACaCE_per_aspect,
     results_performance_report) = (list(column) for column in zip(*runs))

    # concat the results
    final_ATE = results_ATE[0]
    final_CEBaB_metrics = pd.concat(results_CEBaB_metrics, axis=0)
    final_CEBaB_per_aspect_direction = pd.concat(results_CEBaB_metrics_per_aspect_direction, axis=1)
    final_CEBaB_per_aspect = pd.concat(results_CEBaB_metrics_per_aspect, axis=1)
    final_CaCE_per_aspect_direction = pd.concat(results_CaCE_per_aspect_direction, axis=1)
    final_ACaCE_per_aspect = pd.concat(results_ACaCE_per_aspect, axis=1)
    # NOTE(review): only the first run's report is returned, so other models'
    # dev performance is dropped when several distinct models are passed — confirm intended.
    final_report = results_performance_report[0]

    # drop duplicate ICaCE columns (identical column labels produced by several runs)
    final_CaCE_per_aspect_direction = final_CaCE_per_aspect_direction.loc[:, ~final_CaCE_per_aspect_direction.columns.duplicated()]
    final_ACaCE_per_aspect = final_ACaCE_per_aspect.loc[:, ~final_ACaCE_per_aspect.columns.duplicated()]

    return final_ATE, final_CEBaB_metrics, final_CEBaB_per_aspect_direction, final_CEBaB_per_aspect, final_CaCE_per_aspect_direction, final_ACaCE_per_aspect, final_report
