# This file tests tasks/classification_head.py.
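# It simply trains using Adam for a fixed time budget, records the loss curve,
# and prints the validation/test loss and accuracy. (No plot is currently
# produced, although matplotlib is imported.)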

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import time

import context
from tasks import classification_head

# Task setup: the task object supplies the loss/gradient function and the
# validation/test evaluation functions used further down.
task = classification_head.ClassificationHeadTask()

# `learning_rate` is the supported keyword; the old `lr` alias is deprecated
# and is rejected outright by newer Keras releases.
opt = tf.keras.optimizers.Adam(
    learning_rate=0.0003,
    beta_1=0.99,
    beta_2=0.9998,
)

# Wrap the task's initial point in a tf.Variable so the optimizer can
# update it via apply_gradients.
x = tf.Variable(task.get_initialization())

# Trajectory recording and training loop breaking logic
time_limit_seconds = 120
# Measure elapsed time with the monotonic clock: unlike time.time(), it cannot
# jump backwards or forwards when the system clock is adjusted (NTP, DST,
# manual changes), so the time budget is honoured reliably.
start_time = time.monotonic()


def break_condition():
    """Return True once more than `time_limit_seconds` have passed since `start_time`."""
    return time.monotonic() - start_time > time_limit_seconds


# Per-step training losses; the training loop appends one entry per step.
loss_curve = []

# Set the task's regularization coefficient before training starts.
task.regularization = 0.008**2

# Training loop: always takes at least one optimizer step, then keeps stepping
# until break_condition() reports that the time budget is used up.
done = False
while not done:
    loss, grad = task.loss_and_grad_fn(x)
    opt.apply_gradients([(grad, x)])

    # Convert once so the recorded and printed values are the same float.
    loss_value = float(loss)
    loss_curve.append(loss_value)
    print(loss_value)

    done = break_condition()

# Trajectory postprocessing: turn the recorded list of losses into an array.
loss_curve = np.array(loss_curve)

# Exercise each of the task's evaluation functions on the final parameters and
# report the result. Methods are looked up one at a time, in order, so a
# failure in a later evaluation still lets the earlier ones print.
evaluations = (
    ("Validation loss:", "evaluate_validation_loss"),
    ("Validation accuracy:", "evaluate_validation_accuracy"),
    ("Test loss:", "evaluate_test_loss"),
    ("Test accuracy:", "evaluate_test_accuracy"),
)
for label, method_name in evaluations:
    evaluate = getattr(task, method_name)
    print(label, float(evaluate(x)))
