import signal
import time
from typing import Optional, Tuple

import torch

from op_eval.config import (
    FIRST_RUN_TIMEOUT_S,
    num_correct_trials,
    seed_num,
)


def set_seed(seed: int):
    """Seed torch's default RNG and the RNG of the current CUDA device.

    Args:
        seed: Value handed to both seeding functions.
    """
    # NOTE: torch.cuda.manual_seed only seeds the *current* CUDA device.
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)


class CorrectnessTimeoutError(Exception):
    """Signals that a correctness run exceeded its time budget."""


def _run_with_timeout(fn, timeout_s: int):
    """
    Run fn() with a wall-clock timeout driven by SIGALRM.

    A temporary SIGALRM handler is installed for the duration of the call and
    the real-time interval timer is armed for ``timeout_s`` seconds. Both are
    undone on every exit path (normal return, exception from ``fn``, timeout).

    Limitations inherited from Python signal handling:
    - Only usable from the main thread of the main interpreter
      (``signal.signal`` raises ``ValueError`` elsewhere) and only on
      platforms that provide SIGALRM.
    - Python-level handlers run between bytecode instructions, so a call that
      blocks inside native code is only interrupted once it returns to the
      interpreter.
    - Any alarm/itimer that was pending before the call is cancelled.

    Args:
        fn: Zero-argument callable to execute.
        timeout_s: Timeout in seconds. Integers and fractional values are both
            accepted. A value <= 0 arms no timer, i.e. fn() runs without a
            time limit.

    Returns:
        Result of fn()

    Raises:
        CorrectnessTimeoutError: If the operation times out
    """
    def handler(signum, frame):
        raise CorrectnessTimeoutError(f"Operation timed out after {timeout_s}s")

    old_handler = signal.signal(signal.SIGALRM, handler)
    try:
        # Arm the timer inside the try so that a failure while arming still
        # restores the previous handler. setitimer (unlike signal.alarm)
        # accepts fractional seconds.
        if timeout_s > 0:
            signal.setitimer(signal.ITIMER_REAL, timeout_s)
        return fn()
    finally:
        # Disarm first so the alarm cannot fire while handlers are swapped.
        signal.setitimer(signal.ITIMER_REAL, 0)
        # signal.signal() returns None when the previous handler was not
        # installed from Python; None cannot be passed back, so fall back to
        # the default disposition in that case.
        signal.signal(
            signal.SIGALRM,
            old_handler if old_handler is not None else signal.SIG_DFL,
        )

    
def execute_template(
    synchronize,
    device,
    context,
    op_name: Optional[str] = None,
) -> Tuple[bool, str, Optional[float]]:
    """
    Execute correctness check with fail-fast and VRAM optimization.

    Strategy (2-phase, custom-first for fail-fast):
    - Phase 1: Build ``ModelNew`` and run it for all trials, keeping CPU copies
      of every trial's inputs and outputs. Only the first run is wrapped in a
      timeout and timed.
    - Phase 2: Build ``Model`` and run it on the stored inputs, comparing each
      result with the stored custom output.

    This avoids wasting time on the reference if the custom model fails or
    times out. The custom model is released before the reference is built.

    Inputs are copied to CPU *before* the custom model runs, so in-place
    mutation by the custom model cannot change what the reference model sees.

    Args:
        synchronize: Callable invoked as ``synchronize(device=device)`` after
            model construction, input transfer and each forward pass
            (presumably blocks until queued device work has finished).
        device: Target passed to ``.to(...)`` for models and tensors.
        context: Mapping providing ``'get_inputs'``, ``'get_init_inputs'``,
            ``'Model'`` (reference) and ``'ModelNew'`` (custom).
        op_name: Optional label used in log output.

    Returns:
        ``(correct, message, first_run_time_s)``:
        - all trials match: ``(True, '', first_run_time_s)``
        - some trial mismatches: ``(False, report, first_run_time_s)``
        - first run timed out, or any other exception was raised:
          ``(False, '[FAIL] ...', None)``
        ``first_run_time_s`` is the wall time of the first custom forward pass
        (``None`` when no trial ran).

    Raises:
        CorrectnessTimeoutError: Only if one escapes from outside the guarded
            first run; the first-run timeout itself is reported via the
            return value.
    """
    import sys

    def _debug(message):
        print(f"[DEBUG] execute_template: {message}", file=sys.stderr, flush=True)

    _debug(f"Starting for op {op_name}")
    get_inputs = context['get_inputs']
    get_init_inputs = context['get_init_inputs']
    Model = context['Model']
    ModelNew = context['ModelNew']

    # Helper to copy tensors to CPU (recurses into lists/tuples).
    def to_cpu(tensor_or_list):
        if isinstance(tensor_or_list, torch.Tensor):
            # copy=True guarantees an independent snapshot even when the
            # tensor already lives on the CPU (plain .cpu() would alias it).
            return tensor_or_list.detach().to("cpu", copy=True)
        if isinstance(tensor_or_list, (list, tuple)):
            return type(tensor_or_list)(to_cpu(x) for x in tensor_or_list)
        return tensor_or_list

    # Helper to move tensors to device (mirrors to_cpu).
    def to_device(tensor_or_list):
        if isinstance(tensor_or_list, torch.Tensor):
            return tensor_or_list.to(device)
        if isinstance(tensor_or_list, (list, tuple)):
            return type(tensor_or_list)(to_device(x) for x in tensor_or_list)
        return tensor_or_list

    def _fmt_number(value):
        if isinstance(value, complex):
            return f"{value.real:.6g}{'+' if value.imag >= 0 else ''}{value.imag:.6g}j"
        try:
            return f"{float(value):.6g}"
        except (TypeError, ValueError):
            return str(value)

    def _tensor_mismatch_detail(ref_output, custom_output, path, *, atol, rtol):
        """Return a one-line mismatch description, or None if the tensors agree."""
        if ref_output.shape != custom_output.shape:
            return f"{path}.shape mismatch: expected {tuple(ref_output.shape)}, got {tuple(custom_output.shape)}"

        total = int(ref_output.numel())
        if total == 0:
            return None

        compare_as_float = (
            torch.is_floating_point(ref_output)
            or torch.is_complex(ref_output)
            or torch.is_floating_point(custom_output)
            or torch.is_complex(custom_output)
        )
        if compare_as_float:
            # torch.isclose rejects operands of different dtypes; report that
            # as an ordinary mismatch instead of letting it abort the check.
            if ref_output.dtype != custom_output.dtype:
                return f"{path}.dtype mismatch: expected {ref_output.dtype}, got {custom_output.dtype}"
            is_close = torch.isclose(ref_output, custom_output, atol=atol, rtol=rtol)
        else:
            is_close = ref_output == custom_output

        mismatch_mask = ~is_close
        mismatch_count = int(mismatch_mask.sum().item())
        if mismatch_count == 0:
            return None

        if compare_as_float:
            ref_float = ref_output
            custom_float = custom_output
        else:
            # Integer/bool tensors: widen so the difference cannot wrap around.
            ref_float = ref_output.to(torch.float32)
            custom_float = custom_output.to(torch.float32)

        # .abs() of a complex tensor is already real-valued (the magnitude).
        abs_diff = (ref_float - custom_float).abs()
        # Small epsilon keeps the relative error finite where the reference is 0.
        rel_diff = abs_diff / (ref_float.abs() + 1e-12)

        max_abs = abs_diff[mismatch_mask].max().item()
        max_rel = rel_diff[mismatch_mask].max().item()

        bbox = None
        if ref_output.dim() == 0:
            bbox = f"Bounding box of error elements: {path}[()]"
        else:
            mismatch_indices = torch.nonzero(mismatch_mask, as_tuple=False)
            if mismatch_indices.numel() > 0:
                # Per-dimension min/max index of mismatching elements,
                # rendered as half-open slices.
                mins = mismatch_indices.min(dim=0).values
                maxs = mismatch_indices.max(dim=0).values
                slices = [
                    f"{int(mins[i].item())}:{int(maxs[i].item()) + 1}"
                    for i in range(mismatch_indices.size(1))
                ]
                bbox = f"Bounding box of error elements: {path}[" + ", ".join(slices) + "]"

        pct = (mismatch_count / total) * 100.0
        detail = (
            f"{path}: {mismatch_count}/{total} elements mismatched ({pct:.2f}%), "
            f"max_abs={_fmt_number(max_abs)}, max_rel={_fmt_number(max_rel)}"
        )
        if bbox:
            detail += f", {bbox}"
        return detail

    def _compare_outputs(ref_output, custom_output, path, *, atol, rtol, meta):
        """Recursively compare outputs; return (ok, list_of_detail_strings)."""
        if isinstance(ref_output, torch.Tensor) and isinstance(custom_output, torch.Tensor):
            if meta is not None and meta.get("shape") is None:
                # Describe the first reference tensor seen, and record whether
                # the custom tensor really has the same shape and dtype.
                meta["shape"] = tuple(ref_output.shape)
                meta["numel"] = int(ref_output.numel())
                meta["dtype"] = str(ref_output.dtype)
                meta["layout_matches"] = (
                    ref_output.shape == custom_output.shape
                    and ref_output.dtype == custom_output.dtype
                )
            detail = _tensor_mismatch_detail(
                ref_output,
                custom_output,
                path,
                atol=atol,
                rtol=rtol,
            )
            if detail:
                return False, [detail]
            return True, []
        if isinstance(ref_output, (list, tuple)) and isinstance(custom_output, (list, tuple)):
            if len(ref_output) != len(custom_output):
                return False, [f"len({path}) mismatch: expected {len(ref_output)}, got {len(custom_output)}"]
            all_details = []
            ok = True
            for idx, (ref_item, custom_item) in enumerate(zip(ref_output, custom_output)):
                item_ok, item_details = _compare_outputs(
                    ref_item,
                    custom_item,
                    f"{path}[{idx}]",
                    atol=atol,
                    rtol=rtol,
                    meta=meta,
                )
                if not item_ok:
                    ok = False
                    all_details.extend(item_details)
            return ok, all_details

        if type(ref_output) is not type(custom_output):
            return False, [f"type({path}) mismatch: expected {type(ref_output).__name__}, got {type(custom_output).__name__}"]
        if ref_output == custom_output:
            return True, []
        return False, [f"{path} value mismatch: expected {ref_output}, got {custom_output}"]

    try:
        init_inputs = get_init_inputs()
        init_inputs = [
            x.to(device=device) if isinstance(x, torch.Tensor) else x for x in init_inputs
        ]
        first_run_time_s = None

        # ========== PHASE 1: Custom Model First (Fail-Fast) ==========
        # Run custom model for all trials, store inputs and outputs
        custom_data = []  # List of (inputs_cpu, custom_output_cpu)

        _debug("Phase 1 - Running custom model...")

        with torch.no_grad():
            set_seed(seed_num)
            custom_model = ModelNew(*init_inputs).to(device)
            synchronize(device=device)

            for trial in range(num_correct_trials):
                raw_inputs = list(get_inputs())
                # Snapshot BEFORE running the custom model: if it mutates its
                # inputs in place, the reference must still get the originals.
                inputs_cpu = to_cpu(raw_inputs)
                # Same recursive transfer that Phase 2 applies, so both models
                # receive identically placed inputs.
                inputs = to_device(raw_inputs)
                del raw_inputs
                synchronize(device=device)

                if trial == 0:
                    # First run: apply timeout and measure time
                    start_time = time.monotonic()
                    first_run_timeout_s = FIRST_RUN_TIMEOUT_S
                    try:
                        def run_custom():
                            _debug("Running custom_model (Trial 0)...")
                            result = custom_model(*inputs)
                            synchronize(device=device)
                            _debug("custom_model done.")
                            return result

                        custom_output = _run_with_timeout(run_custom, first_run_timeout_s)
                    except CorrectnessTimeoutError:
                        del custom_model
                        return (
                            False,
                            (
                                f"[FAIL] First correctness run timed out after {first_run_timeout_s}s"
                            ),
                            None,
                        )

                    first_run_time_s = time.monotonic() - start_time
                else:
                    # Subsequent runs: no per-run timeout
                    custom_output = custom_model(*inputs)
                    synchronize(device=device)

                custom_data.append((inputs_cpu, to_cpu(custom_output)))
                del inputs, custom_output

            # Drop the custom model before the reference model is created.
            del custom_model

        _debug(f"Phase 1 complete, {len(custom_data)} custom outputs stored")

        # ========== PHASE 2: Reference Model and Compare ==========
        failed_trials = []
        detail_lines = []
        atol = 1e-02
        rtol = 1e-02
        max_failed_trial_details = 5
        output_meta = {"shape": None, "numel": None, "dtype": None, "layout_matches": False}

        _debug("Phase 2 - Running reference model...")

        with torch.no_grad():
            set_seed(seed_num)
            original_model = Model(*init_inputs).to(device)
            synchronize(device=device)

            for trial in range(num_correct_trials):
                inputs_cpu, custom_output_cpu = custom_data[trial]

                # Run reference with same inputs
                inputs = to_device(inputs_cpu)
                ref_output = original_model(*inputs)
                synchronize(device=device)

                # Compare
                custom_output = to_device(custom_output_cpu)

                trial_ok, trial_details = _compare_outputs(
                    ref_output,
                    custom_output,
                    "output",
                    atol=atol,
                    rtol=rtol,
                    meta=output_meta,
                )

                del inputs, ref_output, custom_output

                if not trial_ok:
                    failed_trials.append(trial)
                    # Only the first few failing trials get a detail line.
                    if len(detail_lines) < max_failed_trial_details:
                        joined_details = "; ".join(trial_details) if trial_details else "output mismatch"
                        detail_lines.append(f"Trial {trial}: {joined_details}")

            del original_model

        correctness = not failed_trials
        correctness_information = ''

        if not correctness:
            num_failed = len(failed_trials)
            num_passed = num_correct_trials - num_failed
            summary = (
                f"[FAIL] Output mismatch: {num_passed}/{num_correct_trials} trials passed, "
                f"{num_failed} failed. "
            )
            summary += f"Failing trials: {', '.join(str(t) for t in failed_trials)}. "
            summary += f"Tolerance atol={atol:g}, rtol={rtol:g}."
            if num_failed > max_failed_trial_details:
                summary += f" Showing first {max_failed_trial_details} failing trials."
            shape = output_meta.get("shape")
            numel = output_meta.get("numel")
            dtype = output_meta.get("dtype")
            shape_text = shape if shape is not None else "n/a"
            numel_text = numel if numel is not None else "n/a"
            dtype_text = dtype if dtype is not None else "n/a"
            if output_meta.get("layout_matches"):
                correctness_information = (
                    summary
                    + f"\nOutput: output.shape={shape_text}, output.dtype={dtype_text}, matches reference ground truth!"
                    + f"\nTotal Elements: {numel_text}, matches reference ground truth!"
                )
            else:
                # Shape/dtype agreement was not verified (mismatch, or no
                # tensor pair was compared): describe the reference only.
                correctness_information = (
                    summary
                    + f"\nReference output: output.shape={shape_text}, output.dtype={dtype_text}"
                    + f"\nReference Total Elements: {numel_text}"
                )
            if detail_lines:
                correctness_information += "\n" + "\n".join(detail_lines)

        _debug(f"Phase 2 complete. Correct={correctness}")
        return correctness, correctness_information, first_run_time_s

    except CorrectnessTimeoutError:
        raise
    except Exception as e:
        # Any other failure (model construction, forward pass, comparison) is
        # reported as an incorrect result rather than propagated.
        prefix = f"[{op_name}]" if op_name else "[FAIL]"
        print(f'{prefix} runtime error when evaluating correctness: {e}', file=sys.stderr)
        return False, f"[FAIL] {str(e)}", None
