

import os
import json
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import torch

# Global plot default: base font size (in points) applied to every figure
# this script creates.
mpl.rcParams['font.size'] = 14
# Read the setting back and echo it so the log records the value in effect.
default_font_size = plt.rcParams['font.size']
print(f"Font size: {default_font_size}")

# Resolution, in dots per inch, used when the figure is written to disk.
DPI = 400


def load_results(json_file):
    """Parse the JSON document stored at ``json_file`` and return it.

    The file is decoded as UTF-8 explicitly so the result does not depend on
    the platform's default text encoding.

    Args:
        json_file: Path of the JSON file to read.

    Returns:
        The deserialized JSON content (whatever ``json.load`` produces).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(json_file, 'r', encoding='utf-8') as f:
        return json.load(f)


def create_boxplot(results, output_dir="vis_ab_output"):
    """Draw paired box plots of random vs. stratified losses and save a PNG.

    For every sample ratio two boxes are drawn side by side: the random
    sampling losses (blue, left) and the stratified sampling losses (red,
    right).

    Args:
        results: Parsed results mapping. The keys read here are
            ``"sample_ratios"`` (sequence of numeric ratios), ``"dataset"``,
            ``"model"`` (string; only the part after the last ``/`` appears
            in the title) and ``"results"``, which is indexed as
            ``results["results"][str(ratio)][method]["losses"]`` for
            ``method`` in ``"random"`` and ``"stratified"``.
        output_dir: Directory that receives the image; created if missing.

    Returns:
        Path of the written file, ``<output_dir>/loss_variance_boxplot.png``.

    Raises:
        KeyError: If one of the keys listed above is missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    sample_ratios = results["sample_ratios"]
    # Result entries are keyed by the string form of each ratio.
    per_ratio = [results["results"][str(ratio)] for ratio in sample_ratios]
    random_data = [entry["random"]["losses"] for entry in per_ratio]
    stratified_data = [entry["stratified"]["losses"] for entry in per_ratio]

    output_file = os.path.join(output_dir, "loss_variance_boxplot.png")

    fig, ax = plt.subplots(figsize=(12, 6))
    try:
        x_positions = np.arange(len(sample_ratios))
        width = 0.35

        # Shift each series by half a box width so the pair straddles its tick.
        random_boxes = ax.boxplot(
            random_data,
            positions=x_positions - width / 2,
            patch_artist=True,
            widths=width,
            boxprops=dict(facecolor='lightblue', alpha=0.7),
            medianprops=dict(color='blue', linewidth=2),
            flierprops=dict(marker='o', markerfacecolor='blue', markersize=4),
        )
        stratified_boxes = ax.boxplot(
            stratified_data,
            positions=x_positions + width / 2,
            patch_artist=True,
            widths=width,
            boxprops=dict(facecolor='lightcoral', alpha=0.7),
            medianprops=dict(color='red', linewidth=2),
            flierprops=dict(marker='o', markerfacecolor='red', markersize=4),
        )

        ax.set_xlabel("Sample Ratio")
        ax.set_ylabel("Average Loss")
        ax.set_title(
            f"Loss Variance Analysis: {results['dataset']} - {results['model'].split('/')[-1]}"
        )

        ax.set_xticks(x_positions)
        ax.set_xticklabels([f"{ratio:.1f}" for ratio in sample_ratios])

        # One representative box patch per series serves as the legend handle.
        ax.legend(
            [random_boxes["boxes"][0], stratified_boxes["boxes"][0]],
            ['Random', 'Stratified'],
            loc='upper right',
        )

        ax.grid(True, linestyle="--", alpha=0.3)

        fig.tight_layout()
        fig.savefig(output_file, dpi=DPI, bbox_inches="tight")
    finally:
        # Close this specific figure even if plotting or saving failed, so a
        # failure does not leave an open figure registered with pyplot.
        plt.close(fig)

    print(f"Saved boxplot to: {output_file}")
    return output_file


def _variance_ratio_text(random_std, stratified_std):
    """Format Var(random) / Var(stratified), tolerating a zero denominator."""
    random_var = random_std ** 2
    stratified_var = stratified_std ** 2
    if stratified_var == 0:
        # The ratio is undefined (0/0) or unbounded (x/0); report that instead
        # of raising ZeroDivisionError.
        return "n/a" if random_var == 0 else "inf"
    return f"{random_var / stratified_var:.2f}"


def main():
    """Load the ablation results, render the box plot, and print statistics.

    When a torch.distributed process group is initialized, only rank 0 does
    any work; every other rank returns immediately. If the results file
    ``ablation/stratified_loss_results.json`` does not exist, an error is
    printed and the function returns without plotting.

    Returns:
        None.
    """
    dist = torch.distributed
    # is_available() must be checked first: builds without distributed
    # support do not expose the remaining torch.distributed APIs.
    if dist.is_available() and dist.is_initialized() and dist.get_rank() != 0:
        return

    print("Starting Loss Variance Visualization...")

    json_file = "ablation/stratified_loss_results.json"

    if not os.path.exists(json_file):
        print(f"Error: {json_file} not found!")
        print("Please generate the results file before visualization.")
        return

    results = load_results(json_file)
    print(f"Loaded results for {results['dataset']} - {results['model']}")
    print(f"Sample ratios: {results['sample_ratios']}")
    print(f"Number of trials: {results['num_trials']}")

    print("\nCreating boxplot...")
    boxplot_file = create_boxplot(results)

    print(f"\nBoxplot saved to: {boxplot_file}")

    print("\nBrief Statistics:")
    for ratio in results["sample_ratios"]:
        # Result entries are keyed by the string form of each ratio.
        entry = results["results"][str(ratio)]
        random_data = entry["random"]
        stratified_data = entry["stratified"]
        ratio_text = _variance_ratio_text(random_data['std'], stratified_data['std'])

        print(f"  Ratio {ratio:.1f}:")
        print(f"    Random - Mean: {random_data['mean']:.4f}, Std: {random_data['std']:.4f}")
        print(f"    Stratified - Mean: {stratified_data['mean']:.4f}, Std: {stratified_data['std']:.4f}")
        print(f"    Variance ratio (Random/Stratified): {ratio_text}")


# Run the visualization only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
