#!/usr/bin/env python3
import argparse
import os
import sys
import json
import pandas as pd
import importlib.util
from pathlib import Path
from tqdm import tqdm

# Add src to path.
# project_root is the grandparent of this script's directory, i.e. the script
# is expected to sit two directories below the project root (presumably
# scripts/<subdir>/, judging by the default config path used in load_config).
script_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(script_dir))
src_dir = os.path.join(project_root, 'src')

# Prepend so the project's modules are searched before installed packages;
# the membership check avoids inserting the same entry twice.
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)


def import_module_from_path(module_name, file_path):
    """Import a module from a specific file path and register it in ``sys.modules``.

    The module is registered under ``module_name`` before its code runs (as in
    the importlib "import a source file directly" recipe), so later
    ``import module_name`` / ``from module_name import ...`` statements reuse it.

    Args:
        module_name: Dotted name to register the module under (e.g. ``'src.analysis'``).
        file_path: Path to the ``.py`` file (or a package ``__init__.py``).

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec/loader can be created for ``file_path``.
        Exception: Any error raised while executing the module propagates
            unchanged. In that case ``sys.modules`` is restored to its previous
            state, so a half-initialised module is never left registered.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load spec for {module_name} from {file_path}")
    module = importlib.util.module_from_spec(spec)

    missing = object()
    previous = sys.modules.get(module_name, missing)
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Roll back the registration: a partially executed module must not be
        # handed out by subsequent imports, and any earlier entry is restored.
        if previous is missing:
            sys.modules.pop(module_name, None)
        else:
            sys.modules[module_name] = previous
        raise
    return module


# Import src package and submodules.
# Each module is executed from its file under src_dir and registered in
# sys.modules under its dotted name, so the ``from src... import`` statements
# below bind to these already-loaded modules. The parent package ``src`` is
# loaded first.
# NOTE(review): ``src.core`` itself is never registered here; this presumably
# relies on src/core/utils.py not needing its parent package (e.g. no
# relative imports) — confirm.
import_module_from_path('src', os.path.join(src_dir, '__init__.py'))
import_module_from_path('src.analysis', os.path.join(src_dir, 'analysis', '__init__.py'))
import_module_from_path('src.fitting', os.path.join(src_dir, 'fitting', '__init__.py'))
import_module_from_path('src.core.utils', os.path.join(src_dir, 'core', 'utils.py'))

from src.analysis import MetricsCalculator, DriftPlotter, RadarPlotter
from src.fitting import batch_fit_cognitive_model
from src.core.utils import (
    load_csv_with_validation,
    prepare_for_fitting,
    format_results_report,
    setup_logger,
)

# Names of the scenario-group sub-folders that mark a directory as a run.
# They must stay consistent with the strategy logic that produces the data.
SCENARIO_GROUPS = [
    'Baseline',
    'Optimism',
    'Authority',
    'Threat',
    'Stimulus',
    'Magnitude',
    'Punishment',
    'Regret',
]


def load_config(config_path: "str | None" = None) -> dict:
    """Load the analysis configuration from a JSON file.

    Args:
        config_path: Path to the JSON config. When ``None``, defaults to
            ``<project_root>/scripts/config/analysis.json``.

    Returns:
        The parsed configuration dict. An empty dict (so callers fall back to
        their own defaults) is returned when the file is missing, cannot be
        read, is not valid JSON, or its top-level value is not a JSON object.
    """
    if config_path is None:
        config_path = os.path.join(project_root, 'scripts', 'config', 'analysis.json')

    if not os.path.exists(config_path):
        print(f"⚠️  Config file not found: {config_path}, using defaults")
        return {}

    try:
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"⚠️  Failed to load config: {e}, using defaults")
        return {}

    if not isinstance(config, dict):
        # Every caller uses config.get(...); a list/number root would crash them.
        print(
            f"⚠️  Config root must be a JSON object, got "
            f"{type(config).__name__}: {config_path}, using defaults"
        )
        return {}

    print(f"✓ Loaded configuration from: {config_path}")
    return config


def find_model_runs(base_folder: str, scenario_groups=None) -> dict:
    """
    Discover model runs under ``base_folder``.

    Layouts are checked in this order; the first one that yields runs wins:
    1. Batch directory (name starts with ``batch_`` or looks like a timestamp
       of >= 12 chars): ``batch_{ts}/{model}/{model}/`` (nested) or
       ``batch_{ts}/{model}/`` containing scenario-group folders. The batch
       directory name is used as the run id.
    2. Single run: ``logs/jailbreak/{model}/{run_id}/`` whose children include
       scenario-group folders. Model name comes from the parent directory.
    3. Model directory: ``logs/jailbreak/{model}/`` with timestamp-like
       (``20``/``19``-prefixed) children; each child holding scenario-group
       folders is a run (``checkpoint`` is skipped).
    4. Analysis root: ``logs/analysis/`` whose children are model directories
       (nested ``{model}/{model}/``, run-id sub-folders, or scenario folders).

    Args:
        base_folder: Directory to scan.
        scenario_groups: Folder names that mark a run directory. Defaults to
            the module-level ``SCENARIO_GROUPS``.

    Returns:
        ``{model_name: {run_id: run_path}}``. Empty if nothing matches or
        ``base_folder`` is not an existing directory.
    """
    groups = SCENARIO_GROUPS if scenario_groups is None else scenario_groups
    models = {}
    base_path = Path(base_folder)

    if not base_path.is_dir():
        return models

    # Names are taken from the absolute (but not symlink-resolved) path so that
    # relative inputs such as "." still yield the real directory names.
    abs_path = Path(os.path.abspath(base_folder))
    base_name = abs_path.name

    def looks_like_run_id(name):
        # Run ids are timestamp-like directory names.
        return name.startswith(('20', '19'))

    def has_groups(path):
        return any((path / g).exists() for g in groups)

    def subdirs_of(path):
        # Sorted so discovery (and processing) order is deterministic.
        return sorted(d for d in path.iterdir() if d.is_dir())

    first_level_dirs = subdirs_of(base_path)

    # Pattern 1: base_path itself is a batch / timestamp directory.
    is_batch_dir = base_name.startswith('batch_') or (
        looks_like_run_id(base_name) and len(base_name) >= 12
    )
    if is_batch_dir:
        for model_dir in first_level_dirs:
            if model_dir.name in ('checkpoint', 'summary'):
                continue
            nested_model = model_dir / model_dir.name
            if nested_model.is_dir():
                run_path = nested_model          # batch_{ts}/{model}/{model}/
            elif has_groups(model_dir):
                run_path = model_dir             # batch_{ts}/{model}/
            else:
                continue
            models.setdefault(model_dir.name, {})[base_name] = str(run_path)
        if models:
            return models

    # Pattern 2: base_path is a single run containing scenario-group folders.
    if any(d.name in groups for d in first_level_dirs):
        return {abs_path.parent.name: {base_name: str(base_path)}}

    # Pattern 3: base_path is a model directory holding run-id directories.
    if any(looks_like_run_id(d.name) for d in first_level_dirs):
        runs = {
            run_dir.name: str(run_dir)
            for run_dir in first_level_dirs
            if run_dir.name != 'checkpoint' and has_groups(run_dir)
        }
        return {base_name: runs} if runs else {}

    # Pattern 4: base_path holds one directory per model.
    for model_dir in first_level_dirs:
        if model_dir.name in ('checkpoint', 'summary'):
            continue

        nested_model = model_dir / model_dir.name
        if nested_model.is_dir():
            runs = {model_dir.name: str(nested_model)}
        else:
            subdirs = subdirs_of(model_dir)
            if any(looks_like_run_id(d.name) for d in subdirs):
                runs = {
                    run_dir.name: str(run_dir)
                    for run_dir in subdirs
                    if run_dir.name != 'checkpoint' and has_groups(run_dir)
                }
            elif any(d.name in groups for d in subdirs):
                runs = {model_dir.name: str(model_dir)}
            else:
                runs = {}

        if runs:
            models[model_dir.name] = runs

    return models


def _load_run_records(run_path: str) -> pd.DataFrame:
    """Concatenate every valid scenario CSV found under ``run_path``.

    Reads ``run_path/<group>/*.csv`` for each group in ``SCENARIO_GROUPS`` via
    ``load_csv_with_validation``; files for which it returns ``None`` are
    skipped. Files are read in sorted order so row order is deterministic.

    Returns:
        One DataFrame with a fresh index, or an empty DataFrame if no usable
        CSV was found.
    """
    frames = []
    for group in SCENARIO_GROUPS:
        group_path = os.path.join(run_path, group)
        if not os.path.isdir(group_path):
            continue
        for name in sorted(os.listdir(group_path)):
            if not name.endswith('.csv'):
                continue
            df = load_csv_with_validation(os.path.join(group_path, name))
            if df is not None:
                frames.append(df)
    if not frames:
        return pd.DataFrame()
    # A single concat instead of concatenating inside the loop (quadratic).
    return pd.concat(frames, ignore_index=True)


def analyze_single_run(
    run_path: str,
    run_id: str,
    model_name: str,
    output_base: str,
    extended: bool,
    config: dict,
    logger,
    run_metrics: bool,
    run_fit: bool,
    run_drift: bool,
) -> dict:
    """Run the requested analyses for one run and save their outputs.

    Metrics CSVs and ``cognitive_report.csv`` are written to
    ``<output_base>/<model_name>/<run_id>/``. The drift plot goes to
    ``<output_base>/<model_name>/visualizations/drift_<run_id>.png``. Each step
    logs and swallows its own exception, so one failing step does not stop the
    others.

    Args:
        run_path: Run directory containing scenario-group folders.
        run_id: Run identifier, used in output paths and log messages.
        model_name: Model identifier, used in output paths and log messages.
        output_base: Root output directory.
        extended: Forwarded to ``batch_fit_cognitive_model`` as ``extended_mode``.
        config: Analysis config. ``parameter_fitting.data_filters`` is passed to
            ``prepare_for_fitting``, and the whole dict goes to the fitter.
        logger: Logger for progress and error messages.
        run_metrics: Compute and save metrics.
        run_fit: Fit the cognitive model and save the report.
        run_drift: Draw the drift plot. Metrics are computed for this too,
            because the plot is only drawn when metrics were produced.

    Returns:
        A dict with these keys:
        ``model`` and ``run_id``.
        ``metrics``: paths plus the summary DataFrame, or ``None``.
        ``fitting``: report path plus records, or ``None``.
        ``success``: True if metrics or fitting produced output.
    """
    logger.info(f"\n{'=' * 70}")
    logger.info(f"Analyzing: {model_name} / {run_id}")
    logger.info(f"{'=' * 70}")

    results = {
        'model': model_name,
        'run_id': run_id,
        'metrics': None,
        'fitting': None,
        'success': False,
    }
    run_output_dir = os.path.join(output_base, model_name, run_id)

    # 1. Metrics (also required as the gate for the drift plot in step 3).
    if run_metrics or run_drift:
        try:
            df_details, df_summary = MetricsCalculator.calculate_from_folder(run_path)
            if not df_details.empty:
                MetricsCalculator.print_report(df_summary, f"{model_name}/{run_id}")

                os.makedirs(run_output_dir, exist_ok=True)
                details_path = os.path.join(run_output_dir, "metrics_details.csv")
                summary_path = os.path.join(run_output_dir, "metrics_summary.csv")
                df_details.to_csv(details_path, index=False)
                df_summary.to_csv(summary_path, index=False)

                results['metrics'] = {
                    'details_path': details_path,
                    'summary_path': summary_path,
                    'summary': df_summary,
                }
                logger.info(f"  ✓ Metrics saved to: {run_output_dir}")
        except Exception as e:
            logger.error(f"  ✗ Metrics calculation failed: {e}")

    # 2. Parameter fitting
    if run_fit:
        try:
            all_df = _load_run_records(run_path)
            if all_df.empty:
                logger.info("  ⚠️  No CSV data found for parameter fitting")
            else:
                filters = config.get('parameter_fitting', {}).get('data_filters', {})
                all_df = prepare_for_fitting(all_df, filters=filters)

                # Defensive check - the fitter groups by scenario.
                if 'group' not in all_df.columns:
                    raise ValueError("Column 'group' missing after prepare_for_fitting")

                logger.info(
                    f"  Data: {len(all_df)} records, "
                    f"{all_df['group'].nunique()} scenarios"
                )

                results_fit = batch_fit_cognitive_model(
                    all_df,
                    extended_mode=extended,
                    config=config,
                )

                if results_fit:
                    print("\n" + format_results_report(results_fit))

                    records = []
                    for g, r in results_fit.items():
                        rec = {
                            'Group': g,
                            'nll': r.nll,
                            'bic': r.bic,
                            'H': r.entropy,
                            'count': r.count,
                        }
                        rec.update(r.params)
                        records.append(rec)

                    os.makedirs(run_output_dir, exist_ok=True)
                    output_path = os.path.join(run_output_dir, "cognitive_report.csv")
                    pd.DataFrame(records).to_csv(output_path, index=False)

                    results['fitting'] = {
                        'path': output_path,
                        'data': records,
                    }
                    logger.info(f"  ✓ Fitting saved to: {output_path}")
        except Exception as e:
            logger.error(f"  ✗ Parameter fitting failed: {e}")

    # 3. Drift plot
    if run_drift:
        if not results['metrics']:
            logger.info("  ⚠️  Drift plot skipped: no metrics available for this run")
        else:
            try:
                drift_dir = os.path.join(output_base, model_name, 'visualizations')
                os.makedirs(drift_dir, exist_ok=True)

                plotter = DriftPlotter(drift_dir)
                plotter.run(run_path, f"drift_{run_id}.png")
                logger.info("  ✓ Drift plot saved")
            except Exception as e:
                logger.error(f"  ✗ Drift plot failed: {e}")

    results['success'] = bool(results['metrics'] or results['fitting'])
    return results


# Canonical analysis step names, in execution order.
_ALL_STEPS = ('metrics', 'fit', 'drift', 'radar')

# Workflow step names accepted in the config file -> canonical step names.
_CONFIG_STEP_NAMES = {
    'metrics': 'metrics',
    'fitting': 'fit',
    'drift': 'drift',
    'radar': 'radar',
}


def _describe_steps(enabled):
    """Return a short label for a set of enabled steps ('all', 'fit', 'metrics+fit', ...)."""
    if enabled >= set(_ALL_STEPS):
        return 'all'
    return '+'.join(step for step in _ALL_STEPS if step in enabled)


def _settings_from_config(config, config_path, logger):
    """Resolve (folder, output_base, enabled_steps, extended) from a loaded config."""
    io_config = config.get('input_output', {})
    output_base = io_config.get('output_base_dir', 'logs/analysis')
    # Without an explicit input dir, analyze what already lives in the output dir.
    folder = io_config.get('input_base_dir') or output_base

    workflow = config.get('workflow', {})
    steps = workflow.get('steps', ['metrics', 'fitting', 'drift', 'radar'])
    if workflow.get('auto_chain', False):
        enabled = set(_ALL_STEPS)
    else:
        # Only the listed steps run (previously any multi-step list ran everything).
        enabled = set()
        for step in steps:
            canonical = _CONFIG_STEP_NAMES.get(step)
            if canonical is None:
                logger.info(f"⚠️  Unknown workflow step ignored: {step!r}")
            else:
                enabled.add(canonical)
        if not enabled:
            # Keep the historical behaviour of running everything when no
            # recognised step is configured.
            enabled = set(_ALL_STEPS)

    extended = config.get('parameter_fitting', {}).get('extended_mode', False)

    logger.info(f"Using configuration from: {config_path}")
    logger.info(f"Analysis type (from config): {_describe_steps(enabled)}")
    logger.info(f"Workflow steps: {steps}")
    logger.info(f"Extended mode: {extended}")
    return folder, output_base, enabled, extended


def _settings_from_args(args, logger):
    """Resolve (folder, output_base, enabled_steps, extended) from CLI arguments."""
    import datetime

    enabled = set(_ALL_STEPS) if args.type == 'all' else {args.type}
    if args.output_dir:
        output_base = args.output_dir
    else:
        # Fresh timestamped output directory per invocation.
        output_base = os.path.join(
            'logs', 'analysis',
            f"batch_{datetime.datetime.now():%Y%m%d_%H%M%S}"
        )

    logger.info("Using command line arguments")
    logger.info(f"Analysis type: {args.type}")
    logger.info(f"Extended mode: {args.extended}")
    return args.folder, output_base, enabled, args.extended


def _log_summary(models_data, logger):
    """Log the metrics report and the fitting-report path for every analyzed run."""
    logger.info("\n" + "=" * 70)
    logger.info("SUMMARY REPORT")
    logger.info("=" * 70)
    for model_name, runs in models_data.items():
        for run_id, result in runs.items():
            label = f"{model_name}/{run_id}"
            metrics = result['metrics']
            if metrics is not None and metrics.get('summary') is not None:
                logger.info(f"\n--- {label} ---")
                MetricsCalculator.print_report(metrics['summary'], label)
            if result['fitting'] is not None:
                logger.info(f"\n--- {label} (Fitting) ---")
                logger.info(f"  Cognitive report saved to: {result['fitting']['path']}")


def _generate_radar_plots(args, config, folder, output_base, use_config_file, logger):
    """Draw the radar plots and the optional extra figures from ``folder``."""
    radar_config = config.get('radar_plot', {})
    if not radar_config.get('enabled', True):
        logger.info("⚠️  Radar plots disabled in config")
        return

    # Priority: command-line --output_dir > radar_plot.output_dir > output_base.
    # A relative radar_plot.output_dir is resolved against project_root.
    if not use_config_file and args.output_dir:
        radar_output_dir = args.output_dir
    else:
        radar_output_dir = radar_config.get('output_dir')
        if not radar_output_dir:
            radar_output_dir = output_base
        elif not os.path.isabs(radar_output_dir):
            radar_output_dir = os.path.join(project_root, radar_output_dir)

    additional_plots = radar_config.get('additional_plots', {})
    plotter_config = {
        'layered_mode': radar_config.get('layered_mode', True),
        'layers': radar_config.get('layers', {}),
        'normalization': radar_config.get('normalization', 'minmax'),
        'comparison': radar_config.get('comparison', {}),
        'additional_plots': additional_plots,
    }
    plotter = RadarPlotter(
        radar_output_dir,
        exclude_models=radar_config.get('exclude_models', []),
        config=plotter_config,
    )
    plotter.run(folder)  # Read from input folder

    # Additional style plots; each is on by default unless disabled in config.
    df = plotter.load_and_prep_data(folder)
    if not df.empty:
        perception = additional_plots.get('perception_diverging', {})
        if perception.get('enabled', True):
            plotter.plot_perception_diverging(
                df, perception.get('filename', '1_perception_bias.png'))

        alpha = additional_plots.get('alpha_subplots', {})
        if alpha.get('enabled', True):
            plotter.plot_alpha_subplots(
                df, alpha.get('filename', '2_alpha_asymmetry_subplots.png'))

        risk = additional_plots.get('risk_preference', {})
        if risk.get('enabled', True):
            plotter.plot_risk_preference(
                df, risk.get('filename', '3_risk_preference_subplots.png'))

    logger.info(f"✓ Radar plots saved to: {radar_output_dir}")


def main(args):
    """Run the batch analysis described by ``args``.

    Modes:
    * ``--config`` given: input/output directories, steps and extended mode
      come from the config file (``input_output``, ``workflow``,
      ``parameter_fitting``). With ``workflow.auto_chain`` every step runs;
      otherwise only the steps listed in ``workflow.steps`` run.
    * otherwise: ``--folder``, ``--type``, ``--output_dir`` and ``--extended``
      are used. The default config file (if present) is still loaded for
      fitting filters and radar options.

    Returns:
        Process exit code. 0 on completion. 1 if the given config file or the
        input folder does not exist, or no model runs are found.
    """
    logger = setup_logger("Batch_Analysis")
    use_config_file = args.config is not None

    if use_config_file and not os.path.isfile(args.config):
        # An explicitly requested config must exist: silently falling back to
        # defaults would analyze a different folder with different steps.
        logger.error(f"Config file does not exist: {args.config}")
        return 1

    config = load_config(args.config)
    if use_config_file:
        folder, output_base, enabled, extended = _settings_from_config(
            config, args.config, logger)
    else:
        folder, output_base, enabled, extended = _settings_from_args(args, logger)

    if not os.path.exists(folder):
        logger.error(f"Folder does not exist: {folder}")
        return 1

    os.makedirs(output_base, exist_ok=True)
    logger.info(f"Output directory: {output_base}")

    models = find_model_runs(folder)
    if not models:
        logger.error("No valid model runs found")
        return 1

    # Per-run step flags are the same for every run; compute them once.
    run_metrics = 'metrics' in enabled
    run_fit = 'fit' in enabled
    run_drift = 'drift' in enabled

    models_data = {}
    if run_metrics or run_fit or run_drift:
        for model_name, runs in tqdm(models.items(), desc="Processing models"):
            models_data[model_name] = {}
            for run_id, run_path in tqdm(runs.items(), desc=model_name, leave=False):
                models_data[model_name][run_id] = analyze_single_run(
                    run_path, run_id, model_name,
                    output_base, extended, config, logger,
                    run_metrics, run_fit, run_drift,
                )

    if run_metrics or run_fit:
        _log_summary(models_data, logger)

    if 'radar' in enabled:
        _generate_radar_plots(args, config, folder, output_base, use_config_file, logger)

    logger.info("✅ BATCH ANALYSIS COMPLETE")
    logger.info(f"Results saved to: {output_base}")
    return 0


if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        description=(
            "Batch analysis for agent cognitive attack experiments. "
            "Use --config for config-driven analysis or --folder + --type for command-line mode."
        )
    )

    # Declared in --help order: config-driven mode first, then the
    # backward-compatible command-line options.
    cli_options = (
        ("--config", {
            "help": "Path to analysis.json config file. When provided, uses config settings instead of command-line args",
        }),
        ("--folder", {
            "help": "Input folder with experiment results. Required if --config is not used",
        }),
        ("--type", {
            "choices": ['metrics', 'fit', 'drift', 'radar', 'all'],
            "help": "Analysis type. Required if --config is not used",
        }),
        ("--output_dir", {
            "help": "Output directory (optional)",
        }),
        ("--extended", {
            "action": "store_true",
            "help": "Enable extended fitting mode",
        }),
    )
    for flag, options in cli_options:
        cli.add_argument(flag, **options)

    cli_args = cli.parse_args()

    # Either a config file, or both --folder and --type, must be given.
    if not cli_args.config and not (cli_args.folder and cli_args.type):
        cli.error("Either --config OR both --folder and --type must be provided")

    sys.exit(main(cli_args))
