import logging
from argparse import Namespace
from typing import Optional

import hydra
import pandas as pd

from lm_understanding.data import load_model_behaviors, model_behavior_path
from lm_understanding.explanations import (create_counterfactual_explanations,
                                           create_rationalizations, create_salience, create_wikibabble, create_spamchar)
from lm_understanding.models import load_model
from lm_understanding.prompting import Completer, Prompter
from lm_understanding.question_template import TemplateModelBehavior


def create_explanations(config: Namespace, behavior: TemplateModelBehavior, model: Optional[Completer], prompter: Optional[Prompter]) -> pd.DataFrame:
    """Build the explanations selected by ``config.explanation.name`` for one behavior.

    Args:
        config: Run configuration. ``config.explanation.name`` selects the
            explanation method; the whole config is also forwarded to the
            ``'salience'`` method.
        behavior: The model behavior to explain.
        model: Completer used to query the model. Required only for the
            ``'rationalization'`` method; may be ``None`` otherwise.
        prompter: Prompter paired with ``model``. Same requirement as ``model``.

    Returns:
        The generated explanations, converted via their ``to_df()`` method.

    Raises:
        ValueError: If ``'rationalization'`` is requested but ``model`` or
            ``prompter`` is ``None``.
        NotImplementedError: If ``config.explanation.name`` is not a known method.
    """
    # Log instead of print: main() calls this once per behavior, so a bare print
    # dumped the full config to stdout on every iteration.
    logging.getLogger(__name__).info('Creating explanations with config: %s', config)
    name = config.explanation.name
    if name == 'counterfactual':
        explanations = create_counterfactual_explanations(behavior)
    elif name == 'rationalization':
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if model is None or prompter is None:
            raise ValueError(
                f"Explanation {name!r} requires a model and prompter, but got None "
                "(is explanation.requires_model_queries enabled in the config?)"
            )
        explanations = create_rationalizations(behavior, model, prompter)
    elif name == 'salience':
        explanations = create_salience(behavior, config)
    elif name == 'wikibabble':
        explanations = create_wikibabble(behavior)
    elif name == 'spamchar':
        explanations = create_spamchar(behavior)
    else:
        raise NotImplementedError(f'Unknown explanation method: {name!r}')
    return explanations.to_df()


@hydra.main(config_path='config', config_name='create_explanations.yaml', version_base='1.2')
def main(config):
    """Generate explanations for every loaded model behavior and save them as CSV.

    Hydra builds ``config`` from ``config/create_explanations.yaml``. A model
    and prompter are loaded only when ``config.explanation.requires_model_queries``
    is truthy; otherwise ``None`` is passed for both. Each behavior's
    explanations are written to
    ``model_behavior_path(config) / <template_id> / 'explanations' / '<explanation name>.csv'``.

    Raises:
        ValueError: If a loaded behavior has no ``template_id``, which is needed
            to build its save path.
    """
    model_behaviors, _ = load_model_behaviors(config)
    if config.explanation.requires_model_queries:
        model, prompter = load_model(config.model, free_response=True)
    else:
        model, prompter = None, None
    for model_behavior in model_behaviors:
        # Explicit check rather than `assert`: asserts are stripped under `python -O`,
        # and a None template_id would then only fail later, inside the path join.
        if model_behavior.template_id is None:
            raise ValueError('Model behavior has no template_id; cannot determine where to save its explanations.')
        explanations = create_explanations(config, model_behavior, model, prompter)
        save_path = model_behavior_path(config) / model_behavior.template_id / 'explanations' / f'{config.explanation.name}.csv'
        save_path.parent.mkdir(exist_ok=True, parents=True)
        explanations.to_csv(save_path, index=False)


if __name__ == "__main__":
    # The Hydra decorator on main() builds the config from the command line,
    # so main is invoked here without any explicit arguments.
    main()