from dataclasses import dataclass
from typing import Optional, Tuple

import numpy as np

from wmcal.utils.functions import top_k_ohe

from ....utils import get_rng
from .. import register_dataset
from . import SyntheticDataset, SyntheticDatasetConfig


@dataclass
class SimpleSyntheticDatasetConfig(SyntheticDatasetConfig):
    """Configuration for ``SimpleSyntheticDataset``.

    The split sizes the dataset reads (``test_size``, ``predictor_size``,
    ``calibrator_size``) are not declared here; presumably they are fields of
    ``SyntheticDatasetConfig`` — TODO confirm against the base class.
    """

    # Number of input features per sample (columns of X).
    input_dim: int
    # Highest element-wise power of X used in the latent target; values <= 1
    # produce only the affine term (no extra polynomial weight matrices).
    poly_degree: int
    # Number of target columns per sample.
    output_dim: int
    # k passed to top_k_ohe by the dataset's decision_function; the dataset's
    # scale() also returns this value as a float.
    top_k: int = 1


class SimpleSyntheticDataset(SyntheticDataset):
    """Synthetic dataset with sigmoid-squashed random polynomial targets.

    Inputs ``X`` are i.i.d. standard normal with shape ``(n, input_dim)``.
    The latent target is ``X @ W0 + b0`` plus Gaussian noise, plus
    ``(X**d) @ W_d / d`` for every ``d`` in ``2 .. poly_degree`` (element-wise
    powers).  It is standardised with Monte-Carlo estimates of its mean and
    standard deviation, multiplied by ``_LOGIT_SCALE`` and passed through a
    logistic sigmoid, so returned targets have shape ``(n, output_dim)`` and
    lie in ``[0, 1]``.

    The random weights are drawn once at construction, so every split shares
    the same underlying function.  The test / predictor / calibrator splits
    are generated on first access and cached.
    """

    # Standard deviation of the additive Gaussian noise on the latent target.
    _NOISE_STD = 0.1
    # Number of Monte-Carlo samples used to estimate the normalisation stats.
    _NORM_MC_SAMPLES = 1000
    # Multiplier applied to the standardised latent target before the sigmoid.
    _LOGIT_SCALE = 4.0

    def __init__(self, config: SimpleSyntheticDatasetConfig):
        super().__init__(config)
        self.config = config
        self.test_size = config.test_size
        self.predictor_size = config.predictor_size
        self.calibrator_size = config.calibrator_size
        self.input_dim = config.input_dim
        self.poly_degree = config.poly_degree
        self.output_dim = config.output_dim
        self.top_k = config.top_k

        # Random model parameters, drawn once so that all splits share them.
        self.params = self._sample_params()

        # Per-split caches, filled on the first load_*() call.
        self._cached_test_data: Optional[Tuple[np.ndarray, np.ndarray]] = None
        self._cached_predictor_data: Optional[Tuple[np.ndarray, np.ndarray]] = None
        self._cached_calibrator_data: Optional[Tuple[np.ndarray, np.ndarray]] = None

        # Normalisation statistics of the latent target, estimated lazily on the
        # first _generate_samples() call.
        self._cached_y_mean: Optional[float] = None
        self._cached_y_std: Optional[float] = None

    def scale(self) -> float:
        """Return ``top_k`` as a float.

        What this scale is used for is defined by the ``SyntheticDataset``
        base class / callers, which are not visible here.
        """
        return float(self.top_k)

    def _sample_params(self) -> dict:
        """Draw the random weights of the latent polynomial target.

        Returns:
            A dict with
            - ``"W0"``: ``(input_dim, output_dim)`` linear weights,
            - ``"b0"``: ``(output_dim,)`` bias,
            - ``"poly_W"``: list of ``max(0, poly_degree - 1)`` matrices of
              shape ``(input_dim, output_dim)``, one per power ``2 .. poly_degree``.
        """
        rng = get_rng()
        W0 = rng.standard_normal((self.input_dim, self.output_dim))
        b0 = rng.standard_normal(self.output_dim)
        poly_W = [
            rng.standard_normal((self.input_dim, self.output_dim))
            for _ in range(max(0, self.poly_degree - 1))
        ]
        return {"W0": W0, "b0": b0, "poly_W": poly_W}

    def _generate_polynomial_samples(
        self, n_samples: int
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Draw inputs and their *un-normalised* latent targets.

        Args:
            n_samples: number of rows to generate.

        Returns:
            ``(X, y)`` with shapes ``(n_samples, input_dim)`` and
            ``(n_samples, output_dim)``.
        """
        rng = get_rng()
        X = rng.standard_normal((n_samples, self.input_dim))

        W = self.params["W0"]
        b = self.params["b0"]
        noise = rng.standard_normal((n_samples, self.output_dim)) * self._NOISE_STD
        y = X @ W + b + noise

        # poly_W[i] holds the weights for the power i + 2.
        for idx, degree in enumerate(range(2, self.poly_degree + 1)):
            PW = self.params["poly_W"][idx]
            # Dividing by the degree damps the growth of higher powers.
            y += ((X**degree) @ PW) / degree

        return X, y

    def _generate_samples(self, n_samples: int) -> Tuple[np.ndarray, np.ndarray]:
        """Draw inputs and their final targets in ``[0, 1]``.

        The latent target is standardised with scalar statistics pooled over
        all samples and output columns, scaled by ``_LOGIT_SCALE``, and then
        squashed with a logistic sigmoid.

        Args:
            n_samples: number of rows to generate.

        Returns:
            ``(X, y)`` with shapes ``(n_samples, input_dim)`` and
            ``(n_samples, output_dim)``.
        """
        X, y = self._generate_polynomial_samples(n_samples)

        # The Monte-Carlo draw deliberately happens *after* the requested batch.
        # Moving it would change the RNG stream and therefore every seeded dataset.
        if self._cached_y_mean is None or self._cached_y_std is None:
            _, mc_y = self._generate_polynomial_samples(self._NORM_MC_SAMPLES)
            self._cached_y_mean = mc_y.mean()
            self._cached_y_std = mc_y.std()

        y = (y - self._cached_y_mean) / self._cached_y_std * self._LOGIT_SCALE
        # Very negative logits make np.exp overflow to inf. Because
        # 1 / (1 + inf) == 0.0 is the correct limit, the warning is just noise.
        with np.errstate(over="ignore"):
            y = 1 / (1 + np.exp(-y))

        return X, y

    def synth(self, n_samples: int) -> Tuple[np.ndarray, np.ndarray]:
        """Synthesize ``n_samples`` fresh (uncached) samples on-the-fly."""
        return self._generate_samples(n_samples)

    def _generate_sized_split(
        self, split: str, size: Optional[int]
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Generate ``size`` samples for the named split.

        Raises:
            ValueError: if ``size`` is ``None`` (the split size is not configured).
        """
        if size is None:
            raise ValueError(
                f"{split}_size is None; cannot generate the {split} split."
            )
        return self._generate_samples(size)

    def load_test(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return the cached test split, generating it on first call."""
        if self._cached_test_data is None:
            self._cached_test_data = self._generate_sized_split(
                "test", self.test_size
            )
        return self._cached_test_data

    def load_predictor(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return the cached predictor split, generating it on first call."""
        if self._cached_predictor_data is None:
            self._cached_predictor_data = self._generate_sized_split(
                "predictor", self.predictor_size
            )
        return self._cached_predictor_data

    def load_calibrator(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return the cached calibrator split, generating it on first call."""
        if self._cached_calibrator_data is None:
            self._cached_calibrator_data = self._generate_sized_split(
                "calibrator", self.calibrator_size
            )
        return self._cached_calibrator_data

    def decision_function(self, y_pred: np.ndarray) -> np.ndarray:
        """Apply ``top_k_ohe`` with this dataset's ``top_k`` to ``y_pred``.

        Presumably this yields a one-hot mask of the ``top_k`` largest entries
        per row; see ``wmcal.utils.functions.top_k_ohe`` — TODO confirm.
        """
        return top_k_ohe(y_pred, self.top_k)


# Register the config/dataset pair under the name "synthetic" via the
# package-level register_dataset helper (runs at import time).
register_dataset(
    "synthetic",
    SimpleSyntheticDatasetConfig,
    SimpleSyntheticDataset,
)
