"""
Metrics tracking and experiment logging for ACEAS.
"""

import json
import time
from dataclasses import dataclass, field, asdict
from typing import Dict, List, Any, Optional
from pathlib import Path
import numpy as np


@dataclass
class TrainingMetrics:
    """Snapshot of the metrics recorded for one training step.

    The first six fields are required (positional order as declared). The two
    mapping fields are optional and each instance gets its own empty dict.
    """

    # Where in training this snapshot was taken.
    timestep: int
    update: int

    # Scalar quality / speed measurements for the step.
    policy_loss: float
    mean_reward: float
    success_rate: float
    throughput: float

    # Optional breakdowns; default_factory gives a fresh dict per instance.
    difficulty_distribution: Dict[int, float] = field(default_factory=dict)
    curriculum_stats: Dict[str, Any] = field(default_factory=dict)


@dataclass
class EvalMetrics:
    """Result summary for one evaluation pass.

    Only ``timestep`` and ``pass_at_1`` are required; the remaining
    numeric fields default to zero when not supplied.
    """

    timestep: int  # training timestep at which the evaluation was run
    pass_at_1: float

    # Optional extras, zero unless the caller provides them.
    pass_at_10: float = field(default=0.0)
    avg_reward: float = field(default=0.0)
    num_tasks: int = field(default=0)


class MetricsTracker:
    """
    Tracks and aggregates training metrics.

    Every value passed to ``add_training_step`` is retained for the lifetime
    of the tracker. Per-sample lists (rewards, successes, per-difficulty
    lists) grow by one entry per sample; per-step lists (loss, throughput,
    timings) grow by one entry per call. ``get_recent_stats`` averages the
    last ``window_size`` entries of each list, so the window spans
    ``window_size`` samples for per-sample metrics but ``window_size`` calls
    for per-step metrics.

    Args:
        window_size: Number of most recent entries averaged by
            ``get_recent_stats``. Must be >= 1.
        difficulty_levels: Allowed difficulty tags, in reporting order.
            Defaults to ``1..5``; duplicates are ignored.

    Raises:
        ValueError: If ``window_size`` < 1 or ``difficulty_levels`` is empty.
    """

    def __init__(
        self,
        window_size: int = 100,
        difficulty_levels: Optional[List[int]] = None,
    ):
        if window_size < 1:
            # ``xs[-0:]`` is the whole list and a negative window slices from
            # the front, so non-positive values would silently misreport.
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self.window_size = window_size

        # dict.fromkeys de-duplicates while keeping the caller's order.
        levels = list(dict.fromkeys(
            range(1, 6) if difficulty_levels is None else difficulty_levels
        ))
        if not levels:
            raise ValueError("difficulty_levels must not be empty")
        self.difficulty_levels: List[int] = levels

        # Training metrics (per-sample: rewards/successes; per-step: the rest)
        self.rewards: List[float] = []
        self.successes: List[bool] = []
        self.losses: List[float] = []
        self.throughputs: List[float] = []

        # Per-difficulty metrics
        self.difficulty_rewards: Dict[int, List[float]] = {d: [] for d in levels}
        self.difficulty_successes: Dict[int, List[bool]] = {d: [] for d in levels}

        # Timing metrics (one entry per add_training_step call)
        self.collection_times: List[float] = []
        self.train_times: List[float] = []
        self.update_times: List[float] = []

    def add_training_step(
        self,
        rewards: List[float],
        successes: List[bool],
        difficulties: List[int],
        loss: float,
        throughput: float,
        collection_time: float = 0.0,
        train_time: float = 0.0,
        update_time: float = 0.0,
    ):
        """
        Record metrics from a training step.

        ``rewards``, ``successes`` and ``difficulties`` describe the same
        samples and are consumed in lockstep. All inputs are validated before
        any state changes, so a rejected call leaves the tracker untouched.

        Args:
            rewards: Per-sample rewards.
            successes: Per-sample success flags.
            difficulties: Per-sample difficulty tags; each must be one of
                ``self.difficulty_levels``.
            loss: Loss value for this step.
            throughput: Throughput value for this step.
            collection_time, train_time, update_time: Per-step timings.

        Raises:
            ValueError: If the three per-sample sequences differ in length.
            KeyError: If a difficulty is not one of the configured levels.
        """
        # Materialise once so iterators/arrays are consumed exactly once and
        # can be length-checked.
        rewards = list(rewards)
        successes = list(successes)
        difficulties = list(difficulties)

        if not len(rewards) == len(successes) == len(difficulties):
            raise ValueError(
                "rewards, successes and difficulties must have equal length, got "
                f"{len(rewards)}, {len(successes)}, {len(difficulties)}"
            )
        unknown = [d for d in difficulties if d not in self.difficulty_rewards]
        if unknown:
            raise KeyError(
                f"unknown difficulty level(s) {set(unknown)}; "
                f"expected one of {self.difficulty_levels}"
            )

        self.rewards.extend(rewards)
        self.successes.extend(successes)
        self.losses.append(loss)
        self.throughputs.append(throughput)

        # Per-difficulty
        for reward, success, diff in zip(rewards, successes, difficulties):
            self.difficulty_rewards[diff].append(reward)
            self.difficulty_successes[diff].append(success)

        # Timing
        self.collection_times.append(collection_time)
        self.train_times.append(train_time)
        self.update_times.append(update_time)

    def get_recent_stats(self) -> Dict[str, Any]:
        """
        Get statistics over the most recent ``window_size`` entries.

        Always contains ``mean_reward``, ``std_reward``, ``success_rate``,
        ``mean_loss``, ``mean_throughput`` (0.0 when no data) and
        ``total_samples`` (all-time sample count). Adds
        ``difficulty_{d}_reward`` / ``difficulty_{d}_success`` only for levels
        that have data, and ``mean_*_time`` keys once any step was recorded.
        """
        window = self.window_size

        # Slicing an empty list already yields [], so no emptiness guard needed.
        recent_rewards = self.rewards[-window:]
        recent_successes = self.successes[-window:]
        recent_losses = self.losses[-window:]
        recent_throughputs = self.throughputs[-window:]

        stats: Dict[str, Any] = {
            "mean_reward": np.mean(recent_rewards) if recent_rewards else 0.0,
            "std_reward": np.std(recent_rewards) if recent_rewards else 0.0,
            "success_rate": np.mean(recent_successes) if recent_successes else 0.0,
            "mean_loss": np.mean(recent_losses) if recent_losses else 0.0,
            "mean_throughput": np.mean(recent_throughputs) if recent_throughputs else 0.0,
            "total_samples": len(self.rewards),
        }

        # Per-difficulty stats (window applies per level)
        for d in self.difficulty_levels:
            recent_d = self.difficulty_rewards[d][-window:]
            if recent_d:
                stats[f"difficulty_{d}_reward"] = np.mean(recent_d)
                stats[f"difficulty_{d}_success"] = np.mean(self.difficulty_successes[d][-window:])

        # Timing stats (the three timing lists are always appended together)
        if self.collection_times:
            stats["mean_collection_time"] = np.mean(self.collection_times[-window:])
            stats["mean_train_time"] = np.mean(self.train_times[-window:])
            stats["mean_update_time"] = np.mean(self.update_times[-window:])

        return stats

    def get_difficulty_distribution(self) -> Dict[int, float]:
        """
        Get the all-time empirical distribution of sample difficulties.

        Returns a uniform distribution over the configured levels when no
        samples have been recorded yet.
        """
        counts = {d: len(self.difficulty_rewards[d]) for d in self.difficulty_levels}
        total = sum(counts.values())
        if total == 0:
            uniform = 1.0 / len(counts)
            return {d: uniform for d in counts}

        return {d: c / total for d, c in counts.items()}


class ExperimentLogger:
    """
    Logs experiment results to files.

    All files live in ``output_dir`` (created on construction) and are
    prefixed with ``experiment_name``:

    * ``<name>_log.jsonl``    - one JSON object per ``log`` call, appended.
    * ``<name>_config.json``  - written (overwritten) by ``log_config``.
    * ``<name>_results.json`` - written (overwritten) by ``log_final_results``.

    NumPy scalars and arrays are stored as JSON numbers/lists. Other values
    that ``json`` cannot encode natively are stored as ``str(value)``.
    """

    def __init__(self, output_dir: str, experiment_name: str = "aceas"):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.experiment_name = experiment_name

        self.log_file = self.output_dir / f"{experiment_name}_log.jsonl"
        # Reference point for the relative "timestamp" field written by log().
        self.start_time = time.time()

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """
        Fallback encoder for objects ``json`` cannot serialise natively.

        NumPy scalars (e.g. ``np.float32``, ``np.int64``) are not subclasses
        of Python ``float``/``int``, so plain ``default=str`` would write
        them as JSON strings. Convert them (and arrays) to native values so
        metrics stay numeric. Everything else still falls back to ``str``.
        """
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return str(obj)

    def log(self, data: Dict[str, Any], step: Optional[int] = None):
        """
        Append one entry to the JSONL log.

        The entry holds ``timestamp`` (seconds since this logger was created),
        ``step`` and every key of ``data``. A key in ``data`` named
        ``timestamp`` or ``step`` overrides the generated value.
        """
        entry = {
            "timestamp": time.time() - self.start_time,
            "step": step,
            **data,
        }

        # Opened per call so the entry is flushed and the file closed after
        # every log() call.
        with open(self.log_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(entry, default=self._json_default) + "\n")

    def log_config(self, config: Dict[str, Any]):
        """Write the experiment configuration to ``<name>_config.json``."""
        config_file = self.output_dir / f"{self.experiment_name}_config.json"
        with open(config_file, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2, default=self._json_default)

    def log_final_results(self, results: Dict[str, Any]):
        """Write final experiment results to ``<name>_results.json``."""
        results_file = self.output_dir / f"{self.experiment_name}_results.json"
        with open(results_file, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, default=self._json_default)

    def load_logs(self) -> List[Dict[str, Any]]:
        """
        Load every entry written to the JSONL log, in file order.

        Returns an empty list if nothing has been logged yet. Blank lines
        are skipped.
        """
        if not self.log_file.exists():
            return []
        with open(self.log_file, encoding="utf-8") as f:
            return [json.loads(line) for line in f if line.strip()]


def compute_pass_at_k(
    results: List[Dict[str, Any]],
    k: int = 1,
) -> float:
    """
    Compute pass@k metric, averaged over distinct tasks.

    For a task with ``n`` attempts of which ``c`` passed, uses the unbiased
    estimator ``1 - C(n - c, k) / C(n, k)``. This is one minus the probability
    that ``k`` attempts drawn without replacement all fail. When
    ``n - c < k`` every draw of ``k`` contains a success, so the estimate is
    exactly 1.0. The estimator is undefined for tasks with fewer than ``k``
    attempts; those count as passed iff any attempt passed.

    Args:
        results: List of dicts with 'task_id' and 'passed' keys
        k: Number of attempts to consider; must be >= 1.

    Returns:
        pass@k rate in [0, 1], or 0.0 if ``results`` is empty.

    Raises:
        ValueError: If ``k`` < 1.
    """
    from collections import defaultdict
    from math import comb

    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    # Group attempt outcomes by task
    task_results: Dict[Any, List[bool]] = defaultdict(list)
    for r in results:
        task_results[r["task_id"]].append(bool(r["passed"]))

    if not task_results:
        return 0.0

    total = 0.0
    for attempts in task_results.values():
        n = len(attempts)
        c = sum(attempts)

        if n < k:
            # Not enough attempts for the estimator; fall back to "any pass".
            total += 1.0 if c > 0 else 0.0
        elif n - c < k:
            # Fewer than k failures: every size-k draw includes a success.
            total += 1.0
        else:
            total += 1.0 - comb(n - c, k) / comb(n, k)

    return total / len(task_results)


if __name__ == "__main__":
    # Smoke test for this module. The asserts check invariants that hold for
    # any random draw; the RNG is seeded so the printed output is reproducible.
    import tempfile

    print("Testing metrics modules...")
    rng = np.random.default_rng(0)

    # Test MetricsTracker
    tracker = MetricsTracker()
    num_steps, batch = 100, 8

    for _ in range(num_steps):
        tracker.add_training_step(
            rewards=list(rng.random(batch)),
            successes=list(rng.random(batch) > 0.5),
            difficulties=list(rng.integers(1, 6, batch)),
            loss=float(rng.random()),
            throughput=50 + float(rng.random()) * 10,
            collection_time=0.5,
            train_time=0.1,
            update_time=0.6,
        )

    stats = tracker.get_recent_stats()
    print(f"Recent stats: {stats}")
    assert stats["total_samples"] == num_steps * batch
    assert 0.0 <= stats["success_rate"] <= 1.0

    dist = tracker.get_difficulty_distribution()
    print(f"Difficulty distribution: {dist}")
    assert abs(sum(dist.values()) - 1.0) < 1e-9

    # Test ExperimentLogger
    with tempfile.TemporaryDirectory() as tmpdir:
        logger = ExperimentLogger(tmpdir, "test")
        logger.log_config({"lr": 1e-5, "batch_size": 32})
        logger.log({"loss": 0.5, "reward": 0.8}, step=1)
        logger.log({"loss": 0.4, "reward": 0.9}, step=2)
        logger.log_final_results({"final_pass_at_1": 0.75})

        logs = logger.load_logs()
        print(f"Logged {len(logs)} entries")
        assert len(logs) == 2
        assert [entry["step"] for entry in logs] == [1, 2]

    # Test pass@k
    results = [
        {"task_id": "t1", "passed": True},
        {"task_id": "t1", "passed": False},
        {"task_id": "t2", "passed": False},
        {"task_id": "t2", "passed": False},
        {"task_id": "t3", "passed": True},
    ]
    pass_1 = compute_pass_at_k(results, k=1)
    print(f"pass@1: {pass_1:.3f}")
    assert 0.0 <= pass_1 <= 1.0
    # pass@k must not decrease as k grows.
    assert compute_pass_at_k(results, k=2) >= pass_1

    print("\nAll metrics tests passed!")
