import sys
import argparse
import pandas as pd
import numpy as np

from pathlib import Path

from umfavi.experiments.file_queue import FileTaskQueue, ExperimentStatus

from umfavi.experiments.utils import (
    get_feedback_config_key,
    derive_feedback_combination,
    compute_aggregated_metric,
)
from umfavi.types import FeedbackType

def _generate_all_feedback_combinations() -> list[tuple[str, dict[FeedbackType, bool], str]]:
    """
    Generate all possible non-empty feedback type combinations.

    Single feedback types come first (in enum iteration order), followed by
    every combination of two or more types in the order produced by
    ``itertools.combinations``.

    Returns:
        List of ``(column_name, active, combo_key)`` tuples where

        - ``column_name`` is the display label: ``"<Value> Only"`` for a single
          type, otherwise the capitalized values joined by ``"+"``;
        - ``active`` maps every FeedbackType member to whether it is part of
          the combination;
        - ``combo_key`` is ``"<value>_only"`` for a single type, otherwise the
          alphabetically sorted values joined by ``"+"``.
    """
    from itertools import combinations

    all_types = list(FeedbackType)
    result = []

    for size in range(1, len(all_types) + 1):
        for combo in combinations(all_types, size):
            active = {t: (t in combo) for t in all_types}
            if size == 1:
                # Singles use a dedicated label/key format (e.g. "<value>_only").
                (only_type,) = combo
                col_name = f"{only_type.value.capitalize()} Only"
                combo_key = f"{only_type.value}_only"
            else:
                col_name = "+".join(t.value.capitalize() for t in combo)
                # Keys are sorted so they do not depend on enum declaration order.
                combo_key = "+".join(sorted(t.value for t in combo))
            result.append((col_name, active, combo_key))

    return result

def build_concise_table_df(
    df: pd.DataFrame,
    metric: str,
    aggregate: str = "min",
    precision: int = 4,
) -> pd.DataFrame:
    """
    Build the concise summary table: one column per feedback combination.

    The first rows are indicator rows (one per FeedbackType member) marking
    which feedback types are active in each column. They are followed by one
    row per environment holding the best metric value for that combination.
    Only single, pair and "all types" combinations that actually have valid
    data are shown.

    "Best" is always the minimum over per-``config_hash`` means of the
    aggregated metric (or the plain mean when no ``config_hash`` column is
    present), i.e. lower is treated as better.

    Args:
        df: DataFrame with experiment data from load_experiment_data()
        metric: Metric name (e.g., "regret", "epic_distance")
        aggregate: Aggregation method over epochs ("min", "max", "mean", etc.)
        precision: Decimal places for metric values

    Returns:
        DataFrame indexed by "Row" with string cells; an empty DataFrame when
        the input is empty or has no "config.env_id" column.
    """
    if df.empty:
        return pd.DataFrame()

    # Work on a copy so the caller's frame is not given extra columns.
    work = df.copy()
    metric_col = f"{aggregate}_{metric}"
    work[metric_col] = compute_aggregated_metric(work, metric, aggregate)
    work["feedback_combo"] = work.apply(derive_feedback_combination, axis=1)

    # Rows without a metric value cannot contribute to any cell.
    valid = work.dropna(subset=[metric_col])

    env_col = "config.env_id"
    if env_col not in work.columns:
        print("Warning: No env_id found in config")
        return pd.DataFrame()

    # Environments come from all rows, so an environment without valid data
    # still gets a (blank) row.
    environments = sorted(work[env_col].unique())

    all_combinations = _generate_all_feedback_combinations()
    observed = set() if valid.empty else set(valid["feedback_combo"].unique())
    shown_sizes = (1, 2, len(FeedbackType))

    def _is_shown(combo_key: str) -> bool:
        # Singles, pairs and the all-types combination, and only if observed.
        size = len(combo_key.replace("_only", "").split("+"))
        return size in shown_sizes and combo_key in observed

    shown = [entry for entry in all_combinations if _is_shown(entry[2])]

    def _best_for(env_rows: pd.DataFrame, combo_key: str):
        # Best value of one (environment, combination) cell, NaN if no data.
        if env_rows.empty:
            return np.nan
        subset = env_rows[env_rows["feedback_combo"] == combo_key]
        if subset.empty:
            return np.nan
        if "config_hash" not in subset.columns:
            # No config identity available: fall back to the overall mean.
            return subset[metric_col].mean()
        # Average over seeds per configuration, then keep the lowest mean.
        return subset.groupby("config_hash")[metric_col].mean().min()

    best = {}
    for env in environments:
        env_rows = pd.DataFrame() if valid.empty else valid[valid[env_col] == env]
        best[env] = {label: _best_for(env_rows, key) for label, _, key in shown}

    def _fmt(value) -> str:
        return "" if pd.isna(value) else f"{value:.{precision}f}"

    # Indicator rows: a check mark where the feedback type is part of the column.
    indicator_rows = [
        {
            "Row": fb_type.value.capitalize(),
            **{label: ("✓" if active[fb_type] else "") for label, active, _ in shown},
        }
        for fb_type in FeedbackType
    ]

    # Metric rows: one per environment, formatted to the requested precision.
    metric_rows = [
        {"Row": env, **{label: _fmt(best[env][label]) for label, _, _ in shown}}
        for env in environments
    ]

    return pd.DataFrame(indicator_rows + metric_rows).set_index("Row")


def build_exhaustive_table_df(
    df: pd.DataFrame,
    metric: str,
    aggregate: str = "min",
    precision: int = 4,
) -> pd.DataFrame:
    """
    Build the exhaustive table: every configuration, averaged over seeds.

    Rows are grouped by environment (when a "config.env_id" column exists),
    feedback combination, and each per-feedback-type sample-count column that
    is present. Each group reports the mean, standard deviation and count of
    the aggregated metric.

    Args:
        df: DataFrame with experiment data from load_experiment_data()
        metric: Metric name (e.g., "regret", "epic_distance")
        aggregate: Aggregation method over epochs ("min", "max", "mean", etc.)
        precision: Accepted for signature parity with the concise builder;
            this function does not read it (values are returned unrounded).

    Returns:
        DataFrame with one row per group, sorted by environment, feedback
        combination and sample counts; empty if there is no valid data.
    """
    if df.empty:
        return pd.DataFrame()

    # Copy so the derived columns are not added to the caller's frame.
    data = df.copy()
    metric_col = f"{aggregate}_{metric}"
    data[metric_col] = compute_aggregated_metric(data, metric, aggregate)
    data["feedback_combo"] = data.apply(derive_feedback_combination, axis=1)

    data = data.dropna(subset=[metric_col])
    if data.empty:
        return pd.DataFrame()

    # Grouping keys: environment (optional), combination, then sample counts.
    env_col = "config.env_id"
    group_cols = [env_col] if env_col in data.columns else []
    group_cols.append("feedback_combo")
    group_cols.extend(
        key
        for key in (get_feedback_config_key(fb_type) for fb_type in FeedbackType)
        if key in data.columns
    )

    # NOTE(review): pandas groupby drops rows whose key columns are NaN by
    # default — confirm the sample-count columns are always populated.
    stats = data.groupby(group_cols)[metric_col].agg(["mean", "std", "count"])
    stats = stats.reset_index()

    stats = stats.rename(
        columns={
            "config.env_id": "Environment",
            "feedback_combo": "Feedback",
            "config.n_pref_samples": "n_pref",
            "config.n_demo_samples": "n_demo",
            "config.n_rating_samples": "n_rating",
            "config.n_corr_samples": "n_corr",
            "config.n_stop_samples": "n_stop",
            "mean": f"mean_{metric}",
            "std": f"std_{metric}",
            "count": "n_seeds",
        }
    )

    # Sort order: environment (if any), feedback combination, sample counts.
    sample_cols = ("n_pref", "n_demo", "n_rating", "n_corr", "n_stop")
    sort_cols = ["Environment"] if "Environment" in stats.columns else []
    sort_cols.append("Feedback")
    sort_cols.extend(name for name in sample_cols if name in stats.columns)

    return stats.sort_values(sort_cols).reset_index(drop=True)


# =============================================================================
# LaTeX Export Functions
# =============================================================================

def _parse_feedback_combo(combo: str) -> set[str]:
    """
    Parse a feedback_combo string into a set of feedback type names.
    
    Examples:
        "pref_only" -> {"pref"}
        "demo+pref" -> {"demo", "pref"}
        "demo+pref+rating" -> {"demo", "pref", "rating"}
    """
    if combo.endswith("_only"):
        return {combo.replace("_only", "")}
    return set(combo.split("+"))


def _combo_to_latex_icon(combo: str) -> str:
    """
    Map a feedback_combo key to its corresponding LaTeX icon command.
    
    Command names use alphabetical order (e.g., \\fbDemoPref, not \\fbPrefDemo).
    
    Examples:
        "pref_only" -> r"\\fbPref"
        "demo+pref" -> r"\\fbDemoPref"
        "demo+pref+rating+stop" -> r"\\fbPDRS"
    """
    # Parse the combo into list of feedback types
    if combo.endswith("_only"):
        fb_types = [combo.replace("_only", "")]
    else:
        fb_types = combo.split("+")
    
    # Single feedback type: \fbPref, \fbDemo, \fbRating, \fbStop
    if len(fb_types) == 1:
        return f"\\fb{fb_types[0].capitalize()}"
    
    # All four combined: \fbPDRS
    if len(fb_types) == 4:
        return r"\fbPDRS"
    
    # Pairs/triples: alphabetically sorted, capitalized (e.g., \fbDemoPref, \fbDemoRating)
    fb_types_sorted = sorted(fb_types)
    capitalized = [t.capitalize() for t in fb_types_sorted]
    return f"\\fb{''.join(capitalized)}"


def _value_to_color(
    val: float,
    min_val: float,
    max_val: float,
    lower_is_better: bool = True,
    alpha: float = 0.35,
) -> str:
    """
    Convert a value to LaTeX cellcolor using a red-yellow-green scale.
    
    Uses fully saturated colors blended with white to simulate alpha transparency.
    This gives vibrant, clean colors rather than muddy desaturated ones.
    
    Args:
        val: The value to colorize
        min_val: Minimum value in the range
        max_val: Maximum value in the range
        lower_is_better: If True, low values are green (good), high values are red (bad)
        alpha: Simulated transparency 0-1 (lower = more transparent/lighter)
        
    Returns:
        LaTeX cellcolor command string
    """
    # Define pure RGB colors for red, yellow, green
    # Red: (1, 0, 0), Yellow: (1, 1, 0), Green: (0, 0.8, 0)
    colors = {
        "red": (1.0, 0.0, 0.0),
        "yellow": (1.0, 1.0, 0.0),
        "green": (0.0, 0.8, 0.0),
    }
    
    if max_val == min_val:
        # All values are the same, use yellow (middle)
        r, g, b = colors["yellow"]
    else:
        # Normalize to 0-1 range
        t = (val - min_val) / (max_val - min_val)
        
        # For lower_is_better: low t (good) -> green, high t (bad) -> red
        if lower_is_better:
            t = 1.0 - t  # flip so 0 = red, 1 = green
        
        # Interpolate: 0 -> red, 0.5 -> yellow, 1 -> green
        if t < 0.5:
            # Red to yellow
            factor = t * 2  # 0 to 1
            r = colors["red"][0] + factor * (colors["yellow"][0] - colors["red"][0])
            g = colors["red"][1] + factor * (colors["yellow"][1] - colors["red"][1])
            b = colors["red"][2] + factor * (colors["yellow"][2] - colors["red"][2])
        else:
            # Yellow to green
            factor = (t - 0.5) * 2  # 0 to 1
            r = colors["yellow"][0] + factor * (colors["green"][0] - colors["yellow"][0])
            g = colors["yellow"][1] + factor * (colors["green"][1] - colors["yellow"][1])
            b = colors["yellow"][2] + factor * (colors["green"][2] - colors["yellow"][2])
    
    # Blend with white to simulate alpha (color over white background)
    # result = alpha * color + (1 - alpha) * white
    r = alpha * r + (1 - alpha) * 1.0
    g = alpha * g + (1 - alpha) * 1.0
    b = alpha * b + (1 - alpha) * 1.0
    
    return f"\\cellcolor[rgb]{{{r:.3f},{g:.3f},{b:.3f}}}"


def df_to_latex_concise(
    df: pd.DataFrame,
    metric: str = "regret",
    aggregate: str = "min",
    precision: int = 1,
    heatmap: bool = True,
    lower_is_better: bool = True,
    alpha: float = 0.35,
) -> str:
    """
    Convert aggregated DataFrame to LaTeX format with booktabs style.

    Expects a DataFrame with a MultiIndex (env_id, feedback_combo) and a metric
    column, as produced by:
        df.groupby(["config_hash"]).agg({...}).groupby(["config.env_id", "feedback_combo"]).agg({...})

    The table shows a header row with feedback combination icons (e.g., \\fbPref, \\fbDemoPref)
    followed by one row per environment with metric values. Per row, the best
    value is bolded and the second-best value is underlined. Environments whose
    name contains "grid" are listed first, separated from the rest by a rule.
    Only single, pair and four-way combinations are shown.

    Uses a red-yellow-green color scale for the heatmap, normalized per row.

    Args:
        df: DataFrame indexed by a MultiIndex whose last level is the
            feedback_combo key and whose other level is the environment id.
            Only the first column is read (e.g., "min_regret").
        metric: Metric name, used in the table label
        aggregate: Aggregation method, used in the table label
        precision: Decimal places for metric values
        heatmap: If True, color cells based on value (requires xcolor and colortbl packages)
        lower_is_better: If True, low values are good (green), high are bad (red)
        alpha: Color intensity 0-1 (lower = more transparent/lighter colors)

    Returns:
        LaTeX table string (a comment line when ``df`` is empty)
    """
    if df.empty:
        return "% Empty table - no data available\n"

    # Unstack MultiIndex to get environments as rows, feedback_combo as columns
    pivot_df = df.iloc[:, 0].unstack(level=-1)

    # Sort environments: grids first, then the rest, alphabetically within each group
    def env_sort_key(env: str) -> tuple[int, str]:
        return (0 if "grid" in env.lower() else 1, env)

    environments = sorted(pivot_df.index.tolist(), key=env_sort_key)

    # Filter combos: only singles, pairs, and the four-way combination.
    # NOTE(review): the size of the "all combined" column is hard-coded to 4
    # (matching the \fbPDRS icon) rather than len(FeedbackType) — confirm.
    def should_include_combo(combo: str) -> bool:
        return len(_parse_feedback_combo(combo)) in (1, 2, 4)

    # Canonical order: pref, demo, rating, stop
    CANONICAL_ORDER = ["pref", "demo", "rating", "stop"]
    canonical_rank = {name: rank for rank, name in enumerate(CANONICAL_ORDER)}
    unknown_rank = len(CANONICAL_ORDER)

    def combo_sort_key(combo: str) -> tuple[int, tuple[int, ...]]:
        """Sort by number of feedback types, then by canonical order of types."""
        fb_types = _parse_feedback_combo(combo)
        # Types missing from CANONICAL_ORDER sort after the known ones instead
        # of raising ValueError.
        indices = tuple(sorted(canonical_rank.get(t, unknown_rank) for t in fb_types))
        return (len(fb_types), indices)

    feedback_combos = sorted(
        [c for c in pivot_df.columns.tolist() if should_include_combo(c)],
        key=combo_sort_key,
    )

    # Group sizes (singles -> pairs -> all) decide where separators go
    combo_sizes = [len(_parse_feedback_combo(c)) for c in feedback_combos]
    n_singles = sum(1 for s in combo_sizes if s == 1)
    n_pairs = sum(1 for s in combo_sizes if s == 2)

    # Index of the last column of the singles group and of the pairs group
    group_ends = set()
    if n_singles > 0:
        group_ends.add(n_singles - 1)
    if n_pairs > 0:
        group_ends.add(n_singles + n_pairs - 1)

    # Build column spec with vertical separators between groups. No separator
    # is emitted after the final column, so the table never ends in a stray rule.
    n_cols = len(feedback_combos)
    col_spec_parts = ["l"]  # first column for row labels
    for i in range(n_cols):
        col_spec_parts.append("r")
        if i in group_ends and i < n_cols - 1:
            col_spec_parts.append("|")
    col_spec = "".join(col_spec_parts)

    lines = []
    # Empty entries are dropped when the lines are joined at the end.
    lines.append(r"% Requires: \usepackage[table]{xcolor}" if heatmap else "")
    lines.append(r"\begin{table*}[htbp]")
    lines.append(r"\centering")
    lines.append(f"\\label{{tab:{metric}_{aggregate}_concise}}")
    lines.append(f"\\begin{{tabular}}{{{col_spec}}}")
    lines.append(r"\toprule")

    # Header row with feedback combination icons
    icon_values = [_combo_to_latex_icon(combo) for combo in feedback_combos]
    lines.append(" & " + " & ".join(icon_values) + r" \\")

    lines.append(r"\midrule")

    # Separate environments into grids and everything else
    grid_envs = [e for e in environments if "grid" in e.lower()]
    continuous_envs = [e for e in environments if "grid" not in e.lower()]

    def add_env_rows(env_list: list[str]) -> None:
        """Append one formatted table row per environment to ``lines``."""
        for env in env_list:
            row_values = pivot_df.loc[env]
            cells = [row_values.get(combo) for combo in feedback_combos]

            # Heatmap range and ranking use only the displayed, non-missing cells
            numeric_values = [v for v in cells if pd.notna(v)]
            min_val = min(numeric_values) if numeric_values else 0
            max_val = max(numeric_values) if numeric_values else 1

            # Distinct values ordered best-first
            ranked = sorted(set(numeric_values), reverse=not lower_is_better)
            best_value = ranked[0] if len(ranked) >= 1 else None
            second_best_value = ranked[1] if len(ranked) >= 2 else None

            values = []
            for val in cells:
                if pd.isna(val):
                    values.append("")
                    continue

                formatted = f"{val:.{precision}f}"

                # Bold for best, underline for second-best (ties share the style)
                if best_value is not None and abs(val - best_value) < 1e-9:
                    formatted = f"\\textbf{{{formatted}}}"
                elif second_best_value is not None and abs(val - second_best_value) < 1e-9:
                    formatted = f"\\underline{{{formatted}}}"

                if heatmap:
                    cellcolor_cmd = _value_to_color(val, min_val, max_val, lower_is_better, alpha)
                    formatted = f"{cellcolor_cmd}{formatted}"
                values.append(formatted)

            # Escape underscores in environment name for LaTeX
            env_escaped = env.replace("_", r"\_")
            lines.append(f"{env_escaped} & " + " & ".join(values) + r" \\")

    # Add grid environments
    if grid_envs:
        add_env_rows(grid_envs)

    # Add remaining environments (with separator if both groups exist)
    if continuous_envs:
        if grid_envs:
            lines.append(r"\midrule")
        add_env_rows(continuous_envs)

    lines.append(r"\bottomrule")
    lines.append(r"\end{tabular}")
    lines.append(r"\end{table*}")

    # Filter out empty lines
    return "\n".join(line for line in lines if line)


def df_to_latex_exhaustive(df: pd.DataFrame, metric: str, aggregate: str) -> str:
    """
    Convert exhaustive table DataFrame to LaTeX format with booktabs style.

    Every column becomes a left-aligned LaTeX column. Missing values are
    rendered as "-", floats with four decimals, everything else via ``str``.
    Underscores in column names, string cells and the caption are escaped so
    the output compiles in LaTeX text mode. When an "Environment" column is
    present, a rule is inserted whenever its value changes between rows.

    Args:
        df: Exhaustive table DataFrame from build_exhaustive_table_df()
        metric: Metric name for caption and label
        aggregate: Aggregation method for caption and label

    Returns:
        LaTeX table string (a comment line when ``df`` is empty)
    """
    if df.empty:
        return "% Empty table - no data available\n"

    def _tex(text: str) -> str:
        # A bare "_" is a subscript operator in LaTeX and errors in text mode.
        return text.replace("_", r"\_")

    col_spec = "l" * len(df.columns)

    lines = [
        r"\begin{table}[htbp]",
        r"\centering",
        r"\small",
        f"\\caption{{All {_tex(aggregate)} {_tex(metric)} results by configuration}}",
        # Labels are not typeset, so they keep their raw underscores.
        f"\\label{{tab:{metric}_{aggregate}_exhaustive}}",
        f"\\begin{{tabular}}{{{col_spec}}}",
        r"\toprule",
        " & ".join(_tex(str(col)) for col in df.columns) + r" \\",
        r"\midrule",
    ]

    has_env = "Environment" in df.columns
    prev_env = None
    for _, row in df.iterrows():
        # Add midrule between environments for readability
        if has_env:
            curr_env = row.get("Environment", "")
            if prev_env is not None and curr_env != prev_env:
                lines.append(r"\midrule")
            prev_env = curr_env

        values = []
        for val in row:
            if pd.isna(val):
                values.append("-")
            elif isinstance(val, float):
                values.append(f"{val:.4f}")
            else:
                values.append(_tex(str(val)))

        lines.append(" & ".join(values) + r" \\")

    lines.append(r"\bottomrule")
    lines.append(r"\end{tabular}")
    lines.append(r"\end{table}")

    return "\n".join(lines)


def export_table(
    df: pd.DataFrame,
    output_path: Path,
    table_type: str,
    metric: str,
    aggregate: str,
) -> None:
    """
    Export table to CSV or LaTeX format.

    The format is chosen from the (case-insensitive) file suffix. An empty
    table is reported and nothing is written. An unsupported suffix prints an
    error and terminates the process via ``sys.exit(1)``.

    Args:
        df: Table DataFrame to export
        output_path: Output file path (.csv or .tex)
        table_type: "concise" or "exhaustive"; any value other than "concise"
            is handled as exhaustive
        metric: Metric name
        aggregate: Aggregation method
    """
    if df.empty:
        print("No data to export")
        return

    suffix = output_path.suffix.lower()

    if suffix == ".csv":
        # The concise table keeps its row labels in the index, so the index is
        # written; the exhaustive table's index is just a row counter.
        df.to_csv(output_path, index=(table_type == "concise"))
        print(f"Exported {len(df)} rows to {output_path}")

    elif suffix == ".tex":
        if table_type == "concise":
            latex_str = df_to_latex_concise(df, metric, aggregate)
        else:
            latex_str = df_to_latex_exhaustive(df, metric, aggregate)

        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(latex_str)
        print(f"Exported LaTeX table to {output_path}")

    else:
        print(f"Error: Unsupported output format '{suffix}'")
        print("Supported formats: .csv, .tex")
        sys.exit(1)


def cmd_status(args: argparse.Namespace) -> None:
    """
    Show the current status of the experiment queue.

    Prints per-status counts, a progress bar (completed and failed experiments
    both count as done), and, when ``args.verbose`` is set, a preview of up to
    five failed experiments.

    Args:
        args: Parsed CLI arguments; reads ``queue_dir`` and ``verbose``.
    """
    queue = FileTaskQueue(args.queue_dir)
    summary = queue.get_status_summary()

    print(f"\nExperiment Queue Status ({args.queue_dir})")
    print("=" * 40)
    print(f"  Pending:   {summary['pending']:>6}")
    print(f"  Running:   {summary['running']:>6}")
    print(f"  Completed: {summary['completed']:>6}")
    print(f"  Failed:    {summary['failed']:>6}")
    print("-" * 40)
    print(f"  Total:     {summary['total']:>6}")
    print()

    # Show progress bar (skipped for an empty queue to avoid dividing by zero)
    if summary['total'] > 0:
        done = summary['completed'] + summary['failed']
        pct = 100 * done / summary['total']
        bar_len = 30
        filled = int(bar_len * done / summary['total'])
        bar = "█" * filled + "░" * (bar_len - filled)
        print(f"  Progress: [{bar}] {pct:.1f}%")
        print()

    # If there are failed experiments, show some info
    if summary['failed'] > 0 and args.verbose:
        print("Failed experiments:")
        failed = queue.get_all_experiments(ExperimentStatus.FAILED)
        for exp in failed[:5]:  # Show first 5
            if exp.error_message:
                error_preview = exp.error_message[:80]
                # Only mark the preview as cut off when text was actually dropped.
                if len(exp.error_message) > 80:
                    error_preview += "..."
            else:
                error_preview = "No error message"
            print(f"  ID {exp.id}: {error_preview}")
        if len(failed) > 5:
            print(f"  ... and {len(failed) - 5} more")
        print()