# python run_tabfact_pos.py --use_subset True --load_dataset True > A_DEBUG_TABFACT.txt

import fire
import os
import pandas as pd
import random
# Set the random seed for reproducibility
random.seed(42)

from utils.load_data import *
from utils.llm import ChatGPT
from utils.helper import *
from utils.evaluate import *
from utils.chain import *
from operations import *

#### FREDDY
import openai 
from azure.identity import AzureCliCredential

import os
import shutil
from datetime import datetime

# Azure OpenAI Credentials
# An access token for the Cognitive Services scope is requested through
# AzureCliCredential and stored as the openai API key. This runs once, at
# import time.
# NOTE(review): the token is never refreshed afterwards — presumably a long
# run can outlive it; confirm the token lifetime against the run duration.
credential = AzureCliCredential()
openai_token = credential.get_token("https://cognitiveservices.azure.com/.default")
openai.api_key = openai_token.token

# The API base URL is not set here; main() assigns openai.api_base per model.
openai.api_type = "azure_ad" # required
openai.api_version = "2024-02-15-preview" # to work till: 2024/04/02: "2023-05-15"

# Default sample positions evaluated when main() runs with use_subset=True.
targetted_indices = list(range(2024))

print('Samples tested:', targetted_indices)

# When enabled, main() keeps only the samples whose ids are listed below.
POS_DEBUG = False
if POS_DEBUG:
    pos_wrongs = [
        'test-51', 'test-1652', 'test-255', 'test-161', 'test-1130', 'test-189', 'test-704', 'test-447',
        'test-407', 'test-1466', 'test-1330', 'test-1436', 'test-1751', 'test-1774', 'test-919', 'test-1988',
        'test-1563', 'test-1409', 'test-1402', 'test-1573', 'test-1300', 'test-1794', 'test-1342', 'test-2006',
        'test-1197', 'test-877', 'test-1043', 'test-334', 'test-234', 'test-1812', 'test-1099', 'test-788',
        'test-781', 'test-1083', 'test-1133', 'test-1979', 'test-2022', 'test-601', 'test-6', 'test-611',
        'test-1915', 'test-1561', 'test-330', 'test-1793', 'test-1560', 'test-217', 'test-1782', 'test-1280',
        'test-1963', 'test-629', 'test-1086', 'test-1000', 'test-1557', 'test-1090', 'test-1568', 'test-262',
        'test-411', 'test-1460', 'test-1859', 'test-1995', 'test-897', 'test-1205', 'test-1991', 'test-866',
        'test-1949', 'test-924', 'test-247', 'test-460', 'test-43', 'test-570', 'test-1756', 'test-1481',
        'test-968', 'test-1323', 'test-124', 'test-824', 'test-193', 'test-882', 'test-287', 'test-1719',
        'test-4', 'test-119', 'test-1426', 'test-1496', 'test-996', 'test-607', 'test-832', 'test-1196',
        'test-322', 'test-1218', 'test-139', 'test-826', 'test-643', 'test-19', 'test-1339', 'test-167',
        'test-1860', 'test-1157', 'test-647', 'test-436', 'test-1588', 'test-714', 'test-1035', 'test-431',
    ]
    # Numeric part of each 'test-<n>' id.
    pos_wrong = [int(sample_tag.split('-')[1]) for sample_tag in pos_wrongs]

def _tabfact_index(sample):
    """Return the integer suffix of a sample id of the form ``'test-<n>'``."""
    return int(sample['id'].split('-')[1])


def main(
        dataset_path: str = "data/tabfact/test.jsonl",
        raw2clean_path: str = "data/tabfact/raw2clean.jsonl",
        # model_name: str = "gpt-3.5-turbo-16k-0613",
        model: str = LLM,
        result_dir: str = "results/tabfact",
        first_n: int = 2024,  # Can specify a subset or use None for all data
        use_subset: bool = False,  # Determines whether to use a subset of samples
        subset_indices: list = targetted_indices,  # Indices of the samples to use if use_subset is True; for select_row
        n_proc: int = 10, # 1, 
        chunk_size: int = 10, # 1, 
        load_dataset: bool = False,
):
    """Run the TabFact experiment, or (re)build the date-standardized dataset.

    With ``load_dataset=True`` the preprocessed pickle is loaded, optionally
    reduced to ``subset_indices``, run through the planning chain selected by
    the module-level flags, and the results are pickled into ``result_dir``.
    With ``load_dataset=False`` the raw dataset is date-standardized with the
    LLM, saved to a pickle, and the function returns without evaluating.

    The names ``LLM``, ``K_plans``, ``NATURAL_LANGUAGE_PLANNING``,
    ``OTG_PLANNING`` and ``USING_SQL_FOR_FINAL_QUERY`` are not defined in this
    file; presumably they come from the star imports at the top — TODO confirm.

    Args:
        dataset_path: Path to the TabFact test ``.jsonl`` file.
        raw2clean_path: Path to the raw-to-clean mapping ``.jsonl`` file.
        model: ``'GPT4-o'`` or ``'GPT4'`` select those deployments; any other
            value selects ``gpt-3.5-turbo-0613``.
        result_dir: Directory for the cache and the pickled results.
        first_n: Currently unused; the loader is always called with
            ``first_n=-1``.
        use_subset: If true, keep only the samples at ``subset_indices``.
            Forced to ``True`` when ``model == 'GPT4-o'``.
        subset_indices: Positions (into the loaded dataset) to evaluate.
        n_proc: Worker count for the dynamic chain. Overridden to 1 for the
            GPT-4 models and to 3 when ``K_plans > 1``.
        chunk_size: Chunk size for the dynamic chain; overridden like
            ``n_proc``.
        load_dataset: Selects the evaluate path (``True``) or the
            standardize-and-save path (``False``).

    Returns:
        None.
    """
    # `pickle` is not imported at the top of this file; import it locally so
    # the dumps below do not depend on a star import re-exporting it.
    import pickle

    # Pick the deployment and endpoint. The GPT-4 variants run single-process.
    if model == 'GPT4-o':
        n_proc = 1
        chunk_size = 1
        use_subset = True
        model_name = "gpt-4o"
        openai.api_base = "https://llmopenai-02.org.net/WS0001037P-exp-use2/"

    elif model == 'GPT4':
        n_proc = 1
        chunk_size = 1
        model_name = "gpt-4-turbo"
        openai.api_base = "https://llmopenai-02.org.net/WS0001037P-exp-use2/"

    else:
        model_name = "gpt-3.5-turbo-0613"
        openai.api_base = "https://llmopenai.org.net/WS0001037P-exp" #required #alternative https://llm-test-cib-research.openai.azure.com/

    print(subset_indices)

    print(model_name)
    if K_plans > 1:
        n_proc = 3
        chunk_size = 3
        print(n_proc, chunk_size)

    gpt_llm = ChatGPT(
        model_name=model_name,
        key=openai.api_key,
    )

    # Load processed dataset if needed
    if load_dataset is True:
        print('Loading preprocessed dataset...')
        # dataset = load_dataset_from_pkl("processed_dataset.pkl") # 200 first samples in TabFact
        dataset_raw = load_tabfact_dataset(dataset_path, raw2clean_path, first_n=-1)

        # Raw samples keyed by numeric id, used for the manual override below.
        reformatted_dataset_raw = {
            _tabfact_index(raw_sample): raw_sample for raw_sample in dataset_raw
        }

        # NOTE(review): this path differs from the file name written by the
        # standardize branch below (f"{model_name}_TabFact_date_processed.pkl")
        # — confirm which file is the intended input.
        dataset = load_dataset_from_pkl('data/tabfact/gpt-3.5-turbo-0613_TabFact_processed.pkl')

        if POS_DEBUG:
            # Debug mode: keep only the listed samples, capped at 20. The
            # original cap was overwritten by a later assignment and never
            # took effect.
            debug_ids = set(pos_wrong)
            final_dataset = [
                sample for sample in dataset if _tabfact_index(sample) in debug_ids
            ][:20]
        else:
            final_dataset = []
            # For these specific tables, GPT3.5 cannot preprocess the date correctly, we manually do that
            for sample in dataset:
                sample_idx = _tabfact_index(sample)
                if 330 < sample_idx < 341:
                    final_dataset.append(reformatted_dataset_raw[sample_idx])
                else:
                    final_dataset.append(sample)

        dataset = final_dataset
    else:
        print('Standardizing dataset...')

        # Load dataset
        dataset_raw = load_tabfact_dataset(dataset_path, raw2clean_path, first_n=-1)

        # Process the whole of TabFact. The slice copies the sequence, so the
        # standardizer is not handed `dataset_raw` itself.
        dataset = standardize_dates(dataset_raw[:], llm=gpt_llm)
        # This standardized dataset has been generated using gpt3.5
        save_dataset_to_pkl(dataset, f"{model_name}_TabFact_date_processed.pkl")

        return

    if use_subset:
        dataset = [dataset[i] for i in subset_indices]
    dataset = [preprocess_entry(entry) for entry in dataset]

    print('Model name:', model_name)
    print('The number of samples being tested:', len(dataset))

    os.makedirs(result_dir, exist_ok=True)

    if NATURAL_LANGUAGE_PLANNING is True:
        proc_samples, _ = dynamic_chain_exec_with_cache_mp(
                dataset,
                llm=gpt_llm,
                llm_options=gpt_llm.get_model_options(
                    temperature=0.0, per_example_max_decode_steps=200, per_example_top_p=1.0
                ),
                strategy="top",
                cache_dir=os.path.join(result_dir, "cache"),
                n_proc=n_proc,
                chunk_size=chunk_size,
            )
        print('Final computation:\n')

        # List the samples that executed as SQL but produced the wrong answer.
        for idx, sample in proc_samples.items():
            if sample['answer'] != sample['groundtruth'] and sample['is_sql_executable'] is True:
                print(idx)

        # The return value is not used here; presumably the helper reports the
        # accuracy itself — TODO confirm.
        tabfact_compute_accuracy(proc_samples)

    elif OTG_PLANNING is True:
        # Bound up front so the print and the dumps below work on both paths.
        history = None
        dynamic_chain_log_list = None

        if USING_SQL_FOR_FINAL_QUERY is True:
            proc_samples, dynamic_chain_log_list = dynamic_chain_exec_with_cache_mp(
                dataset,
                llm=gpt_llm,
                llm_options=gpt_llm.get_model_options(
                    temperature=0.0, per_example_max_decode_steps=200, per_example_top_p=1.0
                ),
                strategy="top",
                cache_dir=os.path.join(result_dir, "cache"),
                n_proc=n_proc,
                chunk_size=chunk_size,
            )

            fixed_chain = [
                (
                    "simpleQuery_SQL_fewshot",
                    simple_query_sql,
                    dict(use_demo=True),
                    dict(
                        temperature=0, per_example_max_decode_steps=200, per_example_top_p=1.0
                    ),
                ),
            ]

            final_result, history = fixed_chain_exec_mp(gpt_llm, proc_samples, fixed_chain)

            # Samples whose SQL could not be executed fall back to the CoT
            # result. Iterate by position: `final_result` is only ever
            # indexed with integers here, so its container type is unknown.
            # NOTE(review): `cot_result` is not defined anywhere in this file.
            # Unless a star import provides it, the fallback raises NameError
            # — TODO restore the step that computes it.
            cot_cnt = 0
            for i in range(len(final_result)):
                if final_result[i]["chain"][-1]['Final_query_SQL_executable'] is False or final_result[i]['is_sql_executable'] is False:
                    cot_cnt += 1
                    final_result[i] = cot_result[i]
        else:
            # NOTE(review): same undefined `cot_result` as above.
            final_result = cot_result
            # Every sample uses the CoT answer, so none count as SQL-processed.
            cot_cnt = len(dataset)

        print(f'\nNumber of samples where the (record 2) final query is processed by SQL: {len(dataset) - cot_cnt} queries')

        # The return value is not used here; presumably the helper reports the
        # accuracy itself — TODO confirm.
        tabfact_match_func_for_samples(final_result)

        SQL_done = 0
        for i in range(len(final_result)):
            print(f'Visualizing chain of SQLs for {i}-th sample')
            last_step = final_result[i]['chain'][-1]
            # The key check guards entries whose last step does not carry the
            # flag at all.
            if final_result[i]['is_sql_executable'] is True and ('Final_query_SQL_executable' in last_step and last_step['Final_query_SQL_executable'] is True):
                SQL_done += 1
            process_final_result(final_result[i])

        print(f'\nNumber of samples that are entirely (record 3) processed by SQL: {SQL_done} queries')

        # Context managers close each file handle as soon as it is written.
        with open(os.path.join(result_dir, "final_result.pkl"), "wb") as fh:
            pickle.dump(final_result, fh)
        with open(os.path.join(result_dir, "cotable_log.pkl"), "wb") as fh:
            pickle.dump(history, fh)
        with open(os.path.join(result_dir, "dynamic_chain_log_list.pkl"), "wb") as fh:
            pickle.dump(dynamic_chain_log_list, fh)


# Script entry point: python-fire exposes main()'s keyword parameters as
# command-line flags (e.g. --use_subset True --load_dataset True).
if __name__ == "__main__":
    fire.Fire(main)
