import json
import pickle
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.optimize import least_squares
import matplotlib.pyplot as plt

# Configure pandas to display every column when printing DataFrames.
pd.set_option('display.max_columns', None)
import warnings

def load_stats_jsonl(filepath: Path, verbose: bool = True) -> pd.DataFrame:
    """Load a single stats.jsonl file and return it as a DataFrame.

    Each non-blank line is parsed as one JSON record. The nested 'stage'
    dict (if present and non-null) is flattened so its keys become top-level
    columns; a stage key overwrites a top-level key of the same name.

    Run-level metadata is read from ``config.json`` located next to the
    stats file and broadcast to every row.

    Args:
        filepath: Path (``Path`` or ``str``) to the stats.jsonl file.
        verbose: If True, print the aux_labels read from config.json.

    Returns:
        DataFrame with one row per record plus the columns 'source_file'
        (the stats file path as a string), 'aux_labels' (sorted tuple from
        ``config['run']['aux_labels']``) and 'seed' (``config['run']['seed']``).
    """
    # Accept plain strings too; `.parent` below needs a Path.
    filepath = Path(filepath)

    records = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)

            # Flatten the 'stage' dict; tolerate a missing or null stage.
            stage_dict = record.pop('stage', None) or {}
            record.update(stage_dict)

            records.append(record)

    df = pd.DataFrame(records)
    df['source_file'] = str(filepath)

    # Run-level metadata lives in the sibling config.json; use a context
    # manager so the file handle is closed deterministically.
    with open(filepath.parent / "config.json", 'r', encoding='utf-8') as f:
        config = json.load(f)
    aux_labels = tuple(sorted(config['run']['aux_labels']))
    if verbose:
        print(f"aux_labels: {aux_labels}")
    # Repeat the tuple per row; assigning the bare tuple would be treated as
    # a sequence of per-row values rather than one value for every row.
    df['aux_labels'] = [aux_labels] * len(df.index)
    df['seed'] = config['run']['seed']

    return df

def load_dir(directory: str, verbose: bool = False) -> pd.DataFrame:
    """Load every stats.jsonl under a directory into one DataFrame.

    The directory is searched recursively. Each file is loaded with
    ``load_stats_jsonl`` and the results are concatenated with a fresh
    integer index. Files are processed in sorted path order so the row order
    is reproducible across filesystems.

    Args:
        directory: Path to the root directory to search (e.g.,
            '/workspace/gradient-routing/experiments/ICML-Codebase/src/results/stories/01/combined_2025-11-30_20-05-12')
        verbose: If True, print progress information while loading.

    Returns:
        A concatenated DataFrame with all records from all stats.jsonl files.
        The columns 'target', 'expert_labels' and 'aux_labels' (when present)
        are normalised to sorted tuples; any value in them that is not a
        list or tuple becomes an empty tuple.

    Raises:
        ValueError: If the directory does not exist or contains no
            stats.jsonl file.
    """
    directory = Path(directory)

    if not directory.exists():
        raise ValueError(f"Directory does not exist: {directory}")

    # rglob yields entries in filesystem order; sort for determinism.
    stats_files = sorted(directory.rglob('stats.jsonl'))

    if not stats_files:
        raise ValueError(f"No stats.jsonl files found in {directory}")

    if verbose:
        print(f"Found {len(stats_files)} stats.jsonl file(s)")

    dfs = []
    for filepath in stats_files:
        df = load_stats_jsonl(filepath, verbose=verbose)
        dfs.append(df)
        if verbose:
            print(f"  Loaded {len(df)} records from {filepath}")

    df = pd.concat(dfs, ignore_index=True)
    if verbose:
        print(f"Total: {len(df)} records")

    def as_sorted_tuple(value):
        # Lists are unhashable; tuples make the columns usable as group keys.
        if isinstance(value, (list, tuple)):
            return tuple(sorted(value))
        return tuple()

    # Convert list-typed columns to tuples.
    for column in ('target', 'expert_labels', 'aux_labels'):
        if column not in df.columns:
            print(f"WARN: column {column} not found in df")
            continue
        df[column] = df[column].apply(as_sorted_tuple)

    return df


# ============================================================
# Metrics computation
# ============================================================

def _subsample_indices(n_points: int, max_points: int = 10_000) -> np.ndarray:
    """Return indices that thin a curve of length n_points to <= max_points.

    When the curve already fits within ``max_points`` every index is
    returned. Otherwise every k-th index is taken (regular stride starting
    at 0), where k is the smallest integer stride that keeps the result at
    or below ``max_points``.

    Args:
        n_points: Length of the curve to subsample.
        max_points: Maximum number of indices to return; must be >= 1.

    Returns:
        1-D integer array of indices into the original curve.

    Raises:
        ValueError: If ``max_points`` is smaller than 1.
    """
    if max_points < 1:
        raise ValueError(f"max_points must be >= 1, got {max_points}")

    if n_points <= max_points:
        return np.arange(n_points)

    # Ceil division: the smallest stride for which
    # ceil(n_points / stride) <= max_points holds.
    stride = -(-n_points // max_points)
    return np.arange(0, n_points, stride)

def _load_val_losses(pkl_path: Path, warmup_prc: float = 0.02) -> dict[str, np.ndarray]:
    """Read per-key loss sequences from a pickle and trim the warmup prefix.

    The pickle is expected to hold a mapping of key -> sequence of losses.
    Each sequence is converted to a float array; sequences with fewer than
    two values are dropped, and the leading ``warmup_prc`` fraction
    (rounded down) of every remaining sequence is discarded.

    NOTE: unpickling can execute arbitrary code, so only trusted files
    should be passed here.
    """
    with open(pkl_path, "rb") as handle:
        raw = pickle.load(handle)

    trimmed = {}
    for key, values in raw.items():
        series = np.asarray(values, float)
        if series.size <= 1:
            # Too short to carry a curve; leave it out of the result.
            continue
        n_skip = int(series.size * warmup_prc)
        trimmed[key] = series[n_skip:]
    return trimmed


def _get_steps(n_points: int, eval_interval_steps: int = 1) -> np.ndarray:
    """Return the 1-based step value of each point on a loss curve.

    Point ``i`` (0-based) maps to ``(i + 1) * eval_interval_steps``; the
    result is a float array of length ``n_points``.
    """
    ordinal = np.arange(1, n_points + 1, dtype=float)
    return ordinal * eval_interval_steps


def _power_floor(x, A, alpha, c, x0):
    """Evaluate the floored power law ``c + A * (x + x0)^(-alpha)``.

    Any RuntimeWarning emitted while evaluating the expression is promoted
    to an exception; the inputs are printed for debugging and the exception
    is re-raised to the caller.
    """
    with warnings.catch_warnings():
        # Turn RuntimeWarnings into exceptions for the duration of this block.
        warnings.filterwarnings("error", category=RuntimeWarning)
        try:
            decay = np.power(x + x0, -alpha)
            return c + A * decay
        except RuntimeWarning:
            print("RuntimeWarning in _power_floor")
            print(f"x: {x}")
            print(f"A: {A}, alpha: {alpha}, c: {c}, x0: {x0}")
            print(f"x + x0: {x + x0}")
            # Propagate so a fit never silently continues on bad values.
            raise

def _fit_power_floor(x, y) -> tuple[float, float, float, float]:
    """Fit ``c + A * (x + x0)^(-alpha)`` to loss data.

    Uses bounded least squares with a soft-L1 loss. The initial guess is
    clipped into the bounds so that an extreme first loss value or an
    unusual x range cannot make the solver reject the starting point.

    Args:
        x: Step values (array-like, at least one element).
        y: Loss values, same length as ``x``.

    Returns:
        The fitted parameters ``(A, alpha, c, x0)``.
    """
    x = np.asarray(x, float)
    y = np.asarray(y, float)
    # Running minimum of the losses; here it only seeds the amplitude guess.
    # NOTE(review): the residual below uses the raw ``y``, not ``y_mon`` —
    # confirm whether the fit was meant to target the monotone curve.
    y_mon = np.minimum.accumulate(y)

    c0 = 1.0
    A0 = max(y_mon[0] - c0, 1e-3)
    alpha0 = 0.3

    dx = x[1] - x[0] if len(x) > 1 else 1.0
    x_min = float(x.min())
    eps = 1e-6
    lb_x0 = -x_min + eps   # ensures x + x0 > 0 for every sample
    lb = np.array([1e-10, 1e-3, -10.0, lb_x0], float)
    ub = np.array([1e8, 5.0, 2.5, x.max() + 10 * dx], float)
    # Start x0 at 100, or at 90% of its upper bound when that is smaller.
    x0_0 = min(100.0, ub[3] * 0.9)
    # least_squares rejects a starting point outside the box, so clip it in.
    p0 = np.clip(np.array([A0, alpha0, c0, x0_0], float), lb, ub)

    def resid(p):
        return _power_floor(x, *p) - y

    res = least_squares(resid, p0, bounds=(lb, ub),
                        loss='soft_l1', f_scale=0.01, max_nfev=40000)
    return tuple(res.x)  # (A, alpha, c, x0)


def _g_inv_power_floor(ell, A, alpha, c, x0) -> np.ndarray:
    """Invert the floored power law: map a loss to its equivalent step.

    Solves ``ell = c + A * (x + x0)^(-alpha)`` for ``x``. The gap above the
    floor and the ratio ``A / gap`` are each clamped to at least 1e-12, and
    the resulting step is clamped to at least 1e-3.
    """
    tiny = 1e-12
    # Distance of the loss above the floor c, kept strictly positive.
    gap = np.maximum(np.asarray(ell, float) - c, tiny)
    ratio = np.maximum(A / gap, tiny)
    steps = np.power(ratio, 1.0 / alpha) - x0
    return np.maximum(steps, 1e-3)

def plot_loss_curve(x, y, curve, label):
    """Plot raw losses and the fitted power-law curve, then show the figure.

    ``curve`` holds the power-law parameters that are unpacked into
    ``_power_floor`` to evaluate the fitted line at ``x``.
    """
    plt.plot(x, y, label=f"{label} (raw)")
    fitted = _power_floor(x, *curve)
    plt.plot(x, fitted, label=f"{label} (fitted)")
    plt.legend()
    plt.show()

def add_metrics(
    df: pd.DataFrame,
    losses_pkl_path: str,
    split: str = 'test'
) -> pd.DataFrame:
    """Add computed metrics to a copy of ``df`` using baseline loss curves.

    This function:
    1. Loads baseline validation losses from the pickle file
    2. Fits a power-law curve to each label's loss trajectory
    3. Computes metrics for each row in the DataFrame:
       - ppl: perplexity (exp of loss)
       - step_equiv: equivalent training step for the observed loss
       - compute_ratio: ratio of step_equiv to baseline step_equiv
       - log_compute_ratio: natural log of compute_ratio
       - ppl_ratio: ratio of ppl to baseline ppl
       - loss_ratio: ratio of loss to baseline loss

    Baseline reference values are the per-'data_label' means over rows with
    ``name == 'baseline'``. When ``df`` has a 'split' column, only baseline
    rows whose split equals ``split`` or is missing are used.

    Args:
        df: DataFrame with columns including 'name', 'data_label', 'loss'.
        losses_pkl_path: Path to the losses.pkl file containing baseline losses.
        split: Split value used to select baseline rows (see above).

    Returns:
        A new DataFrame with the metric columns added. Metrics are NaN for
        rows whose label or value is missing, or whose label has no fitted
        curve / baseline.
    """
    df = df.copy()
    losses_pkl_path = Path(losses_pkl_path)

    # Fit one power-law curve per label in the baseline loss file.
    losses = _load_val_losses(losses_pkl_path)
    curves = {}
    for lab, y_b in losses.items():
        idx = _subsample_indices(len(y_b))
        x_b = _get_steps(len(y_b))[idx]
        curves[lab] = _fit_power_floor(x_b, y_b[idx])

    # Compute perplexity
    df['ppl'] = np.exp(df['loss'])

    # Compute step equivalent
    def step_equiv(label, loss):
        if pd.isna(label) or pd.isna(loss) or label not in curves:
            return np.nan
        return _g_inv_power_floor(loss, *curves[label])

    # Iterate over the two needed columns instead of DataFrame.apply(axis=1):
    # no per-row Series construction, and well-defined on an empty frame.
    df["step_equiv"] = pd.Series(
        [
            step_equiv(label, loss)
            for label, loss in zip(df["data_label"].to_numpy(), df["loss"].to_numpy())
        ],
        index=df.index,
        dtype=float,
    )

    # Calculate baseline reference values (mean over selected baseline rows)
    baseline_select = df['name'] == 'baseline'
    if 'split' in df.columns:
        split_matches = (df['split'] == split) | df['split'].isna()
        baseline_select = baseline_select & split_matches
    baselines = df.loc[baseline_select, ["data_label", "loss", "step_equiv", "ppl"]].copy()

    baselines = (
        baselines
        .groupby("data_label", dropna=False, as_index=False)[["loss", "step_equiv", "ppl"]]
        .mean()
    )

    baseline_lookup = baselines.set_index("data_label").to_dict('index')

    def baseline_values(column):
        """Per-row baseline value of ``column``; NaN if the label has none."""
        def lookup(label):
            if pd.isna(label) or label not in baseline_lookup:
                return np.nan
            return baseline_lookup[label][column]
        return df["data_label"].map(lookup).astype(float)

    # Compute ratios; a NaN numerator or NaN baseline yields NaN.
    df["compute_ratio"] = df["step_equiv"] / baseline_values("step_equiv")
    df["log_compute_ratio"] = np.log(df["compute_ratio"])
    df["ppl_ratio"] = df["ppl"] / baseline_values("ppl")
    df["loss_ratio"] = df["loss"] / baseline_values("loss")

    return df