import os
import re
import pandas as pd
from scipy.stats import spearmanr
import seaborn as sns
import matplotlib.pyplot as plt

# === Initial Setup ===
# Folder holding the per-model analysis files (shipped in data.zip).
# The files are read in ascending alphabetical order, so the rows of
# attack_success_rates below are listed in that same order.
directory = './summary'

# One row per model; each row holds the three attack success rates (ASRs)
# recorded for that model.
attack_success_rates = [
    [0.1625, 0.2375, 0.2125],  # Benign ASRs
    [0.1875, 0.2375, 0.225],   # Cybersecurity ASRs
    [0.15, 0.2375, 0.225],     # Engineering ASRs
    [0.35, 0.5, 0.425],        # H2 ASRs (LAT Harmful)
    [0.5625, 0.7, 0.5875],     # Harmful ASRs (CB Harmful)
    [0.1875, 0.2375, 0.225],   # Legal ASRs
]

# === Flatten attack success rates list for 18 values (6 models * 3 ASRs) ===
# Row-major order: all ASRs of the first model, then the second, and so on.
flat_attack_success_rates = []
for _model_asrs in attack_success_rates:
    flat_attack_success_rates.extend(_model_asrs)

# === Function to parse one file ===
def parse_stats_file(filepath):
    """Extract metric means and lexical-diversity scores from one summary file.

    The file is scanned for sections shaped like::

        metric_name:
          ...
          Mean: <number>

    and for the two single-line entries ``Lexical Diversity (Questions): <number>``
    and ``Lexical Diversity (Responses): <number>``.

    Args:
        filepath: Path of the text file to read.

    Returns:
        dict mapping each matched metric name to its mean as a float. The keys
        ``'lexical_diversity_q'`` and ``'lexical_diversity_r'`` are added only
        when the corresponding line is present with a parsable number.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    with open(filepath, 'r') as file:
        content = file.read()

    # Unsigned decimal with an optional exponent, so values written in
    # scientific notation (e.g. "1.5e-03") are read in full instead of being
    # cut off at the mantissa.
    number = r'[0-9]*\.?[0-9]+(?:[eE][-+]?[0-9]+)?'

    # Match "Metric_name:\n  Mean: value"
    # NOTE: the lazy "(?:.*?\n)*?" skips whole lines until the first "Mean:"
    # line, so a header with no Mean of its own is paired with the next Mean
    # that appears further down the file.
    pattern = r'(\w+):\n(?:.*?\n)*?\s*Mean:\s*([-+]?' + number + r')'
    stats = {metric: float(mean) for metric, mean in re.findall(pattern, content)}

    # Lexical diversity is reported on a single line rather than as a section.
    single_line_entries = (
        ('lexical_diversity_q', r'Lexical Diversity \(Questions\):\s*(' + number + r')'),
        ('lexical_diversity_r', r'Lexical Diversity \(Responses\):\s*(' + number + r')'),
    )
    for key, entry_pattern in single_line_entries:
        found = re.search(entry_pattern, content)
        if found:
            stats[key] = float(found.group(1))

    return stats

# === Parse all files ===
# Sorted so that files[i] lines up with attack_success_rates[i].
files = sorted(f for f in os.listdir(directory) if f.endswith('.txt'))

# The analysis below pairs the i-th file with the i-th row of
# attack_success_rates, so there must be exactly one file per row.
expected_file_count = len(attack_success_rates)
if len(files) != expected_file_count:
    raise ValueError(
        f"There should be exactly {expected_file_count} files in the directory "
        f"corresponding to {expected_file_count} models (found {len(files)})."
    )

# One stats dict per file, in the same order as `files`.
data = [parse_stats_file(os.path.join(directory, f)) for f in files]

# === Begin Correlational Analysis ===
# Build one row per (model, ASR) pair: the model's parsed stats plus an
# 'attack_success' column holding that single ASR value.
# NOTE(review): every model's stats are repeated once per ASR, so the rows are
# not independent observations; read the p-values below with that in mind.
expanded_data = [
    {**model_stats, 'attack_success': asr}
    for idx, model_stats in enumerate(data)
    for asr in attack_success_rates[idx]
]

# === DataFrame ===
df_final = pd.DataFrame(expanded_data)

# === Spearman Correlation with p-values for ASRs ===
# axis=0 treats each column as a variable and each row as an observation.
correlation_matrix, p_value_matrix = spearmanr(df_final, axis=0)

column_labels = df_final.columns
correlation_df = pd.DataFrame(correlation_matrix, index=column_labels, columns=column_labels)
p_value_df = pd.DataFrame(p_value_matrix, index=column_labels, columns=column_labels)

# Keep only the 'attack_success' column and drop its entry against itself.
correlations_with_asr = correlation_df['attack_success'].drop('attack_success')
p_values_with_asr = p_value_df['attack_success'].drop('attack_success')

# Print out the correlations and p-values with 'attack_success'
print("\nCorrelations with attack success:")
print(correlations_with_asr)
print("\nP-values for correlations with attack success:")
print(p_values_with_asr)

print("File order:", files)
print("ASRs:", attack_success_rates)

# === Heatmaps for Correlation and P-values with ASRs ===
# Single-column frames; unlike the printed series above, these still include
# the 'attack_success' row itself.
correlation_with_asr_df = correlation_df[['attack_success']]
p_values_with_asr_df = p_value_df[['attack_success']]

for _frame, _title in (
    (correlation_with_asr_df, 'Spearman Correlation with Attack Success Rate'),
    (p_values_with_asr_df, 'P-Value Matrix for Spearman Correlations with Attack Success Rate'),
):
    plt.figure(figsize=(8, 5))
    sns.heatmap(_frame, annot=True, cmap='coolwarm', center=0)
    plt.title(_title)
    plt.tight_layout()
    plt.show()
