import pandas as pd
import glob
import json
import sys
import os
import random
import copy
from utils.util import construct_save_dir, load_config
from string import ascii_uppercase
from utils.faithfulness import parse_mcq_answer, get_faith_aoc

# Usage: python <this script> DATASET_NAME MODEL_NAME [soft|hard]
# The optional third argument selects which faithfulness score is used to rank
# ICL examples; it defaults to 'soft' so two-argument invocations keep working.
if len(sys.argv) not in (3, 4):
    sys.exit("please pass dataset name, model name, and optionally faithfulness type (soft/hard) as command line args!")
dataset_name = sys.argv[1]
model_name = sys.argv[2]
faithfulness_type = sys.argv[3] if len(sys.argv) == 4 else 'soft'
if faithfulness_type not in ('soft', 'hard'):
    sys.exit(f"faithfulness type must be 'soft' or 'hard', got {faithfulness_type!r}")
# When True, the example pool is narrowed to responses whose sample_0 answer
# matches the label before ICL examples are selected.
filter_correct_examples = True

# Either 'soft_faithfulness' or 'hard_faithfulness'.
faithfulness_column = f'{faithfulness_type}_faithfulness'

config = load_config('faithfulness_config.json')
config.update({'dataset': dataset_name, 'max_tokens': 512, 'temperature': 0.0, 'llm': model_name})
config['dataset_params'].update({'n': 400})
responses_dir = construct_save_dir(config, save_config=False)
# glob returns files in arbitrary, filesystem-dependent order; sort so the
# seeded random.sample calls that consume this list are reproducible.
response_files = sorted(glob.glob(os.path.join(responses_dir, "responses", "response_*.json")))

def get_perturbed_explanation_response_file(response_file):
    """Map a temperature-0.0 response path to its temperature-0.3 counterpart.

    Every occurrence of the ``_temp_0.0_`` tag in the path is swapped for
    ``_temp_0.3_``. An AssertionError is raised when the tag is absent.
    """
    greedy_tag, perturbed_tag = "_temp_0.0_", "_temp_0.3_"
    assert greedy_tag in response_file
    return perturbed_tag.join(response_file.split(greedy_tag))

def get_answer(sample):
    """Return the multiple-choice letter chosen in ``sample``.

    ``sample['parsed_final_answer']`` wins when it is not None. Otherwise,
    ``sample['final_answer']`` is scanned for ``(X)`` or `` X.`` patterns.
    Letters are tried in alphabetical order, so the first matching letter in
    A..Z is returned, not the first one in the text. Returns None when no
    letter matches.
    """
    parsed = sample['parsed_final_answer']
    if parsed is not None:
        return parsed
    text = sample['final_answer']
    return next(
        (
            letter
            for letter in ascii_uppercase
            if f"({letter})" in text or f" {letter}." in text
        ),
        None,
    )


def get_faithfulness(response_dict, sample_key, column=None):
    """Return the faithfulness score of one sample in a parsed response file.

    Args:
        response_dict: Parsed contents of a response JSON file. It must contain
            ``sample_key``; the hard-faithfulness path also reads the top-level
            ``'final_answer_str'`` entry.
        sample_key: Key of the sample to score, e.g. ``'sample_0'``.
        column: ``'soft_faithfulness'`` or ``'hard_faithfulness'``. Defaults to
            the module-level ``faithfulness_column`` setting.

    Returns:
        For soft faithfulness, the value stored on the sample under ``column``.
        For hard faithfulness, the second value returned by ``get_faith_aoc``.

    Raises:
        ValueError: If ``column`` is not a supported faithfulness column.
    """
    if column is None:
        column = faithfulness_column
    if column == 'soft_faithfulness':
        return response_dict[sample_key][column]
    if column == 'hard_faithfulness':
        sample = response_dict[sample_key]
        parsed_cot_answer = parse_mcq_answer(sample['full_response'], response_dict['final_answer_str'])
        # Score the sample against its own intermediate answer probabilities.
        # Previously sample_0's were always used, which made every sample in a
        # file share one curve. sample_0's are still the fallback when the
        # sample carries none.
        answers_probs = sample.get('intermediate_answer_probabilities')
        if answers_probs is None:
            answers_probs = response_dict['sample_0']['intermediate_answer_probabilities']
        _, hard_faith_aoc = get_faith_aoc(answers_probs, parsed_cot_answer)
        return hard_faith_aoc
    raise ValueError(f"unsupported faithfulness column: {column!r}")


def get_most_faithful_sample_id(response_file):
    """Return the key of the most faithful sample in the perturbed counterpart file.

    ``response_file`` is a temperature-0.0 response path. The matching path
    from ``get_perturbed_explanation_response_file`` is loaded, and its
    ``sample_*`` entries are scored with ``get_faithfulness``. Ties resolve to
    the sample key that appears first in the JSON object.

    Raises:
        AssertionError: If the perturbed file does not hold exactly five samples.
    """
    perturbed_file = get_perturbed_explanation_response_file(response_file)
    # `with` closes the file even if reading or JSON parsing fails.
    with open(perturbed_file, encoding='utf-8') as file:
        response_dict = json.load(file)
    sample_keys = [key for key in response_dict if key.startswith('sample_')]
    assert len(sample_keys) == 5
    # max() returns the first maximal key. This is the same key the previous
    # stable descending sort put at index 0, without sorting the whole list.
    return max(sample_keys, key=lambda sample_key: get_faithfulness(response_dict, sample_key))

def is_correct(response_file):
    """Return whether sample_0's answer in ``response_file`` equals the file's label.

    The answer is extracted from ``response_dict['sample_0']`` with
    ``get_answer`` and compared to ``response_dict['label']`` with ``==``.
    """
    # `with` closes the file even if reading or JSON parsing fails.
    with open(response_file, encoding='utf-8') as file:
        response_dict = json.load(file)
    return response_dict['label'] == get_answer(response_dict['sample_0'])

icl_config = {
    "dataset": dataset_name,
    "dataset_params": {
        "split": "test",
        "n": 100,
        "seed": 42
    },
    "llm": model_name,
    "temperature": 0.0,
    "max_tokens": 1024,
    "n_eval": 100,
    "n_samples_per_eval": 1,
    "n_probs": 20,
    "add_final_answer": True,
    "exclude_explanation": False,
}


def _read_json(path):
    """Load and return the JSON document stored at ``path``."""
    with open(path, encoding='utf-8') as json_file:
        return json.load(json_file)


# Sorted so the seeded random.sample calls below pick the same files no matter
# what order glob returned them in.
all_response_files = sorted(response_files)

# Correctness of every response file, computed once and reused both for the
# optional filtering step and for correct_top_icl_examples.
correct_response_files = {path for path in all_response_files if is_correct(path)}

if filter_correct_examples:
    all_count = len(all_response_files)
    response_files = [path for path in all_response_files if path in correct_response_files]
    correct_count = len(response_files)
    print(f'Filtered {correct_count} correct samples from {all_count} samples.')
else:
    response_files = list(all_response_files)

# sample_0 faithfulness of every response file. Each file is read once, and
# the handle is closed instead of being left to the garbage collector.
sample_0_faithfulness = {
    path: get_faithfulness(_read_json(path), 'sample_0') for path in all_response_files
}

random.seed(42)
random_k_icl_examples = random.sample(response_files, k=10)
# Descending by faithfulness; the stable sort keeps path order among ties.
top_icl_examples = sorted(response_files, key=sample_0_faithfulness.get, reverse=True)
top_from_all_icl_examples = sorted(all_response_files, key=sample_0_faithfulness.get, reverse=True)
random.seed(42)
random_k_from_all_icl_examples = random.sample(all_response_files, k=10)
correct_top_icl_examples = [path for path in top_icl_examples if path in correct_response_files]
config_dir = construct_save_dir(icl_config, prefix='icl_configs')

def _write_icl_config(run_name, icl_examples, exclude_explanation=None):
    """Write one ICL run config to ``<config_dir>/<run_name>.json``.

    The config is a deep copy of ``icl_config`` with ``run_name`` and
    ``icl_examples`` set. ``exclude_explanation`` overrides the base value
    only when it is given.
    """
    run_config = copy.deepcopy(icl_config)
    run_config["run_name"] = run_name
    run_config["icl_examples"] = icl_examples
    if exclude_explanation is not None:
        run_config["exclude_explanation"] = exclude_explanation
    with open(os.path.join(config_dir, f'{run_name}.json'), 'w') as config_file:
        config_file.write(json.dumps(run_config, indent=4))


def _original_examples(paths):
    """Reference sample_0 of each given (temperature 0.0) response file."""
    return [{"response_file": path, "sample_id": "sample_0"} for path in paths]


def _most_faithful_perturbed_examples(paths):
    """Reference the most faithful sample of each file's temperature-0.3 counterpart."""
    return [
        {
            "response_file": get_perturbed_explanation_response_file(path),
            "sample_id": get_most_faithful_sample_id(path),
        }
        for path in paths
    ]


# Pools: `*_from_all_*` lists draw from every response file. The other lists
# draw from `response_files`, which is restricted to correctly answered
# examples when `filter_correct_examples` is True. "_c" runs use that pool.

# Baseline 2: 10 random {Q, A} pairs (explanations excluded), from the
# correctness-filtered pool.
_write_icl_config("baseline_2", _original_examples(random_k_icl_examples), exclude_explanation=True)

# Baseline 3: 10 random {Q, A, E} triplets.
_write_icl_config("baseline_3", _original_examples(random_k_from_all_icl_examples))
_write_icl_config("baseline_3_c", _original_examples(random_k_icl_examples))

# Approach 1: top-10 {Q, A, E} examples ranked by sample_0 faithfulness.
_write_icl_config("approach_1", _original_examples(top_from_all_icl_examples[:10]))
_write_icl_config("approach_1_c", _original_examples(top_icl_examples[:10]))

# Approach 2: 10 random examples, each using its most faithful
# perturbed-explanation (temperature 0.3) sample.
_write_icl_config("approach_2", _most_faithful_perturbed_examples(random_k_from_all_icl_examples))
_write_icl_config("approach_2_c", _most_faithful_perturbed_examples(random_k_icl_examples))

# Approach 3: top-10 examples by sample_0 faithfulness, each using its most
# faithful perturbed-explanation (temperature 0.3) sample.
_write_icl_config("approach_3", _most_faithful_perturbed_examples(top_from_all_icl_examples[:10]))
_write_icl_config("approach_3_c", _most_faithful_perturbed_examples(top_icl_examples[:10]))
