from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np


def plot_loss(losses):
    '''
    Open a new 6x4 figure and draw ``losses`` on it as a line plot.

    The figure is left as pyplot's current figure; nothing is returned.
    '''
    fig_size = (6, 4)
    plt.figure(figsize=fig_size)
    plt.plot(losses)
    
    
def load_loss(log_path: Path, grad_acc: int = 4, smooth_window: int = 100):
    '''
    Parse loss values from a training log and return a smoothed loss curve.

    Only lines containing ``' | loss '`` are used. Example of such a line:

    [0] [2024-08-06 19:45:57] it 80214 | loss 3.2205 | lr 5.000e-04 | grad_norm 0.000e+00 | it_time 2.923 | fw_time 0.123 | bw_time 2.053 | mem_used 13440MB | s_mean 1.932e-03 |

    Args:
        log_path: Path of the log file to read.
        grad_acc: Number of consecutive logged losses averaged into one
            point. A trailing group with fewer than ``grad_acc`` entries
            is averaged over the entries it has.
        smooth_window: Width of the moving-average window applied to the
            averaged points. If fewer points than this are available, the
            window shrinks to the number of points.

    Returns:
        1-D ``numpy.ndarray`` of smoothed losses; empty when the log
        contains no loss lines.

    Raises:
        ValueError: If ``grad_acc`` or ``smooth_window`` is smaller than 1,
            or a loss value cannot be parsed as a float.
    '''
    if grad_acc < 1:
        raise ValueError(f"grad_acc must be >= 1, got {grad_acc}")
    if smooth_window < 1:
        raise ValueError(f"smooth_window must be >= 1, got {smooth_window}")

    raw = []
    with open(log_path, 'r') as f:
        for line in f:
            # Take the token right after the ' | loss ' marker instead of
            # relying on the loss being the second ' | '-separated field.
            _, marker, tail = line.partition(' | loss ')
            if marker:
                raw.append(float(tail.split(' ', 1)[0]))

    if not raw:
        # np.convolve rejects empty input; an empty curve is the honest result.
        return np.array([], dtype=float)

    # Take the average of the gradient accumulation
    averaged = np.array(
        [np.mean(raw[i:i + grad_acc]) for i in range(0, len(raw), grad_acc)]
    )

    # With mode='valid', np.convolve swaps its operands when the kernel is
    # longer than the signal, which would yield wrongly scaled values rather
    # than a moving average. Clamp the window so that cannot happen.
    window = min(smooth_window, len(averaged))
    return np.convolve(averaged, np.ones(window) / window, mode='valid')


def main():
    '''
    Plot the smoothed loss curve of every listed experiment on shared axes
    and save the resulting figure to ``loss.pdf``.
    '''
    dst_path = 'loss.pdf'

    # (legend label, log file) pairs, plotted in this order.
    experiments = (
        ('128', 'slurm-390.out'),
        ('256', 'slurm-391.out'),
        ('512', 'slurm-393.out'),
        ('2048', 'slurm-395.out'),
        ('24-1024', 'slurm-511.out'),
    )

    for label, log_file in experiments:
        print(f"Plotting {label}")
        curve = load_loss(log_file)
        plt.plot(curve, label=label)

    # Fixed y-range applied to the shared axes.
    plt.ylim(2, 4)
    plt.legend()
    plt.xlabel('Steps')
    plt.ylabel('Loss')
    plt.grid()

    print(f"Saving to {dst_path}")
    plt.savefig(dst_path, dpi=300, bbox_inches='tight')
    

# Run the plotting routine only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
