import json
def log2wandb(log_file="vanilla_trans.227.log"):
    """Replay the losses recorded in a training log into a Weights & Biases run.

    Expected log layout, as read by this function:

    * line 2 holds the model configuration as a Python expression that
      ``dict()`` can consume once evaluated;
    * every line starting with ``iter `` is a progress line. After splitting
      on whitespace, token 1 is the step (any ``:`` removed), the fourth
      token from the end is the training loss (any ``,`` removed) and the
      last token is the validation loss.

    A short preview of everything parsed is printed before uploading.

    Args:
        log_file: Path to the log file. The run name is the last
            ``/``-separated path component, cut at its first ``.``.

    Raises:
        IndexError: If the log has fewer than two lines (no config line), or
            a progress line has too few tokens.
        ValueError: If a step or loss token is not numeric.
    """
    with open(log_file, 'r') as f:
        lines = f.readlines()

    # Fail with an actionable message instead of "list index out of range".
    # IndexError is kept so existing handlers keep working.
    if len(lines) < 2:
        raise IndexError(
            f"{log_file!r} has {len(lines)} line(s); expected the model "
            "config on line 2"
        )

    # SECURITY: eval() executes whatever expression sits on the config line.
    # Only feed this function logs you produced yourself. If the config is
    # always a plain literal, ast.literal_eval is the safe replacement.
    model_config = eval(lines[1])

    # Single pass over the progress lines: split once, pull all three fields.
    iter_lines = [line for line in lines if line.startswith('iter ')]
    records = []
    for line in iter_lines:
        tokens = line.split()
        records.append({
            "iter": int(tokens[1].replace(":", "")),
            "train_loss": float(tokens[-4].replace(",", "")),
            "val_loss": float(tokens[-1]),
        })

    # Take the file name first, then cut at its first dot. Splitting the whole
    # path on '.' first turns "./logs/run.log" into an empty run name.
    model_name = log_file.split('/')[-1].split('.')[0]

    # Debug preview (same output as before the refactor).
    print(model_config)
    print(iter_lines[:5])
    print([[r["iter"], r["train_loss"]] for r in records[:5]])
    print([[r["iter"], r["val_loss"]] for r in records[:5]])
    print(model_name)
    print(records[:5])

    import wandb

    # The context manager finishes the run even if logging fails, so a later
    # call in the same process starts a fresh run instead of reusing this one.
    with wandb.init(
        project="stack_llm",
        name=model_name,
        config=dict(model_config),
    ) as run:
        for record in records:
            # No learning rate is parsed from the log, so none is reported
            # (previously a constant 0 was logged for every step).
            run.log(
                {
                    "train_loss": record["train_loss"],
                    "val_loss": record["val_loss"],
                },
                step=record["iter"],
            )

# Script entry point: upload the stack_v4 training log when run directly.
if __name__ == "__main__":
    log2wandb(log_file="stack_v4.227.log")