"""
Count prompt tokens (and expected response counts) for experiments 1-3 using the Hugging Face GPT-2 tokenizer.
"""

from tqdm import tqdm
from transformers import GPT2Tokenizer

from data import CausalDataset, MoralDataset, Example, JudgmentDatasetSchema
from prompt import CausalJudgmentPrompt, MoralJudgmentPrompt, CausalFactorPrompt, MoralFactorPrompt, CausalAbstractJudgmentPromopt, MoralAbstractJudgmentPrompt
from thought_as_text_translator import CausalTranslator, MoralTranslator

def count_token_for_exp1(*, tokenizer=None, responses_per_prompt=3):
    """Count GPT-2 prompt tokens for experiment 1 (direct judgment prompts).

    Every example of ``CausalDataset`` is rendered with
    ``CausalJudgmentPrompt`` and every example of ``MoralDataset`` with
    ``MoralJudgmentPrompt``. Each resulting prompt text is tokenized and its
    token count is added to a running total.

    Args:
        tokenizer: Callable returning a mapping with an ``'input_ids'``
            sequence (e.g. a Hugging Face tokenizer). Defaults to the
            pretrained ``"gpt2"`` ``GPT2Tokenizer``, loaded only when none is
            given.
        responses_per_prompt: Number of responses budgeted per prompt
            (presumably the number of completions requested per prompt --
            confirm against the experiment runner). Defaults to 3, the value
            previously hard-coded here.

    Returns:
        Tuple ``(total_tokens, response_count)``. Both values are also
        printed.
    """
    cd = CausalDataset()
    md = MoralDataset()

    prompt = CausalJudgmentPrompt()
    if tokenizer is None:
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    total = 0
    response_cnt = 0
    for ex in tqdm(cd):
        # apply() returns (prompt_text, answer); only the text is counted.
        q, _ = prompt.apply(ex)
        total += len(tokenizer(q)['input_ids'])
        response_cnt += responses_per_prompt

    prompt = MoralJudgmentPrompt()
    for ex in tqdm(md):
        q, _ = prompt.apply(ex)
        total += len(tokenizer(q)['input_ids'])
        response_cnt += responses_per_prompt

    print(f"total token counts: {total}")  # 41507 tokens
    print(f"total response counts: {response_cnt}")  # 618
    return total, response_cnt


def count_tokens_for_exp3(*, tokenizer=None, responses_per_prompt=3):
    """Count GPT-2 prompt tokens for experiment 3 (factor prompts).

    Factor prompts are loaded from ``./prompts/exp3_prompts/causal/`` and
    ``./prompts/exp3_prompts/moral/``. Applying a factor prompt to one
    example yields several formatted prompts, and the ``.prompt`` text of
    every one of them is tokenized and counted.

    Args:
        tokenizer: Callable returning a mapping with an ``'input_ids'``
            sequence (e.g. a Hugging Face tokenizer). Defaults to the
            pretrained ``"gpt2"`` ``GPT2Tokenizer``, loaded only when none is
            given.
        responses_per_prompt: Number of responses budgeted per formatted
            prompt (presumably the number of completions requested -- confirm
            against the experiment runner). Defaults to 3, the value
            previously hard-coded here.

    Returns:
        Tuple ``(total_tokens, response_count)``. Both values are also
        printed.
    """
    cd = CausalDataset()
    md = MoralDataset()

    # Factor prompts need the dataset's annotation helpers to render examples.
    cfp = CausalFactorPrompt(anno_utils=cd.anno_utils)
    cfp.load_prompts("./prompts/exp3_prompts/causal/")

    mfp = MoralFactorPrompt(anno_utils=md.anno_utils)
    mfp.load_prompts('./prompts/exp3_prompts/moral/')

    if tokenizer is None:
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    total = 0
    response_cnt = 0
    for ex in tqdm(cd):
        # One example expands into several formatted prompts; count each.
        for fp in cfp.apply(ex):
            total += len(tokenizer(fp.prompt)['input_ids'])
            response_cnt += responses_per_prompt

    for ex in tqdm(md):
        for fp in mfp.apply(ex):
            total += len(tokenizer(fp.prompt)['input_ids'])
            response_cnt += responses_per_prompt

    print(f"total token counts: {total}")  # 102642 tokens
    print(f"total response counts: {response_cnt}")  # 2097
    return total, response_cnt

def count_tokens_for_exp2(*, tokenizer=None, responses_per_prompt=3):
    """Count GPT-2 prompt tokens for experiment 2 (abstract judgment prompts).

    Each example is first converted with ``CausalTranslator`` /
    ``MoralTranslator.translate_example`` into ``(assembled, question,
    answer)``. That triple is then rendered with
    ``CausalAbstractJudgmentPromopt`` / ``MoralAbstractJudgmentPrompt.apply2``,
    and the resulting ``.prompt`` text is tokenized. Examples whose
    ``annotated_sentences`` is empty are skipped entirely: they add no tokens
    and no responses.

    Args:
        tokenizer: Callable returning a mapping with an ``'input_ids'``
            sequence (e.g. a Hugging Face tokenizer). Defaults to the
            pretrained ``"gpt2"`` ``GPT2Tokenizer``, loaded only when none is
            given.
        responses_per_prompt: Number of responses budgeted per prompt
            (presumably the number of completions requested -- confirm
            against the experiment runner). Defaults to 3, the value
            previously hard-coded here.

    Returns:
        Tuple ``(total_tokens, response_count)``. Both values are also
        printed.
    """
    cd = CausalDataset()
    md = MoralDataset()

    # NOTE: "Promopt" is the spelling of the class name imported from `prompt`.
    prompt = CausalAbstractJudgmentPromopt()
    if tokenizer is None:
        tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    total = 0
    response_cnt = 0
    for ex in tqdm(cd):
        if len(ex.annotated_sentences) == 0:
            continue  # un-annotated examples are not counted at all

        assembled, question, answer = CausalTranslator.translate_example(ex)
        q = prompt.apply2(assembled, question, answer).prompt
        total += len(tokenizer(q)['input_ids'])
        response_cnt += responses_per_prompt

    prompt = MoralAbstractJudgmentPrompt()
    for ex in tqdm(md):
        if len(ex.annotated_sentences) == 0:
            continue  # un-annotated examples are not counted at all

        assembled, question, answer = MoralTranslator.translate_example(ex)
        q = prompt.apply2(assembled, question, answer).prompt
        total += len(tokenizer(q)['input_ids'])
        response_cnt += responses_per_prompt

    print(f"total token counts: {total}")  # 15370 tokens
    print(f"total response counts: {response_cnt}")  # 522
    return total, response_cnt

if __name__ == '__main__':
    # Run the token counts for experiments 1, 2 and 3 in turn; each function
    # prints its own totals.
    count_token_for_exp1()
    count_tokens_for_exp2()
    count_tokens_for_exp3()