import os, json, sys, io, types
import coverage
import csv
import gc
import psutil
from datetime import datetime
from evaluation_test_case_pass_k import (
    canonical_solution_load_for_reward,
    safe_eval,
)
from reward_contract import collect_assert_ids_from_source,build_assert_predicates,_run_case_trace,_metrics,compute_score
from inference_parsing import extract_simplified_testcases
from .data_load import load_section

from pathlib import Path
import inspect, tempfile, os
import signal, inspect
import pdb
import argparse

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from TG_CG_main import _build_contract_str

# Section names whose test cases count toward the functionality score;
# ValidityRewarder skips cases from any section not in this set.
USE_SECTIONS = set(load_section())        # the 5 scenario sections used for functionality scoring

_error_prefixes = (
    "AssertionError:",
    "Exception:",
    "TypeError:",
    "Timeout",
    "InvalidInput",
)

def _is_error(res):
    return isinstance(res, str) and res.startswith(_error_prefixes)

class TimeoutError(Exception):
    """Raised by ``_call_with_timeout`` when its SIGALRM deadline fires.

    NOTE: this shadows the builtin ``TimeoutError`` inside this module. It
    derives from ``Exception`` (not ``OSError``), so ``except TimeoutError``
    in this file catches this class rather than the builtin.
    """


def summarize_metrics(metrics_list, weights=None):
    """Average each metric over a list of per-task metric dicts.

    Args:
        metrics_list: e.g. ``[{'line': 0.8, 'branch': 0.7, 'correct': 1}, ...]``.
            A key missing from a dict counts as 0 for that dict.
        weights: optional ``{metric: weight}``. When given (non-empty), an extra
            ``avg_weighted_total`` is added; metrics absent from *weights* get
            weight 1. A zero total weight is treated as 1.

    Returns:
        ``{'avg_<metric>': mean, ...}`` plus ``'avg_weighted_total'`` when
        *weights* is non-empty, e.g.::

            {'avg_line': 0.76, 'avg_branch': 0.68, 'avg_correct': 0.92,
             'avg_weighted_total': 0.73}

        An empty *metrics_list* yields ``{}``.
    """
    if not metrics_list:
        return {}

    count = len(metrics_list)
    metric_names = set().union(*metrics_list)  # union of every dict's keys

    means = {}
    for name in metric_names:
        column_total = sum(entry.get(name, 0) for entry in metrics_list)
        means[name] = column_total / count

    summary = {f"avg_{name}": mean for name, mean in means.items()}

    if not weights:
        return summary

    weight_sum = sum(weights.get(name, 1) for name in metric_names)
    weighted_sum = sum(means[name] * weights.get(name, 1) for name in metric_names)
    summary["avg_weighted_total"] = weighted_sum / (weight_sum or 1)
    return summary

# ------------------------------------------------------------------ #
# 📊 Function to aggregate multiple task metrics and calculate the average
# ------------------------------------------------------------------ #
def evaluate_and_report(vr, model_out, mode, weights, verbose=False):
    """Score every task in *model_out* with *vr*, print a summary, and return the results.

    Args:
        vr: callable ``vr(task_id, raw, mode=..., return_metrics=True)`` returning
            ``(reward, detail_dict)`` (e.g. a ``ValidityRewarder``).
        model_out: ``{task_id: raw_model_output}`` to evaluate.
        mode: ``"functionality_specification"`` or the assert/contract mode; it
            selects the all-zero fallback metrics recorded when a task raises.
        weights: per-metric weights forwarded to ``summarize_metrics``.
        verbose: print a line for each task whose evaluation raised.

    Returns:
        ``(results, stats, task_metrics)``: ``{task_id: reward}``, the aggregate
        statistics from ``summarize_metrics``, and ``{task_id: detail_dict}``.
    """
    if mode == "functionality_specification":
        zero_detail = {'line': 0, 'branch': 0, 'correct': 0}
    else:  # contracts / assert_specification
        zero_detail = {'AVC': 0, 'TS': 0, 'CE': 0}

    results, metrics_list, task_metrics = {}, [], {}
    for tid, raw in model_out.items():
        # Keep the try around the scoring call only, so a problem while
        # recording/printing can never add a second metrics entry for one task.
        try:
            reward, detail = vr(tid, raw, mode=mode, return_metrics=True)
        except Exception as e:
            # Record a zero score and continue; never stop in a debugger here,
            # which would block unattended batch runs waiting for input.
            print(f"Error evaluating {tid}: {e}")
            reward, detail = 0, dict(zero_detail)
            if verbose:
                print(f"{str(tid):30s}  reward=0.0000  (error occurred)")

        results[tid] = reward
        metrics_list.append(detail)
        task_metrics[tid] = detail

    # --- average · weighted average calculation ---
    stats = summarize_metrics(metrics_list, weights)
    avg_reward = sum(results.values()) / (len(results) or 1)

    print("\n=== SUMMARY ===")
    print(f"Tasks evaluated : {len(results)}")
    print(f"Average reward  : {avg_reward:.4f}")
    for k, v in sorted(stats.items()):
        print(f"{k:24s}: {v:.4f}")

    return results, stats, task_metrics


def save_results_to_csv(results, stats, task_metrics, output_dir, dataset, mode, timestamp=None):
    """Write per-task results and overall statistics to two CSV files.

    Args:
        results: per-task reward, ``{task_id: reward}``.
        stats: aggregate statistics, ``{metric: value}``.
        task_metrics: per-task detailed metrics, ``{task_id: {metric: value}}``.
        output_dir: destination directory (created if missing).
        dataset: dataset name (humaneval/mbpp), used in the file names.
        mode: evaluation mode, used in the file names and to choose which
            statistics rows are written.
        timestamp: file-name suffix; defaults to the current ``%Y%m%d_%H%M%S``.

    Returns:
        ``(task_results_file, stats_file)`` paths.

    The per-task file has one column per metric key seen in ANY task (union of
    keys, sorted); a task lacking a key gets 0. All numbers use 6 decimals.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S") if timestamp is None else timestamp
    os.makedirs(output_dir, exist_ok=True)

    metric_columns = sorted({key for per_task in task_metrics.values() for key in per_task.keys()})

    task_results_file = os.path.join(output_dir, f"{dataset}_{mode}_task_results_{stamp}.csv")
    stats_file = os.path.join(output_dir, f"{dataset}_{mode}_overall_stats_{stamp}.csv")

    # 1. per-task rows: task_id, reward, then every metric column
    with open(task_results_file, 'w', newline='', encoding='utf-8') as handle:
        out = csv.writer(handle)
        out.writerow(['task_id', 'reward'] + metric_columns)
        for tid, score in results.items():
            per_task = task_metrics.get(tid, {})
            cells = [f"{per_task.get(col, 0):.6f}" for col in metric_columns]
            out.writerow([tid, f"{score:.6f}"] + cells)

    # 2. overall statistics in a fixed, mode-specific order (missing stats -> 0)
    stat_order = (
        ['Total_task_id', 'avg_AVC', 'avg_TS', 'avg_CE', 'avg_TQS', 'avg_weighted_total']
        if mode == "assert_specification"
        else ['Total_task_id', 'avg_line', 'avg_branch', 'avg_correct', 'avg_weighted_total']
    )
    with open(stats_file, 'w', newline='', encoding='utf-8') as handle:
        out = csv.writer(handle)
        out.writerow(['metric', 'value'])
        for name in stat_order:
            value = len(results) if name == 'Total_task_id' else stats.get(name, 0)
            out.writerow([name, f"{value:.6f}"])

    print(f"\n=== CSV file saved ===")
    print(f"individual results: {task_results_file}")
    print(f"overall statistics: {stats_file}")

    return task_results_file, stats_file


def _call_with_timeout(func, args, timeout=5):
    """execute in the same process + N seconds timeout"""
    def handler(signum, frame):
        raise TimeoutError("Timeout")
    old_handler = signal.signal(signal.SIGALRM, handler)
    signal.alarm(timeout)
    try:
        return func(*args)
    finally:
        signal.alarm(0)
        signal.signal(signal.SIGALRM, old_handler)

def to_temp_file_and_exec(src_str, entry_point):
    """Write *src_str* to a new ``.py`` temp file, execute it, and return the entry point.

    The source is compiled with the temp file's path as its filename, so tracers
    such as coverage.py attribute executed lines to that file. The file is not
    deleted on success (``delete=False``); the caller needs it on disk afterwards
    (e.g. for coverage analysis) and is responsible for removing it.

    Args:
        src_str: Python source that defines ``entry_point``.
        entry_point: name to fetch from the executed module namespace.

    Returns:
        ``(path, obj)``: the temp file path and the object bound to ``entry_point``.

    Raises:
        SyntaxError, KeyError (``entry_point`` not defined), or anything raised
        while executing the module body. In that case the temp file is removed,
        since the caller never receives its path.
    """
    # UTF-8 explicitly: it is Python's source default, while the platform's
    # locale encoding may reject or mis-encode non-ASCII text in the solution.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".py", delete=False, encoding="utf-8"
    ) as tmp:
        tmp.write(src_str)
        path = tmp.name

    namespace = {}
    try:
        exec(compile(src_str, path, "exec"), namespace)
        func = namespace[entry_point]
    except BaseException:
        try:
            os.remove(path)
        except OSError:
            pass
        raise
    return path, func



# ------------------------------------------------------------------ #
# Functionality Rewarder
# ------------------------------------------------------------------ #
class FunctionalityRewarder:
    """Score a single test input against a reference solution by coverage and correctness.

    For each test input, the reference source is written to a temporary ``.py``
    file and executed under coverage.py; its entry point is then called with the
    parsed input. The per-case reward is::

        metric_weights['line'] * line + metric_weights['branch'] * branch
            + corr_weight * correct

    ``correct`` is 1 when the call returned a non-error result (see
    ``_is_error``). Otherwise every metric, including the reward, is 0.
    """

    def __init__(self, metric_weights=None, corr_weight=0.2, verbose=False):
        """
        Args:
            metric_weights: coverage weights, e.g. ``{'line': 0.5, 'branch': 0.5}``;
                defaults to ``{'line': 1.0}``. Branch tracing is enabled only
                when a ``'branch'`` key is present.
            corr_weight: weight of the correctness term.
            verbose: whether to print diagnostics (stored; not read by this class).
        """
        self.metric_weights = metric_weights or {'line': 1.0}
        self.corr_weight = corr_weight
        self._need_branch = 'branch' in self.metric_weights
        self.verbose = verbose

    # -------------------------------------------------------------- #
    # internal helpers
    # -------------------------------------------------------------- #
    @staticmethod
    def _zero_detail():
        """Metrics reported for an unparsable input or a failed call."""
        return {"line": 0, "branch": 0, "correct": 0, "reward": 0}

    @staticmethod
    def _build_args(func, parsed_args):
        """Map a parsed test input onto the positional arguments of *func*."""
        if len(inspect.signature(func).parameters) == 1:
            # One-parameter function: unwrap a single-element wrapper such as
            # ``[x]`` / ``(x,)``; otherwise the whole value is the one argument.
            if isinstance(parsed_args, (list, tuple)) and len(parsed_args) == 1:
                return [parsed_args[0]]
            return [parsed_args]
        if isinstance(parsed_args, (list, tuple)):
            return parsed_args
        return [parsed_args]

    def _run_case(self, func, parsed_args):
        """Call *func* on *parsed_args*; return its result or an error-marker string."""
        args_to_pass = self._build_args(func, parsed_args)
        try:
            return _call_with_timeout(func, args_to_pass, timeout=5)
        except TimeoutError:
            return "Timeout"
        except Exception as e:  # includes AssertionError raised by contract checks
            return f"Exception: {e}"

    @staticmethod
    def _coverage_ratios(cov, target_path):
        """Return ``(line_ratio, branch_ratio)`` for *target_path*.

        Reads coverage.py's public JSON report totals. Note that the 5th item of
        ``Coverage.analysis2()`` is the *formatted missing-lines string*, not
        branch data, so it cannot be used to count branches.

        A file with no statements gets a line ratio of 0.0. A file with no
        measured branches (or with branch tracing disabled) gets a branch ratio
        of 1.0.
        """
        fd, report_path = tempfile.mkstemp(suffix=".json")
        os.close(fd)
        try:
            cov.json_report(morfs=[target_path], outfile=report_path)
            with open(report_path, encoding="utf-8") as f:
                totals = json.load(f)["totals"]
        finally:
            os.remove(report_path)

        n_stmts = totals.get("num_statements", 0)
        line_ratio = totals.get("covered_lines", 0) / n_stmts if n_stmts else 0.0
        n_branches = totals.get("num_branches", 0)
        branch_ratio = (totals.get("covered_branches", 0) / n_branches
                        if n_branches else 1.0)
        return line_ratio, branch_ratio

    def _coverage_reward(self, code_target, test_code):
        """Run one test input and return its coverage/correctness metrics.

        Args:
            code_target: ``(func_obj, src_str, entry_point)``. Only ``src_str`` and
                ``entry_point`` are used: the function is re-created from source
                so coverage can attribute lines to a real file.
            test_code: the input, either a string parsed with ``safe_eval`` or an
                already-parsed value.

        Returns:
            ``{"line", "branch", "correct", "reward"}``. All values are 0 when the
            input cannot be parsed or the call errors or times out.
        """
        _func_obj, src_str, entry_point = code_target
        parsed_args = safe_eval(test_code) if isinstance(test_code, str) else test_code
        if parsed_args is None:  # safe_eval failed
            return self._zero_detail()

        cov = coverage.Coverage(branch=self._need_branch, data_file=None)
        target_path = None
        try:
            cov.start()
            try:
                # Exec under coverage so the module-level ``def`` lines count as hit.
                target_path, tmp_func = to_temp_file_and_exec(src_str, entry_point)
                result = self._run_case(tmp_func, parsed_args)
            finally:
                # Always stop tracing, even if exec or signature inspection
                # raised; otherwise the tracer stays installed for the process.
                cov.stop()

            if _is_error(result):
                return self._zero_detail()
            line_ratio, branch_ratio = self._coverage_ratios(cov, target_path)
        finally:
            # One temp source file is created per case; remove it once analysed.
            if target_path is not None:
                try:
                    os.remove(target_path)
                except OSError:
                    pass

        correct = 1  # reaching here means the call returned a non-error result
        reward = (self.metric_weights.get('line', 0) * line_ratio
                  + self.metric_weights.get('branch', 0) * branch_ratio
                  + self.corr_weight * correct)
        return {"line": line_ratio, "branch": branch_ratio, "correct": correct, "reward": reward}

    # -------------------------------------------------------------- #
    # public API
    # -------------------------------------------------------------- #
    def compute(self, code_target, test_code, return_detail=False):
        """Return the per-case reward, or the full metrics dict if *return_detail*."""
        detail = self._coverage_reward(code_target, test_code)
        return detail if return_detail else detail["reward"]


# ------------------------------------------------------------------ #
# Contracts Reward
# ------------------------------------------------------------------ #

class ContractRewarder:
    """Score assert-specification test cases by which reference asserts they trigger.

    Each case input is run against the reference solution via ``_run_case_trace``.
    The per-case sets of assert IDs it returns are aggregated by ``_metrics`` into
    a reward and a detail dict (presumably keyed ``AVC`` / ``TS`` / ``CE``; confirm
    against ``_metrics``), using the weights stored in ``self.W``.
    """

    def __init__(self,
                 AVC_weight: float = 0.40,
                 TS_weight: float = 0.45,
                 CE_weight: float = 0.15,
                 verbose: bool = False):
        self.W = dict(AVC=AVC_weight, TS=TS_weight, CE=CE_weight)
        self.verbose = verbose  # stored for callers; not read by this class

    # ---------- public API -----------------------------------------
    def compcute(self,
                 reference_src: str,
                 entry_point: str,
                 parsed_assert_blocks: dict,
                 ) -> "tuple[float, dict]":
        """Score *parsed_assert_blocks* against the reference implementation.

        Args:
            reference_src: source of the reference solution (with its asserts).
            entry_point: name of the function to call inside *reference_src*.
            parsed_assert_blocks: ``{block_id: [{"input": ...}, ...]}``. A string
                ``input`` is parsed with ``safe_eval``; other values are used as is.

        Returns:
            ``(reward, detail)`` exactly as produced by ``_metrics``.

        Raises:
            KeyError: if *reference_src* does not define *entry_point* or a case
                lacks ``"input"``; plus anything the executed source raises.
        """
        ns: dict = {}
        exec(compile(reference_src, "<contract_src>", "exec"), ns)
        fn: types.FunctionType = ns[entry_point]
        param_cnt = len(inspect.signature(fn).parameters)

        # map of source line -> assert ID, and the full set of IDs to cover
        line2id = collect_assert_ids_from_source(reference_src)
        all_ids = set(line2id.values())

        fired: dict = {}  # "<block_id>#<case_idx>" -> IDs reported by _run_case_trace
        for aid, cases in parsed_assert_blocks.items():
            for idx, case in enumerate(cases):
                raw = case["input"]
                args_obj = safe_eval(raw) if isinstance(raw, str) else raw

                if param_cnt == 1:
                    # unwrap a one-element wrapper for single-parameter functions
                    if isinstance(args_obj, (list, tuple)) and len(args_obj) == 1:
                        args_to_pass = [args_obj[0]]
                    else:
                        args_to_pass = [args_obj]
                elif isinstance(args_obj, (list, tuple)):
                    args_to_pass = list(args_obj)
                else:
                    args_to_pass = [args_obj]

                fired[f"{aid}#{idx}"] = _run_case_trace(fn, args_to_pass, line2id)

        reward, detail = _metrics(fired, all_ids, len(fired), self.W)
        return reward, detail

    # Correctly spelled alias; ``compcute`` is kept for existing callers.
    compute = compcute

# ------------------------------------------------------------------ #
# Multi-mode ValidityRewarder
# ------------------------------------------------------------------ #
class ValidityRewarder:
    """Score model-generated test inputs per task, in one of two modes.

    * ``"functionality_specification"``: cases from sections in ``USE_SECTIONS``
      are scored by ``FunctionalityRewarder`` (coverage + correctness). The task
      reward is 1 when the mean normalised case score is at least
      ``func_threshold``, else 0.
    * ``"assert_specification"``: cases are scored by ``ContractRewarder``.

    Args:
        output_jsonl: model output path (kept for interface compatibility; not read here).
        canon_jsonl: canonical JSONL. Each non-blank line needs ``task_id`` and
            ``entry_point``; ``contract`` is optional. The path must contain
            ``humaneval`` or ``mbpp`` (case-insensitive) to pick the solution loader.
        cov_weights, corr_weight: forwarded to ``FunctionalityRewarder``.
        AVC_weight, TS_weight, CE_weight: forwarded to ``ContractRewarder``.
        load_variant: forwarded to ``canonical_solution_load_for_reward``.
        func_threshold: pass mark for the functionality reward (default 0.8).

    Raises:
        ValueError: if the dataset cannot be inferred from ``canon_jsonl``.
    """

    # Modes accepted by ``__call__``.
    _MODES = ("functionality_specification", "assert_specification")

    def __init__(self,
                 output_jsonl,
                 canon_jsonl,          # original DATASET_JSONL
                 cov_weights=None, corr_weight=0.2,
                 AVC_weight=0.4, TS_weight=0.45, CE_weight=0.15,
                 load_variant="no_contracts",
                 func_threshold=0.8,
                 ):
        # 1) load canonical functions for the dataset named in the path
        canon_lower = canon_jsonl.lower()
        if 'humaneval' in canon_lower:
            dataset = 'humaneval'
        elif 'mbpp' in canon_lower:
            dataset = 'mbpp'
        else:
            raise ValueError(
                f"Cannot infer dataset from canon_jsonl={canon_jsonl!r}; "
                "expected 'humaneval' or 'mbpp' in the path")
        self.task_funcs = canonical_solution_load_for_reward(dataset, load_variant)
        self.func_threshold = func_threshold

        # 2) entry points and parsed contracts per task, from the canonical file
        self.entry_map = {}
        self.contracts_map = {}
        with open(canon_jsonl, encoding="utf-8") as f:
            for raw_line in f:
                if not raw_line.strip():
                    continue
                rec = json.loads(raw_line)
                tid = rec["task_id"]
                self.entry_map[tid] = rec["entry_point"]
                self.contracts_map[tid] = self._parse_contract(rec.get("contract"))

        # 3) per-mode rewarders
        self.func_rwd = FunctionalityRewarder(cov_weights, corr_weight, verbose=False)
        self.contract_rwd = ContractRewarder(AVC_weight, TS_weight, CE_weight, verbose=False)

    @staticmethod
    def _parse_contract(contract):
        """Split the ``key: value`` lines of ``_build_contract_str(contract)`` into a dict."""
        cdict = {}
        for text in _build_contract_str(contract).splitlines():
            if ":" in text:
                key, val = text.split(":", 1)
                cdict[key.strip()] = val.strip()
        return cdict

    # ---------------------------------------------------------- #
    def __call__(self, task_id, raw_output, mode="functionality_specification", return_metrics=False):
        """Score *raw_output* for *task_id*.

        Args:
            task_id: task key in the canonical data.
            raw_output: model output string (parsed with
                ``extract_simplified_testcases``) or an already-parsed structure.
            mode: one of ``_MODES``.
            return_metrics: also return the metric dict.

        Returns:
            ``reward``, or ``(reward, metrics)`` when *return_metrics* is true.

        Raises:
            ValueError: for an unknown *mode*.
            KeyError: if *task_id* is not in the canonical file.
        """
        if mode not in self._MODES:
            raise ValueError(f"Unknown mode {mode!r}; expected one of {self._MODES}")

        contracts = self.contracts_map[task_id]
        try:
            func_obj, src_str, entry_point = self.task_funcs[task_id]  # (callable, src, name)
        except Exception as e:
            print(f"Error loading task {task_id}: {e}")
            if mode == "functionality_specification":
                zero = {'line': 0, 'branch': 0, 'correct': 0}
            else:
                zero = {'AVC': 0, 'TS': 0, 'CE': 0}
            return (0, zero) if return_metrics else 0

        if mode == "functionality_specification":
            reward, metrics = self._functionality_score(
                task_id, raw_output, contracts, (func_obj, src_str, entry_point))
        else:
            reward, metrics = self._contract_score(
                task_id, raw_output, contracts, src_str, entry_point)
        return (reward, metrics) if return_metrics else reward

    def _functionality_score(self, task_id, raw_output, contracts, code_target):
        """Return ``(0/1 reward, mean line/branch/correct metrics)`` for functionality mode."""
        if isinstance(raw_output, str):
            parsed_sec = extract_simplified_testcases(task_id, raw_output, "section", contracts)
        else:
            parsed_sec = raw_output

        w_line = self.func_rwd.metric_weights.get("line", 0)
        w_branch = self.func_rwd.metric_weights.get("branch", 0)
        w_corr = self.func_rwd.corr_weight
        max_score = w_line + w_branch + w_corr  # best achievable weighted score

        scores, all_metrics = [], []
        for sec, cases in parsed_sec.items():
            if sec not in USE_SECTIONS:
                continue
            for case in cases:
                try:
                    arg_obj = safe_eval(case["input"])  # string -> object
                    # repr round-trip: compute() re-parses string inputs with
                    # safe_eval, so passing a raw str argument would be re-parsed.
                    d = self.func_rwd.compute(code_target, f"{arg_obj!r}", return_detail=True)
                except Exception as e:
                    print(f"Error processing test case in {task_id}: {e}")
                    continue
                raw = w_line * d["line"] + w_branch * d["branch"] + w_corr * d["correct"]
                scores.append(raw / max_score if max_score else 0.0)  # normalised
                all_metrics.append(d)

        avg_score = sum(scores) / (len(scores) or 1)
        func_reward = 1 if avg_score >= self.func_threshold else 0

        if not all_metrics:
            return func_reward, {'line': 0, 'branch': 0, 'correct': 0}
        n = len(all_metrics)
        avg_metrics = {key: sum(m.get(key, 0) for m in all_metrics) / n
                       for key in ('line', 'branch', 'correct')}
        return func_reward, avg_metrics

    def _contract_score(self, task_id, raw_output, contracts, src_str, entry_point):
        """Return ``(reward, metrics)`` for assert-specification mode; ``(0, {})`` on error."""
        try:
            if isinstance(raw_output, str):
                parsed_assert = extract_simplified_testcases(
                    task_id, raw_output, "assert_specification", contracts)
            else:
                parsed_assert = raw_output
            return self.contract_rwd.compcute(
                reference_src=src_str, entry_point=entry_point,
                parsed_assert_blocks=parsed_assert)
        except Exception as e:
            print(f"Error processing contracts for {task_id}: {e}")
            return 0, {}

# --------------------------------------------------------------------------- #
def signal_handler(signum, frame):
    """SIGINT (Ctrl+C) handler: print a notice, then exit with status 0.

    ``signum`` and ``frame`` are the standard signal-handler arguments; both
    are ignored.
    """
    notice = "\n\n⚠️  Process interrupted by user. Cleaning up and exiting..."
    print(notice)
    raise SystemExit(0)

if __name__ == "__main__":
    # register Ctrl+C signal handler
    signal.signal(signal.SIGINT, signal_handler)

    def _str2bool(value):
        """Parse a CLI boolean.

        ``type=bool`` would treat every non-empty string as True, so
        ``--verbose False`` would enable verbose output.
        """
        return str(value).strip().lower() in ("1", "true", "t", "yes", "y", "on")

    parser = argparse.ArgumentParser(description='Reward function calculation script')
    parser.add_argument('--data', type=str, default='mbpp', choices=['humaneval', 'mbpp'], help='Dataset to use (humaneval or mbpp)')
    parser.add_argument('--canon_jsonl', type=str, required=True, help='Path to canonical JSONL file')
    parser.add_argument('--output_jsonl', type=str, required=True, help='Path to model output JSONL file')
    # 'contracts' is accepted as a legacy alias of 'assert_specification'.
    parser.add_argument('--mode', type=str, default='assert_specification',
                        choices=['functionality_specification', 'assert_specification', 'contracts'],
                        help='Reward calculation mode')
    parser.add_argument('--line-weight', type=float, default=0.5, help='Line coverage weight')
    parser.add_argument('--branch-weight', type=float, default=0.5, help='Branch coverage weight')
    parser.add_argument('--corr-weight', type=float, default=0.7, help='Correctness weight')
    parser.add_argument('--avc-weight', type=float, default=0.40, help='AVC weight')
    parser.add_argument('--ts-weight', type=float, default=0.45, help='TS weight')
    parser.add_argument('--ce-weight', type=float, default=0.15, help='CE weight')
    parser.add_argument('--load-variant', type=str, default='in_contracts', choices=['no_contracts', 'in_contracts'], help='Load variant')
    parser.add_argument('--verbose', type=_str2bool, nargs='?', const=True, default=False,
                        help='Enable verbose output (--verbose, or --verbose true/false)')
    parser.add_argument('--output_dir', type=str, default='../../code/metrics_output', help='Directory to save results')
    parser.add_argument('--batch_size', type=int, default=0, help='Batch size (0 means process all)')
    parser.add_argument('--start_idx', type=int, default=0, help='Start index')
    args = parser.parse_args()

    if args.mode == 'contracts':
        # ValidityRewarder only understands the two *_specification modes.
        args.mode = 'assert_specification'

    CANON_JSONL = args.canon_jsonl
    OUTPUT_JSONL = args.output_jsonl

    print(f"Dataset: {args.data}")
    print(f"Canonical JSONL: {CANON_JSONL}")
    print(f"Output JSONL: {OUTPUT_JSONL}")
    print(f"Mode: {args.mode}")

    # ------------------------------------------------------------------ #
    # A. model output file -> {task_id: raw_output} dictionary
    # ------------------------------------------------------------------ #
    model_out = {}
    output_path = Path(OUTPUT_JSONL)

    if 'HumanEvalPlus.jsonl' in OUTPUT_JSONL or 'MbppPlus.jsonl' in OUTPUT_JSONL:
        # EvalPlus dataset files: use their plus_input test inputs.
        with open(OUTPUT_JSONL, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                rec = json.loads(line)
                tid = rec.get("task_id")
                raw = rec["plus_input"]
                if raw:
                    model_out[tid] = raw

    elif output_path.suffix.lower() == ".json":
        # Checked before the 'jsonl' substring test so a .json file inside a
        # directory whose name contains "jsonl" is still parsed as JSON.
        with output_path.open(encoding="utf-8") as f:
            data = json.load(f)
        model_name = list(data.keys())[0]  # only the first top-level key is used
        for tid in data[model_name]:
            model_out[tid] = data[model_name][tid]

    elif 'jsonl' in OUTPUT_JSONL:
        with open(OUTPUT_JSONL, encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue
                rec = json.loads(line)
                tid = rec.get("task_id") or rec["name"]  # either one
                raw = rec["grammar"][0]["production"][0]
                if raw:
                    model_out[tid] = raw

    else:
        print(f"⚠ Unrecognized output file format (expected .jsonl or .json): {OUTPUT_JSONL}")

    print(f"✔ model output {len(model_out)} loaded")

    # ------------------------------------------------------------------ #
    # B. initialize rewarder
    # ------------------------------------------------------------------ #
    vr = ValidityRewarder(
        output_jsonl = OUTPUT_JSONL,
        canon_jsonl  = CANON_JSONL,      # canonical solution·entry_point
        cov_weights  = {'line': args.line_weight, 'branch': args.branch_weight},
        corr_weight  = args.corr_weight,  # used only by the functionality (coverage) mode
        AVC_weight   = args.avc_weight,
        TS_weight    = args.ts_weight,
        CE_weight    = args.ce_weight,
        load_variant = args.load_variant
    )

    vr.func_rwd.verbose = args.verbose
    vr.contract_rwd.verbose = args.verbose

    # ------------------------------------------------------------------ #
    # C. calculate reward and average metrics
    # ------------------------------------------------------------------ #
    if args.batch_size > 0:
        model_out_items = list(model_out.items())
        start_idx = args.start_idx
        end_idx = min(start_idx + args.batch_size, len(model_out_items))

        print(f"Batch processing: {start_idx} ~ {end_idx-1} / total {len(model_out_items)}")

        # select items corresponding to the batch
        batch_model_out = dict(model_out_items[start_idx:end_idx])
    else:
        batch_model_out = model_out

    if args.mode == "functionality_specification":
        w = {'line': args.line_weight,
             'branch': args.branch_weight,
             'correct': args.corr_weight}
    else:  # assert_specification
        w = {'AVC': args.avc_weight,
             'TS' : args.ts_weight,
             'CE' : args.ce_weight}

    results, stats, task_metrics = evaluate_and_report(vr, batch_model_out, mode=args.mode, weights=w, verbose=args.verbose)

    # save results to CSV file
    output_dir = args.output_dir
    # NOTE(review): the date-only stamp means several runs/batches on the same
    # day with the same dataset+mode overwrite each other's CSVs — confirm intended.
    timestamp = datetime.now().strftime("%Y%m%d")
    task_file, stats_file = save_results_to_csv(
        results, stats, task_metrics, output_dir, args.data, args.mode, timestamp
    )
        

