# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os

import torch
from peft import prepare_model_for_kbit_training
from torch.nn import CrossEntropyLoss
from tqdm import tqdm
from transformers import Adafactor, AdamW, get_linear_schedule_with_warmup, AutoTokenizer, GenerationConfig
from transformers import AutoModelForCausalLM, MistralForCausalLM, LlamaTokenizer

import inspect
import warnings

class XICLModel(object):

    def __init__(self, gpt2="gpt2-large", logger=None, 
        out_dir=None, fp16=False, local_rank=-1, soft_prefix=False, n_tokens=10, 
        prefix_embed_file=None, task_counts=None, max_length=None, data=None, use_presfix=False):
        """Pick the device, load the language model and optionally append soft-prefix embeddings.

        Args:
            gpt2: Model name or path handed to ``self.load``.
            logger: Object exposing ``info(text)``; a print-based fallback is used when ``None``.
            out_dir: Directory stored on the instance (``save`` writes into it).
            fp16: Stored flag consulted by ``setup_optimizer`` and ``do_train``.
            local_rank: ``-1`` for single-process mode; any other value is used as the CUDA
                device index and triggers ``torch.distributed`` (NCCL) initialisation.
            soft_prefix: When true, the embedding matrix is resized to hold ``self.n_tokens``
                extra rows.
            n_tokens: Number of extra tokens (multiplied by ``len(task_counts)`` when given).
            prefix_embed_file: Optional path loaded with ``torch.load`` and installed as the
                model's input embeddings.
            task_counts: Optional sized collection; only its length is used.
            max_length: Stored on the instance.
            data: Optional object whose ``tokenizer`` attribute is reused by ``load``.
            use_presfix: Stored flag read by ``load`` when choosing generation limits.
        """
        if logger is None:
            class _PrintLogger():
                def info(self, text):
                    print ("Logging from XICLModel:\t", text)
            logger = _PrintLogger()

        self.logger = logger
        self.out_dir = out_dir
        self.fp16 = fp16
        self.local_rank = local_rank
        self.max_length = max_length

        if self.local_rank != -1:
            # distributed mode: one process per GPU
            torch.cuda.set_device(local_rank)
            device = torch.device("cuda", local_rank)
            world_size = int(os.environ.get("WORLD_SIZE", os.environ.get("SLURM_NTASKS", 1)))
            torch.distributed.init_process_group(backend="nccl")
            n_gpu = 1
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            n_gpu = torch.cuda.device_count()
            world_size = 1

        self.n_gpu = n_gpu
        self.device = device
        if self.local_rank <= 0:
            logger.info("Setting up for local_rank=%d, world_size=%d" % (self.local_rank, world_size))

        self.model_name = None
        self.model = None
        self.mode = None
        self.soft_prefix = soft_prefix
        self.prefix_token_ids = None
        self.data = data
        self.use_presfix = use_presfix
        self.n_tokens = n_tokens if task_counts is None else n_tokens * len(task_counts)

        self.load(gpt2)
        if not soft_prefix:
            return

        n_rows = self.model.get_input_embeddings().weight.size(0)
        # For names containing "bloom" the last 200 rows are not counted as original vocabulary.
        self.orig_vocab_size = n_rows - 200 if "bloom" in self.model_name.lower() else n_rows
        print("original vocab size: ", self.orig_vocab_size)

        self.model.resize_token_embeddings(self.orig_vocab_size + self.n_tokens)
        # Re-fetch the embeddings after the resize before checking the new row count.
        self.new_vocab_size = self.model.get_input_embeddings().weight.size(0)
        assert self.new_vocab_size == self.n_tokens + self.orig_vocab_size
        self.model.config.vocab_size = self.new_vocab_size

        if prefix_embed_file is None:
            # Initialise the appended rows from the first n_tokens rows of the matrix.
            table = self.model.get_input_embeddings().weight.data
            table[-self.n_tokens:] = table[:self.n_tokens]
        else:
            self.model.set_input_embeddings(torch.load(prefix_embed_file))
        self.model.tie_weights()

    def __str__(self):
        text = "[Topic_XICL Model]: "
        if self.model_name is None:
            text += "No model loaded yet"
        else:
            text += self.model_name
            if self.mode is None:
                text += " (no mode setted - try .train() or .eval()"
            else:
                text += " (%s mode)" % self.mode
        text += "\nusing device %s, %d gpus, local_rank=%d" % (self.device, self.n_gpu, self.local_rank)
        return ("="*50) + "\n" + text + "\n" + ("="*50)

    def is_none(self):
        """Return True when no model object is attached (``self.model`` is ``None``)."""
        return self.model is None

    def train(self):
        """Switch the wrapped model to training mode and record ``"train"`` in ``self.mode``."""
        self.model.train()
        self.mode = "train"

    def eval(self):
        """Switch the wrapped model to evaluation mode and record ``"eval"`` in ``self.mode``."""
        self.model.eval()
        self.mode = "eval"

    def cuda(self):
        """Call ``cuda()`` on the wrapped model."""
        self.model.cuda()

    def to_device(self):
        """Move the wrapped model to ``self.device`` (chosen in ``__init__``)."""
        self.model.to(self.device)

    def prepare_model_for_kbit_training(self, model, use_gradient_checkpointing=True, gradient_checkpointing_kwargs=None):
        r"""
        Note this method only works for `transformers` models.

        This method wraps the entire protocol for preparing a model before running a training. This includes:
            1- Cast the layernorm in fp32 2- making output embedding layer require grads 3- Add the upcasting of the lm
            head to fp32

        Args:
            model (`transformers.PreTrainedModel`):
                The loaded model from `transformers`
            use_gradient_checkpointing (`bool`, *optional*, defaults to `True`):
                If True, use gradient checkpointing to save memory at the expense of slower backward pass.
            gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
                Keyword arguments to pass to the gradient checkpointing function, please refer to the documentation of
                `torch.utils.checkpoint.checkpoint` for more details about the arguments that you can pass to that method.
                Note this is only available in the latest transformers versions (> 4.34.1).
        """
        loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
        is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq"
        is_aqlm_quantized = getattr(model, "quantization_method", None) == "aqlm"
        if gradient_checkpointing_kwargs is None:
            gradient_checkpointing_kwargs = {}

        for name, param in model.named_parameters():
            # freeze base model's layers
            param.requires_grad = False
        model.get_input_embeddings().weight.requires_grad = True

        if not is_gptq_quantized and not is_aqlm_quantized:
            # cast all non INT8 parameters to fp32
            for param in model.parameters():
                if (param.dtype == torch.float16) or (param.dtype == torch.bfloat16):
                    param.data = param.data.to(torch.float32)

        if (loaded_in_kbit or is_gptq_quantized or is_aqlm_quantized) and use_gradient_checkpointing:
            # When having `use_reentrant=False` + gradient_checkpointing, there is no need for this hack
            if "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]:
                # For backward compatibility
                if hasattr(model, "enable_input_require_grads"):
                    model.enable_input_require_grads()
                else:

                    def make_inputs_require_grad(output):
                        output.requires_grad_(True)

                    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

            # To support older transformers versions, check if the model supports gradient_checkpointing_kwargs
            _supports_gc_kwargs = "gradient_checkpointing_kwargs" in list(
                inspect.signature(model.gradient_checkpointing_enable).parameters
            )

            if not _supports_gc_kwargs and len(gradient_checkpointing_kwargs) > 0:
                warnings.warn(
                    "gradient_checkpointing_kwargs is not supported in this version of transformers. The passed kwargs will be ignored."
                    " if you want to use that feature, please upgrade to the latest version of transformers.",
                    FutureWarning,
                )

            gc_enable_kwargs = (
                {} if not _supports_gc_kwargs else {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs}
            )

            # enable gradient checkpointing for memory efficiency
            model.gradient_checkpointing_enable(**gc_enable_kwargs)
        return model

    def load(self, model_name="gpt2-large"):
        """Load weights, tokenizer and generation settings for ``model_name``.

        Loading path:
            * no soft prefix: fp16 ``AutoModelForCausalLM`` followed by ``.cuda()``;
            * soft prefix and "Mistral" in the name: fp16 ``MistralForCausalLM`` followed by ``.cuda()``;
            * soft prefix otherwise: ``device_map='auto'``, every parameter frozen except the
              input-embedding weight, and fp16/bf16 parameters cast to fp32 unless the model
              reports ``quantization_method`` "gptq" or "aqlm".

        The model is put in eval mode and stored on ``self.model``; ``self.tokenizer`` comes
        from ``self.data`` when that was supplied; ``self.generation_config`` gets its
        ``max_new_tokens`` (and, for soft-prefix decoding, ``min_new_tokens``) set here.
        """
        half = torch.float16
        if not self.soft_prefix:
            model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=half)
            model.cuda()
        elif "Mistral" in model_name:
            model = MistralForCausalLM.from_pretrained(model_name, torch_dtype=half).cuda()
        else:
            model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', torch_dtype=half)
            # freeze all parameters but the (soft prefix) input embeddings
            for param in model.parameters():
                param.requires_grad = False
            model.get_input_embeddings().weight.requires_grad = True

            quantization = getattr(model, "quantization_method", None)
            quantized = quantization == "gptq" or quantization == "aqlm"
            if not quantized:
                # cast all non INT8 parameters to fp32
                for param in model.parameters():
                    if param.dtype == torch.float16 or param.dtype == torch.bfloat16:
                        param.data = param.data.to(torch.float32)

        self.model_name = model_name
        model.eval()
        # torch.compile is only applied on one specific nightly build.
        self.model = torch.compile(model) if torch.__version__ == '1.14.0.dev20221208+cu117' else model

        if self.data is not None:
            self.tokenizer = self.data.tokenizer
        elif "Mistral" in self.model_name:
            tokenizer = LlamaTokenizer.from_pretrained(
                model_name, trust_remote_code=True, padding_side="left", legacy=False)
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.unk_token
            self.tokenizer = tokenizer
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

        if 'bloom' in model_name or 'polylm' in model_name:
            self.generation_config = self.model.generation_config
        else:
            self.generation_config = GenerationConfig.from_pretrained(model_name, "generation_config.json")

        if 't5' in model_name:
            self.generation_config.max_new_tokens = self.max_length + 32
        elif not self.soft_prefix or self.use_presfix:
            self.generation_config.max_new_tokens = 32
        else:
            # soft-prefix decoding: emit exactly 15 new tokens
            self.generation_config.max_new_tokens = 15
            self.generation_config.min_new_tokens = 15

    def force_word(self):
        force_word = self.prefix_token_ids
        force_words_ids = [[y] for key,x in force_word.items() for y in x]
        self.generation_config.force_words_ids = force_words_ids
        self.generation_config.num_beams=5

    def save(self, step, save_all=False):
        if self.local_rank <= 0:
            if save_all:
                model_state_dict = {key[7:] if key.startswith("module.") else key: value.cpu()
                                    for key, value in self.model.state_dict().items()}
                torch.save(model_state_dict, os.path.join(self.out_dir, "model-{}.pt".format(step)))
                self.logger.info("Saving model parameters at step=%d" % step)
            else:
                torch.save(self.model.get_input_embeddings(),
                    os.path.join(self.out_dir, "soft_embeddings-{}.pt".format(step)))

    def setup_optimizer(self, optimization, num_training_steps, lr, weight_decay, warmup_steps):
        # no_decay = ['bias', 'LayerNorm.weight']
        # optimizer_grouped_parameters = [
        #         {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
        #         {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        # ]

        optimizer_grouped_parameters = [
                {'params': self.model.get_input_embeddings().weight, 'weight_decay': weight_decay}
        ]
        print("fine tune parameters: ", optimizer_grouped_parameters)

        if optimization=="adafactor":
            optimizer = Adafactor(optimizer_grouped_parameters,
                                  lr=lr,
                                  relative_step=False,
                                  warmup_init=False,
                                  weight_decay=weight_decay)
            scheduler = None
        elif optimization.startswith("adamw"):
            optimizer = AdamW(optimizer_grouped_parameters,
                              lr=lr,
                              eps=1e-08,
                              weight_decay=weight_decay)
            if self.fp16:
                self.model, optimizer = setup_fp16(self.model, optimizer)
            if optimization=="adamw":
                scheduler = get_linear_schedule_with_warmup(optimizer,
                                                            num_warmup_steps=warmup_steps,
                                                            num_training_steps=num_training_steps)
            else:
                raise NotImplementedError()
        elif optimization=="8bit-adam":
            import bitsandbytes as bnb
            optimizer = bnb.optim.Adam8bit(optimizer_grouped_parameters,
                                           lr=lr, betas=(0.9, 0.995))
            if self.fp16:
                self.model, optimizer = setup_fp16(self.model, optimizer)
            scheduler = get_linear_schedule_with_warmup(optimizer,
                                                        num_warmup_steps=warmup_steps,
                                                        num_training_steps=num_training_steps)
        else:
            raise NotImplementedError()

        self.optimizer = optimizer
        self.scheduler = scheduler

    def parallel(self):
        if self.n_gpu > 1:
            self.model = torch.nn.DataParallel(self.model)

        if self.local_rank != -1:
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model, device_ids=[self.local_rank], output_device=self.local_rank)


    def do_train(self, data, batch_size, num_training_steps, save_period, log_period,
                 gradient_accumulation_steps=1, max_grad_norm=1):
        dataloader = data.get_dataloader(batch_size, is_training=True)

        n_all_params = sum(p.numel() for p in self.model.parameters())
        n_trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        # n_trainable_params = len([param for param in self.model.parameters() if param.requires_grad])
        n_gpus = torch.cuda.device_count()
        self.logger.info("Training {} parameters on {} examples for {} steps using {} GPUs".format(
            n_trainable_params, len(data), num_training_steps, self.n_gpu))
        print("All {} Training {} parameters on {} examples for {} steps using {} GPUs".format(n_all_params,
            n_trainable_params, len(data), num_training_steps, self.n_gpu))
        global_step = 0
        train_losses = 0.0
        best_accuracy = -1
        stop_training=False
        step = 0
        logging_loss = 0.0

        for epoch in range(num_training_steps):
            print("epoch: ", epoch)
            for batch in dataloader:
                global_step += 1

                input_ids=batch[0].to(self.device)
                attention_mask=batch[1].to(self.device)
                output_ids=batch[2].to(self.device)
                token_type_ids = batch[4].to(self.device)

                loss = self.run_model(input_ids, attention_mask, output_ids, token_type_ids)

                loss = loss.mean()
                #loss = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=output_ids).loss
                if gradient_accumulation_steps > 1:
                    loss = loss / gradient_accumulation_steps


                if torch.isnan(loss).data:
                    print ("Stop training because loss=%s" % (loss.data))
                    stop_training=True
                    break
                # train_losses.append(loss.detach().cpu())
                train_losses += loss.item()
                if self.fp16:
                    from apex import amp
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    # self.optimizer.zero_grad()
                    loss.backward()

                self.model.get_input_embeddings().weight.grad[:self.orig_vocab_size] = 0
                #self.model.lm_head.weight.grad[:self.orig_vocab_size] = 0

                if global_step % gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)

                    self.optimizer.step()    # We have accumulated enought gradients
                    if self.scheduler is not None:
                        self.scheduler.step()
                    self.model.zero_grad()
                    step += 1

                    if step % log_period == 0:
                        self.logger.info("local rank %d\tglobal step %d\ttrain loss %.2f" % (self.local_rank, global_step, (train_losses - logging_loss) / step))
                        logging_loss = train_losses

                    if global_step % save_period == 0:
                        self.save(global_step)

                    if global_step==num_training_steps:
                        break

            if global_step==num_training_steps:
                break

        self.logger.info("Finish training")

    def do_inference(self, data, batch_size=1, verbose=False):
        dataloader = data.get_dataloader(batch_size, is_training=False)
        if verbose:
            dataloader = tqdm(dataloader)
        losses = []
        for batch in dataloader:
            input_ids=batch[0].cuda()
            attention_mask=batch[1].cuda()
            output_ids=batch[2].cuda()
            token_type_ids = batch[4].cuda()
            with torch.no_grad():
                loss = self.run_model(input_ids, attention_mask, output_ids, token_type_ids)
            losses += [loss.item()]
        return losses

    def do_predict(self, data, batch_size=1, losses=None, verbose=False, return_nll=False, use_prefix=False):
        """Generate a prediction for every example and key the results by example index.

        Args:
            data: Object providing ``get_dataloader``, ``len()`` and ``metadata`` (a sequence
                of dicts with ``"indices"`` and ``"answer"`` entries).
            batch_size: Forwarded to ``data.get_dataloader``.
            losses: Pre-computed per-example losses; computed with ``do_inference`` when ``None``.
            verbose: Forwarded to ``do_inference``.
            return_nll: When true, also return the losses and the gold answers.
            use_prefix: When the model uses a soft prefix and this is false, ``batch[3]`` is
                used as the output mask; otherwise ``batch[4]``.

        Returns:
            ``pred_dict`` mapping ``meta['indices']`` to the first prediction seen for it, or
            ``(pred_dict, losses, gt_labels)`` when ``return_nll`` is true.
        """
        if losses is None:
            losses = self.do_inference(data, batch_size, verbose=verbose)
        assert len(losses)==len(data)

        dataloader = data.get_dataloader(batch_size, is_training=False)
        dataloader = tqdm(dataloader, desc='Infering the test set')
        qas_id = [meta['indices'] for meta in data.metadata]
        mask_slot = 3 if (self.soft_prefix and not use_prefix) else 4

        predictions = []
        for batch in dataloader:
            # Use the device chosen in __init__ (same as do_train) instead of assuming CUDA.
            input_ids = batch[0].to(self.device)
            output_mask = batch[mask_slot].to(self.device)
            predictions.extend(self.run_model_predict(input_ids, output_mask))
        assert len(predictions) == len(qas_id)

        gt_labels = [dp["answer"] for dp in data.metadata]
        pred_dict = {}
        for idx, preds in zip(qas_id, predictions):
            # keep only the first prediction for a repeated index
            pred_dict.setdefault(idx, preds)

        if return_nll:
            return pred_dict, losses, gt_labels
        return pred_dict

    def run_model(self, input_ids, attention_mask, output_ids, token_type_ids):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=output_ids)

        label_mask = token_type_ids[..., 1:].contiguous()
        logits = outputs.logits[..., :-1, :].contiguous()
        labels = output_ids.to(logits.device)
        # Shift so that tokens < n predict n
        shift_logits = logits
        shift_labels = labels[..., 1:].contiguous()
        batch_size, seq_length, vocab_size = shift_logits.shape
        # Flatten the tokens
        loss_fct = CrossEntropyLoss(reduction="none")
        loss = loss_fct(
            shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
        )

        losses = loss.view(logits.size(0), logits.size(1)) * label_mask
        return torch.sum(losses, axis=1) / torch.sum(label_mask, axis=1)

    def run_model_predict(self, input_ids, output_mask):
        results = []
        # if self.generation_config is not None:
        for i in range(input_ids.size()[0]):
            input_id = input_ids[i]
            try:
                output_m_id = output_mask[i].cpu().numpy().tolist().index(0)
                if output_m_id == 0:
                    output_m_id = output_mask[i].cpu().numpy().tolist().index(1)
            except:
                output_m_id = input_id.size()[0]
            input_id = input_id[:output_m_id].unsqueeze(0)

            with torch.no_grad():
                outputs = self.model.generate(input_ids=input_id, generation_config=self.generation_config)
                inputs_text = self.tokenizer.batch_decode(input_id, skip_special_tokens=False, clean_up_tokenization_spaces=True)
                rets = self.tokenizer.batch_decode(outputs, skip_special_tokens=False, clean_up_tokenization_spaces=True)
                r = rets[0].replace(inputs_text[0], "").strip()
                out = outputs[:, output_m_id:]
                rets_ = self.tokenizer.batch_decode(out, skip_special_tokens=False,
                                                   clean_up_tokenization_spaces=True)
                print(rets_)
                results.append(r)

        return results

    def get_label(self, input_ids, attention_mask, output_ids, token_type_ids):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=output_ids)

        label_mask = token_type_ids[..., 1:].contiguous()
        logits = outputs.logits[..., :-1, :].contiguous()
        labels = output_ids.to(logits.device)
        # Shift so that tokens < n predict n
        shift_logits = logits
        shift_labels = labels[..., 1:].contiguous()
        batch_size, seq_length, vocab_size = shift_logits.shape
        # Flatten the tokens
        loss_fct = CrossEntropyLoss(reduction="none")
        loss = loss_fct(
            shift_logits.view(batch_size * seq_length, vocab_size), shift_labels.view(batch_size * seq_length)
        )

        losses = loss.view(logits.size(0), logits.size(1)) * label_mask
        for i in range(input_ids.size()[0]):
            output_start = output_ids[i].cpu().numpy().tolist().index(0)
            output_end = output_start + len(self.prefix_token_ids["t0"])
            pre_label = outputs.logits[i,output_start:output_end,self.orig_vocab_size:]  # .argmax(dim=-1) + self.orig_vocab_size
            a, idx1 = torch.sort(pre_label, descending=True)  # descending为alse，升序，为True，降序
            idx = idx1[0,:len(self.prefix_token_ids["t0"])] + self.orig_vocab_size
            label_pre = []
            for j in idx.cpu().numpy().tolist():
                for key,value in self.prefix_token_ids.items():
                    if j in value:
                        label_pre.append(key+"-"+str(value.index(j)))
        return torch.sum(losses, axis=1) / torch.sum(label_mask, axis=1)


class XICLModel_xnli(object):

    def __init__(self, gpt2="gpt2-large", logger=None,
        out_dir=None, fp16=False, local_rank=-1, soft_prefix=False, n_tokens=10,
        prefix_embed_file=None, task_counts=None, max_length=None, data=None, use_presfix=False):
        """Pick the device, load the language model and optionally append soft-prefix embeddings.

        Args:
            gpt2: Model name or path handed to ``self.load``.
            logger: Object exposing ``info(text)``; a print-based fallback is used when ``None``.
            out_dir: Directory stored on the instance (``save`` writes into it).
            fp16: Stored flag consulted by ``setup_optimizer``.
            local_rank: ``-1`` for single-process mode; any other value is used as the CUDA
                device index and triggers ``torch.distributed`` (NCCL) initialisation.
            soft_prefix: When true, the embedding matrix is resized to hold ``self.n_tokens``
                extra rows.
            n_tokens: Number of extra tokens (multiplied by ``len(task_counts)`` when given).
            prefix_embed_file: Optional path to a state dict (read with ``torch.load``) that is
                loaded into the resized input embeddings.
            task_counts: Optional sized collection; only its length is used.
            max_length: Stored on the instance.
            data: Optional object whose ``tokenizer`` attribute is reused by ``load``.
            use_presfix: Stored flag read by ``load`` when choosing the loading path and the
                generation limits.
        """
        if logger is None:
            class _PrintLogger():
                def info(self, text):
                    print ("Logging from XICLModel:\t", text)
            logger = _PrintLogger()

        self.use_presfix = use_presfix
        self.logger = logger
        self.out_dir = out_dir
        self.fp16 = fp16
        self.local_rank = local_rank
        self.max_length = max_length

        if self.local_rank != -1:
            # distributed mode: one process per GPU
            torch.cuda.set_device(local_rank)
            device = torch.device("cuda", local_rank)
            world_size = int(os.environ.get("WORLD_SIZE", os.environ.get("SLURM_NTASKS", 1)))
            torch.distributed.init_process_group(backend="nccl")
            n_gpu = 1
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            n_gpu = torch.cuda.device_count()
            world_size = 1

        self.n_gpu = n_gpu
        self.device = device
        if self.local_rank <= 0:
            logger.info("Setting up for local_rank=%d, world_size=%d" % (self.local_rank, world_size))

        self.model_name = None
        self.model = None
        self.mode = None
        self.soft_prefix = soft_prefix
        self.prefix_token_ids = None
        self.data = data
        self.n_tokens = n_tokens if task_counts is None else n_tokens * len(task_counts)

        self.load(gpt2)
        if not soft_prefix:
            return

        n_rows = self.model.get_input_embeddings().weight.size(0)
        # For names containing "bloom" the last 200 rows are not counted as original vocabulary.
        self.orig_vocab_size = n_rows - 200 if "bloom" in self.model_name.lower() else n_rows
        print("original vocab size: ", self.orig_vocab_size)

        self.model.resize_token_embeddings(self.orig_vocab_size + self.n_tokens)
        # Re-fetch the embeddings after the resize before checking the new row count.
        self.new_vocab_size = self.model.get_input_embeddings().weight.size(0)
        assert self.new_vocab_size == self.n_tokens + self.orig_vocab_size
        self.model.config.vocab_size = self.new_vocab_size

        if prefix_embed_file is None:
            # Initialise the appended rows from the first n_tokens rows of the matrix.
            table = self.model.get_input_embeddings().weight.data
            table[-self.n_tokens:] = table[:self.n_tokens]
        else:
            self.model.get_input_embeddings().load_state_dict(torch.load(prefix_embed_file))
        self.model.tie_weights()

    def __str__(self):
        text = "[Topic_XICL Model]: "
        if self.model_name is None:
            text += "No model loaded yet"
        else:
            text += self.model_name
            if self.mode is None:
                text += " (no mode setted - try .train() or .eval()"
            else:
                text += " (%s mode)" % self.mode
        text += "\nusing device %s, %d gpus, local_rank=%d" % (self.device, self.n_gpu, self.local_rank)
        return ("="*50) + "\n" + text + "\n" + ("="*50)

    def is_none(self):
        """Return True when no model object is attached (``self.model`` is ``None``)."""
        return self.model is None

    def train(self):
        """Switch the wrapped model to training mode and record ``"train"`` in ``self.mode``."""
        self.model.train()
        self.mode = "train"

    def eval(self):
        """Switch the wrapped model to evaluation mode and record ``"eval"`` in ``self.mode``."""
        self.model.eval()
        self.mode = "eval"

    def cuda(self):
        """Call ``cuda()`` on the wrapped model."""
        self.model.cuda()

    def to_device(self):
        """Move the wrapped model to ``self.device`` (chosen in ``__init__``)."""
        self.model.to(self.device)

    def load(self, model_name="gpt2-large"):
        """Load weights, tokenizer and generation settings for ``model_name``.

        Without a soft prefix (or when ``self.use_presfix`` is set) the fp16 model is loaded
        and moved with ``.cuda()``.  Otherwise it is loaded with ``device_map='auto'``, every
        parameter except the input-embedding weight is frozen, and fp16/bf16 parameters are
        cast to fp32 unless the model reports ``quantization_method`` "gptq" or "aqlm".

        ``self.tokenizer`` comes from ``self.data`` when that was supplied;
        ``self.generation_config`` receives its token limits (and, for soft-prefix decoding,
        ``suppress_tokens`` and ``no_repeat_ngram_size``) here.
        """
        plain_loading = not self.soft_prefix or self.use_presfix
        if plain_loading:
            model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
            model.cuda()
        else:
            model = AutoModelForCausalLM.from_pretrained(model_name, device_map='auto', torch_dtype=torch.float16)
            # freeze all parameters but the (soft prefix) input embeddings
            for param in model.parameters():
                param.requires_grad = False
            model.get_input_embeddings().weight.requires_grad = True

            quantization = getattr(model, "quantization_method", None)
            quantized = quantization == "gptq" or quantization == "aqlm"
            if not quantized:
                # cast all non INT8 parameters to fp32
                for param in model.parameters():
                    if param.dtype == torch.float16 or param.dtype == torch.bfloat16:
                        param.data = param.data.to(torch.float32)

        self.model_name = model_name
        # torch.compile is only applied on one specific nightly build.
        self.model = torch.compile(model) if torch.__version__ == '1.14.0.dev20221208+cu117' else model

        if self.data is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        else:
            self.tokenizer = self.data.tokenizer

        if 'bloom' in model_name:
            self.generation_config = self.model.generation_config
        else:
            self.generation_config = GenerationConfig.from_pretrained(model_name, "generation_config.json")

        if 't5' in model_name:
            self.generation_config.max_new_tokens = self.max_length + 32
        elif not self.soft_prefix or self.use_presfix:
            self.generation_config.max_new_tokens = 2
        else:
            # Suppress ids 4 .. (embedding rows - 200) - 1 and emit exactly 10 new tokens.
            n_rows = self.model.get_input_embeddings().weight.size(0)
            self.generation_config.suppress_tokens = list(range(4, n_rows - 200))
            self.generation_config.max_new_tokens = 10
            self.generation_config.no_repeat_ngram_size = 2
            self.generation_config.min_new_tokens = 10

    def force_word(self):
        force_word = self.prefix_token_ids
        force_words_ids = [[y] for key,x in force_word.items() for y in x]
        self.generation_config.force_words_ids = force_words_ids
        self.generation_config.num_beams=5

    def save(self, step, save_all=False):
        if self.local_rank <= 0:
            if save_all:
                model_state_dict = {key[7:] if key.startswith("module.") else key: value.cpu()
                                    for key, value in self.model.state_dict().items()}
                torch.save(model_state_dict, os.path.join(self.out_dir, "model-{}.pt".format(step)))
                self.logger.info("Saving model parameters at step=%d" % step)
            else:
                torch.save(self.model.get_input_embeddings().state_dict(),
                    os.path.join(self.out_dir, "soft_embeddings-{}.pt".format(step)))

    def setup_optimizer(self, optimization, num_training_steps, lr, weight_decay, warmup_steps):
        # no_decay = ['bias', 'LayerNorm.weight']
        # optimizer_grouped_parameters = [
        #         {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
        #         {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        # ]

        optimizer_grouped_parameters = [
                {'params': self.model.get_input_embeddings().weight, 'weight_decay': weight_decay}
        ]
        print("fine tune parameters: ", optimizer_grouped_parameters)

        if optimization=="adafactor":
            optimizer = Adafactor(optimizer_grouped_parameters,
                                  lr=lr,
                                  relative_step=False,
                                  warmup_init=False,
                                  weight_decay=weight_decay)
            scheduler = None
        elif optimization.startswith("adamw"):
            optimizer = AdamW(optimizer_grouped_parameters,
                              lr=lr,
                              eps=1e-08,
                              weight_decay=weight_decay)
            if self.fp16:
                self.model, optimizer = setup_fp16(self.model, optimizer)
            if optimization=="adamw":
                scheduler = get_linear_schedule_with_warmup(optimizer,
                                                            num_warmup_steps=warmup_steps,
                                                            num_training_steps=num_training_steps)
            else:
                raise NotImplementedError()
        elif optimization=="8bit-adam":
            import bitsandbytes as bnb
            optimizer = bnb.optim.Adam8bit(optimizer_grouped_parameters,
                                           lr=lr, betas=(0.9, 0.995))
            if self.fp16:
                self.model, optimizer = setup_fp16(self.model, optimizer)
            scheduler = get_linear_schedule_with_warmup(optimizer,
                                                        num_warmup_steps=warmup_steps,
                                                        num_training_steps=num_training_steps)
        else:
            raise NotImplementedError()

        self.optimizer = optimizer
        self.scheduler = scheduler

    def parallel(self):
        if self.n_gpu > 1:
            self.model = torch.nn.DataParallel(self.model)

        if self.local_rank != -1:
            self.model = torch.nn.parallel.DistributedDataParallel(
                self.model, device_ids=[self.local_rank], output_device=self.local_rank)


    def do_train(self, data, batch_size, num_training_steps, save_period, log_period,
                 gradient_accumulation_steps=1, max_grad_norm=1.0):
        dataloader = data.get_dataloader(batch_size, is_training=True)
        n_all_params = sum(p.numel() for p in self.model.parameters())
        n_trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        # n_trainable_params = len([param for param in self.model.parameters() if param.requires_grad])
        n_gpus = torch.cuda.device_count()
        self.logger.info("Training {} parameters on {} examples for {} steps using {} GPUs".format(
            n_trainable_params, len(data), num_training_steps, self.n_gpu))
        print("All {} Training {} parameters on {} examples for {} steps using {} GPUs".format(n_all_params,
            n_trainable_params, len(data), num_training_steps, self.n_gpu))

        global_step = 0
        train_losses = 0.0
        best_accuracy = -1
        stop_training=False
        step = 0
        logging_loss = 0.0
        for epoch in range(num_training_steps):
            print("epoch: ", epoch)
            for batch in dataloader:
                global_step += 1

                input_ids=batch[0].to(self.device)
                attention_mask=batch[1].to(self.device)
                output_ids=batch[2].to(self.device)

                outputs = self.run_model(input_ids, attention_mask, output_ids)
                loss = outputs.loss.mean()
                if gradient_accumulation_steps > 1:
                    loss = loss / gradient_accumulation_steps

                if torch.isnan(loss).data:
                    print ("Stop training because loss=%s" % (loss.data))
                    stop_training=True
                    break
                # train_losses.append(loss.detach().cpu())
                train_losses += loss.item()
                if self.fp16:
                    from apex import amp
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # self.model.get_input_embeddings().weight.grad[:self.orig_vocab_size] = 0
                if step == 0:
                    n_all_params = sum(p.numel() for p in self.model.parameters())
                    n_trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
                    # n_trainable_params = len([param for param in self.model.parameters() if param.requires_grad])
                    n_gpus = torch.cuda.device_count()
                    self.logger.info("Training {} parameters on {} examples for {} steps using {} GPUs".format(
                        n_trainable_params, len(data), num_training_steps, self.n_gpu))
                    print("All {} Training {} parameters on {} examples for {} steps using {} GPUs".format(n_all_params,
                                                                                                           n_trainable_params,
                                                                                                           len(data),
                                                                                                           num_training_steps,
                                                                                                           self.n_gpu))

                if global_step % save_period == 0:
                    self.save(global_step)
                if global_step % gradient_accumulation_steps == 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)

                    self.optimizer.step()    # We have accumulated enought gradients
                    if self.scheduler is not None:
                        self.scheduler.step()
                    self.model.zero_grad()
                    step += 1

                    if step % log_period == 0:
                        self.logger.info("local rank %d\tglobal step %d\ttrain loss %.2f" % (self.local_rank, global_step, (train_losses - logging_loss) / step))
                        logging_loss = train_losses

                    if global_step==num_training_steps:
                        break

            if global_step==num_training_steps:
                break

        self.logger.info("Finish training")

    def do_inference(self, data, batch_size=1, verbose=False):
        dataloader = data.get_dataloader(batch_size, is_training=False)
        if verbose:
            dataloader = tqdm(dataloader)
        losses = []
        for batch in dataloader:
            input_ids=batch[0].cuda()
            attention_mask=batch[1].cuda()
            output_ids=batch[2].cuda()
            with torch.no_grad():
                loss = self.run_model(input_ids, attention_mask, output_ids)
                # self.run_model_predict(input_ids, attention_mask)
            losses += [loss.loss.item()]
        return losses

    def do_predict(self, data, batch_size=1, losses=None, verbose=False, return_nll=False):
        if not self.use_presfix:
            if losses is None:
                losses = self.do_inference(data, batch_size, verbose=verbose)
            # losses = np.array(losses)
            assert len(losses)==len(data)

        dataloader = data.get_dataloader(batch_size, is_training=False)
        dataloader = tqdm(dataloader, desc='Infering the test set')
        predictions = []
        qas_id = [meta['indices'] for meta in data.metadata]
        for batch in dataloader:
            input_ids = batch[0].cuda()
            output_mask = batch[4].cuda()
            if "polylm" in self.model_name:
                input_mask = batch[1].cuda()
                preds = self.run_model_polylm(input_ids, output_mask, input_mask)
            else:
                preds = self.run_model1(input_ids, output_mask)
            predictions.extend(preds)
        assert len(predictions) == len(qas_id)
        gt_labels = []
        for idx, dp in enumerate(data.metadata):
            gt_labels.append(dp["answer"])
        preds_dict = {}
        for idx, pred in zip(qas_id, predictions):
            if idx not in preds_dict.keys():
                preds_dict[idx] = pred
        if return_nll:
            return preds_dict, losses, gt_labels
        else:
            return preds_dict

    def run_model(self, input_ids, attention_mask, output_ids):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=output_ids)
        # losses = outputs.loss
        return outputs

    def run_model1(self, input_ids, output_mask):
        results = []
        # if self.generation_config is not None:
        for i in range(input_ids.size()[0]):
            input_id = input_ids[i]
            try:
                output_m_id = output_mask[i].cpu().numpy().tolist().index(1)
            except:
                output_m_id = input_id.size()[0]
            input_id = input_id[:output_m_id].unsqueeze(0)

            with torch.no_grad():
                outputs = self.model.generate(input_ids=input_id, generation_config=self.generation_config)
                inputs_text = self.tokenizer.batch_decode(input_id, skip_special_tokens=True,
                                                          clean_up_tokenization_spaces=True)
                rets = self.tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                r = rets[0].strip().replace(inputs_text[0], "")
                if r == '' or r.strip() == '' or r == '.':
                    r = 'no answer'
                results.append(r)

        return results

    def run_model_polylm(self, input_ids, output_mask, input_mask):
        results = []
        # if self.generation_config is not None:
        pad_token_id = self.tokenizer.eos_token_id
        for i in range(input_ids.size()[0]):
            input_id = input_ids[i]
            try:
                output_m_id = output_mask[i].cpu().numpy().tolist().index(1)
            except:
                output_m_id = input_id.size()[0]
            input_id = input_id[:output_m_id].unsqueeze(0)

            with torch.no_grad():
                outputs = self.model.generate(input_ids=input_id, attention_mask=input_mask,
                                              generation_config=self.generation_config, pad_token_id=pad_token_id,
                                              do_sample=False,
                                              num_beams=4,
                                              max_length=128,
                                              early_stopping=True)
                inputs_text = self.tokenizer.batch_decode(input_id, skip_special_tokens=True,
                                                          clean_up_tokenization_spaces=True)
                rets = self.tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                r = rets[0].strip().replace(inputs_text[0], "")
                if r == '' or r.strip() == '' or r == '.':
                    r = 'no answer'
                results.append(r)
                print(r)

        return results

    def run_model_predict(self, input_ids, output_mask):
        results = []
        # if self.generation_config is not None:
        for i in range(input_ids.size()[0]):
            input_id = input_ids[i]
            try:
                output_m_id = output_mask[i].cpu().numpy().tolist().index(0)
                if output_m_id == 0:
                    output_m_id = output_mask[i].cpu().numpy().tolist().index(1)
            except:
                output_m_id = input_id.size()[0]
            input_id = input_id[:output_m_id].unsqueeze(0)

            with torch.no_grad():

                outputs = self.model.generate(input_ids=input_id, generation_config=self.generation_config)
                inputs_text = self.tokenizer.batch_decode(input_id, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                rets = self.tokenizer.batch_decode(outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True)
                r = rets[0].replace(inputs_text[0], "").strip()
                out = outputs[:, output_m_id:]
                rets_ = self.tokenizer.batch_decode(out, skip_special_tokens=False,
                                                   clean_up_tokenization_spaces=True)
                print(rets_)
                #results.append(r)

        #return results


def setup_fp16(model, optimizer):
    """Initialise apex mixed precision for ``model`` and ``optimizer``.

    Registers ``torch.einsum`` as a half-precision function with apex amp, then
    calls ``amp.initialize`` at opt level "O1".

    Returns:
        The ``(model, optimizer)`` pair returned by ``amp.initialize``.
    """
    import apex
    from apex import amp

    apex.amp.register_half_function(torch, "einsum")

    amp_model, amp_optimizer = amp.initialize(model, optimizer, opt_level="O1")
    return amp_model, amp_optimizer



