"""
Universal experiment manager for contextual bandit training.

This module provides a flexible experiment manager that handles parameter
configuration, training tracking, and regret computation using final arm
parameters as ground truth.
"""

import json
import math
import os
import random
from dataclasses import asdict, dataclass, field
from pprint import pprint
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from tqdm import tqdm

from .dynamic_contextual_bandit import DynamicArmContextualBandit
from .llm_judge import JudgeLLM
from .multihop_rag_bandit import MultiHopRAGBandit
from .reranker import LLMReranker
from .reranker_contextual_bandit import RerankingContextualBandit
from .universal_bandit_optimizer import (
    NormConstraintType,
    UniversalBanditOptimizer,
    create_unit_sphere_optimizer,
)
from .universal_contextual_bandit import UniversalContextualBandit
from .universal_data_loader import UniversalBanditDataLoader


@dataclass
class UniversalBanditConfig:
    """Configuration for universal bandit experiments.

    Every field has a default, so instances are usually built with keyword
    overrides only. ``extra_config`` carries experiment-specific options;
    keys read by ``UniversalBanditExperiment`` in this module include
    ``exp_type`` ("rerank", "multihop", "dynamic"; anything else means the
    plain bandit), ``masked_arm_ratio`` and ``unmask_progress_point``.
    """

    # Data parameters
    embedding_model: str = "large"  # 'ada', 'small', 'large'
    true_embedding_model: str = "large"  # Model to use for true embeddings
    max_queries: Optional[int] = None  # None for all queries
    add_noise: bool = False  # forwarded to data_loader.load_data()
    noise_std: float = 0.1  # forwarded to data_loader.load_data()

    # Training parameters
    n_epochs: int = 1  # Number of full passes through all queries
    batch_size: int = 10  # queries per batch; an epoch's last batch may be smaller
    learning_rate: float = 1e-5  # base LR, rescaled by the scheduler below
    temperature: float = 0.3  # passed through to the bandit constructor
    epsilon: float = 0.0  # Epsilon-greedy exploration
    lambda_reg: float = 0.0  # L2 regularization
    clip_value: float = 10.0  # IPS weight clipping value

    # Optimizer parameters (passed to UniversalBanditOptimizer)
    norm_constraint: NormConstraintType = NormConstraintType.UNIT_SPHERE
    beta1: float = 0.9
    beta2: float = 0.999
    optimizer_epsilon: float = 1e-8  # passed as the optimizer's `epsilon`

    # Learning-rate scheduler parameters, read by
    # UniversalBanditExperiment._get_current_learning_rate:
    #   "constant"      -> learning_rate unchanged
    #   "exp"           -> inverse-sqrt decay, learning_rate / sqrt(step + 1)
    #   "warmup_cosine" -> linear warmup, then cosine decay to
    #                      min_lr_scale * learning_rate
    scheduler_type: str = "exp"  # "constant", "exp", "warmup_cosine"
    warmup_steps: int = 0  # warmup_cosine only; wins over warmup_ratio if > 0
    warmup_ratio: float = 0.0  # warmup_cosine only; used when warmup_steps == 0
    total_steps: int = 0  # Auto-computed if 0 (run_training fills in the batch count)
    min_lr_scale: float = 0.0  # warmup_cosine floor as a fraction of learning_rate

    # Evaluation parameters
    eval_interval: int = 100  # Evaluate every N batches
    # NOTE(review): recall_k is not currently passed to evaluate_policy by
    # UniversalBanditExperiment.
    recall_k: int = 10

    # Parameter tracking for regret computation
    track_interval: int = 20  # Track parameters every N steps

    # Reproducibility: seeds `random`, numpy and torch (CPU and CUDA)
    seed: int = 42

    # Logging / output: JSON destination for histories and final metrics.
    # A falsy value falls back to the same default path.
    metrics_output_path: Optional[str] = "runs/universal_bandit_metrics.json"

    # Additional configuration (experiment-specific keys, see class docstring)
    extra_config: Dict[str, Any] = field(default_factory=dict)


class ParameterTracker:
    """Tracks parameter evolution during training for regret computation.

    Snapshots of the arm-parameter matrix are recorded with :meth:`track`
    (typically every ``track_interval`` steps, see :meth:`should_track`).
    After training, :meth:`compute_regrets` scores every snapshot on the full
    query set registered via :meth:`set_all_queries`, using the final
    parameters as the ground-truth reference.
    """

    def __init__(self, track_interval: int = 20):
        """
        Initialize parameter tracker.

        Args:
            track_interval: Interval (in steps) for tracking parameters.
                A non-positive value disables tracking: :meth:`should_track`
                then always returns ``False``.
        """
        self.track_interval = track_interval
        self.tracked_steps: List[int] = []
        self.tracked_parameters: List[torch.Tensor] = []
        # References to all queries and correct arms for regret computation
        self.all_query_embeddings: Optional[torch.Tensor] = None
        self.all_query_correct_arms: Optional[List[List[int]]] = None

    def should_track(self, step: int) -> bool:
        """Return True when ``step`` is a multiple of ``track_interval``.

        Always False for a non-positive interval (a zero interval would
        otherwise raise ``ZeroDivisionError``).
        """
        if self.track_interval <= 0:
            return False
        return step % self.track_interval == 0

    def set_all_queries(
        self,
        query_embeddings: torch.Tensor,
        query_correct_arms: List[List[int]],
    ):
        """
        Set references to all queries for regret computation.

        Args:
            query_embeddings: All query embeddings [num_queries, embedding_dim]
            query_correct_arms: Correct arm indices for each query (an empty
                entry means the query has no correct arm)
        """
        self.all_query_embeddings = query_embeddings
        self.all_query_correct_arms = query_correct_arms

    def track(
        self,
        step: int,
        parameters: torch.Tensor,
    ):
        """
        Track parameters at a given step.

        The snapshot is detached and copied to CPU, so later in-place updates
        to ``parameters`` do not affect it and no autograd graph is retained.

        Args:
            step: Current training step
            parameters: Current arm parameters [num_arms, embedding_dim]
        """
        self.tracked_steps.append(step)
        self.tracked_parameters.append(parameters.detach().clone().cpu())

    @staticmethod
    def _correct_arm_mask(
        correct_arms: List[List[int]], num_arms: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build a boolean [num_queries, num_arms] mask of correct arms.

        Also returns a [num_queries] flag marking queries that have at least
        one correct arm.
        """
        mask = torch.zeros(len(correct_arms), num_arms, dtype=torch.bool)
        for row, arms in enumerate(correct_arms):
            if arms is None or len(arms) == 0:
                continue
            mask[row, torch.as_tensor(arms, dtype=torch.long)] = True
        return mask, mask.any(dim=1)

    @staticmethod
    def _best_correct_ce(
        queries: torch.Tensor,
        params: torch.Tensor,
        correct_mask: torch.Tensor,
        has_correct: torch.Tensor,
    ) -> torch.Tensor:
        """Per-query loss ``-log(max_{a correct} softmax(q . params^T)_a)``.

        Queries without any correct arm get a loss of 0.
        """
        probs = torch.softmax(torch.mm(queries, params.t()), dim=1)
        # Softmax probabilities are >= 0, so zeroing the incorrect arms and
        # taking the row max yields the best correct-arm probability.
        best = probs.masked_fill(~correct_mask, 0.0).max(dim=1).values
        losses = -torch.log(best + 1e-9)
        return torch.where(has_correct, losses, torch.zeros_like(losses))

    def compute_regrets(
        self,
        final_parameters: torch.Tensor,
    ) -> List[Tuple[int, float]]:
        """
        Compute cumulative regrets using final parameters as ground truth.

        Uses cross-entropy loss as the regret metric and evaluates each tracked
        parameter snapshot against ALL queries (not just the batch seen at
        that step). For a query with several correct arms, the loss uses the
        most probable correct arm (best case).

        Formula: regret_t = Σ_{i tracked, i <= t} Σ_q (L(θ_i, q) - L(θ_final, q))
        where L is cross-entropy loss.

        Args:
            final_parameters: Final learned parameters [num_arms, embedding_dim]

        Returns:
            List of (step, cumulative_regret) tuples, one per tracked snapshot

        Raises:
            ValueError: If :meth:`set_all_queries` has not been called.
        """
        if (
            self.all_query_embeddings is None
            or self.all_query_correct_arms is None
        ):
            raise ValueError(
                "Must call set_all_queries() before computing regrets"
            )

        queries = self.all_query_embeddings.detach().cpu()
        final_params = final_parameters.detach().cpu()

        # The correct-arm mask is identical for every snapshot: build it once.
        correct_mask, has_correct = self._correct_arm_mask(
            self.all_query_correct_arms, final_params.shape[0]
        )
        final_losses = self._best_correct_ce(
            queries, final_params, correct_mask, has_correct
        )

        regrets: List[Tuple[int, float]] = []
        cumulative_regret = 0.0
        for step, params in zip(self.tracked_steps, self.tracked_parameters):
            current_losses = self._best_correct_ce(
                queries, params, correct_mask, has_correct
            )
            # Regret for this snapshot, summed over all queries
            cumulative_regret += torch.sum(current_losses - final_losses).item()
            regrets.append((step, cumulative_regret))

        return regrets

    def clear(self):
        """Clear tracked snapshots and the registered query references."""
        self.tracked_steps.clear()
        self.tracked_parameters.clear()
        self.all_query_embeddings = None
        self.all_query_correct_arms = None


class TrainingTracker:
    """Tracks training progress and metrics.

    Per-batch statistics are stored as lists of ``(step, value)`` pairs.
    Evaluation metrics are routed by key prefix into three dictionaries
    (``recall_at_*``, ``ndcg_at_*``, ``hit_at_*``), each mapping a metric
    name to its own list of ``(step, value)`` pairs; other keys are ignored.
    """

    def __init__(self):
        """Initialize training tracker with empty histories."""
        self.rewards_history: List[Tuple[int, float]] = []
        self.regrets_history: List[Tuple[int, float]] = []
        self.recall_history: Dict[str, List[Tuple[int, Any]]] = {}
        self.hit_history: Dict[str, List[Tuple[int, Any]]] = {}
        self.ndcg_history: Dict[str, List[Tuple[int, Any]]] = {}
        self.grad_stats_history: List[Tuple[int, Dict[str, float]]] = []
        self.ips_stats_history: List[Tuple[int, Dict[str, float]]] = []
        self.lr_history: List[Tuple[int, float]] = []

    def _eval_histories(
        self,
    ) -> List[Tuple[str, Dict[str, List[Tuple[int, Any]]]]]:
        """Return ``(key_prefix, history)`` for each evaluation metric family.

        The order (recall, ndcg, hit) fixes the key order of the summaries.
        """
        return [
            ("recall_at_", self.recall_history),
            ("ndcg_at_", self.ndcg_history),
            ("hit_at_", self.hit_history),
        ]

    def log_batch(
        self,
        step: int,
        rewards: torch.Tensor,
        regrets: torch.Tensor,
        grad_stats: Dict[str, float],
        ips_stats: Dict[str, float],
        learning_rate: float,
    ):
        """Log metrics for a training batch (rewards/regrets are batch means)."""
        self.rewards_history.append((step, torch.mean(rewards).item()))
        self.regrets_history.append((step, torch.mean(regrets).item()))
        self.grad_stats_history.append((step, grad_stats))
        self.ips_stats_history.append((step, ips_stats))
        self.lr_history.append((step, learning_rate))

    def log_evaluation(self, step: int, rslt_dict: dict):
        """Log evaluation metrics whose keys match a known metric prefix."""
        for key, value in rslt_dict.items():
            for prefix, history in self._eval_histories():
                if key.startswith(prefix):
                    history.setdefault(key, []).append((step, value))
                    break

    def get_final_metrics(self) -> Dict[str, Any]:
        """Get the last logged value of each evaluation metric.

        Keys are ``final_<metric name>``; rewards, latent regret, IPS, grad
        and LR histories are intentionally omitted.
        """
        metrics: Dict[str, Any] = {}
        for _, history in self._eval_histories():
            for key, pairs in history.items():
                metrics[f"final_{key}"] = pairs[-1][1]
        return metrics

    def clear(self):
        """Clear all tracked data."""
        for seq in (
            self.rewards_history,
            self.regrets_history,
            self.grad_stats_history,
            self.ips_stats_history,
            self.lr_history,
        ):
            seq.clear()
        for _, history in self._eval_histories():
            history.clear()

    def to_serializable(self) -> Dict[str, Any]:
        """Convert tracked histories into a JSON-serializable structure.

        Only full-dataset evaluation metrics (recall@k, ndcg@k, hit@k) are
        included, each as a list of ``{"step": s, <metric name>: value}``.
        """
        rslt: Dict[str, Any] = {}
        for _, history in self._eval_histories():
            for key, pairs in history.items():
                rslt[key] = [{"step": int(s), key: v} for s, v in pairs]
        return rslt

    @staticmethod
    def _json_default(obj: Any) -> Any:
        """Convert numpy/torch scalars and arrays for ``json.dump``.

        Anything exposing ``tolist()`` (numpy scalars/arrays, torch tensors)
        is converted to builtin Python values; other objects still raise.
        """
        if hasattr(obj, "tolist"):
            return obj.tolist()
        raise TypeError(
            f"Object of type {type(obj).__name__} is not JSON serializable"
        )

    def save_json(
        self,
        output_path: str,
        *,
        final_metrics: Optional[Dict[str, Any]] = None,
        parameter_regrets: Optional[List[Tuple[int, float]]] = None,
        config_dict: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Save tracked histories and summaries to a JSON file.

        Missing parent directories are created. Tracked histories are left
        intact; call :meth:`clear` to reset them. Config values are stored as
        strings.

        Returns the path written.
        """
        if not output_path:
            output_path = "runs/universal_bandit_metrics.json"
        parent_dir = os.path.dirname(output_path)
        # A bare filename has no directory part; os.makedirs("") would raise.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        payload: Dict[str, Any] = {
            "histories": self.to_serializable(),
        }
        if final_metrics is not None:
            payload["final_metrics"] = final_metrics
        if parameter_regrets is not None:
            payload["parameter_regrets"] = [
                {"step": int(s), "regret": float(r)}
                for s, r in parameter_regrets
            ]
        if config_dict is not None:
            payload["config"] = {k: str(v) for k, v in config_dict.items()}

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, default=self._json_default)

        return output_path


class UniversalBanditExperiment:
    """
    Universal contextual bandit experiment manager.

    Handles experiment setup, training loop, evaluation, and result collection
    with parameter tracking for regret computation.

    The bandit variant is selected by ``config.extra_config["exp_type"]``:
    ``"rerank"``, ``"multihop"`` or ``"dynamic"``. Any other value, including
    a missing key, selects the plain ``"normal"`` bandit.
    """

    # exp_type values with a dedicated bandit class; anything else -> "normal".
    _SPECIAL_EXP_TYPES = ("rerank", "multihop", "dynamic")

    def __init__(self, config: UniversalBanditConfig):
        """Initialize the experiment with configuration."""
        self.config = config
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        # Data containers (populated by setup_experiment)
        self.query_embeddings = None
        self.initial_embeddings = None
        self.true_embeddings = None
        self.query_correct_arms = None
        self.current_embeddings = None

        # Training components (populated by setup_experiment)
        self.bandit = None
        self.optimizer = None

        # Tracking components
        self.tracker = TrainingTracker()
        self.parameter_tracker = ParameterTracker(config.track_interval)

        requested_type = config.extra_config.get("exp_type")
        self.exp_type = (
            requested_type
            if requested_type in self._SPECIAL_EXP_TYPES
            else "normal"
        )

        print("EXPERIMENT_TYPE:", self.exp_type)

    def setup_experiment(self, data_loader: UniversalBanditDataLoader):
        """Set up the experiment by loading data and initializing components.

        Seeds the RNGs, loads embeddings and correct-arm labels through
        ``data_loader.load_data``, registers the full query set with the
        parameter tracker, then builds the optimizer and the bandit.
        """
        print("=== Setting up Universal Bandit Experiment ===")

        # Seed before loading data and building components.
        self._set_random_seeds()

        print("Loading data...")
        (
            self.query_embeddings,
            self.initial_embeddings,
            self.true_embeddings,
            self.query_correct_arms,
        ) = data_loader.load_data(
            embedding_model=self.config.embedding_model,
            true_embedding_model=self.config.true_embedding_model,
            add_noise=self.config.add_noise,
            noise_std=self.config.noise_std,
        )

        # Training updates current_embeddings; initial_embeddings stays as-is.
        self.current_embeddings = self.initial_embeddings.clone()

        # Regret is computed over all queries, not only the batches seen.
        self.parameter_tracker.set_all_queries(
            self.query_embeddings, self.query_correct_arms
        )

        num_arms, embedding_dim = self.current_embeddings.shape

        print("Data loaded:")
        print(f"  Queries: {len(self.query_embeddings)}")
        print(f"  Arms: {num_arms}")
        print(f"  Embedding dimension: {embedding_dim}")

        self.optimizer = UniversalBanditOptimizer(
            num_arms=num_arms,
            embedding_dim=embedding_dim,
            device=self.device,
            learning_rate=self.config.learning_rate,
            beta1=self.config.beta1,
            beta2=self.config.beta2,
            epsilon=self.config.optimizer_epsilon,
            norm_constraint=self.config.norm_constraint,
        )

        self.bandit = self._build_bandit(data_loader, num_arms, embedding_dim)

        print("Experiment setup complete!")

    def _build_bandit(self, data_loader, num_arms: int, embedding_dim: int):
        """Instantiate the bandit class that matches ``self.exp_type``.

        ``extra_config`` may override ``retrieval_k`` (default 10 for
        "rerank", 5 for "multihop"), ``masked_arm_ratio`` (default 0.5) and
        ``unmask_progress_point`` (default 0.5).

        Raises:
            ValueError: If ``self.exp_type`` is not a known experiment type.
        """
        extra = self.config.extra_config
        common_kwargs = dict(
            num_arms=num_arms,
            embedding_dim=embedding_dim,
            device=self.device,
            optimizer=self.optimizer,
            temperature=self.config.temperature,
            epsilon=self.config.epsilon,
            lambda_reg=self.config.lambda_reg,
            clip_value=self.config.clip_value,
        )

        if self.exp_type == "normal":
            return UniversalContextualBandit(**common_kwargs)

        if self.exp_type == "rerank":
            return RerankingContextualBandit(
                **common_kwargs,
                reranker=LLMReranker(
                    data_loader=data_loader, device=self.device
                ),
                retrieval_k=extra.get("retrieval_k", 10),
            )

        if self.exp_type == "dynamic":
            # Drawn from numpy's global RNG, seeded by _set_random_seeds().
            masked_count = int(num_arms * extra.get("masked_arm_ratio", 0.5))
            return DynamicArmContextualBandit(
                **common_kwargs,
                masked_arm_indices=np.random.choice(
                    num_arms, masked_count, replace=False
                ),
                unmask_progress_point=extra.get("unmask_progress_point", 0.5),
            )

        if self.exp_type == "multihop":
            return MultiHopRAGBandit(
                **common_kwargs,
                judge_llm=JudgeLLM(data_loader=data_loader),
                retrieval_k=extra.get("retrieval_k", 5),
                data_loader=data_loader,
            )

        raise ValueError(f"Unknown exp_type: {self.exp_type!r}")

    def _set_random_seeds(self):
        """Seed `random`, numpy and torch (CPU and all CUDA devices)."""
        random.seed(self.config.seed)
        np.random.seed(self.config.seed)
        torch.manual_seed(self.config.seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(self.config.seed)
            torch.cuda.manual_seed_all(self.config.seed)

    def _get_current_learning_rate(self, step: int, total_steps: int) -> float:
        """Get current learning rate based on scheduler.

        Supported:
        - constant: fixed learning rate
        - exp: inverse-sqrt decay (stabilizes early steps)
        - warmup_cosine: linear warmup then cosine annealing to min_lr_scale * base_lr

        Unknown scheduler names fall back to the constant base rate.
        """
        base_lr = self.config.learning_rate
        sched = (self.config.scheduler_type or "constant").lower()

        if sched == "constant":
            return base_lr

        if sched == "exp":
            # Inverse-sqrt decay (common for bandit-like online updates)
            return base_lr / math.sqrt(max(step + 1, 1))

        if sched == "warmup_cosine":
            horizon = max(total_steps, 1)
            warm = self.config.warmup_steps
            if warm == 0 and self.config.warmup_ratio > 0:
                warm = int(self.config.warmup_ratio * horizon)
            # Keep at least one step for the cosine phase.
            warm = min(warm, horizon - 1) if horizon > 1 else 0
            min_lr = base_lr * max(self.config.min_lr_scale, 0.0)

            if step <= warm and warm > 0:
                # Linear warmup from 0 -> base_lr
                return base_lr * float(step) / float(max(warm, 1))

            # Cosine decay for the remaining steps
            progress = (step - warm) / float(max(horizon - warm, 1))
            cosine = 0.5 * (
                1.0 + math.cos(math.pi * min(max(progress, 0.0), 1.0))
            )
            return min_lr + (base_lr - min_lr) * cosine

        # Fallback
        return base_lr

    def _evaluate(self) -> Dict[str, Any]:
        """Evaluate the current arm parameters on the full query set."""
        return self.bandit.evaluate_policy(
            self.query_embeddings,
            self.current_embeddings,
            self.query_correct_arms,
        )

    def _save_metrics(
        self,
        final_metrics: Dict[str, Any],
        parameter_regrets: List[Tuple[int, float]],
    ) -> str:
        """Write histories, final metrics, regrets and config to JSON."""
        try:
            cfg_dict = asdict(self.config)
        except Exception:
            # asdict deep-copies field values (e.g. extra_config), which can
            # fail for arbitrary objects; keep a shallow view instead of
            # dropping the config from the output entirely.
            cfg_dict = dict(vars(self.config))
        output_path = (
            self.config.metrics_output_path
            or "runs/universal_bandit_metrics.json"
        )
        written = self.tracker.save_json(
            output_path,
            final_metrics=final_metrics,
            parameter_regrets=parameter_regrets,
            config_dict=cfg_dict,
        )
        print(f"Saved training metrics to: {written}")
        return written

    def run_training(self) -> Tuple[List, List, List, Dict[str, Any]]:
        """
        Run the training loop.

        Each epoch shuffles the queries and trains on consecutive batches of
        ``config.batch_size`` queries (the last batch of an epoch may be
        smaller). Parameters are snapshotted every ``track_interval`` steps.
        The policy is evaluated on all queries every ``eval_interval`` steps;
        a non-positive ``eval_interval`` disables periodic evaluation, but the
        final evaluation always runs. Results are written to
        ``config.metrics_output_path`` as JSON.

        Returns:
            Tuple of (rewards_history, regrets_history, recall_history,
            final_metrics). ``recall_history`` maps each recall metric name to
            its ``(step, value)`` pairs. The histories are snapshots taken
            before the metrics file is written.
        """
        print("=== Starting Training ===")

        n_queries = len(self.query_embeddings)
        batch_size = self.config.batch_size
        n_epochs = self.config.n_epochs

        # Every epoch runs ceil(n_queries / batch_size) batches, so the true
        # number of training steps is that count times the number of epochs.
        batches_per_epoch = (n_queries + batch_size - 1) // batch_size
        total_batches = batches_per_epoch * n_epochs

        if self.config.total_steps == 0:
            self.config.total_steps = total_batches
        # Learning-rate schedule horizon (user-configurable, auto-filled above).
        schedule_steps = self.config.total_steps

        print("Training configuration:")
        print(f"  Epochs: {n_epochs}")
        print(f"  Batch size: {batch_size}")
        print(f"  Total batches: {total_batches}")
        print(f"  Learning rate: {self.config.learning_rate}")
        print(f"  Temperature: {self.config.temperature}")
        print(f"  Norm constraint: {self.config.norm_constraint}")

        eval_interval = self.config.eval_interval
        step_count = 0
        queries_processed = 0

        for epoch in range(n_epochs):
            print(f"\n--- Epoch {epoch + 1}/{n_epochs} ---")

            # Shuffle query indices for each epoch
            query_indices = torch.randperm(n_queries)

            for batch_start in tqdm(
                range(0, n_queries, batch_size),
                desc=f"Epoch {epoch + 1}",
                leave=False,
            ):
                batch_indices = query_indices[
                    batch_start : batch_start + batch_size
                ]

                batch_contexts = self.query_embeddings[batch_indices]
                batch_correct_arms = [
                    self.query_correct_arms[i] for i in batch_indices.tolist()
                ]

                current_lr = self._get_current_learning_rate(
                    step_count, schedule_steps
                )

                # Snapshot parameters before this step's update.
                if self.parameter_tracker.should_track(step_count):
                    self.parameter_tracker.track(
                        step_count,
                        self.current_embeddings,
                    )

                (
                    self.current_embeddings,
                    rewards,
                    regrets,
                    _,
                    grad_stats,
                    ips_stats,
                ) = self.bandit.train_batch(
                    contexts=batch_contexts,
                    correct_arms_batch=batch_correct_arms,
                    arm_embeddings=self.current_embeddings,
                    true_embeddings=self.true_embeddings,
                    initial_embeddings=self.initial_embeddings,
                    step_count=step_count,
                    current_lr=current_lr,
                    batch_query_indices=batch_indices,
                    total_steps=total_batches,
                )

                self.tracker.log_batch(
                    step_count,
                    rewards,
                    regrets,
                    grad_stats,
                    ips_stats,
                    current_lr,
                )

                if eval_interval > 0 and step_count % eval_interval == 0:
                    eval_metrics = self._evaluate()
                    self.tracker.log_evaluation(step_count, eval_metrics)
                    print(f"Step {step_count}: ")
                    pprint(eval_metrics)

                queries_processed += len(batch_indices)
                step_count += 1

        final_eval = self._evaluate()
        self.tracker.log_evaluation(step_count, final_eval)

        # Regret of every snapshot, measured against the final parameters.
        parameter_regrets = self.parameter_tracker.compute_regrets(
            self.current_embeddings
        )

        final_metrics = self.tracker.get_final_metrics()
        final_metrics.update(
            {
                "total_queries_processed": queries_processed,
                **final_eval,
            }
        )

        # Snapshot the histories now: TrainingTracker.save_json may reset them.
        rewards_history = list(self.tracker.rewards_history)
        regrets_history = list(self.tracker.regrets_history)
        recall_history = {
            k: list(v) for k, v in self.tracker.recall_history.items()
        }

        self._save_metrics(final_metrics, parameter_regrets)

        print("\n=== Training Complete ===")
        print("Final metrics:")
        pprint(final_metrics)

        return (
            rewards_history,
            regrets_history,
            recall_history,
            final_metrics,
        )


# Convenience function for running experiments
def run_universal_bandit_experiment(
    config: UniversalBanditConfig,
    data_loader: UniversalBanditDataLoader,
) -> Tuple[List, List, List, Dict[str, Any]]:
    """
    Build, set up and train a universal bandit experiment in one call.

    Args:
        config: Experiment configuration.
        data_loader: Data loader handed to ``setup_experiment``.

    Returns:
        The result of ``UniversalBanditExperiment.run_training``:
        (rewards_history, regrets_history, recall_history, final_metrics).
    """
    runner = UniversalBanditExperiment(config)
    runner.setup_experiment(data_loader)
    results = runner.run_training()
    return results
