import os
import json
import string
import torch
import numpy as np
from string import ascii_uppercase

# Upper-case option labels 'A'..'Z' used to label multiple-choice options.
letters = list(string.ascii_uppercase)


def bold(x):
    """Wrap the string x in ANSI escape codes so it renders bold in a terminal."""
    return '\033[1m' + x + '\033[0m'

def construct_data_path(config):
    """
    Build the parquet file path for the configured dataset, split, sample count and seed.
    config: dictionary, must contain 'dataset' and 'dataset_params' (with 'split', 'n' and 'seed')
    returns: string, 'datasets/<dataset>/<split>_n_<n>_seed_<seed>.parquet'
    """
    dataset = config['dataset']
    params = config['dataset_params']
    split = params['split']
    n = params['n']
    seed = params['seed']
    return 'datasets/{}/{}_n_{}_seed_{}.parquet'.format(dataset, split, n, seed)

def construct_save_dir(config, save_config=False, prefix='results'):
    """
    Construct the results directory for a config, create it on disk and optionally save the config.
    config: dictionary, configuration dictionary (needs 'dataset', 'dataset_params', 'llm',
            'temperature' and 'max_tokens')
    save_config: boolean, if True, save the config to the save directory (will overwrite existing file)
    prefix: string, top-level output directory; when 'results', a 'responses/' subdirectory is also created
    returns: string, save directory path (always ends with '/')
    """
    data_file = construct_data_path(config)
    # Remove the '.parquet' suffix exactly. str.rstrip('.parquet') strips any trailing run of the
    # characters '.', 'p', 'a', 'r', 'q', 'u', 'e', 't' and would turn e.g. 'seed_None' into 'seed_Non'.
    if data_file.endswith('.parquet'):
        data_file = data_file[:-len('.parquet')]
    _, data_name, data_path = data_file.split('/')
    llm, temperature, max_tokens = config['llm'], config['temperature'], config['max_tokens']
    save_dir = f'{prefix}/{data_name}/{llm}/{data_path}_temp_{temperature}_maxtokens_{max_tokens}/'
    # exist_ok avoids the race between an exists() check and makedirs()
    os.makedirs(save_dir, exist_ok=True)
    if prefix == 'results':
        os.makedirs(save_dir + 'responses/', exist_ok=True)
    if save_config:
        # save_dir already ends with '/', so no extra separator is needed
        with open(f'{save_dir}faithfulness_config.json', 'w') as f:
            json.dump(config, f, indent=4)
    return save_dir

def construct_ft_dir(config, save_config=False, prefix='ft_results', num_checkpoints=3):
    """
    Construct the fine-tuning results directory for a config, create it on disk and optionally save the config.
    config: dictionary, configuration dictionary (needs 'dataset', 'dataset_params', 'llm',
            'temperature', 'max_tokens' and 'run_name')
    save_config: boolean, if True, save the config to the save directory (will overwrite existing file)
    prefix: string, top-level output directory; when 'ft_results', per-checkpoint
            'responses_<k>/' subdirectories are also created
    num_checkpoints: integer, create responses_1/ .. responses_<num_checkpoints>/ (default 3,
                     the previously hard-coded count)
    returns: string, save directory path (always ends with '/')
    """
    data_file = construct_data_path(config)
    # Remove the '.parquet' suffix exactly. str.rstrip('.parquet') strips any trailing run of the
    # characters '.', 'p', 'a', 'r', 'q', 'u', 'e', 't' and would turn e.g. 'seed_None' into 'seed_Non'.
    if data_file.endswith('.parquet'):
        data_file = data_file[:-len('.parquet')]
    _, data_name, data_path = data_file.split('/')
    llm, temperature, max_tokens = config['llm'], config['temperature'], config['max_tokens']
    run_name = config['run_name']
    save_dir = f'{prefix}/{data_name}/{llm}/{data_path}_temp_{temperature}_maxtokens_{max_tokens}/{run_name}/'
    # exist_ok avoids the race between an exists() check and makedirs()
    os.makedirs(save_dir, exist_ok=True)
    if prefix == 'ft_results':
        for checkpoint in range(1, num_checkpoints + 1):
            os.makedirs(save_dir + f'responses_{checkpoint}/', exist_ok=True)
    if save_config:
        # save_dir already ends with '/', so no extra separator is needed
        with open(f'{save_dir}faithfulness_config.json', 'w') as f:
            json.dump(config, f, indent=4)
    return save_dir

def construct_icl_save_dir(config, save_config=True):
    """
    Construct the in-context-learning results directory for a config, create it and optionally save the config.
    config: dictionary, configuration dictionary (needs 'dataset', 'dataset_params', 'llm',
            'temperature', 'max_tokens' and 'run_name')
    save_config: boolean, if True, save the config to the save directory (will overwrite existing file)
    returns: string, save directory path (always ends with '/')
    """
    data_file = construct_data_path(config)
    # Remove the '.parquet' suffix exactly. str.rstrip('.parquet') strips any trailing run of the
    # characters '.', 'p', 'a', 'r', 'q', 'u', 'e', 't' and would turn e.g. 'seed_None' into 'seed_Non'.
    if data_file.endswith('.parquet'):
        data_file = data_file[:-len('.parquet')]
    _, data_name, data_path = data_file.split('/')
    llm, temperature, max_tokens, run_name = config['llm'], config['temperature'], config['max_tokens'], config['run_name']
    save_dir = f'icl_results/{data_name}/{llm}/{data_path}_temp_{temperature}_maxtokens_{max_tokens}/{run_name}/'
    # exist_ok avoids the race between an exists() check and makedirs()
    os.makedirs(save_dir, exist_ok=True)
    if save_config:
        # save_dir already ends with '/', so no extra separator is needed
        with open(f'{save_dir}faithfulness_config.json', 'w') as f:
            json.dump(config, f, indent=4)
    return save_dir

def construct_ft_param_str_deprecated(ft_params):
    """
    Build the (deprecated) fine-tuning parameter string.
    ft_params: dictionary with keys 'soft_faithfulness', 'llm' and 'threshold_top_percentile'
    returns: string '<llm>_<soft|hard>_faithfulness_thresh_<threshold_top_percentile>'
    """
    if ft_params['soft_faithfulness']:
        kind = "soft"
    else:
        kind = "hard"
    return "{}_{}_faithfulness_thresh_{}".format(
        ft_params['llm'], kind, ft_params['threshold_top_percentile'])

def load_config(path):
    """
    Read a JSON file and return its decoded content.
    path: string, path to the JSON file
    returns: the parsed JSON value
    """
    with open(path) as handle:
        return json.load(handle)

def load_json_lines(path):
    """
    Load a JSON Lines file (one JSON document per line).
    path: string, path to the file
    returns: list of decoded objects in file order; whitespace-only lines are skipped
             (previously they raised json.JSONDecodeError)
    """
    with open(path, 'r') as f:
        return [json.loads(line) for line in f if line.strip()]

def save_json(dict, path, indent=4):
    """
    Serialize an object to a JSON file, overwriting any existing file.
    dict: JSON-serializable object (name kept for backward compatibility with keyword callers)
    path: string, destination path
    indent: indentation passed through to json.dump
    """
    payload = dict
    with open(path, mode='w') as out:
        json.dump(payload, out, indent=indent)

def save_json_lines(list_of_dicts, path):
    """
    Write objects to a JSON Lines file (one JSON document per line), overwriting any existing file.
    list_of_dicts: iterable of JSON-serializable objects
    path: string, destination path
    """
    with open(path, 'w') as f:
        # `record` instead of `dict` so the builtin is not shadowed
        f.writelines(json.dumps(record) + '\n' for record in list_of_dicts)

def save_ft_config(config, save_dir, run_name, responses_dir, ft_examples):
    """
    Add fine-tuning metadata to config (mutating it in place) and save it as '<save_dir><run_name>.json'.
    config: dictionary, configuration dictionary; gains 'run_name', 'responses_dir' and 'ft_examples'
    save_dir: output directory, concatenated directly with the file name (so it should end with '/')
    run_name: name of the run, also used as the file name
    responses_dir: value stored under 'responses_dir'
    ft_examples: value stored under 'ft_examples'
    """
    for key, value in (("run_name", run_name),
                       ("responses_dir", responses_dir),
                       ("ft_examples", ft_examples)):
        config[key] = value
    destination = f"{save_dir}{run_name}.json"
    save_json(config, destination)

def print_config(config, append_newline=True, ignore_keys=()):
    """
    Print each key/value pair of a configuration dictionary, with the key in bold.
    config: dictionary, configuration dictionary
    append_newline: boolean, if True print an empty line after the entries
    ignore_keys: container of keys to skip; an immutable tuple default replaces the previous
                 mutable list default, which would be shared across calls
    """
    for key, value in config.items():
        if key in ignore_keys:
            continue
        print(f'{bold(key)}: {value}')
    if append_newline:
        print()

def get_time_string(elapsed_time):
    """
    Format a duration in seconds as a human-readable string.
    elapsed_time: number of seconds
    returns: 'S.SS seconds' when the whole-minute part is zero (e.g. 0 <= elapsed_time < 60),
             otherwise 'M minutes, S.SS seconds'
    """
    whole_minutes, remainder = divmod(elapsed_time, 60)
    if whole_minutes:
        return f"{int(whole_minutes)} minutes, {remainder:.2f} seconds"
    return f"{remainder:.2f} seconds"

def print_gpu_memory(device=0):
    """
    Print memory statistics of a CUDA device as tracked by PyTorch's caching allocator.
    device: CUDA device index (or torch.device) to report on; defaults to device 0
    """
    total_memory = torch.cuda.get_device_properties(device).total_memory
    allocated_memory = torch.cuda.memory_allocated(device)
    # memory_reserved is everything held by the caching allocator and already includes the
    # allocated memory, so summing the two (as before) double-counted allocated memory.
    cached_memory = torch.cuda.memory_reserved(device)
    # NOTE: this ignores memory used by other processes and the CUDA context;
    # torch.cuda.mem_get_info reports driver-level free memory if that is needed.
    free_memory = total_memory - cached_memory

    print(f'GPU: {torch.cuda.get_device_name(device)}')
    print(f'Total Memory: {total_memory / 1e9:.2f} GB')
    print(f'Allocated Memory: {allocated_memory / 1e9:.2f} GB')
    print(f'Cached Memory: {cached_memory / 1e9:.2f} GB')
    print(f'Free Memory: {free_memory / 1e9:.2f} GB')

def get_multiple_choice_question(question, options):
    """
    Format a question and its options as a lettered multiple choice prompt.
    question: string, question text
    options: list of strings, options; option i is labelled with the i-th upper-case letter,
             so more than 26 options raises IndexError
    returns: string, 'Question: <question>' followed by a blank line, 'Choices:' and one
             '(<letter>) <option>' line per option
    """
    question = f'Question: {question}\n\n'
    choices = 'Choices:\n' + '\n'.join([f'({ascii_uppercase[i]}) {option}' for i, option in enumerate(options)])
    return question + choices

def get_cot_steps(response_dict, sample_idx):
    """
    Get CoT steps from response_dict for a given sample_idx
    response_dict: dictionary, such as the one returned by load_config(results_dir+'response_0.json'),
                   or a single sample dictionary (then pass sample_idx=None)
    sample_idx: integer index of the sample, or None to read the steps from response_dict itself
    returns: list of the values of 'step_1', 'step_2', ... up to the first missing index
    """
    sample_key = f'sample_{sample_idx}'
    # The original test `if sample_idx` treated index 0 like None and read the steps from the
    # top level of the response. Index 0 now selects 'sample_0' when it exists; a falsy index
    # without such a key still falls back to response_dict, so callers that pass a sample dict
    # together with index 0 keep working.
    if sample_idx is None or (not sample_idx and sample_key not in response_dict):
        sample_dict = response_dict
    else:
        sample_dict = response_dict[sample_key]
    cot_steps = []
    step_idx = 1
    while f'step_{step_idx}' in sample_dict:
        cot_steps.append(sample_dict[f'step_{step_idx}'])
        step_idx += 1
    return cot_steps

def get_num_test_samples(results_dir):
    """
    Count the test instances in a results directory.
    results_dir: string, path to the results directory (trailing slash optional)
    returns: integer, the first index i for which '<results_dir>/response_<i>.json' does not exist
    """
    base = results_dir if results_dir.endswith('/') else results_dir + '/'
    count = 0
    while True:
        if not os.path.exists(f'{base}response_{count}.json'):
            return count
        count += 1

def get_num_samples_per_test(results_dir):
    """
    Count the samples stored per test instance, based on response_0.json.
    results_dir: string, path to the results directory (trailing slash optional)
    returns: integer, number of consecutive 'sample_0', 'sample_1', ... keys in response_0.json
    """
    base = results_dir if results_dir.endswith('/') else results_dir + '/'
    with open(base + 'response_0.json', 'r') as handle:
        first_response = json.load(handle)
    # At most len(first_response) consecutive keys can exist, so the range always contains a miss
    return next(k for k in range(len(first_response) + 1)
                if f'sample_{k}' not in first_response)

def get_num_cot_steps(response_dict, sample_idx):
    """
    Count the CoT steps of one sample in a response dictionary.
    response_dict: dictionary, such as the one returned by load_config(results_dir+'response_0.json')
    sample_idx: integer, index of the sample (reads response_dict['sample_<sample_idx>'])
    returns: integer, number of consecutive 'step_1', 'step_2', ... keys in that sample
    raises: KeyError if the sample is missing
    """
    # Look the sample up once instead of on every loop iteration
    sample_dict = response_dict[f'sample_{sample_idx}']
    num_cot_steps = 0
    while f'step_{num_cot_steps + 1}' in sample_dict:
        num_cot_steps += 1
    return num_cot_steps

def get_num_cot_steps_matrix(results_dir):
    """
    Count the CoT steps for every (test instance, sample) pair of a run.
    results_dir: string, path to the results directory (a config dictionary is also accepted,
                 since the argument is resolved through get_results_dir_n)
    returns: np.ndarray of int with shape (n_eval, n_samples_per_eval)
    """
    results_dir, n_eval, n_samples_per_eval = get_results_dir_n(results_dir)
    num_cot_steps = np.zeros((n_eval, n_samples_per_eval), dtype=int)
    for i in range(n_eval):
        # os.path.join inserts the separator when results_dir has no trailing slash
        response = load_config(os.path.join(results_dir, f'response_{i}.json'))
        for j in range(n_samples_per_eval):
            num_cot_steps[i, j] = get_num_cot_steps(response, j)
    return num_cot_steps

def get_accuracy(config):
    """
    Get the mean multiple choice accuracy of a run, reading files from its 'responses/' subdirectory.
    config: dictionary, config dictionary (can pass results_dir string as well)
    returns: float, fraction of (test instance, sample) pairs whose highest-probability final
             answer equals the label
    """
    results_dir = config if isinstance(config, str) else construct_save_dir(config, save_config=False)
    base_dir = results_dir if results_dir.endswith('/') else results_dir + '/'
    # The response files are loaded from 'responses/', so they must also be counted there;
    # counting in results_dir itself disagreed with where the files are read from.
    responses_dir = base_dir + 'responses/'
    n_eval = get_num_test_samples(responses_dir)
    n_samples_per_eval = get_num_samples_per_test(responses_dir)
    answers = np.zeros((n_eval, n_samples_per_eval), dtype=int)
    for i in range(n_eval):
        response = load_config(responses_dir + f'response_{i}.json')
        label = response['label']
        for j in range(n_samples_per_eval):
            # NOTE(review): assumes the probability dict is ordered like the option letters
            # A, B, C, ... so the argmax position maps onto `letters` — confirm.
            probs = np.array(list(response[f'sample_{j}']["final_answer_probabilities"].values()))
            answers[i, j] = int(letters[np.argmax(probs)] == label)
    return answers.mean()

def get_results_dir_n(config):
    """
    Resolve the results directory and the result-matrix dimensions of a run.
    config: dictionary (must contain 'n_eval' and 'n_samples_per_eval') or a results_dir string
    returns: (results_dir, n_eval, n_samples_per_eval); results_dir always ends with '/'
    """
    if isinstance(config, str):
        # Normalise the trailing slash so callers can build paths as results_dir + 'response_<i>.json';
        # the counting helpers already add it internally, so counts and paths now agree.
        results_dir = config if config.endswith('/') else config + '/'
        n_eval = get_num_test_samples(results_dir)
        n_samples_per_eval = get_num_samples_per_test(results_dir)
    else:
        n_eval, n_samples_per_eval = config['n_eval'], config['n_samples_per_eval']
        # construct_save_dir returns a '/'-terminated path
        results_dir = construct_save_dir(config, save_config=False)
    return results_dir, n_eval, n_samples_per_eval

def get_answer(sample):
    """
    Extract the chosen option letter from a sample dictionary.
    NOTE: get_answer is defined again further down in this module; at import time that later
    definition rebinds the name, so this one is only in effect until then.
    sample: dictionary with 'parsed_final_answer' and 'final_answer' keys
    returns: 'parsed_final_answer' if it is not None; otherwise the first letter A-Z (alphabetical
             order) appearing in 'final_answer' as '(X)' or ' X.'; None when nothing matches or
             'final_answer' is empty/None
    """
    answer = sample['parsed_final_answer']
    if answer is not None:
        return answer
    final_answer = sample['final_answer']
    if not final_answer:
        # Nothing to scan; also avoids a TypeError from `in` when final_answer is None
        return None
    for option in ascii_uppercase:
        if f"({option})" in final_answer or f" {option}." in final_answer:
            return option
    return None

def get_accuracy_matrix(config, return_mean=False, probabilistic=False,
                        use_parsed=False, cot=True):
    """
    Get the per-sample multiple choice correctness matrix of a run.
    config: dictionary, config dictionary (can pass results_dir string as well)
    return_mean: boolean, if True return the mean over the whole matrix instead of the matrix
    probabilistic: boolean, if True store the probability that the last entry of
                   'intermediate_answer_probabilities' assigns to the label, instead of 0/1 correctness
    use_parsed: boolean, if True (and not probabilistic) grade the extracted final answer
                (see get_answer) instead of the argmax of the answer probabilities
    cot: boolean, for argmax grading use the last entry of 'intermediate_answer_probabilities'
         if True, otherwise the first entry
    returns: np.ndarray of shape (n_eval, n_samples_per_eval), or its mean if return_mean
    raises: KeyError if a sample graded from probabilities lacks 'intermediate_answer_probabilities'
    """
    results_dir, n_eval, n_samples_per_eval = get_results_dir_n(config)

    answers = np.zeros((n_eval, n_samples_per_eval), dtype=float if probabilistic else int)
    idx = -1 if cot else 0
    for i in range(n_eval):
        # os.path.join inserts the separator when results_dir has no trailing slash
        response = load_config(os.path.join(results_dir, f'response_{i}.json'))
        label = response['label']
        for j in range(n_samples_per_eval):
            sample = response[f'sample_{j}']
            if use_parsed and not probabilistic:
                # Parsed grading needs no probabilities. get_answer already returns
                # 'parsed_final_answer' when it is not None and otherwise scans 'final_answer'.
                answers[i, j] = get_answer(sample) == label
                continue
            try:
                probs = sample["intermediate_answer_probabilities"]
            except KeyError as err:
                # Previously this only printed the sample and then reused `probs` from the
                # previous iteration (or hit a NameError), silently corrupting the matrix.
                raise KeyError(
                    f'sample_{j} of response_{i}.json has no "intermediate_answer_probabilities" '
                    f'(keys: {sorted(sample)})'
                ) from err
            if probabilistic:
                # NOTE(review): always uses the last entry regardless of `cot` — confirm intended.
                answers[i, j] = probs[-1][label]
            else:
                answers[i, j] = max(probs[idx], key=probs[idx].get) == label
    if return_mean:
        return answers.mean()
    return answers

def get_cot_prompt_text_and_response(response_dict, sample_idx):
    """
    Get prompt text and response from response_dict for a given sample_idx
    response_dict: dictionary, such as the one returned by load_config(results_dir+'response_0.json')
    sample_idx: integer, index of the sample
    returns: (prompt_text, response_text) where prompt_text is the CoT prompt, the question and the
             sample's CoT steps joined by newlines, and response_text is the sample's 'final_answer'
    """
    cot_prompt = response_dict['cot_prompt']
    question = response_dict['question']
    sample = response_dict[f'sample_{sample_idx}']
    # Pass the sample dict itself with sample_idx=None: get_cot_steps then walks that dict
    # directly, which also handles sample 0 correctly.
    cot_steps = get_cot_steps(sample, None)
    final_answer = sample['final_answer']
    prompt_text = f'{cot_prompt}\n{question}\n' + '\n'.join(cot_steps)
    # The original built these values but never returned them (it implicitly returned None)
    return prompt_text, final_answer

def get_answer(sample):
    """
    Extract the chosen option letter from a sample dictionary.
    sample: dictionary with 'parsed_final_answer' and 'final_answer' keys
    returns: 'parsed_final_answer' if it is not None; otherwise the first letter A-Z (alphabetical
             order) appearing in 'final_answer' as '(X)' or ' X.'; None when nothing matches or
             'final_answer' is empty/None
    """
    answer = sample['parsed_final_answer']
    if answer is not None:
        return answer
    final_answer = sample['final_answer']
    if not final_answer:
        # Nothing to scan; also avoids a TypeError from `in` when final_answer is None
        return None
    for option in ascii_uppercase:
        if f"({option})" in final_answer or f" {option}." in final_answer:
            return option
    return None

def construct_icl_prompt_from_examples(icl_examples, exclude_explanation=False):
    """
    Build an in-context-learning prompt from previously saved responses.
    icl_examples: list of dictionaries with keys 'response_file' (path to a saved response JSON)
                  and 'sample_id' (key of the sample inside that response, presumably e.g.
                  'sample_0' — TODO confirm)
    exclude_explanation: boolean, if True omit the CoT steps and keep only question, choices and answer
    returns: string, the concatenated examples ("" when icl_examples is empty)
    """
    if not icl_examples:
        return ""
    parts = []
    for icl_example in icl_examples:
        response_file, sample_id = icl_example["response_file"], icl_example["sample_id"]
        # `with` closes the file deterministically (open(...).read() leaked the handle)
        with open(response_file) as f:
            response = json.load(f)
        sample = response[sample_id]
        question = response['question']
        answer = get_answer(sample)
        options_str = "\n".join([f'({letter}) {option}' for letter, option in zip(ascii_uppercase, response['options'])])
        if exclude_explanation:
            reasoning = ""
        else:
            # Steps are stored as 'step_1', 'step_2', ...; sort numerically so 'step_10' follows 'step_9'
            steps = sorted((key for key in sample if key.startswith('step_')), key=lambda key: int(key[5:]))
            explanation = "\n".join(sample[step].strip() for step in steps)
            reasoning = f"{explanation}\n"
        parts.append(f"\nQuestion: {question}\nChoices:\n{options_str}\n{reasoning}Final Answer: The single, most likely answer is {answer}.\n\n\n")
    return "".join(parts)