import numpy as np
import paddle
from paddle_quantum.dataset import *

from QNN_model import train_model


if __name__ == '__main__':
    # Script entry point: train the QNN classifier on a two-class Iris subset
    # for several seeds and store each run's results as .npy files.

    # Local import so this entry point carries its own dependency.
    import os

    # Circuit settings forwarded to train_model (as N / DEPTH / BLOCK) and
    # also embedded in the output file names.
    block = 1
    depth = 1
    num_qubit = 4
    encoding = 'angle_encoding'

    # Results directory, relative to the current working directory.
    output_dir = 'extension/classification/output/iris'
    # np.save does not create missing parent directories. Create the folder
    # before training so a finished run is not lost to a FileNotFoundError.
    os.makedirs(output_dir, exist_ok=True)

    # Load the Iris dataset restricted to classes 0 and 1, encoded with the
    # chosen encoding on num_qubit qubits (return_state=True).
    iris = Iris(encoding=encoding, num_qubits=num_qubit, classes=[0, 1], return_state=True)

    data_X, data_y = iris.feature, iris.target.astype("float64")

    # Shuffle sample indices with a fixed seed so the split is reproducible.
    data_size = len(data_y)
    index = list(range(data_size))
    np.random.seed(0)
    np.random.shuffle(index)

    # First 80% of the shuffled indices for training, the rest for testing.
    training_size = int(data_size*0.8)
    train_index = index[:training_size]
    test_index = index[training_size:]

    train_X = paddle.to_tensor(data_X[train_index])
    train_y = paddle.to_tensor(data_y[train_index])

    test_X = paddle.to_tensor(data_X[test_index])
    test_y = paddle.to_tensor(data_y[test_index])

    print(train_X.shape)
    print(train_y.shape)
    print(test_X.shape)
    print(test_y.shape)

    # One independent training run per seed; the seed is passed to train_model
    # and is also the last component of the output file names.
    num_samples = 10
    for seed in range(num_samples):
        train_loss, test_accuracy = train_model(
            train_X, train_y, test_X, test_y,
            seed=seed, N=num_qubit, DEPTH=depth, BLOCK=block,
            EPOCH=10, BATCH_SIZE=40, LR=0.1,
        )
        # File-name suffix: <num_qubit>_<block>_<depth>_<seed>
        tag = '%s_%s_%s_%s' % (num_qubit, block, depth, seed)
        np.save(os.path.join(output_dir, 'train_loss_%s.npy' % tag), train_loss)
        np.save(os.path.join(output_dir, 'test_accuracy_%s.npy' % tag), test_accuracy)
        print('----------------------------------------------------------------------------------sample: %s/%s  finished'%(seed+1, num_samples))