import os
import glob
import json
import argparse
import subprocess
import random
import datetime
import re
import math
from collections import defaultdict
from typing import Dict, Any, Optional

import pandas as pd

# Directory containing this script, and its parent, which is treated as the
# project root (local results are stored under <PROJECT_ROOT>/results/...).
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(SCRIPT_DIR)

# Bucket URI used whenever no bucket is passed explicitly; can be overridden
# through the SWEEP_S3_BUCKET environment variable.
DEFAULT_S3_BUCKET = os.environ.get("SWEEP_S3_BUCKET", "s3://scottviteri")
# Matches a trailing "eval_metadata.json" or "eval_metadata_stride<N>.json",
# capturing <N>. NOTE(review): not referenced by the code visible in this file.
METADATA_PATTERN = re.compile(r"eval_metadata(?:_stride(\d+))?\.json$")
# NOTE(review): nothing visible in this file reads or sets this flag;
# presumably a one-shot warning guard - confirm before removing.
_S3_WARNING_PRINTED = False
# Desired sample count for wiki tasks when --num_samples is not given (see main()).
DEFAULT_WIKI_NUM_SAMPLES = 1024
# main() scans SUPPORTED_TASKS plus WIKI_TASKS when --task_type is omitted;
# membership in WIKI_TASKS also switches off binomial confidence intervals.
SUPPORTED_TASKS = ["gsm8k", "mmlu", "arc", "svamp", "aqua", "mathqa", "arithmetic"]
WIKI_TASKS = ["wiki_continuation"]
# Ground-truth test/validation set sizes pulled from the scripted dataset loaders.
# These reflect the exact number of evaluation examples consumed by src/evaluation.py
# when --num_samples is omitted, as of 2025-11-22.
# main() uses these as the desired sample count for non-wiki tasks: adapters whose
# stored metadata reports fewer samples are re-evaluated.
TASK_TEST_SET_SIZES = {
    "gsm8k": 1319,       # openai/gsm8k (test split)
    "mmlu": 1531,        # cais/mmlu (validation split)
    "arc": 294,          # ai2_arc ARC-Challenge (validation split, filtered to A-D choices)
    "svamp": 300,        # SVAMP test split
    "aqua": 254,         # AQuA-RAT test split
    "mathqa": 2985,      # MathQA test split
    "arithmetic": 200,   # synthetic evaluation set generated in evaluation.py (chunk_size default)
}


def _binomial_ci(p: float, n: Optional[int], z: float = 1.96) -> Optional[tuple[float, float]]:
    """
    Compute a normal-approximation binomial confidence interval for accuracy.

    Args:
        p: Observed proportion. Values outside [0, 1] are clamped into range.
        n: Number of trials behind ``p``.
        z: Multiplier applied to the standard error (1.96 by default).

    Returns:
        (lower, upper), each clipped to [0, 1], or None if n is
        missing/invalid or p is NaN / not convertible to a float.
    """
    if n is None or n <= 0:
        return None
    try:
        p_value = float(p)
    except (TypeError, ValueError):
        return None
    # NaN must be rejected explicitly: min(1.0, nan) evaluates to 1.0, so the
    # clamp below would silently turn a missing accuracy into a (1.0, 1.0) interval.
    if math.isnan(p_value):
        return None
    # Clamp p into [0,1] to avoid numerical issues from bad metadata
    p_clamped = max(0.0, min(1.0, p_value))
    se = math.sqrt(p_clamped * (1.0 - p_clamped) / n)
    lo = max(0.0, p_clamped - z * se)
    hi = min(1.0, p_clamped + z * se)
    return lo, hi


def safe_relpath(path, base_dir):
    """
    Express ``path`` relative to ``base_dir`` when it lies inside that directory.

    Returns ``path`` unchanged when ``base_dir`` is empty, when ``path`` is not
    located under ``base_dir``, or when the two paths cannot be compared.
    """
    if base_dir:
        try:
            root = os.path.abspath(base_dir)
            target = os.path.abspath(path)
            # The path is inside the base directory exactly when their common
            # ancestor is the base directory itself.
            if os.path.commonpath([target, root]) == root:
                return os.path.relpath(target, root)
        except ValueError:
            # Raised when the paths have no common ancestor to compare against;
            # fall back to the untouched input.
            pass
    return path


def _s3_uri_for_path(path, project_root, bucket):
    """
    Build the S3 URI that mirrors a local path underneath ``bucket``.

    The key is ``path`` relative to ``project_root`` (when it lies inside it),
    with backslashes converted to forward slashes.
    """
    key = safe_relpath(path, project_root).replace("\\", "/")
    base = bucket.rstrip("/")
    uri = f"{base}/{key}"
    # Stripping trailing slashes from a bare "s3://" leaves "s3:", which would
    # produce a single-slash scheme; restore the double slash in that case.
    single_slash_scheme = uri.startswith("s3:/") and not uri.startswith("s3://")
    return uri.replace("s3:/", "s3://", 1) if single_slash_scheme else uri


def upload_adapter_metadata(adapter_dir, project_root, bucket=None):
    """
    Sync evaluation metadata/result files from a local adapter directory to S3.

    Only ``eval_metadata*.json`` and ``eval_results*.jsonl`` are transferred.
    Falls back to DEFAULT_S3_BUCKET when ``bucket`` is empty and does nothing
    if no bucket is available. Failures are reported as a warning, not raised.
    """
    target_bucket = bucket or DEFAULT_S3_BUCKET
    if not target_bucket:
        return

    s3_dest = _s3_uri_for_path(adapter_dir, project_root, target_bucket)
    sync_cmd = [
        "aws", "s3", "sync", adapter_dir, s3_dest,
        # Exclude everything first, then re-include only the evaluation artifacts.
        "--exclude", "*",
        "--include", "eval_metadata*.json",
        "--include", "eval_results*.jsonl",
    ]

    print(f"Uploading metadata for {os.path.basename(adapter_dir)} to S3...")
    try:
        subprocess.run(sync_cmd, check=True)
    except Exception as e:
        # Best effort: an upload failure must not abort the sweep.
        print(f"Warning: failed to upload metadata to {s3_dest}: {e}")


def list_s3_runs(dataset, s3_results_prefix, method_filter=None):
    """
    Discover run directories for ``dataset`` by listing ``<prefix>/<dataset>/`` on S3.

    Run names are expected to look like ``{dataset}_{method}_{timestamp}``, where
    the last two underscore-separated tokens are dropped to obtain the method.
    Directories named ``baseline`` or ``baseline_*`` are always reported with
    method ``"baseline"`` (the method filter is not applied to them).

    Returns a list of (dataset, method, run_name) tuples; an empty list when the
    prefix is empty or the listing command fails.
    """
    if not s3_results_prefix:
        return []

    listing_uri = f"{s3_results_prefix.rstrip('/')}/{dataset}/"
    try:
        listing = subprocess.run(
            ["aws", "s3", "ls", listing_uri],
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError:
        return []

    wanted = method_filter.lower() if method_filter else None
    run_prefix = dataset + "_"
    found = []

    for raw in listing.stdout.splitlines():
        entry = raw.strip()
        # Sub-directories are reported by `aws s3 ls` as "PRE <name>/".
        if not entry.startswith("PRE "):
            continue
        run_name = entry[4:].strip("/")
        if not run_name:
            continue

        if run_name == "baseline" or run_name.startswith("baseline_"):
            found.append((dataset, "baseline", run_name))
            continue

        # Only runs belonging to this dataset, with at least three
        # underscore-separated tokens in the full name, are considered.
        if not run_name.startswith(run_prefix) or len(run_name.split("_")) < 3:
            continue

        # Everything between the dataset prefix and the two trailing tokens.
        method = "_".join(run_name[len(run_prefix):].split("_")[:-2])

        if wanted and wanted not in method.lower():
            continue

        found.append((dataset, method, run_name))

    return found


def list_s3_adapters(dataset, run_name, s3_results_prefix):
    """
    List adapter directories for a run on S3.

    Lists ``<prefix>/<dataset>/<run_name>/`` and keeps the sub-directories whose
    name starts with ``adapter_``.

    Returns:
        Adapter names (e.g. 'adapter_50'). Names with a numeric suffix come
        first, ordered by that number; any remaining names follow in
        alphabetical order. An empty list is returned if the listing fails.
    """
    prefix = s3_results_prefix.rstrip("/")
    s3_path = f"{prefix}/{dataset}/{run_name}/"
    try:
        result = subprocess.run(
            ["aws", "s3", "ls", s3_path],
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError:
        return []

    adapters = []
    for line in result.stdout.splitlines():
        line = line.strip()
        if not line.startswith("PRE adapter_"):
            continue
        adapters.append(line[4:].strip("/"))

    def _sort_key(name):
        # Always return a tuple of the same shape: comparing a bare int key
        # against a bare str key raises TypeError as soon as a listing mixes
        # numeric and non-numeric suffixes (e.g. adapter_50 and adapter_final).
        suffix = name.split("_")[-1]
        if suffix.isdigit():
            return (0, int(suffix), "")
        return (1, 0, name)

    return sorted(adapters, key=_sort_key)


def check_s3_metadata(dataset, run_name, adapter_name, s3_results_prefix):
    """
    Report whether an ``eval_metadata*.json`` file is present on S3.

    The adapter directory is listed when ``adapter_name`` is non-empty,
    otherwise the run directory itself is listed.

    Returns (has_metadata, None). The second element is a placeholder for the
    metadata content and is always None in this implementation. A failed
    listing is reported as (False, None).
    """
    run_uri = f"{s3_results_prefix.rstrip('/')}/{dataset}/{run_name}/"
    listing_uri = f"{run_uri}{adapter_name}/" if adapter_name else run_uri

    try:
        listing = subprocess.run(
            ["aws", "s3", "ls", listing_uri],
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError:
        return False, None

    # A listing line counts when it mentions eval_metadata and ends in .json.
    found = any(
        "eval_metadata" in entry and entry.endswith(".json")
        for entry in listing.stdout.splitlines()
    )
    return found, None


def download_adapter_metadata(dataset, run_name, adapter_name, project_root, s3_results_prefix):
    """
    Sync only the evaluation artifacts of an adapter from S3 to the local results tree.

    Files matching ``eval_metadata*.json`` and ``eval_results*.jsonl`` are copied
    into ``<project_root>/results/<dataset>/<run_name>/<adapter_name>``. With an
    empty ``adapter_name`` the run directory itself is synced. A failed sync
    prints a warning instead of raising.
    """
    remote_dir = f"{s3_results_prefix.rstrip('/')}/{dataset}/{run_name}/"
    if adapter_name:
        remote_dir += f"{adapter_name}/"

    local_dir = os.path.join(project_root, "results", dataset, run_name, adapter_name)
    os.makedirs(local_dir, exist_ok=True)

    sync_cmd = [
        "aws", "s3", "sync", remote_dir, local_dir,
        # Exclude everything first, then re-include only the evaluation artifacts.
        "--exclude", "*",
        "--include", "eval_metadata*.json",
        "--include", "eval_results*.jsonl",
    ]

    try:
        subprocess.run(sync_cmd, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        print(f"Warning: failed to download metadata from {remote_dir}: {e}")


def download_adapter_weights(dataset, run_name, adapter_name, project_root, s3_results_prefix):
    """
    Sync the full contents of an adapter directory from S3 to the local results tree.

    With an empty ``adapter_name`` the run directory itself is synced.

    Returns True when the sync command succeeds, False (after printing an
    error) when it exits with a non-zero status.
    """
    remote_dir = f"{s3_results_prefix.rstrip('/')}/{dataset}/{run_name}/"
    if adapter_name:
        remote_dir += f"{adapter_name}/"
    local_dir = os.path.join(project_root, "results", dataset, run_name, adapter_name)

    print(f"Syncing weights for {adapter_name} from S3...")
    try:
        subprocess.run(["aws", "s3", "sync", remote_dir, local_dir], check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error downloading weights from {remote_dir}: {e}")
        return False
    return True


def load_local_metadata(adapter_dir):
    """
    Load the preferred evaluation metadata from a local adapter directory.

    Every ``eval_metadata*.json`` file in ``adapter_dir`` is parsed and ranked,
    in order of priority, by: lowest stride, highest number of samples,
    highest accuracy. Stride and sample count are read from the file's
    ``evaluation`` block; missing, null or non-numeric values fall back to
    stride 1, 0 samples and accuracy 0.

    Files that cannot be read, are not valid JSON, or do not contain a JSON
    object are skipped.

    Returns:
        The winning metadata dict, with the path it was loaded from stored
        under ``"_metadata_path"``, or None when no usable file exists.
    """
    def _number_or(value, default):
        # Accept real numbers only; NaN is rejected (value != value) because it
        # would make the tuple comparison below meaningless.
        if isinstance(value, (int, float)) and value == value:
            return value
        return default

    best_meta = None
    best_score = None

    # Sorted so that ties resolve the same way regardless of filesystem order.
    for fpath in sorted(glob.glob(os.path.join(adapter_dir, "eval_metadata*.json"))):
        try:
            with open(fpath, "r") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # Unreadable or malformed file (JSON decode errors are ValueErrors).
            continue
        if not isinstance(data, dict):
            continue

        evaluation = data.get("evaluation")
        if not isinstance(evaluation, dict):
            # A null or malformed block is treated like a missing one instead
            # of discarding the whole file.
            evaluation = {}

        stride = _number_or(evaluation.get("stride", 1), 1)
        num_samples = _number_or(evaluation.get("num_samples", 0), 0)
        acc = _number_or(data.get("accuracy", 0), 0)

        # Negated stride so that a smaller stride ranks higher.
        current_score = (-stride, num_samples, acc)

        if best_score is None or current_score > best_score:
            best_score = current_score
            best_meta = data
            best_meta["_metadata_path"] = fpath

    return best_meta


def _parse_metadata_num_samples(entry: Dict[str, Any]) -> Optional[int]:
    """
    Extract the evaluation sample count from a metadata dict.

    The value is looked up, in order, at ``entry["evaluation"]["num_samples"]``,
    ``entry["num_samples"]`` and ``entry["num_examples"]``; the first one that
    is not None is used. It is converted with ``int()``, falling back to
    ``int(float())`` for strings such as "12.0".

    Returns:
        The sample count as an int, or None when no value is present or it
        cannot be converted (non-numeric text, NaN, infinity).
    """
    evaluation_block = entry.get("evaluation")
    if not isinstance(evaluation_block, dict):
        # Null or malformed block: rely on the top-level fallbacks only.
        evaluation_block = {}

    # Fallback to top-level keys for older metadata or different formats
    candidates = (
        evaluation_block.get("num_samples"),
        entry.get("num_samples"),
        entry.get("num_examples"),
    )
    value = next((c for c in candidates if c is not None), None)
    if value is None:
        return None

    for convert in (int, lambda v: int(float(v))):
        try:
            return convert(value)
        except (TypeError, ValueError, OverflowError):
            # OverflowError covers infinities, which json.load can produce
            # from the non-standard "Infinity" literal.
            continue
    return None


def get_run_hyperparameters(dataset, run_name, s3_results_prefix, project_root):
    """
    Download a run's ``log.jsonl`` from S3 and return its first line parsed as JSON.

    The log is copied to ``<project_root>/results/<dataset>/<run_name>/log.jsonl``
    (the directory is created if needed). The first record is presumably the
    run's hyperparameter dict - callers look up keys such as "model_type" in it.

    Returns an empty dict if the download, the read, or the JSON parsing fails.
    """
    remote_log = f"{s3_results_prefix.rstrip('/')}/{dataset}/{run_name}/log.jsonl"
    run_dir = os.path.join(project_root, "results", dataset, run_name)
    log_path = os.path.join(run_dir, "log.jsonl")

    os.makedirs(run_dir, exist_ok=True)
    try:
        subprocess.run(
            ["aws", "s3", "cp", remote_log, log_path],
            check=True,
            capture_output=True,
        )
        with open(log_path, 'r') as f:
            first_line = f.readline()
        return json.loads(first_line)
    except Exception:
        # Any failure is reported to the caller as "no hyperparameters".
        return {}


def get_model_type_from_s3(dataset, run_name, s3_results_prefix, project_root):
    """
    Return the ``model_type`` recorded in the run's log.jsonl, fetched fresh from S3.

    Raises:
        FileNotFoundError: if the log could not be obtained or has no "model_type".
    """
    hyperparams = get_run_hyperparameters(dataset, run_name, s3_results_prefix, project_root)
    if "model_type" not in hyperparams:
        raise FileNotFoundError(f"Could not find log.jsonl or model_type for run {run_name}")
    return hyperparams["model_type"]


def evaluate_adapter(dataset, run_name, adapter_name, project_root, args, s3_results_prefix, model_type, extra_args=None):
    """
    Download an adapter, run ``src/evaluation.py`` on it and publish the results.

    Steps:
      1. Sync the adapter directory from S3.
      2. Run the evaluation script as a subprocess from ``project_root``,
         forwarding --force_eval / --num_samples / --stride / --batch_size
         from ``args`` when they are set, followed by ``extra_args``.
      3. Upload the resulting metadata files back to S3.

    Returns the metadata loaded from the local adapter directory, or None if
    the download or the evaluation subprocess failed.
    """
    adapter_dir = os.path.join(project_root, "results", dataset, run_name, adapter_name)

    weights_ready = download_adapter_weights(
        dataset, run_name, adapter_name, project_root, s3_results_prefix
    )
    if not weights_ready:
        return None

    print(f"Evaluating {dataset}/{run_name}/{adapter_name}...")

    cmd = [
        "python", os.path.join(project_root, "src", "evaluation.py"),
        "--task_type", dataset,
        "--model_path", adapter_dir,
        "--model_type", model_type,
    ]
    if args.force_eval:
        cmd.append("--force_eval")

    # Optional numeric flags are forwarded only when they hold a truthy value.
    optional_flags = (
        ("--num_samples", args.num_samples),
        ("--stride", args.stride),
        ("--batch_size", args.batch_size),
    )
    for flag, value in optional_flags:
        if value:
            cmd.extend([flag, str(value)])

    cmd.extend(extra_args or [])

    try:
        subprocess.run(cmd, check=True, cwd=project_root)
    except subprocess.CalledProcessError as e:
        print(f"Evaluation failed: {e}")
        return None

    # Publish the freshly written metadata, then return what is on disk.
    upload_adapter_metadata(adapter_dir, project_root, args.s3_bucket)
    return load_local_metadata(adapter_dir)


def upload_best_adapter_file(run_dir, project_root, bucket=None):
    """
    Sync ``best_adapter.json`` from a local run directory to its S3 counterpart.

    Falls back to DEFAULT_S3_BUCKET when ``bucket`` is empty. Nothing happens
    when no bucket is available or the file does not exist locally. Upload
    failures are printed as a warning rather than raised.
    """
    target_bucket = bucket or DEFAULT_S3_BUCKET
    if not target_bucket or not os.path.exists(os.path.join(run_dir, "best_adapter.json")):
        return

    s3_dest = _s3_uri_for_path(run_dir, project_root, target_bucket)
    sync_cmd = [
        "aws", "s3", "sync", run_dir, s3_dest,
        # Restrict the sync to the single summary file.
        "--exclude", "*",
        "--include", "best_adapter.json",
    ]

    print(f"    Uploading best_adapter.json for {os.path.basename(run_dir)} to {s3_dest}...")
    try:
        subprocess.run(sync_cmd, check=True)
    except Exception as e:
        print(f"    Warning: failed to upload best_adapter.json to {s3_dest}: {e}")


def generate_best_adapter_file(run_dir, best_meta, best_adapter_name):
    """
    Write ``best_adapter.json`` into ``run_dir`` describing the winning adapter.

    The file records the adapter name, selected fields copied from
    ``best_meta`` (accuracy, model path/type, task type, stride, sample count,
    batch index), the metadata file path relative to ``run_dir``, the full
    metadata dict, and a generation timestamp.

    Args:
        run_dir: Local run directory that receives the file.
        best_meta: Metadata dict of the chosen adapter.
        best_adapter_name: Name of the chosen adapter (e.g. "adapter_50").

    Returns:
        True if the file was written; False if the inputs are empty or the
        write failed (an error message is printed in that case).
    """
    if not best_meta or not run_dir:
        return False

    metadata_path = best_meta.get("_metadata_path")
    rel_metadata_path = None
    if metadata_path:
        # Make path relative to run directory
        try:
            rel_metadata_path = os.path.relpath(metadata_path, run_dir)
        except ValueError:
            rel_metadata_path = os.path.basename(metadata_path)

    # Extract useful fields. A null or malformed evaluation block is treated
    # like a missing one instead of raising AttributeError.
    evaluation = best_meta.get("evaluation")
    stride = evaluation.get("stride", 1) if isinstance(evaluation, dict) else 1
    num_examples = _parse_metadata_num_samples(best_meta)
    batch_index = best_meta.get("batch_index")

    # If batch_index missing in top-level, try to parse from adapter name
    if batch_index is None and best_adapter_name and "adapter_" in best_adapter_name:
        try:
            batch_index = int(best_adapter_name.split("_")[-1])
        except ValueError:
            pass

    best_adapter_data = {
        "adapter": best_adapter_name,
        "accuracy": best_meta.get("accuracy"),
        "model_path": best_meta.get("model_path"),
        "model_type": best_meta.get("model_type"),
        "task_type": best_meta.get("task_type"),
        "stride": stride,
        "num_examples": num_examples,
        "batch_index": batch_index,
        "metadata_file": rel_metadata_path,
        "metadata": best_meta,
        "generated_at": datetime.datetime.now().isoformat()
    }

    output_path = os.path.join(run_dir, "best_adapter.json")
    try:
        with open(output_path, "w") as f:
            json.dump(best_adapter_data, f, indent=2)
    except (OSError, TypeError, ValueError) as e:
        print(f"    Error writing best_adapter.json: {e}")
        return False

    # Build the success message outside the try block: formatting a null or
    # non-numeric accuracy with ".2%" used to raise after the file had already
    # been written, which made a successful write report failure.
    accuracy = best_meta.get("accuracy", 0)
    if isinstance(accuracy, (int, float)):
        acc_text = f"{accuracy:.2%}"
    else:
        acc_text = str(accuracy)
    print(f"    Generated best_adapter.json for {best_adapter_name} (acc: {acc_text})")
    return True


def main() -> None:
    """
    CLI entry point: compile sweep results from S3 into a per-task/per-method table.

    For every selected task the runs found under ``<bucket>/results/<task>/`` are
    scanned. Baseline runs are handled on their own (metadata lives in the run
    root); for every other run each adapter is checked, evaluated when its
    metadata is missing or has too few samples (unless --dry_run), and the
    adapter with the highest accuracy is recorded for the run's method.

    Finally a table is printed as markdown and written to
    ``sweep_results_table.csv`` in the current working directory. For non-wiki
    tasks each cell carries a binomial confidence interval when the sample
    count is known.
    """
    parser = argparse.ArgumentParser(description="Compile sweep results from S3")
    parser.add_argument("--task_type", type=str, help="Limit to specific task")
    parser.add_argument("--method", type=str, help="Limit to specific method")
    parser.add_argument("--column", type=str, help="Alias for method")
    parser.add_argument("--s3_bucket", type=str, default=None, help="S3 bucket")
    parser.add_argument("--dry_run", action="store_true", help="Don't actually run evals")
    parser.add_argument("--force_eval", action="store_true", help="Re-run existing evals")
    parser.add_argument("--num_samples", type=int, help="Num samples for eval")
    # NOTE(review): with a default of 1, evaluate_adapter always forwards --stride.
    parser.add_argument("--stride", type=int, default=1)
    parser.add_argument("--batch_size", type=int)
    parser.add_argument("--generate_best_adapter", action="store_true", help="Generate best_adapter.json for each run")
    
    args = parser.parse_args()
    
    # --column takes precedence over --method when both are given.
    method_filter = args.column or args.method
    bucket = args.s3_bucket or DEFAULT_S3_BUCKET
    s3_results_prefix = f"{bucket.rstrip('/')}/results"
    project_root = PROJECT_ROOT
    
    # Determine tasks
    if args.task_type:
        tasks = [args.task_type]
    else:
        tasks = sorted(set(SUPPORTED_TASKS + WIKI_TASKS))
        
    # Shuffle tasks to reduce contention between workers
    random.shuffle(tasks)
        
    # Store per-dataset, per-method summaries as:
    #   results_table[dataset][method] = {
    #       "accuracy": best_accuracy,
    #       "num_samples": effective_sample_count or None,
    #   }
    results_table: Dict[str, Dict[str, Dict[str, Any]]] = defaultdict(dict)

    def update_results_table(
        dataset: str,
        key: str,
        acc: float,
        meta: Optional[Dict[str, Any]],
        is_wiki_task: bool,
    ):
        """
        Update results_table[dataset][key] if this accuracy is better.
        
        For QA tasks we also retain num_samples so we can print binomial CIs.
        For wiki tasks we keep only the point estimate (log-probability units),
        since accuracy is a mean log-prob rather than a Bernoulli rate.
        """
        if meta is None:
            num_samples = None
        else:
            num_samples = None if is_wiki_task else _parse_metadata_num_samples(meta)

        existing = results_table[dataset].get(key)
        if existing is None or acc > existing.get("accuracy", -float("inf")):
            results_table[dataset][key] = {
                "accuracy": float(acc),
                "num_samples": num_samples,
            }
    
    for task in tasks:
        print(f"\nScanning {task}...")
        is_wiki_task = task in WIKI_TASKS
        # Sample count below which existing adapter metadata is considered
        # insufficient: the CLI value if given, else a per-task default.
        desired_task_samples = args.num_samples
        if desired_task_samples is None:
            if is_wiki_task:
                desired_task_samples = DEFAULT_WIKI_NUM_SAMPLES
            else:
                desired_task_samples = TASK_TEST_SET_SIZES.get(task)
        
        runs = list_s3_runs(task, s3_results_prefix, method_filter)
        
        # Also shuffle runs
        random.shuffle(runs)
        
        for dataset, method, run_name in runs:
            print(f"  Checking run: {run_name} ({method})")
            
            if method == "baseline":
                # For baseline, the run_name IS the adapter (kind of)
                # Check if baseline ITSELF has metadata
                has_meta, _ = check_s3_metadata(dataset, run_name, "", s3_results_prefix)
                
                meta = None
                # NOTE(review): unlike the adapter loop below, existing baseline
                # metadata is accepted without the low-sample-count check.
                if has_meta and not args.force_eval:
                     download_adapter_metadata(dataset, run_name, "", project_root, s3_results_prefix)
                     local_path = os.path.join(project_root, "results", dataset, run_name)
                     meta = load_local_metadata(local_path)

                     samples_msg = ""
                     if meta:
                         ns = _parse_metadata_num_samples(meta)
                         if ns is not None:
                             samples_msg = f", samples={ns}"
                     print(f"    Found metadata in {run_name} root{samples_msg}")
                elif not args.dry_run:
                    print(f"    Evaluating baseline (metadata missing)")
                    
                    # Find sibling run to get hyperparameters
                    sibling_params = {}
                    for _, other_method, other_run in runs:
                        if other_method != "baseline":
                            sibling_params = get_run_hyperparameters(dataset, other_run, s3_results_prefix, project_root)
                            if sibling_params:
                                print(f"    Using hyperparameters from sibling: {other_run}")
                                break
                    
                    # Determine model type
                    # The run name wins over the sibling's recorded model_type;
                    # "llama" is the last-resort default.
                    if "qwen" in run_name.lower():
                        model_type = "qwen3-14b" # Default qwen variant
                    elif "llama" in run_name.lower():
                        model_type = "llama"
                    else:
                        model_type = sibling_params.get("model_type", "llama")

                    # Build extra args
                    extra_args = ["--use_base_model"]
                    if "cot_length" in sibling_params:
                        extra_args.extend(["--cot_length", str(sibling_params["cot_length"])])
                    if "temperature" in sibling_params:
                        extra_args.extend(["--temperature", str(sibling_params["temperature"])])
                    
                    # Ensure we write metadata to the baseline directory
                    local_baseline_dir = os.path.join(project_root, "results", dataset, run_name)
                    extra_args.extend(["--adapter_metadata_dir", local_baseline_dir])

                    meta = evaluate_adapter(dataset, run_name, "", project_root, args, s3_results_prefix, model_type, extra_args=extra_args)

                if meta:
                    acc = meta.get("accuracy", 0)
                    # Baselines are keyed by run name (not by method), so each
                    # baseline directory gets its own column.
                    update_results_table(
                        dataset=dataset,
                        key=run_name,
                        acc=acc,
                        meta=meta,
                        is_wiki_task=is_wiki_task,
                    )
                    
                    # The file is written locally even with --dry_run; only the
                    # upload is skipped in that mode.
                    if args.generate_best_adapter:
                        local_run_path = os.path.join(project_root, "results", dataset, run_name)
                        generated = generate_best_adapter_file(local_run_path, meta, "baseline")
                        if generated and not args.dry_run:
                            upload_best_adapter_file(local_run_path, project_root, args.s3_bucket)
                continue

            adapters = list_s3_adapters(dataset, run_name, s3_results_prefix)
            random.shuffle(adapters)
            if not adapters:
                continue
                
            try:
                model_type = get_model_type_from_s3(dataset, run_name, s3_results_prefix, project_root)
            except FileNotFoundError:
                print(f"    Skipping {run_name} (missing log.jsonl)")
                continue
            
            # Track the highest-accuracy adapter of this run.
            best_acc = -float('inf')
            best_meta = None
            best_adapter_name = None
            
            for adapter in adapters:
                # Check if done
                has_meta, _ = check_s3_metadata(dataset, run_name, adapter, s3_results_prefix)
                
                meta = None
                # Check for existing metadata
                if has_meta and not args.force_eval:
                    # Download metadata only
                    download_adapter_metadata(dataset, run_name, adapter, project_root, s3_results_prefix)
                    local_path = os.path.join(project_root, "results", dataset, run_name, adapter)
                    meta = load_local_metadata(local_path)

                    samples_msg = ""
                    if meta:
                        ns = _parse_metadata_num_samples(meta)
                        if ns is not None:
                            samples_msg = f", samples={ns}"
                    # NOTE(review): this is printed before the low-sample check
                    # below, so an adapter can be reported as skipped and then
                    # re-evaluated anyway.
                    print(f"    Skipping {adapter} (metadata found on S3{samples_msg})")
                    
                    # If metadata has low sample count and we want more, force re-eval
                    if meta and desired_task_samples:
                        current_samples = _parse_metadata_num_samples(meta)
                        if current_samples is not None and current_samples < desired_task_samples:
                            print(f"    Re-evaluating {adapter} (low sample count: {current_samples} < {desired_task_samples})")
                            # Force eval by setting meta to None so we fall through to the eval block
                            meta = None 
                            
                # NOTE(review): with --dry_run, metadata discarded above for a low
                # sample count is not restored, so that adapter is left out of the
                # table entirely.
                if meta is None and not args.dry_run:
                    # Needs eval
                    print(f"    Evaluating {adapter} (metadata missing or insufficient samples)")
                    meta = evaluate_adapter(dataset, run_name, adapter, project_root, args, s3_results_prefix, model_type)
                
                if meta:
                    acc = meta.get("accuracy", 0)
                    num_samples = meta.get("evaluation", {}).get("num_samples", 0)
                    if num_samples and num_samples < 100:
                         print(f"    Warning: Low sample count ({num_samples}) for {run_name}/{adapter}")

                    if acc > best_acc:
                        best_acc = acc
                        best_meta = meta
                        best_adapter_name = adapter
                        
            # Record score for this dataset/method
            # Several runs of the same method share one key; update_results_table
            # keeps whichever has the higher accuracy.
            if best_meta is not None and best_acc > -float("inf"):
                update_results_table(
                    dataset=dataset,
                    key=method,
                    acc=best_acc,
                    meta=best_meta,
                    is_wiki_task=is_wiki_task,
                )
            
            # Generate best adapter file if requested
            if args.generate_best_adapter and best_meta:
                local_run_path = os.path.join(project_root, "results", dataset, run_name)
                generated = generate_best_adapter_file(local_run_path, best_meta, best_adapter_name)
                if generated and not args.dry_run:
                    upload_best_adapter_file(local_run_path, project_root, args.s3_bucket)

    # Print tables
    print("\n" + "="*50)
    print("Results Table")
    print("="*50)
    # Build a human-readable table with confidence intervals where possible.
    pretty_table: Dict[str, Dict[str, str]] = {}
    for dataset, methods in results_table.items():
        pretty_table[dataset] = {}
        for method, stats in methods.items():
            acc = stats.get("accuracy", float("nan"))
            n = stats.get("num_samples")
            # n is None for wiki tasks, so those cells show the point estimate only.
            ci = _binomial_ci(acc, n)
            if ci is not None:
                lo, hi = ci
                pretty_table[dataset][method] = f"{acc:.3f} [{lo:.3f}, {hi:.3f}]"
            else:
                pretty_table[dataset][method] = f"{acc:.3f}"

    # Rows are datasets, columns are methods / baseline run names.
    df = pd.DataFrame.from_dict(pretty_table, orient='index')
    if not df.empty:
        # NOTE: DataFrame.to_markdown depends on the optional `tabulate` package.
        print(df.to_markdown())
        # Written relative to the current working directory.
        df.to_csv("sweep_results_table.csv")

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
