import re
import os
import glob
import math
import numpy as np
import pandas as pd
from chatgpt import *


###########################################
## function 
###########################################
def is_eng(ch):
    """Return True when ``ch`` falls in the ASCII letter ranges a-z or A-Z."""
    is_lowercase = 'a' <= ch <= 'z'
    is_uppercase = 'A' <= ch <= 'Z'
    return is_lowercase or is_uppercase

def is_chi(ch):
    """Return True when ``ch`` lies between U+4E00 and U+9FFF inclusive."""
    lower_bound, upper_bound = '\u4e00', '\u9fff'
    return lower_bound <= ch <= upper_bound

def func_str_all_eng(str):
    """Return True when no character of ``str`` is classified by ``is_chi``.

    Characters that are neither English nor Chinese (digits, '~', spaces, ...)
    do not affect the result, and an empty input yields True.
    """
    return not any(is_chi(ch) for ch in str)

def func_str_all_chi(str):
    """Return True when no character of ``str`` is classified by ``is_eng``.

    Characters that are neither English nor Chinese (digits, '~', spaces, ...)
    do not affect the result, and an empty input yields True.
    """
    return all(not is_eng(ch) for ch in str)


# "['aa', 'bb', 'cc', 'dd']" => ['aa', 'bb', 'cc', 'dd']
def string_to_list(str):
    """Parse a stringified label list into a list of stripped label strings.

    Example: ``"['aa', 'bb', 'cc', 'dd']"`` -> ``['aa', 'bb', 'cc', 'dd']``.

    Args:
        str: a list (returned unchanged), an empty string or missing value
            (``None`` / NaN, both giving ``[]``), or text holding labels
            separated by commas and/or quotes, optionally wrapped in
            ``[`` ... ``]``.

    Returns:
        The list of non-empty labels with surrounding whitespace removed.
    """
    if isinstance(str, list):
        return str

    # Empty text and missing values both mean "no labels".
    if str == '' or pd.isna(str):
        return []

    # Trim first so a padded "[...]" is still recognised, and use
    # startswith/endswith so a lone '[' cannot trigger an IndexError on the
    # now-empty remainder.
    text = str.strip()
    if text.startswith('['):
        text = text[1:]
    if text.endswith(']'):
        text = text[:-1]

    # Quotes and commas are all treated as separators; empty pieces are dropped.
    pieces = (item.strip() for item in re.split('[\'\",]', text))
    return [piece for piece in pieces if piece]


# "[['aa', 'bb'], ['cc', 'dd'], ['ee']]"  => [['aa', 'bb'], ['cc', 'dd'], ['ee']]
def listlist_to_list(str):
    """Parse a stringified list of lists into nested lists of label strings.

    Example: ``"[['aa', 'bb'], ['cc', 'dd'], ['ee']]"``
    -> ``[['aa', 'bb'], ['cc', 'dd'], ['ee']]``.

    The text is first cut at every square bracket, then each remaining chunk is
    cut at single/double quotes. Pieces that are blank or a lone comma are
    separators and are discarded.

    Args:
        str: text representation of the nested list.

    Returns:
        One inner list per bracketed group, each holding its labels with
        surrounding whitespace removed.
    """
    separators = ('', ',')
    results = []
    for chunk in re.split(r"[\[\]]", str):
        if chunk.strip() in separators:
            continue
        # Strip each label so it compares equal to labels produced by
        # string_to_list, which are stripped as well.
        labels = [
            item.strip()
            for item in re.split('[\'\"]', chunk)
            if item.strip() not in separators
        ]
        results.append(labels)
    return results


# map each label into its group ID
def func_map_label_to_synonym(mlist, mapping):
    """Replace every label in ``mlist`` by the value ``mapping`` holds for it.

    Args:
        mlist: iterable of labels.
        mapping: container queried with ``in`` and ``[]`` for each label.

    Returns:
        A list with ``mapping[label]`` for each label, in input order.

    Raises:
        Exception: when at least one label is absent from ``mapping``; the
            whole input is scanned before the error is raised.
    """
    resolved = []
    found_unknown = False
    for label in mlist:
        if label not in mapping:
            found_unknown = True
            continue
        resolved.append(mapping[label])

    if found_unknown:
        raise Exception('some labels does not in the mapping!')
    return resolved


# read *.npy -> map
def func_get_name2reason(reason_root):
    """Load every ``*.npy`` file directly inside ``reason_root`` into a dict.

    Args:
        reason_root: directory to scan (non-recursive).

    Returns:
        Dict mapping each file name without its 4-character ``.npy`` suffix to
        the file content converted with ``ndarray.tolist()``.
    """
    npy_paths = glob.glob(reason_root + '/*.npy')
    return {
        os.path.basename(npy_path)[:-4]: np.load(npy_path).tolist()
        for npy_path in npy_paths
    }


# judge whether the extraction process is complete
def func_judge_extract_complete(extract_root, expected_num=None):
    """Check whether ``extract_root`` holds a complete set of non-empty results.

    The folder is complete when it exists, contains exactly the expected
    number of entries, and none of the entries loads (via ``np.load``) to the
    empty string.

    Args:
        extract_root: directory holding one ``.npy`` result per sample.
        expected_num: required number of entries. When omitted, the
            module-level global ``samplenum`` is used, so existing one-argument
            callers behave as before.

    Returns:
        True when the extraction is complete, False otherwise.
    """
    print (f'test on {extract_root}')

    if not os.path.exists(extract_root):
        return False

    if expected_num is None:
        # Fallback for existing callers: the script entry point defines this global.
        expected_num = samplenum

    names = os.listdir(extract_root)
    if len(names) != expected_num:
        return False

    for name in names:
        val = np.load(os.path.join(extract_root, name)).tolist()
        # An empty string marks a sample whose extraction produced nothing.
        if val == '':
            return False
    return True


###########################################
## Main process 
###########################################
# translate openset labels into another language
def translate_openset(openset_root, gptmodel, source='chi', target='eng', version='hard'):
    """Translate each stored open-set label list and save it to ``<openset_root>-translate``.

    For every entry returned by ``func_get_name2reason(openset_root)`` the
    labels are parsed with ``string_to_list`` and filtered: chi->eng keeps the
    labels accepted by ``func_str_all_chi``, eng->chi those accepted by
    ``func_str_all_eng``. The remaining labels are joined with '~' and sent to
    the translation helper in batches of 10. If the number of returned pieces
    differs from the number of labels, the whole sample is retried one label
    at a time.

    Args:
        openset_root: directory with one ``<name>.npy`` label list per sample.
        gptmodel: model identifier forwarded to the translation helper.
        source: source language, 'chi' or 'eng'.
        target: target language; must be the other one of 'chi' / 'eng'.
        version: what to do when the translated count still mismatches:
            'hard' raises, 'soft' stores the filtered untranslated labels, and
            any other value stores nothing for that sample.

    Raises:
        Exception: for an unsupported source/target pair (raised once a
            non-empty batch is reached), or on a count mismatch in 'hard' mode.
    """
    name2openset = func_get_name2reason(openset_root)

    store_root = openset_root + '-translate'
    os.makedirs(store_root, exist_ok=True)

    for name, raw_openset in name2openset.items():
        openset = string_to_list(raw_openset)

        ## openset filter
        if source == 'chi' and target == 'eng':
            openset = [label for label in openset if func_str_all_chi(label)]
        elif source == 'eng' and target == 'chi':
            openset = [label for label in openset if func_str_all_eng(label)]
        label_num = len(openset)

        # main process: batches of 10 first, then a one-by-one retry
        openset_translate = []
        for splitnum in [10, 1]:
            openset_translate = []
            for start in range(0, label_num, splitnum):
                openset_subset = "~".join(openset[start:start + splitnum])
                all_eng_flag = func_str_all_eng(openset_subset)
                all_chi_flag = func_str_all_chi(openset_subset)
                assert all_eng_flag or all_chi_flag, f'two lang in {openset_subset}'

                # NOTE(review): get_translate_chi2eng / get_translate_eng2chi are not
                # defined in this file; presumably they come from the `chatgpt`
                # wildcard import - confirm.
                if source == 'chi' and target == 'eng':
                    assert all_chi_flag
                    subset_translate = get_translate_chi2eng(openset_subset, model=gptmodel)
                elif source == 'eng' and target == 'chi':
                    assert all_eng_flag
                    subset_translate = get_translate_eng2chi(openset_subset, model=gptmodel)
                else:
                    raise Exception('no support source/target lang')

                # The reply is cut at '~', ',' and '.'; blank pieces are dropped.
                pieces = (item.strip() for item in re.split('[~,.]', subset_translate))
                openset_translate.extend(piece for piece in pieces if piece)
            if label_num == len(openset_translate):
                break

        save_path = os.path.join(store_root, f'{name}.npy')
        # Explicit check instead of assert + bare except: it still runs under
        # `python -O` and no longer hides unrelated errors or KeyboardInterrupt.
        if label_num == len(openset_translate):
            openset_translate = [item.lower() for item in openset_translate]
            # Order-preserving de-duplication keeps the saved file deterministic.
            openset_translate = list(dict.fromkeys(openset_translate))
            print (openset_translate)
            np.save(save_path, openset_translate)
        elif version == 'hard':
            raise Exception(f"Error translate on openset: {openset}!!")
        elif version == 'soft':
            # Best effort: keep the filtered, untranslated labels.
            np.save(save_path, openset)


# reason -> openset labels
def generate_openset_from_reason(reason_root, store_root, gptmodel):
    """Turn each stored reason into open-set labels and save them as ``.npy``.

    The language is taken from the last three characters of ``reason_root`` and
    must be 'chi' or 'eng'. Every entry loaded by ``func_get_name2reason`` is
    passed to ``get_reason_to_openset`` and the response is written to
    ``<store_root>/<name>.npy``.

    Args:
        reason_root: directory of ``.npy`` reasons whose path ends in the language tag.
        store_root: output directory, created when missing.
        gptmodel: model identifier forwarded to ``get_reason_to_openset``.
    """
    if not os.path.exists(store_root):
        os.makedirs(store_root)

    reasons_by_name = func_get_name2reason(reason_root)
    language_tag = reason_root[-3:]
    assert language_tag in ['chi', 'eng']

    for sample_name, sample_reason in reasons_by_name.items():
        labels = get_reason_to_openset(sample_reason, model=gptmodel, lang=language_tag)
        np.save(os.path.join(store_root, f'{sample_name}.npy'), labels)


# calculate synonym
def generate_openset_synonym(gt_root, openset_root, synonym_root, gptmodel):
    """Ask the model to group ground-truth and predicted labels into synonyms.

    Both folders must contain exactly ``samplenum`` (module-level global)
    entries. For each ground-truth sample, the lower-cased gt and prediction
    label lists are passed to ``get_openset_synonym`` and the response is saved
    to ``<synonym_root>/<name>.npy``.

    Args:
        gt_root: directory with the ground-truth label lists.
        openset_root: directory with the predicted label lists.
        synonym_root: output directory, created when missing.
        gptmodel: model identifier forwarded to ``get_openset_synonym``.

    Raises:
        Exception: when either folder does not hold ``samplenum`` entries.
    """
    ## ground-truth label sets
    gt_by_name = func_get_name2reason(gt_root)
    if len(gt_by_name) != samplenum:
        raise Exception("Incorrect openset gt folder!!")

    ## predicted label sets
    pred_by_name = func_get_name2reason(openset_root)
    if len(pred_by_name) != samplenum:
        raise Exception("Incorrect number of predictions!!")

    ## output folder
    if not os.path.exists(synonym_root):
        os.makedirs(synonym_root)

    ## one synonym query per sample
    for sample_name, gt_raw in gt_by_name.items():
        gt_labels = string_to_list(gt_raw)
        pred_labels = string_to_list(pred_by_name[sample_name])
        # both sides are compared in lower case
        gt_lower = [label.lower() for label in gt_labels]
        pred_lower = [label.lower() for label in pred_labels]
        synonym_groups = get_openset_synonym(gt_lower, pred_lower, model=gptmodel)
        np.save(os.path.join(synonym_root, f'{sample_name}.npy'), synonym_groups)


# calculate accuracy and recall
def calculate_openset_overlap_rate(gt_root, openset_root, synonym_root):
    """Compute the mean set-level accuracy and recall of predicted label sets.

    For each sample the stored synonym groups are parsed with
    ``listlist_to_list`` and every label is mapped to the first label of its
    group. Ground-truth and predicted labels are lower-cased, mapped through
    that table and compared as sets:
    ``accuracy = |gt & pred| / |pred|`` and ``recall = |gt & pred| / |gt|``.
    A sample is skipped (and counted in the printed error totals) when any of
    its gt or pred labels is missing from the synonym table.

    Args:
        gt_root: directory with the ground-truth label lists.
        openset_root: directory with the predicted label lists.
        synonym_root: directory with the synonym groups.

    Returns:
        Tuple ``(mean accuracy, mean recall)`` over the retained samples,
        computed with ``np.mean``.

    Raises:
        Exception: when the prediction or synonym folder does not hold
            ``samplenum`` (module-level global) entries.
    """
    name2gt      = func_get_name2reason(gt_root)
    name2pred    = func_get_name2reason(openset_root)
    name2mapping = func_get_name2reason(synonym_root)
    if len(name2pred) != samplenum or len(name2mapping) != samplenum:
        raise Exception("Incorrect number!!")

    # main
    gt_error, pred_error = 0, 0
    accuracy, recall = [], []
    for name, mapping in name2mapping.items():

        # => synonym_map: every label points at the first label of its group
        synonym_map = {}
        for one_list in listlist_to_list(mapping):
            for label in one_list:
                synonym_map[label] = one_list[0]

        # map into group ID
        gt = [item.lower() for item in string_to_list(name2gt[name])]
        try:
            gt = set(func_map_label_to_synonym(gt, synonym_map))
        except Exception:
            # A gt label is missing from the synonym groups: drop the sample.
            gt_error += 1
            continue

        pred = [item.lower() for item in string_to_list(name2pred[name])]
        try:
            pred = set(func_map_label_to_synonym(pred, synonym_map))
        except Exception:
            # A predicted label is missing from the synonym groups: drop the sample.
            pred_error += 1
            continue

        # An empty side scores 0 instead of dividing by zero (previously an
        # empty gt with a non-empty pred raised ZeroDivisionError).
        overlap = len(gt & pred)
        accuracy.append(overlap / len(pred) if pred else 0)
        recall.append(overlap / len(gt) if gt else 0)
    print ('gt_error:', gt_error, '  pred_error: ', pred_error)
    print ('process number (after filter): ', len(accuracy))
    return np.mean(accuracy), np.mean(recall)


# Script entry point. Stage 1 turns stored descriptions into open-set labels
# (and translates the Chinese ones); stage 2 builds synonym groups against the
# checked label sets and prints set-level accuracy / recall. Every stage reads
# the .npy folders written by the previous one, so the order below matters.
if __name__ == '__main__':
    
    # root
    main_root = './'
    step_root = './manual_results_step10'
    # Expected number of samples per folder. This becomes a module-level global
    # that func_judge_extract_complete, generate_openset_synonym and
    # calculate_openset_overlap_rate read directly.
    samplenum = 332

    # gpt models
    gptmodel = 'gpt-3.5-turbo-16k-0613'

    # we take 'Chat-UniVi-main/output-reason-7b-chi' for example
    # => output-reason-7b-chi (S1); 
    # => output-reason-7b-nosubtitle-chi (S0); 
    # => output-reason-7b-nosubtitle-addsub-chi (S2)
    # Each entry is (reason folder ending in '-chi', LaTeX table-row prefix).
    # The second element is not used by the loops below.
    output_roots = [
        (os.path.join(main_root, 'Chat-UniVi-main/output-reason-7b-chi'),                      'Chat-UniVi(vt) \\cite{jin2023chat}  & $\\surd$   & $\\surd$  & $\\times$'),
        (os.path.join(main_root, 'Chat-UniVi-main/output-reason-7b-nosubtitle-chi'),           'Chat-UniVi(v) \\cite{jin2023chat}   & $\\times$  & $\\surd$  & $\\times$'),
        (os.path.join(main_root, 'Chat-UniVi-main/output-reason-7b-nosubtitle-addsub-chi'),    'Chat-UniVi(v) \\cite{jin2023chat}   & $\\surd$   & $\\surd$  & $\\times$'),
    ]
  

    ## descriptions -> open-set labels
    for chi_root, _ in output_roots:
            
            # two independent runs per folder, stored under -run0 / -run1
            for run_time in [0, 1]:

                ## ------ process for chi openset extraction C -> c -> c_t ------
                # chi_root[:-4] drops the trailing '-chi' so the language suffix can be swapped
                reason_root = chi_root[:-4] + '-chi' 
                store_root = reason_root + f'-run{run_time}-openset'
                generate_openset_from_reason(reason_root, store_root, gptmodel)
                print (func_judge_extract_complete(store_root))
                    
                # same path that translate_openset writes to (store_root + '-translate')
                save_root = reason_root + f'-run{run_time}-openset-translate'
                translate_openset(store_root, gptmodel, source='chi', target='eng', version='soft')  # translate c -> c_t
                print (func_judge_extract_complete(save_root))

                ## ------ process for eng openset extraction E -> e ------
                reason_root = chi_root[:-4] + '-eng'
                store_root = reason_root + f'-run{run_time}-openset'
                generate_openset_from_reason(reason_root, store_root, gptmodel)
                print (func_judge_extract_complete(store_root))



    # set-level accuracy and recall calculation
    gt_root = os.path.join(step_root, 'check-result-openset-chi-translate') # checked label sets
    for chi_root, _ in output_roots:
        
            for run_time in [0, 1]:
                
                ## ------ process for chi openset extraction C -> c -> c_t ------
                # evaluates the translated Chinese labels against the same gt_root
                reason_root = chi_root[:-4] + '-chi'
                print (reason_root)
                openset_root = reason_root + f'-run{run_time}-openset-translate'
                synonym_root = reason_root + f'-run{run_time}-openset-translate-synonym'
                generate_openset_synonym(gt_root, openset_root, synonym_root, gptmodel)
                print (func_judge_extract_complete(synonym_root))
                accuracy, recall = calculate_openset_overlap_rate(gt_root, openset_root, synonym_root)
                print (f'accuracy: {accuracy}; recall: {recall}')

                ## ------ process for eng openset extraction E -> e ------
                # English labels are evaluated directly, without a translation step
                reason_root = chi_root[:-4] + '-eng' 
                print (reason_root)
                openset_root = reason_root + f'-run{run_time}-openset'
                synonym_root = reason_root + f'-run{run_time}-openset-synonym'
                generate_openset_synonym(gt_root, openset_root, synonym_root, gptmodel)
                print (func_judge_extract_complete(synonym_root))
                accuracy, recall = calculate_openset_overlap_rate(gt_root, openset_root, synonym_root)
                print (f'accuracy: {accuracy}; recall: {recall}')
