from typing import List, Union, Any, Dict, Mapping

import torch
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling
from transformers.data.data_collator import pad_without_fast_tokenizer_warning, _torch_collate_batch
import random


def train_peft_model(model, tokenizer, dataset, args):
    """Optionally train ``model`` with the HuggingFace ``Trainer`` and return it.

    Args:
        model: Model handed to ``Trainer``. When ``args.method_boost_type`` is
            ``'our_prompt_tuning'`` it must also expose ``ICL_projection()``,
            which is called once training has finished.
        tokenizer: Tokenizer passed to ``OurDataCollatorForLanguageModeling``.
        dataset: Training dataset passed to ``Trainer`` as ``train_dataset``.
        args: Configuration object. Reads ``method_boost_type`` and
            ``train_model`` always; ``peft_lr``, ``peft_steps``,
            ``peft_batch_size``, ``peft_epochs`` and ``fp16`` only when
            training actually runs.

    Returns:
        The same ``model`` object (trained in place when training runs).

    Raises:
        ValueError: If ``args.method_boost_type`` is not a known type.
    """
    trainable_types = ('LORA', 'prompt_tuning', 'prefix_tuning', 'our_prompt_tuning', 'IPT')
    boost_type = args.method_boost_type

    if boost_type != 'none' and boost_type not in trainable_types:
        raise ValueError('Unknown method boost type')

    # 'none' defines no learning rate / step budget. Previously this fell
    # through to TrainingArguments with unbound locals and crashed; there is
    # nothing to train for it, so hand the model back untouched.
    if not args.train_model or boost_type == 'none':
        return model

    # Every trainable method currently shares the same hyper-parameters.
    learning_rate = args.peft_lr
    max_steps = args.peft_steps

    training_args = TrainingArguments(
        output_dir='../.',  # Where the model predictions and checkpoints will be written
        use_cpu=False,  # Do not force CPU execution
        auto_find_batch_size=False,  # Use the configured batch size as-is
        learning_rate=learning_rate,
        # NOTE(review): per the transformers docs a positive max_steps
        # overrides num_train_epochs -- confirm which one is intended.
        max_steps=max_steps,
        logging_steps=2,
        per_device_train_batch_size=args.peft_batch_size,
        save_total_limit=1,
        remove_unused_columns=False,  # keep extra columns (e.g. 'sample_mask') for the collator
        num_train_epochs=args.peft_epochs,
        # args.fp16 is compared against a torch dtype, not treated as a flag.
        fp16=bool(args.fp16 == torch.float16),
        save_strategy='no'
    )

    trainer = Trainer(
        model=model,  # The model to optimise
        args=training_args,  # The args for the training.
        train_dataset=dataset,  # The dataset used to train the model.
        data_collator=OurDataCollatorForLanguageModeling(tokenizer, training=True, args=args, mlm=False)
    )

    trainer.train()

    if boost_type == 'our_prompt_tuning':
        model.ICL_projection()

    return model


class OurDataCollatorForLanguageModeling(DataCollatorForLanguageModeling):
    """Language-modeling collator that also handles a per-token ``sample_mask``.

    Each example may carry a ``'sample_mask'`` entry (one value per token).
    The entry is stripped before padding, re-assembled into a zero-padded
    ``(batch, seq_len)`` long tensor, and then either returned under
    ``batch['sample_mask']`` or used to blank out labels, depending on
    ``args.method_boost_type`` and ``training``.
    """

    # Tokens whose sample_mask equals this value keep their label; every other
    # position is set to -100 for the prefix_tuning / prompt_tuning / LORA
    # training path. NOTE(review): presumably the marker for target tokens --
    # confirm against the code that builds 'sample_mask'.
    _KEEP_LABEL_MASK_VALUE = 4

    def __init__(self, tokenizer, args, training=True, mlm=False):
        """Store the configuration and register ``'[PAD]'`` as the pad token.

        Args:
            tokenizer: Tokenizer used for padding; it is modified in place.
            args: Configuration object; ``method_boost_type`` is read per batch.
            training: Whether the collator feeds a training loop.
            mlm: Forwarded to ``DataCollatorForLanguageModeling``.
        """
        super().__init__(tokenizer, mlm=mlm)
        self.args = args
        self.training = training
        # TODO: check why this is needed. NOTE(review): this runs
        # unconditionally and mutates the caller's tokenizer; if it adds a new
        # vocabulary entry the model embeddings presumably must be resized
        # elsewhere -- confirm.
        self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})

    def torch_call(self, examples: List[Union[List[int], Any, Dict[str, Any]]]) -> Dict[str, Any]:
        """Pad ``examples`` into a batch and attach labels / sample mask.

        Args:
            examples: Either mappings of tokenizer features (optionally with a
                ``'sample_mask'`` entry) or plain sequences of token ids.

        Returns:
            A batch containing ``'input_ids'`` (plus whatever the tokenizer's
            padding adds), ``'labels'`` cloned from ``'input_ids'``, and
            ``'sample_mask'`` when ``method_boost_type`` is ``'IPT'`` or
            ``'our_prompt_tuning'`` or the collator is not in training mode.
        """
        # One slot per example so masks stay aligned with their batch row even
        # if only some examples provide one.
        sample_masks = [None] * len(examples)

        if isinstance(examples[0], Mapping):
            # Work on shallow copies: popping from the caller's dicts would
            # drop 'sample_mask' for any later pass over the same objects.
            features = []
            for row, example in enumerate(examples):
                example = dict(example)
                sample_masks[row] = example.pop('sample_mask', None)
                features.append(example)
            batch = pad_without_fast_tokenizer_warning(
                self.tokenizer, features, return_tensors="pt", pad_to_multiple_of=self.pad_to_multiple_of
            )
        else:
            # Plain token-id sequences carry no sample mask.
            batch = {
                "input_ids": _torch_collate_batch(examples, self.tokenizer, pad_to_multiple_of=self.pad_to_multiple_of)
            }

        input_ids = batch["input_ids"]
        tensor_sample_mask = torch.zeros_like(input_ids, dtype=torch.long)
        # NOTE(review): masks are written starting at column 0, which lines up
        # with the tokens only if the tokenizer pads on the right -- confirm.
        for row, mask in enumerate(sample_masks):
            if mask is None:
                continue
            mask = torch.as_tensor(mask, dtype=torch.long)
            tensor_sample_mask[row, :mask.shape[0]] = mask

        labels = input_ids.clone()
        batch['labels'] = labels

        boost_type = self.args.method_boost_type
        if boost_type in ('IPT', 'our_prompt_tuning') or not self.training:
            batch['sample_mask'] = tensor_sample_mask

        if self.training and boost_type in ('prefix_tuning', 'prompt_tuning', 'LORA'):
            for row, mask in enumerate(sample_masks):
                if mask is None:
                    continue
                # Compare the padded tensor row, not the raw per-example mask:
                # a Python list compared with `!= 4` is just `True` (masking the
                # whole row) and a shorter tensor would not match the row shape.
                labels[row, tensor_sample_mask[row] != self._KEEP_LABEL_MASK_VALUE] = -100

        return batch
