import pandas as pd
import json
from pathlib import Path
from rich.console import Console
from rich.table import Table

# Single-pass translation table for LaTeX special characters. Using
# str.translate avoids re-escaping the backslashes that earlier replacements
# would otherwise introduce.
_LATEX_SPECIALS = str.maketrans({
    '\\': r'\textbackslash{}',
    '&': r'\&',
    '%': r'\%',
    '$': r'\$',
    '#': r'\#',
    '_': r'\_',
    '{': r'\{',
    '}': r'\}',
    '~': r'\textasciitilde{}',
    '^': r'\textasciicircum{}',
})


def _latex_escape(text) -> str:
    """Return ``text`` as a string with LaTeX special characters escaped."""
    return str(text).translate(_LATEX_SPECIALS)


def _format_mean_std(mean, std, pm: str = "±") -> str:
    """Format ``mean <pm> std`` with one decimal; tolerate missing values.

    Returns "N/A" when the mean is missing, and only the mean when the
    standard deviation is missing (e.g. a group with a single seed).
    """
    if pd.isna(mean):
        return "N/A"
    if pd.isna(std):
        return f"{mean:.1f}"
    return f"{mean:.1f} {pm} {std:.1f}"


def analyze_adaptability(results_dir: str, output_dir: str, subsets: list) -> None:
    """
    Parse all JSON results to generate the 'Adaptability Gap Analysis' table.

    Every ``*_results.json`` file found recursively under ``results_dir`` is
    expected to contain a ``params`` object (keys ``dataset``, ``model``,
    ``strategy``, ``track``, ``seed`` and optionally ``subset``, which defaults
    to ``'Full'``) and a top-level ``final_error_count``. Files that do not
    match this layout are skipped with a warning.

    For each (dataset, model, track, subset) group, the non-random strategy
    with the lowest mean final error is compared against the ``'random'``
    strategy of the same group. The result is printed as a rich table and
    written as a LaTeX table to
    ``<output_dir>/adaptability_gap_table_<subsets>.tex``.

    Args:
        results_dir: Directory searched recursively for ``*_results.json``.
        output_dir: Directory for the ``.tex`` file; created (including
            parents) if it does not exist.
        subsets: Subset names to keep. Values are compared as strings. An
            empty list or ``None`` disables the filter.

    Returns:
        None. If no usable records (or no non-random strategies) are found,
        an error is printed and nothing is written.
    """
    console = Console()
    output_path = Path(output_dir)
    # parents=True so that a nested, not-yet-existing output path works.
    output_path.mkdir(parents=True, exist_ok=True)

    # Normalise the filter once: records store the subset as a string, so
    # compare against strings, and tolerate subsets=None.
    wanted_subsets = [str(s) for s in subsets] if subsets else []
    wanted_lookup = set(wanted_subsets)
    subset_label = ', '.join(wanted_subsets) if wanted_subsets else 'all'
    file_tag = '_'.join(wanted_subsets) if wanted_subsets else 'all'

    # 1. Load data
    all_records = []
    for filepath in Path(results_dir).rglob('*_results.json'):
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
            params = data['params']
            subset = str(params.get('subset', 'Full'))
            if wanted_lookup and subset not in wanted_lookup:
                continue

            all_records.append({
                'dataset': params['dataset'],
                'model': params['model'],
                'strategy': params['strategy'],
                'track': params['track'],
                'subset': subset,
                'seed': params['seed'],
                'final_error': data['final_error_count']
            })
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError;
        # TypeError/AttributeError cover JSON whose top level or 'params'
        # is not an object.
        except (KeyError, TypeError, AttributeError, ValueError):
            console.print(f"[yellow]Warning: Skipping malformed file {filepath.name}[/yellow]")

    if not all_records:
        console.print("[red]Error: No valid records found. Check paths and filters.[/red]")
        return

    df = pd.DataFrame(all_records)

    # 2. Process data: mean/std of the final error across seeds.
    group_keys = ['dataset', 'model', 'track', 'subset']
    summary = df.groupby(group_keys + ['strategy']).agg(
        mean_error=('final_error', 'mean'),
        std_error=('final_error', 'std')
    ).reset_index()

    is_random = summary['strategy'] == 'random'

    # Find best strategy (non-random) for each group. Rows without a usable
    # mean are dropped so idxmin always yields a valid label.
    active_strategies = summary[~is_random].dropna(subset=['mean_error'])
    if active_strategies.empty:
        console.print("[red]Error: No non-random strategy results found; nothing to compare.[/red]")
        return
    best_rows = active_strategies.loc[
        active_strategies.groupby(group_keys)['mean_error'].idxmin()
    ]

    # Get random baseline
    random_baseline = summary[is_random].set_index(group_keys)

    # 3. Join and calculate improvement. Groups without a random baseline
    # get NaN in the *_random columns from the left join.
    report_df = best_rows.set_index(group_keys).join(
        random_baseline[['mean_error', 'std_error']], rsuffix='_random'
    )
    report_df = report_df.rename(
        columns={'mean_error': 'best_mean_error', 'std_error': 'best_std_error'}
    )

    # A zero baseline would give +/-inf; treat it as "not computable" instead.
    random_mean = report_df['mean_error_random']
    nonzero_random = random_mean.where(random_mean != 0)
    report_df['improvement_pct'] = (
        (nonzero_random - report_df['best_mean_error']) / nonzero_random
    ) * 100

    # 4. Generate Console Table
    console_table = Table(title=f"Adaptability Gap Analysis (Subsets: {subset_label})")
    for col in ["Dataset", "Model", "Track", "Subset", "Best Strategy", "Best Error", "Random Error", "% Improv."]:
        console_table.add_column(col)

    for (dataset, model, track, subset), row in report_df.iterrows():
        improvement = row['improvement_pct']
        # rich only accepts str/renderables as cells, so stringify the keys.
        console_table.add_row(
            str(dataset),
            str(model),
            str(track),
            str(subset),
            str(row['strategy']),
            _format_mean_std(row['best_mean_error'], row['best_std_error']),
            _format_mean_std(row['mean_error_random'], row['std_error_random']),
            f"{improvement:.1f}%" if pd.notna(improvement) else "N/A",
        )
    console.print(console_table)

    # 5. Generate LaTeX Table
    text_columns = ['dataset', 'model', 'track', 'subset', 'strategy']
    latex_df = report_df.reset_index()[
        text_columns + ['best_mean_error', 'best_std_error', 'mean_error_random', 'improvement_pct']
    ].copy()

    # to_latex runs with escape=False (cells contain $\pm$), so free-text
    # cells such as names with underscores must be escaped here.
    for col in text_columns:
        latex_df[col] = latex_df[col].map(_latex_escape)

    latex_df['best_mean_error'] = [
        _format_mean_std(mean, std, pm=r"$\pm$")
        for mean, std in zip(latex_df['best_mean_error'], latex_df['best_std_error'])
    ]
    latex_df = latex_df.drop(columns=['best_std_error'])
    latex_df.columns = ["Dataset", "Model", "Track", "Subset", "Best Strategy", "Mean Error (Best)", "Mean Error (Random)", r"\% Improv."]

    latex_string = latex_df.to_latex(index=False, escape=False, float_format="%.1f", na_rep="-", column_format='llcccrrr')

    latex_path = output_path / f"adaptability_gap_table_{file_tag}.tex"
    with open(latex_path, 'w', encoding='utf-8') as f:
        f.write("% Auto-generated by analyze_adaptability.py\n")
        f.write(latex_string)

    console.print(f"\n[green]LaTeX table saved to {latex_path}[/green]")
    console.print(latex_string)


if __name__ == '__main__':
    # Command-line entry point: read the results directory, output directory
    # and subset filter from argv, then run the analysis.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument(
        "results_dir",
        help="Directory with JSON results.",
    )
    cli.add_argument(
        "--output_dir",
        default="analysis_output_new",
        help="Directory for output.",
    )
    cli.add_argument(
        "--subsets",
        nargs="+",
        default=["300", "Full"],
        help="Subsets to analyze.",
    )
    options = cli.parse_args()

    analyze_adaptability(
        results_dir=options.results_dir,
        output_dir=options.output_dir,
        subsets=options.subsets,
    )