import csv
import hashlib
import math
import os
import random
import shutil
import sys
import tempfile
import time
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple

import numpy as np
import torch

from op_eval.config import num_perf_trials, seed_num

# For the op summary format produced by CANN 8.5.0.alpha002 and how to parse it, see https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/850alpha002/devaids/Profiling/atlasprofiling_16_0067.html

# For an overview of op profiling with CANN 8.5.0.alpha002, see https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/850alpha002/devaids/Profiling/atlasprofiling_16_0057.html

def _bool_env(name: str, default: bool = False) -> bool:
    """Interpret environment variable *name* as a boolean flag.

    An unset variable yields *default*. A set variable is True only when its
    value, stripped and lower-cased, is one of "1", "true", "yes" or "on";
    any other value (including an empty string) is False.
    """
    value = os.environ.get(name)
    if value is not None:
        return value.strip().lower() in {"1", "true", "yes", "on"}
    return default


def _log(message: str) -> None:
    """Write *message* to stderr, tagged with the ``[op_eval.perf]`` prefix."""
    line = "[op_eval.perf] {}".format(message)
    print(line, file=sys.stderr)


def perf_seed_for_op(op_name: str) -> int:
    """Derive a deterministic RNG seed for *op_name*.

    The seed is ``(base + h) % (2**31 - 1)``. Here ``h`` is the first 8 bytes,
    read as an unsigned little-endian integer, of SHA-256 over
    ``"<salt>:<op_name>"``.

    * salt comes from ``$OP_EVAL_PERF_SEED_SALT`` (default
      ``"op_eval_perf_seed"``).
    * base comes from ``$OP_EVAL_PERF_SEED_BASE`` (default: the configured
      ``seed_num``).

    The result always lies in ``[0, 2**31 - 2]``.
    """
    salt = os.environ.get("OP_EVAL_PERF_SEED_SALT", "op_eval_perf_seed")
    base = int(os.environ.get("OP_EVAL_PERF_SEED_BASE", str(seed_num)))
    payload = f"{salt}:{op_name}".encode("utf-8")
    leading_bytes = hashlib.sha256(payload).digest()[:8]
    offset = int.from_bytes(leading_bytes, byteorder="little", signed=False)
    modulus = 2**31 - 1
    return (base + offset) % modulus


def set_perf_seed(op_name: str | None) -> None:
    """Seed every RNG that may influence input generation for *op_name*.

    Seeds Python's ``random``, NumPy and torch with ``perf_seed_for_op(op_name)``.
    CUDA is also seeded when it is available. ``torch.npu`` / ``torch.xpu`` are
    seeded when they expose ``manual_seed_all``; errors from those two are
    deliberately ignored (best effort).

    Does nothing when *op_name* is falsy (None or "").
    """
    if op_name:
        seed = perf_seed_for_op(op_name)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        # Optional accelerator backends, in the same order as before: npu, then xpu.
        for backend_name in ("npu", "xpu"):
            backend = getattr(torch, backend_name, None)
            if backend is None or not hasattr(backend, "manual_seed_all"):
                continue
            try:
                backend.manual_seed_all(seed)
            except Exception:
                # Best effort: a backend that is importable but not usable
                # must not break seeding of the others.
                pass


def _parse_float(raw: Optional[str]) -> Optional[float]:
    """Parse a CSV cell into a finite float.

    Returns None for:

    * ``None``, blank or whitespace-only input;
    * unparseable text such as ``"N/A"``;
    * any non-finite value, whatever its spelling (``"nan"``, ``"-NaN"``,
      ``"inf"``, ``"+inf"``, ``"Infinity"``, ...).

    Surrounding whitespace is ignored.
    """
    if raw is None:
        return None
    try:
        value = float(str(raw).strip())
    except (ValueError, TypeError):
        # Covers "", "n/a", "na" and any other non-numeric text.
        return None
    # float() accepts many spellings of nan/inf; reject all of them uniformly
    # so non-finite values never leak into averaged metrics.
    return value if math.isfinite(value) else None


def _read_csv_rows(path: str) -> List[Dict[str, str]]:
    """Read the CSV file at *path* and return its rows as header-keyed dicts.

    The file is decoded as ``utf-8-sig``. A leading UTF-8 byte-order mark is
    therefore stripped rather than glued onto the first header name, where it
    would break lookups such as ``row["Device_id"]``. Files without a BOM
    decode exactly as with plain ``utf-8``.

    Best-effort: any read/parse failure is logged and an empty list is returned.
    """
    try:
        with open(path, "r", newline="", encoding="utf-8-sig") as f:
            return list(csv.DictReader(f))
    except Exception as e:
        _log(f"Failed to read CSV {path}: {e}")
        return []


def _round3(value: Optional[float]) -> Optional[float]:
    """Round *value* to 3 decimal places as a builtin ``float``.

    ``None`` passes through unchanged. NumPy scalars (e.g. results of
    ``np.mean`` / ``np.std``) are converted to plain ``float``, so the reported
    numbers print and serialize as ordinary Python floats instead of
    ``np.float64(...)``.
    """
    if value is None:
        return None
    return round(float(value), 3)


def _us_to_ms(value: Optional[float]) -> Optional[float]:
    """Convert microseconds to milliseconds, rounded to 3 decimals; ``None`` stays ``None``."""
    return None if value is None else round(value / 1000.0, 3)


def _parse_step_trace_time_csv(csv_path: str) -> Optional[float]:
    """
    Return the ``Computing`` value (µs) from step_trace_time.csv.

    Only the first data row is consulted. The file is expected to hold a
    single row here, because the profiler is run once without a step schedule.
    Returns None when the file is empty or missing, the column is absent, the
    value is not a finite number, or an unexpected error occurs (logged).
    """
    try:
        rows = _read_csv_rows(csv_path)
        return _parse_float(rows[0].get("Computing")) if rows else None
    except Exception as e:
        _log(f"Failed to parse step_trace_time: {e}")
        return None


# Columns of kernel_details.csv that identify or describe a kernel rather than
# measure it. They are never reported as metrics, even when their values parse
# as numbers (IDs, start/submission timestamps, block dims).
_KERNEL_DETAILS_METADATA_COLS = frozenset({
    'Device_id', 'Model ID', 'Task ID', 'Stream ID', 'Name', 'Type',
    'OP State', 'Accelerator Core', 'Start Time(us)', 'Wait Time(us)',
    'Mix Block Dim', 'HF32 Eligible', 'Input Shapes', 'Input Data Types',
    'Input Formats', 'Output Shapes', 'Output Data Types', 'Output Formats',
    'Context ID', 'Submission Time(us)',
})


def _parse_kernel_details_csv(
    csv_path: str,
    op_name: str,
) -> Optional[Dict[str, Any]]:
    """
    Parse kernel_details.csv from torch_npu.profiler ASCEND_PROFILER_OUTPUT.

    Every row with at least one finite numeric, non-metadata column becomes an
    entry keyed by its ``Name``. A blank or missing name is keyed as
    ``"unknown"``. The n-th occurrence of a name (n >= 2) gets the suffix
    ``_n``. If that key is already taken by a literal kernel name, the next
    free suffix is used, so no row silently overwrites another.

    Args:
        csv_path: Path to kernel_details.csv
        op_name: Operator name (for reference, not used in parsing)

    Returns:
        Dict mapping kernel names to their metrics, e.g.
        {"aclnnMuls_MulAiCore_Mul": {...}, "aclnnTanh_Tanh_Tanh": {...}}.
        Duplicates look like {"kernel": {...}, "kernel_2": {...}}.
        Returns None if the file is empty/unreadable, no row has metrics, or
        parsing fails (logged).
    """

    def _parse_row_metrics(row: Dict[str, str]) -> Dict[str, float]:
        """Collect the finite numeric, non-metadata columns of one row."""
        metrics: Dict[str, float] = {}
        for col, val in row.items():
            # col is None for surplus fields that DictReader gathers under its
            # restkey. Such values are lists, never numbers, so skip them.
            if col is None or col in _KERNEL_DETAILS_METADATA_COLS or not val:
                continue
            parsed = _parse_float(val)
            if parsed is not None:
                metrics[col] = parsed
        return metrics

    try:
        rows = _read_csv_rows(csv_path)
        if not rows:
            return None

        result: Dict[str, Dict[str, float]] = {}
        name_counts: Dict[str, int] = {}  # occurrences seen per kernel name

        for row in rows:
            metrics = _parse_row_metrics(row)
            if not metrics:
                # Rows without metrics do not count as an occurrence.
                continue

            name = row.get('Name') or 'unknown'
            count = name_counts.get(name, 0) + 1
            key = name if count == 1 else f"{name}_{count}"
            # Skip suffixes already used, e.g. by a kernel literally named "k_2".
            while key in result:
                count += 1
                key = f"{name}_{count}"
            name_counts[name] = count
            result[key] = metrics

        return result or None

    except Exception as e:
        _log(f"Failed to parse {csv_path}: {e}")
        return None


def _run_single_pass_profiler(
    run_fn: Callable,
    device,
    prof_dir: str,
    op_name: str,
    aic_metrics_enum,
) -> Optional[Dict[str, Any]]:
    """
    Run a single profiling pass with the given AiCMetrics config.

    Steps:

    1. Execute ``run_fn`` once under ``torch_npu.profiler.profile`` (NPU
       activity, ``ProfilerLevel.Level2``, shapes recorded) inside
       ``torch.no_grad()``.
    2. Synchronize ``device``.
    3. Parse the ``*_ascend_pt/ASCEND_PROFILER_OUTPUT`` directory that the
       trace handler wrote under ``prof_dir``.

    Only ``*_ascend_pt`` directories created by *this* pass are considered,
    newest first. Stale output left in a reused ``prof_dir`` by an earlier run
    is therefore never reported as this pass's result.

    Args:
        run_fn: Function to profile
        device: Target device
        prof_dir: Directory to store profiling output
        op_name: Operator name
        aic_metrics_enum: AiCMetrics enum value

    Returns dict with (each key present only if its CSV exists):
        - 'kernel_details': parsed kernel_details.csv data (all kernels)
        - 'computing_time_us': Computing field from step_trace_time.csv (µs)
    Returns None if torch_npu is unavailable, profiling raises (logged), no
    new output directory appears, or neither CSV is present.
    """
    try:
        import torch_npu
        from torch_npu.profiler import (
            profile,
            ProfilerActivity,
            tensorboard_trace_handler,
            _ExperimentalConfig,
            ProfilerLevel,
        )

        prof_path = Path(prof_dir)
        # Snapshot output directories from earlier runs (glob on a missing
        # directory simply yields nothing).
        preexisting = set(prof_path.glob("*_ascend_pt"))

        experimental_config = _ExperimentalConfig(
            profiler_level=ProfilerLevel.Level2,
            aic_metrics=aic_metrics_enum,
        )

        with torch.no_grad():
            with profile(
                activities=[ProfilerActivity.NPU],
                on_trace_ready=tensorboard_trace_handler(prof_dir),
                experimental_config=experimental_config,
                record_shapes=True,
            ):
                run_fn()
                torch_npu.npu.synchronize(device=device)

        # torch_npu.profiler creates: *_ascend_pt/ASCEND_PROFILER_OUTPUT/
        fresh_dirs = [
            d for d in prof_path.glob("*_ascend_pt") if d not in preexisting
        ]
        fresh_dirs.sort(key=lambda d: d.stat().st_mtime, reverse=True)

        for apt_dir in fresh_dirs:
            ascend_output = apt_dir / "ASCEND_PROFILER_OUTPUT"
            if not ascend_output.is_dir():
                continue

            result: Dict[str, Any] = {}
            kernel_details = ascend_output / "kernel_details.csv"
            if kernel_details.exists():
                result['kernel_details'] = _parse_kernel_details_csv(
                    str(kernel_details), op_name
                )

            # step_trace_time.csv carries the device-side Computing time.
            step_trace = ascend_output / "step_trace_time.csv"
            if step_trace.exists():
                result['computing_time_us'] = _parse_step_trace_time_csv(
                    str(step_trace)
                )

            return result if result else None

        _log(f"No new profiler output found under {prof_dir}")
        return None

    except Exception as e:
        _log(f"Single pass profiling failed: {e}")
        return None


def profile_single_pass(
    run_fn: Callable,
    device,
    op_name: str,
    pass_type: str = 'pipe',
    output_root: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """
    Run a single profiling pass (for partial profiling of slow operators).

    Args:
        run_fn: Function to profile
        device: Target device
        op_name: Operator name
        pass_type: Which AiCMetrics config to use:
            * 'pipe' -> PipeUtilization
            * 'memory' -> Memory (the same config as the memory pass of
              profile_execution)
            * 'resource' -> ResourceConflictRatio
        output_root: Output directory. Falls back to
            $OP_EVAL_MSPROF_OUTPUT_ROOT, then $ASCEND_OP_ROOT, then the system
            temp directory. Output is written to ``<output_root>/pass_<type>``.
            NOTE: that path is reused across calls.

    Returns:
        Dict with:
        - 'performance': single-sample {mean, max, min, std (0.0), note} in ms,
          or None if no Computing time was found.
        - 'profiling': per-kernel metrics, or None.
        Returns None on failure: torch_npu missing, unknown pass_type, or the
        pass yielded no data.
    """
    try:
        import torch_npu  # noqa: F401  (availability check only)
        from torch_npu.profiler import AiCMetrics
    except ImportError:
        _log("torch_npu not available for profiling")
        return None

    # Map pass_type to (AiCMetrics enum, output subdirectory).
    pass_map = {
        'pipe': (AiCMetrics.PipeUtilization, 'pass_pipe'),
        'memory': (AiCMetrics.Memory, 'pass_memory'),
        'resource': (AiCMetrics.ResourceConflictRatio, 'pass_resource'),
    }

    if pass_type not in pass_map:
        _log(f"Unknown pass_type: {pass_type}")
        return None

    metrics_enum, pass_name = pass_map[pass_type]
    output_root = (
        output_root
        or os.environ.get("OP_EVAL_MSPROF_OUTPUT_ROOT")
        or os.environ.get("ASCEND_OP_ROOT")
        or tempfile.gettempdir()
    )
    pass_dir = str(Path(output_root) / pass_name)

    pass_result = _run_single_pass_profiler(
        run_fn, device, pass_dir, op_name, metrics_enum
    )
    if not pass_result:
        return None

    # Performance from step_trace_time. A 0.0 µs reading is still a valid sample.
    performance = None
    computing_time_us = pass_result.get('computing_time_us')
    if computing_time_us is not None:
        time_ms = _round3(computing_time_us / 1000.0)
        performance = {
            'mean': time_ms,
            'max': time_ms,
            'min': time_ms,
            'std': 0.0,
            'note': 'Single pass profiling (partial data)',
        }

    return {
        'performance': performance,
        'profiling': pass_result.get('kernel_details') or None,
    }


def _make_unique_dir(parent: str, name: str) -> str:
    """Create and return ``parent/name``, appending ``_2``, ``_3``, ... if it already exists."""
    candidate = os.path.join(parent, name)
    suffix = 1
    while os.path.exists(candidate):
        suffix += 1
        candidate = os.path.join(parent, f"{name}_{suffix}")
    os.makedirs(candidate, exist_ok=True)
    return candidate


def _aggregate_computing_times(
    computing_times_us: List[float],
) -> Optional[Dict[str, Optional[float]]]:
    """Summarize per-pass Computing times (µs) as {mean, max, min, std} in ms; None if empty."""
    if not computing_times_us:
        return None
    times_ms = [t / 1000.0 for t in computing_times_us]
    return {
        'mean': _round3(np.mean(times_ms)),
        'max': _round3(np.max(times_ms)),
        'min': _round3(np.min(times_ms)),
        'std': _round3(np.std(times_ms)),
    }


def _merge_kernel_metrics(
    all_pass_data: List[Dict[str, Dict[str, float]]],
) -> Dict[str, Dict[str, Optional[float]]]:
    """Average each kernel's metrics across passes; '(us)' fields become '(ms)'."""
    # kernel name -> metric name -> values seen across passes
    collected: Dict[str, Dict[str, List[float]]] = {}
    for pass_data in all_pass_data:
        for kernel_name, metrics in (pass_data or {}).items():
            per_kernel = collected.setdefault(kernel_name, {})
            for metric_name, val in metrics.items():
                if val is not None:
                    per_kernel.setdefault(metric_name, []).append(val)

    merged: Dict[str, Dict[str, Optional[float]]] = {}
    for kernel_name, metrics_dict in collected.items():
        out: Dict[str, Optional[float]] = {}
        for metric_name, values in metrics_dict.items():
            final_val = values[0] if len(values) == 1 else np.mean(values)
            output_name = metric_name
            # Convert time fields from µs to ms and rename the field to match.
            if '(us)' in metric_name:
                final_val = final_val / 1000.0
                output_name = metric_name.replace('(us)', '(ms)')
            out[output_name] = _round3(final_val)
        merged[kernel_name] = out
    return merged


def profile_execution(
    run_fn: Callable,
    device,
    op_name: str,
    output_root: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """
    Profile operator execution using torch_npu.profiler with 3-pass approach.

    Pass 1: PipeUtilization (compute ratios)
    Pass 2: Memory (bandwidth metrics)
    Pass 3: ResourceConflictRatio (pipeline stall metrics)

    Each pass runs ``run_fn`` once. Output goes to a fresh
    ``prof_<op_name>_<unix time>[_N]`` directory under ``output_root``. The
    ``_N`` suffix is added only when that name already exists, e.g. the same
    op profiled twice within one second, so output from earlier runs is never
    mixed in.

    Args:
        run_fn: Function to profile
        device: Target device
        op_name: Operator name
        output_root: Directory for profiling output. Defaults to
            $OP_EVAL_MSPROF_OUTPUT_ROOT. Failing that, a new ``opprof_*``
            temp directory is created under $ASCEND_OP_ROOT or the system
            temp dir.

    Returns:
        Dict with:
        - 'performance': {mean, max, min, std} in ms over the per-pass
          step_trace_time Computing values, or None.
        - 'profiling': dict keyed by kernel name with metrics averaged across
          passes, or None.
        Both values are None when no pass produced data. Returns None if
        torch_npu.profiler is unavailable or an unexpected error occurs (logged).
    """
    # Check if torch_npu.profiler is available
    try:
        from torch_npu.profiler import AiCMetrics
    except ImportError:
        _log("torch_npu.profiler.AiCMetrics not available; skipping profiling")
        return None

    # Setup output directory
    if not output_root:
        output_root = os.environ.get("OP_EVAL_MSPROF_OUTPUT_ROOT")

    if not output_root:
        base = os.environ.get("ASCEND_OP_ROOT") or tempfile.gettempdir()
        output_root = tempfile.mkdtemp(prefix="opprof_", dir=base)
    else:
        os.makedirs(output_root, exist_ok=True)

    prof_dir = _make_unique_dir(output_root, f"prof_{op_name}_{int(time.time())}")

    passes = [
        ("pass_pipe", AiCMetrics.PipeUtilization, "PipeUtilization"),
        ("pass_memory", AiCMetrics.Memory, "Memory"),
        ("pass_conflict", AiCMetrics.ResourceConflictRatio, "ResourceConflictRatio"),
    ]

    try:
        all_pass_data = []       # kernel_details dicts, one per successful pass
        computing_times_us = []  # step_trace_time Computing values (µs)
        for pass_dir_name, metrics_enum, desc in passes:
            _log(f"Running profiling pass: {desc}")
            pass_dir = os.path.join(prof_dir, pass_dir_name)
            os.makedirs(pass_dir, exist_ok=True)

            pass_result = _run_single_pass_profiler(
                run_fn, device, pass_dir, op_name, metrics_enum,
            )
            if not pass_result:
                continue
            if pass_result.get('kernel_details'):
                all_pass_data.append(pass_result['kernel_details'])
            ct = pass_result.get('computing_time_us')
            if ct is not None:
                computing_times_us.append(ct)

        if not all_pass_data and not computing_times_us:
            _log("No profiling data collected from any pass")
            return {
                'performance': None,
                'profiling': None,
            }

        merged_profiling = _merge_kernel_metrics(all_pass_data)
        return {
            'performance': _aggregate_computing_times(computing_times_us),
            'profiling': merged_profiling if merged_profiling else None,
        }

    except Exception as e:
        _log(f"Profiling failed: {e}")
        return None



def time_execution_event_template(
    context,
    device,
    synchronize,
    event_class,
    eval_target,
) -> Tuple[List[float], Optional[Dict[str, Any]]]:
    """
    Time ``context[eval_target]`` using profiler-derived device time.

    Steps:

    1. Seed the RNGs for the op (``context["_op_eval_op_name"]``).
    2. Build inputs from ``context['get_inputs']`` and init args from
       ``context['get_init_inputs']``, moving tensors to ``device``.
    3. Instantiate the model under ``torch.no_grad()``.
    4. If an op name is present, profile a single forward call with
       ``profile_execution``.

    ``synchronize`` and ``event_class`` are accepted for signature
    compatibility but are not used by this profiler-based implementation.

    Returns:
        (elapsed_times_ms, profiling_result) where:

        - elapsed_times_ms is ``[mean]``, the single mean in ms over the
          profiler passes' step_trace_time Computing values. It is ``[]``
          when there is no op name or no performance data.
        - profiling_result is whatever profile_execution returned, or None
          when there is no op name.
    """
    # Look up the callables first, in the same order as before
    # (get_inputs, get_init_inputs, eval_target).
    get_inputs = context['get_inputs']
    get_init_inputs = context['get_init_inputs']
    model_cls = context[eval_target]
    op_name = context.get("_op_eval_op_name")
    set_perf_seed(op_name)

    # Generation order matters for RNG reproducibility: inputs before init inputs.
    inputs = [
        item.to(device) if isinstance(item, torch.Tensor) else item
        for item in get_inputs()
    ]
    init_inputs = [
        item.to(device=device) if isinstance(item, torch.Tensor) else item
        for item in get_init_inputs()
    ]

    elapsed_times_ms: List[float] = []
    profiling_result = None

    with torch.no_grad():
        custom_model = model_cls(*init_inputs).to(device)

        def run_once():
            custom_model(*inputs)

        if op_name:
            profiling_result = profile_execution(run_once, device, op_name)
            perf = (profiling_result or {}).get('performance')
            # Only the aggregated mean is exposed, as a one-element list.
            if perf and perf.get('mean') is not None:
                elapsed_times_ms = [perf['mean']]

    return elapsed_times_ms, profiling_result
