import numpy as np
from utils import range_without, softmax_func  # type: ignore


def synthetic_IIA(N_participants, N_questions, std_v, max_options=4):
    """Simulate choice counts where reduced sets reuse the full-set values.

    One latent value is drawn per (question, option) from a zero-mean
    normal distribution. Choice probabilities for the full option set
    are ``softmax_func`` of those values; for each reduced set (one
    option removed) they are ``softmax_func`` of the very same values
    restricted to the remaining options, with no context effect added.
    This is presumably the "independence of irrelevant alternatives"
    baseline the name refers to.

    Parameters
    ----------
    N_participants
        Number of multinomial trials behind every count vector.
    N_questions
        Number of rows in the latent value matrix.
    std_v
        Standard deviation of the latent values.
    max_options
        Number of options in the full set; also the number of reduced
        sets produced.

    Returns
    -------
    tuple
        ``(res_full, res_rems)``: the counts for the full set, and a
        list of ``max_options`` count arrays in which entry ``i`` was
        drawn with the columns ``range_without(max_options, i)``.

    Notes
    -----
    The latent values come from the global ``np.random`` state, while
    the counts come from a fresh, unseeded ``Generator``.
    Assumes ``softmax_func`` normalises along the last axis and keeps
    the shape of its input -- TODO confirm in ``utils``.
    """
    values = np.random.normal(
        loc=0, scale=std_v, size=(N_questions, max_options)
    )
    generator = np.random.default_rng()

    def draw_counts(probabilities):
        # Every count vector is based on the same number of participants.
        return generator.multinomial(N_participants, pvals=probabilities)

    full_counts = draw_counts(softmax_func(values))
    reduced_counts = [
        draw_counts(
            softmax_func(values[:, range_without(max_options, dropped)])
        )
        for dropped in range(max_options)
    ]
    return full_counts, reduced_counts


def synthetic_additive_context(
    N_participants, N_questions, std_v, std_c, max_options=4
):
    """Simulate choice counts with an additive context effect.

    Latent values are drawn per (question, option) from a zero-mean
    normal distribution with standard deviation ``std_v``. The full set
    uses ``softmax_func`` of those values. For every reduced set (one
    option removed) an independent zero-mean normal perturbation with
    standard deviation ``std_c`` is added to each remaining value before
    ``softmax_func`` is applied.

    Parameters
    ----------
    N_participants
        Number of multinomial trials behind every count vector.
    N_questions
        Number of rows in the latent value matrix.
    std_v
        Standard deviation of the latent values.
    std_c
        Standard deviation of the additive context perturbation.
    max_options
        Number of options in the full set; also the number of reduced
        sets produced.

    Returns
    -------
    tuple
        ``(res_full, res_rems)``: the counts for the full set, and a
        list of ``max_options`` count arrays in which entry ``i`` was
        drawn with the columns ``range_without(max_options, i)``.

    Notes
    -----
    Values and perturbations come from the global ``np.random`` state,
    while the counts come from a fresh, unseeded ``Generator``.
    Assumes ``softmax_func`` normalises along the last axis and keeps
    the shape of its input -- TODO confirm in ``utils``.
    """
    values = np.random.normal(
        loc=0, scale=std_v, size=(N_questions, max_options)
    )
    generator = np.random.default_rng()
    full_counts = generator.multinomial(
        N_participants, pvals=softmax_func(values)
    )

    context_shape = (N_questions, max_options - 1)
    reduced_counts = []
    for dropped in range(max_options):
        remaining = values[:, range_without(max_options, dropped)]
        # A fresh perturbation is drawn for every reduced option set.
        shift = np.random.normal(loc=0, scale=std_c, size=context_shape)
        probabilities = softmax_func(remaining + shift)
        reduced_counts.append(
            generator.multinomial(N_participants, pvals=probabilities)
        )
    return full_counts, reduced_counts


def synthetic_simple_context(
    N_participants, N_questions, std_v, std_c, max_options=4
):
    """Simulate choice counts with a multiplicative context effect.

    Latent values are drawn per (question, option) from a zero-mean
    normal distribution with standard deviation ``std_v``. The full set
    uses ``softmax_func`` of those values. For every reduced set (one
    option removed) the remaining values of each question are multiplied
    by a single factor drawn from a normal distribution with mean 1 and
    standard deviation ``std_c`` before ``softmax_func`` is applied.

    Parameters
    ----------
    N_participants
        Number of multinomial trials behind every count vector.
    N_questions
        Number of rows in the latent value matrix.
    std_v
        Standard deviation of the latent values.
    std_c
        Standard deviation of the per-question scaling factor.
    max_options
        Number of options in the full set; also the number of reduced
        sets produced.

    Returns
    -------
    tuple
        ``(res_full, res_rems)``: the counts for the full set, and a
        list of ``max_options`` count arrays in which entry ``i`` was
        drawn with the columns ``range_without(max_options, i)``.

    Notes
    -----
    Values and scaling factors come from the global ``np.random`` state,
    while the counts come from a fresh, unseeded ``Generator``.
    Assumes ``softmax_func`` normalises along the last axis and keeps
    the shape of its input -- TODO confirm in ``utils``.
    """
    rng = np.random.default_rng()
    # The value matrix must have max_options columns: a hard-coded width
    # of 4 breaks the column selection below for any other option count.
    v = np.random.normal(
        loc=0, scale=std_v, size=(N_questions, max_options)
    )
    p_full = softmax_func(v)
    res_full = rng.multinomial(N_participants, pvals=p_full)
    res_rems = []
    for i in range(max_options):
        # One factor per question (shape (N_questions, 1)) broadcasts
        # over all remaining options of that question.
        scale = np.random.normal(loc=1, scale=std_c, size=(N_questions, 1))
        p_rem = softmax_func(v[:, range_without(max_options, i)] * scale)
        res_rems.append(rng.multinomial(N_participants, pvals=p_rem))
    return res_full, res_rems


def synthetic_noise_context(
    N_participants, N_questions, std_v, alpha, max_options=4
):
    """Simulate choice counts where reduced sets are mixed with noise.

    Latent values are drawn per (question, option) from a zero-mean
    normal distribution with standard deviation ``std_v``. The full set
    uses ``softmax_func`` of those values. For every reduced set (one
    option removed) the participants are split in two groups: a
    ``Binomial(N_participants, alpha)`` number of them choose according
    to ``softmax_func`` of the remaining values, and the rest choose
    uniformly at random among the remaining options.

    Parameters
    ----------
    N_participants
        Total number of responses behind every count vector.
    N_questions
        Number of rows in the latent value matrix.
    std_v
        Standard deviation of the latent values.
    alpha
        Probability that a participant answers according to the values
        rather than uniformly at random.
    max_options
        Number of options in the full set; also the number of reduced
        sets produced.

    Returns
    -------
    tuple
        ``(res_full, res_rems)``: the counts for the full set, and a
        list of ``max_options`` count arrays in which entry ``i`` was
        drawn with the columns ``range_without(max_options, i)``.

    Notes
    -----
    Values and the group split come from the global ``np.random`` state,
    while the counts come from a fresh, unseeded ``Generator``.
    The group split is drawn once per reduced set and shared by all
    questions; the uniform responses are drawn independently for every
    question.
    Assumes ``softmax_func`` normalises along the last axis and keeps
    the shape of its input -- TODO confirm in ``utils``.
    """
    rng = np.random.default_rng()
    # The value matrix must have max_options columns: a hard-coded width
    # of 4 breaks the column selection below for any other option count.
    v = np.random.normal(
        loc=0, scale=std_v, size=(N_questions, max_options)
    )
    p_full = softmax_func(v)
    res_full = rng.multinomial(N_participants, pvals=p_full)
    res_rems = []
    p_uni = np.ones(max_options - 1) / (max_options - 1)
    for i in range(max_options):
        N_partial = np.random.binomial(N_participants, p=alpha)
        p_rem = softmax_func(v[:, range_without(max_options, i)])
        informed = rng.multinomial(N_partial, pvals=p_rem)
        # Draw one noise vector per question. Without ``size`` a single
        # 1-D draw would be broadcast, giving every question exactly the
        # same noise counts.
        noise = rng.multinomial(
            N_participants - N_partial, pvals=p_uni, size=N_questions
        )
        res_rems.append(informed + noise)
    return res_full, res_rems
