from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Iterable, List, Sequence, Optional, Union

import numpy as np
import sympy as sp


class IndependentArmsMixin:
    """Track a linearly independent subset of an arm set.

    Arms are the rows of a 2-D array. :meth:`compute_independent_arms`
    stores the row indices and values of a maximal linearly independent
    subset of those rows. :meth:`is_same_arm_set` uses the stored subset as
    a cheap check of whether a new arm set matches the one it came from.
    """

    # Row indices (into the arm matrix last passed to compute_independent_arms)
    # of the independent arms; None until that method has run.
    indep_arm_indices: Optional[Sequence[int]] = None
    # Copies of the independent arm vectors, in the same order as the indices.
    indep_arms: Optional[Sequence[np.ndarray]] = None

    def compute_independent_arms(
        self,
        arms: Union[Sequence[Sequence[float]], np.ndarray],
        *,
        tol: Optional[float] = None,
    ) -> Sequence[np.ndarray]:
        """Select the first maximal linearly independent subset of ``arms``.

        Rows are scanned in order. A row is kept when it raises the numerical
        rank of the rows kept so far. In exact arithmetic this picks the same
        rows as the pivot columns of ``rref(arms.T)``. Rank is decided with
        :func:`numpy.linalg.matrix_rank`, an SVD with a tolerance, so
        floating-point round-off is not mistaken for a new independent
        direction, as it can be with an exact-zero pivot test.

        Args:
            arms: 2-D array-like; each row is one arm's vector.
            tol: Singular-value threshold forwarded to ``matrix_rank``;
                ``None`` uses NumPy's default relative tolerance.

        Returns:
            Tuple of the independent arm vectors. They are also stored in
            ``self.indep_arms``, and their row indices in
            ``self.indep_arm_indices``.

        Raises:
            ValueError: If ``arms`` is not 2-D.
        """
        matrix = np.asarray(arms, dtype=float)
        if matrix.ndim != 2:
            raise ValueError("Expected a 2-D array of arms.")

        dim = matrix.shape[1]
        chosen: List[int] = []
        for index in range(matrix.shape[0]):
            if len(chosen) == dim:
                # The rank cannot exceed the arm dimension; no more rows can be added.
                break
            if np.linalg.matrix_rank(matrix[chosen + [index]], tol=tol) > len(chosen):
                chosen.append(index)

        self.indep_arm_indices = tuple(chosen)
        # Copy so later in-place edits of the caller's array don't alter the stored arms.
        self.indep_arms = tuple(matrix[i].copy() for i in chosen)
        return self.indep_arms

    def is_same_arm_set(
        self, arms: Union[Sequence[Sequence[float]], np.ndarray], rtol: float = 1e-4
    ) -> bool:
        """Heuristically check whether ``arms`` matches the stored arm set.

        Only the first stored independent arm is compared. It is checked
        against the row of ``arms`` at the index it had in the set it was
        computed from.

        Args:
            arms: Candidate arm set, 2-D array-like with one arm per row.
            rtol: Relative tolerance passed to :func:`numpy.allclose`.

        Returns:
            ``False`` if any of these hold:

            * nothing is stored, or the stored independent set is empty;
            * ``arms`` is not a non-empty 2-D array;
            * the matching row is missing;
            * the arm dimensions differ.

            Otherwise, the result of ``np.allclose``.
        """
        if self.indep_arms is None or len(self.indep_arms) == 0:
            return False
        matrix = np.asarray(arms, dtype=float)
        if matrix.ndim != 2 or matrix.shape[0] == 0:
            return False

        reference = np.asarray(self.indep_arms[0], dtype=float)
        # The reference arm came from row indices[0] of the original set.
        # That is not row 0 when the leading arms were dependent (e.g. all-zero).
        indices = self.indep_arm_indices
        row = indices[0] if indices is not None and len(indices) > 0 else 0
        if row >= matrix.shape[0]:
            return False
        candidate = matrix[row]
        if candidate.shape != reference.shape:
            return False
        return bool(np.allclose(reference, candidate, rtol=rtol))


class BanditAlgorithm(IndependentArmsMixin, ABC):
    """Abstract base for bandit algorithms run over a fixed horizon.

    Subclasses implement :meth:`select_arm` and :meth:`update_statistics`.
    :meth:`update` wraps ``update_statistics`` and advances the round counter
    ``t``. :meth:`re_init` clears per-run state.

    Args:
        num_actions: Number of actions; stored as ``int``.
        horizon: Horizon; stored as ``int`` in ``T``.
    """

    def __init__(self, num_actions: int, horizon: int):
        self.num_actions = int(num_actions)
        self.T = int(horizon)

        # Round counter, incremented by ``update``.
        self.t = 0
        # Set by ``re_init`` so that the next ``update`` does not increment ``t``.
        self.is_reset = False

        self.all_arms: List[np.ndarray] = []
        self.rewards: List[float] = []
        # Not cleared by ``re_init``, so it persists across resets.
        self.ChangePoints: List[int] = []
        self.indep_arm_indices: Optional[Sequence[int]] = None
        self.indep_arms: Optional[Sequence[np.ndarray]] = None
        # Number of independent arms found by the last ``get_indep_arms`` call.
        self.N_e: Optional[int] = None

    @abstractmethod
    def select_arm(self, arms, *args, **kwargs):
        """Choose an arm from ``arms``; implemented by subclasses."""

    @abstractmethod
    def update_statistics(self, arm: int, reward: float, *args, **kwargs):
        """Incorporate the observed ``reward`` for ``arm``; implemented by subclasses."""

    def update(self, arm: int, reward: float, *args, **kwargs) -> None:
        """Feed an observation to ``update_statistics`` and advance the round.

        The round counter ``t`` is incremented unless ``re_init`` was called
        since the previous update. In that case only the reset flag is cleared.
        """
        self.update_statistics(arm, reward, *args, **kwargs)
        if self.is_reset:
            self.is_reset = False
        else:
            self.t += 1

    def get_indep_arms(
        self, arms: Optional[Union[Sequence[Sequence[float]], np.ndarray]] = None
    ):
        """Compute and cache the linearly independent subset of an arm set.

        Args:
            arms: 2-D array-like of arm vectors, one per row. When given, it
                replaces ``self.all_arms`` (only if the computation succeeds).
                When ``None``, the stored ``self.all_arms`` are used.

        Returns:
            The independent arms, as returned by ``compute_independent_arms``.
            Their count is stored in ``self.N_e``.

        Raises:
            ValueError: If ``arms`` is ``None`` and no arms are stored, or if
                ``compute_independent_arms`` rejects the arms (e.g. not 2-D).
        """
        if arms is None:
            if len(self.all_arms) == 0:
                raise ValueError("No arms available; supply them explicitly.")
            arms_array = self.all_arms
        else:
            # ``np.array`` copies, so later in-place edits of the caller's
            # buffer cannot silently change the stored arms.
            arms_array = np.array(arms, dtype=float)
        # Compute before mutating state, so invalid input leaves ``all_arms``
        # consistent with the cached independent arms.
        indep = self.compute_independent_arms(arms_array)
        if arms is not None:
            self.all_arms = arms_array
        self.N_e = len(indep)
        return indep

    def re_init(self) -> None:
        """Reset per-run state: arms, rewards, round counter, independent-arm cache.

        ``is_reset`` is set so the next ``update`` does not advance ``t``.
        ``ChangePoints``, ``num_actions`` and ``T`` are left unchanged.
        """
        self.all_arms = []
        self.rewards = []
        self.t = 0
        self.is_reset = True
        self.indep_arm_indices = None
        self.indep_arms = None
        self.N_e = None
