import random

import numpy as np
import torch
import wandb
import sys
from explainers.matching import Matching
from models.bert import Bert
from models.matching_rep import MatchingRepresentation
from models.roberta import Roberta
from models.sentence_transformer import SentenceTrans
from utils.constants import CEBAB_CONCEPTS
from utils.data_utils import load_source_sets, load_generations_sets, get_intervention_pairs, batchify_eval_cebab, \
    batchify_train_cebab, load_edits_sets, batchify_test_cebab, batchify_train_adding_prompt
from utils.model_utils import get_fine_tuned_model
from utils.results_utils import make_dir, get_results_path


def set_seed(seed):
    """Seed the NumPy, Python ``random`` and PyTorch global RNGs with one value.

    Args:
        seed: Integer seed shared by all three generators so runs are reproducible.
    """
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)


# --- Experiment setup -------------------------------------------------------
# Start a W&B run and read this script's hyperparameters from wandb.config.
# NOTE(review): values are presumably supplied by a sweep/launcher — confirm.
wandb.init()
config = wandb.config
wandb_on = True
loss_version = config.loss_version

# Derive the results directory from the config and store it back on the config
# so it is recorded with the run.
results_dir = get_results_path(config)
config.path = results_dir

# Seed before loading data so any randomness in the loaders is reproducible
# (assumes the loaders draw from the random/numpy/torch global RNGs — confirm).
set_seed(config.seed)

# NOTE(review): `generations` is only referenced by commented-out code below.
generations = load_generations_sets(filter_level=config.filter_level)
sets = load_source_sets(seed=config.seed)
edited_sets = load_edits_sets(filter_level=config.filter_level, seed=config.seed)

# --- Backbone selection -------------------------------------------------------
# Each supported backbone wrapper exposes get_representation_model() and
# get_tokenizer(); map the config name to its wrapper class.
model_to_fine_tune = config.model_to_fine_tune
_MODEL_CLASSES = {
    'bert': Bert,
    'roberta': Roberta,
    'sentence_transformer': SentenceTrans,
}
if model_to_fine_tune not in _MODEL_CLASSES:
    raise ValueError(
        "model_to_fine_tune should be one of 'bert', 'roberta' or "
        f"'sentence_transformer', got {model_to_fine_tune!r}"
    )
model = _MODEL_CLASSES[model_to_fine_tune]()
rep_model = model.get_representation_model()
tokenizer = model.get_tokenizer()

# --- Edit flag ----------------------------------------------------------------
# The filter level decides the `edit` flag forwarded to the batchify_* helpers.
# Unknown levels fail fast here instead of leaving `edit` undefined (which would
# otherwise surface later as a NameError).
if config.filter_level in ('no_filter', 'dry_filter'):
    edit = False
elif config.filter_level in ('all_filter', 'wet_filter'):
    edit = True
else:
    raise ValueError(
        "filter_level should be one of 'no_filter', 'dry_filter', "
        f"'all_filter' or 'wet_filter', got {config.filter_level!r}"
    )

# --- Training / validation batches --------------------------------------------
# treatment == 'all': build prompt-augmented batches for every CEBaB concept and
# concatenate them into train_batches / validation_batches.
# Otherwise: build a single batch list for the configured treatment into
# train_batch / validation_batch.
# Both splits receive the same `edit` flag so they are processed consistently.
train_set = edited_sets[f'full_train_set_{config.seed}']
validation_set = edited_sets['full_validation_set']

train_batches = []
validation_batches = []
if config.treatment == 'all':
    for t in CEBAB_CONCEPTS:
        train_batches += batchify_train_adding_prompt(full_set=train_set, treatment=t,
                                                      tokenizer=tokenizer,
                                                      generation_index=config.generation_index,
                                                      edit=edit)
    for t in CEBAB_CONCEPTS:
        validation_batches += batchify_train_adding_prompt(full_set=validation_set, treatment=t,
                                                           tokenizer=tokenizer,
                                                           generation_index=config.generation_index,
                                                           edit=edit)
else:
    train_batch = batchify_train_cebab(full_set=train_set, treatment=config.treatment,
                                       tokenizer=tokenizer,
                                       generation_index=config.generation_index, edit=edit)
    # Previously `edit` was not forwarded here, so validation silently used the
    # helper's default while training honoured filter_level.
    validation_batch = batchify_train_cebab(full_set=validation_set, treatment=config.treatment,
                                            tokenizer=tokenizer,
                                            generation_index=config.generation_index, edit=edit)

# --- Matching representation and explainer -------------------------------------
# The representation model to be fine-tuned wraps the backbone's representation
# model and tokenizer.
matching_rep = MatchingRepresentation(tokenizer=tokenizer, pretrained_model=rep_model,
                                      group=model.get_model_group())
set_to_match = edited_sets[f'expanded_matching_set_{config.seed}']

# Top-1 matching explainer; prompts are added only when all concepts are treated together.
use_all_treatments = config.treatment == 'all'
matching_explainer_1 = Matching(
    set_to_match=set_to_match,
    representation_model=matching_rep,
    assign=False,
    top_k=1,
    threshold=0,
    adding_prompt=use_all_treatments,
    description='matching',
)
explainers = [matching_explainer_1]

pairs_test = get_intervention_pairs(df=sets['test'], dataset_type="5-way", verbose=1)
pairs_validation = get_intervention_pairs(df=sets['validation'], dataset_type="5-way", verbose=1)

# A per-treatment test batch is only built for a single treatment.
if use_all_treatments:
    test_batch = None
else:
    test_batch = batchify_test_cebab(pairs=pairs_test, df_source=sets['test'],
                                     treatment=config.treatment, tokenizer=tokenizer)
# NOTE(review): model_to_explain is only consumed by the (currently disabled)
# CEBaB evaluation below.
if config.model_to_explain == 'bert':
    model_to_explain = get_fine_tuned_model('bert')

# CEBaB evaluation during training is disabled. To re-enable it, set:
#   cebab_eval = {'model_to_explain': model_to_explain, 'explainers': explainers,
#                 'pairs_validation': pairs_validation, 'pairs_test': pairs_test,
#                 'matching_description': matching_explainer_1.get_explainer_description()}
cebab_eval = None

# With treatment == 'all', train/validate on the batches accumulated over every concept.
if config.treatment == 'all':
    train_batch, validation_batch = train_batches, validation_batches

# --- Fine-tune the matching representation ------------------------------------
matching_rep.train(
    config,
    results_dir=results_dir,
    wandb_on=wandb_on,
    train_batch=train_batch,
    valid_batch=validation_batch,
    test_batch=test_batch,
    mean_loss=config.mean_loss,
    cebab_eval=cebab_eval,
    loss_version=loss_version,
)
