import torch
import argparse
from torch.utils.data import DataLoader
import numpy as np
import random
import time
import os
seed = 30
from itertools import cycle
from models.ns import NSRestrictedNet, NSRestrictedPotentialNet
# Seed every random number generator used by the script with the same value
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(seed)
# Dedicated generator, handed to the DataLoaders further down
gen = torch.Generator().manual_seed(seed)

# Bounds of the (t, x, y) domain; the x extent is 2.2 shortened by
# `restriction_x`, and these bounds are used as axis limits when plotting.
restriction_x = 0.5
t_max = 5.
x_min, x_max = 0., 2.2 - restriction_x
y_min, y_max = 0., 0.41

from tuner_results import ns_best_params

# Command line interface of the training script.
# NOTE: several of these values (batch_size, init_weight, bc_weight, lr_init and
# the mom/div weights) are overwritten later in the script with tuned values.
parser = argparse.ArgumentParser()
parser.add_argument('--init_weight', default=1., type=float, help='Weight for the init loss')
parser.add_argument('--der_weight', default=1., type=float, help='Weight for the derivative loss')
parser.add_argument('--out_weight', default=1., type=float, help='Weight for the output loss')
parser.add_argument('--mom_weight', default=1., type=float, help='Weight for the momentum pde loss')
parser.add_argument('--div_weight', default=1., type=float, help='Weight for the divergence pde loss')
parser.add_argument('--bc_weight', default=1., type=float, help='Weight for the boundary condition loss')
parser.add_argument('--lr_init', default=1e-4, type=float, help='Starting learning rate')
parser.add_argument('--device', default='cuda:2', type=str, help='Device to use')
parser.add_argument('--name', default='base', type=str, help='Experiment name')
parser.add_argument('--interp', default='cubic', type=str, help='Interpolation tag of the empirical PDE data file (pde_data_<interp>.pth)')
parser.add_argument('--train_steps', default=100000, type=int, help='Number of training steps')
parser.add_argument('--epochs', default=500, type=int, help='Number of epochs')
parser.add_argument('--mode', default='Derivative', type=str, help='Training mode: Derivative, Output, PINN or Sobolev')
# BooleanOptionalAction already produces a bool (--flag / --no-flag); passing
# `type=` alongside it is deprecated since Python 3.12 and rejected from 3.14.
parser.add_argument('--use_hessian', default=False, help='Whether the hessian is used', action=argparse.BooleanOptionalAction)
parser.add_argument('--use_empirical', default=False, help='Whether the empirical dataset is used instead of the true one', action=argparse.BooleanOptionalAction)
parser.add_argument('--batch_size', default=128, type=int, help='Number of samples per step')
parser.add_argument('--layers', default=8, type=int, help='Number of layers in the network')
parser.add_argument('--units', default=128, type=int, help='Number of units per layer in the network')

# Parse once and expose every option as a module-level name
args = parser.parse_args()
init_weight = args.init_weight
device = args.device
name = args.name
train_steps = args.train_steps
epochs = args.epochs
batch_size = args.batch_size
layers = args.layers
units = args.units
lr_init = args.lr_init
mode = args.mode
mom_weight = args.mom_weight
div_weight = args.div_weight
der_weight = args.der_weight
out_weight = args.out_weight
bc_weight = args.bc_weight
interp = args.interp
use_hessian = args.use_hessian
use_empirical = args.use_empirical


# Validate the mode before doing any work: previously the check came after the
# `ns_best_params[...]` lookups (and after loading the dataset), so an unknown
# mode failed at those lookups instead of reaching the intended ValueError.
if mode not in ('Derivative', 'Output', 'PINN', 'Sobolev'):
    raise ValueError('Mode is not valid')

print('Loading the data...') 
# Experiment folder and PDE dataset: the empirical variant is selected by the
# interpolation tag, the true variant has a single data file.
if use_empirical:
    EXP_PATH = f'NS_empirical_restricted'
    pde_dataset = torch.load(os.path.join(EXP_PATH, f'pde_data_{interp}.pth'))
else:
    EXP_PATH = f'NS_true_restricted'
    pde_dataset = torch.load(os.path.join(EXP_PATH, f'pde_data.pth'))

# Output folder of this experiment (no error if it already exists)
os.makedirs(f'{EXP_PATH}/{name}', exist_ok=True)

# NOTE: from here on the tuned values replace whatever was given on the
# command line for batch_size, init_weight, bc_weight and lr_init.
batch_size = 512
best_params = ns_best_params[str(mode)]
init_weight = best_params['init_weight']
bc_weight = best_params['bc_weight']
lr_init = best_params['lr_init']

if mode == 'PINN':
    # PINN training only uses the PDE residual weights
    div_weight = best_params['div_weight']
    mom_weight = best_params['mom_weight']
    sys_weight = 0.
else:
    # Derivative / Output / Sobolev use a single system weight and no PDE residual terms
    sys_weight = best_params['sys_weight']
    mom_weight = 0.
    div_weight = 0.

title_mode = mode

# Initial-condition and boundary-condition datasets are stored in the experiment folder
init_dataset = torch.load(os.path.join(EXP_PATH, 'init_data.pth'))
bc_dataset = torch.load(os.path.join(EXP_PATH, 'bc_data.pth'))

# All loaders share the same settings and the same seeded generator
loader_options = dict(batch_size=batch_size, generator=gen, shuffle=True, num_workers=24)
bc_dataloader = DataLoader(bc_dataset, **loader_options)
pde_dataloader = DataLoader(pde_dataset, **loader_options)
init_dataloader = DataLoader(init_dataset, **loader_options)
# The evaluation loader wraps the same pde_dataset as the training loader
test_dataloader = DataLoader(pde_dataset, **loader_options)

print('Data loaded!')

# Network definition: `hidden_units` has `layers` entries, each equal to `units`
hidden_units = [units] * layers
activation = torch.nn.Tanh()

# The 'base' experiment uses NSRestrictedNet; any other experiment name uses
# NSRestrictedPotentialNet. Both receive the same keyword arguments.
model_class = NSRestrictedNet if name == 'base' else NSRestrictedPotentialNet
model = model_class(
    init_weight=init_weight,
    sys_weight=sys_weight,
    mom_weight=mom_weight,
    div_weight=div_weight,
    bc_weight=bc_weight,
    hidden_units=hidden_units,
    lr_init=lr_init,
    activation=activation,
    device=device,
    last_activation=False,
).to(device)
    

# Training history: train_loop appends one entry to each list every time it logs
(step_list, mom_losses, div_losses, out_losses,
 der_losses, init_losses, bc_losses, tot_losses) = ([] for _ in range(8))

# Evaluation history: train_loop appends one entry per epoch
# (time_test holds the duration of each epoch's training pass)
(step_list_test, mom_losses_test, div_losses_test, out_losses_test,
 der_losses_test, init_losses_test, bc_losses_test,
 time_test, tot_losses_test) = ([] for _ in range(9))



# Training loop
def _batches_to_device(pde_data, init_data, bc_data):
    """Move one (pde, init, bc) batch triple to `device` as float tensors.

    Returns the keyword arguments shared by `model.loss_fn` and
    `model.print_losses`. The PDE inputs are flagged as requiring gradients
    and the third PDE entry is reshaped to (-1, 3, 3).
    """
    return dict(
        x_pde=pde_data[0].to(device).float().requires_grad_(True),
        y_pde=pde_data[1].to(device).float(),
        D_pde=pde_data[2].reshape((-1,3,3)).to(device).float(),
        x_init=init_data[0].to(device).float(),
        y_init=init_data[1].to(device).float(),
        x_bc=bc_data[0].to(device).float(),
        y_bc=bc_data[1].to(device).float(),
    )


# Training loop
def train_loop(epochs:int,
        pde_dataloader:DataLoader,
        init_dataloader:DataLoader,
        bc_dataloader:DataLoader,
        print_every:int=10):
    """Train the module-level `model` and record train/evaluation losses.

    Each epoch performs at most `train_steps` optimisation steps over
    `pde_dataloader`, then averages the losses over the module-level
    `test_dataloader`. Results are appended to the module-level history
    lists (`*_losses`, `*_losses_test`, `step_list*`, `time_test`).

    Args:
        epochs: Number of epochs to run.
        pde_dataloader: Loader whose length bounds the steps per epoch.
        init_dataloader: Loader of initial-condition batches (cycled).
        bc_dataloader: Loader of boundary-condition batches (cycled).
        print_every: Log the training losses whenever the global step is a
            multiple of this value (the first step of an epoch is skipped).

    NOTE(review): `itertools.cycle` stores the batches of its first pass and
    replays them afterwards, so init/bc batches are not reshuffled within an
    epoch and are kept in memory until the epoch ends.
    NOTE(review): in this script `test_dataloader` is built on the same
    dataset as the training loader — confirm that this is intended.
    """
    # Training mode for the network
    model.train()

    # Number of optimisation steps actually taken in one epoch
    steps_per_epoch = min(len(pde_dataloader), train_steps)

    for epoch in range(epochs):
        start_time = time.time()
        step_prefix = epoch*steps_per_epoch
        print(f'Epoch: {epoch}, step_prefix: {step_prefix}')
        for step, (pde_data, init_data, bc_data) in enumerate(zip(pde_dataloader, cycle(init_dataloader), cycle(bc_dataloader))):
            # `>=` so that exactly `train_steps` steps run; `>` ran one extra
            # step and made global step indices overlap between epochs.
            if step >= train_steps:
                break
            # Load batches from dataloaders
            batch = _batches_to_device(pde_data, init_data, bc_data)

            # Standard optimisation step (no learning-rate scheduler is stepped here)
            model.opt.zero_grad()
            loss = model.loss_fn(mode=mode, **batch)
            loss.backward()
            model.opt.step()

            # Printing
            if (step_prefix+step) % print_every == 0 and step>0:
                with torch.no_grad():
                    step_val, mom_loss_val, div_loss_val, out_loss_val, der_loss_val, init_loss_val, bc_loss_val, tot_loss_val = model.print_losses(
                        step=step_prefix+step, mode=mode, **batch
                    )
                    step_list.append(step_val)
                    mom_losses.append(mom_loss_val)
                    div_losses.append(div_loss_val)
                    out_losses.append(out_loss_val)
                    der_losses.append(der_loss_val)
                    bc_losses.append(bc_loss_val)
                    init_losses.append(init_loss_val)
                    tot_losses.append(tot_loss_val)
        end_time = time.time()

        epoch_time = end_time - start_time
        print('\n')
        print(f'Epoch: {epoch}, time: {epoch_time}')
        time_test.append(epoch_time)

        # Evaluation pass: sum every loss over the batches, then average
        mom_loss_test = 0.0
        div_loss_test = 0.0
        out_loss_test = 0.0
        der_loss_test = 0.0
        init_loss_test = 0.0
        bc_loss_test = 0.0
        tot_loss_test = 0.0
        with torch.no_grad():
            for (pde_data, init_data, bc_data) in zip(test_dataloader, cycle(init_dataloader), cycle(bc_dataloader)):
                batch = _batches_to_device(pde_data, init_data, bc_data)

                step_val, mom_loss_val, div_loss_val, out_loss_val, der_loss_val, init_loss_val, bc_loss_val, tot_loss_val = model.print_losses(
                    step=step_prefix+step, mode=mode, print_to_screen=False, **batch
                )

                mom_loss_test += mom_loss_val.item()
                div_loss_test += div_loss_val.item()
                out_loss_test += out_loss_val.item()
                der_loss_test += der_loss_val.item()
                init_loss_test += init_loss_val.item()
                bc_loss_test += bc_loss_val.item()
                tot_loss_test += tot_loss_val.item()

            n_test_batches = len(test_dataloader)
            mom_loss_test /= n_test_batches
            div_loss_test /= n_test_batches
            out_loss_test /= n_test_batches
            der_loss_test /= n_test_batches
            init_loss_test /= n_test_batches
            bc_loss_test /= n_test_batches
            tot_loss_test /= n_test_batches

            # `step` is the last index reached by the training loop of this epoch
            step_list_test.append(step_prefix+step)
            mom_losses_test.append(mom_loss_test)
            div_losses_test.append(div_loss_test)
            out_losses_test.append(out_loss_test)
            der_losses_test.append(der_loss_test)
            init_losses_test.append(init_loss_test)
            bc_losses_test.append(bc_loss_test)
            tot_losses_test.append(tot_loss_test)

            print(f'Epoch: {epoch}, step: {step}, mom_loss: {mom_loss_test}, div_loss: {div_loss_test}, out_loss: {out_loss_test}, der_loss: {der_loss_test}, init_loss: {init_loss_test}, bc_loss: {bc_loss_test}, tot_loss: {tot_loss_test}')
            

        

# Train the model, logging the training losses every 100 global steps
train_loop(
    epochs=epochs,
    pde_dataloader=pde_dataloader,
    init_dataloader=init_dataloader,
    bc_dataloader=bc_dataloader,
    print_every=100,
)

# Release cached GPU memory and switch the network to evaluation mode
torch.cuda.empty_cache()
model.eval()


import os
# Checkpoint location: <experiment>/<name>/saved_models/NSnet_<mode>
checkpoint_dir = f'{EXP_PATH}/{name}/saved_models'
checkpoint_file = f'{checkpoint_dir}/NSnet_{title_mode}'
if not os.path.exists(checkpoint_dir):
    os.mkdir(checkpoint_dir)
torch.save(model.state_dict(), checkpoint_file)
# %%
# Reload the weights that were just written (lets this cell be run on its own)
model.load_state_dict(torch.load(checkpoint_file))

# Reference solution used for the comparison plots. Columns read below:
# 0 = t, 1 = x, 2 = y, 3 = x-velocity, 4 = y-velocity, 5 = pressure.
# NOTE(review): this path is fixed to 'NS_empirical' regardless of EXP_PATH — confirm.
with open(f'NS_empirical/data_original.npy', 'rb') as f:
    pde_true = np.load(f)

# Folder receiving every figure and loss file of this mode
os.makedirs(f'{EXP_PATH}/{name}/plots{title_mode}/', exist_ok=True)

dt = 0.01
dx = 0.01
x_max_interp = 1.5
x_min_interp = -1.5
t_vec = np.arange(0.,10.+dt,dt)
steps_print = 10
t_max = 10.

from matplotlib import pyplot as plt


def _plot_field_comparison(X, Y, field_true, field_pred, vel_true, vel_pred, label, out_file):
    """Save a three-row figure: true field, predicted field, absolute error.

    Args:
        X, Y: 2-D coordinate grids passed to contourf/streamplot.
        field_true, field_pred: Reference and predicted scalar field on the grid.
        vel_true, vel_pred: (x-velocity, y-velocity) pairs used for the streamlines.
        label: Field name used in the three subplot titles.
        out_file: Path of the PNG written at 300 dpi.

    The axis limits come from the module-level x_min/x_max/y_min/y_max. The
    predicted field shares the colour range of the true field, and all
    streamlines share the colour normalisation of the true-velocity plot.
    """
    xvel_true, yvel_true = vel_true
    xvel_pred, yvel_pred = vel_pred
    # Streamlines are coloured by the squared velocity magnitude
    speed_sq_true = xvel_true**2+yvel_true**2
    speed_sq_pred = xvel_pred**2+yvel_pred**2

    fig, ax = plt.subplots(nrows=3, ncols=1, figsize=(11,10))
    # Fix limits and titles before drawing, as the limits must not autoscale
    for axis, prefix in zip(ax, ('True', 'Predicted', 'Error')):
        axis.set_xlim((x_min-0.01,x_max + 0.01))
        axis.set_ylim((y_min-0.01,y_max + 0.01))
        axis.set_title(f'{prefix} {label}')

    # Colour range of the reference field, with NaNs replaced by nan_to_num
    finite_true = np.nan_to_num(field_true)
    vmin, vmax = (np.min(finite_true), np.max(finite_true))

    contour = ax[0].contourf(X,Y,field_true,100,cmap='jet', vmin=vmin, vmax=vmax)
    stream = ax[0].streamplot(X,Y,xvel_true,yvel_true,density=1,linewidth=0.2, color=speed_sq_true)

    ax[1].contourf(X,Y,field_pred,100,cmap='jet', vmin=vmin, vmax=vmax)
    ax[1].streamplot(X,Y,xvel_pred,yvel_pred,density=1,linewidth=0.2, color=speed_sq_pred, norm=stream.lines.norm)

    contour_error = ax[2].contourf(X,Y,np.abs(field_pred-field_true),100,cmap='jet')
    ax[2].streamplot(X,Y,xvel_true,yvel_true,density=1,linewidth=0.2, color=speed_sq_true, norm=stream.lines.norm)

    # One colour bar for the field values (top two rows), one for the error
    fig.subplots_adjust(right=0.8)
    cbar_ax = fig.add_axes([0.85, 0.4, 0.03, 0.45])
    cbar_ax1 = fig.add_axes([0.85, 0.125, 0.03, 0.2])
    fig.colorbar(contour, cax=cbar_ax)
    fig.colorbar(contour_error, cax=cbar_ax1)

    plt.savefig(out_file, dpi=300)
    plt.close()


# For each time step, plot the results
for t_ind in range(steps_print):

    t = t_max*t_ind/steps_print
    print(f'Plotting {t}')
    # Rows of the reference data recorded at exactly this time
    out_indexes = np.argwhere(np.float32(pde_true[:,0]) == t).reshape((-1))
    curr_pde = pde_true[out_indexes]

    # Grid size is inferred from the number of distinct x and y coordinates
    len_x = len(np.unique(curr_pde[:,1]))
    len_y = len(np.unique(curr_pde[:,2]))

    X = curr_pde[:,1].reshape((len_x,len_y)).T
    Y = curr_pde[:,2].reshape((len_x,len_y)).T
    T = t*np.ones_like(X.reshape((-1)))
    pts = np.vstack([T,X.reshape(-1),Y.reshape(-1)]).T

    # NOTE(review): forward runs with autograd enabled and the result is
    # detached afterwards; torch.no_grad() would save memory but is only safe
    # if the model's forward does not rely on autograd itself — confirm.
    pred_plot = model.forward(torch.tensor(pts).to(device).float()).detach().cpu().numpy()
    xvel_pred_plot = pred_plot[:,0].reshape(X.shape)
    yvel_pred_plot = pred_plot[:,1].reshape(X.shape)
    pres_pred_plot = pred_plot[:,2].reshape(X.shape)

    xvel_true_plot = curr_pde[:,3].reshape(X.T.shape).T
    yvel_true_plot = curr_pde[:,4].reshape(X.T.shape).T
    pres_true_plot = curr_pde[:,5].reshape(X.T.shape).T

    # One figure per quantity: (title label, file stem, true field, predicted field)
    for label, stem, field_true, field_pred in (
        ('pressure', 'pressure', pres_true_plot, pres_pred_plot),
        ('x-Velocity', 'xvel', xvel_true_plot, xvel_pred_plot),
        ('y-Velocity', 'yvel', yvel_true_plot, yvel_pred_plot),
    ):
        _plot_field_comparison(
            X, Y, field_true, field_pred,
            vel_true=(xvel_true_plot, yvel_true_plot),
            vel_pred=(xvel_pred_plot, yvel_pred_plot),
            label=label,
            out_file=f'{EXP_PATH}/{name}/plots{title_mode}/{stem}_results{t}.png',
        )


# Label and colour of each loss curve, in the column order of the saved tables
_LOSS_CURVES = (
    ('mom_loss', 'red'),
    ('div_loss', 'orange'),
    ('out_loss', 'green'),
    ('der_loss', 'blue'),
    ('init_loss', 'purple'),
    ('bc_loss', 'pink'),
    ('tot_loss', 'black'),
)


def _moving_average(values, window):
    """Return the 'valid' moving average of `values` over `window` samples.

    The window is clipped to the series length: with a window longer than the
    series, np.convolve(mode='valid') swaps its operands and returns an array
    whose length no longer matches the step axis. An empty series yields an
    empty result instead of raising.
    """
    window = min(window, len(values))
    if window == 0:
        return np.empty(0)
    return np.convolve(values, np.ones(window)/window, mode='valid')


def _export_loss_history(steps, losses, window, data_file, figure_file, extra_columns=()):
    """Save a loss history as a .npy table and as a smoothed log-scale plot.

    Args:
        steps: Step index of every record (first column of the table).
        losses: Seven sequences ordered as in `_LOSS_CURVES`.
        window: Moving-average window applied before plotting.
        data_file: Path of the .npy table.
        figure_file: Path of the saved figure.
        extra_columns: Additional sequences appended as trailing table columns.
    """
    steps_np = torch.tensor(steps).cpu().numpy()
    curves = [torch.tensor(series).cpu().numpy() for series in losses]
    extras = [torch.tensor(series).cpu().numpy() for series in extra_columns]

    loss_table = np.column_stack([steps_np, *curves, *extras])
    with open(data_file, 'wb') as f:
        np.save(f, loss_table)

    plt.figure()
    for (label, color), curve in zip(_LOSS_CURVES, curves):
        smoothed = _moving_average(curve, window)
        # Each averaged value is plotted against the leading steps of the series
        plt.plot(steps_np[:len(smoothed)], smoothed, label=label, color=color)
    plt.legend()
    plt.yscale('log')
    plt.title('Losses of the student model')
    plt.xlabel('Training steps')
    plt.ylabel('Loss')
    plt.savefig(figure_file)
    plt.close()


# Training losses, smoothed over 100 logged records
_export_loss_history(
    step_list,
    (mom_losses, div_losses, out_losses, der_losses, init_losses, bc_losses, tot_losses),
    window=100,
    data_file=f'{EXP_PATH}/{name}/plots{title_mode}/traindata.npy',
    figure_file=f'{EXP_PATH}/{name}/plots{title_mode}/losses.png',
)

# Per-epoch evaluation losses, smoothed over 10 epochs; the table also stores
# the duration of each epoch as its last column
_export_loss_history(
    step_list_test,
    (mom_losses_test, div_losses_test, out_losses_test, der_losses_test, init_losses_test, bc_losses_test, tot_losses_test),
    window=10,
    data_file=f'{EXP_PATH}/{name}/plots{title_mode}/testdata.npy',
    figure_file=f'{EXP_PATH}/{name}/plots{title_mode}/losses_test.png',
    extra_columns=(time_test,),
)
