import os
import sys

with open(sys.argv[0]) as f:
    code = f.read()  # read the code of this file ASAP, for logging
import uuid
import glob
import time
from dataclasses import dataclass

import json
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
import torch._inductor.config as config
from torch.nn.parallel import DistributedDataParallel as DDP
from model import GPT, GPTConfig

import argparse
WANDB = True
if WANDB:
    import wandb

from muon import Muon
from dataloader import DistributedDataLoader


# -----------------------------------------------------------------------------
# int main


parser = argparse.ArgumentParser(description="GPT-2 Training Script")

# data hyperparams
# --input_folder is always read (train shard glob and metadata.json), so it is
# required; previously a missing value surfaced as a TypeError in os.path.join.
parser.add_argument(
    "--input_folder",
    type=str,
    required=True,
    help="input folder to train on (must contain *_train_*.bin shards and metadata.json)",
)

# optimization hyperparams
parser.add_argument(
    "--batch_size",
    type=int,
    default=8 * 64,
    help="batch size, in sequences, across all devices",
)
parser.add_argument(
    "--device_batch_size",
    type=int,
    default=64,
    help="batch size, in sequences, per device",
)
parser.add_argument(
    "--sequence_length", type=int, default=512, help="sequence length, in tokens"
)
parser.add_argument("--learning_rate", type=float, default=0.0036)
parser.add_argument(
    "--warmup_ratio",
    type=float,
    default=0,
    help="fraction of total iterations used for the initial linear LR warmup",
)
parser.add_argument(
    "--warmdown_ratio",
    type=float,
    default=1,
    help="fraction of total iterations used for the final linear LR warmdown (decays to 0.1x), "
    "giving a triangular or trapezoidal schedule",
)
parser.add_argument("--num_epochs", type=int, default=1)
parser.add_argument("--weight_decay", type=float, default=0)

# evaluation and logging hyperparams
parser.add_argument(
    "--val_loss_every",
    type=int,
    default=125,
    help="every how many steps to evaluate val loss? 0 for only at the end",
)
parser.add_argument(
    "--val_tokens",
    type=int,
    default=10485760,  # 10 * 2**20 tokens (~10M)
    help="how many tokens of validation data? it's important to keep this fixed for consistent comparisons",
)
parser.add_argument(
    "--save_every",
    type=int,
    default=0,
    help="every how many steps to save the checkpoint? 0 for only at the end",
)
parser.add_argument(
    "--load_checkpoint",
    type=str,
    default=None,
    help="optional checkpoint passed to GPT.from_pretrained instead of building a fresh model",
)

parser.add_argument("--wandb_project", type=str, default="gpt2-finetune")
parser.add_argument("--run_name", type=str, default=None)
# --output_dir is always used by the master process for logs and checkpoints.
parser.add_argument(
    "--output_dir",
    type=str,
    required=True,
    help="directory for text/jsonl logs and checkpoints",
)
parser.add_argument(
    "--bf16", action="store_true", help="run forward passes under bfloat16 autocast"
)
parser.add_argument(
    "--model_size", type=str, default="base", help="key into model_size.model_configs"
)
parser.add_argument(
    "--grad_max_norm",
    type=float,
    default=0,
    help="clip the global gradient norm to this value; 0 disables clipping",
)

args = parser.parse_args()

# Validation is switched off for this run; to re-enable it, restore
# `do_val = args.val_tokens > 0`.
do_val = False  # no need to validate
print("Do val: ", do_val)
# parse input folder
input_folder = args.input_folder
input_bin = os.path.join(input_folder, "*_train_*.bin")
if do_val:
    input_val_bin = os.path.join(input_folder, "*_val_*.bin")
metadata_file = os.path.join(input_folder, "metadata.json")
with open(metadata_file, "r") as f:
    metadata = json.load(f)  # must provide at least "train_tokens" (read below)

# set up DDP (distributed data parallel). torchrun sets these env variables
assert torch.cuda.is_available()
dist.init_process_group(backend="nccl")
ddp_rank = int(os.environ["RANK"])
ddp_local_rank = int(os.environ["LOCAL_RANK"])
ddp_world_size = int(os.environ["WORLD_SIZE"])
device = f"cuda:{ddp_local_rank}"
torch.cuda.set_device(device)
print(f"using device: {device}")
master_process = ddp_rank == 0  # this process will do logging, checkpointing etc.

# convenience variables
B, T = args.device_batch_size, args.sequence_length
tokens_per_micro_step = B * T * ddp_world_size  # tokens consumed by all ranks per micro-step
# Number of steps in the val loop. The divisibility requirement is only enforced
# when validation actually runs, so an unused --val_tokens cannot abort training.
if do_val:
    assert args.val_tokens % tokens_per_micro_step == 0, (
        f"--val_tokens ({args.val_tokens}) must be divisible by "
        f"device_batch_size * sequence_length * world_size ({tokens_per_micro_step})"
    )
val_steps = args.val_tokens // tokens_per_micro_step
# calculate the steps of gradient accumulation required to attain the desired global batch size.
assert args.batch_size % (B * ddp_world_size) == 0, (
    f"--batch_size ({args.batch_size}) must be divisible by "
    f"device_batch_size * world_size ({B * ddp_world_size})"
)
train_accumulation_steps = args.batch_size // (B * ddp_world_size)
epochs = args.num_epochs
num_iterations = epochs * metadata["train_tokens"] // (T * args.batch_size)
warmup_iters = int(args.warmup_ratio * num_iterations)
warmdown_iters = int(args.warmdown_ratio * num_iterations)


def _num_scheduled_events(every):
    """Count loop steps in 0..num_iterations that fire a periodic event.

    Mirrors the training loop's trigger `last_step or (every > 0 and step % every == 0)`.
    `every <= 0` means "only at the end" (one event). Avoids the ZeroDivisionError the
    previous `num_iterations / every` hit with the default `--save_every 0`.
    """
    if every <= 0:
        return 1
    # multiples of `every` in [0, num_iterations], plus the final step if it is not one of them
    return num_iterations // every + 1 + int(num_iterations % every != 0)


if master_process:
    print(f"train_tokens: {metadata['train_tokens']}")
    print(f"num_iterations: {epochs} * {metadata['train_tokens']} // ({T} * {args.batch_size}) = {num_iterations}")
    print(f"warmup_iters: 0-{warmup_iters}, warmdown_iters: {num_iterations-warmdown_iters}-{num_iterations}")
    print(f"Tokens per iteration: {T * args.batch_size}")
    print(f"Save {_num_scheduled_events(args.save_every)} checkpoints")
    if do_val:
        print(f"Evaluate {_num_scheduled_events(args.val_loss_every)} times")

# Build the token loaders. Each one receives this process's rank and the world
# size (presumably so every rank reads its own slice of the shards; confirm in
# dataloader.py).
train_loader = DistributedDataLoader(input_bin, B, T, ddp_rank, ddp_world_size)
if do_val:
    val_loader = DistributedDataLoader(input_val_bin, B, T, ddp_rank, ddp_world_size)

# Summarize the loaded data once, from the master process only.
if master_process:
    loader_summaries = [("Training", train_loader)]
    if do_val:
        loader_summaries.append(("Validation", val_loader))
    for loader_label, summarized_loader in loader_summaries:
        print(
            f"{loader_label} DataLoader: total number of tokens: {summarized_loader.ntok_total} across {len(summarized_loader.files)} files"
        )

# Prefetch the first (inputs, targets) pair, presumably each B x T.
x, y = train_loader.next_batch()

# There are only 50257 unique GPT-2 tokens; we extend to the nearest multiple of 128
# for efficiency (suggested by @Grad62304977; this originates from Karpathy's experiments).
from model_size import model_configs

num_vocab = 50304
if args.load_checkpoint:
    model = GPT.from_pretrained(args.load_checkpoint, device=device)
else:
    # Named `model_config`, not `config`: `config` is torch._inductor.config, and
    # shadowing it made the coordinate_descent_tuning check below test a dict
    # (always False) instead of the inductor config.
    model_config = model_configs[args.model_size]
    model = GPT(GPTConfig(vocab_size=num_vocab, **model_config))
model = model.cuda()
if hasattr(config, "coordinate_descent_tuning"):
    # inductor setting (suggested by @Chillee); it only affects torch.compile,
    # which is currently commented out below
    config.coordinate_descent_tuning = True
# model = torch.compile(model)
# here we wrap model into DDP container
model = DDP(model, device_ids=[ddp_local_rank])
raw_model = model.module  # always contains the "raw" unwrapped model
# bf16 autocast when --bf16 is given; otherwise autocast is simply disabled.
# A float32 autocast target is at best a no-op for an fp32 model, so the
# previous float32 variant is replaced with an explicitly disabled context.
ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=args.bf16)

# init the optimizer(s)
# A single fused AdamW covers every parameter of the unwrapped model. An
# alternative Muon optimizer over raw_model.transformer.h (lr = 0.1 * learning_rate,
# momentum = 0.95) is imported at the top of the file but currently unused.
adamw_hparams = dict(
    lr=args.learning_rate,
    betas=(0.9, 0.95),
    eps=1e-6,
    weight_decay=args.weight_decay,
    fused=True,
)
optimizer1 = torch.optim.AdamW(raw_model.parameters(), **adamw_hparams)
optimizers = [optimizer1]


# learning rate decay scheduler (linear warmup, constant plateau, linear warmdown)
def get_lr(it):
    """Return the LR multiplier for iteration `it` (used as a LambdaLR factor).

    Reads the module-level `num_iterations`, `warmup_iters` and `warmdown_iters`.
    - it < warmup_iters: ramps linearly, reaching 1.0 at it = warmup_iters - 1.
    - plateau (or always, when warmdown_iters == 0): 1.0.
    - final warmdown_iters steps: linear decay from 1.0 down to 0.1 at it = num_iterations.
    """
    assert it <= num_iterations
    if it < warmup_iters:
        return (it + 1) / warmup_iters
    warmdown_start = num_iterations - warmdown_iters
    if warmdown_iters == 0 or it < warmdown_start:
        return 1.0
    return 0.1 + 0.9 * (num_iterations - it) / warmdown_iters


schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, get_lr) for opt in optimizers]

# begin logging: only the master process writes logs and talks to wandb
if master_process:
    import subprocess

    # --run_name has no default; fall back to a random UUID so the log paths below
    # (and the wandb run name) are always well-defined instead of raising TypeError.
    if args.run_name is None:
        args.run_name = str(uuid.uuid4())
    os.makedirs(args.output_dir, exist_ok=True)
    logdir = os.path.join(args.output_dir, args.run_name)
    os.makedirs(logdir, exist_ok=True)  # checkpoints are saved inside this directory
    logfile = f"{logdir}.txt"  # text/jsonl logs sit next to (not inside) logdir
    logjsonl_file = f"{logdir}.jsonl"
    # create the log file
    with open(logfile, "w") as f:
        # begin the log by printing this file (the Python code)
        f.write("=" * 100 + "\n")
        f.write(code)
        f.write("=" * 100 + "\n")
        # log information about the hardware/software environment this is running on
        # and print the full `nvidia-smi` to file
        f.write(
            f"Running pytorch {torch.version.__version__} compiled for CUDA {torch.version.cuda}\nnvidia-smi:\n"
        )
        try:
            smi_output = subprocess.run(
                ["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
            ).stdout
        except FileNotFoundError:
            # best effort: a missing nvidia-smi binary should not abort training
            smi_output = "(nvidia-smi not found)"
        f.write(f"{smi_output}\n")
        f.write("=" * 100 + "\n")
    if WANDB:
        # the run name is passed to init, so no separate wandb.run.name assignment is needed
        wandb.init(project=args.wandb_project, name=args.run_name)
        wandb.config.update(args)

import contextlib


def _log_record(console_msg, file_msg, metrics):
    """Send one metrics record to every sink; call on the master process only.

    `console_msg` is printed, `file_msg` is appended as one line to the text
    logfile, and `metrics` goes to wandb (when WANDB is set) and is appended as
    one JSON line to the .jsonl log.
    """
    print(console_msg)
    with open(logfile, "a") as f:
        f.write(file_msg + "\n")
    if WANDB:
        wandb.log(metrics)
    with open(logjsonl_file, "a") as f:
        f.write(json.dumps(metrics) + "\n")


training_time_ms = 0
# start the clock
torch.cuda.synchronize()
t0 = time.time()
# begin training
train_loader.reset()
# Re-fetch the first batch after rewinding: the batch prefetched during setup was
# taken before reset(), so training on it and then on the post-reset batches
# would visit the first batch twice.
x, y = train_loader.next_batch()
for step in range(num_iterations + 1):
    last_step = step == num_iterations
    # This effectively ignores timing first 10 steps, which are slower for weird reasons.
    # Alternately, and slightly more correctly in terms of benchmarking, we could do 10
    # steps with dummy data first, and then re-initialize the model and reset the loader.
    if step == 10:
        training_time_ms = 0
        t0 = time.time()
    timed_steps = (
        float("nan") if step <= 11 else (step - 10) + 1
    )  # <= 11 to avoid bug in val

    # once in a while evaluate the validation dataset
    if do_val and (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.time() - t0)
        # run validation batches
        model.eval()
        val_loader.reset()  # move to the head of the file
        val_loss = 0.0
        for _ in range(val_steps):
            x_val, y_val = val_loader.next_batch()
            # of course, we'd like to use no_grad() here too, but that creates a torch.compile error for some reason
            with ctx:
                _, loss = model(x_val, y_val, return_logits=False)
                val_loss += loss.detach()
                del loss
        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
        val_loss /= val_steps
        if master_process:
            # This step's training has not run yet: `timed_steps - 1` steps have been
            # timed and `num_iterations - step` steps remain.
            val_step_avg_ms = training_time_ms / (timed_steps - 1)
            val_eta_min = val_step_avg_ms / 1000 / 60 * (num_iterations - step)
            val_summary = (
                f"step:{step}/{num_iterations} val_loss:{val_loss:.4f} "
                f"train_time:{training_time_ms:.0f}ms step_avg:{val_step_avg_ms:.2f}ms"
            )
            _log_record(
                f"{val_summary} estimated_time_left:{val_eta_min:.2f}min",
                val_summary,
                {
                    "val/loss": val_loss.item(),
                    "train/time_ms": training_time_ms,
                    "train/step_avg_ms": val_step_avg_ms,
                    "train/step": step,
                },
            )
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.time()

    if master_process and (
        last_step or (args.save_every > 0 and step % args.save_every == 0)
    ):
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.time() - t0)
        # save the state of the training process
        log = dict(
            step=step,
            code=code,
            model=raw_model.state_dict(),
            config=raw_model.config,
            # optimizers=[opt.state_dict() for opt in optimizers],
        )
        torch.save(log, f"{logdir}/state_step{step:06d}.pt")
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.time()

    # bit confusing: we want to make sure to eval on 0th iteration
    # but also after the very last iteration. so we loop for step <= num_iterations
    # instead of just < num_iterations (one extra due to <=), only to do
    # the validation/sampling one last time, and then we break right here as we're done.
    if last_step:
        break

    # --------------- TRAINING SECTION BEGIN -----------------
    model.train()
    for i in range(1, train_accumulation_steps + 1):
        # Per the DDP.no_sync docs, the forward pass must also be inside the
        # context, otherwise gradients are still synchronized. Wrapping only
        # backward (as before) all-reduced on every micro-step. Only the last
        # micro-step syncs.
        sync_ctx = (
            contextlib.nullcontext() if i == train_accumulation_steps else model.no_sync()
        )
        with sync_ctx:
            # forward pass
            with ctx:
                _, loss = model(x, y, return_logits=False)
                train_loss = loss.detach()  # loss of the last micro-batch only
            # advance the dataset for the next batch
            x, y = train_loader.next_batch()
            # backward pass (accumulates into .grad)
            loss.backward()
    # .grad holds the sum over micro-steps; turn it into the mean. Parameters that
    # received no gradient keep grad=None and are skipped instead of crashing.
    for p in model.parameters():
        if p.grad is not None:
            p.grad /= train_accumulation_steps
    if args.grad_max_norm > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_max_norm)
    # step the optimizers and schedulers
    for opt, sched in zip(optimizers, schedulers):
        opt.step()
        sched.step()
    # null the gradients
    model.zero_grad(set_to_none=True)
    # --------------- TRAINING SECTION END -------------------
    # everything that follows now is just diagnostics, prints, logging, etc.

    # dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower
    if master_process:
        approx_time = training_time_ms + 1000 * (time.time() - t0)
        # `timed_steps` steps (including this one) have been timed and
        # `num_iterations - (step + 1)` remain; ETA uses the same average as step_avg.
        train_step_avg_ms = approx_time / timed_steps
        train_eta_min = train_step_avg_ms / 1000 / 60 * (num_iterations - step - 1)
        train_summary = (
            f"step:{step+1}/{num_iterations} train_loss:{train_loss.item():.4f} "
            f"train_time:{approx_time:.0f}ms step_avg:{train_step_avg_ms:.2f}ms"
        )
        _log_record(
            f"{train_summary} estimated_time_left:{train_eta_min:.2f}min",
            train_summary,
            {
                "train/loss": train_loss.item(),
                "train/time_ms": approx_time,
                "train/step_avg_ms": train_step_avg_ms,
                "train/step": step + 1,
                "lr": get_lr(step) * args.learning_rate,
            },
        )
# Report the peak GPU memory seen by this (master) process, in MiB.
if master_process:
    peak_mib = torch.cuda.max_memory_allocated() // (1024 * 1024)
    print(f"peak memory consumption: {peak_mib} MiB")

# -------------------------------------------------------------------------
# Tear down the NCCL process group created during setup.
dist.destroy_process_group()
