import copy
import logging
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import torch
from pykeen.losses import MarginRankingLoss
from pykeen.sampling import BernoulliNegativeSampler
from torch.nn import BCELoss, BCEWithLogitsLoss, CrossEntropyLoss, NLLLoss
from torch.utils.data import DataLoader
from tqdm import tqdm

from kge.dataset import OGBDataset, SplitDataset
from kge.eval import KGEvaluator, OGBEvaluator
from kge.loggers import MetricsLogger
from kge.losses import BCELossProtocol, CELossProtocol, MarginLossProtocol
from kge.models import KGModel
from kge.sampling import NegativeSamplerProtocol


@dataclass(frozen=True)
class EngineConfig:
    """Configuration for training."""

    # Training settings
    batch_size: int = 128
    num_epochs: int = 100
    learning_rate: float = 0.01
    eval_every: int = 1
    valid_sample_size: int = 1000
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    early_stopping_patience: int = 5
    checkpoint_dir: Path | None = None
    log_train_likelihood: bool = False

    # Model and loss settings
    loss_type: Literal["margin", "bce", "bce_logits", "ce"] = "ce"
    loss_margin: float = 1.0  # Only for margin loss
    neg_samples: int = 10  # Only for margin loss
    regularization_lambda: float = 1e-3

    # Evaluation settings
    test_matrix_rank_at_the_end: bool = False
    test_on_valid: bool = False
    validation_metric: Literal["mr", "mrr"] = "mr"


class Engine:
    """Training engine for KG embedding models."""

    def __init__(
        self,
        config: EngineConfig,
        dataset: SplitDataset | OGBDataset,
        model: KGModel,
        loggers: list[MetricsLogger],
    ):
        self.config = config
        self.dataset = dataset
        self.evaluator = self._init_evaluator()
        self.model = model
        self.loss_fn = self._init_loss_fn()
        self.negative_sampler = self._init_negative_sampler()
        self.loggers = loggers

    def _log_metrics(self, metrics: dict, step: int, prefix: str = "") -> None:
        for logger in self.loggers:
            logger.log_metrics(metrics, step, prefix)

    def _save_best_model(self, val_metric: float, epoch: int) -> None:
        self.best_val_metric = val_metric
        self.best_model_state = self.model.state_dict()
        self._log_metrics({f"best_valid_{self.config.validation_metric}": val_metric}, epoch)
        if self.config.checkpoint_dir:
            checkpoint_path = self.config.checkpoint_dir / f"checkpoint_epoch_{epoch}.pt"
            self.model.save_checkpoint(checkpoint_path)
            msg = f"Saved checkpoint to {checkpoint_path}"
            logging.info(msg)

    def _load_best_model(self) -> None:
        if self.best_model_state is not None:
            self.model.load_state_dict(self.best_model_state)
            logging.info("Loaded best model for final evaluation")
        else:
            msg = "No best model found"
            logging.error(msg)
            raise ValueError(msg)

    def _init_evaluator(self) -> KGEvaluator | OGBEvaluator:
        if isinstance(self.dataset, OGBDataset):
            evaluator = OGBEvaluator(self.dataset.dataset_name)
            evaluator.initialize_filters([self.dataset.train])
            return evaluator
        evaluator = KGEvaluator()
        evaluator.initialize_filters([self.dataset.train, self.dataset.valid, self.dataset.test])
        return evaluator

    def _init_loss_fn(self) -> MarginLossProtocol | BCELossProtocol | CELossProtocol:
        """Initialize the loss function based on config."""
        if self.config.loss_type in "margin":
            return MarginRankingLoss(margin=self.config.loss_margin)
        if self.config.loss_type == "bce":
            return BCELoss()
        if self.config.loss_type == "bce_logits":
            return BCEWithLogitsLoss()
        if self.config.loss_type == "ce":
            if self.model.return_log_prob:
                return NLLLoss()
            return CrossEntropyLoss()
        msg = f"Unknown loss type: {self.config.loss_type}"
        raise ValueError(msg)

    def _init_negative_sampler(self) -> NegativeSamplerProtocol:
        return BernoulliNegativeSampler(
            num_entities=self.dataset.num_entities,
            mapped_triples=self.dataset.train.triples,
        )

    def train(
        self,
    ):
        """Train the model on the given dataset.

        Args:
            dataset: Dataset containing train/valid/test splits
            checkpoint_path: Path to save model checkpoints

        """
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=self.config.learning_rate,
        )
        train_loader = DataLoader(
            self.dataset.train,
            batch_size=self.config.batch_size,
            shuffle=True,
        )
        best_valid_metric = float("inf") if self.config.validation_metric == "mr" else float("-inf")
        patience_counter = 0
        num_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        self._log_metrics({"num_parameters": num_parameters}, 0)
        for epoch in range(self.config.num_epochs):
            self.model.train()
            total_loss = 0.0
            for step, (s, r, o) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}")):
                s = s.to(self.config.device)
                r = r.to(self.config.device)
                o = o.to(self.config.device)
                optimizer.zero_grad()

                # Time loss computation and backprop after warmup
                if epoch == -1 and step == 10:
                    if self.config.device.startswith("cuda"):
                        torch.cuda.reset_peak_memory_stats(self.config.device)
                    else:
                        logging.warning("Memory tracking is only supported for CUDA devices")

                    start_time = time.perf_counter()

                # Compute loss
                if isinstance(self.loss_fn, MarginRankingLoss):
                    pos_scores = self.model.score_sro(s, r, o)
                    neg_o = self.negative_sampler.sample(
                        (s, r, o),
                        num_samples=self.config.neg_samples,
                    )
                    neg_scores = self.model.score_sro(s, r, neg_o)
                    loss = self.loss_fn(pos_scores, neg_scores)
                elif isinstance(self.loss_fn, (CrossEntropyLoss, NLLLoss)):
                    scores = self.model.score_o(s, r)
                    loss = self.loss_fn(scores, o)
                else:
                    raise NotImplementedError
                loss += self.model.regularization_term() * self.config.regularization_lambda

                # Time the backward pass
                if epoch == -1 and step == 10:
                    elapsed_time = (time.perf_counter() - start_time) * 1000  # Convert to ms

                    metrics = {"batch_time_ms": elapsed_time}
                    if self.config.device.startswith("cuda"):
                        peak_memory = (
                            torch.cuda.max_memory_allocated(self.config.device) / 1024**2
                        )  # Convert to MB
                        metrics["peak_memory_mb"] = peak_memory
                        logging.info(
                            f"Batch time: {elapsed_time:.2f}ms, Peak memory: {peak_memory:.2f}MB",
                        )
                    else:
                        logging.info(f"Batch time: {elapsed_time:.2f}ms")

                    self._log_metrics(metrics, epoch)

                loss.backward()
                optimizer.step()
                total_loss += loss.item()
            avg_loss = total_loss / len(train_loader)
            self._log_metrics({"loss": avg_loss}, epoch)

            # Evaluation
            if (epoch + 1) % self.config.eval_every == 0:
                self.model.eval()
                valid_rank_metrics = self.evaluator.evaluate_object_ranks(
                    self.model,
                    self.dataset.valid,
                    batch_size=self.config.batch_size,
                    device=self.config.device,
                    sample_size=self.config.valid_sample_size,
                )

                # Get current validation metric
                current_metric = (
                    valid_rank_metrics.mr
                    if self.config.validation_metric == "mr"
                    else valid_rank_metrics.mrr
                )

                # Update best model check based on metric type
                is_better = (
                    current_metric < best_valid_metric
                    if self.config.validation_metric == "mr"
                    else current_metric > best_valid_metric
                )

                if is_better:
                    best_valid_metric = current_metric
                    self._save_best_model(current_metric, epoch)
                    patience_counter = 0
                else:
                    patience_counter += 1

                self._log_metrics(valid_rank_metrics.to_dict(), epoch, prefix="valid")
                valid_likelihood = self.evaluator.evaluate_object_nll(
                    self.model,
                    self.dataset.valid,
                    batch_size=self.config.batch_size,
                    device=self.config.device,
                    sample_size=self.config.valid_sample_size,
                    filtered=False,
                )
                self._log_metrics({"valid_likelihood": valid_likelihood}, epoch, prefix="valid")
                if self.config.log_train_likelihood:
                    train_likelihood = self.evaluator.evaluate_object_nll(
                        self.model,
                        self.dataset.train,
                        batch_size=self.config.batch_size,
                        device=self.config.device,
                        sample_size=self.config.valid_sample_size,
                    )
                    self._log_metrics({"train_likelihood": train_likelihood}, epoch, prefix="train")
                if patience_counter >= self.config.early_stopping_patience:
                    logging.info("Early stopping triggered")
                    self._log_metrics({"early_stopping": True}, epoch)
                    break
        self._load_best_model()
        test_rank_metrics = self.evaluator.evaluate_object_ranks(
            self.model,
            (self.dataset.valid if self.config.test_on_valid else self.dataset.test),
            batch_size=self.config.batch_size,
            device=self.config.device,
            sample_size=-1,
        )
        self._log_metrics(test_rank_metrics.to_dict(), self.config.num_epochs, prefix="test")
        test_likelihood = self.evaluator.evaluate_object_nll(
            self.model,
            (self.dataset.valid if self.config.test_on_valid else self.dataset.test),
            batch_size=self.config.batch_size,
            device=self.config.device,
            sample_size=-1,
            filtered=True,
        )
        self._log_metrics(
            {"test_likelihood": test_likelihood},
            self.config.num_epochs,
            prefix="test",
        )
        if self.config.test_matrix_rank_at_the_end:
            if isinstance(self.evaluator, OGBEvaluator):
                logging.warning(
                    "OGBEvaluator does not support score matrix rank evaluation. Skipping...",
                )
            else:
                rank = self.evaluator.evaluate_score_matrix_rank(
                    self.model,
                    (self.dataset.valid if self.config.test_on_valid else self.dataset.test),
                    batch_size=self.config.batch_size,
                    device=self.config.device,
                    log_prob=True,
                )
                self._log_metrics({"rank": rank}, self.config.num_epochs, prefix="test")
