import json
import wandb
import fire
import os
import pandas as pd

import sys
import os
import argparse

# Add the directory three levels above this file to the module search path (sys.path)
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '../../..')))

from utils.load_data import load_tabfact_dataset, standardize_dates
from utils.llm import ChatGPT
from utils.helper import *
from utils.evaluate import *
from utils.chain import *
from operations import *
from utils.prompts import *

#### FREDDY
import openai 
from azure.identity import AzureCliCredential

import os
import shutil
from datetime import datetime

# Azure OpenAI Credentials
# This whole section runs at import time: a token for the Cognitive Services
# scope is requested through the Azure CLI credential and installed as the
# OpenAI API key.
# NOTE(review): the token is fetched once here and nothing in this file
# refreshes it — presumably a long run can outlive it; confirm.
credential = AzureCliCredential()
openai_token = credential.get_token("https://cognitiveservices.azure.com/.default")
openai.api_key = openai_token.token

# Default endpoint and API settings; api_base is reassigned by the model
# selection below.
openai.api_base = "https://llmopenai.org.net/WS0001037P-exp"
openai.api_type = "azure_ad"
openai.api_version = "2024-02-15-preview"

# LLM is not defined in this file; presumably it is provided by one of the
# wildcard imports above — TODO confirm where it is set.
model = LLM

# Pick the deployment name (model_name) and endpoint for the selected model.
# n_proc, chunk_size and use_subset are only assigned in some branches and are
# not read anywhere in the visible part of this file.
if model == 'GPT4-o':
    n_proc = 1
    chunk_size = 1
    use_subset = True
    model_name = "gpt-4o"
    openai.api_base = "https://llmopenai-02.org.net/WS0001037P-exp-use2/"

elif model == 'GPT4':
    n_proc = 1
    chunk_size = 1
    model_name = "gpt-4-turbo"
    openai.api_base = "https://llmopenai-02.org.net/WS0001037P-exp-use2/"

# Any other value of LLM falls back to GPT-3.5 on the default endpoint.
else:
    model_name = "gpt-3.5-turbo-0613"
    openai.api_base = "https://llmopenai.org.net/WS0001037P-exp"
#######################################################################################################################################

def load_processed_results(result_file_path):
    """Return the results stored at *result_file_path*.

    When no file exists at that path an empty dict is returned, so callers
    can start from scratch. An existing file is parsed as JSON and the
    decoded object is returned as-is.
    """
    if not os.path.exists(result_file_path):
        return {}

    with open(result_file_path, 'r') as stored:
        raw_text = stored.read()
    return json.loads(raw_text)

def save_processed_results(result_file_path, results):
    """Write *results* to *result_file_path* as indented JSON.

    The data is first written to a sibling ``<path>.tmp`` file and then moved
    over the destination with ``os.replace``. The destination is therefore
    never left truncated or half-written: if serialization fails or the
    process is interrupted mid-write, the previously saved file is untouched.
    This matters because the file is rewritten after every judged example and
    is read back on the next run.

    Args:
        result_file_path: Destination path of the JSON file.
        results: JSON-serializable object to store.

    Raises:
        TypeError: If *results* contains a value ``json`` cannot serialize.
        OSError: If the temporary file cannot be written or moved into place.
    """
    tmp_path = f"{result_file_path}.tmp"
    try:
        with open(tmp_path, 'w') as f:
            json.dump(results, f, indent=4)
        os.replace(tmp_path, result_file_path)
    finally:
        # After a successful replace the temporary file no longer exists;
        # this only cleans up after a failed write.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

def _build_judge_prompt(method_name, explanation_html):
    """Return the judge prompt asking whether the explained prediction is correct."""
    return f"""
The Table Question Answering model is working on Table Fact Verification task (TabFact dataset), verifying if a given Statement is TRUE or FALSE on a given input Table.
You are given an explanation that describes the reasoning process of the Table Question Answering model in HTML format. This explanation is based on the '{method_name}' method. Please carefully analyze the explanation and determine whether the prediction is correct or wrong.

{method_name} Explanation:
{explanation_html}

Is the prediction correct? Answer with 'Yes' or 'No':
"""


def _parse_verdict(acceptance):
    """Map a free-form judge reply to 'yes', 'no', or None when it is neither."""
    words = acceptance.split()
    if not words:
        return None
    # Tolerate case, quoting and trailing punctuation ("Yes.", "'No'", "yes,").
    first_word = words[0].strip('\'"`*.,:;!?()[]').lower()
    return first_word if first_word in ('yes', 'no') else None


def _judged_correctly(reference, acceptance):
    """Return True when the judge's Yes/No agrees with whether the prediction matched the label."""
    verdict = _parse_verdict(acceptance)
    if verdict is None:
        return False
    prediction_is_right = reference['answer'].upper() == reference['prediction'].upper()
    return (verdict == 'yes') == prediction_is_right


def check_prediction_acceptance(file_name, llm, method_a_folder, method_b_folder, llm_options=None, method_a_metadata=None, method_b_metadata=None, method_a_name="Method A", method_b_name="Method B"):
    """Ask the LLM judge whether each method's explanation shows a correct prediction.

    The explanation file *file_name* is read from both method folders, one
    prompt per method is sent to *llm*, and each Yes/No reply is compared with
    the reference record (``answer`` vs. ``prediction``, case-insensitive)
    looked up under the key ``"<method name>_<file_name>"``.

    A reply counts as "Yes"/"No" when its first word is yes/no, ignoring case
    and surrounding punctuation; any other reply is scored as a miss.

    Args:
        file_name: Name of the explanation file present in both folders.
        llm: Object exposing ``generate_plus_with_score(prompt, options=...)``
            and ``get_model_options()``. The reply text is read from
            ``response[0][0]`` — assumes that is the generated string; confirm
            against the LLM wrapper.
        method_a_folder: Directory holding method A's explanation files.
        method_b_folder: Directory holding method B's explanation files.
        llm_options: Options forwarded to the LLM; fetched from
            ``llm.get_model_options()`` when ``None``.
        method_a_metadata: Mapping from ``"<method_a_name>_<file_name>"`` to a
            record with ``answer`` and ``prediction``; ``None`` is treated as
            empty.
        method_b_metadata: Same as above for method B.
        method_a_name: Label of method A, used in the prompt and metadata key.
        method_b_name: Label of method B, used in the prompt and metadata key.

    Returns:
        ``([hit_a, hit_b], acceptance_a, acceptance_b, method_a_data,
        method_b_data)`` on success, where the acceptances are the stripped
        raw replies. A tuple of five ``None`` values is returned when a
        reference record is missing, a reply is ``'Exceed context length'``,
        or the model call fails.

    Raises:
        OSError: If an explanation file cannot be read.
    """
    failure = (None, None, None, None, None)

    with open(os.path.join(method_a_folder, file_name), 'r') as file:
        method_a_content = file.read()

    with open(os.path.join(method_b_folder, file_name), 'r') as file:
        method_b_content = file.read()

    method_a_data = (method_a_metadata or {}).get(f'{method_a_name}_{file_name}')
    method_b_data = (method_b_metadata or {}).get(f'{method_b_name}_{file_name}')

    # Validate the references before spending any model calls: without them
    # the replies could not be scored anyway.
    for method_name, reference in ((method_a_name, method_a_data), (method_b_name, method_b_data)):
        if not isinstance(reference, dict) or 'answer' not in reference or 'prediction' not in reference:
            print(f"Skipping {file_name}: no answer/prediction metadata for '{method_name}_{file_name}'")
            return failure

    if llm_options is None:
        llm_options = llm.get_model_options()

    # Best effort: any failure while querying or scoring is reported and
    # turned into the all-None result so a batch run can continue.
    try:
        response_a = llm.generate_plus_with_score(
            _build_judge_prompt(method_a_name, method_a_content), options=llm_options
        )
        if response_a == 'Exceed context length':
            return failure

        response_b = llm.generate_plus_with_score(
            _build_judge_prompt(method_b_name, method_b_content), options=llm_options
        )
        if response_b == 'Exceed context length':
            return failure

        acceptance_a = response_a[0][0].strip()
        acceptance_b = response_b[0][0].strip()

        hits = [
            _judged_correctly(method_a_data, acceptance_a),
            _judged_correctly(method_b_data, acceptance_b),
        ]
    except Exception as e:
        print(f"Error generating response from the model: {e}")
        return failure

    return hits, acceptance_a, acceptance_b, method_a_data, method_b_data

import os
import json
import random

def check_all_predictions(method_a_folder, method_b_folder, result_file_path="llm-as-a-judge_gpt4.json", llm=None, llm_options=None, method_a_name="Method A", method_b_name="Method B", max_files=500, seed=42):
    """Judge a seeded random sample of explanations shared by two methods.

    For each of the ``TN``/``FP``/``FN``/``TP`` subfolders, the files present
    for both methods are collected, the combined list is shuffled with *seed*
    and cut to *max_files*, and every selected file is passed to
    ``check_prediction_acceptance``. Per-file outcomes are stored under
    ``results[<method name>][<subfolder>][<file name>]`` and written to
    *result_file_path* after every file.

    Explanation folders and metadata files are located under
    ``project_directory`` from the method names. ``project_directory`` is not
    defined in this file; presumably it comes from one of the wildcard
    imports — TODO confirm.

    NOTE(review): entries loaded from *result_file_path* are kept and merged,
    but files selected in this run are judged again and their entries are
    overwritten (including with ``None`` values when a comparison fails);
    nothing is skipped. Confirm whether true resume behaviour is wanted.

    Args:
        method_a_folder: Accepted for interface compatibility; not used, the
            folders are derived from *method_a_name*.
        method_b_folder: Accepted for interface compatibility; not used, the
            folders are derived from *method_b_name*.
        result_file_path: JSON file the per-file outcomes are merged into.
        llm: Judge model passed through to ``check_prediction_acceptance``.
        llm_options: Options passed through to ``check_prediction_acceptance``.
        method_a_name: Name of method A (folder name, metadata suffix, key).
        method_b_name: Name of method B (folder name, metadata suffix, key).
        max_files: Maximum number of shared files to judge; ``None`` judges
            all of them.
        seed: Seed for the global ``random`` generator used for the shuffle.

    Returns:
        Dict with ``"<method_a_name> Accuracy"``, ``"<method_b_name> Accuracy"``
        (percentages over successful comparisons, 0 when there were none) and
        ``"Total Comparisons"``.
    """
    # Set random seed for reproducibility
    random.seed(seed)

    # Start from whatever an earlier run stored so its entries are preserved.
    results = load_processed_results(result_file_path)
    results.setdefault(method_b_name, {})
    results.setdefault(method_a_name, {})

    metadata_dir = f'{project_directory}/plan-of-sqls/visualization'
    explanations_dir = f'{project_directory}/plan-of-sqls/xai_study/llm-judge/scripts'

    # Load metadata for Method A and Method B
    with open(os.path.join(metadata_dir, f"Tabular_LLMs_human_study_vis_6_{method_a_name}.json"), 'r') as file:
        method_a_metadata = json.load(file)

    with open(os.path.join(metadata_dir, f"Tabular_LLMs_human_study_vis_6_{method_b_name}.json"), 'r') as file:
        method_b_metadata = json.load(file)

    # Collect (subfolder, file name) pairs available for both methods.
    all_common_files = []
    for subfolder in ['TN', 'FP', 'FN', 'TP']:
        method_a_files = set(os.listdir(os.path.join(explanations_dir, method_a_name, subfolder)))
        method_b_files = set(os.listdir(os.path.join(explanations_dir, method_b_name, subfolder)))

        # Sort before shuffling: neither os.listdir nor set iteration has a
        # stable order across runs, so without this the seeded shuffle would
        # select a different sample each time.
        common_files = sorted(method_a_files & method_b_files)
        all_common_files.extend((subfolder, file_name) for file_name in common_files)

    # Shuffle the list to mix files from different subfolders
    random.shuffle(all_common_files)
    all_common_files = all_common_files[:max_files]

    method_a_hits = 0
    method_b_hits = 0
    total_comparisons = 0

    for subfolder, file_name in all_common_files:
        hit, acceptance_a, acceptance_b, method_a_data, method_b_data = check_prediction_acceptance(
            file_name, llm,
            method_a_folder=os.path.join(explanations_dir, method_a_name, subfolder),
            method_b_folder=os.path.join(explanations_dir, method_b_name, subfolder),
            llm_options=llm_options, method_a_metadata=method_a_metadata, method_b_metadata=method_b_metadata,
            method_a_name=method_a_name, method_b_name=method_b_name
        )

        results[method_b_name].setdefault(subfolder, {})[file_name] = {
            'LLM_act': acceptance_b,
            'ref': method_b_data
        }
        results[method_a_name].setdefault(subfolder, {})[file_name] = {
            'LLM_act': acceptance_a,
            'ref': method_a_data
        }

        # hit is None when the comparison could not be made; such files are
        # recorded above but do not count towards the accuracy.
        if hit:
            if hit[0]:
                method_a_hits += 1
            if hit[1]:
                method_b_hits += 1
            total_comparisons += 1

        # Save results after processing each file
        save_processed_results(result_file_path, results)

    print(f'{method_a_name} Hits:', method_a_hits)
    print(f'{method_b_name} Hits:', method_b_hits)
    print('Total Comparisons:', total_comparisons)
    print('---------------------------------------')

    # Compute the accuracy rates
    method_a_accuracy = 100 * method_a_hits / total_comparisons if total_comparisons > 0 else 0
    method_b_accuracy = 100 * method_b_hits / total_comparisons if total_comparisons > 0 else 0

    return {
        f'{method_a_name} Accuracy': method_a_accuracy,
        f'{method_b_name} Accuracy': method_b_accuracy,
        'Total Comparisons': total_comparisons
    }


def main():
    """Command-line entry point: compare two explanation methods with the LLM judge.

    Parses ``--method_a`` and ``--method_b``, builds the ``ChatGPT`` judge from
    the module-level ``model_name`` and API key, runs
    ``check_all_predictions`` and prints the returned accuracy summary.
    Per-file outcomes are written to a JSON file under ``../results``
    (relative to the current working directory), whose name includes ``LLM``
    and both method names. ``LLM`` is not defined in this file; presumably it
    comes from one of the wildcard imports — TODO confirm.
    """
    parser = argparse.ArgumentParser(description='Compare predictions acceptance between two methods using an LLM.')
    parser.add_argument('--method_a', type=str, required=True, help='The first explanation method folder (e.g., COT).')
    parser.add_argument('--method_b', type=str, required=True, help='The second explanation method folder (e.g., POS).')

    args = parser.parse_args()

    print(f'Comparing {args.method_a} vs. {args.method_b}...')

    result_file_path = f"../results/llm-as-a-judge_decision-making-{LLM}_{args.method_a}_{args.method_b}.json"
    # Results are saved after every judged file; create the output directory
    # up front so a missing folder cannot fail the run after model calls have
    # already been made.
    os.makedirs(os.path.dirname(result_file_path), exist_ok=True)

    gpt_llm = ChatGPT(
        model_name=model_name,
        key=openai.api_key,
    )

    results = check_all_predictions(
        method_a_folder=args.method_a,
        method_b_folder=args.method_b,
        result_file_path=result_file_path,
        llm=gpt_llm,
        method_a_name=args.method_a,
        method_b_name=args.method_b
    )
    print(results)


if __name__ == '__main__':
    main()
