import warnings
from dataclasses import dataclass, field
from typing import Optional

import torch
from peft.utils import _prepare_prompt_learning_config
from torch.nn import CrossEntropyLoss
from transformers import AutoTokenizer, AutoModelForCausalLM, BloomPreTrainedModel, PreTrainedModel

from functools import partial
from peft.peft_model import _get_batch_size
import numpy as np
from enhanced_models import BloomForCausalLMEnhanced, OPTForCausalLMEnhanced, GPTJForCausalLMEnhanced
from peft import (
    LoraConfig,
    PeftConfig,
    PeftModel,
    get_peft_model,
    get_peft_config,
    get_peft_model,
    PromptTuningInit,
    PromptTuningConfig,
    TaskType,
    PeftType,
    PrefixTuningConfig, PeftModelForCausalLM, PeftMixedModel, MODEL_TYPE_TO_PEFT_MODEL_MAPPING
)

# Short model aliases mapped to the checkpoint identifiers that ``load_generator``
# passes to ``AutoTokenizer.from_pretrained`` / ``AutoModelForCausalLM.from_pretrained``.
NAMES_TO_CHECKPOINTS = {
    'gptj': 'EleutherAI/gpt-j-6B',
    'gpt-neox': 'EleutherAI/gpt-neox-20b',
    'opt-1.3b': "facebook/opt-1.3b",
    'opt-6.7b': "facebook/opt-6.7b",
    'opt-30b': "facebook/opt-30b",
    'opt-66b': "facebook/opt-66b",
    'bloom-1.7b': 'bigscience/bloom-1b7',
    'bloom-3b': 'bigscience/bloom-3b',
    'bloom-7.1b': 'bigscience/bloom-7b1',
    'pythia-6.9b': 'EleutherAI/pythia-6.9b',
    'pythia-12b': 'EleutherAI/pythia-12b',
    'llama-7b': 'Neko-Institute-of-Science/LLaMA-7B-HF',
    'llama3-8b': 'meta-llama/Meta-Llama-3-8B',  # 'Meta-Llama-3-8B', #'Neko-Institute-of-Science/LLaMA-7B-HF',
    'llama-13b': 'Neko-Institute-of-Science/LLaMA-13B-HF',
    'llama-30b': 'Neko-Institute-of-Science/LLaMA-30B-hf',
    'llama-65b': 'Neko-Institute-of-Science/LLaMA-65B-hf',
    'falcon-1b': 'tiiuae/falcon-rw-1b',
    'falcon-7b': 'tiiuae/falcon-7b',
    'falcon-40b': 'tiiuae/falcon-40b',
    'llama-2-13b-hf': 'meta-llama/Llama-2-13b-hf',
    'llama-2-70b-hf': 'meta-llama/Llama-2-70b-hf',
}

# Per-dataset instruction text. ``load_generator`` looks it up with ``args.dataset[0]``
# and uses it as the prompt-tuning initialisation text.
# NOTE: the 'dbpedia' entry continues its string literal with a backslash-newline, so
# the leading spaces of the continuation line are part of the prompt text itself.
# NOTE(review): "Animal,Plant" has no space after the comma, unlike the other
# categories -- confirm whether that is intended before touching the string.
INITIAL_PROMPT = {
    'sst2': "classify the sentiment of the following text as positive or negative:",
    'dbpedia': "classify the following text into one of the following categories: Company, Educational Institution, Artist, Athlete, Office Holder,\
        Mean Of Transportation, Building, Natural Place, Village, Animal,Plant, Album, Film, Written Work",
    'agnews': "classify the following text into one of the following categories: World, Sports, Business, Technology",
    'trec': "classify the following text into one of the following categories: Description, Entity, Expression, Human, Location, Number",
    'seq_language': 'Guess the next value in the sequence'
}

# Variants of the instructions containing ``{}`` placeholders. Only the 'sst2' and
# 'dbpedia' entries have placeholders; 'agnews' and 'trec' are plain strings identical
# to their ``INITIAL_PROMPT`` counterparts. Not referenced in the visible part of this
# file -- presumably filled in via ``str.format`` by a caller; verify before relying on it.
# The 'dbpedia' entry also uses a backslash-newline continuation, so the indentation of
# its second line is embedded in the string.
INITIAL_PROMPT_TEMPLETE = {
    'sst2': "classify the sentiment of the following {} {} {} positive {} or {} negative {}",
    'dbpedia': "classify the following {} {} {} Company {} or {} Educational Institution {} or {} Artist {} or {} Athlete {}or {}Office Holder {} or\
        {} Mean Of Transportation {} or {} Building {} or {} Natural Place {} or {} Village {} or {} Animal {} or {} Plant {} or {} Album {} or {} Film {} or {} Written Work {}",
    'agnews': "classify the following text into one of the following categories: World, Sports, Business, Technology",
    'trec': "classify the following text into one of the following categories: Description, Entity, Expression, Human, Location, Number",
}

# Dataset-independent instruction: every dataset maps to the same generic text.
# Not referenced in the visible part of this file.
INITIAL_PROMPT_VANILLA = {
    'sst2': "classify the following text",
    'dbpedia': "classify the following text",
    'agnews': "classify the following text",
    'trec': "classify the following text",
}


class Generator:
    """Pair a language model with its tokenizer.

    The textual representations (``repr`` and ``str``) of a ``Generator`` are
    those of the wrapped model.
    """

    def __init__(self, model, tokenizer):
        self.tokenizer = tokenizer
        self.model = model

    def __str__(self):
        # Delegate to the wrapped model.
        return str(self.model)

    def __repr__(self):
        # Delegate to the wrapped model.
        return repr(self.model)


def _load_pretrained_with_retries(loader, description, max_attempts=3):
    """Call ``loader`` until it succeeds, allowing at most ``max_attempts`` failures.

    Args:
        loader: Zero-argument callable that performs the load and returns the object.
        description: Short noun used in the failure message (e.g. ``'model'``).
        max_attempts: Number of failed calls after which the process is terminated.

    Returns:
        Whatever ``loader`` returns on its first successful call.

    Raises:
        SystemExit: With status 1 once ``max_attempts`` calls have failed.
    """
    for attempt in range(1, max_attempts + 1):
        try:
            return loader()
        except Exception as err:
            # ``Exception`` rather than a bare ``except`` so that Ctrl-C and
            # ``SystemExit`` are not swallowed and retried; the error itself is
            # printed so the reason for the failure is visible.
            print('Cant load {}: {!r}'.format(description, err))
            if attempt == max_attempts:
                raise SystemExit(1) from err


def load_generator(model_name, cache_dir=None, precision='fp16', local_files_only=False, device_map="auto", args=None,
                   ICL_token_ids=None, ICL_mask=None, ICL_text=None):
    """Load a tokenizer and causal LM, optionally wrap the model for tuning, and bundle both.

    Args:
        model_name: Key of ``NAMES_TO_CHECKPOINTS`` selecting the checkpoint.
        cache_dir: Forwarded to both ``from_pretrained`` calls.
        precision: Ignored -- it is overwritten with ``args.fp16`` (see note below).
        local_files_only: Forwarded to both ``from_pretrained`` calls.
        device_map: Forwarded to the model's ``from_pretrained`` call.
        args: Namespace-like object. Despite the ``None`` default it is required:
            ``args.fp16`` and ``args.method_boost_type`` are always read, and further
            attributes are read depending on the selected method.
        ICL_token_ids: Optional token ids; when given they determine the number of
            virtual tokens (and, for ``'prompt_tuning'``, the initial embeddings).
        ICL_mask: Optional mask forwarded to ``OurPromptTuningConfig``.
        ICL_text: Optional initialisation text for ``'prompt_tuning'``.

    Returns:
        A ``Generator`` holding the (possibly wrapped) model and the tokenizer.

    Raises:
        SystemExit: With status 1 if the tokenizer or the model cannot be loaded
            after three attempts (this includes an unknown ``model_name``).
    """
    torch.backends.cudnn.deterministic = True
    # NOTE: the ``precision`` parameter is unconditionally replaced by ``args.fp16``,
    # which is then passed as ``torch_dtype``. Kept as-is for backward compatibility.
    precision = args.fp16

    # The checkpoint lookup stays inside the loader so that an unknown model name is
    # reported and handled exactly like any other load failure.
    tokenizer = _load_pretrained_with_retries(
        lambda: AutoTokenizer.from_pretrained(
            NAMES_TO_CHECKPOINTS[model_name],
            cache_dir=cache_dir,
            padding_side='right',
            trust_remote_code=True,
            local_files_only=local_files_only,
        ),
        'tokenizer',
    )
    model = _load_pretrained_with_retries(
        lambda: AutoModelForCausalLM.from_pretrained(
            NAMES_TO_CHECKPOINTS[model_name],
            cache_dir=cache_dir,
            torch_dtype=precision,
            device_map=device_map,
            trust_remote_code=True,
            local_files_only=local_files_only,
        ),
        'model',
    )

    if 'llama' in model_name:
        # A '[PAD]' token is registered, but the pad id actually used afterwards is
        # the tokenizer's unk token id.
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        tokenizer.pad_token_id = tokenizer.unk_token_id
        model.config.pad_token_id = tokenizer.unk_token_id
    else:
        tokenizer.pad_token = tokenizer.eos_token
        model.config.pad_token_id = model.config.eos_token_id

    if args.method_boost_type == 'ours':
        model = OurModelWrapper(model, args)
    elif args.method_boost_type == 'LORA':
        config = LoraConfig(  # 8-8
            r=args.lora_r, lora_alpha=args.lora_alpha, lora_dropout=0.0,
            target_modules=["embed_tokens", "lm_head", "q_proj", "v_proj"]
        )
        model = get_peft_model(model, config)

    elif args.method_boost_type == 'prompt_tuning':
        tokenizer_name_or_path = NAMES_TO_CHECKPOINTS[model_name]

        initial_prompt = ICL_text
        if initial_prompt is None:
            initial_prompt = INITIAL_PROMPT[args.dataset[0]]

        # Number of virtual tokens: text length by default, overridden by the ICL
        # token ids, and finally by the random-instruction length when enabled.
        initial_prompt_len = len(tokenizer.tokenize(initial_prompt))
        if ICL_token_ids is not None:
            initial_prompt_len = len(ICL_token_ids)
        if args.use_random_instruction:
            initial_prompt_len = args.num_of_learnable_format

        # NOTE(review): ``encode`` may add special tokens that ``tokenize`` does not,
        # in which case ``context_ids`` is longer than ``initial_prompt_len`` --
        # confirm this is intended for the tokenizers in use.
        context_ids = tokenizer.encode(initial_prompt)
        if ICL_token_ids is not None:
            context_ids = ICL_token_ids

        peft_config = PromptTuningConfig(
            task_type=TaskType.CAUSAL_LM,
            prompt_tuning_init=PromptTuningInit.TEXT,
            num_virtual_tokens=initial_prompt_len,
            prompt_tuning_init_text=initial_prompt,  # TODO: update this prompt
            tokenizer_name_or_path=tokenizer_name_or_path,
        )
        model = get_peft_model(model, peft_config)

        prompt_weight = model.prompt_encoder.default.embedding.weight
        if args.use_random_instruction == 0:
            # Re-initialise the prompt embedding from the word embeddings of the
            # context tokens, stored in float32.
            init_token_ids = torch.Tensor(context_ids).long().to(prompt_weight.device)
            word_embedding_weights = model.word_embeddings(init_token_ids).detach().clone()
            word_embedding_weights = word_embedding_weights.to(torch.float32)
        else:
            # Random initialisation: N(0, 0.02^2) with the shape/dtype/device of the
            # existing prompt embedding.
            word_embedding_weights = torch.nn.init.normal_(torch.zeros_like(prompt_weight), mean=0, std=0.02)

        model.prompt_encoder.default.embedding.weight = torch.nn.Parameter(word_embedding_weights)

    elif args.method_boost_type == 'prefix_tuning':
        peft_config = PrefixTuningConfig(task_type=TaskType.CAUSAL_LM, num_virtual_tokens=args.num_virtual_tokens)
        model = get_peft_model(model, peft_config)
    elif args.method_boost_type == 'our_prompt_tuning':
        tokenizer_name_or_path = NAMES_TO_CHECKPOINTS[model_name]
        initial_prompt = INITIAL_PROMPT[args.dataset[0]]
        initial_prompt_len = len(tokenizer.tokenize(initial_prompt)) if ICL_token_ids is None else len(ICL_token_ids)
        peft_config = OurPromptTuningConfig(
            task_type=TaskType.CAUSAL_LM,
            prompt_tuning_init=PromptTuningInit.TEXT,
            num_virtual_tokens=initial_prompt_len,
            prompt_tuning_init_text=initial_prompt,  # TODO: update this prompt
            tokenizer_name_or_path=tokenizer_name_or_path,
            ICL_token_ids=ICL_token_ids,
            ICL_mask=ICL_mask,
            update_tokens=args.update_tokens,
            ICL_loss_pattern=args.ICL_loss_pattern,
            random_losses_num=args.random_losses_num,
            ICL_projection_epsilon=args.epsilon,
            ICL_projection_epsilon_type=args.ICL_projection_epsilon_type,
            remove_train_example_from_icl=args.remove_train_example_from_icl,
            peft_weighted_loss_type=args.peft_weighted_loss_type,
            peft_weighted_loss_decay_factor=args.peft_weighted_loss_decay_factor,
            decay_projection=args.decay_projection,
            decay_projection_base=args.decay_projection_base,
            used_loss_tokens=args.used_loss_tokens,
            force_same_masks=args.force_same_masks,
            ICL_projection_format_epsilon=args.format_epsilon,
            ICL_projection_epsilon_multiplier=args.ICL_projection_epsilon_multiplier,
            ICL_projection_format_epsilon_multiplier=args.ICL_projection_format_epsilon_multiplier,
            number_of_learnable_tokens=args.number_of_learnable_tokens,
            peft_lr=args.peft_lr
        )
        model = our_get_peft_model(model, peft_config)
    elif args.method_boost_type == 'IPT':
        # Same configuration class as 'our_prompt_tuning', but the projection/decay
        # related options are fixed constants instead of being read from ``args``.
        tokenizer_name_or_path = NAMES_TO_CHECKPOINTS[model_name]
        initial_prompt = INITIAL_PROMPT[args.dataset[0]]
        initial_prompt_len = len(tokenizer.tokenize(initial_prompt)) if ICL_token_ids is None else len(ICL_token_ids)
        peft_config = OurPromptTuningConfig(
            task_type=TaskType.CAUSAL_LM,
            prompt_tuning_init=PromptTuningInit.TEXT,
            num_virtual_tokens=initial_prompt_len,
            prompt_tuning_init_text=initial_prompt,  # TODO: update this prompt
            tokenizer_name_or_path=tokenizer_name_or_path,
            ICL_token_ids=ICL_token_ids,
            ICL_mask=ICL_mask,
            update_tokens=args.update_tokens,
            ICL_loss_pattern=args.ICL_loss_pattern,
            random_losses_num=args.random_losses_num,
            ICL_projection_epsilon=args.epsilon,
            ICL_projection_epsilon_type=args.ICL_projection_epsilon_type,
            remove_train_example_from_icl=args.remove_train_example_from_icl,
            peft_weighted_loss_type=args.peft_weighted_loss_type,
            peft_weighted_loss_decay_factor=args.peft_weighted_loss_decay_factor,
            number_of_learnable_tokens=args.number_of_learnable_tokens,
            ICL_projection_epsilon_multiplier=1.0,
            ICL_projection_format_epsilon=1.0,
            ICL_projection_format_epsilon_multiplier=1.0,
            decay_projection=1.0,
            decay_projection_base=1.0,
            force_same_masks=0,
            used_loss_tokens=0,
            peft_lr=1.0
        )
        model = IPT_get_peft_model(model, peft_config)
        if args.use_random_instruction == 1:
            if ICL_mask is not None:
                # Replace only the first ``inst_len`` prompt rows with N(0, 0.02^2)
                # noise; the remaining rows keep their current values.
                inst_len = args.num_of_learnable_format
                new_weights = torch.zeros_like((model.prompt_encoder.default.embedding.weight)).to(
                    model.prompt_encoder.default.embedding.weight.device)
                new_weights[inst_len:] = model.prompt_encoder.default.embedding.weight[inst_len:]
                new_weights[:inst_len] = torch.nn.init.normal_(new_weights[:inst_len], mean=0, std=0.02).to(
                    model.prompt_encoder.default.embedding.weight.device)

                model.prompt_encoder.default.embedding.weight = torch.nn.Parameter(new_weights)

    if args.method_boost_type in ['prefix_tuning', 'LORA', 'prompt_tuning', 'our_prompt_tuning']:
        # Report the number of trainable parameters (in millions).
        num_of_params = sum(parameter.numel() for parameter in model.parameters() if parameter.requires_grad)
        print('{}M params'.format(num_of_params / 1e6))

    generator = Generator(model=model, tokenizer=tokenizer)

    return generator


def pgd_without_projection(args, delta, masked_output, input_mask, masked_input_ids, batch_size,
                           num_of_loss_tokens_edits, vocab_size, seq_len):
    """Take one normalised gradient-descent step on ``delta`` without any projection.

    The cross-entropy between ``masked_output`` and ``masked_input_ids`` is
    differentiated with respect to ``delta``; every position selected by
    ``input_mask`` is then moved by ``args.alpha`` along the negative gradient,
    where the gradient of each position is divided by its own L2 norm.

    Args:
        args: Object providing ``alpha`` (the step length).
        delta: Tensor with ``requires_grad`` that ``masked_output`` was computed from;
            indexed as ``(batch_size, seq_len, ...)``. Updated in place via ``.data``.
        masked_output: Logits, viewable as ``(batch_size * num_of_loss_tokens_edits, vocab_size)``.
        input_mask: Boolean index selecting the positions of ``delta`` to update.
        masked_input_ids: Target ids, viewable as ``(batch_size * num_of_loss_tokens_edits,)``.
        batch_size, num_of_loss_tokens_edits, vocab_size, seq_len: Sizes used for the views.

    Returns:
        The same ``delta`` object, with ``delta.grad`` set to the computed gradient.
    """
    step_size = args.alpha
    criterion = CrossEntropyLoss()
    flat_logits = masked_output.view(batch_size * num_of_loss_tokens_edits, vocab_size)
    flat_targets = masked_input_ids.view(batch_size * num_of_loss_tokens_edits)
    loss = criterion(flat_logits, flat_targets)
    print('pgd_without_projection loss: {}'.format(loss.item()))

    delta.grad = torch.autograd.grad(loss, delta)[0]
    gradient = delta.grad.detach()
    # One L2 norm per (batch, position); positions with a zero gradient produce a
    # division by zero here, which only matters if they are selected by input_mask.
    per_position_norm = gradient.view(batch_size * seq_len, -1).norm(dim=1).view(batch_size, seq_len, 1)
    normalised_step = (step_size * gradient) / per_position_norm
    delta.data[input_mask] += -normalised_step[input_mask]

    return delta


def pgd_with_projection(args, orig_model, input_ids, inputs_embeds, clipped_attention_mask, input_mask,
                        loss_mask,
                        vocab_size, first_delta=None, alpha=None, epsilon=None, display_enhancment_loss=True):
    """Run one normalised gradient step on the input embeddings and rescale the result.

    A zero perturbation is added to ``inputs_embeds`` (together with ``first_delta``),
    the model is run, and the next-token cross-entropy at the positions selected by
    ``loss_mask`` is differentiated with respect to the perturbation. Positions selected
    by ``input_mask`` move by ``alpha`` along the negative, per-position L2-normalised
    gradient. NaN entries are replaced by zero, ``first_delta`` is added, and the
    positions selected by ``input_mask`` are multiplied by
    ``epsilon / max(norm, epsilon) + args.normalization_epsilon``.

    Args:
        args: Object providing ``normalization_epsilon`` and, as fallbacks,
            ``alpha`` and ``epsilon``.
        orig_model: Callable taking ``inputs_embeds`` and ``attention_mask`` keyword
            arguments and returning an object with a ``logits`` attribute.
        input_ids: Token ids; targets are read with ``input_ids[loss_mask]``.
        inputs_embeds: Embeddings of shape ``(batch, seq_len, hidden)``.
        clipped_attention_mask: Attention mask forwarded to ``orig_model``.
        input_mask: Boolean mask of the positions that are perturbed.
        loss_mask: Boolean mask of the target positions; row 0 determines the
            number of loss tokens assumed for every row.
        vocab_size: Size of the last logits dimension.
        first_delta: Optional perturbation applied on top of ``inputs_embeds``
            and added to the result; zeros when omitted.
        alpha: Step length; defaults to ``args.alpha``.
        epsilon: Rescaling radius; defaults to ``args.epsilon``.
        display_enhancment_loss: When true, print the loss.

    Returns:
        Tensor of shape ``(batch, seq_len, hidden)`` with the resulting perturbation.
    """
    step_size = args.alpha if alpha is None else alpha
    radius = args.epsilon if epsilon is None else epsilon
    norm_eps = args.normalization_epsilon

    criterion = CrossEntropyLoss()
    device = inputs_embeds.device
    batch_size, seq_len, hidden_size = inputs_embeds.shape
    num_of_loss_tokens_edits = loss_mask[0].sum()
    full_shape = (batch_size, seq_len, hidden_size)

    delta = torch.zeros(full_shape, device=device, requires_grad=True)
    if first_delta is None:
        first_delta = torch.zeros(full_shape, device=device)

    perturbed_embeds = inputs_embeds + delta + first_delta.expand(batch_size, -1, -1)
    output = orig_model(inputs_embeds=perturbed_embeds, attention_mask=clipped_attention_mask)

    # Logits at position t predict token t + 1, hence the shifted loss mask.
    shift_logits = output.logits[..., :-1, :].contiguous()
    target_logits = shift_logits[loss_mask[:, 1:]].view(batch_size, -1, vocab_size)
    target_ids = input_ids[loss_mask].view(batch_size, -1)

    loss = criterion(
        target_logits.view(batch_size * num_of_loss_tokens_edits, vocab_size),
        target_ids.view(batch_size * num_of_loss_tokens_edits),
    )
    if display_enhancment_loss:
        print('pgd_with_projection loss: {}'.format(loss.item()))

    delta.grad = torch.autograd.grad(loss, delta)[0]
    gradient = delta.grad.detach()
    gradient_norms = gradient.view(batch_size * seq_len, -1).norm(dim=1).view(batch_size, seq_len, 1)
    descent_step = (step_size * gradient) / gradient_norms
    delta.data[input_mask] += -descent_step[input_mask]

    delta.grad.zero_()

    # Copy over everything that is not NaN; NaN entries stay at zero.
    not_nan = ~delta.isnan()
    delta_non_nan = torch.zeros(full_shape, device=device)
    delta_non_nan[not_nan] = delta[not_nan]

    ret_delta = delta_non_nan + first_delta
    clamped_norms = ret_delta.detach().view(batch_size * seq_len, -1).norm(dim=1).clamp(min=radius)
    # NOTE(review): normalization_epsilon is added to the scale factor itself, not to
    # the denominator -- behaviour kept as found; confirm this is intended.
    scale = radius / clamped_norms.view(batch_size, seq_len, 1) + norm_eps
    ret_delta[input_mask] *= scale[input_mask]

    return ret_delta


def iterative_pgd_with_projection(args, orig_model, input_ids, inputs_embeds, clipped_attention_mask, input_mask,
                                  loss_mask,
                                  vocab_size, first_delta=None):
    """Process the batch row by row, accumulating a single shared perturbation.

    For every row a fresh zero perturbation is optimised with one normalised gradient
    step (on top of the perturbation accumulated so far and ``first_delta``), rescaled,
    and folded into a per-position running average. The average plus ``first_delta``
    is rescaled once more at the positions selected by any row of ``input_mask``.

    Args:
        args: Object providing ``alpha``, ``epsilon`` and ``normalization_epsilon``.
        orig_model: Callable taking ``inputs_embeds`` and ``attention_mask`` keyword
            arguments and returning an object with a ``logits`` attribute.
        input_ids: Token ids, one row per batch element.
        inputs_embeds: Embeddings of shape ``(batch, seq_len, hidden)``.
        clipped_attention_mask: Attention mask, one row per batch element.
        input_mask: Boolean mask of the perturbed positions, one row per batch element.
        loss_mask: Boolean mask of the target positions; row 0 determines the number
            of loss tokens assumed for every row.
        vocab_size: Size of the last logits dimension.
        first_delta: Optional ``(1, seq_len, hidden)`` perturbation; zeros when omitted.

    Returns:
        Tensor of shape ``(1, seq_len, hidden)`` with the resulting perturbation.
    """
    num_of_loss_tokens_edits = loss_mask[0].sum()
    batch_size, seq_len, hidden_size = inputs_embeds.shape
    step_size = args.alpha
    radius = args.epsilon
    norm_eps = args.normalization_epsilon
    criterion = CrossEntropyLoss()
    device = inputs_embeds.device

    row_batch = 1  # every row is processed as a batch of one
    row_shape = (1, seq_len, hidden_size)
    if first_delta is None:
        first_delta = torch.zeros(row_shape, device=device)
    running_delta = torch.zeros(row_shape, device=device)
    num_of_updates = torch.zeros((seq_len,)).int().to(device)
    loss_total = 0

    for row in range(batch_size):
        row_embeds = inputs_embeds[row].unsqueeze(0)
        delta = torch.zeros((row_batch, seq_len, hidden_size), device=device,
                            requires_grad=True)

        perturbed_row = row_embeds + delta + running_delta + first_delta
        row_attention = clipped_attention_mask[row].unsqueeze(0)
        output = orig_model(inputs_embeds=perturbed_row, attention_mask=row_attention)

        # Logits at position t predict token t + 1, hence the shifted loss mask.
        shift_logits = output.logits[..., :-1, :].contiguous()
        row_loss_mask = loss_mask[row].unsqueeze(0)
        row_input_mask = input_mask[row].unsqueeze(0)
        row_ids = input_ids[row].unsqueeze(0)
        target_logits = shift_logits[row_loss_mask[:, 1:]].view(row_batch, -1, vocab_size)
        target_ids = row_ids[row_loss_mask].view(row_batch, -1)
        loss = criterion(
            target_logits.view(row_batch * num_of_loss_tokens_edits, vocab_size),
            target_ids.view(row_batch * num_of_loss_tokens_edits),
        )
        loss_total += loss.item()

        delta.grad = torch.autograd.grad(loss, delta)[0]
        gradient = delta.grad.detach()
        gradient_norms = gradient.view(row_batch * seq_len, -1).norm(dim=1)

        descent_step = (step_size * gradient) / gradient_norms.view(row_batch, seq_len, 1)
        delta.data[row_input_mask] += -descent_step[row_input_mask]

        # NOTE(review): this rescaling is driven by the norm of the gradient, not of
        # delta itself -- behaviour kept as found; confirm this is intended.
        step_scale = radius / gradient_norms.clamp(min=radius).unsqueeze(0).view(row_batch, seq_len, 1) + norm_eps
        delta.data[row_input_mask] *= step_scale[row_input_mask]

        delta.grad.zero_()

        # Copy over everything that is not NaN; NaN entries stay at zero.
        not_nan = ~delta.isnan()
        delta_non_nan = torch.zeros(row_shape, device=device)
        delta_non_nan[not_nan] = delta[not_nan]

        # Running mean per position: undo the previous division, add the new step,
        # then divide by the updated count (positions never updated divide by one).
        running_delta = running_delta * num_of_updates.view(1, -1, 1) + delta_non_nan
        num_of_updates += (delta.abs().sum(-1) > 0).int().sum(dim=0)
        safe_counts = num_of_updates.clone()
        safe_counts[safe_counts == 0] = 1
        running_delta = running_delta / safe_counts.view(1, -1, 1)

    print('iterative_pgd_with_projection loss: {}'.format(loss_total / batch_size))

    ret_delta = running_delta + first_delta
    # A position is rescaled if at least one row of the batch selects it.
    ret_input_mask = input_mask.sum(dim=0, keepdim=True).bool()
    clamped_norms = ret_delta.detach().view(row_batch * seq_len, -1).norm(dim=1).clamp(min=radius)
    ret_scale = radius / clamped_norms.unsqueeze(0).expand(
        row_batch, seq_len).view(row_batch, seq_len, 1) + norm_eps
    ret_delta[ret_input_mask] *= ret_scale[ret_input_mask]

    return ret_delta


def pgd_with_regularization(args, delta, masked_output, input_mask, masked_input_ids, batch_size,
                            num_of_loss_tokens_edits, vocab_size, seq_len):
    """Take one normalised gradient step on ``delta`` for a norm-regularised loss.

    The objective is the cross-entropy between ``masked_output`` and
    ``masked_input_ids`` plus ``args.attack_regularization_value`` times the mean
    L2 norm of ``delta`` along dimension 2. Positions selected by ``input_mask`` move by
    ``args.alpha`` along the negative gradient, where the gradient of each position is
    divided by its L2 norm plus ``args.normalization_epsilon``.

    Args:
        args: Object providing ``alpha``, ``normalization_epsilon`` and
            ``attack_regularization_value``.
        delta: Tensor with ``requires_grad`` that ``masked_output`` was computed from;
            indexed as ``(batch_size, seq_len, ...)``. Updated in place via ``.data``.
        masked_output: Logits, viewable as ``(batch_size * num_of_loss_tokens_edits, vocab_size)``.
        input_mask: Boolean index selecting the positions of ``delta`` to update.
        masked_input_ids: Target ids, viewable as ``(batch_size * num_of_loss_tokens_edits,)``.
        batch_size, num_of_loss_tokens_edits, vocab_size, seq_len: Sizes used for the views.

    Returns:
        The same ``delta`` object, with ``delta.grad`` set to the computed gradient.
    """
    step_size = args.alpha
    norm_eps = args.normalization_epsilon
    criterion = CrossEntropyLoss()

    num_targets = batch_size * num_of_loss_tokens_edits
    loss1 = criterion(masked_output.view(num_targets, vocab_size), masked_input_ids.view(num_targets))
    loss2 = args.attack_regularization_value * delta.norm(dim=2).mean()
    loss = loss1 + loss2

    print('pgd_with_regularization loss: {}, loss1: {}, loss2: {}'.format(loss.item(), loss1.item(), loss2.item()))

    delta.grad = torch.autograd.grad(loss, delta)[0]
    gradient = delta.grad.detach()
    per_position_norm = gradient.view(batch_size * seq_len, -1).norm(dim=1).view(batch_size, seq_len, 1)
    # The epsilon in the denominator keeps the step finite for zero gradients.
    normalised_step = (step_size * gradient) / (per_position_norm + norm_eps)
    delta.data[input_mask] += -normalised_step[input_mask]

    return delta


class EncodeDecodeICL:
    """Base class providing the mask-building helpers used by the encode/decode strategies.

    ``args.update_tokens`` selects which mask ids, relative to an example's input id,
    are marked by ``updateInputMask``.
    """

    def __init__(self, args):
        self.args = args

    def updateInputMask(self, ICL_mask_examples, input_target_id, output_input_mask=None):
        """Mark the positions belonging to ``input_target_id`` and merge them into a mask.

        Depending on ``self.args.update_tokens`` the marked positions are those whose
        mask id equals the target id itself (``'input'``), the id before it
        (``'input_mask'``), the id after it (``'output_mask'``), both neighbours
        (``'all_masks'``) or all three (``'input_and_all_masks'``).

        Args:
            ICL_mask_examples: Tensor of mask ids.
            input_target_id: Mask id of the example input.
            output_input_mask: Optional existing mask to OR the new positions into.

        Returns:
            Boolean tensor with a leading dimension of size one added to the
            comparison result, OR-ed with ``output_input_mask`` when that is given.

        Raises:
            Exception: If ``self.args.update_tokens`` is not one of the known values.
        """
        INPUT_MASK_OFFSET = -1
        OUTPUT_MASK_OFFSET = 1
        offsets_by_mode = {
            'input': (0,),
            'input_mask': (INPUT_MASK_OFFSET,),
            'output_mask': (OUTPUT_MASK_OFFSET,),
            'all_masks': (INPUT_MASK_OFFSET, OUTPUT_MASK_OFFSET),
            'input_and_all_masks': (0, INPUT_MASK_OFFSET, OUTPUT_MASK_OFFSET),
        }

        mode = self.args.update_tokens
        if mode not in offsets_by_mode:
            raise Exception('Unknown update tokens value')

        current_id_mask = None
        for offset in offsets_by_mode[mode]:
            hits = (ICL_mask_examples == input_target_id + offset).unsqueeze(0)
            current_id_mask = hits if current_id_mask is None else current_id_mask | hits

        if output_input_mask is None:
            return current_id_mask
        return output_input_mask | current_id_mask

    def updateOutputMask(self, ICL_mask_examples, loss_target_id, output_loss_mask=None):
        """Mark the positions whose mask id equals ``loss_target_id``.

        Args:
            ICL_mask_examples: Tensor of mask ids.
            loss_target_id: Mask id to look for.
            output_loss_mask: Optional existing mask to OR the new positions into.

        Returns:
            Boolean tensor with a leading dimension of size one added to the
            comparison result, OR-ed with ``output_loss_mask`` when that is given.
        """
        hits = (ICL_mask_examples == loss_target_id).unsqueeze(0)
        if output_loss_mask is None:
            return hits
        return output_loss_mask | hits


class CrossValidation(EncodeDecodeICL):
    """Encode/decode strategy that builds one reordered copy of the examples per target.

    For every target, the tokens whose mask id lies inside that target's id range are
    moved to the end of the sequence; the input/loss masks of that copy cover all
    targets outside the range, and a separate label mask marks the target itself.
    """

    # TODO: change this into a proper cross-validation encoder/decoder: use all examples
    # up to the last one for the loss and the last one for validation.
    def __init__(self, args):
        super().__init__(args)
        self.ICL_initialization = False
        self.ICL_embeding = None

    def encode(
            self,
            to_word_embeding,
            input_ids,
            ICL_mask,
            list_input_target_id,
            list_loss_target_id,
            list_target_id_range,
            attention_mask
    ):
        """Build the per-target reordered inputs and masks.

        Only row 0 of ``input_ids``, ``ICL_mask`` and ``attention_mask`` is used, cut
        off right after the position returned by ``torch.argmax(ICL_mask[0])``.

        Args:
            to_word_embeding: Callable mapping the 1-D tensor of example token ids to
                embeddings (indexed here as a 2-D tensor).
            input_ids: Token ids; row 0 is used.
            ICL_mask: Mask ids; row 0 is used. Its maximum must be a multiple of four.
            list_input_target_id: Mask id of every target's input.
            list_loss_target_id: Mask id of every target's loss tokens.
            list_target_id_range: Per target, a ``(begin, end)`` pair of mask ids
                (inclusive) delimiting that target.
            attention_mask: Attention mask; row 0 is used.

        Returns:
            Tuple ``(embeds, input_examples, input_mask, loss_mask, label_mask,
            attention_mask)``. ``embeds`` is the example embedding expanded to
            ``max(mask id) // 4`` rows; the other entries hold one row per target.
        """
        assert ICL_mask.max() % 4 == 0  # correct format
        device = input_ids.device

        examples_end = torch.argmax(ICL_mask[0]) + 1
        example_ids = input_ids[0, :examples_end].clone()
        example_mask_ids = ICL_mask[0, :examples_end].clone()
        example_embeds = to_word_embeding(example_ids)
        clipped_attention_mask = attention_mask[0, :examples_end].clone().unsqueeze(0)

        # Empty accumulators that the per-target rows are concatenated onto.
        stacked_input_mask = torch.Tensor().to(device=device).bool()
        stacked_loss_mask = torch.Tensor().to(device=device).bool()
        stacked_examples = torch.Tensor().to(device=device).long()
        stacked_labels = torch.Tensor().to(device=device).long()
        stacked_attention = torch.Tensor().to(device=device).long()

        num_targets = len(list_input_target_id)
        for fold in range(num_targets):
            range_begin = list_target_id_range[fold][0]
            range_end = list_target_id_range[fold][1]

            # Tokens of the current target go last, everything else keeps its order.
            in_fold = (example_mask_ids >= range_begin) & (example_mask_ids <= range_end)
            out_of_fold = ~in_fold
            reordered_ids = torch.cat((example_ids[out_of_fold], example_ids[in_fold]), dim=0).unsqueeze(0)
            stacked_examples = torch.cat((stacked_examples, reordered_ids), dim=0)

            reordered_mask_ids = torch.cat((example_mask_ids[out_of_fold], example_mask_ids[in_fold]),
                                           dim=0)

            # Input/loss masks cover every target lying outside the current range.
            fold_input_mask, fold_loss_mask = None, None
            for other in range(num_targets):
                other_input_id = list_input_target_id[other]
                if other_input_id < range_begin or other_input_id > range_end:
                    fold_input_mask = self.updateInputMask(reordered_mask_ids, other_input_id, fold_input_mask)
                    fold_loss_mask = self.updateOutputMask(reordered_mask_ids, list_loss_target_id[other],
                                                           fold_loss_mask)

            stacked_input_mask = torch.cat((stacked_input_mask, fold_input_mask), dim=0)
            stacked_loss_mask = torch.cat((stacked_loss_mask, fold_loss_mask))

            # set ICL label mask
            fold_label_mask = self.updateOutputMask(reordered_mask_ids, list_loss_target_id[fold])
            stacked_labels = torch.cat((stacked_labels, fold_label_mask))
            stacked_attention = torch.cat((stacked_attention, clipped_attention_mask), dim=0)

        embde_dim = (
            example_mask_ids.max().int().item() // 4,
            example_embeds.shape[0],
            example_embeds.shape[1],
        )
        expanded_embeds = example_embeds.expand(embde_dim)

        return (expanded_embeds, stacked_examples, stacked_input_mask, stacked_loss_mask,
                stacked_labels, stacked_attention)

    def decode(self, inputs_embeds, inputs_embeds_delta, input_mask,
               ICL_mask, list_input_target_id, list_loss_target_id, output_label_examples):
        """Return a copy of ``inputs_embeds`` with ``inputs_embeds_delta`` added.

        The remaining parameters are accepted but not used.
        """
        enhanced_inputs_embeds = inputs_embeds.clone()
        enhanced_inputs_embeds.add_(inputs_embeds_delta)
        return enhanced_inputs_embeds


class OneToAll(EncodeDecodeICL):
    """Encode/decode strategy producing a single row that covers all targets at once.

    Derives from ``EncodeDecodeICL`` because ``encode``/``decode`` rely on its
    ``updateInputMask``/``updateOutputMask`` helpers and on ``self.args``.
    """

    def __init__(self, args):
        super(OneToAll, self).__init__(args)

    def encode(
            self,
            to_word_embeding,
            input_ids,
            ICL_mask,
            list_input_target_id,
            list_loss_target_id,
            list_target_id_range,
            attention_mask
    ):
        """Build a single-row input with masks that are the union over all targets.

        Only row 0 of ``input_ids`` and ``ICL_mask`` is used, cut off right after the
        position returned by ``torch.argmax(ICL_mask[0])``. The attention mask is cut
        to the same length but keeps all of its rows.

        Args:
            to_word_embeding: Callable mapping the 1-D tensor of example token ids to
                embeddings (indexed here as a 2-D tensor).
            input_ids: Token ids; row 0 is used.
            ICL_mask: Mask ids; row 0 is used. Its maximum must be a multiple of four.
            list_input_target_id: Mask id of every target's input.
            list_loss_target_id: Mask id of every target's loss tokens.
            list_target_id_range: Unused.
            attention_mask: Attention mask; all rows are kept.

        Returns:
            Tuple ``(embeds, input_examples, input_mask, loss_mask, attention_mask)``
            where ``embeds``, ``input_examples`` and the two masks have one row.
        """
        assert ICL_mask.max() % 4 == 0  # correct format
        # reduce the batch size to 1, extract the examples only
        device = input_ids.device

        last_example_idx = torch.argmax(ICL_mask[0]) + 1
        input_examples_orig = input_ids[0, :last_example_idx].clone()
        ICL_mask_examples = ICL_mask[0, :last_example_idx].clone()
        clipped_attention_mask = attention_mask[0:, :last_example_idx].clone()

        inputs_examples_embeds = to_word_embeding(input_examples_orig)

        output_loss_mask = torch.Tensor().to(device=device).bool()
        output_input_examples = torch.Tensor().to(device=device).long()

        cur_loss_mask = (ICL_mask_examples == list_loss_target_id[0]).unsqueeze(0)
        output_loss_mask = torch.cat((output_loss_mask, cur_loss_mask))

        output_input_examples = torch.cat((output_input_examples, input_examples_orig.clone().unsqueeze(0)), dim=0)
        # NOTE(review): the first target is matched on its input id directly, whereas
        # the remaining targets (and ``decode``) go through ``updateInputMask``, which
        # honours ``args.update_tokens`` -- confirm this asymmetry is intended.
        output_input_mask = (ICL_mask_examples == list_input_target_id[0]).unsqueeze(0)

        for i in range(1, len(list_input_target_id)):
            output_input_mask = self.updateInputMask(ICL_mask_examples, list_input_target_id[i], output_input_mask)
            output_loss_mask = self.updateOutputMask(ICL_mask_examples, list_loss_target_id[i], output_loss_mask)

        embde_dim = (1, inputs_examples_embeds.shape[0], inputs_examples_embeds.shape[1])

        return inputs_examples_embeds.expand(
            embde_dim), output_input_examples, output_input_mask, output_loss_mask, clipped_attention_mask

    def decode(self, inputs_embeds, inputs_embeds_delta, input_mask,
               ICL_mask, list_input_target_id, list_loss_target_id):
        """Add the perturbation to row 0 of ``inputs_embeds`` at every target's positions.

        Args:
            inputs_embeds: Embeddings; row 0 is copied and modified.
            inputs_embeds_delta: Perturbation; row 0 is read.
            input_mask: Unused.
            ICL_mask: Mask ids; row 0 is used, cut off after its arg-max position.
            list_input_target_id: Mask id of every target's input.
            list_loss_target_id: Unused.

        Returns:
            The modified copy of ``inputs_embeds[0]``.

        Raises:
            AssertionError: If the result does not differ from ``inputs_embeds`` anywhere.
        """
        enhanced_inputs_embeds = inputs_embeds[0].clone()
        last_example_idx = torch.argmax(ICL_mask[0]) + 1
        ICL_mask_examples = ICL_mask[0, :last_example_idx].clone()

        for i in range(len(list_input_target_id)):
            mask_example = self.updateInputMask(ICL_mask_examples, list_input_target_id[i])
            mask_example = mask_example.squeeze()
            enhanced_inputs_embeds[mask_example, :] += inputs_embeds_delta[0, mask_example]

        # Sanity check that at least one embedding was actually changed.
        assert torch.any(enhanced_inputs_embeds != inputs_embeds)

        return enhanced_inputs_embeds


class AllToAll(EncodeDecodeICL):
    """Encode/decode strategy producing a single row that covers all targets at once."""

    def __init__(self, args):
        super().__init__(args)

    def encode(
            self,
            to_word_embeding,
            input_ids,
            ICL_mask,
            list_input_target_id,
            list_loss_target_id,
            list_target_id_range,
            attention_mask
    ):
        """Build a single-row input with masks that are the union over all targets.

        Only row 0 of ``input_ids``, ``ICL_mask`` and ``attention_mask`` is used, cut
        off right after the position returned by ``torch.argmax(ICL_mask[0])``.

        Args:
            to_word_embeding: Callable mapping the 1-D tensor of example token ids to
                embeddings (indexed here as a 2-D tensor).
            input_ids: Token ids; row 0 is used.
            ICL_mask: Mask ids; row 0 is used. Its maximum must be a multiple of four.
            list_input_target_id: Mask id of every target's input.
            list_loss_target_id: Mask id of every target's loss tokens.
            list_target_id_range: Unused.
            attention_mask: Attention mask; only its first row is kept.

        Returns:
            Tuple ``(embeds, input_examples, input_mask, loss_mask, attention_mask)``,
            each with a leading dimension of one.
        """
        assert ICL_mask.max() % 4 == 0  # correct format
        # Work on the first batch row only, restricted to the example tokens.
        device = input_ids.device

        examples_end = torch.argmax(ICL_mask[0]) + 1
        example_ids = input_ids[0, :examples_end].clone()
        example_mask_ids = ICL_mask[0, :examples_end].clone()
        first_row_attention = attention_mask[:1, :examples_end].clone()

        example_embeds = to_word_embeding(example_ids)

        empty_bool = torch.Tensor().to(device=device).bool()
        empty_long = torch.Tensor().to(device=device).long()

        first_loss_mask = (example_mask_ids == list_loss_target_id[0]).unsqueeze(0)
        loss_mask = torch.cat((empty_bool, first_loss_mask))
        stacked_examples = torch.cat((empty_long, example_ids.clone().unsqueeze(0)), dim=0)
        # The first target is matched on its input id directly; the remaining ones
        # go through updateInputMask and therefore follow args.update_tokens.
        positions_mask = (example_mask_ids == list_input_target_id[0]).unsqueeze(0)

        for idx in range(1, len(list_input_target_id)):
            positions_mask = self.updateInputMask(example_mask_ids, list_input_target_id[idx], positions_mask)
            loss_mask = self.updateOutputMask(example_mask_ids, list_loss_target_id[idx], loss_mask)

        single_row_embeds = example_embeds.expand((1, example_embeds.shape[0], example_embeds.shape[1]))

        return single_row_embeds, stacked_examples, positions_mask, loss_mask, first_row_attention

    def decode(self, inputs_embeds, inputs_embeds_delta, input_mask,
               ICL_mask, list_input_target_id, list_loss_target_id):
        """Add the perturbation to row 0 of ``inputs_embeds`` at every target's positions.

        ``input_mask`` and ``list_loss_target_id`` are accepted but not used.

        Returns:
            The modified copy of ``inputs_embeds[0]``.

        Raises:
            AssertionError: If the result does not differ from ``inputs_embeds`` anywhere.
        """
        enhanced = inputs_embeds[0].clone()
        examples_end = torch.argmax(ICL_mask[0]) + 1
        example_mask_ids = ICL_mask[0, :examples_end].clone()

        for target_id in list_input_target_id:
            positions = self.updateInputMask(example_mask_ids, target_id).squeeze()
            enhanced[positions, :] += inputs_embeds_delta[0, positions]

        # Sanity check that at least one embedding was actually changed.
        assert torch.any(enhanced != inputs_embeds)

        return enhanced


class OneToAllIter(EncodeDecodeICL):
    """Encode/decode strategy that emits one row per target id.

    In ``encode``, row ``i`` gets an input mask accumulated over every id in
    ``list_input_target_id`` except the ``i``-th, and a loss mask built from
    the ``i``-th id in ``list_loss_target_id``.

    NOTE(review): ``decode`` reads only row 0 of ``inputs_embeds_delta`` even
    though ``encode`` produces one row per target -- confirm this is intended.
    """

    def __init__(self, args):
        super().__init__(args)

    def encode(
            self,
            to_word_embeding,
            input_ids,
            ICL_mask,
            list_input_target_id,
            list_loss_target_id,
            list_target_id_range,
            attention_mask
    ):
        """Build the per-target batch used to optimise the embedding delta.

        Only row 0 of ``input_ids`` / ``ICL_mask`` / ``attention_mask`` is read.
        ``list_target_id_range`` is accepted but not used.

        Returns:
            A 5-tuple ``(embeds, token_ids, input_masks, loss_masks, attention)``
            where ``embeds`` and ``attention`` are expanded to as many rows as
            ``token_ids`` has.
        """
        assert ICL_mask.max() % 4 == 0  # correct format
        device = input_ids.device

        # Keep everything up to (and including) the position torch.argmax
        # reports for the largest mask id.
        prefix_len = torch.argmax(ICL_mask[0]) + 1
        example_ids = input_ids[0, :prefix_len].clone()
        example_mask_ids = ICL_mask[0, :prefix_len].clone()
        prefix_attention = attention_mask[:1, :prefix_len].clone()

        example_embeds = to_word_embeding(example_ids)

        # Empty seeds that each per-target row is concatenated onto.
        stacked_input_masks = torch.Tensor().to(device=device).bool()
        stacked_loss_masks = torch.Tensor().to(device=device).bool()
        stacked_example_ids = torch.Tensor().to(device=device).long()

        num_targets = len(list_input_target_id)
        for held_out in range(num_targets):
            stacked_example_ids = torch.cat((stacked_example_ids, example_ids.clone().unsqueeze(0)))

            # Accumulate the input mask over every target except the held-out one.
            row_input_mask = None
            for other, other_target_id in enumerate(list_input_target_id):
                if other == held_out:
                    continue
                row_input_mask = self.updateInputMask(example_mask_ids, other_target_id, row_input_mask)
            stacked_input_masks = torch.cat((stacked_input_masks, row_input_mask), dim=0)

            row_loss_mask = self.updateOutputMask(example_mask_ids, list_loss_target_id[held_out])
            stacked_loss_masks = torch.cat((stacked_loss_masks, row_loss_mask), dim=0)

        num_rows = stacked_example_ids.shape[0]
        expanded_shape = (num_rows, example_embeds.shape[0], example_embeds.shape[1])

        return (
            example_embeds.expand(expanded_shape),
            stacked_example_ids,
            stacked_input_masks,
            stacked_loss_masks,
            prefix_attention.expand((num_rows, -1)),
        )

    def decode(self, inputs_embeds, inputs_embeds_delta, input_mask,
               ICL_mask, list_input_target_id, list_loss_target_id):
        """Add row 0 of the delta to row 0 of the embeddings at every input target.

        ``input_mask`` and ``list_loss_target_id`` are accepted but not used.

        Returns:
            A copy of ``inputs_embeds[0]`` with ``inputs_embeds_delta[0]`` added at
            the positions ``self.updateInputMask`` selects for each target id.
        """
        prefix_len = torch.argmax(ICL_mask[0]) + 1
        prefix_mask_ids = ICL_mask[0, :prefix_len].clone()

        updated_embeds = inputs_embeds[0].clone()
        for target_id in list_input_target_id:
            selected = self.updateInputMask(prefix_mask_ids, target_id).squeeze()
            updated_embeds[selected, :] += inputs_embeds_delta[0, selected]

        # Sanity check: at least one value must differ from the input.
        assert torch.any(updated_embeds != inputs_embeds)

        return updated_embeds


class OurModelWrapper(BloomPreTrainedModel):
    """Wrap a causal LM so the in-context-example prefix is fed as enhanced embeddings.

    On the first ``forward`` call the wrapper encodes the example prefix with
    the configured token-update strategy, optimises an embedding delta with the
    configured enhancement function and caches the decoded result in
    ``ICL_embeding``. Every call then overwrites the leading positions of the
    input embeddings with that cached tensor before running ``orig_model``.
    """

    def __init__(self, orig_model, args):
        """Store the wrapped model and resolve strategy objects from ``args``.

        Raises:
            Exception: if ``args.enhancement_function`` or
                ``args.token_update_strategy`` is not a known name.
        """
        super(OurModelWrapper, self).__init__(orig_model.config)
        self.orig_model = orig_model
        self.args = args
        self.ICL_initialization = False
        self.ICL_embeding = None
        if self.args.enhancement_function in ['pgd_without_projection']:
            self.enhancement_function = partial(pgd_without_projection, args)
        elif self.args.enhancement_function in ['pgd_with_regularization']:
            self.enhancement_function = partial(pgd_with_regularization, args)
        elif self.args.enhancement_function in ['pgd_with_projection']:
            self.enhancement_function = partial(pgd_with_projection, args)
        elif self.args.enhancement_function in ['iterative_pgd_with_projection']:
            self.enhancement_function = partial(iterative_pgd_with_projection, args)
        else:
            raise Exception('Unknown enhancement_function')

        if self.args.token_update_strategy in ['all_to_all']:
            self.tokenUpdateStrategy = AllToAll(args)
        elif self.args.token_update_strategy in ['one_to_all_iter']:
            self.tokenUpdateStrategy = OneToAllIter(args)
        else:
            raise Exception('Unknown token update strategy')

        # The embedding table lives at a different attribute path per architecture.
        if 'GPT' in orig_model._get_name():
            self.word_embeddings = self.orig_model.transformer.wte
            self.vocab_size = self.orig_model.transformer.config.vocab_size
        elif 'OPT' in orig_model._get_name():
            self.word_embeddings = self.orig_model.base_model.decoder.embed_tokens
            self.vocab_size = self.orig_model.base_model.decoder.vocab_size
        else:
            self.word_embeddings = self.orig_model.transformer.word_embeddings
            self.vocab_size = self.orig_model.transformer.config.vocab_size

    def enhance_input_embeddings(self, inputs_embeds, input_ids, input_mask, loss_mask, clipped_attention_mask):
        """Run ``args.num_of_iter`` enhancement steps and return the detached delta.

        Each step feeds the previous delta back into ``self.enhancement_function``.

        Raises:
            ValueError: if no step produced a delta (e.g. ``num_of_iter`` < 1).
        """
        num_of_iter = self.args.num_of_iter
        batch_size, seq_len, hidden_size = inputs_embeds.shape

        device = inputs_embeds.device

        if self.args.use_non_casual_mask:
            # Repeat the mask along a new last axis of length seq_len and add a
            # singleton axis at position 1.
            # assumes clipped_attention_mask is (batch, seq) -- TODO confirm
            clipped_attention_mask = clipped_attention_mask.unsqueeze(2).repeat(1, 1, seq_len).unsqueeze(1).clone().to(
                device)
        delta = None
        for _ in range(num_of_iter):
            delta = self.enhancement_function(self.orig_model, input_ids, inputs_embeds, clipped_attention_mask,
                                              input_mask, loss_mask,
                                              self.vocab_size, delta, self.args.alpha, self.args.epsilon,
                                              self.args.display_enhancment_loss)

        if delta is None:
            raise ValueError('enhancement produced no delta (num_of_iter={})'.format(num_of_iter))
        return delta.detach()

    def _enhance_in_sub_batches(self, inputs_examples_embeds, input_examples, input_mask, loss_mask,
                                clipped_attention_mask):
        """Enhance consecutive row chunks and concatenate the per-chunk deltas.

        Chunks hold at most ``args.enhancement_batch_size`` rows.

        Raises:
            ValueError: if ``args.enhancement_batch_size`` is smaller than 1.
        """
        chunk_size = self.args.enhancement_batch_size
        if chunk_size < 1:
            raise ValueError('enhancement_batch_size must be >= 1, got {}'.format(chunk_size))

        num_rows = len(inputs_examples_embeds)
        # Empty float seed; torch.cat appends each chunk's delta along dim 0.
        inputs_embeds_delta = torch.Tensor().to(inputs_examples_embeds.device)
        for start_idx in range(0, num_rows, chunk_size):
            end_idx = min(start_idx + chunk_size, num_rows)
            inputs_embeds_delta_i = self.enhance_input_embeddings(
                inputs_examples_embeds[start_idx:end_idx],
                input_examples[start_idx:end_idx],
                input_mask[start_idx:end_idx],
                loss_mask[start_idx:end_idx],
                clipped_attention_mask[start_idx:end_idx]
            )
            inputs_embeds_delta = torch.cat((inputs_embeds_delta, inputs_embeds_delta_i), dim=0)
        return inputs_embeds_delta

    def cross_validation(self, word_embeddings, input_ids, ICL_mask, list_input_target_id, list_loss_target_id,
                         list_target_id_range, attention_mask):
        """Grid-search ``alpha``, ``epsilon`` and ``num_of_iter``.

        For every combination the enhancement is run through ``CrossValidation``
        and scored by the mean softmax probability the model assigns to the
        ground-truth label tokens. ``self.args`` is overwritten with each
        candidate while searching.

        Returns:
            dict with keys ``'alpha'``, ``'epsilon'`` and ``'num_of_iter'`` for the
            best-scoring combination (all zeros if no score exceeded 0).
        """
        batch_method_fn = CrossValidation(self.args)

        best_avg = 0
        best_params = {
            'alpha': 0,
            'epsilon': 0,
            'num_of_iter': 0,
        }
        list_cv_param_options = []
        for alpha_i in [0.01, 0.05, 0.1, 0.5, 1]:
            for epsilon_i in [10, 1, 0.1]:
                for num_of_iter_i in [20, 50, 100]:
                    list_cv_param_options.append((float(alpha_i), float(epsilon_i), int(num_of_iter_i)))
        iter_counter = 1
        for alpha_i, epsilon_i, num_of_iter_i in list_cv_param_options:
            self.args.num_of_iter = num_of_iter_i
            self.args.alpha = alpha_i
            self.args.epsilon = epsilon_i
            self.args.display_enhancment_loss = False

            inputs_examples_embeds, input_examples, input_mask, loss_mask, label_examples, clipped_attention_mask = batch_method_fn.encode(
                word_embeddings,
                input_ids,
                ICL_mask,
                list_input_target_id,
                list_loss_target_id,
                list_target_id_range,
                attention_mask
            )

            inputs_embeds_delta = self._enhance_in_sub_batches(
                inputs_examples_embeds,
                input_examples,
                input_mask,
                loss_mask,
                clipped_attention_mask
            )

            ICL_embeding = batch_method_fn.decode(
                inputs_examples_embeds,
                inputs_embeds_delta,
                input_mask,
                ICL_mask,
                list_input_target_id,
                list_loss_target_id,
                label_examples
            )
            # TODO:EMBED, where ICL_mask is not zero or first token, replace with ICL_embeding...
            output = self.orig_model(inputs_embeds=ICL_embeding,
                                     attention_mask=clipped_attention_mask)

            en_batch_size = ICL_embeding.shape[0]
            gt_label = input_ids[0, :ICL_embeding.shape[1]][label_examples[0].bool()].repeat(en_batch_size, 1)
            model_output_labels = output.logits[label_examples.bool()]
            model_output_pred = torch.nn.Softmax(dim=1)(model_output_labels)
            avg_pred = torch.mean(torch.gather(model_output_pred, 1, gt_label)).item()

            print('Iter {} / {} avg prob: {} - alpha={}, epsilon={}, num_of_iter={}'.format(iter_counter,
                                                                                            len(list_cv_param_options),
                                                                                            avg_pred,
                                                                                            alpha_i, epsilon_i,
                                                                                            num_of_iter_i))
            if avg_pred > best_avg:
                best_avg = avg_pred
                best_params['alpha'] = alpha_i
                best_params['epsilon'] = epsilon_i
                best_params['num_of_iter'] = num_of_iter_i
                print('\tBest params so far!')
            iter_counter += 1

        return best_params

    def forward(self, input_ids=None, attention_mask=None, ICL_mask=None):
        """Run the wrapped model with the enhanced example embeddings spliced in.

        The first call computes and caches ``self.ICL_embeding`` (optionally after
        a hyper-parameter search when ``args.hp_cross_validation`` is set); later
        calls reuse it.

        Returns:
            Whatever ``self.orig_model`` returns for the patched embeddings.
        """
        mask_ids_unique = torch.unique(ICL_mask)
        assert ((len(mask_ids_unique) - 1) // 4) == self.args.num_shots[0]

        if not self.ICL_initialization:
            list_input_target_id = []
            list_loss_target_id = []
            list_target_id_range = []
            template_size = (ICL_mask[0].max() // self.args.num_shots[0]).item()
            # Within each example's id range, offset +2 is the input id and +4 the loss id.
            for i in range(self.args.num_shots[0]):
                list_input_target_id.append(i * template_size + 2)
                list_loss_target_id.append(i * template_size + 4)
                list_target_id_range.append((i * template_size + 1, (i + 1) * template_size))

            if self.args.hp_cross_validation:
                new_args = self.cross_validation(
                    self.word_embeddings,
                    input_ids,
                    ICL_mask,
                    list_input_target_id,
                    list_loss_target_id,
                    list_target_id_range,
                    attention_mask
                )
                self.args.alpha = new_args['alpha']
                self.args.epsilon = new_args['epsilon']
                self.args.num_of_iter = new_args['num_of_iter']
                self.args.display_enhancment_loss = True
                print('*' * 50)
                print('alpha: {}, epsilon: {}, num_of_iter: {}'.format(new_args['alpha'], new_args['epsilon'],
                                                                       new_args['num_of_iter']))
                print('*' * 50)

            inputs_examples_embeds, input_examples, input_mask, loss_mask, clipped_attention_mask = self.tokenUpdateStrategy.encode(
                self.word_embeddings,
                input_ids,
                ICL_mask,
                list_input_target_id,
                list_loss_target_id,
                list_target_id_range,
                attention_mask
            )

            # break batch into smaller sub batches
            inputs_embeds_delta = self._enhance_in_sub_batches(
                inputs_examples_embeds,
                input_examples,
                input_mask,
                loss_mask,
                clipped_attention_mask
            )

            self.ICL_embeding = self.tokenUpdateStrategy.decode(
                inputs_examples_embeds,
                inputs_embeds_delta,
                input_mask,
                ICL_mask,
                list_input_target_id,
                list_loss_target_id
            )

            self.ICL_initialization = True

        input_embedings = self.word_embeddings(input_ids)
        # Overwrite the leading positions of every row with the cached prefix.
        input_embedings[:, :len(self.ICL_embeding), :] = self.ICL_embeding

        # TODO:EMBED, where ICL_mask is not zero or first token, replace with ICL_embeding...
        output = self.orig_model(inputs_embeds=input_embedings, attention_mask=attention_mask)
        return output

    def __device__(self):
        # The wrapped model is stored as ``orig_model``; no ``model`` attribute exists.
        return self.orig_model.device

    def __repr__(self):
        return repr(self.orig_model)

    def __str__(self):
        return str(self.orig_model)


class OursPeftModelForCausalLM(PeftModelForCausalLM):
    # def __init__(self, config, word_embeddings):
    #     super().__init__(config, word_embeddings)
    def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
        """Build the PEFT model and seed the prompt embedding from the ICL token ids.

        When the ``'default'`` config is an ``OurPromptTuningConfig`` carrying
        ``ICL_token_ids``, the prompt encoder's embedding is replaced by a copy
        (in float32) of the word embeddings of those tokens. If
        ``number_of_learnable_tokens`` is positive, that many copies of the first
        ICL token (mask id ``-1``) are prepended first and ``num_virtual_tokens``
        is updated to the new length.
        """
        super().__init__(model, peft_config, adapter_name)
        icl_config = self.peft_config['default']
        if not isinstance(icl_config, OurPromptTuningConfig) or icl_config.ICL_token_ids is None:
            return

        device = model.device
        num_learnable = icl_config.number_of_learnable_tokens
        if num_learnable > 0:
            icl_config.ICL_token_ids = [icl_config.ICL_token_ids[0]] * num_learnable + icl_config.ICL_token_ids
            icl_config.ICL_mask = [-1] * num_learnable + icl_config.ICL_mask
            self.active_peft_config.num_virtual_tokens = len(icl_config.ICL_mask)

        prompt_embedding = torch.nn.Embedding(len(icl_config.ICL_token_ids), icl_config.token_dim)

        token_ids = torch.Tensor(icl_config.ICL_token_ids).long().to(device)
        initial_weights = self.word_embeddings(token_ids).detach().clone().to(torch.float32)
        prompt_embedding.weight = torch.nn.Parameter(initial_weights)

        self.prompt_encoder.default.embedding = prompt_embedding
        self.set_update_embeds()
        self.set_adaptive_projection()

    def set_adaptive_projection(self):
        """Compute the per-token projection radius and store it on ``self.sample_total_epsilon``.

        Base radius by mask id (both configured radii are scaled by
        ``sqrt(encoder_hidden_size / 2048)``):

        * ``-1``: 999
        * ``0`` and positive ids with ``id % 4`` in ``{1, 3}``: ``ICL_projection_format_epsilon``
        * positive ids with ``id % 4 == 2``: ``ICL_projection_epsilon``
        * everything else: ``1e-10``

        Each radius is then multiplied by
        ``ICL_projection_epsilon_multiplier ** k`` where ``k`` counts down from
        the number of tokens (first position) to 1 (last position).
        """
        config = self.peft_config['default']
        hidden_scale = torch.sqrt(torch.Tensor([config.encoder_hidden_size / 2048])).item()
        input_radius = config.ICL_projection_epsilon * hidden_scale
        format_radius = config.ICL_projection_format_epsilon * hidden_scale

        mask_ids = torch.Tensor(config.ICL_mask).long().clone()
        is_positive = mask_ids > 0
        slot = torch.remainder(mask_ids, 4)

        # Start from a tiny radius everywhere, then overwrite per token kind.
        radii = torch.ones_like(mask_ids).to(torch.float32) * 1e-10
        radii[is_positive & ((slot == 1) | (slot == 3))] = format_radius
        radii[mask_ids == 0] = format_radius
        radii[mask_ids == -1] = 999
        radii[is_positive & (slot == 2)] = input_radius

        # Exponent runs n, n-1, ..., 1 from the first to the last token.
        countdown = torch.flip(torch.cumsum(torch.ones_like(radii).long(), dim=0), dims=[0])
        radii *= config.ICL_projection_epsilon_multiplier ** countdown

        self.sample_total_epsilon = radii
        print('#' * 20)
        print('not limited: ', (self.sample_total_epsilon > 1).sum(), '\t limited: ',
              (self.sample_total_epsilon <= 1).sum())
        print('#' * 20)

    def set_update_embeds(self):
        """Register a gradient hook so only selected prompt-embedding rows are updated.

        The rows are chosen from ``ICL_mask`` according to ``update_tokens``
        (``r`` is ``torch.remainder(mask_id, 4)``):

        * ``'input_and_all_masks'``: id ``-1`` or ``r`` in ``{1, 2, 3}``
        * ``'all'``: every row
        * ``'all_non_zero'``: every row whose mask id is not 0
        * ``'all_masks'``: id ``-1`` or ``r`` in ``{1, 3}``
        * ``'input'``: id ``-1`` or ``r == 2``

        The hook multiplies the gradient of the prompt embedding weight by a
        0/1 column so unselected rows receive a zero gradient.

        Raises:
            Exception: if ``update_tokens`` is not one of the names above.
        """
        config = self.peft_config['default']
        update_tokens = config.update_tokens
        tensor_ICL_mask = torch.Tensor(config.ICL_mask).long()
        slot = torch.remainder(tensor_ICL_mask, 4)
        is_instruction = tensor_ICL_mask == -1

        if update_tokens in ['input_and_all_masks']:
            mask_idx = is_instruction | (slot == 1) | (slot == 2) | (slot == 3)
        elif update_tokens in ['all']:
            mask_idx = torch.ones_like(tensor_ICL_mask).bool()
        elif update_tokens in ['all_non_zero']:
            # A remainder modulo 4 is never 4, so comparing it with 4 could not
            # select the ids that are multiples of 4; test the id itself instead.
            mask_idx = tensor_ICL_mask != 0
        elif update_tokens in ['all_masks']:
            mask_idx = is_instruction | (slot == 1) | (slot == 3)
        elif update_tokens in ['input']:
            mask_idx = is_instruction | (slot == 2)
        else:
            raise Exception('Unknown update_pattern')

        # Column of 0/1 flags, one per prompt token, broadcast over the hidden dim.
        mask = mask_idx.long().unsqueeze(1)

        def backward_hook(grad):
            return grad * mask.to(grad.device)

        self.prompt_encoder.default.embedding.weight.register_hook(backward_hook)

    def forward(
            self,
            input_ids=None,
            attention_mask=None,
            inputs_embeds=None,
            labels=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            task_ids=None,
            sample_mask=None,
            **kwargs,
    ):
        """Run the PEFT model, with a dedicated path for the ICL-initialised prompt.

        While training with an ``OurPromptTuningConfig`` that carries
        ``ICL_token_ids``, ``self.ICL_projection()`` is applied first. Then:

        * non-prompt-learning adapters delegate straight to ``self.base_model``
          (MPT models reject ``inputs_embeds``);
        * prefix tuning passes ``self.get_prompt(batch_size)`` as
          ``past_key_values``;
        * otherwise the prompt embeddings are concatenated in front of the input
          embeddings. On the ICL path the labels come from
          ``self.get_extended_labels`` and, when ``peft_weighted_loss_type`` is
          not ``'none'``, the loss is replaced by ``self.get_costum_loss``.

        Returns:
            The base model output. On the ICL path it additionally carries
            ``'gt_labels'`` and ``'extended_input_mask'`` entries.
        """
        if isinstance(self.peft_config['default'], OurPromptTuningConfig) and self.peft_config[
            'default'].ICL_token_ids is not None and self.training:
            self.ICL_projection()
        peft_config = self.active_peft_config
        if not peft_config.is_prompt_learning:
            if self.base_model.config.model_type == "mpt":
                if inputs_embeds is not None:
                    raise AssertionError("forward in MPTForCausalLM does not support inputs_embeds")
                return self.base_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    labels=labels,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    **kwargs,
                )

            if peft_config.peft_type == PeftType.POLY:
                kwargs["task_ids"] = task_ids

            with self._enable_peft_forward_hooks(**kwargs):
                kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
                return self.base_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    inputs_embeds=inputs_embeds,
                    labels=labels,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    **kwargs,
                )

        batch_size = _get_batch_size(input_ids, inputs_embeds)

        list_begin_of_example_in_ICL, list_sample_order_in_ICL = [], []
        if attention_mask is not None:
            # ``self.training`` is the module's mode flag. ``self.train`` is the bound
            # method and never compares equal to False, so it cannot be used here.
            if self.peft_config['default'].remove_train_example_from_icl == 0 or not self.training:
                # concat prompt attention mask
                prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device)
                attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)
            else:
                # Training with example removal: hide the matching ICL example per row.
                prefix_attention_mask, list_begin_of_example_in_ICL, list_sample_order_in_ICL = self.get_icl_attention_mask(
                    input_ids)
                attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

        if kwargs.get("position_ids", None) is not None:
            warnings.warn("Position ids are not supported for parameter efficient tuning. Ignoring position ids.")
            kwargs["position_ids"] = None
        if kwargs.get("token_type_ids", None) is not None:
            warnings.warn("Token type ids are not supported for parameter efficient tuning. Ignoring token type ids")
            kwargs["token_type_ids"] = None
        kwargs.update(
            {
                "attention_mask": attention_mask,
                "labels": labels,
                "output_attentions": output_attentions,
                "output_hidden_states": output_hidden_states,
                "return_dict": return_dict,
            }
        )

        if peft_config.peft_type == PeftType.PREFIX_TUNING:
            past_key_values = self.get_prompt(batch_size)
            return self.base_model(
                input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, **kwargs
            )
        else:
            # NOTE(review): this reads ``ICL_token_ids`` on configs that are NOT
            # OurPromptTuningConfig; confirm such configs define that attribute and
            # that this is the intended guard for the plain prompt-learning path.
            if isinstance(self.peft_config['default'], OurPromptTuningConfig) == False and self.peft_config[
                'default'].ICL_token_ids is not None:
                if inputs_embeds is None:
                    inputs_embeds = self.word_embeddings(input_ids)
                # concat prompt labels
                if labels is not None:
                    prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(labels.device)
                    kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
                prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids)
                prompts = prompts.to(inputs_embeds.dtype)
                inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
                return self.base_model(inputs_embeds=inputs_embeds, **kwargs)
            else:
                device = input_ids.device
                # Look up every virtual token (indices 0 .. num_virtual_tokens - 1) for each row.
                prompt_positions = torch.arange(
                    self.peft_config['default'].num_virtual_tokens, device=device
                ).expand(batch_size, -1)
                ICL_prefix = self.prompt_encoder.default.embedding(prompt_positions)
                train_embeds = self.word_embeddings(input_ids)
                if self.peft_config['default'].force_same_masks == 1:
                    assert self.peft_config['default'].ICL_mask is not None
                    assert train_embeds.shape[0] == 1
                    mean_input_mask = ICL_prefix[0][torch.tensor(self.peft_config['default'].ICL_mask).long() == 1]
                    mean_output_mask = ICL_prefix[0][torch.tensor(self.peft_config['default'].ICL_mask).long() == 3]
                    ## update mean
                    train_embeds[sample_mask == 1] = mean_input_mask.to(train_embeds.dtype)
                    train_embeds[sample_mask == 3] = mean_output_mask.to(train_embeds.dtype)

                inputs_embeds = torch.cat((ICL_prefix, train_embeds), dim=1)
                extended_input_mask_labels, extended_input_mask = self.get_extended_labels(input_ids, labels,
                                                                                           sample_mask,
                                                                                           list_begin_of_example_in_ICL,
                                                                                           list_sample_order_in_ICL)

                kwargs['labels'] = extended_input_mask_labels
                if (labels is not None) and (self.peft_config['default'].peft_weighted_loss_type not in ['none']):
                    # Let the base model skip its own loss; the weighted loss is computed here.
                    kwargs['labels'] = None
                    model_output = self.base_model(inputs_embeds=inputs_embeds, **kwargs)
                    model_output['loss'] = self.get_costum_loss(model_output, extended_input_mask_labels,
                                                                extended_input_mask)
                else:
                    model_output = self.base_model(inputs_embeds=inputs_embeds, **kwargs)

                model_output['gt_labels'] = extended_input_mask_labels.clone()
                model_output['extended_input_mask'] = extended_input_mask.clone()

                return model_output

    def get_costum_loss(self, model_output, labels, extended_input_mask):
        """Return a weighted mean of the per-token next-token cross-entropy.

        Tokens whose (shifted) label is ``-100`` are excluded. The weight of the
        remaining tokens depends on ``peft_weighted_loss_type``:

        * contains ``'equal'``: label tokens (positive mask id divisible by 4)
          other than those with the row's largest mask id get weight
          ``(#largest-id tokens) / (#other label tokens)``; tokens with the
          largest id are scaled by ``peft_weighted_loss_decay_factor``.
        * contains ``'decay'``: label ids are visited from largest to smallest and
          weighted ``1, f, f**2, ...`` with ``f = peft_weighted_loss_decay_factor``.
        * anything else: all weights stay 1.
        """
        logits = model_output.logits
        device = logits.device
        labels = labels.to(device)

        # Next-token alignment: position t is scored against token t + 1.
        pred_logits = logits[..., :-1, :].contiguous()
        target_ids = labels[..., 1:].contiguous()
        target_mask_ids = extended_input_mask[..., 1:].contiguous()

        is_supervised = (target_ids.clone().detach() != -100).bool()
        batch_size, seq_length, vocab_size = pred_logits.shape

        criterion = CrossEntropyLoss(reduction='none', ignore_index=-100)
        per_token_loss = criterion(
            pred_logits.view(batch_size * seq_length, vocab_size),
            target_ids.view(batch_size * seq_length),
        ).view(batch_size, seq_length)
        weights = is_supervised.clone().detach().float()

        config = self.peft_config['default']
        if 'equal' in config.peft_weighted_loss_type:
            for row in range(batch_size):
                # we should have N - 1 loss terms, we removed the test one
                row_ids = target_mask_ids[row]
                is_test = row_ids == row_ids.max()
                is_other_label = (row_ids > 0) & (row_ids % 4 == 0) & ~is_test
                other_weight = is_test.sum() / is_other_label.sum()
                weights[row][is_supervised[row] & is_other_label] = other_weight
                weights[row][is_test] *= config.peft_weighted_loss_decay_factor

        elif 'decay' in config.peft_weighted_loss_type:
            for row in range(batch_size):
                row_ids = target_mask_ids[row]
                label_ids = row_ids[(row_ids > 0) & (row_ids % 4 == 0)].unique()
                row_decay = torch.ones_like(row_ids).to(device=device).float()
                factor = 1
                # Largest label id first, so it keeps weight 1.
                for label_id in torch.flip(label_ids, [0]):
                    row_decay[row_ids == label_id] = factor
                    factor *= config.peft_weighted_loss_decay_factor
                weights[row] *= row_decay

        return (per_token_loss[is_supervised] * weights[is_supervised]).mean()

    def get_icl_attention_mask(self, input_ids):
        """Build a prefix attention mask that hides, per row, its best-matching ICL example.

        Example ``k`` spans from the first position whose mask id is at least
        ``4k + 1`` to one past the last position whose mask id is positive and at
        most ``4k + 4``. Each input row is matched to the example sharing the
        most position-wise equal token ids (first example wins ties).

        Returns:
            ``(prefix_attention_mask, spans, matched)`` where
            ``prefix_attention_mask`` is a long tensor of ones with the matched
            span zeroed per row, ``spans`` is the list of ``(start, end)`` per
            example and ``matched`` is the matched example index per row.
        """
        config = self.peft_config['default']
        icl_token_ids = config.ICL_token_ids
        icl_mask = config.ICL_mask
        mask_array = np.array(icl_mask)
        ids_array = np.array(icl_token_ids)
        batch_size, _ = input_ids.size()

        # Locate every example's span and collect its token ids.
        example_spans = []
        example_token_lists = []
        for example_idx in range(max(icl_mask) // 4):
            first_id = example_idx * 4 + 1
            last_id = example_idx * 4 + 4
            span_start = torch.where(torch.Tensor(mask_array >= first_id))[0][0]
            span_end = torch.where(torch.Tensor((mask_array <= last_id) & (mask_array > 0)))[0][-1] + 1
            example_spans.append((span_start, span_end))
            example_token_lists.append(ids_array[span_start:span_end].tolist())

        # Match each row to the example with the largest position-wise overlap.
        matched_example_ids = []
        for row in input_ids:
            row_tokens = row.tolist()
            best_example, best_score = -1, -1
            for candidate, example_tokens in enumerate(example_token_lists):
                overlap = np.sum([a == b for a, b in zip(row_tokens, example_tokens)])
                if overlap > best_score:
                    best_example, best_score = candidate, overlap
            assert best_example > -1
            matched_example_ids.append(best_example)
        assert len(matched_example_ids) == len(input_ids)  # the same sample can occur twice in the same batch

        prefix_attention_mask = torch.ones((len(input_ids), len(mask_array))).long().to(input_ids.device)
        # Reaching this method means remove_train_example_from_icl is enabled.
        for row_idx in range(batch_size):
            span_start, span_end = example_spans[matched_example_ids[row_idx]]
            prefix_attention_mask[row_idx, span_start:span_end] = 0

        return prefix_attention_mask, example_spans, matched_example_ids

    def ICL_projection(self):
        """
        Clip the learned prompt embeddings to an L2 ball around the original ICL token embeddings.

        The perturbation ``delta = prompt_embedding_weight - word_embeddings(ICL_token_ids)`` is
        rescaled according to ``peft_config['default'].ICL_projection_epsilon_type``:

        * ``'token_wise'``: every row whose delta is non-zero is clipped to its own radius, read
          from ``self.sample_total_epsilon`` (indexed per prompt row).
        * ``'all_tokens'``: the norm of the whole delta matrix is clipped to
          ``ICL_projection_epsilon * sqrt(n_active * encoder_hidden_size / 2048)``, where
          ``n_active`` is the number of rows with a non-zero delta.

        Rows whose delta is exactly zero are left untouched. The result is written back to
        ``self.prompt_encoder.default.embedding.weight.data``; everything runs under
        ``torch.no_grad()``.

        Raises:
            NotImplementedError: if ``ICL_projection_epsilon_type`` is neither ``'token_wise'`` nor
                ``'all_tokens'``, or if ``force_same_masks == 1`` (that path is not validated yet).
        """
        cfg = self.peft_config['default']
        embedding = self.prompt_encoder.default.embedding

        with torch.no_grad():
            device = embedding.weight.device
            # Reference point of the projection: the frozen word embeddings of the ICL tokens (as float32).
            original_embedding_weights = self.word_embeddings(
                torch.Tensor(cfg.ICL_token_ids).long().to(device)
            ).detach().to(torch.float32).clone()
            new_embeddings_weights = embedding.weight.clone()

            epsilon_type = cfg.ICL_projection_epsilon_type
            if epsilon_type not in ('token_wise', 'all_tokens'):
                raise NotImplementedError

            embedding_delta = embedding.weight - original_embedding_weights
            row_norm = torch.norm(embedding_delta, p=2, dim=1)
            # Only rows that actually moved away from their original embedding are rescaled.
            projection_mask = row_norm > 0

            if torch.any(projection_mask):
                if epsilon_type == 'token_wise':
                    sample_total_epsilon = self.sample_total_epsilon.clone().detach().to(device=device)
                    row_epsilon = sample_total_epsilon[projection_mask]
                    # eps / max(norm, eps): equals 1 inside the ball and shrinks the row onto it otherwise.
                    scale = row_epsilon / row_norm[projection_mask].clamp(min=row_epsilon)
                else:  # 'all_tokens'
                    total_norm = torch.norm(embedding_delta, p=2)
                    total_epsilon = cfg.ICL_projection_epsilon * torch.sqrt(
                        torch.sum(projection_mask) * cfg.encoder_hidden_size / 2048).item()
                    scale = total_epsilon / total_norm.clamp(min=total_epsilon)
                embedding_delta[projection_mask] *= scale.view(-1, 1)
                new_embeddings_weights.data = original_embedding_weights.data + embedding_delta.data

            if cfg.force_same_masks == 1:
                # This path has never been validated. Raise a real exception instead of `assert False`:
                # asserts are stripped under `python -O`, which would let the unverified code below run.
                raise NotImplementedError('force_same_masks == 1 is not supported yet')  # TODO check it
                # Unreachable until the TODO above is resolved: replaces the rows with mask code
                # 4 * i + 1 and 4 * i + 3 of every example by their mean across examples.
                ICL_mask = torch.Tensor(cfg.ICL_mask).long()
                N_samples = (torch.max(ICL_mask) // 4).item()

                ## calc mean
                list_input_mask = []
                list_output_mask = []
                for i in range(N_samples):
                    list_input_mask.append(new_embeddings_weights[ICL_mask == 4 * i + 1])
                    list_output_mask.append(new_embeddings_weights[ICL_mask == 4 * i + 3])
                mean_input_mask = torch.stack(list_input_mask).mean(0)
                mean_output_mask = torch.stack(list_output_mask).mean(0)

                ## update mean
                for i in range(N_samples):
                    new_embeddings_weights[ICL_mask == 4 * i + 1] = mean_input_mask
                    new_embeddings_weights[ICL_mask == 4 * i + 3] = mean_output_mask

            embedding.weight.data = new_embeddings_weights.data

    def get_extended_labels(self, input_ids, labels, sample_mask, list_begin_of_example_in_ICL,
                            list_sample_order_in_ICL):
        """
        Prepend the ICL token ids to `labels` and mask out (with -100) the positions excluded from the loss.

        Args:
            input_ids: only its ``.device`` is used here.
            labels: integer tensor of labels for the current batch; ``labels.shape[0]`` is the batch size.
            sample_mask: integer mask aligned with `labels`. Its non-zero entries are shifted by the
                largest ICL mask value so that they continue the numbering of the ICL mask.
            list_begin_of_example_in_ICL: list of ``(begin, end)`` column ranges, one per ICL example.
                When non-empty (training only), the range belonging to each batch row is excluded from the loss.
            list_sample_order_in_ICL: for every batch row, the index into `list_begin_of_example_in_ICL`.

        Returns:
            ``(extended_labels, extended_input_mask)``: the ICL ids concatenated with `labels` along dim 1
            (excluded positions set to -100), and the ICL mask concatenated with the shifted `sample_mask`.

        Raises:
            Exception: if ``ICL_loss_pattern`` is not one of ``'all_labels'``, ``'last_label'``,
                ``'last_and_random_label'`` (training only).
            ValueError: if ``ICL_loss_pattern == 'all_labels'`` and ``used_loss_tokens`` is neither
                ``'answer'`` nor ``'input_and_answer'`` (training only).
        """
        cfg = self.peft_config['default']
        device = input_ids.device
        batch_size = labels.shape[0]

        tensor_ICL_ids_batch = torch.Tensor(cfg.ICL_token_ids).view(1, -1).repeat(
            (batch_size, 1)).long().to(device=device)
        extended_labels = torch.cat((tensor_ICL_ids_batch, labels), dim=1)

        tensor_ICL_mask = torch.Tensor(cfg.ICL_mask).view(1, -1).repeat(
            (batch_size, 1)).long().to(device=device)
        # Continue the mask numbering of the ICL part into the current sample (zeros stay zero).
        shift_mask = tensor_ICL_mask.max()
        tensor_input_mask_shifted = sample_mask.clone().long().to(device=device)
        tensor_input_mask_shifted[tensor_input_mask_shifted != 0] += shift_mask

        extended_input_mask = torch.cat((tensor_ICL_mask, tensor_input_mask_shifted), dim=1)

        ## select the loss pattern ##
        # `loss_mask` is True where the label must be ignored (set to -100 below).
        test_idx = extended_input_mask.max()
        if not self.training:
            loss_mask = torch.zeros_like(extended_labels).bool().to(device=device)
        else:
            ICL_loss_pattern = cfg.ICL_loss_pattern
            if ICL_loss_pattern in ['all_labels']:
                used_loss_tokens = cfg.used_loss_tokens
                if used_loss_tokens in ['answer']:
                    # keep only positive mask codes that are multiples of 4
                    loss_mask = (extended_input_mask <= 0) | (torch.remainder(extended_input_mask, 4) != 0)
                elif used_loss_tokens in ['input_and_answer']:
                    # keep only positive, even mask codes
                    loss_mask = (extended_input_mask <= 0) | ((torch.remainder(extended_input_mask, 4) != 0) & (
                            torch.remainder(extended_input_mask, 2) != 0))
                else:
                    # Previously fell through with `loss_mask` unassigned and crashed later
                    # with an UnboundLocalError; fail here with a meaningful message instead.
                    raise ValueError(f'used_loss_tokens unknown: {used_loss_tokens!r}')

            elif ICL_loss_pattern in ['last_label']:
                loss_mask = extended_input_mask != test_idx
            elif ICL_loss_pattern in ['last_and_random_label']:
                N_samples = extended_input_mask.max() // 4 - 1
                num = cfg.random_losses_num
                wanted_sample_loss = torch.randperm(N_samples)[:num]
                loss_mask = extended_input_mask != test_idx
                for i in wanted_sample_loss:
                    loss_mask = loss_mask & (extended_input_mask != (1 + i) * 4)
            else:
                raise Exception('ICL loss type unknown')

            # remove the sample that is used for the training from the loss calculation
            if len(list_begin_of_example_in_ICL) > 0:
                for i in range(batch_size):
                    sample_number = list_sample_order_in_ICL[i]
                    begin_idx, end_idx = list_begin_of_example_in_ICL[sample_number]
                    loss_mask[i, begin_idx:end_idx] = 1

        extended_labels[loss_mask] = -100

        return extended_labels, extended_input_mask


class IPTPeftModelForCausalLM(OursPeftModelForCausalLM):
    """
    Variant of `OursPeftModelForCausalLM` that restricts gradient updates of the prompt embedding
    to the rows whose `ICL_mask` entry equals -1, and turns `ICL_projection` into a no-op.
    """

    def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None:
        super().__init__(model, peft_config, adapter_name)
        # Without ICL token ids there is nothing to restrict, so no hook is installed.
        if self.peft_config['default'].ICL_token_ids is None:
            return
        self.IPT_set_update_embeds()

    def IPT_set_update_embeds(self):
        """Register a gradient hook that zeroes the gradient of every prompt row whose mask is not -1."""
        icl_mask = torch.Tensor(self.peft_config['default'].ICL_mask)
        # Column vector of 0/1 flags; broadcasts over the embedding dimension of the gradient.
        keep_gradient = (icl_mask == -1).long().view(-1, 1)

        def backward_hook(grad):
            return grad * keep_gradient.to(grad.device)

        prompt_weight = self.prompt_encoder.default.embedding.weight
        prompt_weight.register_hook(backward_hook)

    def ICL_projection(self):
        """Intentionally does nothing: this variant applies no projection."""
        return None


def our_get_peft_model(
        model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default", mixed: bool = False
) -> PeftModel | PeftMixedModel:
    """
    Wrap `model` in a Peft model, using `OursPeftModelForCausalLM` for the regular path.

    Dispatch order:
      1. `mixed=True` -> `PeftMixedModel`.
      2. Task type not in `MODEL_TYPE_TO_PEFT_MODEL_MAPPING` and config is not prompt learning -> `PeftModel`.
      3. Otherwise -> `OursPeftModelForCausalLM`; prompt-learning configs are first passed through
         `_prepare_prompt_learning_config` together with the base model's config dict.

    `peft_config.base_model_name_or_path` is overwritten with the model's `name_or_path`
    (or `None` when the model has no such attribute).

    Args:
        model ([`transformers.PreTrainedModel`]):
            Model to be wrapped.
        peft_config ([`PeftConfig`]):
            Configuration object containing the parameters of the Peft model.
        adapter_name (`str`, `optional`, defaults to `"default"`):
            The name of the adapter to be injected.
        mixed (`bool`, `optional`, defaults to `False`):
            Whether to allow mixing different (compatible) adapter types.
    """
    base_config = getattr(model, "config", {"model_type": "custom"})
    base_config = base_config.to_dict() if hasattr(base_config, "to_dict") else base_config

    peft_config.base_model_name_or_path = model.__dict__.get("name_or_path")

    if mixed:
        return PeftMixedModel(model, peft_config, adapter_name=adapter_name)

    if not (peft_config.task_type in MODEL_TYPE_TO_PEFT_MODEL_MAPPING or peft_config.is_prompt_learning):
        return PeftModel(model, peft_config, adapter_name=adapter_name)

    effective_config = (
        _prepare_prompt_learning_config(peft_config, base_config)
        if peft_config.is_prompt_learning
        else peft_config
    )
    return OursPeftModelForCausalLM(model, effective_config, adapter_name=adapter_name)


def IPT_get_peft_model(
        model: PreTrainedModel, peft_config: PeftConfig, adapter_name: str = "default", mixed: bool = False
) -> PeftModel | PeftMixedModel:
    """
    Wrap `model` in a Peft model, using `IPTPeftModelForCausalLM` for the regular path.

    Dispatch order:
      1. `mixed=True` -> `PeftMixedModel`.
      2. Task type not in `MODEL_TYPE_TO_PEFT_MODEL_MAPPING` and config is not prompt learning -> `PeftModel`.
      3. Otherwise -> `IPTPeftModelForCausalLM`; prompt-learning configs are first passed through
         `_prepare_prompt_learning_config` together with the base model's config dict.

    `peft_config.base_model_name_or_path` is overwritten with the model's `name_or_path`
    (or `None` when the model has no such attribute).

    Args:
        model ([`transformers.PreTrainedModel`]):
            Model to be wrapped.
        peft_config ([`PeftConfig`]):
            Configuration object containing the parameters of the Peft model.
        adapter_name (`str`, `optional`, defaults to `"default"`):
            The name of the adapter to be injected.
        mixed (`bool`, `optional`, defaults to `False`):
            Whether to allow mixing different (compatible) adapter types.
    """
    config_of_base = getattr(model, "config", {"model_type": "custom"})
    if hasattr(config_of_base, "to_dict"):
        config_of_base = config_of_base.to_dict()

    peft_config.base_model_name_or_path = model.__dict__.get("name_or_path")

    wrapper_kwargs = {"adapter_name": adapter_name}
    if mixed:
        return PeftMixedModel(model, peft_config, **wrapper_kwargs)

    task_is_unmapped = peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING
    if task_is_unmapped and not peft_config.is_prompt_learning:
        return PeftModel(model, peft_config, **wrapper_kwargs)

    if peft_config.is_prompt_learning:
        peft_config = _prepare_prompt_learning_config(peft_config, config_of_base)
    return IPTPeftModelForCausalLM(model, peft_config, **wrapper_kwargs)


@dataclass
class OurPromptTuningConfig(PromptTuningConfig):
    """
    `PromptTuningConfig` extended with the options read by the ICL-based prompt-tuning models in this file.

    NOTE: several fields used to end with a trailing comma (``= field(default=0),``), which made the
    class-level default a 1-tuple wrapping a `Field` object instead of the intended value. The commas
    are removed so the declared defaults actually apply. Field order is unchanged.
    """
    # Token ids of the ICL prompt; converted with `torch.Tensor(...).long()` where they are used.
    ICL_token_ids: Optional[torch.TensorType] = field(default=None)
    # Integer code per ICL prompt token. -1 marks the rows IPT keeps trainable; positive codes are
    # consumed in groups of four per example (4 * i + 1 ... 4 * i + 4).
    ICL_mask: Optional[torch.TensorType] = field(default=None)
    # Not read by the code reviewed here — presumably selects which prompt tokens are trained; TODO confirm.
    update_tokens: Optional[str] = field(default=None)
    # One of 'all_labels', 'last_label', 'last_and_random_label' (compared as a string by get_extended_labels).
    ICL_loss_pattern: Optional[str] = field(default=None)
    # Number of randomly chosen examples whose label joins the loss for 'last_and_random_label'.
    random_losses_num: Optional[int] = field(default=1)
    # Base radius used by ICL_projection for the 'all_tokens' epsilon type.
    ICL_projection_epsilon: Optional[float] = field(default=100.0)
    # 'token_wise' or 'all_tokens'; any other value makes ICL_projection raise NotImplementedError.
    ICL_projection_epsilon_type: Optional[str] = field(default=None)
    # When equal to 1, the ICL example matching the training sample is masked out of the attention.
    remove_train_example_from_icl: Optional[int] = field(default=False)
    # The options below are consumed outside the code reviewed here; semantics to be confirmed.
    peft_weighted_loss_type: Optional[str] = field(default='none')
    peft_weighted_loss_decay_factor: Optional[float] = field(default=1.0)
    decay_projection: Optional[int] = field(default=0)
    decay_projection_base: Optional[float] = field(default=0.99)
    # With ICL_loss_pattern == 'all_labels': 'answer' or 'input_and_answer'.
    used_loss_tokens: Optional[str] = field(default='answer')
    # Setting this to 1 is currently rejected by ICL_projection (unverified code path).
    force_same_masks: Optional[int] = field(default=0)
    ICL_projection_format_epsilon: Optional[float] = field(default=1)
    ICL_projection_epsilon_multiplier: Optional[float] = field(default=1.0)
    ICL_projection_format_epsilon_multiplier: Optional[float] = field(default=1.0)
    number_of_learnable_tokens: Optional[int] = field(default=0)
    peft_lr: Optional[float] = field(default=1e-4)
