import math
import os
import sys
import time
import logging
from itertools import chain

import numpy as np
import random
import json
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tensorboardX import SummaryWriter
from tqdm import tqdm

from model_config import ModelConfig
from tasks import TaskRegistry
from train_arguments import get_argument_parser
from turing.logger import Logger
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear, warmup_linear_decay_exp, warmup_exp_decay_exp, cosine_poly_warmup_decay
from utils import is_time_to_exit, master_process

from dataset_utils import ShardedDatasetWrapper, create_dataloader

from clearml import Task
import deepspeed

# Module-level training state. train() mutates these through ``global``
# statements; update_learning_rate() and report_lamb_coefficients() read them.
# Number of optimizer (gradient-accumulation-boundary) steps taken so far.
global_step = 0
# Training samples consumed so far, counted across all ranks.
global_data_samples = 0
# Global step restored from a checkpoint; the LR schedule and the early-exit
# check measure progress relative to this value.
last_global_step_from_restore = 0
# Wall-clock seconds accumulated since the last throughput report.
all_step_time = 0.0

# Sub-directory (under the output dir) that holds TensorBoard event files.
SUMMARY_WRITER_DIR_NAME = 'runs'


def get_sample_writer(name, base=".."):
    """Create a TensorBoard ``SummaryWriter`` writing to ``<base>/runs/<name>``.

    Args:
        name: run name, used as the leaf directory.
        base: parent directory that contains the ``runs`` folder.
    """
    target_dir = os.path.join(base, SUMMARY_WRITER_DIR_NAME, name)
    return SummaryWriter(log_dir=target_dir)


def checkpoint_model(PATH, ckpt_id, model, epoch, last_global_step,
                     last_global_data_samples, **kwargs):
    """Utility function for checkpointing model + optimizer dictionaries
       The main purpose for this is to be able to resume training from that instant again
    """
    checkpoint_state_dict = {
        'epoch': epoch,
        'last_global_step': last_global_step,
        'last_global_data_samples': last_global_data_samples
    }
    # Add extra kwargs too
    checkpoint_state_dict.update(kwargs)

    success = model.save_checkpoint(PATH, ckpt_id,
                                            checkpoint_state_dict)
    status_msg = 'checkpointing: PATH={}, ckpt_id={}'.format(PATH, ckpt_id)
    if success:
        logging.info(f"Success {status_msg}")
    else:
        logging.warning(f"Failure {status_msg}")
    return


def load_training_checkpoint(args, model, PATH, ckpt_id):
    """Restore a checkpoint saved by ``checkpoint_model`` and return its progress counters.

    Loading goes through ``model.load_checkpoint`` with non-strict module
    loading. When ``args.load_only_weights`` is set, only module weights are
    loaded and the returned counters are reset to zero, which is a fresh run
    from pretrained weights. When ``args.resize_posit_embeds`` is set,
    ``model.resize_learned_pos_embeddings`` is passed as ``custom_load_fn``
    so learned positional embeddings can be adapted to a different max
    sequence length. This requires ``args.load_only_weights``.

    Returns:
        Tuple ``(epoch, last_global_step, last_global_data_samples)``.

    Raises:
        ValueError: ``resize_posit_embeds`` is set without ``load_only_weights``.
        FileNotFoundError: ``model.load_checkpoint`` returned no client state
            (e.g. nothing could be loaded from ``PATH``/``ckpt_id``).
    """
    # Workaround for learned positional embeddings if checkpoint and current
    # model's max_seq_lengths mismatch.
    if args.resize_posit_embeds:
        if not args.load_only_weights:
            raise ValueError("If you want to resize positional embeddings "
                             "when loading from checkpoint, you have to set "
                             "args.load_only_weights=True.")
        custom_load_fn = model.resize_learned_pos_embeddings
    else:
        custom_load_fn = None

    _, checkpoint_state_dict = model.load_checkpoint(
        PATH, ckpt_id,
        load_module_strict=False,
        load_module_only=args.load_only_weights,
        custom_load_fn=custom_load_fn
    )
    # Fail loudly instead of with an opaque "NoneType is not subscriptable".
    if checkpoint_state_dict is None:
        raise FileNotFoundError(
            f"No checkpoint could be loaded from PATH={PATH}, ckpt_id={ckpt_id}")

    if args.load_only_weights:
        # Weights-only restore restarts the training counters; the progress
        # keys are not required (the checkpoint may come from another run).
        return (0, 0, 0)

    return (checkpoint_state_dict['epoch'],
            checkpoint_state_dict['last_global_step'],
            checkpoint_state_dict['last_global_data_samples'])


def get_dataloader(args, dataset: Dataset):
    """Return a generator over evaluation batches drawn from ``dataset``.

    Single-process runs (``args.local_rank == -1``) use a ``RandomSampler``;
    distributed runs use a ``DistributedSampler``. The batch size comes from
    ``args.eval_bs`` and the worker count from
    ``args.config['training']['num_workers']``.
    """
    single_process = args.local_rank == -1
    sampler = RandomSampler(dataset) if single_process else DistributedSampler(dataset)
    loader = DataLoader(
        dataset,
        batch_size=args.eval_bs,
        sampler=sampler,
        num_workers=args.config['training']['num_workers'],
    )
    return (batch for batch in loader)


def pretrain_validation(args, dataset, series_name, index, model):
    """Evaluate ``model`` on ``dataset`` with the task's ``eval_func``.

    The evaluation micro batch size is
    ``train_micro_batch_size_per_gpu * eval_bs_ratio``. It is stored on
    ``args.eval_bs``, where ``get_dataloader`` reads it. A
    ``max_validation_samples`` of -1 means "the whole dataset". The model is
    switched to eval mode before evaluation.
    """
    num_layers = args.config["model_config"]["num_hidden_layers"]
    logger = args.logger
    eval_bs = args.train_micro_batch_size_per_gpu * args.eval_bs_ratio

    sample_limit = args.max_validation_samples
    if sample_limit == -1:
        sample_limit = len(dataset)

    # get_dataloader() picks the batch size up from args.
    args.eval_bs = eval_bs
    logger.info(
        f"Validation micro batch size: {eval_bs}")

    if args.dense_attention:
        update_weights_scalers(model, num_layers)
    model.eval()

    batches = get_dataloader(args, dataset)
    run_eval = args.task.eval_func
    run_eval(batches, model, sample_limit, series_name, index, args)

def train(args,
          index,
          model,
          optimizer,
          pretrain_dataset_provider,
          finetune=False):
    """Run one training epoch (one pass over shard/dataset ``index``), then evaluate.

    Iterates the data once, calling the DeepSpeed engine ``model`` for the
    forward pass, ``model.backward`` and ``model.step`` on every micro-batch.
    On gradient-accumulation boundaries the learning rate is recomputed with
    ``update_learning_rate`` and step metrics are reported periodically.
    After the loop, validation runs on the train / validation / test datasets
    depending on the ``args`` flags.

    Side effects: updates the module-level globals ``global_step``,
    ``global_data_samples`` and ``all_step_time``. The model is left in eval
    mode if any validation ran (``pretrain_validation`` calls ``model.eval()``).

    Args:
        args: run namespace (config, logger, DeepSpeed-derived settings, flags).
        index: zero-based epoch index. It also selects the shard when
            ``args.use_sharded_dataset`` is set.
        model: DeepSpeed engine wrapping the network.
        optimizer: optimizer returned by DeepSpeed. With bf16, its inner
            ``.optimizer`` is the one whose param groups get the LR.
        pretrain_dataset_provider: sharded dataset provider (when
            ``args.use_sharded_dataset``) or a plain ``Dataset``.
        finetune: not used by this function.
    """
    global global_step
    global global_data_samples
    global last_global_step_from_restore
    global all_step_time

    if args.use_sharded_dataset:
        # print dataset files according to their order.
        pretrain_dataset_provider.dataset_order_info()
        dataset = pretrain_dataset_provider.get_shard(index)
        train_sampler = RandomSampler(dataset)
        worker_init = pretrain_dataset_provider.worker_init
    else:
        dataset = pretrain_dataset_provider
        worker_init = None
        if args.local_rank == -1:
            train_sampler = RandomSampler(pretrain_dataset_provider)
        else:
            train_sampler = DistributedSampler(pretrain_dataset_provider)
            # Give the distributed shuffle a different seed each epoch.
            train_sampler.set_epoch(index + 1)

    dataset_iterator, total_length = create_dataloader(
        train_data=dataset,
        num_workers=args.config['training']['num_workers'],
        train_batch_size=args.train_micro_batch_size_per_gpu,
        data_sampler=train_sampler, worker_init=worker_init)

    current_data_sample_count = global_data_samples
    rank = dist.get_rank()
    num_layers = args.config["model_config"]["num_hidden_layers"]

    config = args.config
    logger = args.logger
    logger.info(
        f'worker-{dist.get_rank()}: begin epoch {index+1} '
        f'current_sample_count {current_data_sample_count} '
        f'shard_length {total_length} global_data_samples {global_data_samples}'
    )

    model.train()

    epoch_step = 0
    rounds = args.throughput_logging_freq
    step_counts = 0
    # Initial LR for every param group; update_learning_rate() replaces it
    # at the first gradient-accumulation boundary.
    lr_this_step = config["training"]["learning_rate"]
    # With bf16 the optimizer object wraps the underlying optimizer
    # (``.optimizer``); LR updates and state inspection use the inner one.
    inner_optimizer = optimizer if not args.bf16 else optimizer.optimizer
    if args.dense_attention:
        update_weights_scalers(model, num_layers)

    # NOTE(review): the bare triple-quoted strings in this function are
    # disabled experiment code kept as no-op expressions, not documentation.
    """
    current_activation = model.current_activation.item()
    if index != 0 and index % 1 == 0:
        current_activation += 1
        for tensor, state in optimizer.optimizer.state.items():
            optimizer.optimizer.state[tensor]['step'] = 0
            optimizer.optimizer.state[tensor]['exp_avg'] = torch.zeros_like(tensor.data)
            optimizer.optimizer.state[tensor]['exp_avg_sq'] = torch.zeros_like(tensor.data)
    print(f"rank {rank}, index {index}, current_activation {current_activation}")

    for group in optimizer.param_groups:
        if group['name'] == f'layer_{current_activation}':
            group['lr'] = lr_this_step
    if current_activation >= args.config["model_config"]["num_hidden_layers"]:
    """

    for group in optimizer.param_groups:
        group['lr'] = lr_this_step
        #if group['name'] != 'others_with_no_wd': group['weight_decay'] = args.config["training"]["weight_decay"]
    """
    default_norm = 0.03125
    for i in range(num_layers):
        if index % 1 != 0: break
        layer = model.bert.encoder.layer[i].attention
        with torch.no_grad():
            layer.queries.data.mul_(default_norm / layer.queries.abs().max().item())
            layer.norm_ratio_queries = layer.default_norm_queries / default_norm
            layer.combiners.data.mul_(default_norm / layer.combiners.abs().max().item())
            layer.norm_ratio_combiners = layer.default_norm_combiners / default_norm
    if args.fp16:
        optimizer.refresh_fp32_params()
    """
    #for i in range(num_layers):
        # model.bert.encoder.layer[i].attention.scaler.requires_grad = False
        #model.bert.encoder.layer[i].activation.scaler.requires_grad = False

    for _, batch in enumerate(tqdm(dataset_iterator, smoothing=1)):
        try:
            step_start = time.time()
            #batch = pretrain_dataset_provider.get_batch(batch_index)
            batch = {name: t.to(args.device) for name, t in batch.items()}  # Move to GPU
            # Calculate forward pass
            loss = model(**batch)

            unscaled_loss = loss.item()
            #print(f"loss {loss}, rank {rank}")

            # Samples consumed across all ranks: micro-batch per GPU x world size.
            current_data_sample_count += (args.train_micro_batch_size_per_gpu *
                                          dist.get_world_size())

            # Prefetch training data
            #pretrain_dataset_provider.prefetch_batch()

            #if not np.isfinite(unscaled_loss): continue
            model.backward(loss)
            '''
            if not np.isfinite(unscaled_loss):
                report_model_activations(args, model,
                                         batch, global_step,
                                         reg_losses=args.regularized_activations,
                                         reg_coeff=args.reg_coeff)
                if args.log_problematic_weights:
                    report_model_weights(args, model, global_step)
            '''

            loss = None
            del loss
            #for name, param in model.named_parameters():
            #    if param.grad is not None: print(f"Grad Extremums", name, param.grad.min(), param.grad.max())

            # Optimizer-step bookkeeping (LR update, reporting, counters) only
            # happens on gradient-accumulation boundaries.
            if model.is_gradient_accumulation_boundary():

                #torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=2.**15)
                #for name, param in model.named_parameters():
                #    if param.grad is not None: print(f"Grad Extremums", name, param.grad.min(), param.grad.max())
                lr_this_step = update_learning_rate(
                    args, config, global_step, inner_optimizer)
                #"""
                if (global_step + 1) % args.throughput_logging_freq == 0 and master_process(args):
                    report_step_metrics(args, lr_this_step, unscaled_loss,
                                        global_step, current_data_sample_count)
                # if epoch_step % 16 == 0:
                    # For multi-node
                    # if dist.get_world_size() > 1:
                    #    dist.broadcast_object_list(list(model.parameters()), src=0)
                    # For fp16 training
                    # refresh_fp32_params(optimizer)
                    # For bf16 training
                    # optimizer._restore_from_bit16_weights()
                # NOTE(review): on the first step of every print_steps-th epoch
                # this prints every parameter's gradient extrema and the full
                # exp_avg / exp_avg_sq state tensors on every rank; this can be
                # extremely verbose.
                if epoch_step == 0 and index % args.print_steps == 0:
                    for name, param in model.named_parameters():
                        if param.grad is not None: print(f"Grad Extremums", name, param.grad.min(), param.grad.max())
                    for tensor, state in inner_optimizer.state.items():
                        print("step", inner_optimizer.state[tensor]['step'])
                        print("exp_avg", inner_optimizer.state[tensor]['exp_avg'])
                        print("exp_avg_sq", inner_optimizer.state[tensor]['exp_avg_sq'])
                    logger.info(
                        f"Logging model weights and activations distribution "
                        f"at the start of of epoch: {index}, step: {epoch_step}")
                    report_model_activations(args, model,
                                             batch, global_step)
                    report_model_weights(args, model, global_step)
                """
                # Check that optimizer moving averages are finite and reset them all otherwise
                print(f"Started checking optimizer states on {rank}")
                for tensor, state in optimizer.optimizer.state.items():
                    if not torch.all(torch.isfinite(optimizer.optimizer.state[tensor]['exp_avg_sq'])):
                        model.reset_adam[0] = 1
                        break
                dist.all_reduce(model.reset_adam, op=dist.ReduceOp.MAX)
                if model.reset_adam.item():
                    for tensor, state in optimizer.optimizer.state.items():
                        optimizer.optimizer.state[tensor]['step'] = 0
                        optimizer.optimizer.state[tensor]['exp_avg'].zero_()
                        optimizer.optimizer.state[tensor]['exp_avg_sq'].zero_()
                print(f"checked optimizer states on {rank}")
                """
                # for name, param in model.named_parameters():
                #     if param.grad is not None: print(f"Parameter Extremums", name, param.grad.min(), param.grad.max())
                model.step()
                if args.dense_attention:
                    update_weights_scalers(model, num_layers)
                #print(f"Finished optimization step on {rank}")
                """
                try:
                    print(f"Started optimization step on {rank}")
                    for name, param in model.named_parameters():
                        if param.grad is not None: print(f"Parameter Extremums", name, param.grad.min(), param.grad.max())
                        #if param.grad is not None: print(f"Parameter Grad", name, param.grad)
                    model.step()
                    print(f"Finished optimization step on {rank}")
                except AssertionError as ex:
                    print("Error", ex)
                    for name, param in model.named_parameters():
                        if param.grad is not None: print(f"Parameter Extremums", name, param.grad.min(), param.grad.max())
                        print(f"Parameter Grad", name, param.grad)
                    report_model_activations(args, model,
                                             batch, global_step,

                                             reg_losses=args.regularized_activations,
                                             reg_coeff=args.reg_coeff)
                    report_model_weights(args, model, global_step)
                    import traceback
                    traceback.print_exc()
                """

                report_lamb_coefficients(args, optimizer)
                global_step += 1
                epoch_step += 1
            else:
                # Call DeepSpeed engine step on micro steps
                model.step()

        # NOTE(review): this only catches StopIteration raised inside the
        # loop body; exhaustion of dataset_iterator itself ends the for-loop.
        except StopIteration:
            continue

        # Optimizer steps taken since the last checkpoint restore.
        current_global_step = global_step - last_global_step_from_restore
        if is_time_to_exit(args=args,
                           epoch_steps=epoch_step,
                           global_steps=current_global_step):
            print(
                f'Warning: Early epoch termination due to max steps limit, epoch step ={epoch_step}, global step = {current_global_step}, epoch = {index+1}'
            )
            break
        step_time = time.time() - step_start
        all_step_time += step_time
        # Rank 0 reports throughput over the last ``rounds`` optimizer steps.
        # NOTE(review): this check runs after model.step(), so
        # is_gradient_accumulation_boundary() may already refer to the next
        # micro-step; confirm this matches the intended reporting cadence.
        if global_step % rounds == 0 and global_step != 0 and model.is_gradient_accumulation_boundary(
        ) and dist.get_rank() == 0:
            one_step_bs = args.train_micro_batch_size_per_gpu * args.gradient_accumulation_steps * dist.get_world_size(
            ) * rounds
            print(' At step {}, the throughput is {:2f} Samples/s'.format(
                global_step * args.gradient_accumulation_steps,
                one_step_bs / all_step_time),
                  flush=True)
            all_step_time = 0.0

    #pretrain_dataset_provider.release_shard(index)

    # Carry the sample count over to the next epoch.
    global_data_samples = current_data_sample_count

    if  args.eval_train_data and args.train_dataset is not None:
            pretrain_validation(args, args.train_dataset, "Train", index, model)
    if not args.no_eval_val_data:
        pretrain_validation(args, args.eval_dataset, "Validation", index, model)
    if args.eval_test_data:
        pretrain_validation(args, args.test_dataset, "Test", index, model)

def update_weights_scalers(model, num_layers):
    """Re-adjust the FFN norm ratios of a DenseAttention model.

    Calls ``adjust_norm_ratios()`` on the ``ffn`` sub-module of each of the
    first ``num_layers`` entries of ``model.bert.encoder.layer``, in order.
    Only the FFN sub-modules are adjusted; the attention sub-modules are
    left untouched.
    """
    for layer_idx in range(num_layers):
        layer_ffn = model.bert.encoder.layer[layer_idx].ffn
        layer_ffn.adjust_norm_ratios()


def update_learning_rate(args, config, current_global_step, optimizer):
    """Compute this step's learning rate and write it into every param group.

    The step passed to the schedule is counted from
    ``last_global_step_from_restore``. ``config["training"]["lr_schedule"]``
    selects the schedule:

    * ``"EE"``: ``learning_rate * warmup_exp_decay_exp(...)``;
    * ``"cosine"``: ``learning_rate * cosine_poly_warmup_decay(...)``;
    * anything else (including ``"constant"``): ``learning_rate`` as is.

    Returns:
        The learning rate applied to ``optimizer``.
    """
    global last_global_step_from_restore

    schedule_step = current_global_step - last_global_step_from_restore
    training_cfg = config["training"]
    schedule = training_cfg["lr_schedule"]
    base_lr = training_cfg["learning_rate"]

    if schedule == "EE":
        new_lr = base_lr * warmup_exp_decay_exp(
            schedule_step,
            training_cfg["decay_rate"],
            training_cfg["decay_step"],
            training_cfg["one_cycle_steps"],
            training_cfg["warmup_proportion"])
    elif schedule == "cosine":
        new_lr = base_lr * cosine_poly_warmup_decay(
            schedule_step, **training_cfg["lr_scheduler_params"])
    else:
        new_lr = base_lr

    for group in optimizer.param_groups:
        group['lr'] = new_lr
    return new_lr


def report_step_metrics(args, lr, loss, step, data_sample_count):
    """Send LR and loss to ``args.tracker_logger`` and print a progress line.

    Each metric is reported twice: once against the optimizer step and once
    against the number of samples seen.
    """
    tracker = args.tracker_logger
    scalar_reports = (
        ('Train Steps: lr', 'lr', lr, step),
        ('Train Steps: loss', 'loss', loss, step),
        ('Train Samples: lr', 'lr', lr, data_sample_count),
        ('Train Samples: loss', 'loss', loss, data_sample_count),
    )
    for title, series, value, iteration in scalar_reports:
        tracker.report_scalar(title=title, series=series,
                              value=value, iteration=iteration)

    print(f'bing_bert_progress: step={step + 1}, loss={loss}, lr={lr}, '
          f'sample_count={data_sample_count}')


def add_normalizer_preforward_hooks(args, model):
    """Attach forward-pre-hooks that re-normalize learned ``scaler`` values.

    A hook is registered on every module that has a ``scaler`` attribute and
    whose qualified name either contains both ``"layer"`` and ``"activation"``
    or contains ``"LayerNorm"``. Gradient updates to that ``scaler`` are
    disabled (``requires_grad = False``). Instead, before each forward call,
    if ``max|input[0]| * scaler[0]`` exceeds ``bound`` (2), the scaler is
    reset to ``2 ** -floor(log2(max|input[0]|))``. This maps the scaled
    maximum back into ``[1, 2)``.

    Args:
        args: unused; kept for interface compatibility.
        model: ``torch.nn.Module`` to instrument.

    Returns:
        List of hook handles; call ``.remove()`` on each to detach.
    """
    bound = 2

    def make_hook(module_name):
        def hook(module: torch.nn.Module, inputs: tuple):
            max_val = inputs[0].detach().abs().max()
            scaler_val = module.scaler[0].detach()
            if max_val * scaler_val > bound:
                # floor(log_bound(max_val)); log2 keeps powers of two exact
                # (math.log(x, 2) can land just below an integer).
                exponent = math.floor(math.log2(float(max_val)) / math.log2(bound))
                multiplier = bound / 2 * bound ** -exponent
                print(f"max abs input {max_val}, old scaler {scaler_val}")
                module.scaler.data.fill_(multiplier)
                print(f"Changed scaler of layer {module_name} to {multiplier}")
            return inputs
        return hook

    hooks = []
    for name, module in model.named_modules():
        is_activation = "layer" in name and "activation" in name
        # The hasattr guard must apply to both name patterns; without the
        # parentheses, activation-named modules lacking a scaler crashed.
        if (is_activation or "LayerNorm" in name) and hasattr(module, "scaler"):
            module.scaler.requires_grad = False
            hooks.append(module.register_forward_pre_hook(make_hook(name)))
    return hooks


def find_layer_with_nans(args, model, data):
    """Return the lowest encoder-layer index whose output contains NaN/inf.

    Runs one no-grad forward pass ``model(**data)`` with forward hooks on
    every submodule whose qualified name contains ``"encoder.layer."``. The
    layer index is parsed from that name. Tensor outputs and tuple/list
    outputs (checked element-wise) are both inspected. If no layer produces
    a non-finite value, ``config["model_config"]["num_hidden_layers"]`` is
    returned. Hooks are always removed, even if the forward pass raises.
    """
    nan_layers = {args.config["model_config"]["num_hidden_layers"]}

    def _all_finite(value):
        # Layers may return a tensor or a tuple/list of tensors (e.g. with
        # attention maps); non-tensor entries are ignored.
        if isinstance(value, torch.Tensor):
            return bool(torch.isfinite(value).all())
        if isinstance(value, (list, tuple)):
            return all(_all_finite(item) for item in value)
        return True

    def make_hook(layer_number):
        def hook(module, inputs, output):
            if not _all_finite(output):
                nan_layers.add(layer_number)
                print(f"Nans discovered in layer {layer_number}")
        return hook

    handles = []
    for name, module in model.named_modules():
        if "encoder.layer." in name:
            number = int(name.split("encoder.layer.")[1].split(".")[0])
            handles.append(module.register_forward_hook(make_hook(number)))

    try:
        with torch.no_grad():
            model(**data)
    finally:
        for handle in handles:
            handle.remove()

    return min(nan_layers)


def report_model_gradients(args, model):
    """Print gradient extrema of every submodule during backward (master process only).

    A full backward hook is registered on each module from
    ``model.named_modules()``. On backward it prints min/max of the first
    entry of ``grad_output`` and of ``grad_input`` when that entry is not None.

    Returns:
        List of hook handles on the master process (call ``.remove()`` on each
        to detach); None on other ranks.
    """
    if not master_process(args):
        return None

    def build_hook(module_name):
        def on_backward(module, grad_input, grad_output):
            first_in = grad_input[0]
            first_out = grad_output[0]
            if first_out is not None:
                print(module_name, "grad output:", first_out.min(), first_out.max())
            if first_in is not None:
                print(module_name, "grad input:", first_in.min(), first_in.max())
        return on_backward

    return [module.register_full_backward_hook(build_hook(name))
            for name, module in model.named_modules()]

def report_activations_fast(args, model):
    """Print activation extrema of every submodule on each forward call (master process only).

    For each module from ``model.named_modules()`` a forward hook prints
    min/max of the module's output when it is a non-empty tensor. It also
    prints min/max of the first non-empty tensor among the module's
    positional inputs.

    Returns:
        List of forward-hook handles on the master process (call ``.remove()``
        on each to detach); None on other ranks.
    """
    if not master_process(args):
        return None

    def make_hook(name):
        def hook(module, inputs, output):
            # Outputs may be tuples, dicts or None; only plain tensors are printed.
            if isinstance(output, torch.Tensor) and output.numel() > 0:
                print(name, "layer output:", output.min(), output.max())
            # Forward-hook inputs are always a tuple of positional arguments,
            # so inspect its first tensor element rather than the tuple itself.
            first_input = next(
                (item for item in inputs if isinstance(item, torch.Tensor)), None)
            if first_input is not None and first_input.numel() > 0:
                print(name, "layer input:", first_input.min(), first_input.max())
        return hook

    return [module.register_forward_hook(make_hook(name))
            for name, module in model.named_modules()]

# TODO: make the reporting functions below robust to np.inf / NaN values.

def report_model_activations(args, model, data, step, bins=20, **kwargs):
    """Forward a slice of ``data`` and report per-module activation statistics (master process only).

    A no-grad forward pass records the output of every module whose output
    is a tensor. For each recorded activation, NaN / +inf / -inf counts are
    reported when any are present, and a histogram of the finite values is
    sent to ``args.tracker_logger``.

    Args:
        args: namespace providing ``logger``, ``tracker_logger`` and
            ``inputs_logging_ratio`` (fraction of the batch to forward).
        model: module called as ``model(**data)``, or as
            ``model(**data, **kwargs)`` when ``kwargs['reg_losses']`` is not None.
        data: dict of batched tensors; must contain ``"input_ids"``.
        step: iteration index used for reporting.
        bins: number of histogram bins.
    """
    if not master_process(args):
        return
    args.logger.info(f"Starting to report activation for step {step}")
    activations = {}

    def make_hook(name):
        def hook(module, inputs, output):
            # Tuples, dicts and None outputs are skipped.
            if isinstance(output, torch.Tensor):
                activations[name] = output.detach().cpu().float().numpy()
        return hook

    # Forward only a fraction of the batch (at least one sample) to bound cost.
    last_sample = max(1, int(data["input_ids"].shape[0] * args.inputs_logging_ratio))
    data = {name: t[:last_sample] for name, t in data.items()}

    hooks = [module.register_forward_hook(make_hook(name))
             for name, module in model.named_modules()]
    try:
        with torch.no_grad():
            if kwargs.get("reg_losses", None) is not None:
                model(**data, **kwargs)
            else:
                model(**data)
    finally:
        for hook in hooks:
            hook.remove()

    for name, values in activations.items():
        _report_activation_stats(args, name, values, step, bins)


def _report_activation_stats(args, name, values, step, bins):
    """Report non-finite counts and a histogram of the finite entries of one activation array."""
    finite_values = values[np.isfinite(values)]
    if finite_values.size != values.size:
        title = f'nans and infs: {name}'
        for series, count in (('n of nans', np.isnan(values).sum()),
                              ('n of +infs', np.isposinf(values).sum()),
                              ('n of -infs', np.isneginf(values).sum())):
            args.tracker_logger.report_scalar(
                title=title, series=series, value=int(count), iteration=step)

    if finite_values.size == 0:
        # Nothing to histogram for this module; keep reporting the others.
        return
    try:
        # np.histogram needs a finite range, so only finite values are used.
        hist, bounds = np.histogram(
            finite_values, bins=bins,
            range=(finite_values.min(), finite_values.max()))
        args.tracker_logger.report_histogram(
            title=name, series=name, values=hist, iteration=step,
            xlabels=list(bounds)
        )
    except Exception as ex:
        # Reporting is best-effort diagnostics; never abort training for it.
        print(f"Failed to report activation histogram for {name}: {ex}")


def report_model_weights(args, model, step, bins=20):
    """Report a histogram of every parameter's values to ``args.tracker_logger`` (master process only).

    Only finite entries are histogrammed. Previously a single ``inf`` made the
    histogram range non-finite and ``np.histogram`` raised. A parameter with
    no finite entries (empty, all-NaN or all-inf) is skipped.
    """
    if not master_process(args):
        return
    for name, param in model.named_parameters():
        values = param.detach().cpu().float().numpy()
        finite_values = values[np.isfinite(values)]
        if finite_values.size == 0:
            continue
        hist, bounds = np.histogram(
            finite_values, bins=bins,
            range=(finite_values.min(), finite_values.max()))
        args.tracker_logger.report_histogram(
            title=name, series=name, values=hist, iteration=step,
            xlabels=list(bounds)
        )


def report_lamb_coefficients(args, optimizer):
    """Log a histogram of LAMB trust-ratio coefficients to TensorBoard.

    Only runs on the master process and only for fp16 + LAMB runs. The
    coefficients come from ``optimizer.optimizer.get_lamb_coeffs()``; an
    empty result is not logged. The x-axis is the module-level ``global_step``.
    """
    if not master_process(args):
        return
    if not (args.fp16 and args.use_lamb):
        return
    coeffs = np.array(optimizer.optimizer.get_lamb_coeffs())
    if coeffs.size > 0:
        args.summary_writer.add_histogram('Train/lamb_coeffs',
                                          coeffs, global_step)


# Refresh fp32 master params from fp16 copies
def refresh_fp32_params(optimizer):
    """Overwrite the flat fp32 master buffers with the flat fp16 buffers, in place.

    Buffers are paired positionally from ``optimizer.fp32_groups_flat`` and
    ``optimizer.fp16_groups_flat``. In the DeepSpeed fp16 optimizer, the flat
    fp16 buffers track the (originally shaped) fp16 groups, which in turn
    track the real model weights.
    """
    paired_buffers = zip(optimizer.fp32_groups_flat, optimizer.fp16_groups_flat)
    for master_flat, half_flat in paired_buffers:
        master_flat.data.copy_(half_flat.data)


def get_arguments():
    """Parse the command line, including DeepSpeed's configuration flags.

    CPU-only execution is unsupported, so ``args.no_cuda`` is always forced
    to False.
    """
    parser = deepspeed.add_config_arguments(get_argument_parser())
    parsed = parser.parse_args()
    # no cuda mode is not supported
    parsed.no_cuda = False
    return parsed


def construct_arguments():
    """Parse CLI arguments and attach run-wide state to the namespace.

    Adds ``logger``, ``config`` (JSON parsed from ``args.config_file``),
    ``task``, ``job_name`` (defaults to ``config['name']``) and
    ``saved_model_path``. It also seeds the Python, NumPy and torch RNGs
    with ``args.seed``, creates ``args.output_dir``, and warns when
    early-exit step limits are configured.
    """
    args = get_arguments()

    # Prepare Logger
    logger = Logger(cuda=torch.cuda.is_available() and not args.no_cuda)
    args.logger = logger
    # Use a context manager so the config file handle is closed promptly.
    with open(args.config_file, 'r', encoding='utf-8') as config_file:
        config = json.load(config_file)
    args.config = config
    args.task = TaskRegistry.get_task(args.task_type)

    args.job_name = config['name'] if args.job_name is None else args.job_name
    print("Running Config File: ", args.job_name)
    print("Args = {}".format(args))

    # Seed every RNG with the same value so runs are reproducible and
    # consistent across processes.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    os.makedirs(args.output_dir, exist_ok=True)
    args.saved_model_path = os.path.join(args.output_dir, "saved_models/",
                                         args.job_name)

    # Issue warning if early exit from epoch is configured
    if args.max_steps < sys.maxsize:
        logging.warning(
            'Early training exit is set after {} global steps'.format(
                args.max_steps))

    if args.max_steps_per_epoch < sys.maxsize:
        logging.warning('Early epoch exit is set after {} global steps'.format(
            args.max_steps_per_epoch))

    return args


def prepare_optimizer_parameters(args, model):
    """Build the optimizer parameter groups passed to ``deepspeed.initialize``.

    Groups, in order:

    * ``others_with_wd`` / ``others_with_no_wd``: parameters whose name
      contains neither ``'embeddings'`` nor ``'layer'``, split by whether the
      name matches a no-decay pattern (bias / LayerNorm weights);
    * ``embeddings``: all parameters of ``model.bert.embeddings``;
    * per encoder layer ``i``: ``layer_{i}_attention``. With
      ``args.dense_attention`` this group holds only the ``attention``
      sub-module, and a separate ``layer_{i}_ffn`` group is added when the
      layer has an ``ffn`` sub-module. Otherwise it holds the whole layer.

    Every group starts with ``lr=0.0``; the actual learning rate is assigned
    during training. Weight decay is ``config["training"]["weight_decay"]``
    (default 0.01), except in ``others_with_no_wd``, which uses 0.0.
    """
    config = args.config
    weight_decay = config["training"].get("weight_decay", 0.01)
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']

    # Embedding and encoder-layer parameters get dedicated groups below. The
    # match is a case-sensitive substring test, so e.g. 'LayerNorm' outside
    # the encoder stays in the "others" groups.
    other_params = [(n, p) for n, p in model.named_parameters()
                    if 'embeddings' not in n and 'layer' not in n]

    def _group(params, name, decay=weight_decay):
        return {'params': list(params), 'lr': 0.0,
                'weight_decay': decay, 'name': name}

    groups = [
        _group((p for n, p in other_params
                if not any(nd in n for nd in no_decay)), 'others_with_wd'),
        _group((p for n, p in other_params
                if any(nd in n for nd in no_decay)), 'others_with_no_wd', 0.0),
        _group(model.bert.embeddings.parameters(), 'embeddings'),
    ]

    for i in range(config["model_config"]["num_hidden_layers"]):
        layer = model.bert.encoder.layer[i]
        if args.dense_attention:
            # NOTE(review): layer parameters outside ``attention``/``ffn`` are
            # not passed to the optimizer in this branch; confirm intended.
            groups.append(_group(layer.attention.parameters(),
                                 f'layer_{i}_attention'))
            if hasattr(layer, 'ffn'):
                groups.append(_group(layer.ffn.parameters(), f'layer_{i}_ffn'))
        else:
            groups.append(_group(layer.parameters(), f'layer_{i}_attention'))

    return groups


def prepare_model_optimizer(args):
    """Create the model and hand it, with its optimizer, to DeepSpeed.

    Besides building the engine, this mirrors DeepSpeed's effective settings
    back onto ``args``:
    - micro batch size, gradient accumulation steps and global batch size;
    - local rank, device and the fp16/bf16 flags;
    - whether a LAMB optimizer is in use.

    It also derives ``args.throughput_logging_freq`` from
    ``args.throughput_logging_samples``. On global rank 0 it attaches a
    tensorboard writer as ``args.summary_writer`` and creates
    ``args.saved_model_path``.

    Returns:
        ``(model, optimizer)``. ``model`` is the DeepSpeed engine, wrapped
        with ``torch.compile`` when ``args.use_torch_compile`` is set.
    """
    # Initialize torch distributed before anything queries ranks.
    deepspeed.init_distributed(dist_backend='nccl')
    args.local_rank = int(os.environ['LOCAL_RANK'])
    build_model = args.task.model_type

    bert_config = ModelConfig(**args.config["model_config"])
    # Pad the vocabulary up to the next multiple of 8 (divisibility by 8).
    leftover = bert_config.vocab_size % 8
    if leftover != 0:
        bert_config.vocab_size += 8 - leftover
    print("VOCAB SIZE:", bert_config.vocab_size)

    network = build_model(bert_config, args)
    param_groups = prepare_optimizer_parameters(args, network)

    # DeepSpeed takes over FP16, distributed training and the optimizer.
    engine, optimizer, _, _ = deepspeed.initialize(
        args=args, model=network, model_parameters=param_groups)

    # The DeepSpeed config is authoritative: copy its batch settings back.
    args.train_micro_batch_size_per_gpu = engine.train_micro_batch_size_per_gpu()
    args.gradient_accumulation_steps = engine.gradient_accumulation_steps()
    args.batch_size = engine.train_batch_size()
    # Number of global steps between throughput logging sessions (at least 1).
    args.throughput_logging_freq = max(
        args.throughput_logging_samples // args.batch_size, 1)

    # Placement and precision as decided by DeepSpeed.
    args.local_rank = engine.local_rank
    args.device = engine.device
    args.fp16 = engine.fp16_enabled()
    args.bf16 = engine.bfloat16_enabled()
    chosen = engine.optimizer_name()
    args.use_lamb = (chosen == deepspeed.runtime.config.LAMB_OPTIMIZER
                     or chosen == deepspeed.runtime.config.ONEBIT_LAMB_OPTIMIZER)

    # Only global rank 0 writes tensorboard summaries and creates the
    # checkpoint directory.
    if dist.get_rank() == 0:
        args.summary_writer = get_sample_writer(name=args.job_name,
                                                base=args.output_dir)
        os.makedirs(args.saved_model_path, exist_ok=True)

    if args.use_torch_compile:
        engine = torch.compile(engine)
    print(engine)
    return engine, optimizer


def load_checkpoint(args, model):
    """Restore training state from a previously saved checkpoint.

    The checkpoint is read from ``args.load_training_checkpoint`` with id
    ``args.load_checkpoint_id`` via ``load_training_checkpoint``. The
    module-level ``global_step`` and ``global_data_samples`` counters are
    overwritten with the restored values.

    When ``args.rewarmup`` is set, ``last_global_step_from_restore`` is set to
    the restored step. The log message states the intent is to re-warm up the
    learning rate. ``run`` also subtracts this value before its max-steps
    exit check.

    Args:
        args: parsed arguments; must provide ``logger``,
            ``load_training_checkpoint``, ``load_checkpoint_id`` and
            ``rewarmup``.
        model: the (DeepSpeed) model to load the checkpoint into.

    Returns:
        The epoch index to resume training from.
    """
    global global_step
    global global_data_samples
    global last_global_step_from_restore

    logger = args.logger

    logger.info(
        f"Restoring previous training checkpoint from PATH={args.load_training_checkpoint}, CKPT_ID={args.load_checkpoint_id}"
    )
    start_epoch, global_step, global_data_samples = load_training_checkpoint(
        args=args,
        model=model,
        PATH=args.load_training_checkpoint,
        ckpt_id=args.load_checkpoint_id)
    logger.info(
        f"The model is loaded from last checkpoint at epoch {start_epoch} when the global steps were at {global_step} and global data samples at {global_data_samples}"
    )
    if args.rewarmup:
        logger.info(
            f"Rewarmup learning rate with last_global_step_from_restore = {global_step}"
        )
        last_global_step_from_restore = global_step

    return start_epoch

def _load_datasets(args, config, task):
    """Build the train, validation and test datasets requested by ``args``.

    Returns:
        ``(dataset, dataset_val, dataset_test)``.
        - The train dataset is a ``ShardedDatasetWrapper`` when
          ``args.use_sharded_dataset`` is set.
        - ``dataset_val`` is ``None`` when ``args.no_eval_val_data`` is set.
        - ``dataset_test`` is ``None`` unless ``args.eval_test_data`` is set.
    """
    print("Loading train dataset")
    if args.use_sharded_dataset:
        dataset = ShardedDatasetWrapper(args.data_path_prefix,
                                        config["data"]["training"], args)
    else:
        dataset = task.dataset_type(args.data_path_prefix,
                                    config["data"]["training"], args)

    dataset_val = None
    if not args.no_eval_val_data:
        print("Loading eval dataset")
        dataset_val = task.dataset_type(args.data_path_prefix,
                                        config["data"]["validation"], args)

    dataset_test = None
    if args.eval_test_data:
        print("Loading test dataset")
        dataset_test = task.dataset_type(args.data_path_prefix,
                                         config["data"]["test"], args)

    return dataset, dataset_val, dataset_test


def run(args, model, optimizer, start_epoch):
    """Load the datasets and train from ``start_epoch`` to the configured end.

    After each epoch a checkpoint may be written (see ``args.ckpt_to_save``).
    The loop stops early once ``is_time_to_exit`` reports that the step
    budget is exhausted. That check counts steps relative to
    ``last_global_step_from_restore``.

    Side effects: sets ``args.train_dataset`` (``None`` for sharded data),
    ``args.eval_dataset`` and ``args.test_dataset``, and updates the
    module-level step counters through ``train``.
    """
    global global_step
    global global_data_samples
    global last_global_step_from_restore

    config = args.config
    logger = args.logger
    task = args.task

    print(model)
    print(f"Total parameters in the model: {model.get_num_params(non_embedding=False)}")

    dataset, dataset_val, dataset_test = _load_datasets(args, config, task)

    args.train_dataset = dataset if not args.use_sharded_dataset else None
    args.eval_dataset = dataset_val
    args.test_dataset = dataset_test

    for index in range(start_epoch, config["training"]["num_epochs"]):
        logger.info(f"Training Epoch: {index + 1}")
        pre = time.time()
        train(args, index, model, optimizer, dataset)

        # Periodic checkpointing: ``args.ckpt_to_save`` is an integer period.
        # A checkpoint is saved whenever the zero-based epoch index is a
        # multiple of it, so the first trained epoch is always saved.
        # NOTE(review): a value of 0 raises ZeroDivisionError here.
        if index % args.ckpt_to_save == 0:
            logger.info(
                f"Saving a checkpointing of the model for epoch: {index+1}")

            checkpoint_model(PATH=args.saved_model_path,
                             ckpt_id='epoch{}_step{}'.format(
                                 index + 1, global_step),
                             model=model,
                             epoch=index + 1,
                             last_global_step=global_step,
                             last_global_data_samples=global_data_samples)

        post = time.time()
        logger.info(f"Time for shard {index + 1}: {post-pre} seconds")

        # Steps taken since the restore point (0 unless rewarmup was requested).
        current_global_step = global_step - last_global_step_from_restore
        if is_time_to_exit(args=args, global_steps=current_global_step):
            print(
                f'Warning: Early training termination due to max steps limit, epoch={index+1}, global_step={current_global_step}'
            )
            break


def main():
    """Script entry point.

    Steps:
    1. Parse arguments and build the DeepSpeed model and optimizer.
    2. On the master process, register a ClearML task and attach the
       arguments, model config and DeepSpeed config to it.
    3. Restore a checkpoint when one is requested.
    4. Run training and log the total elapsed time.
    """
    start = time.time()
    args = construct_arguments()
    model, optimizer = prepare_model_optimizer(args)
    if master_process(args):
        # ClearML's ``Task.set_random_seed`` only sets the seed that a
        # subsequent ``Task.init`` applies. It must therefore run *before*
        # the task is created, otherwise ClearML seeds with its default value
        # and ``args.seed`` is ignored for this task.
        Task.set_random_seed(args.seed)
        task = Task.init(project_name=args.project_name,
                         task_name="research", reuse_last_task_id=False)
        task.connect(args)
        task.connect(args.config, 'bert_config')
        task.connect_configuration(args.deepspeed_config, name='deepspeed_config')
        args.tracker_logger = task.get_logger()

    start_epoch = 0
    # Resume only when both the checkpoint path and checkpoint id are given.
    if args.load_training_checkpoint and args.load_checkpoint_id:
        start_epoch = load_checkpoint(args, model)

    run(args, model, optimizer, start_epoch)
    elapsed = time.time() - start
    logger = args.logger
    logger.info(f"Elapsed time: {elapsed} seconds")


if __name__ == "__main__":
    main()


