import json
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import argparse

def load_results(filepath):
    """Load the JSON results file.

    Args:
        filepath: Path of the JSON file to read (anything ``open`` accepts).

    Returns:
        The decoded JSON document, exactly as produced by ``json.load``.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # JSON text is UTF-8; do not depend on the platform's locale encoding.
    with open(filepath, 'r', encoding='utf-8') as f:
        return json.load(f)

def _plot_group_histograms(values, groups, xlabel, title, output_path):
    """Save overlaid density histograms of ``values``, one per sample group.

    Args:
        values: 1-D float array of per-sample values.
        groups: Iterable of ``(label, mask, color)`` tuples, where ``mask`` is
            a boolean array aligned with ``values``.
        xlabel: Label for the x axis.
        title: Figure title.
        output_path: Where the figure is saved.
    """
    # Histogram binning needs a finite value range, so NaN/inf samples are
    # excluded up front instead of letting plt.hist fail on them.
    finite = np.isfinite(values)

    plt.figure(figsize=(10, 6))
    drew_any = False
    for label, mask, color in groups:
        subset = values[mask & finite]
        if subset.size == 0:
            # Nothing to bin; a density histogram of no samples is undefined.
            continue
        # Explicit colours keep each group's colour stable even when the
        # other group is skipped.
        plt.hist(subset, bins=50, alpha=0.5, label=label, density=True,
                 color=color)
        drew_any = True
    plt.xlabel(xlabel)
    plt.ylabel('Density')
    plt.title(title)
    if drew_any:
        plt.legend()
    plt.savefig(output_path)
    plt.close()

def plot_lambda_distributions(results, output_dir):
    """Plot distributions of lambda values for robust vs vulnerable samples.

    Reads ``lambda_values_clean``, ``lambda_values_adv``, ``is_robust`` and
    ``is_vulnerable`` from ``results['detailed_results']`` and writes
    ``lambda_distribution_clean.png`` and ``lambda_distribution_adv.png``
    into ``output_dir``.

    Args:
        results: Parsed results dictionary.
        output_dir: Directory object supporting the ``/`` operator
            (e.g. ``pathlib.Path``) in which the figures are saved.

    Raises:
        KeyError: If one of the required entries is missing from ``results``.
    """
    detailed = results['detailed_results']
    lambda_vals_clean = np.asarray(detailed['lambda_values_clean'], dtype=float)
    lambda_vals_adv = np.asarray(detailed['lambda_values_adv'], dtype=float)

    # Force boolean masks: if the JSON stored 0/1 integers, indexing with them
    # would select elements 0 and 1 repeatedly instead of filtering.
    groups = [
        ('Robust', np.asarray(detailed['is_robust'], dtype=bool), 'C0'),
        ('Vulnerable', np.asarray(detailed['is_vulnerable'], dtype=bool), 'C1'),
    ]

    _plot_group_histograms(
        lambda_vals_clean,
        groups,
        'Lambda Values (Clean)',
        'Distribution of Clean Lambda Values: Robust vs Vulnerable Samples',
        output_dir / 'lambda_distribution_clean.png',
    )
    _plot_group_histograms(
        lambda_vals_adv,
        groups,
        'Lambda Values (Adversarial)',
        'Distribution of Adversarial Lambda Values: Robust vs Vulnerable Samples',
        output_dir / 'lambda_distribution_adv.png',
    )

def plot_class_statistics(results, output_dir):
    """Plot class-wise lambda values and robustness.

    Draws one scatter series per lambda type (clean / adversarial) of the
    per-class mean lambda against the per-class robust accuracy, adds a
    least-squares trend line for each series, annotates both Pearson
    correlations, and saves ``class_lambda_vs_robustness.png`` in
    ``output_dir``.

    Args:
        results: Parsed results dictionary; ``results['class_stats']`` maps a
            class id to a dict with ``mean_lambda_clean``, ``mean_lambda_adv``
            and ``robust_accuracy``.
        output_dir: Directory object supporting the ``/`` operator
            (e.g. ``pathlib.Path``) in which the figure is saved.

    Raises:
        KeyError: If a required entry is missing from ``results``.
    """
    per_class = list(results['class_stats'].values())
    robust_accs = np.asarray(
        [stats['robust_accuracy'] for stats in per_class], dtype=float)

    # (legend label, annotation label, stats key, colour). The trend line
    # reuses the scatter colour so each line is visually tied to its points.
    series = (
        ('Clean', 'Clean', 'mean_lambda_clean', 'C0'),
        ('Adversarial', 'Adv', 'mean_lambda_adv', 'C1'),
    )

    plt.figure(figsize=(10, 6))
    corr_lines = []
    for label, short_label, key, color in series:
        mean_lambdas = np.asarray([stats[key] for stats in per_class],
                                  dtype=float)
        plt.scatter(mean_lambdas, robust_accs, alpha=0.5, label=label,
                    color=color)

        # Fit only on finite pairs; a line fit needs at least two distinct x.
        finite = np.isfinite(mean_lambdas) & np.isfinite(robust_accs)
        x, y = mean_lambdas[finite], robust_accs[finite]
        corr = float('nan')
        if x.size >= 2 and x.min() != x.max():
            slope, intercept = np.polyfit(x, y, 1)
            x_ends = np.array([x.min(), x.max()])
            plt.plot(x_ends, slope * x_ends + intercept, '--', color=color,
                     alpha=0.8)
            # Constant accuracies give an undefined correlation (NaN);
            # silence the floating-point warning and report NaN.
            with np.errstate(divide='ignore', invalid='ignore'):
                corr = float(np.corrcoef(x, y)[0, 1])
        corr_lines.append(f'{short_label} Correlation: {corr:.3f}')

    plt.xlabel('Mean Lambda Value')
    plt.ylabel('Robust Accuracy')
    plt.title('Class-wise Mean Lambda vs Robust Accuracy')
    plt.legend()

    # Anchor the two-line annotation by its top edge so it grows downward
    # into the axes instead of spilling over the top border.
    plt.text(0.05, 0.95, '\n'.join(corr_lines),
             transform=plt.gca().transAxes, verticalalignment='top')

    plt.savefig(output_dir / 'class_lambda_vs_robustness.png')
    plt.close()

def analyze_embedding_similarity(results, output_dir):
    """Analyze relationship between embedding similarity and robustness.

    Produces two figures in ``output_dir``:

    * ``embedding_similarity.png`` - density histograms of the embedding
      similarity for robust and non-robust samples.
    * ``similarity_vs_lambda_change.png`` - scatter of embedding similarity
      against the lambda change (adversarial minus clean), with a linear
      trend line and Pearson correlation when they can be computed.

    Non-finite values (NaN/inf) are excluded from both figures.

    Args:
        results: Parsed results dictionary; reads ``embedding_similarity``,
            ``is_robust``, ``lambda_values_clean`` and ``lambda_values_adv``
            from ``results['detailed_results']``.
        output_dir: Directory object supporting the ``/`` operator
            (e.g. ``pathlib.Path``) in which the figures are saved.

    Raises:
        KeyError: If a required entry is missing from ``results``.
    """
    detailed = results['detailed_results']
    similarities = np.asarray(detailed['embedding_similarity'], dtype=float)
    # Force a boolean mask: ``~`` on 0/1 integers is a bitwise NOT (-1/-2),
    # which would then be used as indices instead of a filter.
    is_robust = np.asarray(detailed['is_robust'], dtype=bool)
    lambda_vals_clean = np.asarray(detailed['lambda_values_clean'], dtype=float)
    lambda_vals_adv = np.asarray(detailed['lambda_values_adv'], dtype=float)

    finite_sim = np.isfinite(similarities)

    # Plot embedding similarity distributions
    plt.figure(figsize=(10, 6))
    drew_any = False
    for label, mask, color in (('Robust', is_robust, 'C0'),
                               ('Non-robust', ~is_robust, 'C1')):
        # Histogram binning needs finite values and at least one sample.
        subset = similarities[mask & finite_sim]
        if subset.size == 0:
            continue
        plt.hist(subset, bins=50, alpha=0.5, label=label, density=True,
                 color=color)
        drew_any = True
    plt.xlabel('Embedding Similarity')
    plt.ylabel('Density')
    plt.title('Distribution of Embedding Similarities: Robust vs Non-robust')
    if drew_any:
        plt.legend()
    plt.savefig(output_dir / 'embedding_similarity.png')
    plt.close()

    # Plot lambda change vs embedding similarity
    lambda_change = lambda_vals_adv - lambda_vals_clean

    # Keep only pairs where both coordinates are finite
    valid = finite_sim & np.isfinite(lambda_change)
    sim_valid = similarities[valid]
    change_valid = lambda_change[valid]

    plt.figure(figsize=(10, 6))
    plt.scatter(sim_valid, change_valid, alpha=0.5)
    plt.xlabel('Embedding Similarity')
    plt.ylabel('Lambda Change (Adv - Clean)')
    plt.title('Embedding Similarity vs Lambda Change')

    # A line fit needs at least two points with distinct x values.
    if sim_valid.size >= 2 and sim_valid.min() != sim_valid.max():
        try:
            z = np.polyfit(sim_valid, change_valid, 1)
            # A constant lambda change gives an undefined correlation (NaN).
            with np.errstate(divide='ignore', invalid='ignore'):
                corr = np.corrcoef(sim_valid, change_valid)[0, 1]
        except (np.linalg.LinAlgError, ValueError) as e:
            # Best effort: the scatter plot is still saved without the fit.
            print(f"Warning: Could not compute trend line: {str(e)}")
        else:
            p = np.poly1d(z)
            x_range = np.linspace(np.min(sim_valid), np.max(sim_valid), 100)
            plt.plot(x_range, p(x_range), "r--", alpha=0.8)
            if not np.isnan(corr):
                plt.text(0.05, 0.95, f'Correlation: {corr:.3f}',
                         transform=plt.gca().transAxes)

    plt.savefig(output_dir / 'similarity_vs_lambda_change.png')
    plt.close()

def print_summary_statistics(results):
    """Write the headline metrics in ``results['summary_stats']`` to stdout.

    Output is a blank line, a title, a 50-dash rule, then one line per metric:
    mean/std lambda pairs, subgroup lambda means, the two accuracies
    (multiplied by 100 and shown with a percent sign), and the mean embedding
    similarity.

    Args:
        results: Parsed results dictionary containing ``summary_stats``.

    Raises:
        KeyError: If ``summary_stats`` or one of the printed metrics is absent.
    """
    summary = results['summary_stats']

    print("\nSummary Statistics:")
    print("-" * 50)

    # Rows printed as "mean ± std".
    spread_rows = (
        ("Clean Lambda", 'mean_lambda_clean', 'std_lambda_clean'),
        ("Adversarial Lambda", 'mean_lambda_adv', 'std_lambda_adv'),
    )
    for heading, mean_key, std_key in spread_rows:
        print(f"{heading}: {summary[mean_key]:.4f} ± {summary[std_key]:.4f}")

    # Rows printed as a single four-decimal value.
    mean_rows = (
        ("Clean Lambda (Robust)", 'mean_lambda_robust'),
        ("Clean Lambda (Vulnerable)", 'mean_lambda_vulnerable'),
        ("Adversarial Lambda (Correct)", 'mean_lambda_adv_correct'),
        ("Adversarial Lambda (Incorrect)", 'mean_lambda_adv_incorrect'),
    )
    for heading, key in mean_rows:
        print(f"{heading}: {summary[key]:.4f}")

    # Accuracies are scaled by 100 for display.
    accuracy_rows = (
        ("Clean Accuracy", 'clean_accuracy'),
        ("Robust Accuracy", 'robust_accuracy'),
    )
    for heading, key in accuracy_rows:
        print(f"{heading}: {summary[key]*100:.2f}%")

    print(f"Mean Embedding Similarity: {summary['mean_embedding_similarity']:.4f}")

def _mean_ratio(numerator_values, denominator_values):
    """Return mean(numerator_values) / mean(denominator_values), NaN if a group is empty."""
    if numerator_values.size == 0 or denominator_values.size == 0:
        return float('nan')
    # A zero denominator yields inf/NaN, as plain numpy division would;
    # only the floating-point warning is suppressed.
    with np.errstate(divide='ignore', invalid='ignore'):
        return float(np.mean(numerator_values) / np.mean(denominator_values))

def _pearson(x, y):
    """Return the Pearson correlation of two equal-length arrays, NaN if undefined."""
    if x.size < 2:
        return float('nan')
    # A constant input has zero variance; numpy reports NaN for that case.
    with np.errstate(divide='ignore', invalid='ignore'):
        return float(np.corrcoef(x, y)[0, 1])

def _compute_numerical_analysis(results):
    """Compute the lambda ratios and correlations written to numerical_analysis.json."""
    detailed = results['detailed_results']
    lambda_vals_clean = np.asarray(detailed['lambda_values_clean'], dtype=float)
    lambda_vals_adv = np.asarray(detailed['lambda_values_adv'], dtype=float)
    # Force a boolean mask: ``~`` on 0/1 integers is a bitwise NOT (-1/-2),
    # which would then be used as indices instead of a filter.
    is_robust = np.asarray(detailed['is_robust'], dtype=bool)

    # A single NaN/inf lambda would turn every mean and correlation into NaN,
    # so samples with a non-finite clean or adversarial lambda are left out.
    finite = np.isfinite(lambda_vals_clean) & np.isfinite(lambda_vals_adv)
    dropped = int(finite.size - np.count_nonzero(finite))
    if dropped:
        print(f"Warning: ignoring {dropped} sample(s) with non-finite lambda values")
    lambda_vals_clean = lambda_vals_clean[finite]
    lambda_vals_adv = lambda_vals_adv[finite]
    is_robust = is_robust[finite]

    # NOTE(review): the "vulnerable" denominator is every non-robust sample
    # (~is_robust), whereas plot_lambda_distributions uses the separate
    # 'is_vulnerable' mask - confirm which definition is intended here.
    not_robust = ~is_robust
    robust_as_float = is_robust.astype(float)

    return {
        'lambda_ratios': {
            'clean_robust_vulnerable': _mean_ratio(lambda_vals_clean[is_robust],
                                                   lambda_vals_clean[not_robust]),
            'adv_robust_vulnerable': _mean_ratio(lambda_vals_adv[is_robust],
                                                 lambda_vals_adv[not_robust]),
            'adv_clean_ratio': _mean_ratio(lambda_vals_adv, lambda_vals_clean)
        },
        'correlations': {
            'clean_lambda_robustness': _pearson(lambda_vals_clean, robust_as_float),
            'adv_lambda_robustness': _pearson(lambda_vals_adv, robust_as_float),
            'clean_adv_lambda': _pearson(lambda_vals_clean, lambda_vals_adv)
        }
    }

def main():
    """Command-line entry point: plot, print and save the lambda analysis.

    Command-line arguments:
        --input: Path to the JSON results file (required).
        --output_dir: Directory for the figures and ``numerical_analysis.json``
            (default ``lambda_analysis``); created if it does not exist.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--input', type=str, required=True,
                       help='Path to the JSON results file from calc_lam_vals.py')
    parser.add_argument('--output_dir', type=str, default='lambda_analysis',
                       help='Directory to save analysis outputs')
    args = parser.parse_args()

    # Create output directory
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Load and analyze results
    results = load_results(args.input)

    # Generate plots and analysis
    plot_lambda_distributions(results, output_dir)
    plot_class_statistics(results, output_dir)
    analyze_embedding_similarity(results, output_dir)
    print_summary_statistics(results)

    # Save numerical analysis results
    analysis_results = _compute_numerical_analysis(results)
    with open(output_dir / 'numerical_analysis.json', 'w', encoding='utf-8') as f:
        json.dump(analysis_results, f, indent=2)

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()