# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Script to migrate TensorBoard event files to Weights & Biases.
This script reads all scalar metrics from a TensorBoard event file
and uploads them to a WandB project, preserving the original data structure.
"""

import argparse
import os
import sys
from typing import Dict, List, Optional, Tuple

import wandb
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator


def read_tensorboard_events(
    event_file_path: str,
) -> Dict[str, List[Tuple[int, float, float]]]:
    """
    Read every scalar event from a TensorBoard event file.

    Args:
        event_file_path: Path to the TensorBoard event file

    Returns:
        Dictionary mapping metric names to lists of (step, timestamp, value)
        tuples, in the order the accumulator returns them.
    """
    print(f"Reading TensorBoard event file: {event_file_path}")

    # By default EventAccumulator keeps at most 10,000 scalars per tag and
    # drops the rest via reservoir sampling, which would silently lose data
    # for long runs. A size guidance of 0 means "keep every event".
    event_acc = EventAccumulator(event_file_path, size_guidance={"scalars": 0})
    event_acc.Reload()

    # Get all available scalar tags
    scalar_tags = event_acc.Tags()["scalars"]
    print(f"Found {len(scalar_tags)} scalar metrics: {scalar_tags}")

    # Extract scalar data
    scalar_data = {}
    for tag in scalar_tags:
        scalar_events = event_acc.Scalars(tag)
        scalar_data[tag] = [
            (event.step, event.wall_time, event.value) for event in scalar_events
        ]
        print(f"  {tag}: {len(scalar_events)} data points")

    return scalar_data


def _group_metrics_by_step(
    scalar_data: Dict[str, List[Tuple[int, float, float]]],
) -> Dict[int, Dict[str, float]]:
    """Regroup per-metric series into one {metric_name: value} payload per step."""
    step_data: Dict[int, Dict[str, float]] = {}
    for metric_name, metric_data in scalar_data.items():
        for step, _, value in metric_data:
            # If a metric has several events at the same step, the last one wins.
            step_data.setdefault(step, {})[metric_name] = value
    return step_data


def upload_to_wandb(
    scalar_data: Dict[str, List[Tuple[int, float, float]]],
    entity: str,
    project: str,
    run_name: Optional[str] = None,
    run_id: Optional[str] = None,
    tags: Optional[List[str]] = None,
) -> None:
    """
    Upload scalar data to Weights & Biases.

    All metrics that share a step are logged together in a single
    ``wandb.log`` call, and steps are logged in ascending order. The
    TensorBoard wall-clock timestamps are not uploaded.

    Args:
        scalar_data: Dictionary of metric data from TensorBoard, mapping
            metric names to lists of (step, timestamp, value) tuples
        entity: WandB entity (username/organization)
        project: WandB project name
        run_name: Optional run name (will be auto-generated if not provided)
        run_id: Optional run ID for resuming existing runs
        tags: Optional list of tags for the run

    Raises:
        Any exception raised while logging is re-raised after the run has
        been finished with a non-zero exit code.
    """
    # Initialize WandB run
    run_config = {
        "entity": entity,
        "project": project,
    }

    if run_name:
        run_config["name"] = run_name
    if run_id:
        run_config["id"] = run_id
        # "allow" resumes the run if the ID exists and creates it otherwise
        run_config["resume"] = "allow"
    if tags:
        run_config["tags"] = tags

    print(f"Initializing WandB run in {entity}/{project}")
    if run_name:
        print(f"  Run name: {run_name}")
    if run_id:
        print(f"  Run ID: {run_id}")

    # Initialize WandB
    run = wandb.init(**run_config)

    try:
        # One pass over the actual data points instead of probing every
        # global step for every metric.
        step_data = _group_metrics_by_step(scalar_data)
        print(f"Uploading data for {len(step_data)} steps")

        # Upload in ascending step order, one call per step
        for step in sorted(step_data):
            wandb.log(step_data[step], step=step)

        uploaded_values = sum(len(metrics) for metrics in step_data.values())
        print(
            f"Successfully uploaded {uploaded_values} data points "
            f"across {len(step_data)} steps to WandB"
        )
        print(f"Run URL: {run.url}")
    except BaseException:
        # Mark the run as failed so a partial upload is not shown as a
        # successful migration, then let the caller see the original error.
        wandb.finish(exit_code=1)
        raise
    else:
        # Finish the WandB run
        wandb.finish()


def main(argv: Optional[List[str]] = None) -> None:
    """Main function to orchestrate the TensorBoard to WandB migration.

    Args:
        argv: Command-line arguments without the program name. When None,
            argparse reads them from ``sys.argv``.

    Exits with status 1 if the event file does not exist or if reading or
    uploading raises an exception. Error messages are written to stderr.
    """
    parser = argparse.ArgumentParser(
        description="Migrate TensorBoard event files to Weights & Biases"
    )

    parser.add_argument("event_file", help="Path to the TensorBoard event file")

    # NOTE(review): the default entity is a specific account; pass --entity
    # to upload elsewhere.
    parser.add_argument(
        "--entity",
        default="ajanthan-pluralis-research",
        help="WandB entity (username/organization)",
    )

    parser.add_argument("--project", default="torchtitan", help="WandB project name")

    parser.add_argument(
        "--run-name",
        help="WandB run name (optional, will be auto-generated if not provided)",
    )

    parser.add_argument(
        "--run-id", help="WandB run ID for resuming existing runs (optional)"
    )

    parser.add_argument(
        "--tags", nargs="*", help="Tags to add to the WandB run (optional)"
    )

    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Read the TensorBoard file but don't upload to WandB",
    )

    args = parser.parse_args(argv)

    # Validate event file exists
    if not os.path.exists(args.event_file):
        print(f"Error: Event file '{args.event_file}' not found", file=sys.stderr)
        sys.exit(1)

    print("TensorBoard to WandB Migration")
    print("=============================")
    print(f"Event file: {args.event_file}")
    print(f"Target: {args.entity}/{args.project}")
    if args.run_name:
        print(f"Run name: {args.run_name}")
    if args.dry_run:
        print("DRY RUN MODE - Will not upload to WandB")
    print()

    try:
        # Read TensorBoard data
        scalar_data = read_tensorboard_events(args.event_file)

        if not scalar_data:
            print("No scalar data found in the event file")
            return

        print(f"\nFound data for {len(scalar_data)} metrics")
        total_points = sum(len(data) for data in scalar_data.values())
        print(f"Total data points: {total_points}")

        if args.dry_run:
            print("\nDry run complete - metrics that would be uploaded:")
            for metric_name, data in scalar_data.items():
                print(f"  {metric_name}: {len(data)} points")
            return

        # Upload to WandB
        print("\nUploading to WandB...")
        upload_to_wandb(
            scalar_data=scalar_data,
            entity=args.entity,
            project=args.project,
            run_name=args.run_name,
            run_id=args.run_id,
            tags=args.tags,
        )

        print("\nMigration completed successfully!")

    except Exception as e:
        # Top-level boundary: report the failure and signal it via exit status
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


# Run the migration only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
