import pandas as pd

from utils.constants import CEBAB_CONCEPTS


def create_df_factual_counterfactual(cebab, cebab_counterfactual, concepts=None,
                                     counterfactual_id_col=None):
    """Collect counterfactual predictions for every edited example, keyed by its text.

    For each row of ``cebab`` whose ``edit_type`` names a concept, gather the
    ``prediction`` values of all rows of ``cebab_counterfactual`` whose id
    equals ``str(row['original_id'])``, together with the row's label for the
    edited concept.

    The text and label columns depend on the layout of ``cebab``: if it has a
    ``text`` column, text is read from ``text`` and the label from the column
    named after the concept; otherwise text is read from ``description`` and
    the label from ``<concept>_aspect_majority``.

    Args:
        cebab: Source examples. Must contain ``edit_type``, ``original_id`` and
            the text/label columns described above.
        cebab_counterfactual: Counterfactual rows with a ``prediction`` column
            and an id column (see ``counterfactual_id_col``).
        concepts: Concept names every text entry is pre-populated with (as
            empty lists). Defaults to ``CEBAB_CONCEPTS``.
        counterfactual_id_col: Column of ``cebab_counterfactual`` compared
            against ``str(original_id)``. Defaults to ``'original_id'`` when
            that column exists, otherwise ``'id'``.

    Returns:
        A DataFrame indexed by text with one list-valued column per concept,
        a ``'<concept> label'`` column for each edited concept (NaN for texts
        without that edit), and an ``original_text`` column copying the index.
    """
    if concepts is None:
        concepts = CEBAB_CONCEPTS
    if counterfactual_id_col is None:
        # A composite id such as '<a>_<b>' can never equal a plain original_id
        # string, so prefer an explicit 'original_id' column when provided.
        counterfactual_id_col = ('original_id'
                                 if 'original_id' in cebab_counterfactual.columns
                                 else 'id')

    text_col = 'text' if 'text' in cebab.columns else 'description'

    # Group predictions by id once instead of re-filtering the whole
    # counterfactual frame for every source row.
    predictions_by_id = {}
    for cf_id, prediction in zip(cebab_counterfactual[counterfactual_id_col],
                                 cebab_counterfactual['prediction']):
        predictions_by_id.setdefault(cf_id, []).append(prediction)

    source_counterfactual_map = {text: {concept: [] for concept in concepts}
                                 for text in cebab[text_col]}

    for _, example in cebab.iterrows():
        intervention = example['edit_type']
        # pandas >= 2.0 read_csv parses the string 'None' as NaN by default, so
        # a missing edit_type must also mean "no intervention".
        if pd.isna(intervention) or intervention == 'None':
            continue
        entry = source_counterfactual_map[example[text_col]]
        # Copy so texts sharing an original_id never share one list object.
        entry[intervention] = list(predictions_by_id.get(str(example['original_id']), []))
        label_col = intervention if text_col == 'text' else f'{intervention}_aspect_majority'
        entry[f'{intervention} label'] = example[label_col]

    df = pd.DataFrame.from_dict(source_counterfactual_map, orient='index')
    df['original_text'] = df.index
    return df


if __name__ == '__main__':
    # Guarded so importing create_df_factual_counterfactual does not read
    # these machine-specific files.

    # NOTE(review): presumably maps the generation-file indices onto CEBaB
    # original_id values for this split — TODO confirm where 2027 comes from.
    ORIGINAL_ID_OFFSET = 2027

    set_to_match = pd.read_csv('/home/XXXXXX/MatchingBasedCausalExplanation/sets/nitagens/test_generations.csv')
    # Each 'id' is '<a>_<b>' with two integer parts: the first (shifted by the
    # offset) becomes original_id, the second becomes edit_id.
    set_to_match['original_id'] = set_to_match['id'].apply(
        lambda x: str(int(x.split('_')[0]) + ORIGINAL_ID_OFFSET))
    set_to_match['edit_id'] = set_to_match['id'].apply(
        lambda x: str(int(x.split('_')[1])))
    source = pd.read_csv("/home/XXXXXX/MatchingBasedCausalExplanation/sets/sources/test.csv")
    factual_counterfactual_df = create_df_factual_counterfactual(source, set_to_match)
