import logging
from dataclasses import dataclass
from math import ceil
from typing import Iterator, NamedTuple, Optional, Tuple, override

import numpy as np
from tqdm import tqdm, trange

from ..data import Dataset
from ..data.datasets.synthetic import SyntheticDataset
from ..predictors import Predictor
from ..utils import (
    get_experiment_config,
    get_logger,
    get_rng,
    xover,
)
from ..utils.grid_utils import create_grid, create_grid_sampled
from . import Calibrator, CalibratorConfig, register_calibrator

# Module-level logger. Code below also calls ``logger.log_metric``, which is not a
# stdlib ``logging.Logger`` method — presumably provided by the project's
# ``get_logger`` wrapper (verify).
logger = get_logger(__name__)


class DeltaWeight(NamedTuple):
    """A single correction step learned by ``GridBoostCalibrator``.

    Corrections are replayed in the order they were learned. ``weight`` scales
    the predictions passed to the dataset's decision function, ``sign`` gives
    the direction of the update, and ``source`` records which check produced
    the step (it also selects which predictions are weighted when replaying).
    """

    weight: np.ndarray  # one entry per output dimension
    sign: int  # np.sign of the detected error: -1, 0 or +1
    source: str  # producing check: "check1" or "check2"


@dataclass
class GridBoostCalibratorConfig(CalibratorConfig):
    """Hyper-parameters for ``GridBoostCalibrator``.

    The field order below defines the generated ``__init__`` signature, so it
    must stay stable.
    """

    # Output dimension, passed to the grid builders.
    output_dim: int

    # ---- optimisation ----
    lr: float | None  # step size per correction; None -> tol / 2
    tol: float  # error magnitude that triggers a new correction

    # ---- grid ----
    grid_iter_size: int  # grid rows evaluated per chunk ("Gb")
    grid_resolution: float  # resolution forwarded to the grid builder
    grid_size: int | None  # if set, build a sampled grid of this size

    # ---- training ----
    batch_size: int  # samples drawn per boosting iteration
    max_iter: int  # maximum number of boosting iterations
    early_stop: bool  # stop at the first iteration with no violation
    use_fresh_samples: bool  # draw batches via SyntheticDataset.synth()


class GridBoostCalibrator(Calibrator):
    """Boosting-style calibrator that stacks small corrections on a base predictor.

    ``fit`` repeatedly draws a batch and checks the current predictions for a
    decision-weighted residual error of magnitude ``>= tol``. Each violation it
    finds is recorded as a :class:`DeltaWeight`. ``predict`` replays the
    recorded corrections in order (see :meth:`_step`); each one moves the
    prediction by ``lr`` and clips it to ``[0, 1]``.

    Two checks look for violations, and ``check2`` is tried first:

    * ``check2``: the decision function applied to the current prediction
      (an all-ones weight vector).
    * ``check1``: every row of ``self.grid`` applied to the *base* prediction;
      the row index reported by ``xover`` is used.
    """

    def __init__(
        self,
        predictor: Predictor,
        config: GridBoostCalibratorConfig,
        dataset,
    ) -> None:
        super().__init__(predictor, config, dataset)
        self.predictor = predictor
        # Without an explicit lr, each correction moves by half the tolerance.
        self.lr = config.lr if config.lr is not None else config.tol / 2
        self.tol = config.tol
        self.output_dim = config.output_dim
        self.grid_resolution = config.grid_resolution
        self.grid_iter_size = config.grid_iter_size  # Gb: grid rows per chunk
        self.batch_size = config.batch_size
        self.max_iter = config.max_iter
        self.early_stop = config.early_stop
        self.use_fresh_samples = config.use_fresh_samples

        # Only top_k == 1 is forwarded to the grid builders; any other value
        # (or a dataset without a ``top_k`` attribute) becomes None.
        top_k = getattr(self.dataset, "top_k", None)
        top_k = top_k if top_k == 1 else None
        experiment_config = get_experiment_config()
        seed = experiment_config.seed if experiment_config is not None else None
        if config.grid_size is not None:
            self.grid = create_grid_sampled(
                self.output_dim,
                self.grid_resolution,
                config.grid_size,
                top_k,
                seed=seed,
            )
        else:
            self.grid = create_grid(self.output_dim, self.grid_resolution, top_k)

        self._init_state()

    def _init_state(self) -> None:
        """Reset the iteration counter and drop all learned corrections."""
        self.t = 0
        self.delta_weights: list[DeltaWeight] = []

    def _grid_iter(self, y_pred: np.ndarray) -> Iterator[Tuple[int, int, np.ndarray]]:
        """Yield ``(g_start, g_end, y_ohe)`` for consecutive chunks of ``self.grid``.

        Each chunk covers grid rows ``[g_start, g_end)``, at most
        ``grid_iter_size`` of them. ``y_ohe`` is the dataset's decision
        function applied to each row times ``y_pred``, indexed by the row's
        offset within the chunk along the leading axis.
        """
        for g in range(0, self.grid.shape[0], self.grid_iter_size):
            weights = self.grid[g : g + self.grid_iter_size]  # (Gb, D)
            g_end = g + weights.shape[0]  # the final chunk may be shorter than Gb

            # (Gb, 1, D) * (1, B, D) -> (Gb, B, D)
            wy_pred = weights[:, np.newaxis, :] * y_pred[np.newaxis, :, :]
            y_ohe = self.dataset.decision_function(wy_pred)
            yield g, g_end, y_ohe

    def _step(
        self, y_pred: np.ndarray, y_base: np.ndarray, delta_weight: DeltaWeight
    ) -> np.ndarray:
        """Apply one recorded correction to ``y_pred`` and return a new array.

        For ``"check1"`` corrections the decision function is evaluated on
        ``weight * y_base``, matching how ``_check1`` scanned the grid. For
        ``"check2"`` it is evaluated on ``weight * y_pred``. The result is
        scaled by ``sign`` and ``lr``, added to ``y_pred``, and clipped to
        ``[0, 1]``.
        """
        y_for_weight = y_base if delta_weight.source == "check1" else y_pred
        # (B, D) * (1, D) -> (B, D)
        wy_pred = y_for_weight * delta_weight.weight[np.newaxis, :]
        delta = self.dataset.decision_function(wy_pred) * delta_weight.sign
        return np.clip(y_pred + self.lr * delta, 0.0, 1.0)

    @override
    def predict(self, X: np.ndarray, y_base: Optional[np.ndarray] = None) -> np.ndarray:
        """Return calibrated predictions for ``X``.

        Args:
            X: Inputs for the base predictor.
            y_base: Precomputed ``self.predictor.predict(X)``; computed when omitted.

        Returns:
            ``y_base`` with every learned correction applied in order. The
            caller's ``y_base`` array is not modified.
        """
        if y_base is None:
            y_base = self.predictor.predict(X)

        y_pred = y_base.copy()
        for delta_weight in self.delta_weights:
            y_pred = self._step(y_pred, y_base, delta_weight)

        return y_pred

    def _check1(self, y: np.ndarray, y_base: np.ndarray, y_pred: np.ndarray) -> DeltaWeight | None:
        """Scan the grid for a weight row whose weighted residual violates ``tol``.

        For each grid row, the residual ``y - y_pred`` is weighted by the
        decision function of ``row * y_base``. It is then summed over outputs,
        averaged over the batch, and divided by ``self.dataset.scale()``.

        Returns:
            A ``"check1"`` :class:`DeltaWeight` for the row selected by
            ``xover``, or ``None`` when ``xover`` returns -1 for every chunk.
        """
        diff = y - y_pred  # (B, D)

        for g_start, _g_end, y_ohe in self._grid_iter(y_base):
            # (1, B, D) * (Gb, B, D) -> sum over D, mean over B -> (Gb,)
            errs = (diff[np.newaxis, :] * y_ohe).sum(axis=-1).mean(axis=-1) / self.dataset.scale()
            crossing_idx = xover(np.abs(errs), self.tol)
            if crossing_idx == -1:
                continue

            err = errs[crossing_idx]
            g_star = g_start + crossing_idx  # index of the violating row in self.grid
            logger.debug(f"{self.t}: Weight {g_star} error: {err:.4f}")
            return DeltaWeight(weight=self.grid[g_star], sign=int(np.sign(err)), source="check1")

        return None

    def _check2(self, y: np.ndarray, y_pred: np.ndarray) -> DeltaWeight | None:
        """Check the decision rule on the current predictions with equal weights.

        Returns:
            A ``"check2"`` :class:`DeltaWeight` with an all-ones weight when the
            mean decision-weighted residual has magnitude ``>= tol``, else ``None``.
        """
        diff = y - y_pred
        ohe = self.dataset.decision_function(y_pred)

        # NOTE(review): unlike _check1, this averages over both the batch and
        # output axes and does not divide by dataset.scale(); confirm intended.
        err = (diff * ohe).mean()
        sign = int(np.sign(err))

        if np.abs(err) < self.tol:
            return None
        logger.debug(f"{self.t}: Equal Weight error: {err:.4f}")

        return DeltaWeight(weight=np.ones(y.shape[1]), sign=sign, source="check2")

    def _test_pre(self) -> None:
        """Log pre-calibration test metrics and the best single grid weight.

        Does nothing when no experiment config is active. Otherwise it logs:

        * the base predictor's test MSE;
        * the oracle MVS (decision function applied to the true targets),
          which is also stored as ``self._max_mvs`` for later ratios;
        * the MVS of every grid row applied to the base predictions;
        * the equal-weight MVS of the base predictions.
        """
        config = get_experiment_config()
        if config is None:
            return

        X, y_true = self.dataset.load_test()
        y_base = self.predictor.predict(X)

        # Pre-calibration MSE
        mse = float(np.mean((y_true - y_base) ** 2))
        logger.log_metric("mse", mse, t=self.t)

        max_mvs = (y_true * self.dataset.decision_function(y_true)).sum(axis=1).mean()
        self._max_mvs = max_mvs  # Store for post-calibration logging
        logger.log_metric("mvs", max_mvs, type="oracle")

        # Start from -inf (not -1): a grid whose MVS values are all below -1
        # must still report its real best row instead of g=-1.
        best_mvs, best_g = -np.inf, -1
        use_tqdm = not logger.isEnabledFor(logging.DEBUG)

        grid_iter = self._grid_iter(y_base)
        if use_tqdm:
            grid_iter = tqdm(
                grid_iter,
                desc=" " * 17 + "Evaluating grid",
                total=ceil(len(self.grid) / self.grid_iter_size),
            )

        for g_start, g_end, y_ohe in grid_iter:
            # y_ohe: (Gb, B, D)
            for g in range(g_start, g_end):
                mvs = float((y_true * y_ohe[g - g_start]).sum() / y_true.shape[0])
                logger.log_metric("mvs", mvs, type="grid", g=g)
                if mvs >= best_mvs:  # ">=" keeps the last of equally good rows
                    best_mvs, best_g = mvs, g

        logger.info(
            f"Best g={best_g}: {best_mvs:.3f} -> {best_mvs / max_mvs:.3%} of oracle ({max_mvs:.3f})"
        )

        # Evaluate EW
        y_ohe = self.dataset.decision_function(y_base)
        mvs = float((y_true * y_ohe).sum() / y_true.shape[0])
        logger.log_metric("mvs", mvs, type="ew", t=self.t)
        logger.info(f"Pre-calibration EW: {mvs:.3f} -> {mvs / max_mvs:.3%} of oracle")

    @override
    def fit(self, dataset: Dataset) -> None:
        """Fit the calibrator by greedily appending corrections.

        Each iteration draws ``batch_size`` samples. With
        ``use_fresh_samples`` they come from ``SyntheticDataset.synth``;
        otherwise row indices into the calibration split are drawn with
        ``rng.choice``. The iteration then runs ``_check2`` followed by
        ``_check1`` and appends the first violation found. With
        ``early_stop``, the loop ends at the first iteration with no violation.

        When an experiment config is active, test-set MVS/MSE are logged after
        every update and once on the final iteration (``max_iter`` or an early
        stop), together with the post-calibration summary.

        Note:
            ``dataset`` is not read here; all data comes from ``self.dataset``.
        """
        config = get_experiment_config()
        rng = get_rng()

        # Reset before ``_test_pre``: it logs at ``self.t``, which on a re-fit
        # would otherwise still hold the previous run's final iteration.
        self._init_state()
        self._test_pre()

        tab = " " * 17
        n_updates = 0

        # Validate / load the training source once rather than every iteration.
        if self.use_fresh_samples:
            assert isinstance(self.dataset, SyntheticDataset), (
                "use_fresh_samples=True requires SyntheticDataset with synth() method"
            )
        else:
            # NOTE(review): assumes load_calibrator() returns the same split on every call.
            X_cal, y_cal = self.dataset.load_calibrator()

        use_tqdm = not logger.isEnabledFor(logging.DEBUG)
        if use_tqdm:
            pbar = trange(1, self.max_iter + 1, desc=tab + "Training GridBoostCalibrator")
        else:
            pbar = range(1, self.max_iter + 1)

        if config is not None:
            # Test predictions are tracked incrementally: each new correction is
            # applied once instead of replaying all of them every iteration.
            X_test, y_test = self.dataset.load_test()
            y_test_base = self.predictor.predict(X_test)
            y_test_pred = self.predict(X_test, y_test_base)

        for self.t in pbar:
            if self.use_fresh_samples:
                X_batch, y_batch = self.dataset.synth(self.batch_size)
            else:
                sample_idx = rng.choice(X_cal.shape[0], size=self.batch_size)
                X_batch, y_batch = X_cal[sample_idx], y_cal[sample_idx]

            y_base = self.predictor.predict(X_batch)
            y_pred = self.predict(X_batch, y_base)

            # The equal-weight check takes priority; the grid scan only runs
            # when it finds no violation.
            dw = self._check2(y_batch, y_pred)
            if dw is None:
                dw = self._check1(y_batch, y_base, y_pred)

            if dw is not None:
                self.delta_weights.append(dw)
                n_updates += 1
            stop = dw is None and self.early_stop

            if config is not None:
                if dw is not None:
                    y_test_pred = self._step(y_test_pred, y_test_base, dw)
                # The final iteration (max_iter or an early stop) is always
                # evaluated so the post-calibration summary gets logged.
                is_last = stop or self.t == self.max_iter
                if dw is not None or is_last:
                    mse, mvs_pct = self._log_test_metrics(y_test, y_test_pred, final=is_last)
                    if use_tqdm:
                        pbar.set_postfix(updates=n_updates, mse=f"{mse:.3f}", mvs=f"{mvs_pct:.1f}%")

            if stop:
                break

    def _log_test_metrics(
        self, y_test: np.ndarray, y_test_pred: np.ndarray, final: bool
    ) -> Tuple[float, float]:
        """Log test MVS and MSE at iteration ``self.t``.

        Requires ``_test_pre`` to have set ``self._max_mvs``. When ``final`` is
        true, the post-calibration summary line is also logged.

        Returns:
            ``(mse, mvs_pct)``, where ``mvs_pct`` is MVS as a percentage of the
            oracle MVS. The caller uses these to update the progress bar.
        """
        y_test_ohe = self.dataset.decision_function(y_test_pred)

        # Compute MVS
        mvs = float((y_test * y_test_ohe).sum() / y_test.shape[0])
        logger.log_metric("mvs", mvs, type="ew", t=self.t)
        mvs_pct = mvs / self._max_mvs * 100

        if final:
            logger.info(
                f"Post-calibration EW: {mvs:.3f} -> {mvs / self._max_mvs:.3%} of oracle"
            )

        # Compute MSE
        mse = float(np.mean((y_test - y_test_pred) ** 2))
        logger.log_metric("mse", mse, t=self.t)
        return mse, mvs_pct


# Register this calibrator under the key "grid_boost", paired with its config
# class (presumably so experiment configs can select it by name — verify).
register_calibrator("grid_boost", GridBoostCalibratorConfig, GridBoostCalibrator)
