import copy
import os
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Literal, Optional, Sequence, Tuple, Union

import transformers
from transformers import Trainer
from transformers.modeling_utils import *
from transformers.trainer import _is_peft_model
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
from transformers.data.data_collator import DataCollator

from transformers.training_args import TrainingArguments
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import TrainerCallback
from transformers.trainer_utils import EvalPrediction
import torch
from torch import nn
from torch.utils.data import Dataset, IterableDataset

from datasets import load_dataset
from peft import get_peft_model, HOFTConfig
import wandb

# Sentinel written into label positions that should not be trained on: the
# prompt tokens (see `preprocess`) and label padding (see the data collator).
# Presumably matches the loss function's ignore index -- confirm.
IGNORE_INDEX = -100
# Fallback special tokens. In the visible code they are only referenced from
# commented-out blocks inside `train()`.
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
DEFAULT_UNK_TOKEN = "</s>"
# Instruction template; `{instruction}` is filled in by
# `train_tokenize_function` to build the prompt part of each example.
PROMPT = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n\n"
    "### Instruction:\n{instruction}\n\n### Response:"
)


class MyTrainer(Trainer):
    """`Trainer` variant that additionally stores a regularisation weight.

    ``lamda`` is kept on the instance for an orthogonality penalty whose
    implementation is currently commented out in `compute_loss`; with that
    block disabled the returned loss is the label-smoothed loss (when a label
    smoother is configured) or the loss reported by the model.
    """

    def __init__(
            self,
            model: Union[PreTrainedModel, nn.Module] = None,
            args: TrainingArguments = None,
            data_collator: Optional[DataCollator] = None,
            train_dataset: Optional[Union[Dataset, IterableDataset, "datasets.Dataset"]] = None,
            eval_dataset: Optional[Union[Dataset, Dict[str, Dataset], "datasets.Dataset"]] = None,
            tokenizer: Optional[PreTrainedTokenizerBase] = None,
            model_init: Optional[Callable[[], PreTrainedModel]] = None,
            compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
            callbacks: Optional[List[TrainerCallback]] = None,
            optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
            preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
            lamda: float = 1e-4
    ):
        """Forward the standard `Trainer` arguments and remember ``lamda``.

        Args:
            model, args, data_collator, train_dataset, eval_dataset, tokenizer,
            model_init, compute_metrics, callbacks, optimizers,
            preprocess_logits_for_metrics: passed through unchanged to
                `Trainer.__init__`.
            lamda: weight stored on the instance for the (currently disabled)
                orthogonality regulariser.
        """
        # Everything is forwarded by keyword: the positional order of
        # `Trainer.__init__` is not stable across transformers releases, and
        # a positional call would silently route arguments to the wrong
        # parameter. `optimizers` is forwarded too; it used to be accepted
        # here and then dropped.
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )
        self.lamda = lamda

    def compute_loss(self, model, inputs, num_items_in_batch=None, return_outputs=False):
        """Compute the training loss for one batch.

        Args:
            model: the (possibly wrapped) model to run.
            inputs: dict of keyword arguments for ``model``; ``"labels"`` is
                popped from it when a label smoother is configured.
            num_items_in_batch: accepted for compatibility with the upstream
                hook and not used here. It defaults to ``None`` so that
                callers which omit it (upstream `Trainer` does not always
                supply it, e.g. when evaluating -- confirm against the
                installed version) do not fail with a ``TypeError``.
            return_outputs: when true, return ``(loss, outputs)`` instead of
                just ``loss``.

        Returns:
            The loss, or a ``(loss, outputs)`` tuple if ``return_outputs``.

        Raises:
            ValueError: if no label smoother is used and the model's dict
                output does not contain a ``"loss"`` entry.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            # Label smoothing path: causal LMs need their labels shifted, so
            # look up the underlying model class name (through the PEFT
            # wrapper if there is one).
            unwrapped_model = unwrap_model(model)
            if _is_peft_model(unwrapped_model):
                model_name = unwrapped_model.base_model.model._get_name()
            else:
                model_name = unwrapped_model._get_name()
            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        # ------------------------------------------------------------------------------
        # Disabled orthogonality regulariser (kept for reference). When enabled
        # it adds `self.lamda * ||I - U^T U||` for every `oft_r` parameter.

        #for name, param in model.named_parameters():
        #    if 'oft_r' in name:
        #        device = param.device
        #        householder_U_norm = param / param.norm(dim=0)
        #        orth_loss = torch.norm(
        #            torch.eye(householder_U_norm.size(1), device=device) - householder_U_norm.t() @ householder_U_norm)
        #        print(self.lamda)
        #        loss = loss + self.lamda * orth_loss.to(loss.device)

        # ------------------------------------------------------------------------------

        return (loss, outputs) if return_outputs else loss


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """Command-line arguments for this script.

    Extends `transformers.TrainingArguments` with model, dataset and adapter
    options. Note that this class reuses the name ``TrainingArguments``, so
    after this definition the name refers to this subclass.
    """

    run_name: Optional[str] = field(default="experiment")
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
    adapter_name_or_path: Optional[str] = field(default=None)
    # Optional in the type because the default is None; `train()` needs a value.
    data_path: Optional[str] = field(default=None, metadata={"help": "Path to the training data."})
    dataset_split: str = field(
        default="train[:100000]",
        metadata={"help": "Dataset split to load, e.g. 'train', 'test' or a slice such as 'train[:100000]'."}
    )
    # Optional in the type because the default is None; `train()` reads
    # element 0 as the query column and element 1 as the response column.
    dataset_field: Optional[List[str]] = field(
        default=None, metadata={"help": "Fields of dataset input and output."}
    )
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(default=512, metadata={
        "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."}, )
    # `None` is a meaningful value (selects full fine-tuning in `train()`),
    # hence Optional.
    hrft_r: Optional[int] = field(default=8, metadata={
        "help": "The rank of the adapter. When passing `None` and `adapter_name_or_path` is also `None`, full fine-tuning is used."})
    init_a: float = field(default=1e-4, metadata={"help": "The initial weights"})
    eps: float = field(default=1e-4, metadata={"help": "The control strength of COFT. The freedom of rotation."})
    lamda: float = field(default=1e-4, metadata={"help": "The control strength of regularity"})
    # NOTE(review): not referenced anywhere in the visible code; meaning unknown.
    add_orth: str = field(default='none', metadata={"help": ""})
    peft_type: str = field(default='hoft', metadata={"help": "Use hoft or shoft methods"})
    init_weights: Literal[True, "pissa"] = field(
        default=True,
        metadata={
            "help": (
                "Passing True (default) results in the LoRA initialization. "
                "Passing `pissa` results in PiSSA initialization."
            ),
        },
    )


def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str):
    """Gather the model weights, move them to CPU and write them to ``output_dir``.

    The state dict is collected unconditionally; only when
    ``trainer.args.should_save`` is true is it copied to CPU and handed to
    ``trainer._save``.
    """
    weights = trainer.model.state_dict()
    if not trainer.args.should_save:
        return

    cpu_weights = {}
    for name, tensor in weights.items():
        cpu_weights[name] = tensor.cpu()
    # Drop the reference to the original tensors before saving.
    del weights
    trainer._save(output_dir, state_dict=cpu_weights)  # noqa


def smart_tokenizer_and_embedding_resize(
        special_tokens_dict: Dict,
        tokenizer: transformers.PreTrainedTokenizer,
        model: transformers.PreTrainedModel,
):
    """Add special tokens and grow the model's embeddings to match.

    After resizing, the rows belonging to the newly added tokens are set, in
    both the input and the output embedding matrices, to the mean of the
    pre-existing rows.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if added <= 0:
        return

    # Fetch both weight tensors first, then compute both means, then assign,
    # so every mean is taken over the old rows only.
    embedding_tables = (
        model.get_input_embeddings().weight.data,
        model.get_output_embeddings().weight.data,
    )
    mean_rows = [table[:-added].mean(dim=0, keepdim=True) for table in embedding_tables]
    for table, mean_row in zip(embedding_tables, mean_rows):
        table[-added:] = mean_row


def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
    """Tokenize each string on its own and report the token count of each.

    Args:
        strings: texts to tokenize, one tokenizer call per text.
        tokenizer: tokenizer to use; sequences are truncated to
            ``tokenizer.model_max_length``.

    Returns:
        A dict with ``input_ids`` and ``labels`` (the same list of 1-D
        tensors) and ``input_ids_lens`` / ``labels_lens`` (the same list of
        integer lengths).
    """
    tokenized_list = [
        tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            max_length=tokenizer.model_max_length,
            truncation=True,
        )
        for text in strings
    ]
    input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
    # Every text is tokenized in a batch of one, so no padding is ever added
    # and the true length is simply the tensor length. Counting tokens that
    # differ from `pad_token_id` under-counts whenever a real token shares the
    # pad id (e.g. pad token set to EOS/BOS) and fails outright when the
    # tokenizer has no pad token.
    input_ids_lens = labels_lens = [int(ids.shape[0]) for ids in input_ids]
    return dict(
        input_ids=input_ids,
        labels=labels,
        input_ids_lens=input_ids_lens,
        labels_lens=labels_lens,
    )


def preprocess(
        sources: Sequence[str],
        targets: Sequence[str],
        tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """Tokenize prompt+response pairs and mask the prompt part of the labels.

    Each source is concatenated with its target and tokenized; the sources
    are tokenized separately as well so that their token counts are known.
    Labels start as a deep copy of the input ids, and the first
    ``source length`` positions of every label are overwritten with
    ``IGNORE_INDEX``.
    """
    full_texts = [prompt + completion for prompt, completion in zip(sources, targets)]
    tokenized_full = _tokenize_fn(full_texts, tokenizer)
    tokenized_prompts = _tokenize_fn(sources, tokenizer)

    input_ids = tokenized_full["input_ids"]
    labels = copy.deepcopy(input_ids)
    prompt_lengths = tokenized_prompts["input_ids_lens"]
    for label_row, prompt_len in zip(labels, prompt_lengths):
        label_row[:prompt_len] = IGNORE_INDEX
    return {"input_ids": input_ids, "labels": labels}


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning.

    Right-pads ``input_ids`` with the tokenizer's pad id and ``labels`` with
    ``IGNORE_INDEX``, and builds a boolean attention mask that is true on the
    original (unpadded) positions of every sequence.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Turn a list of ``{"input_ids", "labels"}`` examples into a batch.

        Args:
            instances: examples whose ``input_ids`` and ``labels`` are
                integer sequences (lists or tensors).

        Returns:
            Dict with batch-first ``input_ids``, ``labels`` and a boolean
            ``attention_mask``.
        """
        # `as_tensor` accepts both plain lists and existing tensors without
        # the copy-construct warning that `torch.tensor(tensor)` emits.
        input_ids = [torch.as_tensor(instance["input_ids"]) for instance in instances]
        labels = [torch.as_tensor(instance["labels"]) for instance in instances]
        # Record the real lengths before padding.
        lengths = [ids.size(0) for ids in input_ids]

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

        # Build the mask from the recorded lengths instead of comparing with
        # the pad id: when the pad token is also a real token (`train()` sets
        # pad_token = eos_token), an id comparison would mask genuine tokens.
        positions = torch.arange(input_ids.size(1), device=input_ids.device).unsqueeze(0)
        length_column = torch.tensor(lengths, device=input_ids.device).unsqueeze(1)
        attention_mask = positions < length_column
        return dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask,
        )


def train_tokenize_function(examples, tokenizer, query, response):
    """Build prompts and EOS-terminated answers for a batch, then tokenize.

    ``examples[query]`` supplies the instructions that are inserted into
    ``PROMPT``; ``examples[response]`` supplies the answers, each followed by
    the tokenizer's EOS token. The result of `preprocess` is returned.
    """
    prompts = []
    for instruction in examples[query]:
        prompts.append(PROMPT.format_map({"instruction": instruction}))

    answers = []
    for output in examples[response]:
        answers.append(f"{output}{tokenizer.eos_token}")

    return preprocess(prompts, answers, tokenizer)


def train():
    """Parse CLI arguments, fine-tune the model and save the result.

    Steps: parse `TrainingArguments`, optionally start a W&B run, load the
    base causal LM, wrap it with a HOFT/SHOFT adapter when ``hrft_r`` is set,
    tokenize the dataset, train with `Trainer`, merge the adapter (if any)
    into the base weights and save model and tokenizer under
    ``<output_dir>/ft``.

    Raises:
        ValueError: if ``data_path`` is missing or ``dataset_field`` does not
            name both the query and the response column.
    """
    parser = transformers.HfArgumentParser(TrainingArguments)
    script_args = parser.parse_args_into_dataclasses()[0]
    print(script_args)

    # Fail fast with a clear message instead of crashing on `None` after the
    # model has already been loaded.
    if script_args.data_path is None:
        raise ValueError("`data_path` must be provided.")
    if script_args.dataset_field is None or len(script_args.dataset_field) < 2:
        raise ValueError(
            "`dataset_field` must contain two entries: the query field and the response field."
        )

    use_wandb = 'wandb' in script_args.report_to
    if use_wandb:
        wandb.init(
            project="Mathematical reasoning",
            name=f"LLaMA2-{script_args.peft_type}",
        )

    model = transformers.AutoModelForCausalLM.from_pretrained(
        script_args.model_name_or_path,
        device_map="auto",
    )

    #if script_args.adapter_name_or_path is not None:
    #    print(f"Load {script_args.init_weights} from {script_args.adapter_name_or_path}: ", )
    #    model = PeftModel.from_pretrained(model, script_args.model_name_or_path,
    #                                      subfolder=script_args.adapter_name_or_path, is_trainable=True)
    use_peft = script_args.hrft_r is not None
    if use_peft:
        #print(f"Initilized {script_args.init_weights} layers")

        hra_config = HOFTConfig(
            r=script_args.hrft_r,
            target_modules=["q_proj", "v_proj"],
            use_shoft=script_args.peft_type == 'shoft'
        )
        model = get_peft_model(model, hra_config)
        print(model)
        # Only the PEFT wrapper provides this method; calling it on the bare
        # model in the full fine-tuning branch would raise AttributeError.
        model.print_trainable_parameters()
    else:
        print("Full Parameter Fine-Tuning")

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        script_args.model_name_or_path,
        model_max_length=script_args.model_max_length,
        padding_side="right",
        use_fast=True,
    )
    # The EOS token doubles as the pad token, so no embedding resize is needed.
    tokenizer.pad_token = tokenizer.eos_token
    #if tokenizer.pad_token is None:
    #    smart_tokenizer_and_embedding_resize(
    #        special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
    #        tokenizer=tokenizer,
    #        model=model,
    #    )
    #if "llama" in script_args.model_name_or_path:
    #    tokenizer.add_special_tokens(
    #        {
    #            "eos_token": DEFAULT_EOS_TOKEN,
    #            "bos_token": DEFAULT_BOS_TOKEN,
    #            "unk_token": DEFAULT_UNK_TOKEN,
    #        }
    #    )

    raw_train_datasets = load_dataset(script_args.data_path, split=script_args.dataset_split)
    train_dataset = raw_train_datasets.map(
        train_tokenize_function,
        batched=True,
        batch_size=3000,
        num_proc=4,
        remove_columns=raw_train_datasets.column_names,
        load_from_cache_file=True,
        desc="Running tokenizer on train dataset",
        fn_kwargs={"tokenizer": tokenizer, "query": script_args.dataset_field[0],
                   "response": script_args.dataset_field[1]}
    )

    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    #data_module = dict(train_dataset=train_dataset, data_collator=data_collator)
    #trainer = MyTrainer(model=model, tokenizer=tokenizer, lamda=script_args.lamda,  args=script_args, **data_module)

    # NOTE(review): only the options listed here are taken from the command
    # line; the training output directory, fp16 and the optimizer are fixed,
    # and the final model is saved under `script_args.output_dir` instead.
    train_args = TrainingArguments(
            "models/llama2",
            learning_rate=script_args.learning_rate,
            per_device_train_batch_size=script_args.per_device_train_batch_size,
            save_total_limit=0,
            warmup_ratio=script_args.warmup_ratio,
            gradient_accumulation_steps=script_args.gradient_accumulation_steps,
            num_train_epochs=script_args.num_train_epochs,
            optim="adamw_torch",
            lr_scheduler_type= script_args.lr_scheduler_type,
            fp16=True,
            report_to=script_args.report_to,  # enable logging to W&B
            logging_steps=script_args.logging_steps,  # how often to log to W&B
        )

    trainer = Trainer(
            model,
            train_args,
            train_dataset=train_dataset,
            tokenizer=tokenizer,
            data_collator=data_collator,
        )

    model.config.use_cache = False
    trainer.train()

    # Merge the adapter into the base weights (PEFT only) and save the result.
    if use_peft:
        model = model.merge_and_unload()
    save_dir = os.path.join(script_args.output_dir, 'ft')
    model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)

    if use_wandb:
        wandb.finish()


# Run the fine-tuning pipeline only when this file is executed as a script.
if __name__ == "__main__":

    train()
