from gradiend.model import ModelWithGradiend
from gradiend.setups.race.training import WhiteBlackSetup, WhiteAsianSetup, BlackAsianSetup, ChristianJewishSetup, \
    ChristianMuslimSetup, MuslimJewishSetup


def analyze(setup, model_id):
    """Run ``setup.analyze_model`` on a pretrained GRADIEND checkpoint.

    Args:
        setup: Object exposing ``analyze_model(model, output)``, e.g. one of
            the race/religion training setups imported at the top of this file.
        model_id: Identifier or path passed to
            ``ModelWithGradiend.from_pretrained``. It is also used as the
            directory prefix of the output path.

    The output path handed to ``analyze_model`` is
    ``<model_id>/encoded_values.csv``. It is presumably the CSV that
    ``analyze_model`` writes; confirm against the setup implementation.
    """
    gradiend_model = ModelWithGradiend.from_pretrained(model_id)
    csv_path = f'{model_id}/encoded_values.csv'
    setup.analyze_model(gradiend_model, csv_path)
    print("Analysis complete.")


if __name__ == '__main__':
    # Script entry point: run `analyze` for every (base model, setup) pair.
    # Failures are reported and skipped, so one missing or broken checkpoint
    # does not abort the whole batch.
    import traceback

    # Checkpoint version suffix. Models are expected under
    # results/models/<setup.id>/<base_model>-<version>.
    version = 'v5'
    base_models = ['gpt2', 'distilbert-base-cased', 'roberta-large']
    setups = [
        WhiteBlackSetup(),
        WhiteAsianSetup(),
        BlackAsianSetup(),
        ChristianJewishSetup(),
        ChristianMuslimSetup(),
        MuslimJewishSetup(),
    ]

    for base_model in base_models:
        for setup in setups:
            # Build the path before the try so the error message below always
            # names the checkpoint that was actually being analyzed.
            model = f'results/models/{setup.id}/{base_model}-{version}'
            print(f'Analyzing {setup.id} with base model {base_model}')
            try:
                analyze(setup, model)
            except Exception as e:
                # Deliberate best-effort batch: report the failure with its
                # full traceback (not just str(e)) and continue with the next one.
                print(f"Error analyzing {model}: {e}")
                traceback.print_exc()