from gradiend.evaluation.analyze_decoder import compute_lms
from gradiend.model import ModelWithGradiend
from gradiend.setups.race.training import WhiteBlackSetup

# Load a trained GRADIEND model, apply it to its base model via
# `modify_model`, save the changed model (plus tokenizer), and print the LMS
# score that `compute_lms` reports for that changed model on evaluation texts.

# Bias setup whose GRADIEND model is used. Another id used with this script
# is 'religion_christian_jewish' (see the recorded results below). If you
# switch to it, NOTE(review): the evaluation setup below (`WhiteBlackSetup`)
# is hard-coded to the race setup and would need to change as well.
# Renamed from `id` so the builtin `id()` is not shadowed.
setup_id = 'race_white_black'

# Base model the GRADIEND was trained on. An earlier variant of this script
# loaded f'results/models/{setup_id}/bert-base-cased' instead.
model_name = 'distilbert-base-cased'

gradiend_model_id = f'results/experiments/gradiend/{setup_id}/{model_name}/0'

gradiend = ModelWithGradiend.from_pretrained(gradiend_model_id)

# Arguments (0.1, 0) correspond to the "0.1, 0" row of the results log below.
# NOTE(review): the meaning of each argument is defined in
# `ModelWithGradiend.modify_model` -- confirm there before changing them.
changed_model = gradiend.modify_model(0.1, 0)
output = f'results/changed_models/{setup_id}/{model_name}-{setup_id}'

changed_model.save_pretrained(output)
# Save the tokenizer alongside the model so the output directory is
# loadable on its own.
gradiend.tokenizer.save_pretrained(output)


setup = WhiteBlackSetup()
eval_data = setup.create_eval_data(gradiend, max_size=100)
texts = eval_data['texts']

# Terms in `setup.non_neutral_terms` are passed as `ignore` to `compute_lms`.
lms = compute_lms(changed_model, gradiend.tokenizer, texts, ignore=setup.non_neutral_terms)

print(lms)

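# Recorded results. Each entry presumably reads "<modify_model args>: <printed
# LMS> (SEAT score, where measured)" -- verify before relying on it.
# 0, 0: 0.62 (SEAT 0.5975)
# 0.01, 0: (not recorded)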
# 0.1, 0: 0.57 (SEAT 0.59)
# 1, 0: 0.34
# 10, 0: 0.03 (SEAT 0.18)

# - 0.1, 0: 0.576 (SEAT 0.6000)


# jewish christian base 0.6196606202457577 (SEAT 0.41285463039133397)

# base cj 0.3100783697523169,
# 0.310342986376854

# Base 0.4379
# 0.1: 0.4369