import argparse
import pickle
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from glob import glob

def parse_args():
    """Parse the command-line options of the analysis script.

    Returns:
        argparse.Namespace with two attributes:
        ``input_dir`` (required) - directory holding the ``.pkl`` result files;
        ``output_file`` - base filename for the generated outputs, defaulting
        to ``"analysis_results"``.
    """
    cli = argparse.ArgumentParser(
        description="Analyze Qwen Gradient Evaluation Results",
    )

    # Each entry is (flag, keyword arguments forwarded to add_argument).
    option_specs = (
        (
            "--input_dir",
            {
                "type": str,
                "required": True,
                "help": "Directory containing .pkl result files",
            },
        ),
        (
            "--output_file",
            {
                "type": str,
                "default": "analysis_results",
                "help": "Base filename for output (pdf/txt)",
            },
        ),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    return cli.parse_args()

def _load_relation_sims(fpath):
    """Return the list of ``cosine_sim`` values stored in one result file.

    The file is expected to unpickle to an iterable of dict-like records,
    each carrying a ``'cosine_sim'`` entry.
    """
    # NOTE(security): pickle.load can execute arbitrary code. Only point this
    # script at result files produced by a trusted pipeline.
    with open(fpath, 'rb') as f:
        data = pickle.load(f)
    return [d['cosine_sim'] for d in data]


def _write_report(txt_path, all_sims, global_mean, global_std, rel_stats):
    """Write the global and per-relation statistics as a plain-text report."""
    with open(txt_path, 'w', encoding='utf-8') as f:
        f.write("Gradient Cosine Similarity Analysis\n")
        f.write("===================================\n")
        f.write(f"Total Samples: {len(all_sims)}\n")
        f.write(f"Global Mean Cosine Sim: {global_mean:.4f}\n")
        f.write(f"Global Std Dev: {global_std:.4f}\n\n")

        f.write("Per Relation Stats:\n")
        for rel, stats in rel_stats.items():
            f.write(f"{rel}: Mean={stats['mean']:.4f}, Std={stats['std']:.4f}, N={stats['count']}\n")


def _plot_histogram(all_sims, global_mean, pdf_path):
    """Save a histogram of all cosine similarities to ``pdf_path``."""
    fig = plt.figure(figsize=(6, 1))
    # 40 bins of width 0.05 spanning [-1, 1]. linspace pins both end points
    # exactly, whereas arange with a float step can overshoot or drop the
    # final edge because of rounding.
    bins = np.linspace(-1.0, 1.0, 41)
    sns.histplot(all_sims, bins=bins, kde=False, color='skyblue', stat='count')
    plt.axvline(global_mean, color='r', linestyle='--', label=f'Mean: {global_mean:.2f}')
    plt.axvline(-1.0, color='g', linestyle=':', label='Hypothesis: -1.0')

    plt.xlabel("Cosine Similarity")
    plt.ylabel("Count")
    plt.legend(loc='upper left')
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    plt.savefig(pdf_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def _plot_per_relation(rel_stats, global_mean, pdf_path):
    """Save a mean +/- std error-bar plot per relation, sorted by mean."""
    # sorted() is stable, so relations with equal means keep their input order.
    ordered = sorted(rel_stats.items(), key=lambda item: item[1]['mean'])
    rels = [rel for rel, _ in ordered]
    means = [stats['mean'] for _, stats in ordered]
    stds = [stats['std'] for _, stats in ordered]

    fig = plt.figure(figsize=(8, 5))
    plt.errorbar(rels, means, yerr=stds, fmt='o', capsize=5)
    plt.axhline(-1.0, color='g', linestyle=':', label='Hypothesis: -1.0')
    plt.axhline(global_mean, color='r', linestyle='--', label='Global Mean')
    plt.xticks(rotation=90)
    plt.ylabel("Mean Cosine Similarity")
    # Without this call the labels given to axhline above are never rendered.
    plt.legend()
    plt.tight_layout()

    plt.savefig(pdf_path)
    plt.close(fig)


def main():
    """Aggregate gradient cosine similarities and emit a report and plots.

    Reads every ``*_grads.pkl`` file in ``--input_dir``, computes the mean and
    standard deviation of the ``cosine_sim`` values per file (one file per
    relation, named after the filename prefix) and globally, then writes:

    * ``<output_file>.txt``              - text report,
    * ``<output_file>.pdf``              - histogram of all similarities,
    * ``<output_file>_per_relation.pdf`` - per-relation mean/std plot.

    Returns early (after printing a message) when no files match or when no
    file contains any records.
    """
    args = parse_args()

    # Sort so that the report order does not depend on filesystem listing order.
    files = sorted(glob(os.path.join(args.input_dir, "*_grads.pkl")))
    if not files:
        print(f"No result files found in {args.input_dir}")
        return

    all_sims = []
    rel_stats = {}

    print(f"Found {len(files)} result files.")

    for fpath in files:
        rel_name = os.path.basename(fpath).replace("_grads.pkl", "")
        sims = _load_relation_sims(fpath)
        if not sims:
            # An empty file would only yield NaN statistics; skip it.
            continue

        all_sims.extend(sims)
        rel_stats[rel_name] = {
            "mean": np.mean(sims),
            "std": np.std(sims),
            "count": len(sims)
        }

    if not all_sims:
        print("No valid data found.")
        return

    # Global stats over every sample from every relation.
    global_mean = np.mean(all_sims)
    global_std = np.std(all_sims)

    # Allow --output_file to point into a directory that does not exist yet.
    out_dir = os.path.dirname(args.output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    txt_path = args.output_file + ".txt"
    _write_report(txt_path, all_sims, global_mean, global_std, rel_stats)
    print(f"Saved text report to {txt_path}")

    pdf_path = args.output_file + ".pdf"
    _plot_histogram(all_sims, global_mean, pdf_path)
    print(f"Saved plot to {pdf_path}")

    per_relation_path = args.output_file + "_per_relation.pdf"
    _plot_per_relation(rel_stats, global_mean, per_relation_path)
    print(f"Saved plot to {per_relation_path}")

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()