import json
import pprint
import ipdb
import os
import tiktoken

# Task domains to cost out; each domain has its own JSON files inside every
# folder below (``<folder>/<domain>_*.json``).
domains = [
    'translate', 'qa', 'chat', 'summary', 'harmlessness',
    'math_cot', 'math_pot', 'code_exec', 'code_not_exec',
]

# Data split directories, resolved relative to the current working directory.
folders = ['sub_dev_data', 'sub_test_data']

# Single shared tokenizer; every compute_* helper counts tokens with it.
encoding = tiktoken.get_encoding('cl100k_base')

def compute_correction(data):
    """Count prompt and completion tokens for the correction step.

    For every sample the prompt is the question (for the ``summary`` domain,
    the question immediately followed by the article) with the feedback
    appended; the completion is ``sample['correction']``.

    Note: this reads the module-level ``domain`` (the driver loop's variable)
    and the module-level ``encoding``; it is not self-contained.

    Args:
        data: iterable of sample dicts.

    Returns:
        ``(total_input_tokens, total_output_tokens)`` summed over ``data``.
    """
    def n_tokens(text):
        return len(encoding.encode(text))

    total_in = total_out = 0
    for item in data:
        if domain != 'summary':
            prompt = item['question']
        else:
            prompt = item['question'] + item['article']
        prompt += item['feedback']
        completion = item['correction']
        total_in += n_tokens(prompt)
        total_out += n_tokens(completion)
    return total_in, total_out


def compute_feedback(data, *, include_critique=False):
    """Count prompt and completion tokens for the generation step.

    For every sample the prompt is the question (for the ``summary`` domain,
    the question immediately followed by the article) and the completion is
    ``sample['generation']``.

    Note: this reads the module-level ``domain`` (the driver loop's variable)
    and the module-level ``encoding``; it is not self-contained.

    Args:
        data: iterable of sample dicts.
        include_critique: when True, also count a second (critique) call per
            sample whose input is the prompt plus the generation and whose
            output is ``sample['feedback']``. Off by default, which matches
            the previous behaviour where this stage was commented out (and
            avoids tokenizing the feedback just to discard it).

    Returns:
        ``(total_input_tokens, total_output_tokens)`` summed over ``data``.
    """
    in_tokens, out_tokens = [], []
    for sample in data:
        if domain == 'summary':
            question = sample['question'] + sample['article']
        else:
            question = sample['question']
        response = sample['generation']
        in_token = len(encoding.encode(question))
        out_token = len(encoding.encode(response))
        in_tokens.append(in_token)
        out_tokens.append(out_token)

        if include_critique:
            # The critique call sees the question and the generation, and
            # produces the feedback text.
            in_tokens.append(in_token + out_token)
            out_tokens.append(len(encoding.encode(sample['feedback'])))

    return sum(in_tokens), sum(out_tokens)


def compute_comp_feedback(data, *, include_generation=False, include_critique=False):
    """Count tokens for the pairwise comparison-feedback step.

    Each sample carries two candidate generations and a feedback text under
    ``sample['sub']`` (keys ``generation_a``, ``generation_b``, ``feedback``).
    The prompt is the question (for the ``summary`` domain, the question
    immediately followed by the article).

    Both stages are disabled by default, which reproduces the previous
    behaviour (every append was commented out, so the result was always
    ``(0, 0)``) without tokenizing every sample just to discard the counts.

    Note: when a stage is enabled this reads the module-level ``domain`` (the
    driver loop's variable) and the module-level ``encoding``.

    Args:
        data: iterable of sample dicts.
        include_generation: count the prompt as input and the mean length of
            the two generations as output (the mean is a float).
        include_critique: count prompt + both generations as input and
            ``sample['sub']['feedback']`` as output.

    Returns:
        ``(total_input_tokens, total_output_tokens)`` summed over ``data``.
    """
    if not (include_generation or include_critique):
        # Nothing is counted, so skip tokenization entirely.
        return 0, 0

    in_tokens, out_tokens = [], []
    for sample in data:
        if domain == 'summary':
            question = sample['question'] + sample['article']
        else:
            question = sample['question']
        sub = sample['sub']
        len_a = len(encoding.encode(sub['generation_a']))
        len_b = len(encoding.encode(sub['generation_b']))
        question_len = len(encoding.encode(question))

        if include_generation:
            in_tokens.append(question_len)
            out_tokens.append((len_a + len_b) / 2)

        if include_critique:
            # The critique call sees the question and both generations.
            in_tokens.append(question_len + len_a + len_b)
            out_tokens.append(len(encoding.encode(sub['feedback'])))

    return sum(in_tokens), sum(out_tokens)


# Per-folder, per-domain token usage:
#   rest[folder][domain][dimension] -> {'in_token': ..., 'out_token': ...}
# or None when that dimension was not computed for the domain.
rest = {}

for folder in folders:
    folder_stats = rest.setdefault(folder, {})
    # NOTE: keep this loop variable named `domain` -- the compute_* helpers
    # read the module-level `domain` to decide whether to prepend the article.
    for domain in domains:
        domain_stats = folder_stats.setdefault(
            domain, {'feedback': None, 'comp_feedback': None, 'correction': None}
        )

        # Feedback and correction are both computed from one input file.
        path = f'{folder}/{domain}_feedback_correction.json'
        with open(path, encoding='utf-8') as f:
            data = json.load(f)
        in_sum_, out_sum_ = compute_feedback(data)
        domain_stats['feedback'] = {'in_token': in_sum_, 'out_token': out_sum_}
        try:
            in_sum_c, out_sum_c = compute_correction(data)
        except Exception as err:
            # Best effort: report which file failed and why, then move on to the
            # next domain (this also skips the domain's comparison file).
            print(path, repr(err))
            continue
        domain_stats['correction'] = {'in_token': in_sum_c, 'out_token': out_sum_c}

        # Comparison feedback.
        path = f'{folder}/{domain}_comp_feedback.json'
        with open(path, encoding='utf-8') as f:
            data = json.load(f)
        in_sum_, out_sum_ = compute_comp_feedback(data)
        domain_stats['comp_feedback'] = {'in_token': in_sum_, 'out_token': out_sum_}

# Price per token for input ('in') and output ('out') tokens: 10 and 30 per
# million tokens respectively (currency presumably USD -- confirm).
_PER_MILLION = 1_000_000
price = {'in': 10 / _PER_MILLION, 'out': 30 / _PER_MILLION}

# Sum the token counts of every computed dimension per folder, price them,
# print each folder's cost, then print the total divided by the domain count.
final_cost = 0
for folder in folders:
    overall_in_token, overall_out_token = 0, 0
    for domain in rest[folder]:
        for dimension in ('feedback', 'correction', 'comp_feedback'):
            # A dimension is missing or None when it was never computed (e.g. the
            # correction step failed for this domain); skip it explicitly rather
            # than hiding every error behind a bare except.
            stats = rest[folder][domain].get(dimension)
            if stats is None:
                continue
            overall_in_token += stats['in_token']
            overall_out_token += stats['out_token']
    folder_cost = overall_in_token * price['in'] + overall_out_token * price['out']
    print(f'{folder}:', folder_cost)
    final_cost += folder_cost
# NOTE(review): previously a hard-coded 9 (== len(domains)); presumably the
# average cost per domain across all folders -- confirm the intended divisor.
print(final_cost / len(domains))
