import os
import sys
import json
import argparse
import time
from datetime import datetime
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.multiprocessing as mp
from typing import List, Dict, Tuple, Optional

# Put the parent of this file's directory at the front of sys.path before the
# `exp7.*` imports below run.
# NOTE(review): presumably this lets the script be launched directly without
# installing the `exp7` package — confirm against the repository layout.
sys.path.insert(0, str(Path(__file__).parent.parent))

from exp7.config import ExperimentConfig, SweepConfig
from exp7.train import run_single_experiment
from exp7.evaluate import evaluate_model, generate_summary_statistics


def _experiment_identity(config: ExperimentConfig) -> Dict:
    """Return the fields that identify one sweep point in a result record."""
    return {
        "omega_deg": config.omega_deg,
        "bc_type": config.bc_type,
        "method": config.method,
        "seed": config.seed,
        "true_mu": config.true_mu,
    }


def run_experiment_worker(
    config: ExperimentConfig,
    gpu_id: int,
    verbose: bool = False,
) -> Dict:
    """Train and evaluate a single configuration and return a flat result record.

    The function sets ``config.device`` (``"cuda:<gpu_id>"`` when CUDA is
    available, ``"cpu"`` otherwise), calls ``run_single_experiment`` followed
    by ``evaluate_model``, and merges the returned metrics into the record.

    Any exception raised while selecting the device, training or evaluating is
    converted into a failure record instead of propagating, so one broken
    configuration does not abort a whole sweep.

    Args:
        config: Experiment configuration; its ``device`` attribute is
            overwritten in place.
        gpu_id: CUDA device index to use when CUDA is available.
        verbose: Forwarded to ``run_single_experiment``; when true, the
            traceback of a failed run is also printed.

    Returns:
        A dict with the identifying fields (``omega_deg``, ``bc_type``,
        ``method``, ``seed``, ``true_mu``), the evaluation metrics, a plain
        ``bool`` ``success`` flag (``rel_err_mu < 5.0``) and a ``status``
        string. On failure the metric fields are NaN, ``success`` is False and
        ``status`` is ``"error: <ExceptionType>: <message>"``.
    """
    try:
        # Device selection lives inside the try block so that an invalid
        # gpu_id produces an error record like any other failure.
        if torch.cuda.is_available():
            config.device = f"cuda:{gpu_id}"
            torch.cuda.set_device(gpu_id)
        else:
            config.device = "cpu"

        model, _logs = run_single_experiment(config, verbose=verbose)
        metrics = evaluate_model(model, config)

        return {
            **_experiment_identity(config),
            **metrics,
            # bool(...) keeps the flag a builtin bool even if the metric is a
            # numpy/torch scalar.
            "success": bool(metrics["rel_err_mu"] < 5.0),
            "status": "completed",
        }

    except Exception as exc:
        if verbose:
            import traceback

            traceback.print_exc()

        nan = float("nan")
        return {
            **_experiment_identity(config),
            "rel_err_mu": nan,
            "dominant_mu": nan,
            "constraint_violation": nan,
            "solution_l2_rel": nan,
            "success": False,
            # Include the exception type: str(KeyError("x")) alone is just "'x'".
            "status": f"error: {type(exc).__name__}: {exc}",
        }


def run_experiment_wrapper(args):
    """Single-argument adapter around ``run_experiment_worker``.

    Args:
        args: A 3-item sequence ``(config, gpu_id, verbose)``.

    Returns:
        Whatever ``run_experiment_worker`` returns for those arguments.
    """
    (experiment_config, device_index, be_verbose) = args
    return run_experiment_worker(
        config=experiment_config,
        gpu_id=device_index,
        verbose=be_verbose,
    )


def run_sweep_sequential(
    sweep_config: SweepConfig,
    gpu_id: int = 0,
    verbose: bool = True,
) -> List[Dict]:
    """Run every configuration of the sweep one after another in this process.

    Args:
        sweep_config: Object whose ``get_all_configs()`` yields the
            configurations to run.
        gpu_id: Device index handed to ``run_experiment_worker`` for each run.
        verbose: When true, print a header line before each run and a one-line
            outcome afterwards (plus the status text for runs that did not
            complete).

    Returns:
        The result records in the same order as the configurations.
    """
    configs = sweep_config.get_all_configs()
    total = len(configs)
    nan = float("nan")
    results = []

    print(f"Running {total} experiments sequentially on GPU {gpu_id}")

    for position, config in enumerate(configs, start=1):
        if verbose:
            print(
                f"\n[{position}/{total}] ω={config.omega_deg:.1f}° "
                f"BC={config.bc_type} method={config.method} seed={config.seed}"
            )

        result = run_experiment_worker(config, gpu_id, verbose=False)
        results.append(result)

        if verbose:
            mark = "✓" if result["success"] else "✗"
            # The fallbacks must be floats: a string default cannot be
            # rendered with a float format spec and would raise ValueError.
            print(
                f"  {mark} rel_err={result.get('rel_err_mu', nan):.2f}% "
                f"μ_pred={result.get('dominant_mu', nan):.4f} "
                f"μ_true={result['true_mu']:.4f}"
            )
            # Surface the reason for a failed run instead of only showing NaNs.
            status = result.get("status", "completed")
            if status != "completed":
                print(f"    {status}")

    return results


def run_sweep_parallel(
    sweep_config: SweepConfig,
    n_gpus: int = 4,
    verbose: bool = True,
) -> List[Dict]:
    """Run the sweep in a pool of spawned worker processes, one per GPU.

    Falls back to ``run_sweep_sequential`` when CUDA is not available.
    Otherwise the requested GPU count is capped at the number of visible CUDA
    devices, each configuration is tagged with a device index in round-robin
    order, and the tasks are mapped over a process pool.

    Args:
        sweep_config: Object whose ``get_all_configs()`` yields the
            configurations to run.
        n_gpus: Maximum number of GPUs (and pool processes) to use.
        verbose: When true, print progress every 10 completed experiments and
            once more at the end.

    Returns:
        The result records in the same order as the configurations
        (``imap`` preserves input order).

    Raises:
        ValueError: If CUDA is available but the effective GPU count is
            below 1.
    """
    configs = sweep_config.get_all_configs()

    if not torch.cuda.is_available():
        print("CUDA not available, falling back to sequential execution on CPU")
        return run_sweep_sequential(sweep_config, gpu_id=0, verbose=verbose)

    requested_gpus = n_gpus
    available_gpus = torch.cuda.device_count()
    n_gpus = min(requested_gpus, available_gpus)
    if n_gpus < 1:
        # Fail early with a clear message rather than a ZeroDivisionError in
        # the round-robin assignment below.
        raise ValueError(
            f"n_gpus must be at least 1 (requested {requested_gpus}, "
            f"{available_gpus} CUDA devices visible)"
        )
    total = len(configs)
    print(f"Running {total} experiments in parallel across {n_gpus} GPUs")

    # NOTE(review): the device index is tied to the task position, not to the
    # pool process that picks the task up, so two processes can end up on the
    # same GPU at the same time. Pinning one GPU per process would need a pool
    # initializer.
    args_list = [(config, i % n_gpus, False) for i, config in enumerate(configs)]

    # Use a local "spawn" context instead of set_start_method(force=True) so
    # the process-wide default start method is left untouched.
    spawn_ctx = mp.get_context("spawn")

    results = []
    with spawn_ctx.Pool(n_gpus) as pool:
        completed = enumerate(pool.imap(run_experiment_wrapper, args_list), start=1)
        for done, result in completed:
            results.append(result)
            if verbose and (done % 10 == 0 or done == total):
                print(f"  Completed {done}/{total} experiments")

    return results


def run_sweep_distributed(
    sweep_config: SweepConfig,
    worker_id: int,
    n_workers: int,
    gpu_id: int,
    verbose: bool = True,
) -> List[Dict]:
    """Run this worker's share of the sweep sequentially.

    Configurations are split by position: worker ``k`` of ``n`` runs the
    configurations at indices ``k, k + n, k + 2n, ...``.

    Args:
        sweep_config: Object whose ``get_all_configs()`` yields the full list
            of configurations.
        worker_id: Zero-based index of this worker; must satisfy
            ``0 <= worker_id < n_workers``.
        n_workers: Total number of workers sharing the sweep; must be >= 1.
        gpu_id: Device index handed to ``run_experiment_worker``.
        verbose: When true, print a header line before each run.

    Returns:
        The result records for this worker's configurations, in order.

    Raises:
        ValueError: If ``n_workers`` is below 1 or ``worker_id`` is out of
            range. Without this check an out-of-range id silently runs
            nothing and ``n_workers == 0`` fails with ZeroDivisionError.
    """
    if n_workers < 1:
        raise ValueError(f"n_workers must be at least 1, got {n_workers}")
    if not 0 <= worker_id < n_workers:
        raise ValueError(
            f"worker_id must be in [0, {n_workers}), got {worker_id}"
        )

    configs = sweep_config.get_all_configs()

    # Strided slice == every config whose index is congruent to worker_id.
    worker_configs = list(configs)[worker_id::n_workers]
    share = len(worker_configs)

    print(
        f"Worker {worker_id}/{n_workers}: Running {share} experiments on GPU {gpu_id}"
    )

    results = []
    for position, config in enumerate(worker_configs, start=1):
        if verbose:
            print(
                f"  [{position}/{share}] ω={config.omega_deg:.1f}° "
                f"BC={config.bc_type} method={config.method}"
            )

        results.append(run_experiment_worker(config, gpu_id, verbose=False))

    return results


def save_results(
    results: List[Dict],
    output_dir: str,
    csv_name: str = "exp7.csv",
    json_name: str = "exp7_stats.json",
) -> Tuple[str, str]:
    """Write the per-experiment table as CSV and the summary statistics as JSON.

    ``output_dir`` is created if needed. The CSV contains only the known
    result columns listed below, in that order, restricted to those actually
    present in ``results``. The JSON file holds whatever
    ``generate_summary_statistics(results)`` returns.

    Args:
        results: Result records, one dict per experiment.
        output_dir: Directory that receives both files.
        csv_name: File name of the CSV table.
        json_name: File name of the JSON statistics.

    Returns:
        ``(csv_path, json_path)`` as strings.
    """
    # Fixed column order for the CSV; columns absent from the data are skipped.
    preferred_order = (
        "omega_deg",
        "bc_type",
        "method",
        "seed",
        "true_mu",
        "dominant_mu",
        "rel_err_mu",
        "abs_err_mu",
        "constraint_violation",
        "solution_l2_rel",
        "bc_dirichlet_err",
        "bc_neumann_err",
        "mode_error",
        "success",
        "status",
    )

    os.makedirs(output_dir, exist_ok=True)

    csv_path = os.path.join(output_dir, csv_name)
    frame = pd.DataFrame(results)
    kept = [name for name in preferred_order if name in frame.columns]
    table = frame[kept]
    table.to_csv(csv_path, index=False)
    print(f"Saved results to {csv_path}")

    summary = generate_summary_statistics(results)
    json_path = os.path.join(output_dir, json_name)
    with open(json_path, "w") as handle:
        json.dump(summary, handle, indent=2)
    print(f"Saved statistics to {json_path}")

    return csv_path, json_path


def print_summary(results: List[Dict]):
    """Print a human-readable summary of a sweep to stdout.

    The first part prints, for each category present in the output of
    ``generate_summary_statistics(results)``, the experiment count, success
    rate and relative-error / constraint-violation statistics. The second part
    breaks the ``method == "constraint"`` results down by boundary-condition
    type and prints the count, success rate and median relative error.

    Failed runs carry NaN as ``rel_err_mu``; those are excluded from the
    per-BC median so that one failure does not turn the median into NaN. If no
    run of a BC type has a usable error, the median is printed as ``nan``.

    Args:
        results: Result records, one dict per experiment.
    """
    summary = generate_summary_statistics(results)

    print("\n" + "=" * 70)
    print("EXPERIMENT 7 SUMMARY")
    print("=" * 70)

    for category in ["all", "naive", "constraint", "convex", "reentrant"]:
        if category not in summary:
            continue
        stats = summary[category]
        print(f"\n{category.upper()} ({stats['n_experiments']} experiments):")
        print(f"  Success rate (@5%): {stats['success_rate']:.1f}%")
        print(
            f"  RelErr μ: mean={stats['rel_err_mu']['mean']:.2f}% "
            f"median={stats['rel_err_mu']['median']:.3f}% "
            f"p90={stats['rel_err_mu']['p90']:.2f}%"
        )
        print(f"  Constraint viol: mean={stats['constraint_violation']['mean']:.2e}")

    print("\n" + "-" * 70)
    print("PER BC TYPE (Constraint-aware method only):")
    constraint_results = [r for r in results if r.get("method") == "constraint"]

    nan = float("nan")
    for bc in ["DD", "NN", "DN", "ND"]:
        bc_results = [r for r in constraint_results if r.get("bc_type") == bc]
        if not bc_results:
            continue

        n_success = sum(1 for r in bc_results if r.get("success"))
        success = n_success / len(bc_results) * 100

        # Drop NaN errors (failed runs) before taking the median.
        usable_errs = [
            err
            for err in (r.get("rel_err_mu", nan) for r in bc_results)
            if not np.isnan(err)
        ]
        median_err = float(np.median(usable_errs)) if usable_errs else nan

        print(
            f"  {bc}: n={len(bc_results)} success={success:.1f}% "
            f"median_err={median_err:.3f}%"
        )


def main():
    """Command-line entry point: parse options, run the sweep, save and print.

    Execution mode is chosen from the arguments:

    * ``--worker-id`` and ``--n-workers`` together: run only this worker's
      shard via ``run_sweep_distributed`` and write per-worker file names.
    * ``--parallel`` with ``--n-gpus`` > 1: run via ``run_sweep_parallel``.
    * otherwise: run via ``run_sweep_sequential`` on ``--gpu-id``.

    Giving only one of ``--worker-id`` / ``--n-workers``, or an out-of-range
    shard, is rejected with a usage error. Previously a lone option was
    ignored and the full sweep ran under the shared output names.
    """
    parser = argparse.ArgumentParser(description="Run Experiment 7 sweep")
    parser.add_argument("--n-omega", type=int, default=30, help="Number of angles")
    parser.add_argument(
        "--omega-min", type=float, default=90, help="Min angle (degrees)"
    )
    parser.add_argument(
        "--omega-max", type=float, default=330, help="Max angle (degrees)"
    )
    parser.add_argument("--seeds", type=int, nargs="+", default=[0, 1, 2], help="Seeds")
    parser.add_argument("--n-gpus", type=int, default=1, help="Number of GPUs")
    parser.add_argument("--output-dir", type=str, default=".", help="Output directory")
    parser.add_argument("--parallel", action="store_true", help="Run in parallel")
    parser.add_argument(
        "--worker-id", type=int, default=None, help="Worker ID for distributed"
    )
    parser.add_argument(
        "--n-workers", type=int, default=None, help="Total workers for distributed"
    )
    parser.add_argument("--gpu-id", type=int, default=0, help="GPU ID")
    parser.add_argument("--verbose", action="store_true", help="Verbose output")
    args = parser.parse_args()

    # Distributed mode needs both options; validate before any work starts.
    distributed = args.worker_id is not None or args.n_workers is not None
    if distributed:
        if args.worker_id is None or args.n_workers is None:
            parser.error("--worker-id and --n-workers must be given together")
        if args.n_workers < 1 or not 0 <= args.worker_id < args.n_workers:
            parser.error(
                "--worker-id must satisfy 0 <= worker-id < n-workers "
                f"(got worker-id={args.worker_id}, n-workers={args.n_workers})"
            )

    sweep_config = SweepConfig(
        n_omega=args.n_omega,
        omega_min_deg=args.omega_min,
        omega_max_deg=args.omega_max,
        seeds=args.seeds,
        output_dir=args.output_dir,
    )

    print("Experiment 7: Large-scale wedge/corner sweep")
    print(f"Total experiments: {sweep_config.total_experiments}")
    print(f"Angles: {args.n_omega} from {args.omega_min}° to {args.omega_max}°")
    print(f"BC types: {sweep_config.bc_types}")
    print(f"Seeds: {args.seeds}")

    # Shared output names; distributed workers override them below so shards
    # do not overwrite each other.
    csv_name = "exp7.csv"
    json_name = "exp7_stats.json"

    # Monotonic clock: unaffected by wall-clock adjustments during long runs.
    start_time = time.perf_counter()

    if distributed:
        results = run_sweep_distributed(
            sweep_config, args.worker_id, args.n_workers, args.gpu_id, args.verbose
        )
        csv_name = f"exp7_worker{args.worker_id}.csv"
        json_name = f"exp7_stats_worker{args.worker_id}.json"
    elif args.parallel and args.n_gpus > 1:
        results = run_sweep_parallel(sweep_config, args.n_gpus, args.verbose)
    else:
        results = run_sweep_sequential(sweep_config, args.gpu_id, args.verbose)

    elapsed = time.perf_counter() - start_time
    print(f"\nCompleted in {elapsed/60:.1f} minutes")

    save_results(results, args.output_dir, csv_name, json_name)

    print_summary(results)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
