"""
Utility functions for loading and processing experimental results from W&B runs.
"""

import os
import pickle
from typing import Any, Dict, List, Tuple

import pandas as pd
import wandb


def load_runs_from_wandb(
    run_names: List[str],
    project_path: str = "llm_uq/llm_uq"
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, Dict[str, Any]]]:
    """
    Load results from multiple W&B runs and organize them by model and dataset.

    Each run's config is fetched through the W&B public API and used to locate
    its results file at
    ``<config.output.path>/experiment_<experiment>/<run name>/results.pkl``,
    which is then unpickled and converted into a DataFrame.

    Parameters
    ----------
    run_names : List[str]
        List of W&B run names or IDs to load
    project_path : str, optional
        W&B project path in format "entity/project", by default "llm_uq/llm_uq"

    Returns
    -------
    all_dfs : Dict[str, pd.DataFrame]
        Dictionary mapping "{model_name}_{dataset_name}" keys to DataFrames
        containing the results for each run. If several runs share the same
        model and dataset, the first keeps the plain key and each later one
        gets "_{identifier from run_names}" appended (plus "_2", "_3", ... if
        that is still not unique), so no run is silently overwritten.
        Runs whose results file is missing or fails to load are omitted
        (a message is printed instead).
    run_metadata : Dict[str, Dict[str, Any]]
        Dictionary mapping the same keys to metadata dictionaries containing:
        - run_name: W&B run name
        - wandb_id: the identifier passed in ``run_names``
        - model_name: Model name from config
        - dataset_name: Dataset name from config (part before the first "/")
        - experiment: Experiment number
        - config: Full W&B config
        - file_path: Path to the results pickle file
        Metadata is kept for every registered run, even when its results
        file could not be loaded.

    Examples
    --------
    >>> run_names = ["playful-disco-576", "eeba82mf"]
    >>> dfs, metadata = load_runs_from_wandb(run_names)
    >>> print(list(dfs.keys()))
    ['gemma-3-27b_TRIVIAQA', 'llama-3-8b_NQ']
    """
    api = wandb.Api()

    # Extract metadata and file paths for each run
    filenames = {}
    run_metadata = {}

    for run_name in run_names:
        run = api.run(f"{project_path}/{run_name}")
        config = run.config

        local_run_name = run.name
        model_name = config["model"]["model_name"]
        dataset_name = config["dataset"]["dataset_name"].split("/")[0]
        experiment = config["experiment"]
        # os.path.join yields the same string as plain concatenation when the
        # configured output path ends with "/", and a correct path when it
        # does not (concatenation would glue "experiment_..." onto the dir).
        out_path = os.path.join(
            config["output"]["path"],
            f"experiment_{experiment}",
            local_run_name,
            "results.pkl",
        )

        # Create a readable key for this run
        key = f"{model_name}_{dataset_name}"
        if key in filenames:
            # Same model/dataset seen before (e.g. another experiment):
            # disambiguate instead of silently overwriting the earlier run.
            base_key = f"{key}_{run_name}"
            key = base_key
            counter = 2
            while key in filenames:
                key = f"{base_key}_{counter}"
                counter += 1
            print(f"Warning: duplicate model/dataset for run {run_name}; using key {key}")

        # Store file path
        filenames[key] = out_path

        # Store metadata for reference
        run_metadata[key] = {
            "run_name": local_run_name,
            "wandb_id": run_name,
            "model_name": model_name,
            "dataset_name": dataset_name,
            "experiment": experiment,
            "config": config,
            "file_path": out_path
        }

        print(f"Registered: {key} -> {out_path}")

    # Load all results into DataFrames. The raw unpickled object is only a
    # temporary, so it is not kept alive alongside its DataFrame.
    all_dfs = {}

    for key, filename in filenames.items():
        try:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # load result files from trusted locations.
            with open(filename, "rb") as f:
                df = pd.DataFrame(pickle.load(f))
        except FileNotFoundError:
            print(f"Warning: Could not find file for {key}: {filename}")
            continue
        except Exception as e:
            # Best-effort loading: report and skip runs that fail to load.
            print(f"Error loading {key}: {e}")
            continue
        all_dfs[key] = df
        print(f"Loaded {key}: {len(df)} samples")

    return all_dfs, run_metadata


def load_runs_from_paths(
    file_paths: Dict[str, str]
) -> Tuple[Dict[str, pd.DataFrame], Dict[str, Dict[str, Any]]]:
    """
    Load results from pickle files given explicit file paths.

    Each file is unpickled and passed straight to ``pd.DataFrame``. Files that
    are missing or fail to unpickle/convert are reported with a printed
    message and skipped; they appear in neither returned dictionary.

    Parameters
    ----------
    file_paths : Dict[str, str]
        Dictionary mapping descriptive keys to file paths

    Returns
    -------
    all_dfs : Dict[str, pd.DataFrame]
        Dictionary mapping keys to DataFrames containing the results
    metadata : Dict[str, Dict[str, Any]]
        Dictionary mapping the successfully loaded keys to metadata
        dictionaries containing ``file_path``

    Examples
    --------
    >>> paths = {
    ...     "model1_dataset1": "/path/to/results1.pkl",
    ...     "model2_dataset2": "/path/to/results2.pkl"
    ... }
    >>> dfs, metadata = load_runs_from_paths(paths)
    """
    all_dfs = {}
    metadata = {}

    for key, filename in file_paths.items():
        try:
            # NOTE(security): pickle.load can execute arbitrary code; only
            # load result files from trusted sources.
            # The raw unpickled object is a temporary here, so it is freed
            # once the DataFrame is built instead of being retained for
            # every file until the function returns.
            with open(filename, "rb") as f:
                df = pd.DataFrame(pickle.load(f))
        except FileNotFoundError:
            print(f"Warning: Could not find file for {key}: {filename}")
            continue
        except Exception as e:
            # Best-effort loading: report and skip files that fail to load.
            print(f"Error loading {key}: {e}")
            continue
        all_dfs[key] = df
        metadata[key] = {"file_path": filename}
        print(f"Loaded {key}: {len(df)} samples from {filename}")

    return all_dfs, metadata
