#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
COMPLETE PowerSGD Inspector - Shows EVERYTHING in detail!
"""

import os

import torch
import torch.distributed as dist
from torch.distributed.algorithms.ddp_comm_hooks.powerSGD_hook import PowerSGDState


def setup_minimal_distributed():
    """Bring up a single-process ``gloo`` process group.

    Rendezvous environment variables are only filled in when they are not
    already present in the environment. The process group itself is always
    created with ``world_size=1`` and ``rank=0``.

    Returns:
        bool: True when ``init_process_group`` succeeded, False when it raised
        (the error is printed to stdout).
    """
    rendezvous_defaults = {
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "12355",
        "RANK": "0",
        "WORLD_SIZE": "1",
    }
    for name, fallback in rendezvous_defaults.items():
        os.environ.setdefault(name, fallback)

    try:
        dist.init_process_group(
            backend="gloo", init_method="env://", world_size=1, rank=0
        )
    except Exception as err:
        print(f"Failed to initialize distributed: {err}")
        return False
    else:
        return True


def find_best_checkpoint(outputs_dir="/home/ubuntu/projects/torchtitan/outputs"):
    """Pick a ``powersgd.pt`` checkpoint to examine.

    Scans ``<outputs_dir>/<experiment>/checkpoint/<step>/powersgd.pt`` and
    ranks every file found by size.

    Args:
        outputs_dir: Root directory holding one sub-directory per experiment.
            Defaults to the path this script was originally written against.

    Returns:
        str | None: Path of the second-largest checkpoint when more than one
        exists, the only checkpoint when exactly one exists, or None when
        nothing was found (including when ``outputs_dir`` does not exist).
    """
    # A missing root means "nothing to inspect", not a crash: the caller
    # already handles a None result.
    if not os.path.isdir(outputs_dir):
        return None

    checkpoints = []
    for exp_dir in os.listdir(outputs_dir):
        checkpoint_dir = os.path.join(outputs_dir, exp_dir, "checkpoint")
        # isdir (rather than exists) so a stray file named "checkpoint"
        # does not make the listdir below raise.
        if not os.path.isdir(checkpoint_dir):
            continue
        for step_dir in os.listdir(checkpoint_dir):
            powersgd_path = os.path.join(checkpoint_dir, step_dir, "powersgd.pt")
            if os.path.isfile(powersgd_path):
                checkpoints.append((powersgd_path, os.path.getsize(powersgd_path)))

    if not checkpoints:
        return None

    # Sort by size; the path is a tie-breaker so the choice does not depend
    # on the arbitrary order returned by os.listdir.
    checkpoints.sort(key=lambda item: (item[1], item[0]))

    # Prefer the second largest (a medium-large file) when there is a choice.
    if len(checkpoints) > 1:
        return checkpoints[-2][0]
    return checkpoints[0][0]


def _report_tensor_stats(emit, tensor, indent=""):
    """Emit shape, dtype, element count and min/max/mean/std of ``tensor``.

    Args:
        emit: Callable taking one string; it receives each report line.
        tensor: Object exposing the tensor methods used below.
        indent: Prefix prepended to every emitted line.
    """
    emit(f"{indent}🧮 Tensor: {tensor.shape} {tensor.dtype}")
    emit(f"{indent}📈 Elements: {tensor.numel():,}")
    # min()/max() raise on empty tensors, so statistics are skipped for them.
    if tensor.numel() > 0:
        emit(f"{indent}📊 Stats: min={tensor.min():.6f}, max={tensor.max():.6f}")
        emit(f"{indent}📊 Mean: {tensor.mean():.6f}, Std: {tensor.std():.6f}")


def _report_attribute(emit, attr_name, attr_value):
    """Emit a detailed description of one public attribute of the state."""
    emit(f"\n🔍 {attr_name.upper()}")
    emit("-" * 50)
    emit(f"Type: {type(attr_value)}")

    if isinstance(attr_value, dict):
        emit(f"Dictionary with {len(attr_value)} entries")
        emit(f"Keys: {list(attr_value.keys())}")

        # Show details for every key, in sorted order.
        for key in sorted(attr_value):
            value = attr_value[key]
            emit(f"\n  📌 Key {key}:")

            if not hasattr(value, "shape"):
                emit(f"    📄 Value: {value}")
                continue

            _report_tensor_stats(emit, value, indent="    ")

            # Small tensors are dumped whole; larger ones only head and tail.
            flat = value.flatten()
            if value.numel() <= 50:
                emit(f"    📋 All values: {flat.tolist()}")
            else:
                emit(f"    📋 First 20 values: {flat[:20].tolist()}")
                emit(f"    📋 Last 20 values: {flat[-20:].tolist()}")

            shape_analysis = analyze_tensor_shape(value.shape, attr_name, key)
            if shape_analysis:
                emit(f"    🎯 Likely represents: {shape_analysis}")

    elif hasattr(attr_value, "shape"):
        _report_tensor_stats(emit, attr_value)

    elif callable(attr_value):
        emit("🔧 Method/Function")

    else:
        emit(f"📄 Value: {attr_value}")


def _report_summary(emit, state):
    """Emit the headline settings, compression totals and per-dict sizes."""
    emit(f"🔢 Matrix Approximation Rank: {state.matrix_approximation_rank}")
    emit(f"🔄 Current Iteration: {state.iter}")
    emit(f"🚀 Start PowerSGD Iteration: {state.start_powerSGD_iter}")
    emit(f"📉 Min Compression Rate: {state.min_compression_rate}")
    emit(f"🔄 Use Error Feedback: {state.use_error_feedback}")
    emit(f"🌡️  Warm Start: {state.warm_start}")

    if hasattr(state, "total_numel_before_compression") and hasattr(
        state, "total_numel_after_compression"
    ):
        before = state.total_numel_before_compression
        after = state.total_numel_after_compression
        # Both divisions are guarded so zeroed counters cannot abort the
        # rest of the summary with a ZeroDivisionError.
        ratio = before / after if after > 0 else 0
        saved_pct = (before - after) / before * 100 if before > 0 else 0.0
        emit("\n📊 COMPRESSION STATISTICS:")
        emit(f"  Before: {before:,} parameters")
        emit(f"  After:  {after:,} parameters")
        emit(f"  Ratio:  {ratio:.2f}:1 compression")
        emit(f"  Saved:  {saved_pct:.1f}% reduction")

    # Analyze the memory dictionaries.
    for dict_name in ["p_memory_dict", "q_memory_dict", "error_dict"]:
        memory_dict = getattr(state, dict_name, None)
        if not memory_dict:
            continue

        total_params = sum(
            v.numel() for v in memory_dict.values() if hasattr(v, "numel")
        )
        emit(f"\n📚 {dict_name.upper()}:")
        emit(f"  Layers: {len(memory_dict)}")
        emit(f"  Total parameters: {total_params:,}")

        # Show the size of each entry.
        for key in sorted(memory_dict):
            entry = memory_dict[key]
            if hasattr(entry, "numel"):
                emit(f"  Layer {key}: {entry.numel():,} params")


def complete_powersgd_inspection(output_file="complete_powersgd_inspection.txt"):
    """Inspect a PowerSGD checkpoint and write a full report.

    Every report line is printed to the console and appended to
    ``output_file``. The checkpoint is chosen by ``find_best_checkpoint`` and
    loaded onto the CPU.

    Args:
        output_file: Path of the text report to (over)write.

    Returns:
        The object loaded from the checkpoint, or None when the distributed
        setup fails, no checkpoint is found, or loading raises.
    """
    # The report contains emoji, so the encoding is pinned to UTF-8 instead of
    # relying on the locale default.
    with open(output_file, "w", encoding="utf-8") as f:

        def print_and_save(*args, **kwargs):
            """Print to console and save to file (kwargs go to print only)."""
            text = " ".join(str(arg) for arg in args)
            print(text, **kwargs)
            f.write(text + "\n")
            f.flush()  # Ensure immediate write

        print_and_save("🚀 COMPLETE POWERSGD CHECKPOINT INSPECTION")
        print_and_save("=" * 80)

        # Setup distributed
        print_and_save("\n🔧 Setting up minimal distributed environment...")
        # Remember whether a group existed before this call so that cleanup
        # only tears down a group this call is responsible for.
        group_preexisting = dist.is_initialized()
        if not setup_minimal_distributed():
            print_and_save("❌ Failed to setup distributed")
            return None
        print_and_save("✅ Distributed setup successful")

        try:
            # Find checkpoint
            checkpoint_path = find_best_checkpoint()
            if not checkpoint_path:
                print_and_save("❌ No checkpoint found")
                return None

            print_and_save(f"\n📁 CHECKPOINT: {checkpoint_path}")
            print_and_save(
                f"📏 SIZE: {os.path.getsize(checkpoint_path) / (1024*1024):.1f} MB"
            )

            # Load checkpoint
            print_and_save("\n🔄 Loading checkpoint...")
            try:
                torch.serialization.add_safe_globals([PowerSGDState])
                # SECURITY NOTE: weights_only=False performs a full unpickle,
                # which can execute arbitrary code. Only point this script at
                # checkpoints you trust.
                state = torch.load(
                    checkpoint_path, map_location="cpu", weights_only=False
                )
                print_and_save("✅ Successfully loaded PowerSGD state!")
            except Exception as e:
                print_and_save(f"❌ Failed to load: {e}")
                return None

            # Basic info
            print_and_save(f"\n🏷️  TYPE: {type(state)}")

            # Get all public attributes
            all_attrs = [attr for attr in dir(state) if not attr.startswith("_")]
            print_and_save(f"\n📋 ALL ATTRIBUTES ({len(all_attrs)}):")
            print_and_save(f"    {', '.join(all_attrs)}")

            # Examine each attribute in detail
            print_and_save("\n" + "=" * 80)
            print_and_save("📊 DETAILED ATTRIBUTE EXAMINATION")
            print_and_save("=" * 80)

            for attr_name in all_attrs:
                # A failure on one attribute must not stop the others.
                try:
                    _report_attribute(
                        print_and_save, attr_name, getattr(state, attr_name)
                    )
                except Exception as e:
                    print_and_save(f"❌ Error accessing {attr_name}: {e}")

            # Summary statistics
            print_and_save("\n" + "=" * 80)
            print_and_save("📈 SUMMARY STATISTICS")
            print_and_save("=" * 80)

            try:
                _report_summary(print_and_save, state)
            except Exception as e:
                print_and_save(f"❌ Error in summary: {e}")

            print_and_save(f"\n✅ Complete inspection saved to: {output_file}")

            return state
        finally:
            # Cleanup runs on every exit path (early returns and exceptions
            # included) so a failed inspection does not leave the process
            # group initialized.
            try:
                if not group_preexisting and dist.is_initialized():
                    dist.destroy_process_group()
            except Exception as e:
                print_and_save(f"❌ Error in cleanup: {e}")


def analyze_tensor_shape(shape, attr_name, key):
    """Return a human-readable guess at what a tensor holds.

    Only one-dimensional shapes are described. The guess is keyed on
    ``attr_name`` when it is one of the known dictionary names, and on the
    element count otherwise.

    Args:
        shape: Sequence of dimension sizes.
        attr_name: Name of the attribute the tensor was found under.
        key: Dictionary key of the tensor; interpolated into the description.

    Returns:
        str | None: The description, or None when ``shape`` is not 1-D.
    """
    if len(shape) != 1:
        return None

    # Descriptions for the dictionaries that are recognized by name.
    named_templates = (
        (
            "p_memory_dict",
            "Left SVD matrix (P) for layer {key}, likely flattened from original gradient shape",
        ),
        (
            "q_memory_dict",
            "Right SVD matrix (Q) for layer {key}, compressed representation",
        ),
        (
            "error_dict",
            "Error feedback tensor for layer {key}, stores compression residuals",
        ),
    )
    for known_name, template in named_templates:
        if attr_name == known_name:
            return template.format(key=key)

    # Otherwise fall back to size buckets, checked from largest to smallest.
    numel = shape[0]
    size_buckets = (
        (50_000_000, "Large gradient tensor, likely from major model layer"),
        (1_000_000, "Medium gradient tensor, likely from transformer layer"),
        (100_000, "Compressed representation or smaller layer"),
    )
    for lower_bound, description in size_buckets:
        if numel > lower_bound:
            return description
    return "Small tensor, likely bias or embedding"


if __name__ == "__main__":
    print("🚀 Starting complete PowerSGD inspection...")

    # Run the complete inspection
    result = complete_powersgd_inspection("complete_powersgd_inspection.txt")

    if result:
        print(
            "\n🎉 SUCCESS! Complete inspection saved to 'complete_powersgd_inspection.txt'"
        )
        print("📁 File contains ALL keys, ALL layers, and detailed analysis!")
    else:
        print("\n❌ Inspection failed")
