"""
Common Experiment Runner

Provides a unified training, evaluation, and visualization pipeline
for hyperparameter experiments and ablation studies.
"""

import os
import sys
import json
import torch
import random
import numpy as np
from typing import Dict, List, Optional
from datetime import datetime

# Make the project root importable: three nested dirname() calls on this file's
# absolute path yield the directory two levels above the one containing this file.
# It is inserted at the front of sys.path so the `src.*` imports below resolve.
project_root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
sys.path.insert(0, project_root)

from src.models import *
from src.training import train_model, TrainingConfig, ModelConfig
from src.evaluation import evaluate_model_on_test_set


def set_seed(seed: int = 42):
    """
    Seed every random number generator used by the experiments.

    Seeds Python's ``random`` module, NumPy's global generator, and PyTorch
    (the default generator and all CUDA devices), then switches cuDNN to
    deterministic mode with auto-tuning (benchmark) disabled.

    Args:
        seed: Seed value applied to all generators
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Favor repeatable cuDNN kernels over auto-tuned ones
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def create_model(model_config: ModelConfig, dataset_type: str) -> torch.nn.Module:
    """
    Instantiate the network selected by ``model_config.model_type``.

    The keyword arguments for the model come from ``model_config.to_dict()``.
    Depending on the model family, the dataset's input channel count (and for
    some baselines ``out_channels=1``) is passed in addition; the shallow-water
    specific models and ``FNO_FluxLAP`` are built from the configuration alone.

    Args:
        model_config: Model configuration (provides ``model_type`` and ``to_dict()``)
        dataset_type: Dataset type; determines the number of input channels

    Returns:
        The constructed model

    Raises:
        ValueError: If ``dataset_type`` or ``model_type`` is not recognized
        ImportError: If an optional model class is exported as ``None``
    """
    model_type = model_config.model_type
    config_dict = model_config.to_dict()

    # Input channel count per dataset
    channel_table = (
        ('convection_diffusion', 2),    # c + u
        ('traffic_flow', 2),            # rho + vmax
        ('shallow_water', 3),           # h + mx + my
        ('spinodal_decomposition', 1),  # phi
    )
    matched = [count for name, count in channel_table if dataset_type == name]
    if not matched:
        raise ValueError(f"Unknown dataset_type: {dataset_type}")
    in_channels = matched[0]

    def build_with_inputs(cls):
        # Models that take the dataset's channel count
        return cls(in_channels=in_channels, **config_dict)

    def build_with_io(cls):
        # Baselines that additionally take a single output channel
        return cls(in_channels=in_channels, out_channels=1, **config_dict)

    def build_plain(cls):
        # Models configured entirely through the config dictionary
        return cls(**config_dict)

    def require(cls, label):
        # Optional models are checked against None before construction
        if cls is None:
            raise ImportError(f"{label} not available")
        return cls

    # Class names are referenced only inside the branch that needs them, so a
    # missing or optional class affects nothing but its own model_type.

    # ----- 1D FluxNet models -----
    if model_type == 'FluxNet_N_1D':
        return build_with_inputs(FluxNet_N_1D)
    if model_type == 'FluxNet_P_1D':
        return build_with_inputs(FluxNet_P_1D)
    if model_type == 'FluxNet_L_1D':
        return build_with_inputs(FluxNet_L_1D)
    if model_type == 'FluxNet_D_1D':
        return build_with_inputs(FluxNet_D_1D)
    if model_type == 'FluxNet_U_1D':
        from src.models import FluxNet_U_1D
        return build_with_inputs(FluxNet_U_1D)

    # ----- 2D FluxNet models ('FluxNet_U' is built as FluxNet_N) -----
    if model_type in ('FluxNet_N', 'FluxNet_U'):
        return build_with_inputs(FluxNet_N)
    if model_type == 'FluxNet_P':
        return build_with_inputs(FluxNet_P)
    if model_type == 'FluxNet_L':
        return build_with_inputs(FluxNet_L)
    if model_type == 'FluxNet_D':
        return build_with_inputs(FluxNet_D)

    # ----- Shallow water models -----
    if model_type == 'FluxNet_SW_2D':
        return build_plain(FluxNet_SW_2D)
    if model_type == 'FluxNet_SW_Baseline':
        return build_plain(FluxNet_SW_Baseline)
    if model_type == 'FNO_SW':
        return build_plain(require(FNO_SW, "FNO_SW"))

    # ----- Shallow water baselines with projection -----
    if model_type == 'FNO_SW_Proj':
        from src.models import FNO_SW_Proj
        return build_plain(require(FNO_SW_Proj, "FNO_SW_Proj"))
    if model_type == 'CNN_SW_Proj':
        from src.models import CNN_SW_Proj
        return build_plain(require(CNN_SW_Proj, "CNN_SW_Proj"))

    # ----- 1D FNO models -----
    if model_type == 'FNO_1D':
        from src.models import FNO_1D
        return build_with_io(require(FNO_1D, "FNO_1D"))
    if model_type == 'FNO_FluxD_1D':
        from src.models import FNO_FluxD_1D
        return build_with_inputs(require(FNO_FluxD_1D, "FNO_FluxD_1D"))

    # ----- FNO with FluxLAP head (shallow water ablation) -----
    if model_type == 'FNO_FluxLAP':
        from src.models import FNO_FluxLAP
        return build_plain(require(FNO_FluxLAP, "FNO_FluxLAP"))

    # ----- CNN baselines -----
    if model_type == 'CNN_Baseline_1D':
        return build_with_io(CNN_Baseline_1D)
    if model_type == 'CNN_Baseline_2D':
        return build_with_io(CNN_Baseline_2D)

    raise ValueError(f"Unknown model_type: {model_type}")


def get_experiment_name(model_config: ModelConfig, training_config: TrainingConfig) -> str:
    """
    Build an underscore-separated experiment name from the configurations.

    The name contains, in order: the model type, base channels (``c``), number
    of blocks (``b``), kernel size (``k``), the neighborhood size (``n``, only
    for FluxNet models that are not baselines), the ``ndt`` value, and the
    optional tags ``pf`` (pushforward enabled) and ``soft`` (positive soft
    conservation weight).
    """
    model_type = model_config.model_type
    tokens = [
        model_type,
        f"c{model_config.base_channels}",
        f"b{model_config.num_blocks}",
        f"k{model_config.kernel_size}",
    ]

    # neighborhood_size is read only for non-baseline FluxNet variants
    is_flux_variant = 'FluxNet' in model_type and 'Baseline' not in model_type
    if is_flux_variant:
        tokens.append(f"n{model_config.neighborhood_size}")

    tokens.append(f"ndt{training_config.ndt}")

    optional_tags = (
        ("pf", training_config.use_pushforward),
        ("soft", training_config.soft_conservation_weight > 0),
    )
    tokens.extend(tag for tag, enabled in optional_tags if enabled)

    return "_".join(tokens)


def get_dataset_bounds(dataset_type: str) -> Dict:
    """
    Return the intrinsic value bounds of a dataset.

    The bounds describe the dataset itself, not any model, and are used when
    computing out-of-bounds rates. A bound of ``None`` means "no bound on that
    side"; unknown dataset types get ``None`` for both bounds.

    Returns:
        Dictionary with the keys ``'lower_bound'`` and ``'upper_bound'``
    """
    limits = {
        'convection_diffusion': (0.0, None),    # c >= 0
        'traffic_flow': (0.0, 1.0),             # 0 <= rho <= 1
        'shallow_water': (0.0, None),           # h >= 0 (h field only)
        'spinodal_decomposition': (0.0, 1.0),   # 0 <= phi <= 1
    }
    lower, upper = limits.get(dataset_type, (None, None))
    return {'lower_bound': lower, 'upper_bound': upper}


def run_single_experiment(
    model_config: ModelConfig,
    training_config: TrainingConfig,
    dataset_type: str,
    train_folder: str,
    val_folder: str,
    test_folder: str,
    save_path: str,
    experiment_name: Optional[str] = None,
    gpu_id: int = 0,
    run_training: bool = True,
    run_evaluation: bool = True,
    seed: int = 42,
    evaluate_mode: str = 'both',
    visualize_trajectories: Optional[str] = None
) -> Dict:
    """
    Run one experiment: seed, build the model, optionally train and evaluate.

    A ``config.json`` describing the run is always written to
    ``<save_path>/<experiment_name>``. ``results.json`` and ``results.pkl``
    are written only when evaluation is run.

    Args:
        model_config: Model configuration
        training_config: Training configuration
        dataset_type: Dataset type
        train_folder: Training data directory (absolute path)
        val_folder: Validation data directory (absolute path)
        test_folder: Test data directory (absolute path)
        save_path: Root directory for saving results
        experiment_name: Experiment name (None for auto-generation)
        gpu_id: GPU ID (used only when CUDA is available)
        run_training: Whether to run training
        run_evaluation: Whether to run evaluation
        seed: Random seed
        evaluate_mode: Evaluation mode ('onestep', 'rollout', 'both')
        visualize_trajectories: Visualize trajectories ('all', None, or specific list)

    Returns:
        Dictionary with ``'experiment_name'`` plus a ``'training'`` entry
        (best loss, total time) and/or an ``'evaluation'`` entry (per-mode
        metrics), depending on which stages were run
    """
    set_seed(seed)

    # Fall back to a name derived from the two configurations
    if experiment_name is None:
        experiment_name = get_experiment_name(model_config, training_config)

    wide_rule = "=" * 80
    narrow_rule = "=" * 60

    print(wide_rule)
    print(f"Experiment: {experiment_name}")
    print(f"Dataset: {dataset_type}")
    print(f"Model: {model_config.model_type}")
    print(wide_rule)

    # Pick the requested GPU when CUDA is present, otherwise the CPU
    if torch.cuda.is_available():
        device = torch.device(f"cuda:{gpu_id}")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # Everything produced by this run lives in its own directory
    result_dir = os.path.join(save_path, experiment_name)
    os.makedirs(result_dir, exist_ok=True)
    print(f"Results saved to: {result_dir}")

    model = create_model(model_config, dataset_type)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model parameter count: {param_count:,}")

    # Physical bounds of the dataset, used for out-of-bounds statistics
    dataset_bounds = get_dataset_bounds(dataset_type)

    # Record the full run configuration next to the results
    config_path = os.path.join(result_dir, 'config.json')
    config_data = {
        'experiment_name': experiment_name,
        'dataset_type': dataset_type,
        'model_config': model_config.__dict__,
        'training_config': training_config.to_dict(),
        'seed': seed,
        'param_count': param_count,
        'dataset_bounds': dataset_bounds,
        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }
    with open(config_path, 'w') as f:
        json.dump(config_data, f, indent=2, default=str)

    results = {'experiment_name': experiment_name}

    # ========== Training ==========
    if run_training:
        print(f"\n{narrow_rule}")
        print("Starting training")
        print(narrow_rule)

        train_results = train_model(
            model=model,
            dataset_type=dataset_type,
            train_folder=train_folder,
            val_folder=val_folder,
            result_dir=result_dir,
            config=training_config,
            device=device,
            num_workers=training_config.num_workers
        )

        results['training'] = {
            key: train_results[key] for key in ('best_loss', 'total_time')
        }

    # ========== Evaluation ==========
    if run_evaluation:
        print(f"\n{narrow_rule}")
        print("Starting evaluation")
        print(narrow_rule)

        # Prefer the best checkpoint from the result directory when one exists
        best_model_path = os.path.join(result_dir, 'best_model.pt')
        if not os.path.exists(best_model_path):
            print("Warning: Best model not found, using current model state")
        else:
            model.load_state_dict(torch.load(best_model_path, map_location=device))
            print(f"Loaded best model: {best_model_path}")

        model.eval()

        # Evaluate using dataset intrinsic bounds
        eval_results = evaluate_model_on_test_set(
            model=model,
            test_folder=test_folder,
            dataset_type=dataset_type,
            output_dir=os.path.join(result_dir, 'evaluation'),
            ndt=training_config.ndt,
            lower_bound=dataset_bounds['lower_bound'],
            upper_bound=dataset_bounds['upper_bound'],
            device=device,
            mode=evaluate_mode,
            visualize_trajectories=visualize_trajectories
        )

        # Compact metric name -> key in the evaluation output.
        # NOTE(review): an earlier comment said rollout reports the last-frame
        # error, but the values are read from the '*_overall_*' keys; what those
        # keys contain is decided by evaluate_model_on_test_set - confirm there.
        metric_keys = {
            'mae': 'mae_overall_mean',
            'mae_std': 'mae_overall_std',
            'rmse': 'rmse_overall_mean',
            'cons_drift_mean': 'cons_drift_mean',
            'cons_drift_max': 'cons_drift_max',
            'viol_lower': 'viol_lower_mean',
            'viol_upper': 'viol_upper_mean',
        }
        results['evaluation'] = {
            mode: {short: eval_results[mode][full] for short, full in metric_keys.items()}
            for mode in ('onestep', 'rollout')
            if mode in eval_results
        }

        # Save results only during evaluation - DO NOT CHANGE!!!
        with open(os.path.join(result_dir, 'results.json'), 'w') as f:
            json.dump(results, f, indent=2, default=str)

        # Save compact results for summary
        import joblib
        joblib.dump(results, os.path.join(result_dir, 'results.pkl'))

    print(f"\n{narrow_rule}")
    print(f"Experiment completed: {experiment_name}")
    print(narrow_rule)

    return results


def _load_experiment_result(save_path: str, exp_name: str) -> Optional[Dict]:
    """Load an experiment's ``results.pkl`` and any detailed evaluation summaries.

    Returns ``None`` when ``<save_path>/<exp_name>/results.pkl`` does not exist.
    Otherwise returns the loaded results; for each of the modes ``onestep`` and
    ``rollout`` whose ``test_set_summary_<mode>.json`` exists under the
    experiment's ``evaluation`` directory, its content is attached under the
    key ``'<mode>_detailed'``.
    """
    result_file = os.path.join(save_path, exp_name, 'results.pkl')
    if not os.path.exists(result_file):
        return None

    # Imported only when there is something to load, so a summary over
    # experiments without saved results does not need joblib at all.
    import joblib

    # NOTE: joblib.load unpickles the file - only use with result directories
    # that were produced by trusted runs.
    result = joblib.load(result_file)

    eval_dir = os.path.join(save_path, exp_name, 'evaluation')
    for mode in ('onestep', 'rollout'):
        summary_file = os.path.join(eval_dir, mode, f'test_set_summary_{mode}.json')
        if os.path.exists(summary_file):
            with open(summary_file, 'r') as f:
                result[f'{mode}_detailed'] = json.load(f)

    return result


def generate_summary_table(save_path: str, experiments: List[Dict], output_file: str = "summary.md"):
    """
    Generate experiment summary table (Markdown format)

    For every experiment whose ``results.pkl`` exists under ``save_path``, the
    saved results (and detailed evaluation summaries, when present) are read
    and rendered into these sections:
    - Rollout table: uses the ``*_at_T1.0`` values from the detailed rollout
      summary when available, otherwise the values stored in ``results.pkl``
    - Out-of-bounds statistics (with std) and conditional mean OOB magnitude
    - Separate statistics for the three shallow water fields (only when the
      detailed rollout summary contains them)

    Args:
        save_path: Root directory that contains one sub-directory per experiment
        experiments: Experiment descriptions. Each entry either has a ``'name'``
            key, or has ``'model_config'`` and ``'training_config'`` from which
            the name is derived
        output_file: File name of the Markdown summary, created inside ``save_path``

    Returns:
        None. If no results or no evaluation results are found, a message is
        printed and no file is written.
    """
    results = []
    for exp in experiments:
        # Only derive a name when none is given: dict.get() would evaluate the
        # fallback eagerly and fail for entries that carry just a 'name'.
        if 'name' in exp:
            exp_name = exp['name']
        else:
            exp_name = get_experiment_name(exp['model_config'], exp['training_config'])

        result = _load_experiment_result(save_path, exp_name)
        if result is not None:
            result['config'] = exp
            results.append(result)

    if not results:
        print("No experiment results found")
        return

    valid_results = [r for r in results if 'evaluation' in r]
    if not valid_results:
        print("No valid evaluation results found")
        return

    # Find best results
    rollout_results = [r for r in valid_results if 'rollout' in r['evaluation']]
    best_rollout_mae = min(r.get('rollout_detailed', {}).get('mae_at_T1.0', r['evaluation']['rollout'].get('mae', float('inf')))
                          for r in rollout_results) if rollout_results else None
    # Compare against None explicitly: a best MAE of exactly 0.0 is a valid value
    has_best = best_rollout_mae is not None

    # ========== Generate Markdown content ==========
    generated_at = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    parts = [f"""# Ablation Experiment Summary

Generated: {generated_at}

---

## 1. Rollout Evaluation Results (@T=1.0, Final Timestep)

**Note**: All rollout metrics are values at T=1.0 timestep, not global evolution averages.

| Model | MAE@T=1.0 (mean±std) | RMSE@T=1.0 | Conservation Drift@T=1.0 (mean±std) | Max Conservation Drift |
|-------|---------------------|------------|--------------------------------------|------------------------|
"""]

    for r in valid_results:
        exp_name = r['experiment_name']
        rollout = r['evaluation'].get('rollout', {})
        detailed = r.get('rollout_detailed', {})

        if rollout:
            # Use values at @T=1.0 when the detailed summary provides them
            mae_100 = detailed.get('mae_at_T1.0', rollout.get('mae', 0))
            mae_100_std = detailed.get('mae_at_T1.0_std', rollout.get('mae_std', 0))
            rmse_100 = detailed.get('rmse_at_T1.0', rollout.get('rmse', 0))
            cons_100 = detailed.get('cons_drift_at_T1.0', rollout.get('cons_drift_mean', 0))
            cons_100_std = detailed.get('cons_drift_at_T1.0_std', rollout.get('cons_drift_std', 0))
            cons_max = rollout.get('cons_drift_max', 0)

            mae_str = f"{mae_100:.2e}±{mae_100_std:.2e}"
            # Bold every row that matches the best MAE (within float tolerance)
            if has_best and abs(mae_100 - best_rollout_mae) < 1e-10:
                mae_str = f"**{mae_str}**"

            parts.append(f"| {exp_name} | {mae_str} | {rmse_100:.2e} | {cons_100:.2e}±{cons_100_std:.2e} | {cons_max:.2e} |\n")

    parts.append("""
---

## 2. Out-of-Bounds Statistics (Rollout@T=1.0)

| Model | Lower Bound Violation Rate(mean±std) | Upper Bound Violation Rate(mean±std) | Cond. OOB Mag.(Lower) | Cond. OOB Mag.(Upper) | Prediction Range |
|-------|--------------------------------------|--------------------------------------|------------------------|------------------------|------------------|
""")

    for r in valid_results:
        exp_name = r['experiment_name']
        rollout = r['evaluation'].get('rollout', {})
        detailed = r.get('rollout_detailed', {})

        if rollout:
            # Out-of-bounds rates at @T=1.0 timestep
            viol_l = detailed.get('viol_lower_at_T1.0', rollout.get('viol_lower', 0))
            viol_l_std = detailed.get('viol_lower_at_T1.0_std', rollout.get('viol_lower_std', 0))
            viol_u = detailed.get('viol_upper_at_T1.0', rollout.get('viol_upper', 0))
            viol_u_std = detailed.get('viol_upper_at_T1.0_std', rollout.get('viol_upper_std', 0))

            viol_l_str = f"{viol_l:.2f}%±{viol_l_std:.2f}%"
            viol_u_str = f"{viol_u:.2f}%±{viol_u_std:.2f}%"

            # Conditional out-of-bounds magnitude (if available)
            cond_mag_l = detailed.get('cond_magnitude_lower_mean', 0)
            cond_mag_l_std = detailed.get('cond_magnitude_lower_std', 0)
            cond_mag_u = detailed.get('cond_magnitude_upper_mean', 0)
            cond_mag_u_std = detailed.get('cond_magnitude_upper_std', 0)
            cond_l_str = f"{cond_mag_l:.2e}±{cond_mag_l_std:.2e}" if cond_mag_l > 0 else "N/A"
            cond_u_str = f"{cond_mag_u:.2e}±{cond_mag_u_std:.2e}" if cond_mag_u > 0 else "N/A"

            min_val = detailed.get('min_value_overall', 0)
            max_val = detailed.get('max_value_overall', 1)
            range_str = f"[{min_val:.4f}, {max_val:.4f}]"

            parts.append(f"| {exp_name} | {viol_l_str} | {viol_u_str} | {cond_l_str} | {cond_u_str} | {range_str} |\n")

    # Check if shallow water three-field statistics are available
    has_sw_stats = any(r.get('rollout_detailed', {}).get('sw_mae_h_mean') for r in valid_results)
    if has_sw_stats:
        parts.append("""
---

## 3. Shallow Water Three-Field Statistics (Rollout)

| Model | h Field MAE (mean±std) | mx Field MAE (mean±std) | my Field MAE (mean±std) |
|-------|------------------------|-------------------------|-------------------------|
""")
        for r in valid_results:
            exp_name = r['experiment_name']
            detailed = r.get('rollout_detailed', {})
            if detailed.get('sw_mae_h_mean'):
                h_mae = f"{detailed.get('sw_mae_h_mean', 0):.2e}±{detailed.get('sw_mae_h_std', 0):.2e}"
                mx_mae = f"{detailed.get('sw_mae_mx_mean', 0):.2e}±{detailed.get('sw_mae_mx_std', 0):.2e}"
                my_mae = f"{detailed.get('sw_mae_my_mean', 0):.2e}±{detailed.get('sw_mae_my_std', 0):.2e}"
                parts.append(f"| {exp_name} | {h_mae} | {mx_mae} | {my_mae} |\n")

        parts.append("""
### Shallow Water Conservation Drift

| Model | h Conservation Drift (mean±std) | mx Conservation Drift (mean±std) | my Conservation Drift (mean±std) |
|-------|---------------------------------|----------------------------------|----------------------------------|
""")
        for r in valid_results:
            exp_name = r['experiment_name']
            detailed = r.get('rollout_detailed', {})
            if detailed.get('sw_cons_drift_h_mean'):
                h_cons = f"{detailed.get('sw_cons_drift_h_mean', 0):.2e}±{detailed.get('sw_cons_drift_h_std', 0):.2e}"
                mx_cons = f"{detailed.get('sw_cons_drift_mx_mean', 0):.2e}±{detailed.get('sw_cons_drift_mx_std', 0):.2e}"
                my_cons = f"{detailed.get('sw_cons_drift_my_mean', 0):.2e}±{detailed.get('sw_cons_drift_my_std', 0):.2e}"
                parts.append(f"| {exp_name} | {h_cons} | {mx_cons} | {my_cons} |\n")

        parts.append("""
### Shallow Water h Field Out-of-Bounds Statistics

| Model | h Lower Bound Violation Rate (mean±std) | h Conditional OOB Magnitude (mean±std) |
|-------|----------------------------------------|----------------------------------------|
""")
        for r in valid_results:
            exp_name = r['experiment_name']
            detailed = r.get('rollout_detailed', {})
            if detailed.get('sw_h_viol_rate_mean') is not None:
                h_viol = f"{detailed.get('sw_h_viol_rate_mean', 0):.2f}%±{detailed.get('sw_h_viol_rate_std', 0):.2f}%"
                h_cond = detailed.get('sw_h_cond_mag_mean', 0)
                h_cond_std = detailed.get('sw_h_cond_mag_std', 0)
                h_cond_str = f"{h_cond:.2e}±{h_cond_std:.2e}" if h_cond > 0 else "N/A"
                parts.append(f"| {exp_name} | {h_viol} | {h_cond_str} |\n")

    best_str = f"{best_rollout_mae:.4e}" if has_best else "N/A"
    parts.append(f"""
---

## Notes

- **Bold** indicates the optimal value for that metric
- All rollout metrics are values at **@T=1.0 timestep** (final timestep), not global evolution averages
- Error format: mean±std (across test trajectories)
- Conservation drift: Relative conservation error
- Conditional Mean OOB Magnitude: Average out-of-bounds magnitude calculated only for points that violate bounds, measuring "how far problematic points deviate"

## Best Configuration

- Best Rollout MAE@T=1.0: {best_str}
""")

    # Save
    summary_path = os.path.join(save_path, output_file)
    with open(summary_path, 'w', encoding='utf-8') as f:
        f.write("".join(parts))

    print(f"Summary table saved to: {summary_path}")


if __name__ == "__main__":
    # Running this file directly only prints a usage hint
    for hint in (
        "This is the common experiment runner module",
        "Please create specific experiment scripts to use this module",
    ):
        print(hint)
