from typing import Any, Sequence, Protocol, Literal, Type
from dataclasses import dataclass

from tqdm import tqdm  # type: ignore
from icecream import ic  # type: ignore
import numpy as np  # type: ignore
import pandas as pd  # type: ignore
import matplotlib.pyplot as plt  # type: ignore
from sklearn.dummy import DummyRegressor  # type: ignore
from sklearn.linear_model import LinearRegression, LogisticRegression, PoissonRegressor, QuantileRegressor  # type: ignore
from sklearn.preprocessing import LabelEncoder  # type: ignore
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier, GradientBoostingRegressor, GradientBoostingClassifier  # type: ignore
from sklearn.neural_network import MLPClassifier  # type: ignore
from ucimlrepo import fetch_ucirepo  # type: ignore
from sklearn.metrics import roc_auc_score  # type: ignore
from scipy.optimize import root_scalar  # type: ignore


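# Model wrappers: thin adapters around scikit-learn estimators exposing a
# common fit / predict / predict_proba / predict_uncertainty interface.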
class Regressor(Protocol):
    """Structural interface shared by the regression wrappers below.

    A model is fitted once with ``fit`` and then queried for point
    predictions (``predict``) and, where the implementation supports it, a
    per-sample uncertainty score (``predict_uncertainty``).
    """

    def fit(self, data_x, data_y) -> None:
        """Fit the model on features ``data_x`` and targets ``data_y``."""

    def predict(self, data_x):
        """Return point predictions for ``data_x``."""

    def predict_uncertainty(self, data_x):
        """Return an uncertainty score for each row of ``data_x``."""


class LinearRegressor(Regressor):
    """Ordinary least squares plus a second linear model for the noise scale.

    ``fit`` regresses ``y`` on ``x``, then regresses the squared training
    residuals on ``x``. The square root of that second model's prediction
    (clamped at zero) is the raw uncertainty score. It is divided by a
    constant chosen so that its mean over the training data is 0.2, unless
    that mean is zero.
    """

    def _raw_scale(self, data_x):
        # A linear model can predict negative squared residuals; clamp them
        # to zero before taking the square root.
        predicted_sq = self.residual_model.predict(data_x)
        return np.sqrt(np.where(predicted_sq >= 0, predicted_sq, 0))

    def fit(self, data_x, data_y) -> None:
        self.model = LinearRegression().fit(data_x, data_y)
        residuals = self.model.predict(data_x) - data_y
        self.residual_model = LinearRegression().fit(data_x, residuals**2)
        # Dividing by this constant makes the mean training-set score 0.2.
        self.normalizing_constant = np.mean(self._raw_scale(data_x)) / 0.2

    def predict(self, data_x):
        return self.model.predict(data_x)

    def predict_uncertainty(self, data_x):
        scale = self._raw_scale(data_x)
        if self.normalizing_constant == 0:
            # Degenerate all-zero training scale: avoid dividing by zero.
            return scale
        return scale / self.normalizing_constant


class BatchLinearRegressor(Regressor):
    """Linear regression over arrays with arbitrary leading batch axes.

    All leading axes are flattened into one sample axis. The last axis holds
    the features (for ``data_x``) or the targets (for ``data_y``).
    Predictions are reshaped to ``data_x.shape[:-1]`` plus one trailing axis
    of inferred size.
    """

    def fit(self, data_x, data_y) -> None:
        flat_x = data_x.reshape((-1, data_x.shape[-1]))
        flat_y = data_y.reshape((-1, data_y.shape[-1]))
        self.model = LinearRegression().fit(flat_x, flat_y)

    def predict(self, data_x):
        batch_shape = data_x.shape[:-1]
        flat_pred = self.model.predict(data_x.reshape((-1, data_x.shape[-1])))
        return flat_pred.reshape(batch_shape + (-1,))

    def predict_uncertainty(self, data_x):
        # No uncertainty model: returns the NotImplemented sentinel (no raise).
        return NotImplemented


class BatchRandomForest(Regressor):
    """Random-forest regression over arrays with arbitrary leading batch axes.

    Leading axes are flattened into one sample axis for fitting and
    prediction. The forest is constructed without a ``random_state``.
    Predictions are reshaped to ``data_x.shape[:-1]`` plus one trailing axis
    of inferred size.
    """

    def fit(self, data_x, data_y) -> None:
        flat_x = data_x.reshape((-1, data_x.shape[-1]))
        flat_y = data_y.reshape((-1, data_y.shape[-1]))
        self.model = RandomForestRegressor().fit(flat_x, flat_y)

    def predict(self, data_x):
        batch_shape = data_x.shape[:-1]
        flat_pred = self.model.predict(data_x.reshape((-1, data_x.shape[-1])))
        return flat_pred.reshape(batch_shape + (-1,))

    def predict_uncertainty(self, data_x):
        # No uncertainty model: returns the NotImplemented sentinel (no raise).
        return NotImplemented


class RandomForest(Regressor):
    """Random-forest regressor seeded with ``random_state=0``.

    No uncertainty estimate is provided.
    """

    def fit(self, data_x, data_y) -> None:
        forest = RandomForestRegressor(random_state=0)
        self.model = forest.fit(data_x, data_y)

    def predict(self, data_x):
        return self.model.predict(data_x)

    def predict_uncertainty(self, data_x):
        # Returns the NotImplemented sentinel rather than raising.
        return NotImplemented


class Classifier(Protocol):
    """Structural interface for the probabilistic classifier wrappers below."""

    def fit(self, data_x, data_y) -> None:
        """Fit the classifier on features ``data_x`` and labels ``data_y``."""

    def predict_proba(self, data_x):
        """Return predicted class probabilities for ``data_x``."""


class LogisticRegressor(Classifier):
    """Logistic-regression classifier seeded with ``random_state=0``."""

    def fit(self, data_x, data_y) -> None:
        estimator = LogisticRegression(random_state=0)
        self.model = estimator.fit(data_x, data_y)

    def predict_proba(self, data_x):
        """Return class probabilities from the fitted scikit-learn estimator."""
        return self.model.predict_proba(data_x)


class RandomForestC(Classifier):
    """Random-forest classifier seeded with ``random_state=0``."""

    def fit(self, data_x, data_y) -> None:
        estimator = RandomForestClassifier(random_state=0)
        self.model = estimator.fit(data_x, data_y)

    def predict_proba(self, data_x):
        """Return class probabilities from the fitted scikit-learn estimator."""
        return self.model.predict_proba(data_x)


class GradientBoostingC(Classifier):
    """Gradient-boosting classifier seeded with ``random_state=0``."""

    def fit(self, data_x, data_y) -> None:
        estimator = GradientBoostingClassifier(random_state=0)
        self.model = estimator.fit(data_x, data_y)

    def predict_proba(self, data_x):
        """Return class probabilities from the fitted scikit-learn estimator."""
        return self.model.predict_proba(data_x)


class MLP_C(Classifier):
    """Multi-layer-perceptron classifier seeded with ``random_state=0``."""

    def fit(self, data_x, data_y) -> None:
        estimator = MLPClassifier(random_state=0)
        self.model = estimator.fit(data_x, data_y)

    def predict_proba(self, data_x):
        """Return class probabilities from the fitted scikit-learn estimator."""
        return self.model.predict_proba(data_x)


class SingleEValues(Protocol):
    """A sequence of e-values, evaluated and then advanced one sample at a time.

    ``__call__`` evaluates the e-value for ``sample`` and is expected not to
    mutate state. ``step`` then advances the internal state with that
    sample. ``bounds`` reports the current ``(inf, sup)`` range of
    ``__call__`` and is also expected not to mutate state.
    """

    def __call__(self, sample: Any) -> float:
        """Evaluate the e-value at ``sample`` (should not mutate state)."""

    def step(self, sample: Any) -> None:
        """Advance the internal state with ``sample``."""

    def bounds(self) -> tuple[float, float]:
        """Return ``(inf, sup)`` of the current e-value (should not mutate state)."""


class MeanBettor(Protocol):
    """Betting strategy consumed by the mean-testing e-values.

    ``step`` feeds one observation into the strategy. Calling the bettor
    with a ``candidate_mean`` returns the raw bet for that candidate; the
    callers in this module clip it to an admissible range.
    """

    def step(self, sample: Any) -> None:
        """Update the strategy with one observation."""

    def __call__(self, candidate_mean: float) -> float:
        """Return the raw bet against ``candidate_mean``."""


@dataclass
class TrimToPi:
    """Request that an e-value range be trimmed to a target ``pi``.

    Consumers choose a scaling so that ``1 - inf/sup`` of the resulting
    e-value range equals (or does not exceed) ``pi_lower_bound``.
    """

    # Target value of 1 - inf/sup; the consumers assert it lies in (0, 1].
    pi_lower_bound: float


class StandardAGrapaBettor(MeanBettor):
    """aGRAPA-style bettor based on the running mean and variance of samples.

    One pseudo-observation (value 0.5, contributing variance 0.25) is mixed
    into the estimates. This makes them defined before any data arrives and
    keeps the variance estimate positive. The bet against ``candidate_mean``
    is ``(mean - candidate_mean) / (var + (mean - candidate_mean) ** 2)``.
    """

    def __init__(self):
        self.n = 0
        self.rolling_sum = 0
        self.rolling_sum_sq = 0

    def step(self, sample: Any) -> None:
        """Accumulate one observation, which must lie in [0, 1]."""
        assert 0 <= sample <= 1

        self.n = self.n + 1
        self.rolling_sum = self.rolling_sum + sample
        self.rolling_sum_sq = self.rolling_sum_sq + sample**2

    def __call__(self, candidate_mean: float) -> float:
        n_eff = self.n + 1  # real observations plus the single pseudo-sample
        mean = (0.5 + self.rolling_sum) / n_eff
        assert 0 <= mean <= 1

        # sum_i (x_i - mean)^2 = sum x_i^2 - 2 * mean * sum x_i + n * mean^2,
        # plus 0.25 for the pseudo-sample.
        var = (
            0.25
            + self.rolling_sum_sq
            - 2 * mean * self.rolling_sum
            + self.n * mean**2
        ) / n_eff

        gap = mean - candidate_mean
        return gap / (var + gap**2)


class StandardEwmaAGrapaBettor(MeanBettor):
    """aGRAPA-style bettor using exponentially weighted mean and variance.

    ``momentum`` in (0, 1) is the weight given to each new sample. The
    statistics start at mean 0.5 and variance 0.25. The bet is computed
    against ``candidate_mean + dilation``, where ``dilation`` is the value
    supplied with the most recent ``step``.
    """

    def __init__(self, momentum: float):
        self.rolling_mean = 0.5
        self.rolling_var = 0.25
        assert 0 < momentum < 1
        self.momentum = momentum
        self.dilation = 0

    def step(self, info: Any) -> None:
        """Update the statistics with ``info = (sample, dilation)``."""
        sample, dilation = info
        assert 0 <= sample <= 1

        alpha = self.momentum
        prev_mean = self.rolling_mean
        next_mean = alpha * sample + (1 - alpha) * prev_mean
        # Incremental EWMA variance using deviations from the old and the
        # updated mean.
        self.rolling_var = (
            alpha * (sample - prev_mean) * (sample - next_mean)
            + (1 - alpha) * self.rolling_var
        )
        self.rolling_mean = next_mean
        self.dilation = dilation

    def __call__(self, candidate_mean: float) -> float:
        assert 0 <= self.rolling_mean <= 1
        gap = self.rolling_mean - candidate_mean - self.dilation
        return gap / (self.rolling_var + gap**2)


class MeanSingleEvalues(SingleEValues):
    """Two-sided betting e-values for the null ``E[sample] == candidate_mean``.

    Each e-value is ``1 + bet * (sample - m)``, where ``m = candidate_mean``
    and ``bet`` is the bettor's proposal clipped to ``[-c / (1 - m), c / m]``.
    For samples in [0, 1] this keeps every e-value within
    ``[1 - c, 1 + c * max((1 - m) / m, m / (1 - m))]`` (see ``bounds``).

    Args:
        candidate_mean: The null mean ``m``, asserted to lie in [0, 1].
            NOTE: ``m`` equal to exactly 0 or 1 passes the assertion but
            divides by zero in the clipping bounds. This gives ``inf`` with
            NumPy floats and ``ZeroDivisionError`` with Python floats.
        bettor: Betting strategy; queried in ``__call__`` and advanced in
            ``step``.
        c: Fraction of the largest admissible bet. Must be in (0, 1] when
            ``trimmer`` is None.
        trimmer: If given, ``c`` is multiplied by a factor chosen so that,
            for ``c == 1``, ``1 - inf/sup`` of the e-value range equals
            ``trimmer.pi_lower_bound``. For ``c < 1`` it is smaller.
    """

    def __init__(
        self,
        *,
        candidate_mean: float,
        bettor: MeanBettor,
        c: float = 1.0,
        trimmer: TrimToPi | None = None,
    ):
        self.candidate_mean = candidate_mean
        assert 0 <= self.candidate_mean <= 1

        self.bettor = bettor

        if trimmer is None:
            assert 0 < c <= 1
            self.c = c
        else:
            assert 0 < trimmer.pi_lower_bound <= 1
            # The e-value range is [a, b] with
            #   a = 1 - c
            #   b = 1 + c * foo,  foo = max((1 - m) / m, m / (1 - m)).
            # Solve 1 - a / b = pi (pi = pi_lower_bound) for c:
            #   (1 - c) / (1 + c * foo) = 1 - pi
            #   <=> 1 - c = (1 - pi) + c * (1 - pi) * foo
            #   <=> pi = c * ((1 - pi) * foo + 1)
            #   <=> c = pi / ((1 - pi) * foo + 1)
            foo = max(
                (1 - self.candidate_mean) / self.candidate_mean,
                self.candidate_mean / (1 - self.candidate_mean),
            )
            trim_c = trimmer.pi_lower_bound / ((1 - trimmer.pi_lower_bound) * foo + 1)
            self.c = trim_c * c

    def __call__(self, sample: Any) -> float:
        """Return ``1 + bet * (sample - m)`` for a sample in [0, 1] (no state change)."""
        assert 0 <= sample <= 1

        bet = np.clip(
            self.bettor(candidate_mean=self.candidate_mean),
            -self.c / (1 - self.candidate_mean),
            self.c / self.candidate_mean,
        )
        assert -1 / (1 - self.candidate_mean) <= bet <= 1 / self.candidate_mean

        e = 1 + bet * (sample - self.candidate_mean)
        assert e >= 0

        return e

    def step(self, sample: Any) -> None:
        """Forward ``sample`` to the bettor."""
        self.bettor.step(sample)

    def bounds(self) -> tuple[float, float]:
        """Return ``(inf, sup)`` of ``__call__`` over samples in [0, 1]."""
        lambda_inf = -self.c / (1 - self.candidate_mean)
        lambda_sup = self.c / self.candidate_mean
        return (
            1
            + min(
                -lambda_sup * self.candidate_mean,
                lambda_inf * (1 - self.candidate_mean),
            ),
            # = 1 + min(-c, -c) = 1 - c
            1
            + max(
                lambda_sup * (1 - self.candidate_mean),
                -lambda_inf * self.candidate_mean,
            ),
            # = 1 + max(c * (1 - m) / m, c * m / (1 - m))
            # = 1 + c * max((1 - m) / m, m / (1 - m))
        )


class OneSidedMeanSingleEvalues(SingleEValues):
    """One-sided betting e-values for the null ``E[sample] <= candidate_mean``.

    Each e-value is ``1 + bet * (sample - m)``, where ``m = candidate_mean``
    and ``bet`` is the bettor's proposal clipped to ``[0, c / m]``. Only
    non-negative bets are placed, so the e-value is non-decreasing in the
    sample. For samples in [0, 1] it lies within
    ``[1 - c, 1 + c * (1 - m) / m]`` (see ``bounds``).

    Args:
        candidate_mean: The null mean ``m``, asserted to lie in [0, 1].
            NOTE: ``m == 0`` passes the assertion but divides by zero in
            ``c / m``.
        bettor: Betting strategy; queried for the bet and advanced in
            ``step``.
        c: Fraction of the largest admissible bet. Must be in (0, 1] when
            ``trimmer`` is None.
        trimmer: If given, ``c`` is multiplied by a factor chosen so that,
            for ``c == 1``, ``1 - inf/sup`` of the e-value range equals
            ``trimmer.pi_lower_bound``. For ``c < 1`` it is smaller.
    """

    def __init__(
        self,
        *,
        candidate_mean: float,
        bettor: MeanBettor,
        c: float = 0.5,
        trimmer: TrimToPi | None = None,
    ):
        self.candidate_mean = candidate_mean
        assert 0 <= self.candidate_mean <= 1

        self.bettor = bettor

        if trimmer is None:
            assert 0 < c <= 1
            self.c = c
        else:
            assert 0 < trimmer.pi_lower_bound <= 1
            # The e-value range is [a, b] with
            #   a = 1 - c
            #   b = 1 + c * foo,  foo = max((1 - m) / m, 0).
            # Solve 1 - a / b = pi (pi = pi_lower_bound) for c:
            #   (1 - c) / (1 + c * foo) = 1 - pi
            #   <=> 1 - c = (1 - pi) + c * (1 - pi) * foo
            #   <=> pi = c * ((1 - pi) * foo + 1)
            #   <=> c = pi / ((1 - pi) * foo + 1)
            foo = max((1 - self.candidate_mean) / self.candidate_mean, 0)
            trim_c = trimmer.pi_lower_bound / ((1 - trimmer.pi_lower_bound) * foo + 1)
            self.c = trim_c * c

    def _clipped_bet(self) -> float:
        """Return the bettor's current bet clipped to ``[0, c / m]``."""
        bet = np.clip(
            self.bettor(candidate_mean=self.candidate_mean),
            0,
            # NOTE(review): original author asked whether this upper clip
            # should be at most one.
            self.c / self.candidate_mean,
        )
        assert 0 <= bet <= 1 / self.candidate_mean
        return bet

    def __call__(self, sample: Any) -> float:
        """Return ``1 + bet * (sample - m)`` for a sample in [0, 1] (no state change)."""
        assert 0 <= sample <= 1

        bet = self._clipped_bet()

        # Null: E[sample] <= self.candidate_mean
        e = 1 + bet * (sample - self.candidate_mean)
        assert e >= 0

        return e

    def dilation_scaler(self) -> float:
        """Return the current clipped bet.

        This is the slope of ``__call__`` with respect to the sample, i.e.
        how much the e-value changes per unit shift in the sample.
        """
        return self._clipped_bet()

    def step(self, info: Any) -> None:
        """Forward ``info`` to the bettor unchanged."""
        self.bettor.step(info)

    def bounds(self) -> tuple[float, float]:
        """Return ``(inf, sup)`` of ``__call__`` over samples in [0, 1]."""
        lambda_inf = 0
        lambda_sup = self.c / self.candidate_mean
        return (
            1
            + min(
                -lambda_sup * self.candidate_mean,
                lambda_inf * (1 - self.candidate_mean),
            ),
            # = 1 + min(-c, 0) = 1 - c
            1
            + max(
                lambda_sup * (1 - self.candidate_mean),
                -lambda_inf * self.candidate_mean,
            ),
            # = 1 + max(c * (1 - m) / m, 0) = 1 + c * (1 - m) / m
        )


class Rescale(SingleEValues):
    """Affinely shrink (or stretch) another e-value around 1.

    Each inner e-value ``e`` is mapped to ``factor * (e - 1) + 1``.

    Args:
        inner: The wrapped e-value; ``step`` and ``bounds`` delegate to it.
        factor: Either a fixed number, or a ``TrimToPi``. With a
            ``TrimToPi``, ``factor`` is solved so that ``1 - inf/sup`` of the
            rescaled range equals ``pi_lower_bound``.
        up_until: If given, rescaling only applies while fewer than
            ``up_until`` steps have been taken. After that ``factor`` is
            1.0 (identity).

    Raises:
        TypeError: If ``factor`` is neither a number nor a ``TrimToPi``.
    """

    def __init__(
        self,
        inner: SingleEValues,
        *,
        factor: float | TrimToPi,
        up_until: int | None = None,
    ):
        self.inner = inner
        self.desired_factor = factor
        self.up_until = up_until

        self.n = 0

        self._solve_for_factor()

    def _solve_for_factor(self) -> None:
        """Recompute ``self.factor`` for the current step count ``self.n``."""
        if self.up_until is not None and self.n >= self.up_until:
            self.factor = 1.0
            return

        desired = self.desired_factor
        if isinstance(desired, (float, int)):
            self.factor = desired
        elif isinstance(desired, TrimToPi):
            # With inner range [a, b], the rescaled range is
            #   [factor * (a - 1) + 1, factor * (b - 1) + 1].
            # Solve 1 - inf/sup = pi for factor:
            #   factor * (a - 1) + 1 = (1 - pi) * (factor * (b - 1) + 1)
            #   <=> factor * ((a - 1) - (1 - pi) * (b - 1)) = -pi
            #   <=> factor = pi / ((1 - pi) * b - a + pi)
            pi = desired.pi_lower_bound
            a, b = self.inner.bounds()
            self.factor = pi / ((1 - pi) * b - a + pi)
            assert self.factor >= 0
        else:
            raise TypeError(
                "factor must be a float, int or TrimToPi, "
                f"got {type(desired).__name__}"
            )

    def _rescale(self, e: float) -> float:
        """Map an inner e-value ``e`` (asserted >= 0) to ``factor * (e - 1) + 1``."""
        assert e >= 0, f"{e = }"
        return self.factor * (e - 1) + 1

    def __call__(self, sample: Any) -> float:  # should be immutable
        return self._rescale(self.inner(sample))

    def step(self, sample: Any) -> None:
        """Advance the inner e-value, then refresh ``factor`` for the new count."""
        self.inner.step(sample)
        self.n += 1

        self._solve_for_factor()

    def bounds(self) -> tuple[float, float]:  # should be immutable
        a, b = self.inner.bounds()
        return self._rescale(a), self._rescale(b)


def rollout_eprocess(
    evalue: SingleEValues,
    data: Sequence[Any],
    *,
    max_capital: float | None = None,
) -> np.ndarray:
    """Run ``evalue`` over ``data`` and return the running product of e-values.

    For each sample, the e-value is evaluated *before* ``evalue.step`` sees
    that sample. The bet applied to a sample is therefore formed from
    earlier samples only.

    Args:
        evalue: Sequential e-value; mutated in place through ``step``.
        data: Samples, consumed in order.
        max_capital: Optional cap applied to the running capital after each
            multiplication (e.g. to avoid overflow). ``None`` (the default)
            disables capping.

    Returns:
        Array of length ``len(data) + 1``. Entry 0 is the initial capital
        1.0; entry ``k`` is the capital after the first ``k`` samples.
    """
    capital = 1.0
    out = np.ones(len(data) + 1)
    for k, sample in enumerate(data, start=1):
        capital *= evalue(sample)
        evalue.step(sample)
        if max_capital is not None and capital >= max_capital:
            capital = max_capital
        out[k] = capital
    return out


def _prediction_set(proba1, threshold: float) -> tuple[bool, bool]:
    """Return ``(contains0, contains1)`` for the binary conformal prediction set.

    The nonconformity score is s(x, y) = 1 - p(y | x), with p(1 | x) =
    ``proba1``. Label y is in the set iff s(x, y) <= threshold.
    """
    contains0 = bool(1 - (1 - proba1) <= threshold)  # s(x, 0) <= threshold
    contains1 = bool(1 - proba1 <= threshold)  # s(x, 1) <= threshold
    return contains0, contains1


def rollout_eprocess_cppi(
    evalue: SingleEValues,
    data_y: np.ndarray,
    data_x: np.ndarray,
    regressor: Regressor,
    model,
    conformal_err: float,
    labelled_samples: np.ndarray,
    *,
    verbose: bool = True,
) -> np.ndarray:
    """Run an e-process on conformal prediction sets built from ``model``.

    For each point ``i``:

    * A binary prediction set is built from ``model``'s predicted
      P(y = 1 | x) and the current conformal threshold.
    * The e-value is evaluated at the smallest label in that set. This is
      the infimum over the set when the e-value is non-decreasing in the
      sample. If the set is empty, the rounded predicted label is used.
    * The e-value is then reduced by
      ``evalue.dilation_scaler() * (sup - inf) * conformal_err``.

    Whenever ``i`` is in ``labelled_samples``, the true label ``data_y[i]``
    adapts the threshold online, with step size ``(i + 1) ** (-0.5 - 0.3)``.

    Args:
        evalue: Sequential e-value. Besides the ``SingleEValues`` methods it
            must provide ``dilation_scaler()`` (e.g.
            ``OneSidedMeanSingleEvalues``). Its ``step`` receives
            ``(sample, dilation)`` tuples.
        data_y: Binary labels. Only read at indices in ``labelled_samples``.
        data_x: Features, one row per label.
        regressor: Unused; kept for backward compatibility.
        model: Fitted classifier exposing ``predict_proba``; column 1 is
            taken as P(y = 1 | x).
        conformal_err: Target miscoverage rate for the prediction sets.
        labelled_samples: Indices whose labels may be used for calibration.
        verbose: If True (default), emit debugging output via ``ic`` and
            ``print``.

    Returns:
        Array of length ``len(data_y) + 1`` holding the running capital,
        starting at 1.0.
    """
    assert len(data_y) == data_x.shape[0]
    n = len(data_y)

    # Conformal score threshold, adapted online from labelled samples.
    threshold = 0.5
    epsilon = 0.3
    step_exponent = -0.5 - epsilon

    # Set membership is O(1); testing `i in ndarray` scans the whole array.
    labelled = set(np.asarray(labelled_samples).ravel().tolist())

    predicted_prob1s = model.predict_proba(data_x)[:, 1]

    n_updates = 0

    capital = 1.0
    out = np.ones(n + 1)
    for i in tqdm(range(n)):
        inf, sup = evalue.bounds()
        M = sup - inf

        proba1 = predicted_prob1s[i]
        contains0, contains1 = _prediction_set(proba1, threshold)
        if contains0 or contains1:
            # Smallest label in the set: the infimum of a non-decreasing e-value.
            lower_sample = 0.0 if contains0 else 1.0
        else:
            # Empty prediction set: fall back to the rounded predicted label.
            lower_sample = float(np.around(proba1))

        e_fake = evalue(lower_sample)  # inf e(C(X))
        assert e_fake > 0

        this_evalue = e_fake - evalue.dilation_scaler() * M * conformal_err
        assert this_evalue > 0.0
        capital *= this_evalue
        evalue.step((lower_sample, M * conformal_err))

        out[i + 1] = capital

        # For labelled points, update the conformal calibration.
        if i in labelled:
            y_contained = contains1 if data_y[i] == 1 else contains0
            step_size = (i + 1) ** step_exponent
            # Raise the threshold (larger sets) on miscoverage, lower it otherwise.
            threshold = threshold + step_size * ((1 - y_contained) - conformal_err)

            if verbose:
                ic(threshold)

            n_updates += 1

    if verbose:
        print(f"N UPDATES: {n_updates}")
        if labelled:
            ic(labelled_samples.min())

    return out
