import os
import sys
import uuid
import glob
import time
from dataclasses import dataclass

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torch.distributed as dist
import torch._inductor.config as config
from torch.nn.parallel import DistributedDataParallel as DDP
from models.model_time_token import GPT, GPTConfig
from optimizer import Muon
import wandb
import numpy as np
from omegaconf import OmegaConf
# Run configuration, loaded with OmegaConf from the path given as the first CLI argument.
# NOTE: this rebinds the module-level name `config`, shadowing the
# `torch._inductor.config` module imported above under the same name.
config = OmegaConf.load(sys.argv[1])

import importlib

def load_class(module_name, class_name):
    """Import ``module_name`` and return the attribute named ``class_name`` from it.

    Used to pick implementation classes (flow-matching utilities, model variants)
    by name from the run configuration.
    """
    return getattr(importlib.import_module(module_name), class_name)
# Flow-matching helpers: resolve the class named by `config.fm.type` inside the
# `FM_utils` module and build it from the `config.fm_config` mapping. The
# resulting `fm` object supplies `sampler_0` / `interpolate` to the data loader.
fm_type = load_class(module_name="FM_utils", class_name=config.fm.type)
fm = fm_type(**dict(config.fm_config))


# -----------------------------------------------------------------------------
# Our own simple Distributed Data Loader

import boto3
from botocore.config import Config as boto_config

# botocore client settings for S3 shard downloads: long connect/read timeouts
# and up to 13 attempts per request.
config_boto = boto_config(
    connect_timeout=2000,
    read_timeout=2000,
    retries=dict(max_attempts=13),
)
# Module-level S3 client slot, initialised to None here.
s3 = None


def _peek_data_shard(filename,iters_try=1000):
    # only reads the header, returns header data
    try:
        with open(filename, "rb") as f:
            # first read the header, which is 256 int32 integers (4 bytes each)
            header = np.frombuffer(f.read(256*4), dtype=np.int32)
        assert header[0] == 20240520
        assert header[1] == 1
        ntok = header[2] # number of tokens (claimed)
        return ntok # for 
    except:
        pass
    for i in range(iters_try):
        try:
            f = s3.get_object(Bucket='fineweb',Key=filename.split("/")[-1])['Body']
            header = np.frombuffer(f.read(256*4), dtype=np.int32)
            assert header[0] == 20240520
            assert header[1] == 1
            ntok = header[2] # number of tokens (claimed)
            return ntok # for now just return the number of tokens
        except:
            pass
    assert False, "Iters to read file is out"

def _load_data_shard(filename,iters_try=1000):
    try:
        with open(filename, "rb") as f:
            # first read the header, which is 256 int32 integers (4 bytes each)
            header = np.frombuffer(f.read(256*4), dtype=np.int32)
            assert header[0] == 20240520, "magic number mismatch in the data .bin file"
            assert header[1] == 1, "unsupported version"
            ntok = header[2] # number of tokens (claimed)
            # the rest of it are tokens, stored as uint16
            if num_vocab < 2**16:
                dtype = np.uint16
            else:
                dtype = np.uint32
            tokens = np.frombuffer(f.read(), dtype=dtype)
        assert len(tokens) == ntok, "number of tokens read does not match header?"
        return tokens
    except:
        pass
    for i in range(iters_try):
        try:
            f = s3.get_object(Bucket='fineweb',Key=filename.split("/")[-1])['Body']
            # first read the header, which is 256 int32 integers (4 bytes each)
            header = np.frombuffer(f.read(256*4), dtype=np.int32)
            assert header[0] == 20240520, "magic number mismatch in the data .bin file"
            assert header[1] == 1, "unsupported version"
            ntok = header[2] # number of tokens (claimed)
            # the rest of it are tokens, stored as uint16
            if num_vocab < 2**16:
                dtype = np.uint16
            else:
                dtype = np.uint32
            tokens = np.frombuffer(f.read(), dtype=dtype)
            assert len(tokens) == ntok, "number of tokens read does not match header?"
            return tokens
        except:
            pass
    assert False, "Iters to read file is out"

class DistributedDataLoader:
    """Stream token shards to one DDP rank and turn them into flow-matching batches.

    Every rank reads the same sorted list of shards. Within a shard, rank ``r``
    starts at offset ``r*B*T`` and all ranks advance by ``B*T*num_processes``
    per batch. The loader moves to the next shard (wrapping around) when the
    current one cannot supply another full step.

    When ``condition`` is true, each row read from disk holds ``2*T`` tokens:
    the first half are the data tokens and the second half is a same-shaped
    mask that is passed to ``fm.interpolate``.
    """

    # Shard size assumed when mapping ``num_tokens_done`` to a (shard, offset)
    # resume position. NOTE(review): presumes every shard holds exactly 10**8
    # tokens — confirm against the data-preparation script.
    tokens_per_shard = 10**8

    def __init__(self, vocab_size, filename_pattern, B, T, process_rank, num_processes,num_tokens_done=0, condition=False):
        """Validate all shards and position the loader at its start/resume point.

        Args:
            vocab_size: forwarded to ``fm.sampler_0`` when drawing source tokens.
            filename_pattern: glob pattern of the ``.bin`` shards (used in sorted order).
            B: sequences per batch on this rank.
            T: sequence length; doubled internally when ``condition`` is true.
            process_rank: this process's rank.
            num_processes: world size.
            num_tokens_done: tokens already consumed by all ranks (for resuming).
            condition: whether each row carries a trailing mask half.

        Raises:
            ValueError: if no file matches, or a shard is too small for one full step.
        """
        self.process_rank = process_rank
        self.num_processes = num_processes
        self.B = B
        # conditioned rows are [tokens | mask], hence twice as long on disk
        self.T = T * 2 if condition else T
        self.vocab_size = vocab_size
        self.num_tokens_done = num_tokens_done
        self.condition = condition

        # glob files that match the pattern
        self.files = sorted(glob.glob(filename_pattern))
        if not self.files:
            raise ValueError(f"did not find any files that match the pattern {filename_pattern}")

        # Validate every shard header and count tokens in total. Each shard must
        # fit one full step across all ranks, measured with the effective row
        # length self.T (the un-doubled T under-counts when condition=True).
        min_shard_tokens = num_processes * B * self.T + 1
        self.ntok_total = 0
        for fname in self.files:
            shard_ntok = int(_peek_data_shard(fname))
            if shard_ntok < min_shard_tokens:
                raise ValueError(
                    f"shard {fname} has {shard_ntok} tokens; need at least {min_shard_tokens}"
                )
            self.ntok_total += shard_ntok
        # kick things off
        self.reset()

    def reset(self):
        """Move this rank to the resume point implied by ``num_tokens_done``."""
        shard_idx, offset = divmod(self.num_tokens_done, self.tokens_per_shard)
        self.current_shard = shard_idx % len(self.files)
        self.current_position = self.process_rank * self.B * self.T + offset
        self.tokens = _load_data_shard(self.files[self.current_shard])
        # not enough tokens left for a full step across all ranks: take the next shard
        if self.current_position + (self.B * self.T * self.num_processes+1) > len(self.tokens):
            self.advance()

    def advance(self): # advance to next data shard
        """Load the next shard (cyclically) and rewind to this rank's first window."""
        self.current_shard = (self.current_shard + 1) % len(self.files)
        self.current_position = self.process_rank * self.B * self.T
        self.tokens = _load_data_shard(self.files[self.current_shard])

    def next_batch(self):
        """Return ``(t, xt, x1)`` built from this rank's next ``B x T`` token window.

        ``x1`` holds the data tokens (int64, on GPU) and is used as the target.
        ``fm.sampler_0`` draws the source sample ``x0``; ``fm.interpolate``
        returns ``t`` and the interpolated tokens ``xt``. Their exact shapes are
        defined by ``fm`` and are not visible here.
        """
        B = self.B
        T = self.T
        buf = self.tokens[self.current_position : self.current_position+B*T]
        buf = torch.tensor(buf.astype(np.int32), dtype=torch.long)
        x1 = buf.view(B, T) # targets
        mask = None
        if self.condition:
            # first half of every row = data tokens, second half = mask
            x1, mask = x1.chunk(2,dim=1)
            mask = mask.bool().cuda()
        x1 = x1.cuda()
        x0 = fm.sampler_0(x1,self.vocab_size)
        t,xt = fm.interpolate(x0,x1,mask=mask)
        # advance current position and load next shard if necessary
        self.current_position += B * T * self.num_processes
        if self.current_position + (B * T * self.num_processes+1) > len(self.tokens):
            self.advance()
        return t.cuda(), xt, x1


# -----------------------------------------------------------------------------
# int main

class Hyperparameters:
    """Flat view of the run configuration used throughout the training script.

    Every value is copied from ``config`` when the class body runs. The
    ``_data`` / ``_train`` / ``_optim`` aliases only shorten the lookups and are
    deleted from the class namespace at the end of the body.
    """
    _data = config.data
    input_bin = _data.input_bin
    input_val_bin = _data.input_val_bin
    data_condition = _data.condition
    input_bin_pretrain = _data.input_bin_pretrain
    input_val_bin_pretrain = _data.input_val_bin_pretrain
    _train = config.training_config
    # the config value is expressed in billions of tokens
    num_tokens_to_train = _train.num_tokens_to_train * 10**9
    batch_size = _train.batch_size
    device_batch_size = _train.device_batch_size
    sequence_length = _data.sequence_length
    _optim = config.optimizer
    embed_learning_rate = _optim.embed_learning_rate
    muon_learning_rate = _optim.muon_learning_rate
    warmup_iters = _optim.warmup_iters
    warmdown_iters = _optim.warmdown_iters
    weight_decay = _optim.weight_decay
    val_loss_every = _train.val_loss_every
    save_every = _train.save_every
    project_name = _train.project_name
    run_name = _train.run_name
    checkpoint = _train.checkpoint
    pretrain = _train.pretrain
    del _data, _train, _optim
args = Hyperparameters()

# Where weights come from and which step counting starts at:
#   * `checkpoint` -> resume, reading the saved step from the checkpoint file;
#   * `pretrain`   -> load from that file but restart counting at step 0;
#   * neither      -> train from scratch.
ckpt_path, start_iter = None, 0
if args.checkpoint is not None:
    ckpt_path = args.checkpoint
    start_iter = torch.load(ckpt_path,map_location="cpu")["step"]
elif args.pretrain is not None:
    ckpt_path = args.pretrain

# Optimizer steps needed to consume num_tokens_to_train tokens
# (each step sees batch_size sequences of sequence_length tokens).
args.num_iterations = int(args.num_tokens_to_train / args.batch_size / args.sequence_length)

# Distributed setup: one process per GPU, launched by torchrun, which exports
# RANK / LOCAL_RANK / WORLD_SIZE to every process.
assert torch.cuda.is_available()
dist.init_process_group(backend='nccl')
ddp_rank, ddp_local_rank, ddp_world_size = (
    int(os.environ[var]) for var in ('RANK', 'LOCAL_RANK', 'WORLD_SIZE')
)
device = f'cuda:{ddp_local_rank}'
torch.cuda.set_device(device)
print(f"using device: {device}")
# rank 0 alone does logging, checkpointing etc.
master_process = ddp_rank == 0
# per-device micro-batch: B sequences of T tokens
B, T = args.device_batch_size, args.sequence_length
# validation always covers 200 micro-batches per rank
args.val_tokens = 200 * B * T * ddp_world_size
assert args.val_tokens % (B * T * ddp_world_size) == 0
val_steps = args.val_tokens // (B * T * ddp_world_size)
# gradient-accumulation steps needed to reach the global batch size (in sequences)
assert args.batch_size % (B * ddp_world_size) == 0
train_accumulation_steps = args.batch_size // (B * ddp_world_size)

# load tokens
# (_load_data_shard reads this module-level name to choose the token width)
num_vocab = config.model.vocab_size
# Tokens consumed by all ranks before start_iter, counted in unconditioned rows
# (steps x sequences per step x sequence length).
num_tokens_done = start_iter * args.batch_size * args.sequence_length
# A conditioned loader reads rows of 2*T tokens (data followed by a same-length
# mask), so by start_iter it has consumed twice as many tokens; scale its resume
# offset to match. The pretrain loaders below are unconditioned and use the raw count.
train_tokens_done = num_tokens_done * (2 if args.data_condition else 1)
train_loader = DistributedDataLoader(num_vocab, args.input_bin, B, T, ddp_rank, ddp_world_size,num_tokens_done=train_tokens_done,condition=args.data_condition)
val_loader = DistributedDataLoader(num_vocab, args.input_val_bin, B, T, ddp_rank, ddp_world_size,condition=args.data_condition)
if args.input_bin_pretrain is not None and args.input_val_bin_pretrain is not None:
    pretrain_data = True
    train_loader_pretrain = DistributedDataLoader(num_vocab, args.input_bin_pretrain, B, T, ddp_rank, ddp_world_size,num_tokens_done=num_tokens_done)
    val_loader_pretrain = DistributedDataLoader(num_vocab, args.input_val_bin_pretrain, B, T, ddp_rank, ddp_world_size)
else:
    pretrain_data = False
if master_process:
    print(f"Training DataLoader: total number of tokens: {train_loader.ntok_total} across {len(train_loader.files)} files")
    print(f"Validation DataLoader: total number of tokens: {val_loader.ntok_total} across {len(val_loader.files)} files")
# draw the first training micro-batch(es) consumed by the first training step
t, xt, y = train_loader.next_batch()
if pretrain_data:
    t_p, xt_p, y_p = train_loader_pretrain.next_batch()
# +
# Build the `GPT` class from module `models.<config.model_type>`.
model_type = load_class(f"models.{config.model_type}","GPT")
model = model_type(config.model)
if ckpt_path is not None:
    ckpt = torch.load(ckpt_path,map_location="cpu")["model"]
    # Checkpoints are saved from the torch.compile'd module, whose state-dict
    # keys carry an "_orig_mod." prefix; strip it to match the plain module.
    ckpt = {name.replace("_orig_mod.",""):value for name,value in ckpt.items()}
    # strict=False tolerates missing/extra keys; report them rather than ignore silently.
    incompatible = model.load_state_dict(ckpt,strict=False)
    if master_process:
        print("Model loaded from checkpoint")
        if incompatible.missing_keys:
            print(f"  missing keys: {incompatible.missing_keys}")
        if incompatible.unexpected_keys:
            print(f"  unexpected keys: {incompatible.unexpected_keys}")

model = model.cuda()
# The name `config` was rebound to the OmegaConf run config at the top of the
# script, so it no longer refers to torch._inductor.config; set the inductor flag
# through its full module path instead.
if hasattr(torch._inductor.config, "coordinate_descent_tuning"):
    torch._inductor.config.coordinate_descent_tuning = True # suggested by @Chillee
model = torch.compile(model)
# here we wrap model into DDP container
model = DDP(model, device_ids=[ddp_local_rank])
raw_model = model.module # DDP-unwrapped module (still the torch.compile wrapper)
ctx = torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16)
# -

# Weights & Biases logging runs on rank 0 only, and only when an API key is configured.
if master_process and config.training_config.wandb_key is not None:
    wandb.login(key=config.training_config.wandb_key)
    wandb.init(project=args.project_name, name=args.run_name)

# init the optimizer(s)
# Transformer-block parameters (plus `skip_weights`, when the model defines it)
# are split by rank: 2-D matrices go to Muon, all others to AdamW together with
# the token-embedding, time-embedding and lm_head parameters.
params = list(raw_model.transformer.h.parameters())
if hasattr(raw_model, "skip_weights"):
    params += [raw_model.skip_weights]
optimizer1 = torch.optim.AdamW(list(raw_model.transformer.wte.parameters())+
                               list(raw_model.t_embedder.parameters())+
                               list(raw_model.lm_head.parameters())+
                               [p for p in params if len(p.shape) != 2],
                               lr=args.embed_learning_rate, betas=(0.9, 0.95),
                               weight_decay=args.weight_decay,
                               fused=True
                              )
optimizer2 = Muon([p for p in params if len(p.shape) == 2], lr=args.muon_learning_rate, momentum=0.95,
                  rank=ddp_rank, world_size=ddp_world_size)
optimizers = [optimizer1, optimizer2]
if ckpt_path is not None:
    # Read the checkpoint once (it also holds the full model weights) instead of
    # once per optimizer. Missing or incompatible optimizer state is tolerated.
    # NOTE(review): this also runs for `pretrain` checkpoints, so a fine-tune may
    # inherit the pretrain optimizer state (including its learning rates) — confirm intended.
    saved_optimizer_states = torch.load(ckpt_path,map_location="cpu").get("optimizers", [])
    num_loaded = 0
    for i, opt in enumerate(optimizers):
        try:
            opt.load_state_dict(saved_optimizer_states[i])
            num_loaded += 1
        except Exception as exc:  # best effort: keep the freshly initialised state
            print(f"Optimizer {i} not found in checkpoint ({type(exc).__name__}: {exc})")
    del saved_optimizer_states
    if master_process:
        print(f"Optimizers loaded from checkpoint ({num_loaded}/{len(optimizers)})")
# learning rate decay scheduler (linear warmup and warmdown)
def get_lr(it):
    """Return the learning-rate multiplier for iteration ``it`` (``it <= args.num_iterations``).

    Schedule:
    - linear warmup to 1.0 over ``args.warmup_iters`` steps;
    - then a constant 1.0;
    - then a linear decay to 0.0 over the final ``args.warmdown_iters`` steps.

    A non-positive ``warmdown_iters`` disables the decay (the multiplier stays
    1.0) instead of dividing by zero at ``it == num_iterations``.
    """
    assert it <= args.num_iterations
    # 1) linear warmup for warmup_iters steps
    if it < args.warmup_iters:
        return (it+1) / args.warmup_iters
    # 2) constant lr for a while (for the rest of training when warmdown is disabled)
    if args.warmdown_iters <= 0 or it < args.num_iterations - args.warmdown_iters:
        return 1.0
    # 3) linear warmdown
    return (args.num_iterations - it) / args.warmdown_iters
# LambdaLR counts its own steps from 0 whenever it is constructed. Offset by
# start_iter so a run resumed from a checkpoint continues the warmup / constant /
# warmdown schedule at its saved step instead of restarting the warmup.
def _lr_lambda(sched_step):
    return get_lr(sched_step + start_iter)
schedulers = [torch.optim.lr_scheduler.LambdaLR(opt, _lr_lambda) for opt in optimizers]

# begin logging
if master_process:
    run_id = args.run_name
    logdir = 'logs/%s/' % run_id
    os.makedirs(logdir, exist_ok=True)

training_time_ms = 0
# start the clock
torch.cuda.synchronize()
t0 = time.time()
# begin training
# NOTE: the training loader is intentionally not rewound here. The first
# micro-batch (t, xt, y) was already drawn right after the loaders were built,
# so calling train_loader.reset() at this point would make the first
# accumulation step draw that same batch a second time.
for step in range(start_iter, args.num_iterations + 1):
    last_step = (step == args.num_iterations)
    # This effectively ignores timing first 10 steps, which are slower for weird reasons.
    # Alternately, and slightly more correctly in terms of benchmarking, we could do 10
    # steps with dummy data first, and then re-initialize the model and reset the loader.
    if step - start_iter == 10:
        training_time_ms = 0
        t0 = time.time()
    timed_steps = float('nan') if step - start_iter <= 11 else (step - 10 - start_iter) + 1 # <= 11 to avoid bug in val

    # once in a while evaluate the validation dataset
    if (last_step or (args.val_loss_every > 0 and step % args.val_loss_every == 0)):
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.time() - t0)
        # run validation batches
        model.eval()
        val_loader.reset()
        val_loss = 0.0
        for _ in range(val_steps):
            t_val, x_val, y_val = val_loader.next_batch()
            with ctx: # of course, we'd like to use no_grad() here too, but that creates a torch.compile error for some reason
                _, loss = model(t_val, x_val, y_val, return_logits=False)
                val_loss += loss.detach()
                del loss
        # average over ranks, then over the val_steps micro-batches
        dist.all_reduce(val_loss, op=dist.ReduceOp.AVG)
        val_loss /= val_steps
        # log val loss to the console (rank 0 only)
        if master_process:
            print(f'step:{step}/{args.num_iterations} val_loss:{val_loss:.4f} train_time:{training_time_ms:.0f}ms step_avg:{training_time_ms/(timed_steps-1):.2f}ms')

        # start the clock again
        torch.cuda.synchronize()
        t0 = time.time()

    if master_process and (last_step or (args.save_every > 0 and step % args.save_every == 0)):
        # stop the clock
        torch.cuda.synchronize()
        training_time_ms += 1000 * (time.time() - t0)
        # save the state of the training process
        log = dict(step=step, model=model.module.state_dict(), optimizers=[opt.state_dict() for opt in optimizers])
        torch.save(log, 'logs/%s/ckpt.pt' % (run_id,))
        # start the clock again
        torch.cuda.synchronize()
        t0 = time.time()

    # bit confusing: we want to make sure to eval on 0th iteration
    # but also after the very last iteration. so we loop for step <= num_iterations
    # instead of just < num_iterations (one extra due to <=), only to do
    # the validation/sampling one last time, and then we break right here as we're done.
    if last_step:
        break

    # --------------- TRAINING SECTION BEGIN -----------------
    model.train()
    for i in range(1, train_accumulation_steps+1):
        # forward pass
        with ctx:
            _, loss = model(t, xt, y, return_logits=False)
            if pretrain_data:
                _, loss_p = model(t_p, xt_p, y_p, return_logits=False)
                loss = 0.5 * loss + 0.5 * loss_p
            train_loss = loss.detach()
        # advance the dataset for the next batch
        t, xt, y = train_loader.next_batch()
        if pretrain_data:
            t_p, xt_p, y_p = train_loader_pretrain.next_batch()
        # backward pass
        if i < train_accumulation_steps:
            with model.no_sync(): # there's no need to sync gradients every accumulation step
                loss.backward()
        else:
            loss.backward() # just sync on the last step
    # turn the accumulated sum into a mean; parameters without a gradient are skipped
    for p in model.parameters():
        if p.grad is not None:
            p.grad /= train_accumulation_steps
    # step the optimizers and schedulers
    for opt, sched in zip(optimizers, schedulers):
        opt.step()
        sched.step()
    # null the gradients
    model.zero_grad(set_to_none=True)
    # --------------- TRAINING SECTION END -------------------
    # everything that follows now is just diagnostics, prints, logging, etc.

    #dist.all_reduce(train_loss, op=dist.ReduceOp.AVG) # all-reducing the training loss would be more correct in terms of logging, but slower
    if master_process:
        approx_time = training_time_ms + 1000 * (time.time() - t0)
        print(f"step:{step+1}/{args.num_iterations} train_loss:{train_loss.item():.4f} train_time:{approx_time:.0f}ms step_avg:{approx_time/timed_steps:.2f}ms")
        if step % 10 == 0 and config.training_config.wandb_key is not None:
            wandb.log({"loss":train_loss.item(),"log_loss":np.log(train_loss.item())})

if master_process:
    if config.training_config.wandb_key is not None:
        wandb.finish()
    # report peak GPU memory whether or not W&B logging was enabled
    print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")

# -------------------------------------------------------------------------
# clean up nice
dist.destroy_process_group()
