
import numpy as np
import matplotlib
matplotlib.use('TkAgg')
import matplotlib.pyplot as plt
import torch
import sys

import argparse

from scipy.stats import norm


'''
Plot the validation curves of the four fbNeq3 runs
(batch1e5/batch1e4, mse/ce) together on one log-scale figure.
'''

from results.fbNeq3_batch1e5_mse.trainLoss	import *
from results.fbNeq3_batch1e5_mse.validLoss	import *
from results.fbNeq3_batch1e4_mse.trainLoss	import *
from results.fbNeq3_batch1e4_mse.validLoss	import *
from results.fbNeq3_batch1e5_ce.trainLoss	import *
from results.fbNeq3_batch1e5_ce.validLoss	import *
from results.fbNeq3_batch1e4_ce.trainLoss	import *
from results.fbNeq3_batch1e4_ce.validLoss	import *

# Draw every validation series on a single set of axes in figure 1 so the
# runs can be compared directly.
fig = plt.figure(1)
ax = fig.gca()
ax.set_title("Valid BER comparison")

# (series, legend label) pairs, in the order they are drawn.
# NOTE(review): the series are named validLoss* while the title and y-label
# say BER -- confirm which quantity the result modules actually store.
curves = (
    (validLossBatch1e5Mse, 'batch1e5Mse'),
    (validLossBatch1e4Mse, 'batch1e4Mse'),
    (validLossBatch1e5Ce, 'batch1e5Ce'),
    (validLossBatch1e4Ce, 'batch1e4Ce'),
)
for series, tag in curves:
    ax.plot(series, '-', label=tag)

ax.legend(loc='best')
ax.set_yscale('log')  # logarithmic y-axis
ax.grid(True)
ax.set_xlabel('EPOCH/10')
ax.set_ylabel('Valid Ber')

plt.show()

