import random

import numpy as np
import torch
import pandas as pd

# Silence pandas' chained-assignment (SettingWithCopy) warnings for the
# filtered DataFrame slices used throughout this script.
pd.set_option('mode.chained_assignment', None)

from explainers.matching import Matching
from models.matching_rep import MatchingRepresentation
from models.sentence_transformer import SentenceTrans
from utils.data_utils import batchify_stance, \
    load_stance_detection
from utils.results_utils import get_results_path
from configs import config_stance as config
from utils.constants import STANCE_CONFOUNDERS
import wandb


def set_seed(seed):
    """Seed the NumPy, Python ``random`` and PyTorch global RNGs.

    Args:
        seed: Value passed unchanged to each seeding function.
    """
    # Same order as before: NumPy, then stdlib random, then torch.
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)


# Start the W&B run; from here on `config` is the run's wandb.config rather
# than the imported configs.config_stance module.
wandb.init()
config = wandb.config
wandb_on = True
loss_version = config.loss_version

# Confounders = every stance confounder except the active treatment. This must
# be computed after `config` is rebound so it excludes the same treatment that
# is later passed to batchify_stance as `treatment_col` (config.treatment);
# reading the static module's treatment could disagree with the run's value.
confounders_cols = [c for c in STANCE_CONFOUNDERS if c != config.treatment]

results_dir = get_results_path(config, save_config=True)
config.path = results_dir

set_seed(config.seed)

# Load all stance-detection splits and keep only examples from the selected
# domains.
data = load_stance_detection(filtered=False, prediction_include=True)
domains = ['Climate Change', 'Feminism']
for key in data:
    data[key] = data[key][data[key]['domain_text'].isin(domains)]
    print(f'{key} size: {len(data[key])}')

# Subsample the test split to at most 1000 examples. With random_state unset,
# pandas falls back to NumPy's global RNG (seeded by set_seed above), so the
# subsample is reproducible for a given seed. min() avoids the ValueError
# pandas raises when fewer than 1000 rows survive the domain filter.
data['test_base'] = data['test_base'].sample(min(1000, len(data['test_base'])))

# Optional: keep only test examples whose label differs from the edited label.
# data['test_base'] = data['test_base'][data['test_base']['label_text'] != data['test_base']['edit_label_text']]

# Sentence-transformer backbone; its representation model and tokenizer are
# reused by the batching and matching-representation code below.
model = SentenceTrans()
rep_model, tokenizer = model.get_representation_model(), model.get_tokenizer()

# Train / dev batches built from each split's base examples and its
# counterfactual ('cfs') set.
batches = batchify_stance(
    base_set=data['train_base'],
    cfs_set=data['train_cfs'],
    treatment_col=config.treatment,
    confounders_cols=confounders_cols,
    tokenizer=tokenizer,
)

validation_batches = batchify_stance(
    base_set=data['dev_base'],
    cfs_set=data['dev_cfs'],
    treatment_col=config.treatment,
    confounders_cols=confounders_cols,
    tokenizer=tokenizer,
)

# Initialize the matching representation on top of the backbone loaded above.
matching_rep = MatchingRepresentation(
    tokenizer=tokenizer,
    pretrained_model=rep_model,
    group=model.get_model_group(),
    setup_name=config.setup_name,
)

# A single matching explainer over the 'matching_train' split.
set_to_match = data['matching_train']
matching_explainer_1 = Matching(
    set_to_match=set_to_match,
    representation_model=matching_rep,
    assign=False,
    top_k=1,
    threshold=0,
    adding_prompt=False,
    description='matching',
    setup_name=config.setup_name,
)
explainers = [matching_explainer_1]

# Test batches, built with the same treatment/confounder settings as the
# train and dev batches.
test_batch = batchify_stance(
    base_set=data['test_base'],
    cfs_set=data['test_cfs'],
    treatment_col=config.treatment,
    confounders_cols=confounders_cols,
    tokenizer=tokenizer,
)

model_to_explain = config.model_to_explain

# Evaluation settings handed to matching_rep.train via the `cebab_eval` kwarg.
cebab_eval = dict(
    model_to_explain=model_to_explain,
    explainers=explainers,
    pairs_validation=None,
    pairs_test=data['test_base'],
    matching_description=matching_explainer_1.get_explainer_description(),
)

# Train the matching representation with the train/dev/test batches built above.
matching_rep.train(
    config,
    results_dir=results_dir,
    wandb_on=wandb_on,
    train_batch=batches,
    valid_batch=validation_batches,
    test_batch=test_batch,
    mean_loss=config.mean_loss,
    cebab_eval=cebab_eval,
    loss_version=loss_version,
)
