

####################################################################################################
# Imports.
####################################################################################################


import time

global_start_time = time.time()
import os
import socket
import json

from typing import TYPE_CHECKING, Any, Optional
import sys
import datetime
import shutil
from contextlib import nullcontext, contextmanager
import gc
import torch

# Check device health immediately after loading torch and standard libraries without loading cuda/hip:
# The device count is queried through torch's private amdsmi (HIP builds) or NVML (other builds) helper.
# NOTE(review): these are private torch APIs (leading underscore) — confirm they exist in the pinned version.
nvml_count = torch.cuda._device_count_amdsmi() if torch.version.hip else torch.cuda._device_count_nvml()
# Fail fast, naming the host, when the reported device count is below one.
if nvml_count < 1:
    raise ValueError(f"Node failure! Device manager init failed on {socket.gethostname()}")

import math
from transformers import AutoModelForCausalLM, AutoTokenizer, get_scheduler
from datasets import load_dataset, Dataset, load_from_disk

os.environ["WANDB_INIT_TIMEOUT"] = "1200"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import wandb

from lm_eval import evaluator
from lm_eval.models.huggingface import HFLM
from lm_eval.utils import make_table

USE_LOCAL_CODE = True
if USE_LOCAL_CODE:
    import modeling  # noqa

if TYPE_CHECKING:
    import torch.distributed
    import torch.version
    import torch._dynamo.config


from dataclasses import dataclass, field
from jsonargparse import CLI

import logging

# Only show errors from the lm_eval logger.
logging.getLogger("lm_eval").setLevel(logging.ERROR)

# Report how long the imports above took; printed only when SLURM_PROCID is 0 (or unset).
end_time = time.time()
if int(os.getenv("SLURM_PROCID", "0")) == 0:
    print(f"{time.ctime()[:-5]}: Time to load libraries: {end_time - global_start_time:.02f} seconds.")


@dataclass
class CLISettings:
    """Command-line configuration for the fine-tuning run.

    The dataclass is handed to `jsonargparse.CLI`, so every field becomes a command-line option.
    `startup` additionally attaches `token_id_col_name` and `world_size` to the instance at runtime.
    """

    run_name: str = "default-run"
    out_path: str = "outputs"
    # data
    dataset_location: str = "openai/gsm8k"
    model_name: str = "tomg-group-umd/huginn-0125"
    # q_col / a_col name the question and answer columns; q_col may also be "messages" or "text"
    dataset_args: dict[str, Any] = field(default_factory=lambda: dict(q_col="question", a_col="answer"))
    dataset_config: str = "main"
    dataset_split: str = "train"
    max_seq_length: int = 512
    max_samples: Optional[int] = None

    # impl
    micro_batch_size: int = 8
    compile: bool = True
    optimizer: str = "PagedAdamW8bit"
    seed: int = 0
    precision: str = "bf16-mixed"
    gradient_checkpointing: bool = True
    local_only: bool = False  # do not attempt to look for model changes online
    low_mem_ddp: bool = True  # side-step pytorch ddp for the dumber version

    # training
    max_steps: int = 0
    epochs: int = 1
    global_batch_size: int = 64
    optim_config: dict[str, Any] = field(
        default_factory=lambda: dict(lr=5e-7, weight_decay=0.0, betas=(0.9, 0.95), eps=1e-8)
    )
    # keys are the strings "warmup" and "cooldown"; values are fractions of the total number of steps
    scheduler_args: dict[str, Any] = field(default_factory=lambda: dict(warmup=0.1, cooldown=0.1))
    take_loss_over_all_tokens: bool = False  # for chat templated datasets default is to only supervise assistant tokens
    max_grad_norm: float = 1.0
    freeze_components: str = ""  # comma-separated list in string format
    # Optional: mess with sampling
    sampling_scheme: str = "poisson-lognormal"
    mean_recurrence: int = 32
    max_backprop_depth: int = 8

    # logging & eval
    use_wandb: bool = True
    log_interval: int = 4  # log every n-th microbatch step
    eval_task: str = "gsm8k"
    eval_interval: int = 100000000000000  # eval every n-th microbatch step using HF
    num_eval_examples: int = 400
    run_vllm_eval_after_training: bool = (
        False  # whether to run full vllm eval after training, requires reload from disk
    )
    save_final_checkpoint: bool = True

    def __post_init__(self):
        pass


@dataclass
class Message:
    """A single chat turn; instances are built in `startup` and passed to the tokenizer's chat template."""

    role: str  # speaker label; this file uses "system", "user" and "Huginn"
    content: str  # text of the turn


def is_main_process():
    """Return True on rank 0, or whenever torch.distributed has not been initialized."""
    return not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0


def seed_everything(seed):
    """Seed Python's `random`, NumPy and PyTorch (CUDA devices and the default generator) with `seed`."""
    import random  # noqa: PLC0415
    import numpy as np  # noqa: PLC0415

    # Same call order as before: python, numpy, torch-cuda, torch.
    for apply_seed in (random.seed, np.random.seed, torch.cuda.manual_seed_all, torch.manual_seed):
        apply_seed(seed)


def get_unwrapped_model(state):
    """Return the model held in `state`, stripped of its distributed wrapper when `state["distributed"]` is set."""
    if state["distributed"]:
        return state["model"].module
    return state["model"]


@torch.no_grad()
def _allreduce_chunk_stream(model, world_size=1, device=torch.device("cpu")):
    """Simple implementation with fixed MB chunks that can span gradients"""
    chunk_size = 1024 * 1024 * 64 // 4  # 64MB fp32 as in warmup

    chunk = torch.empty(chunk_size, dtype=torch.float32, device=device)
    chunk_index = 0
    param_refs = []

    for p in model.parameters():
        if p.grad is not None:
            grad_index = 0
            while grad_index < p.grad.numel():
                # Fold
                n = min(chunk_size - chunk_index, p.grad.numel() - grad_index)
                chunk[chunk_index : chunk_index + n] = p.grad.view(-1)[grad_index : grad_index + n]
                param_refs.append((p, grad_index, chunk_index, n))
                chunk_index += n
                grad_index += n
                if chunk_index == chunk_size:
                    # Average over ranks
                    torch.distributed.all_reduce(chunk)
                    chunk.div_(world_size)
                    # Unfold
                    for param, start_p, start_c, numel in param_refs:
                        param.grad.view(-1)[start_p : start_p + numel] = chunk[start_c : start_c + numel]
                    # Reset
                    chunk = torch.empty(chunk_size, dtype=torch.float32, device=device)
                    chunk_index = 0
                    param_refs = []
    # Handle final chunk:
    if chunk_index > 0:
        torch.distributed.all_reduce(chunk)  # keep consistent MB size
        chunk.div_(world_size)
        for param, start_p, start_c, numel in param_refs:
            param.grad.view(-1)[start_p : start_p + numel].copy_(chunk[start_c : start_c + numel])


class LowMemDDP(torch.nn.Module):
    """Minimal data-parallel wrapper that averages gradients only when explicitly asked to.

    The wrapped network is exposed as `.module`. Gradient averaging is not hooked into backward;
    the caller invokes `explicitly_sync_gradients()` after `backward()`, and the call is a no-op
    while inside a `no_sync()` context.

    Args:
        module: The network to wrap.
        device_ids: Optional list whose first entry is the device used for the all-reduce staging
            buffer. Falls back to the CPU device when omitted or empty.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, module, device_ids: Optional[list] = None, **kwargs):
        super().__init__()
        self.module = module
        self.backward_sync = True
        self.world_size = torch.distributed.get_world_size()
        # No shared mutable default: the fallback device is created per instance.
        self.local_device = device_ids[0] if device_ids else torch.device("cpu")

    @contextmanager
    def no_sync(self):
        """Disable gradient synchronization inside the `with` block."""
        # Restore the previous value rather than True so nested contexts stay disabled
        # until the outermost one exits.
        previous = self.backward_sync
        self.backward_sync = False
        try:
            yield
        finally:
            self.backward_sync = previous

    def forward(self, *inputs, **kwargs):
        """Forward all arguments to the wrapped module."""
        return self.module(*inputs, **kwargs)

    def explicitly_sync_gradients(self):
        """Average gradients across ranks unless synchronization is currently disabled."""
        # I originally wanted this to be a hook to mirror the DDP interface exactly, but all working solutions were terribly brittle
        if self.backward_sync:
            _allreduce_chunk_stream(self.module, world_size=self.world_size, device=self.local_device)


def maybe_freeze_model_components(model, freeze_components):
    """Allowed for now: wte,prelude,adapter,core_block,coda,ln_f,lm_head. Hopefully freezing lm_head == freezing wte

    Args:
        model: Object exposing `transformer` (supports `in` and `[]` by component name) and `lm_head`.
        freeze_components: Comma-separated component names. Whitespace around names and empty
            entries are ignored.

    Returns:
        The same `model`, with `requires_grad` set to False on every parameter of the named
        components. Names that match nothing are reported and skipped.
    """
    for raw_component in freeze_components.split(","):
        # Tolerate "wte, lm_head" style lists; previously " lm_head" was treated as unknown.
        component = raw_component.strip()
        if not component:
            continue
        if component in model.transformer:
            module = model.transformer[component]
        elif component == "lm_head":
            module = model.lm_head
        else:
            print(f"Skipping unknown component {component}")
            continue
        for name, param in module.named_parameters():
            param.requires_grad = False
            print(f"Parameter {name} frozen.")
    return model


####################################################################################################
# Main driver functions.
####################################################################################################
DEFAULT_SYS_PROMPT = "You are a helpful assistant that can assist users with mathematical reasoning."


def startup(cfg: CLISettings):
    """The main setup function for the training script.

    Sets up (in this order) seeding, the process group, the precision policy, model and tokenizer,
    the optional distributed wrapper and compilation, the tokenized dataset and dataloader, the
    optimizer, the learning-rate scheduler and, on rank 0, the wandb run.

    Args:
        cfg: Run configuration. `cfg.token_id_col_name` and `cfg.world_size` are attached to it here.

    Returns:
        A tuple `(state, local_device)`. `state` is a dict with the keys "model", "optimizer",
        "tokenizer", "dataloader", "distributed", "rank", "scheduler", "autocast_args" and "step".

    Raises:
        ValueError: If `cfg.precision` is neither "bf16-true" nor "bf16-mixed".
    """
    if cfg.seed > 0:
        seed_everything(cfg.seed)
    ##########    Comms              ##############
    rank = int(os.getenv("SLURM_PROCID", os.getenv("RANK", "0")))
    local_device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    # Distributed mode is keyed on the number of visible devices, not on the launcher's world size.
    if torch.cuda.device_count() > 1:
        distributed = True
        torch.distributed.init_process_group(
            backend="nccl",
            rank=rank,
            world_size=int(os.getenv("SLURM_NTASKS", os.getenv("WORLD_SIZE", -1))),
            device_id=local_device,
            timeout=datetime.timedelta(hours=2),
        )
        world_size = torch.distributed.get_world_size()
        print(f"Comms formed on rank {rank} with device {local_device} out of world size {world_size}.")
    else:
        world_size = 1
        distributed = False
    torch.cuda.set_device(local_device)

    if cfg.precision == "bf16-true":
        torch.set_default_dtype(torch.bfloat16)
        weight_dtype = torch.bfloat16
        autocast_args = {"device_type": "cuda", "enabled": False, "dtype": torch.bfloat16}
    elif cfg.precision == "bf16-mixed":
        torch.set_default_dtype(torch.float32)
        weight_dtype = torch.float32
        autocast_args = {"device_type": "cuda", "enabled": True, "dtype": torch.bfloat16, "cache_enabled": False}
    else:
        # Previously an unknown value fell through and failed later with an unbound-variable error.
        raise ValueError(f"Unsupported precision {cfg.precision!r}; expected 'bf16-true' or 'bf16-mixed'.")

    ########## Model and tokenizer ##############
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_name,
        trust_remote_code=not USE_LOCAL_CODE,
        torch_dtype=weight_dtype,
        low_cpu_mem_usage=True,
        device_map="cuda",
        local_files_only=cfg.local_only,
    )
    model = maybe_freeze_model_components(model, cfg.freeze_components)
    if cfg.gradient_checkpointing:
        model.gradient_checkpointing_enable()
    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, local_files_only=cfg.local_only)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    ##########  Distribute model   ##############
    if distributed:
        if cfg.low_mem_ddp:
            model = LowMemDDP(model, device_ids=[local_device])
        else:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[local_device],
                find_unused_parameters=True,
                gradient_as_bucket_view=True,  # static_graph=True?
            )
    if cfg.compile:
        model = torch.compile(model, fullgraph=False, dynamic=False, mode="max-autotune-no-cudagraphs")

    ##########     Data            ##############
    def format_and_tokenize_examples(examples):
        """Turn one batch of raw rows into padded token ids, a loss mask and an attention mask."""
        conversations = []
        for idx in range(len(examples[cfg.dataset_args["q_col"]])):
            if cfg.dataset_args["q_col"] == "messages":
                # Rows already hold a chat transcript.
                messages = examples["messages"][idx]
            elif cfg.dataset_args["q_col"] != "text":
                # Question/answer columns are wrapped into a three-turn chat.
                messages = [
                    Message(role="system", content=DEFAULT_SYS_PROMPT),
                    Message(role="user", content=examples[cfg.dataset_args["q_col"]][idx].strip()),
                    Message(role="Huginn", content=examples[cfg.dataset_args["a_col"]][idx].strip()),
                ]
            else:
                # Plain text: prepend BOS and tokenize without a chat template.
                messages = tokenizer.bos_token + examples[cfg.dataset_args["q_col"]][idx].strip()
            conversations.append(messages)

        # max_length is one longer than max_seq_length because inputs and labels are shifted by one in train().
        if cfg.dataset_args["q_col"] != "text":
            chat_encoding = tokenizer.apply_chat_template(
                conversations,
                tokenize=True,
                add_generation_prompt=False,
                return_assistant_tokens_mask=True,
                padding="max_length",
                max_length=cfg.max_seq_length + 1,
                return_tensors="pt",
                return_dict=True,
                truncation=True,
            )
            if cfg.take_loss_over_all_tokens:
                chat_encoding["assistant_masks"] = chat_encoding["attention_mask"]
        else:
            chat_encoding = tokenizer(
                conversations,
                padding="max_length",
                max_length=cfg.max_seq_length + 1,
                return_tensors="pt",
                truncation=True,
            )
            chat_encoding["assistant_masks"] = chat_encoding["attention_mask"].clone()

        return {
            "input_ids": chat_encoding["input_ids"],
            "mask": chat_encoding["assistant_masks"],
            "attention_mask": chat_encoding["attention_mask"],
        }

    cfg.token_id_col_name = "input_ids"  # type: ignore
    dataset_save_dir = f"{cfg.out_path}/{cfg.run_name}/dataset"
    if is_main_process():  # only do mapping on rank 0
        try:
            dataset: Dataset = load_dataset(cfg.dataset_location, cfg.dataset_config)[cfg.dataset_split]  # type: ignore
        except Exception as e:  # narrowed so KeyboardInterrupt/SystemExit are no longer swallowed
            print(e)
            # NOTE(review): the second positional argument of `load_from_disk` may not be a config name
            # in the installed `datasets` release, and no split is selected here — confirm this fallback.
            dataset: Dataset = load_from_disk(cfg.dataset_location, cfg.dataset_config)  # type: ignore

        if cfg.max_samples is not None:
            dataset = dataset.select(range(cfg.max_samples))

        if os.path.exists(dataset_save_dir):  # delete any old dataset
            shutil.rmtree(dataset_save_dir)

        tokenized_dataset = dataset.map(
            format_and_tokenize_examples,
            num_proc=16,
            remove_columns=dataset.column_names,
            batched=True,
            batch_size=1024,
        )

    if distributed:  # load the dataset to other ranks
        if is_main_process():
            tokenized_dataset.save_to_disk(dataset_save_dir)
        torch.distributed.barrier()
        tokenized_dataset = load_from_disk(dataset_save_dir)
        torch.distributed.barrier()

    if rank == 0:
        # Print one random processed example as a sanity check.
        idx = int(torch.randint(len(tokenized_dataset), (1,)))
        print(f"-----------------------------------Processed Data example idx {idx}:----------------------------")
        print(tokenized_dataset[idx])
        print(tokenizer.decode(tokenized_dataset[idx]["input_ids"], skip_special_tokens=False))
        print("--------------------------------------------------------------------------------------------")
    tokenized_dataset.set_format("pt")
    if distributed:
        sampler = torch.utils.data.DistributedSampler(
            tokenized_dataset,  # type: ignore
            shuffle=True,
            num_replicas=world_size,
            rank=rank,
            seed=cfg.seed,
        )
        dataloader = torch.utils.data.DataLoader(
            tokenized_dataset,  # type: ignore
            batch_size=cfg.micro_batch_size,
            sampler=sampler,
            pin_memory=True,
        )
    else:
        dataloader = torch.utils.data.DataLoader(
            tokenized_dataset,  # type: ignore
            batch_size=cfg.micro_batch_size,
            shuffle=True,
            pin_memory=True,
        )
    ##########     Optimizer       ##############
    if cfg.optimizer == "OffloadedAdamW":
        from torchao.optim import CPUOffloadOptimizer  # noqa: PLC0415 # optional

        optimizer = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, **cfg.optim_config)

    elif cfg.optimizer == "AdamW8bit":
        from torchao.optim import Adam8bit  # noqa: PLC0415 # optional

        optimizer = Adam8bit(model.parameters(), **cfg.optim_config)
    elif cfg.optimizer == "PagedAdamW8bit":
        from bitsandbytes.optim import PagedAdamW8bit  # noqa: PLC0415 # optional

        optimizer = PagedAdamW8bit(model.parameters(), **cfg.optim_config)
    elif cfg.optimizer == "SGD":
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=cfg.optim_config["lr"],
            weight_decay=cfg.optim_config["weight_decay"],
            nesterov=True,
            momentum=0.95,
        )
    else:
        # Any other optimizer name falls back to plain AdamW.
        optimizer = torch.optim.AdamW(model.parameters(), **cfg.optim_config)
    ##########     Scheduler       ##############
    # refactored into step = micro_batch_step because math is hard
    assert cfg.micro_batch_size <= cfg.global_batch_size
    accumulation_steps = max(1, int(cfg.global_batch_size / cfg.micro_batch_size / world_size))
    max_training_steps = cfg.epochs * len(dataloader)  # mbs steps
    max_training_steps = min(cfg.max_steps, max_training_steps) if cfg.max_steps else max_training_steps
    num_warmup_steps = math.ceil(cfg.scheduler_args["warmup"] * max_training_steps)  # type: ignore
    num_decay_steps = math.ceil(cfg.scheduler_args["cooldown"] * max_training_steps)  # type: ignore
    if rank == 0:
        print(
            f"Running with {accumulation_steps} accum steps from a total {max_training_steps} of data steps:"
            f"of which {num_warmup_steps} are warm up and {num_decay_steps} cool down. Peak LR: {cfg.optim_config['lr']}"
        )
    scheduler = get_scheduler(
        name="warmup_stable_decay",
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=max_training_steps,
        scheduler_specific_kwargs={"num_decay_steps": num_decay_steps},
    )
    ### Step Scheduler:

    state = {
        "model": model,
        "optimizer": optimizer,
        "tokenizer": tokenizer,
        "dataloader": dataloader,
        "distributed": distributed,
        "rank": rank,
        "scheduler": scheduler,
        "autocast_args": autocast_args,
        "step": 0,
    }

    cfg.world_size = world_size  # type: ignore
    if rank == 0:
        wandb.init(
            entity="jonasgeiping",
            project="huginn-post",
            name=cfg.run_name,
            config=cfg,  # type: ignore
            dir=cfg.out_path,
            mode="online" if cfg.use_wandb else "disabled",
        )
        num_trainable_parameters = sum((p.numel() for p in model.parameters() if p.requires_grad))
        wandb.summary["num_trainable_parameters"] = num_trainable_parameters
        wandb.summary["device"] = torch.cuda.get_device_name()
    return state, local_device


def recurrent_step_sampler(step, sampling_scheme, mean_recurrence=32, max_backprop_depth=8, lockstep=True):
    """Sample the number of recurrent steps without (`n`) and with (`k`) gradient for one training step.

    Sampling is deterministic in `step`: two CPU generators are seeded from it. With
    `lockstep=False` and an initialized process group, the seeds are additionally scaled by
    `rank + 1` so that ranks draw different values.

    Args:
        step: Step counter used to derive the generator seeds.
        sampling_scheme: One of "poisson-lognormal", "exponential", "simple", "simple-rand-k",
            "avi" or "fixed".
        mean_recurrence: Target total recurrence; `t = max(mean_recurrence - max_backprop_depth, 0)`.
        max_backprop_depth: Backprop depth `s` used to split the draw into `n` and `k`.
        lockstep: When True, all ranks use the same seeds.

    Returns:
        A tuple `(n, k)` of `torch.long` tensors.

    Raises:
        ValueError: If `sampling_scheme` is not one of the supported names.
    """
    # Returns tensors so torch has an easier time compiling
    seed_n = 514229 + step
    seed_k = 317811 + step
    if not lockstep and torch.distributed.is_initialized():
        rank_factor = torch.distributed.get_rank() + 1
        seed_n = seed_n * rank_factor
        seed_k = seed_k * rank_factor
    n_generator, k_generator = torch.Generator(device="cpu"), torch.Generator(device="cpu")
    n_generator.manual_seed(seed_n % (2**31 - 1))
    k_generator.manual_seed(seed_k % (2**31 - 1))

    t, s = max(mean_recurrence - max_backprop_depth, 0), max_backprop_depth
    if sampling_scheme == "poisson-lognormal":
        sigma = 0.5
        # mu is chosen so that the log-normal rate has mean t + s
        mu = math.log(t + s) - (sigma**2 / 2)
        rate = torch.zeros((1,)).log_normal_(mean=mu, std=sigma, generator=n_generator)
        p = torch.poisson(torch.tensor([rate], dtype=torch.float), generator=n_generator) + 1
        n = torch.clamp(p - s, min=0)
        k = torch.as_tensor(torch.minimum(torch.as_tensor(s), p))
    elif sampling_scheme == "exponential":
        sample = torch.zeros((1,)).exponential_(lambd=1.0 / (t + s), generator=n_generator)
        p = sample.floor().to(torch.long) + 1
        n = torch.clamp(p - s, min=0)
        k = torch.as_tensor(torch.minimum(torch.as_tensor(s), p))
    elif sampling_scheme == "simple":
        p = torch.randint(low=0, high=2 * t, size=(1,), generator=n_generator)
        n = torch.clamp(p - s, min=0)
        k = torch.as_tensor(torch.minimum(torch.as_tensor(s), p))
    elif sampling_scheme == "simple-rand-k":
        n = torch.randint(low=0, high=2 * t, size=(1,), generator=n_generator)
        k = torch.randint(low=1, high=2 * s + 1, size=(1,), generator=k_generator)
    elif sampling_scheme == "avi":
        n = torch.randint(low=0, high=t * 2, size=(1,), generator=n_generator)
        k = torch.randint(low=1, high=1 + min(t * 2 - int(n.item()), s * 2), size=(1,), generator=k_generator)
    elif sampling_scheme == "fixed":
        n, k = torch.tensor(t), torch.tensor(s)
    else:
        # Previously an unknown scheme surfaced as an UnboundLocalError on the return line.
        raise ValueError(f"Unknown sampling scheme {sampling_scheme!r}.")
    return n.to(dtype=torch.long), k.to(dtype=torch.long)


def train(state, device, cfg):
    """Run the training loop over `cfg.epochs` passes of `state["dataloader"]`.

    Every dataloader batch is one micro-batch step (`state["step"]`). Gradients are accumulated
    over `accumulation_steps` micro-batches before clipping and an optimizer update. The scheduler
    is stepped once per micro-batch. Rank 0 prints and logs to wandb every `cfg.log_interval`
    optimizer updates, and `validate` runs every `cfg.eval_interval` micro-batch steps.

    Args:
        state: Dict produced by `startup`; its "step" entry is advanced in place.
        device: Device that inputs, labels and the running loss are placed on.
        cfg: Run configuration; `world_size` and `token_id_col_name` must already be set on it.

    Returns:
        The same `state` dict, with the model switched to eval mode.
    """
    model, optimizer = state["model"], state["optimizer"]
    model.train()
    # Clamp to at least one step, matching startup(); a value of zero would crash the modulo below.
    accumulation_steps = max(1, cfg.global_batch_size // cfg.micro_batch_size // cfg.world_size)
    optimizer_step = 0
    step_time = time.time()
    tokens_in_step = 0
    running_loss = torch.as_tensor(0.0, dtype=torch.float32, device=device)
    # Samplers that expose set_epoch (e.g. DistributedSampler) reuse the same shuffle unless told the epoch.
    sampler = getattr(state["dataloader"], "sampler", None)

    for epoch in range(cfg.epochs):
        if hasattr(sampler, "set_epoch"):
            sampler.set_epoch(epoch)
        for inputs in state["dataloader"]:
            # Inputs drop the last token and labels drop the first: next-token prediction.
            input_ids = inputs[cfg.token_id_col_name][:, :-1].to(dtype=torch.long, device=device, non_blocking=True)
            # Need to take into account the assistant and attention if sequences are being padded
            mask = ~(inputs["mask"].bool() & inputs["attention_mask"].bool())

            # Positions that are not supervised get the label -100.
            labels = torch.where(mask[:, 1:], -100, inputs[cfg.token_id_col_name][:, 1:]).to(
                dtype=torch.long, device=device, non_blocking=True
            )
            tokens_in_step += input_ids.numel()
            state["step"] += 1  # microbatch_step
            is_accumulating = state["step"] % accumulation_steps != 0

            # The actual compute step of  Forward, loss, and backward computation:
            def tightly_scoped_fwd_bwd(model, input_ids, labels):
                num_steps_pair = recurrent_step_sampler(
                    state["step"], cfg.sampling_scheme, cfg.mean_recurrence, cfg.max_backprop_depth
                )
                # Skip gradient synchronization on all but the last micro-batch of an update.
                with model.no_sync() if is_accumulating and state["distributed"] else nullcontext():
                    with torch.autocast(**state["autocast_args"]):
                        outputs = model(input_ids, labels=labels, num_steps=num_steps_pair)
                    (outputs["loss"] / accumulation_steps).backward()
                    if state["distributed"] and cfg.low_mem_ddp:
                        # LowMemDDP has no backward hook; it is a no-op inside no_sync().
                        model.explicitly_sync_gradients()
                    return outputs["loss"].detach(), num_steps_pair

            loss, (n, k) = tightly_scoped_fwd_bwd(model, input_ids, labels)
            running_loss += loss / accumulation_steps
            if not is_accumulating:
                total_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=cfg.max_grad_norm, norm_type=2.0
                )
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)  # making room for validation seqs
                optimizer_step += 1
                if state["distributed"]:
                    torch.distributed.all_reduce(running_loss, op=torch.distributed.ReduceOp.AVG)
                if state["rank"] == 0 and optimizer_step % cfg.log_interval == 0:
                    time_interval = (time.time() - step_time) / accumulation_steps
                    tok_sec = tokens_in_step * cfg.world_size / (time.time() - step_time)
                    lr = optimizer.param_groups[0]["lr"]
                    max_alloc = torch.cuda.max_memory_allocated() / 1024**3
                    max_res = torch.cuda.max_memory_reserved() / 1024**3
                    print(
                        f"Epoch {epoch:2d} | Step: {state['step']:4d} | Updates: {optimizer_step:4d} | LR: {lr:2.2e} | "
                        f"Loss: {running_loss:2.4f} / log-ppl: {running_loss.exp():2.4f} | \n"
                        f"         | Time/step: {time_interval:2.4f} | Tok/sec: {tok_sec:9.2f}    | "
                        f"Grad-Norm: {total_norm.item():2.4f} \n"
                        f"         | Max mem allocated: {max_alloc:4.2f} GB | Max mem reserved: {max_res:4.2f} GB"
                    )
                    wandb.log(
                        {
                            "train/step": state["step"],
                            "train/update": optimizer_step,
                            "train/epoch": epoch,
                            "train/lr": lr,
                            "train/total_samples": state["step"] * cfg.micro_batch_size * cfg.world_size,
                            "train/tok_sec": tok_sec,
                            "train/time_per_step": time_interval,
                            "train/num_steps_no_grad": n,
                            "train/num_steps_with_grad": k,
                            "train/grad_norm": total_norm.item(),
                            "train/loss": running_loss.item(),
                            "train/log_ppl": running_loss.exp().item(),
                        },
                        step=state["step"],
                    )
                # Per-update accumulators are reset after every optimizer step.
                tokens_in_step = 0
                running_loss.zero_()

            state["scheduler"].step()  # always step to simplify tracking
            if state["step"] % cfg.eval_interval == 0:
                validate(state, state["step"], cfg, task=cfg.eval_task)

            if not is_accumulating:
                step_time = time.time()  # reset time here

            if cfg.max_steps and state["step"] >= cfg.max_steps:
                model.eval()
                return state

    model.eval()
    return state


####################################################################################################
# Distributed Live eval
####################################################################################################


class TorchDistributedDefaultAccelerator:
    """Using this to spoof HuggingFace accelerate which is unfortunately woven into lm-eval"""

    def __init__(self, rank, size):
        self.process_index = rank
        self.num_processes = size
        multi_process = self.num_processes > 1
        if multi_process:
            # More than one process requires a live process group.
            assert torch.distributed.is_initialized()
        self.distributed = multi_process
        self.is_local_main_process = True

    def gather(self, tensor):
        """Return `tensor` gathered from all processes, or `tensor` itself in the single-process case."""
        if not self.distributed:
            return tensor
        gathered = torch.empty(self.num_processes * tensor.numel(), dtype=tensor.dtype, device="cuda")
        torch.distributed.all_gather_into_tensor(gathered, tensor)
        trailing_shape = tensor.size()[1:]
        return gathered.view(-1, *trailing_shape)

    def wait_for_everyone(self):
        """No-op placeholder; returns None."""
        return None

    def unwrap_model(self, model):
        """Return `model` unchanged."""
        return model

    @property
    def device(self):
        """The generic "cuda" device."""
        return torch.device("cuda")  # picked up from default device


def validate(state, step: int, cfg, task="gsm8k"):
    """Evaluate the current model in-process with lm-eval and log the scores.

    The unwrapped model is put into eval mode, wrapped in `HFLM`, evaluated on `task` limited to
    `cfg.num_eval_examples` examples, and put back into train mode. On rank 0 the results are
    printed, appended to `{cfg.out_path}/{cfg.run_name}/eval.json` and logged to wandb.

    Args:
        state: Training state dict (uses "model", "tokenizer", "distributed", "rank", "autocast_args").
        step: Micro-batch step used as the result key and as the wandb step.
        cfg: Run configuration.
        task: lm-eval task name.
    """
    # eval on-the-fly
    unwrapped_model = get_unwrapped_model(state)
    unwrapped_model.eval()
    hflm_wrap = HFLM(
        pretrained=unwrapped_model,
        tokenizer=state["tokenizer"],
        batch_size=4,
    )
    if state["distributed"]:
        # Tell lm-eval about the ranks through a stand-in for the accelerate object.
        hflm_wrap._rank = state["rank"]
        hflm_wrap._world_size = cfg.world_size
        hflm_wrap.accelerator = TorchDistributedDefaultAccelerator(state["rank"], cfg.world_size)  # type: ignore # crime

    with torch.autocast(**state["autocast_args"]):  # otherwise eval runs in fp32 for bf16-mixed which takes years
        results = evaluator.simple_evaluate(
            model=hflm_wrap,
            tasks=[task],
            apply_chat_template=True,
            fewshot_as_multiturn=True,
            system_instruction=DEFAULT_SYS_PROMPT,
            limit=cfg.num_eval_examples,
            num_fewshot=0,
            gen_kwargs={"num_steps": cfg.mean_recurrence},
        )

    results_by_step = {}
    if state["rank"] == 0 and results is not None:
        print(make_table(results))
        if "groups" in results:
            print(make_table(results, "groups"))
        results_by_step[str(step)] = results["results"][task]

        os.makedirs(f"{cfg.out_path}/{cfg.run_name}", exist_ok=True)
        with open(f"{cfg.out_path}/{cfg.run_name}/eval.json", "a") as f:
            json.dump(results_by_step, f)
            # One record per line: appended objects are no longer glued together.
            f.write("\n")

        metrics = {f"eval/{k.replace(',', '_')}": v for k, v in results["results"][task].items()}
        metrics["eval/label"] = step
        wandb.log(metrics, step=step)

    unwrapped_model.train()


def validate_vllm(model_path, step: int, cfg, task="gsm8k", single_worker_workaround=False):
    """For now only as offline variant that evals a saved model checkpoint. Requires vllm==0.9.2

    The checkpoint at `model_path` is evaluated once per recurrence value in a fixed list, each
    time through lm-eval's "vllm" backend. On the main process the results are appended to
    `{cfg.out_path}/{cfg.run_name}/eval.json` and logged to wandb as scalars plus one table.

    Args:
        model_path: Directory of the saved checkpoint.
        step: Step used in result keys and as the wandb step.
        cfg: Run configuration.
        task: lm-eval task name.
        single_worker_workaround: When a process group is initialized, destroy it and exit every
            rank except rank 0 instead of synchronizing at a barrier.
    """
    if torch.distributed.is_initialized():
        if single_worker_workaround:
            rank_before_fall = torch.distributed.get_rank()
            torch.distributed.destroy_process_group()
            if rank_before_fall > 0:
                sys.exit(0)
        else:
            torch.distributed.barrier()
    # Release as much memory as possible before vllm starts.
    gc.collect()
    torch.cuda.empty_cache()
    max_alloc = torch.cuda.memory_allocated() / 1024**3
    max_res = torch.cuda.memory_reserved() / 1024**3
    fraction_vis = torch.cuda.get_per_process_memory_fraction()
    print(f"Mem allocated before vllm: {max_alloc:4.2f} GB | Mem reserved: {max_res:4.2f} GB | frac: {fraction_vis}")
    print(cfg.world_size)
    results_by_step = {}
    metrics = {}
    start_time = time.time()
    eval_range = [1, 4, 16, 32, 48, 64]
    eval_table_data = []  # for wandb

    for r in eval_range:
        results = evaluator.simple_evaluate(
            model="vllm",
            model_args=dict(
                pretrained=os.path.realpath(model_path),
                tokenizer=str(cfg.model_name),
                trust_remote_code=True,
                dtype="bfloat16",
                tensor_parallel_size=1,  # cant get this to work easily, exiting early on all others,
                gpu_memory_utilization=0.5,  # :( # cannot get HF to free all remaining memory on h100????
                data_parallel_size=1,
                max_model_len=1024,
                hf_overrides={"mean_recurrence": r},
            ),
            batch_size="auto",
            tasks=[task],
            apply_chat_template=True,
            fewshot_as_multiturn=True,
            system_instruction=DEFAULT_SYS_PROMPT,
            num_fewshot=0,
        )

        if is_main_process() and results is not None:
            print(make_table(results))
            if "groups" in results:
                print(make_table(results, "groups"))
            results_by_step[f"{step}_{r}"] = results["results"][task]
            row_data = {"step": step, "recurrence_r": r}
            # Add all metrics to the row
            for metric_name, value in results["results"][task].items():
                clean_name = metric_name.replace(",", "_")
                row_data[clean_name] = value
                # backward compatibility
                if r == 32:
                    metrics[f"eval/{clean_name}"] = value
                metrics[f"eval/scalar_{clean_name}_{r}"] = value
            eval_table_data.append(row_data)

    if is_main_process():
        # Save results to json as before
        os.makedirs(f"{cfg.out_path}/{cfg.run_name}", exist_ok=True)
        with open(f"{cfg.out_path}/{cfg.run_name}/eval.json", "a") as f:
            json.dump(results_by_step, f)
            # One record per line: appended objects are no longer glued together.
            f.write("\n")

        # Create and log the WandB table with all evaluation data
        if eval_table_data:  # skipped when no evaluation produced results (previously an IndexError)
            columns = list(eval_table_data[0].keys())
            data_rows = [[row[col] for col in columns] for row in eval_table_data]
            # Create table with explicit columns
            eval_table = wandb.Table(columns=columns, data=data_rows)
            metrics[f"eval/tabled_acc_{step}"] = eval_table

        # Log the table and other metrics
        metrics["eval/time"] = time.time() - start_time
        metrics["eval/label"] = step

        wandb.log(metrics, step=step)


def save_checkpoint(state, path):
    """Write the unwrapped model and then the tokenizer to `path` via their `save_pretrained` methods."""
    bare_model = get_unwrapped_model(state)
    bare_model.save_pretrained(path)
    tokenizer = state["tokenizer"]
    tokenizer.save_pretrained(path)


####################################################################################################
# Main control loop
####################################################################################################


def main():
    """Encapsulates main scope away from import calls.

    Parses the command line into `CLISettings`, prints a system banner, sets the torch backend
    flags, runs `startup` and `train`, optionally saves the final checkpoint and runs the vllm
    evaluation, and finally prints timing/memory statistics and cleans up on the main process.

    Returns:
        The run name (`cfg.run_name`), which `guarded_main` prints in its completion message.
    """

    # Configuration loader
    cfg: CLISettings = CLI(CLISettings)

    # Print system setup hello on all devices
    print("--------------------------------------------------------------------")
    print(f"------------------ Launching run {cfg.run_name}--------------------")
    print("--------------------------------------------------------------------")
    print("--------------------------------------------------------------------")
    print(f"Platform: {sys.platform}, Python: {sys.version.split(' (')[0]}, PyTorch: {torch.__version__}")
    print(f"CPU threads: {torch.get_num_threads()}, GPUs: {torch.cuda.device_count()} on {socket.gethostname()}.")
    driver = f"HIP/ROCM {torch.version.hip}" if torch.version.hip else f"CUDA: {torch.version.cuda}"
    print(f"GPU : {torch.cuda.get_device_name()}. {driver}.")

    # set flags
    torch.set_float32_matmul_precision("high")
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True  # Should be true anyway
    torch._dynamo.config.optimize_ddp = "python_reducer"
    torch._dynamo.config.compiled_autograd = False

    train_time = time.time()

    state, device = startup(cfg)
    # Captured now because the process group may be torn down during the vllm evaluation.
    main_process_flag = is_main_process()
    print("Startup completed")
    state = train(state, device, cfg)
    print("Training process complete. Offloading model.")

    # VLLM eval, a bit messy to orchestrate
    if cfg.save_final_checkpoint or cfg.run_vllm_eval_after_training:
        if main_process_flag:
            save_checkpoint(state, f"{cfg.out_path}/{cfg.run_name}/final_checkpoint")
        try:
            if cfg.run_vllm_eval_after_training:
                last_step = state["step"]
                del state  # drop references to the model before vllm loads the checkpoint
                validate_vllm(f"{cfg.out_path}/{cfg.run_name}/final_checkpoint", last_step, cfg, task=cfg.eval_task)
        finally:
            # The checkpoint was only needed for the evaluation; remove it if it should not be kept.
            if not cfg.save_final_checkpoint and main_process_flag:
                shutil.rmtree(f"{cfg.out_path}/{cfg.run_name}/final_checkpoint")

    # Now exit
    if main_process_flag:
        print("--------------------------------------------------------------------")
        print(f"Training time: {str(datetime.timedelta(seconds=time.time() - train_time))} ")
        max_alloc = f"{torch.cuda.max_memory_allocated(device) / float(1024**3):,.3f} GB"
        max_reserved = f"{torch.cuda.max_memory_reserved(device) / float(1024**3):,.3f} GB"
        print(f"Max. Mem allocated: {max_alloc}. Max. Mem reserved: {max_reserved}.")
        print("--------------------------------------------------------------------")
        wandb.finish()
        dataset_save_dir = f"{cfg.out_path}/{cfg.run_name}/dataset"
        if os.path.exists(dataset_save_dir):
            shutil.rmtree(dataset_save_dir)

    # The caller prints this value; previously None was returned and printed as "Run None".
    return cfg.run_name


def shutdown():
    """Destroy the process group if one is initialized, then print the total wall time and a farewell line."""
    process_group_alive = torch.distributed.is_initialized()
    if process_group_alive:
        torch.distributed.destroy_process_group()
    # Elapsed time is measured from the module-level start timestamp.
    elapsed = datetime.timedelta(seconds=time.time() - global_start_time)
    print(f"---------Total time: {str(elapsed)} ---------")
    print("-----------------Shutdown complete.--------------------------")


def guarded_main():
    """Run `main`, report whether it succeeded, re-raise any failure, and always call `shutdown`."""
    divider = "--------------------------------------------------------------------"
    try:
        finished_run = main()
        print(divider)
        print(f"Run {finished_run} finished without error.")
    except BaseException:
        # Report and propagate everything, including interrupts and exits.
        print(divider)
        print("Run finished with errors.")
        raise
    finally:
        shutdown()  # guarantee NCCL deconstruction


if __name__ == "__main__":
    guarded_main()
