import numpy as np
import matplotlib.pyplot as plt
import os

# Legend labels for the known client names; unknown names are shown as-is.
_CLIENT_LABELS = {
    "Client_1_Normal": "Normal",
    "Client_2_FGSM": "FGSM",
    "Client_3_PGD": "PGD",
    "Client_4_BasicIterative": "BasicIterative",
    "Client_5_DeepFool": "DeepFool",
}


def _min_max_normalize_per_round(id_history, num_rounds):
    """Rescale every round's values across clients with min-max normalization.

    For round ``i`` the minimum and maximum are taken over all clients' ``i``-th
    value; each value is mapped to ``(val - min) / (max - min)``. When all
    clients share the same value in a round, ``0.5`` is used for every client.

    Returns a dict with the same keys (and key order) as ``id_history`` whose
    values are lists of ``num_rounds`` normalized numbers.
    """
    normalized = {client: [] for client in id_history}
    if not id_history:
        # min()/max() of an empty sequence would raise; nothing to normalize.
        return normalized

    for i in range(num_rounds):
        round_values = [ids[i] for ids in id_history.values()]
        min_val, max_val = min(round_values), max(round_values)
        for client, ids in id_history.items():
            val = ids[i]
            normalized[client].append(
                0.5 if max_val == min_val else (val - min_val) / (max_val - min_val)
            )
    return normalized


def _save_client_line_plot(series_by_client, rounds, *, marker, linestyle, ylabel, title, path):
    """Draw one line per client over ``rounds`` and save the figure to ``path``."""
    fig = plt.figure(figsize=(10, 6))
    try:
        for client_name, ids in series_by_client.items():
            label = _CLIENT_LABELS.get(client_name, client_name)
            plt.plot(rounds, ids, marker=marker, linestyle=linestyle, label=label)
        plt.xlabel('Round')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend(title="Client Type")
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        plt.savefig(path, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)


def plot_and_save_batch_results(id_history, dataset_name, num_rounds, output_dir='results/exp1'):
    """Plot per-client ID curves per round, both raw and min-max normalized.

    Two PNG files are written into ``output_dir`` (created if missing):

    * ``id_per_client_<dataset>.png`` - the values of ``id_history`` as given.
    * ``id_per_client_<dataset>_min-max.png`` - the same values rescaled per
      round across clients (see ``_min_max_normalize_per_round``).

    ``<dataset>`` is ``dataset_name`` lower-cased; the plot titles use it
    upper-cased. The path of each saved file is printed.

    Args:
        id_history: Mapping of client name to a sequence of ``num_rounds``
            numeric values (one per round). Known client names are shown with
            a short legend label; others are shown unchanged.
        dataset_name: Dataset identifier used in titles and file names.
        num_rounds: Number of rounds; the x-axis runs from 1 to ``num_rounds``.
        output_dir: Directory receiving the PNG files.

    Raises:
        ValueError: If a client's series does not hold exactly ``num_rounds``
            values. This is checked before anything is written.
    """
    # Validate up front so a bad series fails with a clear message and without
    # leaving a directory or a half-finished set of plots behind.
    for client_name, ids in id_history.items():
        if len(ids) != num_rounds:
            raise ValueError(
                f"id_history[{client_name!r}] has {len(ids)} values, "
                f"expected num_rounds={num_rounds}"
            )

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    rounds = np.arange(1, num_rounds + 1)

    raw_path = os.path.join(output_dir, f"id_per_client_{dataset_name.lower()}.png")
    _save_client_line_plot(
        id_history,
        rounds,
        marker='o',
        linestyle='-',
        ylabel='Intrinsic Dimension',
        title=f'ID of Gradients per Round ({dataset_name.upper()})',
        path=raw_path,
    )
    print(f"Raw ID plot saved to {raw_path}")

    normalized_id_history = _min_max_normalize_per_round(id_history, num_rounds)

    norm_path = os.path.join(output_dir, f"id_per_client_{dataset_name.lower()}_min-max.png")
    _save_client_line_plot(
        normalized_id_history,
        rounds,
        marker='s',
        linestyle='--',
        ylabel='Normalized Intrinsic Dimension',
        title=f'Min-Max Normalized ID per Round ({dataset_name.upper()})',
        path=norm_path,
    )
    print(f"Normalized ID plot saved to {norm_path}")

def plot_individual_comparison(results, output_dir='results/exp2a', *, dataset_name='SVHN'):
    """Save one normal-vs-adversarial ID line plot per attack.

    For every entry of ``results`` a PNG named
    ``id_per_sample_<dataset>_<attack_name>.png`` is written into
    ``output_dir`` (created if missing) and its path is printed. ``<dataset>``
    is ``dataset_name`` lower-cased; the title shows it upper-cased.

    Args:
        results: Mapping of attack name to a dict with the keys ``'norm_ids'``
            and ``'adv_ids'``, each a sequence of numeric values plotted
            against its own sample index (0, 1, 2, ...).
        output_dir: Directory receiving the PNG files.
        dataset_name: Dataset identifier used in the title and file name.
            The default reproduces the previously hard-coded "SVHN" naming.

    Raises:
        KeyError: If an entry lacks ``'norm_ids'`` or ``'adv_ids'``.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    dataset_tag = dataset_name.lower()
    dataset_title = dataset_name.upper()

    for attack_name, data in results.items():
        norm_ids, adv_ids = data['norm_ids'], data['adv_ids']

        fig = plt.figure(figsize=(12, 5))
        try:
            # Each series gets its own index so that differing lengths do not
            # make the second plot call fail on an x/y size mismatch.
            plt.plot(np.arange(len(norm_ids)), norm_ids,
                     marker='o', linestyle='-', label='Normal Data ID')
            plt.plot(np.arange(len(adv_ids)), adv_ids,
                     marker='x', linestyle='--', label='Adversarial Data ID')
            plt.xlabel('Sample Index')
            plt.ylabel('Intrinsic Dimension (ID)')
            plt.title(f'Incremental ID Comparison - {attack_name} ({dataset_title})')
            plt.legend()
            plt.tight_layout()

            output_path = os.path.join(
                output_dir, f"id_per_sample_{dataset_tag}_{attack_name}.png"
            )
            plt.savefig(output_path, dpi=300)
        finally:
            # Always release the figure, even if plotting or saving failed.
            plt.close(fig)
        print(f"Plot saved to {output_path}")

def plot_evaluation_histogram(results, bounds, output_dir='results/exp2b', *, dataset_name='SVHN'):
    """Save overlaid histograms of ID values per group with two threshold lines.

    One 20-bin, semi-transparent histogram is drawn per entry of ``results``,
    plus a red dashed vertical line at each bound. The figure is written to
    ``evaluation_<dataset>_histogram.png`` inside ``output_dir`` (created if
    missing) and its path is printed. ``<dataset>`` is ``dataset_name``
    lower-cased; the title shows it upper-cased.

    Args:
        results: Mapping of group name (used as legend label) to a sequence of
            numeric values to histogram.
        bounds: Two-element iterable ``(lower_bound, upper_bound)`` giving the
            x positions of the threshold lines.
        output_dir: Directory receiving the PNG file.
        dataset_name: Dataset identifier used in the title and file name.
            The default reproduces the previously hard-coded "SVHN" naming.

    Raises:
        ValueError: If ``bounds`` does not contain exactly two items.
    """
    # Unpack first so malformed bounds fail before any directory is created.
    lower_bound, upper_bound = bounds

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_dir, exist_ok=True)

    fig = plt.figure(figsize=(12, 8))
    try:
        for group_name, ids in results.items():
            plt.hist(ids, bins=20, alpha=0.6, label=group_name)

        # Only the first line carries a label so the legend shows one entry.
        plt.axvline(lower_bound, color='red', linestyle='--', label="Threshold Bounds")
        plt.axvline(upper_bound, color='red', linestyle='--')
        plt.xlabel("Incremental Intrinsic Dimension (ID)")
        plt.ylabel("Frequency")
        plt.title(
            f"Distribution of Incremental IDs for Evaluation Groups on {dataset_name.upper()}"
        )
        plt.legend()

        output_path = os.path.join(
            output_dir, f"evaluation_{dataset_name.lower()}_histogram.png"
        )
        plt.savefig(output_path, dpi=300)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    print(f"Histogram plot saved to {output_path}")