import os
import sys
from args import fusion_args
from train import run_train
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

#BEST: epoch_51
def loss_curve(args, losses, loss_type='train'):
    """Plot a sequence of loss values and save it as a PNG.

    The x axis is the position in ``losses`` (0, 1, 2, ...), labelled
    "Epoch"; the y axis is the loss value.

    Args:
        args: Options object; ``args.save_path`` and ``args.exp_name`` are
            joined to form the output directory (which must already exist).
        losses: Sequence of loss values, one per epoch.
        loss_type: Label used in the plot title and the output file name
            (``<loss_type>_loss.png``).
    """
    # Use a dedicated figure per call. Drawing on pyplot's implicit
    # "current" figure would overlay every previous call's curve (e.g. the
    # train curve would leak into the val plot).
    fig, ax = plt.subplots()
    try:
        ax.plot(range(len(losses)), losses)
        ax.set_xlabel('Epoch')
        ax.set_ylabel('Loss')
        ax.set_title(f'{loss_type} losses over epochs for fusion model')
        fig.savefig(os.path.join(args.save_path, args.exp_name, f'{loss_type}_loss.png'))
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)

def main(args):
    """Train the fusion model, save train/val loss plots, and print test accuracy.

    Args:
        args: Options object passed through to ``run_train``; its
            ``save_path`` and ``exp_name`` attributes determine the output
            directory for the loss plots.
    """
    out_dir = os.path.join(args.save_path, args.exp_name)
    # exist_ok=True avoids the race between an exists() check and makedirs().
    os.makedirs(out_dir, exist_ok=True)
    # The trained model is returned by run_train but not used here.
    _model, train_losses, val_losses, test_acc = run_train(args)
    loss_curve(args, train_losses)
    loss_curve(args, val_losses, loss_type='val')
    print(test_acc)

if __name__ == "__main__":
    # Script entry point: build the options from fusion_args(), then run
    # training and plotting.
    args = fusion_args()
    main(args)