import numpy as np
import numpy.typing as npt
from functools import partial
import scipy.linalg as linalg
from typing import Callable
import utils
from frequent_directions import DyadicBlockSketching
from line_profiler import profile


class DBSL:
    """Optimistic linear bandit learner backed by a ``DyadicBlockSketching`` sketch.

    While the sketch ``(S, H)`` returned by ``self.sketch.get()`` has fewer
    than ``d`` rows, the regularised inverse design matrix is used implicitly
    as ``(I - S.T @ H @ S) / lmd``.  As soon as the sketch reaches ``d`` rows
    that matrix is materialised once as ``self.V_inv`` and from then on
    updated through ``utils.woodbury``.

    Attributes:
        time: number of ``fit`` calls so far.
        d: feature dimension.
        beta: multiplier applied to the exploration bonus.
        lmd: regularisation constant.
        hat_theta: current parameter estimate, shape ``(d, 1)``.
        X: history of played actions, shape ``(t, d)``.
        rewards: history of observed rewards, shape ``(t,)``.
        xy: running sum of ``reward * action``, shape ``(1, d)``.
        sketch: the ``DyadicBlockSketching`` instance.
        full_rank: ``True`` once the explicit ``V_inv`` is in use.
        V_inv: explicit inverse; only exists after ``full_rank`` becomes True.
    """

    def __init__(
        self,
        d: int,
        sketch_size: int,
        eps: float,
        beta: float,
        lmd: float,
        robust: bool = True,
    ) -> None:
        """Create an empty learner.

        Args:
            d: feature dimension of the actions.
            sketch_size: forwarded to ``DyadicBlockSketching``.
            eps: forwarded to ``DyadicBlockSketching``.
            beta: multiplier of the exploration bonus.
            lmd: regularisation constant (also forwarded to the sketch).
            robust: forwarded to ``DyadicBlockSketching``.
        """
        self.time = 0
        self.d = d
        self.beta = beta
        self.lmd = lmd
        self.hat_theta = np.zeros((d, 1))
        # NOTE: self.V_inv is deliberately not created here; it is only
        # materialised once the sketch reaches d rows (see _update_estimate).
        self.X = np.zeros((0, d))
        self.xy = np.zeros((1, d))
        self.rewards = np.zeros((0))

        self.sketch_size = sketch_size
        self.robust = robust
        self.sketch = DyadicBlockSketching(sketch_size, d, eps, lmd, robust)

        # Not read anywhere in this class; kept for external consumers.
        self.deltas = np.array([self.lmd])
        self.full_rank = False

    @profile
    def fit(
        self, decision_set: npt.NDArray, observe: Callable[[int, npt.ArrayLike], float]
    ) -> float:
        """Play one round: pick the optimistic action, observe, update.

        Args:
            decision_set: candidate actions, shape ``(num_actions, d)``.
            observe: callback invoked as ``observe(index, action)`` that
                returns the reward of the chosen action.

        Returns:
            The reward returned by ``observe``.

        Raises:
            ValueError: if ``decision_set`` is not two-dimensional.
        """
        self.time += 1

        if decision_set.ndim != 2:
            raise ValueError("Array dimension must be 2")

        # Optimistic score = estimated reward + beta * confidence width.
        bonus = self._confidence_widths(decision_set).reshape(-1, 1)
        scores = (decision_set @ self.hat_theta).reshape(-1, 1) + self.beta * bonus

        ind = np.argmax(scores)
        play = decision_set[ind]

        # observe reward
        reward = observe(ind, play)

        # Record the observation *before* refreshing the estimate so that
        # hat_theta is based on the same rounds as the sketch / V_inv.
        self.X = np.vstack([self.X, play])
        self.rewards = np.append(self.rewards, reward)
        self.xy += reward * play

        self._update_estimate(play)

        return reward

    def _confidence_widths(self, decision_set: npt.NDArray) -> npt.NDArray:
        """Return one confidence width per row of ``decision_set``."""
        if self.full_rank:
            # utils.matrix_induced_norm is opaque from here; it is applied
            # row by row exactly as before.
            return np.apply_along_axis(
                partial(utils.matrix_induced_norm, A=self.V_inv), 1, decision_set
            )

        S, H = self.sketch.get()
        sq_width = np.sum(np.square(decision_set), axis=1)
        if len(S) != 0:
            # Row-wise quadratic form (S x)^T H (S x) for every action x.
            projected = decision_set @ S.T
            sq_width = sq_width - np.sum((projected @ H) * projected, axis=1)
        # With an empty sketch only the correction term vanishes; the width
        # is then ||x|| / sqrt(lmd), not zero.  Round-off can push the
        # difference slightly below zero, hence the clip before the sqrt.
        return np.sqrt(np.maximum(sq_width, 0.0)) / np.sqrt(self.lmd)

    def _update_estimate(self, play: npt.NDArray) -> None:
        """Fold ``play`` into the sketch or ``V_inv`` and recompute ``hat_theta``."""
        moment = self.xy.T  # (d, 1): sum of reward * action over all rounds

        if self.full_rank:
            # presumably a rank-one inverse update -- verify against utils.woodbury
            self.V_inv = utils.woodbury(self.V_inv, play, play)
            self.hat_theta = self.V_inv @ moment
            return

        self.sketch.fit(play)
        S, H = self.sketch.get()

        # hat_theta = V_inv @ moment with V_inv = (I - S^T H S) / lmd, applied
        # without forming the d x d matrix.
        if len(S) != 0:
            self.hat_theta = (moment - S.T @ (H @ (S @ moment))) / self.lmd
        else:
            self.hat_theta = moment / self.lmd

        if len(S) >= self.d:
            # The sketch no longer saves anything: switch to the explicit inverse.
            self.full_rank = True
            self.V_inv = (np.eye(self.d) - S.T @ H @ S) / self.lmd
