import argparse
import datetime
import json
import numpy as np
import os
import time
from pathlib import Path
import functools
import multiprocessing

import torch
import torch.backends.cudnn as cudnn

import warnings
warnings.filterwarnings("ignore")

from torch.utils.tensorboard import SummaryWriter
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
    CheckpointImpl,
    apply_activation_checkpointing,
)
from torch.distributed.fsdp.wrap import (
    transformer_auto_wrap_policy,
)

from fairscale.nn.model_parallel import initialize as fs_init

try:
    from apex.optimizers import FusedAdam as AdamW
except ImportError:
    warnings.warn("cannot import FusedAdam from apex, use torch AdamW instead")
    from torch.optim import AdamW

import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from model.meta import MetaModel
from model.meta_tokenpacker import MetaTokenpackerModel
from engine_mmfi_cs_full import train_one_epoch
# from data.one_stage_sense_dataset import OneStageSenseDataset, OneStageDistSampler
# from data.one_stage_sense_dataset_ce import OneStageSenseDataset, OneStageDistSampler
from data.mmfi_dataset_cs_full import OneStageSenseDataset, OneStageDistSampler


def get_args_parser():
    """Build the command-line parser for OneLLM fine-tuning.

    Returns an ``argparse.ArgumentParser`` created with ``add_help=False``; the
    caller decides whether to add ``-h``.

    Notes:
        * ``--datasets`` takes one or more names and always yields a list
          (the default is a one-element list so the type does not depend on
          whether the flag was passed).
        * ``--clip_grad`` is a float norm; values <= 0 presumably disable
          clipping downstream (default -1) — TODO confirm in the loss scaler.
    """
    parser = argparse.ArgumentParser('OneLLM Finetuning', add_help=False)
    # nargs='+' yields a list when the flag is given; keep the default a list too.
    parser.add_argument('--datasets', type=str, default=['image'], nargs='+')
    parser.add_argument('--epochs', default=1, type=int)
    parser.add_argument('--batch_size', default=64, type=int,
                        help='Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)')
    parser.add_argument('--accum_iter', default=4, type=int,
                        help='Accumulate gradient iterations (for increasing the effective batch size under memory constraints)')

    # Model parameters
    parser.add_argument('--llama_type', default='llama', type=str, metavar='MODEL',
                        help='Name of model to train')
    parser.add_argument("--llama_ckpt_dir", type=str, default="")
    parser.add_argument("--llama_config", type=str, default="config/llama2/7B.json")
    parser.add_argument("--tokenizer_path", type=str, default="config/tokenpacker-7b/tokenizer.model")

    # Optimizer parameters
    # %(default)s keeps the help text in sync with the actual default.
    parser.add_argument('--weight_decay', type=float, default=0.02,
                        help='weight decay (default: %(default)s)')

    parser.add_argument('--lr', type=float, default=0.001, metavar='LR',
                        help='learning rate (absolute lr)')
    parser.add_argument('--min_lr', type=float, default=0.0001, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0')

    parser.add_argument('--warmup_epochs', type=float, default=1.0, metavar='N',
                        help='epoch to warmup LR')

    # A gradient norm is real-valued; int parsing would reject e.g. 0.5.
    parser.add_argument('--clip_grad', type=float, default=-1,
                        help='grad clipping norm')

    parser.add_argument('--output_dir', default='./output_dir',
                        help='path where to save, empty for no saving')
    parser.add_argument('--log_dir', default='./output_dir',
                        help='path where to write tensorboard logs')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=0, type=int)
    parser.add_argument('--resume', default='',
                        help='resume from checkpoint')
    parser.add_argument('--auto_resume', action='store_true')
    parser.add_argument('--init_from', default='',
                        help='init from checkpoint')
    parser.add_argument('--init_from_image', action='store_true')

    parser.add_argument('--num_workers', default=5, type=int)
    # --pin_mem / --no_pin_mem toggle the same dest; pinning is on by default.
    parser.add_argument('--pin_mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no_pin_mem', action='store_false', dest='pin_mem')
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument('--local_rank', default=-1, type=int)
    parser.add_argument('--dist_on_itp', action='store_true')
    parser.add_argument('--dist_url', default='env://',
                        help='url used to set up distributed training')

    parser.add_argument('--model_parallel_size', type=int, default=1)
    parser.add_argument('--data_parallel', type=str, choices=['ddp', 'sdp', 'fsdp'], default='sdp')
    parser.add_argument('--precision', type=str, choices=['fp16', 'bf16', 'tf32'], default='bf16')
    parser.add_argument('--save_interval', type=int, default=5000)
    parser.add_argument('--save_consolidated', action="store_true",
                        help="save consolidated model weights along with regular checkpoints "
                             "used to resume training. useful for convenient deployment but "
                             "will occupy some additional disk space.")
    parser.add_argument("--checkpointing", action="store_true")

    parser.add_argument('--max_words', type=int, default=2048)
    parser.add_argument('--image_words', type=int, default=0)  # text padding to 1024

    return parser


def _wrap_model_with_fsdp(model, transformer_block_cls, mixed_precision_dtype, args):
    """Shard ``model`` with FSDP over the data-parallel group.

    Every ``transformer_block_cls`` instance becomes its own FSDP unit. When
    ``args.checkpointing`` is set, those blocks are additionally wrapped with
    non-reentrant activation checkpointing.

    Parameters with ``requires_grad=False`` are handed to FSDP as ignored, so
    they are not flattened/sharded. The keyword for this differs between
    PyTorch releases (``ignored_parameters`` in older FSDP, ``ignored_states``
    in newer). It is selected from the installed FSDP signature because the
    CUDA version does not determine which PyTorch API is present, and may be
    ``None`` on non-CUDA builds.

    Args:
        model: the unwrapped model.
        transformer_block_cls: layer class used for auto-wrapping and checkpointing.
        mixed_precision_dtype: dtype for params, gradient reduction and buffers.
        args: parsed CLI args; reads ``data_parallel`` and ``checkpointing``.

    Returns:
        The FSDP-wrapped model.
    """
    import inspect

    frozen_params = [param for param in model.parameters() if not param.requires_grad]
    if "ignored_states" in inspect.signature(FSDP.__init__).parameters:
        ignore_kwargs = {"ignored_states": frozen_params}
    else:
        ignore_kwargs = {"ignored_parameters": frozen_params}

    model = FSDP(
        model,
        process_group=fs_init.get_data_parallel_group(),
        auto_wrap_policy=functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls=[transformer_block_cls],
        ),
        limit_all_gathers=True,
        use_orig_params=True,
        sync_module_states=True,
        mixed_precision=MixedPrecision(
            param_dtype=mixed_precision_dtype,
            reduce_dtype=mixed_precision_dtype,
            buffer_dtype=mixed_precision_dtype,
        ),
        sharding_strategy={
            "sdp": ShardingStrategy.SHARD_GRAD_OP,
            "ddp": ShardingStrategy.NO_SHARD,
            "fsdp": ShardingStrategy.FULL_SHARD,
        }[args.data_parallel],
        **ignore_kwargs,
    )

    if args.checkpointing:
        print("apply gradient checkpointing")
        # offload_to_cpu is not passed: False was its default where it existed,
        # and newer checkpoint_wrapper versions no longer accept it.
        non_reentrant_wrapper = functools.partial(
            checkpoint_wrapper,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
        apply_activation_checkpointing(
            model,
            checkpoint_wrapper_fn=non_reentrant_wrapper,
            check_fn=lambda submodule: isinstance(submodule, transformer_block_cls),
        )
    return model


def _build_param_groups(model, args):
    """Split the trainable parameters of ``model`` into four AdamW groups.

    Following timm, names ending in ``.bias`` or ``norm.weight`` get zero weight
    decay. Names containing ``llma.resample_layers`` or ``llma.resample_tokens``
    go to the ``scratch_`` groups. All groups use ``args.lr``. Frozen
    parameters are skipped. Each assignment is printed.

    Returns:
        A list of group dicts in the order decay, no_decay, scratch_decay,
        scratch_no_decay.
    """
    group_order = ["decay", "no_decay", "scratch_decay", "scratch_no_decay"]
    param_groups = {
        "decay": {"params": [], "weight_decay": args.weight_decay, "lr": args.lr},
        "no_decay": {"params": [], "weight_decay": 0., "lr": args.lr},
        "scratch_decay": {"params": [], "weight_decay": args.weight_decay, "lr": args.lr},
        "scratch_no_decay": {"params": [], "weight_decay": 0., "lr": args.lr},
    }
    print("Making parameter groups ...")
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        no_decay = name.endswith(".bias") or name.endswith("norm.weight")
        scratch = "llma.resample_layers" in name or "llma.resample_tokens" in name
        group_name = ("scratch_" if scratch else "") + ("no_decay" if no_decay else "decay")
        print(f"{name}: in group {group_name}")
        param_groups[group_name]["params"].append(param)
    return [param_groups[key] for key in group_order]


def main(args):
    """Run the full fine-tuning job described by ``args``.

    The steps are:
      1. Initialise distributed and model-parallel state.
      2. Build the dataset and the model (optionally loading ``args.init_from``).
      3. Wrap the model in FSDP.
      4. Create AdamW and the loss scaler, optionally resuming projector state.
      5. Train for ``args.epochs`` epochs.
      6. Append per-epoch stats to ``<output_dir>/log.txt`` on the main process.
    """
    multiprocessing.set_start_method("spawn")
    misc.init_distributed_mode(args)
    # Must be launched via torchrun; a single GPU is fine.
    fs_init.initialize_model_parallel(args.model_parallel_size)
    if args.precision == "tf32":
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    print('job dir: {}'.format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(', ', ',\n'))

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset per rank)
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    global_rank = misc.get_rank()
    mp_rank = fs_init.get_model_parallel_rank()
    mp_world_size = fs_init.get_model_parallel_world_size()
    dp_rank = fs_init.get_data_parallel_rank()
    dp_world_size = fs_init.get_data_parallel_world_size()

    dataset_train = OneStageSenseDataset(args.datasets, max_words=args.max_words, image_words=args.image_words, tokenizer_path=args.tokenizer_path)

    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    # define the model
    model = MetaTokenpackerModel(args.llama_type, args.llama_config, args.llama_ckpt_dir, args.tokenizer_path)
    model.to(device)
    print("Model = %s" % str(model))
    # Loaded after model construction, so these weights overwrite the
    # base LLM / encoder weights loaded by the constructor.
    if args.init_from:
        print("Init checkpoint from %s" % args.init_from)
        checkpoint = torch.load(os.path.join(args.init_from, f"consolidated.{mp_rank:02d}-of-{mp_world_size:02d}.pth"), map_location='cpu')
        msg = model.load_state_dict(checkpoint, strict=False)
        print(msg)

    mixed_precision_dtype = {
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "tf32": torch.float32,
    }[args.precision]
    TransformerBlock = type(model.llma.language_model.model.layers[0])

    model = _wrap_model_with_fsdp(model, TransformerBlock, mixed_precision_dtype, args)

    eff_batch_size = args.batch_size * args.accum_iter * dp_world_size
    print("effective batch size: %d" % eff_batch_size)

    optimizer = AdamW(
        _build_param_groups(model, args),
        betas=(0.9, 0.95),
    )
    print(optimizer)
    loss_scaler = NativeScaler(args)

    start_epoch = 0
    start_iter = 0
    if args.resume or args.auto_resume:
        start_epoch, start_iter = misc.load_projector(args=args, model=model, optimizer=optimizer, loss_scaler=loss_scaler)

    sampler_train = OneStageDistSampler(
        dataset_train, num_replicas=dp_world_size, rank=dp_rank, shuffle=True, batch_size=args.batch_size,
        acc_grad=args.accum_iter
    )
    data_loader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        sampler=sampler_train,
        drop_last=True
    )

    print(f"Start training for {args.epochs} epochs")
    start_time = time.time()
    for epoch in range(start_epoch, args.epochs):
        if args.distributed:
            # start_iter lets a resumed epoch skip already-consumed batches.
            data_loader_train.sampler.set_epoch(epoch, start_iter)

        train_stats = train_one_epoch(
            model, data_loader_train,
            optimizer, epoch, start_iter, loss_scaler,
            log_writer=log_writer,
            args=args
        )

        # NOTE(review): no validation loop runs; the val_* keys mirror the train
        # stats and are kept only for compatibility with existing log consumers.
        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     'epoch': epoch,
                     **{f'val_{k}': v for k, v in train_stats.items()}}

        if args.output_dir and misc.is_main_process():
            if log_writer is not None:
                log_writer.flush()
            with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
                f.write(json.dumps(log_stats) + "\n")

        # Only the first (resumed) epoch starts mid-way.
        start_iter = 0

    if log_writer is not None:
        log_writer.close()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    # Script entry point: parse CLI flags, make sure the output directory
    # exists (when one is requested), then hand off to main().
    args = get_args_parser().parse_args()
    output_dir = args.output_dir
    if output_dir:
        Path(output_dir).mkdir(exist_ok=True, parents=True)
    main(args)
