from pathlib import Path
import pickle

import numpy as np
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from sklearn.utils import check_random_state

from utils import sample_action_fast, softmax, sigmoid


def _sample_supported_actions(random_, num_data, num_actions, num_def_actions):
    """Draw, for every row, a uniformly random subset of actions pi_0 may choose.

    The subset size is ``int((1 - num_def_actions) * num_actions)``. Taking the
    top-k entries of i.i.d. Gumbel noise per row yields a uniformly random
    k-subset; each row is returned sorted ascending.

    Args:
        random_: ``numpy.random.RandomState`` used for the Gumbel draw.
        num_data: Number of rows.
        num_actions: Number of candidate actions per row.
        num_def_actions: Fraction of actions excluded from the support.

    Returns:
        Integer array of shape ``(num_data, num_supported_actions)``.

    Raises:
        ValueError: if the requested fraction leaves no supported action.
    """
    num_supported_actions = int((1.0 - num_def_actions) * num_actions)
    if num_supported_actions < 1:
        raise ValueError(
            f"num_def_actions={num_def_actions} leaves no supported action out of "
            f"{num_actions}; need (1 - num_def_actions) * num_actions >= 1."
        )
    gumbel = random_.gumbel(size=(num_data, num_actions))
    # Indices of the k largest Gumbel draws per row (descending), then sorted.
    top_k = np.argsort(gumbel, axis=1)[:, ::-1][:, :num_supported_actions]
    return np.sort(top_k, axis=1)


def process_kuairec_data(
    num_data: int,
    num_actions: int = None,
    num_def_actions: float = 0.0,
    beta: float = 0.0,  # pi_0 parameter
    tau: float = 1.0,
    sigma: float = 2.0,
    num_clusters: int = 30,
    random_state: int = 12345,
) -> dict:
    """Build a synthetic logged-bandit dataset from the KuaiRec arrays in ``./data``.

    Users are resampled with replacement. The logging policy ``pi_0`` is a
    softmax over ``tau * (q_x_a + beta * noise)``, optionally restricted to a
    random per-row subset of actions. Rewards are Gaussian around the expected
    reward of the logged action.

    Args:
        num_data: Number of logged rounds to generate.
        num_actions: If given, keep only the first ``num_actions`` columns of
            ``q_x_a.npy``. The returned ``num_actions`` is the count actually
            kept, which can be smaller if fewer columns exist.
        num_def_actions: Fraction of actions left *outside* the support of
            ``pi_0`` in each row (a fraction, not a count). Values ``<= 0``
            give full support.
        beta: Scale of the standard-normal noise added to the policy logits.
        tau: Multiplier applied to the policy logits (inverse temperature).
        sigma: Standard deviation of the Gaussian reward noise.
        num_clusters: Number of KMeans clusters fitted on ``e_a``.
        random_state: Seed for user sampling, noise, clustering and action
            sampling.

    Returns:
        dict with keys:

        - ``num_data`` and ``num_actions``.
        - ``u`` and ``x``: sampled user indices and their features ``x_u[u]``.
        - ``a`` and ``r``: logged actions and their noisy rewards.
        - ``e_a`` and ``phi_a``: item embeddings and their cluster labels.
        - ``pi_0``: the logging policy.
        - ``pscore``: ``pi_0`` at the logged action, clipped to ``[1e-10, 1]``.
        - ``pi_0_c`` and ``pscore_c``: the cluster-level policy and its value
          at the logged action's cluster.
        - ``q_x_a``: expected rewards of the sampled users.

    Raises:
        ValueError: if ``num_def_actions`` leaves fewer than one supported action.
    """
    random_ = check_random_state(random_state)
    q_x_a_raw = np.load("./data/q_x_a.npy")
    if num_actions is not None:
        q_x_a_raw = q_x_a_raw[:, :num_actions]
    # Slicing silently caps at the available columns, so take sizes from the data.
    num_users, num_actions = q_x_a_raw.shape
    x_u = np.load("./data/x_u.npy")
    e_a = np.load("./data/e_a.npy")
    # NOTE(review): e_a is not truncated to num_actions, so clustering and the
    # returned e_a cover every item. phi_a[a] indexing assumes e_a rows are
    # aligned with action ids -- confirm that is intended.
    phi_a = KMeans(n_clusters=num_clusters, random_state=random_state).fit_predict(e_a)

    # Sample users with replacement and gather their features / expected rewards.
    u = random_.choice(np.arange(num_users), size=num_data)
    x = x_u[u]
    q_x_a = q_x_a_raw[u]

    # Logging policy logits: expected reward plus optional Gaussian noise.
    eta = random_.normal(size=q_x_a.shape)
    pi_0_logits = tau * (q_x_a + beta * eta)
    if num_def_actions > 0.0:
        # Deficient support: pi_0 puts mass only on a random subset of each row.
        cols = _sample_supported_actions(random_, num_data, num_actions, num_def_actions)
        rows = np.arange(num_data)[:, np.newaxis]
        pi_0 = np.zeros_like(pi_0_logits)
        pi_0[rows, cols] = softmax(pi_0_logits[rows, cols])
    else:
        pi_0 = softmax(pi_0_logits)

    # NOTE(review): pi_0_c is a uniform placeholder, so pscore_c is always
    # 1 / num_clusters. The marginal cluster probability under pi_0 would be
    # pi_0[:, phi_a == c].sum(1) for each cluster c -- confirm which is intended.
    pi_0_c = np.full((num_data, num_clusters), 1.0 / num_clusters)

    # Sample logged actions from pi_0 and Gaussian rewards around their means.
    a = sample_action_fast(pi_0, random_state=random_state)
    idx = np.arange(num_data)
    q_x_a_factual = q_x_a[idx, a]
    r = random_.normal(q_x_a_factual, scale=sigma)

    return dict(
        num_data=num_data,
        num_actions=num_actions,
        u=u,
        x=x,
        a=a,
        e_a=e_a,
        phi_a=phi_a,
        r=r,
        pi_0=pi_0,
        pi_0_c=pi_0_c,
        pscore=np.clip(pi_0[idx, a], 1e-10, 1.0),
        pscore_c=pi_0_c[idx, phi_a[a]],
        q_x_a=q_x_a,
    )



def _masked_softmax(logits, mask, fill_value):
    """Row-wise softmax of ``logits`` restricted to entries where ``mask`` is non-zero.

    Each row is shifted by its maximum over allowed entries before
    exponentiating, so large logits cannot overflow to ``inf`` and produce NaN.
    Disallowed entries get probability 0. Rows with no allowed entry are filled
    entirely with ``fill_value``.

    Args:
        logits: 2-D float array.
        mask: Array of the same shape; non-zero marks an allowed entry.
        fill_value: Probability assigned to every entry of an all-masked row.

    Returns:
        Float array of the same shape as ``logits``.
    """
    allowed = np.asarray(mask) != 0
    masked = np.where(allowed, logits, -np.inf)
    row_max = masked.max(axis=1, keepdims=True)
    # All-masked rows have max -inf; shift by 0 there to avoid (-inf) - (-inf).
    row_max = np.where(np.isfinite(row_max), row_max, 0.0)
    weights = np.exp(masked - row_max)  # exp(-inf) == 0 for masked entries
    row_sum = weights.sum(axis=1, keepdims=True)
    probs = np.full(weights.shape, fill_value, dtype=float)
    np.divide(weights, row_sum, out=probs, where=row_sum > 0)
    return probs


def process_kuairec_data2(
    num_data: int,
    num_actions: int = None,
    beta: float = 0.0,  # pi_0 parameter
    tau: float = 3.0,  # inverse temperature
    sigma: float = 2.0,
    frac_candidates: float = 0.05,
    is_test: bool = False,
    random_state: int = 12345,
) -> dict:
    """Build a logged-bandit dataset where pi_0 acts only on a sampled candidate set.

    Each round first draws a candidate set ``s``. Each action enters with a
    probability proportional to the sigmoid of its row-centred expected reward,
    rescaled so the row mean is ``frac_candidates`` and clipped to ``[0, 1]``.
    The logging policy ``pi_0`` is a softmax of ``tau * (q_x_a + beta * noise)``
    over the candidates. Rows with no candidate fall back to uniform over all
    actions.

    Args:
        num_data: Number of logged rounds to generate.
        num_actions: If given, keep only the first ``num_actions`` columns of
            ``q_x_a.npy``. The returned ``num_actions`` is the count actually
            kept.
        beta: Scale of the standard-normal noise added to the policy logits.
        tau: Inverse temperature of the logging policy.
        sigma: Standard deviation of the Gaussian reward noise.
        frac_candidates: Target mean inclusion probability of an action in ``s``.
        is_test: If True, skip fitting the propensity model. ``pi_0_hat`` and
            ``pscore`` are then all zeros.
        random_state: Seed for user sampling, candidate sets, noise, action
            sampling and the propensity model.

    Returns:
        dict with keys:

        - ``num_data`` and ``num_actions``.
        - ``u`` and ``x``: sampled user indices and their features.
        - ``s`` and ``pscore_s``: candidate sets and their inclusion
          probabilities.
        - ``a`` and ``r``: logged actions and their noisy rewards.
        - ``e_a``: item embeddings.
        - ``pi_0``: the true logging policy.
        - ``pi_0_hat``: propensities estimated by logistic regression.
        - ``pscore``: ``pi_0_hat`` at the logged action.
        - ``q_x_a`` and ``q_x_a_factual``: expected rewards, for all actions
          and for the logged action.
    """
    random_ = check_random_state(random_state)
    q_x_a_raw = np.load("./data/q_x_a.npy")
    if num_actions is not None:
        q_x_a_raw = q_x_a_raw[:, :num_actions]
    # Slicing silently caps at the available columns; a too-large num_actions
    # must not leak into the uniform fallback below.
    num_users, num_actions = q_x_a_raw.shape
    x_u = np.load("./data/x_u.npy")
    e_a = np.load("./data/e_a.npy")

    # Sample users with replacement and gather their features / expected rewards.
    u = random_.choice(np.arange(num_users), size=num_data)
    x = x_u[u]
    q_x_a = q_x_a_raw[u]

    # Candidate sets.
    pscore_s = sigmoid(q_x_a - q_x_a.mean(axis=1, keepdims=True))
    pscore_s = frac_candidates * pscore_s / pscore_s.mean(axis=1, keepdims=True)
    # Rescaling can push single probabilities above 1, which binomial rejects.
    pscore_s = np.clip(pscore_s, 0.0, 1.0)
    s = random_.binomial(1, p=pscore_s)

    # Logging policy: numerically stable softmax restricted to each candidate set.
    eta = random_.normal(size=q_x_a.shape)
    pi_0 = _masked_softmax(tau * (q_x_a + beta * eta), s, 1.0 / num_actions)

    # Sample logged actions from pi_0 and Gaussian rewards around their means.
    a = sample_action_fast(pi_0, random_state=random_state)
    rows = np.arange(num_data)
    q_x_a_factual = q_x_a[rows, a]
    r = random_.normal(q_x_a_factual, scale=sigma)

    # Estimate propensities by regressing the logged action on x; actions never
    # logged keep probability 0.
    pi_0_hat = np.zeros_like(q_x_a)
    if not is_test:
        pscore_model = LogisticRegression(C=10, solver="sag", random_state=random_state)
        pscore_model.fit(x, a)
        # predict_proba columns are ordered like classes_ (the distinct logged actions).
        pi_0_hat[:, pscore_model.classes_] = pscore_model.predict_proba(x)

    return dict(
        num_data=num_data,
        num_actions=num_actions,
        u=u,
        x=x,
        s=s,
        a=a,
        e_a=e_a,
        r=r,
        pi_0=pi_0,
        pi_0_hat=pi_0_hat,
        pscore_s=pscore_s,
        pscore=pi_0_hat[rows, a],
        q_x_a=q_x_a,
        q_x_a_factual=q_x_a_factual,
    )
