import argparse
import torch

def analyze_fisher(file_path):
    """Print per-parameter and global statistics for a saved Fisher information file.

    The file is expected to be a ``torch.save``-d dict holding a
    ``"fisher_information"`` mapping (parameter name -> tensor) and an optional
    ``"args"`` dict of run settings, which is only used for the printed header.

    Args:
        file_path: Path to the saved ``.pt`` file.

    Returns:
        dict with the printed global statistics: ``total_params`` (int),
        ``nonzero`` (int), ``mean``, ``std``, ``min`` and ``max`` (floats).
        When the file holds no tensor elements at all, the float statistics
        are ``nan``.

    Raises:
        KeyError: if the file has no ``"fisher_information"`` entry.
    """
    # NOTE(review): torch.load unpickles the file. Only analyze files from a
    # trusted source (or pass weights_only=True on torch versions supporting it).
    data = torch.load(file_path, map_location="cpu")
    fisher_dict = data["fisher_information"]
    # Run settings are informational only; missing ones must not abort the analysis.
    args = data.get("args") or {}

    print("="*60)
    print("📂 Analyzing Fisher Information File")
    print(f"Model:           {args.get('model', 'N/A')}")
    print(f"Dataset:         {args.get('dataset', 'N/A')} (subset: {args.get('subset', 'N/A')})")
    print(f"Samples:         {args.get('num_samples', 'N/A')}")
    print(f"Max length:      {args.get('max_len', 'N/A')}")
    print(f"Max new tokens:  {args.get('max_new_tokens', 'N/A')}")
    print("="*60)

    print(f"Loaded Fisher information for {len(fisher_dict)} parameters.\n")

    # Global accumulation
    total_sum = 0.0
    total_sq_sum = 0.0
    total_nonzero = 0
    total_elems = 0
    global_max = float("-inf")
    global_min = float("inf")

    for name, fisher in fisher_dict.items():
        # Accumulate in float64: tensors may be stored in fp16/bf16, where sums and
        # squares overflow, and Fisher values are often tiny enough to underflow.
        fisher = fisher.detach().to(device="cpu", dtype=torch.float64)
        elems = fisher.numel()
        total_elems += elems

        if elems == 0:
            # max()/min() raise on zero-element tensors; there is nothing to add.
            print(f"[{name}] Shape: {tuple(fisher.shape)} | Layer mean=nan")
            continue

        layer_sum = fisher.sum().item()
        total_sum += layer_sum
        total_sq_sum += (fisher ** 2).sum().item()
        total_nonzero += int(torch.count_nonzero(fisher).item())
        global_max = max(global_max, fisher.max().item())
        global_min = min(global_min, fisher.min().item())

        print(f"[{name}] Shape: {tuple(fisher.shape)} | Layer mean={layer_sum / elems:.6e}")

    print("\n" + "="*60)
    print("📊 Global Fisher Information Statistics")

    if total_elems == 0:
        # Avoid dividing by zero and reporting +/-inf as min/max.
        print("No Fisher information entries found; nothing to summarize.")
        print("="*60)
        nan = float("nan")
        return {"total_params": 0, "nonzero": 0,
                "mean": nan, "std": nan, "min": nan, "max": nan}

    # Global statistics
    global_mean = total_sum / total_elems
    # E[x^2] - E[x]^2 can dip slightly below zero through rounding; clamp it so the
    # square root stays real (a negative float ** 0.5 yields a complex number).
    global_var = max(total_sq_sum / total_elems - global_mean ** 2, 0.0)
    global_std = global_var ** 0.5

    print(f"Total parameters:     {total_elems:,}")
    print(f"Non-zero entries:     {total_nonzero:,} ({100*total_nonzero/total_elems:.2f}%)")
    print(f"Global mean:          {global_mean:.6e}")
    print(f"Global std:           {global_std:.6e}")
    print(f"Global min:           {global_min:.6e}")
    print(f"Global max:           {global_max:.6e}")
    print("="*60)

    return {"total_params": total_elems, "nonzero": total_nonzero,
            "mean": global_mean, "std": global_std,
            "min": global_min, "max": global_max}

if __name__ == "__main__":
    # Command-line entry point: analyze the Fisher file passed via --file.
    cli = argparse.ArgumentParser(description="Analyze Fisher information results file")
    cli.add_argument(
        "--file",
        type=str,
        required=True,
        help="Path to saved Fisher information .pt file",
    )
    analyze_fisher(cli.parse_args().file)