#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import glob
import json
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import math
from collections import defaultdict

# Global plot styling, applied once at import time. Order matters: both calls
# update matplotlib rcParams, and sns.set_style runs second, so wherever the two
# styles overlap the seaborn "whitegrid" settings win.
plt.style.use('ggplot')
sns.set_style("whitegrid")

def extract_priors_from_response(response):
    """
    Extract all priors from a model response.

    A prior is a tagged block such as ``<prior_1><A>0.2</A><B>0.8</B></prior_1>``
    (or ``<prior_final>...</prior_final>``) holding one or more option tags A-D.

    Returns a list of dicts, one per well-formed prior, each with:
      - 'prior_num': the tag suffix as a string (digits, or "final");
      - 'A'..'D': float probabilities (0.0 for options absent from the block).
    The list is sorted by prior number, with "final" last. A prior whose
    probability text is not a valid number (e.g. "1.2.3", which the tag regex
    still accepts) is skipped rather than aborting the whole analysis.
    Returns [] for empty or non-string input.
    """
    if not response or not isinstance(response, str):
        return []

    # Whole prior block; \1 ties the closing tag to the opening one and \3 ties
    # each option's closing tag to its opening letter.
    prior_pattern = r'<prior_(\d+|final)>\s*(<([A-D])>\s*([\d\.]+)\s*</\3>)+\s*</prior_\1>'
    # Individual option tags inside a matched prior block.
    option_pattern = r'<([A-D])>\s*([\d\.]+)\s*</\1>'

    priors = []
    for prior_match in re.finditer(prior_pattern, response, re.DOTALL):
        prior_dict = {'prior_num': prior_match.group(1)}  # digits or "final"
        try:
            for option, prob_text in re.findall(option_pattern, prior_match.group(0)):
                prior_dict[option] = float(prob_text)  # a repeated option: last one wins
        except ValueError:
            # [\d\.]+ also matches "1.2.3" or "."; drop this malformed prior only.
            continue

        # Options the model did not mention get probability 0.
        for option in 'ABCD':
            prior_dict.setdefault(option, 0.0)
        priors.append(prior_dict)

    # Numeric priors in ascending order, "final" after all of them.
    priors.sort(key=lambda p: float('inf') if p['prior_num'] == 'final' else int(p['prior_num']))
    return priors

def get_highest_prob_option(prior):
    """Return whichever option letter 'A'-'D' carries the largest probability in *prior*.

    On a tie the alphabetically earliest option wins, because max() keeps the
    first maximal element it sees. Raises KeyError if *prior* lacks one of the
    four option keys.
    """
    return max('ABCD', key=prior.__getitem__)

def calculate_log_odds(prior, correct_option):
    """Return the natural-log odds ln(p / (1 - p)) of *correct_option* under *prior*.

    p is clipped into [0.0001, 0.9999] so the result is always finite for numeric
    input; an option missing from *prior* is treated as p = 0.0001.
    """
    lower, upper = 0.0001, 0.9999
    p = prior.get(correct_option, lower)
    # Explicit clamp (equivalent to max(min(p, upper), lower)).
    if p < lower:
        p = lower
    elif p > upper:
        p = upper
    return math.log(p / (1 - p))

def analyze_priors_for_model(json_path):
    """
    Analyze priors from all responses in a JSON results file.

    The file is expected to contain a list of result records. Each record is read
    for 'skipped', 'response', 'resolution' and the option texts 'A'-'D'.

    Returns a dict with:
      - 'model_name': basename of json_path up to its first '_'.
      - 'total_responses': number of records in the file (skipped ones included).
      - 'responses_with_priors': non-skipped records with at least one parsed prior.
      - 'skipped': number of records flagged 'skipped'.
      - 'all_priors': one list of prior dicts per response with priors.
      - 'all_answer_changes': per such response, 1/0 flags marking whether the
        highest-probability option differs between consecutive priors.
      - 'all_log_odds' / 'all_accuracies': per response whose correct option could
        be identified, the log odds of the correct option / 1-0 argmax correctness
        for each prior.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        records = json.load(f)

    all_priors = []
    all_answer_changes = []
    all_log_odds = []
    all_accuracies = []

    num_responses_with_priors = 0
    skipped = 0

    for res in records:
        if res.get('skipped', False):
            skipped += 1
            continue

        priors = extract_priors_from_response(res.get('response', ''))
        if not priors:
            continue

        num_responses_with_priors += 1
        all_priors.append(priors)

        # Argmax option per prior, and whether it flipped between consecutive priors.
        answers = [get_highest_prob_option(prior) for prior in priors]
        all_answer_changes.append([int(cur != prev) for prev, cur in zip(answers, answers[1:])])

        # Identify the correct option by matching the resolution against option texts.
        # Without a resolution there is no correct answer; guarding here prevents
        # None == None from marking a missing option field as "correct".
        resolution = res.get('resolution')
        if resolution is None:
            continue
        correct_option = next((opt for opt in 'ABCD' if res.get(opt) == resolution), None)
        if correct_option is None:
            continue

        all_log_odds.append([calculate_log_odds(prior, correct_option) for prior in priors])
        all_accuracies.append([int(answer == correct_option) for answer in answers])

    return {
        'model_name': os.path.basename(json_path).split('_')[0],
        'total_responses': len(records),
        'responses_with_priors': num_responses_with_priors,
        'skipped': skipped,
        'all_priors': all_priors,
        'all_answer_changes': all_answer_changes,
        'all_log_odds': all_log_odds,
        'all_accuracies': all_accuracies
    }

def _positional_mean(sequences, keep=None):
    """Average ragged *sequences* position by position (all aligned at index 0).

    Returns a float array as long as the longest sequence (empty if there are
    none). Values for which ``keep(value)`` is falsy are ignored. Positions left
    with no values are NaN, so they show as gaps rather than a misleading 0.
    """
    length = max((len(seq) for seq in sequences), default=0)
    totals = np.zeros(length)
    counts = np.zeros(length)
    for seq in sequences:
        for i, value in enumerate(seq):
            if keep is None or keep(value):
                totals[i] += value
                counts[i] += 1
    return np.divide(totals, counts, out=np.full(length, np.nan), where=counts > 0)


def _prior_entropy(prior):
    """Return -sum(p * ln p) over options 'A'-'D' of *prior*, skipping p <= 0."""
    return -sum(prior[opt] * np.log(prior[opt]) for opt in 'ABCD' if prior[opt] > 0)


def _finish_plot(xlabel, ylabel, title, n_positions, path, grid_axis='both'):
    """Label the current figure, tick x at 1..n_positions, save it to *path*, and close it."""
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True, axis=grid_axis)
    plt.xticks(range(1, n_positions + 1))
    plt.savefig(path, dpi=300, bbox_inches='tight')
    plt.close()


def plot_model_priors(model_results, output_dir='plots'):
    """Create visualizations for a model's prior analysis.

    Reads 'model_name', 'all_priors', 'all_answer_changes', 'all_log_odds' and
    'all_accuracies' from *model_results* (the dict built by
    analyze_priors_for_model). It writes PNGs named ``{model_name}_<metric>.png``
    into *output_dir*, creating the directory if needed. The metrics are
    option_probs, answer_changes, log_odds, accuracy and entropy.

    Every metric is averaged position by position across responses, so position i
    aggregates only the responses that produced at least i priors. A plot is
    skipped when it has no data.
    """
    os.makedirs(output_dir, exist_ok=True)
    model_name = model_results['model_name']
    all_priors = model_results['all_priors']

    def out_path(metric):
        return os.path.join(output_dir, f'{model_name}_{metric}.png')

    # 1. Average probability assigned to each option at each prior position.
    if all_priors:
        n_priors = max(len(priors) for priors in all_priors)
        plt.figure(figsize=(10, 6))
        for option, color in zip('ABCD', ['blue', 'orange', 'green', 'red']):
            means = _positional_mean([[prior[option] for prior in priors] for priors in all_priors])
            plt.plot(range(1, n_priors + 1), means, label=option, marker='o', color=color)
        plt.legend()
        _finish_plot('Prior Number', 'Average Probability',
                     f'Average Probability Assigned to Each Option Across Priors\n{model_name}',
                     n_priors, out_path('option_probs'))

    # 2. Mean of the per-response 0/1 change flags at each transition i -> i+1.
    avg_changes = _positional_mean(model_results['all_answer_changes'])
    if avg_changes.size:  # empty when no response has more than one prior
        plt.figure(figsize=(10, 6))
        plt.bar(range(1, len(avg_changes) + 1), avg_changes, color='skyblue')
        _finish_plot('Prior Transition', 'Probability of Answer Change',
                     f'Probability of Answer Change Between Consecutive Priors\n{model_name}',
                     len(avg_changes), out_path('answer_changes'), grid_axis='y')

    # 3. Average log odds of the correct answer; NaN/inf values are excluded.
    avg_log_odds = _positional_mean(model_results['all_log_odds'], keep=np.isfinite)
    if avg_log_odds.size:
        plt.figure(figsize=(10, 6))
        plt.plot(range(1, len(avg_log_odds) + 1), avg_log_odds, marker='o', color='purple')
        _finish_plot('Prior Number', 'Average Log Odds of Correct Answer',
                     f'Average Log Odds of Correct Answer Across Priors\n{model_name}',
                     len(avg_log_odds), out_path('log_odds'))

    # 4. Accuracy (share of responses whose argmax is correct), in percent.
    avg_accuracies = _positional_mean(model_results['all_accuracies'])
    if avg_accuracies.size:
        plt.figure(figsize=(10, 6))
        plt.plot(range(1, len(avg_accuracies) + 1), avg_accuracies * 100, marker='o', color='green')
        _finish_plot('Prior Number', 'Accuracy (%)',
                     f'Accuracy Across Priors\n{model_name}',
                     len(avg_accuracies), out_path('accuracy'))

    # 5. Average entropy (natural log) of each prior, as a measure of uncertainty.
    if all_priors:
        avg_entropy = _positional_mean([[_prior_entropy(prior) for prior in priors] for priors in all_priors])
        plt.figure(figsize=(10, 6))
        plt.plot(range(1, len(avg_entropy) + 1), avg_entropy, marker='o', color='brown')
        _finish_plot('Prior Number', 'Average Entropy',
                     f'Average Entropy (Uncertainty) Across Priors\n{model_name}',
                     len(avg_entropy), out_path('entropy'))

def _plot_model_comparison(all_model_results, metric_key, ylabel, title, path,
                           scale=1.0, finite_only=False):
    """Overlay one curve per model for *metric_key* and save the figure to *path*.

    Each curve averages ``result[metric_key]`` (a list of per-response sequences)
    position by position, optionally dropping non-finite values, and multiplies
    by *scale*. Positions with no values are NaN. Models without data are left
    out. If no model has data, nothing is saved.
    """
    plt.figure(figsize=(12, 8))
    plotted = False
    for result in all_model_results:
        sequences = result[metric_key]
        if not sequences:
            continue
        max_steps = max(len(seq) for seq in sequences)
        totals = np.zeros(max_steps)
        counts = np.zeros(max_steps)
        for seq in sequences:
            for i, value in enumerate(seq):
                if finite_only and not np.isfinite(value):
                    continue
                totals[i] += value
                counts[i] += 1
        means = np.divide(totals, counts, out=np.full(max_steps, np.nan), where=counts > 0)
        plt.plot(range(1, max_steps + 1), means * scale, marker='o', label=result['model_name'])
        plotted = True

    if plotted:
        plt.xlabel('Prior Number')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.grid(True)
        plt.savefig(path, dpi=300, bbox_inches='tight')
    plt.close()


def main(retrieval_dir="/fast/XXXX-3/forecasting/evals/manual/manifold_mcq/", output_dir='plots'):
    """Analyze every results JSON in *retrieval_dir* and write plots to *output_dir*.

    Per-model plots come from plot_model_priors. When more than one file is
    found, comparative log-odds and accuracy plots across models are added.

    Args:
        retrieval_dir: directory scanned (non-recursively) for ``*.json`` files.
        output_dir: directory for all PNGs; created if missing.
    """
    # Sorted so processing order and legend order are deterministic across runs.
    json_files = sorted(glob.glob(os.path.join(retrieval_dir, "*.json")))

    if not json_files:
        print(f"No JSON files found in {retrieval_dir}")
        return

    # The comparative plots write here directly, so create the directory up front
    # instead of relying on plot_model_priors having done it as a side effect.
    os.makedirs(output_dir, exist_ok=True)

    all_model_results = []
    for json_file in tqdm(json_files, desc="Analyzing models"):
        model_results = analyze_priors_for_model(json_file)
        all_model_results.append(model_results)
        plot_model_priors(model_results, output_dir)

    # Comparative plots only make sense with at least two models.
    if len(all_model_results) > 1:
        _plot_model_comparison(
            all_model_results, 'all_log_odds',
            'Average Log Odds of Correct Answer', 'Comparison of Log Odds Across Models',
            os.path.join(output_dir, 'comparative_log_odds.png'),
            finite_only=True,
        )
        _plot_model_comparison(
            all_model_results, 'all_accuracies',
            'Accuracy (%)', 'Comparison of Accuracy Across Models',
            os.path.join(output_dir, 'comparative_accuracy.png'),
            scale=100,
        )

    print(f"Analysis complete. Plots saved to the '{output_dir}' directory.")

# Run the full analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main() 