import argparse
import copy
import datetime
import json
import os
import time
from pathlib import Path
import random
import numpy as np
import timm.optim.optim_factory as optim_factory
import torch
import torch.backends.cudnn as cudnn
import util.misc as misc
from engine_pretrain import train_one_epoch, val_one_epoch
from torch.utils.data import Dataset
from torch.utils.tensorboard import SummaryWriter
from util.misc import NativeScalerWithGradNormCount as NativeScaler
from util.datasets import WikiTextDataset, MixDataset
from lavin.mm_adaptation import LLama
def get_args_parser():
    """Build the command-line parser for quantized LLaMA pre-training / evaluation.

    The parser is created with ``add_help=False`` (presumably so it can also be
    used via ``parents=[...]``); as a result no ``-h/--help`` flag is registered.

    Returns:
        argparse.ArgumentParser: parser holding training, model, optimizer,
        dataset, distributed, evaluation, OmniQuant and calibration options.
    """
    parser = argparse.ArgumentParser("MAE pre-training", add_help=False)
    parser.add_argument(
        "--batch_size",
        default=64,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument("--epochs", default=400, type=int)
    parser.add_argument('--bits', default='4bit', type=str, choices=['4bit', '8bit', '16bit'],
                        help='Quantization bits for training, 4bit by default')
    parser.add_argument("--accum_iter", default=1, type=int,
                        help="Accumulate gradient iterations (for increasing the effective batch size under memory constraints)")

    # Model parameters
    parser.add_argument('--llama_model_path', default='./llama', type=str,
                        help='path of llama model')
    parser.add_argument('--llm_model', default='7B', type=str, metavar='MODEL',
                        help='Name of llm model to train')
    parser.add_argument('--use_vicuna', action='store_true', help='use vicuna weights')
    parser.add_argument('--cpu_load', action='store_true', help='load the model on cpu and avoid OOM on gpu')
    parser.add_argument('--drop_path', type=float, default=0., metavar='RATE', help='drop path')
    parser.add_argument('--max_seq_len', type=int, default=1024, metavar='LENGTH',
                        help='the maximum sequence length')
    parser.add_argument('--temperature', type=float, default=10., metavar='TEMP',
                        help='the temperature of router')
    parser.add_argument('--hidden_proj', type=int, default=128, metavar='DIM',
                        help='the visual adapter dim')

    # Optimizer parameters
    parser.add_argument("--weight_decay", type=float, default=0.05,
                        help="weight decay (default: 0.05)")
    parser.add_argument("--lr", type=float, default=None, metavar="LR",
                        help="learning rate (absolute lr)")
    parser.add_argument("--blr", type=float, default=1e-3, metavar="LR",
                        help="base learning rate: absolute_lr = base_lr * total_batch_size / 256")
    parser.add_argument("--min_lr", type=float, default=0.0, metavar="LR",
                        help="lower lr bound for cyclic schedulers that hit 0")
    parser.add_argument('--gradient_checkpointing', action='store_true',
                        help='saving memory costs via gradient_checkpointing')
    parser.add_argument("--warmup_epochs", type=int, default=40, metavar="N",
                        help="epochs to warmup LR")
    parser.add_argument('--clip_grad', type=float, default=None, metavar='NORM',
                        help='clips gradient norm of an iterable of parameters')

    # Dataset parameters
    parser.add_argument("--data_path", default="../data", type=str, help="scienceqa dataset path")
    parser.add_argument("--data_root", default="../data", type=str, help="scienceqa dataset path")
    parser.add_argument("--data_name", default="wikitext", type=str, choices=["wikitext", "mix"],
                        help="training dataset: 'wikitext' or 'mix' (wikitext mixed with scienceqa)")
    parser.add_argument('--prompt_format', type=str, default='CQM-A',
                        choices=[
                            'CQM-A', 'CQM-LA', 'CQM-EA', 'CQM-LEA', 'CQM-ELA', 'CQM-AL', 'CQM-AE', 'CQM-ALE', 'QCM-A',
                            'QCM-LA', 'QCM-EA', 'QCM-LEA', 'QCM-ELA', 'QCM-AL', 'QCM-AE', 'QCM-ALE', 'QCML-A', 'QCME-A',
                            'QCMLE-A', 'QCLM-A', 'QCEM-A', 'QCLEM-A', 'QCML-AE'
                        ],
                        help='prompt format template')
    # NOTE: type=list turns a single CLI token into a list of its characters,
    # e.g. `--options ABCD` -> ['A', 'B', 'C', 'D'].
    parser.add_argument('--options', type=list, default=["A", "B", "C", "D", "E"])
    parser.add_argument('--caption_file', type=str, default='../data/captions.json')
    parser.add_argument('--use_caption', action='store_true', help='use image captions or not')
    parser.add_argument('--split', type=str, default='train')
    parser.add_argument('--n_prompt', type=int, default=10, metavar='LENGTH',
                        help='the length of visual features')
    parser.add_argument("--output_dir", default="./output_dir", help="path where to save, empty for no saving")
    parser.add_argument("--log_dir", default="./output_dir", help="path where to tensorboard log")
    parser.add_argument("--device", default="cuda", help="device to use for training / testing")
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--resume", default="", help="resume from checkpoint")

    parser.add_argument("--start_epoch", default=0, type=int, metavar="N", help="start epoch")
    parser.add_argument("--num_workers", default=10, type=int)
    # pin_mem defaults to True; --no_pin_mem is the switch that turns it off.
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )
    parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem")
    parser.set_defaults(pin_mem=True)

    # distributed training parameters
    parser.add_argument("--world_size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--local_rank", default=-1, type=int)
    parser.add_argument("--dist_on_itp", action="store_true")
    parser.add_argument("--dist_url", default="env://", help="url used to set up distributed training")

    # eval config
    parser.add_argument('--ckpt_dir', type=str, default='../data')
    parser.add_argument('--tokenizer_path', type=str, default='../data')
    parser.add_argument('--adapter_path', type=str, default=None)
    parser.add_argument('--max_batch_size', type=int, default=16)

    # omniquant method config
    parser.add_argument("--wbits", type=int, default=4)
    parser.add_argument("--abits", type=int, default=16)
    parser.add_argument("--group_size", type=int, default=128, help="quantization group size for per channel")
    parser.add_argument("--let", default=False, action="store_true", help="activate learnable equivalent transformation")
    parser.add_argument("--lwc", default=False, action="store_true", help="activate learnable weight clipping")
    parser.add_argument("--symmetric", default=False, action="store_true", help="symmetric quantization")
    parser.add_argument("--a_dynamic_method", type=str, default="per_token", choices=["per_token"])
    parser.add_argument("--w_dynamic_method", type=str, default="per_channel", choices=["per_channel"])
    parser.add_argument("--let_lr", type=float, default=5e-3, help="learning rate for learnable equivalent transformation")
    parser.add_argument("--lwc_lr", type=float, default=1e-2, help="learning rate for learnable weight clipping")
    parser.add_argument("--wd", type=float, default=0, help="weight decay for quantization")

    # calib config
    parser.add_argument('--quant_resume', type=str, default=None,  # e.g. llama-7b-w4a16_modify.pth
                        help='resume from quant checkpoint')
    parser.add_argument("--nsamples", type=int, default=512)
    parser.add_argument("--calib_epochs", type=int, default=0)
    parser.add_argument("--cache_dir", default="./cache", type=str, help="cache dir of dataset, leading to faster debug")
    parser.add_argument("--model_family", default="llama", type=str, help="llm family")
    parser.add_argument("--calib_dataset", type=str, default="wikitext2",
                        choices=["wikitext2", "ptb", "c4", "mix", "pile", 'scienceqa', 'scienceqawithimage'],
                        help="Where to extract calibration data from.")

    # pretrain additional config
    parser.add_argument("--start_layer", type=int, default=5)
    parser.add_argument('--exp', type=str, default=None, help='experimental name')
    parser.add_argument('--scaling_resume', type=str, default=None,  # e.g. llama-7b-w4a16_modify.pth
                        help='resume from scaling checkpoint')
    parser.add_argument('--do_train', default=False, action="store_true", help="whether to do train")
    parser.add_argument('--do_test', default=False, action="store_true", help="whether to do evals")
    parser.add_argument('--mix_ratio', type=float, default=None, help="mix ratio between wikitext and scienceqa")
    parser.add_argument('--need_img', default=False, action="store_true",
                        help='whether to must have image')
    parser.add_argument('--idx_path', type=str, default=None, help="the path for scienceqa data")
    return parser


def _build_loader(dataset, sampler, args):
    """Wrap *dataset* in a DataLoader configured from *args*.

    Uses ``args.batch_size``, ``args.num_workers`` and ``args.pin_mem``.
    ``drop_last=True`` discards the final incomplete batch, so every batch has
    exactly ``args.batch_size`` samples.
    """
    return torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=True,
    )


def _append_log(args, log_writer, record):
    """Flush TensorBoard and append ``json.dumps(record)`` as a line to ``<output_dir>/log.txt``.

    Does nothing unless ``args.output_dir`` is set and this is the main process.
    """
    if not (args.output_dir and misc.is_main_process()):
        return
    if log_writer is not None:
        log_writer.flush()
    with open(os.path.join(args.output_dir, "log.txt"), mode="a", encoding="utf-8") as f:
        f.write(json.dumps(record) + "\n")


def main(args):
    """Build data, model and optimizer from *args*, then optionally train and/or evaluate.

    When ``args.do_train`` is set, each epoch trains, saves a checkpoint (if
    ``args.output_dir`` is set), validates, and appends the merged train/val
    stats to ``log.txt``. When ``args.do_test`` is set, one validation pass is
    run and its stats are printed.
    """
    misc.init_distributed_mode(args)

    print("job dir: {}".format(os.path.dirname(os.path.realpath(__file__))))
    print("{}".format(args).replace(", ", ",\n"))

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank)
    seed = args.seed + misc.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    cudnn.benchmark = True

    # The training set depends on --data_name; validation always uses the WikiText test split.
    train_dataset_cls = WikiTextDataset if args.data_name == 'wikitext' else MixDataset
    dataset_train = train_dataset_cls(args, 'train', args.llama_model_path, args.max_seq_len)
    dataset_val = WikiTextDataset(args, 'test', args.llama_model_path, args.max_seq_len)
    print(dataset_train)
    print(dataset_val)

    num_tasks = misc.get_world_size()
    global_rank = misc.get_rank()
    sampler_train = torch.utils.data.DistributedSampler(
        dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
    )
    # Validation is not shuffled: combined with drop_last=True, a shuffled
    # sampler would drop a different subset of the test split every epoch,
    # making validation metrics non-comparable across epochs.
    sampler_val = torch.utils.data.DistributedSampler(
        dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False
    )
    print("Sampler_train = %s" % str(sampler_train))

    if global_rank == 0 and args.log_dir is not None:
        os.makedirs(args.log_dir, exist_ok=True)
        log_writer = SummaryWriter(log_dir=args.log_dir)
    else:
        log_writer = None

    data_loader_train = _build_loader(dataset_train, sampler_train, args)
    data_loader_val = _build_loader(dataset_val, sampler_val, args)

    # define the model
    if args.bits in ['4bit', '8bit']:
        # Quantizer settings attached to args before LLama(args) is built.
        args.weight_quant_params = {
            "n_bits": args.wbits,
            "per_channel_axes": [0],
            "symmetric": args.symmetric,
            "dynamic_method": args.w_dynamic_method,
            "group_size": args.group_size,
            "lwc": args.lwc,
        }
        args.act_quant_params = {
            "n_bits": args.abits,
            "per_channel_axes": [],
            "symmetric": False,
            "dynamic_method": args.a_dynamic_method,
        }
    model = LLama(args)
    model.to(device)

    model_without_ddp = model
    print("Model = %s" % str(model_without_ddp))

    eff_batch_size = args.batch_size * args.accum_iter * misc.get_world_size()
    if args.lr is None:  # only base_lr is specified
        args.lr = args.blr * eff_batch_size / 256

    print("base lr: %.2e" % (args.lr * 256 / eff_batch_size))
    print("actual lr: %.2e" % args.lr)
    print("accumulate grad iterations: %d" % args.accum_iter)
    print("effective batch size: %d" % eff_batch_size)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
        model_without_ddp = model.module

    # following timm: set wd as 0 for bias and norm layers
    param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, args.weight_decay)
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(0.9, 0.999))
    print(optimizer)
    loss_scaler = NativeScaler()

    misc.load_model(args=args, model_without_ddp=model_without_ddp, optimizer=optimizer, loss_scaler=loss_scaler)

    if args.do_train:
        print(f"Start training for {args.epochs} epochs")
        start_time = time.time()
        for epoch in range(args.start_epoch, args.epochs):
            if args.distributed:
                data_loader_train.sampler.set_epoch(epoch)
                data_loader_val.sampler.set_epoch(epoch)

            train_stats = train_one_epoch(
                model, data_loader_train, optimizer, device, epoch, loss_scaler, log_writer=log_writer, args=args
            )
            if args.output_dir:
                misc.save_model(
                    args=args,
                    model=model,
                    model_without_ddp=model_without_ddp,
                    optimizer=optimizer,
                    loss_scaler=loss_scaler,
                    epoch=epoch,
                )

            val_stats = val_one_epoch(
                model, data_loader_val, optimizer, device, epoch, loss_scaler, log_writer=log_writer, args=args
            )

            _append_log(args, log_writer, {
                **{f"train_{k}": v for k, v in train_stats.items()},
                "epoch": epoch,
                **{f"val_{k}": v for k, v in val_stats.items()},
            })

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print("Training time {}".format(total_time_str))
        _append_log(args, log_writer, f"Training time {total_time_str}")

    if args.do_test:
        epoch = 0
        print(f"test model... quant params:{args.quant_resume} | scaling params:{args.scaling_resume}")
        val_stats = val_one_epoch(
            model, data_loader_val, optimizer, device, epoch, loss_scaler, log_writer=log_writer, args=args
        )
        log_stats = {f"val_{k}": v for k, v in val_stats.items()}
        print(log_stats)

if __name__ == "__main__":
    # Parse the CLI in one step, then make sure the output directory exists
    # before main() starts writing checkpoints / logs into it.
    args = get_args_parser().parse_args()
    output_dir = args.output_dir
    if output_dir:
        Path(output_dir).mkdir(exist_ok=True, parents=True)
    main(args)