#!/usr/bin/env python3
"""
Risk-aligned experiment for the poison-class scenario using conformal quantile
calibration (α identified with the target non-coverage risk R0).

Outputs both CNCRC variants, Standard CP, Cost-Aware, and CRC, all calibrated
from the same calibration split, and reports metrics on the test split.
"""
from __future__ import annotations

import argparse
import os
import sys
from dataclasses import dataclass
from typing import Callable, Dict, List

import numpy as np

sys.path.append(os.path.abspath('.'))

from src.cncrc.core.calibration import calibrate_quantile
from src.cncrc.core.risk_weighted_score import calculate_risk_weighted_score


@dataclass
class Sample:
    """One synthetic example: a per-class probability vector and its label."""

    # Per-class probability vector; generate_dataset renormalises it to sum to 1.
    probs: np.ndarray
    # Class index that generate_dataset draws at random using ``probs`` as weights.
    label: int


@dataclass
class Metrics:
    """Aggregate results that ``evaluate`` computes for one method on one split."""

    # Fraction of samples whose label is inside the prediction set.
    coverage: float
    # Mean prediction-set size.
    aps: float
    # Mean non-coverage cost: cost_nc[label] when the label is missed, else 0.
    rnc: float
    # Mean ambiguity cost: for covered samples with a set larger than one, the
    # largest cost_matrix[label, y] over the other labels in the set; else 0.
    amb_cost: float
    # Fraction of poison-labelled samples that are covered (NaN if there are none).
    poison_coverage: float


# ---------------------------------------------------------------------------
# Data generation
# ---------------------------------------------------------------------------


def generate_dataset(
    n_samples: int,
    n_classes: int,
    poison_prob: float,
    poison_cost: float,
    regular_cost: float,
    base_interaction_cost: float,
    poison_row_cost: float,
    seed: int,
) -> tuple[List[Sample], np.ndarray, np.ndarray]:
    """Generate a synthetic split together with its cost structures.

    Class 0 is the poison class. Every sample starts from the same base
    distribution (``poison_prob`` for class 0, the remaining mass shared equally
    by the other classes), is perturbed with small Gaussian noise, clipped to
    stay positive, renormalised, and then used to draw the sample's label.

    Args:
        n_samples: Number of samples to draw.
        n_classes: Total number of classes, including the poison class (>= 2).
        poison_prob: Base probability assigned to the poison class.
        poison_cost: Non-coverage cost of the poison class.
        regular_cost: Non-coverage cost of every other class.
        base_interaction_cost: Scale of the random pairwise ambiguity costs.
        poison_row_cost: Ambiguity cost written to the poison row and column.
        seed: Seed for the NumPy random generator.

    Returns:
        ``(samples, cost_nc, cost_matrix)`` where ``cost_nc`` has one entry per
        class and ``cost_matrix`` is a symmetric ``n_classes x n_classes`` array
        with a zero diagonal.

    Raises:
        ValueError: If ``n_classes`` is smaller than 2 (there would be no
            regular class to share the non-poison probability mass).
    """
    if n_classes < 2:
        raise ValueError(f"n_classes must be at least 2, got {n_classes}")

    rng = np.random.default_rng(seed)
    n_regular = n_classes - 1
    base_probs = np.array([poison_prob] + [(1 - poison_prob) / n_regular] * n_regular)

    samples: List[Sample] = []
    for _ in range(n_samples):
        # Noise and label are drawn per sample, in this order, so the random
        # stream (and therefore the dataset) is reproducible for a given seed.
        noise = rng.normal(0.0, 0.001, size=n_classes)
        probs = base_probs + noise
        probs = np.clip(probs, 1e-6, None)  # keep every probability strictly positive
        probs /= probs.sum()
        label = int(rng.choice(n_classes, p=probs))
        samples.append(Sample(probs=probs, label=label))

    cost_nc = np.array([poison_cost] + [regular_cost] * n_regular, dtype=float)

    # Random pairwise costs in [0.5, 1.5) * base_interaction_cost, symmetrised.
    rand_matrix = rng.random((n_classes, n_classes))
    cost_matrix = base_interaction_cost * (0.5 + rand_matrix)
    cost_matrix = 0.5 * (cost_matrix + cost_matrix.T)
    # Any pairing that involves the poison class gets a fixed cost instead.
    cost_matrix[0, :] = poison_row_cost
    cost_matrix[:, 0] = poison_row_cost
    np.fill_diagonal(cost_matrix, 0.0)

    return samples, cost_nc, cost_matrix


# ---------------------------------------------------------------------------
# Scores
# ---------------------------------------------------------------------------


def cncrc_sum_score(probs: np.ndarray, cost_matrix: np.ndarray, y: int) -> float:
    """Return the probability-weighted sum of row ``y`` of ``cost_matrix``."""
    weighted_costs = probs * cost_matrix[y]
    return float(weighted_costs.sum())


def cost_aware_score(
    probs: np.ndarray,
    cost_matrix: np.ndarray,
    cost_nc: np.ndarray,
    y: int,
    lam: float,
) -> float:
    """Return ``1 - probs[y]`` plus a ``lam``-scaled cost penalty for label ``y``.

    The penalty is the non-coverage cost of ``y`` plus the largest entry in
    row ``y`` of ``cost_matrix``.
    """
    worst_ambiguity = float(cost_matrix[y].max())
    penalty = lam * (cost_nc[y] + worst_ambiguity)
    return (1.0 - probs[y]) + penalty


def crc_prediction_set(
    probs: np.ndarray,
    cost_nc: np.ndarray,
    q: float,
) -> List[int]:
    """Build a prediction set by greedily removing expected non-coverage cost.

    Labels are ranked by ``probs[y] * cost_nc[y]`` (largest first) and added
    until the cost still left outside the set is at most ``q``. The top-ranked
    label is always included, and every label with probability of at least
    0.05 is appended afterwards.
    """
    ranked = sorted(
        ((probs[label] * cost_nc[label], probs[label], label) for label in range(len(probs))),
        reverse=True,
    )

    # Expected cost of the labels that are not (yet) in the set.
    outstanding = sum(entry[0] for entry in ranked)
    selected: List[int] = []
    for weight, _, label in ranked:
        if outstanding <= q:
            break
        selected.append(label)
        outstanding -= weight

    # Never return an empty set: fall back to the top-ranked label.
    if not selected:
        selected.append(ranked[0][2])

    # Include remaining moderately likely labels to mimic CRC behaviour of large sets
    already_in = set(selected)
    extras = [
        label
        for _, prob, label in ranked
        if label not in already_in and prob >= 0.05
    ]
    return selected + extras


# ---------------------------------------------------------------------------
# Evaluation
# ---------------------------------------------------------------------------


def evaluate(
    samples: List[Sample],
    cost_matrix: np.ndarray,
    cost_nc: np.ndarray,
    builder: Callable[[np.ndarray], List[int]],
    poison_class: int,
) -> Metrics:
    """Score a prediction-set builder on ``samples``.

    Args:
        samples: Samples to evaluate on.
        cost_matrix: Pairwise ambiguity costs, indexed ``[true_label, other]``.
        cost_nc: Per-class non-coverage costs.
        builder: Maps a probability vector to a list of predicted labels.
        poison_class: Label whose coverage is tracked separately.

    Returns:
        A ``Metrics`` record with the per-sample quantities averaged; poison
        coverage is NaN when no sample carries the poison label.
    """
    hit_flags = []
    set_sizes = []
    miss_costs = []
    ambiguity_costs = []
    n_poison = 0
    n_poison_covered = 0

    for item in samples:
        truth = item.label
        pred_set = builder(item.probs)
        is_hit = truth in pred_set

        hit_flags.append(is_hit)
        set_sizes.append(len(pred_set))
        miss_costs.append(0.0 if is_hit else float(cost_nc[truth]))

        # Ambiguity is only charged when the label is covered alongside others.
        worst_confusion = 0.0
        if is_hit and len(pred_set) > 1:
            worst_confusion = max(
                float(cost_matrix[truth, other]) for other in pred_set if other != truth
            )
        ambiguity_costs.append(worst_confusion)

        if truth == poison_class:
            n_poison += 1
            if is_hit:
                n_poison_covered += 1

    poison_cov = n_poison_covered / n_poison if n_poison else np.nan
    return Metrics(
        coverage=float(np.mean(hit_flags)),
        aps=float(np.mean(set_sizes)),
        rnc=float(np.mean(miss_costs)),
        amb_cost=float(np.mean(ambiguity_costs)),
        poison_coverage=poison_cov,
    )


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def run_experiment(args: argparse.Namespace) -> None:
    """Calibrate every method on one split and print its test-split metrics.

    A calibration split (seed 123) and a test split (seed 789) are generated
    with identical settings. Each method's threshold is obtained by passing its
    calibration scores to ``calibrate_quantile`` at level ``args.target_rnc``;
    the resulting prediction sets are then scored with ``evaluate`` on the test
    split and printed as a table, followed by the thresholds themselves.

    Args:
        args: Parsed CLI namespace providing ``split_size``, ``n_classes``,
            ``poison_prob``, ``poison_cost``, ``regular_cost``,
            ``base_interaction_cost``, ``poison_row_cost``, ``target_rnc`` and
            ``lambda_param``.
    """
    dataset_kwargs = dict(
        n_samples=args.split_size,
        n_classes=args.n_classes,
        poison_prob=args.poison_prob,
        poison_cost=args.poison_cost,
        regular_cost=args.regular_cost,
        base_interaction_cost=args.base_interaction_cost,
        poison_row_cost=args.poison_row_cost,
    )
    # The cost arrays of the calibration split are used everywhere; the ones
    # generated alongside the test split are discarded.
    cal, cost_nc, cost_matrix = generate_dataset(seed=123, **dataset_kwargs)
    test, _, _ = generate_dataset(seed=789, **dataset_kwargs)

    alpha = args.target_rnc  # Conformal quantile level identified with R0
    poison_class = 0

    def calibrate(scores: List[float]) -> float:
        return calibrate_quantile(np.array(scores), alpha)

    # One score function per threshold-based method. The same function is used
    # for calibration (on the true label) and for building prediction sets (on
    # every candidate label), so the two can never drift apart.
    score_fns: Dict[str, Callable[[np.ndarray, int], float]] = {
        "Standard CP": lambda probs, y: 1.0 - probs[y],
        "Cost-Aware": lambda probs, y: cost_aware_score(
            probs, cost_matrix, cost_nc, y, args.lambda_param
        ),
        "CNCRC-MAX": lambda probs, y: calculate_risk_weighted_score(probs, cost_matrix, y),
        "CNCRC-SUM": lambda probs, y: cncrc_sum_score(probs, cost_matrix, y),
    }
    thresholds: Dict[str, float] = {
        name: calibrate([score(sample.probs, sample.label) for sample in cal])
        for name, score in score_fns.items()
    }

    # CRC
    # NOTE(review): this threshold is calibrated on prob / cost ratios, while
    # crc_prediction_set compares it against a residual of prob * cost terms.
    # Confirm the two scales are meant to be comparable.
    crc_scores = [
        (sample.probs[sample.label]) / (cost_nc[sample.label] + 1e-6)
        for sample in cal
    ]
    crc_q = calibrate(crc_scores)

    def threshold_builder(
        score: Callable[[np.ndarray, int], float],
        q: float,
    ) -> Callable[[np.ndarray], List[int]]:
        # A factory (rather than a lambda in a loop) binds score and q per method.
        def build(probs: np.ndarray) -> List[int]:
            accepted = [y for y in range(len(probs)) if score(probs, y) <= q]
            # Never return an empty set: fall back to the most likely label.
            return accepted or [int(np.argmax(probs))]

        return build

    builders: Dict[str, Callable[[np.ndarray], List[int]]] = {
        name: threshold_builder(score, thresholds[name])
        for name, score in score_fns.items()
    }
    builders["CRC"] = lambda probs: crc_prediction_set(probs, cost_nc, crc_q)

    print("=== Poison scenario with conformal quantile calibration ===")
    print(f"Target R_NC (treated as quantile level) = {alpha:.3f}")

    header = f"{'Method':<12} {'RNC':>6} {'Coverage':>9} {'APS':>7} {'AmbCost':>9} {'PoisonCov':>11}"
    print(header)
    print('-' * len(header))

    for name, builder in builders.items():
        metrics = evaluate(test, cost_matrix, cost_nc, builder, poison_class)
        print(
            f"{name:<12} {metrics.rnc:>6.4f} {metrics.coverage:>9.3f} {metrics.aps:>7.2f} "
            f"{metrics.amb_cost:>9.4f} {metrics.poison_coverage:>11.3f}"
        )

    print("\nCalibrated thresholds (q):")
    print(f"  Standard CP : {thresholds['Standard CP']:.6f}")
    print(f"  Cost-Aware  : {thresholds['Cost-Aware']:.6f} (λ={args.lambda_param})")
    print(f"  CNCRC-MAX   : {thresholds['CNCRC-MAX']:.6f}")
    print(f"  CNCRC-SUM   : {thresholds['CNCRC-SUM']:.6f}")
    print(f"  CRC         : {crc_q:.6f}")


if __name__ == '__main__':
    cli = argparse.ArgumentParser(description="Risk-aligned poison experiment (quantile calibration)")

    # (flag, type, default), listed in the order they appear in --help.
    option_specs = (
        ('--target_rnc', float, 0.10),
        ('--n_classes', int, 5),
        ('--split_size', int, 5000),
        ('--poison_prob', float, 0.002),
        ('--poison_cost', float, 150.0),
        ('--regular_cost', float, 1.0),
        ('--base_interaction_cost', float, 0.5),
        ('--poison_row_cost', float, 0.01),
        ('--lambda_param', float, 0.1),
    )
    for flag, value_type, fallback in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback)

    run_experiment(cli.parse_args())
