import math
import os
import time
from contextlib import nullcontext
from datetime import datetime
from functools import partial

import torch
from model import Transformer, ModelArgs
from torch.distributed import destroy_process_group, init_process_group

from tinystories import Task
from export import model_export

import torch._dynamo
torch._dynamo.config.suppress_errors = True
from tqdm import tqdm

import time 
import numpy as np

# -----------------------------------------------------------------------------
# I/O
init_from = "scratch"  # 'scratch' or 'resume'
# wandb logging
wandb_log = True  # enabled by default in this script; set False to disable
# data
batch_size = 1024  # if gradient_accumulation_steps > 1, this is the micro-batch size
max_seq_len = 6
vocab_source = "custom"  # llama2|custom; use Llama 2 vocab from Meta, or custom trained
vocab_size = 64  # the Llama 2 tokenizer has 32K tokens
# model
dim = 256
n_layers = 12
n_heads = 6
n_kv_heads = 6
multiple_of = 32
dropout = 0.0
pos_enc = "off"
# adamw optimizer
# gradient_accumulation_steps = 1  # used to simulate larger batch sizes
learning_rate = 5e-4  # max learning rate
weight_decay = 1e-5
beta1 = 0.9
beta2 = 0.95
# grad_clip = 1.0  # clip gradients at this value, or disable if == 0.0
# learning rate decay settings
decay_lr = True  # whether to decay the learning rate
warmup_iters = 5  # how many steps to warm up for (get_lr is stepped once per epoch)
# system
device = "cuda"  # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
dtype = "float16"  # float32|bfloat16|float16
compile = True  # use PyTorch 2.0 to compile the model to be faster
# -----------------------------------------------------------------------------
# Names of every public int/float/bool/str global defined so far (i.e. the
# settings above); their post-override values are recorded in `config`.
config_keys = [
    k
    for k, v in globals().items()
    if not k.startswith("_") and isinstance(v, (int, float, bool, str))
]
# Overrides from command line or config file: configurator.py is exec'd in this
# module's namespace so it can rebind the settings above. It runs arbitrary code
# and must therefore be trusted. The `with` closes the file handle promptly.
with open("./configurator.py") as _configurator_file:
    exec(_configurator_file.read())
config = {k: globals()[k] for k in config_keys}  # will be useful for logging

# Number of full passes over the training data. configurator.py has already run
# above, so this assignment overrides any value it might have set.
num_epochs = 10000


import json 
from tokenizer import Tokenizer

DATA_PROCESS_DIR = "./TinyStories_processing_files"
SAVE_DIR = "./"  # checkpoints and loss arrays are written under this directory


# validating checks -- explicit raises instead of `assert`, so they are not
# stripped when Python runs with -O
if vocab_source not in ("llama2", "custom"):
    raise ValueError(f"vocab_source must be 'llama2' or 'custom', got {vocab_source!r}")
if vocab_source != "custom" and vocab_size != 32000:
    raise ValueError("The vocab from Meta has 32K tokens")


# init_process_group is never called in this script, so it always acts as the
# single (master) process with no per-rank seed offset
master_process = True
seed_offset = 0

tokens_per_iter = batch_size * max_seq_len  # tokens per optimizer step (not read elsewhere in this file)


torch.manual_seed(1337 + seed_offset)
torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
device_type = "cuda" if "cuda" in device else "cpu"  # passed to configure_optimizers
# note: float16 data type will automatically use a GradScaler
# NOTE(review): ptdtype is computed but no torch.autocast context is entered in
# this file, so the forward pass presumably runs in the parameters' dtype -- confirm.
ptdtype = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}[dtype]

# model init -- hyperparameters come from the (possibly overridden) config globals
model_args = {
    "dim": dim,
    "n_layers": n_layers,
    "n_heads": n_heads,
    "n_kv_heads": n_kv_heads,
    "vocab_size": vocab_size,
    "multiple_of": multiple_of,
    "max_seq_len": max_seq_len,
    "dropout": dropout,
    "pos_enc": pos_enc,
}

print("Initializing a new model from scratch")
gptconf = ModelArgs(**model_args)
model = Transformer(gptconf)
model.to(device)

# Loss scaling is only active for float16; with enabled=False the scaler is a no-op.
scaler = torch.cuda.amp.GradScaler(enabled=dtype == "float16")

# The optimizer is built from the eager (not yet compiled) module.
optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type)

# Optionally wrap the model with torch.compile, keeping a handle on the eager module.
# NOTE(review): a compiled module's state_dict() keys are typically prefixed with
# "_orig_mod." -- relevant to the torch.save calls in the training loop; verify.
if compile:
    print("compiling the model... (takes a ~minute)")
    unoptimized_model, model = model, torch.compile(model)  # requires PyTorch 2.0



# logging: create the Weights & Biases run only when enabled on the master process
if master_process and wandb_log:
    import wandb

    # set the wandb project where this run will be logged;
    # NOTE: the run name is a fixed string and is not derived from the config above
    wandb.init(
        name="correct_V23_12L_v64_s100_d256_5e-4_total_10000_step_lrdecay_exp",
        project="NTP_Deepnet",
    )

# training loop bookkeeping
running_mfu = -1.0
t0 = time.time()



from loggers_full import * 
import psutil

# iteration_counter: batches processed over the whole run;
# log_iteration: completed epoch-end logging/checkpoint passes
iteration_counter = log_iteration = 0

# fixing some hyperparams to sensible defaults
# The cosine-decay horizon equals the run length; get_lr is called with the
# epoch index, so both quantities are measured in epochs.
lr_decay_iters = num_epochs
# Floor of the cosine decay. Chinchilla suggests ~learning_rate/10; 0.0 decays fully.
min_lr = 0.0

# learning rate decay scheduler (cosine with warmup)
def get_lr(it, *, max_lr=None, warmup=None, decay_iters=None, floor_lr=None):
    """Return the learning rate for step ``it`` (in this script: the epoch index).

    Schedule: linear warmup from 0 over ``warmup`` steps, then cosine decay
    from ``max_lr`` down to ``floor_lr``, reached at ``decay_iters``, then a
    constant ``floor_lr``. Each keyword argument, when left as ``None``, falls
    back to the module-level setting ``learning_rate``, ``warmup_iters``,
    ``lr_decay_iters`` or ``min_lr`` respectively.

    Note: with ``warmup > 0``, ``it == 0`` yields a learning rate of exactly 0.
    """
    max_lr = learning_rate if max_lr is None else max_lr
    warmup = warmup_iters if warmup is None else warmup
    decay_iters = lr_decay_iters if decay_iters is None else decay_iters
    floor_lr = min_lr if floor_lr is None else floor_lr

    # 1) linear warmup for `warmup` steps
    if it < warmup:
        return max_lr * it / warmup
    # 2) past the decay horizon: hold at the floor
    if it > decay_iters:
        return floor_lr
    # 3) in between, cosine decay down to the floor. When the decay span is
    #    empty (decay_iters == warmup) the ratio below would divide by zero;
    #    the end of the decay is the floor, so return that.
    span = decay_iters - warmup
    if span <= 0:
        return floor_lr
    decay_ratio = (it - warmup) / span  # in [0, 1] given the checks above
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 1..0
    return floor_lr + coeff * (max_lr - floor_lr)

# def get_lr(it, lr):
#     # 1) linear warmup for warmup_iters steps
#     if it < warmup_iters:
#         return learning_rate * it / warmup_iters
#     # 2) if it > lr_decay_iters, return min learning rate
#     if it > lr_decay_iters:
#         return min_lr
    
#     if it in [50,100,150,200,250]:
#         lr = lr / 10
    
#     return lr
    
from tinystories import PretokDataset
# Training split of the pre-tokenized dataset. Batches are produced in the main
# process (num_workers=0); pinned host memory pairs with the non_blocking
# .to(device) copies in the training loop.
ds = PretokDataset("train", max_seq_len, vocab_size, vocab_source)
dl = torch.utils.data.DataLoader(ds, batch_size=batch_size, num_workers=0, pin_memory=True)


# Per-epoch histories. The first two are also written to .npy files whenever a
# checkpoint is saved; the WB_* lists are only accumulated in memory here.
total_loss_list = []
total_loss_per_batch_ave_list = []

WB_epochs_list = []
WB_loss_list = []
WB_lr_list = []


def _should_save_checkpoint(epoch_idx, total_epochs):
    """Return True if model/optimizer/loss files should be written after ``epoch_idx``.

    Saves every 10th epoch, epochs 1, 3 and 5, and a few of the last epochs.
    ``epoch_idx`` runs from 0 to ``total_epochs - 1``, so the final epoch is
    ``total_epochs - 1``; listing ``total_epochs`` itself never matches, which
    previously meant the final model was not saved.
    """
    return epoch_idx % 10 == 0 or epoch_idx in {
        1,
        3,
        5,
        total_epochs - 5,
        total_epochs - 2,
        total_epochs - 1,
    }


lr = learning_rate
for epoch in range(num_epochs):

    print("_" * 100)
    print("Epoch " + str(epoch))
    print("_" * 100)

    epoch_time = time.time()

    # The learning rate is scheduled per epoch (not per batch).
    lr = get_lr(epoch) if decay_lr else learning_rate
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    total_loss = 0  # sum of per-batch loss values
    total_loss_per_batch_ave = 0  # same sum; divided by num_batches below
    num_batches = 0
    num_datapoints = 0  # samples seen this epoch (Y.shape[0] per batch)

    for batch_idx, (X, Y) in tqdm(enumerate(dl)):
        X = X.to(device, non_blocking=True)
        Y = Y.to(device, non_blocking=True)

        # Forward with targets; the loss is read back from `model.last_loss`.
        logits, h = model(X, Y)
        loss = model.last_loss
        iteration_counter += 1

        # .item() copies the value to the host; do it once per batch.
        loss_value = loss.item()
        total_loss += loss_value
        total_loss_per_batch_ave += loss_value

        # backward pass, with gradient scaling if training in fp16
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

        num_datapoints += Y.shape[0]
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("training DataLoader yielded no batches; cannot compute epoch loss")

    # NOTE(review): total_loss sums one loss value per batch but is divided by the
    # number of samples, so with equal-sized batches this is the mean per-batch
    # loss divided by the batch size -- confirm this scaling is intended.
    loss_per_datapoint = total_loss / num_datapoints
    batch_avg_loss = total_loss_per_batch_ave / num_batches

    total_loss_list.append(loss_per_datapoint)
    total_loss_per_batch_ave_list.append(batch_avg_loss)

    WB_epochs_list.append(epoch)
    WB_loss_list.append(loss_per_datapoint)
    WB_lr_list.append(lr)
    weight_norm = torch.norm(model.output.weight, p=2).detach().item()

    if wandb_log and master_process:
        try:
            wandb.log(
                {
                    "epoch": epoch,
                    "loss/train": loss_per_datapoint,
                    "lr": lr,
                    "W norm": weight_norm,
                },
                step=epoch,
            )
        except Exception as e:  # best-effort: a logging failure must not stop training
            print(f"logging to wandb failed: {e}")

    epoch_time = time.time() - epoch_time
    print("*" * 100)
    print("This epoch took " + str(epoch_time / 60) + " minutes.")
    print("*" * 100)
    print("num_datapoints: " + str(num_datapoints))
    print("Logging the info ...")
    model.eval()

    if _should_save_checkpoint(log_iteration, num_epochs):
        torch.save(model.state_dict(), os.path.join(SAVE_DIR, "model_it" + str(log_iteration) + ".pth"))
        torch.save(optimizer.state_dict(), os.path.join(SAVE_DIR, "optimizer_it" + str(log_iteration) + ".pth"))
        with open(os.path.join(SAVE_DIR, "loss_list.npy"), "wb") as f:
            np.save(f, total_loss_list)
        with open(os.path.join(SAVE_DIR, "loss_per_batch_ave_list.npy"), "wb") as f:
            np.save(f, total_loss_per_batch_ave_list)
    log_iteration += 1

    model.train()
    optimizer.zero_grad(set_to_none=True)

    # Host RAM usage from a single snapshot, and GPU memory only when on CUDA.
    vm = psutil.virtual_memory()
    print('RAM memory % used:', vm.percent)
    print('RAM Used (GB):', vm.used / 1000000000)
    if device_type == "cuda":
        print("GPU Memory Allocated:", torch.cuda.memory_allocated(device) / (1024 ** 3), "GB")
    print("Total Loss : " + str(total_loss))
    print("Total Loss (per batch average): " + str(batch_avg_loss))
    print("Average Loss : " + str(loss_per_datapoint))
    print("Learning rate: " + str(lr))
    print("Done. ")

