#!/usr/bin/env python3

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Script to fetch TensorBoard data from a remote instance and upload to Weights & Biases.
This script connects to a remote TensorBoard server, extracts scalar metrics for a specific run,
and uploads them to a WandB project.
"""

import argparse
import sys
from typing import Dict, List, Tuple
from urllib.parse import quote, urljoin

import requests
import wandb


def test_tensorboard_connection(tensorboard_url: str) -> bool:
    """
    Check whether the TensorBoard server answers at its base URL.

    Args:
        tensorboard_url: Base URL of the TensorBoard instance

    Returns:
        True when a GET on the URL (10 second timeout) yields HTTP 200;
        False for any other status code or when the request itself fails
    """
    try:
        status = requests.get(tensorboard_url, timeout=10).status_code
    except requests.RequestException:
        # The request could not be completed at all.
        return False
    return status == 200


def fetch_tensorboard_runs(tensorboard_url: str) -> List[str]:
    """
    Fetch available runs from TensorBoard instance.

    Several candidate endpoints are probed in order; the first one that
    answers successfully with a JSON list or JSON object is used.

    Args:
        tensorboard_url: Base URL of the TensorBoard instance. A path
            component (e.g. a reverse-proxy prefix) is preserved when the
            endpoint URLs are built.

    Returns:
        List of available run names: the JSON list as-is, or the keys of
        the JSON object. Empty list when no endpoint gave a usable answer.
    """
    # Try different possible endpoints for runs
    endpoints_to_try = ["/data/runs", "/data/plugin/scalars/runs", "/api/runs"]

    # urljoin() with an absolute path ("/data/runs") throws away any path in
    # the base URL, so resolve the endpoints relative to a slash-terminated
    # base instead.
    base_url = tensorboard_url.rstrip("/") + "/"

    for endpoint in endpoints_to_try:
        runs_url = urljoin(base_url, endpoint.lstrip("/"))
        print(f"Trying endpoint: {runs_url}")

        try:
            response = requests.get(runs_url, timeout=30)
            response.raise_for_status()
            runs_data = response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers a non-JSON body on requests versions whose
            # JSON decode error is not a RequestException.
            print(f"Error with endpoint {endpoint}: {e}")
            continue

        # Handle different response formats
        if isinstance(runs_data, list):
            return runs_data
        if isinstance(runs_data, dict):
            return list(runs_data.keys())
        print(f"Unexpected response format from {endpoint}: {type(runs_data)}")

    print("No working runs endpoint found")
    return []


def fetch_tensorboard_tags(tensorboard_url: str, run_name: str) -> List[str]:
    """
    Fetch available scalar tags for a specific run.

    Several candidate endpoints are probed in order; the first one that
    answers successfully with a JSON list or JSON object is used.

    Args:
        tensorboard_url: Base URL of the TensorBoard instance. A path
            component (e.g. a reverse-proxy prefix) is preserved when the
            endpoint URLs are built.
        run_name: Name of the run to fetch tags for

    Returns:
        List of available scalar tag names. Empty list when no endpoint
        gave a usable answer.
    """
    # Try different possible endpoints for tags
    encoded_run = quote(run_name)
    endpoints_to_try = [
        f"/data/plugin/scalars/tags?run={encoded_run}",
        f"/data/tags?run={encoded_run}",
        f"/api/tags?run={encoded_run}",
    ]

    # urljoin() with an absolute path would throw away any path in the base
    # URL, so resolve the endpoints relative to a slash-terminated base.
    base_url = tensorboard_url.rstrip("/") + "/"

    for endpoint in endpoints_to_try:
        tags_url = urljoin(base_url, endpoint.lstrip("/"))
        print(f"Trying tags endpoint: {tags_url}")

        try:
            response = requests.get(tags_url, timeout=30)
            response.raise_for_status()
            tags_data = response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers a non-JSON body on requests versions whose
            # JSON decode error is not a RequestException.
            print(f"Error with tags endpoint {endpoint}: {e}")
            continue

        # Handle different response formats
        if isinstance(tags_data, list):
            return tags_data
        if isinstance(tags_data, dict):
            # The scalars plugin's tag index is keyed by run name:
            # {run: {tag: metadata}}. Taking the top-level keys of that shape
            # would yield run names, so pick this run's entry when present.
            run_entry = tags_data.get(run_name)
            if isinstance(run_entry, dict):
                return list(run_entry.keys())
            if isinstance(run_entry, list):
                return run_entry
            # Otherwise assume a flat {tag: ...} object (previous behaviour).
            return list(tags_data.keys())
        print(f"Unexpected tags response format: {type(tags_data)}")

    print(f"No working tags endpoint found for run {run_name}")
    return []


def _rows_to_points(rows) -> List[Tuple[int, float, float]]:
    """Turn ``[timestamp, step, value]`` rows into ``(step, timestamp, value)`` tuples.

    Rows that are not lists/tuples with at least three items are skipped.
    """
    return [
        (row[1], row[0], row[2])
        for row in rows
        if isinstance(row, (list, tuple)) and len(row) >= 3
    ]


def fetch_tensorboard_scalars(
    tensorboard_url: str, run_name: str, tag: str
) -> List[Tuple[int, float, float]]:
    """
    Fetch scalar data for a specific run and tag.

    Several candidate endpoints are probed in order. A response is accepted
    when it is a JSON list of rows, or a JSON object whose "values" entry is
    such a list, and at least one row is usable. Rows are assumed to be
    ordered [timestamp, step, value].

    Args:
        tensorboard_url: Base URL of the TensorBoard instance. A path
            component (e.g. a reverse-proxy prefix) is preserved when the
            endpoint URLs are built.
        run_name: Name of the run
        tag: Name of the scalar tag

    Returns:
        List of (step, timestamp, value) tuples. Empty list when no endpoint
        returned any usable data point.
    """
    encoded_run = quote(run_name)
    encoded_tag = quote(tag)

    # Try different possible endpoints for scalar data
    endpoints_to_try = [
        f"/data/plugin/scalars/scalars?run={encoded_run}&tag={encoded_tag}",
        f"/data/scalars?run={encoded_run}&tag={encoded_tag}",
        f"/api/scalars?run={encoded_run}&tag={encoded_tag}",
    ]

    # urljoin() with an absolute path would throw away any path in the base
    # URL, so resolve the endpoints relative to a slash-terminated base.
    base_url = tensorboard_url.rstrip("/") + "/"

    for endpoint in endpoints_to_try:
        scalars_url = urljoin(base_url, endpoint.lstrip("/"))
        print(f"Trying scalars endpoint: {scalars_url}")

        try:
            response = requests.get(scalars_url, timeout=30)
            response.raise_for_status()
            scalars_data = response.json()
        except (requests.RequestException, ValueError) as e:
            # ValueError covers a non-JSON body on requests versions whose
            # JSON decode error is not a RequestException.
            print(f"Error with scalars endpoint {endpoint}: {e}")
            continue

        # Accept either a bare list of rows or a nested {"values": [...]}.
        if isinstance(scalars_data, dict) and "values" in scalars_data:
            rows = scalars_data["values"]
        else:
            rows = scalars_data

        if not isinstance(rows, list):
            print(f"Unexpected scalars response format: {type(scalars_data)}")
            print(f"Sample data: {scalars_data}")
            continue

        points = _rows_to_points(rows)
        if not points:
            # Empty or entirely malformed payload: report it and keep probing
            # rather than failing on the first row.
            print(f"No usable scalar points from {endpoint} (sample: {rows[:2]})")
            continue

        skipped = len(rows) - len(points)
        if skipped:
            print(f"Warning: skipped {skipped} malformed rows from {endpoint}")
        return points

    print(f"No working scalars endpoint found for {run_name}/{tag}")
    return []


def discover_tensorboard_structure(tensorboard_url: str) -> None:
    """
    Discover the structure of the TensorBoard instance by trying common endpoints.

    For every endpoint the HTTP status code is printed. For a 200 response
    a short summary of the body follows: the decoded JSON when it is small,
    its type and length when it is large, or the first 100 characters of the
    text when the body is not JSON.

    Args:
        tensorboard_url: Base URL of the TensorBoard instance. A path
            component (e.g. a reverse-proxy prefix) is preserved when the
            endpoint URLs are built.
    """
    print("Discovering TensorBoard API structure...")

    common_endpoints = [
        "/",
        "/data/logdir",
        "/data/runs",
        "/data/plugin/scalars/runs",
        "/api/runs",
        "/data/environment",
        "/data/plugins",
    ]

    # urljoin() with an absolute path would throw away any path in the base
    # URL, so resolve the endpoints relative to a slash-terminated base.
    base_url = tensorboard_url.rstrip("/") + "/"

    for endpoint in common_endpoints:
        url = urljoin(base_url, endpoint.lstrip("/"))
        try:
            response = requests.get(url, timeout=10)
        except requests.RequestException as e:
            print(f"{endpoint}: ERROR - {e}")
            continue

        print(f"{endpoint}: {response.status_code}")
        if response.status_code != 200:
            continue

        try:
            data = response.json()
        except ValueError:
            # Body is not JSON (e.g. an HTML page): show a text preview.
            text = response.text
            content_preview = text[:100] + "..." if len(text) > 100 else text
            print(f"  Content (text): {content_preview}")
            continue

        # Print small payloads verbatim; only summarise larger ones.
        if isinstance(data, dict) and len(str(data)) < 200:
            print(f"  Content: {data}")
        elif isinstance(data, list) and len(data) < 10:
            print(f"  Content: {data}")
        else:
            print(
                f"  Content type: {type(data)}, length: {len(data) if hasattr(data, '__len__') else 'unknown'}"
            )

    print()


def fetch_run_data(
    tensorboard_url: str, run_name: str, specific_tags: List[str] = None
) -> Dict[str, List[Tuple[int, float, float]]]:
    """
    Collect the scalar series of one run, keyed by tag.

    Args:
        tensorboard_url: Base URL of the TensorBoard instance
        run_name: Name of the run to fetch data for
        specific_tags: Tags to fetch. When None, the tag list is looked up
            via fetch_tensorboard_tags(); any other value is used as given.

    Returns:
        Dictionary mapping tag names to lists of (step, timestamp, value)
        tuples. Tags that yielded no data are left out; the dictionary is
        empty when the tag lookup found nothing.
    """
    print(f"Fetching data for run: {run_name}")

    # Resolve which tags to download.
    tags = specific_tags
    if tags is None:
        tags = fetch_tensorboard_tags(tensorboard_url, run_name)
        if not tags:
            print(f"No tags found for run {run_name}")
            return {}

    print(f"Found {len(tags)} tags: {tags}")

    # Download each tag's series, keeping only the non-empty ones.
    collected = {}
    for tag in tags:
        print(f"  Fetching data for tag: {tag}")
        points = fetch_tensorboard_scalars(tensorboard_url, run_name, tag)
        if not points:
            print(f"    No data found for tag: {tag}")
            continue
        collected[tag] = points
        print(f"    Retrieved {len(points)} data points")

    return collected


def upload_to_wandb(
    scalar_data: Dict[str, List[Tuple[int, float, float]]],
    entity: str,
    project: str,
    run_name: str = None,
    run_id: str = None,
    tags: List[str] = None,
    metric_mapping: Dict[str, str] = None,
) -> None:
    """
    Log TensorBoard scalar series to a Weights & Biases run.

    All metrics that share a step are logged together in one call, and the
    steps are logged in ascending order. The run is always finished, even
    when logging raises.

    Args:
        scalar_data: Dictionary mapping metric names to lists of
            (step, timestamp, value) tuples; timestamps are not uploaded
        entity: WandB entity (username/organization)
        project: WandB project name
        run_name: Optional run name, passed to wandb.init as "name"
        run_id: Optional run ID; when given the run is opened with
            resume="allow"
        tags: Optional list of tags for the run
        metric_mapping: Optional dictionary to map TensorBoard metric names
            to WandB metric names; unmapped metrics keep their name
    """
    # Assemble the wandb.init() keyword arguments from the options supplied.
    init_kwargs = {"entity": entity, "project": project}
    if run_name:
        init_kwargs["name"] = run_name
    if run_id:
        init_kwargs.update(id=run_id, resume="allow")
    if tags:
        init_kwargs["tags"] = tags

    print(f"Initializing WandB run in {entity}/{project}")
    if run_name:
        print(f"  Run name: {run_name}")
    if run_id:
        print(f"  Run ID: {run_id}")

    run = wandb.init(**init_kwargs)

    try:
        # Every step that occurs in any metric, in ascending order.
        ordered_steps = sorted(
            {step for series in scalar_data.values() for step, _, _ in series}
        )
        print(f"Uploading data for {len(ordered_steps)} steps")

        # step -> {wandb metric name: value}
        per_step = {step: {} for step in ordered_steps}
        for source_name, series in scalar_data.items():
            if metric_mapping and source_name in metric_mapping:
                target_name = metric_mapping[source_name]
                print(f"Mapping metric '{source_name}' -> '{target_name}'")
            else:
                target_name = source_name

            # A step repeated within one series keeps its last value.
            for step, _, value in series:
                per_step[step][target_name] = value

        # One log call per step that has at least one metric.
        for step in ordered_steps:
            payload = per_step[step]
            if payload:
                wandb.log(payload, step=step)

        print(f"Successfully uploaded {len(ordered_steps)} data points to WandB")
        print(f"Run URL: {run.url}")

    finally:
        # Close the run whether or not the upload succeeded.
        wandb.finish()


def main():
    """
    Command-line entry point for the TensorBoard to WandB migration.

    Parses the command line, checks that the TensorBoard instance is
    reachable, then either prints diagnostics (--discover, --list-runs) or
    fetches the requested tags of the first run matching --run-name-filter
    and uploads them to WandB unless --dry-run is set. Exits with status 1
    when the connection fails, no run matches, or an unexpected error occurs.
    """
    parser = argparse.ArgumentParser(
        description="Fetch TensorBoard data from remote instance and upload to Weights & Biases"
    )

    # (flag, add_argument keyword arguments), in the order shown by --help.
    option_specs = [
        (
            "--tensorboard-url",
            {
                "default": "http://209.38.22.230:6006",
                "help": "URL of the TensorBoard instance",
            },
        ),
        (
            "--run-name-filter",
            {
                "default": "olmo_1b_allenai",
                "help": "Run name to filter for (can be partial match)",
            },
        ),
        (
            "--tags",
            {
                "nargs": "*",
                "default": ["train/loss"],
                "help": "Specific tags to fetch (default: train/loss)",
            },
        ),
        (
            "--entity",
            {
                "default": "ajanthan-pluralis-research",
                "help": "WandB entity (username/organization)",
            },
        ),
        ("--project", {"default": "torchtitan", "help": "WandB project name"}),
        (
            "--wandb-run-name",
            {
                "help": "WandB run name (optional, will be auto-generated if not provided)"
            },
        ),
        (
            "--wandb-run-id",
            {"help": "WandB run ID for resuming existing runs (optional)"},
        ),
        (
            "--wandb-tags",
            {"nargs": "*", "help": "Tags to add to the WandB run (optional)"},
        ),
        (
            "--dry-run",
            {
                "action": "store_true",
                "help": "Fetch the TensorBoard data but don't upload to WandB",
            },
        ),
        (
            "--list-runs",
            {"action": "store_true", "help": "List available runs and exit"},
        ),
        (
            "--discover",
            {
                "action": "store_true",
                "help": "Discover TensorBoard API structure and exit",
            },
        ),
        (
            "--metric-mapping",
            {
                "nargs": "*",
                "help": "Custom metric mappings in format 'tensorboard_name:wandb_name' (optional)",
            },
        ),
    ]
    for flag, options in option_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # Drop trailing slashes from the base URL
    tensorboard_url = args.tensorboard_url.rstrip("/")

    print("Remote TensorBoard to WandB Migration")
    print("=====================================")
    print(f"TensorBoard URL: {tensorboard_url}")
    print(f"Target: {args.entity}/{args.project}")
    if args.dry_run:
        print("DRY RUN MODE - Will not upload to WandB")
    print()

    # Bail out early when the server cannot be reached
    print("Testing TensorBoard connection...")
    if not test_tensorboard_connection(tensorboard_url):
        print("ERROR: Cannot connect to TensorBoard instance")
        sys.exit(1)
    print("Connection successful!")
    print()

    try:
        # Diagnostic mode: probe the API and stop
        if args.discover:
            discover_tensorboard_structure(tensorboard_url)
            return

        # Diagnostic mode: print the run list and stop
        if args.list_runs:
            print("Fetching available runs...")
            runs = fetch_tensorboard_runs(tensorboard_url)
            if not runs:
                print("No runs found or unable to access runs endpoint")
                print("Try using --discover to see available API endpoints")
                return
            print(f"Available runs ({len(runs)}):")
            for run in runs:
                print(f"  - {run}")
            return

        # Look up the runs and keep those containing the filter string
        print("Fetching available runs...")
        all_runs = fetch_tensorboard_runs(tensorboard_url)
        if not all_runs:
            print("No runs found or unable to connect to TensorBoard instance")
            print("Try using --discover to see available API endpoints")
            sys.exit(1)

        matching_runs = [
            candidate for candidate in all_runs if args.run_name_filter in candidate
        ]
        if not matching_runs:
            print(f"No runs found matching filter '{args.run_name_filter}'")
            print(f"Available runs: {all_runs}")
            sys.exit(1)

        print(f"Found {len(matching_runs)} matching runs:")
        for run in matching_runs:
            print(f"  - {run}")

        # Only the first match is migrated
        target_run = matching_runs[0]
        print(f"\nUsing run: {target_run}")

        scalar_data = fetch_run_data(tensorboard_url, target_run, args.tags)
        if not scalar_data:
            print("No scalar data found for the specified run and tags")
            return

        # Summarise what was fetched
        print(f"\nFound data for {len(scalar_data)} metrics:")
        for metric_name, points in scalar_data.items():
            print(f"  {metric_name}: {len(points)} points")
        total_points = sum(len(points) for points in scalar_data.values())
        print(f"Total data points: {total_points}")

        if args.dry_run:
            print("\nDry run complete - would upload the above metrics to WandB")
            return

        print("\nUploading to WandB...")
        wandb_run_name = args.wandb_run_name or f"{target_run}_loss_curve"

        # Built-in TensorBoard -> WandB rename; --metric-mapping entries are
        # added on top and may override it
        metric_mapping = {"train/loss": "loss_metrics/global_avg_loss"}
        for mapping in args.metric_mapping or []:
            tb_name, separator, wandb_name = mapping.partition(":")
            if not separator:
                print(
                    f"Warning: Invalid mapping format '{mapping}', expected 'tensorboard_name:wandb_name'"
                )
                continue
            tb_name = tb_name.strip()
            wandb_name = wandb_name.strip()
            metric_mapping[tb_name] = wandb_name
            print(f"Added custom mapping: {tb_name} -> {wandb_name}")

        upload_to_wandb(
            scalar_data=scalar_data,
            entity=args.entity,
            project=args.project,
            run_name=wandb_run_name,
            run_id=args.wandb_run_id,
            tags=args.wandb_tags,
            metric_mapping=metric_mapping,
        )

        print("\nMigration completed successfully!")

    except Exception as e:
        # Top-level boundary: report the failure with a traceback and exit 1.
        # sys.exit() calls above raise SystemExit, which is not caught here.
        print(f"Error: {e}")
        import traceback

        traceback.print_exc()
        sys.exit(1)


# Run the migration only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
