"""
This training script can be run both on a single gpu in debug mode,
and also in a larger training run with distributed data parallel (ddp).

To run on a single GPU, example:
$ python train.py --batch_size=32 --compile=False

To run with DDP on 4 gpus on 1 node, example:
$ torchrun --standalone --nproc_per_node=4 train.py

To run with DDP on 16 gpus across 2 nodes (8 gpus per node), example:
- Run on the first (master) node with example IP 123.456.123.456:
$ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py
- Run on the worker node:
$ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py
(If your cluster does not have an Infiniband interconnect, prepend NCCL_IB_DISABLE=1.)
"""

import os
import time
from contextlib import nullcontext
from datetime import datetime
import numpy as np
import torch
from functools import partial
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import destroy_process_group
from nn.gpt2 import StructGPT
from nn.dense_gpt2 import DenseGPT
from model.gpt_fns import get_lr_mult
from model.gpt_fns import update_lrs
from model.gpt_fns import reset_lrs
from model.gpt_fns import get_batch_all
from model.gpt_fns import estimate_loss
from model.gpt_fns import get_epoch
from model.gpt_fns import init_dist_process
from model.gpt_fns import construct_configs
from nn.cola_nn import cola_parameterize
from nn.cola_nn import get_model_summary_and_flops
from tqdm import tqdm

# -----------------------------------------------------------------------------
# Default settings. Every name below is a plain module-level global; the
# int/float/bool/str ones are later collected into the `config` dict, so avoid
# adding helper globals of those types in this section.

# --- evaluation / logging / checkpointing ---
eval_interval = 2000
log_interval = 1
eval_iters = 200
eval_only = False
out_dir = "out"
always_save_checkpoint = False
ckpt_path = ""

# --- data ---
dataset = 'open_small'
# NOTE: the directory is joined with the default `dataset` right here, i.e.
# before any override step that runs after this section.
data_dir = os.path.join('./', dataset)
block_size = 1024
batch_size = 12
gradient_accumulation_steps = 5 * 8
vocab_size = 50_304
data_dtype = np.uint8  # element type used when reading the .bin token files

# --- model: base shape, then actual shape ---
# NOTE(review): -1 looks like an "unset" sentinel for several of these; the
# only place visible in this file that tests it is the d_embd check below the
# configurator call — confirm for the others.
base_n_head = -1
base_d_head = 64
base_d_model = 768
base_d_embd = 768
base_ffn_expansion = 1
n_head = -1
d_head = 64
d_model = 768
d_embd = -1
ffn_expansion = 4
n_layer = 12
num_ffn_experts = 1
split_qkv = False
axial = False
dropout = 0.0
bias = False
do_qk_ln = False

# --- optimizer / schedule ---
opt_name = "AdamW"
init_lr = 6e-4
weight_decay = 1e-1
beta1 = 0.9
beta2 = 0.95
grad_clip = 1.0
decay_lr = True
warmup_iters = 2_000
spec_penalty_weight = 0.
aux_loss_weight = 0.01
max_iters = 100_000

# --- system ---
backend = 'nccl'  # 'nccl', 'gloo', etc.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
# Prefer bfloat16 when CUDA is present and reports bf16 support; otherwise
# fall back to float16.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype = 'bfloat16'
else:
    dtype = 'float16'
compile = True

# --- structured-layer options ---
struct = "none"
tt_dim = 2
tt_rank = 1
num_blocks = 4
rank_frac = 0.2
every_n_fwds = 100
expr = ""
num_experts = 0
lm_head_struct = ''
lm_head_tt_rank = -1
lm_head_rank_frac = -1.
layers = "all_but_last"
input_lr_mult = 1.

# --- experiment tracking ---
wandb_log = False
wandb_project = "attention"
# -----------------------------------------------------------------------------
# Overrides from command line or config file. The configurator source is
# executed with exec() in this module's global namespace, so whatever it
# assigns lands directly on the settings defined above.
# The file is opened in a `with` block so the handle is closed deterministically
# instead of being left to the garbage collector. The handle name starts with
# an underscore so the config snapshot below (which skips `_`-prefixed names)
# is unaffected.
with open('./model/configurator.py') as _configurator_file:
    exec(_configurator_file.read())
# A hack to ensure the base model has emb up/down sampler params: when an
# explicit embedding width (d_embd != -1) differs from the model width, make
# the base embedding width differ from the base model width as well.
if d_embd != -1 and d_embd != d_model:
    base_d_embd = base_d_model - 1

# Timestamps for naming this run.
now = datetime.now()
timestamp = now.strftime("%Y-%m-%d_%H%M%S")

# Learning-rate schedule bounds derived from the (possibly overridden) settings.
lr_decay_iters = max_iters
min_lr = init_lr / 10.

# Shorter alternative run name, kept for reference:
# wandb_run_name = f"l{n_layer}-h{n_head}-d{d_model}-e{d_embd}-{struct}_{now.strftime('%H%M%S')}"
wandb_run_name = (
    f"{struct}_e{num_experts}_{layers}_l{n_layer}"
    f"-dm{d_model}-de{d_embd}-h{n_head}-dh{d_head}"
    f"-ttr{tt_rank}-{now.strftime('%H%M%S')}"
)
# Each run writes into its own sub-directory of the configured output dir.
out_dir = f"{out_dir}/{wandb_run_name}"

if 'cuda' in device:
    device_type = 'cuda'
else:
    device_type = 'cpu'

# Snapshot every public scalar/string global as the run configuration.
# These must stay comprehensions: their loop variables live in their own scope,
# whereas a module-level `for` loop would add globals while globals() is being
# iterated.
config_keys = [
    name
    for name, value in globals().items()
    if isinstance(value, (int, float, bool, str)) and not name.startswith('_')
]
config = {name: globals()[name] for name in config_keys}  # will be useful for logging

# Dump the configuration, alphabetically, between two banner lines.
print("*=" * 50)
for key in sorted(config):
    print(f"{key}: {config[key]}")
print("*=" * 50)
# -----------------------------------------------------------------------------

# Model class selection: DenseGPT when `struct` is "none", StructGPT for any
# other value.
GPT = DenseGPT if struct == "none" else StructGPT
print(f"Eval interval: {eval_interval:,d}")

# Treat this as a DDP run when the RANK environment variable is set to
# anything other than -1.
ddp = int(os.environ.get('RANK', -1)) != -1
if not ddp:
    # Single-process run: this process is the master and the world has size 1.
    master_process = True
    seed_offset = 0
    ddp_world_size = 1
    ddp_local_rank = 0
else:
    # init_dist_process hands back the process-group details together with a
    # new value for gradient_accumulation_steps
    # (presumably rescaled per process — confirm against model.gpt_fns).
    aux = init_dist_process(backend, gradient_accumulation_steps)
    (master_process, seed_offset, ddp_world_size, ddp_local_rank,
     gradient_accumulation_steps) = aux

# Tokens consumed by one optimizer step across all processes.
tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size
print(f"tokens per iteration will be: {tokens_per_iter:,}")



# Only the master process creates the output directory.
if master_process:
    os.makedirs(out_dir, exist_ok=True)

# Seed the RNG; the per-process offset gives each process a different seed.
torch.manual_seed(1337 + seed_offset)

# Allow TF32 on matmul and on cuDNN.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# Map the dtype setting (a string) to the torch dtype; an unsupported name
# raises KeyError.
# note: float16 data type will automatically use a GradScaler
ptdtype = {
    'float32': torch.float32,
    'bfloat16': torch.bfloat16,
    'float16': torch.float16,
}[dtype]

# No autocast on CPU; otherwise autocast to the selected dtype.
if device_type == 'cpu':
    ctx = nullcontext()
else:
    ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

# Memory-map the training split read-only and report its length.
train_data = np.memmap(
    os.path.join(data_dir, 'train.bin'),
    dtype=data_dtype,
    mode='r',
)
dataset_size = len(train_data)
print(f"Total tokens: {dataset_size:,d}")

# Memory-map the validation split the same way.
val_data = np.memmap(
    os.path.join(data_dir, 'val.bin'),
    dtype=data_dtype,
    mode='r',
)

# Bind the data and batching settings once, so callers only pass the one
# remaining argument (the loop below calls get_batch('train')).
get_batch = partial(
    get_batch_all,
    train_data=train_data,
    val_data=val_data,
    batch_size=batch_size,
    block_size=block_size,
    device=device,
    device_type=device_type,
)

for _ in (pbar := tqdm(range(max_iters))):
    X, Y = get_batch('train')