# -*- coding: utf-8 -*-
"""
Created on Wed May  8 20:54:09 2024

@author: William
"""

import torch

# Checkpoint file to inspect (machine-specific absolute path).
path = r'C:\Users\William\Documents\Programming\PhD\Datasets\Robust_RHO_Project\MNLIDataModule\irreducible_models\irred_losses_and_checks_1.pt'

# map_location='cpu' remaps any stored tensors onto the CPU, so the file still
# loads on a machine that lacks the device the tensors were saved from.
# NOTE(review): torch.load unpickles the file's contents, so only point this
# at checkpoints from a trusted source.
data = torch.load(path, map_location='cpu')




#%% Test the lambda scheduler: plot the warm-up factor over 1000 steps
    
import numpy as np
import matplotlib.pyplot as plt
    
# Number of warm-up steps; read as a global by func.
warmup_steps = 150
# Step indices 0.0 .. 999.0 to evaluate (a float array, because 1e3 is a float).
steps = np.arange(1e3)

def func(step, warmup=None):
    """Return the linear warm-up multiplier for ``step``.

    The multiplier is ``(step + 1) / (warmup + 1)``, capped at 1.0, so it
    reaches 1.0 at ``step == warmup`` and stays there afterwards.

    Args:
        step: Zero-based step index.
        warmup: Number of warm-up steps. When omitted, the module-level
            ``warmup_steps`` is used, so single-argument calls behave
            exactly as before.

    Returns:
        The multiplier for ``step``; never greater than 1.0.
    """
    if warmup is None:
        # Fall back to the script-level setting for existing callers.
        warmup = warmup_steps
    return min(1.0, (step + 1) / (warmup + 1))
                                
# Evaluate the warm-up factor at every step and draw it as a line plot.
plt.plot(list(map(func, steps)))