import matplotlib.pyplot as plt
import json
from matplotlib import rcParams
rcParams.update({'figure.autolayout': True})
import numpy as np
import pickle
import sys
import os

# Value substituted for a "-inf" train-loss field. json.loads cannot parse the
# token "-inf" (it only accepts "-Infinity"), so such entries are clamped to
# this finite number before parsing.
_NEG_INF_LOSS_SUBSTITUTE = '-50'

# Read the WanD training log.
#   line 1    : comma-separated run parameters (p, kT, time step, friction
#               coefficient, entropy increase, max train loss), used only in
#               the figure title.
#   line 2    : skipped (presumably a column-header line -- TODO confirm).
#   remaining : step, train loss, test loss, train acc, test acc, weight norm,
#               each field parsed with json.loads (ints stay ints).
# The file is decoded as UTF-8, matching the original bytes.decode() default.
with open('./WanD_output.log', encoding='utf-8') as f:
    # Strip each field so stray spaces around commas do not leak into the title.
    header = [field.strip() for field in f.readline().split(",")]
    p, kT, ts, fc, ei, bias = header[:6]

    # Skip the second line; the default avoids StopIteration on a header-only file.
    next(f, None)

    epoch = []
    x_loss = []
    y_loss = []
    x_acc = []
    y_acc = []
    weight_norm = []
    for raw_line in f:
        fields = [field.strip() for field in raw_line.split(",")]
        # Blank / whitespace-only lines (e.g. a trailing newline) carry no data;
        # indexing into them would raise IndexError.
        if not any(fields):
            continue
        if fields[1] == "-inf":
            fields[1] = _NEG_INF_LOSS_SUBSTITUTE
        step, data0, data1, data2, data3, data4 = (
            json.loads(value) for value in fields[:6]
        )
        epoch.append(step)
        x_loss.append(data0)
        y_loss.append(data1)
        x_acc.append(data2)
        y_acc.append(data3)
        weight_norm.append(data4)

# Build a two-panel figure: train/test accuracy on the left and weight norm
# on the right, both plotted against the optimisation step. The run
# parameters read from the log header are shown in the figure title.
scale = 'linear'  # x-axis scale for both panels ('log' is the alternative)

fig, (acc_ax, norm_ax) = plt.subplots(1, 2, figsize=(8, 4))
fig.suptitle(f"1L Transformer on Modular Addition (p={p}),\n WanD, kT={kT}, time step={ts}, frictionCoefficient={fc}, \n forbiddenRegionEntropyIncrease={ei}, maxTrainLossToStudy={bias}")

# Left panel: accuracy curves.
acc_ax.plot(epoch, x_acc, color='red', label='train')
acc_ax.plot(epoch, y_acc, color='blue', label='test')
acc_ax.legend()
acc_ax.set_xlabel("Optimization Steps")
acc_ax.set_ylabel("Accuracy")
acc_ax.set_xscale(scale)

# Right panel: weight norm (y-axis explicitly linear).
norm_ax.plot(epoch, weight_norm)
norm_ax.set_xlabel("Optimization Steps")
norm_ax.set_ylabel("Weight Norm")
norm_ax.set_yscale('linear')
norm_ax.set_xscale(scale)

filename = "result.png"
fig.savefig(filename)
