from math import sqrt
from typing import List, Optional, Tuple

import numpy as np
from scipy.sparse import csc_matrix, csr_matrix, identity
from scipy.sparse.linalg import splu, spsolve

from llm_mcts.mcts_algo.linucb.data_types import NodeIdentifier
from llm_mcts.mcts_algo.linucb.node_indices import LinUCBNodeIndices


class LinUCBState:
    """Sparse LinUCB state used to choose which node to expand next.

    Every candidate (a GEN node per model, or an already generated node) is
    encoded as a one-hot sparse column ``z`` of length
    ``node_indices.total_dim``. The state keeps the LinUCB design matrix
    ``A = I + sum(z z^T)`` and reward vector ``b = sum(reward * z)``; both are
    grown whenever ``node_indices.total_dim`` increases.
    """

    def __init__(self, model_names: List[str], alpha: float = 0.5):
        # Matrices and vectors initialization for LinUCB.
        # A starts as the identity and b as zeros. A is float64 from the start:
        # it turns float on the first update anyway, and the LU solver below
        # needs a floating-point matrix.
        init_num_feats = len(model_names)
        self.A = identity(init_num_feats, dtype=np.float64).tocsr()
        self.b = csc_matrix((init_num_feats, 1), dtype=np.float64)

        # Node indices for each LLM.
        # Indices should always be sorted in ascending order; this is an
        # invariant of every method in this class.
        self.node_indices = LinUCBNodeIndices(model_names)
        # NOTE(review): not read anywhere in this class; presumably used by
        # callers — confirm before removing.
        self.next_idx: int = 0

        # Hyperparameters
        # Weight of the exploration term sqrt(z^T A^{-1} z) in the UCB score.
        self.alpha = alpha

    def tell_reward(
        self,
        reward: float,
        node_identifier: "NodeIdentifier",
        node_indices: Optional["LinUCBNodeIndices"] = None,
    ) -> None:
        """Fold one observed reward into ``A`` and ``b``.

        Args:
            reward: Observed reward for the evaluated candidate.
            node_identifier: Candidate that produced ``reward``; it is encoded
                as a one-hot vector via ``node_indices.get_one_hot_idxs``.
            node_indices: If given, replaces the stored node indices before
                the update, so ``A`` and ``b`` are grown to its ``total_dim``.
        """
        if node_indices is not None:
            self.node_indices = node_indices

        feature_vector = self._get_onehot_feature_vector(node_identifier)

        self._expand_A_and_b()

        # Rank-one update of the design matrix; reward-weighted update of b.
        self.A = self.A + feature_vector @ feature_vector.transpose()
        self.b = self.b + reward * feature_vector

    def ask_next_idx(self) -> Tuple["LinUCBNodeIndices", "NodeIdentifier"]:
        """Pick the candidate with the highest LinUCB score.

        The score of candidate ``z`` is
        ``z^T beta + alpha * sqrt(z^T A^{-1} z)`` with ``beta = A^{-1} b``.
        Ties (within ``np.isclose``) are broken uniformly at random.

        The creation of the new node and adding that to node indices is the
        responsibility of the caller of this function.

        Returns:
            ``(node_indices, identifier)``: ``identifier`` is the value of
            ``node_indices.get_model_name(idx)`` when a GEN node wins,
            otherwise the int index of the already generated node.

        Raises:
            ValueError: if there are no candidates to choose from.
        """
        self._expand_A_and_b()

        candidates = list(self.node_indices._get_all_idxs())
        if not candidates:
            raise ValueError("LinUCBState has no candidate nodes to choose from")

        # Factorize A once and reuse the LU factors for every solve below;
        # calling spsolve per candidate would refactorize A each time.
        lu = splu(self.A.tocsc())
        beta = lu.solve(self.b.toarray()).ravel()

        scores = []
        for node_idx in candidates:
            z = self._get_onehot_feature_vector(node_idx).toarray().ravel()
            mean = float(z @ beta)
            # A = I + sum(z z^T) is SPD, so z^T A^{-1} z >= 0; clamp round-off.
            variance = max(float(z @ lu.solve(z)), 0.0)
            scores.append(mean + self.alpha * sqrt(variance))

        # Get argmax with random tie-breaking.
        np_scores = np.array(scores)
        best = np.flatnonzero(np.isclose(np_scores, np_scores.max()))
        idx = int(np.random.choice(best))

        # In case a GEN node is chosen, the caller will create a new node.
        num_gen_nodes = self.node_indices.num_gen_nodes
        if idx < num_gen_nodes:
            identifier = self.node_indices.get_model_name(idx)
        else:
            identifier = idx - num_gen_nodes

        return self.node_indices, identifier

    def _get_onehot_feature_vector(
        self, node_identifier: "NodeIdentifier"
    ) -> csr_matrix:
        """Return a ``(total_dim, 1)`` one-hot column for a GEN node or an
        already generated node.

        Every index reported by ``node_indices.get_one_hot_idxs`` is set to 1
        (a repeated index still yields 1). The matrix is built from its
        coordinates instead of by item assignment, which on a CSR matrix
        changes the sparsity structure and emits SparseEfficiencyWarning.
        """
        rows = np.unique(
            np.asarray(
                list(self.node_indices.get_one_hot_idxs(node_identifier)),
                dtype=np.intp,
            )
        )
        cols = np.zeros(rows.shape[0], dtype=np.intp)
        return csr_matrix(
            (np.ones(rows.shape[0]), (rows, cols)),
            shape=(self.node_indices.total_dim, 1),
        )

    def _expand_A_and_b(self) -> None:
        """Grow ``A`` and ``b`` to ``node_indices.total_dim`` if it increased.

        New diagonal entries of ``A`` become 1 and new entries of ``b`` are 0;
        nothing happens when the dimension did not grow.
        """
        total_dim = self.node_indices.total_dim
        orig_size = self.A.shape[0]
        if total_dim <= orig_size:
            return

        # resize() zero-pads in place; then set the new diagonal to 1 with a
        # single sparse addition instead of per-element CSR assignment.
        self.A.resize((total_dim, total_dim))
        new_idxs = np.arange(orig_size, total_dim)
        self.A = self.A + csr_matrix(
            (np.ones(new_idxs.shape[0]), (new_idxs, new_idxs)),
            shape=(total_dim, total_dim),
        )

        self.b.resize((total_dim, 1))
