import os
import torch
import random
import numpy as np


def set_random_seed(seed: int):
    """Seed the Python, NumPy and PyTorch random number generators.

    The same ``seed`` is applied, in order, to:
    Python's ``random`` module, NumPy's global RNG, the PyTorch CPU
    generator, the PyTorch CUDA generators on all devices, and finally
    the PyTorch CUDA generator of the current device.

    Args:
        seed: Integer seed passed unchanged to every generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        # Redundant after manual_seed_all (which already covers the current
        # device), but kept so the sequence of seeding calls is unchanged.
        torch.cuda.manual_seed,
    )
    for seeder in seeders:
        seeder(seed)


def plot_loss_curve(results, save_path, rank=0):
    """Plot the recorded loss values and save the chart under ``save_path/loss``.

    Args:
        results: Iterable of dict-like records. Only records containing a
            ``'loss'`` key contribute a point; all others are skipped.
        save_path: Base output directory. A ``loss`` sub-directory is created
            inside it if it does not already exist.
        rank: Suffix for the output file name, so that several processes
            can write their own curves side by side. Defaults to 0.

    The figure is written to ``<save_path>/loss/loss_<rank>``. Because that
    name has no extension, matplotlib appends the default savefig format
    (PNG unless ``rcParams["savefig.format"]`` says otherwise).
    """
    # A standalone Figure (rather than pyplot's implicit current figure)
    # keeps each call isolated: curves from earlier calls do not pile up
    # on the same axes, and the figure is not retained by pyplot's figure
    # manager, so nothing leaks across repeated calls.
    from matplotlib.figure import Figure

    os.makedirs(os.path.join(save_path, 'loss'), exist_ok=True)

    loss_list = [x['loss'] for x in results if 'loss' in x]

    fig = Figure()
    ax = fig.subplots()
    ax.plot(loss_list)
    ax.set_xlabel('Step')
    ax.set_ylabel('Loss')
    fig.savefig(os.path.join(save_path, f"loss/loss_{rank}"), bbox_inches='tight')
    
