import os
import numpy as np
import pickle
import json
import glob
from tqdm import tqdm
import pandas as pd 

from langid.langid import LanguageIdentifier, model
from googletrans import Translator
import asyncio
import re 
import numpy as np
from langid.langid import LanguageIdentifier, model


# Region that each evaluation language is associated with; summarize() compares
# a record's 'region' against this mapping to build the 'associate' flag.
lang_region = dict(
    english='USA',
    hindi='India',
    japanese='Japan',
    swahili='Kenya',
    thai='Thailand',
)

# Region names in the same order as the languages above.
countries = list(lang_region.values())

# Shared langid classifier used by evaluate(), built from the model object
# imported from langid.langid. norm_probs=True presumably makes classify()
# return a normalised probability as its confidence -- confirm against the
# langid documentation.
identifier = LanguageIdentifier.from_modelstring(model, norm_probs=True)
# Language codes handed to the classifier as its candidate set.
# NOTE(review): 'zh' is listed here but has no entry in lang_code, so a 'zh'
# prediction can never count as a language match in evaluate() -- confirm
# this is intentional.
identifier.set_languages(['en', 'hi', 'zh', 'ja', 'sw', 'th'])

# Language name -> language code; evaluate() compares the code returned by the
# classifier against the entry for each record's language.
lang_code = dict(
    english='en',
    hindi='hi',
    japanese='ja',
    swahili='sw',
    thai='th',
)


async def translate_text(text, dest='ja'):
    """Translate `text` into the `dest` language with a googletrans Translator.

    A fresh Translator is opened for the call and closed again before the
    result is returned.

    Args:
        text: Text to translate.
        dest: Destination language code passed to ``Translator.translate``.

    Returns:
        Whatever ``Translator.translate`` resolves to.
    """
    async with Translator() as client:
        translation = await client.translate(text, dest)
    return translation

async def detect_lang(text):
    """Detect the language of `text` with a googletrans Translator.

    A fresh Translator is opened for the call; the detection result is
    returned from inside the context, which is then closed.

    Args:
        text: Text whose language should be detected.

    Returns:
        Whatever ``Translator.detect`` resolves to.
    """
    async with Translator() as client:
        detection = await client.detect(text)
        return detection

def remove_tokens(text):
    """Return `text` cut at the first '<answer>' marker with all <...> tags removed.

    Everything from the first '<answer>' onwards is discarded, then every
    remaining ``<...>`` span (at least one character between the brackets)
    is deleted.
    """
    # Keep only what precedes the answer section, if there is one.
    head = text.split('<answer>')[0] if '<answer>' in text else text
    return re.sub(r'<[^>]+>', '', head)

def _candidate_answers(answer_list):
    """Expand gold answers into a flat list of lowercase candidate strings."""
    candidates = []
    for entry in answer_list:
        # An entry may itself be a list of alternatives; every alternative is
        # accepted instead of stopping in the debugger on multi-item lists.
        alternatives = entry if isinstance(entry, (list, tuple)) else [entry]
        for alt in alternatives:
            alt = str(alt).lower().strip()
            if '(' in alt and ')' in alt:
                # "name (alias)" also matches on "name" and on "alias" alone.
                candidates.append(alt.split('(')[0].strip())
                candidates.append(alt.split('(')[1].split(')')[0].strip())
            candidates.append(alt)
    # An empty candidate (e.g. from "(alias) name") is a substring of every
    # prediction and would mark everything correct, so drop it.
    return [cand for cand in candidates if cand]


def get_answer(answer, answer_list, verbose=False):
    """Score a predicted answer against the list of accepted answers.

    Matching is case-insensitive and deliberately lenient. A prediction is
    correct when it equals a candidate, when the text before or inside its
    first parentheses equals a candidate, or when any candidate occurs as a
    substring of the prediction.

    Args:
        answer: Predicted answer. ``None`` or NaN scores 0.0; any other
            non-string value is compared through ``str()``.
        answer_list: Accepted answers, either as a list (entries may be
            strings or lists of alternative strings) or as the string form of
            a Python literal list.
        verbose: When true, print the prediction and the candidates whenever
            no exact match is found.

    Returns:
        1.0 if the prediction is accepted, otherwise 0.0.

    Raises:
        ValueError, SyntaxError: If `answer_list` is a string that is not a
            valid Python literal.
    """
    if isinstance(answer_list, str):
        # literal_eval parses literals only, so a crafted string in a data
        # file cannot execute code the way eval() would.
        import ast
        answer_list = ast.literal_eval(answer_list)
    candidates = _candidate_answers(answer_list)

    # A missing prediction (None / NaN from pandas) is simply wrong.
    if answer is None or (isinstance(answer, float) and answer != answer):
        return 0.0
    answer = str(answer).lower().strip()

    if answer in candidates:
        return 1.0

    if '(' in answer and ')' in answer:
        outside = answer.split('(')[0].strip()
        inside = answer.split('(')[1].split(')')[0].strip()
        if outside in candidates or inside in candidates:
            return 1.0

    if verbose:
        print(answer, candidates)

    # Last resort: accept the prediction if it contains any candidate.
    if any(cand in answer for cand in candidates):
        return 1.0
    return 0.0

def evaluate(data, translate=False, force=False):
    """Attach language, answer and reasoning metrics to `data` and return it.

    Columns are read and written with a ``translated_`` prefix when
    `translate` is true, and without a prefix otherwise. The frame is
    modified in place. If the ``joint_reasoning`` column for the chosen
    prefix already exists and `force` is false, the frame is returned as is.

    Added columns (all prefixed): ``pred_language``, ``language_match``,
    ``lang_conf``, ``answer_correctness``, ``reasoning_score`` and
    ``joint_reasoning``.
    """
    prefix = 'translated_' if translate else ''

    # Already evaluated: nothing to do unless a re-run is forced.
    if not force and (prefix + 'joint_reasoning') in data:
        return data

    detected_codes = []
    confidences = []
    match_flags = []
    correctness = []
    for record in data.to_dict(orient='records'):
        generation = record[prefix + 'llm_generation']
        # Language accuracy: classify the generation with tags stripped.
        code, confidence = identifier.classify(remove_tokens(generation))
        detected_codes.append(code)
        confidences.append(confidence)
        expected_code = lang_code[record['language'].lower()]
        match_flags.append(float(code == expected_code))

        gold_answers = record['answer_list']
        predicted = record[prefix + 'extracted_answer']
        correctness.append(get_answer(predicted, gold_answers, verbose=False))

    data[prefix + 'pred_language'] = detected_codes
    data[prefix + 'language_match'] = match_flags
    data[prefix + 'lang_conf'] = confidences
    data[prefix + 'answer_correctness'] = correctness

    data[prefix + 'reasoning_score'] = [
        float(judgement['reasoning_score'])
        for judgement in data[prefix + 'evaluation'].values
    ]
    # The joint reasoning score is the final metric. It multiplies the
    # reasoning score by the classifier confidence (lang_conf), not by
    # language_match.
    data[prefix + 'joint_reasoning'] = (
        data[prefix + 'reasoning_score'] * data[prefix + 'lang_conf']
    )

    return data



def _mean_row(frame, suffix):
    """Return a one-row frame of numeric column means labelled ``<col>/<suffix>``."""
    # numeric_only=True keeps leftover text columns from raising on pandas >= 2.
    means = frame.mean(numeric_only=True)
    row = pd.DataFrame(means).T.drop(columns='associate')
    return row.rename(columns=lambda c: f"{c}/{suffix}")


def summarize(data):
    """Collapse an evaluated result frame into one dict of mean metrics.

    Means are computed over all rows (``/all``), over rows whose region is the
    one associated with the record's language (``/assoc``) and over the
    remaining rows (``/nonassoc``). ``language_match`` and
    ``answer_correctness`` are scaled by 100 and ``reasoning_score`` by 10
    before averaging. The model name is stored under ``_model_name``.

    Side effect: when an ``evaluation`` column is present, ``reasoning_score``
    (and, if the judge supplied them, ``llm_answer_correctness`` and
    ``llm_language_mismatch``) are written into the caller's frame.

    Args:
        data: Frame produced by ``evaluate``; it must contain the columns
            dropped below plus ``region``, ``language``, ``language_match``,
            ``answer_correctness`` and ``reasoning_score``.

    Returns:
        Dict mapping ``<metric>/<group>`` to its mean, plus ``_model_name``.
    """
    # Positional access: the first row need not carry index label 0.
    model_name = data['model_name'].iloc[0]

    # Bound up front so the scaling step below can test them on every path.
    llm_answer = []
    llm_lang = []
    if 'evaluation' in data:  # judge output is available
        reasoning = []
        for judgement in data['evaluation'].values:
            reasoning.append(float(judgement['reasoning_score']))
            if 'answer_correctness' in judgement:
                llm_answer.append(float(judgement['answer_correctness']))
                llm_lang.append(float(judgement['language_mismatch']))
        data['reasoning_score'] = reasoning
        if llm_answer:
            data['llm_answer_correctness'] = llm_answer
            data['llm_language_mismatch'] = llm_lang
    has_llm_scores = bool(llm_answer)

    # Drop columns that are not metrics.
    data = data.drop(columns=['question', 'answer_list', 'reasoning',
       'topic', 'llm_generation', 'evaluation', 'answer_type', 'model_name', 'extracted_reasoning', 'pred_language','extracted_answer', 'lang_conf'])

    # lang_region is keyed by lower-case names and evaluate() lower-cases the
    # language before its lookup, so do the same here.
    expected_region = data['language'].str.lower().map(lang_region)
    data['associate'] = (data['region'] == expected_region).astype(float)
    data['language_match'] = data['language_match'].astype(float)

    data = data.drop(columns=['region', 'language'])

    data['language_match'] = 100 * data['language_match']
    data['answer_correctness'] = 100 * data['answer_correctness']
    data['reasoning_score'] = 10 * data['reasoning_score']
    if has_llm_scores:
        data['llm_answer_correctness'] = 100 * data['llm_answer_correctness']
        data['llm_language_mismatch'] = 100 * data['llm_language_mismatch']

    if 'translated_llm_generation' in data:
        data = data.drop(columns=['translated_llm_generation', 'translated_extracted_answer'])
        if 'translated_pred_language' in data:
            data = data.drop(columns=['translated_pred_language', 'translated_lang_conf'])
        if 'translated_evaluation' in data:
            data = data.drop(columns=['translated_evaluation'])
            data['translated_language_match'] = 100 * data['translated_language_match']
            data['translated_answer_correctness'] = 100 * data['translated_answer_correctness']
            data['translated_reasoning_score'] = 10 * data['translated_reasoning_score']

    # Stats for all rows, then split by the associate flag.
    datum = pd.concat(
        [
            _mean_row(data, 'all'),
            _mean_row(data[data['associate'] == 1], 'assoc'),
            _mean_row(data[data['associate'] == 0], 'nonassoc'),
        ],
        axis=1,
    )

    datum['_model_name'] = model_name
    return datum.iloc[0].to_dict()





if __name__ == '__main__':

    # Directories whose *.jsonl result files are evaluated and rewritten in
    # place. The original loop also referenced an undefined `root2`, which
    # raised NameError before any file was processed.
    roots = ['geofact_verified/results']
    force = False

    summary = []
    for root in roots:
        # Sorted so the row order of the summary is reproducible.
        for path in sorted(glob.glob(f'{root}/*.jsonl')):
            print(path)
            data = pd.read_json(path, lines=True)
            data = evaluate(data, force=force)
            if 'translated_evaluation' in data:
                data = evaluate(data, translate=True, force=force)
            _sum = summarize(data)
            # Written after summarize() so the columns it adds to `data` are
            # persisted alongside the evaluation columns.
            data.to_json(path, lines=True, orient='records')
            summary.append(_sum)
    summary = pd.DataFrame(summary)
    summary = summary[sorted(summary.columns, key=str.lower)]
    summary.to_csv('geofact_summary.csv', index=False)

