import os
import sys

# Read this script's own source text as early as possible so it can be logged.
# Python source files are UTF-8 by default, so decode explicitly instead of
# relying on the locale's preferred encoding.
with open(sys.argv[0], encoding="utf-8") as f:
    code = f.read()
import uuid
import glob
import time
from dataclasses import dataclass

import json
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
import torch._inductor.config as config
from torch.nn.parallel import DistributedDataParallel as DDP
from model import GPT, GPTConfig

import argparse
# Hard-coded switch: the wandb package is imported only when this is set to True.
WANDB = False
if WANDB:
    import wandb

from muon import Muon
from dataloader import DistributedDataLoader


# -----------------------------------------------------------------------------
# int main


# Command-line interface of the training script.
parser = argparse.ArgumentParser(description="GPT-2 Training Script")

# data hyperparams
# Required: the script joins glob patterns onto this path right after parsing,
# so leaving it unset would only surface later as a TypeError from
# os.path.join(None, ...). Let argparse report the missing option instead.
parser.add_argument(
    "--input_folder",
    type=str,
    required=True,
    help="input folder to train on",
)

# optimization hyperparams
parser.add_argument(
    "--batch_size",
    type=int,
    default=8 * 64,
    help="batch size, in sequences, across all devices",
)
parser.add_argument(
    "--device_batch_size",
    type=int,
    default=64,
    help="batch size, in sequences, per device",
)
parser.add_argument(
    "--sequence_length", type=int, default=512, help="sequence length, in tokens"
)
parser.add_argument("--learning_rate", type=float, default=0.0036)
parser.add_argument("--warmup_ratio", type=float, default=0)
parser.add_argument(
    "--warmdown_ratio",
    type=float,
    default=1,
    help="ratio of total iterations for linear warmup/warmdown for triangular or trapezoidal schedule",
)
parser.add_argument("--num_epochs", type=int, default=1)
parser.add_argument("--weight_decay", type=float, default=0)

# evaluation and logging hyperparams
parser.add_argument(
    "--val_loss_every",
    type=int,
    default=125,
    help="every how many steps to evaluate val loss? 0 for only at the end",
)
# A value <= 0 turns validation off (the script derives `do_val` from it).
parser.add_argument(
    "--val_tokens",
    type=int,
    default=10485760,
    help="how many tokens of validation data? it's important to keep this fixed for consistent comparisons",
)
parser.add_argument(
    "--save_every",
    type=int,
    default=0,
    help="every how many steps to save the checkpoint? 0 for only at the end",
)
parser.add_argument("--load_checkpoint", type=str, default=None)

# run bookkeeping and model selection
parser.add_argument("--wandb_project", type=str, default="gpt2-finetune")
parser.add_argument("--run_name", type=str, default=None)
parser.add_argument("--output_dir", type=str)
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--model_size", type=str, default="base")
args = parser.parse_args()

input_folder = args.input_folder
# do_val is True only when --val_tokens is positive
do_val = args.val_tokens > 0
print("Do val: ", do_val)
# glob patterns for the train / val .bin files inside the input folder;
# the val pattern is only defined when validation is enabled
input_bin = os.path.join(input_folder, "*_train_*.bin")
if do_val:
    input_val_bin = os.path.join(input_folder, "*_val_*.bin")
# set up DDP (distributed data parallel): CUDA must be present, the process
# group uses the NCCL backend, and the rank / world-size values are read from
# the environment (torchrun sets these variables)
assert torch.cuda.is_available()
dist.init_process_group(backend="nccl")
ddp_rank, ddp_local_rank, ddp_world_size = (
    int(os.environ[name]) for name in ("RANK", "LOCAL_RANK", "WORLD_SIZE")
)
print(ddp_world_size)