from dataclasses import dataclass

import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.utils import check_random_state
import torch


def _q_hat_buffer(q_ref) -> np.ndarray:
    """Return a float64 zero array with the shape described by ``q_ref``.

    ``q_ref`` may be a reference array (ndim >= 2) whose shape is copied, or an
    explicit shape tuple (the only form the previous ``np.zeros(q_ref)`` accepted).
    """
    shape = np.shape(q_ref) if np.ndim(q_ref) >= 2 else tuple(q_ref)
    return np.zeros(shape)


def learn_q_model(train_data: dict, test_data: dict, random_state: int = 12345):
    """Fit a reward regressor on logged data and predict it for every action.

    A ``RandomForestRegressor`` is trained on ``[x, e_a[a]]`` -> ``r`` with
    sample weights ``1 / pscore``. It is then queried for every
    (context, action) pair of both the train and the test contexts.

    Args:
        train_data: dict with keys ``"x"`` (contexts), ``"a"`` (logged action
            indices into ``e_a``), ``"e_a"`` (per-action feature rows),
            ``"r"`` (observed rewards), ``"pscore"`` (used as ``1 / pscore``
            sample weights) and ``"q_x_a"``, whose shape ``(n, num_actions)``
            fixes the output shape.
        test_data: dict with ``"x"`` and ``"q_x_a"`` for the evaluation contexts.
        random_state: seed forwarded to the random forest.

    Returns:
        ``(q_hat_train, q_hat_test)``: float arrays of shape
        ``(n_train, num_actions)`` and ``(n_test, num_actions)``.
    """
    x_tr, x_te = train_data["x"], test_data["x"]
    a, e_a, r = train_data["a"], train_data["e_a"], train_data["r"]

    model = RandomForestRegressor(
        n_estimators=10, max_depth=10, min_samples_leaf=10, random_state=random_state
    )
    # Inverse-propensity sample weights.
    model.fit(np.c_[x_tr, e_a[a]], r, sample_weight=(1.0 / train_data["pscore"]))

    # np.zeros(<2-D array>) raises, so allocate float buffers shaped like q_x_a.
    q_hat_train = _q_hat_buffer(train_data["q_x_a"])
    q_hat_test = _q_hat_buffer(test_data["q_x_a"])
    n_tr, num_actions = q_hat_train.shape
    n_te = q_hat_test.shape[0]
    for a_ in range(num_actions):
        # Pair every context with action a_'s features and predict its reward.
        q_hat_train[:, a_] = model.predict(np.c_[x_tr, e_a[np.full(n_tr, a_)]])
        q_hat_test[:, a_] = model.predict(np.c_[x_te, e_a[np.full(n_te, a_)]])

    return q_hat_train, q_hat_test


def sample_action_fast(pi: np.ndarray, random_state: int = 12345) -> np.ndarray:
    """Draw one action per row of ``pi`` by inverse-CDF sampling.

    Args:
        pi: 2-D array; row ``i`` holds (non-negative) action probabilities for
            sample ``i``.
        random_state: seed or generator accepted by ``check_random_state``.

    Returns:
        Integer array of length ``pi.shape[0]`` with the sampled action indices.
    """
    random_ = check_random_state(random_state)
    uniform_rvs = random_.uniform(size=pi.shape[0])[:, np.newaxis]
    cum_pi = pi.cumsum(axis=1)
    # Scale each draw by the row's total mass. If a row sums to slightly less
    # than 1 (float rounding), an unscaled draw can exceed every cumulative
    # value. argmax would then silently return action 0. Because the draw is
    # in [0, 1), the scaled value is always below cum_pi[:, -1], so some
    # entry is always True.
    thresholds = uniform_rvs * cum_pi[:, -1:]
    flg = cum_pi > thresholds
    # argmax picks the first True, i.e. the first action whose CDF exceeds the draw.
    return flg.argmax(axis=1)


def sigmoid(x: np.ndarray) -> np.ndarray:
    """Element-wise logistic function, evaluated without overflow.

    Both exponentials receive only non-positive arguments. For x >= 0 this is
    1 / (1 + e^-x), and for x < 0 it is e^x / (1 + e^x).
    """
    top = np.exp(np.minimum(x, 0))
    bottom = 1.0 + np.exp(-np.abs(x))
    return top / bottom


def softmax(x: np.ndarray) -> np.ndarray:
    """Row-wise softmax over axis 1, shifted by the row maximum for stability."""
    shifted = x - np.max(x, axis=1, keepdims=True)
    weights = np.exp(shifted)
    return weights / np.sum(weights, axis=1, keepdims=True)


@dataclass
class RegBasedPolicyDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(context, action, reward)`` per index."""

    context: np.ndarray  # indexed along axis 0, one row per sample
    action: np.ndarray  # same length as context
    reward: np.ndarray  # same length as context

    def __post_init__(self):
        """Check that all arrays hold the same number of samples.

        Raises:
            ValueError: if the first dimensions of ``context``, ``action`` and
                ``reward`` differ.
        """
        # Explicit check instead of ``assert`` so it also runs under ``python -O``.
        sizes = {
            "context": self.context.shape[0],
            "action": self.action.shape[0],
            "reward": self.reward.shape[0],
        }
        if len(set(sizes.values())) != 1:
            raise ValueError(
                f"all arrays must share the same first dimension, got {sizes}"
            )

    def __getitem__(self, index):
        return (
            self.context[index],
            self.action[index],
            self.reward[index],
        )

    def __len__(self):
        return self.context.shape[0]


@dataclass
class GradientBasedPolicyDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding per-sample logged-bandit fields.

    Each item is ``(context, action, reward, pscore, q_hat, pi_0)`` taken at
    the same index along axis 0.
    """

    context: np.ndarray
    action: np.ndarray
    reward: np.ndarray
    pscore: np.ndarray
    q_hat: np.ndarray
    pi_0: np.ndarray

    def __post_init__(self):
        """Check that all arrays hold the same number of samples.

        Raises:
            ValueError: if any field's first dimension differs from the others.
        """
        # Explicit check instead of ``assert`` so it also runs under ``python -O``.
        names = ("context", "action", "reward", "pscore", "q_hat", "pi_0")
        sizes = {name: getattr(self, name).shape[0] for name in names}
        if len(set(sizes.values())) != 1:
            raise ValueError(
                f"all arrays must share the same first dimension, got {sizes}"
            )

    def __getitem__(self, index):
        return (
            self.context[index],
            self.action[index],
            self.reward[index],
            self.pscore[index],
            self.q_hat[index],
            self.pi_0[index],
        )

    def __len__(self):
        return self.context.shape[0]
