from model import *
import numpy as np
import torch.optim as optim
import torch
import sklearn.metrics as metrics

# Index of the preferred CUDA device. Any value <= 0 selects the CPU (so
# 'cuda:0' cannot be requested through this switch).
cuda = 1
# Only use the requested GPU when it actually exists; otherwise fall back to
# the CPU so the script still runs on CPU-only or single-GPU machines.
if cuda > 0 and torch.cuda.is_available() and cuda < torch.cuda.device_count():
    device = 'cuda:' + str(cuda)
else:
    device = 'cpu'

def to_onehot(data, maxval=11):
    """Convert a tensor of integer-valued entries into one-hot vectors.

    Each entry of ``data`` is cast to ``long`` and used to pick a row of a
    ``maxval`` x ``maxval`` identity matrix, so the result has the shape of
    ``data`` plus one trailing dimension of size ``maxval``.

    Args:
        data: tensor whose entries are class indices (any numeric dtype).
        maxval: width of the one-hot encoding.

    Returns:
        A float tensor of one-hot rows, one per entry of ``data``.
    """
    identity = torch.eye(maxval)
    indices = data.long()
    return identity[indices]

def get_data(set_size, query_sizes, elem_dim=2, onehot=True):
    """Sample one random set and a batch of labelled query sets.

    The set holds ``set_size`` elements, each a vector of ``elem_dim``
    integers drawn uniformly from ``0..9``. For every entry of
    ``query_sizes`` a query is generated with probability 0.5 each way:

    * positive (label ``1.0``): ``query_size`` distinct rows of the set;
    * negative (label ``0.0``): ``query_size`` freshly sampled random rows.
      These are not checked against the set, so a negative may occasionally
      coincide with set elements.

    Args:
        set_size: number of elements in the reference set.
        query_sizes: iterable with one size per query. All sizes must be
            equal (the queries are stacked into a single array) and no
            larger than ``set_size`` (positives are drawn without
            replacement).
        elem_dim: number of integer components per element.
        onehot: if true, every component is one-hot encoded via
            ``to_onehot`` and the components are flattened together.

    Returns:
        ``(data, queries, labels)`` moved to the module-level ``device``:
        ``data`` has a leading batch dimension of 1, ``queries`` has one
        leading entry per query, and ``labels`` is a 1-D float tensor.
    """
    # Every element component is an integer in [0, num_values). Negatives
    # must use the same range as the set itself; sampling them from
    # range(query_size) would skew their distribution for small queries and
    # overflow the one-hot table for large ones.
    num_values = 10
    values = np.arange(num_values)

    data = np.random.choice(values, (set_size, elem_dim), replace=True)
    labels = []
    queries = []
    for query_size in query_sizes:
        if np.random.rand() < 0.5:
            # Positive example: a true subset of the set's rows.
            labels.append(1.0)
            idxs = np.random.choice(range(set_size), query_size, replace=False)
            queries.append(data[idxs, :])
        else:
            # Negative example: independent random rows.
            labels.append(0.0)
            queries.append(
                np.random.choice(values, (query_size, elem_dim), replace=True)
            )

    queries = np.stack(queries)
    data = torch.from_numpy(data.astype(np.float32))
    queries = torch.from_numpy(queries.astype(np.float32))
    labels = torch.tensor(labels)

    # Give the single set a batch dimension so it lines up with the queries.
    data = data.unsqueeze(0)
    if onehot:
        # Merge the (elem_dim, one-hot width) axes into one feature axis.
        queries = torch.flatten(to_onehot(queries), start_dim=2)
        data = torch.flatten(to_onehot(data), start_dim=2)
    return data.to(device), queries.to(device), labels.to(device)

def train():
    """Train the encoder and comparison models on synthetic subset queries.

    Builds an ``EncodingModel`` and a ``ComparisonModel``, optimises both
    with one Adam optimiser on batches from ``get_data``, and prints the
    accuracy accumulated over each epoch. Returns nothing.
    """
    set_size = 20
    elem_dim = 2
    query_size = 10
    num_queries = 10
    onehot = True
    # One-hot width; must match the default ``maxval`` of ``to_onehot``,
    # which ``get_data`` relies on.
    maxval = 11

    input_dim = elem_dim * maxval if onehot else elem_dim

    encode_model = EncodingModel(input_dim=input_dim).to(device)
    compare_model = ComparisonModel().to(device)
    params = list(encode_model.parameters()) + list(compare_model.parameters())
    optimizer = optim.Adam(params, lr=1e-3)

    # The same count is used for the number of epochs and for the number of
    # iterations within each epoch.
    epoch_size = 10000
    for epoch in range(epoch_size):
        preds = []
        labels = []
        for it in range(epoch_size):
            # Clears the gradients of both models (all their parameters are
            # registered with this optimiser).
            optimizer.zero_grad()
            # A fresh batch is sampled every 20 iterations and reused for
            # the steps in between.
            if it % 20 == 0:
                # Forward the local flag so the encoding of the data always
                # agrees with the input_dim computed above.
                data, queries, label = get_data(
                    set_size, [query_size] * num_queries, elem_dim, onehot=onehot
                )
            labels.extend(label.cpu().numpy())
            data_embed = encode_model(data)
            query_embed = encode_model(queries)

            pred, loss = compare_model(data_embed, query_embed, label)
            # Round the predictions to hard 0/1 decisions for accuracy.
            preds.extend(np.round(pred.cpu().detach().numpy()))
            loss.backward()
            optimizer.step()

        acc = metrics.accuracy_score(labels, preds)
        print("Epoch %d, accuracy %f" % (epoch, acc))

# Start training only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    train()
