# Train stance/concept classifiers and add their predictions to the stance-detection data sets.
import os

import numpy as np
import pandas as pd

from models.distil_bert import Distil_Bert
from models.roberta import Roberta
from utils.constants import STANCE_CONFOUNDERS, CONFOUNDERS_MAPS
from utils.training_utils import train_concept_stance_data
from utils.data_utils import load_stance_detection


# train all the necessary models
def train_all_models(source_to_target, backbone='roberta', test_sample_size=100, random_state=None):
    """Train one classifier per ``(source, target)`` column pair.

    Uses the ``train_base`` / ``dev_base`` / ``test_base`` splits returned by
    ``load_stance_detection()`` and hands each model to ``train_concept_stance_data``
    under the concept name ``'<source>_to_<target>'``.

    Args:
        source_to_target: iterable of ``(source_column, target_column)`` pairs. The
            model reads text from ``source_column`` and predicts ``target_column``.
            Targets containing ``'label'`` use 3 classes; any other target uses
            ``len(CONFOUNDERS_MAPS[target])`` classes.
        backbone: ``'roberta'`` selects ``Roberta``; any other value selects ``Distil_Bert``.
        test_sample_size: maximum number of ``test_base`` rows passed as ``df_test``
            (previously fixed at 100).
        random_state: seed for that test-row sample; ``None`` leaves it unseeded.
    """
    data = load_stance_detection()
    test_df = data['test_base']
    # DataFrame.sample (without replacement) raises ValueError when asked for more
    # rows than exist, so never request more than the split holds.
    n_test = min(test_sample_size, len(test_df))
    model_cls = Roberta if backbone == 'roberta' else Distil_Bert
    for s_t in source_to_target:
        source, target = s_t[0], s_t[1]
        num_labels = 3 if 'label' in target else len(CONFOUNDERS_MAPS[target])
        model = model_cls(num_labels=num_labels)
        train_concept_stance_data(classifier=model, concept=f'{source}_to_{target}',
                                  df_train=data['train_base'], df_dev=data['dev_base'],
                                  df_test=test_df.sample(n_test, random_state=random_state),
                                  label_column=target,
                                  text_column=source)
        del model

    print('done')


def load_original_source(prediction_included=False, filtered=False,
                         data_dir='/home/XXXXXX/MatchingBasedCausalExplanation/sets/stance_detection'):
    """Load the ``base`` and ``zero`` stance-detection data frames from CSV.

    File pair read from ``data_dir``:
        default:              data.csv / zero_shot.csv
        prediction_included:  data_preds_included.csv / zero_preds_included.csv
        filtered:             data_filtered.csv / zero_filtered.csv

    ``filtered`` takes precedence over ``prediction_included``. The filtered files are
    written by ``filtering`` from frames that already carry prediction columns. When
    ``filtered`` is set, the script writes its output back to the ``*_filtered`` name.
    Loading the unfiltered file in that case would therefore overwrite the filtered data.

    Args:
        prediction_included: load the files that include model predictions.
        filtered: load the filtered files.
        data_dir: directory containing the CSV files.

    Returns:
        ``(base, zero)`` data frames.
    """
    if filtered:
        base_name, zero_name = 'data_filtered.csv', 'zero_filtered.csv'
    elif prediction_included:
        base_name, zero_name = 'data_preds_included.csv', 'zero_preds_included.csv'
    else:
        base_name, zero_name = 'data.csv', 'zero_shot.csv'
    base = pd.read_csv(os.path.join(data_dir, base_name))
    zero = pd.read_csv(os.path.join(data_dir, zero_name))
    return base, zero


def save_predictions(df, source_to_target, path='data_preds_included', backbone='roberta'):
    """Add predictions from saved classifiers to ``df`` and write it to CSV.

    For each ``(source, target)`` pair, the checkpoint
    ``saved_models/stance_setup/<source>_to_<target>/<model dir>`` is loaded and applied
    to the text in column ``source``.

    Args:
        df: data frame to annotate. Rows with a missing ``text`` or ``instruction`` are
            dropped IN PLACE, so the caller's frame is modified.
        source_to_target: iterable of ``(source_column, target_column)`` pairs. Targets
            containing ``'label'`` use 3 classes; any other target uses
            ``len(CONFOUNDERS_MAPS[target])`` classes.
        path: output file name under ``sets/stance_detection``; the ``.csv`` extension
            is optional.
        backbone: ``'roberta'`` (columns ``<source>_to_<target>_preds`` / ``_probs``) or
            ``'distil'`` (columns ``<source>_to_<target>_distil_preds`` / ``_probs``).

    Returns:
        ``df`` with the new prediction columns. Element 1 of ``get_predictions``'s
        output is stored as ``_preds`` and element 0 as ``_probs``.

    Raises:
        ValueError: if ``backbone`` is neither ``'roberta'`` nor ``'distil'``. The
            original code instead hit an unbound model variable.
    """
    if backbone not in ('roberta', 'distil'):
        raise ValueError(f"backbone must be 'roberta' or 'distil', got {backbone!r}")
    # The former default ('data_preds_included.csv') already carried the extension,
    # which produced '<name>.csv.csv'; accept either form.
    path = str(path)
    if path.endswith('.csv'):
        path = path[:-len('.csv')]
    checkpoints_root = '/home/XXXXXX/MatchingBasedCausalExplanation/saved_models/stance_setup'
    output_file = f'/home/XXXXXX/MatchingBasedCausalExplanation/sets/stance_detection/{path}.csv'

    df.dropna(subset=['text', 'instruction'], inplace=True)
    for s_t in source_to_target:
        source, target = s_t[0], s_t[1]
        num_labels = 3 if 'label' in target else len(CONFOUNDERS_MAPS[target])
        checkpoint = f'{checkpoints_root}/{source}_to_{target}'
        if backbone == 'roberta':
            model = Roberta(num_labels=num_labels,
                            pretrained_model_path=f'{checkpoint}/roberta-base')
            column_prefix = f'{source}_to_{target}'
        else:
            model = Distil_Bert(num_labels=num_labels,
                                pretrained_model_path=f'{checkpoint}/distilbert-base-uncased')
            column_prefix = f'{source}_to_{target}_distil'

        outputs = model.get_predictions(list(df[source].values), return_both=True)
        df[f'{column_prefix}_preds'] = outputs[1]
        df[f'{column_prefix}_probs'] = outputs[0]

        del model
        print('finish save predictions to ', source, ' to ', target)

        # Written after every model, so completed columns are on disk even if a
        # later model fails.
        df.to_csv(output_file, index=False)
    print('done')
    return df


def save_edit_predictions(df, source_to_target, path='data_preds_included', backbone='roberta'):
    """Add predictions for edited texts (which may be missing) to ``df`` and save it.

    The checkpoint name is ``'<source>_to_<target>'`` with every ``'edit_'`` removed. As
    a result, an ``edit_*`` column is scored with the model stored under the
    un-prefixed name.

    Rows where ``source`` is null are not sent to the model; both their ``_preds`` and
    ``_probs`` columns get the placeholder value 1.

    Args:
        df: data frame to annotate. On the first pair the placeholder columns are also
            added to the caller's frame in place.
        source_to_target: iterable of ``(source_column, target_column)`` pairs. Targets
            containing ``'label'`` use 3 classes; any other target uses
            ``len(CONFOUNDERS_MAPS[target])`` classes.
        path: output file name under ``sets/stance_detection``, without ``.csv``.
        backbone: ``'roberta'`` (columns ``<source>_to_<target>_preds`` / ``_probs``) or
            ``'distil'`` (columns ``<source>_to_<target>_distil_preds`` / ``_probs``).

    Returns:
        A new frame. Rows with a non-null ``source`` come first, followed by the null
        rows, so row order can differ from the input.

    Raises:
        ValueError: if ``backbone`` is neither ``'roberta'`` nor ``'distil'``.
    """
    if backbone not in ('roberta', 'distil'):
        raise ValueError(f"backbone must be 'roberta' or 'distil', got {backbone!r}")
    checkpoints_root = '/home/XXXXXX/MatchingBasedCausalExplanation/saved_models/stance_setup'
    for s_t in source_to_target:
        source, target = s_t[0], s_t[1]
        checkpoint = f'{source}_to_{target}'.replace('edit_', '')
        num_labels = 3 if 'label' in target else len(CONFOUNDERS_MAPS[target])
        if backbone == 'roberta':
            model = Roberta(num_labels=num_labels,
                            pretrained_model_path=f'{checkpoints_root}/{checkpoint}/roberta-base')
            column_prefix = f'{source}_to_{target}'
        else:
            model = Distil_Bert(num_labels=num_labels,
                                pretrained_model_path=f'{checkpoints_root}/{checkpoint}/distilbert-base-uncased')
            column_prefix = f'{source}_to_{target}_distil'

        is_null = df[source].isnull()
        # Explicit copy: prediction columns are added to it below (avoids writing
        # into a boolean-indexed slice).
        scored = df[~is_null].copy()
        outputs = model.get_predictions(list(scored[source].values), return_both=True)

        # Rows without a source text get the placeholder 1 instead of a prediction.
        null_index = df[is_null].index
        df.loc[null_index, f'{column_prefix}_preds'] = 1
        df.loc[null_index, f'{column_prefix}_probs'] = 1
        scored[f'{column_prefix}_preds'] = outputs[1]
        scored[f'{column_prefix}_probs'] = outputs[0]

        # Recombine the scored rows with the placeholder rows.
        df = pd.concat([scored, df[is_null]], axis=0)
        del model
        print('finish save predictions to ', source, ' to ', target)
        df.to_csv(
            f'/home/XXXXXX/MatchingBasedCausalExplanation/sets/stance_detection/{path}.csv',
            index=False)
    print('done')
    return df


def save_predictions_special_models(df, source_to_target, path='data_preds_included'):
    """Annotate ``df`` with predictions from explicitly named RoBERTa checkpoints.

    Unlike ``save_predictions``, each entry names its own checkpoint directory and
    column prefix. The checkpoint therefore need not follow the
    ``<source>_to_<target>`` naming.

    Args:
        df: data frame to annotate. Rows missing ``text`` or ``instruction`` are dropped
            in place, so the caller's frame is modified.
        source_to_target: iterable of ``(source, target, checkpoint, prefix)`` tuples.
            ``checkpoint`` is a directory under ``saved_models/stance_setup``. The new
            columns are ``<prefix>_<source>_to_<target>_preds`` and ``..._probs``.
            Targets containing ``'label'`` use 3 classes; any other target uses
            ``len(CONFOUNDERS_MAPS[target])`` classes.
        path: output file name under ``sets/stance_detection``, without ``.csv``.

    Returns:
        The same ``df`` object, with the new columns added.
    """
    checkpoints_root = '/home/XXXXXX/MatchingBasedCausalExplanation/saved_models/stance_setup'
    df.dropna(subset=['text', 'instruction'], inplace=True)
    for spec in source_to_target:
        source_col, target_col = spec[0], spec[1]
        n_classes = 3 if 'label' in target_col else len(CONFOUNDERS_MAPS[target_col])
        classifier = Roberta(num_labels=n_classes,
                             pretrained_model_path=f'{checkpoints_root}/{spec[2]}/roberta-base')
        result = classifier.get_predictions(list(df[source_col].values), return_both=True)

        column_stem = f'{spec[3]}_{source_col}_to_{target_col}'
        df[f'{column_stem}_preds'] = result[1]
        df[f'{column_stem}_probs'] = result[0]

        del classifier
        print('finish save predictions to ', source_col, ' to ', target_col)

        # Persist after each checkpoint, as before.
        df.to_csv(f'/home/XXXXXX/MatchingBasedCausalExplanation/sets/stance_detection/{path}.csv',
                  index=False)
    print('done')
    return df


def save_edit_predictions_special_models(df, source_to_target, path='data_preds_included'):
    """Add predictions for edited texts from explicitly named RoBERTa checkpoints.

    Columns are named ``<prefix>_<source>_to_<target>_preds`` / ``_probs``, the same
    scheme as ``save_predictions_special_models``. The original code wrote the null-row
    placeholders to un-prefixed columns and dropped ``<source>_`` from the probs column
    name; both now use the prefixed name. Rows where ``source`` is null are not sent to
    the model and get the placeholder value 1 in both columns.

    Args:
        df: data frame to annotate. On the first entry the placeholder columns are also
            added to the caller's frame in place.
        source_to_target: iterable of ``(source, target, checkpoint, prefix)`` tuples.
            ``checkpoint`` is a directory under ``saved_models/stance_setup``. Targets
            containing ``'label'`` use 3 classes; any other target uses
            ``len(CONFOUNDERS_MAPS[target])`` classes.
        path: output file name under ``sets/stance_detection``, without ``.csv``.

    Returns:
        A new frame. Rows with a non-null ``source`` come first, followed by the null
        rows. The CSV is written once, after all entries are processed.
    """
    checkpoints_root = '/home/XXXXXX/MatchingBasedCausalExplanation/saved_models/stance_setup'
    for s_t in source_to_target:
        source, target = s_t[0], s_t[1]
        num_labels = 3 if 'label' in target else len(CONFOUNDERS_MAPS[target])
        model = Roberta(num_labels=num_labels,
                        pretrained_model_path=f'{checkpoints_root}/{s_t[2]}/roberta-base')
        column_prefix = f'{s_t[3]}_{source}_to_{target}'

        is_null = df[source].isnull()
        # Explicit copy: prediction columns are added to it below.
        scored = df[~is_null].copy()
        outputs = model.get_predictions(list(scored[source].values), return_both=True)

        # Rows without a source text get the placeholder 1, in the SAME columns that
        # receive the predictions.
        null_index = df[is_null].index
        df.loc[null_index, f'{column_prefix}_preds'] = 1
        df.loc[null_index, f'{column_prefix}_probs'] = 1
        scored[f'{column_prefix}_preds'] = outputs[1]
        scored[f'{column_prefix}_probs'] = outputs[0]

        df = pd.concat([scored, df[is_null]], axis=0)
        del model
        print('finish save predictions to ', source, ' to ', target)
    df.to_csv(
        f'/home/XXXXXX/MatchingBasedCausalExplanation/sets/stance_detection/{path}.csv',
        index=False)
    print('done')
    return df


def filtering(base, zero, concepts=('age', 'domain', 'job'),
              data_dir='/home/XXXXXX/MatchingBasedCausalExplanation/sets/stance_detection',
              confounders_maps=None):
    """Drop concept edits whose edited text is not predicted to carry the edit goal.

    For each concept, the stored prediction ``edit_text_to_<concept>_preds`` (values of
    the form ``'tensor(<id>)'``) is decoded into ``edit_text_to_<concept>_preds_text``.
    Decoding uses the inverse of ``confounders_maps[concept]``. Rows with
    ``edit_type == concept`` are kept only when the decoded value equals ``edit_goal``;
    every other row is kept.

    Both frames are decoded the same way, and missing-edit placeholders are passed
    through. Previously only ``base`` guarded against the placeholder, so ``zero``
    raised KeyError on it.

    Args:
        base, zero: data frames to filter; they are copied, not modified.
        concepts: concepts to filter on.
        data_dir: directory that receives ``data_filtered.csv`` and ``zero_filtered.csv``.
        confounders_maps: mapping ``concept -> {value: class id}``; defaults to
            ``CONFOUNDERS_MAPS``.

    Returns:
        ``(base_filtered, zero_filtered)``. The original returned ``None``; the frames
        are still written to CSV.

    Raises:
        KeyError: if a stored prediction is neither a known ``'tensor(<id>)'`` value
            nor a placeholder.
    """
    if confounders_maps is None:
        confounders_maps = CONFOUNDERS_MAPS
    # save_edit_predictions stores 1 for rows with no edit text. The original check
    # compared against '1.0', its form after a CSV round-trip. '1' and 1 are also
    # accepted -- assumption: the same placeholder when not round-tripped as float.
    missing_edit_sentinels = ('1.0', '1', 1)

    def _decode(pred, inverse_map):
        # Placeholders pass through unchanged; unknown values raise KeyError.
        if pred in missing_edit_sentinels:
            return pred
        return inverse_map[pred]

    def _keep_matching_edits(frame, concept, inverse_map):
        # Add the decoded column, then drop concept edits that missed their goal.
        pred_col = f'edit_text_to_{concept}_preds'
        text_col = f'{pred_col}_text'
        frame[text_col] = frame[pred_col].apply(lambda x: _decode(x, inverse_map))
        # Equivalent to: not a concept edit, OR (a concept edit AND goal matches).
        keep = (frame['edit_type'] != concept) | (frame['edit_goal'] == frame[text_col])
        return frame[keep].copy()

    base_temp = base.copy()
    zero_temp = zero.copy()
    for concept in concepts:
        inverse_map = {f'tensor({v})': k for k, v in confounders_maps[concept].items()}
        base_len_before_filtering = len(base_temp)
        zero_len_before_filtering = len(zero_temp)

        base_temp = _keep_matching_edits(base_temp, concept, inverse_map)
        zero_temp = _keep_matching_edits(zero_temp, concept, inverse_map)

        print('dropped base:', base_len_before_filtering - len(base_temp), ' from ', concept)
        print('dropped zero:', zero_len_before_filtering - len(zero_temp), ' from ', concept)

    base_temp.to_csv(os.path.join(data_dir, 'data_filtered.csv'), index=False)
    zero_temp.to_csv(os.path.join(data_dir, 'zero_filtered.csv'), index=False)
    return base_temp, zero_temp


if __name__ == '__main__':
    # Guarded so that importing this module (e.g. to reuse its helpers) does not start
    # training and prediction.

    # 1) Train classifiers. Other pairs used in earlier runs:
    #    ('text', 'age'), ('text', 'gender'), ('text', 'domain'), ('text', 'job'),
    #    ('original_instruction', 'label'), ('original_instruction', 'original_label'),
    #    ('instruction', 'original_label').
    #    An earlier variant also passed domains=['Climate Change', 'Feminism'], which
    #    train_all_models does not accept.
    train_all_models(source_to_target=[
        ('instruction', 'label'),
    ], backbone='roberta')

    # 2) Load the data set to annotate.
    filtered = False
    base, zero = load_original_source(filtered=filtered, prediction_included=True)

    split = 'base'  # which data set to annotate: 'base' or 'zero'
    if split == 'zero':
        df = zero
        print('zero')
        p = 'zero_filtered' if filtered else 'zero_preds_included'
    else:
        print('base')
        df = base
        p = 'data_filtered' if filtered else 'data_preds_included'

    # 3) Add predictions for the original and the edited instructions.
    # NOTE(review): both calls load the checkpoint directory
    # 'instruction_to_label_Feminism_Climate Change', while step 1 trains the concept
    # 'instruction_to_label'. Confirm that this checkpoint exists.
    # Other pairs used in earlier runs: ('original_instruction', 'label'),
    # ('original_instruction', 'original_label'), ('instruction', 'original_label'),
    # ('text', 'age'), ('text', 'gender'), ('text', 'domain'), ('text', 'job').
    df = save_predictions(df, source_to_target=[
        ('instruction', 'label_Feminism_Climate Change'),
    ], path=p, backbone='roberta')
    # Other pairs used in earlier runs: ('edit_instruction', 'original_label'),
    # ('edit_text', 'age'), ('edit_text', 'gender'), ('edit_text', 'domain'),
    # ('edit_text', 'job').
    df = save_edit_predictions(df, source_to_target=[
        ('edit_instruction', 'label_Feminism_Climate Change'),
    ], path=p, backbone='roberta')

    # Optional steps, currently disabled:
    # df = save_predictions_special_models(df, source_to_target=[
    #     ('instruction', 'label', 'original_instruction_to_label', 'original_to_label'),
    #     ('instruction', 'original_label', 'original_instruction_to_label', 'original_to_label'),
    # ], path=p)
    # df = save_edit_predictions_special_models(df, source_to_target=[
    #     ('edit_instruction', 'label', 'original_instruction_to_label', 'original_to_label'),
    #     ('edit_instruction', 'original_label', 'original_instruction_to_label', 'original_to_label'),
    # ], path=p)
    # filtering(base, zero)
