import numpy as np
from ChainMDP import ChainMDP
from sklearn.linear_model import LogisticRegression

# Default experiment settings. ``num_chains`` is only a starting value: the
# sweep loop below rebinds this module-level name on every step, and the
# query helpers read it as a global.
num_chains = 4
# Passed to ChainMDP as its second argument and also used as the length of the
# per-query weight vectors in random_weighted_query (presumably the chain
# length -- confirm against ChainMDP).
length = 32
# Passed to ChainMDP as its third argument.
feature_dim = 16
# Number of preference queries drawn for each dataset.
datasize = 100

# Sweep values. The uncommented list is the one iterated by the loop below;
# the commented lists are alternative sweeps.
# datasizes = [10, 20, 50, 100, 200, 500, 1000, 2000, 5000, 10000]
# feature_dims = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20]
num_chainss = [2, 4, 8, 12, 16, 20, 24, 28, 32]

# Independent repetitions per sweep value.
num_experiments = 100
# results[sweep_index, experiment, k]:
#   k=0 -> mdp.evaluate() of the model fit on unweighted queries,
#   k=1 -> mdp.evaluate() of the model fit on weighted queries,
#   k=2 -> mdp.evaluate() called with no parameters.
# The allocation must use the same list that the loop sweeps over.
# results = np.zeros((len(datasizes), num_experiments, 3))
# results = np.zeros((len(feature_dims), num_experiments, 3))
results = np.zeros((len(num_chainss), num_experiments, 3))


def random_query(mdp, size=20, n_chains=None):
    """Draw ``size`` pairwise preference queries between random distinct chains.

    For each query, two chain indices ``i != j`` are sampled uniformly from
    ``[0, n_chains)`` and ``mdp.query(i, j)`` is called.

    Args:
        mdp: object exposing ``query(i, j)`` that returns a 3-item result
            ``(phi_i, phi_j, preference)``.
        size: number of queries to draw.
        n_chains: number of chains to sample indices from. When ``None``, the
            module-level ``num_chains`` is read at call time (the experiment
            loop rebinds that global on every sweep step).

    Returns:
        list of ``(phi_i, phi_j, preference)`` tuples, one per query.

    Raises:
        ValueError: if ``size > 0`` and fewer than two chains are available.
            No distinct pair exists, so resampling ``j`` would loop forever.
    """
    if n_chains is None:
        n_chains = num_chains
    if size > 0 and n_chains < 2:
        raise ValueError(f"need at least 2 chains to form a pair, got {n_chains}")

    dataset = []
    for _ in range(size):
        i, j = np.random.randint(0, n_chains), np.random.randint(0, n_chains)
        # Resample j until the pair is distinct.
        while i == j:
            j = np.random.randint(0, n_chains)
        phii, phij, preference = mdp.query(i, j)
        dataset.append((phii, phij, preference))
    return dataset


def random_weighted_query(mdp, size=20, n_chains=None, weight_dim=None):
    """Draw ``size`` weighted pairwise preference queries between distinct chains.

    For each query, two chain indices ``i != j`` are sampled uniformly from
    ``[0, n_chains)``, together with two weight vectors of length
    ``weight_dim``. Each weight vector is a standard-normal draw clipped to
    ``[-1, 1]``. Then ``mdp.weighted_query(i, j, wi, wj)`` is called.

    Args:
        mdp: object exposing ``weighted_query(i, j, wi, wj)`` that returns a
            3-item result ``(phi_i, phi_j, preference)``.
        size: number of queries to draw.
        n_chains: number of chains to sample indices from. When ``None``, the
            module-level ``num_chains`` is read at call time (the experiment
            loop rebinds that global on every sweep step).
        weight_dim: length of each weight vector. When ``None``, the
            module-level ``length`` is used.

    Returns:
        list of ``(phi_i, phi_j, preference)`` tuples, one per query.

    Raises:
        ValueError: if ``size > 0`` and fewer than two chains are available.
            No distinct pair exists, so resampling ``j`` would loop forever.
    """
    if n_chains is None:
        n_chains = num_chains
    if weight_dim is None:
        weight_dim = length
    if size > 0 and n_chains < 2:
        raise ValueError(f"need at least 2 chains to form a pair, got {n_chains}")

    dataset = []
    for _ in range(size):
        # Draw order is fixed: i, j, wi, wj, then any resamples of j.
        # This keeps the RNG stream reproducible for seeded runs.
        i, j = np.random.randint(0, n_chains), np.random.randint(0, n_chains)
        wi = np.clip(np.random.randn(weight_dim), -1, 1)
        wj = np.clip(np.random.randn(weight_dim), -1, 1)
        while i == j:
            j = np.random.randint(0, n_chains)
        phii, phij, preference = mdp.weighted_query(i, j, wi, wj)
        dataset.append((phii, phij, preference))
    return dataset


def regression(dataset, lamda=0.1):
    """Fit a linear preference model on feature differences.

    Each sample is turned into the difference ``phi_i - phi_j`` and labeled
    with its preference. An L2-regularized logistic regression is fit with
    inverse regularization strength ``C = 1 / lamda``.

    Args:
        dataset: iterable of ``(phi_i, phi_j, preference)`` tuples. It is
            traversed twice, so pass a re-iterable such as a list.
        lamda: regularization strength; larger values shrink weights more.

    Returns:
        The fitted coefficient vector (first row of ``coef_``).
    """
    features = [left - right for left, right, _ in dataset]
    labels = np.array([label for _, _, label in dataset]).ravel()

    model = LogisticRegression(random_state=0, C=1 / lamda)
    model.fit(features, labels)
    weights = model.coef_[0]
    return weights


# Sweep over the number of chains. To sweep a different variable instead,
# swap in the matching loop header (and the matching `results` allocation),
# and give the output file a distinct name:
#   for i, datasize in enumerate(datasizes):
#   for i, feature_dim in enumerate(feature_dims):
#
# NOTE: the loop variable rebinds the module-level ``num_chains``. The query
# helpers sample chain indices using that global, so every step draws from
# the current chain count.
for n in range(num_experiments):
    for i, num_chains in enumerate(num_chainss):
        mdp = ChainMDP(length, num_chains, feature_dim)
        # Fit one preference model from unweighted queries and one from
        # weighted queries on the same MDP.
        dataset = random_query(mdp, size=datasize)
        param = regression(dataset)
        dataset2 = random_weighted_query(mdp, size=datasize)
        param2 = regression(dataset2)
        results[i, n, 0] = mdp.evaluate(param)
        results[i, n, 1] = mdp.evaluate(param2)
        # evaluate() with no parameters -- presumably a reference baseline;
        # confirm in ChainMDP.
        results[i, n, 2] = mdp.evaluate()
        # Report progress on the variable actually being swept; datasize is
        # constant throughout this sweep.
        print(n, num_chains)

# The output name must match the swept variable. Writing this num_chains sweep
# to results_feature_dim.npy would overwrite the feature-dim experiment.
np.save("./results_num_chain.npy", results)
