import os
import wandb
import pandas as pd
from tqdm import tqdm
import logging
from src.utils import metric_from_dataset_name

# Raise the threshold of the "wandb" library logger to ERROR so that its
# lower-severity messages are suppressed while runs are being pulled.
logger = logging.getLogger(name="wandb")
logger.setLevel(level=logging.ERROR)


def pull_wandb_runs(
    entity: str,
    project_name: str,
    dataset_name: str,
    model_name: str,
    eigen_subset_name: str,
    output_dir: "str | None" = "results",
    **kwargs,
) -> "pd.DataFrame | None":
    """
    Pull the finished W&B runs matching the given filters into a DataFrame.

    Each row merges a run's config with its summary (summary keys win on a
    name collision); the ``_wandb`` summary entry is dropped. Runs whose
    config or summary is empty are skipped, and runs that raise while being
    read are reported and skipped.

    Args:
        entity (str): W&B entity/username
        project_name (str): Name of the W&B project; also the CSV file stem
        dataset_name (str): Value required for ``config.dataset_shortname``
        model_name (str): Value required for ``config.model_name``
        eigen_subset_name (str): Value required for ``config.eigen_subset_name``
        output_dir (str | None): Directory in which ``{project_name}.csv`` is
            written; pass None to skip writing the CSV
        **kwargs: Extra entries merged into the W&B ``filters`` dict. The value
            of ``config.eigen_subset_args_window_size`` is cast to ``int``.

    Returns:
        pd.DataFrame | None: one row per collected run, or None when no valid
        run was found.
    """
    # Initialize W&B API
    api = wandb.Api()

    # CLI-supplied filter values are strings; presumably the stored config
    # value is an int, so cast it for the filter to match -- TODO confirm.
    window_key = "config.eigen_subset_args_window_size"
    if window_key in kwargs:
        kwargs[window_key] = int(kwargs[window_key])

    # Get the finished runs matching all filters
    runs = api.runs(
        f"{entity}/{project_name}",
        per_page=10,
        filters={
            "state": "finished",
            "config.dataset_shortname": dataset_name,
            "config.model_name": model_name,
            "config.eigen_subset_name": eigen_subset_name,
            **kwargs,
        },
    )

    print(f"Found {len(runs)} finished runs")

    # Collect run data
    data = []
    for run in tqdm(runs, desc="Pulling runs"):
        try:
            config = dict(run.config)
            summary = dict(run.summary)
            if not (config and summary):
                continue

            # Drop W&B-internal summary data before merging
            summary.pop("_wandb", None)
            data.append({**config, **summary})
        except Exception as e:
            # Best-effort: one unreadable run should not abort the whole pull,
            # but say which run failed so it can be inspected.
            print(f"Error processing run {getattr(run, 'id', '<unknown>')}: {e}")

    if not data:
        print("No valid runs found")
        return None

    df = pd.DataFrame(data)

    if output_dir is not None:
        os.makedirs(output_dir, exist_ok=True)

        # NOTE(review): the file name only depends on the project, so pulls for
        # different datasets/models overwrite each other.
        output_path = os.path.join(output_dir, f"{project_name}.csv")
        df.to_csv(output_path, index=False)
        print(f"Saved {len(df)} runs to {output_path}")

    return df


def report_best(df, metric=None, expected_runs=None):
    """
    Print the ``group_id`` with the best mean validation score and its results.

    Runs are grouped by ``group_id`` and the validation/test score columns are
    aggregated with mean, std and count. Only groups whose validation count
    equals ``expected_runs`` are considered; the others are ignored.

    Args:
        df (pd.DataFrame | None): runs, e.g. as returned by ``pull_wandb_runs``.
            Needs ``dataset_name``, ``dataset_shortname``, ``group_id``,
            ``model_split_index`` and ``Best/Val/{metric}`` /
            ``Best/Test/{metric}`` columns.
        metric (str | None): metric suffix of the score columns; when None it
            is looked up with ``metric_from_dataset_name`` on the first row's
            ``dataset_name``.
        expected_runs (int | None): required number of runs per group; when
            None it is chosen from the first row's ``dataset_shortname``
            (wikics: 20; arxiv, paris, shanghai: 1; anything else: 10).

    Returns:
        pd.Series | None: the aggregated (mean/std/count) row of the best
        group, whose ``name`` is its ``group_id``; None when there is nothing
        to report.
    """
    if df is None or df.empty:
        print("No runs to report")
        return None

    # Positional lookups: the index of a filtered frame need not contain label 0.
    if metric is None:
        metric = metric_from_dataset_name(df["dataset_name"].iloc[0])
    val_col = f"Best/Val/{metric}"
    test_col = f"Best/Test/{metric}"

    if expected_runs is None:
        runs_per_dataset = {"wikics": 20, "arxiv": 1, "paris": 1, "shanghai": 1}
        expected_runs = runs_per_dataset.get(df["dataset_shortname"].iloc[0], 10)

    grouped = df.groupby("group_id")[[val_col, test_col]].agg(["mean", "std", "count"])
    grouped = grouped[grouped[val_col]["count"] == expected_runs]
    if grouped.empty:
        # idxmax() would raise on an empty selection.
        print(f"No group has exactly {expected_runs} runs with a {val_col} value")
        return None

    # Find the row with the best validation score
    best_val_row = grouped.loc[grouped[val_col]["mean"].idxmax()]
    print("\nBest validation performance:")
    print(f"Group ID: {best_val_row.name}")

    # Per-split results of the best group, ordered by split index
    best_group_runs = df[df["group_id"] == best_val_row.name]
    sorted_runs = best_group_runs.sort_values("model_split_index")
    print("\nResults by split:")
    print(f"{'Split':^6} | {'Val Acc':^8} | {'Test Acc':^8}")
    print("-" * 28)
    for _, run in sorted_runs.iterrows():
        print(
            f"{int(run['model_split_index']):^6} | "
            f"{run[val_col]:^8.4f} | "
            f"{run[test_col]:^8.4f}"
        )
    print()

    print(
        f"Val {metric}: "
        f"{best_val_row[val_col]['mean']:.4f} ± "
        f"{best_val_row[val_col]['std']:.4f}"
    )
    print(
        f"Test {metric}: "
        f"{best_val_row[test_col]['mean']:.4f} ± "
        f"{best_val_row[test_col]['std']:.4f}"
    )
    return best_val_row


def parse_kwarg_pairs(pairs):
    """
    Turn ``["key=value", ...]`` command-line tokens into a dict of strings.

    Only the first ``=`` separates key from value, so values may contain ``=``.

    Raises:
        ValueError: if a token contains no ``=``.
    """
    parsed = {}
    for pair in pairs:
        key, sep, value = pair.partition("=")
        if not sep:
            raise ValueError(f"Expected key=value, got {pair!r}")
        parsed[key] = value
    return parsed


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Pull W&B runs and save to CSV")
    parser.add_argument("--project", type=str, default="BTS", help="W&B project name")
    parser.add_argument("--dataset-name", type=str, default="chameleon", help="Dataset name")
    parser.add_argument("--model-name", type=str, default="gt-bts", help="Model name")
    parser.add_argument("--eigen-subset-name", type=str, default="bts2", help="Eigen subset")
    parser.add_argument("--output-dir", type=str, default="results", help="Output directory for CSV file")
    parser.add_argument("--entity", type=str, default="BTS", help="W&B entity/username")
    parser.add_argument(
        "--kwargs", type=str, nargs="*", default=[], help="Additional arguments in the format key=value"
    )
    args = parser.parse_args()

    try:
        kwargs = parse_kwarg_pairs(args.kwargs)
    except ValueError as err:
        # Report malformed --kwargs tokens as a usage error instead of a traceback.
        parser.error(str(err))

    df = pull_wandb_runs(
        args.entity, args.project, args.dataset_name, args.model_name, args.eigen_subset_name, args.output_dir, **kwargs
    )

    print(args.project)
    # pull_wandb_runs returns None (after printing why) when no run matched.
    if df is not None:
        report_best(df)
