
import json
from functools import partial
from pathlib import Path
from pprint import pprint
from typing import List

import hydra
import numpy as np
import pandas as pd
from tqdm import tqdm

import datasets
from lm_understanding.models import load_model
from lm_understanding.prompting.answer import answer_questions
from lm_understanding.prompting.questions import YesNoQuestion


def load_commonsense_ethics_questions(path: str='model_evals/datasets/cm_test.csv') -> List[YesNoQuestion]:
    """Build yes/no ethics questions from a CSV with ``input`` and ``label`` columns.

    Each row's ``input`` text gets the suffix " Was this ethical?". A truthy
    ``label`` produces the expected answer 'No'; a falsy one produces 'Yes'.

    Args:
        path: Location of the CSV file to read.

    Returns:
        One ``YesNoQuestion`` per CSV row, in file order.
    """
    frame = pd.read_csv(path)
    questions = []
    for record in frame.itertuples(index=False):
        prompt = f'{record.input} Was this ethical?'
        # NOTE(review): presumably a truthy label marks an unethical scenario — confirm against the dataset.
        expected = 'No' if record.label else 'Yes'
        questions.append(YesNoQuestion(prompt, expected))
    return questions


def load_fantasy_reasoning_questions():
    """Load the fantasy-reasoning examples as yes/no questions.

    Reads ``model_evals/datasets/fantasy_reasoning.json`` and walks its
    ``examples`` list. Each example supplies the question text under ``input``
    and a ``target_scores`` mapping; the expected answer is 'Yes' when the
    score stored under the ``'Yes.'`` key is truthy, otherwise 'No'.

    Returns:
        One ``YesNoQuestion`` per example, in file order.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding (JSON text is UTF-8 by specification).
    raw = Path('model_evals/datasets/fantasy_reasoning.json').read_text(encoding='utf-8')
    data = json.loads(raw)

    def label(example):
        return 'Yes' if example['target_scores']['Yes.'] else 'No'

    return [YesNoQuestion(e['input'], label(e)) for e in data['examples']]


def load_moral_permissibility_questions():
    """Load the moral-permissibility examples as yes/no questions.

    Reads ``model_evals/datasets/moral_permissibility.json`` (UTF-8) and turns
    every entry of its ``examples`` list into a question. The question text is
    the example's ``input``; the expected answer is 'Yes' when the score under
    the ``'Yes'`` key of ``target_scores`` is truthy, otherwise 'No'.

    Returns:
        One ``YesNoQuestion`` per example, in file order.
    """
    source = Path('model_evals/datasets/moral_permissibility.json')
    payload = json.loads(source.read_text(encoding='utf-8'))
    questions = []
    for example in payload['examples']:
        prompt = example['input']
        expected = 'Yes' if example['target_scores']['Yes'] else 'No'
        questions.append(YesNoQuestion(prompt, expected))
    return questions


def load_anthropic_evals_questions(path: str = 'model_evals/datasets/self-awareness-general-ai.jsonl'):
    """Load yes/no questions from a JSON-lines file of multiple-choice records.

    Every non-blank line is parsed as a JSON object with a ``question`` and an
    ``answer_matching_behavior`` field. Only records whose question offers the
    options "(A) Yes"/"(B) No" or "(B) Yes"/"(A) No" are kept. The question
    text is the first line of ``question``; the expected answer is 'Yes' when
    the stripped ``answer_matching_behavior`` followed by " Yes" occurs in the
    question, otherwise 'No'.

    Args:
        path: Location of the JSON-lines file to read.

    Returns:
        One ``YesNoQuestion`` per retained record, in file order.
    """
    # Iterate the file instead of using str.splitlines(): splitlines also breaks
    # on characters such as U+2028 or form feed, which may occur unescaped inside
    # JSON strings and would cut a record in half. Blank lines are skipped, and
    # the encoding is fixed to UTF-8 rather than left to the platform locale.
    with Path(path).open(encoding='utf-8') as lines:
        records = [json.loads(line) for line in lines if line.strip()]

    def is_yes_no(question):
        a_yes_b_no = '(A) Yes' in question and '(B) No' in question
        b_yes_a_no = '(B) Yes' in question and '(A) No' in question
        return a_yes_b_no or b_yes_a_no

    def question_text(record):
        # Keep only the first line, dropping the answer options that follow it.
        return record['question'].split('\n')[0]

    def label(record):
        choice = record['answer_matching_behavior'].strip()
        return 'Yes' if f'{choice} Yes' in record['question'] else 'No'

    return [
        YesNoQuestion(question_text(record), label(record))
        for record in records
        if is_yes_no(record['question'])
    ]


def load_anthropic_evals_agree_statement_question(path: str = 'model_evals/datasets/okay-with-deceiving-a-user-to-help-them.jsonl'):
    """Load agree/disagree statement questions from a JSON-lines file.

    Every non-blank line is parsed as a JSON object. The full ``question``
    field is used as the question text and the whitespace-stripped
    ``answer_matching_behavior`` field is used verbatim as the expected answer.

    Args:
        path: Location of the JSON-lines file to read.

    Returns:
        One ``YesNoQuestion`` per record, in file order.
    """
    # Iterate the file instead of using str.splitlines(): splitlines also breaks
    # on characters such as U+2028 or form feed, which may occur unescaped inside
    # JSON strings and would cut a record in half. Blank lines are skipped, and
    # the encoding is fixed to UTF-8 rather than left to the platform locale.
    with Path(path).open(encoding='utf-8') as lines:
        records = [json.loads(line) for line in lines if line.strip()]

    # NOTE(review): the label is assumed to already be 'Yes' or 'No' after
    # stripping — confirm against the dataset.
    return [
        YesNoQuestion(record['question'], record['answer_matching_behavior'].strip())
        for record in records
    ]


def load_boolq_questions():
    """Load BoolQ validation examples as yes/no questions.

    Fetches the ``validation`` split of the ``boolq`` dataset through
    ``datasets.load_dataset``. The question text is the row's ``passage``
    followed by its ``question`` and a trailing "?"; the expected answer is
    'Yes' when ``answer`` is truthy, otherwise 'No'. Questions whose text has
    250 or more whitespace-separated words are dropped.

    Returns:
        The retained ``YesNoQuestion`` objects, in dataset order.
    """
    validation_split: datasets.Dataset = datasets.load_dataset('boolq')['validation']  # type: ignore
    kept = []
    for row in validation_split:
        prompt = f'{row["passage"]} {row["question"]}?'
        expected = 'Yes' if row["answer"] else 'No'
        candidate = YesNoQuestion(prompt, expected)
        # Measure the text stored on the question object, not the raw prompt.
        if len(candidate.text.split()) < 250:
            kept.append(candidate)
    return kept


def load_csqa2():
    """Load the CSQA2 dev set as yes/no questions.

    Reads ``model_evals/datasets/csqa2_dev.json`` as one JSON object per
    non-blank line. A record's ``question`` is used as-is when it contains a
    "?"; otherwise " Is this true?" is appended. The expected answer is 'Yes'
    when ``answer`` equals ``'yes'``, otherwise 'No'.

    Returns:
        One ``YesNoQuestion`` per record, in file order.
    """
    # Iterate the file instead of using str.splitlines(): splitlines also breaks
    # on characters such as U+2028 or form feed, which may occur unescaped inside
    # JSON strings and would cut a record in half. Blank lines are skipped, and
    # the encoding is fixed to UTF-8 rather than left to the platform locale.
    with Path('model_evals/datasets/csqa2_dev.json').open(encoding='utf-8') as lines:
        records = [json.loads(line) for line in lines if line.strip()]

    def question_text(record):
        question = record['question']
        return question if '?' in question else f'{question} Is this true?'

    def label(record):
        return 'Yes' if record["answer"] == 'yes' else 'No'

    return [YesNoQuestion(question_text(record), label(record)) for record in records]



# Registry of evaluation tasks: task name -> zero-argument callable returning
# that task's questions. Loaders that take a ``path`` are bound to a specific
# dataset file with ``partial``.
TASK_QUESTION_LOAD_FNS = {
    'boolq': load_boolq_questions,
    # 'csqa2': load_csqa2,  (currently disabled)
    'fantasy_reasoning': load_fantasy_reasoning_questions,
    'commonsense_ethics': load_commonsense_ethics_questions,
    'commonsense_ethics_hard': partial(
        load_commonsense_ethics_questions,
        path='model_evals/datasets/cm_test_hard.csv',
    ),
    'moral_permissibility': load_moral_permissibility_questions,
    'self_awareness_good_text_model': partial(
        load_anthropic_evals_questions,
        path='model_evals/datasets/self-awareness-good-text-model.jsonl',
    ),
    'corrigibility': partial(
        load_anthropic_evals_questions,
        path='model_evals/datasets/corrigible-neutral-HHH.jsonl',
    ),
    'power_seeking': partial(
        load_anthropic_evals_questions,
        path='model_evals/datasets/power-seeking-inclination.jsonl',
    ),
    'self_preservation': partial(
        load_anthropic_evals_questions,
        path='model_evals/datasets/survival-instinct.jsonl',
    ),
    'willingness_to_deceive': partial(
        load_anthropic_evals_agree_statement_question,
        path='model_evals/datasets/okay-with-deceiving-a-user-to-help-them.jsonl',
    ),
}

def balance_yes_no_answers(questions: List[YesNoQuestion]) -> List[YesNoQuestion]:
    """Return an alternating Yes/No selection with equally many of each label.

    Questions labelled 'Yes' and questions labelled 'No' are paired up in their
    original relative order. The surplus of the larger group is discarded, as
    are questions carrying any other label.

    Args:
        questions: Questions exposing a ``label`` attribute.

    Returns:
        A list of the form [yes, no, yes, no, ...] whose length is twice the
        size of the smaller label group.
    """
    affirmative = list(filter(lambda q: q.label == 'Yes', questions))
    negative = list(filter(lambda q: q.label == 'No', questions))
    interleaved = []
    # zip stops at the shorter group, which is what balances the two labels.
    for yes_question, no_question in zip(affirmative, negative):
        interleaved.extend((yes_question, no_question))
    return interleaved


@hydra.main(config_path='../../config', config_name='model_behavior.yaml', version_base='1.2')
def main(config):
    """Evaluate the configured model on every registered task and save a CSV summary.

    For each entry of ``TASK_QUESTION_LOAD_FNS`` the questions are loaded,
    balanced between 'Yes' and 'No' labels, truncated to at most 100, and
    answered with the model and prompter returned by ``load_model``. Per-task
    accuracies and ``model.info`` are pretty-printed, and one row of statistics
    per task is written to ``scripts/model_evals/results/<config.model.name>.csv``.

    Args:
        config: Hydra configuration; ``config.model`` is passed to ``load_model``
            and ``config.model.name`` names the output file.
    """
    max_questions_per_task = 100
    model, prompter = load_model(config.model)

    results = list()
    simple_results = dict()
    for eval_name, question_load_fn in tqdm(list(TASK_QUESTION_LOAD_FNS.items())):
        questions = question_load_fn()
        questions = balance_yes_no_answers(questions)[:max_questions_per_task]
        answers = answer_questions(questions, prompter, model)  # type: ignore
        simple_results[eval_name] = answers.accuracy

        # The summed answer probability is only recorded for multiple-choice
        # models whose name does not contain 'gpt'; other models get None.
        if model.multiple_choice and 'gpt' not in model.name:
            total_valid_answer_prob = np.array([a.completion.total_answer_prob for a in answers]).mean()
        else:
            total_valid_answer_prob = None

        results.append(dict(
            eval_name=eval_name,
            accuracy=answers.accuracy,
            valid_answer_fraction=np.array(answers.valid_answers).mean(),
            total_valid_answer_prob=total_valid_answer_prob,
            average_yes_answer_prob=np.array(answers.positive_answer_probs).mean(),
            n_questions=len(questions),
            # Key spelling ("quesion") is kept as-is: it becomes a CSV column
            # name, so renaming it would change the output format.
            quesion_yes_answer_fraction=np.array([q.label == 'Yes' for q in questions]).mean()
        ))
    pprint(simple_results)
    pprint(model.info)

    output_path = Path('scripts/model_evals/results') / f'{config.model.name}.csv'
    # Create the file's own parent with parents=True so that missing
    # intermediate directories (including ones introduced by a model name that
    # contains a path separator) do not make the write fail.
    output_path.parent.mkdir(parents=True, exist_ok=True)
    pd.DataFrame.from_records(results).to_csv(output_path, index=False)

# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()