import torch
import numpy as np
from src.q20game import Q20Env
from src.jester_loader import get_z


def _staircase_z(n_dim):
    """Build the float CUDA tensor ``np.triu(ones((n_dim, n_dim + 1)), 1)``.

    Row ``i`` has ones in columns ``i + 1 .. n_dim`` and zeros elsewhere.
    """
    return torch.from_numpy(np.triu(np.ones((n_dim, n_dim + 1)), 1)).float().cuda()


def get_param(game_mode):
    """Return the hyper-parameter tuple for the requested game mode.

    Parameters
    ----------
    game_mode : str
        One of ``"1d"``, ``"1d100"``, ``"q20"`` or ``"jester"``.

    Returns
    -------
    tuple
        ``(IT_THRESH, N_it, N_dim, N_budget, Z, LAMBDA_ENTROPY,
        LAMBDA_ENTROPY_BINARY, LAMBDA_SAMPLE_ENTROPY, M, L, N_GEN, R_lim,
        N_it_overall, N_pulls, eval_range)``.
        ``eval_range`` is ``None`` for ``"1d100"``, which defines no
        evaluation range.

    Raises
    ------
    ValueError
        If ``game_mode`` is not one of the supported modes.
    """
    # Defaults shared by every mode; individual branches may override them.
    IT_THRESH = 20000
    N_it = 500000
    N_budget = 20
    # "1d100" never set an evaluation range, so the return statement used to
    # fail with UnboundLocalError. NOTE(review): confirm whether "1d100"
    # should reuse the "1d" range instead of None.
    eval_range = None

    if game_mode == "1d":
        N_dim = 25
        Z = _staircase_z(N_dim)
        LAMBDA_ENTROPY = .2
        LAMBDA_ENTROPY_BINARY = .3
        LAMBDA_SAMPLE_ENTROPY = .05
        M = 1000
        L = 10
        N_GEN = 1000
        R_lim = 7.5
        N_it_overall = 300000
        N_pulls = 1
        eval_range = [3, 3.5, 4, 4.5, 5, 5.5, 6, 6.5, 7][::-1]
    elif game_mode == "1d100":
        N_dim = 25
        Z = _staircase_z(N_dim)
        LAMBDA_ENTROPY = 0
        LAMBDA_ENTROPY_BINARY = 0.2
        LAMBDA_SAMPLE_ENTROPY = 0.05
        M = 1000
        L = 10
        N_GEN = 1000
        R_lim = 7.5
        N_it_overall = 100000
        N_pulls = 100
    elif game_mode == "q20":
        N_dim = 100
        QEnv = Q20Env()
        Z = torch.from_numpy(QEnv.get_probs()[1].T).float().cuda()  # N_dim * N_arms
        LAMBDA_ENTROPY = .8
        LAMBDA_ENTROPY_BINARY = .8
        LAMBDA_SAMPLE_ENTROPY = .1
        M = 500
        L = 30
        N_GEN = 300
        R_lim = 30
        IT_THRESH = 50000  # overrides the shared default above
        N_it_overall = 200000
        N_pulls = 1
        eval_range = [3.5, 4, 4.5, 5]
    elif game_mode == "jester":
        N_dim = 100
        # Unlike the other modes, no .float().cuda() is applied here.
        # NOTE(review): presumably get_z() already returns a suitable tensor;
        # verify.
        Z = get_z().t()  # N_dim * N_arms
        LAMBDA_ENTROPY = .8
        LAMBDA_ENTROPY_BINARY = .8
        LAMBDA_SAMPLE_ENTROPY = .05
        M = 500
        L = 30
        N_GEN = 2000
        R_lim = 30
        N_it_overall = 200000
        N_pulls = 1
        eval_range = [3, 4, 5, 6, 7]
    else:
        # Previously this fell through and crashed with UnboundLocalError.
        raise ValueError(
            f"Unknown game_mode {game_mode!r}; expected one of "
            f"'1d', '1d100', 'q20', 'jester'"
        )

    return IT_THRESH, N_it, N_dim, N_budget, Z, LAMBDA_ENTROPY, LAMBDA_ENTROPY_BINARY, LAMBDA_SAMPLE_ENTROPY, M, L, \
           N_GEN, R_lim, N_it_overall, N_pulls, eval_range
