# python run_wikitq_pos.py --load_dataset True --use_subset True > A_DEBUG_WIKITQ.txt

import fire
import os
import pandas as pd
import random
# Set the random seed for reproducibility.
# NOTE: this seeds only Python's built-in `random` module; nothing visible in
# this file seeds NumPy's generator.
random.seed(42)
import json
import collections
import numpy as np

from utils.load_data import *
from utils.llm import ChatGPT
from utils.helper import *
from utils.evaluate import *
from utils.chain import *
from utils.wikitq_eval import *
from operations import *

#### FREDDY
import openai 
from azure.identity import AzureCliCredential

import os
import shutil
from datetime import datetime

# Azure OpenAI client configuration.
# "azure_ad" selects Azure Active Directory token authentication (required).
openai.api_type = "azure_ad"
# Current API version; the original note says "2023-05-15" worked until
# 2024/04/02.
openai.api_version = "2024-02-15-preview"

# The access token is requested once, at import time, through the Azure CLI
# login and then used as the OpenAI API key. Nothing in this file refreshes it.
credential = AzureCliCredential()
openai_token = credential.get_token("https://cognitiveservices.azure.com/.default")
openai.api_key = openai_token.token

def check_denotation(target_values, predicted_values):
    """Decide whether a predicted denotation agrees with the gold denotation.

    The two lists must have the same length, and every target value must
    ``match`` at least one predicted value. Matching is one-directional:
    predicted values are not required to be matched by a target.

    Args:
        target_values (list[Value]): gold values; each must expose ``match``.
        predicted_values (list[Value]): values produced by the system.
    Returns:
        bool: True when the prediction is accepted, False otherwise.
    """
    # Different cardinality can never be a correct denotation.
    if len(predicted_values) != len(target_values):
        return False
    # Every gold value needs at least one matching prediction.
    return all(
        any(gold.match(candidate) for candidate in predicted_values)
        for gold in target_values
    )

#########################################################
# Default sample indices evaluated when a subset run is requested (0..9).
targetted_indices = list(range(10))
#########################################################

# Debug switch: when enabled, a fixed list of test-example IDs is materialised
# and converted to their integer indices in `pos_wrong`.
POS_DEBUG = False
if POS_DEBUG:
    pos_wrongs = [
        'test-51', 'test-1652', 'test-255', 'test-161', 'test-1130', 'test-189', 'test-704', 'test-447',
        'test-407', 'test-1466', 'test-1330', 'test-1436', 'test-1751', 'test-1774', 'test-919', 'test-1988',
        'test-1563', 'test-1409', 'test-1402', 'test-1573', 'test-1300', 'test-1794', 'test-1342', 'test-2006',
        'test-1197', 'test-877', 'test-1043', 'test-334', 'test-234', 'test-1812', 'test-1099', 'test-788',
        'test-781', 'test-1083', 'test-1133', 'test-1979', 'test-2022', 'test-601', 'test-6', 'test-611',
        'test-1915', 'test-1561', 'test-330', 'test-1793', 'test-1560', 'test-217', 'test-1782', 'test-1280',
        'test-1963', 'test-629', 'test-1086', 'test-1000', 'test-1557', 'test-1090', 'test-1568', 'test-262',
        'test-411', 'test-1460', 'test-1859', 'test-1995', 'test-897', 'test-1205', 'test-1991', 'test-866',
        'test-1949', 'test-924', 'test-247', 'test-460', 'test-43', 'test-570', 'test-1756', 'test-1481',
        'test-968', 'test-1323', 'test-124', 'test-824', 'test-193', 'test-882', 'test-287', 'test-1719',
        'test-4', 'test-119', 'test-1426', 'test-1496', 'test-996', 'test-607', 'test-832', 'test-1196',
        'test-322', 'test-1218', 'test-139', 'test-826', 'test-643', 'test-19', 'test-1339', 'test-167',
        'test-1860', 'test-1157', 'test-647', 'test-436', 'test-1588', 'test-714', 'test-1035', 'test-431',
    ]
    # Keep only the numeric part after the dash, e.g. 'test-51' -> 51.
    pos_wrong = [int(example_id.split('-')[1]) for example_id in pos_wrongs]

def _percentage(numerator, denominator):
    """Return ``100 * numerator / denominator``, or NaN for an empty denominator.

    The summary rates are computed over groups (fall-back samples, PoS
    samples, all samples) that can be empty; returning NaN keeps the report
    printable instead of raising ZeroDivisionError.
    """
    if denominator == 0:
        return float('nan')
    return 100 * numerator / denominator


def _find_gold_statement(statement, gold_items):
    """Return the first gold statement that contains ``statement``, else None.

    ``gold_items`` is the mapping loaded from the gold JSON file; each value is
    read as ``item['data_item']['statement']``. Entries are scanned in the
    mapping's iteration order and matched by substring containment.
    """
    for item in gold_items.values():
        gold_statement = item['data_item']['statement']
        if statement in gold_statement:
            return gold_statement
    return None


def main(
        dataset_path: str = "data/tabfact/test.jsonl",
        raw2clean_path: str = "data/tabfact/raw2clean.jsonl",
        model: str = LLM,
        result_dir: str = "results/wikitq",
        first_n: int = 4344,  # Can specify a subset or use None for all data
        use_subset: bool = False,  # Determines whether to use a subset of samples
        subset_indices: list = targetted_indices,  # Indices of the samples to use if use_subset is True; for select_row
        n_proc: int = 10,
        chunk_size: int = 10,
        load_dataset: bool = False,
):
    """Run the WikiTQ chain on the test set and print evaluation statistics.

    Steps: pick the model/endpoint, load ``data/wikitq/test.jsonl``, run
    ``wikitq_dynamic_chain_exec_with_cache_mp``, then score every executable
    sample with ``check_denotation`` against the tagged gold values and print
    executability, fall-back/PoS rates and accuracies, and denotation accuracy.

    Args:
        dataset_path: Accepted for CLI compatibility; not read by this function
            (the WikiTQ test file path is fixed in the body).
        raw2clean_path: Accepted for CLI compatibility; not read here.
        model: Model selector, compared case-insensitively. 'GPT4-O'/'GPT4O'
            selects "gpt-4o" (and forces ``use_subset=True``), 'GPT-4'/'GPT4'
            selects "gpt-4-turbo", anything else selects "gpt-3.5-turbo-0613".
        result_dir: Output directory (created if missing); the execution cache
            lives in ``<result_dir>/cache``.
        first_n: Number of leading dataset entries to keep (``None`` keeps all).
        use_subset: When True, only the entries at ``subset_indices`` are run.
        subset_indices: Indices into the (truncated) dataset used for a subset
            run. The list is only read, never modified.
        n_proc: Passed to the chain executor. Forced to 1 for the GPT-4
            variants and to 3 when ``K_plans > 1``.
        chunk_size: Passed to the chain executor; overridden like ``n_proc``.
        load_dataset: Accepted for CLI compatibility; not read here.

    Returns:
        None. All results are printed to stdout.
    """
    model_key = model.upper()
    if model_key in ('GPT4-O', 'GPT4O'):
        n_proc = 1
        chunk_size = 1
        use_subset = True
        model_name = "gpt-4o"
        openai.api_base = "https://llmopenai-02.org.net/WS0001037P-exp-use2/"
    elif model_key in ('GPT-4', 'GPT4'):
        n_proc = 1
        chunk_size = 1
        model_name = "gpt-4-turbo"
        openai.api_base = "https://llmopenai-02.org.net/WS0001037P-exp-use2/"
    else:
        model_name = "gpt-3.5-turbo-0613"
        # Required; alternative endpoint: https://llm-test-cib-research.openai.azure.com/
        openai.api_base = "https://llmopenai.org.net/WS0001037P-exp"

    print(subset_indices)

    print(model_name)
    # K_plans is not defined in this file (presumably provided by one of the
    # star imports); with several plans per sample the parallelism is fixed to 3.
    if K_plans > 1:
        n_proc = 3
        chunk_size = 3
        print(n_proc, chunk_size)

    gpt_llm = ChatGPT(
        model_name=model_name,
        key=openai.api_key,
    )

    # Load dataset: one JSON object per line; the line index becomes the id.
    dataset_raw = []
    with open('data/wikitq/test.jsonl') as f:
        for idx, line in enumerate(f):
            entry = json.loads(line)
            entry['id'] = idx
            dataset_raw.append(entry)

    dataset = [preprocess_entry(entry) for entry in dataset_raw[:first_n]]
    if use_subset:
        dataset = [dataset[i] for i in subset_indices]

    print('Model name:', model_name)
    print('The number of samples being tested:', len(dataset))

    os.makedirs(result_dir, exist_ok=True)

    proc_samples, _ = wikitq_dynamic_chain_exec_with_cache_mp(
        dataset,
        llm=gpt_llm,
        llm_options=gpt_llm.get_model_options(
            temperature=0.0, per_example_max_decode_steps=200, per_example_top_p=1.0
        ),
        strategy="top",
        cache_dir=os.path.join(result_dir, "cache"),
        n_proc=n_proc,
        chunk_size=chunk_size,
    )

    print(proc_samples)

    ############# From DATER paper

    # ID string --> list[Value]
    target_values_map = {}
    tagged_dataset_path = 'data/wikitq/data'
    for filename in os.listdir(tagged_dataset_path):
        if filename.startswith('.'):
            continue  # skip hidden files
        filename = os.path.join(tagged_dataset_path, filename)
        print('Reading dataset from', filename)
        # The encoding must be passed by keyword: the third positional
        # argument of the built-in open() is `buffering`, not the encoding.
        # The keyword form is also accepted by codecs.open, in case a star
        # import rebinds `open` to it.
        with open(filename, 'r', encoding='utf8') as fin:
            header = fin.readline().rstrip('\n').split('\t')
            for line in fin:
                stuff = dict(zip(header, line.rstrip('\n').split('\t')))
                ex_id = stuff['id']
                original_strings = tsv_unescape_list(stuff['targetValue'])
                canon_strings = tsv_unescape_list(stuff['targetCanon'])

                target_values_map[ex_id] = to_value_list(
                    original_strings, canon_strings)

    # Statement text --> example id(s) used as key into target_values_map.
    st2id = {}
    with open('data/wikitq/test_lower.jsonl') as f:
        for line in f:
            record = json.loads(line)
            st2id[record['statement']] = record['ids']

    with open('data/wikitq/gloc_wtq_end2end_wikitq_test.json', 'r') as f:
        gold = json.load(f)

    deno_acc = 0
    execs = 0

    fall_back_crt = 0
    fb_count = 0

    pos_crt = 0
    pos_count = 0

    # Evaluate in ascending sample-index order.
    proc_samples = dict(sorted(proc_samples.items()))
    false_log_files = []

    # Process samples to get denotation accuracy.
    for sample_index, res in proc_samples.items():
        res_st = res['input']['statement']
        res_preds = res['answer']
        fall_back = res['fallback_LLM']

        if fall_back is True:
            fb_count += 1
        else:
            pos_count += 1

        # Samples without an executable query or without any answer are not
        # scored; they still count in the denominators below.
        if res['is_sql_executable'] is False:
            continue
        if len(res_preds) == 0 or len(res_preds[0]) == 0:
            continue
        execs += 1

        st = _find_gold_statement(res_st, gold)
        if st is None:
            # Without a gold entry there is nothing valid to score against.
            # Skip explicitly instead of reusing values from a previous sample.
            print(f'### WARNING - Sample {sample_index}: no gold entry matches the statement; counted as incorrect')
            continue

        # Join the first field of every answer row into one prediction string,
        # then keep only its first line. There is a single candidate per
        # sample, so no vote over candidates is needed.
        answs = ' SEP_TOKEN '.join(str(ans[0]) for ans in res_preds)
        pred_answer = answs.split('\n')[0].strip()

        st = st.split('\n')[0]
        target_values = target_values_map[st2id[st]]

        pred_values = to_value_list(pred_answer.split(' SEP_TOKEN '))

        is_correct = check_denotation(target_values, pred_values)
        if is_correct:
            deno_acc += 1
            if fall_back is True:
                fall_back_crt += 1
            elif fall_back is False:
                pos_crt += 1
        elif fall_back is not True:
            # Only wrong non-fall-back (PoS) predictions are reported.
            print(f'### POS - Sample {sample_index}: {is_correct}, {pred_values}, {target_values}')

        false_log_files.append(f'log_{sample_index}.txt')

    combine_files_from_directory(wikitq_planning_log_path, false_log_files)

    total = len(proc_samples)

    print(f'Executability: {execs}/{total}')
    print('Executability Rate:', _percentage(execs, total))
    print('\n')

    print('Fall-back Rate:', _percentage(fb_count, total))
    print('\n')

    print('Fall-back Acc:', _percentage(fall_back_crt, fb_count))
    print('\n')

    print('PoS Rate:', _percentage(pos_count, total))
    print('\n')

    print('PoS Acc:', _percentage(pos_crt, pos_count))
    print('\n')

    print(f'Denotation: {deno_acc}/{total}')
    print('Denotation Accuracy:', _percentage(deno_acc, total))

    
    #############
        
# Script entry point: python-fire turns main()'s keyword parameters into
# command-line flags (e.g. --load_dataset True --use_subset True).
if __name__ == "__main__":
    fire.Fire(main)

