import re
import os
import numpy as np
import glob
import itertools
import argparse
from config import config
import matplotlib.pyplot as plt

# Command-line interface.
# The summary text is passed as `description` (shown in --help under the usage
# line) rather than `prog`, so the usage line shows the real script name
# instead of the whole sentence.
parser = argparse.ArgumentParser(description='Generates Results & Graphs for one set of Runs of a Particular Problem')
parser.add_argument('-e', '--experiment_dir', type=str, required=True, dest="e", help="Path to the Directory containing the Runs and Scores")
parser.add_argument('-d', '--prompting_depth', type=int, required=True, dest="d", help="Prompting Depth of the Experiment")
args = parser.parse_args()

# Directory that is searched for `run-*` folders further down.
experiments_dir = args.e
# Number of columns (one per program index) of the accuracy matrices built below.
prompting_depth = args.d

def generate_selections(N, M):
    """Return every M-element combination of the integers 1..N.

    The result is a list of tuples in the order produced by
    ``itertools.combinations``.
    """
    return list(itertools.combinations(range(1, N + 1), M))

def parse_stats_file(filename):
    """Parse a scores file into its header fields and per-program counters.

    The file must contain a ``Code Directory: ..., Dataset File: ...`` line
    and a ``Number of Testing Samples: <int>`` line; every
    ``code-<n>.py`` section followed by the four counter lines is collected.

    Returns:
        dict with keys ``directory``, ``dataset_file``, ``num_samples`` and
        ``code_stats`` (a list of dicts, one per matched program section).

    Raises:
        ValueError: if either header line cannot be found.
    """
    with open(filename, 'r') as stats_handle:
        text = stats_handle.read()

    # Header: code directory / dataset file, and the sample count.
    header = re.search(r'Code Directory: (.*), Dataset File: (.*)', text)
    sample_count = re.search(r'Number of Testing Samples: (\d+)', text)
    if header is None or sample_count is None:
        raise ValueError("Failed to parse the header information.")

    directory, dataset_file = header.groups()

    # Per-program sections: output key -> named group of the pattern below.
    counter_groups = (
        ('correct', 'correct'),
        ('runtime-error', 'runtime'),
        ('timeout-error', 'timeout'),
        ('verification-error', 'verification'),
    )
    section_pattern = r'(?P<filename>code-\d+\.py)\nCorrect: (?P<correct>\d+)\nRuntime-Error: (?P<runtime>\d+)\nTimeout-Error: (?P<timeout>\d+)\nVerification-Error: (?P<verification>\d+)'

    code_stats = []
    for section in re.finditer(section_pattern, text):
        entry = {'filename': section.group('filename')}
        for key, group_name in counter_groups:
            entry[key] = int(section.group(group_name))
        code_stats.append(entry)

    return {
        'directory': directory,
        'dataset_file': dataset_file,
        'num_samples': int(sample_count.group(1)),
        'code_stats': code_stats,
    }

def find_scores_files(path):
    """Return the regular files directly inside *path* whose name starts with ``scores-``."""
    candidates = glob.glob(os.path.join(path, 'scores-*'))
    # Keep only regular files; directories matching the pattern are dropped.
    return list(filter(os.path.isfile, candidates))

def find_run_folders(path):
    """Return the directories directly inside *path* whose name starts with ``run-``."""
    run_glob = os.path.join(path, 'run-*')
    # Plain files that happen to match the pattern are ignored.
    return list(filter(os.path.isdir, glob.glob(run_glob)))


# Discover the run folders of this experiment and echo them to the console.
print(f"Experiment Directory: {experiments_dir}")
run_folders = find_run_folders(path=experiments_dir)
number_of_runs = len(run_folders)
print("Run Folders:")
print(*run_folders, sep="\n")

# Containers filled by the per-run loading loop below.
test_total_stats = {}  # run number -> program index (depth) -> stats dict (incl. accuracy)
val_total_stats = {}  # run number -> program index (depth) -> stats dict (incl. accuracy)
test_last_program_stats = {}  # run number -> stats dict of the highest-index program
val_last_program_stats = {}  # run number -> stats dict of the highest-index program


def _stats_by_depth(scores_file):
    """Parse one scores file and return its per-program stats keyed by program index.

    The program index is the integer in the ``code-<index>.py`` filename.
    Every stats dict gets an extra ``accuracy`` entry: correct / (sum of the
    four outcome counters).

    Raises:
        ValueError: if the file contains no per-program sections.
    """
    stats_by_depth = dict()
    for code_stat in parse_stats_file(scores_file)['code_stats']:
        depth_level = int(code_stat['filename'].split(".")[0].split("-")[-1])
        evaluated = (code_stat['correct'] + code_stat['runtime-error']
                     + code_stat['timeout-error'] + code_stat['verification-error'])
        # A program with no recorded outcomes has nothing correct; report 0.0
        # instead of failing with ZeroDivisionError.
        code_stat['accuracy'] = code_stat['correct'] / evaluated if evaluated else 0.0
        stats_by_depth[depth_level] = code_stat
    if not stats_by_depth:
        raise ValueError(f"No per-program stats found in scores file: {scores_file}")
    return stats_by_depth


# Load the validation and test scores of every run folder.
for run_folder in run_folders:
    # 'run-<k>' -> k; basename/normpath instead of split("/") so the result
    # does not depend on the platform's path separator.
    run_number = int(os.path.basename(os.path.normpath(run_folder)).split("-")[-1])

    scores_files = find_scores_files(run_folder)
    if len(scores_files) != 2:  ### one validation and one test file
        raise ValueError(
            f"Expected exactly 2 scores files in {run_folder}, found {len(scores_files)}: {scores_files}"
        )

    # Reset per run so a file found in a previous iteration is never reused,
    # and classify by file name only so that "test"/"validation" appearing in
    # a parent directory name cannot misroute a file.
    val_scores_file = None
    test_scores_file = None
    for scores_file in scores_files:
        scores_name = os.path.basename(scores_file)
        if "validation" in scores_name:
            val_scores_file = scores_file
        elif "test" in scores_name:
            test_scores_file = scores_file
    if val_scores_file is None or test_scores_file is None:
        raise ValueError(
            f"Could not identify one validation and one test scores file in {run_folder}: {scores_files}"
        )
    print(val_scores_file, test_scores_file)

    test_total_stats[run_number] = _stats_by_depth(test_scores_file)
    # "Last program" = the entry with the highest program index of this run.
    test_last_program_stats[run_number] = test_total_stats[run_number][max(test_total_stats[run_number])]

    val_total_stats[run_number] = _stats_by_depth(val_scores_file)
    val_last_program_stats[run_number] = val_total_stats[run_number][max(val_total_stats[run_number])]

# Accuracy matrices: row = run_number - 1, column = program index - 1.
# Cells that are never written keep their initial value of 0.
matrix_shape = (number_of_runs, prompting_depth)
test_accuracy_data = np.zeros(shape=matrix_shape)
val_accuracy_data = np.zeros(shape=matrix_shape)
max_depth_reached = {}  ### run_number --> highest program index present in the test stats

# NOTE(review): this indexes runs as 1..number_of_runs - confirm run folders are numbered that way.
for run_number in range(1, number_of_runs+1):
    test_by_depth = test_total_stats[run_number]
    val_by_depth = val_total_stats[run_number]
    max_depth_reached[run_number] = max(test_by_depth)
    row = run_number - 1
    for program_index in sorted(test_by_depth):
        column = program_index - 1
        test_accuracy_data[row, column] = test_by_depth[program_index]['accuracy']
        val_accuracy_data[row, column] = val_by_depth[program_index]['accuracy']

# Heatmap of test accuracy: one row per run, one column per program index.
fig, ax = plt.subplots()
cax = ax.imshow(test_accuracy_data, cmap='Greens', origin='upper', vmin=0, vmax=1)

# Colorbar to help interpret the colors
cbar = fig.colorbar(cax)
cbar.set_label('Accuracy', rotation=270, labelpad=15)

# Title and axis labels
ax.set_title(f"{config['problem_name']} Accuracy")
ax.set_xlabel("Feedback Depth")
ax.set_ylabel("Independent Runs")

# Thick white horizontal lines between neighbouring rows
num_rows, num_cols = test_accuracy_data.shape
for row_edge in range(num_rows - 1):
    ax.axhline(row_edge + 0.5, color='white', linestyle='-', linewidth=10)

# Red vertical marker at the right edge of the last cell each run reached.
# ymin/ymax are axes fractions; with origin='upper' run 1 is the top band.
band_margin = 0.1/num_rows
for run_number, depth in max_depth_reached.items():
    band_low = ((num_rows-run_number))/num_rows+band_margin
    band_high = ((num_rows-run_number)+1)/num_rows-band_margin
    ax.axvline(depth - 0.5, ymin=band_low, ymax=band_high, color='red', linestyle='-', linewidth=3)

# Annotate the reached cells with their accuracy in percent, then carry the
# last reached accuracy forward into the remaining columns of both matrices
# (the filled values are what the later reports and averages read).
# NOTE(review): the fill happens after imshow() was given the array; whether
# the rendered image reflects it depends on matplotlib internals - confirm.
for row in range(num_rows):
    reached = max_depth_reached[row + 1]
    for col in range(min(reached, num_cols)):
        ax.text(col, row, f"{test_accuracy_data[row, col]*100:.1f}", ha='center', va='center', color='orange', fontweight='bold')
    test_accuracy_data[row, reached:] = test_accuracy_data[row, reached - 1]
    val_accuracy_data[row, reached:] = val_accuracy_data[row, reached - 1]

plt.savefig("Results.png")
plt.clf()


# Text report: problem name, depth reached per run, then both accuracy
# matrices (one line per run, values formatted to three decimals).
with open('results.txt', 'w') as results_file:
    results_file.write(f'Problem Name: {config["problem_name"]}\n\n')
    for run_number in sorted(max_depth_reached):
        results_file.write(f'Run-{run_number} Depth: {max_depth_reached[run_number]}\n')
    results_file.write('\n')

    report_sections = (
        ('Validation Accuracy:\n', val_accuracy_data),
        ('Test Accuracy:\n', test_accuracy_data),
    )
    for heading, matrix in report_sections:
        results_file.write(heading)
        for y in range(num_rows):
            # Every value is followed by a single space, including the last one.
            results_file.write(''.join(f'{matrix[y, x]:.3f} ' for x in range(num_cols)))
            results_file.write('\n')
        results_file.write('\n')

# Echo the two accuracy matrices and the per-run depth to the console.
console_report = (
    ('Validation Accuracy', val_accuracy_data),
    ('\nTest Accuracy', test_accuracy_data),
    ('\nMax Depth', max_depth_reached),
)
for heading, payload in console_report:
    print(heading)
    print(payload)


### effect of prompting depth
# Mean test accuracy over all runs for each column, expressed in percent.
average_scores_per_depth = [
    np.mean(test_accuracy_data[:, depth_index])*100
    for depth_index in range(prompting_depth)
]

plt.plot(average_scores_per_depth, color="red", marker="X", lw=2, ls="--")
plt.title(f"Accuracy vs Feedback Depth ({config['problem_name']})")
plt.xlabel("Feedback Depth")
plt.ylabel("Average Score")
plt.grid(True)
plt.savefig("Results-Feedback.png")
plt.clf()


### effect of multiple runs
# For every k in 1..number_of_runs: take every k-element combination of runs,
# pick the run in it whose last program has the highest validation accuracy,
# and average the test accuracy of that run's last program over all
# combinations.
run_counts = list(range(1, number_of_runs+1))
average_scores_per_run_count = []
for run_count in run_counts: ### k
    run_combinations = generate_selections(number_of_runs, run_count)
    accumulator = 0
    for run_indices in run_combinations: ### all possible samples of k runs
        # max() returns the first maximal element, i.e. the lowest run index
        # among ties - the same choice a stable descending sort would make.
        best_run = max(run_indices, key=lambda run_index: val_last_program_stats[run_index]['accuracy'])
        accumulator += test_last_program_stats[best_run]['accuracy']
    average_scores_per_run_count.append(accumulator/len(run_combinations))

# Pass the x values explicitly: without them pyplot numbers the points from 0,
# which would label the k=1 result as "0 runs".
plt.plot(run_counts, average_scores_per_run_count, color="red", marker="X", lw=2, ls="--")
plt.grid(True)
plt.xlabel("Number of Runs")
plt.ylabel("Average Score")
plt.title(f"Accuracy vs Number of Runs ({config['problem_name']})")
plt.savefig("Results-Runs.png")
plt.clf()