import wandb
import fire
import os
import pandas as pd

from utils.load_data import load_tabfact_dataset, standardize_dates
from utils.llm import ChatGPT
from utils.helper import *
from utils.evaluate import *
from utils.chain import *
from operations import *
from utils.prompts import *

#### FREDDY
import openai 
from azure.identity import AzureCliCredential

import os
import shutil
from datetime import datetime

# Azure OpenAI credentials: request an Azure AD token through the local Azure CLI
# login and hand it to the `openai` client as its API key.
# NOTE(review): the token is fetched once, when this module is loaded; presumably
# it can expire during a long run — confirm against the azure-identity docs.
credential = AzureCliCredential()
openai_token = credential.get_token("https://cognitiveservices.azure.com/.default")
openai.api_key = openai_token.token
# Default endpoint. The active line is the one also used by the GPT-3.5 fallback
# in the model-selection block below, which re-assigns `openai.api_base` per model.
openai.api_base = "https://llmopenai.org.net/WS0001037P-exp" #required #alternative https://llm-test-cib-research.openai.azure.com/
# Endpoint used by the GPT-4 branches of the model-selection block below:
# openai.api_base = "https://llmopenai-02.org.net/WS0001037P-exp-use2/"
#############################################
openai.api_type = "azure_ad" # required
openai.api_version = "2024-02-15-preview" # to work till: 2024/04/02: "2023-05-15"

# `LLM` is not defined in this file; it is expected to come from one of the
# star imports at the top of the module.
model = LLM

if model == 'GPT4-O' or model == 'GPT4':
    # Both GPT-4 variants talk to the same endpoint and use n_proc = chunk_size = 1.
    openai.api_base = "https://llmopenai-02.org.net/WS0001037P-exp-use2/"
    n_proc = chunk_size = 1

    if model == 'GPT4-O':
        model_name = "gpt-4o"
        # Only the GPT4-O configuration defines `use_subset`.
        use_subset = True
    else:
        model_name = "gpt-4-turbo"

else:
    # Every other value of `model` falls back to the GPT-3.5 deployment.
    openai.api_base = "https://llmopenai.org.net/WS0001037P-exp" #required #alternative https://llm-test-cib-research.openai.azure.com/
    model_name = "gpt-3.5-turbo-0613"


#######################################################################################################################################


def _build_acceptance_prompt(method, explanation_html):
    """Return the accept/reject prompt for one explanation method.

    Args:
        method: Label of the explanation method inserted into the prompt text
            (this file uses 'COT' and 'POS').
        explanation_html: Raw HTML text of the explanation to be judged.
    """
    return f"""
You are given an explanation that describes the reasoning process of a Table Question Answering model in HTML format. This explanation is based on the '{method}' method. Please carefully analyze the explanation and determine whether you accept the prediction made by the Table QA model based on this explanation solely, not the input Table.

{method} Explanation:
{explanation_html}

Do you accept the prediction based on this explanation? Answer with 'Accept' or 'Reject':
"""


def _normalize_verdict(raw_answer):
    """Map a model answer onto the canonical 'Accept' / 'Reject' when possible.

    Leading quotes/markdown markers, surrounding whitespace, letter case and any
    trailing punctuation or text after the keyword are ignored, so answers such
    as "Accept.", "'accept'" or "**Reject**" are recognised. Anything that does
    not start with one of the two keywords is returned stripped but otherwise
    unchanged.
    """
    cleaned = raw_answer.strip()
    bare = cleaned.lstrip("'\"*`([ ").lower()
    for verdict in ("Accept", "Reject"):
        keyword = verdict.lower()
        # Require a word boundary so that e.g. "Acceptable" is not coerced.
        if bare.startswith(keyword) and not bare[len(keyword):len(keyword) + 1].isalpha():
            return verdict
    return cleaned


def check_prediction_acceptance(file_name, llm, cot_folder="COT", pos_folder="POS", llm_options=None):
    """Ask the LLM whether it accepts a prediction given each explanation style.

    The file `file_name` is read from both `cot_folder` and `pos_folder`, one
    prompt is built per explanation, and the model is queried once per prompt
    (COT first, then POS).

    Args:
        file_name: Name of the explanation file present in both folders.
        llm: Object providing `get_model_options()` and
            `generate_plus_with_score(prompt, options=...)`. The latter is
            expected to return either the string 'Exceed context length' or a
            sequence whose `[0][0]` element is the answer text.
        cot_folder: Directory holding the COT explanation.
        pos_folder: Directory holding the POS explanation.
        llm_options: Options forwarded to the model; when None they are taken
            from `llm.get_model_options()`.

    Returns:
        `[cot_verdict, pos_verdict]`, where each verdict is 'Accept' or 'Reject'
        when the answer could be recognised and the stripped raw answer
        otherwise; or None when a prompt exceeds the context length or the
        model call / response parsing fails.

    Raises:
        OSError: If either explanation file cannot be read (file access is
            intentionally not covered by the best-effort handler below).
    """
    # Read both explanations up front so a missing file fails before any
    # request is sent to the model.
    prompts = []
    for method, folder in (("COT", cot_folder), ("POS", pos_folder)):
        with open(os.path.join(folder, file_name), 'r', encoding="utf-8") as file:
            prompts.append(_build_acceptance_prompt(method, file.read()))

    # Set up LLM options if not provided
    if llm_options is None:
        llm_options = llm.get_model_options()

    verdicts = []
    try:
        for prompt in prompts:
            response = llm.generate_plus_with_score(prompt, options=llm_options)

            # The pair is discarded as a whole, so stop immediately instead of
            # spending another request on the remaining explanation.
            if response == 'Exceed context length':
                return None

            verdicts.append(_normalize_verdict(response[0][0]))

    except Exception as e:
        # Best effort: one failing sample must not abort a whole batch run.
        print(f"Error generating response from the model: {e}")
        return None

    return verdicts


def check_all_predictions(cot_folder="/opt/service/work/instance1/jupyter/tabular_grounding_llms/plan-of-sqls/visualization/common_COT", pos_folder="/opt/service/work/instance1/jupyter/tabular_grounding_llms/plan-of-sqls/visualization/common_POS", llm=None, llm_options=None):
    """Compute how often the LLM accepts predictions from COT vs. POS explanations.

    For each of the subfolders 'TP', 'TN', 'FP' and 'FN', every file name that
    exists under both `cot_folder` and `pos_folder` is passed to
    `check_prediction_acceptance`. Samples for which that call returns a falsy
    value (e.g. None) are left out of all counts.

    Args:
        cot_folder: Root directory containing the COT explanations, with the
            four subfolders listed above.
        pos_folder: Root directory containing the POS explanations, with the
            same layout.
        llm: Model wrapper forwarded to `check_prediction_acceptance`. Required
            despite the None default kept for signature compatibility.
        llm_options: Options forwarded unchanged to `check_prediction_acceptance`.

    Returns:
        A dict with 'COT Acceptance Rate' and 'POS Acceptance Rate' (percentages
        of the successful comparisons, 0 when there were none) and
        'Total Comparisons' (number of successful comparisons).

    Raises:
        ValueError: If `llm` is None.
        OSError: If one of the expected subfolders cannot be listed.
    """
    # Fail fast: otherwise the missing model only surfaces as an AttributeError
    # inside the per-file helper, after directories have already been scanned.
    if llm is None:
        raise ValueError("check_all_predictions requires an `llm` instance, got None")

    cot_acceptance_count = 0
    pos_acceptance_count = 0
    total_comparisons = 0

    for subfolder in ['TP', 'TN', 'FP', 'FN']:
        cot_subfolder = os.path.join(cot_folder, subfolder)
        pos_subfolder = os.path.join(pos_folder, subfolder)

        # Only files present for both methods can be compared. Sorting makes
        # the processing order (and hence the progress log) reproducible
        # across runs; plain set iteration order is not.
        common_files = sorted(set(os.listdir(cot_subfolder)) & set(os.listdir(pos_subfolder)))

        print(f'Processing {subfolder} folder...')

        for idx, file_name in enumerate(common_files):
            report_progress = idx % 5 == 0
            if report_progress:
                print(idx)

            # Check if the model accepts the predictions based on the explanations
            acceptance = check_prediction_acceptance(file_name, llm, cot_folder=cot_subfolder, pos_folder=pos_subfolder, llm_options=llm_options)

            if acceptance:
                # acceptance is [cot_answer, pos_answer]
                if acceptance[0] == 'Accept':
                    cot_acceptance_count += 1
                if acceptance[1] == 'Accept':
                    pos_acceptance_count += 1

                total_comparisons += 1

            if report_progress:
                print('cot_acceptance_count:', cot_acceptance_count)
                print('pos_acceptance_count:', pos_acceptance_count)
                print('total_comparisons:', total_comparisons)

    # Compute the acceptance rates (guarding against an empty comparison set)
    if total_comparisons > 0:
        cot_acceptance_rate = 100 * cot_acceptance_count / total_comparisons
        pos_acceptance_rate = 100 * pos_acceptance_count / total_comparisons
    else:
        cot_acceptance_rate = 0
        pos_acceptance_rate = 0

    return {
        'COT Acceptance Rate': cot_acceptance_rate,
        'POS Acceptance Rate': pos_acceptance_rate,
        'Total Comparisons': total_comparisons
    }

# Example usage: build the ChatGPT wrapper for the deployment selected above
# and run the COT-vs-POS acceptance comparison over the default folders.
# NOTE: there is no `if __name__ == "__main__":` guard, so these statements
# (including the full batch of model calls) run whenever this module is loaded.
gpt_llm = ChatGPT(
    model_name=model_name,
    key=openai.api_key,
)

# Log which deployment name is in use before the (long) batch run starts.
print(model_name)


# Prints the dict returned by check_all_predictions (acceptance rates and count).
acceptance_results = check_all_predictions(llm=gpt_llm)
print(acceptance_results)
